Update app.py
app.py CHANGED
@@ -71,9 +71,14 @@ class DocumentRetrievalAndGeneration:
         return generate_text
     def initialize_llm2(self,model_id):
 
-
-
-
+        model_name = "mistralai/Mistral-7B-Instruct-v0.2"
+        pipeline = transformers.pipeline(
+            "text-generation",
+            model=model_name,
+            model_kwargs={"torch_dtype": torch.bfloat16},
+            device="cpu",
+        )
+
 
         # return generate_text
 
@@ -150,9 +155,44 @@ class DocumentRetrievalAndGeneration:
 
         # decoded = self.llm.tokenizer.batch_decode(generated_ids)
         # generated_response = decoded[0]
-
-
-
+        messages = []
+        # Check if history is None or empty and handle accordingly
+        if history:
+            for user_msg, assistant_msg in history:
+                messages.append({"role": "user", "content": user_msg})
+                messages.append({"role": "assistant", "content": assistant_msg})
+
+        # Always add the current user message
+        messages.append({"role": "user", "content": message})
+
+        # Construct the prompt using the pipeline's tokenizer
+        prompt = pipeline.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+
+        # Generate the response
+        terminators = [
+            pipeline.tokenizer.eos_token_id,
+            pipeline.tokenizer.convert_tokens_to_ids("")
+        ]
+
+        # Adjust the temperature slightly above given to ensure variety
+        adjusted_temp = temperature + 0.1
+
+        # Generate outputs with adjusted parameters
+        outputs = pipeline(
+            prompt,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=adjusted_temp,
+            top_p=0.9
+        )
+
+        # Extract the generated text, skipping the length of the prompt
+        generated_text = outputs[0]["generated_text"]
+        generated_response = generated_text[len(prompt):]
 
         match1 = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
 
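For readers skimming the first hunk: it hard-codes Mistral-7B-Instruct-v0.2 (the model_id argument is not used) and builds a CPU text-generation pipeline with bfloat16 weights. The hunk only binds the pipeline to a local name, so the sketch below assumes it is meant to be kept on the instance as self.pipeline so later methods can reuse it; that attribute, and the top-level imports, are assumptions rather than something the diff shows.

import torch
import transformers

class DocumentRetrievalAndGeneration:
    def initialize_llm2(self, model_id):
        # The hunk hard-codes Mistral-7B-Instruct-v0.2 and ignores model_id.
        model_name = "mistralai/Mistral-7B-Instruct-v0.2"
        # CPU text-generation pipeline with bfloat16 weights, as in the diff.
        # Storing it on self (an assumption) lets the generation code reuse it.
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device="cpu",
        )
        return self.pipeline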
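The second hunk is the core of the change: it replaces the old batch_decode path with chat-template prompting. Below is a self-contained sketch of that flow, using the same history, message, temperature, and max_new_tokens names as the diff; the generate_reply wrapper, its default argument values, and the fallback when the regex finds no match are assumptions added for illustration.

import re

def generate_reply(pipeline, message, history=None, temperature=0.6, max_new_tokens=512):
    # Rebuild the conversation as alternating user/assistant messages.
    messages = []
    if history:
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Render the messages into Mistral's [INST] ... [/INST] prompt format.
    prompt = pipeline.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Sample a completion; the diff nudges temperature up by 0.1 before sampling.
    outputs = pipeline(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature + 0.1,
        top_p=0.9,
    )

    # The pipeline returns prompt + completion, so strip the prompt prefix.
    generated_response = outputs[0]["generated_text"][len(prompt):]

    # Same post-processing as the diff: keep the text between [/INST] and </s>;
    # fall back to the raw completion if the tags are not present.
    match = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
    return match.group(1).strip() if match else generated_response.strip()

Two details of the hunk worth noting: the terminators list it builds (including convert_tokens_to_ids("") on an empty string) is never passed to the pipeline call, so generation stops on the model's default EOS token; and because the rendered prompt already ends with [/INST], the regex is searched against text that usually no longer contains that tag, which is why the sketch above falls back to the raw completion.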