Ubuntu commited on
Commit
7935b1e
Β·
1 Parent(s): 8f9a049

corrected requiremet

Browse files
.gitignore CHANGED
@@ -1 +1,2 @@
1
  nohup.out
 
 
1
  nohup.out
2
+ __pycache__/**
TestApp/gradio-app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 MosaicML spaces authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ # and
4
+ # the https://huggingface.co/spaces/HuggingFaceH4/databricks-dolly authors
5
+ import datetime
6
+ import os
7
+ from threading import Event, Thread
8
+ from uuid import uuid4
9
+
10
+ import gradio as gr
11
+ import requests
12
+ import torch
13
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
14
+
15
+ from quick_pipeline import InstructionTextGenerationPipeline as pipeline
16
+
17
+
18
+ # Configuration
19
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
20
+
21
+ examples = [
22
+ # to do: add coupled hparams so e.g. poem has higher temp
23
+ "Write a travel blog about a 3-day trip to Thailand.",
24
+ "Write a short story about a robot that has a nice day.",
25
+ "Convert the following to a single line of JSON:\n\n```name: John\nage: 30\naddress:\n street:123 Main St.\n city: San Francisco\n state: CA\n zip: 94101\n```",
26
+ "Write a quick email to congratulate MosaicML about the launch of their inference offering.",
27
+ "Explain how a candle works to a 6 year old in a few sentences.",
28
+ "What are some of the most common misconceptions about birds?",
29
+ ]
30
+
31
+ # Initialize the model and tokenizer
32
+ generate = pipeline(
33
+ "mosaicml/mpt-7b-instruct",
34
+ torch_dtype=torch.bfloat16,
35
+ trust_remote_code=True,
36
+ use_auth_token=HF_TOKEN,
37
+ )
38
+ stop_token_ids = generate.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
39
+
40
+
41
+ # Define a custom stopping criteria
42
+ class StopOnTokens(StoppingCriteria):
43
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
44
+ for stop_id in stop_token_ids:
45
+ if input_ids[0][-1] == stop_id:
46
+ return True
47
+ return False
48
+
49
+
50
+ def log_conversation(session_id, instruction, response, generate_kwargs):
51
+ logging_url = os.getenv("LOGGING_URL", None)
52
+ if logging_url is None:
53
+ return
54
+
55
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
56
+
57
+ data = {
58
+ "session_id": session_id,
59
+ "timestamp": timestamp,
60
+ "instruction": instruction,
61
+ "response": response,
62
+ "generate_kwargs": generate_kwargs,
63
+ }
64
+
65
+ try:
66
+ requests.post(logging_url, json=data)
67
+ except requests.exceptions.RequestException as e:
68
+ print(f"Error logging conversation: {e}")
69
+
70
+
71
+ def process_stream(instruction, temperature, top_p, top_k, max_new_tokens, session_id):
72
+ # Tokenize the input
73
+ input_ids = generate.tokenizer(
74
+ generate.format_instruction(instruction), return_tensors="pt"
75
+ ).input_ids
76
+ input_ids = input_ids.to(generate.model.device)
77
+
78
+ # Initialize the streamer and stopping criteria
79
+ streamer = TextIteratorStreamer(
80
+ generate.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
81
+ )
82
+ stop = StopOnTokens()
83
+
84
+ if temperature < 0.1:
85
+ temperature = 0.0
86
+ do_sample = False
87
+ else:
88
+ do_sample = True
89
+
90
+ gkw = {
91
+ **generate.generate_kwargs,
92
+ **{
93
+ "input_ids": input_ids,
94
+ "max_new_tokens": max_new_tokens,
95
+ "temperature": temperature,
96
+ "do_sample": do_sample,
97
+ "top_p": top_p,
98
+ "top_k": top_k,
99
+ "streamer": streamer,
100
+ "stopping_criteria": StoppingCriteriaList([stop]),
101
+ },
102
+ }
103
+
104
+ response = ""
105
+ stream_complete = Event()
106
+
107
+ def generate_and_signal_complete():
108
+ generate.model.generate(**gkw)
109
+ stream_complete.set()
110
+
111
+ def log_after_stream_complete():
112
+ stream_complete.wait()
113
+ log_conversation(
114
+ session_id,
115
+ instruction,
116
+ response,
117
+ {
118
+ "top_k": top_k,
119
+ "top_p": top_p,
120
+ "temperature": temperature,
121
+ },
122
+ )
123
+
124
+ t1 = Thread(target=generate_and_signal_complete)
125
+ t1.start()
126
+
127
+ t2 = Thread(target=log_after_stream_complete)
128
+ t2.start()
129
+
130
+ for new_text in streamer:
131
+ response += new_text
132
+ yield response
133
+
134
+
135
+ with gr.Blocks(
136
+ theme=gr.themes.Soft(),
137
+ css=".disclaimer {font-variant-caps: all-small-caps;}",
138
+ ) as demo:
139
+ session_id = gr.State(lambda: str(uuid4()))
140
+ gr.Markdown(
141
+ """<h1><center>MosaicML MPT-7B-Instruct</center></h1>
142
+ This demo is of [MPT-7B-Instruct](https://huggingface.co/mosaicml/mpt-7b-instruct). It is based on [MPT-7B](https://huggingface.co/mosaicml/mpt-7b) fine-tuned with approximately [60,000 instruction demonstrations](https://huggingface.co/datasets/sam-mosaic/dolly_hhrlhf)
143
+ If you're interested in [training](https://www.mosaicml.com/training) and [deploying](https://www.mosaicml.com/inference) your own MPT or LLMs, [sign up](https://forms.mosaicml.com/demo?utm_source=huggingface&utm_medium=referral&utm_campaign=mpt-7b) for MosaicML platform.
144
+ This is running on a smaller, shared GPU, so it may take a few seconds to respond. If you want to run it on your own GPU, you can [download the model from HuggingFace](https://huggingface.co/mosaicml/mpt-7b-instruct) and run it locally. Or [Duplicate the Space](https://huggingface.co/spaces/mosaicml/mpt-7b-instruct?duplicate=true) to skip the queue and run in a private space."""
145
+ )
146
+ with gr.Row():
147
+ with gr.Column():
148
+ with gr.Row():
149
+ instruction = gr.Textbox(
150
+ placeholder="Enter your question here",
151
+ label="Question/Instruction",
152
+ elem_id="q-input",
153
+ )
154
+ with gr.Accordion("Advanced Options:", open=False):
155
+ with gr.Row():
156
+ with gr.Column():
157
+ with gr.Row():
158
+ temperature = gr.Slider(
159
+ label="Temperature",
160
+ value=0.1,
161
+ minimum=0.0,
162
+ maximum=1.0,
163
+ step=0.1,
164
+ interactive=True,
165
+ info="Higher values produce more diverse outputs",
166
+ )
167
+ with gr.Column():
168
+ with gr.Row():
169
+ top_p = gr.Slider(
170
+ label="Top-p (nucleus sampling)",
171
+ value=1.0,
172
+ minimum=0.0,
173
+ maximum=1,
174
+ step=0.01,
175
+ interactive=True,
176
+ info=(
177
+ "Sample from the smallest possible set of tokens whose cumulative probability "
178
+ "exceeds top_p. Set to 1 to disable and sample from all tokens."
179
+ ),
180
+ )
181
+ with gr.Column():
182
+ with gr.Row():
183
+ top_k = gr.Slider(
184
+ label="Top-k",
185
+ value=0,
186
+ minimum=0.0,
187
+ maximum=200,
188
+ step=1,
189
+ interactive=True,
190
+ info="Sample from a shortlist of top-k tokens β€” 0 to disable and sample from all tokens.",
191
+ )
192
+ with gr.Column():
193
+ with gr.Row():
194
+ max_new_tokens = gr.Slider(
195
+ label="Maximum new tokens",
196
+ value=256,
197
+ minimum=0,
198
+ maximum=1664,
199
+ step=5,
200
+ interactive=True,
201
+ info="The maximum number of new tokens to generate",
202
+ )
203
+ with gr.Row():
204
+ submit = gr.Button("Submit")
205
+ with gr.Row():
206
+ with gr.Box():
207
+ gr.Markdown("**MPT-7B-Instruct**")
208
+ output_7b = gr.Markdown()
209
+
210
+ with gr.Row():
211
+ gr.Examples(
212
+ examples=examples,
213
+ inputs=[instruction],
214
+ cache_examples=False,
215
+ fn=process_stream,
216
+ outputs=output_7b,
217
+ )
218
+ with gr.Row():
219
+ gr.Markdown(
220
+ "Disclaimer: MPT-7B can produce factually incorrect output, and should not be relied on to produce "
221
+ "factually accurate information. MPT-7B was trained on various public datasets; while great efforts "
222
+ "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
223
+ "biased, or otherwise offensive outputs.",
224
+ elem_classes=["disclaimer"],
225
+ )
226
+ with gr.Row():
227
+ gr.Markdown(
228
+ "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
229
+ elem_classes=["disclaimer"],
230
+ )
231
+
232
+ submit.click(
233
+ process_stream,
234
+ inputs=[instruction, temperature, top_p, top_k, max_new_tokens, session_id],
235
+ outputs=output_7b,
236
+ )
237
+ instruction.submit(
238
+ process_stream,
239
+ inputs=[instruction, temperature, top_p, top_k, max_new_tokens, session_id],
240
+ outputs=output_7b,
241
+ )
242
+
243
+ demo.queue(max_size=32, concurrency_count=4).launch(debug=True)
TestApp/quick_pipeline.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Tuple
2
+ import warnings
3
+
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+
8
+ INSTRUCTION_KEY = "### Instruction:"
9
+ RESPONSE_KEY = "### Response:"
10
+ END_KEY = "### End"
11
+ INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
12
+ PROMPT_FOR_GENERATION_FORMAT = """{intro}
13
+
14
+ {instruction_key}
15
+ {instruction}
16
+
17
+ {response_key}
18
+ """.format(
19
+ intro=INTRO_BLURB,
20
+ instruction_key=INSTRUCTION_KEY,
21
+ instruction="{instruction}",
22
+ response_key=RESPONSE_KEY,
23
+ )
24
+
25
+
26
+ class InstructionTextGenerationPipeline:
27
+ def __init__(
28
+ self,
29
+ model_name,
30
+ torch_dtype=torch.bfloat16,
31
+ trust_remote_code=True,
32
+ use_auth_token=None,
33
+ ) -> None:
34
+ self.model = AutoModelForCausalLM.from_pretrained(
35
+ model_name,
36
+ torch_dtype=torch_dtype,
37
+ trust_remote_code=trust_remote_code,
38
+ use_auth_token=use_auth_token,
39
+ )
40
+
41
+ tokenizer = AutoTokenizer.from_pretrained(
42
+ model_name,
43
+ trust_remote_code=trust_remote_code,
44
+ use_auth_token=use_auth_token,
45
+ )
46
+ if tokenizer.pad_token_id is None:
47
+ warnings.warn(
48
+ "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
49
+ )
50
+ tokenizer.pad_token = tokenizer.eos_token
51
+ tokenizer.padding_side = "left"
52
+ self.tokenizer = tokenizer
53
+
54
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+ self.model.eval()
56
+ self.model.to(device=device, dtype=torch_dtype)
57
+
58
+ self.generate_kwargs = {
59
+ "temperature": 0.5,
60
+ "top_p": 0.92,
61
+ "top_k": 0,
62
+ "max_new_tokens": 512,
63
+ "use_cache": True,
64
+ "do_sample": True,
65
+ "eos_token_id": self.tokenizer.eos_token_id,
66
+ "pad_token_id": self.tokenizer.pad_token_id,
67
+ "repetition_penalty": 1.1, # 1.0 means no penalty, > 1.0 means penalty, 1.2 from CTRL paper
68
+ }
69
+
70
+ def format_instruction(self, instruction):
71
+ return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
72
+
73
+ def __call__(
74
+ self, instruction: str, **generate_kwargs: Dict[str, Any]
75
+ ) -> Tuple[str, str, float]:
76
+ s = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
77
+ input_ids = self.tokenizer(s, return_tensors="pt").input_ids
78
+ input_ids = input_ids.to(self.model.device)
79
+ gkw = {**self.generate_kwargs, **generate_kwargs}
80
+ with torch.no_grad():
81
+ output_ids = self.model.generate(input_ids, **gkw)
82
+ # Slice the output_ids tensor to get only new tokens
83
+ new_tokens = output_ids[0, len(input_ids[0]) :]
84
+ output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
85
+ return output_text
requirements.txt CHANGED
@@ -1,6 +1,7 @@
 
1
  urllib3==1.26.6
2
  gradio
3
- transformers
4
  einops
5
  torch
6
  config
 
1
+ -e git+https://github.com/samhavens/just-triton-flash.git#egg=flash_attn
2
  urllib3==1.26.6
3
  gradio
4
+ transformers==4.29.2
5
  einops
6
  torch
7
  config