yuchenlin committed on
Commit
8df0f23
·
1 Parent(s): 54f7da0
__pycache__/constant.cpython-311.pyc ADDED
Binary file (1.59 kB). View file
 
__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.46 kB). View file
 
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import gradio as gr
2
- from openai import OpenAI
3
  import os
4
  from typing import List
5
  import logging
6
  import urllib.request
 
 
7
 
8
  # add logging info to console
9
  logging.basicConfig(level=logging.INFO)
10
 
11
 
12
- BASE_URL = "https://api.together.xyz/v1"
13
- DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")
14
  URIAL_VERSION = "inst_1k_v4.help"
15
 
16
  URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
@@ -18,55 +18,6 @@ urial_prompt = urllib.request.urlopen(URIAL_URL).read().decode('utf-8')
18
  urial_prompt = urial_prompt.replace("```", '"""') # new version of URIAL uses """ instead of ```
19
  STOP_STRS = ['"""', '# Query:', '# Answer:']
20
 
21
- def urial_template(urial_prompt, history, message):
22
- current_prompt = urial_prompt + "\n"
23
- for user_msg, ai_msg in history:
24
- current_prompt += f'# Query:\n"""\n{user_msg}\n"""\n\n# Answer:\n"""\n{ai_msg}\n"""\n\n'
25
- current_prompt += f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n'
26
- return current_prompt
27
-
28
-
29
-
30
-
31
- def openai_base_request(
32
- model: str=None,
33
- temperature: float=0,
34
- max_tokens: int=512,
35
- top_p: float=1.0,
36
- prompt: str=None,
37
- n: int=1,
38
- repetition_penalty: float=1.0,
39
- stop: List[str]=None,
40
- api_key: str=None,
41
- ):
42
- if api_key is None:
43
- api_key = DEFAULT_API_KEY
44
- client = OpenAI(api_key=api_key, base_url=BASE_URL)
45
- # print(f"Requesting chat completion from OpenAI API with model {model}")
46
- logging.info(f"Requesting chat completion from OpenAI API with model {model}")
47
- logging.info(f"Prompt: {prompt}")
48
- logging.info(f"Temperature: {temperature}")
49
- logging.info(f"Max tokens: {max_tokens}")
50
- logging.info(f"Top-p: {top_p}")
51
- logging.info(f"Repetition penalty: {repetition_penalty}")
52
- logging.info(f"Stop: {stop}")
53
-
54
- request = client.completions.create(
55
- model=model,
56
- prompt=prompt,
57
- temperature=float(temperature),
58
- max_tokens=int(max_tokens),
59
- top_p=float(top_p),
60
- n=n,
61
- extra_body={'repetition_penalty': float(repetition_penalty)},
62
- stop=stop,
63
- stream=True
64
- )
65
-
66
- return request
67
-
68
-
69
-
70
 
