lightmate committed
Commit 9f09252 · verified · 1 Parent(s): 844af01

Update app.py

Files changed (1)
  1. app.py +62 -105
app.py CHANGED
@@ -55,7 +55,7 @@ ov_model = OVModelForCausalLM.from_pretrained(
     trust_remote_code=True,
 )

-# Stopping criteria for token generation
+# Define stopping criteria for specific token sequences
 class StopOnTokens(StoppingCriteria):
     def __init__(self, token_ids):
         self.token_ids = token_ids
@@ -63,29 +63,26 @@ class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)

-# Functions for chatbot logic
+if stop_tokens is not None:
+    if isinstance(stop_tokens[0], str):
+        stop_tokens = tok.convert_tokens_to_ids(stop_tokens)
+    stop_tokens = [StopOnTokens(stop_tokens)]
+
+# Helper function for partial text update
+def default_partial_text_processor(partial_text: str, new_text: str) -> str:
+    return partial_text + new_text
+
+text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor)
+
+# Convert conversation history to tokens based on model template
 def convert_history_to_token(history: List[Tuple[str, str]]):
-    """
-    function for conversion history stored as list pairs of user and assistant messages to tokens according to model expected conversation template
-    Params:
-      history: dialogue history
-    Returns:
-      history in token format
-    """
     if pt_model_name == "baichuan2":
         system_tokens = tok.encode(start_message)
         history_tokens = []
         for old_query, response in history[:-1]:
-            round_tokens = []
-            round_tokens.append(195)
-            round_tokens.extend(tok.encode(old_query))
-            round_tokens.append(196)
-            round_tokens.extend(tok.encode(response))
+            round_tokens = [195] + tok.encode(old_query) + [196] + tok.encode(response)
             history_tokens = round_tokens + history_tokens
-        input_tokens = system_tokens + history_tokens
-        input_tokens.append(195)
-        input_tokens.extend(tok.encode(history[-1][0]))
-        input_tokens.append(196)
+        input_tokens = system_tokens + history_tokens + [195] + tok.encode(history[-1][0]) + [196]
         input_token = torch.LongTensor([input_tokens])
     elif history_template is None or has_chat_template:
         messages = [{"role": "system", "content": start_message}]
@@ -97,130 +94,90 @@ def convert_history_to_token(history: List[Tuple[str, str]]):
             messages.append({"role": "user", "content": user_msg})
             if model_msg:
                 messages.append({"role": "assistant", "content": model_msg})
-
         input_token = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")
     else:
         text = start_message + "".join(
-            ["".join([history_template.format(num=round, user=item[0], assistant=item[1])]) for round, item in enumerate(history[:-1])]
-        )
-        text += "".join(
-            [
-                "".join(
-                    [
-                        current_message_template.format(
-                            num=len(history) + 1,
-                            user=history[-1][0],
-                            assistant=history[-1][1],
-                        )
-                    ]
-                )
-            ]
+            [history_template.format(num=round, user=item[0], assistant=item[1]) for round, item in enumerate(history[:-1])]
         )
+        text += current_message_template.format(num=len(history) + 1, user=history[-1][0], assistant=history[-1][1])
         input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
     return input_token

-# Initialize the search tool
+# Initialize search tool
 search = DuckDuckGoSearchRun()

-# Function to retrieve and format search results based on user input
-def fetch_search_results(query: str) -> str:
-    search_results = search.invoke(query)
-    # Displaying search results for debugging
-    print("Search results: ", search_results)
-    return f"Relevant and recent information:\n{search_results}"
-
-# Function to decide if a search is needed based on the user query
+# Determine if a search is needed based on the query
 def should_use_search(query: str) -> bool:
-    # Simple heuristic, can be extended with more advanced intent analysis
-    search_keywords = ["latest", "news", "update", "which" "who", "what", "when", "why","how", "recent", "result", "tell", "explain",
-                       "announcement", "bulletin", "report", "brief", "insight", "disclosure", "update",
+    search_keywords = ["latest", "news", "update", "which", "who", "what", "when", "why", "how", "recent", "current",
+                       "announcement", "bulletin", "report", "brief", "insight", "disclosure", "update",
                        "release", "memo", "headline", "current", "ongoing", "fresh", "upcoming", "immediate",
                        "recently", "new", "now", "in-progress", "inquiry", "query", "ask", "investigate",
                        "explore", "seek", "clarify", "confirm", "discover", "learn", "describe", "define",
                        "illustrate", "outline", "interpret", "expound", "detail", "summarize", "elucidate",
                        "break down", "outcome", "effect", "consequence", "finding", "achievement", "conclusion",
                        "product", "performance", "resolution"
-    ]
+                       ]
     return any(keyword in query.lower() for keyword in search_keywords)

-# Generate prompt for model with optional search context
+# Construct the prompt with optional search context
 def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
-    # Simple instruction for the model to prioritize search information if available
-    instructions = (
-        "If relevant information is provided below, use it to give an accurate and concise answer. If there is no relevant information available, please rely on your general knowledge and indicate that no recent or specific information is available to answer."
-    )
-
-    # Build the prompt with instructions, search context, and user query
-    prompt = f"{instructions}\n\n"
-    if search_context:
-        prompt += f"{search_context}\n\n"  # Include search context prominently at the top
-
-    # Add the user's query
-    prompt += f"{user_query} ?\n\n"
-
-    # Optionally add recent history for context, without labels
-    # if history:
-    #     prompt += "Recent conversation:\n"
-    #     for user_msg, assistant_msg in history[:-1]:  # Exclude the last message to prevent duplication
-    #         prompt += f"{user_msg}\n{assistant_msg}\n"
-
+    instructions = "Use the information below if relevant to provide an accurate and concise answer. If no information is available, rely on your general knowledge."
+    prompt = f"{instructions}\n\n{search_context if search_context else ''}\n\n{user_query} ?\n\n"
    return prompt

+# Fetch search results for a query
+def fetch_search_results(query: str) -> str:
+    search_results = search.invoke(query)
+    print("Search results:", search_results)  # Optional: Debugging output
+    return f"Relevant and recent information:\n{search_results}"

+# Main chatbot function
 def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
-    """
-    Main callback function for running chatbot on submit button click.
-    """
     user_query = history[-1][0]
-    search_context = ""
+    search_context = fetch_search_results(user_query) if should_use_search(user_query) else ""
+    prompt = construct_model_prompt(user_query, search_context, history)
+    input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids if search_context else convert_history_to_token(history)

-    # Decide if search is required based on the user query
-    if should_use_search(user_query):
-        search_context = fetch_search_results(user_query)
-        prompt = construct_model_prompt(user_query, search_context, history)
-        input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids
-    else:
-        # If no search context, use the original logic with tokenization
-        prompt = construct_model_prompt(user_query, "", history)
-        input_ids = convert_history_to_token(history)
-
-    # Ensure input length does not exceed a threshold (e.g., 2000 tokens)
+    # Limit input length to avoid exceeding token limit
     if input_ids.shape[1] > 2000:
-        # If input exceeds the limit, only use the most recent conversation
         history = [history[-1]]

-    # Streamer for model response generation
+    # Configure response streaming
     streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=256,  # Adjust this as needed
-        temperature=temperature,
-        do_sample=temperature > 0.0,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        streamer=streamer,
-    )
-
-    if stop_tokens is not None:
-        generate_kwargs["stopping_criteria"] = StoppingCriteriaList(stop_tokens)
-
-    # Event to signal when streaming is complete
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "do_sample": temperature > 0.0,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+        "streamer": streamer,
+        "stopping_criteria": StoppingCriteriaList(stop_tokens) if stop_tokens is not None else None,
+    }
+
+    # Signal completion
     stream_complete = Event()
-
     def generate_and_signal_complete():
-        ov_model.generate(**generate_kwargs)
-        stream_complete.set()
+        try:
+            ov_model.generate(**generate_kwargs)
+        except RuntimeError as e:
+            # Check if the error message indicates the request was canceled
+            if "Infer Request was canceled" in str(e):
+                print("Generation request was canceled.")
+            else:
+                # If it's a different RuntimeError, re-raise it
+                raise e
+        finally:
+            # Signal completion of the stream
+            stream_complete.set()

     t1 = Thread(target=generate_and_signal_complete)
     t1.start()

-    # Initialize an empty string to store the generated text
     partial_text = ""
     for new_text in streamer:
         partial_text = text_processor(partial_text, new_text)
-        # Update the last entry in the original history with the response
         history[-1] = (user_query, partial_text)
         yield history

@@ -228,6 +185,6 @@ def request_cancel():
     ov_model.request.cancel()

 # Gradio setup and launch
-demo = make_demo(run_fn=bot, stop_fn=request_cancel, title=f"OpenVINO Search & Reasoning Chatbot", language=model_language_value)
+demo = make_demo(run_fn=bot, title=f"OpenVINO Search & Reasoning Chatbot", language=model_language_value)
 if __name__ == "__main__":
     demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)
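
The StopOnTokens class kept by this commit follows the standard transformers StoppingCriteria pattern. A minimal, self-contained sketch of how such a criterion plugs into generate(); the gpt2 model and the EOS-only stop list are placeholders for illustration, not the model or stop tokens used in app.py:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the last generated token matches any configured stop ID
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)

tok = AutoTokenizer.from_pretrained("gpt2")                  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

stop_ids = [tok.eos_token_id]                                # assumption: stop on EOS only
criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

inputs = tok("OpenVINO is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32, stopping_criteria=criteria)
print(tok.decode(out[0], skip_special_tokens=True))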
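The should_use_search() heuristic is a plain keyword match over the lowercased query. A quick illustration with a deliberately trimmed keyword list (the full list lives in app.py):

search_keywords = ["latest", "news", "update", "recent", "current"]  # trimmed for illustration

def should_use_search(query: str) -> bool:
    return any(keyword in query.lower() for keyword in search_keywords)

print(should_use_search("What is the latest OpenVINO release?"))  # True  -> take the web-search path
print(should_use_search("Write a haiku about rivers."))           # False -> use plain chat history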
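When a search is triggered, the rewritten construct_model_prompt() simply stacks the instruction block, the search snippet, and the user question, separated by blank lines. A sketch with made-up inputs; the history argument is accepted but unused in the committed version, so it is omitted here:

instructions = ("Use the information below if relevant to provide an accurate and concise answer. "
                "If no information is available, rely on your general knowledge.")

def construct_model_prompt(user_query: str, search_context: str) -> str:
    # Mirrors the committed f-string layout: instructions, optional context, then the query
    return f"{instructions}\n\n{search_context if search_context else ''}\n\n{user_query} ?\n\n"

example = construct_model_prompt(
    "Who won the 2024 Tour de France",
    "Relevant and recent information:\n<search snippet here>",
)
print(example)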
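The streaming setup in bot() is the usual TextIteratorStreamer pattern: generation runs on a worker thread while the caller consumes decoded chunks as they arrive. A standalone sketch with a placeholder model and prompt, leaving out the OpenVINO, Gradio, and cancellation pieces:

from threading import Event, Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")                  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tok("Tell me about OpenVINO.", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
stream_complete = Event()

def generate_and_signal_complete():
    model.generate(input_ids=input_ids, max_new_tokens=64, do_sample=False, streamer=streamer)
    stream_complete.set()

Thread(target=generate_and_signal_complete).start()

partial_text = ""
for new_text in streamer:        # yields decoded text chunks as they are produced
    partial_text += new_text
    print(new_text, end="", flush=True)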