# amaye15 — Version 1 - Working (78d359b), 6.47 kB
import os
import shutil
from urllib.parse import quote
from uuid import uuid4

import gradio as gr
import numpy as np
import torch
from gradio_image_prompter import ImagePrompter
from huggingface_hub import upload_folder
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Model id handed to SAM2ImagePredictor.from_pretrained below.
MODEL = "facebook/sam2-hiera-large"
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# One predictor, built once at import time and reused by every prompter() call.
PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)

# Dataset ids typed on the "Setup" tab, keyed by the string passed from the UI
# ("source_dataset" / "destination_dataset").
GLOBALS = {}
# Latest input image, its predicted masks, and the index of the mask the user
# picked. Written by the event handlers and read when saving.
# NOTE(review): plain module globals are shared by every session this process
# serves, so concurrent users could overwrite each other's data — confirm the
# app is meant for a single user at a time.
IMAGE = MASKS = INDEX = None
def prompter(prompts):
    """Run SAM2 on the clicked points and return one red overlay per predicted mask.

    Args:
        prompts: Value of the ``ImagePrompter`` component, read as a mapping with
            an ``"image"`` entry and a ``"points"`` list. Only the first two
            entries of each point (x, y) are used, and every point is treated as
            foreground.

    Returns:
        ``(overlay_1, overlay_2, overlay_3, masks)``. Each overlay is
        ``Image.blend(original, red_layer, alpha=0.5)``, where ``red_layer`` is
        255 in the red channel inside the mask and 0 everywhere else. ``masks``
        is the array returned by ``PREDICTOR.predict``.

    Raises:
        gr.Error: If no image was uploaded or no point was clicked.

    Side effects:
        Stores the image and masks in the module globals ``IMAGE`` and
        ``MASKS`` so the save handler can persist them.
    """
    global IMAGE, MASKS

    # Show a readable message instead of a traceback when Submit is pressed
    # early. np.array(None) and SAM2 with empty point coords both fail obscurely.
    if prompts is None or prompts.get("image") is None:
        raise gr.Error("Please upload an image first.")
    points = prompts.get("points") or []
    if not points:
        raise gr.Error("Please click on the image to add at least one point.")

    # NOTE(review): the [:, :, 0] indexing below assumes an H x W x C image
    # (channel axis present) — confirm against the ImagePrompter image_mode.
    image = np.array(prompts["image"])
    input_point = [[point[0], point[1]] for point in points]
    input_label = [1] * len(points)  # every click is treated as foreground

    with torch.inference_mode():
        PREDICTOR.set_image(image)
        masks, _, _ = PREDICTOR.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )

    # The base image is the same for every overlay, so convert it only once.
    original_image = Image.fromarray(image)
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i+1}:", mask.shape)
        red_mask = np.zeros_like(image)
        red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # paint mask pixels red
        overlay_images.append(
            Image.blend(original_image, Image.fromarray(red_mask), alpha=0.5)
        )

    IMAGE, MASKS = image, masks
    # Assumes multimask_output=True yields exactly three masks (one per UI slot).
    return overlay_images[0], overlay_images[1], overlay_images[2], masks
def select_mask(
    selected_mask_index,
    mask1,
    mask2,
    mask3,
):
    """Remember which of the three candidate masks the user picked.

    ``selected_mask_index`` is the position of the chosen option in the mask
    selector. It is written to the module global ``INDEX``, which the save
    handler reads. The candidate at that position is returned.
    """
    global INDEX
    INDEX = selected_mask_index
    candidates = [mask1, mask2, mask3]
    return candidates[selected_mask_index]
def save_selected_mask(
    image, mask, output_dir="output", repo_id="amaye15/object-segmentation"
):
    """Upload the last predicted image and the user-selected mask to a Hub dataset.

    The data comes from the module globals ``IMAGE``, ``MASKS`` and ``INDEX``,
    which ``prompter`` and ``select_mask`` set. The ``image`` and ``mask``
    arguments are accepted only to match the UI wiring and are ignored.

    Each call does three things:
    1. Writes ``image.npy`` and ``mask.npy`` into a fresh UUID-named folder
       under ``output_dir``.
    2. Uploads that folder to ``<folder_id>/`` in the dataset repo ``repo_id``.
    3. Deletes the local copy.

    Args:
        image: Unused.
        mask: Unused.
        output_dir: Local staging directory, resolved against ``os.getcwd()``.
        repo_id: Target dataset repository on the Hugging Face Hub.

    Returns:
        A status message naming the repo and the uploaded folder.

    Raises:
        gr.Error: If nothing has been predicted yet or no mask has been selected.
    """
    if IMAGE is None or MASKS is None:
        raise gr.Error(
            "Nothing to save yet: submit an image with at least one point first."
        )
    # Without this check MASKS[None] would silently save all masks with an extra axis.
    # NOTE(review): INDEX is not reset when a new image is submitted, so a
    # selection made for a previous image is reused until the user picks again.
    if INDEX is None:
        raise gr.Error("Select one of the predicted masks before saving.")

    folder_id = str(uuid4())
    folder_path = os.path.join(os.getcwd(), output_dir, folder_id)
    os.makedirs(folder_path, exist_ok=True)
    try:
        # np.save keeps these names unchanged because they already end in ".npy".
        np.save(os.path.join(folder_path, "image.npy"), IMAGE)
        np.save(os.path.join(folder_path, "mask.npy"), MASKS[INDEX])
        # Upload only this sample. Uploading the whole staging dir would also
        # push folders left behind by earlier failed runs. The repo layout
        # (<folder_id>/image.npy, <folder_id>/mask.npy) is unchanged.
        upload_folder(
            folder_path=folder_path,
            path_in_repo=folder_id,
            repo_id=repo_id,
            repo_type="dataset",
        )
    finally:
        # Always remove the local copy, even when the upload fails.
        shutil.rmtree(folder_path, ignore_errors=True)
    return f"Image and mask uploaded to {repo_id} under {folder_id}/."
