"""Gradio app that exposes FashionCLIP text and image embeddings as JSON."""

import io
from typing import Optional

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

MODEL_NAME = "patrickjohncyh/fashion-clip"
DEFAULT_IMAGE_TEXT = "this is an image"
# Upper bound (seconds) on connecting to / reading from the image host, so a
# stalled server cannot block a worker indefinitely.
REQUEST_TIMEOUT_SECONDS = 10

# Load the FashionCLIP processor and model once, at import time.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForZeroShotImageClassification.from_pretrained(MODEL_NAME)


def _fetch_image(url: str) -> Image.Image:
    """Download *url* and return it as a PIL image.

    Raises whatever ``requests`` raises on network/HTTP failure and whatever
    PIL raises when the payload is not a decodable image.
    """
    # NOTE(security): the URL is user-supplied, so this performs a
    # server-side request to an arbitrary host. Restrict the reachable hosts
    # at the deployment level if this app can see internal services.
    response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
    response.raise_for_status()
    # ``response.content`` is the fully read, content-decoded body; the raw
    # socket stream would still be gzip/deflate-encoded when the server
    # compresses the response.
    return Image.open(io.BytesIO(response.content))


def _to_embedding(features) -> list:
    """Turn model features into an L2-normalized nested list (batch, dim)."""
    # Unwrap defensively: if the installed transformers release returns a
    # model-output object instead of a bare tensor, use its pooled output.
    features = getattr(features, "pooler_output", features)
    # Normalize so text and image vectors are directly comparable with a
    # dot product / cosine similarity in a vector index.
    features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.cpu().numpy().tolist()


def generate_embeddings(
    input_text: Optional[str] = None,
    input_image_url: Optional[str] = None,
) -> dict:
    """Compute a FashionCLIP embedding for an image URL or a text query.

    When an image URL is given, the image is embedded; the text is only
    echoed back in the response (a default is filled in when it is empty)
    and does not influence the image vector. When only text is given, the
    text is embedded.

    Args:
        input_text: Optional text query.
        input_image_url: Optional URL of the image to embed.

    Returns:
        On success, a dict with ``"type"`` (``"image_embedding"`` or
        ``"text_embedding"``), the echoed inputs, and ``"embedding"`` as a
        nested list with one row per input. On failure, or when neither
        input is provided, a dict with a single ``"error"`` message.
    """
    # Whitespace-only values count as "not provided".
    input_text = (input_text or "").strip()
    input_image_url = (input_image_url or "").strip()

    if not input_text and not input_image_url:
        return {"error": "Please provide either a text query or an image URL."}

    # This function is the API boundary: every failure (network, decoding,
    # model) is reported to the client as JSON instead of propagating.
    try:
        if input_image_url:
            image = _fetch_image(input_image_url)

            # Keep the response shape stable when no text was supplied.
            if not input_text:
                input_text = DEFAULT_IMAGE_TEXT

            inputs = processor(images=image, return_tensors="pt")
            with torch.no_grad():
                # Use the projected image features; ``logits_per_image`` is
                # only an image-text similarity score, not an embedding.
                features = model.get_image_features(**inputs)

            return {
                "type": "image_embedding",
                "input_image_url": input_image_url,
                "input_text": input_text,
                "embedding": _to_embedding(features),
            }

        # Text-only query. A full forward pass needs pixel values as well,
        # so the text tower is called directly.
        inputs = processor(
            text=[input_text],
            return_tensors="pt",
            padding=True,
            truncation=True,  # over-long queries are cut to the model limit
        )
        with torch.no_grad():
            features = model.get_text_features(**inputs)

        return {
            "type": "text_embedding",
            "input_text": input_text,
            "embedding": _to_embedding(features),
        }
    except Exception as e:
        return {"error": str(e)}


# Create the Gradio interface
interface = gr.Interface(
    fn=generate_embeddings,
    inputs=[
        gr.Textbox(
            label="Text Query (Optional)",
            placeholder="e.g., red dress (used with image or for text embedding)",
        ),
        gr.Textbox(
            label="Image URL",
            placeholder="e.g., https://example.com/image.jpg (used with or without text query)",
        ),
    ],
    outputs="json",
    title="FashionCLIP Combined Embedding API",
    description="Provide a text query and/or an image URL to compute embeddings for vector search.",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()