yuchenlin commited on
Commit
9e82682
·
1 Parent(s): 8cba5ca

add more models

Browse files
Files changed (2) hide show
  1. README.md +1 -2
  2. app.py +61 -9
README.md CHANGED
@@ -6,8 +6,7 @@ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.19.2
8
  app_file: app.py
9
- pinned: true
10
- fullWidth: true
11
  ---
12
 
13
  An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
6
  sdk: gradio
7
  sdk_version: 4.19.2
8
  app_file: app.py
9
+ pinned: true
 
10
  ---
11
 
12
  An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py CHANGED
@@ -85,6 +85,18 @@ def respond(
85
  _model_name = "meta-llama/Llama-3-8b-hf"
86
  elif model_name == "Llama-3-70B":
87
  _model_name = "meta-llama/Llama-3-70b-hf"
 
 
 
 
 
 
 
 
 
 
 
 
88
  else:
89
  raise ValueError("Invalid model name")
90
  # _model_name = "meta-llama/Llama-3-8b-hf"
@@ -105,31 +117,69 @@ def respond(
105
  for msg in request:
106
  # print(msg.choices[0].delta.keys())
107
  token = msg.choices[0].delta["content"]
108
- response += token
109
  should_stop = False
110
  for _stop in stop_str:
111
- if _stop in response:
112
  should_stop = True
113
  break
114
  if should_stop:
115
  break
116
- yield response
 
 
 
 
 
117
 
118
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  with gr.Row():
120
  with gr.Column():
121
- gr.Label("Welcome to the URIAL Chatbot!")
122
- model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B"], value="Llama-3-8B", label="Base model name")
123
- together_api_key = gr.Textbox(label="Together API Key", placeholder="Enter your Together API Key. Leave it blank if you want to use the default API key.", type="password")
 
 
 
124
  with gr.Column():
 
125
  with gr.Column():
126
  with gr.Row():
127
  max_tokens = gr.Textbox(value=1024, label="Max tokens")
128
  temperature = gr.Textbox(value=0.5, label="Temperature")
129
- with gr.Column():
130
- with gr.Row():
131
  top_p = gr.Textbox(value=0.9, label="Top-p")
132
  rp = gr.Textbox(value=1.1, label="Repetition penalty")
 
133
 
134
  chat = gr.ChatInterface(
135
  respond,
@@ -139,6 +189,8 @@ with gr.Blocks() as demo:
139
  )
140
  chat.chatbot.height = 600
141
 
 
 
142
 
143
  if __name__ == "__main__":
144
  demo.launch()
 
85
  _model_name = "meta-llama/Llama-3-8b-hf"
86
  elif model_name == "Llama-3-70B":
87
  _model_name = "meta-llama/Llama-3-70b-hf"
88
+ elif model_name == "Llama-2-7B":
89
+ _model_name = "meta-llama/Llama-2-7b-hf"
90
+ elif model_name == "Llama-2-70B":
91
+ _model_name = "meta-llama/Llama-2-70b-hf"
92
+ elif model_name == "Mistral-7B-v0.1":
93
+ _model_name = "mistralai/Mistral-7B-v0.1"
94
+ elif model_name == "mistralai/Mixtral-8x22B":
95
+ _model_name = "mistralai/Mixtral-8x22B"
96
+ elif model_name == "Qwen1.5-72B":
97
+ _model_name = "Qwen/Qwen1.5-72B"
98
+ elif model_name == "Yi-34B":
99
+ _model_name = "zero-one-ai/Yi-34B"
100
  else:
101
  raise ValueError("Invalid model name")
102
  # _model_name = "meta-llama/Llama-3-8b-hf"
 
117
  for msg in request:
118
  # print(msg.choices[0].delta.keys())
119
  token = msg.choices[0].delta["content"]
 
120
  should_stop = False
121
  for _stop in stop_str:
122
+ if _stop in response + token:
123
  should_stop = True
124
  break
125
  if should_stop:
126
  break
127
+ response += token
128
+ if response.endswith('\n"'):
129
+ response = response[:-1]
130
+ elif response.endswith('\n""'):
131
+ response = response[:-2]
132
+ yield response
133
 
134
+ js_code_label = """
135
+ function addApiKeyLink() {
136
+ // Select the div with id 'api_key'
137
+ const apiKeyDiv = document.getElementById('api_key');
138
+
139
+ // Find the span within that div with data-testid 'block-info'
140
+ const blockInfoSpan = apiKeyDiv.querySelector('span[data-testid="block-info"]');
141
+
142
+ // Create the new link element
143
+ const newLink = document.createElement('a');
144
+ newLink.href = 'https://api.together.ai/settings/api-keys';
145
+ newLink.textContent = ' View your keys here.';
146
+ newLink.target = '_blank'; // Open link in new tab
147
+ newLink.style = 'color: #007bff; text-decoration: underline;';
148
+
149
+ // Create the additional text
150
+ const additionalText = document.createTextNode(' (new account will have free credits to use.)');
151
+
152
+ // Append the link and additional text to the span
153
+ if (blockInfoSpan) {
154
+ // add a br
155
+ apiKeyDiv.appendChild(document.createElement('br'));
156
+ apiKeyDiv.appendChild(newLink);
157
+ apiKeyDiv.appendChild(additionalText);
158
+ } else {
159
+ console.error('Span with data-testid "block-info" not found');
160
+ }
161
+ }
162
+ """
163
+ with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
164
  with gr.Row():
165
  with gr.Column():
166
+ gr.Markdown("""# 💬 BaseChat: Chat with Base LLMs with URIAL
167
+ [Paper](https://arxiv.org/abs/2312.01552) | [Website](https://allenai.github.io/re-align/) | [GitHub](https://github.com/Re-Align/urial) | Contact: [Yuchen Lin](https://yuchenlin.xyz/)
168
+
169
+ **Talk with __BASE__ LLMs which are not fine-tuned at all.**
170
+ """)
171
+ model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1", "mistralai/Mixtral-8x22B", "Yi-34B", "Llama-2-7B", "Llama-2-70B"], value="Llama-3-8B", label="Base LLM name")
172
  with gr.Column():
173
+ together_api_key = gr.Textbox(label="🔑 Together APIKey", placeholder="Enter your Together API Key. Leave it blank if you want to use the default API key.", type="password", elem_id="api_key")
174
  with gr.Column():
175
  with gr.Row():
176
  max_tokens = gr.Textbox(value=1024, label="Max tokens")
177
  temperature = gr.Textbox(value=0.5, label="Temperature")
178
+ # with gr.Column():
179
+ # with gr.Row():
180
  top_p = gr.Textbox(value=0.9, label="Top-p")
181
  rp = gr.Textbox(value=1.1, label="Repetition penalty")
182
+
183
 
184
  chat = gr.ChatInterface(
185
  respond,
 
189
  )
190
  chat.chatbot.height = 600
191
 
192
+
193
+
194
 
195
  if __name__ == "__main__":
196
  demo.launch()