Upload folder using huggingface_hub
This view is limited to 50 files because the commit contains too many changes.
- README.md +2 -8
- __init__.py +0 -0
- __pycache__/__init__.cpython-39.pyc +0 -0
- __pycache__/api_provider.cpython-39.pyc +0 -0
- __pycache__/base_model_worker.cpython-39.pyc +0 -0
- __pycache__/call_monitor.cpython-39.pyc +0 -0
- __pycache__/cli.cpython-39.pyc +0 -0
- __pycache__/controller.cpython-39.pyc +0 -0
- __pycache__/gradio_block_arena_anony.cpython-39.pyc +0 -0
- __pycache__/gradio_block_arena_named.cpython-39.pyc +0 -0
- __pycache__/gradio_block_arena_vision.cpython-39.pyc +0 -0
- __pycache__/gradio_web_server.cpython-39.pyc +0 -0
- __pycache__/gradio_web_server_multi.cpython-39.pyc +0 -0
- __pycache__/huggingface_api.cpython-39.pyc +0 -0
- __pycache__/huggingface_api_worker.cpython-39.pyc +0 -0
- __pycache__/inference.cpython-39.pyc +0 -0
- __pycache__/launch_all_serve.cpython-39.pyc +0 -0
- __pycache__/lightllm_worker.cpython-39.pyc +0 -0
- __pycache__/mlx_worker.cpython-39.pyc +0 -0
- __pycache__/model_worker.cpython-39.pyc +0 -0
- __pycache__/multi_model_worker.cpython-39.pyc +0 -0
- __pycache__/openai_api_server.cpython-39.pyc +0 -0
- __pycache__/register_worker.cpython-39.pyc +0 -0
- __pycache__/sglang_worker.cpython-39.pyc +0 -0
- __pycache__/shutdown_serve.cpython-39.pyc +0 -0
- __pycache__/test_message.cpython-39.pyc +0 -0
- __pycache__/test_throughput.cpython-39.pyc +0 -0
- __pycache__/vllm_worker.cpython-39.pyc +0 -0
- api_provider.py +454 -0
- base_model_worker.py +241 -0
- call_monitor.py +219 -0
- cli.py +304 -0
- controller.py +389 -0
- gradio_block_arena_anony.py +811 -0
- gradio_block_arena_named.py +469 -0
- gradio_block_arena_vision.py +187 -0
- gradio_web_server.py +887 -0
- gradio_web_server_multi.py +277 -0
- huggingface_api.py +73 -0
- huggingface_api_worker.py +415 -0
- inference.py +555 -0
- launch_all_serve.py +284 -0
- lightllm_worker.py +512 -0
- mlx_worker.py +288 -0
- model_worker.py +425 -0
- monitor/__pycache__/basic_stats.cpython-39.pyc +0 -0
- monitor/__pycache__/clean_battle_data.cpython-39.pyc +0 -0
- monitor/__pycache__/clean_chat_data.cpython-39.pyc +0 -0
- monitor/__pycache__/elo_analysis.cpython-39.pyc +0 -0
- monitor/__pycache__/inspect_conv.cpython-39.pyc +0 -0
README.md
CHANGED
@@ -1,12 +1,6 @@
 ---
-title:
-
-colorFrom: pink
-colorTo: green
+title: test_fastchat
+app_file: gradio_web_server.py
 sdk: gradio
 sdk_version: 4.27.0
-app_file: app.py
-pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py
ADDED
File without changes

The following compiled bytecode caches were also added as binary files:

- __pycache__/__init__.cpython-39.pyc (190 Bytes)
- __pycache__/api_provider.cpython-39.pyc (7.84 kB)
- __pycache__/base_model_worker.cpython-39.pyc (7.09 kB)
- __pycache__/call_monitor.cpython-39.pyc (6.5 kB)
- __pycache__/cli.cpython-39.pyc (8.98 kB)
- __pycache__/controller.cpython-39.pyc (10.3 kB)
- __pycache__/gradio_block_arena_anony.cpython-39.pyc (15.5 kB)
- __pycache__/gradio_block_arena_named.cpython-39.pyc (11.1 kB)
- __pycache__/gradio_block_arena_vision.cpython-39.pyc (4.25 kB)
- __pycache__/gradio_web_server.cpython-39.pyc (21.9 kB)
- __pycache__/gradio_web_server_multi.cpython-39.pyc (6.25 kB)
- __pycache__/huggingface_api.cpython-39.pyc (1.98 kB)
- __pycache__/huggingface_api_worker.cpython-39.pyc (11 kB)
- __pycache__/inference.cpython-39.pyc (10.7 kB)
- __pycache__/launch_all_serve.cpython-39.pyc (6.34 kB)
- __pycache__/lightllm_worker.cpython-39.pyc (13 kB)
- __pycache__/mlx_worker.cpython-39.pyc (7.55 kB)
- __pycache__/model_worker.cpython-39.pyc (10 kB)
- __pycache__/multi_model_worker.cpython-39.pyc (8.78 kB)
- __pycache__/openai_api_server.cpython-39.pyc (21.6 kB)
- __pycache__/register_worker.cpython-39.pyc (914 Bytes)
- __pycache__/sglang_worker.cpython-39.pyc (8.6 kB)
- __pycache__/shutdown_serve.cpython-39.pyc (923 Bytes)
- __pycache__/test_message.cpython-39.pyc (2.11 kB)
- __pycache__/test_throughput.cpython-39.pyc (3.11 kB)
- __pycache__/vllm_worker.cpython-39.pyc (8.58 kB)
api_provider.py
ADDED
@@ -0,0 +1,454 @@
"""Call API providers."""

import json
import os
import random
import time

import requests

from fastchat.utils import build_logger


logger = build_logger("gradio_web_server", "gradio_web_server.log")


def get_api_provider_stream_iter(
    conv,
    model_name,
    model_api_dict,
    temperature,
    top_p,
    max_new_tokens,
):
    if model_api_dict["api_type"] == "openai":
        prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "anthropic":
        prompt = conv.get_prompt()
        stream_iter = anthropic_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_api_dict["api_type"] == "gemini":
        stream_iter = gemini_api_stream_iter(
            model_api_dict["model_name"],
            conv,
            temperature,
            top_p,
            max_new_tokens,
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "bard":
        prompt = conv.to_openai_api_messages()
        stream_iter = bard_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            api_key=model_api_dict["api_key"],
        )
    elif model_api_dict["api_type"] == "mistral":
        prompt = conv.to_openai_api_messages()
        stream_iter = mistral_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_api_dict["api_type"] == "nvidia":
        prompt = conv.to_openai_api_messages()
        stream_iter = nvidia_api_stream_iter(
            model_name,
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            model_api_dict["api_base"],
        )
    elif model_api_dict["api_type"] == "ai2":
        prompt = conv.to_openai_api_messages()
        stream_iter = ai2_api_stream_iter(
            model_name,
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )
    else:
        raise NotImplementedError()

    return stream_iter


def openai_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
):
    import openai

    api_key = api_key or os.environ["OPENAI_API_KEY"]

    if "azure" in model_name:
        client = openai.AzureOpenAI(
            api_version="2023-07-01-preview",
            azure_endpoint=api_base or "https://api.openai.com/v1",
            api_key=api_key,
        )
    else:
        client = openai.OpenAI(
            base_url=api_base or "https://api.openai.com/v1", api_key=api_key
        )

    if model_name == "gpt-4-turbo":
        model_name = "gpt-4-1106-preview"

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        max_tokens=max_new_tokens,
        stream=True,
    )
    text = ""
    for chunk in res:
        if len(chunk.choices) > 0:
            text += chunk.choices[0].delta.content or ""
        data = {
            "text": text,
            "error_code": 0,
        }
        yield data


def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
    import anthropic

    c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = c.completions.create(
        prompt=prompt,
        stop_sequences=[anthropic.HUMAN_PROMPT],
        max_tokens_to_sample=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        model=model_name,
        stream=True,
    )
    text = ""
    for chunk in res:
        text += chunk.completion
        data = {
            "text": text,
            "error_code": 0,
        }
        yield data


def gemini_api_stream_iter(
    model_name, conv, temperature, top_p, max_new_tokens, api_key=None
):
    import google.generativeai as genai  # pip install google-generativeai

    if api_key is None:
        api_key = os.environ["GEMINI_API_KEY"]
    genai.configure(api_key=api_key)

    generation_config = {
        "temperature": temperature,
        "max_output_tokens": max_new_tokens,
        "top_p": top_p,
    }
    params = {
        "model": model_name,
        "prompt": conv,
    }
    params.update(generation_config)
    logger.info(f"==== request ====\n{params}")

    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
    ]
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    history = []
    for role, message in conv.messages[:-2]:
        history.append({"role": role, "parts": message})
    convo = model.start_chat(history=history)
    response = convo.send_message(conv.messages[-2][1], stream=True)

    try:
        text = ""
        for chunk in response:
            text += chunk.text
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        reason = chunk.candidates
        yield {
            "text": f"**API REQUEST ERROR** Reason: {reason}.",
            "error_code": 1,
        }


def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None):
    del top_p  # not supported
    del temperature  # not supported

    if api_key is None:
        api_key = os.environ["BARD_API_KEY"]

    # convert conv to conv_bard
    conv_bard = []
    for turn in conv:
        if turn["role"] == "user":
            conv_bard.append({"author": "0", "content": turn["content"]})
        elif turn["role"] == "assistant":
            conv_bard.append({"author": "1", "content": turn["content"]})
        else:
            raise ValueError(f"Unsupported role: {turn['role']}")

    params = {
        "model": model_name,
        "prompt": conv_bard,
    }
    logger.info(f"==== request ====\n{params}")

    try:
        res = requests.post(
            f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}",
            json={
                "prompt": {
                    "messages": conv_bard,
                },
            },
            timeout=30,
        )
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}.",
            "error_code": 1,
        }

    if res.status_code != 200:
        logger.error(f"==== error ==== ({res.status_code}): {res.text}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.",
            "error_code": 1,
        }

    response_json = res.json()
    if "candidates" not in response_json:
        logger.error(f"==== error ==== response blocked: {response_json}")
        reason = response_json["filters"][0]["reason"]
        yield {
            "text": f"**API REQUEST ERROR** Reason: {reason}.",
            "error_code": 1,
        }

    response = response_json["candidates"][0]["content"]
    pos = 0
    while pos < len(response):
        # simulate token streaming
        pos += random.randint(3, 6)
        time.sleep(0.002)
        data = {
            "text": response[:pos],
            "error_code": 0,
        }
        yield data


def ai2_api_stream_iter(
    model_name,
    model_id,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    # get keys and needed values
    ai2_key = api_key or os.environ.get("AI2_API_KEY")
    api_base = api_base or "https://inferd.allen.ai/api/v1/infer"

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
    # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
    if temperature == 0.0 and top_p < 1.0:
        raise ValueError("top_p must be 1 when temperature is 0.0")

    res = requests.post(
        api_base,
        stream=True,
        headers={"Authorization": f"Bearer {ai2_key}"},
        json={
            "model_id": model_id,
            # This input format is specific to the Tulu2 model. Other models
            # may require different input formats. See the model's schema
            # documentation on InferD for more information.
            "input": {
                "messages": messages,
                "opts": {
                    "max_tokens": max_new_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    "logprobs": 1,  # increase for more choices
                },
            },
        },
        timeout=5,
    )

    if res.status_code != 200:
        logger.error(f"unexpected response ({res.status_code}): {res.text}")
        raise ValueError("unexpected response from InferD", res)

    text = ""
    for line in res.iter_lines():
        if line:
            part = json.loads(line)
            if "result" in part and "output" in part["result"]:
                for t in part["result"]["output"]["text"]:
                    text += t
            else:
                logger.error(f"unexpected part: {part}")
                raise ValueError("empty result in InferD response")

            data = {
                "text": text,
                "error_code": 0,
            }
            yield data


def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage

    api_key = os.environ["MISTRAL_API_KEY"]

    client = MistralClient(api_key=api_key)

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    new_messages = [
        ChatMessage(role=message["role"], content=message["content"])
        for message in messages
    ]

    res = client.chat_stream(
        model=model_name,
        temperature=temperature,
        messages=new_messages,
        max_tokens=max_new_tokens,
        top_p=top_p,
    )

    text = ""
    for chunk in res:
        if chunk.choices[0].delta.content is not None:
            text += chunk.choices[0].delta.content
        data = {
            "text": text,
            "error_code": 0,
        }
        yield data


def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base):
    assert model_name in ["llama2-70b-steerlm-chat", "yi-34b-chat"]

    api_key = os.environ["NVIDIA_API_KEY"]
    headers = {
        "Authorization": f"Bearer {api_key}",
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    # nvidia api does not accept 0 temperature
    if temp == 0.0:
        temp = 0.0001

    payload = {
        "messages": messages,
        "temperature": temp,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "seed": 42,
        "stream": True,
    }
    logger.info(f"==== request ====\n{payload}")

    response = requests.post(
        api_base, headers=headers, json=payload, stream=True, timeout=1
    )
    text = ""
    for line in response.iter_lines():
        if line:
            data = line.decode("utf-8")
            if data.endswith("[DONE]"):
                break
            data = json.loads(data[6:])["choices"][0]["delta"]["content"]
            text += data
            yield {"text": text, "error_code": 0}
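For context, here is a minimal consumption sketch (not part of the diff) for the OpenAI-style iterator above. It assumes the file is importable as a local api_provider module in this Space, that OPENAI_API_KEY is set, and that the model name and prompt are purely illustrative.

# sketch only: consume the streaming iterator defined in api_provider.py
from api_provider import openai_api_stream_iter  # assumed local import path

messages = [{"role": "user", "content": "Say hello in one sentence."}]
stream = openai_api_stream_iter(
    model_name="gpt-3.5-turbo",  # illustrative model name
    messages=messages,
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=64,
)
for chunk in stream:
    # each yielded dict carries the accumulated text so far and an error code
    print(chunk["error_code"], chunk["text"])

Each provider-specific iterator follows the same contract, so the Gradio server can render partial text from any backend by reading only the "text" and "error_code" keys.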
base_model_worker.py
ADDED
@@ -0,0 +1,241 @@
import asyncio
import threading
import time
from typing import List

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import requests

from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
from fastchat.conversation import Conversation
from fastchat.utils import pretty_print_semaphore, build_logger


worker = None
logger = None

app = FastAPI()


def heart_beat_worker(obj):
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        obj.send_heart_beat()


class BaseModelWorker:
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,
        multimodal: bool = False,
    ):
        global logger, worker

        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_names = model_names or [model_path.split("/")[-1]]
        self.limit_worker_concurrency = limit_worker_concurrency
        self.conv = self.make_conv_template(conv_template, model_path)
        self.conv.sep_style = int(self.conv.sep_style)
        self.multimodal = multimodal
        self.tokenizer = None
        self.context_len = None
        self.call_ct = 0
        self.semaphore = None

        self.heart_beat_thread = None

        if logger is None:
            logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log")
        if worker is None:
            worker = self

    def make_conv_template(
        self,
        conv_template: str = None,
        model_path: str = None,
    ) -> Conversation:
        """
        Can be overridden to customize the conversation template for different model workers.
        """
        from fastchat.conversation import get_conv_template
        from fastchat.model.model_adapter import get_conversation_template

        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        return conv

    def init_heart_beat(self):
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=heart_beat_worker,
            args=(self,),
            daemon=True,
        )
        self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
            "multimodal": self.multimodal,
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        logger.info(
            f"Send heart beat. Models: {self.model_names}. "
            f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
            f"call_ct: {self.call_ct}. "
            f"worker_id: {self.worker_id}. "
        )

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(
                    url,
                    json={
                        "worker_name": self.worker_addr,
                        "queue_length": self.get_queue_length(),
                    },
                    timeout=5,
                )
                exist = ret.json()["exist"]
                break
            except (requests.exceptions.RequestException, KeyError) as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if self.semaphore is None:
            return 0
        else:
            sempahore_value = (
                self.semaphore._value
                if self.semaphore._value is not None
                else self.limit_worker_concurrency
            )
            waiter_count = (
                0 if self.semaphore._waiters is None else len(self.semaphore._waiters)
            )
            return self.limit_worker_concurrency - sempahore_value + waiter_count

    def get_status(self):
        return {
            "model_names": self.model_names,
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    def count_token(self, params):
        prompt = params["prompt"]

        try:
            input_ids = self.tokenizer(prompt).input_ids
            input_echo_len = len(input_ids)
        except TypeError:
            input_echo_len = self.tokenizer.num_tokens(prompt)

        ret = {
            "count": input_echo_len,
            "error_code": 0,
        }
        return ret

    def get_conv_template(self):
        return {"conv": self.conv}

    def generate_stream_gate(self, params):
        raise NotImplementedError

    def generate_gate(self, params):
        raise NotImplementedError

    def get_embeddings(self, params):
        raise NotImplementedError


def release_worker_semaphore():
    worker.semaphore.release()


def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks():
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_worker_semaphore)
    return background_tasks


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    generator = worker.generate_stream_gate(params)
    background_tasks = create_background_tasks()
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    output = await asyncio.to_thread(worker.generate_gate, params)
    release_worker_semaphore()
    return JSONResponse(output)


@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    embedding = worker.get_embeddings(params)
    release_worker_semaphore()
    return JSONResponse(content=embedding)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    return {"context_length": worker.context_len}
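To make the extension points concrete, here is a hypothetical minimal subclass (not part of the diff). The "echo" behaviour, the addresses, the placeholder model name, and the local import path are all assumptions for illustration; a real worker would implement actual inference in these methods.

# sketch only: a toy worker built on the abstract hooks above
import json
import uuid

from base_model_worker import BaseModelWorker  # assumed local import path


class EchoWorker(BaseModelWorker):
    """Streams the prompt straight back instead of running a model."""

    def generate_stream_gate(self, params):
        self.call_ct += 1
        ret = {"text": params["prompt"], "error_code": 0}
        # concrete workers in FastChat stream null-terminated JSON chunks
        yield json.dumps(ret).encode() + b"\0"

    def generate_gate(self, params):
        return {"text": params["prompt"], "error_code": 0}


# assumed addresses and placeholder model name, for illustration only
echo_worker = EchoWorker(
    controller_addr="http://localhost:21001",
    worker_addr="http://localhost:21002",
    worker_id=str(uuid.uuid4())[:8],
    model_path="echo-model",
    model_names=["echo-model"],
    limit_worker_concurrency=5,
)
# echo_worker.init_heart_beat() would register it with a running controller,
# after which `app` can be served with uvicorn to expose the endpoints above.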
call_monitor.py
ADDED
@@ -0,0 +1,219 @@
import json
import os
import glob
import time

from fastapi import FastAPI
import hashlib
import asyncio

REFRESH_INTERVAL_SEC = 60
LOG_DIR = "/home/vicuna/fastchat_logs/server0"
# LOG_DIR = "/home/vicuna/tmp/test_env"


class Monitor:
    """Monitor the number of calls to each model."""

    def __init__(self, log_dir: str):
        self.log_dir = log_dir
        self.model_call = {}
        self.user_call = {}
        self.model_call_limit_global = {
            "gpt-4-1106-preview": 300,
            "gpt-4-0125-preview": 300,
        }
        self.model_call_day_limit_per_user = {"gpt-4-1106-preview": 10}

    async def update_stats(self, num_file=1) -> None:
        while True:
            # find the latest num_file log under log_dir
            json_files = glob.glob(os.path.join(self.log_dir, "*.json"))
            json_files.sort(key=os.path.getctime, reverse=True)
            json_files = json_files[:num_file]

            model_call = {}
            user_call = {}
            for json_file in json_files:
                for line in open(json_file, "r", encoding="utf-8"):
                    obj = json.loads(line)
                    if obj["type"] != "chat":
                        continue
                    if obj["model"] not in model_call:
                        model_call[obj["model"]] = []
                    model_call[obj["model"]].append(
                        {"tstamp": obj["tstamp"], "user_id": obj["ip"]}
                    )
                    if obj["ip"] not in user_call:
                        user_call[obj["ip"]] = []
                    user_call[obj["ip"]].append(
                        {"tstamp": obj["tstamp"], "model": obj["model"]}
                    )

            self.model_call = model_call
            self.model_call_stats_hour = self.get_model_call_stats(top_k=None)
            self.model_call_stats_day = self.get_model_call_stats(
                top_k=None, most_recent_min=24 * 60
            )

            self.user_call = user_call
            self.user_call_stats_hour = self.get_user_call_stats(top_k=None)
            self.user_call_stats_day = self.get_user_call_stats(
                top_k=None, most_recent_min=24 * 60
            )
            await asyncio.sleep(REFRESH_INTERVAL_SEC)

    def get_model_call_limit(self, model: str) -> int:
        if model not in self.model_call_limit_global:
            return -1
        return self.model_call_limit_global[model]

    def update_model_call_limit(self, model: str, limit: int) -> bool:
        if model not in self.model_call_limit_global:
            return False
        self.model_call_limit_global[model] = limit
        return True

    def is_model_limit_reached(self, model: str) -> bool:
        if model not in self.model_call_limit_global:
            return False
        if model not in self.model_call_stats_hour:
            return False
        # check if the model call limit is reached
        if self.model_call_stats_hour[model] >= self.model_call_limit_global[model]:
            return True
        return False

    def is_user_limit_reached(self, model: str, user_id: str) -> bool:
        if model not in self.model_call_day_limit_per_user:
            return False
        if user_id not in self.user_call_stats_day:
            return False
        if model not in self.user_call_stats_day[user_id]["call_dict"]:
            return False
        # check if the user call limit is reached
        if (
            self.user_call_stats_day[user_id]["call_dict"][model]
            >= self.model_call_day_limit_per_user[model]
        ):
            return True
        return False

    def get_model_call_stats(
        self, target_model=None, most_recent_min: int = 60, top_k: int = 20
    ) -> dict:
        model_call_stats = {}
        for model, reqs in self.model_call.items():
            if target_model is not None and model != target_model:
                continue
            model_call = []
            for req in reqs:
                if req["tstamp"] < time.time() - most_recent_min * 60:
                    continue
                model_call.append(req["tstamp"])
            model_call_stats[model] = len(model_call)
        if top_k is not None:
            top_k_model = sorted(
                model_call_stats, key=lambda x: model_call_stats[x], reverse=True
            )[:top_k]
            model_call_stats = {model: model_call_stats[model] for model in top_k_model}
        return model_call_stats

    def get_user_call_stats(
        self, target_model=None, most_recent_min: int = 60, top_k: int = 20
    ) -> dict:
        user_call_stats = {}
        for user_id, reqs in self.user_call.items():
            user_model_call = {"call_dict": {}}
            for req in reqs:
                if req["tstamp"] < time.time() - most_recent_min * 60:
                    continue
                if target_model is not None and req["model"] != target_model:
                    continue
                if req["model"] not in user_model_call["call_dict"]:
                    user_model_call["call_dict"][req["model"]] = 0
                user_model_call["call_dict"][req["model"]] += 1

            user_model_call["total_calls"] = sum(user_model_call["call_dict"].values())
            if user_model_call["total_calls"] > 0:
                user_call_stats[user_id] = user_model_call
        if top_k is not None:
            top_k_user = sorted(
                user_call_stats,
                key=lambda x: user_call_stats[x]["total_calls"],
                reverse=True,
            )[:top_k]
            user_call_stats = {
                user_id: user_call_stats[user_id] for user_id in top_k_user
            }
        return user_call_stats

    def get_num_users(self, most_recent_min: int = 60) -> int:
        user_call_stats = self.get_user_call_stats(
            most_recent_min=most_recent_min, top_k=None
        )
        return len(user_call_stats)


monitor = Monitor(log_dir=LOG_DIR)
app = FastAPI()


@app.on_event("startup")
async def app_startup():
    asyncio.create_task(monitor.update_stats(2))


@app.get("/get_model_call_limit/{model}")
async def get_model_call_limit(model: str):
    return {"model_call_limit": {model: monitor.get_model_call_limit(model)}}


@app.get("/update_model_call_limit/{model}/{limit}")
async def update_model_call_limit(model: str, limit: int):
    if not monitor.update_model_call_limit(model, limit):
        return {"success": False}
    return {"success": True}


@app.get("/is_limit_reached")
async def is_limit_reached(model: str, user_id: str):
    if monitor.is_model_limit_reached(model):
        return {
            "is_limit_reached": True,
            "reason": f"MODEL_HOURLY_LIMIT ({model}): {monitor.get_model_call_limit(model)}",
        }
    if monitor.is_user_limit_reached(model, user_id):
        return {
            "is_limit_reached": True,
            "reason": f"USER_DAILY_LIMIT ({model}): {monitor.model_call_day_limit_per_user[model]}",
        }
    return {"is_limit_reached": False}


@app.get("/get_num_users_hr")
async def get_num_users():
    return {"num_users": len(monitor.user_call_stats_hour)}


@app.get("/get_num_users_day")
async def get_num_users_day():
    return {"num_users": len(monitor.user_call_stats_day)}


@app.get("/get_user_call_stats")
async def get_user_call_stats(
    model: str = None, most_recent_min: int = 60, top_k: int = None
):
    return {
        "user_call_stats": monitor.get_user_call_stats(model, most_recent_min, top_k)
    }


@app.get("/get_model_call_stats")
async def get_model_call_stats(
    model: str = None, most_recent_min: int = 60, top_k: int = None
):
    return {
        "model_call_stats": monitor.get_model_call_stats(model, most_recent_min, top_k)
    }
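A small client-side sketch (not part of the diff) of how the rate-limit service above might be queried. The host and port are assumptions; the endpoints and query parameters mirror the FastAPI routes defined in call_monitor.py, and the model name and user id are illustrative.

# sketch only: query the call monitor over HTTP
import requests

MONITOR_URL = "http://localhost:9090"  # assumption: wherever uvicorn serves call_monitor

# has this user hit the limits for this model?
resp = requests.get(
    f"{MONITOR_URL}/is_limit_reached",
    params={"model": "gpt-4-1106-preview", "user_id": "203.0.113.7"},
    timeout=5,
)
print(resp.json())  # e.g. {"is_limit_reached": False}

# raise the hourly cap for a model via the path-parameter route
resp = requests.get(
    f"{MONITOR_URL}/update_model_call_limit/gpt-4-1106-preview/500", timeout=5
)
print(resp.json())  # {"success": True} if the model has a configured limit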
cli.py
ADDED
@@ -0,0 +1,304 @@
"""
Chat with a model with command line interface.

Usage:
python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5
python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0

Other commands:
- Type "!!exit" or an empty line to exit.
- Type "!!reset" to start a new conversation.
- Type "!!remove" to remove the last prompt.
- Type "!!regen" to regenerate the last message.
- Type "!!save <filename>" to save the conversation history to a json file.
- Type "!!load <filename>" to load a conversation history from a json file.
"""
import argparse
import os
import re
import sys

from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.key_binding import KeyBindings
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
import torch

from fastchat.model.model_adapter import add_model_args
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.xfastertransformer import XftConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.serve.inference import ChatIO, chat_loop
from fastchat.utils import str_to_torch_dtype


class SimpleChatIO(ChatIO):
    def __init__(self, multiline: bool = False):
        self._multiline = multiline

    def prompt_for_input(self, role) -> str:
        if not self._multiline:
            return input(f"{role}: ")

        prompt_data = []
        line = input(f"{role} [ctrl-d/z on empty line to end]: ")
        while True:
            prompt_data.append(line.strip())
            try:
                line = input()
            except EOFError as e:
                break
        return "\n".join(prompt_data)

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        pre = 0
        for outputs in output_stream:
            output_text = outputs["text"]
            output_text = output_text.strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        print(text)


class RichChatIO(ChatIO):
    bindings = KeyBindings()

    @bindings.add("escape", "enter")
    def _(event):
        event.app.current_buffer.newline()

    def __init__(self, multiline: bool = False, mouse: bool = False):
        self._prompt_session = PromptSession(history=InMemoryHistory())
        self._completer = WordCompleter(
            words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
            pattern=re.compile("$"),
        )
        self._console = Console()
        self._multiline = multiline
        self._mouse = mouse

    def prompt_for_input(self, role) -> str:
        self._console.print(f"[bold]{role}:")
        # TODO(suquark): multiline input has some issues. fix it later.
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            mouse_support=self._mouse,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=self.bindings if self._multiline else None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        self._console.print(f"[bold]{role.replace('/', '|')}:")

    def stream_output(self, output_stream):
        """Stream output from a role."""
        # TODO(suquark): the console flickers when there is a code block
        # above it. We need to cut off "live" when a code block is done.

        # Create a Live context for updating the console output
        with Live(console=self._console, refresh_per_second=4) as live:
            # Read lines from the stream
            for outputs in output_stream:
                if not outputs:
                    continue
                text = outputs["text"]
                # Render the accumulated text as Markdown
                # NOTE: this is a workaround for rendering "non-standard markdown"
                # in rich. The chatbot output treats "\n" as a new line for
                # better compatibility with real-world text. However, rendering
                # it as markdown would break the format, because standard markdown
                # treats a single "\n" in normal text as a space.
                # Our workaround is adding two spaces at the end of each line.
                # This is not a perfect solution, as it would
                # introduce trailing spaces (only) in code blocks, but it works well,
                # especially for console output, because in general the console does not
                # care about trailing spaces.
                lines = []
                for line in text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        # Code block marker - do not add trailing spaces, as it would
                        # break the syntax highlighting
                        lines.append("\n")
                    else:
                        lines.append("  \n")
                markdown = Markdown("".join(lines))
                # Update the Live console output
                live.update(markdown)
        self._console.print()
        return text

    def print_output(self, text: str):
        self.stream_output([{"text": text}])


class ProgrammaticChatIO(ChatIO):
    def prompt_for_input(self, role) -> str:
        contents = ""
        # `end_sequence` signals the end of a message. It is unlikely to occur in
        # message content.
        end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
        len_end = len(end_sequence)
        while True:
            if len(contents) >= len_end:
                last_chars = contents[-len_end:]
                if last_chars == end_sequence:
                    break
            try:
                char = sys.stdin.read(1)
                contents = contents + char
            except EOFError:
                continue
        contents = contents[:-len_end]
        print(f"[!OP:{role}]: {contents}", flush=True)
        return contents

    def prompt_for_output(self, role: str):
        print(f"[!OP:{role}]: ", end="", flush=True)

    def stream_output(self, output_stream):
        pre = 0
        for outputs in output_stream:
            output_text = outputs["text"]
            output_text = output_text.strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        print(text)


def main(args):
    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
            cache_8bit=args.exllama_cache_8bit,
        )
    else:
        exllama_config = None
    if args.enable_xft:
        xft_config = XftConfig(
            max_seq_len=args.xft_max_seq_len,
            data_type=args.xft_dtype,
        )
        if args.device != "cpu":
            print("xFasterTransformer currently only supports CPUs. Resetting device to CPU.")
            args.device = "cpu"
    else:
        xft_config = None
    if args.style == "simple":
        chatio = SimpleChatIO(args.multiline)
    elif args.style == "rich":
        chatio = RichChatIO(args.multiline, args.mouse)
    elif args.style == "programmatic":
        chatio = ProgrammaticChatIO()
    else:
        raise ValueError(f"Invalid style for console: {args.style}")
    try:
        chat_loop(
            args.model_path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            str_to_torch_dtype(args.dtype),
            args.load_8bit,
            args.cpu_offloading,
            args.conv_template,
            args.conv_system_msg,
            args.temperature,
            args.repetition_penalty,
            args.max_new_tokens,
            chatio,
            gptq_config=GptqConfig(
                ckpt=args.gptq_ckpt or args.model_path,
                wbits=args.gptq_wbits,
                groupsize=args.gptq_groupsize,
                act_order=args.gptq_act_order,
            ),
            awq_config=AWQConfig(
                ckpt=args.awq_ckpt or args.model_path,
                wbits=args.awq_wbits,
                groupsize=args.awq_groupsize,
            ),
            exllama_config=exllama_config,
            xft_config=xft_config,
            revision=args.revision,
            judge_sent_end=args.judge_sent_end,
            debug=args.debug,
            history=not args.no_history,
        )
    except KeyboardInterrupt:
        print("exit...")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_model_args(parser)
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument(
        "--conv-system-msg", type=str, default=None, help="Conversation system message."
    )
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--no-history", action="store_true")
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich", "programmatic"],
        help="Display style.",
    )
    parser.add_argument(
        "--multiline",
        action="store_true",
        help="Enable multiline input. Use ESC+Enter for newline.",
    )
    parser.add_argument(
        "--mouse",
        action="store_true",
        help="[Rich Style]: Enable mouse support for cursor positioning.",
    )
    parser.add_argument(
        "--judge-sent-end",
        action="store_true",
        help="Whether enable the correction logic that interrupts the output of sentences due to EOS.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print useful debug information (e.g., prompts)",
    )
    args = parser.parse_args()
    main(args)
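The ChatIO classes above are pluggable, so a custom front end only has to implement the same four methods. Below is a hypothetical variant (not part of the diff) that logs every exchange to a file while otherwise behaving like SimpleChatIO; the log path and the local import path are assumptions.

# sketch only: a ChatIO that also appends the conversation to a text file
from cli import SimpleChatIO  # assumed local import path (fastchat.serve.cli upstream)


class LoggingChatIO(SimpleChatIO):
    def __init__(self, multiline: bool = False, log_path: str = "chat_log.txt"):
        super().__init__(multiline)
        self._log_path = log_path  # assumption: illustrative default

    def prompt_for_input(self, role) -> str:
        text = super().prompt_for_input(role)
        with open(self._log_path, "a", encoding="utf-8") as f:
            f.write(f"{role}: {text}\n")
        return text

    def stream_output(self, output_stream):
        text = super().stream_output(output_stream)
        with open(self._log_path, "a", encoding="utf-8") as f:
            f.write(f"assistant: {text}\n")
        return text

Such an instance could presumably be passed to chat_loop() in place of the chatio object selected in main().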
controller.py
ADDED
@@ -0,0 +1,389 @@
"""
A controller manages distributed workers.
It sends worker addresses to clients.
"""
import argparse
import asyncio
import dataclasses
from enum import Enum, auto
import json
import logging
import os
import time
from typing import List, Union
import threading

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn

from fastchat.constants import (
    CONTROLLER_HEART_BEAT_EXPIRATION,
    WORKER_API_TIMEOUT,
    ErrorCode,
    SERVER_ERROR_MSG,
)
from fastchat.utils import build_logger


logger = build_logger("controller", "controller.log")


class DispatchMethod(Enum):
    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        if name == "lottery":
            return cls.LOTTERY
        elif name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        else:
            raise ValueError(f"Invalid dispatch method")


@dataclasses.dataclass
class WorkerInfo:
    model_names: List[str]
    speed: int
    queue_length: int
    check_heart_beat: bool
    last_heart_beat: str
    multimodal: bool


def heart_beat_controller(controller):
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        controller.remove_stale_workers_by_expiration()


class Controller:
    def __init__(self, dispatch_method: str):
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,)
        )
        self.heart_beat_thread.start()

    def register_worker(
        self,
        worker_name: str,
        check_heart_beat: bool,
        worker_status: dict,
        multimodal: bool,
    ):
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
            if not worker_status:
                return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"],
            worker_status["speed"],
            worker_status["queue_length"],
            check_heart_beat,
            time.time(),
            multimodal,
        )

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        del self.worker_info[worker_name]

    def refresh_all_workers(self):
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(
                w_name, w_info.check_heart_beat, None, w_info.multimodal
            ):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        model_names = set()

        for w_name, w_info in self.worker_info.items():
            model_names.update(w_info.model_names)

        return list(model_names)

    def list_multimodal_models(self):
        model_names = set()

        for w_name, w_info in self.worker_info.items():
            if w_info.multimodal:
                model_names.update(w_info.model_names)

        return list(model_names)

    def list_language_models(self):
        model_names = set()

        for w_name, w_info in self.worker_info.items():
            if not w_info.multimodal:
                model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            if True:  # Directly return address
                pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
                worker_name = worker_names[pt]
                return worker_name

            # Check status before returning
            while True:
                pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
                worker_name = worker_names[pt]

                if self.get_worker_status(worker_name):
                    break
                else:
                    self.remove_worker(worker_name)
                    worker_speeds[pt] = 0
                    norm = np.sum(worker_speeds)
                    if norm < 1e-4:
                        return ""
                    worker_speeds = worker_speeds / norm
                    continue
            return worker_name
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            self.worker_info[w_name].queue_length += 1
            logger.info(
                f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
            )
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
208 |
+
|
209 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
210 |
+
if worker_name not in self.worker_info:
|
211 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
212 |
+
return False
|
213 |
+
|
214 |
+
self.worker_info[worker_name].queue_length = queue_length
|
215 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
216 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
217 |
+
return True
|
218 |
+
|
219 |
+
def remove_stale_workers_by_expiration(self):
|
220 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
221 |
+
to_delete = []
|
222 |
+
for worker_name, w_info in self.worker_info.items():
|
223 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
224 |
+
to_delete.append(worker_name)
|
225 |
+
|
226 |
+
for worker_name in to_delete:
|
227 |
+
self.remove_worker(worker_name)
|
228 |
+
|
229 |
+
def handle_no_worker(self, params):
|
230 |
+
logger.info(f"no worker: {params['model']}")
|
231 |
+
ret = {
|
232 |
+
"text": SERVER_ERROR_MSG,
|
233 |
+
"error_code": ErrorCode.CONTROLLER_NO_WORKER,
|
234 |
+
}
|
235 |
+
return json.dumps(ret).encode() + b"\0"
|
236 |
+
|
237 |
+
def handle_worker_timeout(self, worker_address):
|
238 |
+
logger.info(f"worker timeout: {worker_address}")
|
239 |
+
ret = {
|
240 |
+
"text": SERVER_ERROR_MSG,
|
241 |
+
"error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
|
242 |
+
}
|
243 |
+
return json.dumps(ret).encode() + b"\0"
|
244 |
+
|
245 |
+
# Let the controller act as a worker to achieve hierarchical
|
246 |
+
# management. This can be used to connect isolated sub networks.
|
247 |
+
def worker_api_get_status(self):
|
248 |
+
model_names = set()
|
249 |
+
speed = 0
|
250 |
+
queue_length = 0
|
251 |
+
|
252 |
+
for w_name in self.worker_info:
|
253 |
+
worker_status = self.get_worker_status(w_name)
|
254 |
+
if worker_status is not None:
|
255 |
+
model_names.update(worker_status["model_names"])
|
256 |
+
speed += worker_status["speed"]
|
257 |
+
queue_length += worker_status["queue_length"]
|
258 |
+
|
259 |
+
model_names = sorted(list(model_names))
|
260 |
+
return {
|
261 |
+
"model_names": model_names,
|
262 |
+
"speed": speed,
|
263 |
+
"queue_length": queue_length,
|
264 |
+
}
|
265 |
+
|
266 |
+
def worker_api_generate_stream(self, params):
|
267 |
+
worker_addr = self.get_worker_address(params["model"])
|
268 |
+
if not worker_addr:
|
269 |
+
yield self.handle_no_worker(params)
|
270 |
+
|
271 |
+
try:
|
272 |
+
response = requests.post(
|
273 |
+
worker_addr + "/worker_generate_stream",
|
274 |
+
json=params,
|
275 |
+
stream=True,
|
276 |
+
timeout=WORKER_API_TIMEOUT,
|
277 |
+
)
|
278 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
279 |
+
if chunk:
|
280 |
+
yield chunk + b"\0"
|
281 |
+
except requests.exceptions.RequestException as e:
|
282 |
+
yield self.handle_worker_timeout(worker_addr)
|
283 |
+
|
284 |
+
|
285 |
+
app = FastAPI()
|
286 |
+
|
287 |
+
|
288 |
+
@app.post("/register_worker")
|
289 |
+
async def register_worker(request: Request):
|
290 |
+
data = await request.json()
|
291 |
+
controller.register_worker(
|
292 |
+
data["worker_name"],
|
293 |
+
data["check_heart_beat"],
|
294 |
+
data.get("worker_status", None),
|
295 |
+
data.get("multimodal", False),
|
296 |
+
)
|
297 |
+
|
298 |
+
|
299 |
+
@app.post("/refresh_all_workers")
|
300 |
+
async def refresh_all_workers():
|
301 |
+
models = controller.refresh_all_workers()
|
302 |
+
|
303 |
+
|
304 |
+
@app.post("/list_models")
|
305 |
+
async def list_models():
|
306 |
+
models = controller.list_models()
|
307 |
+
return {"models": models}
|
308 |
+
|
309 |
+
|
310 |
+
@app.post("/list_multimodal_models")
|
311 |
+
async def list_multimodal_models():
|
312 |
+
models = controller.list_multimodal_models()
|
313 |
+
return {"models": models}
|
314 |
+
|
315 |
+
|
316 |
+
@app.post("/list_language_models")
|
317 |
+
async def list_language_models():
|
318 |
+
models = controller.list_language_models()
|
319 |
+
return {"models": models}
|
320 |
+
|
321 |
+
|
322 |
+
@app.post("/get_worker_address")
|
323 |
+
async def get_worker_address(request: Request):
|
324 |
+
data = await request.json()
|
325 |
+
addr = controller.get_worker_address(data["model"])
|
326 |
+
return {"address": addr}
|
327 |
+
|
328 |
+
|
329 |
+
@app.post("/receive_heart_beat")
|
330 |
+
async def receive_heart_beat(request: Request):
|
331 |
+
data = await request.json()
|
332 |
+
exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
|
333 |
+
return {"exist": exist}
|
334 |
+
|
335 |
+
|
336 |
+
@app.post("/worker_generate_stream")
|
337 |
+
async def worker_api_generate_stream(request: Request):
|
338 |
+
params = await request.json()
|
339 |
+
generator = controller.worker_api_generate_stream(params)
|
340 |
+
return StreamingResponse(generator)
|
341 |
+
|
342 |
+
|
343 |
+
@app.post("/worker_get_status")
|
344 |
+
async def worker_api_get_status(request: Request):
|
345 |
+
return controller.worker_api_get_status()
|
346 |
+
|
347 |
+
|
348 |
+
@app.get("/test_connection")
|
349 |
+
async def worker_api_get_status(request: Request):
|
350 |
+
return "success"
|
351 |
+
|
352 |
+
|
353 |
+
def create_controller():
|
354 |
+
parser = argparse.ArgumentParser()
|
355 |
+
parser.add_argument("--host", type=str, default="localhost")
|
356 |
+
parser.add_argument("--port", type=int, default=21001)
|
357 |
+
parser.add_argument(
|
358 |
+
"--dispatch-method",
|
359 |
+
type=str,
|
360 |
+
choices=["lottery", "shortest_queue"],
|
361 |
+
default="shortest_queue",
|
362 |
+
)
|
363 |
+
parser.add_argument(
|
364 |
+
"--ssl",
|
365 |
+
action="store_true",
|
366 |
+
required=False,
|
367 |
+
default=False,
|
368 |
+
help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
|
369 |
+
)
|
370 |
+
args = parser.parse_args()
|
371 |
+
logger.info(f"args: {args}")
|
372 |
+
|
373 |
+
controller = Controller(args.dispatch_method)
|
374 |
+
return args, controller
|
375 |
+
|
376 |
+
|
377 |
+
if __name__ == "__main__":
|
378 |
+
args, controller = create_controller()
|
379 |
+
if args.ssl:
|
380 |
+
uvicorn.run(
|
381 |
+
app,
|
382 |
+
host=args.host,
|
383 |
+
port=args.port,
|
384 |
+
log_level="info",
|
385 |
+
ssl_keyfile=os.environ["SSL_KEYFILE"],
|
386 |
+
ssl_certfile=os.environ["SSL_CERTFILE"],
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
gradio_block_arena_anony.py
ADDED
@@ -0,0 +1,811 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Chatbot Arena (battle) tab.
|
3 |
+
Users chat with two anonymous models.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from fastchat.constants import (
|
13 |
+
MODERATION_MSG,
|
14 |
+
CONVERSATION_LIMIT_MSG,
|
15 |
+
SLOW_MODEL_MSG,
|
16 |
+
INPUT_CHAR_LEN_LIMIT,
|
17 |
+
CONVERSATION_TURN_LIMIT,
|
18 |
+
)
|
19 |
+
from fastchat.model.model_adapter import get_conversation_template
|
20 |
+
from fastchat.serve.gradio_block_arena_named import flash_buttons
|
21 |
+
from fastchat.serve.gradio_web_server import (
|
22 |
+
State,
|
23 |
+
bot_response,
|
24 |
+
get_conv_log_filename,
|
25 |
+
no_change_btn,
|
26 |
+
enable_btn,
|
27 |
+
disable_btn,
|
28 |
+
invisible_btn,
|
29 |
+
acknowledgment_md,
|
30 |
+
get_ip,
|
31 |
+
get_model_description_md,
|
32 |
+
)
|
33 |
+
from fastchat.utils import (
|
34 |
+
build_logger,
|
35 |
+
moderation_filter,
|
36 |
+
)
|
37 |
+
|
38 |
+
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
|
39 |
+
|
40 |
+
num_sides = 2
|
41 |
+
enable_moderation = False
|
42 |
+
anony_names = ["", ""]
|
43 |
+
models = []
|
44 |
+
|
45 |
+
|
46 |
+
def set_global_vars_anony(enable_moderation_):
|
47 |
+
global enable_moderation
|
48 |
+
enable_moderation = enable_moderation_
|
49 |
+
|
50 |
+
|
51 |
+
def load_demo_side_by_side_anony(models_, url_params):
|
52 |
+
global models
|
53 |
+
models = models_
|
54 |
+
|
55 |
+
states = (None,) * num_sides
|
56 |
+
selector_updates = (
|
57 |
+
gr.Markdown(visible=True),
|
58 |
+
gr.Markdown(visible=True),
|
59 |
+
)
|
60 |
+
|
61 |
+
return states + selector_updates
|
62 |
+
|
63 |
+
|
64 |
+
def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
|
65 |
+
with open(get_conv_log_filename(), "a") as fout:
|
66 |
+
data = {
|
67 |
+
"tstamp": round(time.time(), 4),
|
68 |
+
"type": vote_type,
|
69 |
+
"models": [x for x in model_selectors],
|
70 |
+
"states": [x.dict() for x in states],
|
71 |
+
"ip": get_ip(request),
|
72 |
+
}
|
73 |
+
fout.write(json.dumps(data) + "\n")
|
74 |
+
|
75 |
+
if ":" not in model_selectors[0]:
|
76 |
+
for i in range(5):
|
77 |
+
names = (
|
78 |
+
"### Model A: " + states[0].model_name,
|
79 |
+
"### Model B: " + states[1].model_name,
|
80 |
+
)
|
81 |
+
yield names + ("",) + (disable_btn,) * 4
|
82 |
+
time.sleep(0.1)
|
83 |
+
else:
|
84 |
+
names = (
|
85 |
+
"### Model A: " + states[0].model_name,
|
86 |
+
"### Model B: " + states[1].model_name,
|
87 |
+
)
|
88 |
+
yield names + ("",) + (disable_btn,) * 4
|
89 |
+
|
90 |
+
|
91 |
+
def leftvote_last_response(
|
92 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
93 |
+
):
|
94 |
+
logger.info(f"leftvote (anony). ip: {get_ip(request)}")
|
95 |
+
for x in vote_last_response(
|
96 |
+
[state0, state1], "leftvote", [model_selector0, model_selector1], request
|
97 |
+
):
|
98 |
+
yield x
|
99 |
+
|
100 |
+
|
101 |
+
def rightvote_last_response(
|
102 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
103 |
+
):
|
104 |
+
logger.info(f"rightvote (anony). ip: {get_ip(request)}")
|
105 |
+
for x in vote_last_response(
|
106 |
+
[state0, state1], "rightvote", [model_selector0, model_selector1], request
|
107 |
+
):
|
108 |
+
yield x
|
109 |
+
|
110 |
+
|
111 |
+
def tievote_last_response(
|
112 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
113 |
+
):
|
114 |
+
logger.info(f"tievote (anony). ip: {get_ip(request)}")
|
115 |
+
for x in vote_last_response(
|
116 |
+
[state0, state1], "tievote", [model_selector0, model_selector1], request
|
117 |
+
):
|
118 |
+
yield x
|
119 |
+
|
120 |
+
|
121 |
+
def bothbad_vote_last_response(
|
122 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
123 |
+
):
|
124 |
+
logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
|
125 |
+
for x in vote_last_response(
|
126 |
+
[state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
|
127 |
+
):
|
128 |
+
yield x
|
129 |
+
|
130 |
+
|
131 |
+
def regenerate(state0, state1, request: gr.Request):
|
132 |
+
logger.info(f"regenerate (anony). ip: {get_ip(request)}")
|
133 |
+
states = [state0, state1]
|
134 |
+
for i in range(num_sides):
|
135 |
+
states[i].conv.update_last_message(None)
|
136 |
+
return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
|
137 |
+
|
138 |
+
|
139 |
+
def clear_history(request: gr.Request):
|
140 |
+
logger.info(f"clear_history (anony). ip: {get_ip(request)}")
|
141 |
+
return (
|
142 |
+
[None] * num_sides
|
143 |
+
+ [None] * num_sides
|
144 |
+
+ anony_names
|
145 |
+
+ [""]
|
146 |
+
+ [invisible_btn] * 4
|
147 |
+
+ [disable_btn] * 2
|
148 |
+
+ [""]
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
|
153 |
+
logger.info(f"share (anony). ip: {get_ip(request)}")
|
154 |
+
if state0 is not None and state1 is not None:
|
155 |
+
vote_last_response(
|
156 |
+
[state0, state1], "share", [model_selector0, model_selector1], request
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
SAMPLING_WEIGHTS = {
|
161 |
+
# tier 0
|
162 |
+
"gpt-4": 4,
|
163 |
+
"gpt-4-0314": 4,
|
164 |
+
"gpt-4-0613": 4,
|
165 |
+
"gpt-4-turbo": 4,
|
166 |
+
"gpt-4-1106-preview": 4,
|
167 |
+
"gpt-4-0125-preview": 4,
|
168 |
+
"gpt-3.5-turbo-0613": 2,
|
169 |
+
"gpt-3.5-turbo-1106": 2,
|
170 |
+
"gpt-3.5-turbo-0125": 4,
|
171 |
+
"claude-2.1": 4,
|
172 |
+
"claude-2.0": 2,
|
173 |
+
"claude-1": 2,
|
174 |
+
"claude-instant-1": 2,
|
175 |
+
"gemini-pro": 4,
|
176 |
+
"gemini-pro-dev-api": 4,
|
177 |
+
"bard-jan-24-gemini-pro": 4,
|
178 |
+
"bard-feb-2024": 4,
|
179 |
+
"mixtral-8x7b-instruct-v0.1": 4,
|
180 |
+
"mistral-medium": 4,
|
181 |
+
"qwen1.5-72b-chat": 4,
|
182 |
+
"qwen1.5-7b-chat": 2,
|
183 |
+
"qwen1.5-4b-chat": 2,
|
184 |
+
"nous-hermes-2-mixtral-8x7b-dpo": 2,
|
185 |
+
"deepseek-llm-67b-chat": 2,
|
186 |
+
"stripedhyena-nous-7b": 2,
|
187 |
+
"openchat-3.5-0106": 2,
|
188 |
+
"mistral-7b-instruct-v0.2": 2,
|
189 |
+
"solar-10.7b-instruct-v1.0": 2,
|
190 |
+
"dolphin-2.2.1-mistral-7b": 2,
|
191 |
+
"starling-lm-7b-alpha": 2,
|
192 |
+
"tulu-2-dpo-70b": 2,
|
193 |
+
"yi-34b-chat": 2,
|
194 |
+
"zephyr-7b-beta": 2,
|
195 |
+
# tier 1
|
196 |
+
"deluxe-chat-v1.2": 4,
|
197 |
+
"llama-2-70b-chat": 4,
|
198 |
+
"llama-2-13b-chat": 2,
|
199 |
+
"llama-2-7b-chat": 2,
|
200 |
+
"mistral-7b-instruct": 2,
|
201 |
+
"codellama-34b-instruct": 1.5,
|
202 |
+
"vicuna-33b": 2,
|
203 |
+
"vicuna-13b": 1.5,
|
204 |
+
"wizardlm-13b": 1.5,
|
205 |
+
"qwen-14b-chat": 1.5,
|
206 |
+
# tier 2
|
207 |
+
"pplx-7b-online": 1,
|
208 |
+
"pplx-70b-online": 1,
|
209 |
+
"openhermes-2.5-mistral-7b": 1.0,
|
210 |
+
"llama2-70b-steerlm-chat": 1.0,
|
211 |
+
"chatglm3-6b": 1.0,
|
212 |
+
"openchat-3.5": 1.0,
|
213 |
+
"wizardlm-70b": 1.0,
|
214 |
+
"vicuna-7b": 1.0,
|
215 |
+
"chatglm2-6b": 1.0,
|
216 |
+
# deprecated
|
217 |
+
"zephyr-7b-alpha": 1.5,
|
218 |
+
"codellama-13b-instruct": 1.0,
|
219 |
+
"mpt-30b-chat": 1.5,
|
220 |
+
"guanaco-33b": 1.0,
|
221 |
+
"fastchat-t5-3b": 0.5,
|
222 |
+
"alpaca-13b": 0.5,
|
223 |
+
"mpt-7b-chat": 0.1,
|
224 |
+
"oasst-pythia-12b": 0.1,
|
225 |
+
"RWKV-4-Raven-14B": 0.1,
|
226 |
+
"gpt4all-13b-snoozy": 0.1,
|
227 |
+
"koala-13b": 0.1,
|
228 |
+
"stablelm-tuned-alpha-7b": 0.1,
|
229 |
+
"dolly-v2-12b": 0.1,
|
230 |
+
"llama-13b": 0.1,
|
231 |
+
"chatglm-6b": 0.5,
|
232 |
+
"deluxe-chat-v1": 4,
|
233 |
+
"palm-2": 1.5,
|
234 |
+
}
|
235 |
+
|
236 |
+
# target model sampling weights will be boosted.
|
237 |
+
BATTLE_TARGETS = {
|
238 |
+
"gpt-4": {"gpt-4-0314", "claude-2.1", "gpt-4-1106-preview"},
|
239 |
+
"gpt-4-0613": {"gpt-4-0314", "claude-2.1", "gpt-4-1106-preview"},
|
240 |
+
"gpt-4-0314": {
|
241 |
+
"gpt-4-1106-preview",
|
242 |
+
"gpt-4-0613",
|
243 |
+
"claude-2.1",
|
244 |
+
"gpt-3.5-turbo-0613",
|
245 |
+
},
|
246 |
+
"gpt-4-1106-preview": {
|
247 |
+
"gpt-4-0613",
|
248 |
+
"gpt-3.5-turbo-0613",
|
249 |
+
"gpt-3.5-turbo-1106",
|
250 |
+
"claude-2.1",
|
251 |
+
"bard-feb-2024",
|
252 |
+
},
|
253 |
+
"gpt-4-0125-preview": {
|
254 |
+
"gpt-4-1106-preview",
|
255 |
+
"gpt-4-0613",
|
256 |
+
"gpt-3.5-turbo-0613",
|
257 |
+
"claude-2.1",
|
258 |
+
"mistral-medium",
|
259 |
+
"bard-feb-2024",
|
260 |
+
},
|
261 |
+
"gpt-3.5-turbo-0613": {"claude-instant-1", "gpt-4-0613", "claude-2.1"},
|
262 |
+
"gpt-3.5-turbo-1106": {"gpt-4-0613", "claude-instant-1", "gpt-3.5-turbo-0613"},
|
263 |
+
"gpt-3.5-turbo-0125": {
|
264 |
+
"gpt-4-0613",
|
265 |
+
"gpt-4-1106-preview",
|
266 |
+
"gpt-3.5-turbo-0613",
|
267 |
+
"gpt-3.5-turbo-1106",
|
268 |
+
"mixtral-8x7b-instruct-v0.1",
|
269 |
+
},
|
270 |
+
"qwen1.5-72b-chat": {
|
271 |
+
"gpt-3.5-turbo-0125",
|
272 |
+
"gpt-4-0613",
|
273 |
+
"gpt-4-1106-preview",
|
274 |
+
"llama-2-70b-chat",
|
275 |
+
"mixtral-8x7b-instruct-v0.1",
|
276 |
+
"mistral-medium",
|
277 |
+
"yi-34b-chat",
|
278 |
+
},
|
279 |
+
"qwen1.5-7b-chat": {
|
280 |
+
"gpt-3.5-turbo-0125",
|
281 |
+
"starling-lm-7b-alpha",
|
282 |
+
"llama-2-70b-chat",
|
283 |
+
"openchat-3.5",
|
284 |
+
"mixtral-8x7b-instruct-v0.1",
|
285 |
+
},
|
286 |
+
"qwen1.5-4b-chat": {
|
287 |
+
"llama-2-70b-chat",
|
288 |
+
"llama-2-13b-chat",
|
289 |
+
"llama-2-7b-chat",
|
290 |
+
"openchat-3.5",
|
291 |
+
},
|
292 |
+
"openchat-3.5-0106": {
|
293 |
+
"gpt-3.5-turbo-0125",
|
294 |
+
"gpt-3.5-turbo-0613",
|
295 |
+
"llama-2-70b-chat",
|
296 |
+
"openchat-3.5",
|
297 |
+
"mixtral-8x7b-instruct-v0.1",
|
298 |
+
},
|
299 |
+
"nous-hermes-2-mixtral-8x7b-dpo": {
|
300 |
+
"gpt-4-1106-preview",
|
301 |
+
"claude-2.1",
|
302 |
+
"mistral-medium",
|
303 |
+
"gpt-3.5-turbo-0613",
|
304 |
+
"mixtral-8x7b-instruct-v0.1",
|
305 |
+
},
|
306 |
+
"mistral-7b-instruct-v0.2": {
|
307 |
+
"llama-2-70b-chat",
|
308 |
+
"mixtral-8x7b-instruct-v0.1",
|
309 |
+
"starling-lm-7b-alpha",
|
310 |
+
"openhermes-2.5-mistral-7b",
|
311 |
+
},
|
312 |
+
"solar-10.7b-instruct-v1.0": {
|
313 |
+
"mixtral-8x7b-instruct-v0.1",
|
314 |
+
"gpt-3.5-turbo-0613",
|
315 |
+
"llama-2-70b-chat",
|
316 |
+
},
|
317 |
+
"mistral-medium": {
|
318 |
+
"gpt-3.5-turbo-0125",
|
319 |
+
"gpt-3.5-turbo-0613",
|
320 |
+
"gpt-4-1106-preview",
|
321 |
+
"mixtral-8x7b-instruct-v0.1",
|
322 |
+
"bard-feb-2024",
|
323 |
+
},
|
324 |
+
"mixtral-8x7b-instruct-v0.1": {
|
325 |
+
"gpt-3.5-turbo-0125",
|
326 |
+
"gpt-3.5-turbo-0613",
|
327 |
+
"gpt-4-1106-preview",
|
328 |
+
"llama-2-70b-chat",
|
329 |
+
},
|
330 |
+
"claude-2.1": {"gpt-4-1106-preview", "gpt-4-0613", "claude-1"},
|
331 |
+
"claude-2.0": {"gpt-4-1106-preview", "gpt-4-0613", "claude-1"},
|
332 |
+
"claude-1": {"claude-2.1", "gpt-4-0613", "gpt-3.5-turbo-0613"},
|
333 |
+
"claude-instant-1": {"gpt-3.5-turbo-0125", "claude-2.1"},
|
334 |
+
"gemini-pro": {"gpt-4-1106-preview", "gpt-4-0613", "gpt-3.5-turbo-0613"},
|
335 |
+
"gemini-pro-dev-api": {
|
336 |
+
"gpt-4-1106-preview",
|
337 |
+
"gpt-4-0613",
|
338 |
+
"gpt-3.5-turbo-0613",
|
339 |
+
"bard-feb-2024",
|
340 |
+
},
|
341 |
+
"bard-jan-24-gemini-pro": {
|
342 |
+
"gpt-4-1106-preview",
|
343 |
+
"gpt-4-0613",
|
344 |
+
"gpt-3.5-turbo-0613",
|
345 |
+
"gemini-pro-dev-api",
|
346 |
+
},
|
347 |
+
"bard-feb-2024": {
|
348 |
+
"gpt-4-1106-preview",
|
349 |
+
"gpt-4-0613",
|
350 |
+
"gpt-3.5-turbo-0613",
|
351 |
+
"bard-jan-24-gemini-pro",
|
352 |
+
},
|
353 |
+
"deepseek-llm-67b-chat": {
|
354 |
+
"gpt-4-1106-preview",
|
355 |
+
"gpt-4-turbo",
|
356 |
+
"gpt-3.5-turbo-0613",
|
357 |
+
},
|
358 |
+
"llama2-70b-steerlm-chat": {
|
359 |
+
"llama-2-70b-chat",
|
360 |
+
"tulu-2-dpo-70b",
|
361 |
+
"yi-34b-chat",
|
362 |
+
},
|
363 |
+
"stripedhyena-nous-7b": {
|
364 |
+
"starling-lm-7b-alpha",
|
365 |
+
"openhermes-2.5-mistral-7b",
|
366 |
+
"mistral-7b-instruct",
|
367 |
+
"llama-2-7b-chat",
|
368 |
+
},
|
369 |
+
"deluxe-chat-v1.1": {"gpt-4-0613", "gpt-4-1106-preview"},
|
370 |
+
"deluxe-chat-v1.2": {"gpt-4-0613", "gpt-4-1106-preview"},
|
371 |
+
"pplx-7b-online": {"gpt-3.5-turbo-0125", "llama-2-70b-chat"},
|
372 |
+
"pplx-70b-online": {"gpt-3.5-turbo-0125", "llama-2-70b-chat"},
|
373 |
+
"openhermes-2.5-mistral-7b": {
|
374 |
+
"gpt-3.5-turbo-0613",
|
375 |
+
"openchat-3.5",
|
376 |
+
"zephyr-7b-beta",
|
377 |
+
},
|
378 |
+
"dolphin-2.2.1-mistral-7b": {
|
379 |
+
"gpt-3.5-turbo-0613",
|
380 |
+
"vicuna-33b",
|
381 |
+
"starling-lm-7b-alpha",
|
382 |
+
"openhermes-2.5-mistral-7b",
|
383 |
+
},
|
384 |
+
"starling-lm-7b-alpha": {"gpt-3.5-turbo-0613", "openchat-3.5", "zephyr-7b-beta"},
|
385 |
+
"tulu-2-dpo-70b": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"},
|
386 |
+
"yi-34b-chat": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"},
|
387 |
+
"openchat-3.5": {"gpt-3.5-turbo-0613", "llama-2-70b-chat", "zephyr-7b-beta"},
|
388 |
+
"chatglm3-6b": {"yi-34b-chat", "qwen-14b-chat"},
|
389 |
+
"qwen-14b-chat": {"vicuna-13b", "llama-2-13b-chat", "llama-2-70b-chat"},
|
390 |
+
"zephyr-7b-alpha": {"mistral-7b-instruct", "llama-2-13b-chat"},
|
391 |
+
"zephyr-7b-beta": {
|
392 |
+
"mistral-7b-instruct",
|
393 |
+
"llama-2-13b-chat",
|
394 |
+
"llama-2-7b-chat",
|
395 |
+
"wizardlm-13b",
|
396 |
+
},
|
397 |
+
"llama-2-70b-chat": {"gpt-3.5-turbo-0125", "claude-instant-1"},
|
398 |
+
"llama-2-13b-chat": {"mistral-7b-instruct", "vicuna-13b", "llama-2-70b-chat"},
|
399 |
+
"llama-2-7b-chat": {"mistral-7b-instruct", "vicuna-7b", "llama-2-13b-chat"},
|
400 |
+
"mistral-7b-instruct": {
|
401 |
+
"llama-2-7b-chat",
|
402 |
+
"llama-2-13b-chat",
|
403 |
+
"llama-2-70b-chat",
|
404 |
+
},
|
405 |
+
"vicuna-33b": {"llama-2-70b-chat", "gpt-3.5-turbo-0613", "claude-instant-1"},
|
406 |
+
"vicuna-13b": {"llama-2-13b-chat", "llama-2-70b-chat"},
|
407 |
+
"vicuna-7b": {"llama-2-7b-chat", "mistral-7b-instruct", "llama-2-13b-chat"},
|
408 |
+
"wizardlm-70b": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"},
|
409 |
+
}
|
410 |
+
|
411 |
+
SAMPLING_BOOST_MODELS = [
|
412 |
+
# "claude-2.1",
|
413 |
+
# "gpt-4-0613",
|
414 |
+
# "gpt-4-0314",
|
415 |
+
# "gpt-4-1106-preview",
|
416 |
+
# "gpt-4-0125-preview",
|
417 |
+
"gpt-3.5-turbo-0125",
|
418 |
+
# "mistral-medium",
|
419 |
+
"nous-hermes-2-mixtral-8x7b-dpo",
|
420 |
+
"openchat-3.5-0106",
|
421 |
+
"qwen1.5-72b-chat",
|
422 |
+
"qwen1.5-7b-chat",
|
423 |
+
"qwen1.5-4b-chat",
|
424 |
+
# "mistral-7b-instruct-v0.2",
|
425 |
+
]
|
426 |
+
|
427 |
+
# outage models won't be sampled.
|
428 |
+
OUTAGE_MODELS = []
|
429 |
+
|
430 |
+
|
431 |
+
def get_sample_weight(model):
|
432 |
+
if model in OUTAGE_MODELS:
|
433 |
+
return 0
|
434 |
+
weight = SAMPLING_WEIGHTS.get(model, 1.0)
|
435 |
+
if model in SAMPLING_BOOST_MODELS:
|
436 |
+
weight *= 5
|
437 |
+
return weight
|
438 |
+
|
439 |
+
|
440 |
+
def get_battle_pair():
|
441 |
+
if len(models) == 1:
|
442 |
+
return models[0], models[0]
|
443 |
+
|
444 |
+
model_weights = []
|
445 |
+
for model in models:
|
446 |
+
weight = get_sample_weight(model)
|
447 |
+
model_weights.append(weight)
|
448 |
+
total_weight = np.sum(model_weights)
|
449 |
+
model_weights = model_weights / total_weight
|
450 |
+
chosen_idx = np.random.choice(len(models), p=model_weights)
|
451 |
+
chosen_model = models[chosen_idx]
|
452 |
+
# for p, w in zip(models, model_weights):
|
453 |
+
# print(p, w)
|
454 |
+
|
455 |
+
rival_models = []
|
456 |
+
rival_weights = []
|
457 |
+
for model in models:
|
458 |
+
if model == chosen_model:
|
459 |
+
continue
|
460 |
+
weight = get_sample_weight(model)
|
461 |
+
if (
|
462 |
+
weight != 0
|
463 |
+
and chosen_model in BATTLE_TARGETS
|
464 |
+
and model in BATTLE_TARGETS[chosen_model]
|
465 |
+
):
|
466 |
+
# boost to 50% chance
|
467 |
+
weight = total_weight / len(BATTLE_TARGETS[chosen_model])
|
468 |
+
rival_models.append(model)
|
469 |
+
rival_weights.append(weight)
|
470 |
+
# for p, w in zip(rival_models, rival_weights):
|
471 |
+
# print(p, w)
|
472 |
+
rival_weights = rival_weights / np.sum(rival_weights)
|
473 |
+
rival_idx = np.random.choice(len(rival_models), p=rival_weights)
|
474 |
+
rival_model = rival_models[rival_idx]
|
475 |
+
|
476 |
+
swap = np.random.randint(2)
|
477 |
+
if swap == 0:
|
478 |
+
return chosen_model, rival_model
|
479 |
+
else:
|
480 |
+
return rival_model, chosen_model
|
481 |
+
|
482 |
+
|
483 |
+
def add_text(
|
484 |
+
state0, state1, model_selector0, model_selector1, text, request: gr.Request
|
485 |
+
):
|
486 |
+
ip = get_ip(request)
|
487 |
+
logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
|
488 |
+
states = [state0, state1]
|
489 |
+
model_selectors = [model_selector0, model_selector1]
|
490 |
+
|
491 |
+
# Init states if necessary
|
492 |
+
if states[0] is None:
|
493 |
+
assert states[1] is None
|
494 |
+
|
495 |
+
model_left, model_right = get_battle_pair()
|
496 |
+
states = [
|
497 |
+
State(model_left),
|
498 |
+
State(model_right),
|
499 |
+
]
|
500 |
+
|
501 |
+
if len(text) <= 0:
|
502 |
+
for i in range(num_sides):
|
503 |
+
states[i].skip_next = True
|
504 |
+
return (
|
505 |
+
states
|
506 |
+
+ [x.to_gradio_chatbot() for x in states]
|
507 |
+
+ [""]
|
508 |
+
+ [
|
509 |
+
no_change_btn,
|
510 |
+
]
|
511 |
+
* 6
|
512 |
+
+ [""]
|
513 |
+
)
|
514 |
+
|
515 |
+
model_list = [states[i].model_name for i in range(num_sides)]
|
516 |
+
flagged = moderation_filter(text, model_list)
|
517 |
+
if flagged:
|
518 |
+
logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
|
519 |
+
# overwrite the original text
|
520 |
+
text = MODERATION_MSG
|
521 |
+
|
522 |
+
conv = states[0].conv
|
523 |
+
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
|
524 |
+
logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
|
525 |
+
for i in range(num_sides):
|
526 |
+
states[i].skip_next = True
|
527 |
+
return (
|
528 |
+
states
|
529 |
+
+ [x.to_gradio_chatbot() for x in states]
|
530 |
+
+ [CONVERSATION_LIMIT_MSG]
|
531 |
+
+ [
|
532 |
+
no_change_btn,
|
533 |
+
]
|
534 |
+
* 6
|
535 |
+
+ [""]
|
536 |
+
)
|
537 |
+
|
538 |
+
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
|
539 |
+
for i in range(num_sides):
|
540 |
+
states[i].conv.append_message(states[i].conv.roles[0], text)
|
541 |
+
states[i].conv.append_message(states[i].conv.roles[1], None)
|
542 |
+
states[i].skip_next = False
|
543 |
+
|
544 |
+
hint_msg = ""
|
545 |
+
for i in range(num_sides):
|
546 |
+
if "deluxe" in states[i].model_name:
|
547 |
+
hint_msg = SLOW_MODEL_MSG
|
548 |
+
return (
|
549 |
+
states
|
550 |
+
+ [x.to_gradio_chatbot() for x in states]
|
551 |
+
+ [""]
|
552 |
+
+ [
|
553 |
+
disable_btn,
|
554 |
+
]
|
555 |
+
* 6
|
556 |
+
+ [hint_msg]
|
557 |
+
)
|
558 |
+
|
559 |
+
|
560 |
+
def bot_response_multi(
|
561 |
+
state0,
|
562 |
+
state1,
|
563 |
+
temperature,
|
564 |
+
top_p,
|
565 |
+
max_new_tokens,
|
566 |
+
request: gr.Request,
|
567 |
+
):
|
568 |
+
logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")
|
569 |
+
|
570 |
+
if state0 is None or state0.skip_next:
|
571 |
+
# This generate call is skipped due to invalid inputs
|
572 |
+
yield (
|
573 |
+
state0,
|
574 |
+
state1,
|
575 |
+
state0.to_gradio_chatbot(),
|
576 |
+
state1.to_gradio_chatbot(),
|
577 |
+
) + (no_change_btn,) * 6
|
578 |
+
return
|
579 |
+
|
580 |
+
states = [state0, state1]
|
581 |
+
gen = []
|
582 |
+
for i in range(num_sides):
|
583 |
+
gen.append(
|
584 |
+
bot_response(
|
585 |
+
states[i],
|
586 |
+
temperature,
|
587 |
+
top_p,
|
588 |
+
max_new_tokens,
|
589 |
+
request,
|
590 |
+
apply_rate_limit=False,
|
591 |
+
)
|
592 |
+
)
|
593 |
+
|
594 |
+
is_gemini = []
|
595 |
+
for i in range(num_sides):
|
596 |
+
is_gemini.append(states[i].model_name in ["gemini-pro", "gemini-pro-dev-api"])
|
597 |
+
chatbots = [None] * num_sides
|
598 |
+
iters = 0
|
599 |
+
while True:
|
600 |
+
stop = True
|
601 |
+
iters += 1
|
602 |
+
for i in range(num_sides):
|
603 |
+
try:
|
604 |
+
# yield gemini fewer times as its chunk size is larger
|
605 |
+
# otherwise, gemini will stream too fast
|
606 |
+
if not is_gemini[i] or (iters % 30 == 1 or iters < 3):
|
607 |
+
ret = next(gen[i])
|
608 |
+
states[i], chatbots[i] = ret[0], ret[1]
|
609 |
+
stop = False
|
610 |
+
except StopIteration:
|
611 |
+
pass
|
612 |
+
yield states + chatbots + [disable_btn] * 6
|
613 |
+
if stop:
|
614 |
+
break
|
615 |
+
|
616 |
+
|
617 |
+
def build_side_by_side_ui_anony(models):
|
618 |
+
notice_markdown = """
|
619 |
+
# ⚔️ Chatbot Arena: Benchmarking LLMs in the Wild
|
620 |
+
| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
|
621 |
+
|
622 |
+
## 📜 Rules
|
623 |
+
- Ask any question to two anonymous models (e.g., ChatGPT, Claude, Llama) and vote for the better one!
|
624 |
+
- You can continue chatting until you identify a winner.
|
625 |
+
- Vote won't be counted if model identity is revealed during conversation.
|
626 |
+
|
627 |
+
## 🏆 Arena Elo [Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard)
|
628 |
+
We collect **200K+** human votes to compute an Elo-based LLM leaderboard.
|
629 |
+
Find out who is the 🥇LLM Champion!
|
630 |
+
|
631 |
+
## 👇 Chat now!
|
632 |
+
"""
|
633 |
+
|
634 |
+
states = [gr.State() for _ in range(num_sides)]
|
635 |
+
model_selectors = [None] * num_sides
|
636 |
+
chatbots = [None] * num_sides
|
637 |
+
|
638 |
+
gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
639 |
+
|
640 |
+
with gr.Group(elem_id="share-region-anony"):
|
641 |
+
with gr.Accordion(
|
642 |
+
f"🔍 Expand to see the descriptions of {len(models)} models", open=False
|
643 |
+
):
|
644 |
+
model_description_md = get_model_description_md(models)
|
645 |
+
gr.Markdown(model_description_md, elem_id="model_description_markdown")
|
646 |
+
with gr.Row():
|
647 |
+
for i in range(num_sides):
|
648 |
+
label = "Model A" if i == 0 else "Model B"
|
649 |
+
with gr.Column():
|
650 |
+
chatbots[i] = gr.Chatbot(
|
651 |
+
label=label,
|
652 |
+
elem_id="chatbot",
|
653 |
+
height=550,
|
654 |
+
show_copy_button=True,
|
655 |
+
)
|
656 |
+
|
657 |
+
with gr.Row():
|
658 |
+
for i in range(num_sides):
|
659 |
+
with gr.Column():
|
660 |
+
model_selectors[i] = gr.Markdown(
|
661 |
+
anony_names[i], elem_id="model_selector_md"
|
662 |
+
)
|
663 |
+
with gr.Row():
|
664 |
+
slow_warning = gr.Markdown("", elem_id="notice_markdown")
|
665 |
+
|
666 |
+
with gr.Row():
|
667 |
+
leftvote_btn = gr.Button(
|
668 |
+
value="👈 A is better", visible=False, interactive=False
|
669 |
+
)
|
670 |
+
rightvote_btn = gr.Button(
|
671 |
+
value="👉 B is better", visible=False, interactive=False
|
672 |
+
)
|
673 |
+
tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
|
674 |
+
bothbad_btn = gr.Button(
|
675 |
+
value="👎 Both are bad", visible=False, interactive=False
|
676 |
+
)
|
677 |
+
|
678 |
+
with gr.Row():
|
679 |
+
textbox = gr.Textbox(
|
680 |
+
show_label=False,
|
681 |
+
placeholder="👉 Enter your prompt and press ENTER",
|
682 |
+
elem_id="input_box",
|
683 |
+
)
|
684 |
+
send_btn = gr.Button(value="Send", variant="primary", scale=0)
|
685 |
+
|
686 |
+
with gr.Row() as button_row:
|
687 |
+
clear_btn = gr.Button(value="🎲 New Round", interactive=False)
|
688 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
689 |
+
share_btn = gr.Button(value="📷 Share")
|
690 |
+
|
691 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
692 |
+
temperature = gr.Slider(
|
693 |
+
minimum=0.0,
|
694 |
+
maximum=1.0,
|
695 |
+
value=0.7,
|
696 |
+
step=0.1,
|
697 |
+
interactive=True,
|
698 |
+
label="Temperature",
|
699 |
+
)
|
700 |
+
top_p = gr.Slider(
|
701 |
+
minimum=0.0,
|
702 |
+
maximum=1.0,
|
703 |
+
value=1.0,
|
704 |
+
step=0.1,
|
705 |
+
interactive=True,
|
706 |
+
label="Top P",
|
707 |
+
)
|
708 |
+
max_output_tokens = gr.Slider(
|
709 |
+
minimum=16,
|
710 |
+
maximum=2048,
|
711 |
+
value=1024,
|
712 |
+
step=64,
|
713 |
+
interactive=True,
|
714 |
+
label="Max output tokens",
|
715 |
+
)
|
716 |
+
|
717 |
+
gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
|
718 |
+
|
719 |
+
# Register listeners
|
720 |
+
btn_list = [
|
721 |
+
leftvote_btn,
|
722 |
+
rightvote_btn,
|
723 |
+
tie_btn,
|
724 |
+
bothbad_btn,
|
725 |
+
regenerate_btn,
|
726 |
+
clear_btn,
|
727 |
+
]
|
728 |
+
leftvote_btn.click(
|
729 |
+
leftvote_last_response,
|
730 |
+
states + model_selectors,
|
731 |
+
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
732 |
+
)
|
733 |
+
rightvote_btn.click(
|
734 |
+
rightvote_last_response,
|
735 |
+
states + model_selectors,
|
736 |
+
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
737 |
+
)
|
738 |
+
tie_btn.click(
|
739 |
+
tievote_last_response,
|
740 |
+
states + model_selectors,
|
741 |
+
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
742 |
+
)
|
743 |
+
bothbad_btn.click(
|
744 |
+
bothbad_vote_last_response,
|
745 |
+
states + model_selectors,
|
746 |
+
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
747 |
+
)
|
748 |
+
regenerate_btn.click(
|
749 |
+
regenerate, states, states + chatbots + [textbox] + btn_list
|
750 |
+
).then(
|
751 |
+
bot_response_multi,
|
752 |
+
states + [temperature, top_p, max_output_tokens],
|
753 |
+
states + chatbots + btn_list,
|
754 |
+
).then(
|
755 |
+
flash_buttons, [], btn_list
|
756 |
+
)
|
757 |
+
clear_btn.click(
|
758 |
+
clear_history,
|
759 |
+
None,
|
760 |
+
states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning],
|
761 |
+
)
|
762 |
+
|
763 |
+
share_js = """
|
764 |
+
function (a, b, c, d) {
|
765 |
+
const captureElement = document.querySelector('#share-region-anony');
|
766 |
+
html2canvas(captureElement)
|
767 |
+
.then(canvas => {
|
768 |
+
canvas.style.display = 'none'
|
769 |
+
document.body.appendChild(canvas)
|
770 |
+
return canvas
|
771 |
+
})
|
772 |
+
.then(canvas => {
|
773 |
+
const image = canvas.toDataURL('image/png')
|
774 |
+
const a = document.createElement('a')
|
775 |
+
a.setAttribute('download', 'chatbot-arena.png')
|
776 |
+
a.setAttribute('href', image)
|
777 |
+
a.click()
|
778 |
+
canvas.remove()
|
779 |
+
});
|
780 |
+
return [a, b, c, d];
|
781 |
+
}
|
782 |
+
"""
|
783 |
+
share_btn.click(share_click, states + model_selectors, [], js=share_js)
|
784 |
+
|
785 |
+
textbox.submit(
|
786 |
+
add_text,
|
787 |
+
states + model_selectors + [textbox],
|
788 |
+
states + chatbots + [textbox] + btn_list + [slow_warning],
|
789 |
+
).then(
|
790 |
+
bot_response_multi,
|
791 |
+
states + [temperature, top_p, max_output_tokens],
|
792 |
+
states + chatbots + btn_list,
|
793 |
+
).then(
|
794 |
+
flash_buttons,
|
795 |
+
[],
|
796 |
+
btn_list,
|
797 |
+
)
|
798 |
+
|
799 |
+
send_btn.click(
|
800 |
+
add_text,
|
801 |
+
states + model_selectors + [textbox],
|
802 |
+
states + chatbots + [textbox] + btn_list,
|
803 |
+
).then(
|
804 |
+
bot_response_multi,
|
805 |
+
states + [temperature, top_p, max_output_tokens],
|
806 |
+
states + chatbots + btn_list,
|
807 |
+
).then(
|
808 |
+
flash_buttons, [], btn_list
|
809 |
+
)
|
810 |
+
|
811 |
+
return states + model_selectors
|
gradio_block_arena_named.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Chatbot Arena (side-by-side) tab.
|
3 |
+
Users chat with two chosen models.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from fastchat.constants import (
|
13 |
+
MODERATION_MSG,
|
14 |
+
CONVERSATION_LIMIT_MSG,
|
15 |
+
INPUT_CHAR_LEN_LIMIT,
|
16 |
+
CONVERSATION_TURN_LIMIT,
|
17 |
+
)
|
18 |
+
from fastchat.model.model_adapter import get_conversation_template
|
19 |
+
from fastchat.serve.gradio_web_server import (
|
20 |
+
State,
|
21 |
+
bot_response,
|
22 |
+
get_conv_log_filename,
|
23 |
+
no_change_btn,
|
24 |
+
enable_btn,
|
25 |
+
disable_btn,
|
26 |
+
invisible_btn,
|
27 |
+
acknowledgment_md,
|
28 |
+
get_ip,
|
29 |
+
get_model_description_md,
|
30 |
+
)
|
31 |
+
from fastchat.utils import (
|
32 |
+
build_logger,
|
33 |
+
moderation_filter,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
|
38 |
+
|
39 |
+
num_sides = 2
|
40 |
+
enable_moderation = False
|
41 |
+
|
42 |
+
|
43 |
+
def set_global_vars_named(enable_moderation_):
|
44 |
+
global enable_moderation
|
45 |
+
enable_moderation = enable_moderation_
|
46 |
+
|
47 |
+
|
48 |
+
def load_demo_side_by_side_named(models, url_params):
|
49 |
+
states = (None,) * num_sides
|
50 |
+
|
51 |
+
model_left = models[0] if len(models) > 0 else ""
|
52 |
+
if len(models) > 1:
|
53 |
+
weights = ([8] * 4 + [4] * 8 + [1] * 32)[: len(models) - 1]
|
54 |
+
weights = weights / np.sum(weights)
|
55 |
+
model_right = np.random.choice(models[1:], p=weights)
|
56 |
+
else:
|
57 |
+
model_right = model_left
|
58 |
+
|
59 |
+
selector_updates = (
|
60 |
+
gr.Dropdown(choices=models, value=model_left, visible=True),
|
61 |
+
gr.Dropdown(choices=models, value=model_right, visible=True),
|
62 |
+
)
|
63 |
+
|
64 |
+
return states + selector_updates
|
65 |
+
|
66 |
+
|
67 |
+
def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
|
68 |
+
with open(get_conv_log_filename(), "a") as fout:
|
69 |
+
data = {
|
70 |
+
"tstamp": round(time.time(), 4),
|
71 |
+
"type": vote_type,
|
72 |
+
"models": [x for x in model_selectors],
|
73 |
+
"states": [x.dict() for x in states],
|
74 |
+
"ip": get_ip(request),
|
75 |
+
}
|
76 |
+
fout.write(json.dumps(data) + "\n")
|
77 |
+
|
78 |
+
|
79 |
+
def leftvote_last_response(
|
80 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
81 |
+
):
|
82 |
+
logger.info(f"leftvote (named). ip: {get_ip(request)}")
|
83 |
+
vote_last_response(
|
84 |
+
[state0, state1], "leftvote", [model_selector0, model_selector1], request
|
85 |
+
)
|
86 |
+
return ("",) + (disable_btn,) * 4
|
87 |
+
|
88 |
+
|
89 |
+
def rightvote_last_response(
|
90 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
91 |
+
):
|
92 |
+
logger.info(f"rightvote (named). ip: {get_ip(request)}")
|
93 |
+
vote_last_response(
|
94 |
+
[state0, state1], "rightvote", [model_selector0, model_selector1], request
|
95 |
+
)
|
96 |
+
return ("",) + (disable_btn,) * 4
|
97 |
+
|
98 |
+
|
99 |
+
def tievote_last_response(
|
100 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
101 |
+
):
|
102 |
+
logger.info(f"tievote (named). ip: {get_ip(request)}")
|
103 |
+
vote_last_response(
|
104 |
+
[state0, state1], "tievote", [model_selector0, model_selector1], request
|
105 |
+
)
|
106 |
+
return ("",) + (disable_btn,) * 4
|
107 |
+
|
108 |
+
|
109 |
+
def bothbad_vote_last_response(
|
110 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
111 |
+
):
|
112 |
+
logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
|
113 |
+
vote_last_response(
|
114 |
+
[state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
|
115 |
+
)
|
116 |
+
return ("",) + (disable_btn,) * 4
|
117 |
+
|
118 |
+
|
119 |
+
def regenerate(state0, state1, request: gr.Request):
|
120 |
+
logger.info(f"regenerate (named). ip: {get_ip(request)}")
|
121 |
+
states = [state0, state1]
|
122 |
+
for i in range(num_sides):
|
123 |
+
states[i].conv.update_last_message(None)
|
124 |
+
return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
|
125 |
+
|
126 |
+
|
127 |
+
def clear_history(request: gr.Request):
|
128 |
+
logger.info(f"clear_history (named). ip: {get_ip(request)}")
|
129 |
+
return (
|
130 |
+
[None] * num_sides
|
131 |
+
+ [None] * num_sides
|
132 |
+
+ [""]
|
133 |
+
+ [invisible_btn] * 4
|
134 |
+
+ [disable_btn] * 2
|
135 |
+
)
|
136 |
+
|
137 |
+
|
138 |
+
def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
|
139 |
+
logger.info(f"share (named). ip: {get_ip(request)}")
|
140 |
+
if state0 is not None and state1 is not None:
|
141 |
+
vote_last_response(
|
142 |
+
[state0, state1], "share", [model_selector0, model_selector1], request
|
143 |
+
)
|
144 |
+
|
145 |
+
|
146 |
+
def add_text(
|
147 |
+
state0, state1, model_selector0, model_selector1, text, request: gr.Request
|
148 |
+
):
|
149 |
+
ip = get_ip(request)
|
150 |
+
logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
|
151 |
+
states = [state0, state1]
|
152 |
+
model_selectors = [model_selector0, model_selector1]
|
153 |
+
|
154 |
+
# Init states if necessary
|
155 |
+
for i in range(num_sides):
|
156 |
+
if states[i] is None:
|
157 |
+
states[i] = State(model_selectors[i])
|
158 |
+
|
159 |
+
if len(text) <= 0:
|
160 |
+
for i in range(num_sides):
|
161 |
+
states[i].skip_next = True
|
162 |
+
return (
|
163 |
+
states
|
164 |
+
+ [x.to_gradio_chatbot() for x in states]
|
165 |
+
+ [""]
|
166 |
+
+ [
|
167 |
+
no_change_btn,
|
168 |
+
]
|
169 |
+
* 6
|
170 |
+
)
|
171 |
+
|
172 |
+
model_list = [states[i].model_name for i in range(num_sides)]
|
173 |
+
flagged = moderation_filter(text, model_list)
|
174 |
+
if flagged:
|
175 |
+
logger.info(f"violate moderation (named). ip: {ip}. text: {text}")
|
176 |
+
# overwrite the original text
|
177 |
+
text = MODERATION_MSG
|
178 |
+
|
179 |
+
conv = states[0].conv
|
180 |
+
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
|
181 |
+
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
|
182 |
+
for i in range(num_sides):
|
183 |
+
states[i].skip_next = True
|
184 |
+
return (
|
185 |
+
states
|
186 |
+
+ [x.to_gradio_chatbot() for x in states]
|
187 |
+
+ [CONVERSATION_LIMIT_MSG]
|
188 |
+
+ [
|
189 |
+
no_change_btn,
|
190 |
+
]
|
191 |
+
* 6
|
192 |
+
)
|
193 |
+
|
194 |
+
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
|
195 |
+
for i in range(num_sides):
|
196 |
+
states[i].conv.append_message(states[i].conv.roles[0], text)
|
197 |
+
states[i].conv.append_message(states[i].conv.roles[1], None)
|
198 |
+
states[i].skip_next = False
|
199 |
+
|
200 |
+
return (
|
201 |
+
states
|
202 |
+
+ [x.to_gradio_chatbot() for x in states]
|
203 |
+
+ [""]
|
204 |
+
+ [
|
205 |
+
disable_btn,
|
206 |
+
]
|
207 |
+
* 6
|
208 |
+
)
|
209 |
+
|
210 |
+
|
211 |
+
def bot_response_multi(
|
212 |
+
state0,
|
213 |
+
state1,
|
214 |
+
temperature,
|
215 |
+
top_p,
|
216 |
+
max_new_tokens,
|
217 |
+
request: gr.Request,
|
218 |
+
):
|
219 |
+
logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
|
220 |
+
|
221 |
+
if state0.skip_next:
|
222 |
+
# This generate call is skipped due to invalid inputs
|
223 |
+
yield (
|
224 |
+
state0,
|
225 |
+
state1,
|
226 |
+
state0.to_gradio_chatbot(),
|
227 |
+
state1.to_gradio_chatbot(),
|
228 |
+
) + (no_change_btn,) * 6
|
229 |
+
return
|
230 |
+
|
231 |
+
states = [state0, state1]
|
232 |
+
gen = []
|
233 |
+
for i in range(num_sides):
|
234 |
+
gen.append(
|
235 |
+
bot_response(
|
236 |
+
states[i],
|
237 |
+
temperature,
|
238 |
+
top_p,
|
239 |
+
max_new_tokens,
|
240 |
+
request,
|
241 |
+
)
|
242 |
+
)
|
243 |
+
|
244 |
+
is_gemini = []
|
245 |
+
for i in range(num_sides):
|
246 |
+
is_gemini.append(states[i].model_name in ["gemini-pro", "gemini-pro-dev-api"])
|
247 |
+
|
248 |
+
chatbots = [None] * num_sides
|
249 |
+
iters = 0
|
250 |
+
while True:
|
251 |
+
stop = True
|
252 |
+
iters += 1
|
253 |
+
for i in range(num_sides):
|
254 |
+
try:
|
255 |
+
# yield gemini fewer times as its chunk size is larger
|
256 |
+
# otherwise, gemini will stream too fast
|
257 |
+
if not is_gemini[i] or (iters % 30 == 1 or iters < 3):
|
258 |
+
ret = next(gen[i])
|
259 |
+
states[i], chatbots[i] = ret[0], ret[1]
|
260 |
+
stop = False
|
261 |
+
except StopIteration:
|
262 |
+
pass
|
263 |
+
yield states + chatbots + [disable_btn] * 6
|
264 |
+
if stop:
|
265 |
+
break
|
266 |
+
|
267 |
+
|
268 |
+
def flash_buttons():
|
269 |
+
btn_updates = [
|
270 |
+
[disable_btn] * 4 + [enable_btn] * 2,
|
271 |
+
[enable_btn] * 6,
|
272 |
+
]
|
273 |
+
for i in range(4):
|
274 |
+
yield btn_updates[i % 2]
|
275 |
+
time.sleep(0.3)
|
276 |
+
|
277 |
+
|
278 |
+
def build_side_by_side_ui_named(models):
|
279 |
+
notice_markdown = """
|
280 |
+
# ⚔️ Chatbot Arena: Benchmarking LLMs in the Wild
|
281 |
+
| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
|
282 |
+
|
283 |
+
## 📜 Rules
|
284 |
+
- Chat with any two models side-by-side and vote!
|
285 |
+
- You can continue chatting for multiple rounds.
|
286 |
+
- Click "Clear history" to start a new round.
|
287 |
+
|
288 |
+
## 🤖 Choose two models to compare
|
289 |
+
"""
|
290 |
+
|
291 |
+
states = [gr.State() for _ in range(num_sides)]
|
292 |
+
model_selectors = [None] * num_sides
|
293 |
+
chatbots = [None] * num_sides
|
294 |
+
|
295 |
+
notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
296 |
+
|
297 |
+
with gr.Group(elem_id="share-region-named"):
|
298 |
+
with gr.Row():
|
299 |
+
for i in range(num_sides):
|
300 |
+
with gr.Column():
|
301 |
+
model_selectors[i] = gr.Dropdown(
|
302 |
+
choices=models,
|
303 |
+
value=models[i] if len(models) > i else "",
|
304 |
+
interactive=True,
|
305 |
+
show_label=False,
|
306 |
+
container=False,
|
307 |
+
)
|
308 |
+
with gr.Row():
|
309 |
+
with gr.Accordion(
|
310 |
+
f"🔍 Expand to see the descriptions of {len(models)} models", open=False
|
311 |
+
):
|
312 |
+
model_description_md = get_model_description_md(models)
|
313 |
+
gr.Markdown(model_description_md, elem_id="model_description_markdown")
|
314 |
+
|
315 |
+
with gr.Row():
|
316 |
+
for i in range(num_sides):
|
317 |
+
label = "Model A" if i == 0 else "Model B"
|
318 |
+
with gr.Column():
|
319 |
+
chatbots[i] = gr.Chatbot(
|
320 |
+
label=label,
|
321 |
+
elem_id=f"chatbot",
|
322 |
+
height=550,
|
323 |
+
show_copy_button=True,
|
324 |
+
)
|
325 |
+
|
326 |
+
with gr.Row():
|
327 |
+
leftvote_btn = gr.Button(
|
328 |
+
value="👈 A is better", visible=False, interactive=False
|
329 |
+
)
|
330 |
+
rightvote_btn = gr.Button(
|
331 |
+
value="👉 B is better", visible=False, interactive=False
|
332 |
+
)
|
333 |
+
tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
|
334 |
+
bothbad_btn = gr.Button(
|
335 |
+
value="👎 Both are bad", visible=False, interactive=False
|
336 |
+
)
|
337 |
+
|
338 |
+
with gr.Row():
|
339 |
+
textbox = gr.Textbox(
|
340 |
+
show_label=False,
|
341 |
+
placeholder="👉 Enter your prompt and press ENTER",
|
342 |
+
elem_id="input_box",
|
343 |
+
)
|
344 |
+
send_btn = gr.Button(value="Send", variant="primary", scale=0)
|
345 |
+
|
346 |
+
with gr.Row() as button_row:
|
347 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
348 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
349 |
+
share_btn = gr.Button(value="📷 Share")
|
350 |
+
|
351 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
352 |
+
temperature = gr.Slider(
|
353 |
+
minimum=0.0,
|
354 |
+
maximum=1.0,
|
355 |
+
value=0.7,
|
356 |
+
step=0.1,
|
357 |
+
interactive=True,
|
358 |
+
label="Temperature",
|
359 |
+
)
|
360 |
+
top_p = gr.Slider(
|
361 |
+
minimum=0.0,
|
362 |
+
maximum=1.0,
|
363 |
+
value=1.0,
|
364 |
+
step=0.1,
|
365 |
+
interactive=True,
|
366 |
+
label="Top P",
|
367 |
+
)
|
368 |
+
max_output_tokens = gr.Slider(
|
369 |
+
minimum=16,
|
370 |
+
maximum=2048,
|
371 |
+
value=1024,
|
372 |
+
step=64,
|
373 |
+
interactive=True,
|
374 |
+
label="Max output tokens",
|
375 |
+
)
|
376 |
+
|
377 |
+
gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
|
378 |
+
|
379 |
+
# Register listeners
|
380 |
+
btn_list = [
|
381 |
+
leftvote_btn,
|
382 |
+
rightvote_btn,
|
383 |
+
tie_btn,
|
384 |
+
bothbad_btn,
|
385 |
+
regenerate_btn,
|
386 |
+
clear_btn,
|
387 |
+
]
|
388 |
+
leftvote_btn.click(
|
389 |
+
leftvote_last_response,
|
390 |
+
states + model_selectors,
|
391 |
+
[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
392 |
+
)
|
393 |
+
rightvote_btn.click(
|
394 |
+
rightvote_last_response,
|
395 |
+
states + model_selectors,
|
396 |
+
[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
397 |
+
)
|
398 |
+
tie_btn.click(
|
399 |
+
tievote_last_response,
|
400 |
+
states + model_selectors,
|
401 |
+
[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
402 |
+
)
|
403 |
+
bothbad_btn.click(
|
404 |
+
bothbad_vote_last_response,
|
405 |
+
states + model_selectors,
|
406 |
+
[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
407 |
+
)
|
408 |
+
regenerate_btn.click(
|
409 |
+
regenerate, states, states + chatbots + [textbox] + btn_list
|
410 |
+
).then(
|
411 |
+
bot_response_multi,
|
412 |
+
states + [temperature, top_p, max_output_tokens],
|
413 |
+
states + chatbots + btn_list,
|
414 |
+
).then(
|
415 |
+
flash_buttons, [], btn_list
|
416 |
+
)
|
417 |
+
clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)
|
418 |
+
|
419 |
+
share_js = """
|
420 |
+
function (a, b, c, d) {
|
421 |
+
const captureElement = document.querySelector('#share-region-named');
|
422 |
+
html2canvas(captureElement)
|
423 |
+
.then(canvas => {
|
424 |
+
canvas.style.display = 'none'
|
425 |
+
document.body.appendChild(canvas)
|
426 |
+
return canvas
|
427 |
+
})
|
428 |
+
.then(canvas => {
|
429 |
+
const image = canvas.toDataURL('image/png')
|
430 |
+
const a = document.createElement('a')
|
431 |
+
a.setAttribute('download', 'chatbot-arena.png')
|
432 |
+
a.setAttribute('href', image)
|
433 |
+
a.click()
|
434 |
+
canvas.remove()
|
435 |
+
});
|
436 |
+
return [a, b, c, d];
|
437 |
+
}
|
438 |
+
"""
|
439 |
+
share_btn.click(share_click, states + model_selectors, [], js=share_js)
|
440 |
+
|
441 |
+
for i in range(num_sides):
|
442 |
+
model_selectors[i].change(
|
443 |
+
clear_history, None, states + chatbots + [textbox] + btn_list
|
444 |
+
)
|
445 |
+
|
446 |
+
textbox.submit(
|
447 |
+
add_text,
|
448 |
+
states + model_selectors + [textbox],
|
449 |
+
states + chatbots + [textbox] + btn_list,
|
450 |
+
).then(
|
451 |
+
bot_response_multi,
|
452 |
+
states + [temperature, top_p, max_output_tokens],
|
453 |
+
states + chatbots + btn_list,
|
454 |
+
).then(
|
455 |
+
flash_buttons, [], btn_list
|
456 |
+
)
|
457 |
+
send_btn.click(
|
458 |
+
add_text,
|
459 |
+
states + model_selectors + [textbox],
|
460 |
+
states + chatbots + [textbox] + btn_list,
|
461 |
+
).then(
|
462 |
+
bot_response_multi,
|
463 |
+
states + [temperature, top_p, max_output_tokens],
|
464 |
+
states + chatbots + btn_list,
|
465 |
+
).then(
|
466 |
+
flash_buttons, [], btn_list
|
467 |
+
)
|
468 |
+
|
469 |
+
return states + model_selectors
|
gradio_block_arena_vision.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The gradio demo server for chatting with a large multimodal model.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 -m fastchat.serve.controller
|
6 |
+
python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf
|
7 |
+
python3 -m fastchat.serve.gradio_web_server_multi --share --multimodal
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
|
14 |
+
from fastchat.serve.gradio_web_server import (
|
15 |
+
upvote_last_response,
|
16 |
+
downvote_last_response,
|
17 |
+
flag_last_response,
|
18 |
+
get_model_description_md,
|
19 |
+
acknowledgment_md,
|
20 |
+
bot_response,
|
21 |
+
add_text,
|
22 |
+
clear_history,
|
23 |
+
regenerate,
|
24 |
+
)
|
25 |
+
from fastchat.utils import (
|
26 |
+
build_logger,
|
27 |
+
)
|
28 |
+
|
29 |
+
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
|
30 |
+
|
31 |
+
|
32 |
+
def build_single_vision_language_model_ui(models, add_promotion_links=False):
|
33 |
+
promotion = (
|
34 |
+
"""
|
35 |
+
| [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
|
36 |
+
"""
|
37 |
+
if add_promotion_links
|
38 |
+
else ""
|
39 |
+
)
|
40 |
+
|
41 |
+
notice_markdown = f"""
|
42 |
+
# 🏔️ Chat with Open Large Vision-Language Models
|
43 |
+
{promotion}
|
44 |
+
"""
|
45 |
+
|
46 |
+
state = gr.State()
|
47 |
+
gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
48 |
+
|
49 |
+
with gr.Group():
|
50 |
+
with gr.Row(elem_id="model_selector_row"):
|
51 |
+
model_selector = gr.Dropdown(
|
52 |
+
choices=models,
|
53 |
+
value=models[0] if len(models) > 0 else "",
|
54 |
+
interactive=True,
|
55 |
+
show_label=False,
|
56 |
+
container=False,
|
57 |
+
)
|
58 |
+
|
59 |
+
with gr.Accordion(
|
60 |
+
f"🔍 Expand to see the descriptions of {len(models)} models", open=False
|
61 |
+
):
|
62 |
+
model_description_md = get_model_description_md(models)
|
63 |
+
gr.Markdown(model_description_md, elem_id="model_description_markdown")
|
64 |
+
|
65 |
+
with gr.Row():
|
66 |
+
with gr.Column(scale=3):
|
67 |
+
textbox = gr.Textbox(
|
68 |
+
show_label=False,
|
69 |
+
placeholder="👉 Enter your prompt and press ENTER",
|
70 |
+
container=False,
|
71 |
+
render=False,
|
72 |
+
elem_id="input_box",
|
73 |
+
)
|
74 |
+
imagebox = gr.Image(type="pil")
|
75 |
+
|
76 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
77 |
+
|
78 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
79 |
+
temperature = gr.Slider(
|
80 |
+
minimum=0.0,
|
81 |
+
maximum=1.0,
|
82 |
+
value=0.2,
|
83 |
+
step=0.1,
|
84 |
+
interactive=True,
|
85 |
+
label="Temperature",
|
86 |
+
)
|
87 |
+
top_p = gr.Slider(
|
88 |
+
minimum=0.0,
|
89 |
+
maximum=1.0,
|
90 |
+
value=0.7,
|
91 |
+
step=0.1,
|
92 |
+
interactive=True,
|
93 |
+
label="Top P",
|
94 |
+
)
|
95 |
+
max_output_tokens = gr.Slider(
|
96 |
+
minimum=0,
|
97 |
+
maximum=1024,
|
98 |
+
value=512,
|
99 |
+
step=64,
|
100 |
+
interactive=True,
|
101 |
+
label="Max output tokens",
|
102 |
+
)
|
103 |
+
|
104 |
+
gr.Examples(
|
105 |
+
examples=[
|
106 |
+
[
|
107 |
+
f"{cur_dir}/example_images/city.jpeg",
|
108 |
+
"What is unusual about this image?",
|
109 |
+
],
|
110 |
+
[
|
111 |
+
f"{cur_dir}/example_images/fridge.jpeg",
|
112 |
+
"What is in this fridge?",
|
113 |
+
],
|
114 |
+
],
|
115 |
+
inputs=[imagebox, textbox],
|
116 |
+
)
|
117 |
+
|
118 |
+
with gr.Column(scale=8):
|
119 |
+
chatbot = gr.Chatbot(
|
120 |
+
elem_id="chatbot", label="Scroll down and start chatting", height=550
|
121 |
+
)
|
122 |
+
|
123 |
+
with gr.Row():
|
124 |
+
with gr.Column(scale=8):
|
125 |
+
textbox.render()
|
126 |
+
with gr.Column(scale=1, min_width=50):
|
127 |
+
send_btn = gr.Button(value="Send", variant="primary")
|
128 |
+
with gr.Row(elem_id="buttons"):
|
129 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
130 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
131 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
132 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
133 |
+
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
|
134 |
+
|
135 |
+
if add_promotion_links:
|
136 |
+
gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
|
137 |
+
|
138 |
+
# Register listeners
|
139 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
140 |
+
upvote_btn.click(
|
141 |
+
upvote_last_response,
|
142 |
+
[state, model_selector],
|
143 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
144 |
+
)
|
145 |
+
downvote_btn.click(
|
146 |
+
downvote_last_response,
|
147 |
+
[state, model_selector],
|
148 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
149 |
+
)
|
150 |
+
flag_btn.click(
|
151 |
+
flag_last_response,
|
152 |
+
[state, model_selector],
|
153 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
154 |
+
)
|
155 |
+
regenerate_btn.click(
|
156 |
+
regenerate, state, [state, chatbot, textbox, imagebox] + btn_list
|
157 |
+
).then(
|
158 |
+
bot_response,
|
159 |
+
[state, temperature, top_p, max_output_tokens],
|
160 |
+
[state, chatbot] + btn_list,
|
161 |
+
)
|
162 |
+
clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
|
163 |
+
|
164 |
+
model_selector.change(
|
165 |
+
clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
|
166 |
+
)
|
167 |
+
|
168 |
+
textbox.submit(
|
169 |
+
add_text,
|
170 |
+
[state, model_selector, textbox, imagebox],
|
171 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
172 |
+
).then(
|
173 |
+
bot_response,
|
174 |
+
[state, temperature, top_p, max_output_tokens],
|
175 |
+
[state, chatbot] + btn_list,
|
176 |
+
)
|
177 |
+
send_btn.click(
|
178 |
+
add_text,
|
179 |
+
[state, model_selector, textbox, imagebox],
|
180 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
181 |
+
).then(
|
182 |
+
bot_response,
|
183 |
+
[state, temperature, top_p, max_output_tokens],
|
184 |
+
[state, chatbot] + btn_list,
|
185 |
+
)
|
186 |
+
|
187 |
+
return [state, model_selector]
|
gradio_web_server.py
ADDED
@@ -0,0 +1,887 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The gradio demo server for chatting with a single model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
from collections import defaultdict
|
7 |
+
import datetime
|
8 |
+
import hashlib
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
import random
|
12 |
+
import time
|
13 |
+
import uuid
|
14 |
+
|
15 |
+
import gradio as gr
|
16 |
+
import requests
|
17 |
+
|
18 |
+
from fastchat.constants import (
|
19 |
+
LOGDIR,
|
20 |
+
WORKER_API_TIMEOUT,
|
21 |
+
ErrorCode,
|
22 |
+
MODERATION_MSG,
|
23 |
+
CONVERSATION_LIMIT_MSG,
|
24 |
+
RATE_LIMIT_MSG,
|
25 |
+
SERVER_ERROR_MSG,
|
26 |
+
INPUT_CHAR_LEN_LIMIT,
|
27 |
+
CONVERSATION_TURN_LIMIT,
|
28 |
+
SESSION_EXPIRATION_TIME,
|
29 |
+
)
|
30 |
+
from fastchat.model.model_adapter import (
|
31 |
+
get_conversation_template,
|
32 |
+
)
|
33 |
+
from fastchat.model.model_registry import get_model_info, model_info
|
34 |
+
from fastchat.serve.api_provider import get_api_provider_stream_iter
|
35 |
+
from fastchat.utils import (
|
36 |
+
build_logger,
|
37 |
+
get_window_url_params_js,
|
38 |
+
get_window_url_params_with_tos_js,
|
39 |
+
moderation_filter,
|
40 |
+
parse_gradio_auth_creds,
|
41 |
+
load_image,
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
46 |
+
|
47 |
+
headers = {"User-Agent": "FastChat Client"}
|
48 |
+
|
49 |
+
no_change_btn = gr.Button()
|
50 |
+
enable_btn = gr.Button(interactive=True, visible=True)
|
51 |
+
disable_btn = gr.Button(interactive=False)
|
52 |
+
invisible_btn = gr.Button(interactive=False, visible=False)
|
53 |
+
|
54 |
+
controller_url = None
|
55 |
+
enable_moderation = False
|
56 |
+
|
57 |
+
acknowledgment_md = """
|
58 |
+
### Terms of Service
|
59 |
+
|
60 |
+
Users are required to agree to the following terms before using the service:
|
61 |
+
|
62 |
+
The service is a research preview. It only provides limited safety measures and may generate offensive content.
|
63 |
+
It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
|
64 |
+
The service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license.
|
65 |
+
Additionally, Bard is offered on LMSys for research purposes only. To access the Bard product, please visit its [website](http://bard.google.com).
|
66 |
+
|
67 |
+
### Acknowledgment
|
68 |
+
We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous [sponsorship](https://lmsys.org/donations/).
|
69 |
+
|
70 |
+
<div class="sponsor-image-about">
|
71 |
+
<img src="https://storage.googleapis.com/public-arena-asset/kaggle.png" alt="Kaggle">
|
72 |
+
<img src="https://storage.googleapis.com/public-arena-asset/mbzuai.jpeg" alt="MBZUAI">
|
73 |
+
<img src="https://storage.googleapis.com/public-arena-asset/a16z.jpeg" alt="a16z">
|
74 |
+
<img src="https://storage.googleapis.com/public-arena-asset/together.png" alt="Together AI">
|
75 |
+
<img src="https://storage.googleapis.com/public-arena-asset/anyscale.png" alt="AnyScale">
|
76 |
+
<img src="https://storage.googleapis.com/public-arena-asset/huggingface.png" alt="HuggingFace">
|
77 |
+
</div>
|
78 |
+
"""
|
79 |
+
|
80 |
+
# JSON file format of API-based models:
|
81 |
+
# {
|
82 |
+
# "gpt-3.5-turbo-0613": {
|
83 |
+
# "model_name": "gpt-3.5-turbo-0613",
|
84 |
+
# "api_type": "openai",
|
85 |
+
# "api_base": "https://api.openai.com/v1",
|
86 |
+
# "api_key": "sk-******",
|
87 |
+
# "anony_only": false
|
88 |
+
# }
|
89 |
+
# }
|
90 |
+
# "api_type" can be one of the following: openai, anthropic, gemini, mistral.
|
91 |
+
# "anony_only" means whether to show this model in anonymous mode only.
|
92 |
+
api_endpoint_info = {}
|
93 |
+
|
94 |
+
|
95 |
+
class State:
|
96 |
+
def __init__(self, model_name):
|
97 |
+
self.conv = get_conversation_template(model_name)
|
98 |
+
self.conv_id = uuid.uuid4().hex
|
99 |
+
self.skip_next = False
|
100 |
+
self.model_name = model_name
|
101 |
+
|
102 |
+
def to_gradio_chatbot(self):
|
103 |
+
return self.conv.to_gradio_chatbot()
|
104 |
+
|
105 |
+
def dict(self):
|
106 |
+
base = self.conv.dict()
|
107 |
+
base.update(
|
108 |
+
{
|
109 |
+
"conv_id": self.conv_id,
|
110 |
+
"model_name": self.model_name,
|
111 |
+
}
|
112 |
+
)
|
113 |
+
return base
|
114 |
+
|
115 |
+
|
116 |
+
def set_global_vars(controller_url_, enable_moderation_):
|
117 |
+
global controller_url, enable_moderation
|
118 |
+
controller_url = controller_url_
|
119 |
+
enable_moderation = enable_moderation_
|
120 |
+
|
121 |
+
|
122 |
+
def get_conv_log_filename():
|
123 |
+
t = datetime.datetime.now()
|
124 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
125 |
+
return name
|
126 |
+
|
127 |
+
|
128 |
+
def get_model_list(controller_url, register_api_endpoint_file, multimodal):
|
129 |
+
global api_endpoint_info
|
130 |
+
|
131 |
+
# Add models from the controller
|
132 |
+
if controller_url:
|
133 |
+
ret = requests.post(controller_url + "/refresh_all_workers")
|
134 |
+
assert ret.status_code == 200
|
135 |
+
|
136 |
+
if multimodal:
|
137 |
+
ret = requests.post(controller_url + "/list_multimodal_models")
|
138 |
+
models = ret.json()["models"]
|
139 |
+
else:
|
140 |
+
ret = requests.post(controller_url + "/list_language_models")
|
141 |
+
models = ret.json()["models"]
|
142 |
+
else:
|
143 |
+
models = []
|
144 |
+
|
145 |
+
# Add models from the API providers
|
146 |
+
if register_api_endpoint_file:
|
147 |
+
api_endpoint_info = json.load(open(register_api_endpoint_file))
|
148 |
+
for mdl, mdl_dict in api_endpoint_info.items():
|
149 |
+
mdl_multimodal = mdl_dict.get("multimodal", False)
|
150 |
+
if multimodal and mdl_multimodal:
|
151 |
+
models += [mdl]
|
152 |
+
elif not multimodal and not mdl_multimodal:
|
153 |
+
models += [mdl]
|
154 |
+
|
155 |
+
# Remove anonymous models
|
156 |
+
models = list(set(models))
|
157 |
+
visible_models = models.copy()
|
158 |
+
for mdl in visible_models:
|
159 |
+
if mdl not in api_endpoint_info:
|
160 |
+
continue
|
161 |
+
mdl_dict = api_endpoint_info[mdl]
|
162 |
+
if mdl_dict["anony_only"]:
|
163 |
+
visible_models.remove(mdl)
|
164 |
+
|
165 |
+
# Sort models and add descriptions
|
166 |
+
priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)}
|
167 |
+
models.sort(key=lambda x: priority.get(x, x))
|
168 |
+
visible_models.sort(key=lambda x: priority.get(x, x))
|
169 |
+
logger.info(f"All models: {models}")
|
170 |
+
logger.info(f"Visible models: {visible_models}")
|
171 |
+
return visible_models, models
|
172 |
+
|
173 |
+
|
174 |
+
def load_demo_single(models, url_params):
|
175 |
+
selected_model = models[0] if len(models) > 0 else ""
|
176 |
+
if "model" in url_params:
|
177 |
+
model = url_params["model"]
|
178 |
+
if model in models:
|
179 |
+
selected_model = model
|
180 |
+
|
181 |
+
dropdown_update = gr.Dropdown(choices=models, value=selected_model, visible=True)
|
182 |
+
state = None
|
183 |
+
return state, dropdown_update
|
184 |
+
|
185 |
+
|
186 |
+
def load_demo(url_params, request: gr.Request):
|
187 |
+
global models
|
188 |
+
|
189 |
+
ip = get_ip(request)
|
190 |
+
logger.info(f"load_demo. ip: {ip}. params: {url_params}")
|
191 |
+
|
192 |
+
if args.model_list_mode == "reload":
|
193 |
+
models, all_models = get_model_list(
|
194 |
+
controller_url, args.register_api_endpoint_file, False
|
195 |
+
)
|
196 |
+
|
197 |
+
return load_demo_single(models, url_params)
|
198 |
+
|
199 |
+
|
200 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
201 |
+
with open(get_conv_log_filename(), "a") as fout:
|
202 |
+
data = {
|
203 |
+
"tstamp": round(time.time(), 4),
|
204 |
+
"type": vote_type,
|
205 |
+
"model": model_selector,
|
206 |
+
"state": state.dict(),
|
207 |
+
"ip": get_ip(request),
|
208 |
+
}
|
209 |
+
fout.write(json.dumps(data) + "\n")
|
210 |
+
|
211 |
+
|
212 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
213 |
+
ip = get_ip(request)
|
214 |
+
logger.info(f"upvote. ip: {ip}")
|
215 |
+
vote_last_response(state, "upvote", model_selector, request)
|
216 |
+
return ("",) + (disable_btn,) * 3
|
217 |
+
|
218 |
+
|
219 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
220 |
+
ip = get_ip(request)
|
221 |
+
logger.info(f"downvote. ip: {ip}")
|
222 |
+
vote_last_response(state, "downvote", model_selector, request)
|
223 |
+
return ("",) + (disable_btn,) * 3
|
224 |
+
|
225 |
+
|
226 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
227 |
+
ip = get_ip(request)
|
228 |
+
logger.info(f"flag. ip: {ip}")
|
229 |
+
vote_last_response(state, "flag", model_selector, request)
|
230 |
+
return ("",) + (disable_btn,) * 3
|
231 |
+
|
232 |
+
|
233 |
+
def regenerate(state, request: gr.Request):
|
234 |
+
ip = get_ip(request)
|
235 |
+
logger.info(f"regenerate. ip: {ip}")
|
236 |
+
state.conv.update_last_message(None)
|
237 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
238 |
+
|
239 |
+
|
240 |
+
def clear_history(request: gr.Request):
|
241 |
+
ip = get_ip(request)
|
242 |
+
logger.info(f"clear_history. ip: {ip}")
|
243 |
+
state = None
|
244 |
+
return (state, [], "", None) + (disable_btn,) * 5
|
245 |
+
|
246 |
+
|
247 |
+
def get_ip(request: gr.Request):
|
248 |
+
if "cf-connecting-ip" in request.headers:
|
249 |
+
ip = request.headers["cf-connecting-ip"]
|
250 |
+
else:
|
251 |
+
ip = request.client.host
|
252 |
+
return ip
|
253 |
+
|
254 |
+
|
255 |
+
def _prepare_text_with_image(state, text, image):
|
256 |
+
if image is not None:
|
257 |
+
if len(state.conv.get_images()) > 0:
|
258 |
+
# reset convo with new image
|
259 |
+
state.conv = get_conversation_template(state.model_name)
|
260 |
+
|
261 |
+
image = state.conv.convert_image_to_base64(
|
262 |
+
image
|
263 |
+
) # PIL type is not JSON serializable
|
264 |
+
|
265 |
+
text = text, [image]
|
266 |
+
|
267 |
+
return text
|
268 |
+
|
269 |
+
|
270 |
+
def add_text(state, model_selector, text, image, request: gr.Request):
|
271 |
+
ip = get_ip(request)
|
272 |
+
logger.info(f"add_text. ip: {ip}. len: {len(text)}")
|
273 |
+
|
274 |
+
if state is None:
|
275 |
+
state = State(model_selector)
|
276 |
+
|
277 |
+
if len(text) <= 0:
|
278 |
+
state.skip_next = True
|
279 |
+
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
|
280 |
+
|
281 |
+
flagged = moderation_filter(text, [state.model_name])
|
282 |
+
if flagged:
|
283 |
+
logger.info(f"violate moderation. ip: {ip}. text: {text}")
|
284 |
+
# overwrite the original text
|
285 |
+
text = MODERATION_MSG
|
286 |
+
|
287 |
+
if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
|
288 |
+
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
|
289 |
+
state.skip_next = True
|
290 |
+
return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
|
291 |
+
no_change_btn,
|
292 |
+
) * 5
|
293 |
+
|
294 |
+
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
|
295 |
+
text = _prepare_text_with_image(state, text, image)
|
296 |
+
state.conv.append_message(state.conv.roles[0], text)
|
297 |
+
state.conv.append_message(state.conv.roles[1], None)
|
298 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
299 |
+
|
300 |
+
|
301 |
+
def model_worker_stream_iter(
|
302 |
+
conv,
|
303 |
+
model_name,
|
304 |
+
worker_addr,
|
305 |
+
prompt,
|
306 |
+
temperature,
|
307 |
+
repetition_penalty,
|
308 |
+
top_p,
|
309 |
+
max_new_tokens,
|
310 |
+
images,
|
311 |
+
):
|
312 |
+
# Make requests
|
313 |
+
gen_params = {
|
314 |
+
"model": model_name,
|
315 |
+
"prompt": prompt,
|
316 |
+
"temperature": temperature,
|
317 |
+
"repetition_penalty": repetition_penalty,
|
318 |
+
"top_p": top_p,
|
319 |
+
"max_new_tokens": max_new_tokens,
|
320 |
+
"stop": conv.stop_str,
|
321 |
+
"stop_token_ids": conv.stop_token_ids,
|
322 |
+
"echo": False,
|
323 |
+
}
|
324 |
+
|
325 |
+
logger.info(f"==== request ====\n{gen_params}")
|
326 |
+
|
327 |
+
if len(images) > 0:
|
328 |
+
gen_params["images"] = images
|
329 |
+
|
330 |
+
# Stream output
|
331 |
+
response = requests.post(
|
332 |
+
worker_addr + "/worker_generate_stream",
|
333 |
+
headers=headers,
|
334 |
+
json=gen_params,
|
335 |
+
stream=True,
|
336 |
+
timeout=WORKER_API_TIMEOUT,
|
337 |
+
)
|
338 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
339 |
+
if chunk:
|
340 |
+
data = json.loads(chunk.decode())
|
341 |
+
yield data
|
342 |
+
|
343 |
+
|
344 |
+
def is_limit_reached(model_name, ip):
|
345 |
+
monitor_url = "http://localhost:9090"
|
346 |
+
try:
|
347 |
+
ret = requests.get(
|
348 |
+
f"{monitor_url}/is_limit_reached?model={model_name}&user_id={ip}", timeout=1
|
349 |
+
)
|
350 |
+
obj = ret.json()
|
351 |
+
return obj
|
352 |
+
except Exception as e:
|
353 |
+
logger.info(f"monitor error: {e}")
|
354 |
+
return None
|
355 |
+
|
356 |
+
|
357 |
+
def bot_response(
|
358 |
+
state,
|
359 |
+
temperature,
|
360 |
+
top_p,
|
361 |
+
max_new_tokens,
|
362 |
+
request: gr.Request,
|
363 |
+
apply_rate_limit=True,
|
364 |
+
):
|
365 |
+
ip = get_ip(request)
|
366 |
+
logger.info(f"bot_response. ip: {ip}")
|
367 |
+
start_tstamp = time.time()
|
368 |
+
temperature = float(temperature)
|
369 |
+
top_p = float(top_p)
|
370 |
+
max_new_tokens = int(max_new_tokens)
|
371 |
+
|
372 |
+
if state.skip_next:
|
373 |
+
# This generate call is skipped due to invalid inputs
|
374 |
+
state.skip_next = False
|
375 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
376 |
+
return
|
377 |
+
|
378 |
+
if apply_rate_limit:
|
379 |
+
ret = is_limit_reached(state.model_name, ip)
|
380 |
+
if ret is not None and ret["is_limit_reached"]:
|
381 |
+
error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"]
|
382 |
+
logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}")
|
383 |
+
state.conv.update_last_message(error_msg)
|
384 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
385 |
+
return
|
386 |
+
|
387 |
+
conv, model_name = state.conv, state.model_name
|
388 |
+
model_api_dict = (
|
389 |
+
api_endpoint_info[model_name] if model_name in api_endpoint_info else None
|
390 |
+
)
|
391 |
+
images = conv.get_images()
|
392 |
+
|
393 |
+
if model_api_dict is None:
|
394 |
+
# Query worker address
|
395 |
+
ret = requests.post(
|
396 |
+
controller_url + "/get_worker_address", json={"model": model_name}
|
397 |
+
)
|
398 |
+
worker_addr = ret.json()["address"]
|
399 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
400 |
+
|
401 |
+
# No available worker
|
402 |
+
if worker_addr == "":
|
403 |
+
conv.update_last_message(SERVER_ERROR_MSG)
|
404 |
+
yield (
|
405 |
+
state,
|
406 |
+
state.to_gradio_chatbot(),
|
407 |
+
disable_btn,
|
408 |
+
disable_btn,
|
409 |
+
disable_btn,
|
410 |
+
enable_btn,
|
411 |
+
enable_btn,
|
412 |
+
)
|
413 |
+
return
|
414 |
+
|
415 |
+
# Construct prompt.
|
416 |
+
# We need to call it here, so it will not be affected by "▌".
|
417 |
+
prompt = conv.get_prompt()
|
418 |
+
|
419 |
+
# Set repetition_penalty
|
420 |
+
if "t5" in model_name:
|
421 |
+
repetition_penalty = 1.2
|
422 |
+
else:
|
423 |
+
repetition_penalty = 1.0
|
424 |
+
|
425 |
+
stream_iter = model_worker_stream_iter(
|
426 |
+
conv,
|
427 |
+
model_name,
|
428 |
+
worker_addr,
|
429 |
+
prompt,
|
430 |
+
temperature,
|
431 |
+
repetition_penalty,
|
432 |
+
top_p,
|
433 |
+
max_new_tokens,
|
434 |
+
images,
|
435 |
+
)
|
436 |
+
else:
|
437 |
+
stream_iter = get_api_provider_stream_iter(
|
438 |
+
conv,
|
439 |
+
model_name,
|
440 |
+
model_api_dict,
|
441 |
+
temperature,
|
442 |
+
top_p,
|
443 |
+
max_new_tokens,
|
444 |
+
)
|
445 |
+
|
446 |
+
conv.update_last_message("▌")
|
447 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
448 |
+
|
449 |
+
try:
|
450 |
+
for i, data in enumerate(stream_iter):
|
451 |
+
if data["error_code"] == 0:
|
452 |
+
output = data["text"].strip()
|
453 |
+
conv.update_last_message(output + "▌")
|
454 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
455 |
+
else:
|
456 |
+
output = data["text"] + f"\n\n(error_code: {data['error_code']})"
|
457 |
+
conv.update_last_message(output)
|
458 |
+
yield (state, state.to_gradio_chatbot()) + (
|
459 |
+
disable_btn,
|
460 |
+
disable_btn,
|
461 |
+
disable_btn,
|
462 |
+
enable_btn,
|
463 |
+
enable_btn,
|
464 |
+
)
|
465 |
+
return
|
466 |
+
output = data["text"].strip()
|
467 |
+
conv.update_last_message(output)
|
468 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
469 |
+
except requests.exceptions.RequestException as e:
|
470 |
+
conv.update_last_message(
|
471 |
+
f"{SERVER_ERROR_MSG}\n\n"
|
472 |
+
f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
|
473 |
+
)
|
474 |
+
yield (state, state.to_gradio_chatbot()) + (
|
475 |
+
disable_btn,
|
476 |
+
disable_btn,
|
477 |
+
disable_btn,
|
478 |
+
enable_btn,
|
479 |
+
enable_btn,
|
480 |
+
)
|
481 |
+
return
|
482 |
+
except Exception as e:
|
483 |
+
conv.update_last_message(
|
484 |
+
f"{SERVER_ERROR_MSG}\n\n"
|
485 |
+
f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
|
486 |
+
)
|
487 |
+
yield (state, state.to_gradio_chatbot()) + (
|
488 |
+
disable_btn,
|
489 |
+
disable_btn,
|
490 |
+
disable_btn,
|
491 |
+
enable_btn,
|
492 |
+
enable_btn,
|
493 |
+
)
|
494 |
+
return
|
495 |
+
|
496 |
+
finish_tstamp = time.time()
|
497 |
+
logger.info(f"{output}")
|
498 |
+
|
499 |
+
# We load the image because gradio accepts base64 but that increases file size by ~1.33x
|
500 |
+
loaded_images = [load_image(image) for image in images]
|
501 |
+
images_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in loaded_images]
|
502 |
+
for image, hash_str in zip(loaded_images, images_hash):
|
503 |
+
t = datetime.datetime.now()
|
504 |
+
filename = os.path.join(
|
505 |
+
LOGDIR,
|
506 |
+
"serve_images",
|
507 |
+
f"{hash_str}.jpg",
|
508 |
+
)
|
509 |
+
if not os.path.isfile(filename):
|
510 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
511 |
+
image.save(filename)
|
512 |
+
|
513 |
+
with open(get_conv_log_filename(), "a") as fout:
|
514 |
+
data = {
|
515 |
+
"tstamp": round(finish_tstamp, 4),
|
516 |
+
"type": "chat",
|
517 |
+
"model": model_name,
|
518 |
+
"gen_params": {
|
519 |
+
"temperature": temperature,
|
520 |
+
"top_p": top_p,
|
521 |
+
"max_new_tokens": max_new_tokens,
|
522 |
+
},
|
523 |
+
"start": round(start_tstamp, 4),
|
524 |
+
"finish": round(finish_tstamp, 4),
|
525 |
+
"state": state.dict(),
|
526 |
+
"ip": get_ip(request),
|
527 |
+
"images": images_hash,
|
528 |
+
}
|
529 |
+
fout.write(json.dumps(data) + "\n")
|
530 |
+
|
531 |
+
|
532 |
+
block_css = """
|
533 |
+
#notice_markdown .prose {
|
534 |
+
font-size: 120% !important;
|
535 |
+
}
|
536 |
+
#notice_markdown th {
|
537 |
+
display: none;
|
538 |
+
}
|
539 |
+
#notice_markdown td {
|
540 |
+
padding-top: 6px;
|
541 |
+
padding-bottom: 6px;
|
542 |
+
}
|
543 |
+
#model_description_markdown {
|
544 |
+
font-size: 120% !important;
|
545 |
+
}
|
546 |
+
#leaderboard_markdown .prose {
|
547 |
+
font-size: 120% !important;
|
548 |
+
}
|
549 |
+
#leaderboard_markdown td {
|
550 |
+
padding-top: 6px;
|
551 |
+
padding-bottom: 6px;
|
552 |
+
}
|
553 |
+
#leaderboard_dataframe td {
|
554 |
+
line-height: 0.1em;
|
555 |
+
}
|
556 |
+
#about_markdown .prose {
|
557 |
+
font-size: 120% !important;
|
558 |
+
}
|
559 |
+
#ack_markdown .prose {
|
560 |
+
font-size: 120% !important;
|
561 |
+
}
|
562 |
+
footer {
|
563 |
+
display:none !important;
|
564 |
+
}
|
565 |
+
.sponsor-image-about img {
|
566 |
+
margin: 0 20px;
|
567 |
+
margin-top: 20px;
|
568 |
+
height: 40px;
|
569 |
+
max-height: 100%;
|
570 |
+
width: auto;
|
571 |
+
float: left;
|
572 |
+
}
|
573 |
+
"""
|
574 |
+
|
575 |
+
|
576 |
+
def get_model_description_md(models):
|
577 |
+
model_description_md = """
|
578 |
+
| | | |
|
579 |
+
| ---- | ---- | ---- |
|
580 |
+
"""
|
581 |
+
ct = 0
|
582 |
+
visited = set()
|
583 |
+
for i, name in enumerate(models):
|
584 |
+
minfo = get_model_info(name)
|
585 |
+
if minfo.simple_name in visited:
|
586 |
+
continue
|
587 |
+
visited.add(minfo.simple_name)
|
588 |
+
one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
|
589 |
+
|
590 |
+
if ct % 3 == 0:
|
591 |
+
model_description_md += "|"
|
592 |
+
model_description_md += f" {one_model_md} |"
|
593 |
+
if ct % 3 == 2:
|
594 |
+
model_description_md += "\n"
|
595 |
+
ct += 1
|
596 |
+
return model_description_md
|
597 |
+
|
598 |
+
|
599 |
+
def build_about():
|
600 |
+
about_markdown = """
|
601 |
+
# About Us
|
602 |
+
Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our [FastChat](https://github.com/lm-sys/FastChat) project at GitHub and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey!
|
603 |
+
|
604 |
+
## Read More
|
605 |
+
- Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/)
|
606 |
+
- LMSYS-Chat-1M [report](https://arxiv.org/abs/2309.11998)
|
607 |
+
|
608 |
+
## Core Members
|
609 |
+
[Lianmin Zheng](https://lmzheng.net/), [Wei-Lin Chiang](https://infwinston.github.io/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ)
|
610 |
+
|
611 |
+
## Advisors
|
612 |
+
[Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/)
|
613 |
+
|
614 |
+
## Contact Us
|
615 |
+
- Follow our [Twitter](https://twitter.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at [email protected]
|
616 |
+
- File issues on [GitHub](https://github.com/lm-sys/FastChat)
|
617 |
+
- Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys)
|
618 |
+
|
619 |
+
## Acknowledgment
|
620 |
+
We thank [SkyPilot](https://github.com/skypilot-org/skypilot) and [Gradio](https://github.com/gradio-app/gradio) team for their system support.
|
621 |
+
We also thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship. Learn more about partnership [here](https://lmsys.org/donations/).
|
622 |
+
|
623 |
+
<div class="sponsor-image-about">
|
624 |
+
<img src="https://storage.googleapis.com/public-arena-asset/kaggle.png" alt="Kaggle">
|
625 |
+
<img src="https://storage.googleapis.com/public-arena-asset/mbzuai.jpeg" alt="MBZUAI">
|
626 |
+
<img src="https://storage.googleapis.com/public-arena-asset/a16z.jpeg" alt="a16z">
|
627 |
+
<img src="https://storage.googleapis.com/public-arena-asset/together.png" alt="Together AI">
|
628 |
+
<img src="https://storage.googleapis.com/public-arena-asset/anyscale.png" alt="AnyScale">
|
629 |
+
<img src="https://storage.googleapis.com/public-arena-asset/huggingface.png" alt="HuggingFace">
|
630 |
+
</div>
|
631 |
+
"""
|
632 |
+
gr.Markdown(about_markdown, elem_id="about_markdown")
|
633 |
+
|
634 |
+
|
635 |
+
def build_single_model_ui(models, add_promotion_links=False):
|
636 |
+
promotion = (
|
637 |
+
"""
|
638 |
+
- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
|
639 |
+
- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
|
640 |
+
- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)
|
641 |
+
|
642 |
+
## 🤖 Choose any model to chat
|
643 |
+
"""
|
644 |
+
if add_promotion_links
|
645 |
+
else ""
|
646 |
+
)
|
647 |
+
|
648 |
+
notice_markdown = f"""
|
649 |
+
# 🏔️ Chat with Open Large Language Models
|
650 |
+
{promotion}
|
651 |
+
"""
|
652 |
+
|
653 |
+
state = gr.State()
|
654 |
+
gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
655 |
+
|
656 |
+
with gr.Group(elem_id="share-region-named"):
|
657 |
+
with gr.Row(elem_id="model_selector_row"):
|
658 |
+
model_selector = gr.Dropdown(
|
659 |
+
choices=models,
|
660 |
+
value=models[0] if len(models) > 0 else "",
|
661 |
+
interactive=True,
|
662 |
+
show_label=False,
|
663 |
+
container=False,
|
664 |
+
)
|
665 |
+
with gr.Row():
|
666 |
+
with gr.Accordion(
|
667 |
+
f"🔍 Expand to see the descriptions of {len(models)} models",
|
668 |
+
open=False,
|
669 |
+
):
|
670 |
+
model_description_md = get_model_description_md(models)
|
671 |
+
gr.Markdown(model_description_md, elem_id="model_description_markdown")
|
672 |
+
|
673 |
+
chatbot = gr.Chatbot(
|
674 |
+
elem_id="chatbot",
|
675 |
+
label="Scroll down and start chatting",
|
676 |
+
height=550,
|
677 |
+
show_copy_button=True,
|
678 |
+
)
|
679 |
+
with gr.Row():
|
680 |
+
textbox = gr.Textbox(
|
681 |
+
show_label=False,
|
682 |
+
placeholder="👉 Enter your prompt and press ENTER",
|
683 |
+
elem_id="input_box",
|
684 |
+
)
|
685 |
+
send_btn = gr.Button(value="Send", variant="primary", scale=0)
|
686 |
+
|
687 |
+
with gr.Row() as button_row:
|
688 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
689 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
690 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
691 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
692 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
693 |
+
|
694 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
695 |
+
temperature = gr.Slider(
|
696 |
+
minimum=0.0,
|
697 |
+
maximum=1.0,
|
698 |
+
value=0.7,
|
699 |
+
step=0.1,
|
700 |
+
interactive=True,
|
701 |
+
label="Temperature",
|
702 |
+
)
|
703 |
+
top_p = gr.Slider(
|
704 |
+
minimum=0.0,
|
705 |
+
maximum=1.0,
|
706 |
+
value=1.0,
|
707 |
+
step=0.1,
|
708 |
+
interactive=True,
|
709 |
+
label="Top P",
|
710 |
+
)
|
711 |
+
max_output_tokens = gr.Slider(
|
712 |
+
minimum=16,
|
713 |
+
maximum=2048,
|
714 |
+
value=1024,
|
715 |
+
step=64,
|
716 |
+
interactive=True,
|
717 |
+
label="Max output tokens",
|
718 |
+
)
|
719 |
+
|
720 |
+
if add_promotion_links:
|
721 |
+
gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
|
722 |
+
|
723 |
+
# Register listeners
|
724 |
+
imagebox = gr.State(None)
|
725 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
726 |
+
upvote_btn.click(
|
727 |
+
upvote_last_response,
|
728 |
+
[state, model_selector],
|
729 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
730 |
+
)
|
731 |
+
downvote_btn.click(
|
732 |
+
downvote_last_response,
|
733 |
+
[state, model_selector],
|
734 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
735 |
+
)
|
736 |
+
flag_btn.click(
|
737 |
+
flag_last_response,
|
738 |
+
[state, model_selector],
|
739 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
740 |
+
)
|
741 |
+
regenerate_btn.click(
|
742 |
+
regenerate, state, [state, chatbot, textbox, imagebox] + btn_list
|
743 |
+
).then(
|
744 |
+
bot_response,
|
745 |
+
[state, temperature, top_p, max_output_tokens],
|
746 |
+
[state, chatbot] + btn_list,
|
747 |
+
)
|
748 |
+
clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
|
749 |
+
|
750 |
+
model_selector.change(
|
751 |
+
clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
|
752 |
+
)
|
753 |
+
|
754 |
+
textbox.submit(
|
755 |
+
add_text,
|
756 |
+
[state, model_selector, textbox, imagebox],
|
757 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
758 |
+
).then(
|
759 |
+
bot_response,
|
760 |
+
[state, temperature, top_p, max_output_tokens],
|
761 |
+
[state, chatbot] + btn_list,
|
762 |
+
)
|
763 |
+
send_btn.click(
|
764 |
+
add_text,
|
765 |
+
[state, model_selector, textbox, imagebox],
|
766 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
767 |
+
).then(
|
768 |
+
bot_response,
|
769 |
+
[state, temperature, top_p, max_output_tokens],
|
770 |
+
[state, chatbot] + btn_list,
|
771 |
+
)
|
772 |
+
|
773 |
+
return [state, model_selector]
|
774 |
+
|
775 |
+
|
776 |
+
def build_demo(models):
|
777 |
+
with gr.Blocks(
|
778 |
+
title="Chat with Open Large Language Models",
|
779 |
+
theme=gr.themes.Default(),
|
780 |
+
css=block_css,
|
781 |
+
) as demo:
|
782 |
+
url_params = gr.JSON(visible=False)
|
783 |
+
|
784 |
+
state, model_selector = build_single_model_ui(models)
|
785 |
+
|
786 |
+
if args.model_list_mode not in ["once", "reload"]:
|
787 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
788 |
+
|
789 |
+
if args.show_terms_of_use:
|
790 |
+
load_js = get_window_url_params_with_tos_js
|
791 |
+
else:
|
792 |
+
load_js = get_window_url_params_js
|
793 |
+
|
794 |
+
demo.load(
|
795 |
+
load_demo,
|
796 |
+
[url_params],
|
797 |
+
[
|
798 |
+
state,
|
799 |
+
model_selector,
|
800 |
+
],
|
801 |
+
js=load_js,
|
802 |
+
)
|
803 |
+
|
804 |
+
return demo
|
805 |
+
|
806 |
+
|
807 |
+
if __name__ == "__main__":
|
808 |
+
parser = argparse.ArgumentParser()
|
809 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
810 |
+
parser.add_argument("--port", type=int)
|
811 |
+
parser.add_argument(
|
812 |
+
"--share",
|
813 |
+
action="store_true",
|
814 |
+
help="Whether to generate a public, shareable link",
|
815 |
+
)
|
816 |
+
parser.add_argument(
|
817 |
+
"--controller-url",
|
818 |
+
type=str,
|
819 |
+
default="http://localhost:21001",
|
820 |
+
help="The address of the controller",
|
821 |
+
)
|
822 |
+
parser.add_argument(
|
823 |
+
"--concurrency-count",
|
824 |
+
type=int,
|
825 |
+
default=10,
|
826 |
+
help="The concurrency count of the gradio queue",
|
827 |
+
)
|
828 |
+
parser.add_argument(
|
829 |
+
"--model-list-mode",
|
830 |
+
type=str,
|
831 |
+
default="once",
|
832 |
+
choices=["once", "reload"],
|
833 |
+
help="Whether to load the model list once or reload the model list every time",
|
834 |
+
)
|
835 |
+
parser.add_argument(
|
836 |
+
"--moderate",
|
837 |
+
action="store_true",
|
838 |
+
help="Enable content moderation to block unsafe inputs",
|
839 |
+
)
|
840 |
+
parser.add_argument(
|
841 |
+
"--show-terms-of-use",
|
842 |
+
action="store_true",
|
843 |
+
help="Shows term of use before loading the demo",
|
844 |
+
)
|
845 |
+
parser.add_argument(
|
846 |
+
"--register-api-endpoint-file",
|
847 |
+
type=str,
|
848 |
+
help="Register API-based model endpoints from a JSON file",
|
849 |
+
)
|
850 |
+
parser.add_argument(
|
851 |
+
"--gradio-auth-path",
|
852 |
+
type=str,
|
853 |
+
help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
|
854 |
+
)
|
855 |
+
parser.add_argument(
|
856 |
+
"--gradio-root-path",
|
857 |
+
type=str,
|
858 |
+
help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
|
859 |
+
)
|
860 |
+
args = parser.parse_args()
|
861 |
+
logger.info(f"args: {args}")
|
862 |
+
|
863 |
+
# Set global variables
|
864 |
+
set_global_vars(args.controller_url, args.moderate)
|
865 |
+
models, all_models = get_model_list(
|
866 |
+
args.controller_url, args.register_api_endpoint_file, False
|
867 |
+
)
|
868 |
+
|
869 |
+
# Set authorization credentials
|
870 |
+
auth = None
|
871 |
+
if args.gradio_auth_path is not None:
|
872 |
+
auth = parse_gradio_auth_creds(args.gradio_auth_path)
|
873 |
+
|
874 |
+
# Launch the demo
|
875 |
+
demo = build_demo(models)
|
876 |
+
demo.queue(
|
877 |
+
default_concurrency_limit=args.concurrency_count,
|
878 |
+
status_update_rate=10,
|
879 |
+
api_open=False,
|
880 |
+
).launch(
|
881 |
+
server_name=args.host,
|
882 |
+
server_port=args.port,
|
883 |
+
share=args.share,
|
884 |
+
max_threads=200,
|
885 |
+
auth=auth,
|
886 |
+
root_path=args.gradio_root_path,
|
887 |
+
)
|
gradio_web_server_multi.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The gradio demo server with multiple tabs.
|
3 |
+
It supports chatting with a single model or chatting with two models side-by-side.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import pickle
|
8 |
+
import time
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
from fastchat.serve.gradio_block_arena_anony import (
|
13 |
+
build_side_by_side_ui_anony,
|
14 |
+
load_demo_side_by_side_anony,
|
15 |
+
set_global_vars_anony,
|
16 |
+
)
|
17 |
+
from fastchat.serve.gradio_block_arena_named import (
|
18 |
+
build_side_by_side_ui_named,
|
19 |
+
load_demo_side_by_side_named,
|
20 |
+
set_global_vars_named,
|
21 |
+
)
|
22 |
+
from fastchat.serve.gradio_block_arena_vision import (
|
23 |
+
build_single_vision_language_model_ui,
|
24 |
+
)
|
25 |
+
from fastchat.serve.gradio_web_server import (
|
26 |
+
set_global_vars,
|
27 |
+
block_css,
|
28 |
+
build_single_model_ui,
|
29 |
+
build_about,
|
30 |
+
get_model_list,
|
31 |
+
load_demo_single,
|
32 |
+
get_ip,
|
33 |
+
)
|
34 |
+
from fastchat.serve.monitor.monitor import build_leaderboard_tab
|
35 |
+
from fastchat.utils import (
|
36 |
+
build_logger,
|
37 |
+
get_window_url_params_js,
|
38 |
+
get_window_url_params_with_tos_js,
|
39 |
+
parse_gradio_auth_creds,
|
40 |
+
)
|
41 |
+
|
42 |
+
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
|
43 |
+
|
44 |
+
|
45 |
+
def load_demo(url_params, request: gr.Request):
|
46 |
+
global models, all_models, vl_models
|
47 |
+
|
48 |
+
ip = get_ip(request)
|
49 |
+
logger.info(f"load_demo. ip: {ip}. params: {url_params}")
|
50 |
+
|
51 |
+
selected = 0
|
52 |
+
if "arena" in url_params:
|
53 |
+
selected = 0
|
54 |
+
elif "compare" in url_params:
|
55 |
+
selected = 1
|
56 |
+
elif "direct" in url_params or "model" in url_params:
|
57 |
+
selected = 2
|
58 |
+
elif "vision" in url_params:
|
59 |
+
selected = 3
|
60 |
+
elif "leaderboard" in url_params:
|
61 |
+
selected = 4
|
62 |
+
|
63 |
+
if args.model_list_mode == "reload":
|
64 |
+
models, all_models = get_model_list(
|
65 |
+
args.controller_url,
|
66 |
+
args.register_api_endpoint_file,
|
67 |
+
False,
|
68 |
+
)
|
69 |
+
|
70 |
+
vl_models, all_vl_models = get_model_list(
|
71 |
+
args.controller_url,
|
72 |
+
args.register_api_endpoint_file,
|
73 |
+
True,
|
74 |
+
)
|
75 |
+
|
76 |
+
single_updates = load_demo_single(models, url_params)
|
77 |
+
side_by_side_anony_updates = load_demo_side_by_side_anony(all_models, url_params)
|
78 |
+
side_by_side_named_updates = load_demo_side_by_side_named(models, url_params)
|
79 |
+
vision_language_updates = load_demo_single(vl_models, url_params)
|
80 |
+
|
81 |
+
return (
|
82 |
+
(gr.Tabs(selected=selected),)
|
83 |
+
+ single_updates
|
84 |
+
+ side_by_side_anony_updates
|
85 |
+
+ side_by_side_named_updates
|
86 |
+
+ vision_language_updates
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
|
91 |
+
text_size = gr.themes.sizes.text_md
|
92 |
+
if args.show_terms_of_use:
|
93 |
+
load_js = get_window_url_params_with_tos_js
|
94 |
+
else:
|
95 |
+
load_js = get_window_url_params_js
|
96 |
+
|
97 |
+
head_js = """
|
98 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.min.js"></script>
|
99 |
+
"""
|
100 |
+
if args.ga_id is not None:
|
101 |
+
head_js += f"""
|
102 |
+
<script async src="https://www.googletagmanager.com/gtag/js?id={args.ga_id}"></script>
|
103 |
+
<script>
|
104 |
+
window.dataLayer = window.dataLayer || [];
|
105 |
+
function gtag(){{dataLayer.push(arguments);}}
|
106 |
+
gtag('js', new Date());
|
107 |
+
|
108 |
+
gtag('config', '{args.ga_id}');
|
109 |
+
window.__gradio_mode__ = "app";
|
110 |
+
</script>
|
111 |
+
"""
|
112 |
+
|
113 |
+
with gr.Blocks(
|
114 |
+
title="Chat with Open Large Language Models",
|
115 |
+
theme=gr.themes.Default(text_size=text_size),
|
116 |
+
css=block_css,
|
117 |
+
head=head_js,
|
118 |
+
) as demo:
|
119 |
+
with gr.Tabs() as tabs:
|
120 |
+
with gr.Tab("Arena (battle)", id=0):
|
121 |
+
side_by_side_anony_list = build_side_by_side_ui_anony(models)
|
122 |
+
|
123 |
+
with gr.Tab("Arena (side-by-side)", id=1):
|
124 |
+
side_by_side_named_list = build_side_by_side_ui_named(models)
|
125 |
+
|
126 |
+
with gr.Tab("Direct Chat", id=2):
|
127 |
+
single_model_list = build_single_model_ui(
|
128 |
+
models, add_promotion_links=True
|
129 |
+
)
|
130 |
+
|
131 |
+
with gr.Tab(
|
132 |
+
"Vision-Language Model Direct Chat", id=3, visible=args.multimodal
|
133 |
+
):
|
134 |
+
single_vision_language_model_list = (
|
135 |
+
build_single_vision_language_model_ui(
|
136 |
+
vl_models, add_promotion_links=True
|
137 |
+
)
|
138 |
+
)
|
139 |
+
|
140 |
+
if elo_results_file:
|
141 |
+
with gr.Tab("Leaderboard", id=4):
|
142 |
+
build_leaderboard_tab(elo_results_file, leaderboard_table_file)
|
143 |
+
|
144 |
+
with gr.Tab("About Us", id=5):
|
145 |
+
about = build_about()
|
146 |
+
|
147 |
+
url_params = gr.JSON(visible=False)
|
148 |
+
|
149 |
+
if args.model_list_mode not in ["once", "reload"]:
|
150 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
151 |
+
|
152 |
+
demo.load(
|
153 |
+
load_demo,
|
154 |
+
[url_params],
|
155 |
+
[tabs]
|
156 |
+
+ single_model_list
|
157 |
+
+ side_by_side_anony_list
|
158 |
+
+ side_by_side_named_list
|
159 |
+
+ single_vision_language_model_list,
|
160 |
+
js=load_js,
|
161 |
+
)
|
162 |
+
|
163 |
+
return demo
|
164 |
+
|
165 |
+
|
166 |
+
if __name__ == "__main__":
|
167 |
+
parser = argparse.ArgumentParser()
|
168 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
169 |
+
parser.add_argument("--port", type=int)
|
170 |
+
parser.add_argument(
|
171 |
+
"--share",
|
172 |
+
action="store_true",
|
173 |
+
help="Whether to generate a public, shareable link",
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--controller-url",
|
177 |
+
type=str,
|
178 |
+
default="http://localhost:21001",
|
179 |
+
help="The address of the controller",
|
180 |
+
)
|
181 |
+
parser.add_argument(
|
182 |
+
"--concurrency-count",
|
183 |
+
type=int,
|
184 |
+
default=10,
|
185 |
+
help="The concurrency count of the gradio queue",
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--model-list-mode",
|
189 |
+
type=str,
|
190 |
+
default="once",
|
191 |
+
choices=["once", "reload"],
|
192 |
+
help="Whether to load the model list once or reload the model list every time.",
|
193 |
+
)
|
194 |
+
parser.add_argument(
|
195 |
+
"--moderate",
|
196 |
+
action="store_true",
|
197 |
+
help="Enable content moderation to block unsafe inputs",
|
198 |
+
)
|
199 |
+
parser.add_argument(
|
200 |
+
"--show-terms-of-use",
|
201 |
+
action="store_true",
|
202 |
+
help="Shows term of use before loading the demo",
|
203 |
+
)
|
204 |
+
parser.add_argument(
|
205 |
+
"--multimodal", action="store_true", help="Show multi modal tabs."
|
206 |
+
)
|
207 |
+
parser.add_argument(
|
208 |
+
"--register-api-endpoint-file",
|
209 |
+
type=str,
|
210 |
+
help="Register API-based model endpoints from a JSON file",
|
211 |
+
)
|
212 |
+
parser.add_argument(
|
213 |
+
"--gradio-auth-path",
|
214 |
+
type=str,
|
215 |
+
help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
|
216 |
+
default=None,
|
217 |
+
)
|
218 |
+
parser.add_argument(
|
219 |
+
"--elo-results-file", type=str, help="Load leaderboard results and plots"
|
220 |
+
)
|
221 |
+
parser.add_argument(
|
222 |
+
"--leaderboard-table-file", type=str, help="Load leaderboard results and plots"
|
223 |
+
)
|
224 |
+
parser.add_argument(
|
225 |
+
"--gradio-root-path",
|
226 |
+
type=str,
|
227 |
+
help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
|
228 |
+
)
|
229 |
+
parser.add_argument(
|
230 |
+
"--ga-id",
|
231 |
+
type=str,
|
232 |
+
help="the Google Analytics ID",
|
233 |
+
default=None,
|
234 |
+
)
|
235 |
+
args = parser.parse_args()
|
236 |
+
logger.info(f"args: {args}")
|
237 |
+
|
238 |
+
# Set global variables
|
239 |
+
set_global_vars(args.controller_url, args.moderate)
|
240 |
+
set_global_vars_named(args.moderate)
|
241 |
+
set_global_vars_anony(args.moderate)
|
242 |
+
models, all_models = get_model_list(
|
243 |
+
args.controller_url,
|
244 |
+
args.register_api_endpoint_file,
|
245 |
+
False,
|
246 |
+
)
|
247 |
+
|
248 |
+
vl_models, all_vl_models = get_model_list(
|
249 |
+
args.controller_url,
|
250 |
+
args.register_api_endpoint_file,
|
251 |
+
True,
|
252 |
+
)
|
253 |
+
|
254 |
+
# Set authorization credentials
|
255 |
+
auth = None
|
256 |
+
if args.gradio_auth_path is not None:
|
257 |
+
auth = parse_gradio_auth_creds(args.gradio_auth_path)
|
258 |
+
|
259 |
+
# Launch the demo
|
260 |
+
demo = build_demo(
|
261 |
+
models,
|
262 |
+
vl_models,
|
263 |
+
args.elo_results_file,
|
264 |
+
args.leaderboard_table_file,
|
265 |
+
)
|
266 |
+
demo.queue(
|
267 |
+
default_concurrency_limit=args.concurrency_count,
|
268 |
+
status_update_rate=10,
|
269 |
+
api_open=False,
|
270 |
+
).launch(
|
271 |
+
server_name=args.host,
|
272 |
+
server_port=args.port,
|
273 |
+
share=args.share,
|
274 |
+
max_threads=200,
|
275 |
+
auth=auth,
|
276 |
+
root_path=args.gradio_root_path,
|
277 |
+
)
|
huggingface_api.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Use FastChat with Hugging Face generation APIs.

Usage:
python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5
python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0
"""
import argparse

import torch

from fastchat.model import load_model, get_conversation_template, add_model_args


@torch.inference_mode()
def main(args):
    # Load model
    model, tokenizer = load_model(
        args.model_path,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        revision=args.revision,
        debug=args.debug,
    )

    # Build the prompt with a conversation template
    msg = args.message
    conv = get_conversation_template(args.model_path)
    conv.append_message(conv.roles[0], msg)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Run inference
    inputs = tokenizer([prompt], return_tensors="pt").to(args.device)
    output_ids = model.generate(
        **inputs,
        do_sample=True if args.temperature > 1e-5 else False,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        max_new_tokens=args.max_new_tokens,
    )

    if model.config.is_encoder_decoder:
        output_ids = output_ids[0]
    else:
        output_ids = output_ids[0][len(inputs["input_ids"][0]) :]
    outputs = tokenizer.decode(
        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
    )

    # Print results
    print(f"{conv.roles[0]}: {msg}")
    print(f"{conv.roles[1]}: {outputs}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_model_args(parser)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--message", type=str, default="Hello! Who are you?")
    args = parser.parse_args()

    # Reset default repetition penalty for T5 models.
    if "t5" in args.model_path and args.repetition_penalty == 1.0:
        args.repetition_penalty = 1.2

    main(args)
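The usage lines in the docstring cover the CLI; below is a minimal sketch of calling main() programmatically instead, assuming the module is importable as fastchat.serve.huggingface_api (upstream layout) and filling a Namespace with the flags defined by the parser and add_model_args above. Values are illustrative, not recommendations.

# Sketch: calling fastchat.serve.huggingface_api.main() programmatically.
# Field names mirror the argparse flags above; values are illustrative.
from argparse import Namespace

from fastchat.serve.huggingface_api import main

args = Namespace(
    model_path="lmsys/vicuna-7b-v1.5",  # --model-path (added by add_model_args)
    device="cuda",
    num_gpus=1,
    max_gpu_memory=None,
    load_8bit=False,
    cpu_offloading=False,
    revision="main",
    debug=False,
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_tokens=1024,
    message="Hello! Who are you?",
)
main(args)  # prints the user message and the model's reply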
huggingface_api_worker.py
ADDED
@@ -0,0 +1,415 @@
1 |
+
"""
|
2 |
+
A model worker that calls huggingface inference endpoint.
|
3 |
+
|
4 |
+
Register models in a JSON file with the following format:
|
5 |
+
{
|
6 |
+
"falcon-180b-chat": {
|
7 |
+
"model_name": "falcon-180B-chat",
|
8 |
+
"api_base": "https://api-inference.huggingface.co/models",
|
9 |
+
"model_path": "tiiuae/falcon-180B-chat",
|
10 |
+
"token": "hf_XXX",
|
11 |
+
"context_length": 2048
|
12 |
+
},
|
13 |
+
"zephyr-7b-beta": {
|
14 |
+
"model_name": "zephyr-7b-beta",
|
15 |
+
"model_path": "",
|
16 |
+
"api_base": "xxx",
|
17 |
+
"token": "hf_XXX",
|
18 |
+
"context_length": 4096
|
19 |
+
}
|
20 |
+
}
|
21 |
+
|
22 |
+
"model_path", "api_base", "token", and "context_length" are necessary, while others are optional.
|
23 |
+
"""
|
24 |
+
import argparse
|
25 |
+
import asyncio
|
26 |
+
import json
|
27 |
+
import uuid
|
28 |
+
import os
|
29 |
+
from typing import List, Optional
|
30 |
+
|
31 |
+
import requests
|
32 |
+
import uvicorn
|
33 |
+
from fastapi import BackgroundTasks, FastAPI, Request
|
34 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
35 |
+
from huggingface_hub import InferenceClient
|
36 |
+
|
37 |
+
from fastchat.constants import SERVER_ERROR_MSG, ErrorCode
|
38 |
+
from fastchat.serve.base_model_worker import BaseModelWorker
|
39 |
+
from fastchat.utils import build_logger
|
40 |
+
|
41 |
+
worker_id = str(uuid.uuid4())[:8]
|
42 |
+
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
43 |
+
|
44 |
+
workers = []
|
45 |
+
worker_map = {}
|
46 |
+
app = FastAPI()
|
47 |
+
|
48 |
+
|
49 |
+
# reference to
|
50 |
+
# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392
|
51 |
+
def get_gen_kwargs(
|
52 |
+
params,
|
53 |
+
seed: Optional[int] = None,
|
54 |
+
):
|
55 |
+
stop = params.get("stop", None)
|
56 |
+
if isinstance(stop, list):
|
57 |
+
stop_sequences = stop
|
58 |
+
elif isinstance(stop, str):
|
59 |
+
stop_sequences = [stop]
|
60 |
+
else:
|
61 |
+
stop_sequences = []
|
62 |
+
gen_kwargs = {
|
63 |
+
"do_sample": True,
|
64 |
+
"return_full_text": bool(params.get("echo", False)),
|
65 |
+
"max_new_tokens": int(params.get("max_new_tokens", 256)),
|
66 |
+
"top_p": float(params.get("top_p", 1.0)),
|
67 |
+
"temperature": float(params.get("temperature", 1.0)),
|
68 |
+
"stop_sequences": stop_sequences,
|
69 |
+
"repetition_penalty": float(params.get("repetition_penalty", 1.0)),
|
70 |
+
"top_k": params.get("top_k", None),
|
71 |
+
"seed": seed,
|
72 |
+
}
|
73 |
+
if gen_kwargs["top_p"] == 1:
|
74 |
+
gen_kwargs["top_p"] = 0.9999999
|
75 |
+
if gen_kwargs["top_p"] == 0:
|
76 |
+
gen_kwargs.pop("top_p")
|
77 |
+
if gen_kwargs["temperature"] == 0:
|
78 |
+
gen_kwargs.pop("temperature")
|
79 |
+
gen_kwargs["do_sample"] = False
|
80 |
+
return gen_kwargs
|
81 |
+
|
82 |
+
|
83 |
+
def could_be_stop(text, stop):
|
84 |
+
for s in stop:
|
85 |
+
if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)):
|
86 |
+
return True
|
87 |
+
return False
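could_be_stop() reports whether the current tail of the text could still grow into one of the stop sequences, which is why the streaming loop below holds such chunks back. A small sketch of the behaviour, assuming the module is importable as fastchat.serve.huggingface_api_worker (upstream layout); the sample strings are made up.

# Sketch of could_be_stop(): True while the text ends with any prefix of a
# stop sequence, i.e. the stream might be mid-way through emitting it.
from fastchat.serve.huggingface_api_worker import could_be_stop

stop = ["\nUser:", "###"]

assert could_be_stop("Sure, here you go\n", stop) is True     # "\n" is a prefix of "\nUser:"
assert could_be_stop("Sure, here you go\nUse", stop) is True  # still a prefix
assert could_be_stop("Sure, here you go.", stop) is False     # cannot extend into a stop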
|
88 |
+
|
89 |
+
|
90 |
+
class HuggingfaceApiWorker(BaseModelWorker):
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
controller_addr: str,
|
94 |
+
worker_addr: str,
|
95 |
+
worker_id: str,
|
96 |
+
model_path: str,
|
97 |
+
api_base: str,
|
98 |
+
token: str,
|
99 |
+
context_length: int,
|
100 |
+
model_names: List[str],
|
101 |
+
limit_worker_concurrency: int,
|
102 |
+
no_register: bool,
|
103 |
+
conv_template: Optional[str] = None,
|
104 |
+
seed: Optional[int] = None,
|
105 |
+
**kwargs,
|
106 |
+
):
|
107 |
+
super().__init__(
|
108 |
+
controller_addr,
|
109 |
+
worker_addr,
|
110 |
+
worker_id,
|
111 |
+
model_path,
|
112 |
+
model_names,
|
113 |
+
limit_worker_concurrency,
|
114 |
+
conv_template=conv_template,
|
115 |
+
)
|
116 |
+
|
117 |
+
self.model_path = model_path
|
118 |
+
self.api_base = api_base
|
119 |
+
self.token = token
|
120 |
+
self.context_len = context_length
|
121 |
+
self.seed = seed
|
122 |
+
|
123 |
+
logger.info(
|
124 |
+
f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
|
125 |
+
)
|
126 |
+
|
127 |
+
if not no_register:
|
128 |
+
self.init_heart_beat()
|
129 |
+
|
130 |
+
def count_token(self, params):
|
131 |
+
# No tokenizer here
|
132 |
+
ret = {
|
133 |
+
"count": 0,
|
134 |
+
"error_code": 0,
|
135 |
+
}
|
136 |
+
return ret
|
137 |
+
|
138 |
+
def generate_stream_gate(self, params):
|
139 |
+
self.call_ct += 1
|
140 |
+
|
141 |
+
prompt = params["prompt"]
|
142 |
+
gen_kwargs = get_gen_kwargs(params, seed=self.seed)
|
143 |
+
stop = gen_kwargs["stop_sequences"]
|
144 |
+
if "falcon" in self.model_path and "chat" in self.model_path:
|
145 |
+
stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
|
146 |
+
stop = list(set(stop))
|
147 |
+
gen_kwargs["stop_sequences"] = stop
|
148 |
+
|
149 |
+
logger.info(f"prompt: {prompt}")
|
150 |
+
logger.info(f"gen_kwargs: {gen_kwargs}")
|
151 |
+
|
152 |
+
try:
|
153 |
+
if self.model_path == "":
|
154 |
+
url = f"{self.api_base}"
|
155 |
+
else:
|
156 |
+
url = f"{self.api_base}/{self.model_path}"
|
157 |
+
client = InferenceClient(url, token=self.token)
|
158 |
+
res = client.text_generation(
|
159 |
+
prompt, stream=True, details=True, **gen_kwargs
|
160 |
+
)
|
161 |
+
|
162 |
+
reason = None
|
163 |
+
text = ""
|
164 |
+
for chunk in res:
|
165 |
+
if chunk.token.special:
|
166 |
+
continue
|
167 |
+
text += chunk.token.text
|
168 |
+
|
169 |
+
s = next((x for x in stop if text.endswith(x)), None)
|
170 |
+
if s is not None:
|
171 |
+
text = text[: -len(s)]
|
172 |
+
reason = "stop"
|
173 |
+
break
|
174 |
+
if could_be_stop(text, stop):
|
175 |
+
continue
|
176 |
+
if (
|
177 |
+
chunk.details is not None
|
178 |
+
and chunk.details.finish_reason is not None
|
179 |
+
):
|
180 |
+
reason = chunk.details.finish_reason
|
181 |
+
if reason not in ["stop", "length"]:
|
182 |
+
reason = None
|
183 |
+
ret = {
|
184 |
+
"text": text,
|
185 |
+
"error_code": 0,
|
186 |
+
"finish_reason": reason,
|
187 |
+
}
|
188 |
+
yield json.dumps(ret).encode() + b"\0"
|
189 |
+
except Exception as e:
|
190 |
+
ret = {
|
191 |
+
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
192 |
+
"error_code": ErrorCode.INTERNAL_ERROR,
|
193 |
+
}
|
194 |
+
yield json.dumps(ret).encode() + b"\0"
|
195 |
+
|
196 |
+
def generate_gate(self, params):
|
197 |
+
for x in self.generate_stream_gate(params):
|
198 |
+
pass
|
199 |
+
return json.loads(x[:-1].decode())
|
200 |
+
|
201 |
+
def get_embeddings(self, params):
|
202 |
+
raise NotImplementedError()
|
203 |
+
|
204 |
+
|
205 |
+
def release_worker_semaphore(worker):
|
206 |
+
worker.semaphore.release()
|
207 |
+
|
208 |
+
|
209 |
+
def acquire_worker_semaphore(worker):
|
210 |
+
if worker.semaphore is None:
|
211 |
+
worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
|
212 |
+
return worker.semaphore.acquire()
|
213 |
+
|
214 |
+
|
215 |
+
def create_background_tasks(worker):
|
216 |
+
background_tasks = BackgroundTasks()
|
217 |
+
background_tasks.add_task(lambda: release_worker_semaphore(worker))
|
218 |
+
return background_tasks
|
219 |
+
|
220 |
+
|
221 |
+
@app.post("/worker_generate_stream")
|
222 |
+
async def api_generate_stream(request: Request):
|
223 |
+
params = await request.json()
|
224 |
+
worker = worker_map[params["model"]]
|
225 |
+
await acquire_worker_semaphore(worker)
|
226 |
+
generator = worker.generate_stream_gate(params)
|
227 |
+
background_tasks = create_background_tasks(worker)
|
228 |
+
return StreamingResponse(generator, background=background_tasks)
|
229 |
+
|
230 |
+
|
231 |
+
@app.post("/worker_generate")
|
232 |
+
async def api_generate(request: Request):
|
233 |
+
params = await request.json()
|
234 |
+
worker = worker_map[params["model"]]
|
235 |
+
await acquire_worker_semaphore(worker)
|
236 |
+
output = worker.generate_gate(params)
|
237 |
+
release_worker_semaphore(worker)
|
238 |
+
return JSONResponse(output)
|
239 |
+
|
240 |
+
|
241 |
+
@app.post("/worker_get_embeddings")
|
242 |
+
async def api_get_embeddings(request: Request):
|
243 |
+
params = await request.json()
|
244 |
+
worker = worker_map[params["model"]]
|
245 |
+
await acquire_worker_semaphore(worker)
|
246 |
+
embedding = worker.get_embeddings(params)
|
247 |
+
release_worker_semaphore(worker)
|
248 |
+
return JSONResponse(content=embedding)
|
249 |
+
|
250 |
+
|
251 |
+
@app.post("/worker_get_status")
|
252 |
+
async def api_get_status(request: Request):
|
253 |
+
return {
|
254 |
+
"model_names": [m for w in workers for m in w.model_names],
|
255 |
+
"speed": 1,
|
256 |
+
"queue_length": sum([w.get_queue_length() for w in workers]),
|
257 |
+
}
|
258 |
+
|
259 |
+
|
260 |
+
@app.post("/count_token")
|
261 |
+
async def api_count_token(request: Request):
|
262 |
+
params = await request.json()
|
263 |
+
worker = worker_map[params["model"]]
|
264 |
+
return worker.count_token(params)
|
265 |
+
|
266 |
+
|
267 |
+
@app.post("/worker_get_conv_template")
|
268 |
+
async def api_get_conv(request: Request):
|
269 |
+
params = await request.json()
|
270 |
+
worker = worker_map[params["model"]]
|
271 |
+
return worker.get_conv_template()
|
272 |
+
|
273 |
+
|
274 |
+
@app.post("/model_details")
|
275 |
+
async def api_model_details(request: Request):
|
276 |
+
params = await request.json()
|
277 |
+
worker = worker_map[params["model"]]
|
278 |
+
return {"context_length": worker.context_len}
|
279 |
+
|
280 |
+
|
281 |
+
def create_huggingface_api_worker():
|
282 |
+
parser = argparse.ArgumentParser()
|
283 |
+
parser.add_argument("--host", type=str, default="localhost")
|
284 |
+
parser.add_argument("--port", type=int, default=21002)
|
285 |
+
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
|
286 |
+
parser.add_argument(
|
287 |
+
"--controller-address", type=str, default="http://localhost:21001"
|
288 |
+
)
|
289 |
+
# all model-related parameters are listed in --model-info-file
|
290 |
+
parser.add_argument(
|
291 |
+
"--model-info-file",
|
292 |
+
type=str,
|
293 |
+
required=True,
|
294 |
+
help="Huggingface API model's info file path",
|
295 |
+
)
|
296 |
+
|
297 |
+
parser.add_argument(
|
298 |
+
"--limit-worker-concurrency",
|
299 |
+
type=int,
|
300 |
+
default=5,
|
301 |
+
help="Limit the model concurrency to prevent OOM.",
|
302 |
+
)
|
303 |
+
parser.add_argument("--no-register", action="store_true")
|
304 |
+
parser.add_argument(
|
305 |
+
"--seed",
|
306 |
+
type=int,
|
307 |
+
default=None,
|
308 |
+
help="Overwrite the random seed for each generation.",
|
309 |
+
)
|
310 |
+
parser.add_argument(
|
311 |
+
"--ssl",
|
312 |
+
action="store_true",
|
313 |
+
required=False,
|
314 |
+
default=False,
|
315 |
+
help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
|
316 |
+
)
|
317 |
+
args = parser.parse_args()
|
318 |
+
|
319 |
+
with open(args.model_info_file, "r", encoding="UTF-8") as f:
|
320 |
+
model_info = json.load(f)
|
321 |
+
|
322 |
+
logger.info(f"args: {args}")
|
323 |
+
|
324 |
+
model_path_list = []
|
325 |
+
api_base_list = []
|
326 |
+
token_list = []
|
327 |
+
context_length_list = []
|
328 |
+
model_names_list = []
|
329 |
+
conv_template_list = []
|
330 |
+
|
331 |
+
for m in model_info:
|
332 |
+
model_path_list.append(model_info[m]["model_path"])
|
333 |
+
api_base_list.append(model_info[m]["api_base"])
|
334 |
+
token_list.append(model_info[m]["token"])
|
335 |
+
|
336 |
+
context_length = model_info[m]["context_length"]
|
337 |
+
model_names = model_info[m].get("model_names", [m.split("/")[-1]])
|
338 |
+
if isinstance(model_names, str):
|
339 |
+
model_names = [model_names]
|
340 |
+
conv_template = model_info[m].get("conv_template", None)
|
341 |
+
|
342 |
+
context_length_list.append(context_length)
|
343 |
+
model_names_list.append(model_names)
|
344 |
+
conv_template_list.append(conv_template)
|
345 |
+
|
346 |
+
logger.info(f"Model paths: {model_path_list}")
|
347 |
+
logger.info(f"API bases: {api_base_list}")
|
348 |
+
logger.info(f"Tokens: {token_list}")
|
349 |
+
logger.info(f"Context lengths: {context_length_list}")
|
350 |
+
logger.info(f"Model names: {model_names_list}")
|
351 |
+
logger.info(f"Conv templates: {conv_template_list}")
|
352 |
+
|
353 |
+
for (
|
354 |
+
model_names,
|
355 |
+
conv_template,
|
356 |
+
model_path,
|
357 |
+
api_base,
|
358 |
+
token,
|
359 |
+
context_length,
|
360 |
+
) in zip(
|
361 |
+
model_names_list,
|
362 |
+
conv_template_list,
|
363 |
+
model_path_list,
|
364 |
+
api_base_list,
|
365 |
+
token_list,
|
366 |
+
context_length_list,
|
367 |
+
):
|
368 |
+
m = HuggingfaceApiWorker(
|
369 |
+
args.controller_address,
|
370 |
+
args.worker_address,
|
371 |
+
worker_id,
|
372 |
+
model_path,
|
373 |
+
api_base,
|
374 |
+
token,
|
375 |
+
context_length,
|
376 |
+
model_names,
|
377 |
+
args.limit_worker_concurrency,
|
378 |
+
no_register=args.no_register,
|
379 |
+
conv_template=conv_template,
|
380 |
+
seed=args.seed,
|
381 |
+
)
|
382 |
+
workers.append(m)
|
383 |
+
for name in model_names:
|
384 |
+
worker_map[name] = m
|
385 |
+
|
386 |
+
# register all the models
|
387 |
+
url = args.controller_address + "/register_worker"
|
388 |
+
data = {
|
389 |
+
"worker_name": workers[0].worker_addr,
|
390 |
+
"check_heart_beat": not args.no_register,
|
391 |
+
"worker_status": {
|
392 |
+
"model_names": [m for w in workers for m in w.model_names],
|
393 |
+
"speed": 1,
|
394 |
+
"queue_length": sum([w.get_queue_length() for w in workers]),
|
395 |
+
},
|
396 |
+
}
|
397 |
+
r = requests.post(url, json=data)
|
398 |
+
assert r.status_code == 200
|
399 |
+
|
400 |
+
return args, workers
|
401 |
+
|
402 |
+
|
403 |
+
if __name__ == "__main__":
|
404 |
+
args, workers = create_huggingface_api_worker()
|
405 |
+
if args.ssl:
|
406 |
+
uvicorn.run(
|
407 |
+
app,
|
408 |
+
host=args.host,
|
409 |
+
port=args.port,
|
410 |
+
log_level="info",
|
411 |
+
ssl_keyfile=os.environ["SSL_KEYFILE"],
|
412 |
+
ssl_certfile=os.environ["SSL_CERTFILE"],
|
413 |
+
)
|
414 |
+
else:
|
415 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
inference.py
ADDED
@@ -0,0 +1,555 @@
1 |
+
"""Inference for FastChat models."""
|
2 |
+
import abc
|
3 |
+
import gc
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
from typing import Iterable, Optional, Dict
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
import psutil
|
13 |
+
import torch
|
14 |
+
from transformers import (
|
15 |
+
AutoTokenizer,
|
16 |
+
AutoModelForCausalLM,
|
17 |
+
LlamaTokenizer,
|
18 |
+
LlamaForCausalLM,
|
19 |
+
AutoModel,
|
20 |
+
AutoModelForSeq2SeqLM,
|
21 |
+
T5Tokenizer,
|
22 |
+
AutoConfig,
|
23 |
+
)
|
24 |
+
from transformers.generation.logits_process import (
|
25 |
+
LogitsProcessorList,
|
26 |
+
RepetitionPenaltyLogitsProcessor,
|
27 |
+
TemperatureLogitsWarper,
|
28 |
+
TopKLogitsWarper,
|
29 |
+
TopPLogitsWarper,
|
30 |
+
)
|
31 |
+
|
32 |
+
from fastchat.conversation import get_conv_template, SeparatorStyle
|
33 |
+
from fastchat.model.model_adapter import (
|
34 |
+
load_model,
|
35 |
+
get_conversation_template,
|
36 |
+
get_generate_stream_function,
|
37 |
+
)
|
38 |
+
from fastchat.modules.awq import AWQConfig
|
39 |
+
from fastchat.modules.gptq import GptqConfig
|
40 |
+
from fastchat.modules.exllama import ExllamaConfig
|
41 |
+
from fastchat.modules.xfastertransformer import XftConfig
|
42 |
+
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
|
43 |
+
|
44 |
+
|
45 |
+
def prepare_logits_processor(
|
46 |
+
temperature: float, repetition_penalty: float, top_p: float, top_k: int
|
47 |
+
) -> LogitsProcessorList:
|
48 |
+
processor_list = LogitsProcessorList()
|
49 |
+
# TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
|
50 |
+
if temperature >= 1e-5 and temperature != 1.0:
|
51 |
+
processor_list.append(TemperatureLogitsWarper(temperature))
|
52 |
+
if repetition_penalty > 1.0:
|
53 |
+
processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
|
54 |
+
if 1e-8 <= top_p < 1.0:
|
55 |
+
processor_list.append(TopPLogitsWarper(top_p))
|
56 |
+
if top_k > 0:
|
57 |
+
processor_list.append(TopKLogitsWarper(top_k))
|
58 |
+
return processor_list
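A small sketch of what prepare_logits_processor() installs for typical settings, assuming the module is importable as fastchat.serve.inference (upstream layout); it only inspects the returned list, no model is needed.

# Sketch: which warpers prepare_logits_processor() installs for given params.
from fastchat.serve.inference import prepare_logits_processor

# temperature=0.7, repetition_penalty=1.2, top_p=0.9, top_k=40
processors = prepare_logits_processor(0.7, 1.2, 0.9, 40)
print([type(p).__name__ for p in processors])
# -> ['TemperatureLogitsWarper', 'RepetitionPenaltyLogitsProcessor',
#     'TopPLogitsWarper', 'TopKLogitsWarper']

# Greedy decoding (temperature below 1e-5) adds no temperature warper, and
# top_p=1.0 / top_k=-1 are treated as disabled:
print(len(prepare_logits_processor(0.0, 1.0, 1.0, -1)))  # -> 0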
|
59 |
+
|
60 |
+
|
61 |
+
@torch.inference_mode()
|
62 |
+
def generate_stream(
|
63 |
+
model,
|
64 |
+
tokenizer,
|
65 |
+
params: Dict,
|
66 |
+
device: str,
|
67 |
+
context_len: int,
|
68 |
+
stream_interval: int = 2,
|
69 |
+
judge_sent_end: bool = False,
|
70 |
+
):
|
71 |
+
if hasattr(model, "device"):
|
72 |
+
device = model.device
|
73 |
+
|
74 |
+
# Read parameters
|
75 |
+
prompt = params["prompt"]
|
76 |
+
len_prompt = len(prompt)
|
77 |
+
temperature = float(params.get("temperature", 1.0))
|
78 |
+
repetition_penalty = float(params.get("repetition_penalty", 1.0))
|
79 |
+
top_p = float(params.get("top_p", 1.0))
|
80 |
+
top_k = int(params.get("top_k", -1)) # -1 means disable
|
81 |
+
max_new_tokens = int(params.get("max_new_tokens", 256))
|
82 |
+
logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
|
83 |
+
echo = bool(params.get("echo", True))
|
84 |
+
stop_str = params.get("stop", None)
|
85 |
+
stop_token_ids = params.get("stop_token_ids", None) or []
|
86 |
+
if tokenizer.eos_token_id not in stop_token_ids:
|
87 |
+
stop_token_ids.append(tokenizer.eos_token_id)
|
88 |
+
|
89 |
+
logits_processor = prepare_logits_processor(
|
90 |
+
temperature, repetition_penalty, top_p, top_k
|
91 |
+
)
|
92 |
+
input_ids = tokenizer(prompt).input_ids
|
93 |
+
|
94 |
+
if model.config.is_encoder_decoder:
|
95 |
+
max_src_len = context_len
|
96 |
+
else: # truncate
|
97 |
+
max_src_len = context_len - max_new_tokens - 1
|
98 |
+
|
99 |
+
input_ids = input_ids[-max_src_len:]
|
100 |
+
output_ids = list(input_ids)
|
101 |
+
input_echo_len = len(input_ids)
|
102 |
+
|
103 |
+
if model.config.is_encoder_decoder:
|
104 |
+
if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
|
105 |
+
raise NotImplementedError
|
106 |
+
encoder_output = model.encoder(
|
107 |
+
input_ids=torch.as_tensor([input_ids], device=device)
|
108 |
+
)[0]
|
109 |
+
start_ids = torch.as_tensor(
|
110 |
+
[[model.generation_config.decoder_start_token_id]],
|
111 |
+
dtype=torch.int64,
|
112 |
+
device=device,
|
113 |
+
)
|
114 |
+
else:
|
115 |
+
start_ids = torch.as_tensor([input_ids], device=device)
|
116 |
+
|
117 |
+
past_key_values = out = None
|
118 |
+
token_logprobs = [None] # The first token has no logprobs.
|
119 |
+
sent_interrupt = False
|
120 |
+
finish_reason = None
|
121 |
+
stopped = False
|
122 |
+
for i in range(max_new_tokens):
|
123 |
+
if i == 0: # prefill
|
124 |
+
if model.config.is_encoder_decoder:
|
125 |
+
out = model.decoder(
|
126 |
+
input_ids=start_ids,
|
127 |
+
encoder_hidden_states=encoder_output,
|
128 |
+
use_cache=True,
|
129 |
+
)
|
130 |
+
logits = model.lm_head(out[0])
|
131 |
+
else:
|
132 |
+
out = model(input_ids=start_ids, use_cache=True)
|
133 |
+
logits = out.logits
|
134 |
+
past_key_values = out.past_key_values
|
135 |
+
|
136 |
+
if logprobs is not None:
|
137 |
+
# Prefull logprobs for the prompt.
|
138 |
+
shift_input_ids = start_ids[..., 1:].contiguous()
|
139 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
140 |
+
shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
|
141 |
+
for label_id, logit in zip(
|
142 |
+
shift_input_ids[0].tolist(), shift_logits[0]
|
143 |
+
):
|
144 |
+
token_logprobs.append(logit[label_id])
|
145 |
+
else: # decoding
|
146 |
+
if model.config.is_encoder_decoder:
|
147 |
+
out = model.decoder(
|
148 |
+
input_ids=torch.as_tensor(
|
149 |
+
[[token] if not sent_interrupt else output_ids],
|
150 |
+
device=device,
|
151 |
+
),
|
152 |
+
encoder_hidden_states=encoder_output,
|
153 |
+
use_cache=True,
|
154 |
+
past_key_values=past_key_values if not sent_interrupt else None,
|
155 |
+
)
|
156 |
+
sent_interrupt = False
|
157 |
+
|
158 |
+
logits = model.lm_head(out[0])
|
159 |
+
else:
|
160 |
+
out = model(
|
161 |
+
input_ids=torch.as_tensor(
|
162 |
+
[[token] if not sent_interrupt else output_ids],
|
163 |
+
device=device,
|
164 |
+
),
|
165 |
+
use_cache=True,
|
166 |
+
past_key_values=past_key_values if not sent_interrupt else None,
|
167 |
+
)
|
168 |
+
sent_interrupt = False
|
169 |
+
logits = out.logits
|
170 |
+
past_key_values = out.past_key_values
|
171 |
+
|
172 |
+
if logits_processor:
|
173 |
+
if repetition_penalty > 1.0:
|
174 |
+
tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
|
175 |
+
else:
|
176 |
+
tmp_output_ids = None
|
177 |
+
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
|
178 |
+
else:
|
179 |
+
last_token_logits = logits[0, -1, :]
|
180 |
+
|
181 |
+
if device == "mps":
|
182 |
+
# Switch to CPU by avoiding some bugs in mps backend.
|
183 |
+
last_token_logits = last_token_logits.float().to("cpu")
|
184 |
+
|
185 |
+
if temperature < 1e-5 or top_p < 1e-8: # greedy
|
186 |
+
_, indices = torch.topk(last_token_logits, 2)
|
187 |
+
tokens = [int(index) for index in indices.tolist()]
|
188 |
+
else:
|
189 |
+
probs = torch.softmax(last_token_logits, dim=-1)
|
190 |
+
indices = torch.multinomial(probs, num_samples=2)
|
191 |
+
tokens = [int(token) for token in indices.tolist()]
|
192 |
+
token = tokens[0]
|
193 |
+
output_ids.append(token)
|
194 |
+
if logprobs is not None:
|
195 |
+
# Cannot use last_token_logits because logprobs is based on raw logits.
|
196 |
+
token_logprobs.append(
|
197 |
+
torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
|
198 |
+
)
|
199 |
+
|
200 |
+
if token in stop_token_ids:
|
201 |
+
stopped = True
|
202 |
+
else:
|
203 |
+
stopped = False
|
204 |
+
|
205 |
+
# Yield the output tokens
|
206 |
+
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
|
207 |
+
if echo:
|
208 |
+
tmp_output_ids = output_ids
|
209 |
+
rfind_start = len_prompt
|
210 |
+
else:
|
211 |
+
tmp_output_ids = output_ids[input_echo_len:]
|
212 |
+
rfind_start = 0
|
213 |
+
|
214 |
+
output = tokenizer.decode(
|
215 |
+
tmp_output_ids,
|
216 |
+
skip_special_tokens=True,
|
217 |
+
spaces_between_special_tokens=False,
|
218 |
+
clean_up_tokenization_spaces=True,
|
219 |
+
)
|
220 |
+
ret_logprobs = None
|
221 |
+
if logprobs is not None:
|
222 |
+
ret_logprobs = {
|
223 |
+
"text_offset": [],
|
224 |
+
"tokens": [
|
225 |
+
tokenizer.decode(token)
|
226 |
+
for token in (
|
227 |
+
output_ids if echo else output_ids[input_echo_len:]
|
228 |
+
)
|
229 |
+
],
|
230 |
+
"token_logprobs": token_logprobs
|
231 |
+
if echo
|
232 |
+
else token_logprobs[input_echo_len:],
|
233 |
+
"top_logprobs": [{}]
|
234 |
+
* len(token_logprobs if echo else token_logprobs[input_echo_len:]),
|
235 |
+
}
|
236 |
+
# Compute text_offset
|
237 |
+
curr_pos = 0
|
238 |
+
for text in ret_logprobs["tokens"]:
|
239 |
+
ret_logprobs["text_offset"].append(curr_pos)
|
240 |
+
curr_pos += len(text)
|
241 |
+
|
242 |
+
# TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
|
243 |
+
if judge_sent_end and stopped and not is_sentence_complete(output):
|
244 |
+
if len(tokens) > 1:
|
245 |
+
token = tokens[1]
|
246 |
+
output_ids[-1] = token
|
247 |
+
else:
|
248 |
+
output_ids.pop()
|
249 |
+
stopped = False
|
250 |
+
sent_interrupt = True
|
251 |
+
|
252 |
+
partially_stopped = False
|
253 |
+
if stop_str:
|
254 |
+
if isinstance(stop_str, str):
|
255 |
+
pos = output.rfind(stop_str, rfind_start)
|
256 |
+
if pos != -1:
|
257 |
+
output = output[:pos]
|
258 |
+
stopped = True
|
259 |
+
else:
|
260 |
+
partially_stopped = is_partial_stop(output, stop_str)
|
261 |
+
elif isinstance(stop_str, Iterable):
|
262 |
+
for each_stop in stop_str:
|
263 |
+
pos = output.rfind(each_stop, rfind_start)
|
264 |
+
if pos != -1:
|
265 |
+
output = output[:pos]
|
266 |
+
stopped = True
|
267 |
+
break
|
268 |
+
else:
|
269 |
+
partially_stopped = is_partial_stop(output, each_stop)
|
270 |
+
if partially_stopped:
|
271 |
+
break
|
272 |
+
else:
|
273 |
+
raise ValueError("Invalid stop field type.")
|
274 |
+
|
275 |
+
# Prevent yielding partial stop sequence
|
276 |
+
if not partially_stopped:
|
277 |
+
yield {
|
278 |
+
"text": output,
|
279 |
+
"logprobs": ret_logprobs,
|
280 |
+
"usage": {
|
281 |
+
"prompt_tokens": input_echo_len,
|
282 |
+
"completion_tokens": i,
|
283 |
+
"total_tokens": input_echo_len + i,
|
284 |
+
},
|
285 |
+
"finish_reason": None,
|
286 |
+
}
|
287 |
+
|
288 |
+
if stopped:
|
289 |
+
break
|
290 |
+
|
291 |
+
# Finish stream event, which contains finish reason
|
292 |
+
else:
|
293 |
+
finish_reason = "length"
|
294 |
+
|
295 |
+
if stopped:
|
296 |
+
finish_reason = "stop"
|
297 |
+
|
298 |
+
yield {
|
299 |
+
"text": output,
|
300 |
+
"logprobs": ret_logprobs,
|
301 |
+
"usage": {
|
302 |
+
"prompt_tokens": input_echo_len,
|
303 |
+
"completion_tokens": i,
|
304 |
+
"total_tokens": input_echo_len + i,
|
305 |
+
},
|
306 |
+
"finish_reason": finish_reason,
|
307 |
+
}
|
308 |
+
|
309 |
+
# Clean
|
310 |
+
del past_key_values, out
|
311 |
+
gc.collect()
|
312 |
+
torch.cuda.empty_cache()
|
313 |
+
if device == "xpu":
|
314 |
+
torch.xpu.empty_cache()
|
315 |
+
if device == "npu":
|
316 |
+
torch.npu.empty_cache()
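generate_stream() is a plain generator of dicts carrying the decoded text so far plus usage and finish_reason; below is a hedged sketch of driving it outside the server stack. The model id, device, and prompt are illustrative, and the module path fastchat.serve.inference is assumed from the upstream layout.

# Sketch: consuming generate_stream() directly. Each yielded dict carries the
# full text decoded so far; the last one also carries the finish reason.
from fastchat.model import load_model
from fastchat.serve.inference import generate_stream
from fastchat.utils import get_context_length

model, tokenizer = load_model("lmsys/vicuna-7b-v1.5", device="cuda", num_gpus=1)
params = {
    "prompt": "USER: Name three prime numbers. ASSISTANT:",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "echo": False,
}

last = None
for out in generate_stream(
    model, tokenizer, params, "cuda", context_len=get_context_length(model.config)
):
    last = out

print(last["text"])
print(last["usage"], last["finish_reason"])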
|
317 |
+
|
318 |
+
|
319 |
+
class ChatIO(abc.ABC):
|
320 |
+
@abc.abstractmethod
|
321 |
+
def prompt_for_input(self, role: str) -> str:
|
322 |
+
"""Prompt for input from a role."""
|
323 |
+
|
324 |
+
@abc.abstractmethod
|
325 |
+
def prompt_for_output(self, role: str):
|
326 |
+
"""Prompt for output from a role."""
|
327 |
+
|
328 |
+
@abc.abstractmethod
|
329 |
+
def stream_output(self, output_stream):
|
330 |
+
"""Stream output."""
|
331 |
+
|
332 |
+
@abc.abstractmethod
|
333 |
+
def print_output(self, text: str):
|
334 |
+
"""Print output."""
|
335 |
+
|
336 |
+
|
337 |
+
def chat_loop(
|
338 |
+
model_path: str,
|
339 |
+
device: str,
|
340 |
+
num_gpus: int,
|
341 |
+
max_gpu_memory: str,
|
342 |
+
dtype: Optional[torch.dtype],
|
343 |
+
load_8bit: bool,
|
344 |
+
cpu_offloading: bool,
|
345 |
+
conv_template: Optional[str],
|
346 |
+
conv_system_msg: Optional[str],
|
347 |
+
temperature: float,
|
348 |
+
repetition_penalty: float,
|
349 |
+
max_new_tokens: int,
|
350 |
+
chatio: ChatIO,
|
351 |
+
gptq_config: Optional[GptqConfig] = None,
|
352 |
+
awq_config: Optional[AWQConfig] = None,
|
353 |
+
exllama_config: Optional[ExllamaConfig] = None,
|
354 |
+
xft_config: Optional[XftConfig] = None,
|
355 |
+
revision: str = "main",
|
356 |
+
judge_sent_end: bool = True,
|
357 |
+
debug: bool = True,
|
358 |
+
history: bool = True,
|
359 |
+
):
|
360 |
+
# Model
|
361 |
+
model, tokenizer = load_model(
|
362 |
+
model_path,
|
363 |
+
device=device,
|
364 |
+
num_gpus=num_gpus,
|
365 |
+
max_gpu_memory=max_gpu_memory,
|
366 |
+
dtype=dtype,
|
367 |
+
load_8bit=load_8bit,
|
368 |
+
cpu_offloading=cpu_offloading,
|
369 |
+
gptq_config=gptq_config,
|
370 |
+
awq_config=awq_config,
|
371 |
+
exllama_config=exllama_config,
|
372 |
+
xft_config=xft_config,
|
373 |
+
revision=revision,
|
374 |
+
debug=debug,
|
375 |
+
)
|
376 |
+
generate_stream_func = get_generate_stream_function(model, model_path)
|
377 |
+
|
378 |
+
model_type = str(type(model)).lower()
|
379 |
+
is_t5 = "t5" in model_type
|
380 |
+
is_codet5p = "codet5p" in model_type
|
381 |
+
is_xft = "xft" in model_type
|
382 |
+
|
383 |
+
# Hardcode T5's default repetition penalty to be 1.2
|
384 |
+
if is_t5 and repetition_penalty == 1.0:
|
385 |
+
repetition_penalty = 1.2
|
386 |
+
|
387 |
+
# Set context length
|
388 |
+
context_len = get_context_length(model.config)
|
389 |
+
|
390 |
+
# Chat
|
391 |
+
def new_chat():
|
392 |
+
if conv_template:
|
393 |
+
conv = get_conv_template(conv_template)
|
394 |
+
else:
|
395 |
+
conv = get_conversation_template(model_path)
|
396 |
+
if conv_system_msg is not None:
|
397 |
+
conv.set_system_message(conv_system_msg)
|
398 |
+
return conv
|
399 |
+
|
400 |
+
def reload_conv(conv):
|
401 |
+
"""
|
402 |
+
Reprints the conversation from the start.
|
403 |
+
"""
|
404 |
+
for message in conv.messages[conv.offset :]:
|
405 |
+
chatio.prompt_for_output(message[0])
|
406 |
+
chatio.print_output(message[1])
|
407 |
+
|
408 |
+
conv = None
|
409 |
+
|
410 |
+
while True:
|
411 |
+
if not history or not conv:
|
412 |
+
conv = new_chat()
|
413 |
+
|
414 |
+
try:
|
415 |
+
inp = chatio.prompt_for_input(conv.roles[0])
|
416 |
+
except EOFError:
|
417 |
+
inp = ""
|
418 |
+
|
419 |
+
if inp == "!!exit" or not inp:
|
420 |
+
print("exit...")
|
421 |
+
break
|
422 |
+
elif inp == "!!reset":
|
423 |
+
print("resetting...")
|
424 |
+
conv = new_chat()
|
425 |
+
continue
|
426 |
+
elif inp == "!!remove":
|
427 |
+
print("removing last message...")
|
428 |
+
if len(conv.messages) > conv.offset:
|
429 |
+
# Assistant
|
430 |
+
if conv.messages[-1][0] == conv.roles[1]:
|
431 |
+
conv.messages.pop()
|
432 |
+
# User
|
433 |
+
if conv.messages[-1][0] == conv.roles[0]:
|
434 |
+
conv.messages.pop()
|
435 |
+
reload_conv(conv)
|
436 |
+
else:
|
437 |
+
print("No messages to remove.")
|
438 |
+
continue
|
439 |
+
elif inp == "!!regen":
|
440 |
+
print("regenerating last message...")
|
441 |
+
if len(conv.messages) > conv.offset:
|
442 |
+
# Assistant
|
443 |
+
if conv.messages[-1][0] == conv.roles[1]:
|
444 |
+
conv.messages.pop()
|
445 |
+
# User
|
446 |
+
if conv.messages[-1][0] == conv.roles[0]:
|
447 |
+
reload_conv(conv)
|
448 |
+
# Set inp to previous message
|
449 |
+
inp = conv.messages.pop()[1]
|
450 |
+
else:
|
451 |
+
# Shouldn't happen in normal circumstances
|
452 |
+
print("No user message to regenerate from.")
|
453 |
+
continue
|
454 |
+
else:
|
455 |
+
print("No messages to regenerate.")
|
456 |
+
continue
|
457 |
+
elif inp.startswith("!!save"):
|
458 |
+
args = inp.split(" ", 1)
|
459 |
+
|
460 |
+
if len(args) != 2:
|
461 |
+
print("usage: !!save <filename>")
|
462 |
+
continue
|
463 |
+
else:
|
464 |
+
filename = args[1]
|
465 |
+
|
466 |
+
# Add .json if extension not present
|
467 |
+
if not "." in filename:
|
468 |
+
filename += ".json"
|
469 |
+
|
470 |
+
print("saving...", filename)
|
471 |
+
with open(filename, "w") as outfile:
|
472 |
+
json.dump(conv.dict(), outfile)
|
473 |
+
continue
|
474 |
+
elif inp.startswith("!!load"):
|
475 |
+
args = inp.split(" ", 1)
|
476 |
+
|
477 |
+
if len(args) != 2:
|
478 |
+
print("usage: !!load <filename>")
|
479 |
+
continue
|
480 |
+
else:
|
481 |
+
filename = args[1]
|
482 |
+
|
483 |
+
# Check if file exists and add .json if needed
|
484 |
+
if not os.path.exists(filename):
|
485 |
+
if (not filename.endswith(".json")) and os.path.exists(
|
486 |
+
filename + ".json"
|
487 |
+
):
|
488 |
+
filename += ".json"
|
489 |
+
else:
|
490 |
+
print("file not found:", filename)
|
491 |
+
continue
|
492 |
+
|
493 |
+
print("loading...", filename)
|
494 |
+
with open(filename, "r") as infile:
|
495 |
+
new_conv = json.load(infile)
|
496 |
+
|
497 |
+
conv = get_conv_template(new_conv["template_name"])
|
498 |
+
conv.set_system_message(new_conv["system_message"])
|
499 |
+
conv.messages = new_conv["messages"]
|
500 |
+
reload_conv(conv)
|
501 |
+
continue
|
502 |
+
|
503 |
+
conv.append_message(conv.roles[0], inp)
|
504 |
+
conv.append_message(conv.roles[1], None)
|
505 |
+
prompt = conv.get_prompt()
|
506 |
+
|
507 |
+
if is_codet5p: # codet5p is a code completion model.
|
508 |
+
prompt = inp
|
509 |
+
|
510 |
+
gen_params = {
|
511 |
+
"model": model_path,
|
512 |
+
"prompt": prompt,
|
513 |
+
"temperature": temperature,
|
514 |
+
"repetition_penalty": repetition_penalty,
|
515 |
+
"max_new_tokens": max_new_tokens,
|
516 |
+
"stop": conv.stop_str,
|
517 |
+
"stop_token_ids": conv.stop_token_ids,
|
518 |
+
"echo": False,
|
519 |
+
}
|
520 |
+
|
521 |
+
try:
|
522 |
+
chatio.prompt_for_output(conv.roles[1])
|
523 |
+
output_stream = generate_stream_func(
|
524 |
+
model,
|
525 |
+
tokenizer,
|
526 |
+
gen_params,
|
527 |
+
device,
|
528 |
+
context_len=context_len,
|
529 |
+
judge_sent_end=judge_sent_end,
|
530 |
+
)
|
531 |
+
t = time.time()
|
532 |
+
outputs = chatio.stream_output(output_stream)
|
533 |
+
duration = time.time() - t
|
534 |
+
conv.update_last_message(outputs.strip())
|
535 |
+
|
536 |
+
if debug:
|
537 |
+
num_tokens = len(tokenizer.encode(outputs))
|
538 |
+
msg = {
|
539 |
+
"conv_template": conv.name,
|
540 |
+
"prompt": prompt,
|
541 |
+
"outputs": outputs,
|
542 |
+
"speed (token/s)": round(num_tokens / duration, 2),
|
543 |
+
}
|
544 |
+
print(f"\n{msg}\n")
|
545 |
+
|
546 |
+
except KeyboardInterrupt:
|
547 |
+
print("stopped generation.")
|
548 |
+
# If generation didn't finish
|
549 |
+
if conv.messages[-1][1] is None:
|
550 |
+
conv.messages.pop()
|
551 |
+
# Remove last user message, so there isn't a double up
|
552 |
+
if conv.messages[-1][0] == conv.roles[0]:
|
553 |
+
conv.messages.pop()
|
554 |
+
|
555 |
+
reload_conv(conv)
|
launch_all_serve.py
ADDED
@@ -0,0 +1,284 @@
1 |
+
"""
|
2 |
+
Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022"
|
3 |
+
|
4 |
+
Workers are listed in format of `model-path`@`host`@`port`
|
5 |
+
|
6 |
+
The key mechanism behind this scripts is:
|
7 |
+
1, execute shell cmd to launch the controller/worker/openai-api-server;
|
8 |
+
2, check the log of controller/worker/openai-api-server to ensure that the serve is launched properly.
|
9 |
+
Note that a few of non-critical `fastchat.serve` cmd options are not supported currently.
|
10 |
+
"""
|
11 |
+
import sys
|
12 |
+
import os
|
13 |
+
|
14 |
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
15 |
+
|
16 |
+
import subprocess
|
17 |
+
import re
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
LOGDIR = "./logs/"
|
21 |
+
|
22 |
+
if not os.path.exists(LOGDIR):
|
23 |
+
os.makedirs(LOGDIR)
|
24 |
+
|
25 |
+
parser = argparse.ArgumentParser()
|
26 |
+
# ------multi worker-----------------
|
27 |
+
parser.add_argument(
|
28 |
+
"--model-path-address",
|
29 |
+
default="THUDM/chatglm2-6b@localhost@20002",
|
30 |
+
nargs="+",
|
31 |
+
type=str,
|
32 |
+
help="model path, host, and port, formatted as model-path@host@port",
|
33 |
+
)
|
34 |
+
# ---------------controller-------------------------
|
35 |
+
|
36 |
+
parser.add_argument("--controller-host", type=str, default="localhost")
|
37 |
+
parser.add_argument("--controller-port", type=int, default=21001)
|
38 |
+
parser.add_argument(
|
39 |
+
"--dispatch-method",
|
40 |
+
type=str,
|
41 |
+
choices=["lottery", "shortest_queue"],
|
42 |
+
default="shortest_queue",
|
43 |
+
)
|
44 |
+
controller_args = ["controller-host", "controller-port", "dispatch-method"]
|
45 |
+
|
46 |
+
# ----------------------worker------------------------------------------
|
47 |
+
|
48 |
+
parser.add_argument("--worker-host", type=str, default="localhost")
|
49 |
+
parser.add_argument("--worker-port", type=int, default=21002)
|
50 |
+
# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
|
51 |
+
# parser.add_argument(
|
52 |
+
# "--controller-address", type=str, default="http://localhost:21001"
|
53 |
+
# )
|
54 |
+
parser.add_argument(
|
55 |
+
"--model-path",
|
56 |
+
type=str,
|
57 |
+
default="lmsys/vicuna-7b-v1.5",
|
58 |
+
help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--revision",
|
62 |
+
type=str,
|
63 |
+
default="main",
|
64 |
+
help="Hugging Face Hub model revision identifier",
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--device",
|
68 |
+
type=str,
|
69 |
+
choices=["cpu", "cuda", "mps", "xpu", "npu"],
|
70 |
+
default="cuda",
|
71 |
+
help="The device type",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--gpus",
|
75 |
+
type=str,
|
76 |
+
default="0",
|
77 |
+
help="A single GPU like 1 or multiple GPUs like 0,2",
|
78 |
+
)
|
79 |
+
parser.add_argument("--num-gpus", type=int, default=1)
|
80 |
+
parser.add_argument(
|
81 |
+
"--max-gpu-memory",
|
82 |
+
type=str,
|
83 |
+
help="The maximum memory per gpu. Use a string like '13Gib'",
|
84 |
+
)
|
85 |
+
parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization")
|
86 |
+
parser.add_argument(
|
87 |
+
"--cpu-offloading",
|
88 |
+
action="store_true",
|
89 |
+
help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--gptq-ckpt",
|
93 |
+
type=str,
|
94 |
+
default=None,
|
95 |
+
help="Load quantized model. The path to the local GPTQ checkpoint.",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"--gptq-wbits",
|
99 |
+
type=int,
|
100 |
+
default=16,
|
101 |
+
choices=[2, 3, 4, 8, 16],
|
102 |
+
help="#bits to use for quantization",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--gptq-groupsize",
|
106 |
+
type=int,
|
107 |
+
default=-1,
|
108 |
+
help="Groupsize to use for quantization; default uses full row.",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--gptq-act-order",
|
112 |
+
action="store_true",
|
113 |
+
help="Whether to apply the activation order GPTQ heuristic",
|
114 |
+
)
|
115 |
+
parser.add_argument(
|
116 |
+
"--model-names",
|
117 |
+
type=lambda s: s.split(","),
|
118 |
+
help="Optional display comma separated names",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--limit-worker-concurrency",
|
122 |
+
type=int,
|
123 |
+
default=5,
|
124 |
+
help="Limit the model concurrency to prevent OOM.",
|
125 |
+
)
|
126 |
+
parser.add_argument("--stream-interval", type=int, default=2)
|
127 |
+
parser.add_argument("--no-register", action="store_true")
|
128 |
+
|
129 |
+
worker_args = [
|
130 |
+
"worker-host",
|
131 |
+
"worker-port",
|
132 |
+
"model-path",
|
133 |
+
"revision",
|
134 |
+
"device",
|
135 |
+
"gpus",
|
136 |
+
"num-gpus",
|
137 |
+
"max-gpu-memory",
|
138 |
+
"load-8bit",
|
139 |
+
"cpu-offloading",
|
140 |
+
"gptq-ckpt",
|
141 |
+
"gptq-wbits",
|
142 |
+
"gptq-groupsize",
|
143 |
+
"gptq-act-order",
|
144 |
+
"model-names",
|
145 |
+
"limit-worker-concurrency",
|
146 |
+
"stream-interval",
|
147 |
+
"no-register",
|
148 |
+
"controller-address",
|
149 |
+
]
|
150 |
+
# -----------------openai server---------------------------
|
151 |
+
|
152 |
+
parser.add_argument("--server-host", type=str, default="localhost", help="host name")
|
153 |
+
parser.add_argument("--server-port", type=int, default=8001, help="port number")
|
154 |
+
parser.add_argument(
|
155 |
+
"--allow-credentials", action="store_true", help="allow credentials"
|
156 |
+
)
|
157 |
+
# parser.add_argument(
|
158 |
+
# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
|
159 |
+
# )
|
160 |
+
# parser.add_argument(
|
161 |
+
# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
|
162 |
+
# )
|
163 |
+
# parser.add_argument(
|
164 |
+
# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
|
165 |
+
# )
|
166 |
+
parser.add_argument(
|
167 |
+
"--api-keys",
|
168 |
+
type=lambda s: s.split(","),
|
169 |
+
help="Optional list of comma separated API keys",
|
170 |
+
)
|
171 |
+
server_args = [
|
172 |
+
"server-host",
|
173 |
+
"server-port",
|
174 |
+
"allow-credentials",
|
175 |
+
"api-keys",
|
176 |
+
"controller-address",
|
177 |
+
]
|
178 |
+
|
179 |
+
args = parser.parse_args()
|
180 |
+
|
181 |
+
args = argparse.Namespace(
|
182 |
+
**vars(args),
|
183 |
+
**{"controller-address": f"http://{args.controller_host}:{args.controller_port}"},
|
184 |
+
)
|
185 |
+
|
186 |
+
if args.gpus:
|
187 |
+
if len(args.gpus.split(",")) < args.num_gpus:
|
188 |
+
raise ValueError(
|
189 |
+
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
|
190 |
+
)
|
191 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
192 |
+
|
193 |
+
# 0,controller, model_worker, openai_api_server
|
194 |
+
# 1, cmd options
|
195 |
+
# 2,LOGDIR
|
196 |
+
# 3, log file name
|
197 |
+
base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"
|
198 |
+
|
199 |
+
# 0 LOGDIR
|
200 |
+
#! 1 log file name
|
201 |
+
# 2 controller, worker, openai_api_server
|
202 |
+
base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do
|
203 |
+
sleep 1s;
|
204 |
+
echo "wait {2} running"
|
205 |
+
done
|
206 |
+
echo '{2} running' """
|
207 |
+
|
208 |
+
|
209 |
+
def string_args(args, args_list):
|
210 |
+
args_str = ""
|
211 |
+
for key, value in args._get_kwargs():
|
212 |
+
key = key.replace("_", "-")
|
213 |
+
if key not in args_list:
|
214 |
+
continue
|
215 |
+
|
216 |
+
key = key.split("-")[-1] if re.search("port|host", key) else key
|
217 |
+
if not value:
|
218 |
+
pass
|
219 |
+
# 1==True -> True
|
220 |
+
elif isinstance(value, bool) and value == True:
|
221 |
+
args_str += f" --{key} "
|
222 |
+
elif (
|
223 |
+
isinstance(value, list)
|
224 |
+
or isinstance(value, tuple)
|
225 |
+
or isinstance(value, set)
|
226 |
+
):
|
227 |
+
value = " ".join(value)
|
228 |
+
args_str += f" --{key} {value} "
|
229 |
+
else:
|
230 |
+
args_str += f" --{key} {value} "
|
231 |
+
|
232 |
+
return args_str
|
233 |
+
|
234 |
+
|
235 |
+
def launch_worker(item):
|
236 |
+
log_name = (
|
237 |
+
item.split("/")[-1]
|
238 |
+
.split("\\")[-1]
|
239 |
+
.replace("-", "_")
|
240 |
+
.replace("@", "_")
|
241 |
+
.replace(".", "_")
|
242 |
+
)
|
243 |
+
|
244 |
+
args.model_path, args.worker_host, args.worker_port = item.split("@")
|
245 |
+
print("*" * 80)
|
246 |
+
worker_str_args = string_args(args, worker_args)
|
247 |
+
print(worker_str_args)
|
248 |
+
worker_sh = base_launch_sh.format(
|
249 |
+
"model_worker", worker_str_args, LOGDIR, f"worker_{log_name}"
|
250 |
+
)
|
251 |
+
worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker")
|
252 |
+
subprocess.run(worker_sh, shell=True, check=True)
|
253 |
+
subprocess.run(worker_check_sh, shell=True, check=True)
|
254 |
+
|
255 |
+
|
256 |
+
def launch_all():
|
257 |
+
controller_str_args = string_args(args, controller_args)
|
258 |
+
controller_sh = base_launch_sh.format(
|
259 |
+
"controller", controller_str_args, LOGDIR, "controller"
|
260 |
+
)
|
261 |
+
controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller")
|
262 |
+
subprocess.run(controller_sh, shell=True, check=True)
|
263 |
+
subprocess.run(controller_check_sh, shell=True, check=True)
|
264 |
+
|
265 |
+
if isinstance(args.model_path_address, str):
|
266 |
+
launch_worker(args.model_path_address)
|
267 |
+
else:
|
268 |
+
for idx, item in enumerate(args.model_path_address):
|
269 |
+
print(f"loading {idx}th model:{item}")
|
270 |
+
launch_worker(item)
|
271 |
+
|
272 |
+
server_str_args = string_args(args, server_args)
|
273 |
+
server_sh = base_launch_sh.format(
|
274 |
+
"openai_api_server", server_str_args, LOGDIR, "openai_api_server"
|
275 |
+
)
|
276 |
+
server_check_sh = base_check_sh.format(
|
277 |
+
LOGDIR, "openai_api_server", "openai_api_server"
|
278 |
+
)
|
279 |
+
subprocess.run(server_sh, shell=True, check=True)
|
280 |
+
subprocess.run(server_check_sh, shell=True, check=True)
|
281 |
+
|
282 |
+
|
283 |
+
if __name__ == "__main__":
|
284 |
+
launch_all()
|
lightllm_worker.py
ADDED
@@ -0,0 +1,512 @@
1 |
+
"""
|
2 |
+
A model worker that executes the model based on LightLLM.
|
3 |
+
|
4 |
+
See documentations at docs/lightllm_integration.md
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import asyncio
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
import torch
|
12 |
+
import uvicorn
|
13 |
+
|
14 |
+
from transformers import AutoConfig
|
15 |
+
|
16 |
+
from typing import List
|
17 |
+
|
18 |
+
from fastapi import FastAPI, Request, BackgroundTasks
|
19 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
20 |
+
|
21 |
+
from fastchat.serve.base_model_worker import BaseModelWorker
|
22 |
+
from fastchat.serve.model_worker import (
|
23 |
+
logger,
|
24 |
+
worker_id,
|
25 |
+
)
|
26 |
+
|
27 |
+
from lightllm.server.sampling_params import SamplingParams
|
28 |
+
from lightllm.server.multimodal_params import MultimodalParams
|
29 |
+
from lightllm.server.httpserver.manager import HttpServerManager
|
30 |
+
from lightllm.server.detokenization.manager import start_detokenization_process
|
31 |
+
from lightllm.server.router.manager import start_router_process
|
32 |
+
from lightllm.server.req_id_generator import ReqIDGenerator
|
33 |
+
|
34 |
+
from lightllm.utils.net_utils import alloc_can_use_network_port
|
35 |
+
from lightllm.utils.start_utils import start_submodule_processes
|
36 |
+
from fastchat.utils import get_context_length, is_partial_stop
|
37 |
+
|
38 |
+
app = FastAPI()
|
39 |
+
g_id_gen = ReqIDGenerator()
|
40 |
+
|
41 |
+
|
42 |
+
class LightLLMWorker(BaseModelWorker):
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
controller_addr: str,
|
46 |
+
worker_addr: str,
|
47 |
+
worker_id: str,
|
48 |
+
model_path: str,
|
49 |
+
model_names: List[str],
|
50 |
+
limit_worker_concurrency: int,
|
51 |
+
        no_register: bool,
        conv_template: str,
        tokenizer,
        context_len,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
        )

        logger.info(
            f"Loading the model {self.model_names} on worker {worker_id}, worker type: LightLLM worker..."
        )
        self.tokenizer = tokenizer
        self.context_len = context_len

        self.is_first = True

        if not no_register:
            self.init_heart_beat()

    async def generate_stream(self, params):
        self.call_ct += 1

        prompt = params.pop("prompt")
        request_id = params.pop("request_id")
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = params.get("top_k", -1.0)
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        max_new_tokens = params.get("max_new_tokens", 256)
        echo = params.get("echo", True)
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_token_ids", None) or []
        if self.tokenizer.eos_token_id is not None:
            stop_token_ids.append(self.tokenizer.eos_token_id)

        request = params.get("request", None)

        # Handle stop_str
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        for tid in stop_token_ids:
            if tid is not None:
                s = self.tokenizer.decode(tid)
                if s != "":
                    stop.add(s)

        if self.is_first:
            loop = asyncio.get_event_loop()
            loop.create_task(httpserver_manager.handle_loop())
            self.is_first = False

        # make sampling params in vllm
        top_p = max(top_p, 1e-5)
        if temperature <= 1e-5:
            top_p = 1.0

        sampling_params = SamplingParams(
            do_sample=temperature > 0.0,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repetition_penalty=repetition_penalty,
            max_new_tokens=max_new_tokens,
            stop_sequences=list(stop),
        )
        sampling_params.verify()

        results_generator = httpserver_manager.generate(
            prompt, sampling_params, request_id, MultimodalParams()
        )

        completion_tokens = 0
        text_outputs = ""
        cumulative_logprob = 0.0

        async for request_output, metadata, finish_status in results_generator:
            text_outputs += request_output
            completion_tokens += 1

            partial_stop = any(is_partial_stop(text_outputs, i) for i in stop)
            # prevent yielding partial stop sequence
            if partial_stop:
                continue

            if type(finish_status) is bool:  # compatibility with old version
                finish_reason = "stop" if finish_status else None
            else:
                finish_reason = finish_status.get_finish_reason()

            if request and await request.is_disconnected():
                await httpserver_manager.abort(request_id)
                finish_reason = "abort"

            logprob = metadata.get("logprob", None)
            if logprob is not None:
                cumulative_logprob += logprob

            prompt_tokens = metadata["prompt_tokens"]
            ret = {
                "text": prompt + text_outputs if echo else text_outputs,
                "error_code": 0,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
                "cumulative_logprob": cumulative_logprob,
            }

            if finish_reason is not None:
                yield (
                    json.dumps({**ret, "finish_reason": None}, ensure_ascii=False)
                    + "\0"
                ).encode("utf-8")
            yield (
                json.dumps({**ret, "finish_reason": finish_reason}, ensure_ascii=False)
                + "\0"
            ).encode("utf-8")

            if finish_reason is not None:  # In case of abort, we need to break the loop
                break

    async def generate(self, params):
        async for x in self.generate_stream(params):
            pass
        return json.loads(x[:-1].decode())


def release_worker_semaphore():
    worker.semaphore.release()


def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(request_id):
    async def abort_request() -> None:
        await httpserver_manager.abort(request_id)

    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_worker_semaphore)
    background_tasks.add_task(abort_request)
    return background_tasks


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = g_id_gen.generate_id()
    params["request_id"] = request_id
    params["request"] = request
    generator = worker.generate_stream(params)
    background_tasks = create_background_tasks(request_id)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = g_id_gen.generate_id()
    params["request_id"] = request_id
    params["request"] = request
    output = await worker.generate(params)
    release_worker_semaphore()
    await httpserver_manager.abort(request_id)
    return JSONResponse(output)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    return {"context_length": worker.context_len}


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)

    parser.add_argument(
        "--model-path",
        dest="model_dir",
        type=str,
        default=None,
        help="the model weight dir path, the app will load config, weights and tokenizer from this dir",
    )
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional display comma separated names",
    )
    parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
    parser.add_argument("--no-register", action="store_true")

    parser.add_argument(
        "--tokenizer_mode",
        type=str,
        default="slow",
        help="""tokenizer load mode, can be slow or auto; slow mode loads fast but runs slow and is good for debug and test,
        when you want to get the best performance, try auto mode""",
    )
    parser.add_argument(
        "--load_way",
        type=str,
        default="HF",
        help="the way of loading model weights, the default is HF (Huggingface format); llama also supports DS (Deepspeed)",
    )
    parser.add_argument(
        "--max_total_token_num",
        type=int,
        default=6000,
        help="the total token num the gpu and model can support, equals max_batch * (input_len + output_len)",
    )
    parser.add_argument(
        "--batch_max_tokens",
        type=int,
        default=None,
        help="max token num for a new batch; it controls the prefill batch size to prevent OOM",
    )
    parser.add_argument("--eos_id", type=int, default=2, help="eos stop token id")
    parser.add_argument(
        "--running_max_req_size",
        type=int,
        default=1000,
        help="the max number of requests forwarded at the same time",
    )
    parser.add_argument(
        "--tp", type=int, default=1, help="model tensor parallel size, the default is 1"
    )
    parser.add_argument(
        "--max_req_input_len",
        type=int,
        default=None,
        help="the max value for req input tokens num. If None, it will be derived from the config.",
    )
    parser.add_argument(
        "--max_req_total_len",
        type=int,
        default=None,
        help="the max value for req_input_len + req_output_len. If None, it will be derived from the config.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default=[],
        nargs="+",
        help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding
                | triton_gqa_attention | triton_gqa_flashdecoding]
                [triton_int8weight | triton_int4weight | lmdeploy_int4weight | ppl_int4weight],
                triton_flashdecoding mode is for long context, currently supporting llama, llama2 and qwen;
                triton_gqa_attention and triton_gqa_flashdecoding are fast kernels for models which use GQA;
                triton_int8kv mode uses int8 to store kv cache, can increase token capacity, uses triton kernel;
                ppl_int8kv mode uses int8 to store kv cache, and uses ppl fast kernel;
                ppl_fp16 mode uses ppl fast fp16 decode attention kernel;
                triton_int8weight, triton_int4weight, lmdeploy_int4weight and ppl_int4weight modes use int8 and int4 to store weights;
                you need to read the source code to make sure of the supported detail modes for all models""",
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
    )
    parser.add_argument(
        "--disable_log_stats",
        action="store_true",
        help="disable logging throughput stats.",
    )
    parser.add_argument(
        "--log_stats_interval",
        type=int,
        default=10,
        help="log stats interval in seconds.",
    )

    parser.add_argument(
        "--router_token_ratio",
        type=float,
        default=0.0,
        help="token ratio to control router dispatch",
    )
    parser.add_argument(
        "--router_max_new_token_len",
        type=int,
        default=1024,
        help="the request max new token len for router",
    )

    parser.add_argument(
        "--no_skipping_special_tokens",
        action="store_true",
        help="whether to skip special tokens when decoding",
    )
    parser.add_argument(
        "--no_spaces_between_special_tokens",
        action="store_true",
        help="whether to add spaces between special tokens when decoding",
    )

    parser.add_argument(
        "--splitfuse_mode", action="store_true", help="use splitfuse mode"
    )
    parser.add_argument(
        "--splitfuse_block_size", type=int, default=256, help="splitfuse block size"
    )
    parser.add_argument(
        "--prompt_cache_strs",
        type=str,
        default=[],
        nargs="+",
        help="""prompt cache strs""",
    )
    parser.add_argument(
        "--cache_capacity",
        type=int,
        default=200,
        help="cache server capacity for multimodal resources",
    )
    parser.add_argument(
        "--cache_reserved_ratio",
        type=float,
        default=0.5,
        help="cache server reserved capacity ratio after clear",
    )
    parser.add_argument(
        "--return_all_prompt_logprobs",
        action="store_true",
        help="return all prompt tokens logprobs",
    )
    parser.add_argument(
        "--long_truncation_mode",
        type=str,
        choices=[None, "head", "center"],
        default=None,
        help="""use to select the handling when input token len > max_req_input_len.
        None : raise Exception
        head : remove some head tokens to make input token len <= max_req_input_len
        center : remove some tokens in the center to make input token len <= max_req_input_len""",
    )

    args = parser.parse_args()

    # prompt cache is only supported in splitfuse mode
    if not args.splitfuse_mode:
        assert len(args.prompt_cache_strs) == 0

    model_config = AutoConfig.from_pretrained(args.model_dir)
    context_length = get_context_length(model_config)

    if args.max_req_input_len is None:
        args.max_req_input_len = context_length - 1
    if args.max_req_total_len is None:
        args.max_req_total_len = context_length

    assert args.max_req_input_len < args.max_req_total_len
    assert args.max_req_total_len <= args.max_total_token_num

    if not args.splitfuse_mode:
        # normal (non-splitfuse) mode
        if args.batch_max_tokens is None:
            batch_max_tokens = int(1 / 6 * args.max_total_token_num)
            batch_max_tokens = max(batch_max_tokens, args.max_req_total_len)
            args.batch_max_tokens = batch_max_tokens
        else:
            assert (
                args.batch_max_tokens >= args.max_req_total_len
            ), "batch_max_tokens must >= max_req_total_len"
    else:
        # splitfuse mode
        # assert args.batch_max_tokens is not None, "need to set by yourself"
        if args.batch_max_tokens is None:
            batch_max_tokens = int(1 / 6 * args.max_total_token_num)
            batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size)
            args.batch_max_tokens = batch_max_tokens

    can_use_ports = alloc_can_use_network_port(num=6 + args.tp)

    assert can_use_ports is not None, "Can not alloc enough free ports."
    (
        router_port,
        detokenization_port,
        httpserver_port,
        visual_port,
        cache_port,
        nccl_port,
    ) = can_use_ports[0:6]
    args.nccl_port = nccl_port
    model_rpc_ports = can_use_ports[6:]

    global httpserver_manager
    httpserver_manager = HttpServerManager(
        args,
        router_port=router_port,
        cache_port=cache_port,
        visual_port=visual_port,
        httpserver_port=httpserver_port,
        enable_multimodal=False,
    )

    start_submodule_processes(
        start_funcs=[start_router_process, start_detokenization_process],
        start_args=[
            (args, router_port, detokenization_port, model_rpc_ports),
            (args, detokenization_port, httpserver_port),
        ],
    )
    worker = LightLLMWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_dir,
        args.model_names,
        args.limit_worker_concurrency,
        args.no_register,
        args.conv_template,
        httpserver_manager.tokenizer,
        context_length,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
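For reference, a minimal client sketch for the streaming protocol used by /worker_generate_stream above: the worker yields JSON objects separated by null bytes ("\0"). The host, port, prompt, and sampling values below are illustrative assumptions and are not part of the uploaded file.

# Hypothetical client sketch (not part of the worker): assumes a LightLLM worker
# is already running on http://localhost:8000 as launched by the __main__ block above.
import json
import requests

payload = {
    "prompt": "Hello, how are you?",  # illustrative prompt
    "temperature": 0.7,
    "max_new_tokens": 64,
    "echo": False,
}
with requests.post(
    "http://localhost:8000/worker_generate_stream", json=payload, stream=True
) as resp:
    # Each chunk is a JSON object terminated by a null byte.
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            print(data["text"], data.get("finish_reason"))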
mlx_worker.py
ADDED
@@ -0,0 +1,288 @@
"""
A model worker using Apple MLX

https://github.com/ml-explore/mlx-examples/tree/main/llms

Code based on vllm_worker https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py

You must install MLX python:

pip install mlx-lm
"""

import argparse
import asyncio
import atexit
import json
from typing import List
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
    logger,
    worker_id,
)
from fastchat.utils import get_context_length, is_partial_stop

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.utils import generate_step

app = FastAPI()


class MLXWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        llm_engine: "MLX",
        conv_template: str,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
        )

        logger.info(
            f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..."
        )

        self.model_name = model_path
        self.mlx_model, self.mlx_tokenizer = load(model_path)

        self.tokenizer = self.mlx_tokenizer
        # self.context_len = get_context_length(
        #     llm_engine.engine.model_config.hf_config)
        self.context_len = 2048  # hard code for now -- not sure how to get in MLX

        if not no_register:
            self.init_heart_beat()

    async def generate_stream(self, params):
        self.call_ct += 1

        context = params.pop("prompt")
        request_id = params.pop("request_id")
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = params.get("top_k", -1.0)
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        max_new_tokens = params.get("max_new_tokens", 256)
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_token_ids", None) or []
        if self.tokenizer.eos_token_id is not None:
            stop_token_ids.append(self.tokenizer.eos_token_id)
        echo = params.get("echo", True)
        use_beam_search = params.get("use_beam_search", False)
        best_of = params.get("best_of", None)

        # Handle stop_str
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        for tid in stop_token_ids:
            if tid is not None:
                s = self.tokenizer.decode(tid)
                if s != "":
                    stop.add(s)

        print("Stop patterns: ", stop)

        top_p = max(top_p, 1e-5)
        if temperature <= 1e-5:
            top_p = 1.0

        tokens = []
        skip = 0

        context_mlx = mx.array(self.tokenizer.encode(context))

        finish_reason = "length"

        iterator = await run_in_threadpool(
            generate_step, context_mlx, self.mlx_model, temperature
        )

        for i in range(max_new_tokens):
            (token, _) = await run_in_threadpool(next, iterator)
            if token == self.mlx_tokenizer.eos_token_id:
                finish_reason = "stop"
                break
            tokens.append(token.item())
            tokens_decoded = self.mlx_tokenizer.decode(tokens)
            last_token_decoded = self.mlx_tokenizer.decode([token.item()])
            skip = len(tokens_decoded)

            partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop)

            if partial_stop:
                finish_reason = "stop"
                break

            ret = {
                "text": tokens_decoded,
                "error_code": 0,
                "usage": {
                    "prompt_tokens": len(context),
                    "completion_tokens": len(tokens),
                    "total_tokens": len(context) + len(tokens),
                },
                "cumulative_logprob": [],
                "finish_reason": None,  # hard code for now
            }
            # print(ret)
            yield (json.dumps(ret) + "\0").encode()
        ret = {
            "text": self.mlx_tokenizer.decode(tokens),
            "error_code": 0,
            "usage": {},
            "cumulative_logprob": [],
            "finish_reason": finish_reason,
        }
        yield (json.dumps(obj={**ret, **{"finish_reason": None}}) + "\0").encode()
        yield (json.dumps(ret) + "\0").encode()

    async def generate(self, params):
        async for x in self.generate_stream(params):
            pass
        return json.loads(x[:-1].decode())


def release_worker_semaphore():
    worker.semaphore.release()


def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(request_id):
    async def abort_request() -> None:
        print("trying to abort but not implemented")

    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_worker_semaphore)
    background_tasks.add_task(abort_request)
    return background_tasks


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = uuid.uuid4()
    params["request_id"] = str(request_id)
    generator = worker.generate_stream(params)
    background_tasks = create_background_tasks(request_id)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = uuid.uuid4()
    params["request_id"] = str(request_id)
    output = await worker.generate(params)
    release_worker_semaphore()
    # await engine.abort(request_id)
    print("Trying to abort but not implemented")
    return JSONResponse(output)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    return {"context_length": worker.context_len}


worker = None


def cleanup_at_exit():
    global worker
    print("Cleaning up...")
    del worker


atexit.register(cleanup_at_exit)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--model-path", type=str, default="microsoft/phi-2")
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional display comma separated names",
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_false",
        default=True,
        help="Trust remote code (e.g., from HuggingFace) when"
        "downloading the model and tokenizer.",
    )

    args, unknown = parser.parse_known_args()

    if args.model_path:
        args.model = args.model_path

    worker = MLXWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_path,
        args.model_names,
        1024,
        False,
        "MLX",
        args.conv_template,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
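The MLX worker above hard-codes self.context_len to 2048. A possible refinement, shown only as a sketch and not part of the uploaded file, would be to read the Hugging Face config for the same model_path and reuse fastchat.utils.get_context_length (already imported above), which inspects fields such as max_position_embeddings. The function name infer_context_len below is hypothetical.

# Sketch only (assumption, not what the uploaded worker does): derive the
# context length from the model's Hugging Face config instead of hard-coding it.
from transformers import AutoConfig
from fastchat.utils import get_context_length


def infer_context_len(model_path: str, default: int = 2048) -> int:
    try:
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        return get_context_length(config)
    except Exception:
        # Fall back to the current hard-coded value if the config is unavailable.
        return default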
model_worker.py
ADDED
@@ -0,0 +1,425 @@
"""
A model worker that executes the model.
"""
import argparse
import base64
import gc
import json
import os
from typing import List, Optional
import uuid

import torch
import torch.nn.functional as F
from transformers import set_seed
import uvicorn

from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
from fastchat.model.model_adapter import (
    load_model,
    add_model_args,
    get_generate_stream_function,
)
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.xfastertransformer import XftConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.serve.base_model_worker import BaseModelWorker, app
from fastchat.utils import (
    build_logger,
    get_context_length,
    str_to_torch_dtype,
)

worker_id = str(uuid.uuid4())[:8]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")


class ModelWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        device: str,
        num_gpus: int,
        max_gpu_memory: str,
        revision: str = None,
        dtype: Optional[torch.dtype] = None,
        load_8bit: bool = False,
        cpu_offloading: bool = False,
        gptq_config: Optional[GptqConfig] = None,
        awq_config: Optional[AWQConfig] = None,
        exllama_config: Optional[ExllamaConfig] = None,
        xft_config: Optional[XftConfig] = None,
        stream_interval: int = 2,
        conv_template: Optional[str] = None,
        embed_in_truncate: bool = False,
        seed: Optional[int] = None,
        debug: bool = False,
        **kwargs,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template=conv_template,
        )

        logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
        self.model, self.tokenizer = load_model(
            model_path,
            revision=revision,
            device=device,
            num_gpus=num_gpus,
            max_gpu_memory=max_gpu_memory,
            dtype=dtype,
            load_8bit=load_8bit,
            cpu_offloading=cpu_offloading,
            gptq_config=gptq_config,
            awq_config=awq_config,
            exllama_config=exllama_config,
            xft_config=xft_config,
            debug=debug,
        )
        self.device = device
        if self.tokenizer.pad_token == None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.context_len = get_context_length(self.model.config)
        self.generate_stream_func = get_generate_stream_function(self.model, model_path)
        self.stream_interval = stream_interval
        self.embed_in_truncate = embed_in_truncate
        self.seed = seed

        if not no_register:
            self.init_heart_beat()

    def generate_stream_gate(self, params):
        if self.device == "npu":
            import torch_npu

            torch_npu.npu.set_device("npu:0")
        self.call_ct += 1

        try:
            if self.seed is not None:
                set_seed(self.seed)
            for output in self.generate_stream_func(
                self.model,
                self.tokenizer,
                params,
                self.device,
                self.context_len,
                self.stream_interval,
            ):
                ret = {
                    "text": output["text"],
                    "error_code": 0,
                }
                if "usage" in output:
                    ret["usage"] = output["usage"]
                if "finish_reason" in output:
                    ret["finish_reason"] = output["finish_reason"]
                if "logprobs" in output:
                    ret["logprobs"] = output["logprobs"]
                yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.OutOfMemoryError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }
            yield json.dumps(ret).encode() + b"\0"
        except (ValueError, RuntimeError) as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"

    def generate_gate(self, params):
        for x in self.generate_stream_gate(params):
            pass
        return json.loads(x[:-1].decode())

    def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
        if model_type_dict.get("is_bert"):
            model_output = self.model(input_ids)
            if model_type_dict.get("is_robert"):
                data = model_output.last_hidden_state
            else:
                data = model_output[0]
        elif model_type_dict.get("is_t5"):
            model_output = self.model(input_ids, decoder_input_ids=input_ids)
            data = model_output.encoder_last_hidden_state
        else:
            model_output = self.model(input_ids, output_hidden_states=True)
            if model_type_dict.get("is_chatglm"):
                data = model_output.hidden_states[-1].transpose(0, 1)
            else:
                data = model_output.hidden_states[-1]

        if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling:
            sum_embeddings = data[:, 0]
        else:
            mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
            masked_embeddings = data * mask
            sum_embeddings = torch.sum(masked_embeddings, dim=1)
        token_num = torch.sum(attention_mask).item()

        return sum_embeddings, token_num

    def __encode_base64(self, embeddings: torch.Tensor) -> List[str]:
        embeddings = embeddings.cpu()
        return [
            base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings
        ]

    @torch.inference_mode()
    def get_embeddings(self, params):
        self.call_ct += 1

        try:
            tokenizer = self.tokenizer
            ret = {"embedding": [], "token_num": 0}

            model_type_dict = {
                "is_llama": "llama" in str(type(self.model)),
                "is_t5": "t5" in str(type(self.model)),
                "is_chatglm": "chatglm" in str(type(self.model)),
                "is_bert": "bert" in str(type(self.model)),
                "is_robert": "robert" in str(type(self.model)),
            }

            if self.embed_in_truncate:
                encoding = tokenizer.batch_encode_plus(
                    params["input"],
                    padding=True,
                    truncation="longest_first",
                    return_tensors="pt",
                    max_length=self.context_len,
                )
            else:
                encoding = tokenizer.batch_encode_plus(
                    params["input"], padding=True, return_tensors="pt"
                )
            input_ids = encoding["input_ids"].to(self.device)
            attention_mask = input_ids != tokenizer.pad_token_id

            base64_encode = params.get("encoding_format", None)

            if self.embed_in_truncate:
                embedding, token_num = self.__process_embed_chunk(
                    input_ids, attention_mask, **model_type_dict
                )
                if (
                    not hasattr(self.model, "use_cls_pooling")
                    or not self.model.use_cls_pooling
                ):
                    embedding = embedding / token_num
                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
                ret["token_num"] = token_num
            else:
                all_embeddings = []
                all_token_num = 0
                for i in range(0, input_ids.size(1), self.context_len):
                    chunk_input_ids = input_ids[:, i : i + self.context_len]
                    chunk_attention_mask = attention_mask[:, i : i + self.context_len]

                    # add cls token and mask to get cls embedding
                    if (
                        hasattr(self.model, "use_cls_pooling")
                        and self.model.use_cls_pooling
                    ):
                        cls_tokens = (
                            torch.zeros(
                                (chunk_input_ids.size(0), 1),
                                dtype=chunk_input_ids.dtype,
                                device=chunk_input_ids.device,
                            )
                            + tokenizer.cls_token_id
                        )
                        chunk_input_ids = torch.cat(
                            [cls_tokens, chunk_input_ids], dim=-1
                        )
                        mask = torch.ones(
                            (chunk_attention_mask.size(0), 1),
                            dtype=chunk_attention_mask.dtype,
                            device=chunk_attention_mask.device,
                        )
                        chunk_attention_mask = torch.cat(
                            [mask, chunk_attention_mask], dim=-1
                        )

                    chunk_embeddings, token_num = self.__process_embed_chunk(
                        chunk_input_ids, chunk_attention_mask, **model_type_dict
                    )
                    if (
                        hasattr(self.model, "use_cls_pooling")
                        and self.model.use_cls_pooling
                    ):
                        all_embeddings.append(chunk_embeddings * token_num)
                    else:
                        all_embeddings.append(chunk_embeddings)
                    all_token_num += token_num

                all_embeddings_tensor = torch.stack(all_embeddings)
                embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
                normalized_embeddings = F.normalize(embedding, p=2, dim=1)

                ret["token_num"] = all_token_num

            if base64_encode == "base64":
                out_embeddings = self.__encode_base64(normalized_embeddings)
            else:
                out_embeddings = normalized_embeddings.tolist()
            ret["embedding"] = out_embeddings

            gc.collect()
            torch.cuda.empty_cache()
            if self.device == "xpu":
                torch.xpu.empty_cache()
            if self.device == "npu":
                torch.npu.empty_cache()
        except torch.cuda.OutOfMemoryError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }
        except (ValueError, RuntimeError) as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
        return ret


def create_model_worker():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    add_model_args(parser)
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional display comma separated names",
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument("--embed-in-truncate", action="store_true")
    parser.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    parser.add_argument("--stream-interval", type=int, default=2)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    parser.add_argument(
        "--debug", type=bool, default=False, help="Print debugging messages"
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    gptq_config = GptqConfig(
        ckpt=args.gptq_ckpt or args.model_path,
        wbits=args.gptq_wbits,
        groupsize=args.gptq_groupsize,
        act_order=args.gptq_act_order,
    )
    awq_config = AWQConfig(
        ckpt=args.awq_ckpt or args.model_path,
        wbits=args.awq_wbits,
        groupsize=args.awq_groupsize,
    )
    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
            cache_8bit=args.exllama_cache_8bit,
        )
    else:
        exllama_config = None
    if args.enable_xft:
        xft_config = XftConfig(
            max_seq_len=args.xft_max_seq_len,
            data_type=args.xft_dtype,
        )
        if args.device != "cpu":
            print("xFasterTransformer now is only support CPUs. Reset device to CPU")
            args.device = "cpu"
    else:
        xft_config = None

    worker = ModelWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_path,
        args.model_names,
        args.limit_worker_concurrency,
        revision=args.revision,
        no_register=args.no_register,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        dtype=str_to_torch_dtype(args.dtype),
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        xft_config=xft_config,
        stream_interval=args.stream_interval,
        conv_template=args.conv_template,
        embed_in_truncate=args.embed_in_truncate,
        seed=args.seed,
        debug=args.debug,
    )
    return args, worker


if __name__ == "__main__":
    args, worker = create_model_worker()
    if args.ssl:
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level="info",
            ssl_keyfile=os.environ["SSL_KEYFILE"],
            ssl_certfile=os.environ["SSL_CERTFILE"],
        )
    else:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")
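The embedding path above (get_embeddings together with __process_embed_chunk) masks out padding positions, sums the hidden states, divides by the number of real tokens, and L2-normalizes. A standalone sketch of that pooling step is shown below; the tensor shapes and values are illustrative only.

# Standalone sketch of the pooling used in get_embeddings / __process_embed_chunk.
import torch
import torch.nn.functional as F

data = torch.randn(2, 5, 8)                       # (batch, seq_len, hidden) hidden states
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding

mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
sum_embeddings = torch.sum(data * mask, dim=1)    # (batch, hidden), padding zeroed out
token_num = torch.sum(attention_mask).item()      # total real tokens, as in the worker
embedding = sum_embeddings / token_num            # same division the worker applies
normalized = F.normalize(embedding, p=2, dim=1)   # unit-length embeddings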
monitor/__pycache__/basic_stats.cpython-39.pyc
ADDED
Binary file (6.1 kB).

monitor/__pycache__/clean_battle_data.cpython-39.pyc
ADDED
Binary file (7 kB).

monitor/__pycache__/clean_chat_data.cpython-39.pyc
ADDED
Binary file (4.36 kB).

monitor/__pycache__/elo_analysis.cpython-39.pyc
ADDED
Binary file (9.21 kB).

monitor/__pycache__/inspect_conv.cpython-39.pyc
ADDED
Binary file (2.16 kB).