m7mdal7aj commited on
Commit
88484c4
·
verified ·
1 Parent(s): b434799

Update my_model/object_detection.py

Browse files
Files changed (1) hide show
  1. my_model/object_detection.py +4 -11
my_model/object_detection.py CHANGED
@@ -8,8 +8,6 @@ import os
8
  from my_model.gen_utilities import get_image_path, get_model_path ,show_image
9
 
10
 
11
-
12
-
13
  class ObjectDetector:
14
  """
15
  A class for detecting objects in images using models like Detic and YOLOv5.
@@ -63,7 +61,6 @@ class ObjectDetector:
63
 
64
  try:
65
  model_path = get_model_path('deformable-detr-detic')
66
-
67
  self.processor = AutoImageProcessor.from_pretrained(model_path)
68
  self.model = AutoModelForObjectDetection.from_pretrained(model_path)
69
  except Exception as e:
@@ -115,8 +112,7 @@ class ObjectDetector:
115
  print(f"Error processing image: {e}")
116
  raise
117
 
118
-
119
-
120
  def detect_objects(self, image, threshold=0.4):
121
  """
122
  Detect objects in the given image using the loaded model.
@@ -139,6 +135,7 @@ class ObjectDetector:
139
  else:
140
  raise ValueError("Model not loaded or unsupported model name")
141
 
 
142
  def _detect_with_detic(self, image, threshold):
143
  """
144
  Detect objects using the Detic model.
@@ -155,9 +152,7 @@ class ObjectDetector:
155
  inputs = self.processor(images=image, return_tensors="pt")
156
  outputs = self.model(**inputs)
157
  target_sizes = torch.tensor([image.size[::-1]])
158
- results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[
159
- 0]
160
-
161
  detected_objects_str = ""
162
  detected_objects_list = []
163
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
@@ -169,6 +164,7 @@ class ObjectDetector:
169
  detected_objects_list.append((label_name, box_rounded, certainty))
170
  return detected_objects_str, detected_objects_list
171
 
 
172
  def _detect_with_yolov5(self, image, threshold):
173
  """
174
  Detect objects using the YOLOv5 model.
@@ -184,7 +180,6 @@ class ObjectDetector:
184
 
185
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
186
  results = self.model(cv2_img)
187
-
188
  detected_objects_str = ""
189
  detected_objects_list = []
190
  for *bbox, conf, cls in results.xyxy[0]:
@@ -214,7 +209,6 @@ class ObjectDetector:
214
  font = ImageFont.truetype("arial.ttf", 15)
215
  except IOError:
216
  font = ImageFont.load_default()
217
-
218
  colors = ["red", "green", "blue", "yellow", "purple", "orange"]
219
  label_color_map = {}
220
 
@@ -224,7 +218,6 @@ class ObjectDetector:
224
 
225
  color = label_color_map[label_name]
226
  draw.rectangle(box, outline=color, width=3)
227
-
228
  label_text = f"{label_name}"
229
  if show_confidence:
230
  label_text += f" ({round(score, 2)}%)"
 
8
  from my_model.gen_utilities import get_image_path, get_model_path ,show_image
9
 
10
 
 
 
11
  class ObjectDetector:
12
  """
13
  A class for detecting objects in images using models like Detic and YOLOv5.
 
61
 
62
  try:
63
  model_path = get_model_path('deformable-detr-detic')
 
64
  self.processor = AutoImageProcessor.from_pretrained(model_path)
65
  self.model = AutoModelForObjectDetection.from_pretrained(model_path)
66
  except Exception as e:
 
112
  print(f"Error processing image: {e}")
113
  raise
114
 
115
+
 
116
  def detect_objects(self, image, threshold=0.4):
117
  """
118
  Detect objects in the given image using the loaded model.
 
135
  else:
136
  raise ValueError("Model not loaded or unsupported model name")
137
 
138
+
139
  def _detect_with_detic(self, image, threshold):
140
  """
141
  Detect objects using the Detic model.
 
152
  inputs = self.processor(images=image, return_tensors="pt")
153
  outputs = self.model(**inputs)
154
  target_sizes = torch.tensor([image.size[::-1]])
155
+ results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
 
 
156
  detected_objects_str = ""
157
  detected_objects_list = []
158
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
 
164
  detected_objects_list.append((label_name, box_rounded, certainty))
165
  return detected_objects_str, detected_objects_list
166
 
167
+
168
  def _detect_with_yolov5(self, image, threshold):
169
  """
170
  Detect objects using the YOLOv5 model.
 
180
 
181
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
182
  results = self.model(cv2_img)
 
183
  detected_objects_str = ""
184
  detected_objects_list = []
185
  for *bbox, conf, cls in results.xyxy[0]:
 
209
  font = ImageFont.truetype("arial.ttf", 15)
210
  except IOError:
211
  font = ImageFont.load_default()
 
212
  colors = ["red", "green", "blue", "yellow", "purple", "orange"]
213
  label_color_map = {}
214
 
 
218
 
219
  color = label_color_map[label_name]
220
  draw.rectangle(box, outline=color, width=3)
 
221
  label_text = f"{label_name}"
222
  if show_confidence:
223
  label_text += f" ({round(score, 2)}%)"