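# Repository page header (not part of the program):
#   author: amaye15
#   commit: 933c40c - "Sam 2 point prompt"
#   page links: raw / history / blame
#   file size: 5.88 kB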
# import gradio as gr
# from gradio_image_prompter import ImagePrompter
# import os
# import torch
# def prompter(prompts):
# image = prompts["image"] # Get the image from prompts
# points = prompts["points"] # Get the points from prompts
# # Print the collected inputs for debugging or logging
# print("Image received:", image)
# print("Points received:", points)
# import torch
# from sam2.sam2_image_predictor import SAM2ImagePredictor
# device = torch.device("cpu")
# predictor = SAM2ImagePredictor.from_pretrained(
# "facebook/sam2-hiera-base-plus", device=device
# )
# with torch.inference_mode():
# predictor.set_image(image)
# # masks, _, _ = predictor.predict([[point[0], point[1]] for point in points])
# input_point = [[point[0], point[1]] for point in points]
# input_label = [1]
# masks, _, _ = predictor.predict(
# point_coords=input_point, point_labels=input_label
# )
# print("Predicted Mask:", masks)
# return image, points
# # Define the Gradio interface
# demo = gr.Interface(
# fn=prompter, # Use the custom prompter function
# inputs=ImagePrompter(
# show_label=False
# ), # ImagePrompter for image input and point selection
# outputs=[
# gr.Image(show_label=False), # Display the image
# gr.Dataframe(label="Points"), # Display the points in a DataFrame
# ],
# title="Image Point Collector",
# description="Upload an image, click on it, and get the coordinates of the clicked points.",
# )
# # Launch the Gradio app
# demo.launch()
# import gradio as gr
# from gradio_image_prompter import ImagePrompter
# import torch
# from sam2.sam2_image_predictor import SAM2ImagePredictor
# def prompter(prompts):
# image = prompts["image"] # Get the image from prompts
# points = prompts["points"] # Get the points from prompts
# # Print the collected inputs for debugging or logging
# print("Image received:", image)
# print("Points received:", points)
# device = torch.device("cpu")
# # Load the SAM2ImagePredictor model
# predictor = SAM2ImagePredictor.from_pretrained(
# "facebook/sam2-hiera-base-plus", device=device
# )
# # Perform inference
# with torch.inference_mode():
# predictor.set_image(image)
# input_point = [[point[0], point[1]] for point in points]
# input_label = [1] * len(points) # Assuming all points are foreground
# masks, _, _ = predictor.predict(
# point_coords=input_point, point_labels=input_label
# )
# # The masks are returned as a list of numpy arrays
# print("Predicted Mask:", masks)
# # Assuming there's only one mask returned, you can adjust if there are multiple
# predicted_mask = masks[0]
# print(len(image))
# print(len(predicted_mask))
# # Create annotations for AnnotatedImage
# annotations = [(predicted_mask, "Predicted Mask")]
# return image, annotations
# # Define the Gradio interface
# demo = gr.Interface(
# fn=prompter, # Use the custom prompter function
# inputs=ImagePrompter(
# show_label=False
# ), # ImagePrompter for image input and point selection
# outputs=gr.AnnotatedImage(), # Display the image with the predicted mask
# title="Image Point Collector with Mask Overlay",
# description="Upload an image, click on it, and get the predicted mask overlayed on the image.",
# )
# # Launch the Gradio app
# demo.launch()
import threading

import gradio as gr
import numpy as np
import torch
from gradio_image_prompter import ImagePrompter
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Hugging Face model id of the SAM 2 checkpoint used for segmentation.
_SAM2_MODEL_ID = "facebook/sam2-hiera-base-plus"

# The predictor is expensive to build, so it is created once on first use and
# then shared. It also stores per-image state (set_image), so every use is
# serialized through the lock to keep requests from interfering.
_predictor = None
_predictor_lock = threading.Lock()


def _get_predictor():
    """Return the shared SAM 2 predictor, loading it on first use.

    Must be called while holding ``_predictor_lock``.
    """
    global _predictor
    if _predictor is None:
        _predictor = SAM2ImagePredictor.from_pretrained(
            _SAM2_MODEL_ID, device=torch.device("cpu")
        )
    return _predictor


def _overlay_mask(image, base_image, mask):
    """Blend ``base_image`` 50/50 with a red rendering of ``mask``.

    Args:
        image: The input image as a numpy array (used only for shape/dtype).
        base_image: The same image as a PIL image.
        mask: A 2-D mask whose non-zero entries mark the segmented region.

    Returns:
        A PIL image with the masked region tinted red.
    """
    red_mask = np.zeros_like(image)
    red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # Fill the red channel only
    return Image.blend(base_image, Image.fromarray(red_mask), alpha=0.5)


def prompter(prompts):
    """Segment the clicked points with SAM 2 and return red mask overlays.

    Args:
        prompts: Mapping with an ``"image"`` entry (anything ``np.array`` can
            convert to an H x W x C array) and a ``"points"`` entry (a sequence
            of points whose first two items are the x and y coordinates).

    Returns:
        A list of PIL images, one per mask returned by the predictor, each
        showing the input image blended with that mask in red.

    Raises:
        ValueError: If no image was supplied or no point was selected.
    """
    # Fail early with a clear message instead of an obscure error from deep
    # inside numpy/torch when the user submits an incomplete prompt.
    raw_image = prompts.get("image") if prompts is not None else None
    if raw_image is None:
        raise ValueError("No image provided: upload an image before submitting.")
    points = prompts.get("points")
    if points is None or len(points) == 0:
        raise ValueError("No points provided: click on the image at least once.")

    image = np.array(raw_image)  # Convert the image to a numpy array

    # Print the collected inputs for debugging or logging
    print("Image received:", image)
    print("Points received:", points)

    # NOTE(review): every entry is treated as a foreground point using only its
    # first two values; if the prompter widget can also emit box prompts they
    # would be misread as single points -- confirm against the widget docs.
    input_point = [[point[0], point[1]] for point in points]
    input_label = [1] * len(points)  # Assuming all points are foreground

    # Perform inference with multimask_output=True
    with _predictor_lock, torch.inference_mode():
        predictor = _get_predictor()
        predictor.set_image(image)
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )

    # The PIL version of the input is the same for every mask: build it once.
    original_image = Image.fromarray(image)

    # Prepare individual images with separate overlays
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i+1}:", mask)
        overlay_images.append(_overlay_mask(image, original_image, mask))
    return overlay_images
# Input widget: lets the user upload an image and click points on it.
_prompt_input = ImagePrompter(show_label=False)

# Output widgets: three unlabeled image slots, one per overlay image.
_overlay_outputs = [gr.Image(show_label=False) for _ in range(3)]

# Wire the widgets to the prompter callback.
demo = gr.Interface(
    fn=prompter,
    inputs=_prompt_input,
    outputs=_overlay_outputs,
    title="Image Point Collector with Multiple Separate Mask Overlays",
    description="Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images.",
)

# Start serving the app.
demo.launch()