Spaces:
Runtime error
Runtime error
""" | |
The gradio demo server for chatting with a single model. | |
""" | |
import argparse | |
from collections import defaultdict | |
import datetime | |
import hashlib | |
import json | |
import os | |
import random | |
import time | |
import uuid | |
import gradio as gr | |
import requests | |
from fastchat.constants import ( | |
LOGDIR, | |
WORKER_API_TIMEOUT, | |
ErrorCode, | |
MODERATION_MSG, | |
CONVERSATION_LIMIT_MSG, | |
RATE_LIMIT_MSG, | |
SERVER_ERROR_MSG, | |
INPUT_CHAR_LEN_LIMIT, | |
CONVERSATION_TURN_LIMIT, | |
SESSION_EXPIRATION_TIME, | |
) | |
from fastchat.model.model_adapter import ( | |
get_conversation_template, | |
) | |
from fastchat.model.model_registry import get_model_info, model_info | |
from fastchat.serve.api_provider import get_api_provider_stream_iter | |
from fastchat.utils import ( | |
build_logger, | |
get_window_url_params_js, | |
get_window_url_params_with_tos_js, | |
moderation_filter, | |
parse_gradio_auth_creds, | |
load_image, | |
) | |
# Module-wide logger; build_logger receives the logger name and a log file name.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# HTTP headers sent with the streaming request to a model worker
# (see model_worker_stream_iter).
headers = {"User-Agent": "FastChat Client"}

# Shared button values that the event handlers return to update the
# five action buttons (upvote, downvote, flag, regenerate, clear).
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True, visible=True)
disable_btn = gr.Button(interactive=False)
invisible_btn = gr.Button(interactive=False, visible=False)

# Process-wide settings; both are overwritten by set_global_vars().
controller_url = None
enable_moderation = False

# Markdown rendered at the bottom of the single-model UI when
# build_single_model_ui() is called with add_promotion_links=True.
acknowledgment_md = """
### Terms of Service
Users are required to agree to the following terms before using the service:
The service is a research preview. It only provides limited safety measures and may generate offensive content.
It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
The service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license.
Additionally, Bard is offered on LMSys for research purposes only. To access the Bard product, please visit its [website](http://bard.google.com).
### Acknowledgment
We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous [sponsorship](https://lmsys.org/donations/).
<div class="sponsor-image-about">
<img src="https://storage.googleapis.com/public-arena-asset/kaggle.png" alt="Kaggle">
<img src="https://storage.googleapis.com/public-arena-asset/mbzuai.jpeg" alt="MBZUAI">
<img src="https://storage.googleapis.com/public-arena-asset/a16z.jpeg" alt="a16z">
<img src="https://storage.googleapis.com/public-arena-asset/together.png" alt="Together AI">
<img src="https://storage.googleapis.com/public-arena-asset/anyscale.png" alt="AnyScale">
<img src="https://storage.googleapis.com/public-arena-asset/huggingface.png" alt="HuggingFace">
</div>
"""

# JSON file format of API-based models:
# {
#   "gpt-3.5-turbo-0613": {
#     "model_name": "gpt-3.5-turbo-0613",
#     "api_type": "openai",
#     "api_base": "https://api.openai.com/v1",
#     "api_key": "sk-******",
#     "anony_only": false
#   }
# }
# "api_type" can be one of the following: openai, anthropic, gemini, mistral.
# "anony_only" means whether to show this model in anonymous mode only.
#
# Populated by get_model_list() from the registered endpoint file and
# consulted by bot_response() to decide between a worker and an API provider.
api_endpoint_info = {}
class State:
    """Per-session chat state: a conversation plus the metadata logged with it."""

    def __init__(self, model_name):
        # A fresh conversation template for the chosen model, tagged with a
        # random hex identifier.
        self.conv = get_conversation_template(model_name)
        self.conv_id = uuid.uuid4().hex
        # Set by add_text() when the next bot_response() call must be a no-op.
        self.skip_next = False
        self.model_name = model_name

    def to_gradio_chatbot(self):
        """Return the conversation in the form produced by its own
        ``to_gradio_chatbot`` method."""
        rendered = self.conv.to_gradio_chatbot()
        return rendered

    def dict(self):
        """Return the conversation's dict with ``conv_id`` and ``model_name`` added."""
        snapshot = self.conv.dict()
        snapshot["conv_id"] = self.conv_id
        snapshot["model_name"] = self.model_name
        return snapshot
