aizanlabs committed on
Commit 89e6c61 · verified · 1 Parent(s): e67420d

Update app.py

Files changed (1)
  1. app.py +46 -6
app.py CHANGED
@@ -71,9 +71,14 @@ class DocumentRetrievalAndGeneration:
         return generate_text
     def initialize_llm2(self,model_id):

-        self.tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-3b")#, local_files_only=True)
-        self.model = BloomForCausalLM.from_pretrained("bigscience/bloom-3b")#, local_files_only=True)
-        self.result_length = 2048
+        model_name = "mistralai/Mistral-7B-Instruct-v0.2"
+        pipeline = transformers.pipeline(
+            "text-generation",
+            model=model_name,
+            model_kwargs={"torch_dtype": torch.bfloat16},
+            device="cpu",
+        )
+

         # return generate_text

@@ -150,9 +155,44 @@ class DocumentRetrievalAndGeneration:

         # decoded = self.llm.tokenizer.batch_decode(generated_ids)
         # generated_response = decoded[0]
-        inputs = self.tokenizer(prompt, return_tensors="pt")
-        generated_response=self.tokenizer.decode(self.model.generate(inputs["input_ids"], max_length=self.result_length,no_repeat_ngram_size=2)[0])
-        print(generated_response)
+        messages = []
+        # Check if history is None or empty and handle accordingly
+        if history:
+            for user_msg, assistant_msg in history:
+                messages.append({"role": "user", "content": user_msg})
+                messages.append({"role": "assistant", "content": assistant_msg})
+
+        # Always add the current user message
+        messages.append({"role": "user", "content": message})
+
+        # Construct the prompt using the pipeline's tokenizer
+        prompt = pipeline.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+
+        # Generate the response
+        terminators = [
+            pipeline.tokenizer.eos_token_id,
+            pipeline.tokenizer.convert_tokens_to_ids("")
+        ]
+
+        # Adjust the temperature slightly above given to ensure variety
+        adjusted_temp = temperature + 0.1
+
+        # Generate outputs with adjusted parameters
+        outputs = pipeline(
+            prompt,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=adjusted_temp,
+            top_p=0.9
+        )
+
+        # Extract the generated text, skipping the length of the prompt
+        generated_text = outputs[0]["generated_text"]
+        generated_response = generated_text[len(prompt):]

         match1 = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
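
Taken together, the diff swaps the Bloom-3B tokenizer/model pair for a Mistral-7B-Instruct-v0.2 text-generation pipeline and rebuilds the prompt from chat-style messages with the tokenizer's chat template. The sketch below walks through that new generation path in isolation; it is a minimal sketch, not the app's code: the model id, dtype, and device come from the diff, while the sample history, message, and sampling values (max_new_tokens, temperature, top_p) are hypothetical placeholders.

# Minimal sketch of the generation path introduced by this commit, run outside
# the DocumentRetrievalAndGeneration class. History/message contents and
# sampling values are hypothetical placeholders.
import torch
import transformers

pipe = transformers.pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.2",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cpu",
)

# Hypothetical prior turns and current user message.
history = [("hello", "Hi, how can I help?")]
message = "Summarise the retrieved context in two sentences."

# Rebuild the chat as role/content dicts, mirroring the committed code.
messages = []
for user_msg, assistant_msg in history:
    messages.append({"role": "user", "content": user_msg})
    messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})

# Render the messages into Mistral's [INST] ... [/INST] prompt string.
prompt = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Sample a completion; the pipeline echoes the prompt, so slice it off.
outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9)
generated_response = outputs[0]["generated_text"][len(prompt):]
print(generated_response)

The slice by len(prompt) relies on the text-generation pipeline's default of returning the prompt followed by the completion (return_full_text=True).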