haodongli committed
Commit 9486ff0 · 1 Parent(s): 73d2593

add training progress visualization

Files changed (2):
  1. gradio_demo.py +15 -7
  2. train.py +3 -3
gradio_demo.py CHANGED
@@ -32,17 +32,25 @@ example_inputs = [[
     "A DSLR photo of a wooden car, super detailed, best quality, 4K, HD.",
     "a wooden car."
 ]]
-example_outputs = [
+example_outputs_1 = [
     gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/boots.mp4'), autoplay=True),
     gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/Donut.mp4'), autoplay=True),
     gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/durian.mp4'), autoplay=True),
     gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/pillow_huskies.mp4'), autoplay=True),
     gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/wooden_car.mp4'), autoplay=True)
 ]
+example_outputs_2 = [
+    gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/boots_pro.mp4'), autoplay=True),
+    gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/Donut_pro.mp4'), autoplay=True),
+    gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/durian_pro.mp4'), autoplay=True),
+    gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/pillow_huskies_pro.mp4'), autoplay=True),
+    gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/wooden_car_pro.mp4'), autoplay=True)
+]
+
 
 def main(prompt, init_prompt, negative_prompt, num_iter, CFG, seed):
     if [prompt, init_prompt] in example_inputs:
-        return example_outputs[example_inputs.index([prompt, init_prompt])]
+        return example_outputs_1[example_inputs.index([prompt, init_prompt])], example_outputs_2[example_inputs.index([prompt, init_prompt])]
     args, lp, op, pp, gcp, gp = args_parser(default_opt=os.path.join(os.path.dirname(__file__), 'configs/white_hair_ironman.yaml'))
     gp.text = prompt
     gp.negative = negative_prompt
@@ -59,16 +67,16 @@ def main(prompt, init_prompt, negative_prompt, num_iter, CFG, seed):
     if os.environ.get('QUEUE_1') != "True":
         os.environ['QUEUE_1'] = "True"
         lp.workspace = 'gradio_demo_1'
-        video_path = start_training(args, lp, op, pp, gcp, gp)
+        video_path, pro_video_path = start_training(args, lp, op, pp, gcp, gp)
         os.environ['QUEUE_1'] = "False"
     else:
         lp.workspace = 'gradio_demo_2'
-        video_path = start_training(args, lp, op, pp, gcp, gp)
-    return gr.Video(value=video_path, autoplay=True)
+        video_path, pro_video_path = start_training(args, lp, op, pp, gcp, gp)
+    return gr.Video(value=video_path, autoplay=True), gr.Video(value=pro_video_path, autoplay=True)
 
 with gr.Blocks() as demo:
     gr.Markdown("# <center>LucidDreamer: Towards High-Fidelity Text-to-3D Generation via Interval Score Matching</center>")
-    gr.Markdown("This live demo allows you to generate high-quality 3D content using text prompts.<br> \
+    gr.Markdown("This live demo allows you to generate high-quality 3D content using text prompts. The outputs are a 360° rendered video of the 3D Gaussians and a training progress visualization.<br> \
        It is based on Stable Diffusion 2.1. Please check out our <strong><a href=https://github.com/EnVision-Research/LucidDreamer>Project Page</a> / <a href=https://arxiv.org/abs/2311.11284>Paper</a> / <a href=https://github.com/EnVision-Research/LucidDreamer>Code</a></strong> if you want to learn more about our method!<br> \
        Note that this demo is running on A10G Small, the running time might be longer than the reported 35 minutes (5000 iterations) on A100.<br> \
        &copy; This Gradio space was developed by Haodong LI.")
@@ -78,7 +86,7 @@ with gr.Blocks() as demo:
             gr.Slider(1000, 5000, value=5000, label="Number of iterations"),
             gr.Slider(7.5, 100, value=7.5, label="CFG"),
             gr.Number(value=0, label="Seed")],
-        outputs="playable_video",
+        outputs=["playable_video", "playable_video"],
         examples=example_inputs,
         cache_examples=True,
         concurrency_limit=2)
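
For readers unfamiliar with multi-output Gradio apps, here is a minimal, self-contained sketch of the pattern this commit adopts: a function that returns a 2-tuple maps positionally onto two output components, which is how main can hand back the 360° render and the progress video together. All function names and file paths below are hypothetical placeholders, not part of the LucidDreamer code.

import gradio as gr

# Hypothetical stand-in for start_training(): returns the two video paths
# that the two output components will play.
def fake_generate(prompt):
    return "example/render_360.mp4", "example/training_progress.mp4"

demo = gr.Interface(
    fn=fake_generate,
    inputs=gr.Textbox(label="Prompt"),
    # Two outputs: the function's 2-tuple is mapped onto them in order.
    outputs=[gr.Video(label="360° render"), gr.Video(label="Training progress")],
)

if __name__ == "__main__":
    demo.launch()

Note also that the QUEUE_1 environment variable in main acts as a simple busy flag, steering the two concurrent jobs permitted by concurrency_limit=2 into separate workspaces; since checking and setting the flag are not atomic, two simultaneous requests could in principle still race into the same workspace.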
train.py CHANGED
@@ -382,7 +382,7 @@ def training(dataset, opt, pipe, gcams, guidance_opt, testing_iterations, saving
 
     if opt.save_process:
         imageio.mimwrite(os.path.join(save_folder_proc, "video_rgb.mp4"), pro_img_frames, fps=30, quality=8)
-    return video_path
+    return video_path, os.path.join(save_folder_proc, "video_rgb.mp4")
 
 
 def prepare_output_and_logger(args):
@@ -543,10 +543,10 @@ def start_training(args, lp, op, pp, gcp, gp):
     # Start GUI server, configure and run training
     network_gui.init(args.ip, args.port)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
-    video_path = training(lp, op, pp, gcp, gp, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args.save_video)
+    video_path, pro_video_path = training(lp, op, pp, gcp, gp, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args.save_video)
     # All done
     print("\nTraining complete.")
-    return video_path
+    return video_path, pro_video_path
 
 if __name__ == "__main__":
     args, lp, op, pp, gcp, gp = args_parser()
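
On the train.py side, the progress video comes from the opt.save_process branch: frames accumulated in pro_img_frames during training are encoded once at the end with imageio.mimwrite. Below is a minimal sketch of that write-and-return pattern, assuming imageio with its ffmpeg plugin (imageio-ffmpeg) is installed; the synthetic frames and folder name are placeholders.

import os

import imageio
import numpy as np

def write_progress_video(save_folder_proc, pro_img_frames):
    # Same call shape as train.py: encode accumulated RGB frames to MP4.
    os.makedirs(save_folder_proc, exist_ok=True)
    pro_video_path = os.path.join(save_folder_proc, "video_rgb.mp4")
    imageio.mimwrite(pro_video_path, pro_img_frames, fps=30, quality=8)
    return pro_video_path

# Synthetic frames; in train.py these are intermediate renders of the scene.
frames = [np.full((64, 64, 3), (4 * i) % 256, dtype=np.uint8) for i in range(60)]
print(write_progress_video("demo_workspace/process", frames))

One small nit in the diff itself: the new return rebuilds the same os.path.join(save_folder_proc, "video_rgb.mp4") expression just passed to mimwrite, and it appears to be returned even when opt.save_process is false, in which case the file may not exist; binding the path to a variable once, as in the sketch above, would avoid the duplication.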