ClownRat commited on
Commit
000c55e
Β·
1 Parent(s): b76f6c0
app.py CHANGED
@@ -19,7 +19,7 @@ from videollama2.mm_utils import KeywordsStoppingCriteria, tokenizer_MMODAL_toke
19
  title_markdown = ("""
20
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
21
  <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA2" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
22
- <img src="https://s2.loli.net/2024/06/03/D3NeXHWy5az9tmT.png" alt="VideoLLaMA2πŸš€" style="max-width: 120px; height: auto;">
23
  </a>
24
  <div>
25
  <h1 >VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs</h1>
@@ -89,9 +89,8 @@ class Chat:
89
  # 2. text preprocess (tag process & generate prompt).
90
  state = self.get_prompt(prompt, state)
91
  prompt = state.get_prompt()
92
- # print('\n\n\n')
93
- # print(prompt)
94
- input_ids = tokenizer_MMODAL_token(prompt, tokenizer, MMODAL_TOKEN_INDEX[modals[0]], return_tensors='pt').unsqueeze(0).to(self.model.device)
95
 
96
  # 3. generate response according to visual signals and prompts.
97
  stop_str = self.conv.sep if self.conv.sep_style in [SeparatorStyle.SINGLE] else self.conv.sep2
@@ -116,19 +115,6 @@ class Chat:
116
  return outputs, state
117
 
118
 
119
- def save_image_to_local(image):
120
- filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
121
- image = Image.open(image)
122
- image.save(filename)
123
- return filename
124
-
125
-
126
- def save_video_to_local(video_path):
127
- filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
128
- shutil.copyfile(video_path, filename)
129
- return filename
130
-
131
-
132
  @spaces.GPU(duration=120)
133
  def generate(image, video, first_run, state, state_, textbox_in, dtype=torch.float16):
134
  flag = 1
@@ -180,14 +166,10 @@ def generate(image, video, first_run, state, state_, textbox_in, dtype=torch.flo
180
  text_en_out = text_en_out.split('#')[0]
181
  textbox_out = text_en_out
182
 
183
- print(image, video)
184
-
185
  show_images = ""
186
  if os.path.exists(image):
187
- # filename = save_image_to_local(image)
188
  show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
189
  if os.path.exists(video):
190
- # filename = save_video_to_local(video)
191
  show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={video}"></video>'
192
 
193
  if flag:
@@ -215,58 +197,30 @@ def clear_history(state, state_):
215
  state.to_gradio_chatbot(), \
216
  True, state, state_, gr.update(value=None, interactive=True))
217
 
 
 
 
 
218
 
219
  conv_mode = "llama_2"
220
  model_path = 'DAMO-NLP-SG/VideoLLaMA2-7B'
221
 
222
- def find_cuda():
223
- # Check if CUDA_HOME or CUDA_PATH environment variables are set
224
- cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
225
-
226
- if cuda_home and os.path.exists(cuda_home):
227
- return cuda_home
228
-
229
- # Search for the nvcc executable in the system's PATH
230
- nvcc_path = shutil.which('nvcc')
231
-
232
- if nvcc_path:
233
- # Remove the 'bin/nvcc' part to get the CUDA installation path
234
- cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
235
- return cuda_path
236
-
237
- return None
238
-
239
- cuda_path = find_cuda()
240
-
241
- if cuda_path:
242
- print(f"CUDA installation found at: {cuda_path}")
243
- else:
244
- print("CUDA installation not found")
245
-
246
  device = torch.device("cuda")
247
 
248
  handler = Chat(model_path, conv_mode=conv_mode, load_8bit=False, load_4bit=True)
249
- # handler.model.to(dtype=torch.float16)
250
- # handler = handler.model.to(device)
251
 
252
- if not os.path.exists("temp"):
253
- os.makedirs("temp")
254
 
255
- textbox = gr.Textbox(
256
- show_label=False, placeholder="Enter text and press ENTER", container=False
257
- )
258
- with gr.Blocks(title='VideoLLaMA2πŸš€', theme=gr.themes.Default(), css=block_css) as demo:
259
  gr.Markdown(title_markdown)
260
  state = gr.State()
261
  state_ = gr.State()
262
  first_run = gr.State()
263
- # tensor = gr.State()
264
- # modals = gr.State()
265
 
266
  with gr.Row():
267
  with gr.Column(scale=3):
268
  image = gr.Image(label="Input Image", type="filepath")
269
- video = gr.Video(label="Input Video")
270
 
271
  cur_dir = os.path.dirname(os.path.abspath(__file__))
272
  gr.Examples(
@@ -288,19 +242,19 @@ with gr.Blocks(title='VideoLLaMA2πŸš€', theme=gr.themes.Default(), css=block_css
288
  )
289
 
290
  with gr.Column(scale=7):
291
- chatbot = gr.Chatbot(label="VideoLLaMA2", bubble_full_width=True, height=750)
292
  with gr.Row():
293
  with gr.Column(scale=8):
294
  textbox.render()
295
  with gr.Column(scale=1, min_width=50):
296
  submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
297
  with gr.Row(elem_id="buttons") as button_row:
298
- upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=True)
299
- downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=True)
300
- # flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
301
- # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
302
  regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=True)
303
- clear_btn = gr.Button(value="πŸ—‘οΈ Clear history", interactive=True)
304
 
305
  gr.Markdown(tos_markdown)
306
  gr.Markdown(learn_more_markdown)
@@ -308,9 +262,7 @@ with gr.Blocks(title='VideoLLaMA2πŸš€', theme=gr.themes.Default(), css=block_css
308
  submit_btn.click(
309
  generate,
310
  [image, video, first_run, state, state_, textbox],
311
- [image, video, chatbot, first_run, state, state_, textbox,
312
- # tensor, modals
313
- ])
314
 
315
  regenerate_btn.click(
316
  regenerate,
 
19
  title_markdown = ("""
20
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
21
  <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA2" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
22
+ <img src="https://s2.loli.net/2024/06/03/D3NeXHWy5az9tmT.png" alt="VideoLLaMA 2 πŸ”₯πŸš€πŸ”₯" style="max-width: 120px; height: auto;">
23
  </a>
24
  <div>
25
  <h1 >VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs</h1>
 
89
  # 2. text preprocess (tag process & generate prompt).
90
  state = self.get_prompt(prompt, state)
91
  prompt = state.get_prompt()
92
+ input_ids = tokenizer_MMODAL_token(prompt, tokenizer, MMODAL_TOKEN_INDEX[modals[0]], return_tensors='pt')
93
+ input_ids = input_ids.unsqueeze(0).to(self.model.device)
 
94
 
95
  # 3. generate response according to visual signals and prompts.
96
  stop_str = self.conv.sep if self.conv.sep_style in [SeparatorStyle.SINGLE] else self.conv.sep2
 
115
  return outputs, state
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  @spaces.GPU(duration=120)
119
  def generate(image, video, first_run, state, state_, textbox_in, dtype=torch.float16):
120
  flag = 1
 
166
  text_en_out = text_en_out.split('#')[0]
167
  textbox_out = text_en_out
168
 
 
 
169
  show_images = ""
170
  if os.path.exists(image):
 
171
  show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
172
  if os.path.exists(video):
 
173
  show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={video}"></video>'
174
 
175
  if flag:
 
197
  state.to_gradio_chatbot(), \
198
  True, state, state_, gr.update(value=None, interactive=True))
