Files changed (1)
  1. app.py +95 -43
app.py CHANGED
@@ -2,13 +2,12 @@ import spaces
 
 import os
 import re
-import traceback
 
 import torch
 import gradio as gr
 
 import sys
-sys.path.append('./VideoLLaMA2')
+sys.path.append('./')
 from videollama2 import model_init, mm_infer
 from videollama2.utils import disable_torch_init
 
@@ -98,7 +97,7 @@ class Chat:
 
 
 @spaces.GPU(duration=120)
-def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
+def generate(image, video, audio, message, chatbot, va_tag, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
     data = []
 
     processor = handler.processor
@@ -106,7 +105,15 @@ def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max
     if image is not None:
         data.append((processor['image'](image).to(handler.model.device, dtype=dtype), '<image>'))
     elif video is not None:
-        data.append((processor['video'](video).to(handler.model.device, dtype=dtype), '<video>'))
+        video_audio = processor['video'](video, va=va_tag == "Audio Vision")
+        if va_tag == "Audio Vision":
+            for k, v in video_audio.items():
+                video_audio[k] = v.to(handler.model.device, dtype=dtype)
+        else:
+            video_audio = video_audio.to(handler.model.device, dtype=dtype)
+        data.append((video_audio, '<video>'))
+    elif audio is not None:
+        data.append((processor['audio'](audio).to(handler.model.device, dtype=dtype), '<audio>'))
     elif image is None and video is None:
         data.append((None, '<text>'))
     else:
@@ -122,6 +129,8 @@ def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max
         show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
     if video is not None:
         show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={video}"></video>'
+    if audio is not None:
+        show_images += f'<audio controls style="display: inline-block;" src="./file={audio}"></audio>'
 
     one_turn_chat = [textbox_in, None]
 
@@ -130,35 +139,50 @@ def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max
         one_turn_chat[0] += "\n" + show_images
     # 2. not first run case
     else:
-        # scanning the last image or video
-        length = len(chatbot)
-        for i in range(length - 1, -1, -1):
-            previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[i][0])
-            previous_video = re.findall(r'<video controls playsinline width="500" style="display: inline-block;" src="./file=(.+?)"', chatbot[i][0])
-
-            if len(previous_image) > 0:
-                previous_image = previous_image[-1]
-                # 2.1 new image append or pure text input will start a new conversation
-                if (video is not None) or (image is not None and os.path.basename(previous_image) != os.path.basename(image)):
-                    message.clear()
-                    one_turn_chat[0] += "\n" + show_images
-                break
-            elif len(previous_video) > 0:
-                previous_video = previous_video[-1]
-                # 2.2 new video append or pure text input will start a new conversation
-                if image is not None or (video is not None and os.path.basename(previous_video) != os.path.basename(video)):
-                    message.clear()
-                    one_turn_chat[0] += "\n" + show_images
-                break
+        previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[0][0])
+        previous_video = re.findall(r'<video controls playsinline width="500" style="display: inline-block;" src="./file=(.+?)"', chatbot[0][0])
+        previous_audio = re.findall(r'<audio controls style="display: inline-block;" src="./file=(.+?)"', chatbot[0][0])
+        if len(previous_image) > 0:
+            previous_image = previous_image[0]
+            # 2.1 a new image or pure text input starts a new conversation
+            if image is not None and os.path.basename(previous_image) != os.path.basename(image):
+                message.clear()
+                one_turn_chat[0] += "\n" + show_images
+        elif len(previous_video) > 0:
+            previous_video = previous_video[0]
+            # 2.2 a new video or pure text input starts a new conversation
+            if video is not None and os.path.basename(previous_video) != os.path.basename(video):
+                message.clear()
+                one_turn_chat[0] += "\n" + show_images
+        elif len(previous_audio) > 0:
+            previous_audio = previous_audio[0]
+            # 2.3 a new audio clip or pure text input starts a new conversation
+            if audio is not None and os.path.basename(previous_audio) != os.path.basename(audio):
+                message.clear()
+                one_turn_chat[0] += "\n" + show_images
 
     message.append({'role': 'user', 'content': textbox_in})
+
+    if va_tag == "Vision Only":
+        audio_tower = handler.model.model.audio_tower
+        handler.model.model.audio_tower = None
+    elif va_tag == "Audio Only":
+        vision_tower = handler.model.model.vision_tower
+        handler.model.model.vision_tower = None
+
     text_en_out = handler.generate(data, message, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens)
+
+    if va_tag == "Vision Only":
+        handler.model.model.audio_tower = audio_tower
+    elif va_tag == "Audio Only":
+        handler.model.model.vision_tower = vision_tower
+
     message.append({'role': 'assistant', 'content': text_en_out})
 
     one_turn_chat[1] = text_en_out
     chatbot.append(one_turn_chat)
 
-    return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), message, chatbot
+    return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), gr.update(value=audio, interactive=True), message, chatbot
 
 
 def regenerate(message, chatbot):
@@ -170,6 +194,7 @@ def regenerate(message, chatbot):
 def clear_history(message, chatbot):
     message.clear(), chatbot.clear()
     return (gr.update(value=None, interactive=True),
+            gr.update(value=None, interactive=True),
             gr.update(value=None, interactive=True),
             message, chatbot,
             gr.update(value=None, interactive=True))
@@ -180,9 +205,9 @@ def clear_history(message, chatbot):
 # 2. Operations or tensors that require CUDA must stay inside functions wrapped with spaces.GPU
 # 3. The function can't return tensors or other CUDA objects.
 
-model_path = 'DAMO-NLP-SG/VideoLLaMA2.1-7B-16F'
+model_path = 'DAMO-NLP-SG/VideoLLaMA2.1-7B-AV'
 
-handler = Chat(model_path, load_8bit=False, load_4bit=True)
+handler = Chat(model_path, load_8bit=False, load_4bit=False)
 
 textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
 
@@ -194,6 +219,7 @@ theme.set(block_label_text_color="#9C276A")
 theme.set(button_primary_text_color="#9C276A")
 # theme.set(button_secondary_text_color="*neutral_800")
 
+
 with gr.Blocks(title='VideoLLaMA 2 🔥🚀🔥', theme=theme, css=block_css) as demo:
     gr.Markdown(title_markdown)
     message = gr.State([])
@@ -202,6 +228,7 @@ with gr.Blocks(title='VideoLLaMA 2 🔥🚀🔥', theme=theme, css=block_css) as
     with gr.Column(scale=3):
         image = gr.Image(label="Input Image", type="filepath")
         video = gr.Video(label="Input Video")
+        audio = gr.Audio(label="Input Audio", type="filepath")
 
         with gr.Accordion("Parameters", open=True) as parameter_row:
             # num_beams = gr.Slider(
@@ -213,6 +240,8 @@ with gr.Blocks(title='VideoLLaMA 2 🔥🚀🔥', theme=theme, css=block_css) as
             #     label="beam search numbers",
             # )
 
+            va_tag = gr.Radio(choices=["Audio Vision", "Vision Only", "Audio Only"], value="Audio Vision", label="Select one")
+
             temperature = gr.Slider(
                 minimum=0.1,
                 maximum=1.0,
@@ -256,8 +285,9 @@ with gr.Blocks(title='VideoLLaMA 2 🔥🚀🔥', theme=theme, css=block_css) as
         clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
 
     with gr.Row():
+        cur_dir = os.path.dirname(os.path.abspath(__file__))
+
         with gr.Column():
-            cur_dir = os.path.dirname(os.path.abspath(__file__))
             gr.Examples(
                 examples=[
                     [
@@ -268,51 +298,73 @@ with gr.Blocks(title='VideoLLaMA 2 🔥🚀🔥', theme=theme, css=block_css) as
                         f"{cur_dir}/examples/waterview.jpg",
                         "What are the things I should be cautious about when I visit here?",
                     ],
-                    [
-                        f"{cur_dir}/examples/desert.jpg",
-                        "If there are factual errors in the questions, point it out; if not, proceed answering the question. What's happening in the desert?",
-                    ],
                 ],
                 inputs=[image, textbox],
             )
+
         with gr.Column():
             gr.Examples(
                 examples=[
                     [
-                        f"{cur_dir}/examples/rap.mp4",
-                        "What happens in this video?",
+                        f"{cur_dir}/examples/WBS4I.mp4",
+                        "Please describe the video:",
+                    ],
+                    [
+                        f"{cur_dir}/examples/sample_demo_1.mp4",
+                        "Please describe the video:",
                     ],
+                ],
+                inputs=[video, textbox],
+            )
+        with gr.Column():
+            gr.Examples(
+                examples=[
                     [
-                        f"{cur_dir}/examples/demo2.mp4",
-                        "Do you think it's morning or night in this video? Why?",
+                        f"{cur_dir}/examples/00000368.mp4",
+                        "Where is the loudest instrument?",
                     ],
                     [
-                        f"{cur_dir}/examples/demo3.mp4",
-                        "At the intersection, in which direction does the red car turn?",
+                        f"{cur_dir}/examples/00003491.mp4",
+                        "Is the instrument on the left louder than the instrument on the right?",
                     ],
                 ],
                 inputs=[video, textbox],
             )
+        with gr.Column():
+            # audio
+            gr.Examples(
+                examples=[
+                    [
+                        f"{cur_dir}/examples/Y--ZHUMfueO0.flac",
+                        "Please describe the audio:",
+                    ],
+                    [
+                        f"{cur_dir}/examples/Traffic and pedestrians.wav",
+                        "Please describe the audio:",
+                    ],
+                ],
+                inputs=[audio, textbox],
+            )
 
     gr.Markdown(tos_markdown)
     gr.Markdown(learn_more_markdown)
 
     submit_btn.click(
         generate,
-        [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
-        [image, video, message, chatbot])
+        [image, video, audio, message, chatbot, va_tag, textbox, temperature, top_p, max_output_tokens],
+        [image, video, audio, message, chatbot])
 
     regenerate_btn.click(
         regenerate,
         [message, chatbot],
         [message, chatbot]).then(
         generate,
-        [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
-        [image, video, message, chatbot])
+        [image, video, audio, message, chatbot, va_tag, textbox, temperature, top_p, max_output_tokens],
+        [image, video, audio, message, chatbot])
 
     clear_btn.click(
         clear_history,
         [message, chatbot],
-        [image, video, message, chatbot, textbox])
+        [image, video, audio, message, chatbot, textbox])
 
 demo.launch()
 
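Note on the new video branch: processor['video'] is assumed to return a dict of tensors when audio is decoded alongside the frames (va=True, the "Audio Vision" mode) and a plain tensor otherwise, which is why the branch moves the two shapes to the GPU differently. A minimal sketch of that dispatch under the same assumption; move_to_device is a hypothetical helper, not part of this PR:

    import torch

    def move_to_device(sample, device, dtype=torch.float16):
        # Dict of tensors (video frames plus audio features) in "Audio Vision"
        # mode, a single tensor otherwise -- mirrors the branch in generate().
        if isinstance(sample, dict):
            return {k: v.to(device, dtype=dtype) for k, v in sample.items()}
        return sample.to(device, dtype=dtype)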
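The "Vision Only" / "Audio Only" paths disable a tower by assigning None to handler.model.model.audio_tower or vision_tower and put it back after generation; if handler.generate raises, the tower stays detached for later requests. A context-manager sketch that restores the attribute either way, using only the attribute names visible in the diff (tower_disabled is hypothetical):

    from contextlib import contextmanager

    @contextmanager
    def tower_disabled(model, name):
        # Temporarily null out 'audio_tower' or 'vision_tower', restoring
        # the module even when inference raises.
        saved = getattr(model, name)
        setattr(model, name, None)
        try:
            yield
        finally:
            setattr(model, name, saved)

Usage would be: with tower_disabled(handler.model.model, 'audio_tower'): run generation inside the block.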
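The rewritten new-conversation check recovers the file shown in the first turn by matching the exact HTML snippets that generate() itself writes into the chatbot, then compares basenames. A condensed sketch of that check, reusing the three patterns from the diff (media_changed is a hypothetical name):

    import os
    import re

    PATTERNS = {
        'image': r'<img src="./file=(.+?)"',
        'video': r'<video controls playsinline width="500" style="display: inline-block;" src="./file=(.+?)"',
        'audio': r'<audio controls style="display: inline-block;" src="./file=(.+?)"',
    }

    def media_changed(first_turn_html, kind, new_path):
        # True when the first turn displayed a `kind` file whose basename
        # differs from the newly supplied path, i.e. the history should reset.
        found = re.findall(PATTERNS[kind], first_turn_html)
        if not found or new_path is None:
            return False
        return os.path.basename(found[0]) != os.path.basename(new_path)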
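One wiring detail worth keeping in mind: each callback's return arity must match its outputs list. With audio added, generate now returns five updates for [image, video, audio, message, chatbot], and clear_history returns six values for [image, video, audio, message, chatbot, textbox], which is why one more gr.update(value=None, interactive=True) is added to its return tuple.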