Commit 82a72c7 (parent f341903), committed by im
Trying a faster model

Files changed:
- app.py +31 -7
- encoder.py +11 -18
app.py CHANGED

@@ -93,6 +93,7 @@ from encoder import FashionCLIPEncoder
 REQUESTS_HEADERS = {
     'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
 }
+BATCH_SIZE = 30  # Define batch size for processing
 
 # Initialize encoder
 encoder = FashionCLIPEncoder()
@@ -116,8 +117,9 @@ def batch_process_images(image_urls: str):
     if not urls:
         return {"error": "No valid image URLs provided."}
 
-    embeddings = []
     results = []
+    batch_urls, batch_images = [], []
+
     for url in urls:
         try:
             # Download image
@@ -125,10 +127,32 @@ def batch_process_images(image_urls: str):
             if not image:
                 results.append({"image_url": url, "error": "Failed to download image"})
                 continue
-
-
-
-
+
+            batch_urls.append(url)
+            batch_images.append(image)
+
+            # Process batch when reaching batch size
+            if len(batch_images) == BATCH_SIZE:
+                process_batch(batch_urls, batch_images, results)
+                batch_urls, batch_images = [], []
+
+        except Exception as e:
+            results.append({"image_url": url, "error": str(e)})
+
+    # Process remaining images in the last batch
+    if batch_images:
+        process_batch(batch_urls, batch_images, results)
+
+    return results
+
+
+# Helper function to process a batch
+def process_batch(batch_urls, batch_images, results):
+    try:
+        # Generate embeddings
+        embeddings = encoder.encode_images(batch_images)
+
+        for url, embedding in zip(batch_urls, embeddings):
             # Normalize embedding
             embedding_normalized = embedding / np.linalg.norm(embedding)
 
@@ -138,9 +162,9 @@ def batch_process_images(image_urls: str):
                 "embedding_preview": embedding_normalized[:5].tolist(), # First 5 values for preview
                 "success": True
             })
-
+    except Exception as e:
+        for url in batch_urls:
             results.append({"image_url": url, "error": str(e)})
-    return results
 
 
 # Gradio Interface
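The speed-up in app.py comes from a standard chunk-and-flush pattern: downloaded images are accumulated until BATCH_SIZE is reached, encoded with a single model call per batch, and the final partial batch is flushed at the end. The sketch below shows that pattern in isolation; chunked, embed_all, and encode are illustrative names, not functions from app.py (encode stands in for encoder.encode_images).

from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def chunked(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most `size` items."""
    chunk: List[T] = []
    for item in items:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:  # the final, possibly partial, chunk
        yield chunk


def embed_all(images: List[T],
              encode: Callable[[List[T]], List[List[float]]],
              size: int = 30) -> List[List[float]]:
    """One encoder call per chunk instead of one call per image."""
    embeddings: List[List[float]] = []
    for chunk in chunked(images, size):
        embeddings.extend(encode(chunk))  # e.g. encoder.encode_images(chunk)
    return embeddings

app.py layers per-URL bookkeeping on top of this: process_batch records one result entry per URL, and if encode_images raises, every URL in that batch receives the same error entry instead of the whole request failing.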
encoder.py CHANGED

@@ -1,4 +1,5 @@
 from typing import List, Dict, Optional
+
 import torch
 from PIL.Image import Image
 from torch.utils.data import DataLoader
@@ -15,15 +16,14 @@ class FashionCLIPEncoder:
         )
         self.model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
         self.model.eval()
+        self.device = self.model.device
 
     def encode_images(
         self, images: List[Image], batch_size: Optional[int] = None
     ) -> List[List[float]]:
         if batch_size is None:
-            batch_size =
+            batch_size = len(images)
 
         def transform_fn(el: Dict):
             return self.processor(
@@ -39,12 +39,9 @@ class FashionCLIPEncoder:
 
         with torch.no_grad():
             for batch in dataloader:
-
-
-
-                    image_embeddings.extend(embeddings)
-                except Exception as e:
-                    print(f"Error encoding image batch: {e}")
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+                embeddings = self._encode_images(batch)
+                image_embeddings.extend(embeddings)
 
         return image_embeddings
 
@@ -52,7 +49,7 @@ class FashionCLIPEncoder:
         self, text: List[str], batch_size: Optional[int] = None
     ) -> List[List[float]]:
         if batch_size is None:
-            batch_size =
+            batch_size = len(text)
 
         def transform_fn(el: Dict):
             kwargs = {
@@ -68,17 +65,13 @@ class FashionCLIPEncoder:
         )
         dataset.set_format("torch")
         dataloader = DataLoader(dataset, batch_size=batch_size)
-
         text_embeddings = []
 
         with torch.no_grad():
             for batch in dataloader:
-
-
-
-                    text_embeddings.extend(embeddings)
-                except Exception as e:
-                    print(f"Error encoding text batch: {e}")
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+                embeddings = self._encode_text(batch)
+                text_embeddings.extend(embeddings)
 
         return text_embeddings
 
@@ -86,4 +79,4 @@ class FashionCLIPEncoder:
         return self.model.get_image_features(**batch).detach().cpu().numpy().tolist()
 
     def _encode_text(self, batch: Dict) -> List:
-        return self.model.get_text_features(**batch).detach().cpu().numpy().tolist()
+        return self.model.get_text_features(**batch).detach().cpu().numpy().tolist()
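Taken together, the encoder changes drop the explicit CUDA setup in favour of whatever device the model was loaded on, move each DataLoader batch to that device before the forward pass, and default batch_size to the full input length. A minimal usage sketch under those assumptions; the image paths are placeholders, and the normalization step is borrowed from app.py.

import numpy as np
from PIL import Image

from encoder import FashionCLIPEncoder

encoder = FashionCLIPEncoder()

# Placeholder paths; any local RGB images work.
images = [Image.open(p).convert("RGB") for p in ("shirt.jpg", "dress.jpg")]

# batch_size now defaults to len(images), so the whole list is encoded as a
# single DataLoader batch, moved to encoder.device before the forward pass.
embeddings = encoder.encode_images(images)

# app.py normalizes each vector before building its preview.
vec = np.asarray(embeddings[0])
vec = vec / np.linalg.norm(vec)
print(vec[:5])

Note that the model itself is no longer moved explicitly; self.device simply reflects wherever from_pretrained placed it.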