# openFashionClip / app.py
# im: Updated app to handle both text and image embeddings (commit e27c656)
# raw | history | blame | 2.76 kB
import gradio as gr
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from PIL import Image
import requests
import torch
# Identifier of the FashionCLIP checkpoint. The processor and the model are
# both loaded from this same identifier so they always stay in sync.
_FASHION_CLIP_CHECKPOINT = "patrickjohncyh/fashion-clip"

processor = AutoProcessor.from_pretrained(_FASHION_CLIP_CHECKPOINT)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    _FASHION_CLIP_CHECKPOINT
)
# Seconds to wait on the remote image host before giving up; without a
# timeout ``requests.get`` can block the worker indefinitely.
_IMAGE_FETCH_TIMEOUT_SECONDS = 10

# Echoed back as ``input_text`` when an image is submitted without a query,
# so the response keeps the same keys it has always had.
_DEFAULT_IMAGE_TEXT = "this is an image"

_NO_INPUT_MESSAGE = "Please provide either a text query or an image URL."


def _download_image(url):
    """Download *url* and return its contents as an RGB ``PIL.Image``."""
    from io import BytesIO  # function-scope import keeps this block self-contained

    # NOTE(review): the URL is user-supplied, so this request can reach any
    # host the server can reach (SSRF). Consider an allow-list if that matters.
    response = requests.get(url, timeout=_IMAGE_FETCH_TIMEOUT_SECONDS)
    response.raise_for_status()
    # ``response.content`` is the fully read, content-decoded body. The raw
    # urllib3 stream is not gzip/deflate-decoded and was never being closed.
    image = Image.open(BytesIO(response.content))
    # Normalise palette / RGBA / greyscale images to the 3-channel input
    # the image processor works with.
    return image.convert("RGB")


def _features_to_list(features):
    """Convert the result of a ``get_*_features`` call to nested lists of floats."""
    # NOTE(review): some transformers releases appear to wrap the features in
    # an output object exposing ``pooler_output`` instead of returning a bare
    # tensor -- confirm against the pinned transformers version.
    tensor = getattr(features, "pooler_output", features)
    return tensor.cpu().numpy().tolist()


def generate_embeddings(input_text=None, input_image_url=None):
    """Compute a FashionCLIP embedding for an image URL or a text query.

    When an image URL is given it takes precedence: the image is downloaded
    and its embedding is returned. The text query is echoed back in the
    response but does not influence the image embedding. When only text is
    given, the text embedding is returned.

    Args:
        input_text: Optional text query. ``None``, empty and whitespace-only
            values count as "not provided".
        input_image_url: Optional URL of an image to embed. Same
            "not provided" rules as ``input_text``.

    Returns:
        A JSON-serialisable dict. On success it holds ``"type"``
        (``"image_embedding"`` or ``"text_embedding"``), the echoed inputs and
        ``"embedding"`` as a nested list of floats (one row per input). On
        failure, or when neither input is provided, it holds a single
        ``"error"`` key with a message; no exception is raised to the caller.
    """
    # Pasted values frequently carry stray whitespace or newlines.
    text = (input_text or "").strip()
    image_url = (input_image_url or "").strip()

    if not text and not image_url:
        return {"error": _NO_INPUT_MESSAGE}

    try:
        if image_url:
            image = _download_image(image_url)
            inputs = processor(images=image, return_tensors="pt")
            with torch.no_grad():
                # ``logits_per_image`` is only an image/text similarity score;
                # the projected image features are the actual embedding.
                features = model.get_image_features(**inputs)
            return {
                "type": "image_embedding",
                "input_image_url": image_url,
                "input_text": text or _DEFAULT_IMAGE_TEXT,
                "embedding": _features_to_list(features),
            }

        # Text-only request. The combined forward pass requires pixel values,
        # so the text tower has to be called on its own.
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,  # over-long queries are clipped instead of failing
        )
        with torch.no_grad():
            features = model.get_text_features(**inputs)
        return {
            "type": "text_embedding",
            "input_text": text,
            "embedding": _features_to_list(features),
        }
    except Exception as exc:
        # API boundary: report the failure as JSON rather than crashing the UI.
        return {"error": str(exc)}
# Input widgets, listed in the order generate_embeddings receives them:
# the text query first, then the image URL.
_text_query_box = gr.Textbox(
    label="Text Query (Optional)",
    placeholder="e.g., red dress (used with image or for text embedding)",
)
_image_url_box = gr.Textbox(
    label="Image URL",
    placeholder="e.g., https://example.com/image.jpg (used with or without text query)",
)

# Gradio interface: both textbox values go to generate_embeddings and the
# returned dict is rendered as JSON.
interface = gr.Interface(
    fn=generate_embeddings,
    inputs=[_text_query_box, _image_url_box],
    outputs="json",
    title="FashionCLIP Combined Embedding API",
    description="Provide a text query and/or an image URL to compute embeddings for vector search.",
)

# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()