charanhu committed
Commit 5f5a729 · 1 Parent(s): 9ad03b5

Update app.py

Files changed (1)
  1. app.py  +5 -10
app.py CHANGED

@@ -6,7 +6,6 @@ from threading import Thread
 # Load model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
 model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
-model = model.to('cuda:0')
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -24,7 +23,7 @@ def predict(message, history):
     messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])  #curr_system_message +
                 for item in history_transformer_format])
 
-    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    model_inputs = tokenizer([messages], return_tensors="pt")
     temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Temperature"),
     max_new_tokens = gr.Slider(minimum=0, maximum=2048, value=10, label="Temperature"),
     min_new_tokens = gr.Slider(minimum=0, maximum=2048, value=1, label="Temperature"),
@@ -39,14 +38,10 @@ def predict(message, history):
         num_beams=1,
         stopping_criteria=StoppingCriteriaList([stop])
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    partial_message = ""
-    for new_token in streamer:
-        if new_token != '<':
-            partial_message += new_token
-            yield partial_message
+    generated_sequence = model.generate(**generate_kwargs)[0]
+    generated_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)
+
+    yield generated_text
 
 
 gr.ChatInterface(predict).queue().launch()
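
Taken together, the diff drops the CUDA placement of the model and its inputs and replaces the Thread-plus-streamer token loop with a single blocking model.generate call whose output is decoded and yielded once. Below is a minimal sketch of the resulting predict path, assuming the pieces the diff does not show: the flattening of the chat history into history_transformer_format and the exact contents of generate_kwargs. The generation hyperparameters here are illustrative placeholders, not the repository's values.

# Minimal sketch of the post-commit flow: CPU-only, blocking generation
# instead of the earlier Thread + streamer loop. Names and settings not
# visible in the diff are assumptions.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
# No model.to('cuda:0') any more: the commit removes the GPU requirement.

def predict(message, history):
    # Flatten the history plus the new message into the <human>/<bot> prompt
    # format (assumed to be what history_transformer_format held).
    history = history + [[message, ""]]
    messages = "".join("\n<human>:" + user + "\n<bot>:" + bot for user, bot in history)

    # Inputs stay on CPU; no .to("cuda") as in the removed line.
    model_inputs = tokenizer([messages], return_tensors="pt")

    # Blocking generation; max_new_tokens here is an illustrative value.
    generated_sequence = model.generate(
        **model_inputs,
        max_new_tokens=128,
        num_beams=1,
    )[0]
    generated_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)

    # The whole reply is yielded once, rather than token by token via a streamer.
    yield generated_text

gr.ChatInterface(predict).queue().launch()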