yuchenlin committed on
Commit
d8f6559
·
1 Parent(s): f45484b

side by side

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
__pycache__/constant.cpython-311.pyc CHANGED
Binary files a/__pycache__/constant.cpython-311.pyc and b/__pycache__/constant.cpython-311.pyc differ
 
__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
 
app.py CHANGED
@@ -3,8 +3,8 @@ import os
3
  from typing import List
4
  import logging
5
  import urllib.request
6
- from utils import model_name_mapping, urial_template, openai_base_request, DEFAULT_API_KEY
7
- from constant import js_code_label, HEADER_MD
8
  from openai import OpenAI
9
  import datetime
10
  # add logging info to console
@@ -19,28 +19,49 @@ STOP_STRS = ['"""', '# Query:', '# Answer:']
19
  addr_limit_counter = {}
20
  LAST_UPDATE_TIME = datetime.datetime.now()
21
 
 
 
 
 
 
 
 
 
 
22
  def respond(
23
  message,
24
  history: list[tuple[str, str]],
25
  max_tokens,
26
  temperature,
27
  top_p,
28
- rp,
29
  model_name,
30
- together_api_key,
 
31
  request:gr.Request
32
  ):
33
- global STOP_STRS, urial_prompt, LAST_UPDATE_TIME, addr_limit_counter
34
- rp = 1.0
35
- prompt = urial_template(urial_prompt, history, message)
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # _model_name = "meta-llama/Llama-3-8b-hf"
38
  _model_name = model_name_mapping(model_name)
39
 
40
- if together_api_key and len(together_api_key) == 64:
41
- api_key = together_api_key
42
  else:
43
- api_key = DEFAULT_API_KEY
44
 
45
  # headers = request.headers
46
  # if already 24 hours passed, reset the counter
@@ -53,12 +74,21 @@ def respond(
53
  if addr_limit_counter[host_addr] > 100:
54
  return "You have reached the limit of 100 requests for today. Please use your own API key."
55
 
56
- infer_request = openai_base_request(prompt=prompt, model=_model_name,
57
- temperature=temperature,
58
- max_tokens=max_tokens,
59
- top_p=top_p,
60
- repetition_penalty=rp,
61
- stop=STOP_STRS, api_key=api_key)
 
 
 
 
 
 
 
 
 
62
  addr_limit_counter[host_addr] += 1
63
  logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
64
  logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")
@@ -66,45 +96,103 @@ def respond(
66
  response = ""
67
  for msg in infer_request:
68
  # print(msg.choices[0].delta.keys())
69
- token = msg.choices[0].delta["content"]
70
- should_stop = False
71
- for _stop in STOP_STRS:
72
- if _stop in response + token:
73
- should_stop = True
 
 
 
 
 
 
 
 
 
 
74
  break
75
- if should_stop:
76
- break
77
  response += token
78
- if response.endswith('\n"'):
79
- response = response[:-1]
80
- elif response.endswith('\n""'):
81
- response = response[:-2]
82
- yield response
 
 
 
 
83
 
84
- with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  with gr.Row():
86
- with gr.Column():
87
- gr.Markdown(HEADER_MD)
88
- model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1",
89
- "Mixtral-8x22B", "Qwen1.5-72B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"]
90
- , value="Llama-3-8B", label="Base LLM name")
91
- with gr.Column():
92
- together_api_key = gr.Textbox(label="🔑 Together APIKey", placeholder="Enter your Together API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key")
93
- with gr.Column():
94
- with gr.Row():
95
- max_tokens = gr.Textbox(value=256, label="Max tokens")
96
- temperature = gr.Textbox(value=0.5, label="Temperature")
97
- top_p = gr.Textbox(value=0.9, label="Top-p")
98
- rp = gr.Textbox(value=1.1, label="Repetition penalty")
99
- chat = gr.ChatInterface(
100
- respond,
101
- additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
102
- # additional_inputs_accordion="⚙️ Parameters",
103
- # fill_height=True,
104
- )
105
- chat.chatbot.label="Chat with Base LLMs via URIAL"
106
- chat.chatbot.height = 550
107
- chat.chatbot.show_copy_button = True
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  if __name__ == "__main__":
110
  demo.launch(show_api=False)
 
3
  from typing import List
4
  import logging
5
  import urllib.request
6
+ from utils import model_name_mapping, urial_template, openai_base_request, chat_template, openai_chat_request
7
+ from constant import js_code_label, HEADER_MD, BASE_TO_ALIGNED, MODELS
8
  from openai import OpenAI
9
  import datetime
10
  # add logging info to console
 
19
  addr_limit_counter = {}
20
  LAST_UPDATE_TIME = datetime.datetime.now()
21
 
22
+
23
+ models = MODELS
24
+
25
+
26
+ # mega_hist = {
27
+ # "base": [],
28
+ # "aligned": []
29
+ # }
30
+
31
  def respond(
32
  message,
33
  history: list[tuple[str, str]],
34
  max_tokens,
35
  temperature,
36
  top_p,
37
+ rp,
38
  model_name,
39
+ model_type,
40
+ api_key,
41
  request:gr.Request
42
  ):
