Update app.py
app.py CHANGED
@@ -55,7 +55,7 @@ ov_model = OVModelForCausalLM.from_pretrained(
     trust_remote_code=True,
 )

-#
+# Define stopping criteria for specific token sequences
 class StopOnTokens(StoppingCriteria):
     def __init__(self, token_ids):
         self.token_ids = token_ids
@@ -63,29 +63,26 @@ class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)

-
+if stop_tokens is not None:
+    if isinstance(stop_tokens[0], str):
+        stop_tokens = tok.convert_tokens_to_ids(stop_tokens)
+    stop_tokens = [StopOnTokens(stop_tokens)]
+
+# Helper function for partial text update
+def default_partial_text_processor(partial_text: str, new_text: str) -> str:
+    return partial_text + new_text
+
+text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor)
+
+# Convert conversation history to tokens based on model template
 def convert_history_to_token(history: List[Tuple[str, str]]):
-    """
-    function for conversion history stored as list pairs of user and assistant messages to tokens according to model expected conversation template
-    Params:
-      history: dialogue history
-    Returns:
-      history in token format
-    """
     if pt_model_name == "baichuan2":
         system_tokens = tok.encode(start_message)
         history_tokens = []
         for old_query, response in history[:-1]:
-            round_tokens = []
-            round_tokens.append(195)
-            round_tokens.extend(tok.encode(old_query))
-            round_tokens.append(196)
-            round_tokens.extend(tok.encode(response))
+            round_tokens = [195] + tok.encode(old_query) + [196] + tok.encode(response)
             history_tokens = round_tokens + history_tokens
-        input_tokens = system_tokens + history_tokens
-        input_tokens.append(195)
-        input_tokens.extend(tok.encode(history[-1][0]))
-        input_tokens.append(196)
+        input_tokens = system_tokens + history_tokens + [195] + tok.encode(history[-1][0]) + [196]
         input_token = torch.LongTensor([input_tokens])
     elif history_template is None or has_chat_template:
         messages = [{"role": "system", "content": start_message}]
@@ -97,130 +94,90 @@ def convert_history_to_token(history: List[Tuple[str, str]]):
             messages.append({"role": "user", "content": user_msg})
             if model_msg:
                 messages.append({"role": "assistant", "content": model_msg})
-
         input_token = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")
     else:
         text = start_message + "".join(
-            [
-        )
-        text += "".join(
-            [
-                "".join(
-                    [
-                        current_message_template.format(
-                            num=len(history) + 1,
-                            user=history[-1][0],
-                            assistant=history[-1][1],
-                        )
-                    ]
-                )
-            ]
+            [history_template.format(num=round, user=item[0], assistant=item[1]) for round, item in enumerate(history[:-1])]
         )
+        text += current_message_template.format(num=len(history) + 1, user=history[-1][0], assistant=history[-1][1])
         input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
     return input_token

-# Initialize
+# Initialize search tool
 search = DuckDuckGoSearchRun()

-#
-def fetch_search_results(query: str) -> str:
-    search_results = search.invoke(query)
-    # Displaying search results for debugging
-    print("Search results: ", search_results)
-    return f"Relevant and recent information:\n{search_results}"
-
-# Function to decide if a search is needed based on the user query
+# Determine if a search is needed based on the query
 def should_use_search(query: str) -> bool:
-
-
-        "announcement", "bulletin", "report", "brief", "insight", "disclosure", "update",
+    search_keywords = ["latest", "news", "update", "which", "who", "what", "when", "why", "how", "recent", "current",
+        "announcement", "bulletin", "report", "brief", "insight", "disclosure", "update",
         "release", "memo", "headline", "current", "ongoing", "fresh", "upcoming", "immediate",
         "recently", "new", "now", "in-progress", "inquiry", "query", "ask", "investigate",
         "explore", "seek", "clarify", "confirm", "discover", "learn", "describe", "define",
         "illustrate", "outline", "interpret", "expound", "detail", "summarize", "elucidate",
         "break down", "outcome", "effect", "consequence", "finding", "achievement", "conclusion",
         "product", "performance", "resolution"
-
+    ]
     return any(keyword in query.lower() for keyword in search_keywords)

-#
+# Construct the prompt with optional search context
 def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
-
-
-        "If relevant information is provided below, use it to give an accurate and concise answer. If there is no relevant information available, please rely on your general knowledge and indicate that no recent or specific information is available to answer."
-    )
-
-    # Build the prompt with instructions, search context, and user query
-    prompt = f"{instructions}\n\n"
-    if search_context:
-        prompt += f"{search_context}\n\n"  # Include search context prominently at the top
-
-    # Add the user's query
-    prompt += f"{user_query} ?\n\n"
-
-    # Optionally add recent history for context, without labels
-    # if history:
-    #     prompt += "Recent conversation:\n"
-    #     for user_msg, assistant_msg in history[:-1]:  # Exclude the last message to prevent duplication
-    #         prompt += f"{user_msg}\n{assistant_msg}\n"
-
+    instructions = "Use the information below if relevant to provide an accurate and concise answer. If no information is available, rely on your general knowledge."
+    prompt = f"{instructions}\n\n{search_context if search_context else ''}\n\n{user_query} ?\n\n"
     return prompt

+# Fetch search results for a query
+def fetch_search_results(query: str) -> str:
+    search_results = search.invoke(query)
+    print("Search results:", search_results)  # Optional: Debugging output
+    return f"Relevant and recent information:\n{search_results}"

+# Main chatbot function
 def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
-    """
-    Main callback function for running chatbot on submit button click.
-    """
     user_query = history[-1][0]
-    search_context = ""
+    search_context = fetch_search_results(user_query) if should_use_search(user_query) else ""
+    prompt = construct_model_prompt(user_query, search_context, history)
+    input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids if search_context else convert_history_to_token(history)

-    #
-    if should_use_search(user_query):
-        search_context = fetch_search_results(user_query)
-        prompt = construct_model_prompt(user_query, search_context, history)
-        input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids
-    else:
-        # If no search context, use the original logic with tokenization
-        prompt = construct_model_prompt(user_query, "", history)
-        input_ids = convert_history_to_token(history)
-
-    # Ensure input length does not exceed a threshold (e.g., 2000 tokens)
+    # Limit input length to avoid exceeding token limit
     if input_ids.shape[1] > 2000:
-        # If input exceeds the limit, only use the most recent conversation
         history = [history[-1]]

-    #
+    # Configure response streaming
     streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
-
-
-
-
-        temperature
-
-
-
-
-
-
-
-    generate_kwargs["stopping_criteria"] = StoppingCriteriaList(stop_tokens)
-
-    # Event to signal when streaming is complete
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "do_sample": temperature > 0.0,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+        "streamer": streamer,
+        "stopping_criteria": StoppingCriteriaList(stop_tokens) if stop_tokens is not None else None,
+    }
+
+    # Signal completion
     stream_complete = Event()
-
     def generate_and_signal_complete():
-
-
+        try:
+            ov_model.generate(**generate_kwargs)
+        except RuntimeError as e:
+            # Check if the error message indicates the request was canceled
+            if "Infer Request was canceled" in str(e):
+                print("Generation request was canceled.")
+            else:
+                # If it's a different RuntimeError, re-raise it
+                raise e
+        finally:
+            # Signal completion of the stream
+            stream_complete.set()

     t1 = Thread(target=generate_and_signal_complete)
     t1.start()

-    # Initialize an empty string to store the generated text
     partial_text = ""
     for new_text in streamer:
         partial_text = text_processor(partial_text, new_text)
-        # Update the last entry in the original history with the response
         history[-1] = (user_query, partial_text)
         yield history

@@ -228,6 +185,6 @@ def request_cancel():
     ov_model.request.cancel()

 # Gradio setup and launch
-demo = make_demo(run_fn=bot,
+demo = make_demo(run_fn=bot, title=f"OpenVINO Search & Reasoning Chatbot", language=model_language_value)
 if __name__ == "__main__":
     demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)
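
The added block at the top of the diff wires token-level stopping criteria into generation: string stop tokens are converted to ids and wrapped in StopOnTokens. Below is a minimal, self-contained sketch of the same pattern (not part of the commit; the token ids are arbitrary placeholders):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token is one of the stop ids
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)

stop_criterion = StopOnTokens([2, 50256])                  # placeholder ids
print(stop_criterion(torch.tensor([[11, 50256]]), None))   # True: last generated id is a stop id
print(stop_criterion(torch.tensor([[11, 42]]), None))      # False: generation continues

A StoppingCriteriaList built from such objects is what generate() receives through the "stopping_criteria" entry of generate_kwargs.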
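should_use_search() routes a query to DuckDuckGo whenever it contains one of the listed keywords. A stand-alone sketch of the heuristic with an abbreviated keyword list:

def should_use_search(query: str) -> bool:
    # Abbreviated list; the committed version carries several dozen keywords
    search_keywords = ["latest", "news", "recent", "current", "who", "what", "when", "why", "how"]
    return any(keyword in query.lower() for keyword in search_keywords)

print(should_use_search("What is the latest OpenVINO release?"))  # True  -> fetch search context
print(should_use_search("Tell me a joke"))                        # False -> plain chat path

The match is substring-based, so a keyword such as "now" also fires inside words like "know", and with broad interrogatives ("what", "how", "who") in the list most questions will take the search path.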
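When a search is triggered, bot() builds a flat prompt with construct_model_prompt() and tokenizes it directly with truncation to 2500 tokens; otherwise it falls back to convert_history_to_token(). A sketch of the prompt the search path produces (the placeholder search text stands in for real DuckDuckGo output; the history argument is accepted but unused, as in the committed function):

def construct_model_prompt(user_query: str, search_context: str, history=None) -> str:
    instructions = ("Use the information below if relevant to provide an accurate and concise answer. "
                    "If no information is available, rely on your general knowledge.")
    return f"{instructions}\n\n{search_context if search_context else ''}\n\n{user_query} ?\n\n"

context = "Relevant and recent information:\n<search results would appear here>"
print(construct_model_prompt("What is the newest OpenVINO version", context, history=[]))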
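Generation runs on a background thread while the UI consumes decoded text from a TextIteratorStreamer, which is what lets bot() yield partial responses to Gradio as they arrive. The same pattern, shown with a tiny CPU test checkpoint rather than the Space's OpenVINO model (the model name is only a stand-in for illustration):

from threading import Thread, Event
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")   # tiny test model, not the Space's model
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
    "input_ids": tok("Hello, world", return_tensors="pt").input_ids,
    "max_new_tokens": 16,
    "do_sample": False,
    "streamer": streamer,
}

stream_complete = Event()

def generate_and_signal_complete():
    try:
        model.generate(**generate_kwargs)
    finally:
        stream_complete.set()  # tell the consumer that streaming is over

Thread(target=generate_and_signal_complete).start()

partial_text = ""
for new_text in streamer:   # yields decoded chunks as they are produced
    partial_text += new_text
print(partial_text)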
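request_cancel() calls ov_model.request.cancel(), which makes the in-flight ov_model.generate() raise a RuntimeError containing "Infer Request was canceled"; the worker thread treats that as a normal stop and always sets the completion event. A stand-in sketch of that contract (fake_generate is a placeholder for ov_model.generate()):

from threading import Event

stream_complete = Event()

def fake_generate(cancel: bool):
    # Placeholder: a cancelled OpenVINO infer request surfaces as this RuntimeError
    if cancel:
        raise RuntimeError("Infer Request was canceled")
    return "full response"

def generate_and_signal_complete(cancel: bool):
    try:
        fake_generate(cancel)
    except RuntimeError as e:
        if "Infer Request was canceled" in str(e):
            print("Generation request was canceled.")  # treated as a normal stop
        else:
            raise
    finally:
        stream_complete.set()  # the UI stops waiting either way

generate_and_signal_complete(cancel=True)
print(stream_complete.is_set())  # True: completion is signalled even after cancellation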