Spaces:
Running
Running
# import gradio as gr | |
# from transformers import AutoProcessor, AutoModelForZeroShotImageClassification | |
# from PIL import Image | |
# import requests | |
# import torch | |
# # Load the FashionCLIP processor and model | |
# processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip") | |
# model = AutoModelForZeroShotImageClassification.from_pretrained("patrickjohncyh/fashion-clip") | |
# # Define the function to process both text and image inputs | |
# def generate_embeddings(input_text=None, input_image_url=None): | |
# try: | |
# if input_image_url: | |
# # Process image with accompanying text | |
# response = requests.get(input_image_url, stream=True) | |
# response.raise_for_status() | |
# image = Image.open(response.raw) | |
# # Use a default text if none is provided | |
# if not input_text: | |
# input_text = "this is an image" | |
# # Prepare inputs for the model | |
# inputs = processor( | |
# text=[input_text], | |
# images=image, | |
# return_tensors="pt", | |
# padding=True | |
# ) | |
# with torch.no_grad(): | |
# outputs = model(**inputs) | |
# image_embedding = outputs.logits_per_image.cpu().numpy().tolist() | |
# return { | |
# "type": "image_embedding", | |
# "input_image_url": input_image_url, | |
# "input_text": input_text, | |
# "embedding": image_embedding | |
# } | |
# elif input_text: | |
# # Process text input only | |
# inputs = processor( | |
# text=[input_text], | |
# images=None, | |
# return_tensors="pt", | |
# padding=True | |
# ) | |
# with torch.no_grad(): | |
# outputs = model(**inputs) | |
# text_embedding = outputs.logits_per_text.cpu().numpy().tolist() | |
# return { | |
# "type": "text_embedding", | |
# "input_text": input_text, | |
# "embedding": text_embedding | |
# } | |
# else: | |
# return {"error": "Please provide either a text query or an image URL."} | |
# except Exception as e: | |
# return {"error": str(e)} | |
# # Create the Gradio interface | |
# interface = gr.Interface( | |
# fn=generate_embeddings, | |
# inputs=[ | |
# gr.Textbox(label="Text Query (Optional)", placeholder="e.g., red dress (used with image or for text embedding)"), | |
# gr.Textbox(label="Image URL", placeholder="e.g., https://example.com/image.jpg (used with or without text query)") | |
# ], | |
# outputs="json", | |
# title="FashionCLIP Combined Embedding API", | |
# description="Provide a text query and/or an image URL to compute embeddings for vector search." | |
# ) | |
# # Launch the app | |
# if __name__ == "__main__": | |
# interface.launch() | |
# print(generate_embeddings("red dress")) | |
import io
import uuid

import gradio as gr
import numpy as np
import requests
from PIL import Image

from encoder import FashionCLIPEncoder
# Constants
# Browser-style User-Agent sent with every image download — presumably so
# hosts that refuse the default python-requests agent still serve the file.
REQUESTS_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/91.0.4472.124 Safari/537.36"
    ),
}

# Maximum number of decoded images handed to the encoder in a single call.
BATCH_SIZE = 30

