import gradio as gr
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from PIL import Image
import requests
import torch

# Load the FashionCLIP processor and model
processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
model = AutoModelForZeroShotImageClassification.from_pretrained("patrickjohncyh/fashion-clip")

# Define the function to process both text and image inputs
def generate_embeddings(input_text=None, input_image_url=None):
    try:
        if input_image_url:
            # Download and decode the image
            response = requests.get(input_image_url, stream=True)
            response.raise_for_status()
            image = Image.open(response.raw).convert("RGB")

            # Use a default text if none is provided (the joint forward pass needs both modalities)
            if not input_text:
                input_text = "this is an image"

            # Prepare inputs for the model
            inputs = processor(
                text=[input_text],
                images=image,
                return_tensors="pt",
                padding=True
            )
            with torch.no_grad():
                outputs = model(**inputs)
            # outputs.image_embeds is the L2-normalized image vector;
            # logits_per_image is only an image-text similarity score, not an embedding
            image_embedding = outputs.image_embeds.cpu().numpy().tolist()
            return {
                "type": "image_embedding",
                "input_image_url": input_image_url,
                "input_text": input_text,
                "embedding": image_embedding
            }
        elif input_text:
            # Process text input only
            inputs = processor(
                text=[input_text],
                return_tensors="pt",
                padding=True
            )
            with torch.no_grad():
                # The joint forward pass requires pixel_values, so call the text tower directly
                text_features = model.get_text_features(**inputs)
            # Normalize so the vector is comparable with the (already normalized) image embeddings
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            text_embedding = text_features.cpu().numpy().tolist()
            return {
                "type": "text_embedding",
                "input_text": input_text,
                "embedding": text_embedding
            }
        else:
            return {"error": "Please provide either a text query or an image URL."}
    except Exception as e:
        return {"error": str(e)}

# Create the Gradio interface
interface = gr.Interface(
    fn=generate_embeddings,
    inputs=[
        gr.Textbox(label="Text Query (Optional)", placeholder="e.g., red dress (used with image or for text embedding)"),
        gr.Textbox(label="Image URL", placeholder="e.g., https://example.com/image.jpg (used with or without text query)")
    ],
    outputs="json",
    title="FashionCLIP Combined Embedding API",
    description="Provide a text query and/or an image URL to compute embeddings for vector search."
)

# Launch the app
if __name__ == "__main__":
    interface.launch()
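
The embedding returned by either branch is a plain list of floats, so it can be fed straight into a nearest-neighbour lookup. Below is a minimal sketch of a cosine-similarity search over a catalog of precomputed image embeddings; it is a separate script, not part of the app, and the catalog array, the 512-dimension figure (the usual CLIP ViT-B/32 projection size), and the app.py file name are illustrative assumptions.

# example_search.py - run separately from the Space (illustrative only)
import numpy as np

from app import generate_embeddings  # assumes the code above is saved as app.py

# Hypothetical catalog of precomputed image embeddings, one row per product image.
catalog = np.random.rand(1000, 512)
catalog /= np.linalg.norm(catalog, axis=1, keepdims=True)

# Embed a text query with the function defined above.
result = generate_embeddings(input_text="red floral summer dress")
query_vec = np.array(result["embedding"][0])
query_vec /= np.linalg.norm(query_vec)

# Cosine similarity against every catalog item; pick the five closest.
scores = catalog @ query_vec
top_k = np.argsort(scores)[::-1][:5]
print(top_k, scores[top_k])

In a deployed Space the same JSON payload could also be fetched remotely (for example with the gradio_client package) instead of importing the function locally.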