hysts HF staff commited on
Commit
141ba59
·
1 Parent(s): 89f9579
Files changed (3) hide show
  1. app.py +83 -220
  2. model.py +0 -64
  3. style.css +1 -1
app.py CHANGED
@@ -1,20 +1,16 @@
 
1
  from typing import Iterator
2
 
3
  import gradio as gr
4
  import torch
 
5
 
6
- from model import get_input_token_length, run
7
-
8
- DEFAULT_SYSTEM_PROMPT = """\
9
- You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
10
-
11
- If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
12
- """
13
  MAX_MAX_NEW_TOKENS = 2048
14
  DEFAULT_MAX_NEW_TOKENS = 1024
15
- MAX_INPUT_TOKEN_LENGTH = 4000
16
 
17
- DESCRIPTION = """
18
  # Llama-2 7B Chat
19
 
20
  This Space demonstrates model [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, a Llama 2 model with 7B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
@@ -36,245 +32,112 @@ if not torch.cuda.is_available():
36
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
37
 
38
 
39
- def clear_and_save_textbox(message: str) -> tuple[str, str]:
40
- return "", message
41
-
42
-
43
- def display_input(message: str, history: list[tuple[str, str]]) -> list[tuple[str, str]]:
44
- history.append((message, ""))
45
- return history
46
-
47
-
48
- def delete_prev_fn(history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
49
- try:
50
- message, _ = history.pop()
51
- except IndexError:
52
- message = ""
53
- return history, message or ""
54
 
55
 
56
  def generate(
57
  message: str,
58
- history_with_input: list[tuple[str, str]],
59
  system_prompt: str,
60
- max_new_tokens: int,
61
- temperature: float,
62
- top_p: float,
63
- top_k: int,
64
- ) -> Iterator[list[tuple[str, str]]]:
65
- if max_new_tokens > MAX_MAX_NEW_TOKENS:
66
- raise ValueError
67
-
68
- history = history_with_input[:-1]
69
- generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
70
- try:
71
- first_response = next(generator)
72
- yield history + [(message, first_response)]
73
- except StopIteration:
74
- yield history + [(message, "")]
75
- for response in generator:
76
- yield history + [(message, response)]
77
-
78
-
79
- def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
80
- generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
81
- for x in generator:
82
- pass
83
- return "", x
84
-
85
-
86
- def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
87
- input_token_length = get_input_token_length(message, chat_history, system_prompt)
88
- if input_token_length > MAX_INPUT_TOKEN_LENGTH:
89
- raise gr.Error(
90
- f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
91
- )
92
-
93
-
94
- with gr.Blocks(css="style.css") as demo:
95
- gr.Markdown(DESCRIPTION)
96
- gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
97
 
98
- with gr.Group():
99
- chatbot = gr.Chatbot(label="Chatbot")
100
- with gr.Row():
101
- textbox = gr.Textbox(
102
- container=False,
103
- show_label=False,
104
- placeholder="Type a message...",
105
- scale=10,
106
- )
107
- submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)
108
- with gr.Row():
109
- retry_button = gr.Button("🔄 Retry", variant="secondary")
110
- undo_button = gr.Button("↩️ Undo", variant="secondary")
111
- clear_button = gr.Button("🗑️ Clear", variant="secondary")
112
 
113
- saved_input = gr.State()
114
 
115
- with gr.Accordion(label="Advanced options", open=False):
116
- system_prompt = gr.Textbox(label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6)
117
- max_new_tokens = gr.Slider(
 
 
118
  label="Max new tokens",
119
  minimum=1,
120
  maximum=MAX_MAX_NEW_TOKENS,
121
  step=1,
122
  value=DEFAULT_MAX_NEW_TOKENS,
123
- )
124
- temperature = gr.Slider(
125
  label="Temperature",
126
  minimum=0.1,
127
  maximum=4.0,
128
  step=0.1,
129
- value=1.0,
130
- )
131
- top_p = gr.Slider(
132
  label="Top-p (nucleus sampling)",
133
  minimum=0.05,
134
  maximum=1.0,
135
  step=0.05,
136
- value=0.95,
137
- )
138
- top_k = gr.Slider(
139
  label="Top-k",
140
  minimum=1,
141
  maximum=1000,
142
  step=1,
143
  value=50,
144
- )
145
-
146
- gr.Examples(
147
- examples=[
148
- "Hello there! How are you doing?",
149
- "Can you explain briefly to me what is the Python programming language?",
150
- "Explain the plot of Cinderella in a sentence.",
151
- "How many hours does it take a man to eat a Helicopter?",
152
- "Write a 100-word article on 'Benefits of Open-Source in AI research'",
153
- ],
154
- inputs=textbox,
155
- outputs=[textbox, chatbot],
156
- fn=process_example,
157
- cache_examples=True,
158
- )
 
 
 
159
 
 
 
 
 
160
  gr.Markdown(LICENSE)
161
 
162
- textbox.submit(
163
- fn=clear_and_save_textbox,
164
- inputs=textbox,
165
- outputs=[textbox, saved_input],
166
- api_name=False,
167
- queue=False,
168
- ).then(
169
- fn=display_input,
170
- inputs=[saved_input, chatbot],
171
- outputs=chatbot,
172
- api_name=False,
173
- queue=False,
174
- ).then(
175
- fn=check_input_token_length,
176
- inputs=[saved_input, chatbot, system_prompt],
177
- api_name=False,
178
- queue=False,
179
- ).success(
180
- fn=generate,
181
- inputs=[
182
- saved_input,
183
- chatbot,
184
- system_prompt,
185
- max_new_tokens,
186
- temperature,
187
- top_p,
188
- top_k,
189
- ],
190
- outputs=chatbot,
191
- api_name=False,
192
- )
193
-
194
- button_event_preprocess = (
195
- submit_button.click(
196
- fn=clear_and_save_textbox,
197
- inputs=textbox,
198
- outputs=[textbox, saved_input],
199
- api_name=False,
200
- queue=False,
201
- )
202
- .then(
203
- fn=display_input,
204
- inputs=[saved_input, chatbot],
205
- outputs=chatbot,
206
- api_name=False,
207
- queue=False,
208
- )
209
- .then(
210
- fn=check_input_token_length,
211
- inputs=[saved_input, chatbot, system_prompt],
212
- api_name=False,
213
- queue=False,
214
- )
215
- .success(
216
- fn=generate,
217
- inputs=[
218
- saved_input,
219
- chatbot,
220
- system_prompt,
221
- max_new_tokens,
222
- temperature,
223
- top_p,
224
- top_k,
225
- ],
226
- outputs=chatbot,
227
- api_name=False,
228
- )
229
- )
230
-
231
- retry_button.click(
232
- fn=delete_prev_fn,
233
- inputs=chatbot,
234
- outputs=[chatbot, saved_input],
235
- api_name=False,
236
- queue=False,
237
- ).then(
238
- fn=display_input,
239
- inputs=[saved_input, chatbot],
240
- outputs=chatbot,
241
- api_name=False,
242
- queue=False,
243
- ).then(
244
- fn=generate,
245
- inputs=[
246
- saved_input,
247
- chatbot,
248
- system_prompt,
249
- max_new_tokens,
250
- temperature,
251
- top_p,
252
- top_k,
253
- ],
254
- outputs=chatbot,
255
- api_name=False,
256
- )
257
-
258
- undo_button.click(
259
- fn=delete_prev_fn,
260
- inputs=chatbot,
261
- outputs=[chatbot, saved_input],
262
- api_name=False,
263
- queue=False,
264
- ).then(
265
- fn=lambda x: x,
266
- inputs=[saved_input],
267
- outputs=textbox,
268
- api_name=False,
269
- queue=False,
270
- )
271
-
272
- clear_button.click(
273
- fn=lambda: ([], ""),
274
- outputs=[chatbot, saved_input],
275
- queue=False,
276
- api_name=False,
277
- )
278
-
279
  if __name__ == "__main__":
280
  demo.queue(max_size=20).launch()
 
1
+ from threading import Thread
2
  from typing import Iterator
3
 
4
  import gradio as gr
5
  import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
 
8
+ DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
 
 
 
 
 
 
9
  MAX_MAX_NEW_TOKENS = 2048
10
  DEFAULT_MAX_NEW_TOKENS = 1024
11
+ MAX_INPUT_TOKEN_LENGTH = 4096
12
 
13
+ DESCRIPTION = """\
14
  # Llama-2 7B Chat