# Created once at import time and reused for every batch the app encodes.
encoder = FashionCLIPEncoder()
# Helper function to download images | |
def download_image_as_pil(url: str, timeout: int = 10) -> Image.Image: | |
try: | |
response = requests.get(url, stream=True, headers=REQUESTS_HEADERS, timeout=timeout) | |
if response.status_code == 200: | |
return Image.open(response.raw).convert("RGB") # Ensure consistent format | |
return None | |
except Exception as e: | |
print(f"Error downloading image: {e}") | |
return None | |
# Embedding function for a batch of images
def batch_process_images(image_urls: str):
    """Download comma-separated image URLs and embed them in batches.

    Args:
        image_urls: Image URLs separated by commas. Whitespace around each URL
            and empty entries are ignored; ``None`` is treated as empty.

    Returns:
        ``{"error": "No valid image URLs provided."}`` when no URL remains after
        splitting. Otherwise a list with exactly one result dict per input URL,
        in the same order as the input. Each dict comes from ``process_batch``
        (on successful download) or carries an ``"error"`` message.
    """
    # Split the input string by commas and strip whitespace
    urls = [url.strip() for url in (image_urls or "").split(",") if url.strip()]
    if not urls:
        return {"error": "No valid image URLs provided."}

    # Results are placed by input index: download failures are known at once,
    # while successes only resolve when their batch is encoded, so plain
    # appending would reorder the output relative to the input.
    results = [None] * len(urls)
    pending = []  # (input index, url, image) tuples awaiting encoding

    def flush_pending():
        """Encode the pending images and store each result at its input index."""
        batch_results = []
        process_batch(
            [url for _, url, _ in pending],
            [image for _, _, image in pending],
            batch_results,
        )
        # Match results back by URL (consuming each at most once) rather than
        # by position, so a stray extra/missing result cannot shift the rest.
        remaining = list(batch_results)
        for index, url, _ in pending:
            for position, result in enumerate(remaining):
                if result.get("image_url") == url:
                    results[index] = remaining.pop(position)
                    break
        pending.clear()

    for index, url in enumerate(urls):
        try:
            image = download_image_as_pil(url)
        except Exception as e:
            results[index] = {"image_url": url, "error": str(e)}
            continue
        if image is None:
            results[index] = {"image_url": url, "error": "Failed to download image"}
            continue
        pending.append((index, url, image))
        # Process batch when reaching batch size
        if len(pending) >= BATCH_SIZE:
            flush_pending()

    # Process remaining images in the last batch
    if pending:
        flush_pending()

    # Guarantee one entry per URL even if the encoder step produced none.
    return [
        result if result is not None else {"image_url": url, "error": "No embedding returned"}
        for url, result in zip(urls, results)
    ]
# Helper function to process a batch
def process_batch(batch_urls, batch_images, results):
    """Encode one batch of images and append one result dict per URL to ``results``.

    Args:
        batch_urls: Source URLs, parallel to ``batch_images``.
        batch_images: Decoded images passed to ``encoder.encode_images``.
        results: List mutated in place; receives exactly ``len(batch_urls)``
            dicts, in the same order as ``batch_urls``.

    Each dict has ``"image_url"`` plus either ``"embedding_preview"`` (first
    five values of the L2-normalized embedding) and ``"success": True``, or an
    ``"error"`` message.
    """
    try:
        # Generate embeddings
        # assumes encode_images returns one embedding per input image, in input
        # order — TODO confirm against encoder.FashionCLIPEncoder
        embeddings = list(encoder.encode_images(batch_images))
    except Exception as e:
        # The encoder call failed as a whole, so every URL in the batch failed.
        results.extend({"image_url": url, "error": str(e)} for url in batch_urls)
        return

    for position, url in enumerate(batch_urls):
        # Report, rather than silently drop, URLs with no returned embedding.
        if position >= len(embeddings):
            results.append({"image_url": url, "error": "No embedding returned"})
            continue
        # Per-URL try: a failure here must not re-report earlier successes.
        try:
            embedding = np.asarray(embeddings[position])
            norm = np.linalg.norm(embedding)
            # A zero or non-finite norm would make the preview NaN/inf, which
            # strict JSON cannot represent.
            if not np.isfinite(norm) or norm == 0:
                results.append({"image_url": url, "error": "Embedding cannot be normalized"})
                continue
            # Normalize embedding
            embedding_normalized = embedding / norm
            results.append({
                "image_url": url,
                "embedding_preview": embedding_normalized[:5].tolist(),  # First 5 values for preview
                "success": True,
            })
        except Exception as e:
            results.append({"image_url": url, "error": str(e)})
# Gradio Interface
# A single multi-line textbox feeds batch_process_images; whatever it returns
# (a list of per-URL dicts, or an error dict) is rendered as JSON.
iface = gr.Interface(
    fn=batch_process_images,
    title="Batch Fashion CLIP Embedding API",
    description=(
        "Enter multiple image URLs (separated by commas) to generate embeddings "
        "for the batch. Each embedding preview includes the first 5 values."
    ),
    inputs=gr.Textbox(
        label="Batch Image URLs",
        placeholder="Enter image URLs separated by commas",
        lines=5,
    ),
    outputs=gr.JSON(label="Embedding Results"),
    examples=[
        [
            "https://cdn.shopify.com/s/files/1/0522/2239/4534/files/CT21355-22_1024x1024.webp, "
            "https://cdn.shopify.com/s/files/1/0522/2239/4534/files/00907857-C6B0-4D2A-8AEA-688BDE1E67D7_1024x1024.jpg"
        ],
    ],
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()