def set_global_vars(controller_url_, enable_moderation_):
    """Store the controller address and the moderation switch as module globals."""
    global controller_url, enable_moderation
    controller_url, enable_moderation = controller_url_, enable_moderation_
def get_conv_log_filename():
    """Return the path of today's conversation log: ``LOGDIR/YYYY-MM-DD-conv.json``."""
    today = datetime.datetime.now()
    basename = f"{today:%Y-%m-%d}-conv.json"
    return os.path.join(LOGDIR, basename)
def get_model_list(controller_url, register_api_endpoint_file, multimodal):
    """Collect the model names this server can offer.

    Args:
        controller_url: Base URL of the controller; when falsy no worker
            models are queried.
        register_api_endpoint_file: Path of a JSON file describing API-based
            models; when falsy no API models are added. Loading it replaces
            the module-level ``api_endpoint_info``.
        multimodal: Select multimodal models when true, language models
            otherwise (applies to both controller and API models).

    Returns:
        ``(visible_models, models)``: ``models`` holds every distinct model
        name, ``visible_models`` the same minus API models flagged
        ``anony_only``. Both lists are sorted.
    """
    global api_endpoint_info

    # Add models from the controller
    if controller_url:
        ret = requests.post(controller_url + "/refresh_all_workers")
        assert ret.status_code == 200

        list_route = (
            "/list_multimodal_models" if multimodal else "/list_language_models"
        )
        ret = requests.post(controller_url + list_route)
        models = ret.json()["models"]
    else:
        models = []

    # Add models from the API providers
    if register_api_endpoint_file:
        # Use a context manager so the file handle is closed deterministically.
        with open(register_api_endpoint_file) as fin:
            api_endpoint_info = json.load(fin)
        # Keep only the API models whose "multimodal" flag matches the request.
        models += [
            mdl
            for mdl, mdl_dict in api_endpoint_info.items()
            if bool(mdl_dict.get("multimodal", False)) == bool(multimodal)
        ]

    # Remove anonymous models.
    # Build a new list instead of calling remove() on the list being iterated:
    # removing during iteration skips the element that follows each removal.
    # A missing "anony_only" key is treated as "not anonymous-only".
    models = list(set(models))
    visible_models = [
        mdl
        for mdl in models
        if not api_endpoint_info.get(mdl, {}).get("anony_only", False)
    ]

    # Sort models: names present in model_info are keyed by their position
    # there ("___NNN"); all other names are keyed by the name itself.
    priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)}
    models.sort(key=lambda x: priority.get(x, x))
    visible_models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"All models: {models}")
    logger.info(f"Visible models: {visible_models}")
    return visible_models, models
def load_demo_single(models, url_params):
    """Build the initial ``(state, dropdown)`` pair for a freshly loaded page.

    The dropdown preselects the model named by the ``model`` URL parameter
    when that model is available, otherwise the first model (or an empty
    string when there are no models). The returned state is always None.
    """
    selected_model = models[0] if models else ""
    if "model" in url_params and url_params["model"] in models:
        selected_model = url_params["model"]

    dropdown_update = gr.Dropdown(choices=models, value=selected_model, visible=True)
    return None, dropdown_update
