basilevh committed on
Commit
244d4ad
·
1 Parent(s): 65beccc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -54
app.py CHANGED
@@ -1,6 +1,6 @@
1
  '''
2
  conda activate zero123
3
- cd stable-diffusion
4
  python gradio_new.py 0
5
  '''
6
 
@@ -11,6 +11,7 @@ import gradio as gr
11
  import lovely_numpy
12
  import lovely_tensors
13
  import numpy as np
 
14
  import plotly.express as px
15
  import plotly.graph_objects as go
16
  import rich
@@ -27,7 +28,7 @@ from lovely_numpy import lo
27
  from omegaconf import OmegaConf
28
  from PIL import Image
29
  from rich import print
30
- from transformers import AutoFeatureExtractor #, CLIPImageProcessor
31
  from torch import autocast
32
  from torchvision import transforms
33
 
@@ -43,8 +44,8 @@ _TITLE = 'Zero-1-to-3: Zero-shot One Image to 3D Object'
43
 
44
  # This demo allows you to generate novel viewpoints of an object depicted in an input image using a fine-tuned version of Stable Diffusion.
45
  _DESCRIPTION = '''
46
- This demo allows you to control camera rotation and thereby generate novel viewpoints of an object within a single image.
47
- It is based on Stable Diffusion. Check out our [project webpage](https://zero123.cs.columbia.edu/) and [paper](https://arxiv.org/) if you want to learn more about the method!
48
  Note that this model is not intended for images of humans or faces, and is unlikely to work well for them.
49
  '''
50
 
@@ -319,7 +320,7 @@ def main_run(models, device, cam_vis, return_what,
319
  '''
320
  :param raw_im (PIL Image).
321
  '''
322
-
323
  safety_checker_input = models['clip_fe'](raw_im, return_tensors='pt').to(device)
324
  (image, has_nsfw_concept) = models['nsfw'](
325
  images=np.ones((1, 3)), clip_input=safety_checker_input.pixel_values)
@@ -507,6 +508,18 @@ def run_demo(
507
  with open('instructions.md', 'r') as f:
508
  article = f.read()
509
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  # Compose demo layout & data flow.
511
  demo = gr.Blocks(title=_TITLE)
512
 
@@ -558,7 +571,8 @@ def run_demo(
558
  vis_btn = gr.Button('Visualize Angles', variant='secondary')
559
  run_btn = gr.Button('Run Generation', variant='primary')
560
 
561
- desc_output = gr.Markdown('The results will appear on the right.', visible=_SHOW_DESC)
 
562
 
563
  with gr.Column(scale=1.1, variant='panel'):
564
 
@@ -571,55 +585,25 @@ def run_demo(
571
  preproc_output = gr.Image(type='pil', image_mode='RGB',
572
  label='Preprocessed input image', visible=_SHOW_INTERMEDIATE)
573
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  gr.Markdown(article)
575
 
576
  # NOTE: I am forced to update vis_output for these preset buttons,
577
  # because otherwise the gradio plot always resets the plotly 3D viewpoint for some reason,
578
  # which might confuse the user into thinking that the plot has been updated too.
579
 
580
- # OLD 1:
581
- # left_btn.click(fn=lambda: [0.0, -90.0], #, 0.0],
582
- # inputs=[], outputs=[polar_slider, azimuth_slider]), #], radius_slider])
583
- # above_btn.click(fn=lambda: [90.0, 0.0], #, 0.0],
584
- # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
585
- # right_btn.click(fn=lambda: [0.0, 90.0], #, 0.0],
586
- # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
587
- # random_btn.click(fn=lambda: [int(np.round(np.random.uniform(-60.0, 60.0))),
588
- # int(np.round(np.random.uniform(-150.0, 150.0)))], #, 0.0],
589
- # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
590
- # below_btn.click(fn=lambda: [-90.0, 0.0], #, 0.0],
591
- # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
592
- # behind_btn.click(fn=lambda: [0.0, 180.0], #, 0.0],
593
- # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
594
-
595
- # OLD 2:
596
- # preset_text = ('You have selected a preset target camera view. '
597
- # 'Now click Run Generation to update the results!')
598
-
599
- # left_btn.click(fn=lambda: [0.0, -90.0, None, preset_text],
600
- # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
601
- # above_btn.click(fn=lambda: [90.0, 0.0, None, preset_text],
602
- # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
603
- # right_btn.click(fn=lambda: [0.0, 90.0, None, preset_text],
604
- # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
605
- # random_btn.click(fn=lambda: [int(np.round(np.random.uniform(-60.0, 60.0))),
606
- # int(np.round(np.random.uniform(-150.0, 150.0))),
607
- # None, preset_text],
608
- # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
609
- # below_btn.click(fn=lambda: [-90.0, 0.0, None, preset_text],
610
- # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
611
- # behind_btn.click(fn=lambda: [0.0, 180.0, None, preset_text],
612
- # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
613
-
614
- # OLD 3 (does not work at all):
615
- # def a():
616
- # polar_slider.value = 77.7
617
- # polar_slider.postprocess(77.7)
618
- # print('testa')
619
- # left_btn.click(fn=a)
620
-
621
- cam_vis = CameraVisualizer(vis_output)
622
-
623
  vis_btn.click(fn=partial(main_run, models, device, cam_vis, 'vis'),
624
  inputs=[polar_slider, azimuth_slider, radius_slider,
625
  image_block, preprocess_chk],
@@ -641,19 +625,19 @@ def run_demo(
641
  inputs=preset_inputs, outputs=preset_outputs)
642
  above_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
643
  -90.0, 0.0, 0.0),
644
- inputs=preset_inputs, outputs=preset_outputs)
645
  right_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
646
  0.0, 90.0, 0.0),
647
- inputs=preset_inputs, outputs=preset_outputs)
648
  random_btn.click(fn=partial(main_run, models, device, cam_vis, 'rand_angles_gen',
649
  -1.0, -1.0, -1.0),
650
- inputs=preset_inputs, outputs=preset_outputs)
651
  below_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
652
  90.0, 0.0, 0.0),
