|
import torch |
|
import gradio as gr |
|
from flash_vstream.serve.demo import Chat, title_markdown, block_css |
|
from flash_vstream.constants import * |
|
from flash_vstream.conversation import conv_templates, Conversation |
|
import os |
|
from PIL import Image |
|
import tempfile |
|
import imageio |
|
import shutil |
|
|
|
|
|
# Checkpoint handed to ``Chat`` below; presumably a Hugging Face Hub repo id
# (or a local path) — TODO confirm against Chat's loader.
model_path = "IVGSZ/Flash-VStream-7b"

# Quantized-loading switches forwarded verbatim to ``Chat``; with both False
# neither 8-bit nor 4-bit loading is requested.
load_8bit, load_4bit = False, False
|
|
|
def save_image_to_local(image):
    """Save a JPEG copy of ``image`` under the local ``temp`` directory.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a file object).

    Returns:
        The relative path ``temp/<random hex>.jpg`` of the written copy.
    """
    import uuid  # public replacement for the private tempfile._get_candidate_names()

    os.makedirs('temp', exist_ok=True)
    filename = os.path.join('temp', uuid.uuid4().hex + '.jpg')
    # The context manager closes the source file once the copy is written.
    with Image.open(image) as img:
        # JPEG cannot store alpha/palette modes (e.g. RGBA, P); convert so such
        # inputs do not make Pillow raise while saving.
        if img.mode not in ('RGB', 'L'):
            img = img.convert('RGB')
        img.save(filename)
    return filename
|
|
|
|
|
def save_video_to_local(video_path):
    """Copy a video file into the local ``temp`` directory under a random name.

    Args:
        video_path: Path of the source video file.

    Returns:
        The relative path ``temp/<random hex>.mp4`` of the copy.
    """
    import uuid  # public replacement for the private tempfile._get_candidate_names()

    os.makedirs('temp', exist_ok=True)
    filename = os.path.join('temp', uuid.uuid4().hex + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename
|
|
|
|
|
def generate(video, textbox_in, first_run, state, state_, images_tensor):
    """Answer one instruction about the uploaded video and refresh the chat UI.

    Also serves the Regenerate button: when ``textbox_in`` is empty, the last
    user prompt stored in ``state_`` is popped and sent again, and only the new
    assistant reply is appended to ``state``.

    Args:
        video: Path of the uploaded video, or a falsy value if none is set.
        textbox_in: Instruction typed by the user (may be empty or None).
        first_run: Ignored on input; recomputed from ``state``.
        state: Conversation rendered in the chatbot, or a non-Conversation
            value (e.g. the initial ``gr.State()`` value) to start a session.
        state_: Conversation handed to ``handler.generate``.
        images_tensor: List of per-frame tensors accumulated this session.

    Returns:
        ``(state, state_, chatbot_history, False, textbox_update,
        images_tensor, video_update)`` — one value per output wired to this
        callback; the textbox and video widgets are cleared.

    Raises:
        gr.Error: If there is no instruction and no previous prompt to reuse.
    """
    is_new_turn = True
    if not textbox_in:
        # Regenerate path: re-send the last model-side user prompt.
        # NOTE(review): that prompt is whatever handler.generate stored,
        # presumably already prefixed with the image tokens — confirm.
        if isinstance(state_, Conversation) and len(state_.messages) > 0:
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            is_new_turn = False
        else:
            # A bare return value would not match the seven outputs this
            # callback is wired to, so report through Gradio's error modal.
            raise gr.Error("Please enter instruction")

    has_video = bool(video) and os.path.exists(video)

    if not isinstance(state, Conversation):
        # New session: start both conversations and the frame list afresh.
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = []

    first_run = len(state.messages) == 0

    text_en_in = textbox_in.replace("picture", "image")

    if has_video:
        image_processor = handler.image_processor
        video_tensor = handler._get_rawvideo_dec(video, image_processor, max_frames=MAX_IMAGE_LENGTH)
        for img in video_tensor:
            pixel_values = image_processor(img, return_tensors='pt')['pixel_values'][0]
            images_tensor.append(pixel_values.to(handler.model.device, dtype=torch.float16))
        # One DEFAULT_IMAGE_TOKEN per decoded frame, ahead of the question.
        text_en_in = DEFAULT_IMAGE_TOKEN * len(video_tensor) + '\n' + text_en_in

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    # The model-side history keeps the full, untruncated reply.
    state_.messages[-1] = (state_.roles[1], text_en_out)

    # Only the text before the first '#' is shown to the user.
    textbox_out = text_en_out.split('#')[0]

    show_images = ""
    if has_video:
        filename = save_video_to_local(video)
        show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'

    if is_new_turn:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)

    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True), images_tensor,
            gr.update(value=None, interactive=True))
