# T-MoENet / app.py
import os
import shutil
import tempfile

import gradio as gr
import torch

from Infer import Infer
title_markdown = ("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>Temporal-guided Mixture-of-Experts for Zero-Shot Video Question Answering</h1>
    <h5 style="margin: 0;">Under review.</h5>
  </div>
</div>
<div align="center">
  <div style="display:flex; gap: 0.25rem;" align="center">
    <a href='https://github.com/qyx1121/T-MoENet'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
  </div>
</div>
""")
block_css = """
#buttons button {
    min-width: min(120px, 100%);
}
"""
def save_video_to_local(video_path):
    # copy the uploaded video into ./temp under a fresh random name
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename
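
# `_get_candidate_names()` is a private CPython helper and may change between
# versions. A sketch of the same behaviour using only the public tempfile API
# (hypothetical alternative, not used elsewhere in this app):
def save_video_to_local_public(video_path):
    # NamedTemporaryFile reserves a unique name; delete=False keeps the file around
    with tempfile.NamedTemporaryFile(dir='temp', suffix='.mp4', delete=False) as f:
        filename = f.name
    shutil.copyfile(video_path, filename)
    return filename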
def generate(video, textbox_in, first_run, state, state_):
    # `state` / `state_` are conversation objects from the Infer pipeline
    # (they expose .messages, .roles, .append_message, .to_gradio_chatbot)
    flag = 1
    if not textbox_in:
        if state_ is not None and len(state_.messages) > 0:
            # empty textbox: regenerate from the last recorded user message
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            flag = 0
        else:
            # keep the outputs aligned with the six components wired to
            # submit_btn.click below; gr.update() leaves a component unchanged
            return (state, state_, gr.update(), first_run,
                    gr.update(value="Please enter instruction", interactive=True),
                    gr.update())

    video = video if video else "none"
    # this Space only takes video input; `image1` is a leftover from the
    # original image+video demo and never points to an existing file here
    image1 = "none"

    first_run = state is None or len(state.messages) == 0
    text_en_in = textbox_in.replace("picture", "image")
    images_tensor = [[], []]
    dtype = torch.float16  # assumed model precision; adjust to match Infer

    image_processor = handler.image_processor
    if os.path.exists(image1) and not os.path.exists(video):
        tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
        tensor = tensor.to(handler.model.device, dtype=dtype)
        images_tensor[0] = images_tensor[0] + [tensor]
        images_tensor[1] = images_tensor[1] + ['image']

    video_processor = handler.video_processor
    if not os.path.exists(image1) and os.path.exists(video):
        tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
        tensor = tensor.to(handler.model.device, dtype=dtype)
        images_tensor[0] = images_tensor[0] + [tensor]
        images_tensor[1] = images_tensor[1] + ['video']

    if os.path.exists(image1) and os.path.exists(video):
        tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
        tensor = tensor.to(handler.model.device, dtype=dtype)
        images_tensor[0] = images_tensor[0] + [tensor]
        images_tensor[1] = images_tensor[1] + ['video']
        tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
        tensor = tensor.to(handler.model.device, dtype=dtype)
        images_tensor[0] = images_tensor[0] + [tensor]
        images_tensor[1] = images_tensor[1] + ['image']

    print(torch.cuda.memory_allocated())
    print(torch.cuda.max_memory_allocated())

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)
    text_en_out = text_en_out.split('#')[0]  # drop anything after the stop marker
    textbox_out = text_en_out

    if state is None:
        # assumption: gr.State() starts as None and nothing else seeds the
        # display-side conversation, so reuse the one Infer just returned
        state = state_
    if flag:
        state.append_message(state.roles[0], textbox_in)
    state.append_message(state.roles[1], textbox_out)
    torch.cuda.empty_cache()

    # keep the outputs aligned with the six components wired to submit_btn.click
    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True),
            gr.update(value=video if os.path.exists(video) else None, interactive=True))
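
# Illustrative only -- a direct, UI-free call path, assuming the Infer API
# exactly as used in generate() above (video_processor, model.device, and
# generate(images_tensor, text, first_run=..., state=...)):
def answer_one_video(video_path, question):
    tensor = handler.video_processor(video_path, return_tensors='pt')['pixel_values'][0]
    tensor = tensor.to(handler.model.device, dtype=torch.float16)  # assumed fp16
    answer, _ = handler.generate([[tensor], ['video']], question, first_run=True, state=None)
    return answer.split('#')[0]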
device = "cuda"
handler = Infer(device)

os.makedirs("temp", exist_ok=True)

print(torch.cuda.memory_allocated())
print(torch.cuda.max_memory_allocated())
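
# A hedged variant for machines without a GPU (assumes Infer accepts any torch
# device string, as it does "cuda" above):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   handler = Infer(device)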
textbox = gr.Textbox(
    show_label=False, placeholder="Enter text and press ENTER", container=False
)
with gr.Blocks(title='T-MoENet', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)
    # per-session conversation state: display-side, model-side, first-turn flag
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()

    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")
            cur_dir = os.path.dirname(os.path.abspath(__file__))
            gr.Examples(
                examples=[
                    [
                        cur_dir + "/videos/3249402410.mp4",
                        "what did the lady in black on the left do after she finished spreading the sauce on her pizza?",
                    ],
                    [
                        cur_dir + "/videos/4882821564.mp4",
                        "why did the boy clap his hands when he ran to the christmas tree?",
                    ],
                    [
                        cur_dir + "/videos/6233408665.mp4",
                        "what did the people on the sofa do after the lady in pink finished singing?",
                    ],
                ],
                inputs=[video, textbox],
            )
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="T-MoENet", bubble_full_width=True)
            with gr.Row():
                with gr.Column(scale=2):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(
                        value="Send", variant="primary", interactive=True
                    )

    # the six outputs here must stay in sync with generate()'s return tuple
    submit_btn.click(
        generate,
        [video, textbox, first_run, state, state_],
        [state, state_, chatbot, first_run, textbox, video],
    )

demo.launch(share=True)
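
# If several users hit the single GPU at once, Gradio's queue serializes the
# calls; a hedged alternative to the launch line above:
#   demo.queue(max_size=8).launch(share=True)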