add different llms
- app.py +7 -7
- backend/query_llm.py +13 -5
app.py
CHANGED
@@ -63,25 +63,25 @@ def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encode
     prompt_html = template_html.render(documents=documents, query=query)
 
     if llm_model == "mistralai/Mistral-7B-Instruct-v0.2":
-
+        generate_fn = generate_hf
     if llm_model == "mistralai/Mistral-7B-v0.1":
-
+        generate_fn = generate_hf
     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-
+        generate_fn = generate_hf
     if llm_model == "gpt-3.5-turbo":
-
+        generate_fn = generate_openai
    if llm_model == "gpt-4-turbo-preview":
-
+        generate_fn = generate_openai
 
     #if api_kind == "HuggingFace":
     #    generate_fn = generate_hf
     #elif api_kind == "OpenAI":
     #    generate_fn = generate_openai
     #else:
-
+        raise gr.Error(f"API {api_kind} is not supported")
 
     history[-1][1] = ""
-    for character in generate_fn(prompt, history[:-1]):
+    for character in generate_fn(prompt, history[:-1], llm_model):
         history[-1][1] = character
         yield history, prompt_html
 
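The new branches select the streaming backend per model name and pass llm_model through to it, while the old api_kind dispatch stays commented out; the surviving raise gr.Error(...) no longer sits under an active else:, so there is no catch-all for an unrecognised model. Below is a minimal sketch of an equivalent table-based dispatch that keeps that guard, assuming the same generate_hf/generate_openai callables exported by backend/query_llm.py (the helper name pick_generate_fn is invented for illustration):

import gradio as gr

from backend.query_llm import generate_hf, generate_openai

# Model identifier -> streaming generation backend (the same pairs as in the diff above).
GENERATE_FNS = {
    "mistralai/Mistral-7B-Instruct-v0.2": generate_hf,
    "mistralai/Mistral-7B-v0.1": generate_hf,
    "mistralai/Mixtral-8x7B-Instruct-v0.1": generate_hf,
    "gpt-3.5-turbo": generate_openai,
    "gpt-4-turbo-preview": generate_openai,
}


def pick_generate_fn(llm_model: str):
    """Return the generator for the selected model, or surface a visible Gradio error."""
    try:
        return GENERATE_FNS[llm_model]
    except KeyError:
        raise gr.Error(f"Model {llm_model} is not supported")

Inside bot() this would collapse the five if statements into a single generate_fn = pick_generate_fn(llm_model), leaving the streaming loop unchanged.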
backend/query_llm.py
CHANGED
@@ -34,7 +34,7 @@ OAI_GENERATE_KWARGS = {
 }
 
 
-def format_prompt(message: str, api_kind: str):
+def format_prompt(message: str, api_kind: str, tokenizer_name = None):
     """
     Formats the given message using a chat template.
 
@@ -51,12 +51,13 @@ def format_prompt(message: str, api_kind: str):
     if api_kind == "openai":
         return messages
     elif api_kind == "hf":
+        TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_name)
         return TOKENIZER.apply_chat_template(messages, tokenize=False)
     elif api_kind:
         raise ValueError("API is not supported")
 
 
-def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
+def generate_hf(prompt: str, history: str, hf_model_name: str) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -67,8 +68,14 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
         Generator[str, None, str]: A generator yielding chunks of generated text.
         Returns a final string if an error occurs.
     """
+
 
-
+    HF_CLIENT = InferenceClient(
+        hf_model_name,
+        token=os.getenv("HF_TOKEN")
+    )
+
+    formatted_prompt = format_prompt(prompt, "hf", hf_model_name)
     formatted_prompt = formatted_prompt.encode("utf-8").decode("utf-8")
 
     try:
@@ -93,7 +100,7 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
         raise gr.Error(f"Unhandled Exception: {str(e)}")
 
 
-def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
+def generate_openai(prompt: str, history: str, model_name: str) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -108,7 +115,8 @@ def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
 
     try:
         stream = OAI_CLIENT.chat.completions.create(
-            model=os.getenv("OPENAI_MODEL"),
+            #model=os.getenv("OPENAI_MODEL"),
+            model = model_name,
             messages=formatted_prompt,
             **OAI_GENERATE_KWARGS,
             stream=True
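With these changes, format_prompt loads the chat-template tokenizer for whichever Hugging Face model is requested, and both generators take the model identifier as a third argument: generate_hf builds an InferenceClient for that model and formats the prompt with its template, while generate_openai forwards the name as model=... to the OpenAI chat-completions call. A rough usage sketch, assuming HF_TOKEN and the OpenAI credentials are set in the environment, that the functions are importable as backend.query_llm, and that (as the streaming loop in app.py implies) each generator yields progressively longer versions of the reply:

from backend.query_llm import generate_hf, generate_openai

prompt = "What is retrieval-augmented generation?"  # example query, not from the repo
history = []                                        # no previous chat turns

# Hugging Face-hosted model: the name selects both the InferenceClient endpoint
# and the tokenizer whose chat template formats the prompt.
reply = ""
for reply in generate_hf(prompt, history, "mistralai/Mistral-7B-Instruct-v0.2"):
    pass
print(reply)

# OpenAI model: same call shape; the name becomes model=... in
# OAI_CLIENT.chat.completions.create.
reply = ""
for reply in generate_openai(prompt, history, "gpt-3.5-turbo"):
    pass
print(reply)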