71
  def respond(
72
  message,
@@ -81,29 +32,9 @@ def respond(
81
  global STOP_STRS, urial_prompt
82
  rp = 1.0
83
  prompt = urial_template(urial_prompt, history, message)
84
- if model_name == "Llama-3-8B":
85
- _model_name = "meta-llama/Llama-3-8b-hf"
86
- elif model_name == "Llama-3-70B":
87
- _model_name = "meta-llama/Llama-3-70b-hf"
88
- elif model_name == "Llama-2-7B":
89
- _model_name = "meta-llama/Llama-2-7b-hf"
90
- elif model_name == "Llama-2-70B":
91
- _model_name = "meta-llama/Llama-2-70b-hf"
92
- elif model_name == "Mistral-7B-v0.1":
93
- _model_name = "mistralai/Mistral-7B-v0.1"
94
- elif model_name == "Mixtral-8x22B":
95
- _model_name = "mistralai/Mixtral-8x22B"
96
- elif model_name == "Qwen1.5-72B":
97
- _model_name = "Qwen/Qwen1.5-72B"
98
- elif model_name == "Yi-34B":
99
- _model_name = "zero-one-ai/Yi-34B"
100
- elif model_name == "Yi-6B":
101
- _model_name = "zero-one-ai/Yi-6B"
102
- elif model_name == "OLMO":
103
- _model_name = "allenai/OLMo-7B"
104
- else:
105
- raise ValueError("Invalid model name")
106
  # _model_name = "meta-llama/Llama-3-8b-hf"
 
107
 
108
  if together_api_key and len(together_api_key) == 64:
109
  api_key = together_api_key
@@ -116,7 +47,6 @@ def respond(
116
  top_p=top_p,
117
  repetition_penalty=rp,
118
  stop=STOP_STRS, api_key=api_key)
119
-
120
  response = ""
121
  for msg in request:
122
  # print(msg.choices[0].delta.keys())
@@ -135,43 +65,10 @@ def respond(
135
  response = response[:-2]
136
  yield response
137
 
138
- js_code_label = """
139
- function addApiKeyLink() {
140
- // Select the div with id 'api_key'
141
- const apiKeyDiv = document.getElementById('api_key');
142
-
143
- // Find the span within that div with data-testid 'block-info'
144
- const blockInfoSpan = apiKeyDiv.querySelector('span[data-testid="block-info"]');
145
-
146
- // Create the new link element
147
- const newLink = document.createElement('a');
148
- newLink.href = 'https://api.together.ai/settings/api-keys';
149
- newLink.textContent = ' View your keys here.';
150
- newLink.target = '_blank'; // Open link in new tab
151
- newLink.style = 'color: #007bff; text-decoration: underline;';
152
-
153
- // Create the additional text
154
- const additionalText = document.createTextNode(' (new account will have free credits to use.)');
155
-
156
- // Append the link and additional text to the span
157
- if (blockInfoSpan) {
158
- // add a br
159
- apiKeyDiv.appendChild(document.createElement('br'));
160
- apiKeyDiv.appendChild(newLink);
161
- apiKeyDiv.appendChild(additionalText);
162
- } else {
163
- console.error('Span with data-testid "block-info" not found');
164
- }
165
- }
166
- """
167
  with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
168
  with gr.Row():
169
  with gr.Column():
170
- gr.Markdown("""# 💬 BaseChat: Chat with Base LLMs with URIAL
171
- [Paper](https://arxiv.org/abs/2312.01552) | [Website](https://allenai.github.io/re-align/) | [GitHub](https://github.com/Re-Align/urial) | Contact: [Yuchen Lin](https://yuchenlin.xyz/)
172
-
173
- **Talk with __BASE__ LLMs which are not fine-tuned at all.**
174
- """)
175
  model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1",
176
  "Mixtral-8x22B", "Yi-6B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"]
177
  , value="Llama-3-8B", label="Base LLM name")
@@ -181,12 +78,8 @@ with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
181
  with gr.Row():
182
  max_tokens = gr.Textbox(value=256, label="Max tokens")
183
  temperature = gr.Textbox(value=0.5, label="Temperature")
184
- # with gr.Column():
185
- # with gr.Row():
186
  top_p = gr.Textbox(value=0.9, label="Top-p")
187
  rp = gr.Textbox(value=1.1, label="Repetition penalty")
188
-
189
-
190
  chat = gr.ChatInterface(
191
  respond,
192
  additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
 
1
  import gradio as gr
 
2
  import os
3
  from typing import List
4
  import logging
5
  import urllib.request
6
+ from utils import model_name_mapping, urial_template, openai_base_request, DEFAULT_API_KEY
7
+ from constant import js_code_label, HEADER_MD
8
 
9
  # add logging info to console
10
  logging.basicConfig(level=logging.INFO)
11
 
12
 
13
+
 
14
  URIAL_VERSION = "inst_1k_v4.help"
15
 
16
  URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
 
18
  urial_prompt = urial_prompt.replace("```", '"""') # new version of URIAL uses """ instead of ```
19
  STOP_STRS = ['"""', '# Query:', '# Answer:']
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def respond(
23
  message,
 
32
  global STOP_STRS, urial_prompt
33
  rp = 1.0
34
  prompt = urial_template(urial_prompt, history, message)
35
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # _model_name = "meta-llama/Llama-3-8b-hf"
37
+ _model_name = model_name_mapping(model_name)
38
 
39
  if together_api_key and len(together_api_key) == 64:
40
  api_key = together_api_key
 
47
  top_p=top_p,
48
  repetition_penalty=rp,
49
  stop=STOP_STRS, api_key=api_key)
 
50
  response = ""
51
  for msg in request:
52
  # print(msg.choices[0].delta.keys())
 
65
  response = response[:-2]
66
  yield response
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
69
  with gr.Row():
70
  with gr.Column():
71
+ gr.Markdown(HEADER_MD)
 
 
 
 
72
  model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1",
73
  "Mixtral-8x22B", "Yi-6B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"]
74
  , value="Llama-3-8B", label="Base LLM name")
 
