vilarin's picture
Update app.py
82ba711 verified
raw
history blame
4.17 kB
import os
import gradio as gr
import torch
import spaces
import random
from PIL import Image
import icecream as ic
import numpy as np
from glob import glob
from pathlib import Path
from typing import Optional
from diffsynth import ModelManager, SVDVideoPipeline, HunyuanDiTImagePipeline
from diffsynth import ModelManager
from diffusers.utils import load_image, export_to_video
import uuid
from huggingface_hub import hf_hub_download
# NOTE(review): huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER when it is
# first imported, and it (plus diffsynth/diffusers) is imported above this
# line, so this assignment likely does not enable hf_transfer for downloads in
# this process. Move it above the imports if that speed-up is wanted (it also
# requires the `hf_transfer` package to be installed).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Optional Hub access token from the environment; currently unreferenced below.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Constants
# Upper bound for random seeds and for the seed slider (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max

# Hide the default Gradio footer.
CSS = """
footer {
visibility: hidden;
}
"""

# Runs on page load and forces Gradio's dark theme through the `__theme` query
# parameter. The URL API edits the query string properly, so existing
# parameters survive and a different `__theme` value is overwritten. Naively
# appending "?__theme=dark" would corrupt a URL that already has a query string.
JS = """function () {
    const url = new URL(window.location.href);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.replace(url.href);
    }
}"""
# Load the SVD image-to-video base model and the ExVideo-SVD-128f-v1 weights
# once, at import time, in fp16 on CUDA, and build the pipeline generate() uses.
if torch.cuda.is_available():
    model_manager = ModelManager(
        torch_dtype=torch.float16,
        device="cuda",
        model_id_list=["stable-video-diffusion-img2vid-xt", "ExVideo-SVD-128f-v1"],
        downloading_priority=["HuggingFace"],
    )
    pipe = SVDVideoPipeline.from_model_manager(model_manager)
else:
    # Nothing is loaded without CUDA. Bind `pipe` anyway so code using it can
    # test for None instead of hitting a NameError.
    pipe = None
# Adapted from the multimodalart/stable-video-diffusion Space.
@spaces.GPU(duration=120)
def generate(
    image,
    seed: Optional[int] = -1,
    motion_bucket_id: int = 127,
    fps_id: int = 25,
    output_folder: str = "outputs",
    progress=gr.Progress(track_tqdm=True),
):
    """Animate a still image into a 128-frame ExVideo clip saved as MP4.

    Args:
        image: Filesystem path of the input image. The UI's ``gr.Image`` is
            configured with ``type="filepath"``.
        seed: RNG seed. ``-1`` (or ``None``) draws a random seed in
            ``[0, MAX_SEED]``.
        motion_bucket_id: Motion conditioning value forwarded to the pipeline.
        fps_id: Frame rate passed to the pipeline and used to encode the MP4.
        output_folder: Directory the MP4 is written to; created if missing.
        progress: Gradio progress tracker. ``track_tqdm`` mirrors tqdm bars.

    Returns:
        ``(video_path, seed)``: the written MP4's path and the seed actually
        used, so the UI can show it for reproduction.

    Raises:
        gr.Error: If no image was supplied or the pipeline was never loaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if pipe is None:
        raise gr.Error("Video pipeline is not loaded: no CUDA device was available at startup.")

    # Gradio sliders may deliver floats; these values are used as integers.
    seed = -1 if seed is None else int(seed)
    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    fps_id = int(fps_id)
    motion_bucket_id = int(motion_bucket_id)

    # The context manager releases the file handle. Converting to RGB keeps
    # RGBA/palette images from reaching the pipeline with the wrong channel count.
    with Image.open(image) as img:
        input_image = img.convert("RGB").resize((512, 512))

    torch.manual_seed(seed)

    os.makedirs(output_folder, exist_ok=True)
    # Use a unique name. Numbering by the count of existing files can give two
    # overlapping requests the same name, or overwrite an existing file once
    # older outputs have been deleted.
    video_path = os.path.join(output_folder, f"{uuid.uuid4().hex}.mp4")

    result = pipe(
        input_image=input_image,
        num_frames=128,
        fps=fps_id,
        height=512,
        width=512,
        motion_bucket_id=motion_bucket_id,
        num_inference_steps=50,
        min_cfg_scale=2,
        max_cfg_scale=2,
        contrast_enhance_scale=1.2,
    )
    # diffsynth's SVDVideoPipeline returns the decoded frames directly, while
    # diffusers-style pipelines wrap them as `.frames[0]`. Accept either.
    frames = result.frames[0] if hasattr(result, "frames") else result

    export_to_video(frames, video_path, fps=fps_id)
    return video_path, seed
# Sample inputs for gr.Examples. The paths are relative to the working
# directory the app is launched from.
examples = ["./train.jpg", "./girl.webp", "./robo.jpg"]
# Gradio Interface: upload an image, tweak seed/motion/fps, and get a video back.
with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
    gr.HTML("<h1><center>Exvideo📽️</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1'>ExVideo</a> image-to-video generation<br><b>Update</b>: first version</center></p>")

    with gr.Row():
        image = gr.Image(label='Upload Image', height=600, scale=2, image_mode="RGB", type="filepath")
        video = gr.Video(label="Generated Video", height=600, scale=2)

    with gr.Accordion("Advanced Options", open=True):
        with gr.Column(scale=1):
            seed = gr.Slider(
                label="Seed (-1 Random)",
                minimum=-1,
                maximum=MAX_SEED,
                step=1,
                value=-1,
            )
            motion_bucket_id = gr.Slider(
                label="Motion bucket id",
                info="Controls how much motion to add/remove from the image",
                value=127,
                minimum=1,
                maximum=255,
            )
            fps_id = gr.Slider(
                label="Frames per second",
                # generate() always renders 128 frames, so the clip lasts
                # 128 / fps seconds. The previous "25/fps" hint was wrong.
                info="The length of your video in seconds will be 128/fps",
                value=25,
                minimum=5,
                maximum=30,
            )

    with gr.Row():
        submit_btn = gr.Button(value="Generate")
        clear_btn = gr.ClearButton([image, seed, video])

    # Examples pass only the image; generate() falls back to its defaults for
    # the remaining parameters. Results are cached lazily on first use.
    gr.Examples(
        examples=examples,
        inputs=image,
        outputs=[video, seed],
        fn=generate,
        cache_examples="lazy",
        examples_per_page=4,
    )

    submit_btn.click(fn=generate, inputs=[image, seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video")

demo.queue().launch()