yuantao-infini-ai commited on
Commit
72e6273
·
1 Parent(s): 40b01dd

Init commit

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +289 -0
  3. requirements.txt +1 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐠
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.9.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 5.3.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ # File: app.py
3
+ # Description: None
4
+
5
+
6
+ from copy import deepcopy
7
+ from typing import Dict, List
8
+ from PIL import Image
9
+ import io
10
+ import subprocess
11
+ import requests
12
+ import json
13
+ import base64
14
+ import gradio as gr
15
+ import librosa
16
+
17
+
18
+ IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
19
+ VIDEO_EXTENSIONS = (".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v")
20
+ AUDIO_EXTENSIONS = (".mp3", ".wav", "flac", ".m4a")
21
+
22
+ DEFAULT_SAMPLING_PARAMS = {
23
+ "top_p": 0.8,
24
+ "top_k": 100,
25
+ "temperature": 0.7,
26
+ "do_sample": True,
27
+ "num_beams": 1,
28
+ "repetition_penalty": 1.2,
29
+ }
30
+ MAX_NEW_TOKENS = 1024
31
+
32
+
33
+
34
+ def load_image_to_base64(image_path):
35
+ """Load image and convert to base64 string"""
36
+ with Image.open(image_path) as img:
37
+ if img.mode != 'RGB':
38
+ img = img.convert('RGB')
39
+ img_byte_arr = io.BytesIO()
40
+ img.save(img_byte_arr, format='PNG')
41
+ img_byte_arr = img_byte_arr.getvalue()
42
+ return base64.b64encode(img_byte_arr).decode('utf-8')
43
+
44
+ def wav_to_bytes_with_ffmpeg(wav_file_path):
45
+ process = subprocess.Popen(
46
+ ['ffmpeg', '-i', wav_file_path, '-f', 'wav', '-'],
47
+ stdout=subprocess.PIPE,
48
+ stderr=subprocess.PIPE
49
+ )
50
+ out, _ = process.communicate()
51
+ return base64.b64encode(out).decode('utf-8')
52
+
53
+ def parse_sse_response(response):
54
+ for line in response.iter_lines():
55
+ if line:
56
+ line = line.decode('utf-8')
57
+ if line.startswith('data: '):
58
+ data = line[6:] # Remove 'data: ' prefix
59
+ if data == '[DONE]':
60
+ break
61
+ try:
62
+ json_data = json.loads(data)
63
+ yield json_data['text']
64
+ except json.JSONDecodeError:
65
+ raise gr.Error(f"Failed to parse JSON: {data}")
66
+
67
+ def history2messages(history: List[Dict]) -> List[Dict]:
68
+ """
69
+ Transform gradio history to chat messages.
70
+ """
71
+ messages = []
72
+ cur_message = dict()
73
+ for item in history:
74
+ if item["role"] == "assistant":
75
+ if len(cur_message) > 0:
76
+ messages.append(deepcopy(cur_message))
77
+ cur_message = dict()
78
+ messages.append(deepcopy(item))
79
+ continue
80
+
81
+ if "role" not in cur_message:
82
+ cur_message["role"] = "user"
83
+ if "content" not in cur_message:
84
+ cur_message["content"] = dict()
85
+
86
+ if "metadata" not in item:
87
+ item["metadata"] = {"title": None}
88
+ if item["metadata"]["title"] is None:
89
+ cur_message["content"]["text"] = item["content"]
90
+ elif item["metadata"]["title"] == "image":
91
+ cur_message["content"]["image"] = load_image_to_base64(item["content"][0])
92
+ elif item["metadata"]["title"] == "audio":
93
+ cur_message["content"]["audio"] = wav_to_bytes_with_ffmpeg(item["content"][0])
94
+ if len(cur_message) > 0:
95
+ messages.append(cur_message)
96
+ return messages
97
+
98
+ def check_messages(history, message, audio):
99
+ has_text = message["text"] and message["text"].strip()
100
+ has_files = len(message["files"]) > 0
101
+ has_audio = audio is not None
102
+
103
+ if not (has_text or has_files or has_audio):
104
+ raise gr.Error("请输入文字或上传音频/图片后再发送。")
105
+
106
+ audios = []
107
+ images = []
108
+
109
+ for file_msg in message["files"]:
110
+ if file_msg.endswith(AUDIO_EXTENSIONS) or file_msg.endswith(VIDEO_EXTENSIONS):
111
+ duration = librosa.get_duration(filename=file_msg)
112
+ if duration > 30:
113
+ raise gr.Error("音频时长不能超过30秒。")
114
+ if duration == 0:
115
+ raise gr.Error("音频时长不能为0秒。")
116
+ audios.append(file_msg)
117
+ elif file_msg.endswith(IMAGE_EXTENSIONS):
118
+ images.append(file_msg)
119
+ else:
120
+ filename = file_msg.split("/")[-1]
121
+ raise gr.Error(f"Unsupported file type: {filename}. It should be an image or audio file.")
122
+
123
+ if len(audios) > 1:
124
+ raise gr.Error("Please upload only one audio file.")
125
+
126
+ if len(images) > 1:
127
+ raise gr.Error("Please upload only one image file.")
128
+
129
+ if audio is not None:
130
+ if len(audios) > 0:
131
+ raise gr.Error("Please upload only one audio file or record audio.")
132
+ audios.append(audio)
133
+
134
+ # Append the message to the history
135
+ for image in images:
136
+ history.append({"role": "user", "content": (image,), "metadata": {"title": "image"}})
137
+
138
+ for audio in audios:
139
+ history.append({"role": "user", "content": (audio,), "metadata": {"title": "audio"}})
140
+
141
+ if message["text"]:
142
+ history.append({"role": "user", "content": message["text"]})
143
+
144
+ return history, gr.MultimodalTextbox(value=None, interactive=False), None
145
+
146
+ def bot(
147
+ history: list,
148
+ top_p: float,
149
+ top_k: int,
150
+ temperature: float,
151
+ repetition_penalty: float,
152
+ max_new_tokens: int = MAX_NEW_TOKENS,
153
+ regenerate: bool = False,
154
+ ):
155
+
156
+ if history and regenerate:
157
+ history = history[:-1]
158
+
159
+ if not history:
160
+ return history
161
+
162
+ msgs = history2messages(history)
163
+
164
+ API_URL = "http://8.152.0.142:8000/v1/chat"
165
+
166
+ payload = {
167
+ "messages": msgs,
168
+ "sampling_params": {
169
+ "top_p": top_p,
170
+ "top_k": top_k,
171
+ "temperature": temperature,
172
+ "repetition_penalty": repetition_penalty,
173
+ "max_new_tokens": max_new_tokens
174
+ }
175
+ }
176
+
177
+ response = requests.post(
178
+ API_URL,
179
+ json=payload,
180
+ headers={'Accept': 'text/event-stream'},
181
+ stream=True
182
+ )
183
+
184
+ response_text = ""
185
+ for text in parse_sse_response(response):
186
+ response_text += text
187
+ yield history + [{"role": "assistant", "content": response_text}]
188
+
189
+ return response_text
190
+
191
+ def change_state(state):
192
+ return gr.update(visible=not state), not state
193
+
194
+ def reset_user_input():
195
+ return gr.update(value="")
196
+
197
+ if __name__ == "__main__":
198
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
199
+ gr.Markdown(
200
+ f"""
201
+ # 🪐 Chat with <a href="https://github.com/infinigence/Infini-Megrez-Omni">Megrez-3B-Omni</a>
202
+ """
203
+ )
204
+ chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", height='48vh')
205
+
206
+ sampling_params_group_hidden_state = gr.State(False)
207
+
208
+
209
+ with gr.Row(equal_height=True):
210
+ chat_input = gr.MultimodalTextbox(
211
+ file_count="multiple",
212
+ placeholder="Enter your prompt or upload image/audio here, then press ENTER...",
213
+ show_label=False,
214
+ scale=8,
215
+ file_types=["image", "audio"],
216
+ interactive=True,
217
+ # stop_btn=True,
218
+ )
219
+ with gr.Row(equal_height=True):
220
+ audio_input = gr.Audio(
221
+ sources=["microphone", "upload"],
222
+ type="filepath",
223
+ scale=1,
224
+ max_length=30
225
+ )
226
+ with gr.Row(equal_height=True):
227
+ with gr.Column(scale=1, min_width=150):
228
+ with gr.Row(equal_height=True):
229
+ regenerate_btn = gr.Button("Regenerate", variant="primary")
230
+ clear_btn = gr.ClearButton(
231
+ [chat_input, audio_input, chatbot],
232
+ )
233
+
234
+ with gr.Row():
235
+ sampling_params_toggle_btn = gr.Button("Sampling Parameters")
236
+
237
+ with gr.Group(visible=False) as sampling_params_group:
238
+ with gr.Row():
239
+ temperature = gr.Slider(
240
+ minimum=0, maximum=1.2, value=DEFAULT_SAMPLING_PARAMS["temperature"], label="Temperature"
241
+ )
242
+ repetition_penalty = gr.Slider(
243
+ minimum=0,
244
+ maximum=2,
245
+ value=DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
246
+ label="Repetition Penalty",
247
+ )
248
+
249
+ with gr.Row():
250
+ top_p = gr.Slider(minimum=0, maximum=1, value=DEFAULT_SAMPLING_PARAMS["top_p"], label="Top-p")
251
+ top_k = gr.Slider(minimum=0, maximum=1000, value=DEFAULT_SAMPLING_PARAMS["top_k"], label="Top-k")
252
+
253
+ with gr.Row():
254
+ max_new_tokens = gr.Slider(
255
+ minimum=1,
256
+ maximum=MAX_NEW_TOKENS,
257
+ value=MAX_NEW_TOKENS,
258
+ label="Max New Tokens",
259
+ interactive=True,
260
+ )
261
+
262
+ sampling_params_toggle_btn.click(
263
+ change_state,
264
+ sampling_params_group_hidden_state,
265
+ [sampling_params_group, sampling_params_group_hidden_state],
266
+ )
267
+
268
+ chat_msg = chat_input.submit(
269
+ check_messages,
270
+ [chatbot, chat_input, audio_input],
271
+ [chatbot, chat_input, audio_input],
272
+ )
273
+
274
+ bot_msg = chat_msg.then(
275
+ bot,
276
+ inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens],
277
+ outputs=chatbot,
278
+ api_name="bot_response",
279
+ )
280
+
281
+ bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
282
+
283
+ regenerate_btn.click(
284
+ bot,
285
+ inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens, gr.State(True)],
286
+ outputs=chatbot,
287
+ )
288
+
289
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ librosa