import os
import shutil
import tempfile

import gradio as gr
import torch
from fastapi import FastAPI

from Infer import Infer

title_markdown = ("""

# Temporal-guided Mixture-of-Experts for Zero-Shot Video Question Answering

Under review.
""") block_css = """ #buttons button { min-width: min(120px,100%); } """ def save_video_to_local(video_path): filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4') shutil.copyfile(video_path, filename) return filename def generate(video, textbox_in, first_run, state, state_): flag = 1 if not textbox_in: if len(state_.messages) > 0: textbox_in = state_.messages[-1][1] state_.messages.pop(-1) flag = 0 else: return "Please enter instruction" video = video if video else "none" # assert not (os.path.exists(image1) and os.path.exists(video)) first_run = False if len(state.messages) > 0 else True text_en_in = textbox_in.replace("picture", "image") # images_tensor = [[], []] image_processor = handler.image_processor if os.path.exists(image1) and not os.path.exists(video): tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] # print(tensor.shape) tensor = tensor.to(handler.model.device, dtype=dtype) images_tensor[0] = images_tensor[0] + [tensor] images_tensor[1] = images_tensor[1] + ['image'] print(torch.cuda.memory_allocated()) print(torch.cuda.max_memory_allocated()) video_processor = handler.video_processor if not os.path.exists(image1) and os.path.exists(video): tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] # print(tensor.shape) tensor = tensor.to(handler.model.device, dtype=dtype) images_tensor[0] = images_tensor[0] + [tensor] images_tensor[1] = images_tensor[1] + ['video'] print(torch.cuda.memory_allocated()) print(torch.cuda.max_memory_allocated()) if os.path.exists(image1) and os.path.exists(video): tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] # print(tensor.shape) tensor = tensor.to(handler.model.device, dtype=dtype) images_tensor[0] = images_tensor[0] + [tensor] images_tensor[1] = images_tensor[1] + ['video'] tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] # print(tensor.shape) tensor = tensor.to(handler.model.device, dtype=dtype) images_tensor[0] = images_tensor[0] + [tensor] images_tensor[1] = images_tensor[1] + ['image'] print(torch.cuda.memory_allocated()) print(torch.cuda.max_memory_allocated()) text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_) state_.messages[-1] = (state_.roles[1], text_en_out) text_en_out = text_en_out.split('#')[0] textbox_out = text_en_out show_images = "" if flag: state.append_message(state.roles[0], textbox_in + "\n" + show_images) state.append_message(state.roles[1], textbox_out) torch.cuda.empty_cache() return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True)) device = "cuda" handler = Infer(device) # handler.model.to(dtype=dtype) if not os.path.exists("temp"): os.makedirs("temp") print(torch.cuda.memory_allocated()) print(torch.cuda.max_memory_allocated()) textbox = gr.Textbox( show_label=False, placeholder="Enter text and press ENTER", container=False ) with gr.Blocks(title='T-MoENet', theme=gr.themes.Default(), css=block_css) as demo: gr.Markdown(title_markdown) state = gr.State() state_ = gr.State() first_run = gr.State() images_tensor = gr.State() with gr.Row(): with gr.Column(scale=3): video = gr.Video(label="Input Video") cur_dir = os.path.dirname(os.path.abspath(__file__)) print(cur_dir) gr.Examples( examples=[ [ cur_dir + "/videos/3249402410.mp4", 
"what did the lady in black on the left do after she finished spreading the sauce on her pizza?", ], [ cur_dir + "/videos/4882821564.mp4", "why did the boy clap his hands when he ran to the christmas tree?", ], [ cur_dir + "/videos/6233408665.mp4", "what did the people on the sofa do after the lady in pink finished singing?", ], ], inputs=[video, textbox], ) with gr.Column(scale=7): chatbot = gr.Chatbot(label="T-MoENet", bubble_full_width=True) with gr.Row(): with gr.Column(scale=2): textbox.render() with gr.Column(scale=1, min_width=50): submit_btn = gr.Button( value="Send", variant="primary", interactive=True ) submit_btn.click(generate, [video, textbox, first_run, state, state_], [state, state_, chatbot, first_run, textbox, video]) demo.launch(share=True)