653
- inputs=preset_inputs, outputs=preset_outputs)
654
  behind_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
655
  0.0, 180.0, 0.0),
656
- inputs=preset_inputs, outputs=preset_outputs)
657
 
658
  demo.launch(enable_queue=True)
659
 
 
1
  '''
2
  conda activate zero123
3
+ cd zero123
4
  python gradio_new.py 0
5
  '''
6
 
 
11
  import lovely_numpy
12
  import lovely_tensors
13
  import numpy as np
14
+ import os
15
  import plotly.express as px
16
  import plotly.graph_objects as go
17
  import rich
 
28
  from omegaconf import OmegaConf
29
  from PIL import Image
30
  from rich import print
31
+ from transformers import AutoFeatureExtractor
32
  from torch import autocast
33
  from torchvision import transforms
34
 
 
44
 
45
  # This demo allows you to generate novel viewpoints of an object depicted in an input image using a fine-tuned version of Stable Diffusion.
46
  _DESCRIPTION = '''
47
+ This live demo allows you to control camera rotation and thereby generate novel viewpoints of an object within a single image.
48
+ It is based on Stable Diffusion. Check out our [project webpage](https://zero123.cs.columbia.edu/) and [paper](https://arxiv.org/pdf/2303.11328.pdf) if you want to learn more about the method!
49
  Note that this model is not intended for images of humans or faces, and is unlikely to work well for them.
50
  '''
51
 
 
320
  '''
321
  :param raw_im (PIL Image).
322
  '''
323
+
324
  safety_checker_input = models['clip_fe'](raw_im, return_tensors='pt').to(device)
325
  (image, has_nsfw_concept) = models['nsfw'](
326
  images=np.ones((1, 3)), clip_input=safety_checker_input.pixel_values)
 
