"""Gradio app: point-prompted SAM2 segmentation with upload to a HF dataset.

The user clicks points on an image, SAM2 proposes several masks, the user
picks one, and the raw image plus the chosen mask are uploaded as ``.npy``
files to a Hugging Face dataset repository.
"""

import os
import shutil
from uuid import uuid4

import gradio as gr
import numpy as np
import torch
from gradio_image_prompter import ImagePrompter
from huggingface_hub import login, upload_folder
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

MODEL = "facebook/sam2-hiera-large"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)

# Authenticate against the Hugging Face Hub with the token from the environment.
login(os.getenv("TOKEN"))

# Dataset names entered on the "Setup" tab, keyed by role.
GLOBALS = {}

# Last prediction and selection, shared between the event handlers below.
# NOTE(review): these are module globals, so they are shared by every browser
# session connected to this process -- confirm that single-user use is intended.
IMAGE = None
MASKS = None
INDEX = None


def prompter(prompts):
    """Run SAM2 on the clicked points and build one red overlay per mask.

    Args:
        prompts: Value of the ``ImagePrompter`` component; read as a mapping
            with an ``"image"`` entry and a ``"points"`` entry whose items
            start with the x and y coordinates of a click.

    Returns:
        A 4-tuple: the first three overlay images (PIL) followed by the raw
        mask array returned by the predictor.

    Raises:
        gr.Error: If no image was provided or no point was clicked.
    """
    if not prompts or prompts.get("image") is None:
        raise gr.Error("Upload an image before submitting.")
    points = prompts.get("points") or []
    if not points:
        raise gr.Error("Click at least one point on the image before submitting.")

    image = np.array(prompts["image"])

    # Only the first two entries of each prompt (x, y) are used, and every
    # point is treated as a foreground click.
    input_point = np.array(
        [[point[0], point[1]] for point in points], dtype=np.float32
    )
    input_label = np.ones(len(points), dtype=np.int32)

    with torch.inference_mode():
        PREDICTOR.set_image(image)
        masks, _, _ = PREDICTOR.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )

    # The base image is the same for every overlay, so convert it once.
    original_image = Image.fromarray(image)
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i+1}:", mask.shape)
        red_mask = np.zeros_like(image)
        red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # red channel only
        overlay_images.append(
            Image.blend(original_image, Image.fromarray(red_mask), alpha=0.5)
        )

    # Remember the raw inputs so the save handler can persist them later.
    global IMAGE, MASKS
    IMAGE, MASKS = image, masks

    # presumably multimask_output=True always yields three masks -- TODO confirm
    return overlay_images[0], overlay_images[1], overlay_images[2], masks


def select_mask(
    selected_mask_index,
    mask1,
    mask2,
    mask3,
):
    """Record which mask the user picked and return the matching overlay.

    Args:
        selected_mask_index: Zero-based index from the radio selector, or
            ``None`` when nothing is selected.
        mask1: First overlay image.
        mask2: Second overlay image.
        mask3: Third overlay image.

    Returns:
        The overlay at ``selected_mask_index``, or ``None`` if there is no
        selection.
    """
    global INDEX
    INDEX = selected_mask_index
    if selected_mask_index is None:
        return None
    masks = [mask1, mask2, mask3]
    return masks[selected_mask_index]