43
+ global STOP_STRS, urial_prompt, LAST_UPDATE_TIME, addr_limit_counter
44
+
45
+ assert model_type in ["base", "aligned"]
46
+ # if history:
47
+ # if model_type == "base":
48
+ # mega_hist["base"] = history
49
+ # else:
50
+ # mega_hist["aligned"] = history
51
+
52
+
53
+ if model_type == "base":
54
+ prompt = urial_template(urial_prompt, history, message)
55
+ else:
56
+ messages = chat_template(history, message)
57
 
58
  # _model_name = "meta-llama/Llama-3-8b-hf"
59
  _model_name = model_name_mapping(model_name)
60
 
61
+ if api_key and len(api_key) == 64:
62
+ api_key = api_key
63
  else:
64
+ api_key = None
65
 
66
  # headers = request.headers
67
  # if already 24 hours passed, reset the counter
 
74
  if addr_limit_counter[host_addr] > 100:
75
  return "You have reached the limit of 100 requests for today. Please use your own API key."
76
 
77
+ if model_type == "base":
78
+ infer_request = openai_base_request(prompt=prompt, model=_model_name,
79
+ temperature=temperature,
80
+ max_tokens=max_tokens,
81
+ top_p=top_p,
82
+ repetition_penalty=rp,
83
+ stop=STOP_STRS, api_key=api_key)
84
+ else:
85
+ infer_request = openai_chat_request(messages=messages, model=_model_name,
86
+ temperature=temperature,
87
+ max_tokens=max_tokens,
88
+ top_p=top_p,
89
+ repetition_penalty=rp,
90
+ stop=STOP_STRS, api_key=api_key)
91
+
92
  addr_limit_counter[host_addr] += 1
93
  logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
94
  logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")
 
96
  response = ""
97
  for msg in infer_request:
98
  # print(msg.choices[0].delta.keys())
99
+ if hasattr(msg.choices[0], "delta"):
100
+ # Note: 'ChoiceDelta' object may or may not be not subscriptable
101
+ if "content" in msg.choices[0].delta:
102
+ token = msg.choices[0].delta["content"]
103
+ else:
104
+ token = msg.choices[0].delta.content
105
+ else:
106
+ token = msg.choices[0].text
107
+ if model_type == "base":
108
+ should_stop = False
109
+ for _stop in STOP_STRS:
110
+ if _stop in response + token:
111
+ should_stop = True
112
+ break
113
+ if should_stop:
114
  break
115
+ if token is None:
116
+ continue
117
  response += token
118
+ if model_type == "base":
119
+ if response.endswith('\n"'):
120
+ response = response[:-1]
121
+ elif response.endswith('\n""'):
122
+ response = response[:-2]
123
+ yield history + [(message, response)]
124
+ # mega_hist[model_type].append((message, response))
125
+ # yield mega_hist[model_type]
126
+
127
 
128
+
129
def load_models(base_model_name):
    """Build the UI updates for a newly selected base model.

    Looks up the aligned counterpart of ``base_model_name`` in
    ``BASE_TO_ALIGNED`` and returns three ``gr.update`` payloads, in order:
    the label of the base-model chatbot, the label of the aligned-model
    chatbot, and the (non-interactive) value of the aligned-model field.

    Raises:
        KeyError: if ``base_model_name`` has no entry in ``BASE_TO_ALIGNED``.
    """
    print(f"base_model_name={base_model_name}")
    aligned_model_name = BASE_TO_ALIGNED[base_model_name]
    # Only update payloads are returned; no component instances have to be
    # created inside the event handler.
    return (
        gr.update(label=f"Chat with Base LLM: {base_model_name}"),
        gr.update(label=f"Chat with Aligned LLM: {aligned_model_name}"),
        gr.update(value=aligned_model_name, interactive=False),
    )
137
+
138
def clear_fn():
    """Return three ``None`` values, one per output component being cleared."""
    return (None,) * 3
142
+
143
+
144
+ with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
145
+ api_key = gr.Textbox(label="🔑 APIKey", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False)
146
+
147
+ gr.Markdown(HEADER_MD)
148
+
149
  with gr.Row():
