yatima-k commited on
Commit
b818699
·
verified ·
1 Parent(s): beff924

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +2 -8
  2. __init__.py +0 -0
  3. __pycache__/__init__.cpython-39.pyc +0 -0
  4. __pycache__/api_provider.cpython-39.pyc +0 -0
  5. __pycache__/base_model_worker.cpython-39.pyc +0 -0
  6. __pycache__/call_monitor.cpython-39.pyc +0 -0
  7. __pycache__/cli.cpython-39.pyc +0 -0
  8. __pycache__/controller.cpython-39.pyc +0 -0
  9. __pycache__/gradio_block_arena_anony.cpython-39.pyc +0 -0
  10. __pycache__/gradio_block_arena_named.cpython-39.pyc +0 -0
  11. __pycache__/gradio_block_arena_vision.cpython-39.pyc +0 -0
  12. __pycache__/gradio_web_server.cpython-39.pyc +0 -0
  13. __pycache__/gradio_web_server_multi.cpython-39.pyc +0 -0
  14. __pycache__/huggingface_api.cpython-39.pyc +0 -0
  15. __pycache__/huggingface_api_worker.cpython-39.pyc +0 -0
  16. __pycache__/inference.cpython-39.pyc +0 -0
  17. __pycache__/launch_all_serve.cpython-39.pyc +0 -0
  18. __pycache__/lightllm_worker.cpython-39.pyc +0 -0
  19. __pycache__/mlx_worker.cpython-39.pyc +0 -0
  20. __pycache__/model_worker.cpython-39.pyc +0 -0
  21. __pycache__/multi_model_worker.cpython-39.pyc +0 -0
  22. __pycache__/openai_api_server.cpython-39.pyc +0 -0
  23. __pycache__/register_worker.cpython-39.pyc +0 -0
  24. __pycache__/sglang_worker.cpython-39.pyc +0 -0
  25. __pycache__/shutdown_serve.cpython-39.pyc +0 -0
  26. __pycache__/test_message.cpython-39.pyc +0 -0
  27. __pycache__/test_throughput.cpython-39.pyc +0 -0
  28. __pycache__/vllm_worker.cpython-39.pyc +0 -0
  29. api_provider.py +454 -0
  30. base_model_worker.py +241 -0
  31. call_monitor.py +219 -0
  32. cli.py +304 -0
  33. controller.py +389 -0
  34. gradio_block_arena_anony.py +811 -0
  35. gradio_block_arena_named.py +469 -0
  36. gradio_block_arena_vision.py +187 -0
  37. gradio_web_server.py +887 -0
  38. gradio_web_server_multi.py +277 -0
  39. huggingface_api.py +73 -0
  40. huggingface_api_worker.py +415 -0
  41. inference.py +555 -0
  42. launch_all_serve.py +284 -0
  43. lightllm_worker.py +512 -0
  44. mlx_worker.py +288 -0
  45. model_worker.py +425 -0
  46. monitor/__pycache__/basic_stats.cpython-39.pyc +0 -0
  47. monitor/__pycache__/clean_battle_data.cpython-39.pyc +0 -0
  48. monitor/__pycache__/clean_chat_data.cpython-39.pyc +0 -0
  49. monitor/__pycache__/elo_analysis.cpython-39.pyc +0 -0
  50. monitor/__pycache__/inspect_conv.cpython-39.pyc +0 -0
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Test Fastchat
3
- emoji: 👁
4
- colorFrom: pink
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.27.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: test_fastchat
3
+ app_file: gradio_web_server.py
 
 
4
  sdk: gradio
5
  sdk_version: 4.27.0
 
 
6
  ---
 
 
__init__.py ADDED
File without changes
__pycache__/__init__.cpython-39.pyc ADDED
Binary file (190 Bytes). View file
 
__pycache__/api_provider.cpython-39.pyc ADDED
Binary file (7.84 kB). View file
 
__pycache__/base_model_worker.cpython-39.pyc ADDED
Binary file (7.09 kB). View file
 
__pycache__/call_monitor.cpython-39.pyc ADDED
Binary file (6.5 kB). View file
 
__pycache__/cli.cpython-39.pyc ADDED
Binary file (8.98 kB). View file
 
__pycache__/controller.cpython-39.pyc ADDED
Binary file (10.3 kB). View file
 
__pycache__/gradio_block_arena_anony.cpython-39.pyc ADDED
Binary file (15.5 kB). View file
 
__pycache__/gradio_block_arena_named.cpython-39.pyc ADDED
Binary file (11.1 kB). View file
 
__pycache__/gradio_block_arena_vision.cpython-39.pyc ADDED
Binary file (4.25 kB). View file
 
__pycache__/gradio_web_server.cpython-39.pyc ADDED
Binary file (21.9 kB). View file
 
__pycache__/gradio_web_server_multi.cpython-39.pyc ADDED
Binary file (6.25 kB). View file
 
__pycache__/huggingface_api.cpython-39.pyc ADDED
Binary file (1.98 kB). View file
 
__pycache__/huggingface_api_worker.cpython-39.pyc ADDED
Binary file (11 kB). View file
 
__pycache__/inference.cpython-39.pyc ADDED
Binary file (10.7 kB). View file
 
__pycache__/launch_all_serve.cpython-39.pyc ADDED
Binary file (6.34 kB). View file
 
__pycache__/lightllm_worker.cpython-39.pyc ADDED
Binary file (13 kB). View file
 
__pycache__/mlx_worker.cpython-39.pyc ADDED
Binary file (7.55 kB). View file
 
__pycache__/model_worker.cpython-39.pyc ADDED
Binary file (10 kB). View file
 
__pycache__/multi_model_worker.cpython-39.pyc ADDED
Binary file (8.78 kB). View file
 
__pycache__/openai_api_server.cpython-39.pyc ADDED
Binary file (21.6 kB). View file
 
__pycache__/register_worker.cpython-39.pyc ADDED
Binary file (914 Bytes). View file
 
__pycache__/sglang_worker.cpython-39.pyc ADDED
Binary file (8.6 kB). View file
 
__pycache__/shutdown_serve.cpython-39.pyc ADDED
Binary file (923 Bytes). View file
 
__pycache__/test_message.cpython-39.pyc ADDED
Binary file (2.11 kB). View file
 
__pycache__/test_throughput.cpython-39.pyc ADDED
Binary file (3.11 kB). View file
 
__pycache__/vllm_worker.cpython-39.pyc ADDED
Binary file (8.58 kB). View file
 
