dakkoong commited on
Commit
83b7fa3
ยท
1 Parent(s): 771616a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -18
app.py CHANGED
@@ -1,12 +1,10 @@
1
  import gradio as gr
2
-
3
- from matplotlib import gridspec
4
- import matplotlib.pyplot as plt
5
  import numpy as np
6
  from PIL import Image
7
  import tensorflow as tf
8
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
9
 
 
10
  feature_extractor = SegformerFeatureExtractor.from_pretrained(
11
  "nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
12
  )
@@ -85,27 +83,18 @@ def sepia(input_img):
85
  logits = tf.transpose(logits, [0, 2, 3, 1])
86
  logits = tf.image.resize(
87
  logits, input_img.size[::-1]
88
- ) # We reverse the shape of `image` because `image.size` returns width and height.
89
  seg = tf.math.argmax(logits, axis=-1)[0]
90
 
91
- color_seg = np.zeros(
92
- (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
93
- ) # height, width, 3
94
- for label, color in enumerate(colormap):
95
- color_seg[seg.numpy() == label, :] = color
96
-
97
- # Show image + mask
98
- pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
99
- pred_img = pred_img.astype(np.uint8)
100
-
101
- fig = draw_plot(pred_img, seg)
102
- return fig
103
 
 
104
  demo = gr.Interface(fn=sepia,
105
  inputs=gr.Image(shape=(800, 600)),
106
- outputs=['plot'],
107
  examples=["cityoutdoor-1.jpg", "cityoutdoor-2.jpg", "cityoutdoor-3.jpg"],
108
  allow_flagging='never')
109
 
110
-
111
  demo.launch()
 
1
  import gradio as gr
 
 
 
2
  import numpy as np
3
  from PIL import Image
4
  import tensorflow as tf
5
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
6
 
7
+
8
  feature_extractor = SegformerFeatureExtractor.from_pretrained(
9
  "nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
10
  )
 
83
  logits = tf.transpose(logits, [0, 2, 3, 1])
84
  logits = tf.image.resize(
85
  logits, input_img.size[::-1]
86
+ )
87
  seg = tf.math.argmax(logits, axis=-1)[0]
88
 
89
+ # Return segmentation label image instead of Matplotlib Figure
90
+ return seg.numpy()
 
 
 
 
 
 
 
 
 
 
91
 
92
+ # Gradio Interface ์„ค์ •
93
  demo = gr.Interface(fn=sepia,
94
  inputs=gr.Image(shape=(800, 600)),
95
+ outputs=['label'], # 'plot'์—์„œ 'label'๋กœ ๋ณ€๊ฒฝ
96
  examples=["cityoutdoor-1.jpg", "cityoutdoor-2.jpg", "cityoutdoor-3.jpg"],
97
  allow_flagging='never')
98
 
99
+ # Gradio ์‹คํ–‰
100
  demo.launch()