150
+ chat_a = gr.Chatbot(height=500, label="Chat with Base LLMs via URIAL")
151
+ chat_b = gr.Chatbot(height=500, label="Chat with Aligned LLMs")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ with gr.Group():
154
+ with gr.Row():
155
+ with gr.Column(scale=2):
156
+ message = gr.Textbox(label="Prompt", placeholder="Enter your message here")
157
+ with gr.Row():
158
+ with gr.Column(scale=2):
159
+ with gr.Row():
160
+ left_model_choice = gr.Dropdown(label="Base Model", choices=models, interactive=True)
161
+ right_model_choice = gr.Textbox(label="Aligned Model", placeholder="xxx", visible=True)
162
+ with gr.Row():
163
+ btn = gr.Button("🚀 Chat")
164
+ # gr.Markdown("---")
165
+ with gr.Row():
166
+ stop_btn = gr.Button("⏸️ Stop")
167
+ clear_btn = gr.Button("🔁 Clear")
168
+ with gr.Row():
169
+ gr.Markdown("We thank for the support from [Hyperbolic AI](https://hyperbolic.xyz/).")
170
+ with gr.Column(scale=1):
171
+ with gr.Accordion("⚙️ Params for **Base** LLM", open=True):
172
+ with gr.Row():
173
+ max_tokens_1 = gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=2048, step=16, interactive=True, visible=True)
174
+ temperature_1 = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
175
+ with gr.Row():
176
+ top_p_1 = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
177
+ rp_1 = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.1)
178
+ with gr.Accordion("⚙️ Params for **Aligned** LLM", open=True):
179
+ with gr.Row():
180
+ max_tokens_2 = gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=2048, step=16, interactive=True, visible=True)
181
+ temperature_2 = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
182
+ with gr.Row():
183
+ top_p_2 = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
184
+ rp_2 = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
185
+
186
+ left_model_choice.change(load_models, [left_model_choice], [chat_a, chat_b, right_model_choice])
187
+
188
+ model_type_left = gr.Textbox(visible=False, value="base")
189
+ model_type_right = gr.Textbox(visible=False, value="aligned")
190
+
191
+ go1 = btn.click(respond, [message, chat_a, max_tokens_1, temperature_1, top_p_1, rp_1, left_model_choice, model_type_left, api_key], chat_a)
192
+ go2 = btn.click(respond, [message, chat_b, max_tokens_2, temperature_2, top_p_2, rp_2, right_model_choice, model_type_right, api_key], chat_b)
193
+
194
+ stop_btn.click(None, None, None, cancels=[go1, go2])
195
+ clear_btn.click(clear_fn, None, [message, chat_a, chat_b])
196
+
197
  if __name__ == "__main__":
198
  demo.launch(show_api=False)
app_single.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from typing import List
4
+ import logging
5
+ import urllib.request
6
+ from utils import model_name_mapping, urial_template, openai_base_request
7
+ from constant import js_code_label, HEADER_MD
8
+ from openai import OpenAI
9
+ import datetime
10
+ # add logging info to console
11
+ logging.basicConfig(level=logging.INFO)
12
+
13
+ URIAL_VERSION = "inst_1k_v4.help"
14
+ URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
15
+ urial_prompt = urllib.request.urlopen(URIAL_URL).read().decode('utf-8')
16
+ urial_prompt = urial_prompt.replace("```", '"""') # new version of URIAL uses """ instead of ```
17
+ STOP_STRS = ['"""', '# Query:', '# Answer:']
18
+
19
+ addr_limit_counter = {}
20
+ LAST_UPDATE_TIME = datetime.datetime.now()
21
+
22
+
23
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    api_key,
    request:gr.Request
):
    """Stream a base-model answer to ``message`` using the URIAL prompt.

    This is a generator: it yields the accumulated response text after each
    streamed token. A per-client daily counter, keyed by
    ``request.client.host``, caps usage at 100 requests; once the cap is
    reached the limit message is yielded as the only output.

    Args:
        message: The new user query.
        history: Previous (user, assistant) turns, passed to ``urial_template``.
        max_tokens, temperature, top_p: Sampling settings forwarded to
            ``openai_base_request``.
        rp: Repetition penalty from the UI (currently overridden, see below).
        model_name: UI model label, resolved through ``model_name_mapping``.
        api_key: Optional user key; only a 64-character value is forwarded.
        request: The incoming Gradio request, used for the client address.
    """
    global STOP_STRS, urial_prompt, LAST_UPDATE_TIME, addr_limit_counter
    # NOTE(review): the repetition penalty chosen in the UI is overridden
    # here; kept as in the original -- confirm whether this is intentional.
    rp = 1.0
    prompt = urial_template(urial_prompt, history, message)

    _model_name = model_name_mapping(model_name)

    # Anything that is not a 64-character key is dropped and None is passed on.
    if not (api_key and len(api_key) == 64):
        api_key = None

    # Reset all counters once more than 24 hours have passed.
    if datetime.datetime.now() - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = datetime.datetime.now()
    host_addr = request.client.host
    addr_limit_counter.setdefault(host_addr, 0)
    if addr_limit_counter[host_addr] >= 100:
        # This function is a generator, so the message must be yielded;
        # a returned value would be discarded and never shown.
        yield "You have reached the limit of 100 requests for today. Please use your own API key."
        return

    infer_request = openai_base_request(prompt=prompt, model=_model_name,
                                        temperature=temperature,
                                        max_tokens=max_tokens,
                                        top_p=top_p,
                                        repetition_penalty=rp,
                                        stop=STOP_STRS, api_key=api_key)
    addr_limit_counter[host_addr] += 1
    logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")

    response = ""
    for msg in infer_request:
        choice = msg.choices[0]
        if hasattr(choice, "delta"):
            delta = choice.delta
            # The delta may be a plain dict or a non-subscriptable object.
            if isinstance(delta, dict):
                token = delta.get("content")
            else:
                token = getattr(delta, "content", None)
        else:
            token = choice.text
        if token is None:
            # Chunks without text cannot be appended; skip them.
            continue
        if any(_stop in response + token for _stop in STOP_STRS):
            break
        response += token
        # Hide a trailing partial closing quote sequence from the output.
        if response.endswith('\n"'):
            response = response[:-1]
        elif response.endswith('\n""'):
            response = response[:-2]
        yield response
