tokenid commited on
Commit
fc6f56d
·
1 Parent(s): 146da98

add random seed

Browse files
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -11,6 +11,7 @@ import torch
11
  from torchvision import transforms
12
  import rembg
13
  import cv2
 
14
 
15
  from src.visualizer import CameraVisualizer
16
  from src.pose_estimation import load_model_from_config, estimate_poses, estimate_elevs
@@ -113,7 +114,9 @@ def group_recenter(images, ratio=1.5, mask_thres=127, bkg_color=[255, 255, 255,
113
  return out_rgbs
114
 
115
 
116
- def run_preprocess(image1, image2, preprocess_chk):
 
 
117
 
118
  if preprocess_chk:
119
  rembg_session = rembg.new_session()
@@ -138,7 +141,9 @@ def image_to_tensor(img, width=256, height=256):
138
 
139
 
140
  @spaces.GPU
141
- def run_pose_exploration_a(cam_vis, image1, image2):
 
 
142
 
143
  image1 = image_to_tensor(image1).to(_device_)
144
  image2 = image_to_tensor(image2).to(_device_)
@@ -157,7 +162,9 @@ def run_pose_exploration_a(cam_vis, image1, image2):
157
 
158
 
159
  @spaces.GPU
160
- def run_pose_exploration_b(cam_vis, image1, image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters):
 
 
161
 
162
  noise = np.random.randn(probe_bsz, 4, 32, 32)
163
 
@@ -206,7 +213,9 @@ def run_pose_exploration_b(cam_vis, image1, image2, elevs, elev_ranges, probe_bs
206
 
207
 
208
  @spaces.GPU
209
- def run_pose_refinement(cam_vis, image1, image2, anchor_polar, explored_sph, refine_iters):
 
 
210
 
211
  cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
212
 
@@ -295,7 +304,8 @@ def run_demo():
295
  with gr.Accordion('Advanced options', open=False):
296
  probe_bsz = gr.Slider(4, 32, value=16, step=4, label='Probe Batch Size')
297
  adj_bsz = gr.Slider(1, 8, value=4, step=1, label='Adjust Batch Size')
298
- adj_iters = gr.Slider(1, 20, value=5, step=1, label='Adjust Iterations')
 
299
 
300
  with gr.Row():
301
  run_btn = gr.Button('Estimate', variant='primary', interactive=True)
@@ -369,21 +379,21 @@ def run_demo():
369
 
370
  run_btn.click(
371
  fn=run_preprocess,
372
- inputs=[input_image1, input_image2, preprocess_chk],
373
  outputs=[processed_image1, processed_image2],
374
  ).success(
375
- fn=partial(run_pose_exploration_a, cam_vis),
376
- inputs=[processed_image1, processed_image2],
377
  outputs=[elevs, elev_ranges, vis_output]
378
  ).success(
379
  fn=partial(run_pose_exploration_b, cam_vis),
380
- inputs=[processed_image1, processed_image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters],
381
  outputs=[anchor_polar, explored_sph, vis_output, refine_btn]
382
  )
383
 
384
  refine_btn.click(
385
  fn=partial(run_pose_refinement, cam_vis),
386
- inputs=[processed_image1, processed_image2, anchor_polar, explored_sph, refine_iters],
387
  outputs=[refined_sph, vis_output]
388
  )
389
 
 
11
  from torchvision import transforms
12
  import rembg
13
  import cv2
14
+ from pytorch_lightning import seed_everything
15
 
16
  from src.visualizer import CameraVisualizer
17
  from src.pose_estimation import load_model_from_config, estimate_poses, estimate_elevs
 
114
  return out_rgbs
115
 
116
 
117
+ def run_preprocess(image1, image2, preprocess_chk, seed_value):
118
+
119
+ seed_everything(seed_value)
120
 
121
  if preprocess_chk:
122
  rembg_session = rembg.new_session()
 
141
 
142
 
143
  @spaces.GPU
144
+ def run_pose_exploration_a(image1, image2, seed_value):
145
+
146
+ seed_everything(seed_value)
147
 
148
  image1 = image_to_tensor(image1).to(_device_)
149
  image2 = image_to_tensor(image2).to(_device_)
 
162
 
163
 
164
  @spaces.GPU
165
+ def run_pose_exploration_b(cam_vis, image1, image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters, seed_value):
166
+
167
+ seed_everything(seed_value)
168
 
169
  noise = np.random.randn(probe_bsz, 4, 32, 32)
170
 
 
213
 
214
 
215
  @spaces.GPU
216
+ def run_pose_refinement(cam_vis, image1, image2, anchor_polar, explored_sph, refine_iters, seed_value):
217
+
218
+ seed_everything(seed_value)
219
 
220
  cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
221
 
 
304
  with gr.Accordion('Advanced options', open=False):
305
  probe_bsz = gr.Slider(4, 32, value=16, step=4, label='Probe Batch Size')
306
  adj_bsz = gr.Slider(1, 8, value=4, step=1, label='Adjust Batch Size')
307
+ adj_iters = gr.Slider(1, 20, value=8, step=1, label='Adjust Iterations')
308
+ seed_value = gr.Number(value=0, label="Seed Value", precision=0)
309
 
310
  with gr.Row():
311
  run_btn = gr.Button('Estimate', variant='primary', interactive=True)
 
379
 
380
  run_btn.click(
381
  fn=run_preprocess,
382
+ inputs=[input_image1, input_image2, preprocess_chk, seed_value],
383
  outputs=[processed_image1, processed_image2],
384
  ).success(
385
+ fn=run_pose_exploration_a,
386
+ inputs=[processed_image1, processed_image2, seed_value],
387
  outputs=[elevs, elev_ranges, vis_output]
388
  ).success(
389
  fn=partial(run_pose_exploration_b, cam_vis),
390
+ inputs=[processed_image1, processed_image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters, seed_value],
391
  outputs=[anchor_polar, explored_sph, vis_output, refine_btn]
392
  )
393
 
394
  refine_btn.click(
395
  fn=partial(run_pose_refinement, cam_vis),
396
+ inputs=[processed_image1, processed_image2, anchor_polar, explored_sph, refine_iters, seed_value],
397
  outputs=[refined_sph, vis_output]
398
  )
399