"""
This module provides a platform for comparing the responses of two LLMs side by side.
"""
from random import sample

from fastchat.serve import gradio_web_server
from fastchat.serve.gradio_web_server import bot_response
import gradio as gr

# TODO(#1): Add more models.
SUPPORTED_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gemini-pro"]


def user(user_prompt):
  """Samples two models and appends the user prompt to fresh conversation states."""
  model_pair = sample(SUPPORTED_MODELS, 2)
  new_state_a = gradio_web_server.State(model_pair[0])
  new_state_b = gradio_web_server.State(model_pair[1])

  for state in [new_state_a, new_state_b]:
    state.conv.append_message(state.conv.roles[0], user_prompt)
    state.conv.append_message(state.conv.roles[1], None)
    state.skip_next = False

  return [
      new_state_a, new_state_b, new_state_a.model_name, new_state_b.model_name
  ]


def bot(state_a, state_b, request: gr.Request):
  """Streams responses from both models, yielding updated states and response texts."""
  new_states = [state_a, state_b]
  generators = []
  for state in new_states:
    try:
      # TODO(#1): Allow user to set configuration.
      # bot_response returns a generator yielding states.
      generator = bot_response(state,
                               temperature=0.9,
                               top_p=0.9,
                               max_new_tokens=100,
                               request=request)
      generators.append(generator)

    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
      print(f"Error in bot_response: {e}")
      raise e

  new_responses = [None, None]

  # This loop simulates concurrent response generation from the two models.
  while True:
    stop = True

    for i in range(len(generators)):
      try:
        yielded = next(generators[i])

        # The generator yields a tuple, with the new state as the first item.
        new_state = yielded[0]
        new_states[i] = new_state

        # The last item from 'messages' represents the response to the prompt.
        bot_message = new_state.conv.messages[-1]

        # Each message in conv.messages is structured as [role, message],
        # so we extract the last message component.
        new_responses[i] = bot_message[-1]

        stop = False

      except StopIteration:
        pass

      # TODO(#1): Narrow down the exception type.
      except Exception as e:  # pylint: disable=broad-except
        print(f"Error in generator: {e}")
        raise e

    yield new_states + new_responses

    if stop:
      break
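
# Note on the streaming loop in `bot` above: pulling one chunk from each
# generator per pass keeps both responses advancing at roughly the same pace
# without threads; a model that finishes early raises StopIteration, which is
# simply skipped on later passes. This assumes FastChat's bot_response yields
# tuples whose first element is the updated state; the exact tuple layout may
# differ across FastChat versions.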

with gr.Blocks() as app:
  model_names = [gr.State(None), gr.State(None)]
  responses = [gr.State(None), gr.State(None)]

  # states stores FastChat-specific conversation states.
  states = [gr.State(None), gr.State(None)]

  prompt = gr.TextArea(label="Prompt", lines=4)
  submit = gr.Button()

  with gr.Row():
    responses[0] = gr.Textbox(label="Model A", interactive=False)
    responses[1] = gr.Textbox(label="Model B", interactive=False)

  with gr.Accordion("Show models", open=False):
    with gr.Row():
      model_names[0] = gr.Textbox(label="Model A", interactive=False)
      model_names[1] = gr.Textbox(label="Model B", interactive=False)

  submit.click(user, prompt, states + model_names,
               queue=False).then(bot, states, states + responses)
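  # The click event first runs `user` with queue=False so the sampled model
  # pair and the prompt are applied immediately, then chains into `bot`,
  # whose generator yields stream partial responses into the two textboxes.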

if __name__ == "__main__":
  # We need to enable the queue to use generators.
  app.queue()
  app.launch(debug=True)
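
# A minimal way to run this locally (assuming the file is saved as app.py;
# the filename is illustrative):
#   python app.py
# launch() also accepts standard Gradio options if the defaults need changing,
# e.g. app.launch(server_name="0.0.0.0", server_port=7860) to serve on all
# interfaces.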