87
+
88
+ with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
89
+ with gr.Row():
90
+ with gr.Column():
91
+ gr.Markdown(HEADER_MD)
92
+ model_name = gr.Radio(["Llama-3.1-405B-FP8", "Llama-3-70B", "Llama-3-8B",
93
+ "Mistral-7B-v0.1",
94
+ "Mixtral-8x22B", "Qwen1.5-72B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"]
95
+ , value="Llama-3.1-405B-FP8", label="Base LLM name")
96
+ with gr.Column():
97
+ api_key = gr.Textbox(label="🔑 APIKey", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False)
98
+ # with gr.Column():
99
+ with gr.Accordion("⚙️ Parameters for Base LLM", open=True):
100
+ with gr.Row():
101
+ max_tokens = gr.Textbox(value=256, label="Max tokens")
102
+ temperature = gr.Textbox(value=0.5, label="Temperature")
103
+ top_p = gr.Textbox(value=0.9, label="Top-p")
104
+ rp = gr.Textbox(value=1.1, label="Repetition penalty")
105
+ # with gr.Row():
106
+ chat = gr.ChatInterface(
107
+ respond,
108
+ additional_inputs=[max_tokens, temperature, top_p, rp, model_name, api_key],
109
+ # additional_inputs_accordion="⚙️ Parameters",
110
+ # fill_height=True,
111
+ )
112
+ chat.chatbot.label="Chat with Base LLMs via URIAL"
113
+ chat.chatbot.height = 550
114
+ chat.chatbot.show_copy_button = True
115
+
116
+ if __name__ == "__main__":
117
+ demo.launch(show_api=False)
constant.py CHANGED
@@ -33,3 +33,57 @@ function addApiKeyLink() {
33
  }
34
  }
35
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  }
34
  }
