"""Gradio demo: land-cover segmentation of satellite images.

A segmentation_models_pytorch model described by ``config`` is loaded from
``cp_path``. For each input image the app returns a colour-coded class mask
and the same mask blended over the input image.
"""

import os

import gradio as gr
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import draw_segmentation_masks

config = {
    "downsize_res": 512,
    # batch_size / epochs / lr are not read anywhere in this inference script.
    "batch_size": 6,
    "epochs": 30,
    "lr": 3e-4,
    "model_architecture": "Unet",
    "model_config": {
        "encoder_name": "resnet34",
        "encoder_weights": "imagenet",
        "in_channels": 3,
        "classes": 7,
    },
}

# One RGB colour per class index: class i is drawn with class_rgb_colors[i].
class_rgb_colors = [
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 255),
    (0, 0, 0),
]

NUM_CLASSES = config["model_config"]["classes"]

# smp U-Net with the default encoder depth (5) halves the resolution five
# times, so the model input height and width must be divisible by 2**5 = 32.
_SIZE_DIVISOR = 32


def label_to_onehot(mask, num_classes):
    """Convert an integer label map into a stack of boolean per-class masks.

    Args:
        mask: tensor of class indices, shaped (H, W) or (N, H, W); every value
            must lie in ``[0, num_classes)``.
        num_classes: number of classes C.

    Returns:
        Bool tensor shaped (C, H, W) for 2-D input or (N, C, H, W) for 3-D input.
    """
    # F.one_hot appends the class axis last; move it to the channel position.
    dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
    return torch.permute(
        F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
        dims_p,
    )


cp_path = "CP_epoch20.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load model
model_architecture = getattr(smp, config["model_architecture"])
# load_state_dict below is strict, so every parameter is replaced by the
# checkpoint; skip downloading pretrained encoder weights that would be
# overwritten immediately.
model = model_architecture(**{**config["model_config"], "encoder_weights": None})
model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
model.to(device)
model.eval()

# NOTE(review): normalisation is disabled; only re-enable these stats if the
# checkpoint was trained with them.
# mean = [0.4085, 0.3798, 0.2822]
# std = [0.1410, 0.1051, 0.0927]

# transforms
# Kept for backward compatibility; segment() uses _resize_for_model, which
# also enforces the divisible-by-32 constraint.
downsize_t = transforms.Resize(config["downsize_res"], antialias=True)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.Normalize(mean, std),
    ]
)


def _resize_for_model(x):
    """Resize ``x`` (..., H, W) so its short side is about ``downsize_res`` and both sides divide by 32."""
    height, width = x.shape[-2:]
    scale = config["downsize_res"] / min(height, width)

    def _fit(side):
        # Round to the nearest multiple of the divisor, never below one multiple.
        return max(_SIZE_DIVISOR, round(side * scale / _SIZE_DIVISOR) * _SIZE_DIVISOR)

    return transforms.functional.resize(x, [_fit(height), _fit(width)], antialias=True)


def get_overlay(sat_img, preds, alpha):
    """Paint every predicted class onto ``sat_img`` with its class colour.

    Args:
        sat_img: uint8 image tensor shaped (3, H, W) to draw on.
        preds: class-index tensor holding one (H, W) label map, optionally
            with leading singleton dimensions (e.g. (1, H, W)).
        alpha: mask opacity in [0, 1]; 1 paints the class colours opaquely.

    Returns:
        uint8 tensor shaped (3, H, W), on the same device as ``sat_img``.
    """
    # reshape (rather than squeeze) keeps H or W intact even when it equals 1.
    label_map = preds.reshape(preds.shape[-2:])
    # draw_segmentation_masks indexes the image with the masks, so both must
    # live on the same device.
    masks = label_to_onehot(label_map, NUM_CLASSES).to(sat_img.device)
    return draw_segmentation_masks(
        sat_img, masks=masks, alpha=alpha, colors=class_rgb_colors
    )


def segment(numpy_arr):
    """Segment one RGB image.

    Args:
        numpy_arr: uint8 array shaped (H, W, 3), as delivered by ``gr.Image``.

    Returns:
        ``(raw_masks, overlay)``: two uint8 arrays shaped (H, W, 3). The first
        is the colour-coded class mask; the second is that mask blended over
        the input with 20% opacity.

    Raises:
        gr.Error: if no image was provided.
    """
    if numpy_arr is None:
        raise gr.Error("Please upload an image or select an example.")

    sat_img_arr = torch.permute(torch.from_numpy(numpy_arr), (2, 0, 1))  # HWC -> CHW
    sat_img_pil = transforms.functional.to_pil_image(sat_img_arr)

    # preprocess image
    X = transform(sat_img_pil).unsqueeze(0).to(device)
    X_down = _resize_for_model(X)

    # forward pass; no autograd graph is needed for inference
    with torch.no_grad():
        logits = model(X_down)
        preds = torch.argmax(logits, 1)  # (1, h, w) class indices

    # Resize the labels back to the original resolution. Nearest-neighbour
    # keeps every pixel a valid class id; bilinear would blend neighbouring ids.
    preds = (
        F.interpolate(preds.unsqueeze(1).float(), size=X.shape[-2:], mode="nearest")
        .squeeze(1)
        .long()
        .cpu()
    )

    # get RGB-formatted images
    overlay = get_overlay(sat_img_arr, preds, 0.2)
    raw_masks = get_overlay(torch.zeros_like(sat_img_arr), preds, 1)
    raw_masks = torch.permute(raw_masks, (1, 2, 0))
    overlay = torch.permute(overlay, (1, 2, 0))
    return raw_masks.numpy(), overlay.numpy()


# gr.inputs.* is the deprecated Gradio API (removed in Gradio 4); gr.Image
# serves both roles. segment() expects an (H, W, 3) uint8 numpy array.
i = gr.Image(type="numpy", image_mode="RGB")
o = [gr.Image(label="Predicted mask"), gr.Image(label="Overlay")]

images_dir = "sample_sat_images/"
# Sorted so the examples list has a stable order across runs and platforms.
image_ids = sorted(os.listdir(images_dir))
examples = [os.path.join(images_dir, image_id) for image_id in image_ids]

title = "Satellite Images Landcover Segmentation"
description = "Upload an image or select from examples to segment"

iface = gr.Interface(
    segment, i, o, examples=examples, title=title, description=description
)
iface.launch(debug=True)