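"""Gradio demo: land-cover segmentation of satellite images with a
segmentation_models_pytorch U-Net (ResNet-34 encoder, 7 classes)."""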
import os

import gradio as gr
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import draw_segmentation_masks
# training-time settings kept for reference; at inference only "downsize_res",
# "model_architecture" and "model_config" are actually used
config = {
    "downsize_res": 512,
    "batch_size": 6,
    "epochs": 30,
    "lr": 3e-4,
    "model_architecture": "Unet",
    "model_config": {
        "encoder_name": "resnet34",
        "encoder_weights": "imagenet",
        "in_channels": 3,
        "classes": 7,
    },
}
# one RGB color per class index (cyan, yellow, magenta, green, blue, white,
# black); this palette appears to match the DeepGlobe land-cover classes
# (urban, agriculture, rangeland, forest, water, barren, unknown)
class_rgb_colors = [
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 255),
    (0, 0, 0),
]
def label_to_onehot(mask, num_classes):
    """Convert an integer label mask to a boolean one-hot mask.

    Accepts a (H, W) mask or a (N, H, W) batch and returns
    (num_classes, H, W) or (N, num_classes, H, W) respectively.
    """
    dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
    return torch.permute(
        F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
        dims_p,
    )
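# e.g. label_to_onehot(torch.zeros(4, 4, dtype=torch.long), 7)
# -> bool tensor of shape (7, 4, 4)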
cp_path = "CP_epoch20.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load model: instantiate the architecture named in the config, then restore
# the checkpoint weights
model_architecture = getattr(smp, config["model_architecture"])
model = model_architecture(**config["model_config"])
model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
model.to(device)
model.eval()
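# eval() freezes batch-norm statistics and disables dropout for inference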
# mean = [0.4085, 0.3798, 0.2822]
# std = [0.1410, 0.1051, 0.0927]

# transforms: ToTensor converts a PIL HWC uint8 image to a CHW float
# tensor in [0, 1]; Resize scales the smaller edge to downsize_res
downsize_t = transforms.Resize(config["downsize_res"], antialias=True)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.Normalize(mean, std),
    ]
)
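# Normalize is left commented out, presumably to match how the checkpoint was
# trained; enable it only if the model saw normalized inputs during training.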
def get_overlay(sat_img, preds, alpha):
    """Blend per-class masks over a uint8 (3, H, W) image.

    alpha is the mask opacity: 0 is fully transparent, 1 is opaque.
    """
    masks = preds.squeeze()
    masks = label_to_onehot(masks, config["model_config"]["classes"])
    overlay = draw_segmentation_masks(
        sat_img, masks=masks, alpha=alpha, colors=class_rgb_colors
    )
    return overlay
def segment(numpy_arr):
    sat_img_arr = torch.from_numpy(numpy_arr)
    sat_img_arr = torch.permute(sat_img_arr, (2, 0, 1))
    sat_img_pil = transforms.functional.to_pil_image(sat_img_arr)

    # preprocess: float tensor in [0, 1], add batch dim, downsize for the model
    X = transform(sat_img_pil).unsqueeze(0)
    X = X.to(device)
    X_down = downsize_t(X)

    # forward pass; no_grad avoids building the autograd graph at inference
    with torch.no_grad():
        logits = model(X_down)
    preds = torch.argmax(logits, 1)

    # resize predictions back to the original resolution; nearest-neighbor
    # interpolation keeps the integer class labels intact (bilinear would
    # fail on an integer tensor and blur class boundaries)
    preds = transforms.functional.resize(
        preds, X.shape[-2:], interpolation=transforms.InterpolationMode.NEAREST
    )

    # get RGB-formatted images: a translucent overlay and the opaque raw masks
    overlay = get_overlay(sat_img_arr, preds, 0.2)
    raw_masks = get_overlay(torch.zeros_like(sat_img_arr), preds, 1)
    raw_masks = torch.permute(raw_masks, (1, 2, 0))
    overlay = torch.permute(overlay, (1, 2, 0))
    return raw_masks.numpy(), overlay.numpy()
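# Gradio passes the uploaded image to segment() as an (H, W, 3) uint8 numpy
# array and renders the two returned arrays as output images.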
i = gr.Image()  # gr.inputs.Image is deprecated; gr.Image is the current API
o = [gr.Image(), gr.Image()]

images_dir = "sample_sat_images"
image_ids = os.listdir(images_dir)
examples = [os.path.join(images_dir, image_id) for image_id in image_ids]
title = "Satellite Images Landcover Segmentation"
description = "Upload a satellite image or select one of the examples to segment it"
iface = gr.Interface(
    segment, i, o, examples=examples, title=title, description=description
)
iface.launch(debug=True)
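# debug=True keeps the script blocking and prints stack traces to the
# console, which is handy while developing the Space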