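"""Gradio demo for T-MoENet: zero-shot video question answering.

Loads the model once via Infer, then serves a chat UI where a user uploads a
video, asks a question, and receives the model's answer.
"""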
import copy
import os
import shutil
import uuid

import gradio as gr
import torch

from Infer import Infer

title_markdown = ("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>Temporal-guided Mixture-of-Experts for Zero-Shot Video Question Answering</h1>
    <h5 style="margin: 0;">Under review.</h5>
  </div>
</div>

<div align="center">
  <div style="display:flex; gap: 0.25rem;" align="center">
    <a href='https://github.com/qyx1121/T-MoENet'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
  </div>
</div>
""")

block_css = """
#buttons button {
    min-width: min(120px, 100%);
}
"""


def save_video_to_local(video_path):
    """Copy an uploaded video into ./temp under a unique name and return the new path.

    Retained from the original demo; not referenced by the UI below. uuid4 replaces
    the private tempfile._get_candidate_names(), which is not a stable API.
    """
    filename = os.path.join('temp', uuid.uuid4().hex + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename


def generate(video, textbox_in, first_run, state, state_):
    flag = 1
    if not textbox_in:
        if state_ is not None and len(state_.messages) > 0:
            # Empty textbox: regenerate from the most recent instruction.
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            flag = 0
        else:
            # Raising gr.Error shows a popup instead of returning a malformed tuple.
            raise gr.Error("Please enter an instruction.")
    # "none" is a sentinel path that os.path.exists() rejects when no video is given.
    video = video if video else "none"

    first_run = state is None or len(state.messages) == 0

    text_en_in = textbox_in.replace("picture", "image")

    # Parallel lists: images_tensor[0] holds visual tensors, [1] their modality tags.
    images_tensor = [[], []]
    dtype = torch.float16  # assumed half precision; match the dtype Infer loads with

    if os.path.exists(video):
        video_processor = handler.video_processor
        tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
        tensor = tensor.to(handler.model.device, dtype=dtype)
        images_tensor[0].append(tensor)
        images_tensor[1].append('video')

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)

    if state is None:
        # First turn: clone the conversation object Infer just built so the display
        # history shares its roles/format (assumes state_ is deep-copyable).
        state = copy.deepcopy(state_)
        state.messages = []

    # Keep only the answer text before the first '#' delimiter.
    text_en_out = text_en_out.split('#')[0]
    textbox_out = text_en_out

    if flag:
        state.append_message(state.roles[0], textbox_in)
    state.append_message(state.roles[1], textbox_out)
    torch.cuda.empty_cache()

    # One value per output component: state, state_, chatbot, first_run, textbox, video.
    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True),
            gr.update(value=video if os.path.exists(video) else None, interactive=True))


# Load the model once at startup; every request shares this handler.
device = "cuda" if torch.cuda.is_available() else "cpu"
handler = Infer(device)

os.makedirs("temp", exist_ok=True)

textbox = gr.Textbox(
    show_label=False, placeholder="Enter text and press ENTER", container=False
)
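
# UI: title banner on top; left column = video upload + example questions,
# right column = chat history with a send box.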
with gr.Blocks(title='T-MoENet', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()

    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")
            cur_dir = os.path.dirname(os.path.abspath(__file__))
            gr.Examples(
                examples=[
                    [
                        cur_dir + "/videos/3249402410.mp4",
                        "what did the lady in black on the left do after she finished spreading the sauce on her pizza?",
                    ],
                    [
                        cur_dir + "/videos/4882821564.mp4",
                        "why did the boy clap his hands when he ran to the christmas tree?",
                    ],
                    [
                        cur_dir + "/videos/6233408665.mp4",
                        "what did the people on the sofa do after the lady in pink finished singing?",
                    ],
                ],
                inputs=[video, textbox],
            )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="T-MoENet", bubble_full_width=True)
            with gr.Row():
                with gr.Column(scale=2):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(
                        value="Send", variant="primary", interactive=True
                    )

    # Outputs line up one-to-one with the tuple returned by generate().
    submit_btn.click(
        generate,
        [video, textbox, first_run, state, state_],
        [state, state_, chatbot, first_run, textbox, video],
    )

demo.launch(share=True)
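
# share=True asks Gradio for a temporary public gradio.live link in addition to
# the local server; set share=False to serve on localhost only.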