"""Gradio app that returns Segment Anything (SAM) image embeddings as base64."""

import base64
import io
import os
import subprocess
import typing
import urllib.parse

import numpy as np
import gradio as gr
import segment_anything
import torch
import requests
import PIL.Image

# Upper bound (seconds) on the image download so a stalled server cannot
# block a request forever.
_DOWNLOAD_TIMEOUT_SECONDS = 30


def download_image(url) -> PIL.Image.Image:
    """Download an image from a URL and return it as a PIL image.

    The response body is read fully into memory before decoding, so the HTTP
    connection is released immediately and any Content-Encoding (gzip/deflate)
    is decoded by ``requests`` rather than being handed to PIL as raw bytes.

    Args:
        url: Address of the image to fetch.

    Returns:
        The image opened with ``PIL.Image.open``.

    Raises:
        requests.RequestException: On network failure, timeout, or a non-2xx
            HTTP status.
    """
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
    # Fail with the HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()
    return PIL.Image.open(io.BytesIO(response.content))


def image_to_sam_image_embedding(
    image_url: str,
    # The annotation is kept as plain `str`; the original had
    # typing.Literal["base", "large", "huge"] commented out.
    model_size: str = "base",
) -> str:
    """Generate an image embedding.

    The predictors (``base_predictor``, ``large_predictor``,
    ``huge_predictor``) are module globals created in the
    ``if __name__ == "__main__"`` section below, so this function is only
    usable once that section has run.

    Args:
        image_url: URL of the image; percent-encoding is undone before the
            download.
        model_size: One of ``"base"``, ``"large"`` or ``"huge"``.

    Returns:
        The embedding, flattened, cast to float32, and base64-encoded as a
        UTF-8 string.

    Raises:
        ValueError: If ``model_size`` is not one of the supported values.
    """
    image_url = urllib.parse.unquote(image_url)
    image = download_image(image_url)
    image = image.convert("RGB")
    image = np.asarray(image)

    # Select model size
    if model_size == "base":
        predictor = base_predictor
    elif model_size == "large":
        predictor = large_predictor
    elif model_size == "huge":
        predictor = huge_predictor
    else:
        # Previously fell through and crashed later with UnboundLocalError.
        raise ValueError(
            f"Unknown model_size {model_size!r}; "
            "expected 'base', 'large' or 'huge'."
        )

    # Run model
    predictor.set_image(image)
    # Output shape is (1, 256, 64, 64)
    image_embedding = predictor.get_image_embedding().cpu().numpy()

    # Flatten to 1D, serialise as float32 bytes, then base64-encode.
    flat_arr = image_embedding.flatten()
    bytes_arr = flat_arr.astype(np.float32).tobytes()
    return base64.b64encode(bytes_arr).decode("utf-8")


if __name__ == "__main__":
    # Load the model into memory to make running multiple predictions efficient
    device = "cuda" if torch.cuda.is_available() else "cpu"

    base_sam_checkpoint = "sam_vit_b_01ec64.pth"  # 375 MB
    large_sam_checkpoint = "sam_vit_l_0b3195.pth"  # 1.25 GB
    huge_sam_checkpoint = "sam_vit_h_4b8939.pth"  # 2.56 GB

    # Download any model checkpoint that is not already in the working directory
    for checkpoint in (
        base_sam_checkpoint,
        large_sam_checkpoint,
        huge_sam_checkpoint,
    ):
        if not os.path.exists(f"./{checkpoint}"):
            result = subprocess.run(
                [
                    "wget",
                    f"https://dl.fbaipublicfiles.com/segment_anything/{checkpoint}",
                ],
                check=True,
            )
            print(f"wget {checkpoint} result = {result}")

    base_sam = segment_anything.sam_model_registry["vit_b"](
        checkpoint=base_sam_checkpoint
    )
    large_sam = segment_anything.sam_model_registry["vit_l"](
        checkpoint=large_sam_checkpoint
    )
    huge_sam = segment_anything.sam_model_registry["vit_h"](
        checkpoint=huge_sam_checkpoint
    )

    base_sam.to(device=device)
    large_sam.to(device=device)
    huge_sam.to(device=device)

    # These names are read as globals by image_to_sam_image_embedding.
    base_predictor = segment_anything.SamPredictor(base_sam)
    large_predictor = segment_anything.SamPredictor(large_sam)
    huge_predictor = segment_anything.SamPredictor(huge_sam)

    # Gradio app
    app = gr.Interface(
        fn=image_to_sam_image_embedding,
        inputs=[
            gr.components.Textbox(label="Image URL"),
            gr.components.Radio(
                choices=["base", "large", "huge"],
                label="Model Size",
                value="base",
            ),
        ],
        outputs="text",
    )
    app.launch()