DavidFM43 commited on
Commit
bc4c15d
·
1 Parent(s): 3a7ac79

Clean code

Browse files
Files changed (1) hide show
  1. app.py +27 -34
app.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn.functional as F
4
  from torchvision import transforms
5
  import segmentation_models_pytorch as smp
6
  import gradio as gr
 
7
  import os
8
 
9
  config = {
@@ -31,7 +32,9 @@ class_rgb_colors = [
31
  ]
32
 
33
 
34
- def label_to_onehot(mask, num_classes):
 
 
35
  dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
36
  return torch.permute(
37
  F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
@@ -50,43 +53,36 @@ model.to(device)
50
  model.eval()
51
 
52
 
53
- # mean = [0.4085, 0.3798, 0.2822]
54
- # std = [0.1410, 0.1051, 0.0927]
55
  # transforms
56
  downsize_t = transforms.Resize(config["downsize_res"], antialias=True)
57
  transform = transforms.Compose(
58
  [
59
  transforms.ToTensor(),
60
- # transforms.Normalize(mean, std),
61
  ]
62
  )
63
 
64
 
65
- def get_overlay(sat_img, preds, alpha):
66
- class_rgb_colors = [
67
- (0, 255, 255),
68
- (255, 255, 0),
69
- (255, 0, 255),
70
- (0, 255, 0),
71
- (0, 0, 255),
72
- (255, 255, 255),
73
- (0, 0, 0),
74
- ]
75
- masks = preds.squeeze()
76
- masks = label_to_onehot(masks, 7)
77
  overlay = draw_segmentation_masks(
78
- sat_img, masks=masks, alpha=alpha, colors=class_rgb_colors
79
  )
80
  return overlay
81
 
82
 
83
- def segment(numpy_arr):
84
- sat_img_arr = torch.from_numpy(numpy_arr)
85
- sat_img_arr = torch.permute(sat_img_arr, (2, 0, 1))
86
- sat_img_pil = transforms.functional.to_pil_image(sat_img_arr)
 
 
 
 
87
 
88
  # preprocess image
89
- X = transform(sat_img_pil).unsqueeze(0)
90
  X = X.to(device)
91
  X_down = downsize_t(X)
92
  # forward pass
@@ -94,26 +90,23 @@ def segment(numpy_arr):
94
  preds = torch.argmax(logits, 1).detach()
95
  # resize to evaluate with the original image
96
  preds = transforms.functional.resize(preds, X.shape[-2:], antialias=True)
 
 
 
 
 
97
 
98
- # ger rbg formatted images
99
- overlay = get_overlay(sat_img_arr, preds, 0.2)
100
- raw_masks = get_overlay(torch.zeros_like(sat_img_arr), preds, 1)
101
- raw_masks = torch.permute(raw_masks, (1, 2, 0))
102
- overlay = torch.permute(overlay, (1, 2, 0))
103
-
104
- return raw_masks.numpy(), overlay.numpy()
105
 
106
 
107
  i = gr.inputs.Image()
108
  o = [gr.Image(), gr.Image()]
109
  images_dir = "sample_sat_images/"
110
-
111
- image_ids = os.listdir(images_dir)
112
- examples = [f"{images_dir}/{image_id}" for image_id in image_ids]
113
  title = "Satellite Images Landcover Segmentation"
114
  description = "Upload an image or select from examples to segment"
115
 
116
  iface = gr.Interface(
117
- segment, i, o, examples=examples, title=title, description=description
118
  )
119
- iface.launch(debug=True)
 
4
  from torchvision import transforms
5
  import segmentation_models_pytorch as smp
6
  import gradio as gr
7
+ import numpy as np
8
  import os
9
 
10
  config = {
 
32
  ]
33
 
34
 
35
+ def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
36
+ """Transforms a tensor from label encoding to one hot encoding in boolean dtype"""
37
+
38
  dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
39
  return torch.permute(
40
  F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
 
53
  model.eval()
54
 
55
 
 
 
56
  # transforms
57
  downsize_t = transforms.Resize(config["downsize_res"], antialias=True)
58
  transform = transforms.Compose(
59
  [
60
  transforms.ToTensor(),
 
61
  ]
62
  )
63
 
64
 
65
+ def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
66
+ """Generates the segmentation ovelay for an satellite image"""
67
+
68
+ masks = preds.squeeze().label_to_onehot(masks, 7)
 
 
 
 
 
 
 
 
69
  overlay = draw_segmentation_masks(
70
+ image, masks=masks, alpha=alpha, colors=class_rgb_colors
71
  )
72
  return overlay
73
 
74
 
75
+ def hwc_to_chw(image_tensor: torch.Tensor) -> torch.Tensor:
76
+ return torch.permute(image_tensor, (2, 0, 1))
77
+
78
+
79
+ def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
80
+ image_tensor = torch.from_numpy(satellite_image)
81
+ image_tensor = hwc_to_chw(image_tensor)
82
+ pil_image = transforms.functional.to_pil_image(image_tensor)
83
 
84
  # preprocess image
85
+ X = transform(pil_image).unsqueeze(0)
86
  X = X.to(device)
87
  X_down = downsize_t(X)
88
  # forward pass
 
90
  preds = torch.argmax(logits, 1).detach()
91
  # resize to evaluate with the original image
92
  preds = transforms.functional.resize(preds, X.shape[-2:], antialias=True)
93
+ # get rbg formatted images
94
+ segmentation_overlay = hwc_to_chw(get_overlay(image_tensor, preds, 0.2)).numpy()
95
+ raw_segmentation = hwc_to_chw(
96
+ get_overlay(torch.zeros_like(image_tensor), preds, 1)
97
+ ).numpy()
98
 
99
+ return raw_segmentation, segmentation_overlay
 
 
 
 
 
 
100
 
101
 
102
  i = gr.inputs.Image()
103
  o = [gr.Image(), gr.Image()]
104
  images_dir = "sample_sat_images/"
105
+ examples = [f"{images_dir}/{image_id}" for image_id in os.listdir(images_dir)]
 
 
106
  title = "Satellite Images Landcover Segmentation"
107
  description = "Upload an image or select from examples to segment"
108
 
109
  iface = gr.Interface(
110
+ segment, i, o, examples=examples, title=title, description=description, cache_examples=True
111
  )
112
+ iface.launch()