upload demo
- .gitattributes +1 -0
- CP_epoch20.pth +3 -0
- app.py +119 -0
- requirements.txt +4 -0
- sample_sat_images/118757_sat.jpg +3 -0
- sample_sat_images/26961_sat.jpg +3 -0
- sample_sat_images/6390_sat.jpg +3 -0
- sample_sat_images/999380_sat.jpg +3 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
CP_epoch20.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddd8fde0d7ee3ccc5714c0b4c1f3cd31cbfaa8edf23676c92187e504b9bc58b5
+size 97921621
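The checkpoint above is only a Git LFS pointer; the actual ~98 MB weights file is fetched by `git lfs pull`. A minimal sketch, assuming the file has been pulled and `segmentation_models_pytorch` is installed, to confirm the checkpoint matches the Unet configuration declared in app.py:

```python
import torch
import segmentation_models_pytorch as smp

# Build the same architecture as app.py; encoder_weights=None skips the
# ImageNet download, since the checkpoint overwrites all weights anyway.
model = smp.Unet(encoder_name="resnet34", encoder_weights=None,
                 in_channels=3, classes=7)
state = torch.load("CP_epoch20.pth", map_location="cpu")
model.load_state_dict(state)  # raises if keys or shapes do not match
print(f"loaded {sum(p.numel() for p in model.parameters()):,} parameters")
```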
app.py ADDED
@@ -0,0 +1,119 @@
+import os
+
+import torch
+import torch.nn.functional as F
+import segmentation_models_pytorch as smp
+import gradio as gr
+from torchvision import transforms
+from torchvision.utils import draw_segmentation_masks
+
+config = {
+    "downsize_res": 512,
+    "batch_size": 6,
+    "epochs": 30,
+    "lr": 3e-4,
+    "model_architecture": "Unet",
+    "model_config": {
+        "encoder_name": "resnet34",
+        "encoder_weights": "imagenet",
+        "in_channels": 3,
+        "classes": 7,
+    },
+}
+
+# one RGB color per landcover class
+class_rgb_colors = [
+    (0, 255, 255),
+    (255, 255, 0),
+    (255, 0, 255),
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 255),
+    (0, 0, 0),
+]
+
+
+def label_to_onehot(mask, num_classes):
+    """Convert an (H, W) or (N, H, W) label map to boolean one-hot masks."""
+    dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
+    return torch.permute(
+        F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
+        dims_p,
+    )
+
+
+cp_path = "CP_epoch20.pth"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# load model
+model_architecture = getattr(smp, config["model_architecture"])
+model = model_architecture(**config["model_config"])
+model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
+model.to(device)
+model.eval()
+
+# transforms
+# mean = [0.4085, 0.3798, 0.2822]
+# std = [0.1410, 0.1051, 0.0927]
+downsize_t = transforms.Resize(config["downsize_res"], antialias=True)
+transform = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        # transforms.Normalize(mean, std),
+    ]
+)
+
+
+def get_overlay(sat_img, preds, alpha):
+    """Draw per-class masks over a uint8 (C, H, W) image tensor."""
+    masks = preds.squeeze()
+    masks = label_to_onehot(masks, len(class_rgb_colors))
+    overlay = draw_segmentation_masks(
+        sat_img, masks=masks, alpha=alpha, colors=class_rgb_colors
+    )
+    return overlay
+
+
+def segment(numpy_arr):
+    sat_img_arr = torch.from_numpy(numpy_arr)
+    sat_img_arr = torch.permute(sat_img_arr, (2, 0, 1))
+    sat_img_pil = transforms.functional.to_pil_image(sat_img_arr)
+
+    # preprocess image
+    X = transform(sat_img_pil).unsqueeze(0)
+    X = X.to(device)
+    X_down = downsize_t(X)
+    # forward pass
+    with torch.no_grad():
+        logits = model(X_down)
+    preds = torch.argmax(logits, 1)
+    # resize back to the input resolution; nearest-neighbor keeps the
+    # outputs valid class indices (bilinear would blend labels)
+    preds = transforms.functional.resize(
+        preds, X.shape[-2:], interpolation=transforms.InterpolationMode.NEAREST
+    )
+
+    # get RGB-formatted images
+    overlay = get_overlay(sat_img_arr, preds, 0.2)
+    raw_masks = get_overlay(torch.zeros_like(sat_img_arr), preds, 1)
+    raw_masks = torch.permute(raw_masks, (1, 2, 0))
+    overlay = torch.permute(overlay, (1, 2, 0))
+
+    return raw_masks.numpy(), overlay.numpy()
+
+
+i = gr.Image()
+o = [gr.Image(), gr.Image()]
+images_dir = "sample_sat_images"
+
+image_ids = os.listdir(images_dir)
+examples = [os.path.join(images_dir, image_id) for image_id in image_ids]
+title = "Satellite Images Landcover Segmentation"
+description = "Upload an image or select from examples to segment"
+
+iface = gr.Interface(
+    segment, i, o, examples=examples, title=title, description=description
+)
+
+if __name__ == "__main__":
+    iface.launch(debug=True)
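Because launching the UI now sits behind the `__main__` guard, `segment` can also be exercised headlessly. A sketch under the assumption that one of the bundled sample images is present; the output filenames are arbitrary:

```python
import numpy as np
from PIL import Image  # Pillow ships as a torchvision dependency

from app import segment  # importing runs the module-level model setup once

img = np.array(Image.open("sample_sat_images/6390_sat.jpg").convert("RGB"))
raw_masks, overlay = segment(img)  # two (H, W, 3) uint8 arrays
Image.fromarray(raw_masks).save("masks.png")
Image.fromarray(overlay).save("overlay.png")
```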
requirements.txt ADDED
@@ -0,0 +1,4 @@
+torch
+torchvision
+segmentation-models-pytorch
+gradio
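The requirements are unpinned, so the Space picks up whatever versions resolve at build time. A small sketch to print what actually got installed, useful when reproducing behavior locally:

```python
# Print the versions the unpinned requirements resolved to.
import gradio
import segmentation_models_pytorch
import torch
import torchvision

for mod in (torch, torchvision, segmentation_models_pytorch, gradio):
    print(mod.__name__, getattr(mod, "__version__", "unknown"))
```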
sample_sat_images/118757_sat.jpg ADDED (Git LFS)
sample_sat_images/26961_sat.jpg ADDED (Git LFS)
sample_sat_images/6390_sat.jpg ADDED (Git LFS)
sample_sat_images/999380_sat.jpg ADDED (Git LFS)