Vishu26 commited on
Commit
c9f086f
·
1 Parent(s): ec04505
Files changed (1) hide show
  1. app.py +24 -3
app.py CHANGED
@@ -13,6 +13,8 @@ def get_index_of_element_containing_word(lst, word):
13
  return indices[0] if indices else -1
14
 
15
  pred_global = None
 
 
16
 
17
  stl_preds = np.load("stl_species.npy")
18
  df = pd.read_csv("gbif_full_filtered.csv")
@@ -36,30 +38,45 @@ def update_fn(val):
36
  return gr.Dropdown(label="Name", choices=obs, interactive=True)
37
 
38
  def text_fn(taxon, name):
39
- global pred_global
40
 
41
  species_index = get_index_of_element_containing_word(obs, name)
42
  preds = np.flip(stl_preds[:, species_index].reshape(510, 510), 1)
43
 
44
  pred_global = preds
 
45
  cmap = plt.get_cmap('plasma')
46
 
47
  rgba_img = cmap(preds)
48
  rgb_img = np.delete(rgba_img, 3, 2)
49
- blend = Image.blend(stl_base, Image.fromarray((rgb_img * 255).astype(np.uint8)), 0.5)
50
  rgb_img = np.array(blend)
51
  #return gr.Image(preds, label="Predicted Heatmap", visible=True)
52
  return rgb_img
53
 
54
  def thresh_fn(val):
55
- global pred_global
56
  preds = deepcopy(pred_global)
57
  preds[preds<val] = 0
58
  preds[preds>=val] = 1
 
59
  cmap = plt.get_cmap('plasma')
60
 
61
  rgba_img = cmap(preds)
62
  rgb_img = np.delete(rgba_img, 3, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  return rgb_img
64
 
65
  with gr.Blocks() as demo:
@@ -79,10 +96,14 @@ with gr.Blocks() as demo:
79
  with gr.Row():
80
  slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
81
 
 
 
 
82
  with gr.Row():
83
  pred = gr.Image(label="Predicted Heatmap", visible=True)
84
 
85
  check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
86
  slider.change(thresh_fn, slider, outputs=pred)
 
87
 
88
  demo.launch()
 
13
  return indices[0] if indices else -1
14
 
15
  pred_global = None
16
+ alpha_global = 0.5
17
+ alpha_image = None
18
 
19
  stl_preds = np.load("stl_species.npy")
20
  df = pd.read_csv("gbif_full_filtered.csv")
 
38
  return gr.Dropdown(label="Name", choices=obs, interactive=True)
39
 
40
  def text_fn(taxon, name):
41
+ global pred_global, alpha_global, alpha_image
42
 
43
  species_index = get_index_of_element_containing_word(obs, name)
44
  preds = np.flip(stl_preds[:, species_index].reshape(510, 510), 1)
45
 
46
  pred_global = preds
47
+ alpha_image = preds
48
  cmap = plt.get_cmap('plasma')
49
 
50
  rgba_img = cmap(preds)
51
  rgb_img = np.delete(rgba_img, 3, 2)
52
+ blend = Image.blend(stl_base, Image.fromarray((rgb_img * 255).astype(np.uint8)), alpha_global)
53
  rgb_img = np.array(blend)
54
  #return gr.Image(preds, label="Predicted Heatmap", visible=True)
55
  return rgb_img
56
 
57
  def thresh_fn(val):
58
+ global pred_global, alpha_global, alpha_image
59
  preds = deepcopy(pred_global)
60
  preds[preds<val] = 0
61
  preds[preds>=val] = 1
62
+ alpha_image = deepcopy(preds)
63
  cmap = plt.get_cmap('plasma')
64
 
65
  rgba_img = cmap(preds)
66
  rgb_img = np.delete(rgba_img, 3, 2)
67
+ blend = Image.blend(stl_base, Image.fromarray((rgb_img * 255).astype(np.uint8)), alpha_global)
68
+ rgb_img = np.array(blend)
69
+ return rgb_img
70
+
71
+ def alpha_fn(val):
72
+ global pred_global, alpha_global, alpha_image
73
+ alpha_global = val
74
+ preds = deepcopy(alpha_image)
75
+ cmap = plt.get_cmap('plasma')
76
+ rgba_img = cmap(preds)
77
+ rgb_img = np.delete(rgba_img, 3, 2)
78
+ blend = Image.blend(stl_base, Image.fromarray((rgb_img * 255).astype(np.uint8)), alpha_global)
79
+ rgb_img = np.array(blend)
80
  return rgb_img
81
 
82
  with gr.Blocks() as demo:
 
96
  with gr.Row():
97
  slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
98
 
99
+ with gr.Row():
100
+ alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Image Transparency")
101
+
102
  with gr.Row():
103
  pred = gr.Image(label="Predicted Heatmap", visible=True)
104
 
105
  check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
106
  slider.change(thresh_fn, slider, outputs=pred)
107
+ alpha.change(alpha_fn, alpha, outputs=pred)
108
 
109
  demo.launch()