# openFashionClip / app.py
# import gradio as gr
# from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
# from PIL import Image
# import requests
# import torch
#
# # Load the FashionCLIP processor and model
# processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
# model = AutoModelForZeroShotImageClassification.from_pretrained("patrickjohncyh/fashion-clip")
#
# # Define the function to process both text and image inputs
# def generate_embeddings(input_text=None, input_image_url=None):
#     try:
#         if input_image_url:
#             # Process image with accompanying text
#             response = requests.get(input_image_url, stream=True)
#             response.raise_for_status()
#             image = Image.open(response.raw)
#             # Use a default text if none is provided
#             if not input_text:
#                 input_text = "this is an image"
#             # Prepare inputs for the model
#             inputs = processor(
#                 text=[input_text],
#                 images=image,
#                 return_tensors="pt",
#                 padding=True
#             )
#             with torch.no_grad():
#                 outputs = model(**inputs)
#             image_embedding = outputs.logits_per_image.cpu().numpy().tolist()
#             return {
#                 "type": "image_embedding",
#                 "input_image_url": input_image_url,
#                 "input_text": input_text,
#                 "embedding": image_embedding
#             }
#         elif input_text:
#             # Process text input only
#             inputs = processor(
#                 text=[input_text],
#                 images=None,
#                 return_tensors="pt",
#                 padding=True
#             )
#             with torch.no_grad():
#                 outputs = model(**inputs)
#             text_embedding = outputs.logits_per_text.cpu().numpy().tolist()
#             return {
#                 "type": "text_embedding",
#                 "input_text": input_text,
#                 "embedding": text_embedding
#             }
#         else:
#             return {"error": "Please provide either a text query or an image URL."}
#     except Exception as e:
#         return {"error": str(e)}
#
# # Create the Gradio interface
# interface = gr.Interface(
#     fn=generate_embeddings,
#     inputs=[
#         gr.Textbox(label="Text Query (Optional)", placeholder="e.g., red dress (used with image or for text embedding)"),
#         gr.Textbox(label="Image URL", placeholder="e.g., https://example.com/image.jpg (used with or without text query)")
#     ],
#     outputs="json",
#     title="FashionCLIP Combined Embedding API",
#     description="Provide a text query and/or an image URL to compute embeddings for vector search."
# )
#
# # Launch the app
# if __name__ == "__main__":
#     interface.launch()
#     print(generate_embeddings("red dress"))
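
# The commented-out block above is the earlier single-input FashionCLIP API,
# kept for reference. The active app below batches multiple image URLs through
# helper modules (encoder.py and process.py).
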
import gradio as gr

from encoder import FashionCLIPEncoder
from process import batch_process_images

# Initialize the encoder once at startup so all requests share a single model
encoder = FashionCLIPEncoder()

# Constants
BATCH_SIZE = 30  # Number of images sent to the encoder per batch
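
# `FashionCLIPEncoder` and `batch_process_images` are defined in encoder.py and
# process.py, which are not part of this file. A minimal sketch of the contract
# this app assumes (hypothetical names and fields, kept commented out so only
# the real helpers run):
#
# def batch_process_images(image_urls: str, encoder: FashionCLIPEncoder, batch_size: int) -> list:
#     # Split the comma-separated textbox input into individual URLs
#     urls = [u.strip() for u in image_urls.split(",") if u.strip()]
#     if not urls:
#         return [{"error": "No valid image URLs provided."}]
#     results = []
#     # Encode the URLs in chunks of `batch_size`
#     for start in range(0, len(urls), batch_size):
#         batch = urls[start:start + batch_size]
#         embeddings = encoder.encode_images(batch)  # assumed encoder method
#         for url, embedding in zip(batch, embeddings):
#             results.append({
#                 "image_url": url,
#                 "embedding_preview": embedding[:5],  # "first 5 values" per the UI description
#                 "success": True,
#             })
#     return results
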
# Gradio Interface
iface = gr.Interface(
    fn=lambda image_urls: batch_process_images(image_urls, encoder, BATCH_SIZE),
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter image URLs separated by commas",
        label="Batch Image URLs",
    ),
    outputs=gr.JSON(label="Embedding Results"),
    title="Batch Fashion CLIP Embedding API",
    description="Enter multiple image URLs (separated by commas) to generate embeddings for the batch. Each embedding preview includes the first 5 values.",
    examples=[
        ["https://cdn.shopify.com/s/files/1/0522/2239/4534/files/CT21355-22_1024x1024.webp, https://cdn.shopify.com/s/files/1/0522/2239/4534/files/00907857-C6B0-4D2A-8AEA-688BDE1E67D7_1024x1024.jpg"]
    ],
)

# Launch Gradio App
if __name__ == "__main__":
    iface.launch()
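
# A hedged usage sketch: once the app is running, it can be called from Python
# with the gradio_client package (a separate install; assumes the default
# Interface endpoint name "/predict"):
#
# from gradio_client import Client
#
# client = Client("http://127.0.0.1:7860")
# result = client.predict(
#     "https://example.com/image1.jpg, https://example.com/image2.jpg",
#     api_name="/predict",
# )
# print(result)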