def load_demo(url_params, request: gr.Request):
    """Page-load handler: optionally refresh the model list, then build the
    initial state and dropdown via load_demo_single()."""
    global models

    client_ip = get_ip(request)
    logger.info(f"load_demo. ip: {client_ip}. params: {url_params}")

    # In "reload" mode the visible model list is re-fetched on every page load.
    if args.model_list_mode == "reload":
        refreshed = get_model_list(
            controller_url, args.register_api_endpoint_file, False
        )
        models = refreshed[0]

    return load_demo_single(models, url_params)
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one JSON line describing a vote to today's conversation log."""
    log_path = get_conv_log_filename()
    with open(log_path, "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=get_ip(request),
        )
        log_file.write(f"{json.dumps(record)}\n")
def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote, then clear the textbox and disable the three vote buttons."""
    client_ip = get_ip(request)
    logger.info(f"upvote. ip: {client_ip}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote, then clear the textbox and disable the three vote buttons."""
    client_ip = get_ip(request)
    logger.info(f"downvote. ip: {client_ip}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response(state, model_selector, request: gr.Request):
    """Log a flag, then clear the textbox and disable the three vote buttons."""
    client_ip = get_ip(request)
    logger.info(f"flag. ip: {client_ip}")
    vote_last_response(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def regenerate(state, request: gr.Request):
    """Blank the last message so the following bot_response() call refills it.

    Returns values for ``[state, chatbot, textbox, imagebox]`` followed by
    the five action buttons, all disabled.
    """
    client_ip = get_ip(request)
    logger.info(f"regenerate. ip: {client_ip}")
    state.conv.update_last_message(None)
    chat_view = state.to_gradio_chatbot()
    return (state, chat_view, "", None, *([disable_btn] * 5))
def clear_history(request: gr.Request):
    """Reset the session: no state, empty chat, empty textbox, no image,
    and all five action buttons disabled."""
    client_ip = get_ip(request)
    logger.info(f"clear_history. ip: {client_ip}")
    return (None, [], "", None, *([disable_btn] * 5))
def get_ip(request: gr.Request):
    """Return the client's IP address.

    The ``cf-connecting-ip`` request header wins when present; otherwise the
    address of the connected client is used.
    """
    forwarded_header = "cf-connecting-ip"
    if forwarded_header in request.headers:
        return request.headers[forwarded_header]
    return request.client.host
def _prepare_text_with_image(state, text, image): | |
if image is not None: | |
if len(state.conv.get_images()) > 0: | |
# reset convo with new image | |
state.conv = get_conversation_template(state.model_name) | |
image = state.conv.convert_image_to_base64( | |
image | |
) # PIL type is not JSON serializable | |
text = text, [image] | |
return text | |
def add_text(state, model_selector, text, image, request: gr.Request):
    """Validate the user's input and append it to the conversation.

    Every return path yields nine values matching the listener outputs
    ``[state, chatbot, textbox, imagebox]`` plus the five action buttons.

    Args:
        state: The session ``State`` or None (a new one is created for
            ``model_selector``).
        model_selector: Name of the currently selected model.
        text: The user's message.
        image: Optional image to attach, or None.
        request: The Gradio request, used only for logging the client IP.

    Returns:
        ``(state, chatbot, textbox_value, imagebox_value, *five_buttons)``.
        On empty input or when the turn limit is reached, ``state.skip_next``
        is set so the chained bot_response() call does nothing.
    """
    ip = get_ip(request)
    logger.info(f"add_text. ip: {ip}. len: {len(text)}")

    if state is None:
        state = State(model_selector)

    if len(text) <= 0:
        state.skip_next = True
        # Include the imagebox value (None): the listeners bind four leading
        # outputs, not three, so an 8-tuple would not match the outputs.
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5

    flagged = moderation_filter(text, [state.model_name])
    if flagged:
        logger.info(f"violate moderation. ip: {ip}. text: {text}")
        # overwrite the original text
        text = MODERATION_MSG

    if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
        state.skip_next = True
        # Same output arity as above: the textbox shows the limit message and
        # the imagebox is cleared.
        return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG, None) + (
            no_change_btn,
        ) * 5

    text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
    text = _prepare_text_with_image(state, text, image)
    state.conv.append_message(state.conv.roles[0], text)
    # Placeholder for the assistant's reply; bot_response() fills it in.
    state.conv.append_message(state.conv.roles[1], None)
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def model_worker_stream_iter(
    conv,
    model_name,
    worker_addr,
    prompt,
    temperature,
    repetition_penalty,
    top_p,
    max_new_tokens,
    images,
):
    """Stream generation results from a model worker.

    Posts the generation parameters to ``<worker_addr>/worker_generate_stream``
    and yields each non-empty NUL-delimited chunk of the response decoded
    from JSON.
    """
    # Make requests
    payload = dict(
        model=model_name,
        prompt=prompt,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        stop=conv.stop_str,
        stop_token_ids=conv.stop_token_ids,
        echo=False,
    )
    # Logged before the images are attached, so they never appear in the log.
    logger.info(f"==== request ====\n{payload}")

    if len(images) > 0:
        payload["images"] = images

    # Stream output
    stream_url = worker_addr + "/worker_generate_stream"
    response = requests.post(
        stream_url,
        headers=headers,
        json=payload,
        stream=True,
        timeout=WORKER_API_TIMEOUT,
    )
    for raw_chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not raw_chunk:
            continue
        yield json.loads(raw_chunk.decode())
def is_limit_reached(model_name, ip):
    """Ask the local monitor service whether ``ip`` is rate-limited for a model.

    Args:
        model_name: Name of the model being requested.
        ip: Client identifier sent as ``user_id``.

    Returns:
        The decoded JSON body of the monitor's reply (bot_response() reads its
        ``is_limit_reached`` and ``reason`` entries), or None when the request
        or the decoding fails. Failures are logged and otherwise ignored, so
        an unavailable monitor never blocks a chat.
    """
    monitor_url = "http://localhost:9090"
    try:
        ret = requests.get(
            f"{monitor_url}/is_limit_reached",
            # Pass the values as params so they are URL-encoded; interpolating
            # them into the URL lets characters such as "&", "#" or "+" in a
            # model name or a client-supplied IP header corrupt the query.
            params={"model": model_name, "user_id": ip},
            timeout=1,
        )
        return ret.json()
    except Exception as e:
        logger.info(f"monitor error: {e}")
        return None
def bot_response(
    state,
    temperature,
    top_p,
    max_new_tokens,
    request: gr.Request,
    apply_rate_limit=True,
):
    """Generate the assistant's reply and stream UI updates.

    This is a generator. Every yielded tuple is
    ``(state, chatbot, *five_buttons)`` where the buttons are, in order,
    upvote, downvote, flag, regenerate and clear.

    Args:
        state: The session ``State`` whose last message is the empty
            assistant slot to fill.
        temperature: Sampling temperature (converted with ``float``).
        top_p: Nucleus sampling parameter (converted with ``float``).
        max_new_tokens: Generation length limit (converted with ``int``).
        request: The Gradio request, used for the client IP.
        apply_rate_limit: When true, the monitor service is consulted first.

    On success the finished exchange is appended to today's conversation log
    and any conversation images are saved under ``LOGDIR/serve_images``.
    """
    ip = get_ip(request)
    logger.info(f"bot_response. ip: {ip}")
    start_tstamp = time.time()
    temperature = float(temperature)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)

    # Button states after a failure: voting stays disabled while
    # regenerate and clear are re-enabled.
    error_btns = (disable_btn,) * 3 + (enable_btn,) * 2

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        state.skip_next = False
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if apply_rate_limit:
        ret = is_limit_reached(state.model_name, ip)
        if ret is not None and ret["is_limit_reached"]:
            error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"]
            logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}")
            state.conv.update_last_message(error_msg)
            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
            return

    conv, model_name = state.conv, state.model_name
    # None means the model is served by a worker rather than an API provider.
    model_api_dict = api_endpoint_info.get(model_name)
    images = conv.get_images()

    if model_api_dict is None:
        # Query worker address
        ret = requests.post(
            controller_url + "/get_worker_address", json={"model": model_name}
        )
        worker_addr = ret.json()["address"]
        logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

        # No available worker
        if worker_addr == "":
            conv.update_last_message(SERVER_ERROR_MSG)
            yield (state, state.to_gradio_chatbot()) + error_btns
            return

        # Construct prompt.
        # We need to call it here, so it will not be affected by "▌".
        prompt = conv.get_prompt()

        # Set repetition_penalty
        repetition_penalty = 1.2 if "t5" in model_name else 1.0

        stream_iter = model_worker_stream_iter(
            conv,
            model_name,
            worker_addr,
            prompt,
            temperature,
            repetition_penalty,
            top_p,
            max_new_tokens,
            images,
        )
    else:
        stream_iter = get_api_provider_stream_iter(
            conv,
            model_name,
            model_api_dict,
            temperature,
            top_p,
            max_new_tokens,
        )

    # Show a cursor while waiting for the first chunk.
    conv.update_last_message("▌")
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        for data in stream_iter:
            if data["error_code"] == 0:
                output = data["text"].strip()
                conv.update_last_message(output + "▌")
                yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
            else:
                output = data["text"] + f"\n\n(error_code: {data['error_code']})"
                conv.update_last_message(output)
                yield (state, state.to_gradio_chatbot()) + error_btns
                return
        # Final update without the cursor. If the stream produced nothing,
        # `data` is unbound and the resulting error is reported by the
        # generic handler below.
        output = data["text"].strip()
        conv.update_last_message(output)
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
    except requests.exceptions.RequestException as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + error_btns
        return
    except Exception as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    finish_tstamp = time.time()
    logger.info(f"{output}")

    # We load the image because gradio accepts base64 but that increases file size by ~1.33x
    loaded_images = [load_image(image) for image in images]
    images_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in loaded_images]
    for image, hash_str in zip(loaded_images, images_hash):
        filename = os.path.join(
            LOGDIR,
            "serve_images",
            f"{hash_str}.jpg",
        )
        # Files are named by content hash, so an existing file is not rewritten.
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
            "images": images_hash,
        }
        fout.write(json.dumps(data) + "\n")
