|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from transformers import SamModel, SamProcessor |
|
from gradio_image_prompter import ImagePrompter |
|
|
|
|
|
# Hub checkpoints for the two models being compared.
SAM_CHECKPOINT = "facebook/sam-vit-huge"
SLIMSAM_CHECKPOINT = "nielsr/slimsam-50-uniform"

# Prefer the GPU whenever torch can see one; otherwise run everything on CPU.
_use_gpu = torch.cuda.is_available()
device = torch.device("cuda") if _use_gpu else torch.device("cpu")

# Both checkpoints are loaded through the same SamModel / SamProcessor classes;
# the models are moved to `device`, the processors stay CPU-side.
sam_model = SamModel.from_pretrained(SAM_CHECKPOINT).to(device)
sam_processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
slimsam_model = SamModel.from_pretrained(SLIMSAM_CHECKPOINT).to(device)
slimsam_processor = SamProcessor.from_pretrained(SLIMSAM_CHECKPOINT)
|
|
|
def sam_box_inference(image, model, x_min, y_min, x_max, y_max):
    """Segment the region inside a bounding box with the given SAM-family model.

    Args:
        image: The input image as an array accepted by ``PIL.Image.fromarray``
            (the Box Prompt tab's ``ImagePrompter`` output; presumably an RGB
            ``uint8`` array -- confirm against the component configuration).
        model: The ``SamModel`` to run (``sam_model`` or ``slimsam_model``),
            already placed on ``device``.
        x_min, y_min, x_max, y_max: Box corners in pixel coordinates of
            ``image``. The corners may be given in either order; they are
            normalised so that min <= max on each axis.

    Returns:
        A one-element list ``[(mask, "mask")]``, the annotation format fed to
        ``gr.AnnotatedImage``. ``mask`` is the first predicted mask for the
        box, with a leading axis of length 1 added.
    """
    # The corners come straight from the user's drag and may arrive swapped
    # (e.g. a box drawn right-to-left); SAM expects top-left then bottom-right.
    x_min, x_max = sorted((x_min, x_max))
    y_min, y_max = sorted((y_min, y_max))

    inputs = sam_processor(
        Image.fromarray(image),
        # One image with one box: (batch_size=1, num_boxes=1, 4), the layout
        # documented for SamProcessor's `input_boxes`.
        input_boxes=[[[x_min, y_min, x_max, y_max]]],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Rescale the predicted masks back to the original image size;
    # [0][0][0] picks image 0, box 0 and the first candidate mask.
    mask = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0][0].numpy()
    mask = mask[np.newaxis, ...]
    return [(mask, "mask")]
|
|
|
|
|
def sam_point_inference(image, model, x, y):
    """Segment the object under a single point prompt with the given SAM-family model.

    Args:
        image: The image handed directly to ``sam_processor`` (``infer_point``
            passes an RGB ``PIL.Image``).
        model: The ``SamModel`` to run (``sam_model`` or ``slimsam_model``),
            already placed on ``device``.
        x, y: The point prompt in pixel coordinates of ``image``.

    Returns:
        A one-element list ``[(mask, "mask")]``, the annotation format fed to
        ``gr.AnnotatedImage``. ``mask`` is the first predicted mask for the
        point, with a leading axis of length 1 added.
    """
    inputs = sam_processor(
        image,
        # One image with one point, nested as SamProcessor's `input_points` expects.
        input_points=[[[x, y]]],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        # Run the model the caller selected, so SlimSAM and SAM outputs
        # actually come from different networks.
        outputs = model(**inputs)

    # Rescale the predicted masks back to the original image size;
    # [0][0][0] picks image 0, prompt 0 and the first candidate mask.
    mask = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0][0].numpy()
    mask = mask[np.newaxis, ...]
    return [(mask, "mask")]