15
 
16
  This Space demonstrates model [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, a Llama 2 model with 7B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
 
32
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
33
 
34
 
35
+ if torch.cuda.is_available():
36
+ model_id = "meta-llama/Llama-2-7b-chat-hf"
37
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
38
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
39
+ tokenizer.use_default_system_prompt = False
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def generate(
43
  message: str,
44
+ chat_history: list[tuple[str, str]],
45
  system_prompt: str,
46
+ max_new_tokens: int = 1024,
47
+ temperature: float = 0.6,
48
+ top_p: float = 0.9,
49
+ top_k: int = 50,
50
+ repetition_penalty: float = 1.2,
51
+ ) -> Iterator[str]:
52
+ conversation = []
53
+ if system_prompt:
54
+ conversation.append({"role": "system", "content": system_prompt})
55
+ for user, assistant in chat_history:
56
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
57
+ conversation.append({"role": "user", "content": message})
58
+
59
+ chat = tokenizer.apply_chat_template(conversation, tokenize=False)
60
+ inputs = tokenizer(chat, return_tensors="pt", add_special_tokens=False).to("cuda")
61
+ if len(inputs) > MAX_INPUT_TOKEN_LENGTH:
62
+ inputs = inputs[-MAX_INPUT_TOKEN_LENGTH:]
63
+ gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
64
+
65
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
66
+ generate_kwargs = dict(
67
+ inputs,
68
+ streamer=streamer,
69
+ max_new_tokens=max_new_tokens,
70
+ do_sample=True,
71
+ top_p=top_p,
72
+ top_k=top_k,
73
+ temperature=temperature,
74
+ num_beams=1,
75
+ repetition_penalty=repetition_penalty,
76
+ )
77
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
78
+ t.start()
 
 
 
 
79
 
80
+ outputs = []
81
+ for text in streamer:
82
+ outputs.append(text)
83
+ yield "".join(outputs)
 
 
 
 
 
 
 
 
 
 
84
 
 
85
 
86
+ chat_interface = gr.ChatInterface(
87
+ fn=generate,
88
+ additional_inputs=[
89
+ gr.Textbox(label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6),
90
+ gr.Slider(
91
  label="Max new tokens",
92
  minimum=1,
93
  maximum=MAX_MAX_NEW_TOKENS,
94
  step=1,
95
  value=DEFAULT_MAX_NEW_TOKENS,
96
+ ),
97
+ gr.Slider(
98
  label="Temperature",
99
  minimum=0.1,
100
  maximum=4.0,
101
  step=0.1,
102
+ value=0.6,
103
+ ),
104
+ gr.Slider(
105
  label="Top-p (nucleus sampling)",
106
  minimum=0.05,
107
  maximum=1.0,
108
  step=0.05,
109
+ value=0.9,
110
+ ),
111
+ gr.Slider(
112
  label="Top-k",
113
  minimum=1,
114
  maximum=1000,
115
  step=1,
116
  value=50,
117
+ ),
118
+ gr.Slider(
119
+ label="Repetition penalty",
120
+ minimum=1.0,
121
+ maximum=2.0,
122
+ step=0.05,
123
+ value=1.2,
124
+ ),
125
+ ],
126
+ stop_btn=None,
127
+ examples=[
128
+ ["Hello there! How are you doing?"],
129
+ ["Can you explain briefly to me what is the Python programming language?"],
130
+ ["Explain the plot of Cinderella in a sentence."],
131
+ ["How many hours does it take a man to eat a Helicopter?"],
132
+ ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
133
+ ],
134
+ )
135
 
136
+ with gr.Blocks(css="style.css") as demo:
137
+ gr.Markdown(DESCRIPTION)
138
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
139
+ chat_interface.render()
140
  gr.Markdown(LICENSE)
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  if __name__ == "__main__":
143
  demo.queue(max_size=20).launch()
model.py DELETED
@@ -1,64 +0,0 @@
1
- from threading import Thread
2
- from typing import Iterator
3
-
4
- import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
-
7
- model_id = "meta-llama/Llama-2-7b-chat-hf"
8
-
9
- if torch.cuda.is_available():
10
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
11
- else:
12
- model = None
13
- tokenizer = AutoTokenizer.from_pretrained(model_id)
14
-
15
-
16
- def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
17
- texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
18
- # The first user input is _not_ stripped
19
- do_strip = False
20
- for user_input, response in chat_history:
21
- user_input = user_input.strip() if do_strip else user_input
22
- do_strip = True
23
- texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
24
- message = message.strip() if do_strip else message
25
- texts.append(f"{message} [/INST]")
26
- return "".join(texts)
27
-
28
-
29
- def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
30
- prompt = get_prompt(message, chat_history, system_prompt)
31
- input_ids = tokenizer([prompt], return_tensors="np", add_special_tokens=False)["input_ids"]
32
- return input_ids.shape[-1]
33
-
34
-
35
- def run(
36
- message: str,
37
- chat_history: list[tuple[str, str]],
38
- system_prompt: str,
39
- max_new_tokens: int = 1024,
40
- temperature: float = 0.8,
41
- top_p: float = 0.95,
42
- top_k: int = 50,
43
- ) -> Iterator[str]:
44
- prompt = get_prompt(message, chat_history, system_prompt)
45
- inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to("cuda")
46
-
47
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
48
- generate_kwargs = dict(
49
- inputs,
50
- streamer=streamer,
51
- max_new_tokens=max_new_tokens,
52
- do_sample=True,
53
- top_p=top_p,
54
- top_k=top_k,
55
- temperature=temperature,
56
- num_beams=1,
57
- )
58
- t = Thread(target=model.generate, kwargs=generate_kwargs)
59
- t.start()
60
-
61
- outputs = []
62
- for text in streamer:
63
- outputs.append(text)
64
- yield "".join(outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
style.css CHANGED
@@ -9,7 +9,7 @@ h1 {
9
  border-radius: 100vh;
10
  }
11
 
12
- #component-0 {
13
  max-width: 900px;
14
  margin: auto;
15
  padding-top: 1.5rem;
 
9
  border-radius: 100vh;
10
  }
11
 
12
+ .contain {
13
  max-width: 900px;
14
  margin: auto;
15
  padding-top: 1.5rem;