# Custom CSS passed to gr.Blocks in build_demo(). The "#..." selectors match
# elem_id values given to components (e.g. "notice_markdown", "about_markdown",
# "ack_markdown", "model_description_markdown"); ".sponsor-image-about" matches
# the sponsor logo <div> in the markdown texts.
block_css = """
#notice_markdown .prose {
    font-size: 120% !important;
}
#notice_markdown th {
    display: none;
}
#notice_markdown td {
    padding-top: 6px;
    padding-bottom: 6px;
}
#model_description_markdown {
    font-size: 120% !important;
}
#leaderboard_markdown .prose {
    font-size: 120% !important;
}
#leaderboard_markdown td {
    padding-top: 6px;
    padding-bottom: 6px;
}
#leaderboard_dataframe td {
    line-height: 0.1em;
}
#about_markdown .prose {
    font-size: 120% !important;
}
#ack_markdown .prose {
    font-size: 120% !important;
}
footer {
    display:none !important;
}
.sponsor-image-about img {
    margin: 0 20px;
    margin-top: 20px;
    height: 40px;
    max-height: 100%;
    width: auto;
    float: left;
}
"""
def get_model_description_md(models):
    """Render a three-column markdown table describing ``models``.

    Each cell is ``[simple_name](link): description``; models that share a
    ``simple_name`` with an earlier entry are listed only once.
    """
    pieces = [
        """
| | | |
| ---- | ---- | ---- |
"""
    ]
    seen_names = set()
    for name in models:
        info = get_model_info(name)
        if info.simple_name in seen_names:
            continue
        seen_names.add(info.simple_name)

        # Zero-based column of this cell within its table row.
        column = (len(seen_names) - 1) % 3
        if column == 0:
            pieces.append("|")
        pieces.append(f" [{info.simple_name}]({info.link}): {info.description} |")
        if column == 2:
            pieces.append("\n")
    return "".join(pieces)
