zouhsab committed
Commit c79b360 · verified · 1 Parent(s): feaa9a9

Upload gradio_web_server.py

Files changed (1):
  1. gradio_web_server.py +472 -0
gradio_web_server.py ADDED
@@ -0,0 +1,472 @@
+ import argparse
+ import datetime
+ import hashlib
+ import json
+ import os
+ import time
+
+ import gradio as gr
+ import requests
+
+ from tinyllava.conversation import (default_conversation, conv_templates,
+                                     SeparatorStyle)
+ from tinyllava.constants import LOGDIR
+ from tinyllava.utils import (build_logger, server_error_msg,
+                              violates_moderation, moderation_msg)
+
+
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+ headers = {"User-Agent": "LLaVA Client"}
+
+ no_change_btn = gr.Button.update()
+ enable_btn = gr.Button.update(interactive=True)
+ disable_btn = gr.Button.update(interactive=False)
+
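+ # Models named here sort ahead of all others in get_model_list(); the
+ # values are synthetic keys that compare before ordinary model names.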
+ priority = {
+     "vicuna-13b": "aaaaaaa",
+     "koala-13b": "aaaaaab",
+ }
+
+
+ def get_conv_log_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+     return name
+
+
+ def get_model_list():
+     ret = requests.post(args.controller_url + "/refresh_all_workers")
+     assert ret.status_code == 200
+     ret = requests.post(args.controller_url + "/list_models")
+     models = ret.json()["models"]
+     models.sort(key=lambda x: priority.get(x, x))
+     logger.info(f"Models: {models}")
+     return models
+
+
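+ # JavaScript run once on page load: it reads the URL query string
+ # (e.g. ?model=...) and hands the parameters to load_demo().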
+ get_window_url_params = """
+ function() {
+     const params = new URLSearchParams(window.location.search);
+     url_params = Object.fromEntries(params);
+     console.log(url_params);
+     return url_params;
+ }
+ """
+
+
+ def load_demo(url_params, request: gr.Request):
+     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+     dropdown_update = gr.Dropdown.update(visible=True)
+     if "model" in url_params:
+         model = url_params["model"]
+         if model in models:
+             dropdown_update = gr.Dropdown.update(
+                 value=model, visible=True)
+
+     state = default_conversation.copy()
+     return state, dropdown_update
+
+
+ def load_demo_refresh_model_list(request: gr.Request):
+     logger.info(f"load_demo_refresh_model_list. ip: {request.client.host}")
+     models = get_model_list()
+     state = default_conversation.copy()
+     dropdown_update = gr.Dropdown.update(
+         choices=models,
+         value=models[0] if len(models) > 0 else ""
+     )
+     return state, dropdown_update
+
+
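+ # Each vote handler appends the full conversation state to the daily
+ # conversation log, then clears the textbox and disables the vote buttons.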
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "tstamp": round(time.time(), 4),
+             "type": vote_type,
+             "model": model_selector,
+             "state": state.dict(),
+             "ip": request.client.host,
+         }
+         fout.write(json.dumps(data) + "\n")
+
+
+ def upvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"upvote. ip: {request.client.host}")
+     vote_last_response(state, "upvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def downvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"downvote. ip: {request.client.host}")
+     vote_last_response(state, "downvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def flag_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"flag. ip: {request.client.host}")
+     vote_last_response(state, "flag", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
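+ # Regenerate drops the last assistant reply and re-runs generation for the
+ # previous human turn, refreshing its image_process_mode first.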
+ def regenerate(state, image_process_mode, request: gr.Request):
+     logger.info(f"regenerate. ip: {request.client.host}")
+     state.messages[-1][-1] = None
+     prev_human_msg = state.messages[-2]
+     if type(prev_human_msg[1]) in (tuple, list):
+         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+ def clear_history(request: gr.Request):
+     logger.info(f"clear_history. ip: {request.client.host}")
+     state = default_conversation.copy()
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
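+ # add_text appends the user turn (text plus optional image) to the
+ # conversation; setting skip_next makes http_bot ignore invalid or
+ # moderated input.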
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
+     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+     if len(text) <= 0 and image is None:
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+     if args.moderate:
+         flagged = violates_moderation(text)
+         if flagged:
+             state.skip_next = True
+             return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
+                 no_change_btn,) * 5
+
+     text = text[:1536]  # Hard cut-off
+     if image is not None:
+         text = text[:1200]  # Hard cut-off for images
+         if '<image>' not in text:
+             # text = '<Image><image></Image>' + text
+             text = text + '\n<image>'
+         text = (text, image, image_process_mode)
+         if len(state.get_images(return_pil=True)) > 0:
+             state = default_conversation.copy()
+     state.append_message(state.roles[0], text)
+     state.append_message(state.roles[1], None)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
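+ # http_bot streams a reply for the last user turn: on the first turn it
+ # picks a conversation template from the model name, then asks the
+ # controller for a worker address and streams tokens back to the chatbot.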
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+     logger.info(f"http_bot. ip: {request.client.host}")
+     start_tstamp = time.time()
+     model_name = model_selector
+
+     if state.skip_next:
+         # This generate call is skipped due to invalid inputs
+         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+         return
+
+     if len(state.messages) == state.offset + 2:
+         # First round of conversation
+         if "tinyllava" in model_name.lower():
+             if 'llama-2' in model_name.lower():
+                 template_name = "llava_llama_2"
+             elif "v1" in model_name.lower():
+                 if 'mmtag' in model_name.lower():
+                     template_name = "v1_mmtag"
+                 elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+                     template_name = "v1_mmtag"
+                 else:
+                     template_name = "llava_v1"
+             elif 'phi' in model_name.lower():
+                 template_name = "phi"
+             elif "mpt" in model_name.lower():
+                 template_name = "mpt"
+             else:
+                 if 'mmtag' in model_name.lower():
+                     template_name = "v0_mmtag"
+                 elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+                     template_name = "v0_mmtag"
+                 else:
+                     template_name = "llava_v0"
+         elif "mpt" in model_name:
+             template_name = "mpt_text"
+         elif "llama-2" in model_name:
+             template_name = "llama_2"
+         else:
+             template_name = "vicuna_v1"
+         new_state = conv_templates[template_name].copy()
+         new_state.append_message(new_state.roles[0], state.messages[-2][1])
+         new_state.append_message(new_state.roles[1], None)
+         state = new_state
+
+     # Query worker address
+     controller_url = args.controller_url
+     ret = requests.post(controller_url + "/get_worker_address",
+                         json={"model": model_name})
+     worker_addr = ret.json()["address"]
+     logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+     # No available worker
+     if worker_addr == "":
+         state.messages[-1][-1] = server_error_msg
+         yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+         return
+
+     # Construct prompt
+     prompt = state.get_prompt()
+
+     # Save a copy of every image under a content-hash filename for logging
+     all_images = state.get_images(return_pil=True)
+     all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+     for image, image_hash in zip(all_images, all_image_hash):
+         t = datetime.datetime.now()
+         filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
+         if not os.path.isfile(filename):
+             os.makedirs(os.path.dirname(filename), exist_ok=True)
+             image.save(filename)
+
+     # Make requests
+     pload = {
+         "model": model_name,
+         "prompt": prompt,
+         "temperature": float(temperature),
+         "top_p": float(top_p),
+         "max_new_tokens": min(int(max_new_tokens), 1536),
+         "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
+         "images": f'List of {len(state.get_images())} images: {all_image_hash}',
+     }
+     logger.info(f"==== request ====\n{pload}")
+
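+     # The line above logs only the image hashes; attach the real image
+     # payloads to the request after logging.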
+     pload['images'] = state.get_images()
+
+     state.messages[-1][-1] = "▌"
+     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+     try:
+         # Stream output
+         response = requests.post(worker_addr + "/worker_generate_stream",
+                                  headers=headers, json=pload, stream=True, timeout=10)
+         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+             if chunk:
+                 data = json.loads(chunk.decode())
+                 if data["error_code"] == 0:
+                     output = data["text"][len(prompt):].strip()
+                     state.messages[-1][-1] = output + "▌"
+                     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+                 else:
+                     output = data["text"] + f" (error_code: {data['error_code']})"
+                     state.messages[-1][-1] = output
+                     yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+                     return
+                 time.sleep(0.03)
+     except requests.exceptions.RequestException as e:
+         state.messages[-1][-1] = server_error_msg
+         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+         return
+
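+     # Strip the trailing "▌" streaming cursor from the final message.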
+     state.messages[-1][-1] = state.messages[-1][-1][:-1]
+     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+     finish_tstamp = time.time()
+     logger.info(f"{output}")
+
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "tstamp": round(finish_tstamp, 4),
+             "type": "chat",
+             "model": model_name,
+             "start": round(start_tstamp, 4),
+             "finish": round(finish_tstamp, 4),
+             "state": state.dict(),
+             "images": all_image_hash,
+             "ip": request.client.host,
+         }
+         fout.write(json.dumps(data) + "\n")
+
+ title_markdown = ("""
+ # 🌋 LLaVA: Large Language and Vision Assistant
+ [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
+ """)
+
+ tos_markdown = ("""
+ ### Terms of use
+ By using this service, users are required to agree to the following terms:
+ The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+ For an optimal experience, please use a desktop computer for this demo; mobile devices may compromise its quality.
+ """)
+
+
+ learn_more_markdown = ("""
+ ### License
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+ """)
+
+ block_css = """
+
+ #buttons button {
+     min-width: min(120px, 100%);
+ }
+
+ """
+
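+ # Build the Gradio Blocks UI: model dropdown, image upload, sampling
+ # parameters, and the chatbot with vote/regenerate/clear controls.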
+ def build_demo(embed_mode):
+     textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+     with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
+         state = gr.State()
+
+         if not embed_mode:
+             gr.Markdown(title_markdown)
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 with gr.Row(elem_id="model_selector_row"):
+                     model_selector = gr.Dropdown(
+                         choices=models,
+                         value=models[0] if len(models) > 0 else "",
+                         interactive=True,
+                         show_label=False,
+                         container=False)
+
+                 imagebox = gr.Image(type="pil")
+                 image_process_mode = gr.Radio(
+                     ["Crop", "Resize", "Pad", "Default"],
+                     value="Default",
+                     label="Preprocess for non-square image", visible=False)
+
+                 cur_dir = os.path.dirname(os.path.abspath(__file__))
+                 gr.Examples(examples=[
+                     [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
+                     [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
+                 ], inputs=[imagebox, textbox])
+
+                 with gr.Accordion("Parameters", open=False) as parameter_row:
+                     temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
+                     top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
+                     max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+
+             with gr.Column(scale=8):
+                 chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
+                 with gr.Row():
+                     with gr.Column(scale=8):
+                         textbox.render()
+                     with gr.Column(scale=1, min_width=50):
+                         submit_btn = gr.Button(value="Send", variant="primary")
+                 with gr.Row(elem_id="buttons") as button_row:
+                     upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+                     downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+                     flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+                     # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+                     regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+                     clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+         if not embed_mode:
+             gr.Markdown(tos_markdown)
+             gr.Markdown(learn_more_markdown)
+         url_params = gr.JSON(visible=False)
+
+         # Register listeners
+         btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+         upvote_btn.click(
+             upvote_last_response,
+             [state, model_selector],
+             [textbox, upvote_btn, downvote_btn, flag_btn],
+             queue=False
+         )
+         downvote_btn.click(
+             downvote_last_response,
+             [state, model_selector],
+             [textbox, upvote_btn, downvote_btn, flag_btn],
+             queue=False
+         )
+         flag_btn.click(
+             flag_last_response,
+             [state, model_selector],
+             [textbox, upvote_btn, downvote_btn, flag_btn],
+             queue=False
+         )
+
+         regenerate_btn.click(
+             regenerate,
+             [state, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         ).then(
+             http_bot,
+             [state, model_selector, temperature, top_p, max_output_tokens],
+             [state, chatbot] + btn_list
+         )
+
+         clear_btn.click(
+             clear_history,
+             None,
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         )
+
+         textbox.submit(
+             add_text,
+             [state, textbox, imagebox, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         ).then(
+             http_bot,
+             [state, model_selector, temperature, top_p, max_output_tokens],
+             [state, chatbot] + btn_list
+         )
+
+         submit_btn.click(
+             add_text,
+             [state, textbox, imagebox, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         ).then(
+             http_bot,
+             [state, model_selector, temperature, top_p, max_output_tokens],
+             [state, chatbot] + btn_list
+         )
+
+         if args.model_list_mode == "once":
+             demo.load(
+                 load_demo,
+                 [url_params],
+                 [state, model_selector],
+                 _js=get_window_url_params,
+                 queue=False
+             )
+         elif args.model_list_mode == "reload":
+             demo.load(
+                 load_demo_refresh_model_list,
+                 None,
+                 [state, model_selector],
+                 queue=False
+             )
+         else:
+             raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+     return demo
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="127.0.0.1")
+     parser.add_argument("--port", type=int, default=6006)
+     parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+     parser.add_argument("--concurrency-count", type=int, default=10)
+     parser.add_argument("--model-list-mode", type=str, default="once",
+                         choices=["once", "reload"])
+     parser.add_argument("--share", action="store_true")
+     parser.add_argument("--moderate", action="store_true")
+     parser.add_argument("--embed", action="store_true")
+     args = parser.parse_args()
+     logger.info(f"args: {args}")
+
+     models = get_model_list()
+
+     demo = build_demo(args.embed)
+     demo.queue(
+         concurrency_count=args.concurrency_count,
+         api_open=False
+     ).launch(
+         server_name=args.host,
+         server_port=args.port,
+         share=args.share
+     )
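
A minimal launch sketch based on the flags this file defines (a hypothetical invocation, not part of the commit: it assumes a TinyLLaVA controller is already serving at http://localhost:21001 with at least one model worker registered):

    python gradio_web_server.py --controller-url http://localhost:21001 \
        --model-list-mode reload --host 127.0.0.1 --port 6006

With --model-list-mode reload, the model dropdown is refreshed from the controller on every page load instead of only once at startup.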