suhyun.kang commited on
Commit
cf196b3
·
1 Parent(s): 5e33531

[#1] Add voting feature

Browse files

Changes:
- Added response type selection.
- Included source and target language options for the "Translate" selection.
- Implemented the voting feature.

Firestore example item: https://console.firebase.google.com/u/0/project/special-tf-prod/firestore/data/~2Farena-summarizations~2F28213a8a0c1c44c295745841dabc7ad4?hl=ko

Screenshot: https://screen.yanolja.in/tCD6mJ0CpqoGDZwr.png

Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +109 -0
  3. requirments.txt +12 -0
.gitignore CHANGED
@@ -1 +1,2 @@
1
  venv
 
 
1
  venv
2
+ *.log
app.py CHANGED
@@ -2,15 +2,70 @@
2
  It provides a platform for comparing the responses of two LLMs.
3
  """
4
 
 
 
5
  from random import sample
 
6
 
7
  from fastchat.serve import gradio_web_server
8
  from fastchat.serve.gradio_web_server import bot_response
 
 
9
  import gradio as gr
10
 
 
 
 
11
  # TODO(#1): Add more models.
12
  SUPPORTED_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gemini-pro"]
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def user(user_prompt):
16
  model_pair = sample(SUPPORTED_MODELS, 2)
@@ -85,6 +140,40 @@ def bot(state_a, state_b, request: gr.Request):
85
 
86
 
87
  with gr.Blocks() as app:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  model_names = [gr.State(None), gr.State(None)]
89
  responses = [gr.State(None), gr.State(None)]
90
 
@@ -98,6 +187,26 @@ with gr.Blocks() as app:
98
  responses[0] = gr.Textbox(label="Model A", interactive=False)
99
  responses[1] = gr.Textbox(label="Model B", interactive=False)
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  with gr.Accordion("Show models", open=False):
102
  with gr.Row():
103
  model_names[0] = gr.Textbox(label="Model A", interactive=False)
 
2
  It provides a platform for comparing the responses of two LLMs.
3
  """
4
 
5
+ import enum
6
+ import json
7
  from random import sample
8
+ from uuid import uuid4
9
 
10
  from fastchat.serve import gradio_web_server
11
  from fastchat.serve.gradio_web_server import bot_response
12
+ import firebase_admin
13
+ from firebase_admin import firestore
14
  import gradio as gr
15
 
16
+ db_app = firebase_admin.initialize_app()
17
+ db = firestore.client()
18
+
19
  # TODO(#1): Add more models.
20
  SUPPORTED_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gemini-pro"]
21
 
22
+ SUPPORTED_TRANSLATION_LANGUAGES = ["Korean", "English"]
23
+
24
+
25
+ class ResponseType(enum.Enum):
26
+ SUMMARIZE = "Summarize"
27
+ TRANSLATE = "Translate"
28
+
29
+
30
+ class VoteOptions(enum.Enum):
31
+ MODEL_A = "Model A is better"
32
+ MODEL_B = "Model B is better"
33
+ TIE = "Tie"
34
+
35
+
36
+ def vote(state_a, state_b, vote_button, res_type, source_lang, target_lang):
37
+ winner = VoteOptions(vote_button).name.lower()
38
+
39
+ # The 'messages' field in the state is an array of arrays, a data type
40
+ # not supported by Firestore. Therefore, we convert it to a JSON string.
41
+ model_a_conv = json.dumps(state_a.dict())
42
+ model_b_conv = json.dumps(state_b.dict())
43
+
44
+ if res_type == ResponseType.SUMMARIZE.value:
45
+ doc_ref = db.collection("arena-summarizations").document(uuid4().hex)
46
+ doc_ref.set({
47
+ "model_a": state_a.model_name,
48
+ "model_b": state_b.model_name,
49
+ "model_a_conv": model_a_conv,
50
+ "model_b_conv": model_b_conv,
51
+ "winner": winner,
52
+ "timestamp": firestore.SERVER_TIMESTAMP
53
+ })
54
+ return
55
+
56
+ if res_type == ResponseType.TRANSLATE.value:
57
+ doc_ref = db.collection("arena-translations").document(uuid4().hex)
58
+ doc_ref.set({
59
+ "model_a": state_a.model_name,
60
+ "model_b": state_b.model_name,
61
+ "model_a_conv": model_a_conv,
62
+ "model_b_conv": model_b_conv,
63
+ "source_language": source_lang.lower(),
64
+ "target_language": target_lang.lower(),
65
+ "winner": winner,
66
+ "timestamp": firestore.SERVER_TIMESTAMP
67
+ })
68
+
69
 
70
  def user(user_prompt):
71
  model_pair = sample(SUPPORTED_MODELS, 2)
 
140
 
141
 
142
  with gr.Blocks() as app:
143
+ with gr.Row():
144
+ response_type_radio = gr.Radio(
145
+ [response_type.value for response_type in ResponseType],
146
+ label="Response type",
147
+ info="Choose the type of response you want from the model.")
148
+
149
+ source_language = gr.Dropdown(
150
+ choices=SUPPORTED_TRANSLATION_LANGUAGES,
151
+ label="Source language",
152
+ info="Choose the source language for translation.",
153
+ interactive=True,
154
+ visible=False)
155
+ target_language = gr.Dropdown(
156
+ choices=SUPPORTED_TRANSLATION_LANGUAGES,
157
+ label="Target language",
158
+ info="Choose the target language for translation.",
159
+ interactive=True,
160
+ visible=False)
161
+
162
+ def update_language_visibility(response_type):
163
+ if response_type != ResponseType.TRANSLATE.value:
164
+ return {
165
+ source_language: gr.Dropdown(visible=False),
166
+ target_language: gr.Dropdown(visible=False)
167
+ }
168
+
169
+ return {
170
+ source_language: gr.Dropdown(visible=True),
171
+ target_language: gr.Dropdown(visible=True)
172
+ }
173
+
174
+ response_type_radio.change(update_language_visibility, response_type_radio,
175
+ [source_language, target_language])
176
+
177
  model_names = [gr.State(None), gr.State(None)]
178
  responses = [gr.State(None), gr.State(None)]
179
 
 
187
  responses[0] = gr.Textbox(label="Model A", interactive=False)
188
  responses[1] = gr.Textbox(label="Model B", interactive=False)
189
 
190
+ # TODO(#1): Display it only after the user submits the prompt.
191
+ # TODO(#1): Block voting if the response_type is not set.
192
+ # TODO(#1): Block voting if the user already voted.
193
+ with gr.Row():
194
+ option_a = gr.Button(VoteOptions.MODEL_A.value)
195
+ option_a.click(
196
+ vote, states +
197
+ [option_a, response_type_radio, source_language, target_language])
198
+
199
+ option_b = gr.Button("Model B is better")
200
+ option_b.click(
201
+ vote, states +
202
+ [option_b, response_type_radio, source_language, target_language])
203
+
204
+ tie = gr.Button("Tie")
205
+ tie.click(
206
+ vote,
207
+ states + [tie, response_type_radio, source_language, target_language])
208
+
209
+ # TODO(#1): Hide it until the user votes.
210
  with gr.Accordion("Show models", open=False):
211
  with gr.Row():
212
  model_names[0] = gr.Textbox(label="Model A", interactive=False)
requirments.txt CHANGED
@@ -6,26 +6,33 @@ altair==5.2.0
6
  annotated-types==0.6.0
7
  anyio==4.2.0
8
  attrs==23.2.0
 
9
  cachetools==5.3.2
10
  certifi==2023.11.17
 
11
  charset-normalizer==3.3.2
12
  click==8.1.7
13
  colorama==0.4.6
14
  contourpy==1.2.0
 
15
  cycler==0.12.1
16
  distro==1.9.0
17
  fastapi==0.109.0
18
  ffmpy==0.3.1
19
  filelock==3.13.1
 
20
  fonttools==4.47.2
21
  frozenlist==1.4.1
22
  fschat==0.2.35
23
  fsspec==2023.12.2
24
  google-api-core==2.16.1
 
25
  google-auth==2.27.0
 
26
  google-cloud-aiplatform==1.40.0
27
  google-cloud-bigquery==3.17.1
28
  google-cloud-core==2.4.1
 
29
  google-cloud-resource-manager==1.11.0
30
  google-cloud-storage==2.14.0
31
  google-crc32c==1.5.0
@@ -38,6 +45,7 @@ grpcio==1.60.0
38
  grpcio-status==1.60.0
39
  h11==0.14.0
40
  httpcore==1.0.2
 
41
  httpx==0.26.0
42
  huggingface-hub==0.20.3
43
  idna==3.6
@@ -52,6 +60,7 @@ MarkupSafe==2.1.4
52
  matplotlib==3.8.2
53
  mdurl==0.1.2
54
  mpmath==1.3.0
 
55
  multidict==6.0.4
56
  networkx==3.2.1
57
  nh3==0.2.15
@@ -68,10 +77,12 @@ protobuf==4.25.2
68
  psutil==5.9.8
69
  pyasn1==0.5.1
70
  pyasn1-modules==0.3.0
 
71
  pydantic==1.10.14
72
  pydantic_core==2.16.1
73
  pydub==0.25.1
74
  Pygments==2.17.2
 
75
  pyparsing==3.1.1
76
  python-dateutil==2.8.2
77
  python-multipart==0.0.6
@@ -105,6 +116,7 @@ transformers==4.37.2
105
  typer==0.9.0
106
  typing_extensions==4.9.0
107
  tzdata==2023.4
 
108
  urllib3==2.2.0
109
  uvicorn==0.27.0.post1
110
  wavedrom==2.0.3.post3
 
6
  annotated-types==0.6.0
7
  anyio==4.2.0
8
  attrs==23.2.0
9
+ CacheControl==0.13.1
10
  cachetools==5.3.2
11
  certifi==2023.11.17
12
+ cffi==1.16.0
13
  charset-normalizer==3.3.2
14
  click==8.1.7
15
  colorama==0.4.6
16
  contourpy==1.2.0
17
+ cryptography==42.0.2
18
  cycler==0.12.1
19
  distro==1.9.0
20
  fastapi==0.109.0
21
  ffmpy==0.3.1
22
  filelock==3.13.1
23
+ firebase-admin==6.4.0
24
  fonttools==4.47.2
25
  frozenlist==1.4.1
26
  fschat==0.2.35
27
  fsspec==2023.12.2
28
  google-api-core==2.16.1
29
+ google-api-python-client==2.116.0
30
  google-auth==2.27.0
31
+ google-auth-httplib2==0.2.0
32
  google-cloud-aiplatform==1.40.0
33
  google-cloud-bigquery==3.17.1
34
  google-cloud-core==2.4.1
35
+ google-cloud-firestore==2.14.0
36
  google-cloud-resource-manager==1.11.0
37
  google-cloud-storage==2.14.0
38
  google-crc32c==1.5.0
 
45
  grpcio-status==1.60.0
46
  h11==0.14.0
47
  httpcore==1.0.2
48
+ httplib2==0.22.0
49
  httpx==0.26.0
50
  huggingface-hub==0.20.3
51
  idna==3.6
 
60
  matplotlib==3.8.2
61
  mdurl==0.1.2
62
  mpmath==1.3.0
63
+ msgpack==1.0.7
64
  multidict==6.0.4
65
  networkx==3.2.1
66
  nh3==0.2.15
 
77
  psutil==5.9.8
78
  pyasn1==0.5.1
79
  pyasn1-modules==0.3.0
80
+ pycparser==2.21
81
  pydantic==1.10.14
82
  pydantic_core==2.16.1
83
  pydub==0.25.1
84
  Pygments==2.17.2
85
+ PyJWT==2.8.0
86
  pyparsing==3.1.1
87
  python-dateutil==2.8.2
88
  python-multipart==0.0.6
 
116
  typer==0.9.0
117
  typing_extensions==4.9.0
118
  tzdata==2023.4
119
+ uritemplate==4.1.1
120
  urllib3==2.2.0
121
  uvicorn==0.27.0.post1
122
  wavedrom==2.0.3.post3