add fps control

#2
by multimodalart HF staff - opened
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -62,7 +62,7 @@ model = load_model()
62
 
63
  # Text-to-video generation function
64
  @spaces.GPU(duration=120)
65
- def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guidance_scale=5, progress=gr.Progress(track_tqdm=True)):
66
  multiplier = 1.2 if is_canonical else 3.0
67
  temp = int(duration * multiplier) + 1
68
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
@@ -95,7 +95,7 @@ def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guida
95
  save_memory=True,
96
  )
97
  output_path = f"{str(uuid.uuid4())}_output_video.mp4"
98
- export_to_video(frames, output_path, fps=8 if is_canonical else 24)
99
  return output_path
100
 
101
  # Gradio interface
@@ -111,11 +111,12 @@ with gr.Blocks() as demo:
111
  t2v_prompt = gr.Textbox(label="Prompt")
112
  with gr.Accordion("Advanced settings", open=False):
113
  t2v_duration = gr.Slider(minimum=1, maximum=3 if is_canonical else 10, value=3 if is_canonical else 5, step=1, label="Duration (seconds)", visible=not is_canonical)
 
114
  t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
115
  t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
116
  t2v_generate_btn = gr.Button("Generate Video")
117
  with gr.Column():
118
- t2v_output = gr.Video(label=f"Generated Video at {'8fps' if is_canonical else '24fps'}")
119
  gr.HTML("""
120
  <div style="display: flex; flex-direction: column;justify-content: center; align-items: center; text-align: center;">
121
  <p style="display: flex;gap: 6px;">
@@ -138,7 +139,7 @@ with gr.Blocks() as demo:
138
  )
139
  t2v_generate_btn.click(
140
  generate_video,
141
- inputs=[t2v_prompt, i2v_image, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale],
142
  outputs=t2v_output
143
  )
144
 
 
62
 
63
  # Text-to-video generation function
64
  @spaces.GPU(duration=120)
65
+ def generate_video(prompt, image=None, duration=3, guidance_scale=9, video_guidance_scale=5, frames_per_second=8, progress=gr.Progress(track_tqdm=True)):
66
  multiplier = 1.2 if is_canonical else 3.0
67
  temp = int(duration * multiplier) + 1
68
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
 
95
  save_memory=True,
96
  )
97
  output_path = f"{str(uuid.uuid4())}_output_video.mp4"
98
+ export_to_video(frames, output_path, fps=frames_per_second)
99
  return output_path
100
 
101
  # Gradio interface
 
111
  t2v_prompt = gr.Textbox(label="Prompt")
112
  with gr.Accordion("Advanced settings", open=False):
113
  t2v_duration = gr.Slider(minimum=1, maximum=3 if is_canonical else 10, value=3 if is_canonical else 5, step=1, label="Duration (seconds)", visible=not is_canonical)
114
+ t2v_fps = gr.Slider(minimum=8, maximum=24, step=16, value=8 if is_canonical else 24, label="Frames per second", visible=is_canonical)
115
  t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
116
  t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
117
  t2v_generate_btn = gr.Button("Generate Video")
118
  with gr.Column():
119
+ t2v_output = gr.Video(label=f"Generated Video")
120
  gr.HTML("""
121
  <div style="display: flex; flex-direction: column;justify-content: center; align-items: center; text-align: center;">
122
  <p style="display: flex;gap: 6px;">
 
139
  )
140
  t2v_generate_btn.click(
141
  generate_video,
142
+ inputs=[t2v_prompt, i2v_image, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale, t2v_fps],
143
  outputs=t2v_output
144
  )
145