Spaces:
Runtime error
Runtime error
File size: 2,332 Bytes
d4dcd19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import torch
import random
import gradio as gr
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
# Use GPU if available
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-tiny-ade").to(device)
model.eval()
preprocessor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-tiny-ade")
def visualize_instance_seg_mask(mask):
# Initialize image
image = np.zeros((mask.shape[0], mask.shape[1], 3))
labels = np.unique(mask)
label2color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
for i in range(image.shape[0]):
for j in range(image.shape[1]):
image[i, j, :] = label2color[mask[i, j]]
image = image / 255
return image
def query_image(img):
target_size = (img.shape[0], img.shape[1])
inputs = preprocessor(images=img, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
outputs.class_queries_logits = outputs.class_queries_logits.cpu()
outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
results = preprocessor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()
results = torch.argmax(results, dim=0).numpy()
results = visualize_instance_seg_mask(results)
return results
description = """
Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/maskformer">MaskFormer</a>,
introduced in <a href="https://arxiv.org/abs/2107.06278">Per-Pixel Classification is Not All You Need for Semantic Segmentation
</a>.
\n\n"Mask2Former is a unified framework architecture based on MaskFormer meta-architecture that achieves SOTA on panoptic,
instance and semantic segmentation across four popular datasets (ADE20K, Cityscapes, COCO, Mapillary Vistas). You can use
MaskFormer for semantic, instance (illustrated in the demo) and panoptic segmentation.
"""
demo = gr.Interface(
query_image,
inputs=[gr.Image()],
outputs="image",
title="MaskFormer Demo",
description=description,
examples=["assets/test_image_35.png", "assets/test_image_82.png"]
)
demo.launch() |