35
  """
36
+
37
+
38
+ MODELS = ["Llama-3.1-405B-FP8", "Llama-3-70B", "Llama-3-8B",
39
+ "Mistral-7B-v0.1",
40
+ "Mixtral-8x22B", "Qwen1.5-72B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMo-7B"]
41
+
42
+ HYPERBOLIC_MODELS = ["meta-llama/Meta-Llama-3.1-405B-FP8", "meta-llama/Meta-Llama-3.1-405B-Instruct"]
43
+
44
+ BASE_TO_ALIGNED = {
45
+ "Llama-3-70B": "Llama-3-70B-Instruct",
46
+ "Llama-3-8B": "Llama-3-8B-Instruct",
47
+ "Mistral-7B-v0.1": "Mistral-7B-v0.1-Instruct",
48
+ "Mixtral-8x22B": "Mixtral-8x22B-Instruct",
49
+ "Qwen1.5-72B": "Qwen1.5-72B-Instruct",
50
+ "Llama-3.1-405B-FP8": "Llama-3.1-405B-FP8-Instruct",
51
+ "Yi-34B": "Yi-34B-chat",
52
+ "Llama-2-7B": "Llama-2-7B-chat",
53
+ "Llama-2-70B": "Llama-2-70B-chat",
54
+ "OLMo-7B": "OLMo-7B-Instruct",
55
+ }
56
+
57
+
58
+ MODEL_MAPPING = {
59
+ "Llama-3-8B": "meta-llama/Llama-3-8b-hf",
60
+ "Llama-3-70B": "meta-llama/Llama-3-70b-hf",
61
+ "Llama-2-7B": "meta-llama/Llama-2-7b-hf",
62
+ "Llama-2-70B": "meta-llama/Llama-2-70b-hf",
63
+ "Mistral-7B-v0.1": "mistralai/Mistral-7B-v0.1",
64
+ "Mixtral-8x22B": "mistralai/Mixtral-8x22B",
65
+ "Qwen1.5-72B": "Qwen/Qwen1.5-72B",
66
+ "Yi-34B": "zero-one-ai/Yi-34B",
67
+ "Yi-6B": "zero-one-ai/Yi-6B",
68
+ "OLMo-7B": "allenai/OLMo-7B",
69
+ "Llama-3.1-405B-FP8": "meta-llama/Meta-Llama-3.1-405B-FP8",
70
+ # Aligned models below
71
+ "Llama-3-70B-Instruct": "meta-llama/Meta-Llama-3-70B-Instruct-Lite",
72
+ "Llama-3-8B-Instruct": "meta-llama/Meta-Llama-3-8B-Instruct-Lite",
73
+ "Mistral-7B-v0.1-Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
74
+ "Mixtral-8x22B-Instruct": "mistralai/Mixtral-8x22B-Instruct-v0.1",
75
+ "Qwen1.5-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
76
+ "Yi-34B-chat": "zero-one-ai/Yi-34B-Chat",
77
+ "Llama-2-7B-chat": "meta-llama/Llama-2-7b-chat-hf",
78
+ "Llama-2-70B-chat": "meta-llama/Llama-2-70b-chat-hf",
79
+ "OLMo-7B-Instruct": "allenai/OLMo-7B-Instruct",
80
+ "Llama-3.1-405B-FP8-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
81
+ }
82
+
83
+ # import json
84
+ # with open("together_model_ids.json", "r") as f:
85
+ # TOGETHER_MODEL_IDS = json.load(f)
86
+
87
+ # for _, model_id in MODEL_MAPPING.items():
88
+ # if model_id not in TOGETHER_MODEL_IDS + HYPERBOLIC_MODELS:
89
+ # print(model_id)
list_models.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import os
4
+
5
# Fetch the model list from the Together API and save the ids that contain
# any of ``keywords`` (case-insensitive) to together_model_ids.json.
url = "https://api.together.xyz/v1/models"

headers = {
    "accept": "application/json",
    "Authorization": f"Bearer {os.getenv('TOGETHER_API_KEY')}"
}

# Bound the wait, and fail loudly on an HTTP error instead of iterating an
# error payload as if it were the model list.
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()

data = response.json()
keywords = ["OLMO"]
# Lowercase the keywords once rather than once per model.
lowered_keywords = [keyword.lower() for keyword in keywords]

model_ids = []
for item in data:
    model_id = item["id"]
    lowered_id = model_id.lower()
    if any(keyword in lowered_id for keyword in lowered_keywords):
        print(model_id)
        model_ids.append(model_id)

with open("together_model_ids.json", "w") as f:
    json.dump(model_ids, f, indent=4)
together_model_ids.json ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "Nexusflow/NexusRaven-V2-13B",
3
+ "bert-base-uncased",
4
+ "WizardLM/WizardLM-13B-V1.2",
5
+ "codellama/CodeLlama-34b-Instruct-hf",
6
+ "google/gemma-7b",
7
+ "upstage/SOLAR-10.7B-Instruct-v1.0",
8
+ "zero-one-ai/Yi-34B",
9
+ "togethercomputer/StripedHyena-Hessian-7B",
10
+ "meta-llama/Llama-3-70b-chat-hf",
11
+ "teknium/OpenHermes-2-Mistral-7B",
12
+ "mistralai/Mixtral-8x7B-v0.1",
13
+ "WhereIsAI/UAE-Large-V1",
14
+ "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1",
15
+ "togethercomputer/Llama-2-7B-32K-Instruct",
16
+ "Undi95/ReMM-SLERP-L2-13B",
17
+ "meta-llama/Meta-Llama-Guard-3-8B",
18
+ "Undi95/Toppy-M-7B",
19
+ "Phind/Phind-CodeLlama-34B-v2",
20
+ "stabilityai/stable-diffusion-2-1",
21
+ "openchat/openchat-3.5-1210",
22
+ "Austism/chronos-hermes-13b",
23
+ "microsoft/phi-2",
24
+ "Qwen/Qwen1.5-0.5B",
25
+ "Qwen/Qwen1.5-1.8B",
26
+ "Qwen/Qwen1.5-4B",
27
+ "Qwen/Qwen1.5-7B",
28
+ "togethercomputer/m2-bert-80M-32k-retrieval",
29
+ "snorkelai/Snorkel-Mistral-PairRM-DPO",
30
+ "Qwen/Qwen1.5-7B-Chat",
31
+ "Qwen/Qwen1.5-14B",
32
+ "Qwen/Qwen1.5-14B-Chat",
33
+ "Qwen/Qwen1.5-72B",
34
+ "Qwen/Qwen1.5-1.8B-Chat",
35
+ "BAAI/bge-base-en-v1.5",
36
+ "Snowflake/snowflake-arctic-instruct",
37
+ "codellama/CodeLlama-13b-Python-hf",
38
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT",
39
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
40
+ "togethercomputer/m2-bert-80M-2k-retrieval",
41
+ "deepseek-ai/deepseek-coder-33b-instruct",
42
+ "codellama/CodeLlama-34b-Python-hf",
43
+ "NousResearch/Nous-Hermes-Llama2-13b",
44
+ "lmsys/vicuna-13b-v1.5",
45
+ "Qwen/Qwen1.5-0.5B-Chat",
46
+ "codellama/CodeLlama-70b-Python-hf",
47
+ "codellama/CodeLlama-7b-Instruct-hf",
48
+ "NousResearch/Nous-Hermes-2-Yi-34B",
49
+ "codellama/CodeLlama-13b-Instruct-hf",
50
+ "BAAI/bge-large-en-v1.5",
51
+ "togethercomputer/Llama-3-8b-chat-hf-int4",
52
+ "meta-llama/Llama-2-13b-hf",
53
+ "teknium/OpenHermes-2p5-Mistral-7B",
54
+ "NousResearch/Nous-Capybara-7B-V1p9",
55
+ "WizardLM/WizardCoder-Python-34B-V1.0",
56
+ "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
57
+ "NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
58
+ "togethercomputer/StripedHyena-Nous-7B",
59
+ "togethercomputer/alpaca-7b",
60
+ "garage-bAInd/Platypus2-70B-instruct",
61
+ "google/gemma-2b",
62
+ "google/gemma-2b-it",
63
+ "google/gemma-7b-it",
64
+ "meta-llama/Llama-2-7b-chat-hf",
65
+ "allenai/OLMo-7B",
66
+ "allenai/OLMo-7B-Instruct",
67
+ "Qwen/Qwen1.5-4B-Chat",
68
+ "stabilityai/stable-diffusion-xl-base-1.0",
69
+ "Gryphe/MythoMax-L2-13b",
70
+ "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
71
+ "meta-llama/LlamaGuard-2-8b",
72
+ "mistralai/Mistral-7B-Instruct-v0.1",
73
+ "mistralai/Mistral-7B-Instruct-v0.2",
74
+ "meta-llama/Meta-Llama-3-8B",
75
+ "mistralai/Mistral-7B-v0.1",
76
+ "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
77
+ "Open-Orca/Mistral-7B-OpenOrca",
78
+ "Qwen/Qwen1.5-32B",
79
+ "NousResearch/Nous-Hermes-llama-2-7b",
80
+ "Qwen/Qwen1.5-32B-Chat",
81
+ "mistralai/Mixtral-8x22B",
82
+ "Qwen/Qwen2-72B-Instruct",
83
+ "Qwen/Qwen1.5-72B-Chat",
84
+ "meta-llama/Meta-Llama-3-70B",
85
+ "meta-llama/Llama-3-8b-hf",
86
+ "deepseek-ai/deepseek-llm-67b-chat",
87
+ "sentence-transformers/msmarco-bert-base-dot-v5",
88
+ "zero-one-ai/Yi-6B",
89
+ "lmsys/vicuna-7b-v1.5",
90
+ "togethercomputer/m2-bert-80M-8k-retrieval",
91
+ "microsoft/WizardLM-2-8x22B",
92
+ "togethercomputer/Llama-3-8b-chat-hf-int8",
93
+ "wavymulder/Analog-Diffusion",
94
+ "mistralai/Mistral-7B-Instruct-v0.3",
95
+ "Qwen/Qwen1.5-110B-Chat",
96
+ "runwayml/stable-diffusion-v1-5",
97
+ "prompthero/openjourney",
98
+ "meta-llama/Llama-2-7b-hf",
99
+ "SG161222/Realistic_Vision_V3.0_VAE",
100
+ "meta-llama/Llama-2-13b-chat-hf",
101
+ "google/gemma-2-27b-it",
102
+ "zero-one-ai/Yi-34B-Chat",
103
+ "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
104
+ "meta-llama/Meta-Llama-3-70B-Instruct-Lite",
105
+ "google/gemma-2-9b-it",
106
+ "google/gemma-2-9b",
107
+ "meta-llama/Llama-3-8b-chat-hf",
108
+ "mistralai/Mixtral-8x7B-Instruct-v0.1",
109
+ "codellama/CodeLlama-70b-hf",
110
+ "togethercomputer/LLaMA-2-7B-32K",
111
+ "databricks/dbrx-instruct",
112
+ "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
113
+ "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
114
+ "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
115
+ "mistralai/Mixtral-8x22B-Instruct-v0.1",
116
+ "togethercomputer/evo-1-131k-base",
117
+ "meta-llama/Llama-2-70b-hf",
118
+ "codellama/CodeLlama-70b-Instruct-hf",
119
+ "meta-llama/Meta-Llama-3-8B-Instruct-Lite",
120
+ "togethercomputer/evo-1-8k-base",
121
+ "meta-llama/Llama-2-70b-chat-hf",
122
+ "codellama/CodeLlama-7b-Python-hf",
123
+ "Meta-Llama/Llama-Guard-7b",
124
+ "togethercomputer/Koala-7B",
125
+ "Qwen/Qwen2-1.5B-Instruct",
126
+ "Qwen/Qwen2-7B-Instruct",
127
+ "NousResearch/Nous-Hermes-13b",
128
+ "togethercomputer/guanaco-65b",
129
+ "togethercomputer/llama-2-7b",
130
+ "huggyllama/llama-7b",
131
+ "lmsys/vicuna-7b-v1.3",
132
+ "Qwen/Qwen2-72B",
133
+ "Phind/Phind-CodeLlama-34B-Python-v1",
134
+ "NumbersStation/nsql-llama-2-7B",
135
+ "NousResearch/Nous-Hermes-Llama2-70b",
136
+ "WizardLM/WizardLM-70B-V1.0",
137
+ "huggyllama/llama-65b",
138
+ "lmsys/vicuna-13b-v1.5-16k",
139
+ "HuggingFaceH4/zephyr-7b-beta",
140
+ "togethercomputer/llama-2-13b",
141
+ "togethercomputer/CodeLlama-7b-Instruct",
142
+ "togethercomputer/guanaco-13b",
143
+ "togethercomputer/CodeLlama-34b-Python",
144
+ "togethercomputer/CodeLlama-34b-Instruct",
145
+ "togethercomputer/CodeLlama-34b",
146
+ "togethercomputer/llama-2-70b",
147
+ "codellama/CodeLlama-13b-hf",
148
+ "Qwen/Qwen2-7B",
149
+ "Qwen/Qwen2-1.5B",
150
+ "togethercomputer/CodeLlama-13b-Instruct",
151
+ "togethercomputer/llama-2-13b-chat",
152
+ "lmsys/vicuna-13b-v1.3",
153
+ "huggyllama/llama-13b",
154
+ "huggyllama/llama-30b",
155
+ "togethercomputer/guanaco-33b",
156
+ "togethercomputer/Koala-13B",
157
+ "togethercomputer/llama-2-7b-chat",
158
+ "togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4",
159
+ "togethercomputer/guanaco-7b",
160
+ "EleutherAI/llemma_7b",
161
+ "meta-llama/Meta-Llama-3-8B-Instruct",
162
+ "codellama/CodeLlama-34b-hf",
163
+ "meta-llama/Meta-Llama-3-70B-Instruct",
164
+ "meta-llama/Llama-3-70b-hf",
165
+ "togethercomputer/CodeLlama-7b-Python",
166
+ "NousResearch/Hermes-2-Theta-Llama-3-70B",
167
+ "carson/ml318bit",
168
+ "togethercomputer/CodeLlama-13b-Python",
169
+ "codellama/CodeLlama-7b-hf",
170
+ "togethercomputer/llama-2-70b-chat",
171
+ "carson/ml31405bit",
172
+ "carson/ml3170bit",
173
+ "carson/mlg38b",
174
+ "carson/ml318br",
175
+ "meta-llama/Meta-Llama-3.1-8B-Reference",
176
+ "gradientai/Llama-3-70B-Instruct-Gradient-1048k",
177
+ "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference",
178
+ "meta-llama/Meta-Llama-3.1-70B-Reference"
179
+ ]
utils.py CHANGED
@@ -3,36 +3,15 @@ from openai import OpenAI
3
  import logging
4
  from typing import List
5
  import os
 
6
 
7
- BASE_URL = "https://api.together.xyz/v1"
8
- DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")
9
 
10
  def model_name_mapping(model_name):
11
- if model_name == "Llama-3-8B":
12
- _model_name = "meta-llama/Llama-3-8b-hf"
13
- elif model_name == "Llama-3-70B":
14
- _model_name = "meta-llama/Llama-3-70b-hf"
15
- elif model_name == "Llama-2-7B":
16
- _model_name = "meta-llama/Llama-2-7b-hf"
17
- elif model_name == "Llama-2-70B":
18
- _model_name = "meta-llama/Llama-2-70b-hf"
19
- elif model_name == "Mistral-7B-v0.1":
20
- _model_name = "mistralai/Mistral-7B-v0.1"
21
- elif model_name == "Mixtral-8x22B":
22
- _model_name = "mistralai/Mixtral-8x22B"
23
- elif model_name == "Qwen1.5-72B":
24
- _model_name = "Qwen/Qwen1.5-72B"
25
- elif model_name == "Yi-34B":
26
- _model_name = "zero-one-ai/Yi-34B"
27
- elif model_name == "Yi-6B":
28
- _model_name = "zero-one-ai/Yi-6B"
29
- elif model_name == "OLMO":
30
- _model_name = "allenai/OLMo-7B"
31
- elif model_name == "Qwen1.5-72B":
32
- _model_name = "Qwen/Qwen1.5-72B"
33
  else:
34
- raise ValueError("Invalid model name")
35
- return _model_name
36
 
37
 
38
  def urial_template(urial_prompt, history, message):
@@ -41,7 +20,14 @@ def urial_template(urial_prompt, history, message):
41
  current_prompt += f'# Query:\n"""\n{user_msg}\n"""\n\n# Answer:\n"""\n{ai_msg}\n"""\n\n'
42
  current_prompt += f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n'
43
  return current_prompt
44
-
 
 
 
 
 
 
 
45
 
46
  def openai_base_request(
47
  model: str=None,
@@ -54,11 +40,18 @@ def openai_base_request(
54
  stop: List[str]=None,
55
  api_key: str=None,
56
  ):
 
 
 
 
 
 
 
 
57
  if api_key is None:
58
  api_key = DEFAULT_API_KEY
59
- client = OpenAI(api_key=api_key, base_url=BASE_URL)
60
- # print(f"Requesting chat completion from OpenAI API with model {model}")
61
- logging.info(f"Requesting chat completion from OpenAI API with model {model}")
62
  logging.info(f"Prompt: {prompt}")
63
  logging.info(f"Temperature: {temperature}")
64
  logging.info(f"Max tokens: {max_tokens}")
@@ -80,3 +73,44 @@ def openai_base_request(
80
 
81
  return request
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import logging
4
  from typing import List
5
  import os
6
+ from constant import HYPERBOLIC_MODELS, MODEL_MAPPING
7
 
 
 
8
 
9
def model_name_mapping(model_name):
    """Translate a UI model label into the corresponding ``MODEL_MAPPING`` value.

    Raises:
        ValueError: if ``model_name`` is not a key of ``MODEL_MAPPING``.
    """
    try:
        return MODEL_MAPPING[model_name]
    except KeyError:
        # Build one formatted message; passing two positional arguments to
        # ValueError makes the error text render as a tuple.
        raise ValueError(f"Invalid model name: {model_name}") from None
 
15
 
16
 
17
  def urial_template(urial_prompt, history, message):
 
20
  current_prompt += f'# Query:\n"""\n{user_msg}\n"""\n\n# Answer:\n"""\n{ai_msg}\n"""\n\n'
21
  current_prompt += f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n'
22
  return current_prompt
23
+
24
def chat_template(history, message):
    """Turn (user, assistant) history pairs plus ``message`` into chat messages.

    Each history pair contributes a "user" entry followed by an "assistant"
    entry; the new ``message`` is appended last as a "user" entry.
    """
    turns = [
        {"role": speaker, "content": text}
        for user_msg, ai_msg in history
        for speaker, text in (("user", user_msg), ("assistant", ai_msg))
    ]
    return turns + [{"role": "user", "content": message}]
31
 
32
  def openai_base_request(
33
  model: str=None,
 
40
  stop: List[str]=None,
41
  api_key: str=None,
42
  ):
43
+
44
+ if model in HYPERBOLIC_MODELS:
45
+ BASE_URL = "https://api.hyperbolic.xyz/v1"
46
+ DEFAULT_API_KEY = os.getenv("HYPERBOLIC_API_KEY")
47
+ else:
48
+ BASE_URL = "https://api.together.xyz/v1"
49
+ DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")
50
+
51
  if api_key is None:
52
  api_key = DEFAULT_API_KEY
53
+ client = OpenAI(api_key=api_key, base_url=BASE_URL)
54
+ logging.info(f"Requesting base completion from OpenAI API with model {model}")
 
55
  logging.info(f"Prompt: {prompt}")
56
  logging.info(f"Temperature: {temperature}")
57
  logging.info(f"Max tokens: {max_tokens}")
 
73
 
74
  return request
75
 
76
+
77
+
78
def openai_chat_request(
    model: str=None,
    temperature: float=0,
    max_tokens: int=512,
    top_p: float=1.0,
    messages=None,
    n: int=1,
    repetition_penalty: float=1.0,
    stop: List[str]=None,
    api_key: str=None,
):
    """Start a streaming chat-completion request and return the stream.

    Models listed in ``HYPERBOLIC_MODELS`` are sent to the Hyperbolic
    endpoint, all others to the Together endpoint. When ``api_key`` is None
    the key is read from the matching environment variable
    (``HYPERBOLIC_API_KEY`` or ``TOGETHER_API_KEY``).

    ``temperature``, ``top_p`` and ``repetition_penalty`` are coerced to
    float and ``max_tokens`` to int before being sent; the repetition
    penalty travels in ``extra_body``.
    """
    use_hyperbolic = model in HYPERBOLIC_MODELS
    if use_hyperbolic:
        endpoint = "https://api.hyperbolic.xyz/v1"
        fallback_key = os.getenv("HYPERBOLIC_API_KEY")
    else:
        endpoint = "https://api.together.xyz/v1"
        fallback_key = os.getenv("TOGETHER_API_KEY")

    effective_key = fallback_key if api_key is None else api_key

    logging.info(f"Requesting chat completion from OpenAI API with model {model}")

    client = OpenAI(api_key=effective_key, base_url=endpoint)

    sampling = {
        "temperature": float(temperature),
        "max_tokens": int(max_tokens),
        "top_p": float(top_p),
    }
    return client.chat.completions.create(
        model=model,
        messages=messages,
        n=n,
        extra_body={'repetition_penalty': float(repetition_penalty)},
        stop=stop,
        stream=True,
        **sampling,
    )
116
+