Commit 82a72c7 · 1 Parent(s): f341903

Trying a faster model

Files changed (2)
  1. app.py +31 -7
  2. encoder.py +11 -18
app.py CHANGED
```diff
@@ -93,6 +93,7 @@ from encoder import FashionCLIPEncoder
 REQUESTS_HEADERS = {
     'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
 }
+BATCH_SIZE = 30  # Define batch size for processing
 
 # Initialize encoder
 encoder = FashionCLIPEncoder()
@@ -116,8 +117,9 @@ def batch_process_images(image_urls: str):
     if not urls:
         return {"error": "No valid image URLs provided."}
 
-    embeddings = []
     results = []
+    batch_urls, batch_images = [], []
+
     for url in urls:
         try:
             # Download image
@@ -125,10 +127,32 @@ def batch_process_images(image_urls: str):
             if not image:
                 results.append({"image_url": url, "error": "Failed to download image"})
                 continue
-
-            # Generate embedding
-            embedding = encoder.encode_images([image])[0]
-
+
+            batch_urls.append(url)
+            batch_images.append(image)
+
+            # Process batch when reaching batch size
+            if len(batch_images) == BATCH_SIZE:
+                process_batch(batch_urls, batch_images, results)
+                batch_urls, batch_images = [], []
+
+        except Exception as e:
+            results.append({"image_url": url, "error": str(e)})
+
+    # Process remaining images in the last batch
+    if batch_images:
+        process_batch(batch_urls, batch_images, results)
+
+    return results
+
+
+# Helper function to process a batch
+def process_batch(batch_urls, batch_images, results):
+    try:
+        # Generate embeddings
+        embeddings = encoder.encode_images(batch_images)
+
+        for url, embedding in zip(batch_urls, embeddings):
             # Normalize embedding
             embedding_normalized = embedding / np.linalg.norm(embedding)
 
@@ -138,9 +162,9 @@ def batch_process_images(image_urls: str):
                 "embedding_preview": embedding_normalized[:5].tolist(),  # First 5 values for preview
                 "success": True
             })
-        except Exception as e:
+    except Exception as e:
+        for url in batch_urls:
             results.append({"image_url": url, "error": str(e)})
-    return results
 
 
 # Gradio Interface
```
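The app.py change replaces the per-image call `encoder.encode_images([image])[0]` with an accumulate-and-flush pattern: downloaded images are collected into `batch_images` and handed to a new `process_batch` helper every `BATCH_SIZE` (30) URLs, with a final flush for the leftover partial batch. Below is a minimal, self-contained sketch of that pattern; `fake_encode` and the placeholder "image" strings are stand-ins for the real encoder and downloads, everything else mirrors the diff's structure.

```python
import numpy as np

BATCH_SIZE = 30

def fake_encode(images):
    # Placeholder for encoder.encode_images: one vector per input image.
    return [np.random.rand(512) for _ in images]

def process_batch(batch_urls, batch_images, results):
    try:
        embeddings = fake_encode(batch_images)
        for url, embedding in zip(batch_urls, embeddings):
            # Normalize and keep a short preview, as in the diff.
            embedding_normalized = embedding / np.linalg.norm(embedding)
            results.append({
                "image_url": url,
                "embedding_preview": embedding_normalized[:5].tolist(),
                "success": True,
            })
    except Exception as e:
        # One failure marks every URL in the batch as failed.
        for url in batch_urls:
            results.append({"image_url": url, "error": str(e)})

def batch_process(urls):
    results, batch_urls, batch_images = [], [], []
    for url in urls:
        batch_urls.append(url)
        batch_images.append(f"image-for-{url}")  # a downloaded image would go here
        # Flush a full batch to the encoder.
        if len(batch_images) == BATCH_SIZE:
            process_batch(batch_urls, batch_images, results)
            batch_urls, batch_images = [], []
    # Flush the final partial batch.
    if batch_images:
        process_batch(batch_urls, batch_images, results)
    return results
```

Batching amortizes per-call model overhead across up to 30 images, which is presumably the speed-up the commit message is after. The trade-off is coarser error reporting: a single failing image inside `process_batch` marks every URL in that batch with the same error.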
encoder.py CHANGED
```diff
@@ -1,4 +1,5 @@
 from typing import List, Dict, Optional
+
 import torch
 from PIL.Image import Image
 from torch.utils.data import DataLoader
@@ -15,15 +16,14 @@ class FashionCLIPEncoder:
         )
         self.model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
         self.model.eval()
+        self.device = self.model.device
 
     def encode_images(
         self, images: List[Image], batch_size: Optional[int] = None
     ) -> List[List[float]]:
         if batch_size is None:
-            batch_size = min(len(images), 32)  # Default to a safe batch size
+            batch_size = len(images)
 
         def transform_fn(el: Dict):
             return self.processor(
@@ -39,12 +39,9 @@ class FashionCLIPEncoder:
 
         with torch.no_grad():
             for batch in dataloader:
-                try:
-                    batch = {k: v.to(self.device) for k, v in batch.items()}
-                    embeddings = self._encode_images(batch)
-                    image_embeddings.extend(embeddings)
-                except Exception as e:
-                    print(f"Error encoding image batch: {e}")
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+                embeddings = self._encode_images(batch)
+                image_embeddings.extend(embeddings)
 
         return image_embeddings
 
@@ -52,7 +49,7 @@ class FashionCLIPEncoder:
         self, text: List[str], batch_size: Optional[int] = None
     ) -> List[List[float]]:
         if batch_size is None:
-            batch_size = min(len(text), 32)  # Default to a safe batch size
+            batch_size = len(text)
 
         def transform_fn(el: Dict):
             kwargs = {
@@ -68,17 +65,13 @@ class FashionCLIPEncoder:
         )
         dataset.set_format("torch")
         dataloader = DataLoader(dataset, batch_size=batch_size)
-
         text_embeddings = []
 
         with torch.no_grad():
             for batch in dataloader:
-                try:
-                    batch = {k: v.to(self.device) for k, v in batch.items()}
-                    embeddings = self._encode_text(batch)
-                    text_embeddings.extend(embeddings)
-                except Exception as e:
-                    print(f"Error encoding text batch: {e}")
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+                embeddings = self._encode_text(batch)
+                text_embeddings.extend(embeddings)
 
         return text_embeddings
 
@@ -86,4 +79,4 @@ class FashionCLIPEncoder:
         return self.model.get_image_features(**batch).detach().cpu().numpy().tolist()
 
     def _encode_text(self, batch: Dict) -> List:
-        return self.model.get_text_features(**batch).detach().cpu().numpy().tolist()
\ No newline at end of file
+        return self.model.get_text_features(**batch).detach().cpu().numpy().tolist()
```
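In encoder.py the commit drops the explicit CUDA selection and the per-batch try/except: `self.device` is now read from the loaded model (which stays on whatever device `from_pretrained` placed it, typically CPU, since there is no `.to()` call anymore), `batch_size` defaults to the full input list so each call is a single forward pass, and encoding errors now propagate to the caller, where app.py's `process_batch` records them per batch instead of silently printing and skipping. A rough usage sketch of the updated class follows; `FashionCLIPEncoder` and `encode_images` are from the diff, while the solid-color test images are just stand-ins for real product photos.

```python
from PIL import Image
from encoder import FashionCLIPEncoder

encoder = FashionCLIPEncoder()

# Two dummy RGB images; any PIL images would work here.
images = [Image.new("RGB", (224, 224), color=c) for c in ("red", "blue")]

# batch_size now defaults to len(images), so this runs as one forward
# pass on encoder.model.device.
embeddings = encoder.encode_images(images)
print(len(embeddings), len(embeddings[0]))  # -> 2, embedding dimension
```

One caveat worth noting: because `batch_size` defaults to the whole input list, a very large list now becomes one very large forward pass; the app.py side compensates by capping batches at `BATCH_SIZE = 30` before calling the encoder.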