import os
import random
import shutil
from uuid import uuid4

import gradio as gr
import numpy as np
import torch
from datasets import Array2D, Dataset, Features, Image, load_dataset
from gradio_image_prompter import ImagePrompter
from huggingface_hub import login, upload_folder
from PIL import Image as PILImage
from sam2.sam2_image_predictor import SAM2ImagePredictor
MODEL = "facebook/sam2-hiera-large"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
DESTINATION_DS = "amaye15/object-segmentation"

# Log in to the Hub so contributions can be uploaded (TOKEN must be a write-scoped token)
token = os.getenv("TOKEN")
if token:
    login(token)
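# A minimal pre-flight sketch to confirm the token is valid before launching
# (illustrative only; assumes network access to the Hub):
#
#     from huggingface_hub import whoami
#     if token:
#         print("Logged in as:", whoami(token)["name"])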
# Module-level state shared across the Gradio callbacks
IMAGE = None
MASKS = None
MASKED_IMAGES = None
INDEX = None
ds_name = ["amaye15/product_labels"]  # "amaye15/Products-10k", "amaye15/receipts"
choices = ["test", "train"]
max_len = None

# Stream a random source dataset/split and count its examples, so we can
# later skip to a random position in the stream
ds_stream = load_dataset(random.choice(ds_name), streaming=True)
ds_split = ds_stream[random.choice(choices)]
ds_iter = ds_split.iter(batch_size=1)

for idx, val in enumerate(ds_iter):
    max_len = idx
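# Note: the counting pass above is O(N) at startup. A hedged alternative,
# assuming the Hub metadata for this dataset exposes split sizes, would be:
#
#     from datasets import load_dataset_builder
#     builder = load_dataset_builder(ds_name[0])
#     max_len = builder.info.splits["train"].num_examples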
def prompter(prompts):
    """Run SAM2 on the prompted image and return three candidate mask overlays."""
    image = np.array(prompts["image"])  # Convert the prompted image to a numpy array
    points = prompts["points"]  # Clicked points from the ImagePrompter

    # Perform inference with multimask_output=True
    with torch.inference_mode():
        PREDICTOR.set_image(image)
        # Each ImagePrompter point is [x, y, type, ...]; only the coordinates are needed
        input_point = np.array([[point[0], point[1]] for point in points])
        input_label = np.ones(len(points), dtype=np.int64)  # Treat every click as foreground
        masks, _, _ = PREDICTOR.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )

    # Prepare individual images with separate overlays
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i + 1}:", mask.shape)
        red_mask = np.zeros_like(image)
        red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # Paint the mask into the red channel
        red_mask = PILImage.fromarray(red_mask)

        # Blend the original image with the red mask
        original_image = PILImage.fromarray(image)
        blended_image = PILImage.blend(original_image, red_mask, alpha=0.5)
        overlay_images.append(blended_image)

    # Stash results for the later mask-selection and save callbacks
    global IMAGE, MASKS, MASKED_IMAGES
    IMAGE, MASKS = image, masks
    MASKED_IMAGES = [np.array(img) for img in overlay_images]

    return overlay_images[0], overlay_images[1], overlay_images[2], masks
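# Minimal usage sketch for `prompter` outside the UI. The file path and click
# coordinates are hypothetical; ImagePrompter encodes each click as
# [x, y, type, ...], and only the first two values are used here:
#
#     payload = {
#         "image": PILImage.open("example.jpg"),  # hypothetical local image
#         "points": [[120, 80, 1, 0, 0, 4]],      # one foreground click
#     }
#     overlay_a, overlay_b, overlay_c, masks = prompter(payload)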
def select_mask(selected_mask_index, mask1, mask2, mask3):
    """Record which of the three candidate masks the user picked and return it."""
    masks = [mask1, mask2, mask3]
    global INDEX
    INDEX = selected_mask_index
    return masks[selected_mask_index]
def save_selected_mask(image, mask, output_dir="output"):
    """Write the selected image/mask pair to a parquet file and push it to the Hub.

    Note: the `image` and `mask` arguments are placeholders wired to empty
    gr.State components; the actual data comes from the module-level globals
    set in `prompter` and `select_mask`.
    """
    output_dir = os.path.join(os.getcwd(), output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # Give each contribution its own folder so uploads never collide
    folder_id = str(uuid4())
    folder_path = os.path.join(output_dir, folder_id)
    os.makedirs(folder_path, exist_ok=True)
    data_path = os.path.join(folder_path, "data.parquet")

    data = {
        "image": IMAGE,
        "masked_image": MASKED_IMAGES[INDEX],
        "mask": MASKS[INDEX].astype("int64"),  # match the declared Array2D dtype
    }
    features = Features(
        {
            "image": Image(),
            "masked_image": Image(),
            "mask": Array2D(
                dtype="int64", shape=(MASKS[INDEX].shape[0], MASKS[INDEX].shape[1])
            ),
        }
    )

    ds = Dataset.from_list([data], features=features)
    ds.to_parquet(data_path)

    upload_folder(
        folder_path=output_dir,
        repo_id=DESTINATION_DS,
        repo_type="dataset",
    )
    shutil.rmtree(folder_path)  # Clean up the local copy once uploaded

    iframe_code = """## Success! 🎉🤗✅
You've successfully contributed to the dataset.

Please note that because new data has been added, the dataset viewer may take a couple of minutes to refresh.

Check it out here:

[Object Segmentation Dataset](https://huggingface.co/datasets/amaye15/object-segmentation)
"""
    return iframe_code
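# To sanity-check a contribution after upload (hedged sketch; the new parquet
# file may take a moment to appear on the Hub):
#
#     check = load_dataset(DESTINATION_DS, streaming=True)
#     sample = next(iter(check["train"]))
#     print(sample.keys())  # expect: image, masked_image, mask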
def get_random_image():
    """Get a random image from the streaming dataset split."""
    global max_len
    random_idx = random.choice(range(max_len))
    # Skip to a random position in the stream and take a single example
    image_data = list(ds_split.skip(random_idx).take(1))[0]["pixel_values"]
    # Return it in the {"image": ..., "points": ...} format ImagePrompter expects
    formatted_image = {
        "image": np.array(image_data),
        "points": [],
    }
    return formatted_image
# Define the Gradio Blocks app
with gr.Blocks() as demo:
    gr.Markdown("# Object Segmentation: Image Point Collector and Mask Overlay Tool")
    gr.Markdown(
        """
        This application uses **Segment Anything V2 (SAM2)** to let you upload an image, or select a random image from a dataset, and interactively generate segmentation masks from points you select on the image.

        ### How It Works:
        1. **Upload or Select an Image**: Either upload your own image or use a random image from the dataset.
        2. **Point Selection**: Click on the image to mark points of interest. You can add multiple points, and they are used collectively to generate segmentation masks with SAM2.
        3. **Mask Generation**: The app generates up to three different segmentation masks for the selected points, each displayed separately with a red overlay.
        4. **Mask Selection**: Review the generated masks carefully and select the one that best fits your needs. **Choose the correct mask, as your selection will be saved and used for further processing.**
        5. **Save and Contribute**: Save the selected mask along with the image, contributing to a shared dataset on Hugging Face.

        **Disclaimer**: All images and masks you work with are collected and stored in a public dataset. Please make sure you are comfortable with your selections and the data you provide before saving.

        This tool is particularly useful for creating precise object segmentation masks for computer vision tasks, such as training models or generating labeled datasets.
        """
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.State()

            # Input: ImagePrompter for the uploaded image
            upload_image_input = ImagePrompter(show_label=False)
            random_image_button = gr.Button("Use Random Image")
            submit_button = gr.Button("Submit")

    with gr.Row():
        with gr.Column():
            # Outputs: up to 3 overlay images
            image_output_1 = gr.Image(show_label=False)
        with gr.Column():
            image_output_2 = gr.Image(show_label=False)
        with gr.Column():
            image_output_3 = gr.Image(show_label=False)

    # Radio buttons for selecting the correct mask
    with gr.Row():
        mask_selector = gr.Radio(
            label="Select the correct mask",
            choices=["Mask 1", "Mask 2", "Mask 3"],
            type="index",
        )
        # selected_mask_output = gr.Image(show_label=False)

    save_button = gr.Button("Save Selected Mask and Image")
    iframe_display = gr.Markdown()
    # Logic for the random image button
    random_image_button.click(
        fn=get_random_image,
        inputs=None,
        outputs=upload_image_input,  # Pass the formatted random image to ImagePrompter
    )

    # Mirror the uploaded image into the state component
    upload_image_input.change(
        fn=lambda img: img, inputs=upload_image_input, outputs=image_input
    )

    # Define the action triggered by the submit button
    submit_button.click(
        fn=prompter,
        inputs=upload_image_input,  # The final image input (whether uploaded or random)
        outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
        show_progress=True,
    )

    # Define the action triggered by mask selection
    mask_selector.change(
        fn=select_mask,
        inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
        outputs=gr.State(),
    )

    # Define the action triggered by the save button. The gr.State inputs are
    # empty placeholders: `save_selected_mask` actually reads the module-level
    # globals populated by `prompter` and `select_mask`.
    save_button.click(
        fn=save_selected_mask,
        inputs=[gr.State(), gr.State()],
        outputs=iframe_display,
        show_progress=True,
    )

# Launch the Gradio app
demo.launch()
# Legacy (unused) source/destination dataset pickers, kept for reference:
# with gr.Column():
#     source = gr.Textbox(label="Source Dataset")
#     source_display = gr.Markdown()
#     iframe_display = gr.HTML()

#     source.change(
#         save_dataset_name,
#         inputs=(gr.State("source_dataset"), source),
#         outputs=(source_display, iframe_display),
#     )

# with gr.Column():
#     destination = gr.Textbox(label="Destination Dataset")
#     destination_display = gr.Markdown()

#     destination.change(
#         save_dataset_name,
#         inputs=(gr.State("destination_dataset"), destination),
#         outputs=destination_display,
#     )