import gradio as gr
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from PIL import Image
import requests
import torch
# Load the FashionCLIP processor and model
processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
model = AutoModelForZeroShotImageClassification.from_pretrained("patrickjohncyh/fashion-clip")
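
# Optional GPU sketch (an assumption, not part of the original app): the model
# could be moved to a GPU when one is available; the tensors produced by the
# processor inside the function below would then also need .to(device) before
# the forward pass.
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model.to(device).eval()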
# Compute a text-image similarity score and embeddings for a product
def process_image_and_text(product_title, image_url):
    try:
        # Fetch and decode the image (timeout so a dead URL fails fast)
        response = requests.get(image_url, stream=True, timeout=10)
        response.raise_for_status()
        image = Image.open(response.raw).convert("RGB")

        # Prepare inputs for the model
        inputs = processor(
            text=[product_title],
            images=image,
            return_tensors="pt",
            padding=True
        )

        # Perform inference
        with torch.no_grad():
            outputs = model(**inputs)

        # logits_per_image holds the scaled image-text similarity; the actual
        # embeddings live in text_embeds / image_embeds (logits_per_text is a
        # similarity matrix, not an embedding)
        similarity_score = outputs.logits_per_image[0].item()
        text_embedding = outputs.text_embeds.cpu().numpy().tolist()
        image_embedding = outputs.image_embeds.cpu().numpy().tolist()

        return {
            "similarity_score": similarity_score,
            "text_embedding": text_embedding,
            "image_embedding": image_embedding
        }
    except Exception as e:
        return {"error": str(e)}

# Create the Gradio interface
interface = gr.Interface(
    fn=process_image_and_text,
    inputs=[
        gr.Textbox(label="Product Title", placeholder="e.g., ring for men"),
        gr.Textbox(label="Image URL", placeholder="e.g., https://example.com/image.jpg")
    ],
    outputs="json",
    title="FashionCLIP API",
    description="Provide a product title and an image URL to compute a similarity score and embeddings."
)

# Launch the app
if __name__ == "__main__":
    interface.launch()
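
# To query the running app from another process, the gradio_client package can
# be used. A sketch, assuming the default local URL and the "/predict" endpoint
# that gr.Interface registers by default:
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict("ring for men",
#                           "https://example.com/image.jpg",
#                           api_name="/predict")
#   print(result)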