def build_about():
    """Add the "About Us" markdown component (elem_id ``about_markdown``).

    Must be called where Gradio components can be created; the component is
    not returned.
    """
    # NOTE(review): the contact address below reads "[email protected]", which
    # looks like e-mail obfuscation left by a web scraper rather than the
    # intended text -- confirm the real address against the upstream file.
    about_markdown = """
# About Us
Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our [FastChat](https://github.com/lm-sys/FastChat) project at GitHub and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey!
## Read More
- Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/)
- LMSYS-Chat-1M [report](https://arxiv.org/abs/2309.11998)
## Core Members
[Lianmin Zheng](https://lmzheng.net/), [Wei-Lin Chiang](https://infwinston.github.io/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ)
## Advisors
[Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/)
## Contact Us
- Follow our [Twitter](https://twitter.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at [email protected]
- File issues on [GitHub](https://github.com/lm-sys/FastChat)
- Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys)
## Acknowledgment
We thank [SkyPilot](https://github.com/skypilot-org/skypilot) and [Gradio](https://github.com/gradio-app/gradio) team for their system support.
We also thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship. Learn more about partnership [here](https://lmsys.org/donations/).
<div class="sponsor-image-about">
<img src="https://storage.googleapis.com/public-arena-asset/kaggle.png" alt="Kaggle">
<img src="https://storage.googleapis.com/public-arena-asset/mbzuai.jpeg" alt="MBZUAI">
<img src="https://storage.googleapis.com/public-arena-asset/a16z.jpeg" alt="a16z">
<img src="https://storage.googleapis.com/public-arena-asset/together.png" alt="Together AI">
<img src="https://storage.googleapis.com/public-arena-asset/anyscale.png" alt="AnyScale">
<img src="https://storage.googleapis.com/public-arena-asset/huggingface.png" alt="HuggingFace">
</div>
"""
    gr.Markdown(about_markdown, elem_id="about_markdown")
