LanguageBind committed on
Commit
4b9c0aa
·
verified ·
1 Parent(s): 2631dca

Update opensora/serve/gradio_web_server.py

Browse files
Files changed (1) hide show
  1. opensora/serve/gradio_web_server.py +24 -18
opensora/serve/gradio_web_server.py CHANGED
@@ -22,32 +22,38 @@ from opensora.models.ae import ae_stride_config, getae, getae_wrapper
22
  from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
23
  from opensora.models.diffusion.latte.modeling_latte import LatteT2V
24
  from opensora.sample.pipeline_videogen import VideoGenPipeline
25
- from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION
26
-
27
 
28
  @spaces.GPU(duration=300)
 
 
29
  @torch.inference_mode()
30
  def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
31
- seed = int(randomize_seed_fn(seed, randomize_seed))
32
- set_env(seed)
33
  video_length = transformer_model.config.video_length if not force_images else 1
34
  height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2])
35
  num_frames = 1 if video_length == 1 else int(args.version.split('x')[0])
36
- videos = videogen_pipeline(prompt,
37
- num_frames=num_frames,
38
- height=height,
39
- width=width,
40
- num_inference_steps=sample_steps,
41
- guidance_scale=scale,
42
- enable_temporal_attentions=not force_images,
43
- num_images_per_prompt=1,
44
- mask_feature=True,
45
- ).video
 
 
 
 
 
 
46
 
47
- torch.cuda.empty_cache()
48
- videos = videos[0]
49
- tmp_save_path = 'tmp.mp4'
50
- imageio.mimwrite(tmp_save_path, videos, fps=24, quality=9) # highest quality is 10, lowest is 0
51
  display_model_info = f"Video size: {num_frames}×{height}×{width}, \nSampling Step: {sample_steps}, \nGuidance Scale: {scale}"
52
  return tmp_save_path, prompt, display_model_info, seed
53
 
 
22
  from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
23
  from opensora.models.diffusion.latte.modeling_latte import LatteT2V
24
  from opensora.sample.pipeline_videogen import VideoGenPipeline
25
+ from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, DESCRIPTION
26
+ from opensora.serve.gradio_utils import examples_txt, examples
27
 
28
  @spaces.GPU(duration=300)
29
+ @torch.inference_mode()
30
+
31
  @torch.inference_mode()
32
  def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
 
 
33
  video_length = transformer_model.config.video_length if not force_images else 1
34
  height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2])
35
  num_frames = 1 if video_length == 1 else int(args.version.split('x')[0])
36
+ if not force_images and prompt in examples_txt:
37
+ idx = examples_txt.index(prompt)
38
+ tmp_save_path = f'demo65-221/f65/{idx+1}.mp4'
39
+ else:
40
+ seed = int(randomize_seed_fn(seed, randomize_seed))
41
+ set_env(seed)
42
+ videos = videogen_pipeline(prompt,
43
+ num_frames=num_frames,
44
+ height=height,
45
+ width=width,
46
+ num_inference_steps=sample_steps,
47
+ guidance_scale=scale,
48
+ enable_temporal_attentions=not force_images,
49
+ num_images_per_prompt=1,
50
+ mask_feature=True,
51
+ ).video
52
 
53
+ torch.cuda.empty_cache()
54
+ videos = videos[0]
55
+ tmp_save_path = 'tmp.mp4'
56
+ imageio.mimwrite(tmp_save_path, videos, fps=24, quality=9) # highest quality is 10, lowest is 0
57
  display_model_info = f"Video size: {num_frames}×{height}×{width}, \nSampling Step: {sample_steps}, \nGuidance Scale: {scale}"
58
  return tmp_save_path, prompt, display_model_info, seed
59