|
|
|
|
|
def regenerate(state, state_):
    """Drop the last reply from both conversations so it can be regenerated.

    Args:
        state: Conversation rendered in the chatbot (or None before any turn).
        state_: Conversation handed to the model (or None before any turn).

    Returns:
        ``(state, state_, chatbot_update, first_run)`` where ``first_run`` is
        True when no messages remain in ``state``.
    """
    if (not isinstance(state, Conversation) or not isinstance(state_, Conversation)
            or not state.messages or not state_.messages):
        # Nothing generated yet (or history just cleared): popping would
        # raise, so leave everything untouched.
        return state, state_, gr.update(), True

    state.messages.pop(-1)
    state_.messages.pop(-1)
    return state, state_, state.to_gradio_chatbot(), len(state.messages) == 0
|
|
|
|
|
def clear_history(state, state_):
    """Start both conversations over and blank the video and textbox widgets.

    The incoming ``state`` and ``state_`` values are ignored.

    Returns:
        ``(video_update, textbox_update, True, new_state, new_state_,
        chatbot_history, [])`` — fresh conversations copied from the
        ``conv_mode`` template and an empty frame list.
    """
    template = conv_templates[conv_mode]
    fresh_display = template.copy()
    fresh_model = template.copy()
    reset_widget = dict(value=None, interactive=True)
    return (
        gr.update(**reset_widget),
        gr.update(**reset_widget),
        True,
        fresh_display,
        fresh_model,
        fresh_display.to_gradio_chatbot(),
        [],
    )
|
|
|
|
|
# Conversation template name used by the callbacks above and by the Chat handler.
conv_mode = "simple"

# Load the model once at import time; the Gradio callbacks read this global.
handler = Chat(model_path, conv_mode=conv_mode, load_4bit=load_4bit, load_8bit=load_8bit)

# Scratch directory the save_*_to_local helpers write into. exist_ok avoids the
# check-then-create race of testing os.path.exists first.
os.makedirs("temp", exist_ok=True)

# Report CUDA memory usage (bytes) after the model has been loaded.
print(torch.cuda.memory_allocated())
print(torch.cuda.max_memory_allocated())
|
|
|
with gr.Blocks(title='Flash-VStream', theme=gr.themes.Soft(), css=block_css) as demo:
    gr.Markdown(title_markdown)

    # Per-session values threaded through the callbacks below.
    state = gr.State()          # chatbot-side Conversation (None until generate sets it)
    state_ = gr.State()         # Conversation handed to handler.generate
    first_run = gr.State()
    images_tensor = gr.State()  # frame tensors accumulated by generate

    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")

        with gr.Column(scale=7):
            # `height` is a constructor argument: the `.style()` method is
            # deprecated in Gradio 3.x and no longer exists in Gradio 4.
            chatbot = gr.Chatbot(label="Flash-VStream", bubble_full_width=True, height=700)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox = gr.Textbox(show_label=False,
                                         placeholder="Enter text and press Send",
                                         container=False)
                with gr.Column(scale=2, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)

            with gr.Row(visible=True) as button_row:
                # NOTE(review): no event handler is attached to flag_btn in this file.
                flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    cur_dir = os.path.dirname(os.path.abspath(__file__))

    with gr.Row():
        # One single-example gr.Examples widget per bundled clip.
        for clip, prompt in (
            ("video2.mp4", "Describe the video briefly."),
            ("video4.mp4", "What is the boy doing?"),
            ("video5.mp4", "Why is this video funny?"),
        ):
            gr.Examples(
                examples=[[f"{cur_dir}/examples/{clip}", prompt]],
                inputs=[video, textbox],
            )

    # generate is wired identically for Send and for the second step of Regenerate.
    generate_inputs = [video, textbox, first_run, state, state_, images_tensor]
    generate_outputs = [state, state_, chatbot, first_run, textbox, images_tensor, video]

    submit_btn.click(generate, generate_inputs, generate_outputs)

    # Regenerate: drop the last reply, then run generate again (with an empty
    # textbox it reuses the last stored prompt).
    regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
        generate, generate_inputs, generate_outputs)

    clear_btn.click(clear_history, [state, state_],
                    [video, textbox, first_run, state, state_, chatbot, images_tensor])

demo.launch()