def build_single_model_ui(models, add_promotion_links=False):
    """Create the single-model chat UI and wire up its event handlers.

    Args:
        models: Model names offered in the dropdown; the first one is
            preselected (an empty string when the list is empty).
        add_promotion_links: When true, promotional links are added to the
            header and the acknowledgment markdown is shown at the bottom.

    Returns:
        ``[state, model_selector]``: the session state component and the
        model dropdown.
    """
    # NOTE(review): the nesting of the `with` layout blocks below was
    # reconstructed from a copy of this file that had lost its indentation;
    # confirm it against the original layout.
    promotion = (
        """
- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)
## 🤖 Choose any model to chat
"""
        if add_promotion_links
        else ""
    )
    notice_markdown = f"""
# 🏔️ Chat with Open Large Language Models
{promotion}
"""
    # Per-session state; starts empty and is created lazily by add_text().
    state = gr.State()
    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Group(elem_id="share-region-named"):
        with gr.Row(elem_id="model_selector_row"):
            model_selector = gr.Dropdown(
                choices=models,
                value=models[0] if len(models) > 0 else "",
                interactive=True,
                show_label=False,
                container=False,
            )
        with gr.Row():
            with gr.Accordion(
                f"🔍 Expand to see the descriptions of {len(models)} models",
                open=False,
            ):
                model_description_md = get_model_description_md(models)
                gr.Markdown(model_description_md, elem_id="model_description_markdown")

        chatbot = gr.Chatbot(
            elem_id="chatbot",
            label="Scroll down and start chatting",
            height=550,
            show_copy_button=True,
        )
    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            elem_id="input_box",
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)

    # All action buttons start disabled; the handlers enable them as needed.
    with gr.Row() as button_row:
        upvote_btn = gr.Button(value="👍  Upvote", interactive=False)
        downvote_btn = gr.Button(value="👎  Downvote", interactive=False)
        flag_btn = gr.Button(value="⚠️  Flag", interactive=False)
        regenerate_btn = gr.Button(value="🔄  Regenerate", interactive=False)
        clear_btn = gr.Button(value="🗑️  Clear history", interactive=False)

    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=2048,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    if add_promotion_links:
        gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # Register listeners
    # This UI has no image input; a State holding None fills the image slot
    # in the handlers' inputs and outputs.
    imagebox = gr.State(None)
    # Order matters: handlers return button updates positionally in this order.
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    upvote_btn.click(
        upvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    downvote_btn.click(
        downvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    flag_btn.click(
        flag_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    # Regenerate: blank the last reply, then stream a new one.
    regenerate_btn.click(
        regenerate, state, [state, chatbot, textbox, imagebox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)

    # Switching models starts a new conversation.
    model_selector.change(
        clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
    )

    # Pressing ENTER and clicking "Send" run the same two-step chain:
    # record the user's text, then stream the model's reply.
    textbox.submit(
        add_text,
        [state, model_selector, textbox, imagebox],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    send_btn.click(
        add_text,
        [state, model_selector, textbox, imagebox],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    return [state, model_selector]
def build_demo(models):
    """Assemble the Gradio app around the single-model UI and return it.

    Raises:
        ValueError: if ``args.model_list_mode`` is neither "once" nor "reload".
    """
    blocks_options = dict(
        title="Chat with Open Large Language Models",
        theme=gr.themes.Default(),
        css=block_css,
    )
    with gr.Blocks(**blocks_options) as demo:
        # Hidden component that receives the URL parameters from the load JS.
        url_params = gr.JSON(visible=False)
        state, model_selector = build_single_model_ui(models)

        if args.model_list_mode not in ("once", "reload"):
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

        load_js = (
            get_window_url_params_with_tos_js
            if args.show_terms_of_use
            else get_window_url_params_js
        )
        demo.load(load_demo, [url_params], [state, model_selector], js=load_js)

    return demo
# Script entry point: parse the command line, fetch the model list, build the
# demo and launch the server. `args` and `models` are assigned at module level
# here because load_demo() and build_demo() read them as globals.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument(
        "--share",
        action="store_true",
        help="Whether to generate a public, shareable link",
    )
    parser.add_argument(
        "--controller-url",
        type=str,
        default="http://localhost:21001",
        help="The address of the controller",
    )
    parser.add_argument(
        "--concurrency-count",
        type=int,
        default=10,
        help="The concurrency count of the gradio queue",
    )
    parser.add_argument(
        "--model-list-mode",
        type=str,
        default="once",
        choices=["once", "reload"],
        help="Whether to load the model list once or reload the model list every time",
    )
    parser.add_argument(
        "--moderate",
        action="store_true",
        help="Enable content moderation to block unsafe inputs",
    )
    parser.add_argument(
        "--show-terms-of-use",
        action="store_true",
        help="Shows term of use before loading the demo",
    )
    parser.add_argument(
        "--register-api-endpoint-file",
        type=str,
        help="Register API-based model endpoints from a JSON file",
    )
    parser.add_argument(
        "--gradio-auth-path",
        type=str,
        help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
    )
    parser.add_argument(
        "--gradio-root-path",
        type=str,
        help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Set global variables
    set_global_vars(args.controller_url, args.moderate)
    # Language (non-multimodal) models only; `models` is the visible subset.
    models, all_models = get_model_list(
        args.controller_url, args.register_api_endpoint_file, False
    )

    # Set authorization credentials
    auth = None
    if args.gradio_auth_path is not None:
        auth = parse_gradio_auth_creds(args.gradio_auth_path)

    # Launch the demo
    demo = build_demo(models)
    demo.queue(
        default_concurrency_limit=args.concurrency_count,
        status_update_rate=10,
        api_open=False,
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
        auth=auth,
        root_path=args.gradio_root_path,
    )