File size: 3,189 Bytes
988ebda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import glob
import gradio as gr
import numpy as np
from os import environ
from PIL import Image
from torchvision import transforms as T
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor


# Sample inputs offered in the UI, listed in sorted filename order.
example_images = sorted(glob.glob('examples/map*.jpg'))

# Per-channel mean / standard deviation used to standardise input pixels.
# NOTE(review): these values match the widely used ImageNet statistics;
# confirm they are the ones the checkpoint was trained with.
ade_mean = [0.485, 0.456, 0.406]
ade_std = [0.229, 0.224, 0.225]

# Tensor conversion followed by per-channel standardisation. query_image runs
# this on the raw pixel array before handing it to the image processor.
test_transform = T.Compose(
    [T.ToTensor(), T.Normalize(mean=ade_mean, std=ade_std)]
)

# Overlay colours as [R, G, B]; visualize_instance_seg_mask indexes this list
# by label id.
palette = [
    [120, 120, 120],
    [4, 200, 4],
    [4, 4, 250],
    [6, 230, 230],
    [80, 50, 50],
    [120, 120, 80],
    [140, 140, 140],
    [204, 5, 255],
]

model_id = "thiagohersan/maskformer-satellite-trees"
vegetation_labels = ["vegetation"]

# The processor is configured by hand instead of being loaded with
# MaskFormerImageProcessor.from_pretrained(model_id): resizing, rescaling and
# normalisation are switched off because test_transform is applied to the
# image before the processor sees it.
preprocessor = MaskFormerImageProcessor(
    reduce_labels=False,
    ignore_index=255,
    do_rescale=False,
    do_normalize=False,
    do_resize=False,
)

# Access token read from the environment; None when HFTOKEN is not set.
hf_token = environ.get('HFTOKEN')
model = MaskFormerForInstanceSegmentation.from_pretrained(
    model_id,
    use_auth_token=hf_token,
)


def visualize_instance_seg_mask(img_in, mask, id2label, included_labels, colors=None):
    """Blend a colour-coded segmentation mask over an image and summarise it.

    Args:
        img_in: Image array that broadcasts against shape ``(H, W, 3)``.
        mask: Integer array of shape ``(H, W)`` with one label id per pixel.
        id2label: Mapping from label id to label name.
        included_labels: Label names that should get a row in the summary.
        colors: Optional sequence of ``[R, G, B]`` entries indexed by label
            id. Defaults to the module-level ``palette``.

    Returns:
        A ``(image_res, dataframe)`` tuple. ``image_res`` is the ``uint8``
        50/50 blend of ``img_in`` and the colour-coded mask. ``dataframe`` is
        a list of ``[label, pixel percent, square length]`` string rows, one
        per included label present in ``mask``; when no included label is
        present it holds a single all-zero placeholder row.

    Raises:
        IndexError: If ``mask`` contains a label id beyond the colour table.
        KeyError: If ``mask`` contains a label id missing from ``id2label``.
    """
    if colors is None:
        colors = palette

    image_total_pixels = mask.shape[0] * mask.shape[1]

    # A single pass over the mask yields both the labels present and how many
    # pixels each one covers.
    label_ids, label_counts = np.unique(mask, return_counts=True)

    # Indexing the colour table with the whole mask paints every pixel at
    # once instead of looping over rows and columns in Python.
    color_table = np.asarray(colors, dtype=np.float64)
    img_out = color_table[mask]

    image_res = (0.5 * img_in + 0.5 * img_out).astype(np.uint8)

    # .tolist() yields plain Python ints, so the id2label lookups and the
    # arithmetic below do not depend on NumPy scalar types.
    dataframe = [
        [
            f"{id2label[label_id]}",
            f"{(100 * pixel_count / image_total_pixels):.2f} %",
            # Side length of a square covering the same fraction of the image.
            f"{np.sqrt(pixel_count / image_total_pixels):.2f} m",
        ]
        for label_id, pixel_count in zip(label_ids.tolist(), label_counts.tolist())
        if id2label[label_id] in included_labels
    ]

    if not dataframe:
        # Keep the table shape stable when nothing of interest was detected.
        dataframe = [["", "0.00 %", "0.00 m"]]

    return image_res, dataframe


def query_image(image_path):
    """Segment the image at ``image_path`` and report vegetation coverage.

    Args:
        image_path: Filesystem path of the image to analyse.

    Returns:
        A ``(mask_img, dataframe)`` tuple as produced by
        ``visualize_instance_seg_mask``: the input image blended with the
        colour-coded segmentation, and the summary rows for the labels listed
        in ``vegetation_labels``.
    """
    # Imported here so the function carries its own dependency; torch is
    # already required by the torchvision / transformers imports of this file.
    import torch

    # Close the file handle deterministically, and force three channels so
    # that RGBA, greyscale or palette images work with the 3-channel
    # normalisation and with the RGB overlay blend.
    with Image.open(image_path) as source:
        img = np.array(source.convert("RGB"))

    img_size = (img.shape[0], img.shape[1])
    inputs = preprocessor(images=test_transform(img), return_tensors="pt")

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)

    results = preprocessor.post_process_semantic_segmentation(
        outputs=outputs, target_sizes=[img_size]
    )[0]
    mask_img, dataframe = visualize_instance_seg_mask(
        img, results.numpy(), model.config.id2label, vegetation_labels
    )
    return mask_img, dataframe


# Text shown under the title of the demo page.
_demo_description = "Using a finetuned version of the [facebook/maskformer-swin-base-ade](https://huggingface.co/facebook/maskformer-swin-base-ade) model (created specifically to work with satellite images) to calculate percentage of pixels in an image that belong to vegetation."

# The input is delivered to query_image as a file path, which it opens itself.
_demo_inputs = [gr.Image(type="filepath", label="Input Image")]

# One output per element of the tuple returned by query_image.
_demo_outputs = [
    gr.Image(label="Vegetation"),
    gr.DataFrame(
        label="Info",
        headers=["Object Label", "Pixel Percent", "Square Length"],
    ),
]

demo = gr.Interface(
    fn=query_image,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="Maskformer Satellite+Trees",
    description=_demo_description,
    examples=example_images,
    cache_examples=True,
    allow_flagging="never",
    analytics_enabled=None,
)

# Start the web UI when this file is executed.
demo.launch(show_api=False)