|
import gradio as gr |
|
from gradio_image_prompter import ImagePrompter |
|
import torch |
|
import numpy as np |
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
from PIL import Image |
|
from uuid import uuid4 |
|
import os |
|
from huggingface_hub import upload_folder, login |
|
import shutil |
|
|
|
# Hugging Face model id of the SAM 2 checkpoint used for point-prompted segmentation.
MODEL = "facebook/sam2-hiera-large"

# Run on CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One predictor instance, loaded at import time and shared by every request.
PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)

# Authenticate with the Hub so upload_folder() can push to the dataset repo.
# login(None) falls back to an interactive prompt, which would block a
# headless server at startup, so only log in when TOKEN is actually set.
_HF_TOKEN = os.getenv("TOKEN")
if _HF_TOKEN:
    login(token=_HF_TOKEN)

# Dataset names entered on the "Setup" tab, keyed by role
# ("source_dataset" / "destination_dataset"); written by save_dataset_name.
GLOBALS = {}

# State of the most recent prediction, shared process-wide (NOT per user
# session): prompter() sets IMAGE and MASKS, select_mask() sets INDEX, and
# save_selected_mask() reads all three.
IMAGE = None
MASKS = None
INDEX = None
|
|
|
|
|
def _to_rgb(image):
    """Return *image* as a contiguous H x W x 3 array.

    Grayscale input is replicated across three channels and an alpha channel
    is dropped. SAM 2's ``set_image`` expects an RGB HWC array, and the red
    overlay below writes channel 0 of a 3-D array.
    """
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]
    return np.ascontiguousarray(image)


def _overlay_mask(original_image, image, mask):
    """Blend *mask* over *original_image* as a 50% red tint.

    ``image`` is the array behind ``original_image``; it supplies the shape
    and dtype of the red layer so that ``Image.blend`` gets matching inputs.
    """
    red_mask = np.zeros_like(image)
    red_mask[:, :, 0] = mask.astype(np.uint8) * 255
    return Image.blend(original_image, Image.fromarray(red_mask), alpha=0.5)


def prompter(prompts):
    """Run SAM 2 on the clicked points and render each candidate mask in red.

    Args:
        prompts: Value of the ``ImagePrompter`` component: a dict with an
            ``"image"`` entry (anything ``np.array`` accepts) and a
            ``"points"`` list whose entries start with ``x, y``.

    Returns:
        A 4-tuple: the three blended overlay images (one per predicted mask,
        ``multimask_output=True``), followed by the raw mask array.

    Raises:
        gr.Error: If no image was provided or no point was clicked.

    Side effects:
        Stores the image and masks in the module-level ``IMAGE`` / ``MASKS``
        so that ``save_selected_mask`` can persist them later.
    """
    global IMAGE, MASKS

    if not prompts or prompts.get("image") is None:
        raise gr.Error("Upload an image before submitting.")
    points = prompts.get("points") or []
    if not points:
        # SAM 2 needs at least one prompt; fail with a readable message
        # instead of an opaque error from inside the predictor.
        raise gr.Error("Click on the image to add at least one point before submitting.")

    image = _to_rgb(np.array(prompts["image"]))

    # Only (x, y) is used and every click is treated as a foreground point.
    # NOTE(review): ImagePrompter entries carry extra fields after x, y
    # (presumably a prompt type / box corner); they are ignored here.
    input_point = [[point[0], point[1]] for point in points]
    input_label = [1] * len(points)

    with torch.inference_mode():
        PREDICTOR.set_image(image)
        masks, _, _ = PREDICTOR.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )

    # The base image is identical for every mask, so build it once.
    original_image = Image.fromarray(image)
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i+1}:", mask.shape)
        overlay_images.append(_overlay_mask(original_image, image, mask))

    IMAGE, MASKS = image, masks

    return overlay_images[0], overlay_images[1], overlay_images[2], masks
|
|
|
|
|
def select_mask(
    selected_mask_index,
    mask1,
    mask2,
    mask3,
):
    """Record the user's mask choice and hand back the matching overlay.

    Args:
        selected_mask_index: Position chosen in the radio group (the UI uses
            ``type="index"``, so 0, 1 or 2).
        mask1, mask2, mask3: The three overlay images shown to the user.

    Returns:
        The argument among ``mask1``/``mask2``/``mask3`` at
        ``selected_mask_index``.

    Side effects:
        Stores the index in the module-level ``INDEX`` used when saving.
    """
    global INDEX

    INDEX = selected_mask_index
    candidates = (mask1, mask2, mask3)
    return candidates[selected_mask_index]
