DavidFM43 commited on
Commit
3a7ac79
·
1 Parent(s): f6e32c6

upload demo

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.jpg filter=lfs diff=lfs merge=lfs -text
CP_epoch20.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddd8fde0d7ee3ccc5714c0b4c1f3cd31cbfaa8edf23676c92187e504b9bc58b5
3
+ size 97921621
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.utils import draw_segmentation_masks
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ import segmentation_models_pytorch as smp
6
+ import gradio as gr
7
+ import os
8
+
9
+ config = {
10
+ "downsize_res": 512,
11
+ "batch_size": 6,
12
+ "epochs": 30,
13
+ "lr": 3e-4,
14
+ "model_architecture": "Unet",
15
+ "model_config": {
16
+ "encoder_name": "resnet34",
17
+ "encoder_weights": "imagenet",
18
+ "in_channels": 3,
19
+ "classes": 7,
20
+ },
21
+ }
22
+
23
+ class_rgb_colors = [
24
+ (0, 255, 255),
25
+ (255, 255, 0),
26
+ (255, 0, 255),
27
+ (0, 255, 0),
28
+ (0, 0, 255),
29
+ (255, 255, 255),
30
+ (0, 0, 0),
31
+ ]
32
+
33
+
34
+ def label_to_onehot(mask, num_classes):
35
+ dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
36
+ return torch.permute(
37
+ F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
38
+ dims_p,
39
+ )
40
+
41
+
42
+ cp_path = "CP_epoch20.pth"
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+
45
+ # load model
46
+ model_architecture = getattr(smp, config["model_architecture"])
47
+ model = model_architecture(**config["model_config"])
48
+ model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
49
+ model.to(device)
50
+ model.eval()
51
+
52
+
53
+ # mean = [0.4085, 0.3798, 0.2822]
54
+ # std = [0.1410, 0.1051, 0.0927]
55
+ # transforms
56
+ downsize_t = transforms.Resize(config["downsize_res"], antialias=True)
57
+ transform = transforms.Compose(
58
+ [
59
+ transforms.ToTensor(),
60
+ # transforms.Normalize(mean, std),
61
+ ]
62
+ )
63
+
64
+
65
+ def get_overlay(sat_img, preds, alpha):
66
+ class_rgb_colors = [
67
+ (0, 255, 255),
68
+ (255, 255, 0),
69
+ (255, 0, 255),
70
+ (0, 255, 0),
71
+ (0, 0, 255),
72
+ (255, 255, 255),
73
+ (0, 0, 0),
74
+ ]
75
+ masks = preds.squeeze()
76
+ masks = label_to_onehot(masks, 7)
77
+ overlay = draw_segmentation_masks(
78
+ sat_img, masks=masks, alpha=alpha, colors=class_rgb_colors
79
+ )
80
+ return overlay
81
+
82
+
83
+ def segment(numpy_arr):
84
+ sat_img_arr = torch.from_numpy(numpy_arr)
85
+ sat_img_arr = torch.permute(sat_img_arr, (2, 0, 1))
86
+ sat_img_pil = transforms.functional.to_pil_image(sat_img_arr)
87
+
88
+ # preprocess image
89
+ X = transform(sat_img_pil).unsqueeze(0)
90
+ X = X.to(device)
91
+ X_down = downsize_t(X)
92
+ # forward pass
93
+ logits = model(X_down)
94
+ preds = torch.argmax(logits, 1).detach()
95
+ # resize to evaluate with the original image
96
+ preds = transforms.functional.resize(preds, X.shape[-2:], antialias=True)
97
+
98
+ # ger rbg formatted images
99
+ overlay = get_overlay(sat_img_arr, preds, 0.2)
100
+ raw_masks = get_overlay(torch.zeros_like(sat_img_arr), preds, 1)
101
+ raw_masks = torch.permute(raw_masks, (1, 2, 0))
102
+ overlay = torch.permute(overlay, (1, 2, 0))
103
+
104
+ return raw_masks.numpy(), overlay.numpy()
105
+
106
+
107
+ i = gr.inputs.Image()
108
+ o = [gr.Image(), gr.Image()]
109
+ images_dir = "sample_sat_images/"
110
+
111
+ image_ids = os.listdir(images_dir)
112
+ examples = [f"{images_dir}/{image_id}" for image_id in image_ids]
113
+ title = "Satellite Images Landcover Segmentation"
114
+ description = "Upload an image or select from examples to segment"
115
+
116
+ iface = gr.Interface(
117
+ segment, i, o, examples=examples, title=title, description=description
118
+ )
119
+ iface.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ segmentation-models-pytorch
4
+ gradio
sample_sat_images/118757_sat.jpg ADDED

Git LFS Details

  • SHA256: 71f5be35839d4563db38d5ccc9122ae45d49c21796c0c3da823bfe05fb6f683c
  • Pointer size: 132 Bytes
  • Size of remote file: 2.81 MB
sample_sat_images/26961_sat.jpg ADDED

Git LFS Details

  • SHA256: a6b9726624a5a0bf7c9ce5254933ebfa16e826826a71c3abde4739190d63481f
  • Pointer size: 132 Bytes
  • Size of remote file: 2.76 MB
sample_sat_images/6390_sat.jpg ADDED

Git LFS Details

  • SHA256: 37a280ed2750173772b279c3bf79aac062286797e681bacaca5983b6a3796657
  • Pointer size: 132 Bytes
  • Size of remote file: 2.64 MB
sample_sat_images/999380_sat.jpg ADDED

Git LFS Details

  • SHA256: da30b6992e759fd8c6ed1a75df655f359284090b25e0513f3c504e92c8e5e29d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.96 MB