multimodalart HF staff commited on
Commit
6a87547
·
verified ·
1 Parent(s): eb60639

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from huggingface_hub import snapshot_download
6
+ from pyramid_dit import PyramidDiTForVideoGeneration
7
+ from diffusers.utils import export_to_video
8
+
9
+ # Constants
10
+ MODEL_PATH = "pyramid-flow-model"
11
+ MODEL_REPO = "rain1011/pyramid-flow-sd3"
12
+ MODEL_VARIANT = "diffusion_transformer_768p"
13
+ MODEL_DTYPE = "bf16"
14
+
15
+ # Download and load the model
16
+ def load_model():
17
+ if not os.path.exists(MODEL_PATH):
18
+ snapshot_download(MODEL_REPO, local_dir=MODEL_PATH, local_dir_use_symlinks=False, repo_type='model')
19
+
20
+ model = PyramidDiTForVideoGeneration(
21
+ MODEL_PATH,
22
+ MODEL_DTYPE,
23
+ model_variant=MODEL_VARIANT,
24
+ )
25
+
26
+ model.vae.to("cuda")
27
+ model.dit.to("cuda")
28
+ model.text_encoder.to("cuda")
29
+ model.vae.enable_tiling()
30
+
31
+ return model
32
+
33
+ # Global model variable
34
+ model = load_model()
35
+
36
+ # Text-to-video generation function
37
+ def generate_video(prompt, duration, guidance_scale, video_guidance_scale):
38
+ temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
39
+ torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
40
+
41
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
42
+ frames = model.generate(
43
+ prompt=prompt,
44
+ num_inference_steps=[20, 20, 20],
45
+ video_num_inference_steps=[10, 10, 10],
46
+ height=768,
47
+ width=1280,
48
+ temp=temp,
49
+ guidance_scale=guidance_scale,
50
+ video_guidance_scale=video_guidance_scale,
51
+ output_type="pil",
52
+ save_memory=True,
53
+ )
54
+
55
+ output_path = "output_video.mp4"
56
+ export_to_video(frames, output_path, fps=24)
57
+ return output_path
58
+
59
+ # Image-to-video generation function
60
+ def generate_video_from_image(image, prompt, duration, video_guidance_scale):
61
+ temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
62
+ torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
63
+
64
+ image = image.resize((1280, 768))
65
+
66
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
67
+ frames = model.generate_i2v(
68
+ prompt=prompt,
69
+ input_image=image,
70
+ num_inference_steps=[10, 10, 10],
71
+ temp=temp,
72
+ guidance_scale=7.0,
73
+ video_guidance_scale=video_guidance_scale,
74
+ output_type="pil",
75
+ save_memory=True,
76
+ )
77
+
78
+ output_path = "output_video_i2v.mp4"
79
+ export_to_video(frames, output_path, fps=24)
80
+ return output_path
81
+
82
+ # Gradio interface
83
+ with gr.Blocks() as demo:
84
+ gr.Markdown("# Pyramid Flow Video Generation Demo")
85
+
86
+ with gr.Tab("Text-to-Video"):
87
+ with gr.Row():
88
+ with gr.Column():
89
+ t2v_prompt = gr.Textbox(label="Prompt")
90
+ t2v_duration = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Duration (seconds)")
91
+ t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
92
+ t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
93
+ t2v_generate_btn = gr.Button("Generate Video")
94
+ with gr.Column():
95
+ t2v_output = gr.Video(label="Generated Video")
96
+
97
+ t2v_generate_btn.click(
98
+ generate_video,
99
+ inputs=[t2v_prompt, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale],
100
+ outputs=t2v_output
101
+ )
102
+
103
+ with gr.Tab("Image-to-Video"):
104
+ with gr.Row():
105
+ with gr.Column():
106
+ i2v_image = gr.Image(type="pil", label="Input Image")
107
+ i2v_prompt = gr.Textbox(label="Prompt")
108
+ i2v_duration = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Duration (seconds)")
109
+ i2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=4, step=0.1, label="Video Guidance Scale")
110
+ i2v_generate_btn = gr.Button("Generate Video")
111
+ with gr.Column():
112
+ i2v_output = gr.Video(label="Generated Video")
113
+
114
+ i2v_generate_btn.click(
115
+ generate_video_from_image,
116
+ inputs=[i2v_image, i2v_prompt, i2v_duration, i2v_video_guidance_scale],
117
+ outputs=i2v_output
118
+ )
119
+
120
+ demo.launch()