File size: 5,976 Bytes
513e1fb 544e91d 513e1fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import shutil
import gradio as gr
import torch
from fastapi import FastAPI
import os
import tempfile
from Infer import Infer
# HTML header shown at the top of the demo page (passed to gr.Markdown):
# the paper title, its review status, and a badge linking to the GitHub
# repository.
title_markdown = ("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1 >Temporal-guided Mixture-of-Experts for Zero-Shot Video Question Answering</h1>
<h5 style="margin: 0;">Under review.</h5>
</div>
</div>
<div align="center">
<div style="display:flex; gap: 0.25rem;" align="center">
<a href='https://github.com/qyx1121/T-MoENet'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
</div>
</div>
""")
# Extra CSS handed to gr.Blocks(css=...): sets a min-width of
# min(120px, 100%) on every <button> inside the element with id "buttons".
# NOTE(review): no component in the visible layout sets elem_id="buttons",
# so this rule currently appears to match nothing — confirm.
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
def save_video_to_local(video_path):
    """Copy a video into the local ``temp`` directory under a unique name.

    Args:
        video_path: Path of the source video file.

    Returns:
        Relative path of the copy, of the form ``temp/<random>.mp4``.

    Raises:
        OSError: If the source cannot be read or the copy cannot be written.
    """
    target_dir = 'temp'
    # Create the directory on demand so the function does not depend on
    # module-level setup having run first.
    os.makedirs(target_dir, exist_ok=True)
    # mkstemp is the public API for a unique name and reserves the file
    # atomically, unlike the private tempfile._get_candidate_names().
    fd, reserved_path = tempfile.mkstemp(suffix='.mp4', prefix='', dir=target_dir)
    os.close(fd)
    shutil.copyfile(video_path, reserved_path)
    # mkstemp reports an absolute path; hand back the relative form callers
    # have always received.
    return os.path.join(target_dir, os.path.basename(reserved_path))
def generate(video, textbox_in, first_run, state, state_):
    """Answer one question about the uploaded video and update the chat.

    Used as the click handler of the "Send" button. The returned tuple lines
    up one-to-one with the handler's outputs
    ``[state, state_, chatbot, first_run, textbox, video]``.

    Args:
        video: Path of the uploaded video, or a falsy value when none is set.
        textbox_in: Question typed by the user. When empty, the last message
            stored in ``state_`` is popped and reused as the prompt.
        first_run: Ignored on input; recomputed from ``state.messages``.
        state: Conversation object backing the visible chat history.
        state_: Conversation object passed to ``handler.generate``.

    Returns:
        ``(state, state_, chatbot_messages, False, textbox_update,
        video_update)``.

    Raises:
        gr.Error: If no question was entered and there is no earlier message
            to reuse.
    """
    append_user_turn = True
    if not textbox_in:
        # gr.State() holds None until something is stored in it, so guard
        # before touching .messages. Raising gr.Error (instead of returning a
        # bare string) keeps the six declared outputs consistent.
        if state_ is None or len(state_.messages) == 0:
            raise gr.Error("Please enter instruction")
        # Empty box: reuse the most recent stored message as the prompt and
        # do not add a second user turn to the visible history.
        textbox_in = state_.messages[-1][1]
        state_.messages.pop(-1)
        append_user_turn = False

    video = video if video else "none"

    # NOTE(review): nothing visible in this file ever stores a conversation
    # object in `state` / `state_`; both start out as None, so the attribute
    # accesses below rely on initialisation that is not shown — confirm.
    first_run = len(state.messages) == 0
    text_en_in = textbox_in.replace("picture", "image")

    # Per-request container: [list of tensors, list of modality tags]. It is
    # a local because the module-level `images_tensor` is a gr.State
    # component, not a list, and is not among this handler's inputs.
    images_tensor = [[], []]

    # The UI only offers a video input, so only the video branch is kept
    # (the former image branches referenced an undefined `image1`).
    if os.path.exists(video):
        tensor = handler.video_processor(video, return_tensors='pt')['pixel_values'][0]
        # `dtype` is not defined anywhere in this file, so only the device is
        # matched and the processor's dtype is left as-is.
        tensor = tensor.to(handler.model.device)
        images_tensor[0].append(tensor)
        images_tensor[1].append('video')
        print(torch.cuda.memory_allocated())
        print(torch.cuda.max_memory_allocated())

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)

    # Only the text before the first '#' is shown to the user.
    textbox_out = text_en_out.split('#')[0]

    if append_user_turn:
        state.append_message(state.roles[0], textbox_in + "\n")
    state.append_message(state.roles[1], textbox_out)
    torch.cuda.empty_cache()

    return (
        state,
        state_,
        state.to_gradio_chatbot(),
        False,
        gr.update(value=None, interactive=True),
        gr.update(value=video if os.path.exists(video) else None, interactive=True),
    )
# The inference wrapper is built once at import time on the CPU and is used
# by generate() through this module-level name.
device = "cpu"
handler = Infer(device)
# NOTE(review): left disabled — `dtype` is not defined in the visible file.
# handler.model.to(dtype=dtype)

# Directory that save_video_to_local() writes into. exist_ok replaces the
# separate os.path.exists() test, which could race with another process
# creating the directory in between.
os.makedirs("temp", exist_ok=True)

print(torch.cuda.memory_allocated())
print(torch.cuda.max_memory_allocated())
# The question box is created before the layout so gr.Examples can refer to
# it; it is placed on the page further down with textbox.render().
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)

with gr.Blocks(title='T-MoENet', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)

    # Per-session values.
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()
    images_tensor = gr.State()

    with gr.Row():
        # Left column: video input plus clickable example (clip, question) pairs.
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")

            cur_dir = os.path.dirname(os.path.abspath(__file__))
            print(cur_dir)
            gr.Examples(
                examples=[
                    [cur_dir + clip, question]
                    for clip, question in (
                        (
                            "/videos/3249402410.mp4",
                            "what did the lady in black on the left do after she finished spreading the sauce on her pizza?",
                        ),
                        (
                            "/videos/4882821564.mp4",
                            "why did the boy clap his hands when he ran to the christmas tree?",
                        ),
                        (
                            "/videos/6233408665.mp4",
                            "what did the people on the sofa do after the lady in pink finished singing?",
                        ),
                    )
                ],
                inputs=[video, textbox],
            )

        # Right column: chat history above the question box and Send button.
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="T-MoENet", bubble_full_width=True)
            with gr.Row():
                with gr.Column(scale=2):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)

    # Send runs generate() and writes its results back to the listed outputs.
    submit_btn.click(
        fn=generate,
        inputs=[video, textbox, first_run, state, state_],
        outputs=[state, state_, chatbot, first_run, textbox, video],
    )

demo.launch(share=True)
|