File size: 2,171 Bytes
d4dcd19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f07415
d4dcd19
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import random
import gradio as gr
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation


# Pick the compute device once at import time: CUDA when PyTorch reports it
# as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both the model weights and the preprocessing config come from the same
# pretrained checkpoint.
MODEL_CHECKPOINT = "facebook/maskformer-swin-tiny-ade"

model = MaskFormerForInstanceSegmentation.from_pretrained(MODEL_CHECKPOINT)
model = model.to(device)
model.eval()  # inference mode
preprocessor = MaskFormerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)


def visualize_instance_seg_mask(mask):
    """Render a 2-D label mask as an RGB image with one random color per label.

    Args:
        mask: 2-D array-like of labels, indexed as ``mask[row, col]``.

    Returns:
        A float array of shape ``(mask.shape[0], mask.shape[1], 3)`` with
        values in ``[0, 1]``. Every pixel carrying the same label gets the
        same color; colors are drawn from the ``random`` module, so they
        differ between calls unless ``random`` is seeded.
    """
    mask = np.asarray(mask)

    # `labels` is sorted and unique; `inverse` maps every pixel to the index
    # of its label in `labels`.
    labels, inverse = np.unique(mask, return_inverse=True)

    # One RGB triple per label. All three channels span the full 0..255 range
    # (the red channel was previously drawn from 0..1 only).
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)

    # Single palette lookup instead of a per-pixel Python loop. `inverse` is
    # flattened explicitly because its shape differs across NumPy versions.
    image = palette[inverse.reshape(-1)].reshape(mask.shape[0], mask.shape[1], 3)

    # Scale 0..255 integer colors to 0..1 floats.
    return image / 255

def query_image(img):
    """Run MaskFormer on one image and return a colorized segmentation map.

    Args:
        img: Input image array. Only ``img.shape[0]`` and ``img.shape[1]``
            are read here (used as the target height and width); the array
            itself is handed to the preprocessor.

    Returns:
        The array returned by ``visualize_instance_seg_mask`` for the
        per-pixel argmax of the post-processed segmentation.
    """
    target_size = (img.shape[0], img.shape[1])

    # The model was moved to `device`, so the input tensors must live there
    # too; otherwise the forward pass fails with a device mismatch on CUDA.
    inputs = preprocessor(images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Bring the logits back to the CPU before post-processing.
    outputs.class_queries_logits = outputs.class_queries_logits.cpu()
    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
    results = preprocessor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()

    # Collapse dim 0 (presumably the class dimension -- confirm against the
    # transformers docs) into a single label per pixel.
    results = torch.argmax(results, dim=0).numpy()
    results = visualize_instance_seg_mask(results)

    return results


# HTML blurb passed to gr.Interface as the demo's description text.
# (A stray, unmatched double quote before "MaskFormer is a unified framework"
# was removed so it no longer shows up in the rendered page.)
description = """
Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/maskformer">MaskFormer</a>,
introduced in <a href="https://arxiv.org/abs/2107.06278">Per-Pixel Classification is Not All You Need for Semantic Segmentation
</a>.
\n\nMaskFormer is a unified framework for panoptic, instance and semantic segmentation, trained across four popular datasets (ADE20K, Cityscapes, COCO, Mapillary Vistas).
"""

# Assemble the Gradio UI: a single image input fed to `query_image`, a single
# image output, plus two bundled example images.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="MaskFormer Demo",
    description=description,
    examples=[
        "assets/test_image_35.png",
        "assets/test_image_82.png",
    ],
)

# Start the web server for the demo.
demo.launch()