# import gradio as gr
# from gradio_image_prompter import ImagePrompter
# import os
# import torch
# def prompter(prompts):
#     image = prompts["image"]  # Get the image from prompts
#     points = prompts["points"]  # Get the points from prompts
#     # Print the collected inputs for debugging or logging
#     print("Image received:", image)
#     print("Points received:", points)
#     import torch
#     from sam2.sam2_image_predictor import SAM2ImagePredictor
#     device = torch.device("cpu")
#     predictor = SAM2ImagePredictor.from_pretrained(
#         "facebook/sam2-hiera-base-plus", device=device
#     )
#     with torch.inference_mode():
#         predictor.set_image(image)
#         # masks, _, _ = predictor.predict([[point[0], point[1]] for point in points])
#         input_point = [[point[0], point[1]] for point in points]
#         input_label = [1]
#         masks, _, _ = predictor.predict(
#             point_coords=input_point, point_labels=input_label
#         )
#         print("Predicted Mask:", masks)
#     return image, points
# # Define the Gradio interface
# demo = gr.Interface(
#     fn=prompter,  # Use the custom prompter function
#     inputs=ImagePrompter(
#         show_label=False
#     ),  # ImagePrompter for image input and point selection
#     outputs=[
#         gr.Image(show_label=False),  # Display the image
#         gr.Dataframe(label="Points"),  # Display the points in a DataFrame
#     ],
#     title="Image Point Collector",
#     description="Upload an image, click on it, and get the coordinates of the clicked points.",
# )
# # Launch the Gradio app
# demo.launch()
# import gradio as gr
# from gradio_image_prompter import ImagePrompter
# import torch
# from sam2.sam2_image_predictor import SAM2ImagePredictor
# def prompter(prompts):
#     image = prompts["image"]  # Get the image from prompts
#     points = prompts["points"]  # Get the points from prompts
#     # Print the collected inputs for debugging or logging
#     print("Image received:", image)
#     print("Points received:", points)
#     device = torch.device("cpu")
#     # Load the SAM2ImagePredictor model
#     predictor = SAM2ImagePredictor.from_pretrained(
#         "facebook/sam2-hiera-base-plus", device=device
#     )
#     # Perform inference
#     with torch.inference_mode():
#         predictor.set_image(image)
#         input_point = [[point[0], point[1]] for point in points]
#         input_label = [1] * len(points)  # Assuming all points are foreground
#         masks, _, _ = predictor.predict(
#             point_coords=input_point, point_labels=input_label
#         )
#     # The masks are returned as a list of numpy arrays
#     print("Predicted Mask:", masks)
#     # Assuming there's only one mask returned, you can adjust if there are multiple
#     predicted_mask = masks[0]
#     print(len(image))
#     print(len(predicted_mask))
#     # Create annotations for AnnotatedImage
#     annotations = [(predicted_mask, "Predicted Mask")]
#     return image, annotations
# # Define the Gradio interface
# demo = gr.Interface(
#     fn=prompter,  # Use the custom prompter function
#     inputs=ImagePrompter(
#         show_label=False
#     ),  # ImagePrompter for image input and point selection
#     outputs=gr.AnnotatedImage(),  # Display the image with the predicted mask
#     title="Image Point Collector with Mask Overlay",
#     description="Upload an image, click on it, and get the predicted mask overlaid on the image.",
# )
# # Launch the Gradio app
# demo.launch()
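

# Active version (the earlier iterations above are kept commented out for reference):
# collect point prompts with ImagePrompter, run SAM 2 on them, and show each
# predicted candidate mask as a red overlay on its own copy of the input image.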
import gradio as gr
from gradio_image_prompter import ImagePrompter
import torch
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image


def prompter(prompts):
    image = np.array(prompts["image"])  # Convert the image to a numpy array
    points = prompts["points"]  # Get the points from prompts

    # Print the collected inputs for debugging or logging
    print("Image received:", image)
    print("Points received:", points)
    device = torch.device("cpu")

    # Load the SAM2ImagePredictor model
    predictor = SAM2ImagePredictor.from_pretrained(
        "facebook/sam2-hiera-base-plus", device=device
    )
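    # Note: the predictor is rebuilt on every call to keep the example self-contained;
    # for a more responsive app it could be created once at module level instead.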

    # Perform inference with multimask_output=True
    with torch.inference_mode():
        predictor.set_image(image)
        input_point = [[point[0], point[1]] for point in points]
        input_label = [1] * len(points)  # Assuming all points are foreground
        masks, _, _ = predictor.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )
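    # With multimask_output=True, SAM 2 returns three candidate masks (the scores and
    # low-resolution logits are ignored here), one per gr.Image output declared below.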

    # Prepare individual images with separate overlays
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i+1}:", mask)
        red_mask = np.zeros_like(image)
        red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # Apply the red channel
        red_mask = Image.fromarray(red_mask)

        # Convert the original image to a PIL image
        original_image = Image.fromarray(image)

        # Blend the original image with the red mask
        blended_image = Image.blend(original_image, red_mask, alpha=0.5)

        # Add the blended image to the list
        overlay_images.append(blended_image)

    return overlay_images


# Define the Gradio interface
demo = gr.Interface(
    fn=prompter,  # Use the custom prompter function
    inputs=ImagePrompter(
        show_label=False
    ),  # ImagePrompter for image input and point selection
    outputs=[
        gr.Image(show_label=False) for _ in range(3)
    ],  # Display the three mask overlays
    title="Image Point Collector with Multiple Separate Mask Overlays",
    description="Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images.",
)

# Launch the Gradio app
demo.launch()
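
# To try it locally: install the dependencies imported above (gradio, gradio_image_prompter,
# torch, numpy, Pillow, and the sam2 package), run this script with Python, and open the
# local URL that demo.launch() prints.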