api_provider.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Call API providers."""
2
+
3
+ import json
4
+ import os
5
+ import random
6
+ import time
7
+
8
+ import requests
9
+
10
+ from fastchat.utils import build_logger
11
+
12
+
13
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
14
+
15
+
16
+ def get_api_provider_stream_iter(
17
+ conv,
18
+ model_name,
19
+ model_api_dict,
20
+ temperature,
21
+ top_p,
22
+ max_new_tokens,
23
+ ):
24
+ if model_api_dict["api_type"] == "openai":
25
+ prompt = conv.to_openai_api_messages()
26
+ stream_iter = openai_api_stream_iter(
27
+ model_api_dict["model_name"],
28
+ prompt,
29
+ temperature,
30
+ top_p,
31
+ max_new_tokens,
32
+ api_base=model_api_dict["api_base"],
33
+ api_key=model_api_dict["api_key"],
34
+ )
35
+ elif model_api_dict["api_type"] == "anthropic":
36
+ prompt = conv.get_prompt()
37
+ stream_iter = anthropic_api_stream_iter(
38
+ model_name, prompt, temperature, top_p, max_new_tokens
39
+ )
40
+ elif model_api_dict["api_type"] == "gemini":
41
+ stream_iter = gemini_api_stream_iter(
42
+ model_api_dict["model_name"],
43
+ conv,
44
+ temperature,
45
+ top_p,
46
+ max_new_tokens,
47
+ api_key=model_api_dict["api_key"],
48
+ )
49
+ elif model_api_dict["api_type"] == "bard":
50
+ prompt = conv.to_openai_api_messages()
51
+ stream_iter = bard_api_stream_iter(
52
+ model_api_dict["model_name"],
53
+ prompt,
54
+ temperature,
55
+ top_p,
56
+ api_key=model_api_dict["api_key"],
57
+ )
58
+ elif model_api_dict["api_type"] == "mistral":
59
+ prompt = conv.to_openai_api_messages()
60
+ stream_iter = mistral_api_stream_iter(
61
+ model_name, prompt, temperature, top_p, max_new_tokens
62
+ )
63
+ elif model_api_dict["api_type"] == "nvidia":
64
+ prompt = conv.to_openai_api_messages()
65
+ stream_iter = nvidia_api_stream_iter(
66
+ model_name,
67
+ prompt,
68
+ temperature,
69
+ top_p,
70
+ max_new_tokens,
71
+ model_api_dict["api_base"],
72
+ )
73
+ elif model_api_dict["api_type"] == "ai2":
74
+ prompt = conv.to_openai_api_messages()
75
+ stream_iter = ai2_api_stream_iter(
76
+ model_name,
77
+ model_api_dict["model_name"],
78
+ prompt,
79
+ temperature,
80
+ top_p,
81
+ max_new_tokens,
82
+ api_base=model_api_dict["api_base"],
83
+ api_key=model_api_dict["api_key"],
84
+ )
85
+ else:
86
+ raise NotImplementedError()
87
+
88
+ return stream_iter
89
+
90
+
91
+ def openai_api_stream_iter(
92
+ model_name,
93
+ messages,
94
+ temperature,
95
+ top_p,
96
+ max_new_tokens,
97
+ api_base=None,
98
+ api_key=None,
99
+ ):
100
+ import openai
101
+
102
+ api_key = api_key or os.environ["OPENAI_API_KEY"]
103
+
104
+ if "azure" in model_name:
105
+ client = openai.AzureOpenAI(
106
+ api_version="2023-07-01-preview",
107
+ azure_endpoint=api_base or "https://api.openai.com/v1",
108
+ api_key=api_key,
109
+ )
110
+ else:
111
+ client = openai.OpenAI(
112
+ base_url=api_base or "https://api.openai.com/v1", api_key=api_key
113
+ )
114
+
115
+ if model_name == "gpt-4-turbo":
116
+ model_name = "gpt-4-1106-preview"
117
+
118
+ # Make requests
119
+ gen_params = {
120
+ "model": model_name,
121
+ "prompt": messages,
122
+ "temperature": temperature,
123
+ "top_p": top_p,
124
+ "max_new_tokens": max_new_tokens,
125
+ }
126
+ logger.info(f"==== request ====\n{gen_params}")
127
+
128
+ res = client.chat.completions.create(
129
+ model=model_name,
130
+ messages=messages,
131
+ temperature=temperature,
132
+ max_tokens=max_new_tokens,
133
+ stream=True,
134
+ )
135
+ text = ""
136
+ for chunk in res:
137
+ if len(chunk.choices) > 0:
138
+ text += chunk.choices[0].delta.content or ""
139
+ data = {
140
+ "text": text,
141
+ "error_code": 0,
142
+ }
143
+ yield data
144
+
145
+
146
+ def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
147
+ import anthropic
148
+
149
+ c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
150
+
151
+ # Make requests
152
+ gen_params = {
153
+ "model": model_name,
154
+ "prompt": prompt,
155
+ "temperature": temperature,
156
+ "top_p": top_p,
157
+ "max_new_tokens": max_new_tokens,
158
+ }
159
+ logger.info(f"==== request ====\n{gen_params}")
160
+
161
+ res = c.completions.create(
162
+ prompt=prompt,
163
+ stop_sequences=[anthropic.HUMAN_PROMPT],
164
+ max_tokens_to_sample=max_new_tokens,
165
+ temperature=temperature,
166
+ top_p=top_p,
167
+ model=model_name,
168
+ stream=True,
169
+ )
170
+ text = ""
171
+ for chunk in res:
172
+ text += chunk.completion
173
+ data = {
174
+ "text": text,
175
+ "error_code": 0,
176
+ }
177
+ yield data
178
+
179
+
180
+ def gemini_api_stream_iter(
181
+ model_name, conv, temperature, top_p, max_new_tokens, api_key=None
182
+ ):
183
+ import google.generativeai as genai # pip install google-generativeai
184
+
185
+ if api_key is None:
186
+ api_key = os.environ["GEMINI_API_KEY"]
187
+ genai.configure(api_key=api_key)
188
+
189
+ generation_config = {
190
+ "temperature": temperature,
191
+ "max_output_tokens": max_new_tokens,
192
+ "top_p": top_p,
193
+ }
194
+ params = {
195
+ "model": model_name,
196
+ "prompt": conv,
197
+ }
198
+ params.update(generation_config)
199
+ logger.info(f"==== request ====\n{params}")
200
+
201
+ safety_settings = [
202
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
203
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
204
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
205
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
206
+ ]
207
+ model = genai.GenerativeModel(
208
+ model_name=model_name,
209
+ generation_config=generation_config,
210
+ safety_settings=safety_settings,
211
+ )
212
+ history = []
213
+ for role, message in conv.messages[:-2]:
214
+ history.append({"role": role, "parts": message})
215
+ convo = model.start_chat(history=history)
216
+ response = convo.send_message(conv.messages[-2][1], stream=True)
217
+
218
+ try:
219
+ text = ""
220
+ for chunk in response:
221
+ text += chunk.text
222
+ data = {
223
+ "text": text,
224
+ "error_code": 0,
225
+ }
226
+ yield data
227
+ except Exception as e:
228
+ logger.error(f"==== error ====\n{e}")
229
+ reason = chunk.candidates
230
+ yield {
231
+ "text": f"**API REQUEST ERROR** Reason: {reason}.",
232
+ "error_code": 1,
233
+ }
234
+
235
+
236
+ def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None):
237
+ del top_p # not supported
238
+ del temperature # not supported
239
+
240
+ if api_key is None:
241
+ api_key = os.environ["BARD_API_KEY"]
242
+
243
+ # convert conv to conv_bard
244
+ conv_bard = []
245
+ for turn in conv:
246
+ if turn["role"] == "user":
247
+ conv_bard.append({"author": "0", "content": turn["content"]})
248
+ elif turn["role"] == "assistant":
249
+ conv_bard.append({"author": "1", "content": turn["content"]})
250
+ else:
251
+ raise ValueError(f"Unsupported role: {turn['role']}")
252
+
253
+ params = {
254
+ "model": model_name,
255
+ "prompt": conv_bard,
256
+ }
257
+ logger.info(f"==== request ====\n{params}")
258
+
259
+ try:
260
+ res = requests.post(
261
+ f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}",
262
+ json={
263
+ "prompt": {
264
+ "messages": conv_bard,
265
+ },
266
+ },
267
+ timeout=30,
268
+ )
269
+ except Exception as e:
270
+ logger.error(f"==== error ====\n{e}")
271
+ yield {
272
+ "text": f"**API REQUEST ERROR** Reason: {e}.",
273
+ "error_code": 1,
274
+ }
275
+
276
+ if res.status_code != 200:
277
+ logger.error(f"==== error ==== ({res.status_code}): {res.text}")
278
+ yield {
279
+ "text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.",
280
+ "error_code": 1,
281
+ }
282
+
283
+ response_json = res.json()
284
+ if "candidates" not in response_json:
285
+ logger.error(f"==== error ==== response blocked: {response_json}")
286
+ reason = response_json["filters"][0]["reason"]
287
+ yield {
288
+ "text": f"**API REQUEST ERROR** Reason: {reason}.",
289
+ "error_code": 1,
290
+ }
291
+
292
+ response = response_json["candidates"][0]["content"]
293
+ pos = 0
294
+ while pos < len(response):
295
+ # simulate token streaming
296
+ pos += random.randint(3, 6)
297
+ time.sleep(0.002)
298
+ data = {
299
+ "text": response[:pos],
300
+ "error_code": 0,
301
+ }
302
+ yield data
303
+
304
+
305
+ def ai2_api_stream_iter(
306
+ model_name,
307
+ model_id,
308
+ messages,
309
+ temperature,
310
+ top_p,
311
+ max_new_tokens,
312
+ api_key=None,
313
+ api_base=None,
314
+ ):
315
+ # get keys and needed values
316
+ ai2_key = api_key or os.environ.get("AI2_API_KEY")
317
+ api_base = api_base or "https://inferd.allen.ai/api/v1/infer"
318
+
319
+ # Make requests
320
+ gen_params = {
321
+ "model": model_name,
322
+ "prompt": messages,
323
+ "temperature": temperature,
324
+ "top_p": top_p,
325
+ "max_new_tokens": max_new_tokens,
326
+ }
327
+ logger.info(f"==== request ====\n{gen_params}")
328
+
329
+ # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
330
+ # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
331
+ if temperature == 0.0 and top_p < 1.0:
332
+ raise ValueError("top_p must be 1 when temperature is 0.0")
333
+
334
+ res = requests.post(
335
+ api_base,
336
+ stream=True,
337
+ headers={"Authorization": f"Bearer {ai2_key}"},
338
+ json={
339
+ "model_id": model_id,
340
+ # This input format is specific to the Tulu2 model. Other models
341
+ # may require different input formats. See the model's schema
342
+ # documentation on InferD for more information.
343
+ "input": {
344
+ "messages": messages,
345
+ "opts": {
346
+ "max_tokens": max_new_tokens,
347
+ "temperature": temperature,
348
+ "top_p": top_p,
349
+ "logprobs": 1, # increase for more choices
350
+ },
351
+ },
352
+ },
353
+ timeout=5,
354
+ )
355
+
356
+ if res.status_code != 200:
357
+ logger.error(f"unexpected response ({res.status_code}): {res.text}")
358
+ raise ValueError("unexpected response from InferD", res)
359
+
360
+ text = ""
361
+ for line in res.iter_lines():
362
+ if line:
363
+ part = json.loads(line)
364
+ if "result" in part and "output" in part["result"]:
365
+ for t in part["result"]["output"]["text"]:
366
+ text += t
367
+ else:
368
+ logger.error(f"unexpected part: {part}")
369
+ raise ValueError("empty result in InferD response")
370
+
371
+ data = {
372
+ "text": text,
373
+ "error_code": 0,
374
+ }
375
+ yield data
376
+
377
+
378
+ def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
379
+ from mistralai.client import MistralClient
380
+ from mistralai.models.chat_completion import ChatMessage
381
+
382
+ api_key = os.environ["MISTRAL_API_KEY"]
383
+
384
+ client = MistralClient(api_key=api_key)
385
+
386
+ # Make requests
387
+ gen_params = {
388
+ "model": model_name,
389
+ "prompt": messages,
390
+ "temperature": temperature,
391
+ "top_p": top_p,
392
+ "max_new_tokens": max_new_tokens,
393
+ }
394
+ logger.info(f"==== request ====\n{gen_params}")
395
+
396
+ new_messages = [
397
+ ChatMessage(role=message["role"], content=message["content"])
398
+ for message in messages
399
+ ]
400
+
401
+ res = client.chat_stream(
402
+ model=model_name,
403
+ temperature=temperature,
404
+ messages=new_messages,
405
+ max_tokens=max_new_tokens,
406
+ top_p=top_p,
407
+ )
408
+
409
+ text = ""
410
+ for chunk in res:
411
+ if chunk.choices[0].delta.content is not None:
412
+ text += chunk.choices[0].delta.content
413
+ data = {
414
+ "text": text,
415
+ "error_code": 0,
416
+ }
417
+ yield data
418
+
419
+
420
+ def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base):
421
+ assert model_name in ["llama2-70b-steerlm-chat", "yi-34b-chat"]
422
+
423
+ api_key = os.environ["NVIDIA_API_KEY"]
424
+ headers = {
425
+ "Authorization": f"Bearer {api_key}",
426
+ "accept": "text/event-stream",
427
+ "content-type": "application/json",
428
+ }
429
+ # nvidia api does not accept 0 temperature
430
+ if temp == 0.0:
431
+ temp = 0.0001
432
+
433
+ payload = {
434
+ "messages": messages,
435
+ "temperature": temp,
436
+ "top_p": top_p,
437
+ "max_tokens": max_tokens,
438
+ "seed": 42,
439
+ "stream": True,
440
+ }
441
+ logger.info(f"==== request ====\n{payload}")
442
+
443
+ response = requests.post(
444
+ api_base, headers=headers, json=payload, stream=True, timeout=1
445
+ )
446
+ text = ""
447
+ for line in response.iter_lines():
448
+ if line:
449
+ data = line.decode("utf-8")
450
+ if data.endswith("[DONE]"):
451
+ break
452
+ data = json.loads(data[6:])["choices"][0]["delta"]["content"]
453
+ text += data
454
+ yield {"text": text, "error_code": 0}
base_model_worker.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import threading
3
+ import time
4
+ from typing import List
5
+
6
+ from fastapi import FastAPI, Request, BackgroundTasks
7
+ from fastapi.responses import StreamingResponse, JSONResponse
8
+ import requests
9
+
10
+ from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
11
+ from fastchat.conversation import Conversation
12
+ from fastchat.utils import pretty_print_semaphore, build_logger
13
+
14
+
15
+ worker = None
16
+ logger = None
17
+
18
+ app = FastAPI()
19
+
20
+
21
+ def heart_beat_worker(obj):
22
+ while True:
23
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
24
+ obj.send_heart_beat()
25
+
26
+
27
+ class BaseModelWorker:
28
+ def __init__(
29
+ self,
30
+ controller_addr: str,
31
+ worker_addr: str,
32
+ worker_id: str,
33
+ model_path: str,
34
+ model_names: List[str],
35
+ limit_worker_concurrency: int,
36
+ conv_template: str = None,
37
+ multimodal: bool = False,
38
+ ):
39
+ global logger, worker
40
+
41
+ self.controller_addr = controller_addr
42
+ self.worker_addr = worker_addr
43
+ self.worker_id = worker_id
44
+ if model_path.endswith("/"):
45
+ model_path = model_path[:-1]
46
+ self.model_names = model_names or [model_path.split("/")[-1]]
47
+ self.limit_worker_concurrency = limit_worker_concurrency
48
+ self.conv = self.make_conv_template(conv_template, model_path)
49
+ self.conv.sep_style = int(self.conv.sep_style)
50
+ self.multimodal = multimodal
51
+ self.tokenizer = None
52
+ self.context_len = None
53
+ self.call_ct = 0
54
+ self.semaphore = None
55
+
56
+ self.heart_beat_thread = None
57
+
58
+ if logger is None:
59
+ logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log")
60
+ if worker is None:
61
+ worker = self
62
+
63
+ def make_conv_template(
64
+ self,
65
+ conv_template: str = None,
66
+ model_path: str = None,
67
+ ) -> Conversation:
68
+ """
69
+ can be overrided to costomize the conversation template for different model workers.
70
+ """
71
+ from fastchat.conversation import get_conv_template
72
+ from fastchat.model.model_adapter import get_conversation_template
73
+
74
+ if conv_template:
75
+ conv = get_conv_template(conv_template)
76
+ else:
77
+ conv = get_conversation_template(model_path)
78
+ return conv
79
+
80
+ def init_heart_beat(self):
81
+ self.register_to_controller()
82
+ self.heart_beat_thread = threading.Thread(
83
+ target=heart_beat_worker,
84
+ args=(self,),
85
+ daemon=True,
86
+ )
87
+ self.heart_beat_thread.start()
88
+
89
+ def register_to_controller(self):
90
+ logger.info("Register to controller")
91
+
92
+ url = self.controller_addr + "/register_worker"
93
+ data = {
94
+ "worker_name": self.worker_addr,
95
+ "check_heart_beat": True,
96
+ "worker_status": self.get_status(),
97
+ "multimodal": self.multimodal,
98
+ }
99
+ r = requests.post(url, json=data)
100
+ assert r.status_code == 200
101
+
102
+ def send_heart_beat(self):
103
+ logger.info(
104
+ f"Send heart beat. Models: {self.model_names}. "
105
+ f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
106
+ f"call_ct: {self.call_ct}. "
107
+ f"worker_id: {self.worker_id}. "
108
+ )
109
+
110
+ url = self.controller_addr + "/receive_heart_beat"
111
+
112
+ while True:
113
+ try:
114
+ ret = requests.post(
115
+ url,
116
+ json={
117
+ "worker_name": self.worker_addr,
118
+ "queue_length": self.get_queue_length(),
119
+ },
120
+ timeout=5,
121
+ )
122
+ exist = ret.json()["exist"]
123
+ break
124
+ except (requests.exceptions.RequestException, KeyError) as e:
125
+ logger.error(f"heart beat error: {e}")
126
+ time.sleep(5)
127
+
128
+ if not exist:
129
+ self.register_to_controller()
130
+
131
+ def get_queue_length(self):
132
+ if self.semaphore is None:
133
+ return 0
134
+ else:
135
+ sempahore_value = (
136
+ self.semaphore._value
137
+ if self.semaphore._value is not None
138
+ else self.limit_worker_concurrency
139
+ )
140
+ waiter_count = (
141
+ 0 if self.semaphore._waiters is None else len(self.semaphore._waiters)
142
+ )
143
+ return self.limit_worker_concurrency - sempahore_value + waiter_count
144
+
145
+ def get_status(self):
146
+ return {
147
+ "model_names": self.model_names,
148
+ "speed": 1,
149
+ "queue_length": self.get_queue_length(),
150
+ }
151
+
152
+ def count_token(self, params):
153
+ prompt = params["prompt"]
154
+
155
+ try:
156
+ input_ids = self.tokenizer(prompt).input_ids
157
+ input_echo_len = len(input_ids)
158
+ except TypeError:
159
+ input_echo_len = self.tokenizer.num_tokens(prompt)
160
+
161
+ ret = {
162
+ "count": input_echo_len,
163
+ "error_code": 0,
164
+ }
165
+ return ret
166
+
167
+ def get_conv_template(self):
168
+ return {"conv": self.conv}
169
+
170
+ def generate_stream_gate(self, params):
171
+ raise NotImplementedError
172
+
173
+ def generate_gate(self, params):
174
+ raise NotImplementedError
175
+
176
+ def get_embeddings(self, params):
177
+ raise NotImplementedError
178
+
179
+
180
+ def release_worker_semaphore():
181
+ worker.semaphore.release()
182
+
183
+
184
+ def acquire_worker_semaphore():
185
+ if worker.semaphore is None:
186
+ worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
187
+ return worker.semaphore.acquire()
188
+
189
+
190
+ def create_background_tasks():
191
+ background_tasks = BackgroundTasks()
192
+ background_tasks.add_task(release_worker_semaphore)
193
+ return background_tasks
194
+
195
+
196
+ @app.post("/worker_generate_stream")
197
+ async def api_generate_stream(request: Request):
198
+ params = await request.json()
199
+ await acquire_worker_semaphore()
200
+ generator = worker.generate_stream_gate(params)
201
+ background_tasks = create_background_tasks()
202
+ return StreamingResponse(generator, background=background_tasks)
203
+
204
+
205
+ @app.post("/worker_generate")
206
+ async def api_generate(request: Request):
207
+ params = await request.json()
208
+ await acquire_worker_semaphore()
209
+ output = await asyncio.to_thread(worker.generate_gate, params)
210
+ release_worker_semaphore()
211
+ return JSONResponse(output)
212
+
213
+
214
+ @app.post("/worker_get_embeddings")
215
+ async def api_get_embeddings(request: Request):
216
+ params = await request.json()
217
+ await acquire_worker_semaphore()
218
+ embedding = worker.get_embeddings(params)
219
+ release_worker_semaphore()
220
+ return JSONResponse(content=embedding)
221
+
222
+
223
+ @app.post("/worker_get_status")
224
+ async def api_get_status(request: Request):
225
+ return worker.get_status()
226
+
227
+
228
+ @app.post("/count_token")
229
+ async def api_count_token(request: Request):
230
+ params = await request.json()
231
+ return worker.count_token(params)
232
+
233
+
234
+ @app.post("/worker_get_conv_template")
235
+ async def api_get_conv(request: Request):
236
+ return worker.get_conv_template()
237
+
238
+
239
+ @app.post("/model_details")
240
+ async def api_model_details(request: Request):
241
+ return {"context_length": worker.context_len}
call_monitor.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import glob
4
+ import time
5
+
6
+ from fastapi import FastAPI
7
+ import hashlib
8
+ import asyncio
9
+
10
+ REFRESH_INTERVAL_SEC = 60
11
+ LOG_DIR = "/home/vicuna/fastchat_logs/server0"
12
+ # LOG_DIR = "/home/vicuna/tmp/test_env"
13
+
14
+
15
+ class Monitor:
16
+ """Monitor the number of calls to each model."""
17
+
18
+ def __init__(self, log_dir: str):
19
+ self.log_dir = log_dir
20
+ self.model_call = {}
21
+ self.user_call = {}
22
+ self.model_call_limit_global = {
23
+ "gpt-4-1106-preview": 300,
24
+ "gpt-4-0125-preview": 300,
25
+ }
26
+ self.model_call_day_limit_per_user = {"gpt-4-1106-preview": 10}
27
+
28
+ async def update_stats(self, num_file=1) -> None:
29
+ while True:
30
+ # find the latest num_file log under log_dir
31
+ json_files = glob.glob(os.path.join(self.log_dir, "*.json"))
32
+ json_files.sort(key=os.path.getctime, reverse=True)
33
+ json_files = json_files[:num_file]
34
+
35
+ model_call = {}
36
+ user_call = {}
37
+ for json_file in json_files:
38
+ for line in open(json_file, "r", encoding="utf-8"):
39
+ obj = json.loads(line)
40
+ if obj["type"] != "chat":
41
+ continue
42
+ if obj["model"] not in model_call:
43
+ model_call[obj["model"]] = []
44
+ model_call[obj["model"]].append(
45
+ {"tstamp": obj["tstamp"], "user_id": obj["ip"]}
46
+ )
47
+ if obj["ip"] not in user_call:
48
+ user_call[obj["ip"]] = []
49
+ user_call[obj["ip"]].append(
50
+ {"tstamp": obj["tstamp"], "model": obj["model"]}
51
+ )
52
+
53
+ self.model_call = model_call
54
+ self.model_call_stats_hour = self.get_model_call_stats(top_k=None)
55
+ self.model_call_stats_day = self.get_model_call_stats(
56
+ top_k=None, most_recent_min=24 * 60
57
+ )
58
+
59
+ self.user_call = user_call
60
+ self.user_call_stats_hour = self.get_user_call_stats(top_k=None)
61
+ self.user_call_stats_day = self.get_user_call_stats(
62
+ top_k=None, most_recent_min=24 * 60
63
+ )
64
+ await asyncio.sleep(REFRESH_INTERVAL_SEC)
65
+
66
+ def get_model_call_limit(self, model: str) -> int:
67
+ if model not in self.model_call_limit_global:
68
+ return -1
69
+ return self.model_call_limit_global[model]
70
+
71
+ def update_model_call_limit(self, model: str, limit: int) -> bool:
72
+ if model not in self.model_call_limit_global:
73
+ return False
74
+ self.model_call_limit_global[model] = limit
75
+ return True
76
+
77
+ def is_model_limit_reached(self, model: str) -> bool:
78
+ if model not in self.model_call_limit_global:
79
+ return False
80
+ if model not in self.model_call_stats_hour:
81
+ return False
82
+ # check if the model call limit is reached
83
+ if self.model_call_stats_hour[model] >= self.model_call_limit_global[model]:
84
+ return True
85
+ return False
86
+
87
+ def is_user_limit_reached(self, model: str, user_id: str) -> bool:
88
+ if model not in self.model_call_day_limit_per_user:
89
+ return False
90
+ if user_id not in self.user_call_stats_day:
91
+ return False
92
+ if model not in self.user_call_stats_day[user_id]["call_dict"]:
93
+ return False
94
+ # check if the user call limit is reached
95
+ if (
96
+ self.user_call_stats_day[user_id]["call_dict"][model]
97
+ >= self.model_call_day_limit_per_user[model]
98
+ ):
99
+ return True
100
+ return False
101
+
102
+ def get_model_call_stats(
103
+ self, target_model=None, most_recent_min: int = 60, top_k: int = 20
104
+ ) -> dict:
105
+ model_call_stats = {}
106
+ for model, reqs in self.model_call.items():
107
+ if target_model is not None and model != target_model:
108
+ continue
109
+ model_call = []
110
+ for req in reqs:
111
+ if req["tstamp"] < time.time() - most_recent_min * 60:
112
+ continue
113
+ model_call.append(req["tstamp"])
114
+ model_call_stats[model] = len(model_call)
115
+ if top_k is not None:
116
+ top_k_model = sorted(
117
+ model_call_stats, key=lambda x: model_call_stats[x], reverse=True
118
+ )[:top_k]
119
+ model_call_stats = {model: model_call_stats[model] for model in top_k_model}
120
+ return model_call_stats
121
+
122
+ def get_user_call_stats(
123
+ self, target_model=None, most_recent_min: int = 60, top_k: int = 20
124
+ ) -> dict:
125
+ user_call_stats = {}
126
+ for user_id, reqs in self.user_call.items():
127
+ user_model_call = {"call_dict": {}}
128
+ for req in reqs:
129
+ if req["tstamp"] < time.time() - most_recent_min * 60:
130
+ continue
131
+ if target_model is not None and req["model"] != target_model:
132
+ continue
133
+ if req["model"] not in user_model_call["call_dict"]:
134
+ user_model_call["call_dict"][req["model"]] = 0
135
+ user_model_call["call_dict"][req["model"]] += 1
136
+
137
+ user_model_call["total_calls"] = sum(user_model_call["call_dict"].values())
138
+ if user_model_call["total_calls"] > 0:
139
+ user_call_stats[user_id] = user_model_call
140
+ if top_k is not None:
141
+ top_k_user = sorted(
142
+ user_call_stats,
143
+ key=lambda x: user_call_stats[x]["total_calls"],
144
+ reverse=True,
145
+ )[:top_k]
146
+ user_call_stats = {
147
+ user_id: user_call_stats[user_id] for user_id in top_k_user
148
+ }
149
+ return user_call_stats
150
+
151
+ def get_num_users(self, most_recent_min: int = 60) -> int:
152
+ user_call_stats = self.get_user_call_stats(
153
+ most_recent_min=most_recent_min, top_k=None
154
+ )
155
+ return len(user_call_stats)
156
+
157
+
158
+ monitor = Monitor(log_dir=LOG_DIR)
159
+ app = FastAPI()
160
+
161
+
162
+ @app.on_event("startup")
163
+ async def app_startup():
164
+ asyncio.create_task(monitor.update_stats(2))
165
+
166
+
167
+ @app.get("/get_model_call_limit/{model}")
168
+ async def get_model_call_limit(model: str):
169
+ return {"model_call_limit": {model: monitor.get_model_call_limit(model)}}
170
+
171
+
172
+ @app.get("/update_model_call_limit/{model}/{limit}")
173
+ async def update_model_call_limit(model: str, limit: int):
174
+ if not monitor.update_model_call_limit(model, limit):
175
+ return {"success": False}
176
+ return {"success": True}
177
+
178
+
179
+ @app.get("/is_limit_reached")
180
+ async def is_limit_reached(model: str, user_id: str):
181
+ if monitor.is_model_limit_reached(model):
182
+ return {
183
+ "is_limit_reached": True,
184
+ "reason": f"MODEL_HOURLY_LIMIT ({model}): {monitor.get_model_call_limit(model)}",
185
+ }
186
+ if monitor.is_user_limit_reached(model, user_id):
187
+ return {
188
+ "is_limit_reached": True,
189
+ "reason": f"USER_DAILY_LIMIT ({model}): {monitor.model_call_day_limit_per_user[model]}",
190
+ }
191
+ return {"is_limit_reached": False}
192
+
193
+
194
+ @app.get("/get_num_users_hr")
195
+ async def get_num_users():
196
+ return {"num_users": len(monitor.user_call_stats_hour)}
197
+
198
+
199
+ @app.get("/get_num_users_day")
200
+ async def get_num_users_day():
201
+ return {"num_users": len(monitor.user_call_stats_day)}
202
+
203
+
204
+ @app.get("/get_user_call_stats")
205
+ async def get_user_call_stats(
206
+ model: str = None, most_recent_min: int = 60, top_k: int = None
207
+ ):
208
+ return {
209
+ "user_call_stats": monitor.get_user_call_stats(model, most_recent_min, top_k)
210
+ }
211
+
212
+
213
+ @app.get("/get_model_call_stats")
214
+ async def get_model_call_stats(
215
+ model: str = None, most_recent_min: int = 60, top_k: int = None
216
+ ):
217
+ return {
218
+ "model_call_stats": monitor.get_model_call_stats(model, most_recent_min, top_k)
219
+ }
cli.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chat with a model with command line interface.
3
+
4
+ Usage:
5
+ python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5
6
+ python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0
7
+
8
+ Other commands:
9
+ - Type "!!exit" or an empty line to exit.
10
+ - Type "!!reset" to start a new conversation.
11
+ - Type "!!remove" to remove the last prompt.
12
+ - Type "!!regen" to regenerate the last message.
13
+ - Type "!!save <filename>" to save the conversation history to a json file.
14
+ - Type "!!load <filename>" to load a conversation history from a json file.
15
+ """
16
+ import argparse
17
+ import os
18
+ import re
19
+ import sys
20
+
21
+ from prompt_toolkit import PromptSession
22
+ from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
23
+ from prompt_toolkit.completion import WordCompleter
24
+ from prompt_toolkit.history import InMemoryHistory
25
+ from prompt_toolkit.key_binding import KeyBindings
26
+ from rich.console import Console
27
+ from rich.live import Live
28
+ from rich.markdown import Markdown
29
+ import torch
30
+
31
+ from fastchat.model.model_adapter import add_model_args
32
+ from fastchat.modules.awq import AWQConfig
33
+ from fastchat.modules.exllama import ExllamaConfig
34
+ from fastchat.modules.xfastertransformer import XftConfig
35
+ from fastchat.modules.gptq import GptqConfig
36
+ from fastchat.serve.inference import ChatIO, chat_loop
37
+ from fastchat.utils import str_to_torch_dtype
38
+
39
+
40
+ class SimpleChatIO(ChatIO):
41
+ def __init__(self, multiline: bool = False):
42
+ self._multiline = multiline
43
+
44
+ def prompt_for_input(self, role) -> str:
45
+ if not self._multiline:
46
+ return input(f"{role}: ")
47
+
48
+ prompt_data = []
49
+ line = input(f"{role} [ctrl-d/z on empty line to end]: ")
50
+ while True:
51
+ prompt_data.append(line.strip())
52
+ try:
53
+ line = input()
54
+ except EOFError as e:
55
+ break
56
+ return "\n".join(prompt_data)
57
+
58
+ def prompt_for_output(self, role: str):
59
+ print(f"{role}: ", end="", flush=True)
60
+
61
+ def stream_output(self, output_stream):
62
+ pre = 0
63
+ for outputs in output_stream:
64
+ output_text = outputs["text"]
65
+ output_text = output_text.strip().split(" ")
66
+ now = len(output_text) - 1
67
+ if now > pre:
68
+ print(" ".join(output_text[pre:now]), end=" ", flush=True)
69
+ pre = now
70
+ print(" ".join(output_text[pre:]), flush=True)
71
+ return " ".join(output_text)
72
+
73
+ def print_output(self, text: str):
74
+ print(text)
75
+
76
+
77
+ class RichChatIO(ChatIO):
78
+ bindings = KeyBindings()
79
+
80
+ @bindings.add("escape", "enter")
81
+ def _(event):
82
+ event.app.current_buffer.newline()
83
+
84
+ def __init__(self, multiline: bool = False, mouse: bool = False):
85
+ self._prompt_session = PromptSession(history=InMemoryHistory())
86
+ self._completer = WordCompleter(
87
+ words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
88
+ pattern=re.compile("$"),
89
+ )
90
+ self._console = Console()
91
+ self._multiline = multiline
92
+ self._mouse = mouse
93
+
94
+ def prompt_for_input(self, role) -> str:
95
+ self._console.print(f"[bold]{role}:")
96
+ # TODO(suquark): multiline input has some issues. fix it later.
97
+ prompt_input = self._prompt_session.prompt(
98
+ completer=self._completer,
99
+ multiline=False,
100
+ mouse_support=self._mouse,
101
+ auto_suggest=AutoSuggestFromHistory(),
102
+ key_bindings=self.bindings if self._multiline else None,
103
+ )
104
+ self._console.print()
105
+ return prompt_input
106
+
107
+ def prompt_for_output(self, role: str):
108
+ self._console.print(f"[bold]{role.replace('/', '|')}:")
109
+
110
+ def stream_output(self, output_stream):
111
+ """Stream output from a role."""
112
+ # TODO(suquark): the console flickers when there is a code block
113
+ # above it. We need to cut off "live" when a code block is done.
114
+
115
+ # Create a Live context for updating the console output
116
+ with Live(console=self._console, refresh_per_second=4) as live:
117
+ # Read lines from the stream
118
+ for outputs in output_stream:
119
+ if not outputs:
120
+ continue
121
+ text = outputs["text"]
122
+ # Render the accumulated text as Markdown
123
+ # NOTE: this is a workaround for the rendering "unstandard markdown"
124
+ # in rich. The chatbots output treat "\n" as a new line for
125
+ # better compatibility with real-world text. However, rendering
126
+ # in markdown would break the format. It is because standard markdown
127
+ # treat a single "\n" in normal text as a space.
128
+ # Our workaround is adding two spaces at the end of each line.
129
+ # This is not a perfect solution, as it would
130
+ # introduce trailing spaces (only) in code block, but it works well
131
+ # especially for console output, because in general the console does not
132
+ # care about trailing spaces.
133
+ lines = []
134
+ for line in text.splitlines():
135
+ lines.append(line)
136
+ if line.startswith("```"):
137
+ # Code block marker - do not add trailing spaces, as it would
138
+ # break the syntax highlighting
139
+ lines.append("\n")
140
+ else:
141
+ lines.append(" \n")
142
+ markdown = Markdown("".join(lines))
143
+ # Update the Live console output
144
+ live.update(markdown)
145
+ self._console.print()
146
+ return text
147
+
148
+ def print_output(self, text: str):
149
+ self.stream_output([{"text": text}])
150
+
151
+
152
+ class ProgrammaticChatIO(ChatIO):
153
+ def prompt_for_input(self, role) -> str:
154
+ contents = ""
155
+ # `end_sequence` signals the end of a message. It is unlikely to occur in
156
+ # message content.
157
+ end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
158
+ len_end = len(end_sequence)
159
+ while True:
160
+ if len(contents) >= len_end:
161
+ last_chars = contents[-len_end:]
162
+ if last_chars == end_sequence:
163
+ break
164
+ try:
165
+ char = sys.stdin.read(1)
166
+ contents = contents + char
167
+ except EOFError:
168
+ continue
169
+ contents = contents[:-len_end]
170
+ print(f"[!OP:{role}]: {contents}", flush=True)
171
+ return contents
172
+
173
+ def prompt_for_output(self, role: str):
174
+ print(f"[!OP:{role}]: ", end="", flush=True)
175
+
176
+ def stream_output(self, output_stream):
177
+ pre = 0
178
+ for outputs in output_stream:
179
+ output_text = outputs["text"]
180
+ output_text = output_text.strip().split(" ")
181
+ now = len(output_text) - 1
182
+ if now > pre:
183
+ print(" ".join(output_text[pre:now]), end=" ", flush=True)
184
+ pre = now
185
+ print(" ".join(output_text[pre:]), flush=True)
186
+ return " ".join(output_text)
187
+
188
+ def print_output(self, text: str):
189
+ print(text)
190
+
191
+
192
+ def main(args):
193
+ if args.gpus:
194
+ if len(args.gpus.split(",")) < args.num_gpus:
195
+ raise ValueError(
196
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
197
+ )
198
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
199
+ os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
200
+ if args.enable_exllama:
201
+ exllama_config = ExllamaConfig(
202
+ max_seq_len=args.exllama_max_seq_len,
203
+ gpu_split=args.exllama_gpu_split,
204
+ cache_8bit=args.exllama_cache_8bit,
205
+ )
206
+ else:
207
+ exllama_config = None
208
+ if args.enable_xft:
209
+ xft_config = XftConfig(
210
+ max_seq_len=args.xft_max_seq_len,
211
+ data_type=args.xft_dtype,
212
+ )
213
+ if args.device != "cpu":
214
+ print("xFasterTransformer now is only support CPUs. Reset device to CPU")
215
+ args.device = "cpu"
216
+ else:
217
+ xft_config = None
218
+ if args.style == "simple":
219
+ chatio = SimpleChatIO(args.multiline)
220
+ elif args.style == "rich":
221
+ chatio = RichChatIO(args.multiline, args.mouse)
222
+ elif args.style == "programmatic":
223
+ chatio = ProgrammaticChatIO()
224
+ else:
225
+ raise ValueError(f"Invalid style for console: {args.style}")
226
+ try:
227
+ chat_loop(
228
+ args.model_path,
229
+ args.device,
230
+ args.num_gpus,
231
+ args.max_gpu_memory,
232
+ str_to_torch_dtype(args.dtype),
233
+ args.load_8bit,
234
+ args.cpu_offloading,
235
+ args.conv_template,
236
+ args.conv_system_msg,
237
+ args.temperature,
238
+ args.repetition_penalty,
239
+ args.max_new_tokens,
240
+ chatio,
241
+ gptq_config=GptqConfig(
242
+ ckpt=args.gptq_ckpt or args.model_path,
243
+ wbits=args.gptq_wbits,
244
+ groupsize=args.gptq_groupsize,
245
+ act_order=args.gptq_act_order,
246
+ ),
247
+ awq_config=AWQConfig(
248
+ ckpt=args.awq_ckpt or args.model_path,
249
+ wbits=args.awq_wbits,
250
+ groupsize=args.awq_groupsize,
251
+ ),
252
+ exllama_config=exllama_config,
253
+ xft_config=xft_config,
254
+ revision=args.revision,
255
+ judge_sent_end=args.judge_sent_end,
256
+ debug=args.debug,
257
+ history=not args.no_history,
258
+ )
259
+ except KeyboardInterrupt:
260
+ print("exit...")
261
+
262
+
263
+ if __name__ == "__main__":
264
+ parser = argparse.ArgumentParser()
265
+ add_model_args(parser)
266
+ parser.add_argument(
267
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
268
+ )
269
+ parser.add_argument(
270
+ "--conv-system-msg", type=str, default=None, help="Conversation system message."
271
+ )
272
+ parser.add_argument("--temperature", type=float, default=0.7)
273
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
274
+ parser.add_argument("--max-new-tokens", type=int, default=512)
275
+ parser.add_argument("--no-history", action="store_true")
276
+ parser.add_argument(
277
+ "--style",
278
+ type=str,
279
+ default="simple",
280
+ choices=["simple", "rich", "programmatic"],
281
+ help="Display style.",
282
+ )
283
+ parser.add_argument(
284
+ "--multiline",
285
+ action="store_true",
286
+ help="Enable multiline input. Use ESC+Enter for newline.",
287
+ )
288
+ parser.add_argument(
289
+ "--mouse",
290
+ action="store_true",
291
+ help="[Rich Style]: Enable mouse support for cursor positioning.",
292
+ )
293
+ parser.add_argument(
294
+ "--judge-sent-end",
295
+ action="store_true",
296
+ help="Whether enable the correction logic that interrupts the output of sentences due to EOS.",
297
+ )
298
+ parser.add_argument(
299
+ "--debug",
300
+ action="store_true",
301
+ help="Print useful debug information (e.g., prompts)",
302
+ )
303
+ args = parser.parse_args()
304
+ main(args)
controller.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import os
12
+ import time
13
+ from typing import List, Union
14
+ import threading
15
+
16
+ from fastapi import FastAPI, Request
17
+ from fastapi.responses import StreamingResponse
18
+ import numpy as np
19
+ import requests
20
+ import uvicorn
21
+
22
+ from fastchat.constants import (
23
+ CONTROLLER_HEART_BEAT_EXPIRATION,
24
+ WORKER_API_TIMEOUT,
25
+ ErrorCode,
26
+ SERVER_ERROR_MSG,
27
+ )
28
+ from fastchat.utils import build_logger
29
+
30
+
31
+ logger = build_logger("controller", "controller.log")
32
+
33
+
34
+ class DispatchMethod(Enum):
35
+ LOTTERY = auto()
36
+ SHORTEST_QUEUE = auto()
37
+
38
+ @classmethod
39
+ def from_str(cls, name):
40
+ if name == "lottery":
41
+ return cls.LOTTERY
42
+ elif name == "shortest_queue":
43
+ return cls.SHORTEST_QUEUE
44
+ else:
45
+ raise ValueError(f"Invalid dispatch method")
46
+
47
+
48
+ @dataclasses.dataclass
49
+ class WorkerInfo:
50
+ model_names: List[str]
51
+ speed: int
52
+ queue_length: int
53
+ check_heart_beat: bool
54
+ last_heart_beat: str
55
+ multimodal: bool
56
+
57
+
58
+ def heart_beat_controller(controller):
59
+ while True:
60
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
61
+ controller.remove_stale_workers_by_expiration()
62
+
63
+
64
+ class Controller:
65
+ def __init__(self, dispatch_method: str):
66
+ # Dict[str -> WorkerInfo]
67
+ self.worker_info = {}
68
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
69
+
70
+ self.heart_beat_thread = threading.Thread(
71
+ target=heart_beat_controller, args=(self,)
72
+ )
73
+ self.heart_beat_thread.start()
74
+
75
+ def register_worker(
76
+ self,
77
+ worker_name: str,
78
+ check_heart_beat: bool,
79
+ worker_status: dict,
80
+ multimodal: bool,
81
+ ):
82
+ if worker_name not in self.worker_info:
83
+ logger.info(f"Register a new worker: {worker_name}")
84
+ else:
85
+ logger.info(f"Register an existing worker: {worker_name}")
86
+
87
+ if not worker_status:
88
+ worker_status = self.get_worker_status(worker_name)
89
+ if not worker_status:
90
+ return False
91
+
92
+ self.worker_info[worker_name] = WorkerInfo(
93
+ worker_status["model_names"],
94
+ worker_status["speed"],
95
+ worker_status["queue_length"],
96
+ check_heart_beat,
97
+ time.time(),
98
+ multimodal,
99
+ )
100
+
101
+ logger.info(f"Register done: {worker_name}, {worker_status}")
102
+ return True
103
+
104
+ def get_worker_status(self, worker_name: str):
105
+ try:
106
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
107
+ except requests.exceptions.RequestException as e:
108
+ logger.error(f"Get status fails: {worker_name}, {e}")
109
+ return None
110
+
111
+ if r.status_code != 200:
112
+ logger.error(f"Get status fails: {worker_name}, {r}")
113
+ return None
114
+
115
+ return r.json()
116
+
117
+ def remove_worker(self, worker_name: str):
118
+ del self.worker_info[worker_name]
119
+
120
+ def refresh_all_workers(self):
121
+ old_info = dict(self.worker_info)
122
+ self.worker_info = {}
123
+
124
+ for w_name, w_info in old_info.items():
125
+ if not self.register_worker(
126
+ w_name, w_info.check_heart_beat, None, w_info.multimodal
127
+ ):
128
+ logger.info(f"Remove stale worker: {w_name}")
129
+
130
+ def list_models(self):
131
+ model_names = set()
132
+
133
+ for w_name, w_info in self.worker_info.items():
134
+ model_names.update(w_info.model_names)
135
+
136
+ return list(model_names)
137
+
138
+ def list_multimodal_models(self):
139
+ model_names = set()
140
+
141
+ for w_name, w_info in self.worker_info.items():
142
+ if w_info.multimodal:
143
+ model_names.update(w_info.model_names)
144
+
145
+ return list(model_names)
146
+
147
+ def list_language_models(self):
148
+ model_names = set()
149
+
150
+ for w_name, w_info in self.worker_info.items():
151
+ if not w_info.multimodal:
152
+ model_names.update(w_info.model_names)
153
+
154
+ return list(model_names)
155
+
156
+ def get_worker_address(self, model_name: str):
157
+ if self.dispatch_method == DispatchMethod.LOTTERY:
158
+ worker_names = []
159
+ worker_speeds = []
160
+ for w_name, w_info in self.worker_info.items():
161
+ if model_name in w_info.model_names:
162
+ worker_names.append(w_name)
163
+ worker_speeds.append(w_info.speed)
164
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
165
+ norm = np.sum(worker_speeds)
166
+ if norm < 1e-4:
167
+ return ""
168
+ worker_speeds = worker_speeds / norm
169
+ if True: # Directly return address
170
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
171
+ worker_name = worker_names[pt]
172
+ return worker_name
173
+
174
+ # Check status before returning
175
+ while True:
176
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
177
+ worker_name = worker_names[pt]
178
+
179
+ if self.get_worker_status(worker_name):
180
+ break
181
+ else:
182
+ self.remove_worker(worker_name)
183
+ worker_speeds[pt] = 0
184
+ norm = np.sum(worker_speeds)
185
+ if norm < 1e-4:
186
+ return ""
187
+ worker_speeds = worker_speeds / norm
188
+ continue
189
+ return worker_name
190
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
191
+ worker_names = []
192
+ worker_qlen = []
193
+ for w_name, w_info in self.worker_info.items():
194
+ if model_name in w_info.model_names:
195
+ worker_names.append(w_name)
196
+ worker_qlen.append(w_info.queue_length / w_info.speed)
197
+ if len(worker_names) == 0:
198
+ return ""
199
+ min_index = np.argmin(worker_qlen)
200
+ w_name = worker_names[min_index]
201
+ self.worker_info[w_name].queue_length += 1
202
+ logger.info(
203
+ f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
204
+ )
205
+ return w_name
206
+ else:
207
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
208
+
209
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
210
+ if worker_name not in self.worker_info:
211
+ logger.info(f"Receive unknown heart beat. {worker_name}")
212
+ return False
213
+
214
+ self.worker_info[worker_name].queue_length = queue_length
215
+ self.worker_info[worker_name].last_heart_beat = time.time()
216
+ logger.info(f"Receive heart beat. {worker_name}")
217
+ return True
218
+
219
+ def remove_stale_workers_by_expiration(self):
220
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
221
+ to_delete = []
222
+ for worker_name, w_info in self.worker_info.items():
223
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
224
+ to_delete.append(worker_name)
225
+
226
+ for worker_name in to_delete:
227
+ self.remove_worker(worker_name)
228
+
229
+ def handle_no_worker(self, params):
230
+ logger.info(f"no worker: {params['model']}")
231
+ ret = {
232
+ "text": SERVER_ERROR_MSG,
233
+ "error_code": ErrorCode.CONTROLLER_NO_WORKER,
234
+ }
235
+ return json.dumps(ret).encode() + b"\0"
236
+
237
+ def handle_worker_timeout(self, worker_address):
238
+ logger.info(f"worker timeout: {worker_address}")
239
+ ret = {
240
+ "text": SERVER_ERROR_MSG,
241
+ "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
242
+ }
243
+ return json.dumps(ret).encode() + b"\0"
244
+
245
+ # Let the controller act as a worker to achieve hierarchical
246
+ # management. This can be used to connect isolated sub networks.
247
+ def worker_api_get_status(self):
248
+ model_names = set()
249
+ speed = 0
250
+ queue_length = 0
251
+
252
+ for w_name in self.worker_info:
253
+ worker_status = self.get_worker_status(w_name)
254
+ if worker_status is not None:
255
+ model_names.update(worker_status["model_names"])
256
+ speed += worker_status["speed"]
257
+ queue_length += worker_status["queue_length"]
258
+
259
+ model_names = sorted(list(model_names))
260
+ return {
261
+ "model_names": model_names,
262
+ "speed": speed,
263
+ "queue_length": queue_length,
264
+ }
265
+
266
+ def worker_api_generate_stream(self, params):
267
+ worker_addr = self.get_worker_address(params["model"])
268
+ if not worker_addr:
269
+ yield self.handle_no_worker(params)
270
+
271
+ try:
272
+ response = requests.post(
273
+ worker_addr + "/worker_generate_stream",
274
+ json=params,
275
+ stream=True,
276
+ timeout=WORKER_API_TIMEOUT,
277
+ )
278
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
279
+ if chunk:
280
+ yield chunk + b"\0"
281
+ except requests.exceptions.RequestException as e:
282
+ yield self.handle_worker_timeout(worker_addr)
283
+
284
+
285
+ app = FastAPI()
286
+
287
+
288
+ @app.post("/register_worker")
289
+ async def register_worker(request: Request):
290
+ data = await request.json()
291
+ controller.register_worker(
292
+ data["worker_name"],
293
+ data["check_heart_beat"],
294
+ data.get("worker_status", None),
295
+ data.get("multimodal", False),
296
+ )
297
+
298
+
299
+ @app.post("/refresh_all_workers")
300
+ async def refresh_all_workers():
301
+ models = controller.refresh_all_workers()
302
+
303
+
304
+ @app.post("/list_models")
305
+ async def list_models():
306
+ models = controller.list_models()
307
+ return {"models": models}
308
+
309
+
310
+ @app.post("/list_multimodal_models")
311
+ async def list_multimodal_models():
312
+ models = controller.list_multimodal_models()
313
+ return {"models": models}
314
+
315
+
316
+ @app.post("/list_language_models")
317
+ async def list_language_models():
318
+ models = controller.list_language_models()
319
+ return {"models": models}
320
+
321
+
322
+ @app.post("/get_worker_address")
323
+ async def get_worker_address(request: Request):
324
+ data = await request.json()
325
+ addr = controller.get_worker_address(data["model"])
326
+ return {"address": addr}
327
+
328
+
329
+ @app.post("/receive_heart_beat")
330
+ async def receive_heart_beat(request: Request):
331
+ data = await request.json()
332
+ exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
333
+ return {"exist": exist}
334
+
335
+
336
+ @app.post("/worker_generate_stream")
337
+ async def worker_api_generate_stream(request: Request):
338
+ params = await request.json()
339
+ generator = controller.worker_api_generate_stream(params)
340
+ return StreamingResponse(generator)
341
+
342
+
343
+ @app.post("/worker_get_status")
344
+ async def worker_api_get_status(request: Request):
345
+ return controller.worker_api_get_status()
346
+
347
+
348
+ @app.get("/test_connection")
349
+ async def worker_api_get_status(request: Request):
350
+ return "success"
351
+
352
+
353
+ def create_controller():
354
+ parser = argparse.ArgumentParser()
355
+ parser.add_argument("--host", type=str, default="localhost")
356
+ parser.add_argument("--port", type=int, default=21001)
357
+ parser.add_argument(
358
+ "--dispatch-method",
359
+ type=str,
360
+ choices=["lottery", "shortest_queue"],
361
+ default="shortest_queue",
362
+ )
363
+ parser.add_argument(
364
+ "--ssl",
365
+ action="store_true",
366
+ required=False,
367
+ default=False,
368
+ help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
369
+ )
370
+ args = parser.parse_args()
371
+ logger.info(f"args: {args}")
372
+
373
+ controller = Controller(args.dispatch_method)
374
+ return args, controller
375
+
376
+
377
+ if __name__ == "__main__":
378
+ args, controller = create_controller()
379
+ if args.ssl:
380
+ uvicorn.run(
381
+ app,
382
+ host=args.host,
383
+ port=args.port,
384
+ log_level="info",
385
+ ssl_keyfile=os.environ["SSL_KEYFILE"],
386
+ ssl_certfile=os.environ["SSL_CERTFILE"],
387
+ )
388
+ else:
389
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
gradio_block_arena_anony.py ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatbot Arena (battle) tab.
3
+ Users chat with two anonymous models.
4
+ """
5
+
6
+ import json
7
+ import time
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+
12
+ from fastchat.constants import (
13
+ MODERATION_MSG,
14
+ CONVERSATION_LIMIT_MSG,
15
+ SLOW_MODEL_MSG,
16
+ INPUT_CHAR_LEN_LIMIT,
17
+ CONVERSATION_TURN_LIMIT,
18
+ )
19
+ from fastchat.model.model_adapter import get_conversation_template
20
+ from fastchat.serve.gradio_block_arena_named import flash_buttons
21
+ from fastchat.serve.gradio_web_server import (
22
+ State,
23
+ bot_response,
24
+ get_conv_log_filename,
25
+ no_change_btn,
26
+ enable_btn,
27
+ disable_btn,
28
+ invisible_btn,
29
+ acknowledgment_md,
30
+ get_ip,
31
+ get_model_description_md,
32
+ )
33
+ from fastchat.utils import (
34
+ build_logger,
35
+ moderation_filter,
36
+ )
37
+
38
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
39
+
40
+ num_sides = 2
41
+ enable_moderation = False
42
+ anony_names = ["", ""]
43
+ models = []
44
+
45
+
46
+ def set_global_vars_anony(enable_moderation_):
47
+ global enable_moderation
48
+ enable_moderation = enable_moderation_
49
+
50
+
51
+ def load_demo_side_by_side_anony(models_, url_params):
52
+ global models
53
+ models = models_
54
+
55
+ states = (None,) * num_sides
56
+ selector_updates = (
57
+ gr.Markdown(visible=True),
58
+ gr.Markdown(visible=True),
59
+ )
60
+
61
+ return states + selector_updates
62
+
63
+
64
+ def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
65
+ with open(get_conv_log_filename(), "a") as fout:
66
+ data = {
67
+ "tstamp": round(time.time(), 4),
68
+ "type": vote_type,
69
+ "models": [x for x in model_selectors],
70
+ "states": [x.dict() for x in states],
71
+ "ip": get_ip(request),
72
+ }
73
+ fout.write(json.dumps(data) + "\n")
74
+
75
+ if ":" not in model_selectors[0]:
76
+ for i in range(5):
77
+ names = (
78
+ "### Model A: " + states[0].model_name,
79
+ "### Model B: " + states[1].model_name,
80
+ )
81
+ yield names + ("",) + (disable_btn,) * 4
82
+ time.sleep(0.1)
83
+ else:
84
+ names = (
85
+ "### Model A: " + states[0].model_name,
86
+ "### Model B: " + states[1].model_name,
87
+ )
88
+ yield names + ("",) + (disable_btn,) * 4
89
+
90
+
91
+ def leftvote_last_response(
92
+ state0, state1, model_selector0, model_selector1, request: gr.Request
93
+ ):
94
+ logger.info(f"leftvote (anony). ip: {get_ip(request)}")
95
+ for x in vote_last_response(
96
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
97
+ ):
98
+ yield x
99
+
100
+
101
+ def rightvote_last_response(
102
+ state0, state1, model_selector0, model_selector1, request: gr.Request
103
+ ):
104
+ logger.info(f"rightvote (anony). ip: {get_ip(request)}")
105
+ for x in vote_last_response(
106
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
107
+ ):
108
+ yield x
109
+
110
+
111
+ def tievote_last_response(
112
+ state0, state1, model_selector0, model_selector1, request: gr.Request
113
+ ):
114
+ logger.info(f"tievote (anony). ip: {get_ip(request)}")
115
+ for x in vote_last_response(
116
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
117
+ ):
118
+ yield x
119
+
120
+
121
+ def bothbad_vote_last_response(
122
+ state0, state1, model_selector0, model_selector1, request: gr.Request
123
+ ):
124
+ logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
125
+ for x in vote_last_response(
126
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
127
+ ):
128
+ yield x
129
+
130
+
131
+ def regenerate(state0, state1, request: gr.Request):
132
+ logger.info(f"regenerate (anony). ip: {get_ip(request)}")
133
+ states = [state0, state1]
134
+ for i in range(num_sides):
135
+ states[i].conv.update_last_message(None)
136
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
137
+
138
+
139
+ def clear_history(request: gr.Request):
140
+ logger.info(f"clear_history (anony). ip: {get_ip(request)}")
141
+ return (
142
+ [None] * num_sides
143
+ + [None] * num_sides
144
+ + anony_names
145
+ + [""]
146
+ + [invisible_btn] * 4
147
+ + [disable_btn] * 2
148
+ + [""]
149
+ )
150
+
151
+
152
+ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
153
+ logger.info(f"share (anony). ip: {get_ip(request)}")
154
+ if state0 is not None and state1 is not None:
155
+ vote_last_response(
156
+ [state0, state1], "share", [model_selector0, model_selector1], request
157
+ )
158
+
159
+
160
+ SAMPLING_WEIGHTS = {
161
+ # tier 0
162
+ "gpt-4": 4,
163
+ "gpt-4-0314": 4,
164
+ "gpt-4-0613": 4,
165
+ "gpt-4-turbo": 4,
166
+ "gpt-4-1106-preview": 4,
167
+ "gpt-4-0125-preview": 4,
168
+ "gpt-3.5-turbo-0613": 2,
169
+ "gpt-3.5-turbo-1106": 2,
170
+ "gpt-3.5-turbo-0125": 4,
171
+ "claude-2.1": 4,
172
+ "claude-2.0": 2,
173
+ "claude-1": 2,
174
+ "claude-instant-1": 2,
175
+ "gemini-pro": 4,
176
+ "gemini-pro-dev-api": 4,
177
+ "bard-jan-24-gemini-pro": 4,
178
+ "bard-feb-2024": 4,
179
+ "mixtral-8x7b-instruct-v0.1": 4,
180
+ "mistral-medium": 4,
181
+ "qwen1.5-72b-chat": 4,
182
+ "qwen1.5-7b-chat": 2,
183
+ "qwen1.5-4b-chat": 2,
184
+ "nous-hermes-2-mixtral-8x7b-dpo": 2,
185
+ "deepseek-llm-67b-chat": 2,
186
+ "stripedhyena-nous-7b": 2,
187
+ "openchat-3.5-0106": 2,
188
+ "mistral-7b-instruct-v0.2": 2,
189
+ "solar-10.7b-instruct-v1.0": 2,
190
+ "dolphin-2.2.1-mistral-7b": 2,
191
+ "starling-lm-7b-alpha": 2,
192
+ "tulu-2-dpo-70b": 2,
193
+ "yi-34b-chat": 2,
194
+ "zephyr-7b-beta": 2,
195
+ # tier 1
196
+ "deluxe-chat-v1.2": 4,
197
+ "llama-2-70b-chat": 4,
198
+ "llama-2-13b-chat": 2,
199
+ "llama-2-7b-chat": 2,
200
+ "mistral-7b-instruct": 2,
201
+ "codellama-34b-instruct": 1.5,
202
+ "vicuna-33b": 2,
203
+ "vicuna-13b": 1.5,
204
+ "wizardlm-13b": 1.5,
205
+ "qwen-14b-chat": 1.5,
206
+ # tier 2
207
+ "pplx-7b-online": 1,
208
+ "pplx-70b-online": 1,
209
+ "openhermes-2.5-mistral-7b": 1.0,
210
+ "llama2-70b-steerlm-chat": 1.0,
211
+ "chatglm3-6b": 1.0,
212
+ "openchat-3.5": 1.0,
213
+ "wizardlm-70b": 1.0,
214
+ "vicuna-7b": 1.0,
215
+ "chatglm2-6b": 1.0,
216
+ # deprecated
217
+ "zephyr-7b-alpha": 1.5,
218
+ "codellama-13b-instruct": 1.0,
219
+ "mpt-30b-chat": 1.5,
220
+ "guanaco-33b": 1.0,
221
+ "fastchat-t5-3b": 0.5,
222
+ "alpaca-13b": 0.5,
223
+ "mpt-7b-chat": 0.1,
224
+ "oasst-pythia-12b": 0.1,
225
+ "RWKV-4-Raven-14B": 0.1,
226
+ "gpt4all-13b-snoozy": 0.1,
227
+ "koala-13b": 0.1,
228
+ "stablelm-tuned-alpha-7b": 0.1,
229
+ "dolly-v2-12b": 0.1,
230
+ "llama-13b": 0.1,
231
+ "chatglm-6b": 0.5,
232
+ "deluxe-chat-v1": 4,
233
+ "palm-2": 1.5,
234
+ }
235
+
236
+ # target model sampling weights will be boosted.
237
+ BATTLE_TARGETS = {
238
+ "gpt-4": {"gpt-4-0314", "claude-2.1", "gpt-4-1106-preview"},
239
+ "gpt-4-0613": {"gpt-4-0314", "claude-2.1", "gpt-4-1106-preview"},
240
+ "gpt-4-0314": {
241
+ "gpt-4-1106-preview",
242
+ "gpt-4-0613",
243
+ "claude-2.1",
244
+ "gpt-3.5-turbo-0613",
245
+ },
246
+ "gpt-4-1106-preview": {
247
+ "gpt-4-0613",
248
+ "gpt-3.5-turbo-0613",
249
+ "gpt-3.5-turbo-1106",
250
+ "claude-2.1",
251
+ "bard-feb-2024",
252
+ },
253
+ "gpt-4-0125-preview": {
254
+ "gpt-4-1106-preview",
255
+ "gpt-4-0613",
256
+ "gpt-3.5-turbo-0613",
257
+ "claude-2.1",
258
+ "mistral-medium",
259
+ "bard-feb-2024",
260
+ },
261
+ "gpt-3.5-turbo-0613": {"claude-instant-1", "gpt-4-0613", "claude-2.1"},
262
+ "gpt-3.5-turbo-1106": {"gpt-4-0613", "claude-instant-1", "gpt-3.5-turbo-0613"},
263
+ "gpt-3.5-turbo-0125": {
264
+ "gpt-4-0613",
265
+ "gpt-4-1106-preview",
266
+ "gpt-3.5-turbo-0613",
267
+ "gpt-3.5-turbo-1106",
268
+ "mixtral-8x7b-instruct-v0.1",
269
+ },
270
+ "qwen1.5-72b-chat": {
271
+ "gpt-3.5-turbo-0125",
272
+ "gpt-4-0613",
273
+ "gpt-4-1106-preview",
274
+ "llama-2-70b-chat",
275
+ "mixtral-8x7b-instruct-v0.1",
276
+ "mistral-medium",
277
+ "yi-34b-chat",
278
+ },
279
+ "qwen1.5-7b-chat": {
280
+ "gpt-3.5-turbo-0125",
281
+ "starling-lm-7b-alpha",
282
+ "llama-2-70b-chat",
283
+ "openchat-3.5",
284
+ "mixtral-8x7b-instruct-v0.1",
285
+ },
286
+ "qwen1.5-4b-chat": {
287
+ "llama-2-70b-chat",
288
+ "llama-2-13b-chat",
289
+ "llama-2-7b-chat",
290
+ "openchat-3.5",
291
+ },
292
+ "openchat-3.5-0106": {
293
+ "gpt-3.5-turbo-0125",
294
+ "gpt-3.5-turbo-0613",
295
+ "llama-2-70b-chat",
296
+ "openchat-3.5",
297
+ "mixtral-8x7b-instruct-v0.1",
298
+ },
299
+ "nous-hermes-2-mixtral-8x7b-dpo": {
300
+ "gpt-4-1106-preview",
301
+ "claude-2.1",
302
+ "mistral-medium",
303
+ "gpt-3.5-turbo-0613",
304
+ "mixtral-8x7b-instruct-v0.1",
305
+ },
306
+ "mistral-7b-instruct-v0.2": {
307
+ "llama-2-70b-chat",
308
+ "mixtral-8x7b-instruct-v0.1",
309
+ "starling-lm-7b-alpha",
310
+ "openhermes-2.5-mistral-7b",
311
+ },
312
+ "solar-10.7b-instruct-v1.0": {
313
+ "mixtral-8x7b-instruct-v0.1",
314
+ "gpt-3.5-turbo-0613",
315
+ "llama-2-70b-chat",
316
+ },
317
+ "mistral-medium": {
318
+ "gpt-3.5-turbo-0125",
319
+ "gpt-3.5-turbo-0613",
320
+ "gpt-4-1106-preview",
321
+ "mixtral-8x7b-instruct-v0.1",
322
+ "bard-feb-2024",
323
+ },
324
+ "mixtral-8x7b-instruct-v0.1": {
325
+ "gpt-3.5-turbo-0125",
326
+ "gpt-3.5-turbo-0613",
327
+ "gpt-4-1106-preview",
328
+ "llama-2-70b-chat",
329
+ },
330
+ "claude-2.1": {"gpt-4-1106-preview", "gpt-4-0613", "claude-1"},
331
+ "claude-2.0": {"gpt-4-1106-preview", "gpt-4-0613", "claude-1"},
332
+ "claude-1": {"claude-2.1", "gpt-4-0613", "gpt-3.5-turbo-0613"},
333
+ "claude-instant-1": {"gpt-3.5-turbo-0125", "claude-2.1"},
334
+ "gemini-pro": {"gpt-4-1106-preview", "gpt-4-0613", "gpt-3.5-turbo-0613"},
335
+ "gemini-pro-dev-api": {
336
+ "gpt-4-1106-preview",
337
+ "gpt-4-0613",
338
+ "gpt-3.5-turbo-0613",
339
+ "bard-feb-2024",
340
+ },
341
+ "bard-jan-24-gemini-pro": {
342
+ "gpt-4-1106-preview",
343
+ "gpt-4-0613",
344
+ "gpt-3.5-turbo-0613",
345
+ "gemini-pro-dev-api",
346
+ },
347
+ "bard-feb-2024": {
348
+ "gpt-4-1106-preview",
349
+ "gpt-4-0613",
350
+ "gpt-3.5-turbo-0613",
351
+ "bard-jan-24-gemini-pro",
352
+ },
353
+ "deepseek-llm-67b-chat": {
354
+ "gpt-4-1106-preview",
355
+ "gpt-4-turbo",
356
+ "gpt-3.5-turbo-0613",
357
+ },
358
+ "llama2-70b-steerlm-chat": {
359
+ "llama-2-70b-chat",
360
+ "tulu-2-dpo-70b",
361
+ "yi-34b-chat",
362
+ },
363
+ "stripedhyena-nous-7b": {
364
+ "starling-lm-7b-alpha",
365
+ "openhermes-2.5-mistral-7b",
366
+ "mistral-7b-instruct",
367
+ "llama-2-7b-chat",
368
+ },
369
+ "deluxe-chat-v1.1": {"gpt-4-0613", "gpt-4-1106-preview"},
370
+ "deluxe-chat-v1.2": {"gpt-4-0613", "gpt-4-1106-preview"},
371
+ "pplx-7b-online": {"gpt-3.5-turbo-0125", "llama-2-70b-chat"},
372
+ "pplx-70b-online": {"gpt-3.5-turbo-0125", "llama-2-70b-chat"},
373
+ "openhermes-2.5-mistral-7b": {
374
+ "gpt-3.5-turbo-0613",
375
+ "openchat-3.5",
376
+ "zephyr-7b-beta",
377
+ },
378
+ "dolphin-2.2.1-mistral-7b": {
379
+ "gpt-3.5-turbo-0613",
380
+ "vicuna-33b",
381
+ "starling-lm-7b-alpha",
382
+ "openhermes-2.5-mistral-7b",
383
+ },
384
+ "starling-lm-7b-alpha": {"gpt-3.5-turbo-0613", "openchat-3.5", "zephyr-7b-beta"},
385
+ "tulu-2-dpo-70b": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"},
386
+ "yi-34b-chat": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"},
387
+ "openchat-3.5": {"gpt-3.5-turbo-0613", "llama-2-70b-chat", "zephyr-7b-beta"},
388
+ "chatglm3-6b": {"yi-34b-chat", "qwen-14b-chat"},
389
+ "qwen-14b-chat": {"vicuna-13b", "llama-2-13b-chat", "llama-2-70b-chat"},
390
+ "zephyr-7b-alpha": {"mistral-7b-instruct", "llama-2-13b-chat"},
391
+ "zephyr-7b-beta": {
392
+ "mistral-7b-instruct",
393
+ "llama-2-13b-chat",
394
+ "llama-2-7b-chat",
395
+ "wizardlm-13b",
396
+ },
397
+ "llama-2-70b-chat": {"gpt-3.5-turbo-0125", "claude-instant-1"},
398
+ "llama-2-13b-chat": {"mistral-7b-instruct", "vicuna-13b", "llama-2-70b-chat"},
399
+ "llama-2-7b-chat": {"mistral-7b-instruct", "vicuna-7b", "llama-2-13b-chat"},
400
+ "mistral-7b-instruct": {
401
+ "llama-2-7b-chat",
402
+ "llama-2-13b-chat",
403
+ "llama-2-70b-chat",
404
+ },
405
+ "vicuna-33b": {"llama-2-70b-chat", "gpt-3.5-turbo-0613", "claude-instant-1"},
406
+ "vicuna-13b": {"llama-2-13b-chat", "llama-2-70b-chat"},
407
+ "vicuna-7b": {"llama-2-7b-chat", "mistral-7b-instruct", "llama-2-13b-chat"},
408
+ "wizardlm-70b": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"},
409
+ }
410
+
411
+ SAMPLING_BOOST_MODELS = [
412
+ # "claude-2.1",
413
+ # "gpt-4-0613",
414
+ # "gpt-4-0314",
415
+ # "gpt-4-1106-preview",
416
+ # "gpt-4-0125-preview",
417
+ "gpt-3.5-turbo-0125",
418
+ # "mistral-medium",
419
+ "nous-hermes-2-mixtral-8x7b-dpo",
420
+ "openchat-3.5-0106",
421
+ "qwen1.5-72b-chat",
422
+ "qwen1.5-7b-chat",
423
+ "qwen1.5-4b-chat",
424
+ # "mistral-7b-instruct-v0.2",
425
+ ]
426
+
427
+ # outage models won't be sampled.
428
+ OUTAGE_MODELS = []
429
+
430
+
431
+ def get_sample_weight(model):
432
+ if model in OUTAGE_MODELS:
433
+ return 0
434
+ weight = SAMPLING_WEIGHTS.get(model, 1.0)
435
+ if model in SAMPLING_BOOST_MODELS:
436
+ weight *= 5
437
+ return weight
438
+
439
+
440
+ def get_battle_pair():
441
+ if len(models) == 1:
442
+ return models[0], models[0]
443
+
444
+ model_weights = []
445
+ for model in models:
446
+ weight = get_sample_weight(model)
447
+ model_weights.append(weight)
448
+ total_weight = np.sum(model_weights)
449
+ model_weights = model_weights / total_weight
450
+ chosen_idx = np.random.choice(len(models), p=model_weights)
451
+ chosen_model = models[chosen_idx]
452
+ # for p, w in zip(models, model_weights):
453
+ # print(p, w)
454
+
455
+ rival_models = []
456
+ rival_weights = []
457
+ for model in models:
458
+ if model == chosen_model:
459
+ continue
460
+ weight = get_sample_weight(model)
461
+ if (
462
+ weight != 0
463
+ and chosen_model in BATTLE_TARGETS
464
+ and model in BATTLE_TARGETS[chosen_model]
465
+ ):
466
+ # boost to 50% chance
467
+ weight = total_weight / len(BATTLE_TARGETS[chosen_model])
468
+ rival_models.append(model)
469
+ rival_weights.append(weight)
470
+ # for p, w in zip(rival_models, rival_weights):
471
+ # print(p, w)
472
+ rival_weights = rival_weights / np.sum(rival_weights)
473
+ rival_idx = np.random.choice(len(rival_models), p=rival_weights)
474
+ rival_model = rival_models[rival_idx]
475
+
476
+ swap = np.random.randint(2)
477
+ if swap == 0:
478
+ return chosen_model, rival_model
479
+ else:
480
+ return rival_model, chosen_model
481
+
482
+
483
+ def add_text(
484
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
485
+ ):
486
+ ip = get_ip(request)
487
+ logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
488
+ states = [state0, state1]
489
+ model_selectors = [model_selector0, model_selector1]
490
+
491
+ # Init states if necessary
492
+ if states[0] is None:
493
+ assert states[1] is None
494
+
495
+ model_left, model_right = get_battle_pair()
496
+ states = [
497
+ State(model_left),
498
+ State(model_right),
499
+ ]
500
+
501
+ if len(text) <= 0:
502
+ for i in range(num_sides):
503
+ states[i].skip_next = True
504
+ return (
505
+ states
506
+ + [x.to_gradio_chatbot() for x in states]
507
+ + [""]
508
+ + [
509
+ no_change_btn,
510
+ ]
511
+ * 6
512
+ + [""]
513
+ )
514
+
515
+ model_list = [states[i].model_name for i in range(num_sides)]
516
+ flagged = moderation_filter(text, model_list)
517
+ if flagged:
518
+ logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
519
+ # overwrite the original text
520
+ text = MODERATION_MSG
521
+
522
+ conv = states[0].conv
523
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
524
+ logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
525
+ for i in range(num_sides):
526
+ states[i].skip_next = True
527
+ return (
528
+ states
529
+ + [x.to_gradio_chatbot() for x in states]
530
+ + [CONVERSATION_LIMIT_MSG]
531
+ + [
532
+ no_change_btn,
533
+ ]
534
+ * 6
535
+ + [""]
536
+ )
537
+
538
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
539
+ for i in range(num_sides):
540
+ states[i].conv.append_message(states[i].conv.roles[0], text)
541
+ states[i].conv.append_message(states[i].conv.roles[1], None)
542
+ states[i].skip_next = False
543
+
544
+ hint_msg = ""
545
+ for i in range(num_sides):
546
+ if "deluxe" in states[i].model_name:
547
+ hint_msg = SLOW_MODEL_MSG
548
+ return (
549
+ states
550
+ + [x.to_gradio_chatbot() for x in states]
551
+ + [""]
552
+ + [
553
+ disable_btn,
554
+ ]
555
+ * 6
556
+ + [hint_msg]
557
+ )
558
+
559
+
560
+ def bot_response_multi(
561
+ state0,
562
+ state1,
563
+ temperature,
564
+ top_p,
565
+ max_new_tokens,
566
+ request: gr.Request,
567
+ ):
568
+ logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")
569
+
570
+ if state0 is None or state0.skip_next:
571
+ # This generate call is skipped due to invalid inputs
572
+ yield (
573
+ state0,
574
+ state1,
575
+ state0.to_gradio_chatbot(),
576
+ state1.to_gradio_chatbot(),
577
+ ) + (no_change_btn,) * 6
578
+ return
579
+
580
+ states = [state0, state1]
581
+ gen = []
582
+ for i in range(num_sides):
583
+ gen.append(
584
+ bot_response(
585
+ states[i],
586
+ temperature,
587
+ top_p,
588
+ max_new_tokens,
589
+ request,
590
+ apply_rate_limit=False,
591
+ )
592
+ )
593
+
594
+ is_gemini = []
595
+ for i in range(num_sides):
596
+ is_gemini.append(states[i].model_name in ["gemini-pro", "gemini-pro-dev-api"])
597
+ chatbots = [None] * num_sides
598
+ iters = 0
599
+ while True:
600
+ stop = True
601
+ iters += 1
602
+ for i in range(num_sides):
603
+ try:
604
+ # yield gemini fewer times as its chunk size is larger
605
+ # otherwise, gemini will stream too fast
606
+ if not is_gemini[i] or (iters % 30 == 1 or iters < 3):
607
+ ret = next(gen[i])
608
+ states[i], chatbots[i] = ret[0], ret[1]
609
+ stop = False
610
+ except StopIteration:
611
+ pass
612
+ yield states + chatbots + [disable_btn] * 6
613
+ if stop:
614
+ break
615
+
616
+
617
+ def build_side_by_side_ui_anony(models):
618
+ notice_markdown = """
619
+ # ⚔️ Chatbot Arena: Benchmarking LLMs in the Wild
620
+ | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
621
+
622
+ ## 📜 Rules
623
+ - Ask any question to two anonymous models (e.g., ChatGPT, Claude, Llama) and vote for the better one!
624
+ - You can continue chatting until you identify a winner.
625
+ - Vote won't be counted if model identity is revealed during conversation.
626
+
627
+ ## 🏆 Arena Elo&nbsp;[Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard)
628
+ We collect **200K+** human votes to compute an Elo-based LLM leaderboard.
629
+ Find out who is the 🥇LLM Champion!
630
+
631
+ ## 👇 Chat now!
632
+ """
633
+
634
+ states = [gr.State() for _ in range(num_sides)]
635
+ model_selectors = [None] * num_sides
636
+ chatbots = [None] * num_sides
637
+
638
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
639
+
640
+ with gr.Group(elem_id="share-region-anony"):
641
+ with gr.Accordion(
642
+ f"🔍 Expand to see the descriptions of {len(models)} models", open=False
643
+ ):
644
+ model_description_md = get_model_description_md(models)
645
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
646
+ with gr.Row():
647
+ for i in range(num_sides):
648
+ label = "Model A" if i == 0 else "Model B"
649
+ with gr.Column():
650
+ chatbots[i] = gr.Chatbot(
651
+ label=label,
652
+ elem_id="chatbot",
653
+ height=550,
654
+ show_copy_button=True,
655
+ )
656
+
657
+ with gr.Row():
658
+ for i in range(num_sides):
659
+ with gr.Column():
660
+ model_selectors[i] = gr.Markdown(
661
+ anony_names[i], elem_id="model_selector_md"
662
+ )
663
+ with gr.Row():
664
+ slow_warning = gr.Markdown("", elem_id="notice_markdown")
665
+
666
+ with gr.Row():
667
+ leftvote_btn = gr.Button(
668
+ value="👈 A is better", visible=False, interactive=False
669
+ )
670
+ rightvote_btn = gr.Button(
671
+ value="👉 B is better", visible=False, interactive=False
672
+ )
673
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
674
+ bothbad_btn = gr.Button(
675
+ value="👎 Both are bad", visible=False, interactive=False
676
+ )
677
+
678
+ with gr.Row():
679
+ textbox = gr.Textbox(
680
+ show_label=False,
681
+ placeholder="👉 Enter your prompt and press ENTER",
682
+ elem_id="input_box",
683
+ )
684
+ send_btn = gr.Button(value="Send", variant="primary", scale=0)
685
+
686
+ with gr.Row() as button_row:
687
+ clear_btn = gr.Button(value="🎲 New Round", interactive=False)
688
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
689
+ share_btn = gr.Button(value="📷 Share")
690
+
691
+ with gr.Accordion("Parameters", open=False) as parameter_row:
692
+ temperature = gr.Slider(
693
+ minimum=0.0,
694
+ maximum=1.0,
695
+ value=0.7,
696
+ step=0.1,
697
+ interactive=True,
698
+ label="Temperature",
699
+ )
700
+ top_p = gr.Slider(
701
+ minimum=0.0,
702
+ maximum=1.0,
703
+ value=1.0,
704
+ step=0.1,
705
+ interactive=True,
706
+ label="Top P",
707
+ )
708
+ max_output_tokens = gr.Slider(
709
+ minimum=16,
710
+ maximum=2048,
711
+ value=1024,
712
+ step=64,
713
+ interactive=True,
714
+ label="Max output tokens",
715
+ )
716
+
717
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
718
+
719
+ # Register listeners
720
+ btn_list = [
721
+ leftvote_btn,
722
+ rightvote_btn,
723
+ tie_btn,
724
+ bothbad_btn,
725
+ regenerate_btn,
726
+ clear_btn,
727
+ ]
728
+ leftvote_btn.click(
729
+ leftvote_last_response,
730
+ states + model_selectors,
731
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
732
+ )
733
+ rightvote_btn.click(
734
+ rightvote_last_response,
735
+ states + model_selectors,
736
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
737
+ )
738
+ tie_btn.click(
739
+ tievote_last_response,
740
+ states + model_selectors,
741
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
742
+ )
743
+ bothbad_btn.click(
744
+ bothbad_vote_last_response,
745
+ states + model_selectors,
746
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
747
+ )
748
+ regenerate_btn.click(
749
+ regenerate, states, states + chatbots + [textbox] + btn_list
750
+ ).then(
751
+ bot_response_multi,
752
+ states + [temperature, top_p, max_output_tokens],
753
+ states + chatbots + btn_list,
754
+ ).then(
755
+ flash_buttons, [], btn_list
756
+ )
757
+ clear_btn.click(
758
+ clear_history,
759
+ None,
760
+ states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning],
761
+ )
762
+
763
+ share_js = """
764
+ function (a, b, c, d) {
765
+ const captureElement = document.querySelector('#share-region-anony');
766
+ html2canvas(captureElement)
767
+ .then(canvas => {
768
+ canvas.style.display = 'none'
769
+ document.body.appendChild(canvas)
770
+ return canvas
771
+ })
772
+ .then(canvas => {
773
+ const image = canvas.toDataURL('image/png')
774
+ const a = document.createElement('a')
775
+ a.setAttribute('download', 'chatbot-arena.png')
776
+ a.setAttribute('href', image)
777
+ a.click()
778
+ canvas.remove()
779
+ });
780
+ return [a, b, c, d];
781
+ }
782
+ """
783
+ share_btn.click(share_click, states + model_selectors, [], js=share_js)
784
+
785
+ textbox.submit(
786
+ add_text,
787
+ states + model_selectors + [textbox],
788
+ states + chatbots + [textbox] + btn_list + [slow_warning],
789
+ ).then(
790
+ bot_response_multi,
791
+ states + [temperature, top_p, max_output_tokens],
792
+ states + chatbots + btn_list,
793
+ ).then(
794
+ flash_buttons,
795
+ [],
796
+ btn_list,
797
+ )
798
+
799
+ send_btn.click(
800
+ add_text,
801
+ states + model_selectors + [textbox],
802
+ states + chatbots + [textbox] + btn_list,
803
+ ).then(
804
+ bot_response_multi,
805
+ states + [temperature, top_p, max_output_tokens],
806
+ states + chatbots + btn_list,
807
+ ).then(
808
+ flash_buttons, [], btn_list
809
+ )
810
+
811
+ return states + model_selectors
gradio_block_arena_named.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatbot Arena (side-by-side) tab.
3
+ Users chat with two chosen models.
4
+ """
5
+
6
+ import json
7
+ import time
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+
12
+ from fastchat.constants import (
13
+ MODERATION_MSG,
14
+ CONVERSATION_LIMIT_MSG,
15
+ INPUT_CHAR_LEN_LIMIT,
16
+ CONVERSATION_TURN_LIMIT,
17
+ )
18
+ from fastchat.model.model_adapter import get_conversation_template
19
+ from fastchat.serve.gradio_web_server import (
20
+ State,
21
+ bot_response,
22
+ get_conv_log_filename,
23
+ no_change_btn,
24
+ enable_btn,
25
+ disable_btn,
26
+ invisible_btn,
27
+ acknowledgment_md,
28
+ get_ip,
29
+ get_model_description_md,
30
+ )
31
+ from fastchat.utils import (
32
+ build_logger,
33
+ moderation_filter,
34
+ )
35
+
36
+
37
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
38
+
39
+ num_sides = 2
40
+ enable_moderation = False
41
+
42
+
43
+ def set_global_vars_named(enable_moderation_):
44
+ global enable_moderation
45
+ enable_moderation = enable_moderation_
46
+
47
+
48
+ def load_demo_side_by_side_named(models, url_params):
49
+ states = (None,) * num_sides
50
+
51
+ model_left = models[0] if len(models) > 0 else ""
52
+ if len(models) > 1:
53
+ weights = ([8] * 4 + [4] * 8 + [1] * 32)[: len(models) - 1]
54
+ weights = weights / np.sum(weights)
55
+ model_right = np.random.choice(models[1:], p=weights)
56
+ else:
57
+ model_right = model_left
58
+
59
+ selector_updates = (
60
+ gr.Dropdown(choices=models, value=model_left, visible=True),
61
+ gr.Dropdown(choices=models, value=model_right, visible=True),
62
+ )
63
+
64
+ return states + selector_updates
65
+
66
+
67
+ def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
68
+ with open(get_conv_log_filename(), "a") as fout:
69
+ data = {
70
+ "tstamp": round(time.time(), 4),
71
+ "type": vote_type,
72
+ "models": [x for x in model_selectors],
73
+ "states": [x.dict() for x in states],
74
+ "ip": get_ip(request),
75
+ }
76
+ fout.write(json.dumps(data) + "\n")
77
+
78
+
79
+ def leftvote_last_response(
80
+ state0, state1, model_selector0, model_selector1, request: gr.Request
81
+ ):
82
+ logger.info(f"leftvote (named). ip: {get_ip(request)}")
83
+ vote_last_response(
84
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
85
+ )
86
+ return ("",) + (disable_btn,) * 4
87
+
88
+
89
+ def rightvote_last_response(
90
+ state0, state1, model_selector0, model_selector1, request: gr.Request
91
+ ):
92
+ logger.info(f"rightvote (named). ip: {get_ip(request)}")
93
+ vote_last_response(
94
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
95
+ )
96
+ return ("",) + (disable_btn,) * 4
97
+
98
+
99
+ def tievote_last_response(
100
+ state0, state1, model_selector0, model_selector1, request: gr.Request
101
+ ):
102
+ logger.info(f"tievote (named). ip: {get_ip(request)}")
103
+ vote_last_response(
104
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
105
+ )
106
+ return ("",) + (disable_btn,) * 4
107
+
108
+
109
+ def bothbad_vote_last_response(
110
+ state0, state1, model_selector0, model_selector1, request: gr.Request
111
+ ):
112
+ logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
113
+ vote_last_response(
114
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
115
+ )
116
+ return ("",) + (disable_btn,) * 4
117
+
118
+
119
+ def regenerate(state0, state1, request: gr.Request):
120
+ logger.info(f"regenerate (named). ip: {get_ip(request)}")
121
+ states = [state0, state1]
122
+ for i in range(num_sides):
123
+ states[i].conv.update_last_message(None)
124
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
125
+
126
+
127
+ def clear_history(request: gr.Request):
128
+ logger.info(f"clear_history (named). ip: {get_ip(request)}")
129
+ return (
130
+ [None] * num_sides
131
+ + [None] * num_sides
132
+ + [""]
133
+ + [invisible_btn] * 4
134
+ + [disable_btn] * 2
135
+ )
136
+
137
+
138
+ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
139
+ logger.info(f"share (named). ip: {get_ip(request)}")
140
+ if state0 is not None and state1 is not None:
141
+ vote_last_response(
142
+ [state0, state1], "share", [model_selector0, model_selector1], request
143
+ )
144
+
145
+
146
+ def add_text(
147
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
148
+ ):
149
+ ip = get_ip(request)
150
+ logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
151
+ states = [state0, state1]
152
+ model_selectors = [model_selector0, model_selector1]
153
+
154
+ # Init states if necessary
155
+ for i in range(num_sides):
156
+ if states[i] is None:
157
+ states[i] = State(model_selectors[i])
158
+
159
+ if len(text) <= 0:
160
+ for i in range(num_sides):
161
+ states[i].skip_next = True
162
+ return (
163
+ states
164
+ + [x.to_gradio_chatbot() for x in states]
165
+ + [""]
166
+ + [
167
+ no_change_btn,
168
+ ]
169
+ * 6
170
+ )
171
+
172
+ model_list = [states[i].model_name for i in range(num_sides)]
173
+ flagged = moderation_filter(text, model_list)
174
+ if flagged:
175
+ logger.info(f"violate moderation (named). ip: {ip}. text: {text}")
176
+ # overwrite the original text
177
+ text = MODERATION_MSG
178
+
179
+ conv = states[0].conv
180
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
181
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
182
+ for i in range(num_sides):
183
+ states[i].skip_next = True
184
+ return (
185
+ states
186
+ + [x.to_gradio_chatbot() for x in states]
187
+ + [CONVERSATION_LIMIT_MSG]
188
+ + [
189
+ no_change_btn,
190
+ ]
191
+ * 6
192
+ )
193
+
194
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
195
+ for i in range(num_sides):
196
+ states[i].conv.append_message(states[i].conv.roles[0], text)
197
+ states[i].conv.append_message(states[i].conv.roles[1], None)
198
+ states[i].skip_next = False
199
+
200
+ return (
201
+ states
202
+ + [x.to_gradio_chatbot() for x in states]
203
+ + [""]
204
+ + [
205
+ disable_btn,
206
+ ]
207
+ * 6
208
+ )
209
+
210
+
211
+ def bot_response_multi(
212
+ state0,
213
+ state1,
214
+ temperature,
215
+ top_p,
216
+ max_new_tokens,
217
+ request: gr.Request,
218
+ ):
219
+ logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
220
+
221
+ if state0.skip_next:
222
+ # This generate call is skipped due to invalid inputs
223
+ yield (
224
+ state0,
225
+ state1,
226
+ state0.to_gradio_chatbot(),
227
+ state1.to_gradio_chatbot(),
228
+ ) + (no_change_btn,) * 6
229
+ return
230
+
231
+ states = [state0, state1]
232
+ gen = []
233
+ for i in range(num_sides):
234
+ gen.append(
235
+ bot_response(
236
+ states[i],
237
+ temperature,
238
+ top_p,
239
+ max_new_tokens,
240
+ request,
241
+ )
242
+ )
243
+
244
+ is_gemini = []
245
+ for i in range(num_sides):
246
+ is_gemini.append(states[i].model_name in ["gemini-pro", "gemini-pro-dev-api"])
247
+
248
+ chatbots = [None] * num_sides
249
+ iters = 0
250
+ while True:
251
+ stop = True
252
+ iters += 1
253
+ for i in range(num_sides):
254
+ try:
255
+ # yield gemini fewer times as its chunk size is larger
256
+ # otherwise, gemini will stream too fast
257
+ if not is_gemini[i] or (iters % 30 == 1 or iters < 3):
258
+ ret = next(gen[i])
259
+ states[i], chatbots[i] = ret[0], ret[1]
260
+ stop = False
261
+ except StopIteration:
262
+ pass
263
+ yield states + chatbots + [disable_btn] * 6
264
+ if stop:
265
+ break
266
+
267
+
268
+ def flash_buttons():
269
+ btn_updates = [
270
+ [disable_btn] * 4 + [enable_btn] * 2,
271
+ [enable_btn] * 6,
272
+ ]
273
+ for i in range(4):
274
+ yield btn_updates[i % 2]
275
+ time.sleep(0.3)
276
+
277
+
278
+ def build_side_by_side_ui_named(models):
279
+ notice_markdown = """
280
+ # ⚔️ Chatbot Arena: Benchmarking LLMs in the Wild
281
+ | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
282
+
283
+ ## 📜 Rules
284
+ - Chat with any two models side-by-side and vote!
285
+ - You can continue chatting for multiple rounds.
286
+ - Click "Clear history" to start a new round.
287
+
288
+ ## 🤖 Choose two models to compare
289
+ """
290
+
291
+ states = [gr.State() for _ in range(num_sides)]
292
+ model_selectors = [None] * num_sides
293
+ chatbots = [None] * num_sides
294
+
295
+ notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
296
+
297
+ with gr.Group(elem_id="share-region-named"):
298
+ with gr.Row():
299
+ for i in range(num_sides):
300
+ with gr.Column():
301
+ model_selectors[i] = gr.Dropdown(
302
+ choices=models,
303
+ value=models[i] if len(models) > i else "",
304
+ interactive=True,
305
+ show_label=False,
306
+ container=False,
307
+ )
308
+ with gr.Row():
309
+ with gr.Accordion(
310
+ f"🔍 Expand to see the descriptions of {len(models)} models", open=False
311
+ ):
312
+ model_description_md = get_model_description_md(models)
313
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
314
+
315
+ with gr.Row():
316
+ for i in range(num_sides):
317
+ label = "Model A" if i == 0 else "Model B"
318
+ with gr.Column():
319
+ chatbots[i] = gr.Chatbot(
320
+ label=label,
321
+ elem_id=f"chatbot",
322
+ height=550,
323
+ show_copy_button=True,
324
+ )
325
+
326
+ with gr.Row():
327
+ leftvote_btn = gr.Button(
328
+ value="👈 A is better", visible=False, interactive=False
329
+ )
330
+ rightvote_btn = gr.Button(
331
+ value="👉 B is better", visible=False, interactive=False
332
+ )
333
+ tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
334
+ bothbad_btn = gr.Button(
335
+ value="👎 Both are bad", visible=False, interactive=False
336
+ )
337
+
338
+ with gr.Row():
339
+ textbox = gr.Textbox(
340
+ show_label=False,
341
+ placeholder="👉 Enter your prompt and press ENTER",
342
+ elem_id="input_box",
343
+ )
344
+ send_btn = gr.Button(value="Send", variant="primary", scale=0)
345
+
346
+ with gr.Row() as button_row:
347
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
348
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
349
+ share_btn = gr.Button(value="📷 Share")
350
+
351
+ with gr.Accordion("Parameters", open=False) as parameter_row:
352
+ temperature = gr.Slider(
353
+ minimum=0.0,
354
+ maximum=1.0,
355
+ value=0.7,
356
+ step=0.1,
357
+ interactive=True,
358
+ label="Temperature",
359
+ )
360
+ top_p = gr.Slider(
361
+ minimum=0.0,
362
+ maximum=1.0,
363
+ value=1.0,
364
+ step=0.1,
365
+ interactive=True,
366
+ label="Top P",
367
+ )
368
+ max_output_tokens = gr.Slider(
369
+ minimum=16,
370
+ maximum=2048,
371
+ value=1024,
372
+ step=64,
373
+ interactive=True,
374
+ label="Max output tokens",
375
+ )
376
+
377
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
378
+
379
+ # Register listeners
380
+ btn_list = [
381
+ leftvote_btn,
382
+ rightvote_btn,
383
+ tie_btn,
384
+ bothbad_btn,
385
+ regenerate_btn,
386
+ clear_btn,
387
+ ]
388
+ leftvote_btn.click(
389
+ leftvote_last_response,
390
+ states + model_selectors,
391
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
392
+ )
393
+ rightvote_btn.click(
394
+ rightvote_last_response,
395
+ states + model_selectors,
396
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
397
+ )
398
+ tie_btn.click(
399
+ tievote_last_response,
400
+ states + model_selectors,
401
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
402
+ )
403
+ bothbad_btn.click(
404
+ bothbad_vote_last_response,
405
+ states + model_selectors,
406
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
407
+ )
408
+ regenerate_btn.click(
409
+ regenerate, states, states + chatbots + [textbox] + btn_list
410
+ ).then(
411
+ bot_response_multi,
412
+ states + [temperature, top_p, max_output_tokens],
413
+ states + chatbots + btn_list,
414
+ ).then(
415
+ flash_buttons, [], btn_list
416
+ )
417
+ clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)
418
+
419
+ share_js = """
420
+ function (a, b, c, d) {
421
+ const captureElement = document.querySelector('#share-region-named');
422
+ html2canvas(captureElement)
423
+ .then(canvas => {
424
+ canvas.style.display = 'none'
425
+ document.body.appendChild(canvas)
426
+ return canvas
427
+ })
428
+ .then(canvas => {
429
+ const image = canvas.toDataURL('image/png')
430
+ const a = document.createElement('a')
431
+ a.setAttribute('download', 'chatbot-arena.png')
432
+ a.setAttribute('href', image)
433
+ a.click()
434
+ canvas.remove()
435
+ });
436
+ return [a, b, c, d];
437
+ }
438
+ """
439
+ share_btn.click(share_click, states + model_selectors, [], js=share_js)
440
+
441
+ for i in range(num_sides):
442
+ model_selectors[i].change(
443
+ clear_history, None, states + chatbots + [textbox] + btn_list
444
+ )
445
+
446
+ textbox.submit(
447
+ add_text,
448
+ states + model_selectors + [textbox],
449
+ states + chatbots + [textbox] + btn_list,
450
+ ).then(
451
+ bot_response_multi,
452
+ states + [temperature, top_p, max_output_tokens],
453
+ states + chatbots + btn_list,
454
+ ).then(
455
+ flash_buttons, [], btn_list
456
+ )
457
+ send_btn.click(
458
+ add_text,
459
+ states + model_selectors + [textbox],
460
+ states + chatbots + [textbox] + btn_list,
461
+ ).then(
462
+ bot_response_multi,
463
+ states + [temperature, top_p, max_output_tokens],
464
+ states + chatbots + btn_list,
465
+ ).then(
466
+ flash_buttons, [], btn_list
467
+ )
468
+
469
+ return states + model_selectors
gradio_block_arena_vision.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The gradio demo server for chatting with a large multimodal model.
3
+
4
+ Usage:
5
+ python3 -m fastchat.serve.controller
6
+ python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf
7
+ python3 -m fastchat.serve.gradio_web_server_multi --share --multimodal
8
+ """
9
+
10
+ import os
11
+
12
+ import gradio as gr
13
+
14
+ from fastchat.serve.gradio_web_server import (
15
+ upvote_last_response,
16
+ downvote_last_response,
17
+ flag_last_response,
18
+ get_model_description_md,
19
+ acknowledgment_md,
20
+ bot_response,
21
+ add_text,
22
+ clear_history,
23
+ regenerate,
24
+ )
25
+ from fastchat.utils import (
26
+ build_logger,
27
+ )
28
+
29
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
30
+
31
+
32
+ def build_single_vision_language_model_ui(models, add_promotion_links=False):
33
+ promotion = (
34
+ """
35
+ | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
36
+ """
37
+ if add_promotion_links
38
+ else ""
39
+ )
40
+
41
+ notice_markdown = f"""
42
+ # 🏔️ Chat with Open Large Vision-Language Models
43
+ {promotion}
44
+ """
45
+
46
+ state = gr.State()
47
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
48
+
49
+ with gr.Group():
50
+ with gr.Row(elem_id="model_selector_row"):
51
+ model_selector = gr.Dropdown(
52
+ choices=models,
53
+ value=models[0] if len(models) > 0 else "",
54
+ interactive=True,
55
+ show_label=False,
56
+ container=False,
57
+ )
58
+
59
+ with gr.Accordion(
60
+ f"🔍 Expand to see the descriptions of {len(models)} models", open=False
61
+ ):
62
+ model_description_md = get_model_description_md(models)
63
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
64
+
65
+ with gr.Row():
66
+ with gr.Column(scale=3):
67
+ textbox = gr.Textbox(
68
+ show_label=False,
69
+ placeholder="👉 Enter your prompt and press ENTER",
70
+ container=False,
71
+ render=False,
72
+ elem_id="input_box",
73
+ )
74
+ imagebox = gr.Image(type="pil")
75
+
76
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
77
+
78
+ with gr.Accordion("Parameters", open=False) as parameter_row:
79
+ temperature = gr.Slider(
80
+ minimum=0.0,
81
+ maximum=1.0,
82
+ value=0.2,
83
+ step=0.1,
84
+ interactive=True,
85
+ label="Temperature",
86
+ )
87
+ top_p = gr.Slider(
88
+ minimum=0.0,
89
+ maximum=1.0,
90
+ value=0.7,
91
+ step=0.1,
92
+ interactive=True,
93
+ label="Top P",
94
+ )
95
+ max_output_tokens = gr.Slider(
96
+ minimum=0,
97
+ maximum=1024,
98
+ value=512,
99
+ step=64,
100
+ interactive=True,
101
+ label="Max output tokens",
102
+ )
103
+
104
+ gr.Examples(
105
+ examples=[
106
+ [
107
+ f"{cur_dir}/example_images/city.jpeg",
108
+ "What is unusual about this image?",
109
+ ],
110
+ [
111
+ f"{cur_dir}/example_images/fridge.jpeg",
112
+ "What is in this fridge?",
113
+ ],
114
+ ],
115
+ inputs=[imagebox, textbox],
116
+ )
117
+
118
+ with gr.Column(scale=8):
119
+ chatbot = gr.Chatbot(
120
+ elem_id="chatbot", label="Scroll down and start chatting", height=550
121
+ )
122
+
123
+ with gr.Row():
124
+ with gr.Column(scale=8):
125
+ textbox.render()
126
+ with gr.Column(scale=1, min_width=50):
127
+ send_btn = gr.Button(value="Send", variant="primary")
128
+ with gr.Row(elem_id="buttons"):
129
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
130
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
131
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
132
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
133
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
134
+
135
+ if add_promotion_links:
136
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
137
+
138
+ # Register listeners
139
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
140
+ upvote_btn.click(
141
+ upvote_last_response,
142
+ [state, model_selector],
143
+ [textbox, upvote_btn, downvote_btn, flag_btn],
144
+ )
145
+ downvote_btn.click(
146
+ downvote_last_response,
147
+ [state, model_selector],
148
+ [textbox, upvote_btn, downvote_btn, flag_btn],
149
+ )
150
+ flag_btn.click(
151
+ flag_last_response,
152
+ [state, model_selector],
153
+ [textbox, upvote_btn, downvote_btn, flag_btn],
154
+ )
155
+ regenerate_btn.click(
156
+ regenerate, state, [state, chatbot, textbox, imagebox] + btn_list
157
+ ).then(
158
+ bot_response,
159
+ [state, temperature, top_p, max_output_tokens],
160
+ [state, chatbot] + btn_list,
161
+ )
162
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
163
+
164
+ model_selector.change(
165
+ clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
166
+ )
167
+
168
+ textbox.submit(
169
+ add_text,
170
+ [state, model_selector, textbox, imagebox],
171
+ [state, chatbot, textbox, imagebox] + btn_list,
172
+ ).then(
173
+ bot_response,
174
+ [state, temperature, top_p, max_output_tokens],
175
+ [state, chatbot] + btn_list,
176
+ )
177
+ send_btn.click(
178
+ add_text,
179
+ [state, model_selector, textbox, imagebox],
180
+ [state, chatbot, textbox, imagebox] + btn_list,
181
+ ).then(
182
+ bot_response,
183
+ [state, temperature, top_p, max_output_tokens],
184
+ [state, chatbot] + btn_list,
185
+ )
186
+
187
+ return [state, model_selector]
gradio_web_server.py ADDED
@@ -0,0 +1,887 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The gradio demo server for chatting with a single model.
3
+ """
4
+
5
+ import argparse
6
+ from collections import defaultdict
7
+ import datetime
8
+ import hashlib
9
+ import json
10
+ import os
11
+ import random
12
+ import time
13
+ import uuid
14
+
15
+ import gradio as gr
16
+ import requests
17
+
18
+ from fastchat.constants import (
19
+ LOGDIR,
20
+ WORKER_API_TIMEOUT,
21
+ ErrorCode,
22
+ MODERATION_MSG,
23
+ CONVERSATION_LIMIT_MSG,
24
+ RATE_LIMIT_MSG,
25
+ SERVER_ERROR_MSG,
26
+ INPUT_CHAR_LEN_LIMIT,
27
+ CONVERSATION_TURN_LIMIT,
28
+ SESSION_EXPIRATION_TIME,
29
+ )
30
+ from fastchat.model.model_adapter import (
31
+ get_conversation_template,
32
+ )
33
+ from fastchat.model.model_registry import get_model_info, model_info
34
+ from fastchat.serve.api_provider import get_api_provider_stream_iter
35
+ from fastchat.utils import (
36
+ build_logger,
37
+ get_window_url_params_js,
38
+ get_window_url_params_with_tos_js,
39
+ moderation_filter,
40
+ parse_gradio_auth_creds,
41
+ load_image,
42
+ )
43
+
44
+
45
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
46
+
47
+ headers = {"User-Agent": "FastChat Client"}
48
+
49
+ no_change_btn = gr.Button()
50
+ enable_btn = gr.Button(interactive=True, visible=True)
51
+ disable_btn = gr.Button(interactive=False)
52
+ invisible_btn = gr.Button(interactive=False, visible=False)
53
+
54
+ controller_url = None
55
+ enable_moderation = False
56
+
57
+ acknowledgment_md = """
58
+ ### Terms of Service
59
+
60
+ Users are required to agree to the following terms before using the service:
61
+
62
+ The service is a research preview. It only provides limited safety measures and may generate offensive content.
63
+ It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
64
+ The service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license.
65
+ Additionally, Bard is offered on LMSys for research purposes only. To access the Bard product, please visit its [website](http://bard.google.com).
66
+
67
+ ### Acknowledgment
68
+ We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous [sponsorship](https://lmsys.org/donations/).
69
+
70
+ <div class="sponsor-image-about">
71
+ <img src="https://storage.googleapis.com/public-arena-asset/kaggle.png" alt="Kaggle">
72
+ <img src="https://storage.googleapis.com/public-arena-asset/mbzuai.jpeg" alt="MBZUAI">
73
+ <img src="https://storage.googleapis.com/public-arena-asset/a16z.jpeg" alt="a16z">
74
+ <img src="https://storage.googleapis.com/public-arena-asset/together.png" alt="Together AI">
75
+ <img src="https://storage.googleapis.com/public-arena-asset/anyscale.png" alt="AnyScale">
76
+ <img src="https://storage.googleapis.com/public-arena-asset/huggingface.png" alt="HuggingFace">
77
+ </div>
78
+ """
79
+
80
+ # JSON file format of API-based models:
81
+ # {
82
+ # "gpt-3.5-turbo-0613": {
83
+ # "model_name": "gpt-3.5-turbo-0613",
84
+ # "api_type": "openai",
85
+ # "api_base": "https://api.openai.com/v1",
86
+ # "api_key": "sk-******",
87
+ # "anony_only": false
88
+ # }
89
+ # }
90
+ # "api_type" can be one of the following: openai, anthropic, gemini, mistral.
91
+ # "anony_only" means whether to show this model in anonymous mode only.
92
+ api_endpoint_info = {}
93
+
94
+
95
+ class State:
96
+ def __init__(self, model_name):
97
+ self.conv = get_conversation_template(model_name)
98
+ self.conv_id = uuid.uuid4().hex
99
+ self.skip_next = False
100
+ self.model_name = model_name
101
+
102
+ def to_gradio_chatbot(self):
103
+ return self.conv.to_gradio_chatbot()
104
+
105
+ def dict(self):
106
+ base = self.conv.dict()
107
+ base.update(
108
+ {
109
+ "conv_id": self.conv_id,
110
+ "model_name": self.model_name,
111
+ }
112
+ )
113
+ return base
114
+
115
+
116
+ def set_global_vars(controller_url_, enable_moderation_):
117
+ global controller_url, enable_moderation
118
+ controller_url = controller_url_
119
+ enable_moderation = enable_moderation_
120
+
121
+
122
+ def get_conv_log_filename():
123
+ t = datetime.datetime.now()
124
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
125
+ return name
126
+
127
+
128
+ def get_model_list(controller_url, register_api_endpoint_file, multimodal):
129
+ global api_endpoint_info
130
+
131
+ # Add models from the controller
132
+ if controller_url:
133
+ ret = requests.post(controller_url + "/refresh_all_workers")
134
+ assert ret.status_code == 200
135
+
136
+ if multimodal:
137
+ ret = requests.post(controller_url + "/list_multimodal_models")
138
+ models = ret.json()["models"]
139
+ else:
140
+ ret = requests.post(controller_url + "/list_language_models")
141
+ models = ret.json()["models"]
142
+ else:
143
+ models = []
144
+
145
+ # Add models from the API providers
146
+ if register_api_endpoint_file:
147
+ api_endpoint_info = json.load(open(register_api_endpoint_file))
148
+ for mdl, mdl_dict in api_endpoint_info.items():
149
+ mdl_multimodal = mdl_dict.get("multimodal", False)
150
+ if multimodal and mdl_multimodal:
151
+ models += [mdl]
152
+ elif not multimodal and not mdl_multimodal:
153
+ models += [mdl]
154
+
155
+ # Remove anonymous models
156
+ models = list(set(models))
157
+ visible_models = models.copy()
158
+ for mdl in visible_models:
159
+ if mdl not in api_endpoint_info:
160
+ continue
161
+ mdl_dict = api_endpoint_info[mdl]
162
+ if mdl_dict["anony_only"]:
163
+ visible_models.remove(mdl)
164
+
165
+ # Sort models and add descriptions
166
+ priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)}
167
+ models.sort(key=lambda x: priority.get(x, x))
168
+ visible_models.sort(key=lambda x: priority.get(x, x))
169
+ logger.info(f"All models: {models}")
170
+ logger.info(f"Visible models: {visible_models}")
171
+ return visible_models, models
172
+
173
+
174
+ def load_demo_single(models, url_params):
175
+ selected_model = models[0] if len(models) > 0 else ""
176
+ if "model" in url_params:
177
+ model = url_params["model"]
178
+ if model in models:
179
+ selected_model = model
180
+
181
+ dropdown_update = gr.Dropdown(choices=models, value=selected_model, visible=True)
182
+ state = None
183
+ return state, dropdown_update
184
+
185
+
186
+ def load_demo(url_params, request: gr.Request):
187
+ global models
188
+
189
+ ip = get_ip(request)
190
+ logger.info(f"load_demo. ip: {ip}. params: {url_params}")
191
+
192
+ if args.model_list_mode == "reload":
193
+ models, all_models = get_model_list(
194
+ controller_url, args.register_api_endpoint_file, False
195
+ )
196
+
197
+ return load_demo_single(models, url_params)
198
+
199
+
200
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
201
+ with open(get_conv_log_filename(), "a") as fout:
202
+ data = {
203
+ "tstamp": round(time.time(), 4),
204
+ "type": vote_type,
205
+ "model": model_selector,
206
+ "state": state.dict(),
207
+ "ip": get_ip(request),
208
+ }
209
+ fout.write(json.dumps(data) + "\n")
210
+
211
+
212
+ def upvote_last_response(state, model_selector, request: gr.Request):
213
+ ip = get_ip(request)
214
+ logger.info(f"upvote. ip: {ip}")
215
+ vote_last_response(state, "upvote", model_selector, request)
216
+ return ("",) + (disable_btn,) * 3
217
+
218
+
219
+ def downvote_last_response(state, model_selector, request: gr.Request):
220
+ ip = get_ip(request)
221
+ logger.info(f"downvote. ip: {ip}")
222
+ vote_last_response(state, "downvote", model_selector, request)
223
+ return ("",) + (disable_btn,) * 3
224
+
225
+
226
+ def flag_last_response(state, model_selector, request: gr.Request):
227
+ ip = get_ip(request)
228
+ logger.info(f"flag. ip: {ip}")
229
+ vote_last_response(state, "flag", model_selector, request)
230
+ return ("",) + (disable_btn,) * 3
231
+
232
+
233
+ def regenerate(state, request: gr.Request):
234
+ ip = get_ip(request)
235
+ logger.info(f"regenerate. ip: {ip}")
236
+ state.conv.update_last_message(None)
237
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
238
+
239
+
240
+ def clear_history(request: gr.Request):
241
+ ip = get_ip(request)
242
+ logger.info(f"clear_history. ip: {ip}")
243
+ state = None
244
+ return (state, [], "", None) + (disable_btn,) * 5
245
+
246
+
247
+ def get_ip(request: gr.Request):
248
+ if "cf-connecting-ip" in request.headers:
249
+ ip = request.headers["cf-connecting-ip"]
250
+ else:
251
+ ip = request.client.host
252
+ return ip
253
+
254
+
255
+ def _prepare_text_with_image(state, text, image):
256
+ if image is not None:
257
+ if len(state.conv.get_images()) > 0:
258
+ # reset convo with new image
259
+ state.conv = get_conversation_template(state.model_name)
260
+
261
+ image = state.conv.convert_image_to_base64(
262
+ image
263
+ ) # PIL type is not JSON serializable
264
+
265
+ text = text, [image]
266
+
267
+ return text
268
+
269
+
270
+ def add_text(state, model_selector, text, image, request: gr.Request):
271
+ ip = get_ip(request)
272
+ logger.info(f"add_text. ip: {ip}. len: {len(text)}")
273
+
274
+ if state is None:
275
+ state = State(model_selector)
276
+
277
+ if len(text) <= 0:
278
+ state.skip_next = True
279
+ return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
280
+
281
+ flagged = moderation_filter(text, [state.model_name])
282
+ if flagged:
283
+ logger.info(f"violate moderation. ip: {ip}. text: {text}")
284
+ # overwrite the original text
285
+ text = MODERATION_MSG
286
+
287
+ if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
288
+ logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
289
+ state.skip_next = True
290
+ return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
291
+ no_change_btn,
292
+ ) * 5
293
+
294
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
295
+ text = _prepare_text_with_image(state, text, image)
296
+ state.conv.append_message(state.conv.roles[0], text)
297
+ state.conv.append_message(state.conv.roles[1], None)
298
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
299
+
300
+
301
+ def model_worker_stream_iter(
302
+ conv,
303
+ model_name,
304
+ worker_addr,
305
+ prompt,
306
+ temperature,
307
+ repetition_penalty,
308
+ top_p,
309
+ max_new_tokens,
310
+ images,
311
+ ):
312
+ # Make requests
313
+ gen_params = {
314
+ "model": model_name,
315
+ "prompt": prompt,
316
+ "temperature": temperature,
317
+ "repetition_penalty": repetition_penalty,
318
+ "top_p": top_p,
319
+ "max_new_tokens": max_new_tokens,
320
+ "stop": conv.stop_str,
321
+ "stop_token_ids": conv.stop_token_ids,
322
+ "echo": False,
323
+ }
324
+
325
+ logger.info(f"==== request ====\n{gen_params}")
326
+
327
+ if len(images) > 0:
328
+ gen_params["images"] = images
329
+
330
+ # Stream output
331
+ response = requests.post(
332
+ worker_addr + "/worker_generate_stream",
333
+ headers=headers,
334
+ json=gen_params,
335
+ stream=True,
336
+ timeout=WORKER_API_TIMEOUT,
337
+ )
338
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
339
+ if chunk:
340
+ data = json.loads(chunk.decode())
341
+ yield data
342
+
343
+
344
+ def is_limit_reached(model_name, ip):
345
+ monitor_url = "http://localhost:9090"
346
+ try:
347
+ ret = requests.get(
348
+ f"{monitor_url}/is_limit_reached?model={model_name}&user_id={ip}", timeout=1
349
+ )
350
+ obj = ret.json()
351
+ return obj
352
+ except Exception as e:
353
+ logger.info(f"monitor error: {e}")
354
+ return None
355
+
356
+
357
+ def bot_response(
358
+ state,
359
+ temperature,
360
+ top_p,
361
+ max_new_tokens,
362
+ request: gr.Request,
363
+ apply_rate_limit=True,
364
+ ):
365
+ ip = get_ip(request)
366
+ logger.info(f"bot_response. ip: {ip}")
367
+ start_tstamp = time.time()
368
+ temperature = float(temperature)
369
+ top_p = float(top_p)
370
+ max_new_tokens = int(max_new_tokens)
371
+
372
+ if state.skip_next:
373
+ # This generate call is skipped due to invalid inputs
374
+ state.skip_next = False
375
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
376
+ return
377
+
378
+ if apply_rate_limit:
379
+ ret = is_limit_reached(state.model_name, ip)
380
+ if ret is not None and ret["is_limit_reached"]:
381
+ error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"]
382
+ logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}")
383
+ state.conv.update_last_message(error_msg)
384
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
385
+ return
386
+
387
+ conv, model_name = state.conv, state.model_name
388
+ model_api_dict = (
389
+ api_endpoint_info[model_name] if model_name in api_endpoint_info else None
390
+ )
391
+ images = conv.get_images()
392
+
393
+ if model_api_dict is None:
394
+ # Query worker address
395
+ ret = requests.post(
396
+ controller_url + "/get_worker_address", json={"model": model_name}
397
+ )
398
+ worker_addr = ret.json()["address"]
399
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
400
+
401
+ # No available worker
402
+ if worker_addr == "":
403
+ conv.update_last_message(SERVER_ERROR_MSG)
404
+ yield (
405
+ state,
406
+ state.to_gradio_chatbot(),
407
+ disable_btn,
408
+ disable_btn,
409
+ disable_btn,
410
+ enable_btn,
411
+ enable_btn,
412
+ )
413
+ return
414
+
415
+ # Construct prompt.
416
+ # We need to call it here, so it will not be affected by "▌".
417
+ prompt = conv.get_prompt()
418
+
419
+ # Set repetition_penalty
420
+ if "t5" in model_name:
421
+ repetition_penalty = 1.2
422
+ else:
423
+ repetition_penalty = 1.0
424
+
425
+ stream_iter = model_worker_stream_iter(
426
+ conv,
427
+ model_name,
428
+ worker_addr,
429
+ prompt,
430
+ temperature,
431
+ repetition_penalty,
432
+ top_p,
433
+ max_new_tokens,
434
+ images,
435
+ )
436
+ else:
437
+ stream_iter = get_api_provider_stream_iter(
438
+ conv,
439
+ model_name,
440
+ model_api_dict,
441
+ temperature,
442
+ top_p,
443
+ max_new_tokens,
444
+ )
445
+
446
+ conv.update_last_message("▌")
447
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
448
+
449
+ try:
450
+ for i, data in enumerate(stream_iter):
451
+ if data["error_code"] == 0:
452
+ output = data["text"].strip()
453
+ conv.update_last_message(output + "▌")
454
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
455
+ else:
456
+ output = data["text"] + f"\n\n(error_code: {data['error_code']})"
457
+ conv.update_last_message(output)
458
+ yield (state, state.to_gradio_chatbot()) + (
459
+ disable_btn,
460
+ disable_btn,
461
+ disable_btn,
462
+ enable_btn,
463
+ enable_btn,
464
+ )
465
+ return
466
+ output = data["text"].strip()
467
+ conv.update_last_message(output)
468
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
469
+ except requests.exceptions.RequestException as e:
470
+ conv.update_last_message(
471
+ f"{SERVER_ERROR_MSG}\n\n"
472
+ f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
473
+ )
474
+ yield (state, state.to_gradio_chatbot()) + (
475
+ disable_btn,
476
+ disable_btn,
477
+ disable_btn,
478
+ enable_btn,
479
+ enable_btn,
480
+ )
481
+ return
482
+ except Exception as e:
483
+ conv.update_last_message(
484
+ f"{SERVER_ERROR_MSG}\n\n"
485
+ f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
486
+ )
487
+ yield (state, state.to_gradio_chatbot()) + (
488
+ disable_btn,
489
+ disable_btn,
490
+ disable_btn,
491
+ enable_btn,
492
+ enable_btn,
493
+ )
494
+ return
495
+
496
+ finish_tstamp = time.time()
497
+ logger.info(f"{output}")
498
+
499
+ # We load the image because gradio accepts base64 but that increases file size by ~1.33x
500
+ loaded_images = [load_image(image) for image in images]
501
+ images_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in loaded_images]
502
+ for image, hash_str in zip(loaded_images, images_hash):
503
+ t = datetime.datetime.now()
504
+ filename = os.path.join(
505
+ LOGDIR,
506
+ "serve_images",
507
+ f"{hash_str}.jpg",
508
+ )
509
+ if not os.path.isfile(filename):
510
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
511
+ image.save(filename)
512
+
513
+ with open(get_conv_log_filename(), "a") as fout:
514
+ data = {
515
+ "tstamp": round(finish_tstamp, 4),
516
+ "type": "chat",
517
+ "model": model_name,
518
+ "gen_params": {
519
+ "temperature": temperature,
520
+ "top_p": top_p,
521
+ "max_new_tokens": max_new_tokens,
522
+ },
523
+ "start": round(start_tstamp, 4),
524
+ "finish": round(finish_tstamp, 4),
525
+ "state": state.dict(),
526
+ "ip": get_ip(request),
527
+ "images": images_hash,
528
+ }
529
+ fout.write(json.dumps(data) + "\n")
530
+
531
+
532
+ block_css = """
533
+ #notice_markdown .prose {
534
+ font-size: 120% !important;
535
+ }
536
+ #notice_markdown th {
537
+ display: none;
538
+ }
539
+ #notice_markdown td {
540
+ padding-top: 6px;
541
+ padding-bottom: 6px;
542
+ }
543
+ #model_description_markdown {
544
+ font-size: 120% !important;
545
+ }
546
+ #leaderboard_markdown .prose {
547
+ font-size: 120% !important;
548
+ }
549
+ #leaderboard_markdown td {
550
+ padding-top: 6px;
551
+ padding-bottom: 6px;
552
+ }
553
+ #leaderboard_dataframe td {
554
+ line-height: 0.1em;
555
+ }
556
+ #about_markdown .prose {
557
+ font-size: 120% !important;
558
+ }
559
+ #ack_markdown .prose {
560
+ font-size: 120% !important;
561
+ }
562
+ footer {
563
+ display:none !important;
564
+ }
565
+ .sponsor-image-about img {
566
+ margin: 0 20px;
567
+ margin-top: 20px;
568
+ height: 40px;
569
+ max-height: 100%;
570
+ width: auto;
571
+ float: left;
572
+ }
573
+ """
574
+
575
+
576
+ def get_model_description_md(models):
577
+ model_description_md = """
578
+ | | | |
579
+ | ---- | ---- | ---- |
580
+ """
581
+ ct = 0
582
+ visited = set()
583
+ for i, name in enumerate(models):
584
+ minfo = get_model_info(name)
585
+ if minfo.simple_name in visited:
586
+ continue
587
+ visited.add(minfo.simple_name)
588
+ one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
589
+
590
+ if ct % 3 == 0:
591
+ model_description_md += "|"
592
+ model_description_md += f" {one_model_md} |"
593
+ if ct % 3 == 2:
594
+ model_description_md += "\n"
595
+ ct += 1
596
+ return model_description_md
597
+
598
+
599
+ def build_about():
600
+ about_markdown = """
601
+ # About Us
602
+ Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our [FastChat](https://github.com/lm-sys/FastChat) project at GitHub and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey!
603
+
604
+ ## Read More
605
+ - Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/)
606
+ - LMSYS-Chat-1M [report](https://arxiv.org/abs/2309.11998)
607
+
608
+ ## Core Members
609
+ [Lianmin Zheng](https://lmzheng.net/), [Wei-Lin Chiang](https://infwinston.github.io/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ)
610
+
611
+ ## Advisors
612
+ [Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/)
613
+
614
+ ## Contact Us
615
+ - Follow our [Twitter](https://twitter.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at [email protected]
616
+ - File issues on [GitHub](https://github.com/lm-sys/FastChat)
617
+ - Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys)
618
+
619
+ ## Acknowledgment
620
+ We thank [SkyPilot](https://github.com/skypilot-org/skypilot) and [Gradio](https://github.com/gradio-app/gradio) team for their system support.
621
+ We also thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship. Learn more about partnership [here](https://lmsys.org/donations/).
622
+
623
+ <div class="sponsor-image-about">
624
+ <img src="https://storage.googleapis.com/public-arena-asset/kaggle.png" alt="Kaggle">
625
+ <img src="https://storage.googleapis.com/public-arena-asset/mbzuai.jpeg" alt="MBZUAI">
626
+ <img src="https://storage.googleapis.com/public-arena-asset/a16z.jpeg" alt="a16z">
627
+ <img src="https://storage.googleapis.com/public-arena-asset/together.png" alt="Together AI">
628
+ <img src="https://storage.googleapis.com/public-arena-asset/anyscale.png" alt="AnyScale">
629
+ <img src="https://storage.googleapis.com/public-arena-asset/huggingface.png" alt="HuggingFace">
630
+ </div>
631
+ """
632
+ gr.Markdown(about_markdown, elem_id="about_markdown")
633
+
634
+
635
+ def build_single_model_ui(models, add_promotion_links=False):
636
+ promotion = (
637
+ """
638
+ - | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
639
+ - Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
640
+ - Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)
641
+
642
+ ## 🤖 Choose any model to chat
643
+ """
644
+ if add_promotion_links
645
+ else ""
646
+ )
647
+
648
+ notice_markdown = f"""
649
+ # 🏔️ Chat with Open Large Language Models
650
+ {promotion}
651
+ """
652
+
653
+ state = gr.State()
654
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
655
+
656
+ with gr.Group(elem_id="share-region-named"):
657
+ with gr.Row(elem_id="model_selector_row"):
658
+ model_selector = gr.Dropdown(
659
+ choices=models,
660
+ value=models[0] if len(models) > 0 else "",
661
+ interactive=True,
662
+ show_label=False,
663
+ container=False,
664
+ )
665
+ with gr.Row():
666
+ with gr.Accordion(
667
+ f"🔍 Expand to see the descriptions of {len(models)} models",
668
+ open=False,
669
+ ):
670
+ model_description_md = get_model_description_md(models)
671
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
672
+
673
+ chatbot = gr.Chatbot(
674
+ elem_id="chatbot",
675
+ label="Scroll down and start chatting",
676
+ height=550,
677
+ show_copy_button=True,
678
+ )
679
+ with gr.Row():
680
+ textbox = gr.Textbox(
681
+ show_label=False,
682
+ placeholder="👉 Enter your prompt and press ENTER",
683
+ elem_id="input_box",
684
+ )
685
+ send_btn = gr.Button(value="Send", variant="primary", scale=0)
686
+
687
+ with gr.Row() as button_row:
688
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
689
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
690
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
691
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
692
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
693
+
694
+ with gr.Accordion("Parameters", open=False) as parameter_row:
695
+ temperature = gr.Slider(
696
+ minimum=0.0,
697
+ maximum=1.0,
698
+ value=0.7,
699
+ step=0.1,
700
+ interactive=True,
701
+ label="Temperature",
702
+ )
703
+ top_p = gr.Slider(
704
+ minimum=0.0,
705
+ maximum=1.0,
706
+ value=1.0,
707
+ step=0.1,
708
+ interactive=True,
709
+ label="Top P",
710
+ )
711
+ max_output_tokens = gr.Slider(
712
+ minimum=16,
713
+ maximum=2048,
714
+ value=1024,
715
+ step=64,
716
+ interactive=True,
717
+ label="Max output tokens",
718
+ )
719
+
720
+ if add_promotion_links:
721
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
722
+
723
+ # Register listeners
724
+ imagebox = gr.State(None)
725
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
726
+ upvote_btn.click(
727
+ upvote_last_response,
728
+ [state, model_selector],
729
+ [textbox, upvote_btn, downvote_btn, flag_btn],
730
+ )
731
+ downvote_btn.click(
732
+ downvote_last_response,
733
+ [state, model_selector],
734
+ [textbox, upvote_btn, downvote_btn, flag_btn],
735
+ )
736
+ flag_btn.click(
737
+ flag_last_response,
738
+ [state, model_selector],
739
+ [textbox, upvote_btn, downvote_btn, flag_btn],
740
+ )
741
+ regenerate_btn.click(
742
+ regenerate, state, [state, chatbot, textbox, imagebox] + btn_list
743
+ ).then(
744
+ bot_response,
745
+ [state, temperature, top_p, max_output_tokens],
746
+ [state, chatbot] + btn_list,
747
+ )
748
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
749
+
750
+ model_selector.change(
751
+ clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
752
+ )
753
+
754
+ textbox.submit(
755
+ add_text,
756
+ [state, model_selector, textbox, imagebox],
757
+ [state, chatbot, textbox, imagebox] + btn_list,
758
+ ).then(
759
+ bot_response,
760
+ [state, temperature, top_p, max_output_tokens],
761
+ [state, chatbot] + btn_list,
762
+ )
763
+ send_btn.click(
764
+ add_text,
765
+ [state, model_selector, textbox, imagebox],
766
+ [state, chatbot, textbox, imagebox] + btn_list,
767
+ ).then(
768
+ bot_response,
769
+ [state, temperature, top_p, max_output_tokens],
770
+ [state, chatbot] + btn_list,
771
+ )
772
+
773
+ return [state, model_selector]
774
+
775
+
776
+ def build_demo(models):
777
+ with gr.Blocks(
778
+ title="Chat with Open Large Language Models",
779
+ theme=gr.themes.Default(),
780
+ css=block_css,
781
+ ) as demo:
782
+ url_params = gr.JSON(visible=False)
783
+
784
+ state, model_selector = build_single_model_ui(models)
785
+
786
+ if args.model_list_mode not in ["once", "reload"]:
787
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
788
+
789
+ if args.show_terms_of_use:
790
+ load_js = get_window_url_params_with_tos_js
791
+ else:
792
+ load_js = get_window_url_params_js
793
+
794
+ demo.load(
795
+ load_demo,
796
+ [url_params],
797
+ [
798
+ state,
799
+ model_selector,
800
+ ],
801
+ js=load_js,
802
+ )
803
+
804
+ return demo
805
+
806
+
807
+ if __name__ == "__main__":
808
+ parser = argparse.ArgumentParser()
809
+ parser.add_argument("--host", type=str, default="0.0.0.0")
810
+ parser.add_argument("--port", type=int)
811
+ parser.add_argument(
812
+ "--share",
813
+ action="store_true",
814
+ help="Whether to generate a public, shareable link",
815
+ )
816
+ parser.add_argument(
817
+ "--controller-url",
818
+ type=str,
819
+ default="http://localhost:21001",
820
+ help="The address of the controller",
821
+ )
822
+ parser.add_argument(
823
+ "--concurrency-count",
824
+ type=int,
825
+ default=10,
826
+ help="The concurrency count of the gradio queue",
827
+ )
828
+ parser.add_argument(
829
+ "--model-list-mode",
830
+ type=str,
831
+ default="once",
832
+ choices=["once", "reload"],
833
+ help="Whether to load the model list once or reload the model list every time",
834
+ )
835
+ parser.add_argument(
836
+ "--moderate",
837
+ action="store_true",
838
+ help="Enable content moderation to block unsafe inputs",
839
+ )
840
+ parser.add_argument(
841
+ "--show-terms-of-use",
842
+ action="store_true",
843
+ help="Shows term of use before loading the demo",
844
+ )
845
+ parser.add_argument(
846
+ "--register-api-endpoint-file",
847
+ type=str,
848
+ help="Register API-based model endpoints from a JSON file",
849
+ )
850
+ parser.add_argument(
851
+ "--gradio-auth-path",
852
+ type=str,
853
+ help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
854
+ )
855
+ parser.add_argument(
856
+ "--gradio-root-path",
857
+ type=str,
858
+ help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
859
+ )
860
+ args = parser.parse_args()
861
+ logger.info(f"args: {args}")
862
+
863
+ # Set global variables
864
+ set_global_vars(args.controller_url, args.moderate)
865
+ models, all_models = get_model_list(
866
+ args.controller_url, args.register_api_endpoint_file, False
867
+ )
868
+
869
+ # Set authorization credentials
870
+ auth = None
871
+ if args.gradio_auth_path is not None:
872
+ auth = parse_gradio_auth_creds(args.gradio_auth_path)
873
+
874
+ # Launch the demo
875
+ demo = build_demo(models)
876
+ demo.queue(
877
+ default_concurrency_limit=args.concurrency_count,
878
+ status_update_rate=10,
879
+ api_open=False,
880
+ ).launch(
881
+ server_name=args.host,
882
+ server_port=args.port,
883
+ share=args.share,
884
+ max_threads=200,
885
+ auth=auth,
886
+ root_path=args.gradio_root_path,
887
+ )
gradio_web_server_multi.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The gradio demo server with multiple tabs.
3
+ It supports chatting with a single model or chatting with two models side-by-side.
4
+ """
5
+
6
+ import argparse
7
+ import pickle
8
+ import time
9
+
10
+ import gradio as gr
11
+
12
+ from fastchat.serve.gradio_block_arena_anony import (
13
+ build_side_by_side_ui_anony,
14
+ load_demo_side_by_side_anony,
15
+ set_global_vars_anony,
16
+ )
17
+ from fastchat.serve.gradio_block_arena_named import (
18
+ build_side_by_side_ui_named,
19
+ load_demo_side_by_side_named,
20
+ set_global_vars_named,
21
+ )
22
+ from fastchat.serve.gradio_block_arena_vision import (
23
+ build_single_vision_language_model_ui,
24
+ )
25
+ from fastchat.serve.gradio_web_server import (
26
+ set_global_vars,
27
+ block_css,
28
+ build_single_model_ui,
29
+ build_about,
30
+ get_model_list,
31
+ load_demo_single,
32
+ get_ip,
33
+ )
34
+ from fastchat.serve.monitor.monitor import build_leaderboard_tab
35
+ from fastchat.utils import (
36
+ build_logger,
37
+ get_window_url_params_js,
38
+ get_window_url_params_with_tos_js,
39
+ parse_gradio_auth_creds,
40
+ )
41
+
42
+ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
43
+
44
+
45
+ def load_demo(url_params, request: gr.Request):
46
+ global models, all_models, vl_models
47
+
48
+ ip = get_ip(request)
49
+ logger.info(f"load_demo. ip: {ip}. params: {url_params}")
50
+
51
+ selected = 0
52
+ if "arena" in url_params:
53
+ selected = 0
54
+ elif "compare" in url_params:
55
+ selected = 1
56
+ elif "direct" in url_params or "model" in url_params:
57
+ selected = 2
58
+ elif "vision" in url_params:
59
+ selected = 3
60
+ elif "leaderboard" in url_params:
61
+ selected = 4
62
+
63
+ if args.model_list_mode == "reload":
64
+ models, all_models = get_model_list(
65
+ args.controller_url,
66
+ args.register_api_endpoint_file,
67
+ False,
68
+ )
69
+
70
+ vl_models, all_vl_models = get_model_list(
71
+ args.controller_url,
72
+ args.register_api_endpoint_file,
73
+ True,
74
+ )
75
+
76
+ single_updates = load_demo_single(models, url_params)
77
+ side_by_side_anony_updates = load_demo_side_by_side_anony(all_models, url_params)
78
+ side_by_side_named_updates = load_demo_side_by_side_named(models, url_params)
79
+ vision_language_updates = load_demo_single(vl_models, url_params)
80
+
81
+ return (
82
+ (gr.Tabs(selected=selected),)
83
+ + single_updates
84
+ + side_by_side_anony_updates
85
+ + side_by_side_named_updates
86
+ + vision_language_updates
87
+ )
88
+
89
+
90
+ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
91
+ text_size = gr.themes.sizes.text_md
92
+ if args.show_terms_of_use:
93
+ load_js = get_window_url_params_with_tos_js
94
+ else:
95
+ load_js = get_window_url_params_js
96
+
97
+ head_js = """
98
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.min.js"></script>
99
+ """
100
+ if args.ga_id is not None:
101
+ head_js += f"""
102
+ <script async src="https://www.googletagmanager.com/gtag/js?id={args.ga_id}"></script>
103
+ <script>
104
+ window.dataLayer = window.dataLayer || [];
105
+ function gtag(){{dataLayer.push(arguments);}}
106
+ gtag('js', new Date());
107
+
108
+ gtag('config', '{args.ga_id}');
109
+ window.__gradio_mode__ = "app";
110
+ </script>
111
+ """
112
+
113
+ with gr.Blocks(
114
+ title="Chat with Open Large Language Models",
115
+ theme=gr.themes.Default(text_size=text_size),
116
+ css=block_css,
117
+ head=head_js,
118
+ ) as demo:
119
+ with gr.Tabs() as tabs:
120
+ with gr.Tab("Arena (battle)", id=0):
121
+ side_by_side_anony_list = build_side_by_side_ui_anony(models)
122
+
123
+ with gr.Tab("Arena (side-by-side)", id=1):
124
+ side_by_side_named_list = build_side_by_side_ui_named(models)
125
+
126
+ with gr.Tab("Direct Chat", id=2):
127
+ single_model_list = build_single_model_ui(
128
+ models, add_promotion_links=True
129
+ )
130
+
131
+ with gr.Tab(
132
+ "Vision-Language Model Direct Chat", id=3, visible=args.multimodal
133
+ ):
134
+ single_vision_language_model_list = (
135
+ build_single_vision_language_model_ui(
136
+ vl_models, add_promotion_links=True
137
+ )
138
+ )
139
+
140
+ if elo_results_file:
141
+ with gr.Tab("Leaderboard", id=4):
142
+ build_leaderboard_tab(elo_results_file, leaderboard_table_file)
143
+
144
+ with gr.Tab("About Us", id=5):
145
+ about = build_about()
146
+
147
+ url_params = gr.JSON(visible=False)
148
+
149
+ if args.model_list_mode not in ["once", "reload"]:
150
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
151
+
152
+ demo.load(
153
+ load_demo,
154
+ [url_params],
155
+ [tabs]
156
+ + single_model_list
157
+ + side_by_side_anony_list
158
+ + side_by_side_named_list
159
+ + single_vision_language_model_list,
160
+ js=load_js,
161
+ )
162
+
163
+ return demo
164
+
165
+
166
+ if __name__ == "__main__":
167
+ parser = argparse.ArgumentParser()
168
+ parser.add_argument("--host", type=str, default="0.0.0.0")
169
+ parser.add_argument("--port", type=int)
170
+ parser.add_argument(
171
+ "--share",
172
+ action="store_true",
173
+ help="Whether to generate a public, shareable link",
174
+ )
175
+ parser.add_argument(
176
+ "--controller-url",
177
+ type=str,
178
+ default="http://localhost:21001",
179
+ help="The address of the controller",
180
+ )
181
+ parser.add_argument(
182
+ "--concurrency-count",
183
+ type=int,
184
+ default=10,
185
+ help="The concurrency count of the gradio queue",
186
+ )
187
+ parser.add_argument(
188
+ "--model-list-mode",
189
+ type=str,
190
+ default="once",
191
+ choices=["once", "reload"],
192
+ help="Whether to load the model list once or reload the model list every time.",
193
+ )
194
+ parser.add_argument(
195
+ "--moderate",
196
+ action="store_true",
197
+ help="Enable content moderation to block unsafe inputs",
198
+ )
199
+ parser.add_argument(
200
+ "--show-terms-of-use",
201
+ action="store_true",
202
+ help="Shows term of use before loading the demo",
203
+ )
204
+ parser.add_argument(
205
+ "--multimodal", action="store_true", help="Show multi modal tabs."
206
+ )
207
+ parser.add_argument(
208
+ "--register-api-endpoint-file",
209
+ type=str,
210
+ help="Register API-based model endpoints from a JSON file",
211
+ )
212
+ parser.add_argument(
213
+ "--gradio-auth-path",
214
+ type=str,
215
+ help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
216
+ default=None,
217
+ )
218
+ parser.add_argument(
219
+ "--elo-results-file", type=str, help="Load leaderboard results and plots"
220
+ )
221
+ parser.add_argument(
222
+ "--leaderboard-table-file", type=str, help="Load leaderboard results and plots"
223
+ )
224
+ parser.add_argument(
225
+ "--gradio-root-path",
226
+ type=str,
227
+ help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
228
+ )
229
+ parser.add_argument(
230
+ "--ga-id",
231
+ type=str,
232
+ help="the Google Analytics ID",
233
+ default=None,
234
+ )
235
+ args = parser.parse_args()
236
+ logger.info(f"args: {args}")
237
+
238
+ # Set global variables
239
+ set_global_vars(args.controller_url, args.moderate)
240
+ set_global_vars_named(args.moderate)
241
+ set_global_vars_anony(args.moderate)
242
+ models, all_models = get_model_list(
243
+ args.controller_url,
244
+ args.register_api_endpoint_file,
245
+ False,
246
+ )
247
+
248
+ vl_models, all_vl_models = get_model_list(
249
+ args.controller_url,
250
+ args.register_api_endpoint_file,
251
+ True,
252
+ )
253
+
254
+ # Set authorization credentials
255
+ auth = None
256
+ if args.gradio_auth_path is not None:
257
+ auth = parse_gradio_auth_creds(args.gradio_auth_path)
258
+
259
+ # Launch the demo
260
+ demo = build_demo(
261
+ models,
262
+ vl_models,
263
+ args.elo_results_file,
264
+ args.leaderboard_table_file,
265
+ )
266
+ demo.queue(
267
+ default_concurrency_limit=args.concurrency_count,
268
+ status_update_rate=10,
269
+ api_open=False,
270
+ ).launch(
271
+ server_name=args.host,
272
+ server_port=args.port,
273
+ share=args.share,
274
+ max_threads=200,
275
+ auth=auth,
276
+ root_path=args.gradio_root_path,
277
+ )
huggingface_api.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Use FastChat with Hugging Face generation APIs.
3
+
4
+ Usage:
5
+ python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5
6
+ python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0
7
+ """
8
+ import argparse
9
+
10
+ import torch
11
+
12
+ from fastchat.model import load_model, get_conversation_template, add_model_args
13
+
14
+
15
+ @torch.inference_mode()
16
+ def main(args):
17
+ # Load model
18
+ model, tokenizer = load_model(
19
+ args.model_path,
20
+ device=args.device,
21
+ num_gpus=args.num_gpus,
22
+ max_gpu_memory=args.max_gpu_memory,
23
+ load_8bit=args.load_8bit,
24
+ cpu_offloading=args.cpu_offloading,
25
+ revision=args.revision,
26
+ debug=args.debug,
27
+ )
28
+
29
+ # Build the prompt with a conversation template
30
+ msg = args.message
31
+ conv = get_conversation_template(args.model_path)
32
+ conv.append_message(conv.roles[0], msg)
33
+ conv.append_message(conv.roles[1], None)
34
+ prompt = conv.get_prompt()
35
+
36
+ # Run inference
37
+ inputs = tokenizer([prompt], return_tensors="pt").to(args.device)
38
+ output_ids = model.generate(
39
+ **inputs,
40
+ do_sample=True if args.temperature > 1e-5 else False,
41
+ temperature=args.temperature,
42
+ repetition_penalty=args.repetition_penalty,
43
+ max_new_tokens=args.max_new_tokens,
44
+ )
45
+
46
+ if model.config.is_encoder_decoder:
47
+ output_ids = output_ids[0]
48
+ else:
49
+ output_ids = output_ids[0][len(inputs["input_ids"][0]) :]
50
+ outputs = tokenizer.decode(
51
+ output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
52
+ )
53
+
54
+ # Print results
55
+ print(f"{conv.roles[0]}: {msg}")
56
+ print(f"{conv.roles[1]}: {outputs}")
57
+
58
+
59
+ if __name__ == "__main__":
60
+ parser = argparse.ArgumentParser()
61
+ add_model_args(parser)
62
+ parser.add_argument("--temperature", type=float, default=0.7)
63
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
64
+ parser.add_argument("--max-new-tokens", type=int, default=1024)
65
+ parser.add_argument("--debug", action="store_true")
66
+ parser.add_argument("--message", type=str, default="Hello! Who are you?")
67
+ args = parser.parse_args()
68
+
69
+ # Reset default repetition penalty for T5 models.
70
+ if "t5" in args.model_path and args.repetition_penalty == 1.0:
71
+ args.repetition_penalty = 1.2
72
+
73
+ main(args)
huggingface_api_worker.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker that calls huggingface inference endpoint.
3
+
4
+ Register models in a JSON file with the following format:
5
+ {
6
+ "falcon-180b-chat": {
7
+ "model_name": "falcon-180B-chat",
8
+ "api_base": "https://api-inference.huggingface.co/models",
9
+ "model_path": "tiiuae/falcon-180B-chat",
10
+ "token": "hf_XXX",
11
+ "context_length": 2048
12
+ },
13
+ "zephyr-7b-beta": {
14
+ "model_name": "zephyr-7b-beta",
15
+ "model_path": "",
16
+ "api_base": "xxx",
17
+ "token": "hf_XXX",
18
+ "context_length": 4096
19
+ }
20
+ }
21
+
22
+ "model_path", "api_base", "token", and "context_length" are necessary, while others are optional.
23
+ """
24
+ import argparse
25
+ import asyncio
26
+ import json
27
+ import uuid
28
+ import os
29
+ from typing import List, Optional
30
+
31
+ import requests
32
+ import uvicorn
33
+ from fastapi import BackgroundTasks, FastAPI, Request
34
+ from fastapi.responses import JSONResponse, StreamingResponse
35
+ from huggingface_hub import InferenceClient
36
+
37
+ from fastchat.constants import SERVER_ERROR_MSG, ErrorCode
38
+ from fastchat.serve.base_model_worker import BaseModelWorker
39
+ from fastchat.utils import build_logger
40
+
41
+ worker_id = str(uuid.uuid4())[:8]
42
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
43
+
44
+ workers = []
45
+ worker_map = {}
46
+ app = FastAPI()
47
+
48
+
49
+ # reference to
50
+ # https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392
51
+ def get_gen_kwargs(
52
+ params,
53
+ seed: Optional[int] = None,
54
+ ):
55
+ stop = params.get("stop", None)
56
+ if isinstance(stop, list):
57
+ stop_sequences = stop
58
+ elif isinstance(stop, str):
59
+ stop_sequences = [stop]
60
+ else:
61
+ stop_sequences = []
62
+ gen_kwargs = {
63
+ "do_sample": True,
64
+ "return_full_text": bool(params.get("echo", False)),
65
+ "max_new_tokens": int(params.get("max_new_tokens", 256)),
66
+ "top_p": float(params.get("top_p", 1.0)),
67
+ "temperature": float(params.get("temperature", 1.0)),
68
+ "stop_sequences": stop_sequences,
69
+ "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
70
+ "top_k": params.get("top_k", None),
71
+ "seed": seed,
72
+ }
73
+ if gen_kwargs["top_p"] == 1:
74
+ gen_kwargs["top_p"] = 0.9999999
75
+ if gen_kwargs["top_p"] == 0:
76
+ gen_kwargs.pop("top_p")
77
+ if gen_kwargs["temperature"] == 0:
78
+ gen_kwargs.pop("temperature")
79
+ gen_kwargs["do_sample"] = False
80
+ return gen_kwargs
81
+
82
+
83
+ def could_be_stop(text, stop):
84
+ for s in stop:
85
+ if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)):
86
+ return True
87
+ return False
88
+
89
+
90
+ class HuggingfaceApiWorker(BaseModelWorker):
91
+ def __init__(
92
+ self,
93
+ controller_addr: str,
94
+ worker_addr: str,
95
+ worker_id: str,
96
+ model_path: str,
97
+ api_base: str,
98
+ token: str,
99
+ context_length: int,
100
+ model_names: List[str],
101
+ limit_worker_concurrency: int,
102
+ no_register: bool,
103
+ conv_template: Optional[str] = None,
104
+ seed: Optional[int] = None,
105
+ **kwargs,
106
+ ):
107
+ super().__init__(
108
+ controller_addr,
109
+ worker_addr,
110
+ worker_id,
111
+ model_path,
112
+ model_names,
113
+ limit_worker_concurrency,
114
+ conv_template=conv_template,
115
+ )
116
+
117
+ self.model_path = model_path
118
+ self.api_base = api_base
119
+ self.token = token
120
+ self.context_len = context_length
121
+ self.seed = seed
122
+
123
+ logger.info(
124
+ f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
125
+ )
126
+
127
+ if not no_register:
128
+ self.init_heart_beat()
129
+
130
+ def count_token(self, params):
131
+ # No tokenizer here
132
+ ret = {
133
+ "count": 0,
134
+ "error_code": 0,
135
+ }
136
+ return ret
137
+
138
+ def generate_stream_gate(self, params):
139
+ self.call_ct += 1
140
+
141
+ prompt = params["prompt"]
142
+ gen_kwargs = get_gen_kwargs(params, seed=self.seed)
143
+ stop = gen_kwargs["stop_sequences"]
144
+ if "falcon" in self.model_path and "chat" in self.model_path:
145
+ stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
146
+ stop = list(set(stop))
147
+ gen_kwargs["stop_sequences"] = stop
148
+
149
+ logger.info(f"prompt: {prompt}")
150
+ logger.info(f"gen_kwargs: {gen_kwargs}")
151
+
152
+ try:
153
+ if self.model_path == "":
154
+ url = f"{self.api_base}"
155
+ else:
156
+ url = f"{self.api_base}/{self.model_path}"
157
+ client = InferenceClient(url, token=self.token)
158
+ res = client.text_generation(
159
+ prompt, stream=True, details=True, **gen_kwargs
160
+ )
161
+
162
+ reason = None
163
+ text = ""
164
+ for chunk in res:
165
+ if chunk.token.special:
166
+ continue
167
+ text += chunk.token.text
168
+
169
+ s = next((x for x in stop if text.endswith(x)), None)
170
+ if s is not None:
171
+ text = text[: -len(s)]
172
+ reason = "stop"
173
+ break
174
+ if could_be_stop(text, stop):
175
+ continue
176
+ if (
177
+ chunk.details is not None
178
+ and chunk.details.finish_reason is not None
179
+ ):
180
+ reason = chunk.details.finish_reason
181
+ if reason not in ["stop", "length"]:
182
+ reason = None
183
+ ret = {
184
+ "text": text,
185
+ "error_code": 0,
186
+ "finish_reason": reason,
187
+ }
188
+ yield json.dumps(ret).encode() + b"\0"
189
+ except Exception as e:
190
+ ret = {
191
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
192
+ "error_code": ErrorCode.INTERNAL_ERROR,
193
+ }
194
+ yield json.dumps(ret).encode() + b"\0"
195
+
196
+ def generate_gate(self, params):
197
+ for x in self.generate_stream_gate(params):
198
+ pass
199
+ return json.loads(x[:-1].decode())
200
+
201
+ def get_embeddings(self, params):
202
+ raise NotImplementedError()
203
+
204
+
205
+ def release_worker_semaphore(worker):
206
+ worker.semaphore.release()
207
+
208
+
209
+ def acquire_worker_semaphore(worker):
210
+ if worker.semaphore is None:
211
+ worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
212
+ return worker.semaphore.acquire()
213
+
214
+
215
+ def create_background_tasks(worker):
216
+ background_tasks = BackgroundTasks()
217
+ background_tasks.add_task(lambda: release_worker_semaphore(worker))
218
+ return background_tasks
219
+
220
+
221
+ @app.post("/worker_generate_stream")
222
+ async def api_generate_stream(request: Request):
223
+ params = await request.json()
224
+ worker = worker_map[params["model"]]
225
+ await acquire_worker_semaphore(worker)
226
+ generator = worker.generate_stream_gate(params)
227
+ background_tasks = create_background_tasks(worker)
228
+ return StreamingResponse(generator, background=background_tasks)
229
+
230
+
231
+ @app.post("/worker_generate")
232
+ async def api_generate(request: Request):
233
+ params = await request.json()
234
+ worker = worker_map[params["model"]]
235
+ await acquire_worker_semaphore(worker)
236
+ output = worker.generate_gate(params)
237
+ release_worker_semaphore(worker)
238
+ return JSONResponse(output)
239
+
240
+
241
+ @app.post("/worker_get_embeddings")
242
+ async def api_get_embeddings(request: Request):
243
+ params = await request.json()
244
+ worker = worker_map[params["model"]]
245
+ await acquire_worker_semaphore(worker)
246
+ embedding = worker.get_embeddings(params)
247
+ release_worker_semaphore(worker)
248
+ return JSONResponse(content=embedding)
249
+
250
+
251
+ @app.post("/worker_get_status")
252
+ async def api_get_status(request: Request):
253
+ return {
254
+ "model_names": [m for w in workers for m in w.model_names],
255
+ "speed": 1,
256
+ "queue_length": sum([w.get_queue_length() for w in workers]),
257
+ }
258
+
259
+
260
+ @app.post("/count_token")
261
+ async def api_count_token(request: Request):
262
+ params = await request.json()
263
+ worker = worker_map[params["model"]]
264
+ return worker.count_token(params)
265
+
266
+
267
+ @app.post("/worker_get_conv_template")
268
+ async def api_get_conv(request: Request):
269
+ params = await request.json()
270
+ worker = worker_map[params["model"]]
271
+ return worker.get_conv_template()
272
+
273
+
274
+ @app.post("/model_details")
275
+ async def api_model_details(request: Request):
276
+ params = await request.json()
277
+ worker = worker_map[params["model"]]
278
+ return {"context_length": worker.context_len}
279
+
280
+
281
+ def create_huggingface_api_worker():
282
+ parser = argparse.ArgumentParser()
283
+ parser.add_argument("--host", type=str, default="localhost")
284
+ parser.add_argument("--port", type=int, default=21002)
285
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
286
+ parser.add_argument(
287
+ "--controller-address", type=str, default="http://localhost:21001"
288
+ )
289
+ # all model-related parameters are listed in --model-info-file
290
+ parser.add_argument(
291
+ "--model-info-file",
292
+ type=str,
293
+ required=True,
294
+ help="Huggingface API model's info file path",
295
+ )
296
+
297
+ parser.add_argument(
298
+ "--limit-worker-concurrency",
299
+ type=int,
300
+ default=5,
301
+ help="Limit the model concurrency to prevent OOM.",
302
+ )
303
+ parser.add_argument("--no-register", action="store_true")
304
+ parser.add_argument(
305
+ "--seed",
306
+ type=int,
307
+ default=None,
308
+ help="Overwrite the random seed for each generation.",
309
+ )
310
+ parser.add_argument(
311
+ "--ssl",
312
+ action="store_true",
313
+ required=False,
314
+ default=False,
315
+ help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
316
+ )
317
+ args = parser.parse_args()
318
+
319
+ with open(args.model_info_file, "r", encoding="UTF-8") as f:
320
+ model_info = json.load(f)
321
+
322
+ logger.info(f"args: {args}")
323
+
324
+ model_path_list = []
325
+ api_base_list = []
326
+ token_list = []
327
+ context_length_list = []
328
+ model_names_list = []
329
+ conv_template_list = []
330
+
331
+ for m in model_info:
332
+ model_path_list.append(model_info[m]["model_path"])
333
+ api_base_list.append(model_info[m]["api_base"])
334
+ token_list.append(model_info[m]["token"])
335
+
336
+ context_length = model_info[m]["context_length"]
337
+ model_names = model_info[m].get("model_names", [m.split("/")[-1]])
338
+ if isinstance(model_names, str):
339
+ model_names = [model_names]
340
+ conv_template = model_info[m].get("conv_template", None)
341
+
342
+ context_length_list.append(context_length)
343
+ model_names_list.append(model_names)
344
+ conv_template_list.append(conv_template)
345
+
346
+ logger.info(f"Model paths: {model_path_list}")
347
+ logger.info(f"API bases: {api_base_list}")
348
+ logger.info(f"Tokens: {token_list}")
349
+ logger.info(f"Context lengths: {context_length_list}")
350
+ logger.info(f"Model names: {model_names_list}")
351
+ logger.info(f"Conv templates: {conv_template_list}")
352
+
353
+ for (
354
+ model_names,
355
+ conv_template,
356
+ model_path,
357
+ api_base,
358
+ token,
359
+ context_length,
360
+ ) in zip(
361
+ model_names_list,
362
+ conv_template_list,
363
+ model_path_list,
364
+ api_base_list,
365
+ token_list,
366
+ context_length_list,
367
+ ):
368
+ m = HuggingfaceApiWorker(
369
+ args.controller_address,
370
+ args.worker_address,
371
+ worker_id,
372
+ model_path,
373
+ api_base,
374
+ token,
375
+ context_length,
376
+ model_names,
377
+ args.limit_worker_concurrency,
378
+ no_register=args.no_register,
379
+ conv_template=conv_template,
380
+ seed=args.seed,
381
+ )
382
+ workers.append(m)
383
+ for name in model_names:
384
+ worker_map[name] = m
385
+
386
+ # register all the models
387
+ url = args.controller_address + "/register_worker"
388
+ data = {
389
+ "worker_name": workers[0].worker_addr,
390
+ "check_heart_beat": not args.no_register,
391
+ "worker_status": {
392
+ "model_names": [m for w in workers for m in w.model_names],
393
+ "speed": 1,
394
+ "queue_length": sum([w.get_queue_length() for w in workers]),
395
+ },
396
+ }
397
+ r = requests.post(url, json=data)
398
+ assert r.status_code == 200
399
+
400
+ return args, workers
401
+
402
+
403
+ if __name__ == "__main__":
404
+ args, workers = create_huggingface_api_worker()
405
+ if args.ssl:
406
+ uvicorn.run(
407
+ app,
408
+ host=args.host,
409
+ port=args.port,
410
+ log_level="info",
411
+ ssl_keyfile=os.environ["SSL_KEYFILE"],
412
+ ssl_certfile=os.environ["SSL_CERTFILE"],
413
+ )
414
+ else:
415
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
inference.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference for FastChat models."""
2
+ import abc
3
+ import gc
4
+ import json
5
+ import math
6
+ import os
7
+ import sys
8
+ import time
9
+ from typing import Iterable, Optional, Dict
10
+ import warnings
11
+
12
+ import psutil
13
+ import torch
14
+ from transformers import (
15
+ AutoTokenizer,
16
+ AutoModelForCausalLM,
17
+ LlamaTokenizer,
18
+ LlamaForCausalLM,
19
+ AutoModel,
20
+ AutoModelForSeq2SeqLM,
21
+ T5Tokenizer,
22
+ AutoConfig,
23
+ )
24
+ from transformers.generation.logits_process import (
25
+ LogitsProcessorList,
26
+ RepetitionPenaltyLogitsProcessor,
27
+ TemperatureLogitsWarper,
28
+ TopKLogitsWarper,
29
+ TopPLogitsWarper,
30
+ )
31
+
32
+ from fastchat.conversation import get_conv_template, SeparatorStyle
33
+ from fastchat.model.model_adapter import (
34
+ load_model,
35
+ get_conversation_template,
36
+ get_generate_stream_function,
37
+ )
38
+ from fastchat.modules.awq import AWQConfig
39
+ from fastchat.modules.gptq import GptqConfig
40
+ from fastchat.modules.exllama import ExllamaConfig
41
+ from fastchat.modules.xfastertransformer import XftConfig
42
+ from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
43
+
44
+
45
+ def prepare_logits_processor(
46
+ temperature: float, repetition_penalty: float, top_p: float, top_k: int
47
+ ) -> LogitsProcessorList:
48
+ processor_list = LogitsProcessorList()
49
+ # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
50
+ if temperature >= 1e-5 and temperature != 1.0:
51
+ processor_list.append(TemperatureLogitsWarper(temperature))
52
+ if repetition_penalty > 1.0:
53
+ processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
54
+ if 1e-8 <= top_p < 1.0:
55
+ processor_list.append(TopPLogitsWarper(top_p))
56
+ if top_k > 0:
57
+ processor_list.append(TopKLogitsWarper(top_k))
58
+ return processor_list
59
+
60
+
61
+ @torch.inference_mode()
62
+ def generate_stream(
63
+ model,
64
+ tokenizer,
65
+ params: Dict,
66
+ device: str,
67
+ context_len: int,
68
+ stream_interval: int = 2,
69
+ judge_sent_end: bool = False,
70
+ ):
71
+ if hasattr(model, "device"):
72
+ device = model.device
73
+
74
+ # Read parameters
75
+ prompt = params["prompt"]
76
+ len_prompt = len(prompt)
77
+ temperature = float(params.get("temperature", 1.0))
78
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
79
+ top_p = float(params.get("top_p", 1.0))
80
+ top_k = int(params.get("top_k", -1)) # -1 means disable
81
+ max_new_tokens = int(params.get("max_new_tokens", 256))
82
+ logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
83
+ echo = bool(params.get("echo", True))
84
+ stop_str = params.get("stop", None)
85
+ stop_token_ids = params.get("stop_token_ids", None) or []
86
+ if tokenizer.eos_token_id not in stop_token_ids:
87
+ stop_token_ids.append(tokenizer.eos_token_id)
88
+
89
+ logits_processor = prepare_logits_processor(
90
+ temperature, repetition_penalty, top_p, top_k
91
+ )
92
+ input_ids = tokenizer(prompt).input_ids
93
+
94
+ if model.config.is_encoder_decoder:
95
+ max_src_len = context_len
96
+ else: # truncate
97
+ max_src_len = context_len - max_new_tokens - 1
98
+
99
+ input_ids = input_ids[-max_src_len:]
100
+ output_ids = list(input_ids)
101
+ input_echo_len = len(input_ids)
102
+
103
+ if model.config.is_encoder_decoder:
104
+ if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
105
+ raise NotImplementedError
106
+ encoder_output = model.encoder(
107
+ input_ids=torch.as_tensor([input_ids], device=device)
108
+ )[0]
109
+ start_ids = torch.as_tensor(
110
+ [[model.generation_config.decoder_start_token_id]],
111
+ dtype=torch.int64,
112
+ device=device,
113
+ )
114
+ else:
115
+ start_ids = torch.as_tensor([input_ids], device=device)
116
+
117
+ past_key_values = out = None
118
+ token_logprobs = [None] # The first token has no logprobs.
119
+ sent_interrupt = False
120
+ finish_reason = None
121
+ stopped = False
122
+ for i in range(max_new_tokens):
123
+ if i == 0: # prefill
124
+ if model.config.is_encoder_decoder:
125
+ out = model.decoder(
126
+ input_ids=start_ids,
127
+ encoder_hidden_states=encoder_output,
128
+ use_cache=True,
129
+ )
130
+ logits = model.lm_head(out[0])
131
+ else:
132
+ out = model(input_ids=start_ids, use_cache=True)
133
+ logits = out.logits
134
+ past_key_values = out.past_key_values
135
+
136
+ if logprobs is not None:
137
+ # Prefull logprobs for the prompt.
138
+ shift_input_ids = start_ids[..., 1:].contiguous()
139
+ shift_logits = logits[..., :-1, :].contiguous()
140
+ shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
141
+ for label_id, logit in zip(
142
+ shift_input_ids[0].tolist(), shift_logits[0]
143
+ ):
144
+ token_logprobs.append(logit[label_id])
145
+ else: # decoding
146
+ if model.config.is_encoder_decoder:
147
+ out = model.decoder(
148
+ input_ids=torch.as_tensor(
149
+ [[token] if not sent_interrupt else output_ids],
150
+ device=device,
151
+ ),
152
+ encoder_hidden_states=encoder_output,
153
+ use_cache=True,
154
+ past_key_values=past_key_values if not sent_interrupt else None,
155
+ )
156
+ sent_interrupt = False
157
+
158
+ logits = model.lm_head(out[0])
159
+ else:
160
+ out = model(
161
+ input_ids=torch.as_tensor(
162
+ [[token] if not sent_interrupt else output_ids],
163
+ device=device,
164
+ ),
165
+ use_cache=True,
166
+ past_key_values=past_key_values if not sent_interrupt else None,
167
+ )
168
+ sent_interrupt = False
169
+ logits = out.logits
170
+ past_key_values = out.past_key_values
171
+
172
+ if logits_processor:
173
+ if repetition_penalty > 1.0:
174
+ tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
175
+ else:
176
+ tmp_output_ids = None
177
+ last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
178
+ else:
179
+ last_token_logits = logits[0, -1, :]
180
+
181
+ if device == "mps":
182
+ # Switch to CPU by avoiding some bugs in mps backend.
183
+ last_token_logits = last_token_logits.float().to("cpu")
184
+
185
+ if temperature < 1e-5 or top_p < 1e-8: # greedy
186
+ _, indices = torch.topk(last_token_logits, 2)
187
+ tokens = [int(index) for index in indices.tolist()]
188
+ else:
189
+ probs = torch.softmax(last_token_logits, dim=-1)
190
+ indices = torch.multinomial(probs, num_samples=2)
191
+ tokens = [int(token) for token in indices.tolist()]
192
+ token = tokens[0]
193
+ output_ids.append(token)
194
+ if logprobs is not None:
195
+ # Cannot use last_token_logits because logprobs is based on raw logits.
196
+ token_logprobs.append(
197
+ torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
198
+ )
199
+
200
+ if token in stop_token_ids:
201
+ stopped = True
202
+ else:
203
+ stopped = False
204
+
205
+ # Yield the output tokens
206
+ if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
207
+ if echo:
208
+ tmp_output_ids = output_ids
209
+ rfind_start = len_prompt
210
+ else:
211
+ tmp_output_ids = output_ids[input_echo_len:]
212
+ rfind_start = 0
213
+
214
+ output = tokenizer.decode(
215
+ tmp_output_ids,
216
+ skip_special_tokens=True,
217
+ spaces_between_special_tokens=False,
218
+ clean_up_tokenization_spaces=True,
219
+ )
220
+ ret_logprobs = None
221
+ if logprobs is not None:
222
+ ret_logprobs = {
223
+ "text_offset": [],
224
+ "tokens": [
225
+ tokenizer.decode(token)
226
+ for token in (
227
+ output_ids if echo else output_ids[input_echo_len:]
228
+ )
229
+ ],
230
+ "token_logprobs": token_logprobs
231
+ if echo
232
+ else token_logprobs[input_echo_len:],
233
+ "top_logprobs": [{}]
234
+ * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
235
+ }
236
+ # Compute text_offset
237
+ curr_pos = 0
238
+ for text in ret_logprobs["tokens"]:
239
+ ret_logprobs["text_offset"].append(curr_pos)
240
+ curr_pos += len(text)
241
+
242
+ # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
243
+ if judge_sent_end and stopped and not is_sentence_complete(output):
244
+ if len(tokens) > 1:
245
+ token = tokens[1]
246
+ output_ids[-1] = token
247
+ else:
248
+ output_ids.pop()
249
+ stopped = False
250
+ sent_interrupt = True
251
+
252
+ partially_stopped = False
253
+ if stop_str:
254
+ if isinstance(stop_str, str):
255
+ pos = output.rfind(stop_str, rfind_start)
256
+ if pos != -1:
257
+ output = output[:pos]
258
+ stopped = True
259
+ else:
260
+ partially_stopped = is_partial_stop(output, stop_str)
261
+ elif isinstance(stop_str, Iterable):
262
+ for each_stop in stop_str:
263
+ pos = output.rfind(each_stop, rfind_start)
264
+ if pos != -1:
265
+ output = output[:pos]
266
+ stopped = True
267
+ break
268
+ else:
269
+ partially_stopped = is_partial_stop(output, each_stop)
270
+ if partially_stopped:
271
+ break
272
+ else:
273
+ raise ValueError("Invalid stop field type.")
274
+
275
+ # Prevent yielding partial stop sequence
276
+ if not partially_stopped:
277
+ yield {
278
+ "text": output,
279
+ "logprobs": ret_logprobs,
280
+ "usage": {
281
+ "prompt_tokens": input_echo_len,
282
+ "completion_tokens": i,
283
+ "total_tokens": input_echo_len + i,
284
+ },
285
+ "finish_reason": None,
286
+ }
287
+
288
+ if stopped:
289
+ break
290
+
291
+ # Finish stream event, which contains finish reason
292
+ else:
293
+ finish_reason = "length"
294
+
295
+ if stopped:
296
+ finish_reason = "stop"
297
+
298
+ yield {
299
+ "text": output,
300
+ "logprobs": ret_logprobs,
301
+ "usage": {
302
+ "prompt_tokens": input_echo_len,
303
+ "completion_tokens": i,
304
+ "total_tokens": input_echo_len + i,
305
+ },
306
+ "finish_reason": finish_reason,
307
+ }
308
+
309
+ # Clean
310
+ del past_key_values, out
311
+ gc.collect()
312
+ torch.cuda.empty_cache()
313
+ if device == "xpu":
314
+ torch.xpu.empty_cache()
315
+ if device == "npu":
316
+ torch.npu.empty_cache()
317
+
318
+
319
+ class ChatIO(abc.ABC):
320
+ @abc.abstractmethod
321
+ def prompt_for_input(self, role: str) -> str:
322
+ """Prompt for input from a role."""
323
+
324
+ @abc.abstractmethod
325
+ def prompt_for_output(self, role: str):
326
+ """Prompt for output from a role."""
327
+
328
+ @abc.abstractmethod
329
+ def stream_output(self, output_stream):
330
+ """Stream output."""
331
+
332
+ @abc.abstractmethod
333
+ def print_output(self, text: str):
334
+ """Print output."""
335
+
336
+
337
+ def chat_loop(
338
+ model_path: str,
339
+ device: str,
340
+ num_gpus: int,
341
+ max_gpu_memory: str,
342
+ dtype: Optional[torch.dtype],
343
+ load_8bit: bool,
344
+ cpu_offloading: bool,
345
+ conv_template: Optional[str],
346
+ conv_system_msg: Optional[str],
347
+ temperature: float,
348
+ repetition_penalty: float,
349
+ max_new_tokens: int,
350
+ chatio: ChatIO,
351
+ gptq_config: Optional[GptqConfig] = None,
352
+ awq_config: Optional[AWQConfig] = None,
353
+ exllama_config: Optional[ExllamaConfig] = None,
354
+ xft_config: Optional[XftConfig] = None,
355
+ revision: str = "main",
356
+ judge_sent_end: bool = True,
357
+ debug: bool = True,
358
+ history: bool = True,
359
+ ):
360
+ # Model
361
+ model, tokenizer = load_model(
362
+ model_path,
363
+ device=device,
364
+ num_gpus=num_gpus,
365
+ max_gpu_memory=max_gpu_memory,
366
+ dtype=dtype,
367
+ load_8bit=load_8bit,
368
+ cpu_offloading=cpu_offloading,
369
+ gptq_config=gptq_config,
370
+ awq_config=awq_config,
371
+ exllama_config=exllama_config,
372
+ xft_config=xft_config,
373
+ revision=revision,
374
+ debug=debug,
375
+ )
376
+ generate_stream_func = get_generate_stream_function(model, model_path)
377
+
378
+ model_type = str(type(model)).lower()
379
+ is_t5 = "t5" in model_type
380
+ is_codet5p = "codet5p" in model_type
381
+ is_xft = "xft" in model_type
382
+
383
+ # Hardcode T5's default repetition penalty to be 1.2
384
+ if is_t5 and repetition_penalty == 1.0:
385
+ repetition_penalty = 1.2
386
+
387
+ # Set context length
388
+ context_len = get_context_length(model.config)
389
+
390
+ # Chat
391
+ def new_chat():
392
+ if conv_template:
393
+ conv = get_conv_template(conv_template)
394
+ else:
395
+ conv = get_conversation_template(model_path)
396
+ if conv_system_msg is not None:
397
+ conv.set_system_message(conv_system_msg)
398
+ return conv
399
+
400
+ def reload_conv(conv):
401
+ """
402
+ Reprints the conversation from the start.
403
+ """
404
+ for message in conv.messages[conv.offset :]:
405
+ chatio.prompt_for_output(message[0])
406
+ chatio.print_output(message[1])
407
+
408
+ conv = None
409
+
410
+ while True:
411
+ if not history or not conv:
412
+ conv = new_chat()
413
+
414
+ try:
415
+ inp = chatio.prompt_for_input(conv.roles[0])
416
+ except EOFError:
417
+ inp = ""
418
+
419
+ if inp == "!!exit" or not inp:
420
+ print("exit...")
421
+ break
422
+ elif inp == "!!reset":
423
+ print("resetting...")
424
+ conv = new_chat()
425
+ continue
426
+ elif inp == "!!remove":
427
+ print("removing last message...")
428
+ if len(conv.messages) > conv.offset:
429
+ # Assistant
430
+ if conv.messages[-1][0] == conv.roles[1]:
431
+ conv.messages.pop()
432
+ # User
433
+ if conv.messages[-1][0] == conv.roles[0]:
434
+ conv.messages.pop()
435
+ reload_conv(conv)
436
+ else:
437
+ print("No messages to remove.")
438
+ continue
439
+ elif inp == "!!regen":
440
+ print("regenerating last message...")
441
+ if len(conv.messages) > conv.offset:
442
+ # Assistant
443
+ if conv.messages[-1][0] == conv.roles[1]:
444
+ conv.messages.pop()
445
+ # User
446
+ if conv.messages[-1][0] == conv.roles[0]:
447
+ reload_conv(conv)
448
+ # Set inp to previous message
449
+ inp = conv.messages.pop()[1]
450
+ else:
451
+ # Shouldn't happen in normal circumstances
452
+ print("No user message to regenerate from.")
453
+ continue
454
+ else:
455
+ print("No messages to regenerate.")
456
+ continue
457
+ elif inp.startswith("!!save"):
458
+ args = inp.split(" ", 1)
459
+
460
+ if len(args) != 2:
461
+ print("usage: !!save <filename>")
462
+ continue
463
+ else:
464
+ filename = args[1]
465
+
466
+ # Add .json if extension not present
467
+ if not "." in filename:
468
+ filename += ".json"
469
+
470
+ print("saving...", filename)
471
+ with open(filename, "w") as outfile:
472
+ json.dump(conv.dict(), outfile)
473
+ continue
474
+ elif inp.startswith("!!load"):
475
+ args = inp.split(" ", 1)
476
+
477
+ if len(args) != 2:
478
+ print("usage: !!load <filename>")
479
+ continue
480
+ else:
481
+ filename = args[1]
482
+
483
+ # Check if file exists and add .json if needed
484
+ if not os.path.exists(filename):
485
+ if (not filename.endswith(".json")) and os.path.exists(
486
+ filename + ".json"
487
+ ):
488
+ filename += ".json"
489
+ else:
490
+ print("file not found:", filename)
491
+ continue
492
+
493
+ print("loading...", filename)
494
+ with open(filename, "r") as infile:
495
+ new_conv = json.load(infile)
496
+
497
+ conv = get_conv_template(new_conv["template_name"])
498
+ conv.set_system_message(new_conv["system_message"])
499
+ conv.messages = new_conv["messages"]
500
+ reload_conv(conv)
501
+ continue
502
+
503
+ conv.append_message(conv.roles[0], inp)
504
+ conv.append_message(conv.roles[1], None)
505
+ prompt = conv.get_prompt()
506
+
507
+ if is_codet5p: # codet5p is a code completion model.
508
+ prompt = inp
509
+
510
+ gen_params = {
511
+ "model": model_path,
512
+ "prompt": prompt,
513
+ "temperature": temperature,
514
+ "repetition_penalty": repetition_penalty,
515
+ "max_new_tokens": max_new_tokens,
516
+ "stop": conv.stop_str,
517
+ "stop_token_ids": conv.stop_token_ids,
518
+ "echo": False,
519
+ }
520
+
521
+ try:
522
+ chatio.prompt_for_output(conv.roles[1])
523
+ output_stream = generate_stream_func(
524
+ model,
525
+ tokenizer,
526
+ gen_params,
527
+ device,
528
+ context_len=context_len,
529
+ judge_sent_end=judge_sent_end,
530
+ )
531
+ t = time.time()
532
+ outputs = chatio.stream_output(output_stream)
533
+ duration = time.time() - t
534
+ conv.update_last_message(outputs.strip())
535
+
536
+ if debug:
537
+ num_tokens = len(tokenizer.encode(outputs))
538
+ msg = {
539
+ "conv_template": conv.name,
540
+ "prompt": prompt,
541
+ "outputs": outputs,
542
+ "speed (token/s)": round(num_tokens / duration, 2),
543
+ }
544
+ print(f"\n{msg}\n")
545
+
546
+ except KeyboardInterrupt:
547
+ print("stopped generation.")
548
+ # If generation didn't finish
549
+ if conv.messages[-1][1] is None:
550
+ conv.messages.pop()
551
+ # Remove last user message, so there isn't a double up
552
+ if conv.messages[-1][0] == conv.roles[0]:
553
+ conv.messages.pop()
554
+
555
+ reload_conv(conv)
launch_all_serve.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022"
3
+
4
+ Workers are listed in format of `model-path`@`host`@`port`
5
+
6
+ The key mechanism behind this scripts is:
7
+ 1, execute shell cmd to launch the controller/worker/openai-api-server;
8
+ 2, check the log of controller/worker/openai-api-server to ensure that the serve is launched properly.
9
+ Note that a few of non-critical `fastchat.serve` cmd options are not supported currently.
10
+ """
11
+ import sys
12
+ import os
13
+
14
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
15
+
16
+ import subprocess
17
+ import re
18
+ import argparse
19
+
20
+ LOGDIR = "./logs/"
21
+
22
+ if not os.path.exists(LOGDIR):
23
+ os.makedirs(LOGDIR)
24
+
25
+ parser = argparse.ArgumentParser()
26
+ # ------multi worker-----------------
27
+ parser.add_argument(
28
+ "--model-path-address",
29
+ default="THUDM/chatglm2-6b@localhost@20002",
30
+ nargs="+",
31
+ type=str,
32
+ help="model path, host, and port, formatted as model-path@host@port",
33
+ )
34
+ # ---------------controller-------------------------
35
+
36
+ parser.add_argument("--controller-host", type=str, default="localhost")
37
+ parser.add_argument("--controller-port", type=int, default=21001)
38
+ parser.add_argument(
39
+ "--dispatch-method",
40
+ type=str,
41
+ choices=["lottery", "shortest_queue"],
42
+ default="shortest_queue",
43
+ )
44
+ controller_args = ["controller-host", "controller-port", "dispatch-method"]
45
+
46
+ # ----------------------worker------------------------------------------
47
+
48
+ parser.add_argument("--worker-host", type=str, default="localhost")
49
+ parser.add_argument("--worker-port", type=int, default=21002)
50
+ # parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
51
+ # parser.add_argument(
52
+ # "--controller-address", type=str, default="http://localhost:21001"
53
+ # )
54
+ parser.add_argument(
55
+ "--model-path",
56
+ type=str,
57
+ default="lmsys/vicuna-7b-v1.5",
58
+ help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
59
+ )
60
+ parser.add_argument(
61
+ "--revision",
62
+ type=str,
63
+ default="main",
64
+ help="Hugging Face Hub model revision identifier",
65
+ )
66
+ parser.add_argument(
67
+ "--device",
68
+ type=str,
69
+ choices=["cpu", "cuda", "mps", "xpu", "npu"],
70
+ default="cuda",
71
+ help="The device type",
72
+ )
73
+ parser.add_argument(
74
+ "--gpus",
75
+ type=str,
76
+ default="0",
77
+ help="A single GPU like 1 or multiple GPUs like 0,2",
78
+ )
79
+ parser.add_argument("--num-gpus", type=int, default=1)
80
+ parser.add_argument(
81
+ "--max-gpu-memory",
82
+ type=str,
83
+ help="The maximum memory per gpu. Use a string like '13Gib'",
84
+ )
85
+ parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization")
86
+ parser.add_argument(
87
+ "--cpu-offloading",
88
+ action="store_true",
89
+ help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
90
+ )
91
+ parser.add_argument(
92
+ "--gptq-ckpt",
93
+ type=str,
94
+ default=None,
95
+ help="Load quantized model. The path to the local GPTQ checkpoint.",
96
+ )
97
+ parser.add_argument(
98
+ "--gptq-wbits",
99
+ type=int,
100
+ default=16,
101
+ choices=[2, 3, 4, 8, 16],
102
+ help="#bits to use for quantization",
103
+ )
104
+ parser.add_argument(
105
+ "--gptq-groupsize",
106
+ type=int,
107
+ default=-1,
108
+ help="Groupsize to use for quantization; default uses full row.",
109
+ )
110
+ parser.add_argument(
111
+ "--gptq-act-order",
112
+ action="store_true",
113
+ help="Whether to apply the activation order GPTQ heuristic",
114
+ )
115
+ parser.add_argument(
116
+ "--model-names",
117
+ type=lambda s: s.split(","),
118
+ help="Optional display comma separated names",
119
+ )
120
+ parser.add_argument(
121
+ "--limit-worker-concurrency",
122
+ type=int,
123
+ default=5,
124
+ help="Limit the model concurrency to prevent OOM.",
125
+ )
126
+ parser.add_argument("--stream-interval", type=int, default=2)
127
+ parser.add_argument("--no-register", action="store_true")
128
+
129
+ worker_args = [
130
+ "worker-host",
131
+ "worker-port",
132
+ "model-path",
133
+ "revision",
134
+ "device",
135
+ "gpus",
136
+ "num-gpus",
137
+ "max-gpu-memory",
138
+ "load-8bit",
139
+ "cpu-offloading",
140
+ "gptq-ckpt",
141
+ "gptq-wbits",
142
+ "gptq-groupsize",
143
+ "gptq-act-order",
144
+ "model-names",
145
+ "limit-worker-concurrency",
146
+ "stream-interval",
147
+ "no-register",
148
+ "controller-address",
149
+ ]
150
+ # -----------------openai server---------------------------
151
+
152
+ parser.add_argument("--server-host", type=str, default="localhost", help="host name")
153
+ parser.add_argument("--server-port", type=int, default=8001, help="port number")
154
+ parser.add_argument(
155
+ "--allow-credentials", action="store_true", help="allow credentials"
156
+ )
157
+ # parser.add_argument(
158
+ # "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
159
+ # )
160
+ # parser.add_argument(
161
+ # "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
162
+ # )
163
+ # parser.add_argument(
164
+ # "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
165
+ # )
166
+ parser.add_argument(
167
+ "--api-keys",
168
+ type=lambda s: s.split(","),
169
+ help="Optional list of comma separated API keys",
170
+ )
171
+ server_args = [
172
+ "server-host",
173
+ "server-port",
174
+ "allow-credentials",
175
+ "api-keys",
176
+ "controller-address",
177
+ ]
178
+
179
+ args = parser.parse_args()
180
+
181
+ args = argparse.Namespace(
182
+ **vars(args),
183
+ **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"},
184
+ )
185
+
186
+ if args.gpus:
187
+ if len(args.gpus.split(",")) < args.num_gpus:
188
+ raise ValueError(
189
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
190
+ )
191
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
192
+
193
+ # 0,controller, model_worker, openai_api_server
194
+ # 1, cmd options
195
+ # 2,LOGDIR
196
+ # 3, log file name
197
+ base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"
198
+
199
+ # 0 LOGDIR
200
+ #! 1 log file name
201
+ # 2 controller, worker, openai_api_server
202
+ base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do
203
+ sleep 1s;
204
+ echo "wait {2} running"
205
+ done
206
+ echo '{2} running' """
207
+
208
+
209
+ def string_args(args, args_list):
210
+ args_str = ""
211
+ for key, value in args._get_kwargs():
212
+ key = key.replace("_", "-")
213
+ if key not in args_list:
214
+ continue
215
+
216
+ key = key.split("-")[-1] if re.search("port|host", key) else key
217
+ if not value:
218
+ pass
219
+ # 1==True -> True
220
+ elif isinstance(value, bool) and value == True:
221
+ args_str += f" --{key} "
222
+ elif (
223
+ isinstance(value, list)
224
+ or isinstance(value, tuple)
225
+ or isinstance(value, set)
226
+ ):
227
+ value = " ".join(value)
228
+ args_str += f" --{key} {value} "
229
+ else:
230
+ args_str += f" --{key} {value} "
231
+
232
+ return args_str
233
+
234
+
235
+ def launch_worker(item):
236
+ log_name = (
237
+ item.split("/")[-1]
238
+ .split("\\")[-1]
239
+ .replace("-", "_")
240
+ .replace("@", "_")
241
+ .replace(".", "_")
242
+ )
243
+
244
+ args.model_path, args.worker_host, args.worker_port = item.split("@")
245
+ print("*" * 80)
246
+ worker_str_args = string_args(args, worker_args)
247
+ print(worker_str_args)
248
+ worker_sh = base_launch_sh.format(
249
+ "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}"
250
+ )
251
+ worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker")
252
+ subprocess.run(worker_sh, shell=True, check=True)
253
+ subprocess.run(worker_check_sh, shell=True, check=True)
254
+
255
+
256
+ def launch_all():
257
+ controller_str_args = string_args(args, controller_args)
258
+ controller_sh = base_launch_sh.format(
259
+ "controller", controller_str_args, LOGDIR, "controller"
260
+ )
261
+ controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller")
262
+ subprocess.run(controller_sh, shell=True, check=True)
263
+ subprocess.run(controller_check_sh, shell=True, check=True)
264
+
265
+ if isinstance(args.model_path_address, str):
266
+ launch_worker(args.model_path_address)
267
+ else:
268
+ for idx, item in enumerate(args.model_path_address):
269
+ print(f"loading {idx}th model:{item}")
270
+ launch_worker(item)
271
+
272
+ server_str_args = string_args(args, server_args)
273
+ server_sh = base_launch_sh.format(
274
+ "openai_api_server", server_str_args, LOGDIR, "openai_api_server"
275
+ )
276
+ server_check_sh = base_check_sh.format(
277
+ LOGDIR, "openai_api_server", "openai_api_server"
278
+ )
279
+ subprocess.run(server_sh, shell=True, check=True)
280
+ subprocess.run(server_check_sh, shell=True, check=True)
281
+
282
+
283
+ if __name__ == "__main__":
284
+ launch_all()
lightllm_worker.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker that executes the model based on LightLLM.
3
+
4
+ See documentations at docs/lightllm_integration.md
5
+ """
6
+
7
+ import argparse
8
+ import asyncio
9
+ import json
10
+ import os
11
+ import torch
12
+ import uvicorn
13
+
14
+ from transformers import AutoConfig
15
+
16
+ from typing import List
17
+
18
+ from fastapi import FastAPI, Request, BackgroundTasks
19
+ from fastapi.responses import StreamingResponse, JSONResponse
20
+
21
+ from fastchat.serve.base_model_worker import BaseModelWorker
22
+ from fastchat.serve.model_worker import (
23
+ logger,
24
+ worker_id,
25
+ )
26
+
27
+ from lightllm.server.sampling_params import SamplingParams
28
+ from lightllm.server.multimodal_params import MultimodalParams
29
+ from lightllm.server.httpserver.manager import HttpServerManager
30
+ from lightllm.server.detokenization.manager import start_detokenization_process
31
+ from lightllm.server.router.manager import start_router_process
32
+ from lightllm.server.req_id_generator import ReqIDGenerator
33
+
34
+ from lightllm.utils.net_utils import alloc_can_use_network_port
35
+ from lightllm.utils.start_utils import start_submodule_processes
36
+ from fastchat.utils import get_context_length, is_partial_stop
37
+
38
+ app = FastAPI()
39
+ g_id_gen = ReqIDGenerator()
40
+
41
+
42
+ class LightLLMWorker(BaseModelWorker):
43
+ def __init__(
44
+ self,
45
+ controller_addr: str,
46
+ worker_addr: str,
47
+ worker_id: str,
48
+ model_path: str,
49
+ model_names: List[str],
50
+ limit_worker_concurrency: int,
51
+ no_register: bool,
52
+ conv_template: str,
53
+ tokenizer,
54
+ context_len,
55
+ ):
56
+ super().__init__(
57
+ controller_addr,
58
+ worker_addr,
59
+ worker_id,
60
+ model_path,
61
+ model_names,
62
+ limit_worker_concurrency,
63
+ conv_template,
64
+ )
65
+
66
+ logger.info(
67
+ f"Loading the model {self.model_names} on worker {worker_id}, worker type: LightLLM worker..."
68
+ )
69
+ self.tokenizer = tokenizer
70
+ self.context_len = context_len
71
+
72
+ self.is_first = True
73
+
74
+ if not no_register:
75
+ self.init_heart_beat()
76
+
77
+ async def generate_stream(self, params):
78
+ self.call_ct += 1
79
+
80
+ prompt = params.pop("prompt")
81
+ request_id = params.pop("request_id")
82
+ temperature = float(params.get("temperature", 1.0))
83
+ top_p = float(params.get("top_p", 1.0))
84
+ top_k = params.get("top_k", -1.0)
85
+ presence_penalty = float(params.get("presence_penalty", 0.0))
86
+ frequency_penalty = float(params.get("frequency_penalty", 0.0))
87
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
88
+ max_new_tokens = params.get("max_new_tokens", 256)
89
+ echo = params.get("echo", True)
90
+ stop_str = params.get("stop", None)
91
+ stop_token_ids = params.get("stop_token_ids", None) or []
92
+ if self.tokenizer.eos_token_id is not None:
93
+ stop_token_ids.append(self.tokenizer.eos_token_id)
94
+
95
+ request = params.get("request", None)
96
+
97
+ # Handle stop_str
98
+ stop = set()
99
+ if isinstance(stop_str, str) and stop_str != "":
100
+ stop.add(stop_str)
101
+ elif isinstance(stop_str, list) and stop_str != []:
102
+ stop.update(stop_str)
103
+
104
+ for tid in stop_token_ids:
105
+ if tid is not None:
106
+ s = self.tokenizer.decode(tid)
107
+ if s != "":
108
+ stop.add(s)
109
+
110
+ if self.is_first:
111
+ loop = asyncio.get_event_loop()
112
+ loop.create_task(httpserver_manager.handle_loop())
113
+ self.is_first = False
114
+
115
+ # make sampling params in vllm
116
+ top_p = max(top_p, 1e-5)
117
+ if temperature <= 1e-5:
118
+ top_p = 1.0
119
+
120
+ sampling_params = SamplingParams(
121
+ do_sample=temperature > 0.0,
122
+ temperature=temperature,
123
+ top_p=top_p,
124
+ top_k=top_k,
125
+ presence_penalty=presence_penalty,
126
+ frequency_penalty=frequency_penalty,
127
+ repetition_penalty=repetition_penalty,
128
+ max_new_tokens=max_new_tokens,
129
+ stop_sequences=list(stop),
130
+ )
131
+ sampling_params.verify()
132
+
133
+ results_generator = httpserver_manager.generate(
134
+ prompt, sampling_params, request_id, MultimodalParams()
135
+ )
136
+
137
+ completion_tokens = 0
138
+ text_outputs = ""
139
+ cumulative_logprob = 0.0
140
+
141
+ async for request_output, metadata, finish_status in results_generator:
142
+ text_outputs += request_output
143
+ completion_tokens += 1
144
+
145
+ partial_stop = any(is_partial_stop(text_outputs, i) for i in stop)
146
+ # prevent yielding partial stop sequence
147
+ if partial_stop:
148
+ continue
149
+
150
+ if type(finish_status) is bool: # compatibility with old version
151
+ finish_reason = "stop" if finish_status else None
152
+ else:
153
+ finish_reason = finish_status.get_finish_reason()
154
+
155
+ if request and await request.is_disconnected():
156
+ await httpserver_manager.abort(request_id)
157
+ finish_reason = "abort"
158
+
159
+ logprob = metadata.get("logprob", None)
160
+ if logprob is not None:
161
+ cumulative_logprob += logprob
162
+
163
+ prompt_tokens = metadata["prompt_tokens"]
164
+ ret = {
165
+ "text": prompt + text_outputs if echo else text_outputs,
166
+ "error_code": 0,
167
+ "usage": {
168
+ "prompt_tokens": prompt_tokens,
169
+ "completion_tokens": completion_tokens,
170
+ "total_tokens": prompt_tokens + completion_tokens,
171
+ },
172
+ "cumulative_logprob": cumulative_logprob,
173
+ }
174
+
175
+ if finish_reason is not None:
176
+ yield (
177
+ json.dumps({**ret, "finish_reason": None}, ensure_ascii=False)
178
+ + "\0"
179
+ ).encode("utf-8")
180
+ yield (
181
+ json.dumps({**ret, "finish_reason": finish_reason}, ensure_ascii=False)
182
+ + "\0"
183
+ ).encode("utf-8")
184
+
185
+ if finish_reason is not None: # In case of abort, we need to break the loop
186
+ break
187
+
188
+ async def generate(self, params):
189
+ async for x in self.generate_stream(params):
190
+ pass
191
+ return json.loads(x[:-1].decode())
192
+
193
+
194
+ def release_worker_semaphore():
195
+ worker.semaphore.release()
196
+
197
+
198
+ def acquire_worker_semaphore():
199
+ if worker.semaphore is None:
200
+ worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
201
+ return worker.semaphore.acquire()
202
+
203
+
204
+ def create_background_tasks(request_id):
205
+ async def abort_request() -> None:
206
+ await httpserver_manager.abort(request_id)
207
+
208
+ background_tasks = BackgroundTasks()
209
+ background_tasks.add_task(release_worker_semaphore)
210
+ background_tasks.add_task(abort_request)
211
+ return background_tasks
212
+
213
+
214
+ @app.post("/worker_generate_stream")
215
+ async def api_generate_stream(request: Request):
216
+ params = await request.json()
217
+ await acquire_worker_semaphore()
218
+ request_id = g_id_gen.generate_id()
219
+ params["request_id"] = request_id
220
+ params["request"] = request
221
+ generator = worker.generate_stream(params)
222
+ background_tasks = create_background_tasks(request_id)
223
+ return StreamingResponse(generator, background=background_tasks)
224
+
225
+
226
+ @app.post("/worker_generate")
227
+ async def api_generate(request: Request):
228
+ params = await request.json()
229
+ await acquire_worker_semaphore()
230
+ request_id = g_id_gen.generate_id()
231
+ params["request_id"] = request_id
232
+ params["request"] = request
233
+ output = await worker.generate(params)
234
+ release_worker_semaphore()
235
+ await httpserver_manager.abort(request_id)
236
+ return JSONResponse(output)
237
+
238
+
239
+ @app.post("/worker_get_status")
240
+ async def api_get_status(request: Request):
241
+ return worker.get_status()
242
+
243
+
244
+ @app.post("/count_token")
245
+ async def api_count_token(request: Request):
246
+ params = await request.json()
247
+ return worker.count_token(params)
248
+
249
+
250
+ @app.post("/worker_get_conv_template")
251
+ async def api_get_conv(request: Request):
252
+ return worker.get_conv_template()
253
+
254
+
255
+ @app.post("/model_details")
256
+ async def api_model_details(request: Request):
257
+ return {"context_length": worker.context_len}
258
+
259
+
260
+ if __name__ == "__main__":
261
+ torch.multiprocessing.set_start_method("spawn")
262
+ parser = argparse.ArgumentParser()
263
+ parser.add_argument("--host", type=str, default="127.0.0.1")
264
+ parser.add_argument("--port", type=int, default=8000)
265
+
266
+ parser.add_argument(
267
+ "--model-path",
268
+ dest="model_dir",
269
+ type=str,
270
+ default=None,
271
+ help="the model weight dir path, the app will load config, weights and tokenizer from this dir",
272
+ )
273
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
274
+ parser.add_argument(
275
+ "--controller-address", type=str, default="http://localhost:21001"
276
+ )
277
+ parser.add_argument(
278
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
279
+ )
280
+ parser.add_argument(
281
+ "--model-names",
282
+ type=lambda s: s.split(","),
283
+ help="Optional display comma separated names",
284
+ )
285
+ parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
286
+ parser.add_argument("--no-register", action="store_true")
287
+
288
+ parser.add_argument(
289
+ "--tokenizer_mode",
290
+ type=str,
291
+ default="slow",
292
+ help="""tokenizer load mode, can be slow or auto, slow mode load fast but run slow, slow mode is good for debug and test,
293
+ when you want to get best performance, try auto mode""",
294
+ )
295
+ parser.add_argument(
296
+ "--load_way",
297
+ type=str,
298
+ default="HF",
299
+ help="the way of loading model weights, the default is HF(Huggingface format), llama also supports DS(Deepspeed)",
300
+ )
301
+ parser.add_argument(
302
+ "--max_total_token_num",
303
+ type=int,
304
+ default=6000,
305
+ help="the total token nums the gpu and model can support, equals = max_batch * (input_len + output_len)",
306
+ )
307
+ parser.add_argument(
308
+ "--batch_max_tokens",
309
+ type=int,
310
+ default=None,
311
+ help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM",
312
+ )
313
+ parser.add_argument("--eos_id", type=int, default=2, help="eos stop token id")
314
+ parser.add_argument(
315
+ "--running_max_req_size",
316
+ type=int,
317
+ default=1000,
318
+ help="the max size for forward requests in the same time",
319
+ )
320
+ parser.add_argument(
321
+ "--tp", type=int, default=1, help="model tp parral size, the default is 1"
322
+ )
323
+ parser.add_argument(
324
+ "--max_req_input_len",
325
+ type=int,
326
+ default=None,
327
+ help="the max value for req input tokens num. If None, it will be derived from the config.",
328
+ )
329
+ parser.add_argument(
330
+ "--max_req_total_len",
331
+ type=int,
332
+ default=None,
333
+ help="the max value for req_input_len + req_output_len. If None, it will be derived from the config.",
334
+ )
335
+ parser.add_argument(
336
+ "--mode",
337
+ type=str,
338
+ default=[],
339
+ nargs="+",
340
+ help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding
341
+ | triton_gqa_attention | triton_gqa_flashdecoding]
342
+ [triton_int8weight | triton_int4weight | lmdeploy_int4weight | ppl_int4weight],
343
+ triton_flashdecoding mode is for long context, current support llama llama2 qwen;
344
+ triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA;
345
+ triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
346
+ ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel;
347
+ ppl_fp16 mode use ppl fast fp16 decode attention kernel;
348
+ triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode use int8 and int4 to store weights;
349
+ you need to read source code to make sure the supported detail mode for all models""",
350
+ )
351
+ parser.add_argument(
352
+ "--trust_remote_code",
353
+ action="store_true",
354
+ help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
355
+ )
356
+ parser.add_argument(
357
+ "--disable_log_stats",
358
+ action="store_true",
359
+ help="disable logging throughput stats.",
360
+ )
361
+ parser.add_argument(
362
+ "--log_stats_interval",
363
+ type=int,
364
+ default=10,
365
+ help="log stats interval in second.",
366
+ )
367
+
368
+ parser.add_argument(
369
+ "--router_token_ratio",
370
+ type=float,
371
+ default=0.0,
372
+ help="token ratio to control router dispatch",
373
+ )
374
+ parser.add_argument(
375
+ "--router_max_new_token_len",
376
+ type=int,
377
+ default=1024,
378
+ help="the request max new token len for router",
379
+ )
380
+
381
+ parser.add_argument(
382
+ "--no_skipping_special_tokens",
383
+ action="store_true",
384
+ help="whether to skip special tokens when decoding",
385
+ )
386
+ parser.add_argument(
387
+ "--no_spaces_between_special_tokens",
388
+ action="store_true",
389
+ help="whether to add spaces between special tokens when decoding",
390
+ )
391
+
392
+ parser.add_argument(
393
+ "--splitfuse_mode", action="store_true", help="use splitfuse mode"
394
+ )
395
+ parser.add_argument(
396
+ "--splitfuse_block_size", type=int, default=256, help="splitfuse block size"
397
+ )
398
+ parser.add_argument(
399
+ "--prompt_cache_strs",
400
+ type=str,
401
+ default=[],
402
+ nargs="+",
403
+ help="""prompt cache strs""",
404
+ )
405
+ parser.add_argument(
406
+ "--cache_capacity",
407
+ type=int,
408
+ default=200,
409
+ help="cache server capacity for multimodal resources",
410
+ )
411
+ parser.add_argument(
412
+ "--cache_reserved_ratio",
413
+ type=float,
414
+ default=0.5,
415
+ help="cache server reserved capacity ratio after clear",
416
+ )
417
+ parser.add_argument(
418
+ "--return_all_prompt_logprobs",
419
+ action="store_true",
420
+ help="return all prompt tokens logprobs",
421
+ )
422
+ parser.add_argument(
423
+ "--long_truncation_mode",
424
+ type=str,
425
+ choices=[None, "head", "center"],
426
+ default=None,
427
+ help="""use to select the handle way when input token len > max_req_input_len.
428
+ None : raise Exception
429
+ head : remove some head tokens to make input token len <= max_req_input_len
430
+ center : remove some tokens in center loc to make input token len <= max_req_input_len""",
431
+ )
432
+
433
+ args = parser.parse_args()
434
+
435
+ # 非splitfuse 模式,不支持 prompt cache 特性
436
+ if not args.splitfuse_mode:
437
+ assert len(args.prompt_cache_strs) == 0
438
+
439
+ model_config = AutoConfig.from_pretrained(args.model_dir)
440
+ context_length = get_context_length(model_config)
441
+
442
+ if args.max_req_input_len is None:
443
+ args.max_req_input_len = context_length - 1
444
+ if args.max_req_total_len is None:
445
+ args.max_req_total_len = context_length
446
+
447
+ assert args.max_req_input_len < args.max_req_total_len
448
+ assert args.max_req_total_len <= args.max_total_token_num
449
+
450
+ if not args.splitfuse_mode:
451
+ # 普通模式下
452
+ if args.batch_max_tokens is None:
453
+ batch_max_tokens = int(1 / 6 * args.max_total_token_num)
454
+ batch_max_tokens = max(batch_max_tokens, args.max_req_total_len)
455
+ args.batch_max_tokens = batch_max_tokens
456
+ else:
457
+ assert (
458
+ args.batch_max_tokens >= args.max_req_total_len
459
+ ), "batch_max_tokens must >= max_req_total_len"
460
+ else:
461
+ # splitfuse 模式下
462
+ # assert args.batch_max_tokens is not None, "need to set by yourself"
463
+ if args.batch_max_tokens is None:
464
+ batch_max_tokens = int(1 / 6 * args.max_total_token_num)
465
+ batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size)
466
+ args.batch_max_tokens = batch_max_tokens
467
+
468
+ can_use_ports = alloc_can_use_network_port(num=6 + args.tp)
469
+
470
+ assert can_use_ports is not None, "Can not alloc enough free ports."
471
+ (
472
+ router_port,
473
+ detokenization_port,
474
+ httpserver_port,
475
+ visual_port,
476
+ cache_port,
477
+ nccl_port,
478
+ ) = can_use_ports[0:6]
479
+ args.nccl_port = nccl_port
480
+ model_rpc_ports = can_use_ports[6:]
481
+
482
+ global httpserver_manager
483
+ httpserver_manager = HttpServerManager(
484
+ args,
485
+ router_port=router_port,
486
+ cache_port=cache_port,
487
+ visual_port=visual_port,
488
+ httpserver_port=httpserver_port,
489
+ enable_multimodal=False,
490
+ )
491
+
492
+ start_submodule_processes(
493
+ start_funcs=[start_router_process, start_detokenization_process],
494
+ start_args=[
495
+ (args, router_port, detokenization_port, model_rpc_ports),
496
+ (args, detokenization_port, httpserver_port),
497
+ ],
498
+ )
499
+ worker = LightLLMWorker(
500
+ args.controller_address,
501
+ args.worker_address,
502
+ worker_id,
503
+ args.model_dir,
504
+ args.model_names,
505
+ args.limit_worker_concurrency,
506
+ args.no_register,
507
+ args.conv_template,
508
+ httpserver_manager.tokenizer,
509
+ context_length,
510
+ )
511
+
512
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
mlx_worker.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker using Apple MLX
3
+
4
+ https://github.com/ml-explore/mlx-examples/tree/main/llms
5
+
6
+ Code based on vllm_worker https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py
7
+
8
+ You must install MLX python:
9
+
10
+ pip install mlx-lm
11
+ """
12
+
13
+ import argparse
14
+ import asyncio
15
+ import atexit
16
+ import json
17
+ from typing import List
18
+ import uuid
19
+
20
+ from fastapi import FastAPI, Request, BackgroundTasks
21
+ from fastapi.concurrency import run_in_threadpool
22
+ from fastapi.responses import StreamingResponse, JSONResponse
23
+ import uvicorn
24
+
25
+ from fastchat.serve.base_model_worker import BaseModelWorker
26
+ from fastchat.serve.model_worker import (
27
+ logger,
28
+ worker_id,
29
+ )
30
+ from fastchat.utils import get_context_length, is_partial_stop
31
+
32
+ import mlx.core as mx
33
+ from mlx_lm import load, generate
34
+ from mlx_lm.utils import generate_step
35
+
36
+ app = FastAPI()
37
+
38
+
39
+ class MLXWorker(BaseModelWorker):
40
+ def __init__(
41
+ self,
42
+ controller_addr: str,
43
+ worker_addr: str,
44
+ worker_id: str,
45
+ model_path: str,
46
+ model_names: List[str],
47
+ limit_worker_concurrency: int,
48
+ no_register: bool,
49
+ llm_engine: "MLX",
50
+ conv_template: str,
51
+ ):
52
+ super().__init__(
53
+ controller_addr,
54
+ worker_addr,
55
+ worker_id,
56
+ model_path,
57
+ model_names,
58
+ limit_worker_concurrency,
59
+ conv_template,
60
+ )
61
+
62
+ logger.info(
63
+ f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..."
64
+ )
65
+
66
+ self.model_name = model_path
67
+ self.mlx_model, self.mlx_tokenizer = load(model_path)
68
+
69
+ self.tokenizer = self.mlx_tokenizer
70
+ # self.context_len = get_context_length(
71
+ # llm_engine.engine.model_config.hf_config)
72
+ self.context_len = 2048 # hard code for now -- not sure how to get in MLX
73
+
74
+ if not no_register:
75
+ self.init_heart_beat()
76
+
77
+ async def generate_stream(self, params):
78
+ self.call_ct += 1
79
+
80
+ context = params.pop("prompt")
81
+ request_id = params.pop("request_id")
82
+ temperature = float(params.get("temperature", 1.0))
83
+ top_p = float(params.get("top_p", 1.0))
84
+ top_k = params.get("top_k", -1.0)
85
+ presence_penalty = float(params.get("presence_penalty", 0.0))
86
+ frequency_penalty = float(params.get("frequency_penalty", 0.0))
87
+ max_new_tokens = params.get("max_new_tokens", 256)
88
+ stop_str = params.get("stop", None)
89
+ stop_token_ids = params.get("stop_token_ids", None) or []
90
+ if self.tokenizer.eos_token_id is not None:
91
+ stop_token_ids.append(self.tokenizer.eos_token_id)
92
+ echo = params.get("echo", True)
93
+ use_beam_search = params.get("use_beam_search", False)
94
+ best_of = params.get("best_of", None)
95
+
96
+ # Handle stop_str
97
+ stop = set()
98
+ if isinstance(stop_str, str) and stop_str != "":
99
+ stop.add(stop_str)
100
+ elif isinstance(stop_str, list) and stop_str != []:
101
+ stop.update(stop_str)
102
+
103
+ for tid in stop_token_ids:
104
+ if tid is not None:
105
+ s = self.tokenizer.decode(tid)
106
+ if s != "":
107
+ stop.add(s)
108
+
109
+ print("Stop patterns: ", stop)
110
+
111
+ top_p = max(top_p, 1e-5)
112
+ if temperature <= 1e-5:
113
+ top_p = 1.0
114
+
115
+ tokens = []
116
+ skip = 0
117
+
118
+ context_mlx = mx.array(self.tokenizer.encode(context))
119
+
120
+ finish_reason = "length"
121
+
122
+ iterator = await run_in_threadpool(
123
+ generate_step, context_mlx, self.mlx_model, temperature
124
+ )
125
+
126
+ for i in range(max_new_tokens):
127
+ (token, _) = await run_in_threadpool(next, iterator)
128
+ if token == self.mlx_tokenizer.eos_token_id:
129
+ finish_reason = "stop"
130
+ break
131
+ tokens.append(token.item())
132
+ tokens_decoded = self.mlx_tokenizer.decode(tokens)
133
+ last_token_decoded = self.mlx_tokenizer.decode([token.item()])
134
+ skip = len(tokens_decoded)
135
+
136
+ partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop)
137
+
138
+ if partial_stop:
139
+ finish_reason = "stop"
140
+ break
141
+
142
+ ret = {
143
+ "text": tokens_decoded,
144
+ "error_code": 0,
145
+ "usage": {
146
+ "prompt_tokens": len(context),
147
+ "completion_tokens": len(tokens),
148
+ "total_tokens": len(context) + len(tokens),
149
+ },
150
+ "cumulative_logprob": [],
151
+ "finish_reason": None, # hard code for now
152
+ }
153
+ # print(ret)
154
+ yield (json.dumps(ret) + "\0").encode()
155
+ ret = {
156
+ "text": self.mlx_tokenizer.decode(tokens),
157
+ "error_code": 0,
158
+ "usage": {},
159
+ "cumulative_logprob": [],
160
+ "finish_reason": finish_reason,
161
+ }
162
+ yield (json.dumps(obj={**ret, **{"finish_reason": None}}) + "\0").encode()
163
+ yield (json.dumps(ret) + "\0").encode()
164
+
165
+ async def generate(self, params):
166
+ async for x in self.generate_stream(params):
167
+ pass
168
+ return json.loads(x[:-1].decode())
169
+
170
+
171
+ def release_worker_semaphore():
172
+ worker.semaphore.release()
173
+
174
+
175
+ def acquire_worker_semaphore():
176
+ if worker.semaphore is None:
177
+ worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
178
+ return worker.semaphore.acquire()
179
+
180
+
181
+ def create_background_tasks(request_id):
182
+ async def abort_request() -> None:
183
+ print("trying to abort but not implemented")
184
+
185
+ background_tasks = BackgroundTasks()
186
+ background_tasks.add_task(release_worker_semaphore)
187
+ background_tasks.add_task(abort_request)
188
+ return background_tasks
189
+
190
+
191
+ @app.post("/worker_generate_stream")
192
+ async def api_generate_stream(request: Request):
193
+ params = await request.json()
194
+ await acquire_worker_semaphore()
195
+ request_id = uuid.uuid4()
196
+ params["request_id"] = str(request_id)
197
+ generator = worker.generate_stream(params)
198
+ background_tasks = create_background_tasks(request_id)
199
+ return StreamingResponse(generator, background=background_tasks)
200
+
201
+
202
+ @app.post("/worker_generate")
203
+ async def api_generate(request: Request):
204
+ params = await request.json()
205
+ await acquire_worker_semaphore()
206
+ request_id = uuid.uuid4()
207
+ params["request_id"] = str(request_id)
208
+ output = await worker.generate(params)
209
+ release_worker_semaphore()
210
+ # await engine.abort(request_id)
211
+ print("Trying to abort but not implemented")
212
+ return JSONResponse(output)
213
+
214
+
215
+ @app.post("/worker_get_status")
216
+ async def api_get_status(request: Request):
217
+ return worker.get_status()
218
+
219
+
220
+ @app.post("/count_token")
221
+ async def api_count_token(request: Request):
222
+ params = await request.json()
223
+ return worker.count_token(params)
224
+
225
+
226
+ @app.post("/worker_get_conv_template")
227
+ async def api_get_conv(request: Request):
228
+ return worker.get_conv_template()
229
+
230
+
231
+ @app.post("/model_details")
232
+ async def api_model_details(request: Request):
233
+ return {"context_length": worker.context_len}
234
+
235
+
236
+ worker = None
237
+
238
+
239
+ def cleanup_at_exit():
240
+ global worker
241
+ print("Cleaning up...")
242
+ del worker
243
+
244
+
245
+ atexit.register(cleanup_at_exit)
246
+
247
+ if __name__ == "__main__":
248
+ parser = argparse.ArgumentParser()
249
+ parser.add_argument("--host", type=str, default="localhost")
250
+ parser.add_argument("--port", type=int, default=21002)
251
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
252
+ parser.add_argument(
253
+ "--controller-address", type=str, default="http://localhost:21001"
254
+ )
255
+ parser.add_argument("--model-path", type=str, default="microsoft/phi-2")
256
+ parser.add_argument(
257
+ "--model-names",
258
+ type=lambda s: s.split(","),
259
+ help="Optional display comma separated names",
260
+ )
261
+ parser.add_argument(
262
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
263
+ )
264
+ parser.add_argument(
265
+ "--trust_remote_code",
266
+ action="store_false",
267
+ default=True,
268
+ help="Trust remote code (e.g., from HuggingFace) when"
269
+ "downloading the model and tokenizer.",
270
+ )
271
+
272
+ args, unknown = parser.parse_known_args()
273
+
274
+ if args.model_path:
275
+ args.model = args.model_path
276
+
277
+ worker = MLXWorker(
278
+ args.controller_address,
279
+ args.worker_address,
280
+ worker_id,
281
+ args.model_path,
282
+ args.model_names,
283
+ 1024,
284
+ False,
285
+ "MLX",
286
+ args.conv_template,
287
+ )
288
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
model_worker.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker that executes the model.
3
+ """
4
+ import argparse
5
+ import base64
6
+ import gc
7
+ import json
8
+ import os
9
+ from typing import List, Optional
10
+ import uuid
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from transformers import set_seed
15
+ import uvicorn
16
+
17
+ from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
18
+ from fastchat.model.model_adapter import (
19
+ load_model,
20
+ add_model_args,
21
+ get_generate_stream_function,
22
+ )
23
+ from fastchat.modules.awq import AWQConfig
24
+ from fastchat.modules.exllama import ExllamaConfig
25
+ from fastchat.modules.xfastertransformer import XftConfig
26
+ from fastchat.modules.gptq import GptqConfig
27
+ from fastchat.serve.base_model_worker import BaseModelWorker, app
28
+ from fastchat.utils import (
29
+ build_logger,
30
+ get_context_length,
31
+ str_to_torch_dtype,
32
+ )
33
+
34
+ worker_id = str(uuid.uuid4())[:8]
35
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
36
+
37
+
38
+ class ModelWorker(BaseModelWorker):
39
+ def __init__(
40
+ self,
41
+ controller_addr: str,
42
+ worker_addr: str,
43
+ worker_id: str,
44
+ model_path: str,
45
+ model_names: List[str],
46
+ limit_worker_concurrency: int,
47
+ no_register: bool,
48
+ device: str,
49
+ num_gpus: int,
50
+ max_gpu_memory: str,
51
+ revision: str = None,
52
+ dtype: Optional[torch.dtype] = None,
53
+ load_8bit: bool = False,
54
+ cpu_offloading: bool = False,
55
+ gptq_config: Optional[GptqConfig] = None,
56
+ awq_config: Optional[AWQConfig] = None,
57
+ exllama_config: Optional[ExllamaConfig] = None,
58
+ xft_config: Optional[XftConfig] = None,
59
+ stream_interval: int = 2,
60
+ conv_template: Optional[str] = None,
61
+ embed_in_truncate: bool = False,
62
+ seed: Optional[int] = None,
63
+ debug: bool = False,
64
+ **kwargs,
65
+ ):
66
+ super().__init__(
67
+ controller_addr,
68
+ worker_addr,
69
+ worker_id,
70
+ model_path,
71
+ model_names,
72
+ limit_worker_concurrency,
73
+ conv_template=conv_template,
74
+ )
75
+
76
+ logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
77
+ self.model, self.tokenizer = load_model(
78
+ model_path,
79
+ revision=revision,
80
+ device=device,
81
+ num_gpus=num_gpus,
82
+ max_gpu_memory=max_gpu_memory,
83
+ dtype=dtype,
84
+ load_8bit=load_8bit,
85
+ cpu_offloading=cpu_offloading,
86
+ gptq_config=gptq_config,
87
+ awq_config=awq_config,
88
+ exllama_config=exllama_config,
89
+ xft_config=xft_config,
90
+ debug=debug,
91
+ )
92
+ self.device = device
93
+ if self.tokenizer.pad_token == None:
94
+ self.tokenizer.pad_token = self.tokenizer.eos_token
95
+ self.context_len = get_context_length(self.model.config)
96
+ self.generate_stream_func = get_generate_stream_function(self.model, model_path)
97
+ self.stream_interval = stream_interval
98
+ self.embed_in_truncate = embed_in_truncate
99
+ self.seed = seed
100
+
101
+ if not no_register:
102
+ self.init_heart_beat()
103
+
104
+ def generate_stream_gate(self, params):
105
+ if self.device == "npu":
106
+ import torch_npu
107
+
108
+ torch_npu.npu.set_device("npu:0")
109
+ self.call_ct += 1
110
+
111
+ try:
112
+ if self.seed is not None:
113
+ set_seed(self.seed)
114
+ for output in self.generate_stream_func(
115
+ self.model,
116
+ self.tokenizer,
117
+ params,
118
+ self.device,
119
+ self.context_len,
120
+ self.stream_interval,
121
+ ):
122
+ ret = {
123
+ "text": output["text"],
124
+ "error_code": 0,
125
+ }
126
+ if "usage" in output:
127
+ ret["usage"] = output["usage"]
128
+ if "finish_reason" in output:
129
+ ret["finish_reason"] = output["finish_reason"]
130
+ if "logprobs" in output:
131
+ ret["logprobs"] = output["logprobs"]
132
+ yield json.dumps(ret).encode() + b"\0"
133
+ except torch.cuda.OutOfMemoryError as e:
134
+ ret = {
135
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
136
+ "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
137
+ }
138
+ yield json.dumps(ret).encode() + b"\0"
139
+ except (ValueError, RuntimeError) as e:
140
+ ret = {
141
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
142
+ "error_code": ErrorCode.INTERNAL_ERROR,
143
+ }
144
+ yield json.dumps(ret).encode() + b"\0"
145
+
146
+ def generate_gate(self, params):
147
+ for x in self.generate_stream_gate(params):
148
+ pass
149
+ return json.loads(x[:-1].decode())
150
+
151
+ def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
152
+ if model_type_dict.get("is_bert"):
153
+ model_output = self.model(input_ids)
154
+ if model_type_dict.get("is_robert"):
155
+ data = model_output.last_hidden_state
156
+ else:
157
+ data = model_output[0]
158
+ elif model_type_dict.get("is_t5"):
159
+ model_output = self.model(input_ids, decoder_input_ids=input_ids)
160
+ data = model_output.encoder_last_hidden_state
161
+ else:
162
+ model_output = self.model(input_ids, output_hidden_states=True)
163
+ if model_type_dict.get("is_chatglm"):
164
+ data = model_output.hidden_states[-1].transpose(0, 1)
165
+ else:
166
+ data = model_output.hidden_states[-1]
167
+
168
+ if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling:
169
+ sum_embeddings = data[:, 0]
170
+ else:
171
+ mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
172
+ masked_embeddings = data * mask
173
+ sum_embeddings = torch.sum(masked_embeddings, dim=1)
174
+ token_num = torch.sum(attention_mask).item()
175
+
176
+ return sum_embeddings, token_num
177
+
178
+ def __encode_base64(self, embeddings: torch.Tensor) -> List[str]:
179
+ embeddings = embeddings.cpu()
180
+ return [
181
+ base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings
182
+ ]
183
+
184
+ @torch.inference_mode()
185
+ def get_embeddings(self, params):
186
+ self.call_ct += 1
187
+
188
+ try:
189
+ tokenizer = self.tokenizer
190
+ ret = {"embedding": [], "token_num": 0}
191
+
192
+ model_type_dict = {
193
+ "is_llama": "llama" in str(type(self.model)),
194
+ "is_t5": "t5" in str(type(self.model)),
195
+ "is_chatglm": "chatglm" in str(type(self.model)),
196
+ "is_bert": "bert" in str(type(self.model)),
197
+ "is_robert": "robert" in str(type(self.model)),
198
+ }
199
+
200
+ if self.embed_in_truncate:
201
+ encoding = tokenizer.batch_encode_plus(
202
+ params["input"],
203
+ padding=True,
204
+ truncation="longest_first",
205
+ return_tensors="pt",
206
+ max_length=self.context_len,
207
+ )
208
+ else:
209
+ encoding = tokenizer.batch_encode_plus(
210
+ params["input"], padding=True, return_tensors="pt"
211
+ )
212
+ input_ids = encoding["input_ids"].to(self.device)
213
+ attention_mask = input_ids != tokenizer.pad_token_id
214
+
215
+ base64_encode = params.get("encoding_format", None)
216
+
217
+ if self.embed_in_truncate:
218
+ embedding, token_num = self.__process_embed_chunk(
219
+ input_ids, attention_mask, **model_type_dict
220
+ )
221
+ if (
222
+ not hasattr(self.model, "use_cls_pooling")
223
+ or not self.model.use_cls_pooling
224
+ ):
225
+ embedding = embedding / token_num
226
+ normalized_embeddings = F.normalize(embedding, p=2, dim=1)
227
+ ret["token_num"] = token_num
228
+ else:
229
+ all_embeddings = []
230
+ all_token_num = 0
231
+ for i in range(0, input_ids.size(1), self.context_len):
232
+ chunk_input_ids = input_ids[:, i : i + self.context_len]
233
+ chunk_attention_mask = attention_mask[:, i : i + self.context_len]
234
+
235
+ # add cls token and mask to get cls embedding
236
+ if (
237
+ hasattr(self.model, "use_cls_pooling")
238
+ and self.model.use_cls_pooling
239
+ ):
240
+ cls_tokens = (
241
+ torch.zeros(
242
+ (chunk_input_ids.size(0), 1),
243
+ dtype=chunk_input_ids.dtype,
244
+ device=chunk_input_ids.device,
245
+ )
246
+ + tokenizer.cls_token_id
247
+ )
248
+ chunk_input_ids = torch.cat(
249
+ [cls_tokens, chunk_input_ids], dim=-1
250
+ )
251
+ mask = torch.ones(
252
+ (chunk_attention_mask.size(0), 1),
253
+ dtype=chunk_attention_mask.dtype,
254
+ device=chunk_attention_mask.device,
255
+ )
256
+ chunk_attention_mask = torch.cat(
257
+ [mask, chunk_attention_mask], dim=-1
258
+ )
259
+
260
+ chunk_embeddings, token_num = self.__process_embed_chunk(
261
+ chunk_input_ids, chunk_attention_mask, **model_type_dict
262
+ )
263
+ if (
264
+ hasattr(self.model, "use_cls_pooling")
265
+ and self.model.use_cls_pooling
266
+ ):
267
+ all_embeddings.append(chunk_embeddings * token_num)
268
+ else:
269
+ all_embeddings.append(chunk_embeddings)
270
+ all_token_num += token_num
271
+
272
+ all_embeddings_tensor = torch.stack(all_embeddings)
273
+ embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
274
+ normalized_embeddings = F.normalize(embedding, p=2, dim=1)
275
+
276
+ ret["token_num"] = all_token_num
277
+
278
+ if base64_encode == "base64":
279
+ out_embeddings = self.__encode_base64(normalized_embeddings)
280
+ else:
281
+ out_embeddings = normalized_embeddings.tolist()
282
+ ret["embedding"] = out_embeddings
283
+
284
+ gc.collect()
285
+ torch.cuda.empty_cache()
286
+ if self.device == "xpu":
287
+ torch.xpu.empty_cache()
288
+ if self.device == "npu":
289
+ torch.npu.empty_cache()
290
+ except torch.cuda.OutOfMemoryError as e:
291
+ ret = {
292
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
293
+ "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
294
+ }
295
+ except (ValueError, RuntimeError) as e:
296
+ ret = {
297
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
298
+ "error_code": ErrorCode.INTERNAL_ERROR,
299
+ }
300
+ return ret
301
+
302
+
303
+ def create_model_worker():
304
+ parser = argparse.ArgumentParser()
305
+ parser.add_argument("--host", type=str, default="localhost")
306
+ parser.add_argument("--port", type=int, default=21002)
307
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
308
+ parser.add_argument(
309
+ "--controller-address", type=str, default="http://localhost:21001"
310
+ )
311
+ add_model_args(parser)
312
+ parser.add_argument(
313
+ "--model-names",
314
+ type=lambda s: s.split(","),
315
+ help="Optional display comma separated names",
316
+ )
317
+ parser.add_argument(
318
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
319
+ )
320
+ parser.add_argument("--embed-in-truncate", action="store_true")
321
+ parser.add_argument(
322
+ "--limit-worker-concurrency",
323
+ type=int,
324
+ default=5,
325
+ help="Limit the model concurrency to prevent OOM.",
326
+ )
327
+ parser.add_argument("--stream-interval", type=int, default=2)
328
+ parser.add_argument("--no-register", action="store_true")
329
+ parser.add_argument(
330
+ "--seed",
331
+ type=int,
332
+ default=None,
333
+ help="Overwrite the random seed for each generation.",
334
+ )
335
+ parser.add_argument(
336
+ "--debug", type=bool, default=False, help="Print debugging messages"
337
+ )
338
+ parser.add_argument(
339
+ "--ssl",
340
+ action="store_true",
341
+ required=False,
342
+ default=False,
343
+ help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
344
+ )
345
+ args = parser.parse_args()
346
+ logger.info(f"args: {args}")
347
+
348
+ if args.gpus:
349
+ if len(args.gpus.split(",")) < args.num_gpus:
350
+ raise ValueError(
351
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
352
+ )
353
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
354
+
355
+ gptq_config = GptqConfig(
356
+ ckpt=args.gptq_ckpt or args.model_path,
357
+ wbits=args.gptq_wbits,
358
+ groupsize=args.gptq_groupsize,
359
+ act_order=args.gptq_act_order,
360
+ )
361
+ awq_config = AWQConfig(
362
+ ckpt=args.awq_ckpt or args.model_path,
363
+ wbits=args.awq_wbits,
364
+ groupsize=args.awq_groupsize,
365
+ )
366
+ if args.enable_exllama:
367
+ exllama_config = ExllamaConfig(
368
+ max_seq_len=args.exllama_max_seq_len,
369
+ gpu_split=args.exllama_gpu_split,
370
+ cache_8bit=args.exllama_cache_8bit,
371
+ )
372
+ else:
373
+ exllama_config = None
374
+ if args.enable_xft:
375
+ xft_config = XftConfig(
376
+ max_seq_len=args.xft_max_seq_len,
377
+ data_type=args.xft_dtype,
378
+ )
379
+ if args.device != "cpu":
380
+ print("xFasterTransformer now is only support CPUs. Reset device to CPU")
381
+ args.device = "cpu"
382
+ else:
383
+ xft_config = None
384
+
385
+ worker = ModelWorker(
386
+ args.controller_address,
387
+ args.worker_address,
388
+ worker_id,
389
+ args.model_path,
390
+ args.model_names,
391
+ args.limit_worker_concurrency,
392
+ revision=args.revision,
393
+ no_register=args.no_register,
394
+ device=args.device,
395
+ num_gpus=args.num_gpus,
396
+ max_gpu_memory=args.max_gpu_memory,
397
+ dtype=str_to_torch_dtype(args.dtype),
398
+ load_8bit=args.load_8bit,
399
+ cpu_offloading=args.cpu_offloading,
400
+ gptq_config=gptq_config,
401
+ awq_config=awq_config,
402
+ exllama_config=exllama_config,
403
+ xft_config=xft_config,
404
+ stream_interval=args.stream_interval,
405
+ conv_template=args.conv_template,
406
+ embed_in_truncate=args.embed_in_truncate,
407
+ seed=args.seed,
408
+ debug=args.debug,
409
+ )
410
+ return args, worker
411
+
412
+
413
+ if __name__ == "__main__":
414
+ args, worker = create_model_worker()
415
+ if args.ssl:
416
+ uvicorn.run(
417
+ app,
418
+ host=args.host,
419
+ port=args.port,
420
+ log_level="info",
421
+ ssl_keyfile=os.environ["SSL_KEYFILE"],
422
+ ssl_certfile=os.environ["SSL_CERTFILE"],
423
+ )
424
+ else:
425
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
monitor/__pycache__/basic_stats.cpython-39.pyc ADDED
Binary file (6.1 kB). View file
 
monitor/__pycache__/clean_battle_data.cpython-39.pyc ADDED
Binary file (7 kB). View file
 
monitor/__pycache__/clean_chat_data.cpython-39.pyc ADDED
Binary file (4.36 kB). View file
 
monitor/__pycache__/elo_analysis.cpython-39.pyc ADDED
Binary file (9.21 kB). View file
 
monitor/__pycache__/inspect_conv.cpython-39.pyc ADDED
Binary file (2.16 kB). View file