|
|
|
|
|
def save_selected_mask(image, mask, output_dir="output"):
    """Upload an image and its selected mask to a Hugging Face dataset.

    The arrays are written as ``image.npy`` / ``mask.npy`` into a fresh
    ``<output_dir>/<uuid4>/`` scratch folder. That folder is uploaded under
    ``<uuid4>/`` in the dataset repo and then removed locally.

    Args:
        image: Image array to save. When ``None`` (as the UI passes today),
            the module-level ``IMAGE`` from the last prediction is used.
        mask: Mask array to save. When ``None``, ``MASKS[INDEX]`` (the mask
            picked in the radio group) is used.
        output_dir: Scratch directory, resolved relative to the working
            directory.

    Returns:
        A status message for the UI.
    """
    if image is None:
        image = IMAGE
    if mask is None and MASKS is not None and INDEX is not None:
        mask = MASKS[INDEX]
    if image is None or mask is None:
        # Without this guard, MASKS[None] would silently save *all* masks
        # with an extra leading axis instead of the selected one.
        return "Nothing to save yet: submit points and select a mask first."

    # Upload to the dataset entered on the Setup tab, falling back to the
    # original hard-coded repository when none was given.
    repo_id = (GLOBALS.get("destination_dataset") or "").strip() or "amaye15/object-segmentation"

    folder_id = str(uuid4())
    folder_path = os.path.join(os.getcwd(), output_dir, folder_id)
    os.makedirs(folder_path, exist_ok=True)

    try:
        np.save(os.path.join(folder_path, "image.npy"), image)
        np.save(os.path.join(folder_path, "mask.npy"), mask)

        # Upload only this request's folder while keeping the <uuid>/ prefix
        # in the repo. Uploading all of output_dir would also re-push folders
        # left behind by earlier failed or concurrent saves.
        upload_folder(
            folder_path=folder_path,
            path_in_repo=folder_id,
            repo_id=repo_id,
            repo_type="dataset",
        )
    finally:
        # Always remove the scratch copy, even when the upload fails.
        shutil.rmtree(folder_path, ignore_errors=True)

    return f"Image and mask uploaded to {repo_id}/{folder_id}."
|
|
|
|
|
def save_dataset_name(key, dataset_name):
    """Remember a dataset name under *key* and build a Hub viewer embed for it.

    Args:
        key: Slot in the module-level ``GLOBALS`` dict to store the name under
            (the UI uses ``"source_dataset"`` and ``"destination_dataset"``).
        dataset_name: Hub dataset id typed by the user, e.g. ``"user/name"``.

    Returns:
        ``(markdown_label, iframe_html)``. ``iframe_html`` is an empty string
        when no name has been entered.
    """
    from html import escape

    # Mutating the dict needs no `global` statement.
    GLOBALS[key] = dataset_name

    if not dataset_name:
        # Nothing to preview yet; don't embed a viewer pointing at a broken URL.
        iframe_code = ""
    else:
        # The name is free-form user input: escape it so quotes or angle
        # brackets cannot break out of the src attribute.
        safe_name = escape(dataset_name, quote=True)
        iframe_code = f"""
<iframe
    src="https://huggingface.co/datasets/{safe_name}/embed/viewer/default/train"
    frameborder="0"
    width="100%"
    height="560px"
></iframe>
"""

    return f"Huggingface Dataset: {dataset_name}", iframe_code
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    # Setup tab: pick the source dataset to browse and name the destination
    # dataset; save_dataset_name records both in GLOBALS.
    with gr.Tab("Setup"):
        with gr.Row():
            with gr.Column():
                source = gr.Textbox(label="Source Dataset")
                source_display = gr.Markdown()
                iframe_display = gr.HTML()

                source.change(
                    save_dataset_name,
                    inputs=(gr.State("source_dataset"), source),
                    outputs=(source_display, iframe_display),
                )

            with gr.Column():
                destination = gr.Textbox(label="Destination Dataset")
                destination_display = gr.Markdown()

                # save_dataset_name returns (label, iframe_html), but this event
                # has a single Markdown output. Forward only the label: handing
                # the whole tuple to one Markdown component fails when Gradio
                # postprocesses it.
                destination.change(
                    lambda key, name: save_dataset_name(key, name)[0],
                    inputs=(gr.State("destination_dataset"), destination),
                    outputs=destination_display,
                )

    with gr.Tab("Object Mask - Point Prompt"):
        gr.Markdown("# Image Point Collector with Multiple Separate Mask Overlays")
        gr.Markdown(
            "Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images."
        )

        with gr.Row():
            with gr.Column():
                image_input = ImagePrompter(show_label=False)
                submit_button = gr.Button("Submit")

        with gr.Row():
            with gr.Column():
                image_output_1 = gr.Image(show_label=False)
            with gr.Column():
                image_output_2 = gr.Image(show_label=False)
            with gr.Column():
                image_output_3 = gr.Image(show_label=False)

        with gr.Row():
            mask_selector = gr.Radio(
                label="Select the correct mask",
                choices=["Mask 1", "Mask 2", "Mask 3"],
                type="index",
            )

        save_button = gr.Button("Save Selected Mask and Image")
        # Visible so the status string returned by save_selected_mask
        # actually reaches the user.
        save_message = gr.Textbox(label="Save status", interactive=False)

        # NOTE(review): the anonymous gr.State() components below are
        # placeholders that nothing else reads. The actual image, masks and
        # selected index flow between the callbacks through the module-level
        # IMAGE / MASKS / INDEX globals, which every session shares.
        submit_button.click(
            fn=prompter,
            inputs=image_input,
            outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
        )

        mask_selector.change(
            fn=select_mask,
            inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
            outputs=gr.State(),
        )

        save_button.click(
            fn=save_selected_mask,
            inputs=[gr.State(), gr.State()],
            outputs=save_message,
            show_progress=True,
        )

demo.launch()
|
|