def save_dataset_name(key, dataset_name):
    """Remember a dataset id under ``key`` and build its Hub viewer preview.

    Args:
        key: Slot in ``GLOBALS`` to store the name under.
        dataset_name: Textbox contents, e.g. ``"user/dataset"``. Surrounding
            whitespace is stripped.

    Returns:
        ``(markdown, iframe_html)``. Both are empty strings when the name is
        blank, so clearing the textbox clears the preview instead of embedding
        a broken ``datasets//embed`` URL.
    """
    dataset_name = (dataset_name or "").strip()
    # Mutating the dict does not rebind the name, so no ``global`` is needed.
    GLOBALS[key] = dataset_name
    if not dataset_name:
        return "", ""
    # Percent-encode the user-typed name so quotes or angle brackets cannot
    # break out of the src attribute. "/" stays literal because it separates
    # the owner from the dataset name.
    safe_name = quote(dataset_name, safe="/")
    iframe_code = f"""
<iframe
  src="https://huggingface.co/datasets/{safe_name}/embed/viewer/default/train"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>
"""
    return f"Huggingface Dataset: {dataset_name}", iframe_code
# Define the Gradio Blocks app
with gr.Blocks() as demo:
    with gr.Tab("Setup"):
        with gr.Row():
            with gr.Column():
                source = gr.Textbox(label="Source Dataset")
                source_display = gr.Markdown()
                iframe_display = gr.HTML()
                source.change(
                    save_dataset_name,
                    inputs=[gr.State("source_dataset"), source],
                    outputs=[source_display, iframe_display],
                )
            with gr.Column():
                destination = gr.Textbox(label="Destination Dataset")
                destination_display = gr.Markdown()
                # save_dataset_name returns (markdown, iframe_html), so the
                # handler needs one output per value. With a single output the
                # whole tuple would be passed to gr.Markdown and fail.
                destination_iframe_display = gr.HTML()
                destination.change(
                    save_dataset_name,
                    inputs=[gr.State("destination_dataset"), destination],
                    outputs=[destination_display, destination_iframe_display],
                )
    with gr.Tab("Object Mask - Point Prompt"):
        gr.Markdown("# Image Point Collector with Multiple Separate Mask Overlays")
        gr.Markdown(
            "Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images."
        )
        with gr.Row():
            with gr.Column():
                # Input: ImagePrompter
                image_input = ImagePrompter(show_label=False)
                submit_button = gr.Button("Submit")
        with gr.Row():
            # Outputs: one overlay per predicted mask
            with gr.Column():
                image_output_1 = gr.Image(show_label=False)
            with gr.Column():
                image_output_2 = gr.Image(show_label=False)
            with gr.Column():
                image_output_3 = gr.Image(show_label=False)
        # Radio for selecting the correct mask (type="index" yields 0, 1 or 2)
        with gr.Row():
            mask_selector = gr.Radio(
                label="Select the correct mask",
                choices=["Mask 1", "Mask 2", "Mask 3"],
                type="index",
            )
        save_button = gr.Button("Save Selected Mask and Image")
        # Visible so the status message returned by the save handler is shown.
        save_message = gr.Textbox(label="Status", interactive=False)

        # Define the action triggered by the submit button
        submit_button.click(
            fn=prompter,
            inputs=image_input,
            outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
        )
        # Define the action triggered by mask selection
        mask_selector.change(
            fn=select_mask,
            inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
            outputs=gr.State(),
        )
        # Define the action triggered by the save button. The two gr.State()
        # inputs only fill save_selected_mask's (ignored) image/mask parameters.
        save_button.click(
            fn=save_selected_mask,
            inputs=[gr.State(), gr.State()],
            outputs=save_message,
            show_progress=True,
        )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()