|
|
|
def infer_point(img):
    """Run SlimSAM and SAM on a point drawn over an uploaded image.

    Args:
        img: The ``gr.ImageEditor`` value (configured with ``type="pil"``): a
            dict whose ``"background"`` is the uploaded image and whose
            ``"layers"`` are the drawn layers; the first layer carries the point.

    Returns:
        Two ``(image, annotations)`` pairs for ``gr.AnnotatedImage``:
        SlimSAM's result first, SAM's second.

    Raises:
        gr.Error: If no image is uploaded or nothing has been drawn on it.
    """
    # gr.Error only reaches the user when it is *raised*.
    if img is None or img["background"] is None:
        raise gr.Error("Please upload an image and select a point.")

    image = img["background"].convert("RGB")

    layers = img["layers"]
    if not layers:
        raise gr.Error("Please select a point on top of the image.")
    img_arr = np.array(layers[0])
    if not np.any(img_arr):
        raise gr.Error("Please select a point on top of the image.")

    # The point prompt is the centroid of every drawn (non-zero) pixel:
    # axis 0 of the layer array is the row (y), axis 1 the column (x).
    rows, cols = np.nonzero(img_arr)[:2]
    center_x = int(np.mean(cols))
    center_y = int(np.mean(rows))

    return (
        (image, sam_point_inference(image, slimsam_model, center_x, center_y)),
        (image, sam_point_inference(image, sam_model, center_x, center_y)),
    )
|
|
|
def infer_box(prompts):
    """Run SlimSAM and SAM on the first box drawn in the ``ImagePrompter``.

    Args:
        prompts: The ``ImagePrompter`` value: a dict with ``"image"`` (the
            uploaded image) and ``"points"`` (a list of drawn prompts). The
            first prompt is read as a sequence whose indices 0, 1 and 3, 4 are
            the two box corners -- presumably ``[x1, y1, type1, x2, y2, type2]``;
            confirm against the gradio_image_prompter documentation.

    Returns:
        Two ``(image, annotations)`` pairs for ``gr.AnnotatedImage``:
        SlimSAM's result first, SAM's second.

    Raises:
        gr.Error: If no image was uploaded or no box was drawn.
    """
    # gr.Error only reaches the user when it is *raised*.
    if prompts is None or prompts["image"] is None:
        raise gr.Error("Please upload an image and draw a box before submitting")
    image = prompts["image"]

    # Check for an empty prompt list before indexing into it.
    drawn = prompts["points"]
    if not drawn or drawn[0] is None or len(drawn[0]) < 6:
        raise gr.Error("Please draw a box before submitting.")
    points = drawn[0]
    x1, y1, x2, y2 = points[0], points[1], points[3], points[4]

    return (
        (image, sam_box_inference(image, slimsam_model, x1, y1, x2, y2)),
        (image, sam_box_inference(image, sam_model, x1, y1, x2, y2)),
    )
|
# Two-tab UI: each tab shows SlimSAM's and SAM's masks side by side for the
# same prompt (a drawn box, or a point painted on an image editor layer).
with gr.Blocks(title="SlimSAM") as demo:
    gr.Markdown("# SlimSAM")
    gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")
    gr.Markdown("In this demo, you can compare SlimSAM and SAM outputs in point and box prompts.")

    with gr.Tab("Box Prompt"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Box Prompting")
                with gr.Row():
                    with gr.Column():
                        box_prompter = ImagePrompter()
                        btn = gr.Button("Submit")
                    with gr.Column():
                        output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                        output_box_sam = gr.AnnotatedImage(label="SAM Output")

        # Box inference runs only on an explicit submit.
        btn.click(infer_box, inputs=box_prompter, outputs=[output_box_slimsam, output_box_sam])

    with gr.Tab("Point Prompt"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Point Prompting")
                with gr.Row():
                    with gr.Column():
                        point_editor = gr.ImageEditor(
                            type="pil",
                        )
                    with gr.Column():
                        output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                        output_sam = gr.AnnotatedImage(label="SAM Output")

        # Point inference re-runs every time the editor's value changes.
        point_editor.change(infer_point, inputs=point_editor, outputs=[output_slimsam, output_sam])

# Launch only when executed as a script, so importing the module (e.g. for
# tests or `gradio` reload mode, which looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)