def save_selected_mask(image, mask, output_dir="output"):
    """Save the image and selected mask as ``.npy`` files and upload them.

    The files are written to a fresh UUID-named folder under ``output_dir``,
    uploaded to the dataset repo under that same UUID, and the local folder
    is removed afterwards whether or not the upload succeeded.

    Args:
        image: Image array to save. When ``None`` the image stored by the
            last ``prompter`` call is used.
        mask: Mask array to save. When ``None`` the mask chosen through
            ``select_mask`` from the last prediction is used.
        output_dir: Scratch directory, relative to the working directory.

    Returns:
        A status message naming the folder that was written.

    Raises:
        gr.Error: If there is no image or no selected mask to save.
    """
    if image is None:
        image = IMAGE
    if mask is None and MASKS is not None and INDEX is not None:
        mask = MASKS[INDEX]
    if image is None or mask is None:
        raise gr.Error("Run a prediction and select a mask before saving.")

    output_dir = os.path.join(os.getcwd(), output_dir)

    # A unique folder per save keeps uploads from overwriting each other.
    folder_id = str(uuid4())
    folder_path = os.path.join(output_dir, folder_id)
    os.makedirs(folder_path, exist_ok=True)

    image_path = os.path.join(folder_path, "image.npy")
    mask_path = os.path.join(folder_path, "mask.npy")

    try:
        with open(image_path, "wb") as f:
            np.save(f, image)
        with open(mask_path, "wb") as f:
            np.save(f, mask)

        # Upload only this save's folder. Placing it under ``folder_id`` in
        # the repo keeps the ``<uuid>/image.npy`` layout while not re-sending
        # anything else that happens to sit in ``output_dir``.
        upload_folder(
            folder_path=folder_path,
            path_in_repo=folder_id,
            repo_id="amaye15/object-segmentation",
            repo_type="dataset",
        )
    finally:
        # Never leave scratch data behind, even when the upload fails.
        shutil.rmtree(folder_path, ignore_errors=True)

    return f"Image and mask saved to {folder_path}."


def save_dataset_name(key, dataset_name):
    """Store a dataset name under ``key`` and return its display values.

    Args:
        key: Slot in ``GLOBALS`` to store the name under.
        dataset_name: Dataset name typed by the user.

    Returns:
        A 2-tuple of the Markdown label and the HTML snippet for the viewer.
    """
    GLOBALS[key] = dataset_name
    # The embed markup is blank in this file, so nothing is rendered.
    iframe_code = """ """
    return f"Huggingface Dataset: {dataset_name}", iframe_code


def _save_dataset_label(key, dataset_name):
    """Store a dataset name and return only its Markdown label.

    Used for listeners that have a single output component, where returning
    the 2-tuple from ``save_dataset_name`` would not match the outputs.
    """
    label, _ = save_dataset_name(key, dataset_name)
    return label


# Define the Gradio Blocks app
with gr.Blocks() as demo:
    with gr.Tab("Setup"):
        with gr.Row():
            with gr.Column():
                source = gr.Textbox(label="Source Dataset")
                source_display = gr.Markdown()
                iframe_display = gr.HTML()

                source.change(
                    save_dataset_name,
                    inputs=[gr.State("source_dataset"), source],
                    outputs=[source_display, iframe_display],
                )

            with gr.Column():
                destination = gr.Textbox(label="Destination Dataset")
                destination_display = gr.Markdown()

                # Only one output here, so use the label-only wrapper.
                destination.change(
                    _save_dataset_label,
                    inputs=[gr.State("destination_dataset"), destination],
                    outputs=destination_display,
                )

    with gr.Tab("Object Mask - Point Prompt"):
        gr.Markdown("# Image Point Collector with Multiple Separate Mask Overlays")
        gr.Markdown(
            "Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images."
        )

        with gr.Row():
            with gr.Column():
                # Input: ImagePrompter
                image_input = ImagePrompter(show_label=False)
                submit_button = gr.Button("Submit")

        with gr.Row():
            with gr.Column():
                # Outputs: Up to 3 overlay images
                image_output_1 = gr.Image(show_label=False)
            with gr.Column():
                image_output_2 = gr.Image(show_label=False)
            with gr.Column():
                image_output_3 = gr.Image(show_label=False)

        # Radio selector for choosing the correct mask
        with gr.Row():
            mask_selector = gr.Radio(
                label="Select the correct mask",
                choices=["Mask 1", "Mask 2", "Mask 3"],
                type="index",
            )

        save_button = gr.Button("Save Selected Mask and Image")
        save_message = gr.Textbox(visible=False)

        # The gr.State() components created inline below are not connected
        # to each other; the handlers exchange data through the module-level
        # IMAGE / MASKS / INDEX values instead.

        # Define the action triggered by the submit button
        submit_button.click(
            fn=prompter,
            inputs=image_input,
            outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
        )

        # Define the action triggered by mask selection
        mask_selector.change(
            fn=select_mask,
            inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
            outputs=gr.State(),
        )

        # Define the action triggered by the save button
        save_button.click(
            fn=save_selected_mask,
            inputs=[gr.State(), gr.State()],
            outputs=save_message,
            show_progress=True,
        )

# Launch the Gradio app
demo.launch()