508
  with open('instructions.md', 'r') as f:
509
  article = f.read()
510
 
511
+ # NOTE: Examples must match inputs
512
+ # [polar_slider, azimuth_slider, radius_slider, image_block,
513
+ # preprocess_chk, scale_slider, samples_slider, steps_slider].
514
+ example_fns = ['1_blue_arm.png', '2_cybercar.png', '3_sushi.png', '4_blackarm.png',
515
+ '5_cybercar.png', '6_burger.png', '7_london.png', '8_motor.png']
516
+ num_examples = len(example_fns)
517
+ example_fps = [os.path.join(os.path.dirname(__file__), 'assets', x) for x in example_fns]
518
+ example_angles = [(-40.0, -65.0, 0.0), (-30.0, 90.0, 0.0), (45.0, -15.0, 0.0), (-75.0, 100.0, 0.0),
519
+ (-40.0, -75.0, 0.0), (-45.0, 0.0, 0.0), (-55.0, 90.0, 0.0), (-20.0, 125.0, 0.0)]
520
+ examples_full = [[*example_angles[i], example_fps[i], True, 3, 4, 50] for i in range(num_examples)]
521
+ print('examples_full:', examples_full)
522
+
523
  # Compose demo layout & data flow.
524
  demo = gr.Blocks(title=_TITLE)
525
 
 
571
  vis_btn = gr.Button('Visualize Angles', variant='secondary')
572
  run_btn = gr.Button('Run Generation', variant='primary')
573
 
574
+ desc_output = gr.Markdown(
575
+ 'The results will appear on the right.', visible=_SHOW_DESC)
576
 
577
  with gr.Column(scale=1.1, variant='panel'):
578
 
 
585
  preproc_output = gr.Image(type='pil', image_mode='RGB',
586
  label='Preprocessed input image', visible=_SHOW_INTERMEDIATE)
587
 
588
+ cam_vis = CameraVisualizer(vis_output)
589
+
590
+ gr.Examples(
591
+ examples=examples_full, # NOTE: elements must match inputs list!
592
+ fn=partial(main_run, models, device, cam_vis, 'gen'),
593
+ inputs=[polar_slider, azimuth_slider, radius_slider,
594
+ image_block, preprocess_chk,
595
+ scale_slider, samples_slider, steps_slider],
596
+ outputs=[desc_output, vis_output, preproc_output, gen_output],
597
+ cache_examples=True,
598
+ run_on_click=True,
599
+ )
600
+
601
  gr.Markdown(article)
602
 
603
  # NOTE: I am forced to update vis_output for these preset buttons,
604
  # because otherwise the gradio plot always resets the plotly 3D viewpoint for some reason,
605
  # which might confuse the user into thinking that the plot has been updated too.
606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
  vis_btn.click(fn=partial(main_run, models, device, cam_vis, 'vis'),
608
  inputs=[polar_slider, azimuth_slider, radius_slider,
609
  image_block, preprocess_chk],
 
625
  inputs=preset_inputs, outputs=preset_outputs)
626
  above_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
627
  -90.0, 0.0, 0.0),
628
+ inputs=preset_inputs, outputs=preset_outputs)
629
  right_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
630
  0.0, 90.0, 0.0),
631
+ inputs=preset_inputs, outputs=preset_outputs)
632
  random_btn.click(fn=partial(main_run, models, device, cam_vis, 'rand_angles_gen',
633
  -1.0, -1.0, -1.0),
634
+ inputs=preset_inputs, outputs=preset_outputs)
635
  below_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
636
  90.0, 0.0, 0.0),
637
+ inputs=preset_inputs, outputs=preset_outputs)
638
  behind_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
639
  0.0, 180.0, 0.0),
640
+ inputs=preset_inputs, outputs=preset_outputs)
641
 
642
  demo.launch(enable_queue=True)
643