Liangyu commited on
Commit
00d78be
Β·
1 Parent(s): ada7a74

resolution limit to 800

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -45,6 +45,9 @@ Check out our [GitHub repo](https://github.com/Jingkang50/OpenPSG) and [official
45
  <img id="visualzation" src="https://github.com/Jingkang50/OpenPSG/blob/main/assets/psgtr_long.gif?raw=true" alt="visualzation" style="width:100%">
46
  </div>
47
  </div>
 
 
 
48
  '''
49
  FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=c-liangyu.openpsg" alt="visitor badge" />'
50
 
@@ -64,7 +67,7 @@ def parse_args() -> argparse.Namespace:
64
  def update_input_image(image: np.ndarray) -> dict:
65
  if image is None:
66
  return gr.Image.update(value=None)
67
- scale = 1500 / max(image.shape[:2])
68
  if scale < 1:
69
  image = cv2.resize(image, None, fx=scale, fy=scale)
70
  return gr.Image.update(value=image)
@@ -114,7 +117,7 @@ def main():
114
  run_button = gr.Button(value='Run')
115
  with gr.Column():
116
  with gr.Row():
117
- result = gr.Gallery(label='Result', type='numpy')
118
 
119
  with gr.Row():
120
  paths = sorted(pathlib.Path('images').rglob('*.jpg'))
 
45
  <img id="visualzation" src="https://github.com/Jingkang50/OpenPSG/blob/main/assets/psgtr_long.gif?raw=true" alt="visualzation" style="width:100%">
46
  </div>
47
  </div>
48
+
49
+ Inference takes 10-30 seconds per image. The model is PSGTR (60 epochs).
50
+
51
  '''
52
  FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=c-liangyu.openpsg" alt="visitor badge" />'
53
 
 
67
  def update_input_image(image: np.ndarray) -> dict:
68
  if image is None:
69
  return gr.Image.update(value=None)
70
+ scale = 800 / max(image.shape[:2])
71
  if scale < 1:
72
  image = cv2.resize(image, None, fx=scale, fy=scale)
73
  return gr.Image.update(value=image)
 
117
  run_button = gr.Button(value='Run')
118
  with gr.Column():
119
  with gr.Row():
120
+ result = gr.Gallery(label='PSGTR Result', type='numpy')
121
 
122
  with gr.Row():
123
  paths = sorted(pathlib.Path('images').rglob('*.jpg'))