199
 
200
+ # BUG of Zero Environment
201
+ # 1. The environment is fixed to torch==2.0.1+cu117, gradio>=4.x.x
202
+ # 2. The operation or tensor which requires cuda are limited in those functions wrapped via spaces.GPU
203
+ # 3. The function can't return tensor or other cuda objects.
204
 
205
  conv_mode = "llama_2"
206
  model_path = 'DAMO-NLP-SG/VideoLLaMA2-7B'
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  device = torch.device("cuda")
209
 
210
  handler = Chat(model_path, conv_mode=conv_mode, load_8bit=False, load_4bit=True)
 
 
211
 
212
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
 
213
 
214
+ with gr.Blocks(title='VideoLLaMA 2 πŸ”₯πŸš€πŸ”₯', theme=gr.themes.Default(), css=block_css) as demo:
 
 
 
215
  gr.Markdown(title_markdown)
216
  state = gr.State()
217
  state_ = gr.State()
218
  first_run = gr.State()
 
 
219
 
220
  with gr.Row():
221
  with gr.Column(scale=3):
222
  image = gr.Image(label="Input Image", type="filepath")
223
+ video = gr.Video(label="Input Video", type="filepath")
224
 
225
  cur_dir = os.path.dirname(os.path.abspath(__file__))
226
  gr.Examples(
 
242
  )
243
 
244
  with gr.Column(scale=7):
245
+ chatbot = gr.Chatbot(label="VideoLLaMA 2", bubble_full_width=True, height=750)
246
  with gr.Row():
247
  with gr.Column(scale=8):
248
  textbox.render()
249
  with gr.Column(scale=1, min_width=50):
250
  submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
251
  with gr.Row(elem_id="buttons") as button_row:
252
+ upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=True)
253
+ downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=True)
254
+ # flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
255
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
256
  regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=True)
257
+ clear_btn = gr.Button(value="πŸ—‘οΈ Clear history", interactive=True)
258
 
259
  gr.Markdown(tos_markdown)
260
  gr.Markdown(learn_more_markdown)
 
262
  submit_btn.click(
263
  generate,
264
  [image, video, first_run, state, state_, textbox],
265
+ [image, video, chatbot, first_run, state, state_, textbox])
 
 
266
 
267
  regenerate_btn.click(
268
  regenerate,
videollama2/model/multimodal_projector/builder.py CHANGED
@@ -20,7 +20,7 @@ import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
  from timm.models.regnet import RegStage
23
- from timm.models.layers import LayerNorm, LayerNorm2d
24
  from transformers import TRANSFORMERS_CACHE
25
 
26
 
 
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
  from timm.models.regnet import RegStage
23
+ from timm.models.layers import LayerNorm2d
24
  from transformers import TRANSFORMERS_CACHE
25
 
26