aizanlabs committed
Commit 6bac8e1 · verified · 1 Parent(s): 8338839

Update app.py

Files changed (1)
  1. app.py +24 -18
app.py CHANGED
@@ -13,7 +13,7 @@ from datetime import datetime
 import json
 import gradio as gr
 import re
-
+from unsloth import FastLanguageModel
 class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
         # hf_token = os.getenv('HF_TOKEN')
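The new import pulls in unsloth's FastLanguageModel. Note that unsloth's fast paths normally assume the model was loaded through unsloth itself; a minimal sketch of that loading path (the checkpoint name and settings here are illustrative assumptions, not what this commit uses; the commit keeps AutoModelForCausalLM):

# Sketch only: typical unsloth loading, assuming an unsloth-prepared checkpoint
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b-bnb-4bit",  # assumed example checkpoint
    max_seq_length=2048,                       # assumed context window
    load_in_4bit=True,                         # 4-bit quantized weights
)
FastLanguageModel.for_inference(model)  # enable the fast inference mode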
@@ -66,15 +66,9 @@ class DocumentRetrievalAndGeneration:
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         model = AutoModelForCausalLM.from_pretrained(model_id)
-        generate_text = pipeline(
-            model=model,
-            tokenizer=tokenizer,
-            return_full_text=True,
-            task='text-generation',
-            temperature=0.6,
-            max_new_tokens=256,
-        )
-        return generate_text
+        FastLanguageModel.for_inference(model)  # Enable native 2x faster inference
+
+        # return generate_text
 
     def generate_response_with_timeout(self, model_inputs):
         try:
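The deleted block wrapped the model in a transformers pipeline and returned it; the replacement enables unsloth's inference mode and returns nothing, so the method now implicitly returns None and any attribute that stored its return value (e.g. self.llm) no longer holds a callable pipeline. A rough equivalent of the removed pipeline call, sketched with a direct generate (prompt_text is a placeholder, and do_sample=True is an assumption needed for temperature to take effect):

# Sketch of what the deleted pipeline did, using model.generate directly
ids = tokenizer(prompt_text, return_tensors="pt").input_ids  # prompt_text: placeholder
out = model.generate(ids, max_new_tokens=256, temperature=0.6, do_sample=True)
text = tokenizer.decode(out[0], skip_special_tokens=True)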
@@ -127,16 +121,28 @@ class DocumentRetrievalAndGeneration:
         </s>
         """
 
-        messages = [{"role": "user", "content": prompt}]
-        encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
-        model_inputs = encodeds.to(self.llm.device)
+        # messages = [{"role": "user", "content": prompt}]
+        # encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
+        # model_inputs = encodeds.to(self.llm.device)
 
-        start_time = datetime.now()
-        generated_ids = self.generate_response_with_timeout(model_inputs)
-        elapsed_time = datetime.now() - start_time
+        # start_time = datetime.now()
+        # generated_ids = self.generate_response_with_timeout(model_inputs)
+        # elapsed_time = datetime.now() - start_time
+
+        # decoded = self.llm.tokenizer.batch_decode(generated_ids)
+        # generated_response = decoded[0]
 
-        decoded = self.llm.tokenizer.batch_decode(generated_ids)
-        generated_response = decoded[0]
+        inputs = tokenizer(
+            [
+                alpaca_prompt.format(
+                    "",      # instruction
+                    prompt,  # input
+                    "",      # output - leave this blank for generation!
+                )
+            ], return_tensors="pt")  # .to("cuda")
+
+        outputs = model.generate(**inputs, max_new_tokens=64, use_cache=True)
+        tokenizer.batch_decode(outputs)
         match1 = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
 
         match2 = re.search(r'Solution:(.*?)</s>', generated_response, re.DOTALL | re.IGNORECASE)
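Two things in this hunk are worth flagging: alpaca_prompt is referenced but never defined anywhere in the diff, and the final tokenizer.batch_decode(outputs) result is discarded, so the re.search calls on generated_response that follow would raise NameError. Also note max_new_tokens drops from 256 to 64, which may truncate longer answers. A sketch of what is presumably intended (the template below is the standard Alpaca prompt from the unsloth example notebooks, an assumption here):

# Assumed Alpaca template (not shown anywhere in this diff)
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

# Capture the decoded text so the re.search calls below have something to match
generated_response = tokenizer.batch_decode(outputs)[0]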
 
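A related note on the parsing: the old chat-template path produced Mistral-style output delimited by [/INST] ... </s>, which is what match1 expects. With the Alpaca prompt the answer follows "### Response:" instead, so match1 will likely stop matching and only the "Solution:" fallback can fire. A pattern along these lines (an assumption, depending on the template actually used) may be needed:

# Assumed: with the Alpaca template the answer follows "### Response:"
match3 = re.search(r'### Response:\s*(.*)', generated_response, re.DOTALL)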