DavidFM43 commited on
Commit
ef005af
·
1 Parent(s): d38fb36

small bug fix

Browse files
Files changed (1) hide show
  1. app.py +25 -18
app.py CHANGED
@@ -22,7 +22,7 @@ config = {
22
  },
23
  }
24
 
25
- class_rgb_colors = [
26
  (0, 255, 255),
27
  (255, 255, 0),
28
  (255, 0, 255),
@@ -33,16 +33,6 @@ class_rgb_colors = [
33
  ]
34
 
35
 
36
- def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
37
- """Transforms a tensor from label encoding to one hot encoding in boolean dtype"""
38
-
39
- dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
40
- return torch.permute(
41
- F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
42
- dims_p,
43
- )
44
-
45
-
46
  cp_path = "CP_epoch20.pth"
47
  device = "cuda" if torch.cuda.is_available() else "cpu"
48
 
@@ -63,13 +53,21 @@ transform = transforms.Compose(
63
  )
64
 
65
 
 
 
 
 
 
 
 
 
 
 
66
  def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
67
  """Generates the segmentation ovelay for an satellite image"""
68
 
69
  masks = label_to_onehot(preds.squeeze(), 7)
70
- overlay = draw_segmentation_masks(
71
- image, masks=masks, alpha=alpha, colors=class_rgb_colors
72
- )
73
  return overlay
74
 
75
 
@@ -77,11 +75,14 @@ def hwc_to_chw(image_tensor: torch.Tensor) -> torch.Tensor:
77
  return torch.permute(image_tensor, (2, 0, 1))
78
 
79
 
 
 
 
 
80
  def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
81
  image_tensor = torch.from_numpy(satellite_image)
82
  image_tensor = hwc_to_chw(image_tensor)
83
  pil_image = transforms.functional.to_pil_image(image_tensor)
84
-
85
  # preprocess image
86
  X = transform(pil_image).unsqueeze(0)
87
  X = X.to(device)
@@ -92,8 +93,8 @@ def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
92
  # resize to evaluate with the original image
93
  preds = transforms.functional.resize(preds, X.shape[-2:], antialias=True)
94
  # get rbg formatted images
95
- segmentation_overlay = hwc_to_chw(get_overlay(image_tensor, preds, 0.2)).numpy()
96
- raw_segmentation = hwc_to_chw(
97
  get_overlay(torch.zeros_like(image_tensor), preds, 1)
98
  ).numpy()
99
 
@@ -108,6 +109,12 @@ title = "Satellite Images Landcover Segmentation"
108
  description = "Upload an image or select from examples to segment"
109
 
110
  iface = gr.Interface(
111
- segment, i, o, examples=examples, title=title, description=description, cache_examples=True
 
 
 
 
 
 
112
  )
113
  iface.launch()
 
22
  },
23
  }
24
 
25
+ colors = [
26
  (0, 255, 255),
27
  (255, 255, 0),
28
  (255, 0, 255),
 
33
  ]
34
 
35
 
 
 
 
 
 
 
 
 
 
 
36
  cp_path = "CP_epoch20.pth"
37
  device = "cuda" if torch.cuda.is_available() else "cpu"
38
 
 
53
  )
54
 
55
 
56
+ def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
57
+ """Transforms a tensor from label encoding to one hot encoding in boolean dtype"""
58
+
59
+ dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
60
+ return torch.permute(
61
+ F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
62
+ dims_p,
63
+ )
64
+
65
+
66
  def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
67
  """Generates the segmentation ovelay for an satellite image"""
68
 
69
  masks = label_to_onehot(preds.squeeze(), 7)
70
+ overlay = draw_segmentation_masks(image, masks=masks, alpha=alpha, colors=colors)
 
 
71
  return overlay
72
 
73
 
 
75
  return torch.permute(image_tensor, (2, 0, 1))
76
 
77
 
78
+ def chw_to_hwc(image_tensor: torch.Tensor) -> torch.Tensor:
79
+ return torch.permute(image_tensor, (1, 2, 0))
80
+
81
+
82
  def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
83
  image_tensor = torch.from_numpy(satellite_image)
84
  image_tensor = hwc_to_chw(image_tensor)
85
  pil_image = transforms.functional.to_pil_image(image_tensor)
 
86
  # preprocess image
87
  X = transform(pil_image).unsqueeze(0)
88
  X = X.to(device)
 
93
  # resize to evaluate with the original image
94
  preds = transforms.functional.resize(preds, X.shape[-2:], antialias=True)
95
  # get rbg formatted images
96
+ segmentation_overlay = chw_to_hwc(get_overlay(image_tensor, preds, 0.2)).numpy()
97
+ raw_segmentation = chw_to_hwc(
98
  get_overlay(torch.zeros_like(image_tensor), preds, 1)
99
  ).numpy()
100
 
 
109
  description = "Upload an image or select from examples to segment"
110
 
111
  iface = gr.Interface(
112
+ segment,
113
+ i,
114
+ o,
115
+ examples=examples,
116
+ title=title,
117
+ description=description,
118
+ cache_examples=True,
119
  )
120
  iface.launch()