littlebird13 commited on
Commit
00cfd74
·
verified ·
1 Parent(s): ea1cd92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -35
app.py CHANGED
@@ -85,41 +85,41 @@ def generate(
85
  top_k: int = 50,
86
  repetition_penalty: float = 1.2,
87
  ) -> Iterator[str]:
88
- # print_gpu()
89
-
90
- # conversation = []
91
- # if system_prompt:
92
- # conversation.append({"role": "system", "content": system_prompt})
93
- # for user, assistant in chat_history:
94
- # conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
95
- # conversation.append({"role": "user", "content": message})
96
-
97
- # input_ids = tokenizer.apply_chat_template(conversation, tokenize=False,add_generation_prompt=True)
98
- # input_ids = tokenizer([input_ids],return_tensors="pt").to(model.device)
99
-
100
- # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
101
- # generate_kwargs = dict(
102
- # input_ids=input_ids.input_ids,
103
- # streamer=streamer,
104
- # max_new_tokens=max_new_tokens,
105
- # do_sample=True,
106
- # top_p=top_p,
107
- # top_k=top_k,
108
- # temperature=temperature,
109
- # repetition_penalty=repetition_penalty,
110
- # )
111
- # t = Thread(target=model.generate, kwargs=generate_kwargs)
112
- # t.start()
113
- # #dictionary update sequence element #0 has length 19; 2 is required
114
-
115
- # outputs = []
116
- # for text in streamer:
117
- # outputs.append(text)
118
- # yield "".join(outputs)
119
-
120
- # #outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
121
- # print(outputs)
122
- #yield outputs
123
 
124
 
125
  chat_interface = gr.ChatInterface(
 
85
  top_k: int = 50,
86
  repetition_penalty: float = 1.2,
87
  ) -> Iterator[str]:
88
+ print_gpu()
89
+
90
+ conversation = []
91
+ if system_prompt:
92
+ conversation.append({"role": "system", "content": system_prompt})
93
+ for user, assistant in chat_history:
94
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
95
+ conversation.append({"role": "user", "content": message})
96
+
97
+ input_ids = tokenizer.apply_chat_template(conversation, tokenize=False,add_generation_prompt=True)
98
+ input_ids = tokenizer([input_ids],return_tensors="pt").to(model.device)
99
+
100
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
101
+ generate_kwargs = dict(
102
+ input_ids=input_ids.input_ids,
103
+ streamer=streamer,
104
+ max_new_tokens=max_new_tokens,
105
+ do_sample=True,
106
+ top_p=top_p,
107
+ top_k=top_k,
108
+ temperature=temperature,
109
+ repetition_penalty=repetition_penalty,
110
+ )
111
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
112
+ t.start()
113
+ #dictionary update sequence element #0 has length 19; 2 is required
114
+
115
+ outputs = []
116
+ for text in streamer:
117
+ outputs.append(text)
118
+ yield "".join(outputs)
119
+
120
+ #outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
121
+ print(outputs)
122
+ yield outputs
123
 
124
 
125
  chat_interface = gr.ChatInterface(