Commit 660ea27 · 1 Parent(s): 3957ec0

Example 2 encoder

Files changed (3)
  1. app.py +148 -71
  2. encoder.py +89 -0
  3. test_local.py +14 -0
app.py CHANGED
@@ -1,80 +1,157 @@
- import gradio as gr
- from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
- from PIL import Image
- import requests
- import torch
-
- # Load the FashionCLIP processor and model
- processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
- model = AutoModelForZeroShotImageClassification.from_pretrained("patrickjohncyh/fashion-clip")
-
- # Define the function to process both text and image inputs
- def generate_embeddings(input_text=None, input_image_url=None):
-     try:
-         if input_image_url:
-             # Process image with accompanying text
-             response = requests.get(input_image_url, stream=True)
-             response.raise_for_status()
-             image = Image.open(response.raw)
-
-             # Use a default text if none is provided
-             if not input_text:
-                 input_text = "this is an image"
-
-             # Prepare inputs for the model
-             inputs = processor(
-                 text=[input_text],
-                 images=image,
-                 return_tensors="pt",
-                 padding=True
-             )
-
-             with torch.no_grad():
-                 outputs = model(**inputs)
-
-             image_embedding = outputs.logits_per_image.cpu().numpy().tolist()
-             return {
-                 "type": "image_embedding",
-                 "input_image_url": input_image_url,
-                 "input_text": input_text,
-                 "embedding": image_embedding
-             }
-
-         elif input_text:
-             # Process text input only
-             inputs = processor(
-                 text=[input_text],
-                 images=None,
-                 return_tensors="pt",
-                 padding=True
-             )
-             with torch.no_grad():
-                 outputs = model(**inputs)
-
-             text_embedding = outputs.logits_per_text.cpu().numpy().tolist()
-             return {
-                 "type": "text_embedding",
-                 "input_text": input_text,
-                 "embedding": text_embedding
-             }
-         else:
-             return {"error": "Please provide either a text query or an image URL."}
-
-     except Exception as e:
-         return {"error": str(e)}
-
- # Create the Gradio interface
- interface = gr.Interface(
-     fn=generate_embeddings,
-     inputs=[
-         gr.Textbox(label="Text Query (Optional)", placeholder="e.g., red dress (used with image or for text embedding)"),
-         gr.Textbox(label="Image URL", placeholder="e.g., https://example.com/image.jpg (used with or without text query)")
-     ],
-     outputs="json",
-     title="FashionCLIP Combined Embedding API",
-     description="Provide a text query and/or an image URL to compute embeddings for vector search."
- )
-
- # Launch the app
- if __name__ == "__main__":
-     interface.launch()
+ # import gradio as gr
+ # from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
+ # from PIL import Image
+ # import requests
+ # import torch
+
+ # # Load the FashionCLIP processor and model
+ # processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
+ # model = AutoModelForZeroShotImageClassification.from_pretrained("patrickjohncyh/fashion-clip")
+
+ # # Define the function to process both text and image inputs
+ # def generate_embeddings(input_text=None, input_image_url=None):
+ #     try:
+ #         if input_image_url:
+ #             # Process image with accompanying text
+ #             response = requests.get(input_image_url, stream=True)
+ #             response.raise_for_status()
+ #             image = Image.open(response.raw)
+
+ #             # Use a default text if none is provided
+ #             if not input_text:
+ #                 input_text = "this is an image"
+
+ #             # Prepare inputs for the model
+ #             inputs = processor(
+ #                 text=[input_text],
+ #                 images=image,
+ #                 return_tensors="pt",
+ #                 padding=True
+ #             )
+
+ #             with torch.no_grad():
+ #                 outputs = model(**inputs)
+
+ #             image_embedding = outputs.logits_per_image.cpu().numpy().tolist()
+ #             return {
+ #                 "type": "image_embedding",
+ #                 "input_image_url": input_image_url,
+ #                 "input_text": input_text,
+ #                 "embedding": image_embedding
+ #             }
+
+ #         elif input_text:
+ #             # Process text input only
+ #             inputs = processor(
+ #                 text=[input_text],
+ #                 images=None,
+ #                 return_tensors="pt",
+ #                 padding=True
+ #             )
+ #             with torch.no_grad():
+ #                 outputs = model(**inputs)
+
+ #             text_embedding = outputs.logits_per_text.cpu().numpy().tolist()
+ #             return {
+ #                 "type": "text_embedding",
+ #                 "input_text": input_text,
+ #                 "embedding": text_embedding
+ #             }
+ #         else:
+ #             return {"error": "Please provide either a text query or an image URL."}
+
+ #     except Exception as e:
+ #         return {"error": str(e)}
+
+ # # Create the Gradio interface
+ # interface = gr.Interface(
+ #     fn=generate_embeddings,
+ #     inputs=[
+ #         gr.Textbox(label="Text Query (Optional)", placeholder="e.g., red dress (used with image or for text embedding)"),
+ #         gr.Textbox(label="Image URL", placeholder="e.g., https://example.com/image.jpg (used with or without text query)")
+ #     ],
+ #     outputs="json",
+ #     title="FashionCLIP Combined Embedding API",
+ #     description="Provide a text query and/or an image URL to compute embeddings for vector search."
+ # )
+
+ # # Launch the app
+ # if __name__ == "__main__":
+ #     interface.launch()
+ # print(generate_embeddings("red dress"))
+
+
+
+ import uuid
+ import requests
+ from PIL import Image
+ import numpy as np
+ import gradio as gr
+ from encoder import FashionCLIPEncoder
+
+ # Constants
+ REQUESTS_HEADERS = {
+     'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
+ }
+
+ # Initialize encoder
+ encoder = FashionCLIPEncoder()
+
+ # Helper function to download images
+ def download_image_as_pil(url: str, timeout: int = 10) -> Image.Image:
+     try:
+         response = requests.get(url, stream=True, headers=REQUESTS_HEADERS, timeout=timeout)
+         if response.status_code == 200:
+             return Image.open(response.raw).convert("RGB")  # Ensure consistent format
+         return None
+     except Exception as e:
+         print(f"Error downloading image: {e}")
+         return None
+
+ # Embedding function for a batch of images
+ def batch_process_images(image_urls: list):
+     embeddings = []
+     results = []
+     for url in image_urls:
+         try:
+             # Download image
+             image = download_image_as_pil(url)
+             if not image:
+                 results.append({"image_url": url, "error": "Failed to download image"})
+                 continue
+
+             # Generate embedding
+             embedding = encoder.encode_images([image])[0]
+
+             # Normalize embedding
+             embedding_normalized = embedding / np.linalg.norm(embedding)
+
+             # Append results
+             results.append({
+                 "image_url": url,
+                 "embedding_preview": embedding_normalized[:5].tolist(),  # First 5 values for preview
+                 "success": True
+             })
+         except Exception as e:
+             results.append({"image_url": url, "error": str(e)})
+     return results
+
+ # Gradio Interface
+ iface = gr.Interface(
+     fn=batch_process_images,
+     inputs=gr.Textbox(
+         lines=5,
+         placeholder="Enter image URLs separated by commas",
+         label="Batch Image URLs",
+     ),
+     outputs=gr.JSON(label="Embedding Results"),
+     title="Batch Fashion CLIP Embedding API",
+     description="Enter multiple image URLs (separated by commas) to generate embeddings for the batch. Each embedding preview includes the first 5 values.",
+     examples=[
+         ["https://cdn.shopify.com/s/files/1/0522/2239/4534/files/CT21355-22_1024x1024.webp, https://cdn.shopify.com/s/files/1/0522/2239/4534/files/00907857-C6B0-4D2A-8AEA-688BDE1E67D7_1024x1024.jpg"]
+     ],
+ )
+
+ # Launch Gradio App
+ if __name__ == "__main__":
+     iface.launch()
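
A quick way to sanity-check the new batch path without the Gradio UI is to call batch_process_images directly. The sketch below is only an illustration (the example.com URLs are placeholders): the Gradio Textbox delivers one comma-separated string, so the caller here splits it into a list before invoking the function, which iterates over its argument.

# Minimal sketch: exercising batch_process_images outside Gradio.
# The URLs below are placeholders, not real product images.
from app import batch_process_images

raw_input = "https://example.com/a.jpg, https://example.com/b.jpg"
urls = [u.strip() for u in raw_input.split(",") if u.strip()]  # split the comma-separated string into a list

results = batch_process_images(urls)
for result in results:
    print(result.get("image_url"), result.get("success") or result.get("error"))
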
encoder.py ADDED
@@ -0,0 +1,89 @@
+ from typing import List, Dict, Optional
+ import torch
+ from PIL.Image import Image
+ from torch.utils.data import DataLoader
+ from datasets import Dataset
+ from transformers import AutoModel, AutoProcessor
+
+ MODEL_NAME = "Marqo/marqo-fashionCLIP"
+
+
+ class FashionCLIPEncoder:
+     def __init__(self):
+         self.processor = AutoProcessor.from_pretrained(
+             MODEL_NAME, trust_remote_code=True
+         )
+         self.model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.model.eval()
+
+     def encode_images(
+         self, images: List[Image], batch_size: Optional[int] = None
+     ) -> List[List[float]]:
+         if batch_size is None:
+             batch_size = min(len(images), 32)  # Default to a safe batch size
+
+         def transform_fn(el: Dict):
+             return self.processor(
+                 images=[content for content in el["image"]], return_tensors="pt"
+             )
+
+         dataset = Dataset.from_dict({"image": images})
+         dataset.set_format("torch")
+         dataset.set_transform(transform_fn)
+         dataloader = DataLoader(dataset, batch_size=batch_size)
+
+         image_embeddings = []
+
+         with torch.no_grad():
+             for batch in dataloader:
+                 try:
+                     batch = {k: v.to(self.device) for k, v in batch.items()}
+                     embeddings = self._encode_images(batch)
+                     image_embeddings.extend(embeddings)
+                 except Exception as e:
+                     print(f"Error encoding image batch: {e}")
+
+         return image_embeddings
+
+     def encode_text(
+         self, text: List[str], batch_size: Optional[int] = None
+     ) -> List[List[float]]:
+         if batch_size is None:
+             batch_size = min(len(text), 32)  # Default to a safe batch size
+
+         def transform_fn(el: Dict):
+             kwargs = {
+                 "padding": "max_length",
+                 "return_tensors": "pt",
+                 "truncation": True,
+             }
+             return self.processor(text=el["text"], **kwargs)
+
+         dataset = Dataset.from_dict({"text": text})
+         dataset = dataset.map(
+             function=transform_fn, batched=True, remove_columns=["text"]
+         )
+         dataset.set_format("torch")
+         dataloader = DataLoader(dataset, batch_size=batch_size)
+
+         text_embeddings = []
+
+         with torch.no_grad():
+             for batch in dataloader:
+                 try:
+                     batch = {k: v.to(self.device) for k, v in batch.items()}
+                     embeddings = self._encode_text(batch)
+                     text_embeddings.extend(embeddings)
+                 except Exception as e:
+                     print(f"Error encoding text batch: {e}")
+
+         return text_embeddings
+
+     def _encode_images(self, batch: Dict) -> List:
+         return self.model.get_image_features(**batch).detach().cpu().numpy().tolist()
+
+     def _encode_text(self, batch: Dict) -> List:
+         return self.model.get_text_features(**batch).detach().cpu().numpy().tolist()
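
For reference, a minimal usage sketch of the new FashionCLIPEncoder class, assuming the Marqo/marqo-fashionCLIP weights can be downloaded locally; "local.jpg" is a placeholder path, not a file from this repo.

# Minimal sketch: using the encoder module directly.
from PIL import Image
from encoder import FashionCLIPEncoder

encoder = FashionCLIPEncoder()

# encode_text returns one embedding (a list of floats) per input string.
text_vectors = encoder.encode_text(["red dress", "leather boots"])
print(len(text_vectors), len(text_vectors[0]))

# encode_images expects PIL images; "local.jpg" is a placeholder file.
image = Image.open("local.jpg").convert("RGB")
image_vectors = encoder.encode_images([image])
print(len(image_vectors[0]))
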
test_local.py ADDED
@@ -0,0 +1,14 @@
+ from app import generate_embeddings  # Import the function from app.py
+
+
+ # Test with a text query only
+ result_text = generate_embeddings(input_text="red dress")
+ print("Text Embedding Result:", result_text)
+
+ # Test with an image URL only
+ result_image = generate_embeddings(input_image_url="https://vacier.com/cdn/shop/files/Unisize_Ring_db777381-c510-457f-b8c9-5812665d094b.jpg?v=1731838123&width=1080")
+ print("Image Embedding Result:", result_image)
+
+ # Test with both text and image
+ result_both = generate_embeddings(input_text="red dress", input_image_url="https://vacier.com/cdn/shop/files/Unisize_Ring_db777381-c510-457f-b8c9-5812665d094b.jpg?v=1731838123&width=1080")
+ print("Text and Image Embedding Result:", result_both)