# Text-to-Video / app.py
# Captured from the Hugging Face Space file view: author Tennish, commit 02366db
# ("Update app.py", verified), 5.23 kB.
import gradio as gr
import torch
import os
import spaces
import uuid
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
# Constants
# UI label -> Hugging Face repo id of a base model. generate_image swaps this
# repo's UNet weights into `pipe` when the user picks a different base.
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2",
}
# What is currently loaded into `pipe`. generate_image compares requests
# against these so it only reloads weights that actually changed.
# None means "not loaded yet", which forces a load on the first request.
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None

# Ensure GPU availability
if not torch.cuda.is_available():
    raise NotImplementedError("No GPU detected!")
device = "cuda"
dtype = torch.float16

# Load initial pipeline
print("Loading AnimateDiff pipeline...")
# The base repos only provide image-model components. Passing an explicit
# MotionAdapter gives pipe.unet motion modules. Some diffusers releases refuse
# AnimateDiffPipeline.from_pretrained without one. The adapter starts with
# default-initialised weights. Because step_loaded is None, the first
# generate_image call loads an AnimateDiff-Lightning checkpoint into pipe.unet
# before any frames are rendered.
adapter = MotionAdapter().to(device, dtype)
pipe = AnimateDiffPipeline.from_pretrained(
    bases[base_loaded], motion_adapter=adapter, torch_dtype=dtype
).to(device)
# Scheduler settings used together with the AnimateDiff-Lightning checkpoints
# (trailing timestep spacing, linear beta schedule).
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)
print("Pipeline loaded successfully.")
# CLIP image processor for openai/clip-vit-base-patch32.
# NOTE(review): this is labelled as a safety checker, but nothing in this file
# passes it to `pipe` or uses it in generate_image. It is kept only so the
# module-level name still exists. Remove it if nothing imports it, which would
# also save a download at startup.
# CLIPImageProcessor is the current name of the deprecated CLIPFeatureExtractor
# alias. It loads the same preprocessor without the deprecation warning.
from transformers import CLIPImageProcessor
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Video Generation Function
def _sync_weights(base, motion, step):
    """Load into the shared ``pipe`` only the step/base/motion weights that changed.

    Updates the module-level ``step_loaded``/``base_loaded``/``motion_loaded``
    trackers, so repeated requests with the same settings skip reloading.
    """
    global step_loaded, base_loaded, motion_loaded
    # Load the step-specific AnimateDiff-Lightning checkpoint.
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        # strict=False: the checkpoint covers only part of pipe.unet's
        # parameters (presumably the motion modules; confirm).
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step
    # Load the base model's UNet weights.
    if base_loaded != base:
        # weights_only=True: the file is a plain tensor state dict downloaded
        # from a third-party repo, so refuse arbitrary pickled objects.
        state_dict = torch.load(
            hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"),
            map_location=device,
            weights_only=True,
        )
        pipe.unet.load_state_dict(state_dict, strict=False)
        base_loaded = base
    # Swap the camera-motion LoRA ("" means none).
    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion


def _render_frames(prompt, step, total_frames, clip_frames, progress):
    """Render clips of ``clip_frames`` frames until ``total_frames`` exist; return exactly that many.

    Progress is reported over all denoising steps of all clips. No generator or
    latents are passed, so each clip is sampled independently and consecutive
    clips are not guaranteed to join smoothly.
    """
    # Ceiling division: enough clips to cover total_frames.
    num_clips = (total_frames + clip_frames - 1) // clip_frames
    total_steps = num_clips * step
    progress((0, total_steps))
    frames = []
    for clip in range(num_clips):
        offset = clip * step

        # diffusers calls this as (pipeline, step_index, timestep, callback_kwargs)
        # and expects the kwargs dict back. offset is bound as a default argument
        # so each clip's callback keeps its own value.
        def on_step_end(pipeline, i, t, callback_kwargs, offset=offset):
            progress((offset + i + 1, total_steps))
            return callback_kwargs

        output = pipe(
            prompt=prompt,
            num_frames=clip_frames,
            guidance_scale=1.2,
            num_inference_steps=step,
            callback_on_step_end=on_step_end,
        )
        frames.extend(output.frames[0])  # frames of the first (only) video
    return frames[:total_frames]


# A 30 s video at 10 FPS takes about 19 sequential pipeline calls, which does
# not reliably fit the previous 30 s GPU allocation. Tune this to the hardware.
@spaces.GPU(duration=120, queue=False)
def generate_image(prompt, base="Realistic", motion="", step=8, progress=gr.Progress()):
    """Generate a 30-second MP4 from a text prompt and return its file path.

    Args:
        prompt: Text description of the video.
        base: Key into ``bases`` selecting the base model.
        motion: Hub repo id of a motion LoRA, or "" for none.
        step: Number of inference steps; also selects the matching
            AnimateDiff-Lightning checkpoint (the UI offers 1, 2, 4, 8).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        Path of the exported ``.mp4`` file under ``/tmp``.
    """
    print(f"Generating video for: Prompt='{prompt}', Base='{base}', Motion='{motion}', Steps='{step}'")
    # Treat a missing motion value (None) the same as "no LoRA" rather than
    # passing None to load_lora_weights.
    motion = motion or ""
    _sync_weights(base, motion, step)

    # Video parameters: 30-second duration
    fps = 10
    duration = 30  # seconds
    total_frames = fps * duration  # 300 frames for 30s at 10 FPS
    # Each pipeline call yields one clip (16 frames is diffusers' default
    # num_frames for AnimateDiff). The old code made total_frames calls and
    # produced 16x too many frames; here we make just enough calls and trim.
    clip_frames = 16
    output_frames = _render_frames(prompt, step, total_frames, clip_frames, progress)

    # Export to video
    path = f"/tmp/{uuid.uuid4().hex}.mp4"
    export_to_video(output_frames, path, fps=fps)
    return path
# Gradio Interface
# Builds the UI: prompt box, base/motion/step selectors, run button and video
# output. Both the button click and Enter in the prompt box call generate_image.
with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Textual Imagination: A Text To Video Synthesis</center></h1>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Prompt', placeholder="Enter your video description here...")
        with gr.Row():
            select_base = gr.Dropdown(
                label='Base model',
                # Derived from `bases` so every choice is a valid key for
                # generate_image's bases[base] lookup.
                choices=list(bases),
                value=base_loaded,
                interactive=True
            )
            # (label, value) pairs: the value is the motion LoRA repo id passed
            # to generate_image; "" means no LoRA.
            select_motion = gr.Dropdown(
                label='Motion',
                choices=[
                    ("Default", ""),
                    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
                    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
                    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
                    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
                    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
                    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
                    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                ],
                value="guoyww/animatediff-motion-lora-zoom-in",
                interactive=True
            )
            # Step counts that have a matching AnimateDiff-Lightning checkpoint.
            select_step = gr.Dropdown(
                label='Inference steps',
                choices=[('1-Step', 1), ('2-Step', 2), ('4-Step', 4), ('8-Step', 8)],
                value=4,
                interactive=True
            )
            submit = gr.Button(scale=1, variant='primary')
    video = gr.Video(
        label='Generated Video',
        autoplay=True,
        height=512,
        width=512,
        elem_id="video_output"
    )
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=generate_image,
        inputs=[prompt, select_base, select_motion, select_step],
        outputs=[video],
        api_name="instant_video",
        queue=False
    )
demo.queue().launch()