vwxyzjn commited on
Commit
c9776de
·
1 Parent(s): 73437fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -60
app.py CHANGED
@@ -1,22 +1,17 @@
1
  import os
2
  import re
3
  import copy
 
4
 
5
  import gradio as gr
6
  from text_generation import Client
7
  from transformers import load_tool
8
-
9
-
10
  from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
11
 
12
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
14
  print(HF_TOKEN)
15
 
16
- API_URL = "https://api-inference.huggingface.co/models/vwxyzjn/starcoderbase-triviaqa"
17
- API_URL_BASE ="https://api-inference.huggingface.co/models/bigcode/starcoderbase"
18
- API_URL_PLUS = "https://api-inference.huggingface.co/models/bigcode/starcoderplus"
19
-
20
  FIM_PREFIX = "<fim_prefix>"
21
  FIM_MIDDLE = "<fim_middle>"
22
  FIM_SUFFIX = "<fim_suffix>"
@@ -77,13 +72,31 @@ theme = gr.themes.Monochrome(
77
  ],
78
  )
79
 
80
- client = Client(
81
- API_URL,
82
- headers={"Authorization": f"Bearer {HF_TOKEN}"},
83
- )
84
  tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
85
- tool_fn = lambda x: tool(x).split("\n")[1][:600] # limit the amount of tokens
86
- tools = {"Wiki": tool_fn}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  def parse_tool_call(text, request_token="<request>", call_token="<call>"):
89
  """
@@ -113,9 +126,9 @@ def parse_tool_call(text, request_token="<request>", call_token="<call>"):
113
 
114
 
115
  def generate(
116
- prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, version="StarCoderBase TriviaQA",
117
  ):
118
-
119
  temperature = float(temperature)
120
  if temperature < 1e-2:
121
  temperature = 1e-2
@@ -132,48 +145,55 @@ def generate(
132
  seed=42,
133
  stop_sequences=["<call>"]
134
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- if version == "StarCoderBase TriviaQA":
137
- stream = client.generate_stream(prompt, **generate_kwargs)
138
-
139
-
140
- # call env phase
141
- output = prompt
142
- previous_token = ""
143
- for response in stream:
144
- if response.token.text == "<|endoftext|>":
145
- return output
146
- else:
147
- output += response.token.text
148
- previous_token = response.token.text
149
- # text env logic:
150
- tool, query = parse_tool_call(output[len(prompt):])
151
- if tool is not None and query is not None:
152
- if tool not in tools:
153
- response = f"Unknown tool {tool}."
154
- try:
155
- response = tools[tool](query)
156
- output += response + "<response>"
157
- except Exception as error:
158
- response = f"Tool error: {str(error)}"
159
- yield output[len(prompt):]
160
-
161
- call_output = copy.deepcopy(output)
162
- # response phase
163
- generate_kwargs["stop_sequences"] = ["<submit>"]
164
- stream = client.generate_stream(output, **generate_kwargs)
165
- previous_token = ""
166
- for response in stream:
167
- if response.token.text == "<|endoftext|>":
168
  return output
169
- else:
170
- output += response.token.text
171
- previous_token = response.token.text
172
- yield output[len(prompt):]
173
-
174
-
175
-
176
- return output
177
 
178
 
179
  examples = [
@@ -224,15 +244,21 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
224
  gr.Markdown(description)
225
  with gr.Row():
226
  version = gr.Dropdown(
227
- ["StarCoderBase TriviaQA"],
228
- value="StarCoderBase TriviaQA",
229
  label="Model",
230
  info="Choose a model from the list",
231
  )
 
 
 
 
 
232
  with gr.Row():
233
  with gr.Column():
234
  instruction = gr.Textbox(
235
- placeholder="Enter your code here",
 
236
  lines=5,
237
  label="Input",
238
  elem_id="q-input",
@@ -297,15 +323,15 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
297
  fn=process_example,
298
  outputs=[output],
299
  )
300
- gr.Markdown(FORMATS)
301
 
302
  submit.click(
303
  generate,
304
- inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
305
  outputs=[output],
306
  )
307
  share_button.click(None, [], [], _js=share_js)
308
- demo.queue(concurrency_count=16).launch(debug=True)
309
 
310
 
311
  """
 
1
  import os
2
  import re
3
  import copy
4
+ import time
5
 
6
  import gradio as gr
7
  from text_generation import Client
8
  from transformers import load_tool
 
 
9
  from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
10
 
11
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
  print(HF_TOKEN)
14
 
 
 
 
 
15
  FIM_PREFIX = "<fim_prefix>"
