Spaces:
Runtime error
Runtime error
import torch | |
import random | |
import gradio as gr | |
import numpy as np | |
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation | |
# Prefer CUDA when PyTorch reports it as usable; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The network weights and the preprocessing configuration are both loaded
# from the same pretrained checkpoint identifier.
_CHECKPOINT = "facebook/maskformer-swin-tiny-ade"

model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT)
model = model.to(device)
model.eval()  # evaluation mode: this script only runs inference
preprocessor = MaskFormerFeatureExtractor.from_pretrained(_CHECKPOINT)
def visualize_instance_seg_mask(mask):
    """Render a 2-D label mask as an RGB image with one random colour per label.

    Args:
        mask: 2-D array of labels; every distinct value gets its own colour.

    Returns:
        A float array of shape ``(mask.shape[0], mask.shape[1], 3)`` whose
        values lie in ``[0, 1]``. Pixels sharing a label share a colour.
        Colours come from the global ``random`` generator, so they differ
        between calls unless the generator is seeded.
    """
    mask = np.asarray(mask)
    # ``inverse`` maps every pixel to the index of its label in ``labels``.
    labels, inverse = np.unique(mask, return_inverse=True)

    # One (R, G, B) triple per label, each channel drawn from the full
    # 0..255 range. The reshape keeps the palette 2-D when there are no labels.
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)

    # A single fancy-indexing lookup replaces the per-pixel Python loop.
    # ``inverse`` may be flat or input-shaped depending on the NumPy version,
    # so it is flattened explicitly before the lookup.
    image = palette[inverse.reshape(-1)].reshape(mask.shape[0], mask.shape[1], 3)

    # Scale 8-bit colour values into the [0, 1] float range.
    return image / 255
def query_image(img):
    """Segment an image with the loaded MaskFormer model and colourise the result.

    Args:
        img: Input image; only ``img.shape[0]`` and ``img.shape[1]`` are read
            here (presumably an H x W x C array as delivered by ``gr.Image``
            -- confirm against the Gradio version in use). ``None`` is
            accepted and means "no image supplied".

    Returns:
        The array produced by ``visualize_instance_seg_mask`` for the
        per-pixel argmax of the post-processed model output, or ``None``
        when ``img`` is ``None``.
    """
    # Without an image there is nothing to segment.
    if img is None:
        return None

    target_size = (img.shape[0], img.shape[1])

    # The model lives on ``device``, so its inputs must be moved there too;
    # leaving them on the CPU fails with a device mismatch when CUDA is used.
    inputs = preprocessor(images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing is done on the CPU.
    outputs.class_queries_logits = outputs.class_queries_logits.cpu()
    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()

    # First (and only) item of the batch; assumed to be a per-class score map
    # with classes along dim 0 -- TODO confirm against the transformers docs.
    scores = preprocessor.post_process_segmentation(
        outputs=outputs, target_size=target_size
    )[0].cpu().detach()

    # Collapse the class dimension to a single label per pixel.
    label_map = torch.argmax(scores, dim=0).numpy()
    return visualize_instance_seg_mask(label_map)
# Blurb handed to gr.Interface as its description; it mixes HTML anchors with
# explicit "\n\n" escapes that add blank lines before the final sentence.
description = """
Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/maskformer">MaskFormer</a>,
introduced in <a href="https://arxiv.org/abs/2107.06278">Per-Pixel Classification is Not All You Need for Semantic Segmentation
</a>.
\n\nMaskFormer is a unified framework for panoptic, instance and semantic segmentation, trained across four popular datasets (ADE20K, Cityscapes, COCO, Mapillary Vistas).
"""
# Example image paths passed to the interface.
example_images = ["assets/test_image_35.png", "assets/test_image_82.png"]

# Expose query_image through a Gradio interface: one image in, one image out.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="MaskFormer Demo",
    description=description,
    examples=example_images,
)

# Start serving the interface as soon as the script is executed.
demo.launch()