78
  with gr.Row():
79
  max_tokens = gr.Textbox(value=256, label="Max tokens")
80
  temperature = gr.Textbox(value=0.5, label="Temperature")
 
 
81
  top_p = gr.Textbox(value=0.9, label="Top-p")
82
  rp = gr.Textbox(value=1.1, label="Repetition penalty")
 
 
83
  chat = gr.ChatInterface(
84
  respond,
85
  additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
constant.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Markdown shown at the top of the demo page (rendered via gr.Markdown).
HEADER_MD = """# 💬 BaseChat: Chat with Base LLMs with URIAL
[Paper](https://arxiv.org/abs/2312.01552) | [Website](https://allenai.github.io/re-align/) | [GitHub](https://github.com/Re-Align/urial) | Contact: [Yuchen Lin](https://yuchenlin.xyz/)

**Talk with __BASE__ LLMs which are not fine-tuned at all.**
"""

# JavaScript handed to gr.Blocks(js=...). It adds a "View your keys here."
# link (plus a short note) under the element whose id is 'api_key'.
# The function bails out early when that element is absent instead of
# throwing a TypeError on the querySelector call.
js_code_label = """
function addApiKeyLink() {
    // Select the div with id 'api_key'
    const apiKeyDiv = document.getElementById('api_key');
    if (!apiKeyDiv) {
        console.error("Element with id 'api_key' not found");
        return;
    }

    // Find the span within that div with data-testid 'block-info'
    const blockInfoSpan = apiKeyDiv.querySelector('span[data-testid="block-info"]');

    // Create the new link element
    const newLink = document.createElement('a');
    newLink.href = 'https://api.together.ai/settings/api-keys';
    newLink.textContent = ' View your keys here.';
    newLink.target = '_blank'; // Open link in new tab
    newLink.style = 'color: #007bff; text-decoration: underline;';

    // Create the additional text
    const additionalText = document.createTextNode(' (new account will have free credits to use.)');

    // Append a line break, the link and the additional text to the api_key div
    // (only when the block-info span exists, i.e. the textbox has rendered)
    if (blockInfoSpan) {
        // add a br
        apiKeyDiv.appendChild(document.createElement('br'));
        apiKeyDiv.appendChild(newLink);
        apiKeyDiv.appendChild(additionalText);
    } else {
        console.error('Span with data-testid "block-info" not found');
    }
}
"""
utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from openai import OpenAI
3
+ import logging
4
+ from typing import List
5
+ import os
6
+
7
+ BASE_URL = "https://api.together.xyz/v1"
8
+ DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")
9
+
10
# Display name (as offered in the UI) -> model identifier sent to the API.
_MODEL_NAME_TO_ID = {
    "Llama-3-8B": "meta-llama/Llama-3-8b-hf",
    "Llama-3-70B": "meta-llama/Llama-3-70b-hf",
    "Llama-2-7B": "meta-llama/Llama-2-7b-hf",
    "Llama-2-70B": "meta-llama/Llama-2-70b-hf",
    "Mistral-7B-v0.1": "mistralai/Mistral-7B-v0.1",
    "Mixtral-8x22B": "mistralai/Mixtral-8x22B",
    "Qwen1.5-72B": "Qwen/Qwen1.5-72B",
    "Yi-34B": "zero-one-ai/Yi-34B",
    "Yi-6B": "zero-one-ai/Yi-6B",
    "OLMO": "allenai/OLMo-7B",
}


def model_name_mapping(model_name):
    """Translate a short display name into the full model identifier.

    Args:
        model_name: One of the keys of ``_MODEL_NAME_TO_ID``
            (for example ``"Llama-3-8B"``). Matching is exact and
            case-sensitive.

    Returns:
        The corresponding identifier string, e.g.
        ``"meta-llama/Llama-3-8b-hf"``.

    Raises:
        ValueError: If ``model_name`` is not a known display name.
    """
    try:
        return _MODEL_NAME_TO_ID[model_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable input (e.g. a list); it is reported the
        # same way as any other unknown name so callers only see ValueError.
        raise ValueError(f"Invalid model name: {model_name!r}") from None
34
+
35
+
36
def urial_template(urial_prompt, history, message):
    """Build the full completion prompt for one chat turn.

    The result is the URIAL prefix, then every past exchange, then the new
    query followed by an opened (unterminated) answer section for the model
    to continue. Queries and answers are each wrapped in lines consisting of
    three double-quote characters.

    Args:
        urial_prompt: The prompt prefix text; a newline is appended to it.
        history: Iterable of ``(user_msg, ai_msg)`` pairs for earlier turns.
        message: The new user query.

    Returns:
        The assembled prompt string.
    """
    # Collect the pieces and join once rather than growing a string with
    # ``+=`` on every history turn.
    pieces = [urial_prompt, "\n"]
    pieces.extend(
        f'# Query:\n"""\n{user_msg}\n"""\n\n# Answer:\n"""\n{ai_msg}\n"""\n\n'
        for user_msg, ai_msg in history
    )
    # The final answer block is left open on purpose: the model completes it.
    pieces.append(f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n')
    return "".join(pieces)
42
+
43
+
44
def openai_base_request(
    model: str=None,
    temperature: float=0,
    max_tokens: int=512,
    top_p: float=1.0,
    prompt: str=None,
    n: int=1,
    repetition_penalty: float=1.0,
    stop: List[str]=None,
    api_key: str=None,
):
    """Start a streaming text-completion request against ``BASE_URL``.

    Args:
        model: Model identifier passed to the completions endpoint.
        temperature: Sampling temperature; coerced with ``float()``.
        max_tokens: Completion length limit; coerced with ``int()``.
        top_p: Nucleus-sampling value; coerced with ``float()``.
        prompt: The raw prompt text to complete.
        n: Number of completions requested.
        repetition_penalty: Coerced with ``float()`` and sent through
            ``extra_body`` rather than as a regular keyword argument.
        stop: Stop strings forwarded to the endpoint.
        api_key: Key used for this request. ``None`` or an empty string
            falls back to ``DEFAULT_API_KEY``.

    Returns:
        Whatever ``client.completions.create(..., stream=True)`` returns;
        the caller iterates over it to receive the streamed chunks.
    """
    # Treat an empty key like a missing one so the default is used instead
    # of sending an empty credential.
    if not api_key:
        api_key = DEFAULT_API_KEY
    client = OpenAI(api_key=api_key, base_url=BASE_URL)

    # Lazy %-style arguments: the text is only formatted if INFO is enabled.
    logging.info("Requesting chat completion from OpenAI API with model %s", model)
    logging.info("Prompt: %s", prompt)
    logging.info("Temperature: %s", temperature)
    logging.info("Max tokens: %s", max_tokens)
    logging.info("Top-p: %s", top_p)
    logging.info("Repetition penalty: %s", repetition_penalty)
    logging.info("Stop: %s", stop)

    # The numeric settings are coerced explicitly, so string values
    # (for instance text typed into a UI textbox) are accepted as well.
    request = client.completions.create(
        model=model,
        prompt=prompt,
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        top_p=float(top_p),
        n=n,
        extra_body={'repetition_penalty': float(repetition_penalty)},
        stop=stop,
        stream=True,
    )

    return request
80
+