16
  FIM_MIDDLE = "<fim_middle>"
17
  FIM_SUFFIX = "<fim_suffix>"
 
72
  ],
73
  )
74
 
 
 
 
 
75
  tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
76
+ tool_fn = lambda x: tool(x).split("\n")[1][:600] # limit the amount of tokens
77
+
78
+ clients = {
79
+ "StarCoderBase TriviaQA": [
80
+ Client(
81
+ "https://api-inference.huggingface.co/models/vwxyzjn/starcoderbase-triviaqa",
82
+ headers={"Authorization": f"Bearer {HF_TOKEN}"},
83
+ ),
84
+ {"Wiki": tool_fn},
85
+ """\
86
+ Answer the following question:
87
+
88
+ Q: In which branch of the arts is Patricia Neary famous?
89
+ A: Ballets
90
+ A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
91
+ Result=Ballets<submit>
92
+
93
+ Q: Who won Super Bowl XX?
94
+ A: Chicago Bears
95
+ A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
96
+ Result=Chicago Bears<submit>
97
+ """
98
+ ],
99
+ }
100
 
101
  def parse_tool_call(text, request_token="<request>", call_token="<call>"):
102
  """
 
126
 
127
 
128
  def generate(
129
+ prompt, system_prompt, version, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
130
  ):
131
+ client, tools, _ = clients[version]
132
  temperature = float(temperature)
133
  if temperature < 1e-2:
134
  temperature = 1e-2
 
145
  seed=42,
146
  stop_sequences=["<call>"]
147
  )
148
+ generation_still_running = True
149
+ while generation_still_running:
150
+ try:
151
+ stream = client.generate_stream(system_prompt + prompt, **generate_kwargs)
152
+
153
+
154
+ # call env phase
155
+ output = system_prompt + prompt
156
+ previous_token = ""
157
+ for response in stream:
158
+ if response.token.text == "<|endoftext|>":
159
+ return output
160
+ else:
161
+ output += response.token.text
162
+ previous_token = response.token.text
163
+ # text env logic:
164
+ tool, query = parse_tool_call(output[len(system_prompt + prompt):])
165
+ print("tool", tool, query)
166
+ if tool is not None and query is not None:
167
+ if tool not in tools:
168
+ response = f"Unknown tool {tool}."
169
+ try:
170
+ response = tools[tool](query)
171
+ output += response + "<response>"
172
+ except Exception as error:
173
+ response = f"Tool error: {str(error)}"
174
+ yield output[len(system_prompt + prompt):]
175
+
176
+ call_output = copy.deepcopy(output)
177
+ # response phase
178
+ generate_kwargs["stop_sequences"] = ["<submit>"]
179
+ stream = client.generate_stream(output, **generate_kwargs)
180
+ previous_token = ""
181
+ for response in stream:
182
+ if response.token.text == "<|endoftext|>":
183
+ return output
184
+ else:
185
+ output += response.token.text
186
+ previous_token = response.token.text
187
+ yield output[len(system_prompt + prompt):]
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  return output
190
+ except Exception as e:
191
+ if "loading" in str(e):
192
+ gr.Warning("waiting for model to load... (this could take up to 20 minutes, after which things are much faster)")
193
+ time.sleep(7)
194
+ continue
195
+ else:
196
+ raise gr.Error(str(e))
 
197
 
198
 
199
  examples = [
 
244
  gr.Markdown(description)
245
  with gr.Row():
246
  version = gr.Dropdown(
247
+ list(clients.keys()),
248
+ value=list(clients.keys())[0],
249
  label="Model",
250
  info="Choose a model from the list",
251
  )
252
+ system_prompt = gr.Textbox(
253
+ value=clients[list(clients.keys())[0]][2],
254
+ label="System prompt",
255
+ )
256
+
257
  with gr.Row():
258
  with gr.Column():
259
  instruction = gr.Textbox(
260
+ value="Q: In which country is Oberhofen situated?",
261
+ # placeholder="Enter your question here. E.g., Q: In which country is Oberhofen situated?",
262
  lines=5,
263
  label="Input",
264
  elem_id="q-input",
 
323
  fn=process_example,
324
  outputs=[output],
325
  )
326
+ # gr.Markdown(FORMATS)
327
 
328
  submit.click(
329
  generate,
330
+ inputs=[instruction, system_prompt, version, temperature, max_new_tokens, top_p, repetition_penalty],
331
  outputs=[output],
332
  )
333
  share_button.click(None, [], [], _js=share_js)
334
+ demo.queue(concurrency_count=16).launch(share=True)
335
 
336
 
337
  """