# import gradio as gr # from gradio_image_prompter import ImagePrompter # import os # import torch # def prompter(prompts): # image = prompts["image"] # Get the image from prompts # points = prompts["points"] # Get the points from prompts # # Print the collected inputs for debugging or logging # print("Image received:", image) # print("Points received:", points) # import torch # from sam2.sam2_image_predictor import SAM2ImagePredictor # device = torch.device("cpu") # predictor = SAM2ImagePredictor.from_pretrained( # "facebook/sam2-hiera-base-plus", device=device # ) # with torch.inference_mode(): # predictor.set_image(image) # # masks, _, _ = predictor.predict([[point[0], point[1]] for point in points]) # input_point = [[point[0], point[1]] for point in points] # input_label = [1] # masks, _, _ = predictor.predict( # point_coords=input_point, point_labels=input_label # ) # print("Predicted Mask:", masks) # return image, points # # Define the Gradio interface # demo = gr.Interface( # fn=prompter, # Use the custom prompter function # inputs=ImagePrompter( # show_label=False # ), # ImagePrompter for image input and point selection # outputs=[ # gr.Image(show_label=False), # Display the image # gr.Dataframe(label="Points"), # Display the points in a DataFrame # ], # title="Image Point Collector", # description="Upload an image, click on it, and get the coordinates of the clicked points.", # ) # # Launch the Gradio app # demo.launch() # import gradio as gr # from gradio_image_prompter import ImagePrompter # import torch # from sam2.sam2_image_predictor import SAM2ImagePredictor # def prompter(prompts): # image = prompts["image"] # Get the image from prompts # points = prompts["points"] # Get the points from prompts # # Print the collected inputs for debugging or logging # print("Image received:", image) # print("Points received:", points) # device = torch.device("cpu") # # Load the SAM2ImagePredictor model # predictor = SAM2ImagePredictor.from_pretrained( # 
import gradio as gr
from gradio_image_prompter import ImagePrompter
import torch
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image

# Number of overlay images the interface displays; SAM2 with
# multimask_output=True returns *up to* this many candidate masks.
MAX_MASKS = 3

# Lazily-created predictor singleton: loading the checkpoint is expensive,
# so it must not happen on every prediction request.
_predictor = None


def _get_predictor():
    """Return a cached SAM2 image predictor, loading it on first use."""
    global _predictor
    if _predictor is None:
        device = torch.device("cpu")
        _predictor = SAM2ImagePredictor.from_pretrained(
            "facebook/sam2-hiera-base-plus", device=device
        )
    return _predictor


def _overlay_mask(image, mask):
    """Blend *mask* over *image* as a 50% red overlay.

    image: HxWxC uint8 numpy array (as produced by np.array on the upload).
    mask:  HxW array-like; pixels > 0 are painted red.
    Returns a PIL.Image the same size/mode as *image*.
    """
    red_mask = np.zeros_like(image)
    # Normalize boolean/float masks to 0/255 before painting channel 0 (red);
    # a bare astype(uint8) * 255 would wrap for non-binary mask values.
    red_mask[:, :, 0] = (np.asarray(mask) > 0).astype(np.uint8) * 255
    original = Image.fromarray(image)
    overlay = Image.fromarray(red_mask)
    return Image.blend(original, overlay, alpha=0.5)


def prompter(prompts):
    """Segment the uploaded image at the clicked points with SAM2.

    prompts: dict from ImagePrompter with keys "image" (the uploaded image)
    and "points" (click coordinates).
    Returns a list of exactly MAX_MASKS images — each predicted mask blended
    in red over the original — padded with None when fewer masks come back,
    so the length always matches the fixed number of gr.Image outputs.
    """
    image = np.array(prompts["image"])  # Convert the image to a numpy array
    points = prompts["points"]  # Get the points from prompts

    # Print the collected inputs for debugging or logging
    print("Image received:", image)
    print("Points received:", points)

    if not points:
        # Nothing to segment without at least one click.
        return [None] * MAX_MASKS

    predictor = _get_predictor()

    # Perform inference with multimask_output=True
    with torch.inference_mode():
        predictor.set_image(image)
        # NOTE(review): ImagePrompter points look like [x, y, ...] tuples;
        # only the first two coordinates are used here — confirm upstream.
        input_point = [[p[0], p[1]] for p in points]
        input_label = [1] * len(points)  # treat every click as foreground
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )

    overlay_images = [_overlay_mask(image, m) for m in masks[:MAX_MASKS]]
    # Pad so the return length always equals the number of declared outputs;
    # Gradio raises when fewer values than outputs are returned.
    overlay_images.extend([None] * (MAX_MASKS - len(overlay_images)))
    return overlay_images


# Gradio interface: collect clicks on an uploaded image and show each
# predicted mask as a separate red overlay.
demo = gr.Interface(
    fn=prompter,  # Use the custom prompter function
    inputs=ImagePrompter(
        show_label=False
    ),  # ImagePrompter for image input and point selection
    outputs=[
        gr.Image(show_label=False) for _ in range(MAX_MASKS)
    ],  # Display up to 3 overlay images
    title="Image Point Collector with Multiple Separate Mask Overlays",
    description=(
        "Upload an image, click on it, and get each predicted mask "
        "overlaid separately in red on individual images."
    ),
)

# Launch only when run as a script so importing this module stays side-effect
# free (e.g. for tests or embedding the interface elsewhere).
if __name__ == "__main__":
    demo.launch()