santit96 committed on
Commit fa84113 · 0 Parent(s)

Create the streamlit app that classifies the trash in an image into classes

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +4 -0
  2. .github/workflows/main.yml +20 -0
  3. .gitignore +5 -0
  4. README.md +14 -0
  5. app.py +82 -0
  6. constants.py +8 -0
  7. efficientdet/__init__.py +0 -0
  8. efficientdet/effdet/__init__.py +7 -0
  9. efficientdet/effdet/anchors.py +421 -0
  10. efficientdet/effdet/bench.py +143 -0
  11. efficientdet/effdet/config/__init__.py +4 -0
  12. efficientdet/effdet/config/config_utils.py +9 -0
  13. efficientdet/effdet/config/fpn_config.py +184 -0
  14. efficientdet/effdet/config/model_config.py +538 -0
  15. efficientdet/effdet/config/train_config.py +34 -0
  16. efficientdet/effdet/data/__init__.py +6 -0
  17. efficientdet/effdet/data/dataset.py +145 -0
  18. efficientdet/effdet/data/dataset_config.py +194 -0
  19. efficientdet/effdet/data/dataset_factory.py +85 -0
  20. efficientdet/effdet/data/input_config.py +60 -0
  21. efficientdet/effdet/data/loader.py +226 -0
  22. efficientdet/effdet/data/parsers/__init__.py +2 -0
  23. efficientdet/effdet/data/parsers/parser.py +82 -0
  24. efficientdet/effdet/data/parsers/parser_coco.py +93 -0
  25. efficientdet/effdet/data/parsers/parser_config.py +49 -0
  26. efficientdet/effdet/data/parsers/parser_factory.py +19 -0
  27. efficientdet/effdet/data/parsers/parser_open_images.py +211 -0
  28. efficientdet/effdet/data/parsers/parser_voc.py +148 -0
  29. efficientdet/effdet/data/random_erasing.py +94 -0
  30. efficientdet/effdet/data/transforms.py +275 -0
  31. efficientdet/effdet/data/transforms_albumentation.py +23 -0
  32. efficientdet/effdet/distributed.py +308 -0
  33. efficientdet/effdet/efficientdet.py +557 -0
  34. efficientdet/effdet/evaluation/README.md +7 -0
  35. efficientdet/effdet/evaluation/__init__.py +0 -0
  36. efficientdet/effdet/evaluation/detection_evaluator.py +590 -0
  37. efficientdet/effdet/evaluation/fields.py +105 -0
  38. efficientdet/effdet/evaluation/metrics.py +148 -0
  39. efficientdet/effdet/evaluation/np_box_list.py +696 -0
  40. efficientdet/effdet/evaluation/np_mask_list.py +478 -0
  41. efficientdet/effdet/evaluation/object_detection_evaluation.py +273 -0
  42. efficientdet/effdet/evaluation/per_image_evaluation.py +538 -0
  43. efficientdet/effdet/evaluator.py +195 -0
  44. efficientdet/effdet/factory.py +54 -0
  45. efficientdet/effdet/helpers.py +22 -0
  46. efficientdet/effdet/loss.py +259 -0
  47. efficientdet/effdet/object_detection/README.md +3 -0
  48. efficientdet/effdet/object_detection/__init__.py +22 -0
  49. efficientdet/effdet/object_detection/argmax_matcher.py +174 -0
  50. efficientdet/effdet/object_detection/box_coder.py +172 -0
.gitattributes ADDED
@@ -0,0 +1,4 @@
+ *.psd filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
.github/workflows/main.yml ADDED
@@ -0,0 +1,20 @@
+ name: Sync to Hugging Face hub
+ on:
+   push:
+     branches: [main]
+
+   # to run this workflow manually from the Actions tab
+   workflow_dispatch:
+
+ jobs:
+   sync-to-hub:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v3
+         with:
+           fetch-depth: 0
+           lfs: true
+       - name: Push to hub
+         env:
+           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+         run: git push --force https://santit96:[email protected]/spaces/rootstrap-org/waste-classifier main
.gitignore ADDED
@@ -0,0 +1,5 @@
+ __pycache__
+ .DS_Store
+ *.jpg
+ *.png
+ *.jpeg
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Waste Classifier
+ emoji: ♻️
+ colorFrom: green
+ colorTo: gray
+ sdk: streamlit
+ sdk_version: 1.25.0
+ pinned: false
+ ---
+
+ Waste Classifier
+ ==============================
+
+ Waste Detection and Classifier tool
app.py ADDED
@@ -0,0 +1,82 @@
+ """
+ Streamlit app
+ """
+ import sys
+
+ import streamlit as st
+
+ from constants import (CLAS_FILEPATH, CLAS_THRESHOLD, CLASSES, DET_FILEPATH,
+                        DET_NAME, DET_THRESHOLD, DEVICE, OUTPUT_IMG_FILEPATH)
+
+ sys.path.append("./efficientdet")
+
+ from PIL import Image
+
+ from efficientdet.efficientdet import plot_results
+ from trash_detector import detect_trash
+
+
+ def initial_config():
+     """
+     Initial configuration of streamlit page
+     """
+     st.set_page_config(
+         page_title="Waste Classifier",
+         page_icon="♻️",
+     )
+
+
+ def render():
+     """
+     Render the streamlit app
+     """
+     st.title("Waste classifier")
+     st.markdown("""Classify your waste into different classes""")
+
+     # Image loader and button
+     uploaded_file = st.file_uploader(
+         "Upload image with trash", type=["jpg", "jpeg", "png", "gif", "bmp"]
+     )
+     classify_button = st.button("Classify trash")
+
+     if classify_button:
+         if not uploaded_file:
+             st.error("Upload an image")
+         else:
+             # Create two columns
+             col1, col2 = st.columns(2)
+
+             # Column 1: Uploaded image
+             with col1:
+                 st.write("Uploaded image")
+                 st.image(
+                     uploaded_file, caption="Uploaded Image.", use_column_width=True
+                 )
+
+             # Column 2: Classified image
+             with col2:
+                 with st.spinner(text="Classifying the trash..."):
+                     img = Image.open(uploaded_file).convert("RGB")
+                     cls_prob, bboxes_final = detect_trash(
+                         img,
+                         DET_NAME,
+                         DET_FILEPATH,
+                         CLAS_FILEPATH,
+                         DEVICE,
+                         DET_THRESHOLD,
+                         CLAS_THRESHOLD,
+                     )
+                     # plot and save demo image
+                     plot_results(
+                         img, cls_prob, bboxes_final, CLASSES, OUTPUT_IMG_FILEPATH
+                     )
+                     output_img = Image.open(OUTPUT_IMG_FILEPATH)
+                     st.write("Classified image")
+                     st.image(
+                         output_img, caption="Classified Image.", use_column_width=True
+                     )
+
+
+ if __name__ == "__main__":
+     initial_config()
+     render()
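
For reference, the same detect-and-plot pipeline that render() drives can be exercised outside Streamlit. A minimal sketch mirroring the calls above; the classify_file helper is hypothetical and assumes trash_detector.detect_trash and plot_results keep the signatures used in app.py:

    import sys

    sys.path.append("./efficientdet")

    from PIL import Image

    from constants import (CLAS_FILEPATH, CLAS_THRESHOLD, CLASSES, DET_FILEPATH,
                           DET_NAME, DET_THRESHOLD, DEVICE, OUTPUT_IMG_FILEPATH)
    from efficientdet.efficientdet import plot_results
    from trash_detector import detect_trash


    def classify_file(path: str) -> str:
        """Detect and classify trash in a local image; save and return the annotated copy."""
        img = Image.open(path).convert("RGB")
        cls_prob, bboxes_final = detect_trash(
            img, DET_NAME, DET_FILEPATH, CLAS_FILEPATH, DEVICE, DET_THRESHOLD, CLAS_THRESHOLD
        )
        plot_results(img, cls_prob, bboxes_final, CLASSES, OUTPUT_IMG_FILEPATH)
        return OUTPUT_IMG_FILEPATH
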
constants.py ADDED
@@ -0,0 +1,8 @@
+ CLAS_FILEPATH = "models/resnet50-classifier.pkl"
+ DET_FILEPATH = "models/efficientdet-d2-detector.pth.tar"
+ CLASSES = ["cardboard", "compost", "glass", "metal", "paper", "plastic", "trash"]
+ DET_NAME = "tf_efficientdet_d2"
+ CLAS_THRESHOLD = 0.5
+ DET_THRESHOLD = 0.17
+ DEVICE = "cpu"
+ OUTPUT_IMG_FILEPATH = "classified_image.jpg"
efficientdet/__init__.py ADDED
File without changes
efficientdet/effdet/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .efficientdet import EfficientDet
+ from .bench import DetBenchPredict, DetBenchTrain, unwrap_bench
+ from .data import create_dataset, create_loader, create_parser, DetectionDatset, SkipSubset
+ from .evaluator import CocoEvaluator, PascalEvaluator, OpenImagesEvaluator, create_evaluator
+ from .config import get_efficientdet_config, default_detection_model_configs
+ from .factory import create_model, create_model_from_config
+ from .helpers import load_checkpoint, load_pretrained
efficientdet/effdet/anchors.py ADDED
@@ -0,0 +1,421 @@
1
+ """ RetinaNet / EfficientDet Anchor Gen
2
+
3
+ Adapted for PyTorch from Tensorflow impl at
4
+ https://github.com/google/automl/blob/6f6694cec1a48cdb33d5d1551a2d5db8ad227798/efficientdet/anchors.py
5
+
6
+ Hacked together by Ross Wightman, original copyright below
7
+ """
8
+ # Copyright 2020 Google Research. All Rights Reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ # ==============================================================================
22
+ """Anchor definition.
23
+
24
+ This module is borrowed from TPU RetinaNet implementation:
25
+ https://github.com/tensorflow/tpu/blob/master/models/official/retinanet/anchors.py
26
+ """
27
+ from typing import Optional, Tuple, Sequence
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+ #import torchvision.ops.boxes as tvb
33
+ from torchvision.ops.boxes import batched_nms, remove_small_boxes
34
+ from typing import List
35
+
36
+ from effdet.object_detection import ArgMaxMatcher, FasterRcnnBoxCoder, BoxList, IouSimilarity, TargetAssigner
37
+ from .soft_nms import batched_soft_nms
38
+
39
+
40
+ # The minimum score to consider a logit for identifying detections.
41
+ MIN_CLASS_SCORE = -5.0
42
+
43
+ # The score for a dummy detection
44
+ _DUMMY_DETECTION_SCORE = -1e5
45
+
46
+ # The maximum number of (anchor,class) pairs to keep for non-max suppression.
47
+ MAX_DETECTION_POINTS = 5000
48
+
49
+ # The maximum number of detections per image.
50
+ MAX_DETECTIONS_PER_IMAGE = 100
51
+
52
+
53
+ def decode_box_outputs(rel_codes, anchors, output_xyxy: bool=False):
54
+ """Transforms relative regression coordinates to absolute positions.
55
+
56
+ Network predictions are normalized and relative to a given anchor; this
57
+ reverses the transformation and outputs absolute coordinates for the input image.
58
+
59
+ Args:
60
+ rel_codes: box regression targets.
61
+
62
+ anchors: anchors on all feature levels.
63
+
64
+ Returns:
65
+ outputs: bounding boxes.
66
+
67
+ """
68
+ ycenter_a = (anchors[:, 0] + anchors[:, 2]) / 2
69
+ xcenter_a = (anchors[:, 1] + anchors[:, 3]) / 2
70
+ ha = anchors[:, 2] - anchors[:, 0]
71
+ wa = anchors[:, 3] - anchors[:, 1]
72
+
73
+ ty, tx, th, tw = rel_codes.unbind(dim=1)
74
+
75
+ w = torch.exp(tw) * wa
76
+ h = torch.exp(th) * ha
77
+ ycenter = ty * ha + ycenter_a
78
+ xcenter = tx * wa + xcenter_a
79
+ ymin = ycenter - h / 2.
80
+ xmin = xcenter - w / 2.
81
+ ymax = ycenter + h / 2.
82
+ xmax = xcenter + w / 2.
83
+ if output_xyxy:
84
+ out = torch.stack([xmin, ymin, xmax, ymax], dim=1)
85
+ else:
86
+ out = torch.stack([ymin, xmin, ymax, xmax], dim=1)
87
+ return out
88
+
89
+
90
+ def clip_boxes_xyxy(boxes: torch.Tensor, size: torch.Tensor):
91
+ boxes = boxes.clamp(min=0)
92
+ size = torch.cat([size, size], dim=0)
93
+ boxes = boxes.min(size)
94
+ return boxes
95
+
96
+
97
+ def generate_detections(
98
+ cls_outputs, box_outputs, anchor_boxes, indices, classes,
99
+ img_scale: Optional[torch.Tensor], img_size: Optional[torch.Tensor],
100
+ max_det_per_image: int = MAX_DETECTIONS_PER_IMAGE, soft_nms: bool = False):
101
+ """Generates detections with RetinaNet model outputs and anchors.
102
+
103
+ Args:
104
+ cls_outputs: a torch tensor with shape [N, 1], which has the highest class
105
+ scores on all feature levels. The N is the number of selected
106
+ top-K total anchors on all levels. (k being MAX_DETECTION_POINTS)
107
+
108
+ box_outputs: a torch tensor with shape [N, 4], which stacks box regression
109
+ outputs on all feature levels. The N is the number of selected top-k
110
+ total anchors on all levels. (k being MAX_DETECTION_POINTS)
111
+
112
+ anchor_boxes: a torch tensor with shape [N, 4], which stacks anchors on all
113
+ feature levels. The N is the number of selected top-k total anchors on all levels.
114
+
115
+ indices: a torch tensor with shape [N], which is the indices from top-k selection.
116
+
117
+ classes: a torch tensor with shape [N], which represents the class
118
+ prediction on all selected anchors from top-k selection.
119
+
120
+ img_scale: a float tensor representing the scale between original image
121
+ and input image for the detector. It is used to rescale detections for
122
+ evaluating with the original groundtruth annotations.
123
+
124
+ max_det_per_image: an int constant, added as argument to make torchscript happy
125
+
126
+ Returns:
127
+ detections: detection results in a tensor with shape [MAX_DETECTION_POINTS, 6],
128
+ each row representing [x_min, y_min, x_max, y_max, score, class]
129
+ """
130
+ assert box_outputs.shape[-1] == 4
131
+ assert anchor_boxes.shape[-1] == 4
132
+ assert cls_outputs.shape[-1] == 1
133
+
134
+ anchor_boxes = anchor_boxes[indices, :]
135
+
136
+ # Appply bounding box regression to anchors, boxes are converted to xyxy
137
+ # here since PyTorch NMS expects them in that form.
138
+ boxes = decode_box_outputs(box_outputs.float(), anchor_boxes, output_xyxy=True)
139
+ if img_scale is not None and img_size is not None:
140
+ boxes = clip_boxes_xyxy(boxes, img_size / img_scale) # clip before NMS better?
141
+
142
+ scores = cls_outputs.sigmoid().squeeze(1).float()
143
+ if soft_nms:
144
+ top_detection_idx, soft_scores = batched_soft_nms(
145
+ boxes, scores, classes, method_gaussian=True, iou_threshold=0.3, score_threshold=.001)
146
+ scores[top_detection_idx] = soft_scores
147
+ else:
148
+ top_detection_idx = batched_nms(boxes, scores, classes, iou_threshold=0.5)
149
+
150
+ # keep only topk scoring predictions
151
+ top_detection_idx = top_detection_idx[:max_det_per_image]
152
+ boxes = boxes[top_detection_idx]
153
+ scores = scores[top_detection_idx, None]
154
+ classes = classes[top_detection_idx, None] + 1 # back to class idx with background class = 0
155
+
156
+ if img_scale is not None:
157
+ boxes = boxes * img_scale
158
+
159
+ # FIXME add option to convert boxes back to yxyx? Otherwise must be handled downstream if
160
+ # that is the preferred output format.
161
+
162
+ # stack em and pad out to MAX_DETECTIONS_PER_IMAGE if necessary
163
+ num_det = len(top_detection_idx)
164
+ detections = torch.cat([boxes, scores, classes.float()], dim=1)
165
+ if num_det < max_det_per_image:
166
+ detections = torch.cat([
167
+ detections,
168
+ torch.zeros((max_det_per_image - num_det, 6), device=detections.device, dtype=detections.dtype)
169
+ ], dim=0)
170
+ return detections
171
+
172
+
173
+ def get_feat_sizes(image_size: Tuple[int, int], max_level: int):
174
+ """Get feat widths and heights for all levels.
175
+ Args:
176
+ image_size: a tuple (H, W)
177
+ max_level: maximum feature level.
178
+ Returns:
179
+ feat_sizes: a list of tuples (height, width) for each level.
180
+ """
181
+ feat_size = image_size
182
+ feat_sizes = [feat_size]
183
+ for _ in range(1, max_level + 1):
184
+ feat_size = ((feat_size[0] - 1) // 2 + 1, (feat_size[1] - 1) // 2 + 1)
185
+ feat_sizes.append(feat_size)
186
+ return feat_sizes
187
+
188
+
189
+ class Anchors(nn.Module):
190
+ """RetinaNet Anchors class."""
191
+
192
+ def __init__(self, min_level, max_level, num_scales, aspect_ratios, anchor_scale, image_size: Tuple[int, int]):
193
+ """Constructs multiscale RetinaNet anchors.
194
+
195
+ Args:
196
+ min_level: integer number of minimum level of the output feature pyramid.
197
+
198
+ max_level: integer number of maximum level of the output feature pyramid.
199
+
200
+ num_scales: integer number representing intermediate scales added
201
+ on each level. For instance, num_scales=2 adds two additional
202
+ anchor scales [2^0, 2^0.5] on each level.
203
+
204
+ aspect_ratios: list of tuples representing the aspect ratio anchors added
205
+ on each level. For instance, aspect_ratios =
206
+ [(1, 1), (1.4, 0.7), (0.7, 1.4)] adds three anchors on each level.
207
+
208
+ anchor_scale: float number representing the scale of size of the base
209
+ anchor to the feature stride 2^level.
210
+
211
+ image_size: Sequence specifying input image size of model (H, W).
212
+ The image_size should be divided by the largest feature stride 2^max_level.
213
+ """
214
+ super(Anchors, self).__init__()
215
+ self.min_level = min_level
216
+ self.max_level = max_level
217
+ self.num_scales = num_scales
218
+ self.aspect_ratios = aspect_ratios
219
+ if isinstance(anchor_scale, Sequence):
220
+ assert len(anchor_scale) == max_level - min_level + 1
221
+ self.anchor_scales = anchor_scale
222
+ else:
223
+ self.anchor_scales = [anchor_scale] * (max_level - min_level + 1)
224
+
225
+ assert isinstance(image_size, Sequence) and len(image_size) == 2
226
+ # FIXME this restriction can likely be relaxed with some additional changes
227
+ assert image_size[0] % 2 ** max_level == 0, 'Image size must be divisible by 2 ** max_level (128)'
228
+ assert image_size[1] % 2 ** max_level == 0, 'Image size must be divisible by 2 ** max_level (128)'
229
+ self.image_size = tuple(image_size)
230
+ self.feat_sizes = get_feat_sizes(image_size, max_level)
231
+ self.config = self._generate_configs()
232
+ self.register_buffer('boxes', self._generate_boxes())
233
+
234
+ @classmethod
235
+ def from_config(cls, config):
236
+ return cls(
237
+ config.min_level, config.max_level,
238
+ config.num_scales, config.aspect_ratios,
239
+ config.anchor_scale, config.image_size)
240
+
241
+ def _generate_configs(self):
242
+ """Generate configurations of anchor boxes."""
243
+ anchor_configs = {}
244
+ feat_sizes = self.feat_sizes
245
+ for level in range(self.min_level, self.max_level + 1):
246
+ anchor_configs[level] = []
247
+ for scale_octave in range(self.num_scales):
248
+ for aspect in self.aspect_ratios:
249
+ anchor_configs[level].append(
250
+ ((feat_sizes[0][0] // feat_sizes[level][0],
251
+ feat_sizes[0][1] // feat_sizes[level][1]),
252
+ scale_octave / float(self.num_scales), aspect,
253
+ self.anchor_scales[level - self.min_level]))
254
+ return anchor_configs
255
+
256
+ def _generate_boxes(self):
257
+ """Generates multiscale anchor boxes."""
258
+ boxes_all = []
259
+ for _, configs in self.config.items():
260
+ boxes_level = []
261
+ for config in configs:
262
+ stride, octave_scale, aspect, anchor_scale = config
263
+ base_anchor_size_x = anchor_scale * stride[1] * 2 ** octave_scale
264
+ base_anchor_size_y = anchor_scale * stride[0] * 2 ** octave_scale
265
+ if isinstance(aspect, Sequence):
266
+ aspect_x = aspect[0]
267
+ aspect_y = aspect[1]
268
+ else:
269
+ aspect_x = np.sqrt(aspect)
270
+ aspect_y = 1.0 / aspect_x
271
+ anchor_size_x_2 = base_anchor_size_x * aspect_x / 2.0
272
+ anchor_size_y_2 = base_anchor_size_y * aspect_y / 2.0
273
+
274
+ x = np.arange(stride[1] / 2, self.image_size[1], stride[1])
275
+ y = np.arange(stride[0] / 2, self.image_size[0], stride[0])
276
+ xv, yv = np.meshgrid(x, y)
277
+ xv = xv.reshape(-1)
278
+ yv = yv.reshape(-1)
279
+
280
+ boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
281
+ yv + anchor_size_y_2, xv + anchor_size_x_2))
282
+ boxes = np.swapaxes(boxes, 0, 1)
283
+ boxes_level.append(np.expand_dims(boxes, axis=1))
284
+
285
+ # concat anchors on the same level into shape NxAx4
286
+ boxes_level = np.concatenate(boxes_level, axis=1)
287
+ boxes_all.append(boxes_level.reshape([-1, 4]))
288
+
289
+ anchor_boxes = np.vstack(boxes_all)
290
+ anchor_boxes = torch.from_numpy(anchor_boxes).float()
291
+ return anchor_boxes
292
+
293
+ def get_anchors_per_location(self):
294
+ return self.num_scales * len(self.aspect_ratios)
295
+
296
+
297
+ class AnchorLabeler(object):
298
+ """Labeler for multiscale anchor boxes.
299
+ """
300
+
301
+ def __init__(self, anchors, num_classes: int, match_threshold: float = 0.5):
302
+ """Constructs anchor labeler to assign labels to anchors.
303
+
304
+ Args:
305
+ anchors: an instance of class Anchors.
306
+
307
+ num_classes: integer number representing number of classes in the dataset.
308
+
309
+ match_threshold: float number between 0 and 1 representing the threshold
310
+ to assign positive labels for anchors.
311
+ """
312
+ similarity_calc = IouSimilarity()
313
+ matcher = ArgMaxMatcher(
314
+ match_threshold,
315
+ unmatched_threshold=match_threshold,
316
+ negatives_lower_than_unmatched=True,
317
+ force_match_for_each_row=True)
318
+ box_coder = FasterRcnnBoxCoder()
319
+
320
+ self.target_assigner = TargetAssigner(similarity_calc, matcher, box_coder)
321
+ self.anchors = anchors
322
+ self.match_threshold = match_threshold
323
+ self.num_classes = num_classes
324
+ self.indices_cache = {}
325
+
326
+ def label_anchors(self, gt_boxes, gt_classes, filter_valid=True):
327
+ """Labels anchors with ground truth inputs.
328
+
329
+ Args:
330
+ gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
331
+ For each row, it stores [y0, x0, y1, x1] for four corners of a box.
332
+
333
+ gt_classes: A integer tensor with shape [N, 1] representing groundtruth classes.
334
+
335
+ filter_valid: Filter out any boxes w/ gt class <= -1 before assigning
336
+
337
+ Returns:
338
+ cls_targets_dict: ordered dictionary with keys [min_level, min_level+1, ..., max_level].
339
+ The values are tensor with shape [height_l, width_l, num_anchors]. The height_l and width_l
340
+ represent the dimension of class logits at l-th level.
341
+
342
+ box_targets_dict: ordered dictionary with keys [min_level, min_level+1, ..., max_level].
343
+ The values are tensor with shape [height_l, width_l, num_anchors * 4]. The height_l and
344
+ width_l represent the dimension of bounding box regression output at l-th level.
345
+
346
+ num_positives: scalar tensor storing number of positives in an image.
347
+ """
348
+ cls_targets_out = []
349
+ box_targets_out = []
350
+
351
+ if filter_valid:
352
+ valid_idx = gt_classes > -1 # filter gt targets w/ label <= -1
353
+ gt_boxes = gt_boxes[valid_idx]
354
+ gt_classes = gt_classes[valid_idx]
355
+
356
+ cls_targets, box_targets, matches = self.target_assigner.assign(
357
+ BoxList(self.anchors.boxes), BoxList(gt_boxes), gt_classes)
358
+
359
+ # class labels start from 1 and the background class = -1
360
+ cls_targets = (cls_targets - 1).long()
361
+
362
+ # Unpack labels.
363
+ """Unpacks an array of cls/box into multiple scales."""
364
+ count = 0
365
+ for level in range(self.anchors.min_level, self.anchors.max_level + 1):
366
+ feat_size = self.anchors.feat_sizes[level]
367
+ steps = feat_size[0] * feat_size[1] * self.anchors.get_anchors_per_location()
368
+ cls_targets_out.append(cls_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
369
+ box_targets_out.append(box_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
370
+ count += steps
371
+
372
+ num_positives = (matches.match_results > -1).float().sum()
373
+
374
+ return cls_targets_out, box_targets_out, num_positives
375
+
376
+ def batch_label_anchors(self, gt_boxes, gt_classes, filter_valid=True):
377
+ batch_size = len(gt_boxes)
378
+ assert batch_size == len(gt_classes)
379
+ num_levels = self.anchors.max_level - self.anchors.min_level + 1
380
+ cls_targets_out = [[] for _ in range(num_levels)]
381
+ box_targets_out = [[] for _ in range(num_levels)]
382
+ num_positives_out = []
383
+
384
+ anchor_box_list = BoxList(self.anchors.boxes)
385
+ for i in range(batch_size):
386
+ last_sample = i == batch_size - 1
387
+
388
+ if filter_valid:
389
+ valid_idx = gt_classes[i] > -1 # filter gt targets w/ label <= -1
390
+ gt_box_list = BoxList(gt_boxes[i][valid_idx])
391
+ gt_class_i = gt_classes[i][valid_idx]
392
+ else:
393
+ gt_box_list = BoxList(gt_boxes[i])
394
+ gt_class_i = gt_classes[i]
395
+ cls_targets, box_targets, matches = self.target_assigner.assign(anchor_box_list, gt_box_list, gt_class_i)
396
+
397
+ # class labels start from 1 and the background class = -1
398
+ cls_targets = (cls_targets - 1).long()
399
+
400
+ # Unpack labels.
401
+ """Unpacks an array of cls/box into multiple scales."""
402
+ count = 0
403
+ for level in range(self.anchors.min_level, self.anchors.max_level + 1):
404
+ level_idx = level - self.anchors.min_level
405
+ feat_size = self.anchors.feat_sizes[level]
406
+ steps = feat_size[0] * feat_size[1] * self.anchors.get_anchors_per_location()
407
+ cls_targets_out[level_idx].append(
408
+ cls_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
409
+ box_targets_out[level_idx].append(
410
+ box_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
411
+ count += steps
412
+ if last_sample:
413
+ cls_targets_out[level_idx] = torch.stack(cls_targets_out[level_idx])
414
+ box_targets_out[level_idx] = torch.stack(box_targets_out[level_idx])
415
+
416
+ num_positives_out.append((matches.match_results > -1).float().sum())
417
+ if last_sample:
418
+ num_positives_out = torch.stack(num_positives_out)
419
+
420
+ return cls_targets_out, box_targets_out, num_positives_out
421
+
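
As a quick orientation to the anchor generation above, a minimal sketch that builds the anchor grid for the tf_efficientdet_d2 config used by this Space and decodes a dummy set of offsets; it assumes ./efficientdet is on sys.path, as app.py arranges:

    import sys

    sys.path.append("./efficientdet")

    import torch

    from effdet import get_efficientdet_config
    from effdet.anchors import Anchors, decode_box_outputs

    config = get_efficientdet_config("tf_efficientdet_d2")  # 768x768 input, levels 3..7
    anchors = Anchors.from_config(config)

    # num_scales (3) * len(aspect_ratios) (3) anchors per cell, summed over the 5 pyramid levels
    print(anchors.boxes.shape)  # torch.Size([110484, 4])

    # zero offsets decode back to the anchors themselves, returned here in xyxy order
    boxes = decode_box_outputs(torch.zeros_like(anchors.boxes), anchors.boxes, output_xyxy=True)
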
efficientdet/effdet/bench.py ADDED
@@ -0,0 +1,143 @@
1
+ """ PyTorch EfficientDet support benches
2
+
3
+ Hacked together by Ross Wightman
4
+ """
5
+ from typing import Optional, Dict, List
6
+ import torch
7
+ import torch.nn as nn
8
+ from timm.utils import ModelEma
9
+ from .anchors import Anchors, AnchorLabeler, generate_detections, MAX_DETECTION_POINTS
10
+ from .loss import DetectionLoss
11
+
12
+
13
+ def _post_process(
14
+ cls_outputs: List[torch.Tensor],
15
+ box_outputs: List[torch.Tensor],
16
+ num_levels: int,
17
+ num_classes: int,
18
+ max_detection_points: int = MAX_DETECTION_POINTS,
19
+ ):
20
+ """Selects top-k predictions.
21
+
22
+ Post-proc code adapted from Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet
23
+ and optimized for PyTorch.
24
+
25
+ Args:
26
+ cls_outputs: a list of tensors ordered by level, with values
27
+ representing logits in [batch_size, height, width, num_anchors].
28
+
29
+ box_outputs: a list of tensors ordered by level, with values
30
+ representing box regression targets in [batch_size, height, width, num_anchors * 4].
31
+
32
+ num_levels (int): number of feature levels
33
+
34
+ num_classes (int): number of output classes
35
+ """
36
+ batch_size = cls_outputs[0].shape[0]
37
+ cls_outputs_all = torch.cat([
38
+ cls_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, num_classes])
39
+ for level in range(num_levels)], 1)
40
+
41
+ box_outputs_all = torch.cat([
42
+ box_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, 4])
43
+ for level in range(num_levels)], 1)
44
+
45
+ _, cls_topk_indices_all = torch.topk(cls_outputs_all.reshape(batch_size, -1), dim=1, k=max_detection_points)
46
+ indices_all = cls_topk_indices_all // num_classes
47
+ classes_all = cls_topk_indices_all % num_classes
48
+
49
+ box_outputs_all_after_topk = torch.gather(
50
+ box_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, 4))
51
+
52
+ cls_outputs_all_after_topk = torch.gather(
53
+ cls_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, num_classes))
54
+ cls_outputs_all_after_topk = torch.gather(
55
+ cls_outputs_all_after_topk, 2, classes_all.unsqueeze(2))
56
+
57
+ return cls_outputs_all_after_topk, box_outputs_all_after_topk, indices_all, classes_all
58
+
59
+
60
+ @torch.jit.script
61
+ def _batch_detection(
62
+ batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
63
+ img_scale: Optional[torch.Tensor] = None, img_size: Optional[torch.Tensor] = None):
64
+ batch_detections = []
65
+ # FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome
66
+ for i in range(batch_size):
67
+ img_scale_i = None if img_scale is None else img_scale[i]
68
+ img_size_i = None if img_size is None else img_size[i]
69
+ detections = generate_detections(
70
+ class_out[i], box_out[i], anchor_boxes, indices[i], classes[i], img_scale_i, img_size_i)
71
+ batch_detections.append(detections)
72
+ return torch.stack(batch_detections, dim=0)
73
+
74
+
75
+ class DetBenchPredict(nn.Module):
76
+ def __init__(self, model):
77
+ super(DetBenchPredict, self).__init__()
78
+ self.model = model
79
+ self.config = model.config # FIXME remove this when we can use @property (torchscript limitation)
80
+ self.num_levels = model.config.num_levels
81
+ self.num_classes = model.config.num_classes
82
+ self.anchors = Anchors.from_config(model.config)
83
+
84
+ def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
85
+ class_out, box_out = self.model(x)
86
+ class_out, box_out, indices, classes = _post_process(
87
+ class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
88
+ if img_info is None:
89
+ img_scale, img_size = None, None
90
+ else:
91
+ img_scale, img_size = img_info['img_scale'], img_info['img_size']
92
+ return _batch_detection(
93
+ x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes, img_scale, img_size)
94
+
95
+
96
+ class DetBenchTrain(nn.Module):
97
+ def __init__(self, model, create_labeler=True):
98
+ super(DetBenchTrain, self).__init__()
99
+ self.model = model
100
+ self.config = model.config # FIXME remove this when we can use @property (torchscript limitation)
101
+ self.num_levels = model.config.num_levels
102
+ self.num_classes = model.config.num_classes
103
+ self.anchors = Anchors.from_config(model.config)
104
+ self.anchor_labeler = None
105
+ if create_labeler:
106
+ self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
107
+ self.loss_fn = DetectionLoss(model.config)
108
+
109
+ def forward(self, x, target: Dict[str, torch.Tensor]):
110
+ class_out, box_out = self.model(x)
111
+ if self.anchor_labeler is None:
112
+ # target should contain pre-computed anchor labels if labeler not present in bench
113
+ assert 'label_num_positives' in target
114
+ cls_targets = [target[f'label_cls_{l}'] for l in range(self.num_levels)]
115
+ box_targets = [target[f'label_bbox_{l}'] for l in range(self.num_levels)]
116
+ num_positives = target['label_num_positives']
117
+ else:
118
+ cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
119
+ target['bbox'], target['cls'])
120
+
121
+ loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
122
+ output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss}
123
+ if not self.training:
124
+ # if eval mode, output detections for evaluation
125
+ class_out_pp, box_out_pp, indices, classes = _post_process(
126
+ class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
127
+ output['detections'] = _batch_detection(
128
+ x.shape[0], class_out_pp, box_out_pp, self.anchors.boxes, indices, classes,
129
+ target['img_scale'], target['img_size'])
130
+ return output
131
+
132
+
133
+ def unwrap_bench(model):
134
+ # Unwrap a model in support bench so that various other fns can access the weights and attribs of the
135
+ # underlying model directly
136
+ if isinstance(model, ModelEma): # unwrap ModelEma
137
+ return unwrap_bench(model.ema)
138
+ elif hasattr(model, 'module'): # unwrap DDP
139
+ return unwrap_bench(model.module)
140
+ elif hasattr(model, 'model'): # unwrap Bench -> model
141
+ return unwrap_bench(model.model)
142
+ else:
143
+ return model
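
To show how the prediction bench above is meant to be used, a minimal inference sketch. It assumes ./efficientdet is on sys.path and that the EfficientDet constructor in efficientdet.py (not shown in this truncated view) accepts the config plus a pretrained_backbone flag, as in the upstream effdet library:

    import sys

    sys.path.append("./efficientdet")

    import torch

    from effdet import DetBenchPredict, EfficientDet, get_efficientdet_config

    config = get_efficientdet_config("tf_efficientdet_d2")
    model = EfficientDet(config, pretrained_backbone=False)  # assumed constructor signature
    bench = DetBenchPredict(model).eval()

    x = torch.randn(1, 3, *config.image_size)
    with torch.no_grad():
        detections = bench(x)

    # [batch, MAX_DETECTIONS_PER_IMAGE, 6], rows of [x_min, y_min, x_max, y_max, score, class]
    print(detections.shape)  # torch.Size([1, 100, 6])
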
efficientdet/effdet/config/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .config_utils import set_config_readonly, set_config_writeable
+ from .fpn_config import get_fpn_config
+ from .model_config import get_efficientdet_config, default_detection_model_configs
+ from .train_config import default_detection_train_config
efficientdet/effdet/config/config_utils.py ADDED
@@ -0,0 +1,9 @@
+ from omegaconf import OmegaConf
+
+
+ def set_config_readonly(conf):
+     OmegaConf.set_readonly(conf, True)
+
+
+ def set_config_writeable(conf):
+     OmegaConf.set_readonly(conf, False)
efficientdet/effdet/config/fpn_config.py ADDED
@@ -0,0 +1,184 @@
1
+ import itertools
2
+
3
+ from omegaconf import OmegaConf
4
+
5
+
6
+ def bifpn_config(min_level, max_level, weight_method=None):
7
+ """BiFPN config.
8
+ Adapted from https://github.com/google/automl/blob/56815c9986ffd4b508fe1d68508e268d129715c1/efficientdet/keras/fpn_configs.py
9
+ """
10
+ p = OmegaConf.create()
11
+ weight_method = weight_method or 'fastattn'
12
+
13
+ num_levels = max_level - min_level + 1
14
+ node_ids = {min_level + i: [i] for i in range(num_levels)}
15
+
16
+ level_last_id = lambda level: node_ids[level][-1]
17
+ level_all_ids = lambda level: node_ids[level]
18
+ id_cnt = itertools.count(num_levels)
19
+
20
+ p.nodes = []
21
+ for i in range(max_level - 1, min_level - 1, -1):
22
+ # top-down path.
23
+ p.nodes.append({
24
+ 'reduction': 1 << i,
25
+ 'inputs_offsets': [level_last_id(i), level_last_id(i + 1)],
26
+ 'weight_method': weight_method,
27
+ })
28
+ node_ids[i].append(next(id_cnt))
29
+
30
+ for i in range(min_level + 1, max_level + 1):
31
+ # bottom-up path.
32
+ p.nodes.append({
33
+ 'reduction': 1 << i,
34
+ 'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)],
35
+ 'weight_method': weight_method,
36
+ })
37
+ node_ids[i].append(next(id_cnt))
38
+ return p
39
+
40
+
41
+ def panfpn_config(min_level, max_level, weight_method=None):
42
+ """PAN FPN config.
43
+
44
+ This defines FPN layout from Path Aggregation Networks as an alternate to
45
+ BiFPN, it does not implement the full PAN spec.
46
+
47
+ Paper: https://arxiv.org/abs/1803.01534
48
+ """
49
+ p = OmegaConf.create()
50
+ weight_method = weight_method or 'fastattn'
51
+
52
+ num_levels = max_level - min_level + 1
53
+ node_ids = {min_level + i: [i] for i in range(num_levels)}
54
+ level_last_id = lambda level: node_ids[level][-1]
55
+ id_cnt = itertools.count(num_levels)
56
+
57
+ p.nodes = []
58
+ for i in range(max_level, min_level - 1, -1):
59
+ # top-down path.
60
+ offsets = [level_last_id(i), level_last_id(i + 1)] if i != max_level else [level_last_id(i)]
61
+ p.nodes.append({
62
+ 'reduction': 1 << i,
63
+ 'inputs_offsets': offsets,
64
+ 'weight_method': weight_method,
65
+ })
66
+ node_ids[i].append(next(id_cnt))
67
+
68
+ for i in range(min_level, max_level + 1):
69
+ # bottom-up path.
70
+ offsets = [level_last_id(i), level_last_id(i - 1)] if i != min_level else [level_last_id(i)]
71
+ p.nodes.append({
72
+ 'reduction': 1 << i,
73
+ 'inputs_offsets': offsets,
74
+ 'weight_method': weight_method,
75
+ })
76
+ node_ids[i].append(next(id_cnt))
77
+
78
+ return p
79
+
80
+
81
+ def qufpn_config(min_level, max_level, weight_method=None):
82
+ """A dynamic quad fpn config that can adapt to different min/max levels.
83
+
84
+ It extends the idea of BiFPN, and has four paths:
85
+ (up_down -> bottom_up) + (bottom_up -> up_down).
86
+
87
+ Paper: https://ieeexplore.ieee.org/document/9225379
88
+ Ref code: From contribution to TF EfficientDet
89
+ https://github.com/google/automl/blob/eb74c6739382e9444817d2ad97c4582dbe9a9020/efficientdet/keras/fpn_configs.py
90
+ """
91
+ p = OmegaConf.create()
92
+ weight_method = weight_method or 'fastattn'
93
+ quad_method = 'fastattn'
94
+ num_levels = max_level - min_level + 1
95
+ node_ids = {min_level + i: [i] for i in range(num_levels)}
96
+ level_last_id = lambda level: node_ids[level][-1]
97
+ level_all_ids = lambda level: node_ids[level]
98
+ level_first_id = lambda level: node_ids[level][0]
99
+ id_cnt = itertools.count(num_levels)
100
+
101
+ p.nodes = []
102
+ for i in range(max_level - 1, min_level - 1, -1):
103
+ # top-down path 1.
104
+ p.nodes.append({
105
+ 'reduction': 1 << i,
106
+ 'inputs_offsets': [level_last_id(i), level_last_id(i + 1)],
107
+ 'weight_method': weight_method
108
+ })
109
+ node_ids[i].append(next(id_cnt))
110
+ node_ids[max_level].append(node_ids[max_level][-1])
111
+
112
+ for i in range(min_level + 1, max_level):
113
+ # bottom-up path 2.
114
+ p.nodes.append({
115
+ 'reduction': 1 << i,
116
+ 'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)],
117
+ 'weight_method': weight_method
118
+ })
119
+ node_ids[i].append(next(id_cnt))
120
+
121
+ i = max_level
122
+ p.nodes.append({
123
+ 'reduction': 1 << i,
124
+ 'inputs_offsets': [level_first_id(i)] + [level_last_id(i - 1)],
125
+ 'weight_method': weight_method
126
+ })
127
+ node_ids[i].append(next(id_cnt))
128
+ node_ids[min_level].append(node_ids[min_level][-1])
129
+
130
+ for i in range(min_level + 1, max_level + 1, 1):
131
+ # bottom-up path 3.
132
+ p.nodes.append({
133
+ 'reduction': 1 << i,
134
+ 'inputs_offsets': [
135
+ level_first_id(i), level_last_id(i - 1) if i != min_level + 1 else level_first_id(i - 1)],
136
+ 'weight_method': weight_method
137
+ })
138
+ node_ids[i].append(next(id_cnt))
139
+ node_ids[min_level].append(node_ids[min_level][-1])
140
+
141
+ for i in range(max_level - 1, min_level, -1):
142
+ # top-down path 4.
143
+ p.nodes.append({
144
+ 'reduction': 1 << i,
145
+ 'inputs_offsets': [node_ids[i][0]] + [node_ids[i][-1]] + [level_last_id(i + 1)],
146
+ 'weight_method': weight_method
147
+ })
148
+ node_ids[i].append(next(id_cnt))
149
+ i = min_level
150
+ p.nodes.append({
151
+ 'reduction': 1 << i,
152
+ 'inputs_offsets': [node_ids[i][0]] + [level_last_id(i + 1)],
153
+ 'weight_method': weight_method
154
+ })
155
+ node_ids[i].append(next(id_cnt))
156
+ node_ids[max_level].append(node_ids[max_level][-1])
157
+
158
+ # NOTE: the order of the quad path is reversed from the original, my code expects the output of
159
+ # each FPN repeat to be same as input from backbone, in order of increasing reductions
160
+ for i in range(min_level, max_level + 1):
161
+ # quad-add path.
162
+ p.nodes.append({
163
+ 'reduction': 1 << i,
164
+ 'inputs_offsets': [node_ids[i][2], node_ids[i][4]],
165
+ 'weight_method': quad_method
166
+ })
167
+ node_ids[i].append(next(id_cnt))
168
+
169
+ return p
170
+
171
+
172
+ def get_fpn_config(fpn_name, min_level=3, max_level=7):
173
+ if not fpn_name:
174
+ fpn_name = 'bifpn_fa'
175
+ name_to_config = {
176
+ 'bifpn_sum': bifpn_config(min_level=min_level, max_level=max_level, weight_method='sum'),
177
+ 'bifpn_attn': bifpn_config(min_level=min_level, max_level=max_level, weight_method='attn'),
178
+ 'bifpn_fa': bifpn_config(min_level=min_level, max_level=max_level, weight_method='fastattn'),
179
+ 'pan_sum': panfpn_config(min_level=min_level, max_level=max_level, weight_method='sum'),
180
+ 'pan_fa': panfpn_config(min_level=min_level, max_level=max_level, weight_method='fastattn'),
181
+ 'qufpn_sum': qufpn_config(min_level=min_level, max_level=max_level, weight_method='sum'),
182
+ 'qufpn_fa': qufpn_config(min_level=min_level, max_level=max_level, weight_method='fastattn'),
183
+ }
184
+ return name_to_config[fpn_name]
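
The FPN configs above are plain node lists that the BiFPN builder later consumes; a small sketch for inspecting the topology they describe (same sys.path assumption as app.py):

    import sys

    sys.path.append("./efficientdet")

    from effdet.config import get_fpn_config

    p = get_fpn_config("bifpn_fa", min_level=3, max_level=7)
    print(len(p.nodes))  # 8 fusion nodes: top-down P6..P3, then bottom-up P4..P7
    for node in p.nodes:
        # reduction is the feature stride (1 << level); inputs_offsets index earlier nodes
        print(node.reduction, list(node.inputs_offsets), node.weight_method)
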
efficientdet/effdet/config/model_config.py ADDED
@@ -0,0 +1,538 @@
1
+ """EfficientDet Configurations
2
+
3
+ Adapted from official impl at https://github.com/google/automl/tree/master/efficientdet
4
+
5
+ TODO use a different config system (OmegaConfig -> Hydra?), separate model from train specific hparams
6
+ """
7
+
8
+ from omegaconf import OmegaConf
9
+ from copy import deepcopy
10
+
11
+
12
+ def default_detection_model_configs():
13
+ """Returns a default detection configs."""
14
+ h = OmegaConf.create()
15
+
16
+ # model name.
17
+ h.name = 'tf_efficientdet_d1'
18
+
19
+ h.backbone_name = 'tf_efficientnet_b1'
20
+ h.backbone_args = None # FIXME sort out kwargs vs config for backbone creation
21
+
22
+ # model specific, input preprocessing parameters
23
+ h.image_size = (640, 640)
24
+
25
+ # dataset specific head parameters
26
+ h.num_classes = 90
27
+
28
+ # feature + anchor config
29
+ h.min_level = 3
30
+ h.max_level = 7
31
+ h.num_levels = h.max_level - h.min_level + 1
32
+ h.num_scales = 3
33
+ h.aspect_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
34
+ # ratio w/h: 2.0 means w=1.4, h=0.7. Can be computed with k-mean per dataset.
35
+ #h.aspect_ratios = [1.0, 2.0, 0.5]
36
+ h.anchor_scale = 4.0
37
+
38
+ # FPN and head config
39
+ h.pad_type = 'same' # original TF models require an equivalent of Tensorflow 'SAME' padding
40
+ h.act_type = 'swish'
41
+ h.norm_layer = None # defaults to batch norm when None
42
+ h.norm_kwargs = dict(eps=.001, momentum=.01)
43
+ h.box_class_repeats = 3
44
+ h.fpn_cell_repeats = 3
45
+ h.fpn_channels = 88
46
+ h.separable_conv = True
47
+ h.apply_bn_for_resampling = True
48
+ h.conv_after_downsample = False
49
+ h.conv_bn_relu_pattern = False
50
+ h.use_native_resize_op = False
51
+ h.pooling_type = None
52
+ h.redundant_bias = True # original TF models have back to back bias + BN layers, not necessary!
53
+ h.head_bn_level_first = False # change order of BN in head repeat list of lists, True for torchscript compat
54
+
55
+ h.fpn_name = None
56
+ h.fpn_config = None
57
+ h.fpn_drop_path_rate = 0. # No stochastic depth in default. NOTE not currently used, unstable training
58
+
59
+ # classification loss (used by train bench)
60
+ h.alpha = 0.25
61
+ h.gamma = 1.5
62
+ h.label_smoothing = 0. # only supported if new_focal == True
63
+ h.new_focal = False # use new focal loss (supports label smoothing but uses more mem, less optimal w/ jit script)
64
+ h.jit_loss = False # torchscript jit for loss fn speed improvement, can impact stability and/or increase mem usage
65
+
66
+ # localization loss (used by train bench)
67
+ h.delta = 0.1
68
+ h.box_loss_weight = 50.0
69
+
70
+ return h
71
+
72
+
73
+ efficientdet_model_param_dict = dict(
74
+ # Models with PyTorch friendly padding and my PyTorch pretrained backbones, training TBD
75
+ efficientdet_d0=dict(
76
+ name='efficientdet_d0',
77
+ backbone_name='efficientnet_b0',
78
+ image_size=(512, 512),
79
+ fpn_channels=64,
80
+ fpn_cell_repeats=3,
81
+ box_class_repeats=3,
82
+ pad_type='',
83
+ redundant_bias=False,
84
+ backbone_args=dict(drop_path_rate=0.1),
85
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_d0-f3276ba8.pth',
86
+ ),
87
+ efficientdet_d1=dict(
88
+ name='efficientdet_d1',
89
+ backbone_name='efficientnet_b1',
90
+ image_size=(640, 640),
91
+ fpn_channels=88,
92
+ fpn_cell_repeats=4,
93
+ box_class_repeats=3,
94
+ pad_type='',
95
+ redundant_bias=False,
96
+ backbone_args=dict(drop_path_rate=0.2),
97
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_d1-bb7e98fe.pth',
98
+ ),
99
+ efficientdet_d2=dict(
100
+ name='efficientdet_d2',
101
+ backbone_name='efficientnet_b2',
102
+ image_size=(768, 768),
103
+ fpn_channels=112,
104
+ fpn_cell_repeats=5,
105
+ box_class_repeats=3,
106
+ pad_type='',
107
+ redundant_bias=False,
108
+ backbone_args=dict(drop_path_rate=0.2),
109
+ url='', # no pretrained weights yet
110
+ ),
111
+ efficientdet_d3=dict(
112
+ name='efficientdet_d3',
113
+ backbone_name='efficientnet_b3',
114
+ image_size=(896, 896),
115
+ fpn_channels=160,
116
+ fpn_cell_repeats=6,
117
+ box_class_repeats=4,
118
+ pad_type='',
119
+ redundant_bias=False,
120
+ backbone_args=dict(drop_path_rate=0.2),
121
+ url='', # no pretrained weights yet
122
+ ),
123
+ efficientdet_d4=dict(
124
+ name='efficientdet_d4',
125
+ backbone_name='efficientnet_b4',
126
+ image_size=(1024, 1024),
127
+ fpn_channels=224,
128
+ fpn_cell_repeats=7,
129
+ box_class_repeats=4,
130
+ backbone_args=dict(drop_path_rate=0.2),
131
+ ),
132
+ efficientdet_d5=dict(
133
+ name='efficientdet_d5',
134
+ backbone_name='efficientnet_b5',
135
+ image_size=(1280, 1280),
136
+ fpn_channels=288,
137
+ fpn_cell_repeats=7,
138
+ box_class_repeats=4,
139
+ backbone_args=dict(drop_path_rate=0.2),
140
+ url='',
141
+ ),
142
+
143
+ # My own experimental configs with alternate models, training TBD
144
+ # Note: any 'timm' model in the EfficientDet family can be used as a backbone here.
145
+ resdet50=dict(
146
+ name='resdet50',
147
+ backbone_name='resnet50',
148
+ image_size=(640, 640),
149
+ fpn_channels=88,
150
+ fpn_cell_repeats=4,
151
+ box_class_repeats=3,
152
+ pad_type='',
153
+ act_type='relu',
154
+ redundant_bias=False,
155
+ separable_conv=False,
156
+ backbone_args=dict(drop_path_rate=0.2),
157
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/resdet50_416-08676892.pth',
158
+ ),
159
+ cspresdet50=dict(
160
+ name='cspresdet50',
161
+ backbone_name='cspresnet50',
162
+ image_size=(640, 640),
163
+ aspect_ratios=[1.0, 2.0, 0.5],
164
+ fpn_channels=88,
165
+ fpn_cell_repeats=4,
166
+ box_class_repeats=3,
167
+ pad_type='',
168
+ act_type='leaky_relu',
169
+ redundant_bias=False,
170
+ separable_conv=False,
171
+ head_bn_level_first=True,
172
+ backbone_args=dict(drop_path_rate=0.2),
173
+ url='',
174
+ ),
175
+ cspresdext50=dict(
176
+ name='cspresdext50',
177
+ backbone_name='cspresnext50',
178
+ image_size=(640, 640),
179
+ aspect_ratios=[1.0, 2.0, 0.5],
180
+ fpn_channels=88,
181
+ fpn_cell_repeats=4,
182
+ box_class_repeats=3,
183
+ pad_type='',
184
+ act_type='leaky_relu',
185
+ redundant_bias=False,
186
+ separable_conv=False,
187
+ head_bn_level_first=True,
188
+ backbone_args=dict(drop_path_rate=0.2),
189
+ url='',
190
+ ),
191
+ cspresdext50pan=dict(
192
+ name='cspresdext50pan',
193
+ backbone_name='cspresnext50',
194
+ image_size=(640, 640),
195
+ aspect_ratios=[1.0, 2.0, 0.5],
196
+ fpn_channels=88,
197
+ fpn_cell_repeats=3,
198
+ box_class_repeats=3,
199
+ pad_type='',
200
+ act_type='leaky_relu',
201
+ fpn_name='pan_fa', # PAN FPN experiment
202
+ redundant_bias=False,
203
+ separable_conv=False,
204
+ head_bn_level_first=True,
205
+ backbone_args=dict(drop_path_rate=0.2),
206
+ url='',
207
+ ),
208
+ cspdarkdet53=dict(
209
+ name='cspdarkdet53',
210
+ backbone_name='cspdarknet53',
211
+ image_size=(640, 640),
212
+ aspect_ratios=[1.0, 2.0, 0.5],
213
+ fpn_channels=88,
214
+ fpn_cell_repeats=4,
215
+ box_class_repeats=3,
216
+ pad_type='',
217
+ act_type='leaky_relu',
218
+ redundant_bias=False,
219
+ separable_conv=False,
220
+ head_bn_level_first=True,
221
+ backbone_args=dict(drop_path_rate=0.2),
222
+ url='',
223
+ ),
224
+ mixdet_m=dict(
225
+ name='mixdet_m',
226
+ backbone_name='mixnet_m',
227
+ image_size=(512, 512),
228
+ aspect_ratios=[1.0, 2.0, 0.5],
229
+ fpn_channels=64,
230
+ fpn_cell_repeats=3,
231
+ box_class_repeats=3,
232
+ pad_type='',
233
+ redundant_bias=False,
234
+ head_bn_level_first=True,
235
+ backbone_args=dict(drop_path_rate=0.1),
236
+ url='', # no pretrained weights yet
237
+ ),
238
+ mixdet_l=dict(
239
+ name='mixdet_l',
240
+ backbone_name='mixnet_l',
241
+ image_size=(640, 640),
242
+ aspect_ratios=[1.0, 2.0, 0.5],
243
+ fpn_channels=88,
244
+ fpn_cell_repeats=4,
245
+ box_class_repeats=3,
246
+ pad_type='',
247
+ redundant_bias=False,
248
+ head_bn_level_first=True,
249
+ backbone_args=dict(drop_path_rate=0.2),
250
+ url='', # no pretrained weights yet
251
+ ),
252
+ mobiledetv2_110d=dict(
253
+ name='mobiledetv2_110d',
254
+ backbone_name='mobilenetv2_110d',
255
+ image_size=(384, 384),
256
+ aspect_ratios=[1.0, 2.0, 0.5],
257
+ fpn_channels=48,
258
+ fpn_cell_repeats=3,
259
+ box_class_repeats=3,
260
+ pad_type='',
261
+ act_type='relu6',
262
+ redundant_bias=False,
263
+ head_bn_level_first=True,
264
+ backbone_args=dict(drop_path_rate=0.05),
265
+ url='', # no pretrained weights yet
266
+ ),
267
+ mobiledetv2_120d=dict(
268
+ name='mobiledetv2_120d',
269
+ backbone_name='mobilenetv2_120d',
270
+ image_size=(512, 512),
271
+ aspect_ratios=[1.0, 2.0, 0.5],
272
+ fpn_channels=56,
273
+ fpn_cell_repeats=3,
274
+ box_class_repeats=3,
275
+ pad_type='',
276
+ act_type='relu6',
277
+ redundant_bias=False,
278
+ head_bn_level_first=True,
279
+ backbone_args=dict(drop_path_rate=0.1),
280
+ url='', # no pretrained weights yet
281
+ ),
282
+ mobiledetv3_large=dict(
283
+ name='mobiledetv3_large',
284
+ backbone_name='mobilenetv3_large_100',
285
+ image_size=(512, 512),
286
+ aspect_ratios=[1.0, 2.0, 0.5],
287
+ fpn_channels=64,
288
+ fpn_cell_repeats=3,
289
+ box_class_repeats=3,
290
+ pad_type='',
291
+ act_type='hard_swish',
292
+ redundant_bias=False,
293
+ head_bn_level_first=True,
294
+ backbone_args=dict(drop_path_rate=0.1),
295
+ url='', # no pretrained weights yet
296
+ ),
297
+ efficientdet_q0=dict(
298
+ name='efficientdet_q0',
299
+ backbone_name='efficientnet_b0',
300
+ image_size=(512, 512),
301
+ fpn_channels=64,
302
+ fpn_cell_repeats=3,
303
+ box_class_repeats=3,
304
+ pad_type='',
305
+ fpn_name='qufpn_fa', # quad-fpn + fast attn experiment
306
+ redundant_bias=False,
307
+ head_bn_level_first=True,
308
+ backbone_args=dict(drop_path_rate=0.1),
309
+ url='',
310
+ ),
311
+ efficientdet_w0=dict(
312
+ name='efficientdet_w0', # 'wide'
313
+ backbone_name='efficientnet_b0',
314
+ image_size=(512, 512),
315
+ aspect_ratios=[1.0, 2.0, 0.5],
316
+ fpn_channels=80,
317
+ fpn_cell_repeats=3,
318
+ box_class_repeats=3,
319
+ pad_type='',
320
+ redundant_bias=False,
321
+ head_bn_level_first=True,
322
+ backbone_args=dict(
323
+ drop_path_rate=0.1,
324
+ feature_location='depthwise'), # features from after DW/SE in IR block
325
+ url='', # no pretrained weights yet
326
+ ),
327
+ efficientdet_es=dict(
328
+ name='efficientdet_es', #EdgeTPU-Small
329
+ backbone_name='efficientnet_es',
330
+ image_size=(512, 512),
331
+ aspect_ratios=[1.0, 2.0, 0.5],
332
+ fpn_channels=72,
333
+ fpn_cell_repeats=3,
334
+ box_class_repeats=3,
335
+ pad_type='',
336
+ act_type='relu',
337
+ redundant_bias=False,
338
+ head_bn_level_first=True,
339
+ separable_conv=False,
340
+ backbone_args=dict(drop_path_rate=0.1),
341
+ url='',
342
+ ),
343
+ efficientdet_em=dict(
344
+ name='efficientdet_em', # Edge-TPU Medium
345
+ backbone_name='efficientnet_em',
346
+ image_size=(640, 640),
347
+ aspect_ratios=[1.0, 2.0, 0.5],
348
+ fpn_channels=96,
349
+ fpn_cell_repeats=4,
350
+ box_class_repeats=3,
351
+ pad_type='',
352
+ act_type='relu',
353
+ redundant_bias=False,
354
+ head_bn_level_first=True,
355
+ separable_conv=False,
356
+ backbone_args=dict(drop_path_rate=0.2),
357
+ url='', # no pretrained weights yet
358
+ ),
359
+ efficientdet_lite0=dict(
360
+ name='efficientdet_lite0',
361
+ backbone_name='efficientnet_lite0',
362
+ image_size=(512, 512),
363
+ fpn_channels=64,
364
+ fpn_cell_repeats=3,
365
+ box_class_repeats=3,
366
+ act_type='relu',
367
+ redundant_bias=False,
368
+ head_bn_level_first=True,
369
+ backbone_args=dict(drop_path_rate=0.1),
370
+ url='',
371
+ ),
372
+
373
+ # Models ported from Tensorflow with pretrained backbones ported from Tensorflow
374
+ tf_efficientdet_d0=dict(
375
+ name='tf_efficientdet_d0',
376
+ backbone_name='tf_efficientnet_b0',
377
+ image_size=(512, 512),
378
+ fpn_channels=64,
379
+ fpn_cell_repeats=3,
380
+ box_class_repeats=3,
381
+ backbone_args=dict(drop_path_rate=0.2),
382
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d0_34-f153e0cf.pth',
383
+ ),
384
+ tf_efficientdet_d1=dict(
385
+ name='tf_efficientdet_d1',
386
+ backbone_name='tf_efficientnet_b1',
387
+ image_size=(640, 640),
388
+ fpn_channels=88,
389
+ fpn_cell_repeats=4,
390
+ box_class_repeats=3,
391
+ backbone_args=dict(drop_path_rate=0.2),
392
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d1_40-a30f94af.pth'
393
+ ),
394
+ tf_efficientdet_d2=dict(
395
+ name='tf_efficientdet_d2',
396
+ backbone_name='tf_efficientnet_b2',
397
+ image_size=(768, 768),
398
+ fpn_channels=112,
399
+ fpn_cell_repeats=5,
400
+ box_class_repeats=3,
401
+ backbone_args=dict(drop_path_rate=0.2),
402
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d2_43-8107aa99.pth',
403
+ ),
404
+ tf_efficientdet_d3=dict(
405
+ name='tf_efficientdet_d3',
406
+ backbone_name='tf_efficientnet_b3',
407
+ image_size=(896, 896),
408
+ fpn_channels=160,
409
+ fpn_cell_repeats=6,
410
+ box_class_repeats=4,
411
+ backbone_args=dict(drop_path_rate=0.2),
412
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d3_47-0b525f35.pth',
413
+ ),
414
+ tf_efficientdet_d4=dict(
415
+ name='tf_efficientdet_d4',
416
+ backbone_name='tf_efficientnet_b4',
417
+ image_size=(1024, 1024),
418
+ fpn_channels=224,
419
+ fpn_cell_repeats=7,
420
+ box_class_repeats=4,
421
+ backbone_args=dict(drop_path_rate=0.2),
422
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d4_49-f56376d9.pth',
423
+ ),
424
+ tf_efficientdet_d5=dict(
425
+ name='tf_efficientdet_d5',
426
+ backbone_name='tf_efficientnet_b5',
427
+ image_size=(1280, 1280),
428
+ fpn_channels=288,
429
+ fpn_cell_repeats=7,
430
+ box_class_repeats=4,
431
+ backbone_args=dict(drop_path_rate=0.2),
432
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d5_51-c79f9be6.pth',
433
+ ),
434
+ tf_efficientdet_d6=dict(
435
+ name='tf_efficientdet_d6',
436
+ backbone_name='tf_efficientnet_b6',
437
+ image_size=(1280, 1280),
438
+ fpn_channels=384,
439
+ fpn_cell_repeats=8,
440
+ box_class_repeats=5,
441
+ fpn_name='bifpn_sum', # Use unweighted sum for training stability.
442
+ backbone_args=dict(drop_path_rate=0.2),
443
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d6_52-4eda3773.pth'
444
+ ),
445
+ tf_efficientdet_d7=dict(
446
+ name='tf_efficientdet_d7',
447
+ backbone_name='tf_efficientnet_b6',
448
+ image_size=(1536, 1536),
449
+ fpn_channels=384,
450
+ fpn_cell_repeats=8,
451
+ box_class_repeats=5,
452
+ anchor_scale=5.0,
453
+ fpn_name='bifpn_sum', # Use unweighted sum for training stability.
454
+ backbone_args=dict(drop_path_rate=0.2),
455
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d7_53-6d1d7a95.pth'
456
+ ),
457
+ tf_efficientdet_d7x=dict(
458
+ name='tf_efficientdet_d7x',
459
+ backbone_name='tf_efficientnet_b7',
460
+ image_size=(1536, 1536),
461
+ fpn_channels=384,
462
+ fpn_cell_repeats=8,
463
+ box_class_repeats=5,
464
+ anchor_scale=4.0,
465
+ max_level=8,
466
+ fpn_name='bifpn_sum', # Use unweighted sum for training stability.
467
+ backbone_args=dict(drop_path_rate=0.2),
468
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d7x-f390b87c.pth'
469
+ ),
470
+
471
+ # The lite configs are in TF automl repository but no weights yet and listed as 'not final'
472
+ tf_efficientdet_lite0=dict(
473
+ name='tf_efficientdet_lite0',
474
+ backbone_name='tf_efficientnet_lite0',
475
+ image_size=(512, 512),
476
+ fpn_channels=64,
477
+ fpn_cell_repeats=3,
478
+ box_class_repeats=3,
479
+ act_type='relu',
480
+ redundant_bias=False,
481
+ backbone_args=dict(drop_path_rate=0.1),
482
+ # unlike other tf_ models, this was not ported from the tf automl impl, but trained from tf-pretrained efficientnet-lite
483
+ # weights using this code; it will likely be replaced if/when official det-lite weights are released
484
+ url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_lite0-f5f303a9.pth',
485
+ ),
486
+ tf_efficientdet_lite1=dict(
487
+ name='tf_efficientdet_lite1',
488
+ backbone_name='tf_efficientnet_lite1',
489
+ image_size=(640, 640),
490
+ fpn_channels=88,
491
+ fpn_cell_repeats=4,
492
+ box_class_repeats=3,
493
+ act_type='relu',
494
+ backbone_args=dict(drop_path_rate=0.2),
495
+ url='', # no pretrained weights yet
496
+ ),
497
+ tf_efficientdet_lite2=dict(
498
+ name='tf_efficientdet_lite2',
499
+ backbone_name='tf_efficientnet_lite2',
500
+ image_size=(768, 768),
501
+ fpn_channels=112,
502
+ fpn_cell_repeats=5,
503
+ box_class_repeats=3,
504
+ act_type='relu',
505
+ backbone_args=dict(drop_path_rate=0.2),
506
+ url='',
507
+ ),
508
+ tf_efficientdet_lite3=dict(
509
+ name='tf_efficientdet_lite3',
510
+ backbone_name='tf_efficientnet_lite3',
511
+ image_size=(896, 896),
512
+ fpn_channels=160,
513
+ fpn_cell_repeats=6,
514
+ box_class_repeats=4,
515
+ act_type='relu',
516
+ backbone_args=dict(drop_path_rate=0.2),
517
+ url='',
518
+ ),
519
+ tf_efficientdet_lite4=dict(
520
+ name='tf_efficientdet_lite4',
521
+ backbone_name='tf_efficientnet_lite4',
522
+ image_size=(1024, 1024),
523
+ fpn_channels=224,
524
+ fpn_cell_repeats=7,
525
+ box_class_repeats=4,
526
+ act_type='relu',
527
+ backbone_args=dict(drop_path_rate=0.2),
528
+ url='',
529
+ ),
530
+ )
531
+
532
+
533
+ def get_efficientdet_config(model_name='tf_efficientdet_d1'):
534
+ """Get the default config for EfficientDet based on model name."""
535
+ h = default_detection_model_configs()
536
+ h.update(efficientdet_model_param_dict[model_name])
537
+ h.num_levels = h.max_level - h.min_level + 1
538
+ return deepcopy(h) # may be unnecessary, ensure no references to param dict values
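
For reference, `get_efficientdet_config` is the entry point for pulling one of the model dictionaries above. A minimal sketch, assuming the `efficientdet/` directory is on the path so that `effdet` is importable (the `num_classes` override is illustrative, not part of this file):

    from effdet.config import get_efficientdet_config

    config = get_efficientdet_config('tf_efficientdet_d2')
    config.num_classes = 7                           # illustrative override, e.g. waste categories
    print(config.image_size, config.fpn_channels)    # 768x768 input, 112 FPN channels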
efficientdet/effdet/config/train_config.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+
3
+
4
+ def default_detection_train_config():
5
+ # FIXME currently using args for train config, will revisit, perhaps move to Hydra
6
+ h = OmegaConf.create()
7
+
8
+ # dataset
9
+ h.skip_crowd_during_training = True
10
+
11
+ # augmentation
12
+ h.input_rand_hflip = True
13
+ h.train_scale_min = 0.1
14
+ h.train_scale_max = 2.0
15
+ h.autoaugment_policy = None
16
+
17
+ # optimization
18
+ h.momentum = 0.9
19
+ h.learning_rate = 0.08
20
+ h.lr_warmup_init = 0.008
21
+ h.lr_warmup_epoch = 1.0
22
+ h.first_lr_drop_epoch = 200.0
23
+ h.second_lr_drop_epoch = 250.0
24
+ h.clip_gradients_norm = 10.0
25
+ h.num_epochs = 300
26
+
27
+ # regularization l2 loss.
28
+ h.weight_decay = 4e-5
29
+
30
+ h.lr_decay_method = 'cosine'
31
+ h.moving_average_decay = 0.9998
32
+ h.ckpt_var_scope = None
33
+
34
+ return h
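
The returned object is an OmegaConf node, so individual hyperparameters can be overridden attribute-style. A rough sketch with illustrative values:

    from effdet.config.train_config import default_detection_train_config

    train_cfg = default_detection_train_config()
    train_cfg.learning_rate = 0.04   # illustrative: halve the 0.08 default for a smaller batch
    train_cfg.num_epochs = 100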
efficientdet/effdet/data/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .dataset_factory import create_dataset
2
+ from .dataset import DetectionDatset, SkipSubset
3
+ from .input_config import resolve_input_config
4
+ from .loader import create_loader
5
+ from .parsers import create_parser
6
+ from .transforms import *
efficientdet/effdet/data/dataset.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Detection dataset
2
+
3
+ Hacked together by Ross Wightman
4
+ """
5
+ import torch.utils.data as data
6
+ import numpy as np
7
+ import albumentations as A
8
+ import torch
9
+
10
+ from PIL import Image
11
+ from .parsers import create_parser
12
+
13
+
14
+ class DetectionDatset(data.Dataset):
15
+ """`Object Detection Dataset. Use with parsers for COCO, VOC, and OpenImages.
16
+ Args:
17
+ parser (string, Parser): a parser name (e.g. 'coco') or an already constructed Parser instance
18
+ transform (callable, optional): A function/transform that takes in a PIL image
19
+ and returns a transformed version. E.g., ``transforms.ToTensor``
20
+
21
+ """
22
+
23
+ def __init__(self, data_dir, parser=None, parser_kwargs=None, transform=None, transforms=None):
24
+ super(DetectionDatset, self).__init__()
25
+ parser_kwargs = parser_kwargs or {}
26
+ self.data_dir = data_dir
27
+ if isinstance(parser, str):
28
+ self._parser = create_parser(parser, **parser_kwargs)
29
+ else:
30
+ assert parser is not None and len(parser.img_ids)
31
+ self._parser = parser
32
+ self._transform = transform
33
+ self._transforms = transforms
34
+
35
+ def __getitem__(self, index):
36
+ """
37
+ Args:
38
+ index (int): Index
39
+ Returns:
40
+ tuple: Tuple (image, annotations (target)).
41
+ """
42
+ img_info = self._parser.img_infos[index]
43
+ target = dict(img_idx=index, img_size=(img_info['width'], img_info['height']))
44
+ if self._parser.has_labels:
45
+ ann = self._parser.get_ann_info(index)
46
+ target.update(ann)
47
+ img_path = self.data_dir / img_info['file_name']
48
+ img = Image.open(img_path).convert('RGB')
49
+ if self.transforms is not None:
50
+ img = torch.as_tensor(np.array(img), dtype=torch.uint8)
51
+ voc_boxes = []
52
+ for coord in target['bbox']:
53
+ xmin = coord[1]
54
+ ymin = coord[0]
55
+ xmax = coord[3]
56
+ ymax = coord[2]
57
+ if xmin<1:
58
+ xmin = 1
59
+ if ymin<1:
60
+ ymin = 1
61
+ if xmax>=img.shape[1]-1:
62
+ xmax = img.shape[1]-1
63
+ if ymax>=img.shape[0]-1:
64
+ ymax = img.shape[0]-1
65
+ voc_boxes.append([xmin, ymin, xmax, ymax])
66
+ transformed = self.transforms(image=np.array(img), bbox_classes=target['cls'], bboxes=voc_boxes)
67
+ img = torch.as_tensor(transformed['image'], dtype=torch.uint8)
68
+ target['bbox'] = []
69
+ for coord in transformed['bboxes']:
70
+ ymin = int(coord[1])
71
+ xmin = int(coord[0])
72
+ ymax = int(coord[3])
73
+ xmax = int(coord[2])
74
+ target['bbox'].append([ymin, xmin, ymax, xmax])
75
+ target['bbox'] = np.array(target['bbox'], dtype=np.float32)
76
+ target['cls'] = np.array(transformed['bbox_classes'])
77
+ img = Image.fromarray(np.array(img).astype('uint8'), 'RGB')
78
+ target['img_size'] = img.size
79
+
80
+ if self.transform is not None:
81
+ img, target = self.transform(img, target)
82
+
83
+ return img, target
84
+
85
+ def __len__(self):
86
+ return len(self._parser.img_ids)
87
+
88
+ @property
89
+ def parser(self):
90
+ return self._parser
91
+
92
+ @property
93
+ def transform(self):
94
+ return self._transform
95
+
96
+ @transform.setter
97
+ def transform(self, t):
98
+ self._transform = t
99
+
100
+ @property
101
+ def transforms(self):
102
+ return self._transforms
103
+
104
+ @transforms.setter
105
+ def transforms(self, t):
106
+ self._transforms = t
107
+
108
+ class SkipSubset(data.Dataset):
109
+ r"""
110
+ Subset of a dataset that keeps every nth sample.
111
+
112
+ Arguments:
113
+ dataset (Dataset): The whole Dataset
114
+ n (int): skip rate (select every nth)
115
+ """
116
+ def __init__(self, dataset, n=2):
117
+ self.dataset = dataset
118
+ assert n >= 1
119
+ self.indices = np.arange(len(dataset))[::n]
120
+
121
+ def __getitem__(self, idx):
122
+ return self.dataset[self.indices[idx]]
123
+
124
+ def __len__(self):
125
+ return len(self.indices)
126
+
127
+ @property
128
+ def parser(self):
129
+ return self.dataset.parser
130
+
131
+ @property
132
+ def transform(self):
133
+ return self.dataset.transform
134
+
135
+ @transform.setter
136
+ def transform(self, t):
137
+ self.dataset.transform = t
138
+
139
+ @property
140
+ def transforms(self):
141
+ return self.dataset.transforms
142
+
143
+ @transforms.setter
144
+ def transforms(self, t):
145
+ self.dataset.transforms = t
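
A rough usage sketch for the dataset wrapper, assuming a COCO-style annotation file (paths below are placeholders): the dataset is built around a parser, and `SkipSubset` can thin it for a quick pass.

    from pathlib import Path
    from effdet.data.dataset import DetectionDatset, SkipSubset
    from effdet.data.parsers import CocoParserCfg, create_parser

    parser = create_parser('coco', cfg=CocoParserCfg(ann_filename='annotations/train.json'))
    dataset = DetectionDatset(data_dir=Path('images'), parser=parser)
    subset = SkipSubset(dataset, n=5)   # keep every 5th sample
    img, target = subset[0]             # PIL image and dict with 'bbox' / 'cls' when no transform is set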
efficientdet/effdet/data/dataset_config.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ COCO detect-waste dataset configurations
2
+
3
+ Updated 2021 Wimlds in Detect Waste in Pomerania
4
+ """
5
+ from dataclasses import dataclass
6
+ from typing import Dict
7
+
8
+
9
+ @dataclass
10
+ class CocoCfg:
11
+ variant: str = None
12
+ parser: str = 'coco'
13
+ num_classes: int = 80
14
+ splits: Dict[str, dict] = None
15
+
16
+
17
+ @dataclass
18
+ class TACOCfg(CocoCfg):
19
+ root: str = ""
20
+ ann: str = ""
21
+ variant: str = '2017'
22
+ num_classes: int = 28
23
+
24
+ def add_split(self):
25
+ self.splits = {
26
+ 'train': {'ann_filename': self.ann+'_train.json',
27
+ 'img_dir': self.root,
28
+ 'has_labels': True},
29
+ 'val': {'ann_filename': self.ann+'_test.json',
30
+ 'img_dir': self.root,
31
+ 'has_labels': True}
32
+ }
33
+
34
+
35
+ @dataclass
36
+ class DetectwasteCfg(CocoCfg):
37
+ root: str = ""
38
+ ann: str = ""
39
+ variant: str = '2017'
40
+ num_classes: int = 7
41
+
42
+ def add_split(self):
43
+ self.splits = {
44
+ 'train': {'ann_filename': self.ann+'_train.json',
45
+ 'img_dir': self.root,
46
+ 'has_labels': True},
47
+ 'val': {'ann_filename': self.ann+'_test.json',
48
+ 'img_dir': self.root,
49
+ 'has_labels': True}
50
+ }
51
+
52
+
53
+ @dataclass
54
+ class BinaryCfg(CocoCfg):
55
+ root: str = ""
56
+ ann: str = ""
57
+ variant: str = '2017'
58
+ num_classes: int = 1
59
+
60
+ def add_split(self):
61
+ self.splits = {
62
+ 'train': {'ann_filename': self.ann+'_train.json',
63
+ 'img_dir': self.root,
64
+ 'has_labels': True},
65
+ 'val': {'ann_filename': self.ann+'_test.json',
66
+ 'img_dir': self.root,
67
+ 'has_labels': True}
68
+ }
69
+
70
+
71
+ @dataclass
72
+ class BinaryMultiCfg(CocoCfg):
73
+ root: str = ""
74
+ ann: str = ""
75
+ variant: str = '2017'
76
+ num_classes: int = 1
77
+
78
+ def add_split(self):
79
+ self.splits = {
80
+ 'train': {'ann_filename': self.ann+'_train.json',
81
+ 'img_dir': self.root,
82
+ 'has_labels': True},
83
+ 'val': {'ann_filename': self.ann+'_test.json',
84
+ 'img_dir': self.root,
85
+ 'has_labels': True}
86
+ }
87
+
88
+
89
+ @dataclass
90
+ class TrashCanCfg(CocoCfg):
91
+ root: str = ""
92
+ ann: str = ""
93
+ variant: str = '2017'
94
+ num_classes: int = 8
95
+
96
+ def add_split(self):
97
+ self.splits = {
98
+ 'train': {'ann_filename': self.ann+'_train.json',
99
+ 'img_dir': self.root,
100
+ 'has_labels': True},
101
+ 'val': {'ann_filename': self.ann+'_test.json',
102
+ 'img_dir': self.root,
103
+ 'has_labels': True}
104
+ }
105
+
106
+
107
+ @dataclass
108
+ class UAVVasteCfg(CocoCfg):
109
+ root: str = ""
110
+ ann: str = ""
111
+ variant: str = '2017'
112
+ num_classes: int = 1
113
+
114
+ def add_split(self):
115
+ self.splits = {
116
+ 'train': {'ann_filename': self.ann+'_train.json',
117
+ 'img_dir': self.root,
118
+ 'has_labels': True},
119
+ 'val': {'ann_filename': self.ann+'_test.json',
120
+ 'img_dir': self.root,
121
+ 'has_labels': True}
122
+ }
123
+
124
+
125
+ @dataclass
126
+ class ICRACfg(CocoCfg):
127
+ root: str = ""
128
+ ann: str = ""
129
+ variant: str = '2017'
130
+ num_classes: int = 7
131
+
132
+ def add_split(self):
133
+ self.splits = {
134
+ 'train': {'ann_filename': self.ann+'_train.json',
135
+ 'img_dir': self.root,
136
+ 'has_labels': True},
137
+ 'val': {'ann_filename': self.ann+'_test.json',
138
+ 'img_dir': self.root,
139
+ 'has_labels': True}
140
+ }
141
+
142
+
143
+ @dataclass
144
+ class DrinkWasteCfg(CocoCfg):
145
+ root: str = ""
146
+ ann: str = ""
147
+ variant: str = '2017'
148
+ num_classes: int = 4
149
+
150
+ def add_split(self):
151
+ self.splits = {
152
+ 'train': {'ann_filename': self.ann+'_train.json',
153
+ 'img_dir': self.root,
154
+ 'has_labels': True},
155
+ 'val': {'ann_filename': self.ann+'_test.json',
156
+ 'img_dir': self.root,
157
+ 'has_labels': True}
158
+ }
159
+
160
+
161
+ @dataclass
162
+ class MJU_WasteCfg(CocoCfg):
163
+ root: str = ""
164
+ ann: str = ""
165
+ variant: str = '2017'
166
+ num_classes: int = 1
167
+
168
+ def add_split(self):
169
+ self.splits = {
170
+ 'train': {'ann_filename': self.ann+'_train.json',
171
+ 'img_dir': self.root,
172
+ 'has_labels': True},
173
+ 'val': {'ann_filename': self.ann+'_test.json',
174
+ 'img_dir': self.root,
175
+ 'has_labels': True}
176
+ }
177
+
178
+
179
+ @dataclass
180
+ class WadeCfg(CocoCfg):
181
+ root: str = ""
182
+ ann: str = ""
183
+ variant: str = '2017'
184
+ num_classes: int = 1
185
+
186
+ def add_split(self):
187
+ self.splits = {
188
+ 'train': {'ann_filename': self.ann+'_train.json',
189
+ 'img_dir': self.root,
190
+ 'has_labels': True},
191
+ 'val': {'ann_filename': self.ann+'_test.json',
192
+ 'img_dir': self.root,
193
+ 'has_labels': True}
194
+ }
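
Every config above follows the same pattern: `root` is the shared image directory and the `ann` prefix plus `_train.json` / `_test.json` gives the split annotation files. A minimal sketch with placeholder paths:

    cfg = TACOCfg(root='data/taco/images', ann='data/taco/annotations/taco')
    cfg.add_split()
    # cfg.splits['train']['ann_filename'] -> 'data/taco/annotations/taco_train.json'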
efficientdet/effdet/data/dataset_factory.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Dataset factory
2
+
3
+ Updated 2021 Wimlds in Detect Waste in Pomerania
4
+ """
5
+ from collections import OrderedDict
6
+ from pathlib import Path
7
+
8
+ from .dataset_config import *
9
+ from .parsers import *
10
+ from .dataset import DetectionDatset
11
+ from .parsers import create_parser
12
+
13
+ # list of detect-waste datasets
14
+ waste_datasets_list = ['taco', 'detectwaste', 'binary', 'multi',
15
+ 'uav', 'mju', 'trashcan', 'wade', 'icra',
16
+ 'drinkwaste']
17
+
18
+
19
+ def create_dataset(name, root, ann, splits=('train', 'val')):
20
+ if isinstance(splits, str):
21
+ splits = (splits,)
22
+ name = name.lower()
23
+ root = Path(root)
24
+ dataset_cls = DetectionDatset
25
+ datasets = OrderedDict()
26
+ if name.startswith('coco'):
27
+ if 'coco2014' in name:
28
+ dataset_cfg = Coco2014Cfg()
29
+ else:
30
+ dataset_cfg = Coco2017Cfg()
31
+ for s in splits:
32
+ if s not in dataset_cfg.splits:
33
+ raise RuntimeError(f'{s} split not found in config')
34
+ split_cfg = dataset_cfg.splits[s]
35
+ ann_file = root / split_cfg['ann_filename']
36
+ parser_cfg = CocoParserCfg(
37
+ ann_filename=ann_file,
38
+ has_labels=split_cfg['has_labels']
39
+ )
40
+ datasets[s] = dataset_cls(
41
+ data_dir=root / Path(split_cfg['img_dir']),
42
+ parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
43
+ )
44
+ datasets = OrderedDict()
45
+ elif name in waste_datasets_list:
46
+ if name.startswith('taco'):
47
+ dataset_cfg = TACOCfg(root=root, ann=ann)
48
+ elif name.startswith('detectwaste'):
49
+ dataset_cfg = DetectwasteCfg(root=root, ann=ann)
50
+ elif name.startswith('binary'):
51
+ dataset_cfg = BinaryCfg(root=root, ann=ann)
52
+ elif name.startswith('multi'):
53
+ dataset_cfg = BinaryMultiCfg(root=root, ann=ann)
54
+ elif name.startswith('uav'):
55
+ dataset_cfg = UAVVasteCfg(root=root, ann=ann)
56
+ elif name.startswith('trashcan'):
57
+ dataset_cfg = TrashCanCfg(root=root, ann=ann)
58
+ elif name.startswith('drinkwaste'):
59
+ dataset_cfg = DrinkWasteCfg(root=root, ann=ann)
60
+ elif name.startswith('mju'):
61
+ dataset_cfg = MJU_WasteCfg(root=root, ann=ann)
62
+ elif name.startswith('wade'):
63
+ dataset_cfg = WadeCfg(root=root, ann=ann)
64
+ elif name.startswith('icra'):
65
+ dataset_cfg = ICRACfg(root=root, ann=ann)
66
+ else:
67
+ assert False, f'Unknown dataset parser ({name})'
68
+ dataset_cfg.add_split()
69
+ for s in splits:
70
+ if s not in dataset_cfg.splits:
71
+ raise RuntimeError(f'{s} split not found in config')
72
+ split_cfg = dataset_cfg.splits[s]
73
+ parser_cfg = CocoParserCfg(
74
+ ann_filename=split_cfg['ann_filename'],
75
+ has_labels=split_cfg['has_labels']
76
+ )
77
+ datasets[s] = dataset_cls(
78
+ data_dir=split_cfg['img_dir'],
79
+ parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
80
+ )
81
+ else:
82
+ assert False, f'Unknown dataset parser ({name})'
83
+
84
+ datasets = list(datasets.values())
85
+ return datasets if len(datasets) > 1 else datasets[0]
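
A hedged usage sketch for the factory (the detect-waste layout and paths are assumptions):

    from effdet.data import create_dataset

    train_ds, val_ds = create_dataset(
        'taco',
        root='data/taco/images',
        ann='data/taco/annotations/taco',
        splits=('train', 'val'))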
efficientdet/effdet/data/input_config.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2
+
3
+
4
+ def resolve_input_config(args, model_config=None, model=None):
5
+ if not isinstance(args, dict):
6
+ args = vars(args)
7
+ input_config = {}
8
+ if not model_config and model is not None and hasattr(model, 'config'):
9
+ model_config = model.config
10
+
11
+ # Resolve input/image size
12
+ in_chans = 3
13
+ input_size = (in_chans, 512, 512)
14
+
15
+ if 'input_size' in model_config:
16
+ input_size = tuple(model_config['input_size'])
17
+ elif 'image_size' in model_config:
18
+ input_size = (in_chans,) + tuple(model_config['image_size'])
19
+ assert isinstance(input_size, tuple) and len(input_size) == 3
20
+ input_config['input_size'] = input_size
21
+
22
+ # resolve interpolation method
23
+ input_config['interpolation'] = 'bicubic'
24
+ if 'interpolation' in args and args['interpolation']:
25
+ input_config['interpolation'] = args['interpolation']
26
+ elif 'interpolation' in model_config:
27
+ input_config['interpolation'] = model_config['interpolation']
28
+
29
+ # resolve dataset + model mean for normalization
30
+ input_config['mean'] = IMAGENET_DEFAULT_MEAN
31
+ if 'mean' in args and args['mean'] is not None:
32
+ mean = tuple(args['mean'])
33
+ if len(mean) == 1:
34
+ mean = tuple(list(mean) * in_chans)
35
+ else:
36
+ assert len(mean) == in_chans
37
+ input_config['mean'] = mean
38
+ elif 'mean' in model_config:
39
+ input_config['mean'] = model_config['mean']
40
+
41
+ # resolve dataset + model std deviation for normalization
42
+ input_config['std'] = IMAGENET_DEFAULT_STD
43
+ if 'std' in args and args['std'] is not None:
44
+ std = tuple(args['std'])
45
+ if len(std) == 1:
46
+ std = tuple(list(std) * in_chans)
47
+ else:
48
+ assert len(std) == in_chans
49
+ input_config['std'] = std
50
+ elif 'std' in model_config:
51
+ input_config['std'] = model_config['std']
52
+
53
+ # resolve letterbox fill color
54
+ input_config['fill_color'] = 'mean'
55
+ if 'fill_color' in args and args['fill_color'] is not None:
56
+ input_config['fill_color'] = args['fill_color']
57
+ elif 'fill_color' in model_config:
58
+ input_config['fill_color'] = model_config['fill_color']
59
+
60
+ return input_config
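
A small sketch of how the resolver falls back to the model config and the ImageNet defaults (model name taken from the configs above):

    from effdet.config import get_efficientdet_config
    from effdet.data import resolve_input_config

    model_config = get_efficientdet_config('tf_efficientdet_d2')
    input_config = resolve_input_config({}, model_config=model_config)
    # input_config['input_size'] -> (3, 768, 768); mean/std fall back to the ImageNet constants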
efficientdet/effdet/data/loader.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Object detection loader/collate
2
+
3
+ Hacked together by / Copyright 2020 Ross Wightman
4
+ """
5
+ import torch.utils.data
6
+ from .transforms import *
7
+ from .transforms_albumentation import get_transform
8
+ from .random_erasing import RandomErasing
9
+ from effdet.anchors import AnchorLabeler
10
+ from timm.data.distributed_sampler import OrderedDistributedSampler
11
+ import os
12
+
13
+ MAX_NUM_INSTANCES = 100
14
+
15
+
16
+ class DetectionFastCollate:
17
+ """ A detection specific, optimized collate function w/ a bit of state.
18
+
19
+ Optionally performs anchor labelling. Doing this here offloads some work from the
20
+ GPU and the main training process thread and increases the load on the dataloader
21
+ threads.
22
+
23
+ """
24
+ def __init__(
25
+ self,
26
+ instance_keys=None,
27
+ instance_shapes=None,
28
+ instance_fill=-1,
29
+ max_instances=MAX_NUM_INSTANCES,
30
+ anchor_labeler=None,
31
+ ):
32
+ instance_keys = instance_keys or {'bbox', 'bbox_ignore', 'cls'}
33
+ instance_shapes = instance_shapes or dict(
34
+ bbox=(max_instances, 4), bbox_ignore=(max_instances, 4), cls=(max_instances,))
35
+ self.instance_info = {k: dict(fill=instance_fill, shape=instance_shapes[k]) for k in instance_keys}
36
+ self.max_instances = max_instances
37
+ self.anchor_labeler = anchor_labeler
38
+
39
+ def __call__(self, batch):
40
+ batch_size = len(batch)
41
+ target = dict()
42
+ labeler_outputs = dict()
43
+ img_tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
44
+ for i in range(batch_size):
45
+ img_tensor[i] += torch.from_numpy(batch[i][0])
46
+ labeler_inputs = {}
47
+ for tk, tv in batch[i][1].items():
48
+ instance_info = self.instance_info.get(tk, None)
49
+ if instance_info is not None:
50
+ # target tensor is associated with a detection instance
51
+ tv = torch.from_numpy(tv).to(dtype=torch.float32)
52
+ if self.anchor_labeler is None:
53
+ if i == 0:
54
+ shape = (batch_size,) + instance_info['shape']
55
+ target_tensor = torch.full(shape, instance_info['fill'], dtype=torch.float32)
56
+ target[tk] = target_tensor
57
+ else:
58
+ target_tensor = target[tk]
59
+ num_elem = min(tv.shape[0], self.max_instances)
60
+ target_tensor[i, 0:num_elem] = tv[0:num_elem]
61
+ else:
62
+ # no need to pass gt tensors through when labeler in use
63
+ if tk in ('bbox', 'cls'):
64
+ labeler_inputs[tk] = tv
65
+ else:
66
+ # target tensor is an image-level annotation / metadata
67
+ if i == 0:
68
+ # first batch elem, create destination tensors
69
+ if isinstance(tv, (tuple, list)):
70
+ # per batch elem sequence
71
+ shape = (batch_size, len(tv))
72
+ dtype = torch.float32 if isinstance(tv[0], (float, np.floating)) else torch.int32
73
+ else:
74
+ # per batch elem scalar
75
+ shape = batch_size,
76
+ dtype = torch.float32 if isinstance(tv, (float, np.floating)) else torch.int64
77
+ target_tensor = torch.zeros(shape, dtype=dtype)
78
+ target[tk] = target_tensor
79
+ else:
80
+ target_tensor = target[tk]
81
+ target_tensor[i] = torch.tensor(tv, dtype=target_tensor.dtype)
82
+
83
+ if self.anchor_labeler is not None:
84
+ cls_targets, box_targets, num_positives = self.anchor_labeler.label_anchors(
85
+ labeler_inputs['bbox'], labeler_inputs['cls'], filter_valid=False)
86
+ if i == 0:
87
+ # first batch elem, create destination tensors, separate key per level
88
+ for j, (ct, bt) in enumerate(zip(cls_targets, box_targets)):
89
+ labeler_outputs[f'label_cls_{j}'] = torch.zeros(
90
+ (batch_size,) + ct.shape, dtype=torch.int64)
91
+ labeler_outputs[f'label_bbox_{j}'] = torch.zeros(
92
+ (batch_size,) + bt.shape, dtype=torch.float32)
93
+ labeler_outputs['label_num_positives'] = torch.zeros(batch_size)
94
+ for j, (ct, bt) in enumerate(zip(cls_targets, box_targets)):
95
+ labeler_outputs[f'label_cls_{j}'][i] = ct
96
+ labeler_outputs[f'label_bbox_{j}'][i] = bt
97
+ labeler_outputs['label_num_positives'][i] = num_positives
98
+ if labeler_outputs:
99
+ target.update(labeler_outputs)
100
+
101
+ return img_tensor, target
102
+
103
+
104
+ class PrefetchLoader:
105
+
106
+ def __init__(self,
107
+ loader,
108
+ mean=IMAGENET_DEFAULT_MEAN,
109
+ std=IMAGENET_DEFAULT_STD,
110
+ re_prob=0.,
111
+ re_mode='pixel',
112
+ re_count=1,
113
+ ):
114
+ self.loader = loader
115
+ self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
116
+ self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
117
+ if re_prob > 0.:
118
+ self.random_erasing = RandomErasing(probability=re_prob, mode=re_mode, max_count=re_count)
119
+ else:
120
+ self.random_erasing = None
121
+
122
+ def __iter__(self):
123
+ stream = torch.cuda.Stream()
124
+ first = True
125
+
126
+ for next_input, next_target in self.loader:
127
+ with torch.cuda.stream(stream):
128
+ next_input = next_input.cuda(non_blocking=True)
129
+ next_input = next_input.float().sub_(self.mean).div_(self.std)
130
+ next_target = {k: v.cuda(non_blocking=True) for k, v in next_target.items()}
131
+ if self.random_erasing is not None:
132
+ next_input = self.random_erasing(next_input, next_target)
133
+
134
+ if not first:
135
+ yield input, target
136
+ else:
137
+ first = False
138
+
139
+ torch.cuda.current_stream().wait_stream(stream)
140
+ input = next_input
141
+ target = next_target
142
+
143
+ yield input, target
144
+
145
+ def __len__(self):
146
+ return len(self.loader)
147
+
148
+ @property
149
+ def sampler(self):
150
+ return self.loader.sampler
151
+
152
+ @property
153
+ def dataset(self):
154
+ return self.loader.dataset
155
+
156
+
157
+ def create_loader(
158
+ dataset,
159
+ input_size,
160
+ batch_size,
161
+ is_training=False,
162
+ use_prefetcher=True,
163
+ re_prob=0.,
164
+ re_mode='pixel',
165
+ re_count=1,
166
+ interpolation='bilinear',
167
+ fill_color='mean',
168
+ mean=IMAGENET_DEFAULT_MEAN,
169
+ std=IMAGENET_DEFAULT_STD,
170
+ num_workers=1,
171
+ distributed=False,
172
+ pin_mem=False,
173
+ anchor_labeler=None,
174
+ ):
175
+ if isinstance(input_size, tuple):
176
+ img_size = input_size[-2:]
177
+ else:
178
+ img_size = input_size
179
+
180
+ if is_training:
181
+ transforms = get_transform()
182
+ transform = transforms_coco_train(
183
+ img_size,
184
+ interpolation=interpolation,
185
+ use_prefetcher=use_prefetcher,
186
+ fill_color=fill_color,
187
+ mean=mean,
188
+ std=std)
189
+ else:
190
+ transforms = None
191
+ transform = transforms_coco_eval(
192
+ img_size,
193
+ interpolation=interpolation,
194
+ use_prefetcher=use_prefetcher,
195
+ fill_color=fill_color,
196
+ mean=mean,
197
+ std=std)
198
+ dataset.transforms = transforms
199
+ dataset.transform = transform
200
+
201
+ sampler = None
202
+ if distributed:
203
+ if is_training:
204
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset)
205
+ else:
206
+ # This will add extra duplicate entries to result in equal num
207
+ # of samples per-process, will slightly alter validation results
208
+ sampler = OrderedDistributedSampler(dataset)
209
+
210
+ collate_fn = DetectionFastCollate(anchor_labeler=anchor_labeler)
211
+ loader = torch.utils.data.DataLoader(
212
+ dataset,
213
+ batch_size=batch_size,
214
+ shuffle=sampler is None and is_training,
215
+ num_workers=num_workers,
216
+ sampler=sampler,
217
+ pin_memory=pin_mem,
218
+ collate_fn=collate_fn,
219
+ )
220
+ if use_prefetcher:
221
+ if is_training:
222
+ loader = PrefetchLoader(loader, mean=mean, std=std, re_prob=re_prob, re_mode=re_mode, re_count=re_count)
223
+ else:
224
+ loader = PrefetchLoader(loader, mean=mean, std=std)
225
+
226
+ return loader
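
A rough end-to-end sketch wiring a dataset into the loader (paths and sizes are assumptions; with the default prefetcher, normalization runs on GPU, so a CUDA device is assumed):

    from effdet.data import create_dataset, create_loader

    train_ds, _ = create_dataset('taco', root='data/taco/images',
                                 ann='data/taco/annotations/taco')
    train_loader = create_loader(
        train_ds, input_size=(3, 768, 768), batch_size=8,
        is_training=True, num_workers=4)
    images, targets = next(iter(train_loader))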
efficientdet/effdet/data/parsers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .parser_config import OpenImagesParserCfg, CocoParserCfg, VocParserCfg
2
+ from .parser_factory import create_parser
efficientdet/effdet/data/parsers/parser.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numbers import Integral
2
+ from typing import List, Union, Dict, Any
3
+
4
+
5
+ class Parser:
6
+ """ Parser base class.
7
+
8
+ The attributes listed below make up a public interface common to all parsers. They can be accessed directly
9
+ once the dataset is constructed and annotations are populated.
10
+
11
+ Attributes:
12
+
13
+ cat_names (list[str]):
14
+ list of category (class) names, with background class at position 0.
15
+ cat_ids (list[Union[str, int]]):
16
+ list of dataset specific, unique integer or string category ids, does not include background
17
+ cat_id_to_label (dict):
18
+ map from category id to integer 1-indexed class label
19
+
20
+ img_ids (list):
21
+ list of dataset specific, unique image ids corresponding to valid samples in dataset
22
+ img_ids_invalid (list):
23
+ list of image ids corresponding to invalid images, not used as samples
24
+ img_infos (list[dict]):
25
+ image info, list of info dicts with filename, width, height for each image sample
26
+ """
27
+ def __init__(
28
+ self,
29
+ bbox_yxyx: bool = False,
30
+ has_labels: bool = True,
31
+ include_masks: bool = False,
32
+ include_bboxes_ignore: bool = False,
33
+ ignore_empty_gt: bool = False,
34
+ min_img_size: int = 32,
35
+ ):
36
+ """
37
+ Args:
38
+ bbox_yxyx (bool): output coords in yxyx format, otherwise xyxy
39
+ has_labels (bool): dataset has labels (for training/validation; usually False for test sets)
40
+ include_masks (bool): include segmentation masks in target output (not supported yet for any dataset)
41
+ include_bboxes_ignore (bool): include ignored bbox in target output
42
+ ignore_empty_gt (bool): ignore images with no ground truth (no negative images)
43
+ min_img_size (int): ignore images with width or height smaller than this number
44
+ sub_sample (int): sample every N images from the dataset
45
+ """
46
+ # parser config, determines how dataset parsed and validated
47
+ self.yxyx = bbox_yxyx
48
+ self.has_labels = has_labels
49
+ self.include_masks = include_masks
50
+ self.include_bboxes_ignore = include_bboxes_ignore
51
+ self.ignore_empty_gt = ignore_empty_gt
52
+ self.min_img_size = min_img_size
53
+ self.label_offset = 1
54
+
55
+ # Category (class) metadata. Populated by _load_annotations()
56
+ self.cat_names: List[str] = []
57
+ self.cat_ids: List[Union[str, Integral]] = []
58
+ self.cat_id_to_label: Dict[Union[str, Integral], Integral] = dict()
59
+
60
+ # Image metadata. Populated by _load_annotations()
61
+ self.img_ids: List[Union[str, Integral]] = []
62
+ self.img_ids_invalid: List[Union[str, Integral]] = []
63
+ self.img_infos: List[Dict[str, Any]] = []
64
+
65
+ @property
66
+ def cat_dicts(self):
67
+ """return category names and labels in format compatible with TF Models Evaluator
68
+ list[dict(name=<class name>, id=<class label>)]
69
+ """
70
+ return [
71
+ dict(
72
+ name=name,
73
+ id=cat_id if not self.cat_id_to_label else self.cat_id_to_label[cat_id]
74
+ ) for name, cat_id in zip(self.cat_names, self.cat_ids)]
75
+
76
+ @property
77
+ def max_label(self):
78
+ if self.cat_id_to_label:
79
+ return max(self.cat_id_to_label.values())
80
+ else:
81
+ assert len(self.cat_ids) and isinstance(self.cat_ids[0], Integral)
82
+ return max(self.cat_ids)
efficientdet/effdet/data/parsers/parser_coco.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ COCO dataset parser
2
+
3
+ Copyright 2020 Ross Wightman
4
+ """
5
+ import numpy as np
6
+ from pycocotools.coco import COCO
7
+ from .parser import Parser
8
+ from .parser_config import CocoParserCfg
9
+
10
+
11
+ class CocoParser(Parser):
12
+
13
+ def __init__(self, cfg: CocoParserCfg):
14
+ super().__init__(
15
+ bbox_yxyx=cfg.bbox_yxyx,
16
+ has_labels=cfg.has_labels,
17
+ include_masks=cfg.include_masks,
18
+ include_bboxes_ignore=cfg.include_bboxes_ignore,
19
+ ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt,
20
+ min_img_size=cfg.min_img_size
21
+ )
22
+ self.cat_ids_as_labels = True # this is the default for original TF EfficientDet models
23
+ self.coco = None
24
+ self._load_annotations(cfg.ann_filename)
25
+
26
+ def get_ann_info(self, idx):
27
+ img_id = self.img_ids[idx]
28
+ return self._parse_img_ann(img_id)
29
+
30
+ def _load_annotations(self, ann_file):
31
+ assert self.coco is None
32
+ self.coco = COCO(ann_file)
33
+ self.cat_ids = self.coco.getCatIds()
34
+ self.cat_names = [c['name'] for c in self.coco.loadCats(ids=self.cat_ids)]
35
+ if not self.cat_ids_as_labels:
36
+ self.cat_id_to_label = {cat_id: i + self.label_offset for i, cat_id in enumerate(self.cat_ids)}
37
+ img_ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
38
+ for img_id in sorted(self.coco.imgs.keys()):
39
+ info = self.coco.loadImgs([img_id])[0]
40
+ if (min(info['width'], info['height']) < self.min_img_size or
41
+ (self.ignore_empty_gt and img_id not in img_ids_with_ann)):
42
+ self.img_ids_invalid.append(img_id)
43
+ continue
44
+ self.img_ids.append(img_id)
45
+ self.img_infos.append(info)
46
+
47
+ def _parse_img_ann(self, img_id):
48
+ ann_ids = self.coco.getAnnIds(imgIds=[img_id])
49
+ ann_info = self.coco.loadAnns(ann_ids)
50
+ bboxes = []
51
+ bboxes_ignore = []
52
+ cls = []
53
+
54
+ for i, ann in enumerate(ann_info):
55
+ if ann.get('ignore', False):
56
+ continue
57
+ x1, y1, w, h = ann['bbox']
58
+ if self.include_masks and ann['area'] <= 0:
59
+ continue
60
+ if w < 1 or h < 1:
61
+ continue
62
+
63
+ if self.yxyx:
64
+ bbox = [y1, x1, y1 + h, x1 + w]
65
+ else:
66
+ bbox = [x1, y1, x1 + w, y1 + h]
67
+
68
+ if ann.get('iscrowd', False):
69
+ if self.include_bboxes_ignore:
70
+ bboxes_ignore.append(bbox)
71
+ else:
72
+ bboxes.append(bbox)
73
+ cls.append(self.cat_id_to_label[ann['category_id']] if self.cat_id_to_label else ann['category_id'])
74
+
75
+ if bboxes:
76
+ bboxes = np.array(bboxes, ndmin=2, dtype=np.float32)
77
+ cls = np.array(cls, dtype=np.int64)
78
+ else:
79
+ bboxes = np.zeros((0, 4), dtype=np.float32)
80
+ cls = np.array([], dtype=np.int64)
81
+
82
+ if self.include_bboxes_ignore:
83
+ if bboxes_ignore:
84
+ bboxes_ignore = np.array(bboxes_ignore, ndmin=2, dtype=np.float32)
85
+ else:
86
+ bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
87
+
88
+ ann = dict(bbox=bboxes, cls=cls)
89
+
90
+ if self.include_bboxes_ignore:
91
+ ann['bbox_ignore'] = bboxes_ignore
92
+
93
+ return ann
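
A small sketch of using the parser directly (the annotation path is a placeholder); boxes come back in yxyx order because `bbox_yxyx` defaults to True in the config:

    from effdet.data.parsers import CocoParserCfg
    from effdet.data.parsers.parser_coco import CocoParser

    cfg = CocoParserCfg(ann_filename='data/taco/annotations/taco_train.json')
    parser = CocoParser(cfg)
    ann = parser.get_ann_info(0)   # dict(bbox=<N x 4 float32, yxyx>, cls=<N int64>)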
efficientdet/effdet/data/parsers/parser_config.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Dataset parser configs
2
+
3
+ Copyright 2020 Ross Wightman
4
+ """
5
+ from dataclasses import dataclass
6
+
7
+ __all__ = ['CocoParserCfg', 'OpenImagesParserCfg', 'VocParserCfg']
8
+
9
+
10
+ @dataclass
11
+ class CocoParserCfg:
12
+ ann_filename: str # absolute path
13
+ include_masks: bool = False
14
+ include_bboxes_ignore: bool = False
15
+ has_labels: bool = True
16
+ bbox_yxyx: bool = True
17
+ min_img_size: int = 32
18
+ ignore_empty_gt: bool = False
19
+
20
+
21
+ @dataclass
22
+ class VocParserCfg:
23
+ split_filename: str
24
+ ann_filename: str
25
+ img_filename: str = '%s.jpg'
26
+ keep_difficult: bool = True
27
+ classes: list = None
28
+ add_background: bool = True
29
+ has_labels: bool = True
30
+ bbox_yxyx: bool = True
31
+ min_img_size: int = 32
32
+ ignore_empty_gt: bool = False
33
+
34
+
35
+ @dataclass
36
+ class OpenImagesParserCfg:
37
+ categories_filename: str
38
+ img_info_filename: str
39
+ bbox_filename: str
40
+ img_label_filename: str = ''
41
+ masks_filename: str = ''
42
+ img_filename: str = '%s.jpg' # relative to dataset img_dir
43
+ task: str = 'obj'
44
+ prefix_levels: int = 1
45
+ add_background: bool = True
46
+ has_labels: bool = True
47
+ bbox_yxyx: bool = True
48
+ min_img_size: int = 32
49
+ ignore_empty_gt: bool = False
efficientdet/effdet/data/parsers/parser_factory.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Parser factory
2
+
3
+ Copyright 2020 Ross Wightman
4
+ """
5
+ from .parser_coco import CocoParser
6
+ from .parser_voc import VocParser
7
+ from .parser_open_images import OpenImagesParser
8
+
9
+
10
+ def create_parser(name, **kwargs):
11
+ if name == 'coco':
12
+ parser = CocoParser(**kwargs)
13
+ elif name == 'voc':
14
+ parser = VocParser(**kwargs)
15
+ elif name == 'openimages':
16
+ parser = OpenImagesParser(**kwargs)
17
+ else:
18
+ assert False, f'Unknown dataset parser ({name})'
19
+ return parser
efficientdet/effdet/data/parsers/parser_open_images.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ OpenImages dataset parser
2
+
3
+ Copyright 2020 Ross Wightman
4
+ """
5
+ import numpy as np
6
+ import os
7
+ import logging
8
+
9
+ from .parser import Parser
10
+ from .parser_config import OpenImagesParserCfg
11
+
12
+ _logger = logging.getLogger(__name__)
13
+
14
+
15
+ class OpenImagesParser(Parser):
16
+
17
+ def __init__(self, cfg: OpenImagesParserCfg):
18
+ super().__init__(
19
+ bbox_yxyx=cfg.bbox_yxyx,
20
+ has_labels=cfg.has_labels,
21
+ include_masks=False, # FIXME to support someday
22
+ include_bboxes_ignore=False,
23
+ ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt,
24
+ min_img_size=cfg.min_img_size
25
+ )
26
+ self.img_prefix_levels = cfg.prefix_levels
27
+ self.mask_prefix_levels = 1
28
+ self._anns = None # access via get_ann_info()
29
+ self._img_to_ann = None
30
+ self._load_annotations(
31
+ categories_filename=cfg.categories_filename,
32
+ img_info_filename=cfg.img_info_filename,
33
+ img_filename=cfg.img_filename,
34
+ masks_filename=cfg.masks_filename,
35
+ bbox_filename=cfg.bbox_filename
36
+ )
37
+
38
+ def _load_annotations(
39
+ self,
40
+ categories_filename: str,
41
+ img_info_filename: str,
42
+ img_filename: str,
43
+ masks_filename: str,
44
+ bbox_filename: str,
45
+ ):
46
+ import pandas as pd # For now, blow up on pandas req only when trying to load open images anno
47
+
48
+ _logger.info('Loading categories...')
49
+ classes_df = pd.read_csv(categories_filename, header=None)
50
+ self.cat_ids = classes_df[0].tolist()
51
+ self.cat_names = classes_df[1].tolist()
52
+ self.cat_id_to_label = {c: i + self.label_offset for i, c in enumerate(self.cat_ids)}
53
+
54
+ def _img_filename(img_id):
55
+ # build image filenames that are relative to img_dir
56
+ filename = img_filename % img_id
57
+ if self.img_prefix_levels:
58
+ levels = [c for c in img_id[:self.img_prefix_levels]]
59
+ filename = os.path.join(*levels, filename)
60
+ return filename
61
+
62
+ def _mask_filename(mask_path):
63
+ # FIXME finish
64
+ if self.mask_prefix_levels:
65
+ levels = [c for c in mask_path[:self.mask_prefix_levels]]
66
+ mask_path = os.path.join(*levels, mask_path)
67
+ return mask_path
68
+
69
+ def _load_img_info(csv_file, select_img_ids=None):
70
+ _logger.info('Read img_info csv...')
71
+ img_info_df = pd.read_csv(csv_file, index_col='id')
72
+
73
+ _logger.info('Filter images...')
74
+ if select_img_ids is not None:
75
+ img_info_df = img_info_df.loc[select_img_ids]
76
+ img_info_df = img_info_df[
77
+ (img_info_df['width'] >= self.min_img_size) & (img_info_df['height'] >= self.min_img_size)]
78
+
79
+ _logger.info('Mapping ids...')
80
+ img_info_df['img_id'] = img_info_df.index
81
+ img_info_df['file_name'] = img_info_df.index.map(lambda x: _img_filename(x))
82
+ img_info_df = img_info_df[['img_id', 'file_name', 'width', 'height']]
83
+ img_sizes = img_info_df[['width', 'height']].values
84
+ self.img_infos = img_info_df.to_dict('records')
85
+ self.img_ids = img_info_df.index.values.tolist()
86
+ img_id_to_idx = {img_id: idx for idx, img_id in enumerate(self.img_ids)}
87
+ return img_sizes, img_id_to_idx
88
+
89
+ if self.include_masks and self.has_labels:
90
+ masks_df = pd.read_csv(masks_filename)
91
+
92
+ # NOTE currently using dataset masks anno ImageIDs to form valid img_ids from the dataset
93
+ anno_img_ids = sorted(masks_df['ImageID'].unique())
94
+ img_sizes, img_id_to_idx = _load_img_info(img_info_filename, select_img_ids=anno_img_ids)
95
+
96
+ masks_df['ImageIdx'] = masks_df['ImageID'].map(img_id_to_idx)
97
+ if np.issubdtype(masks_df.ImageIdx.dtype, np.floating):
98
+ masks_df = masks_df.dropna(axis='rows')
99
+ masks_df['ImageIdx'] = masks_df.ImageIdx.astype(np.int32)
100
+ masks_df.sort_values('ImageIdx', inplace=True)
101
+ ann_img_idx = masks_df['ImageIdx'].values
102
+ img_sizes = img_sizes[ann_img_idx]
103
+ masks_df['BoxXMin'] = masks_df['BoxXMin'] * img_sizes[:, 0]
104
+ masks_df['BoxXMax'] = masks_df['BoxXMax'] * img_sizes[:, 0]
105
+ masks_df['BoxYMin'] = masks_df['BoxYMin'] * img_sizes[:, 1]
106
+ masks_df['BoxYMax'] = masks_df['BoxYMax'] * img_sizes[:, 1]
107
+ masks_df['LabelIdx'] = masks_df['LabelName'].map(self.cat_id_to_label)
108
+ # FIXME remap mask filename with _mask_filename
109
+
110
+ self._anns = dict(
111
+ bbox=masks_df[['BoxXMin', 'BoxYMin', 'BoxXMax', 'BoxYMax']].values.astype(np.float32),
112
+ label=masks_df[['LabelIdx']].values.astype(np.int32),
113
+ mask_path=masks_df[['MaskPath']].values
114
+ )
115
+ _, ri, rc = np.unique(ann_img_idx, return_index=True, return_counts=True)
116
+ self._img_to_ann = list(zip(ri, rc)) # index, count tuples
117
+ elif self.has_labels:
118
+ _logger.info('Loading bbox...')
119
+ bbox_df = pd.read_csv(bbox_filename)
120
+
121
+ # NOTE currently using dataset box anno ImageIDs to form valid img_ids from the larger dataset.
122
+ # FIXME use *imagelabels.csv or imagelabels-boxable.csv for negative examples (without box?)
123
+ anno_img_ids = sorted(bbox_df['ImageID'].unique())
124
+ img_sizes, img_id_to_idx = _load_img_info(img_info_filename, select_img_ids=anno_img_ids)
125
+
126
+ _logger.info('Process bbox...')
127
+ bbox_df['ImageIdx'] = bbox_df['ImageID'].map(img_id_to_idx)
128
+ if np.issubdtype(bbox_df.ImageIdx.dtype, np.floating):
129
+ bbox_df = bbox_df.dropna(axis='rows')
130
+ bbox_df['ImageIdx'] = bbox_df.ImageIdx.astype(np.int32)
131
+ bbox_df.sort_values('ImageIdx', inplace=True)
132
+ ann_img_idx = bbox_df['ImageIdx'].values
133
+ img_sizes = img_sizes[ann_img_idx]
134
+ bbox_df['XMin'] = bbox_df['XMin'] * img_sizes[:, 0]
135
+ bbox_df['XMax'] = bbox_df['XMax'] * img_sizes[:, 0]
136
+ bbox_df['YMin'] = bbox_df['YMin'] * img_sizes[:, 1]
137
+ bbox_df['YMax'] = bbox_df['YMax'] * img_sizes[:, 1]
138
+ bbox_df['LabelIdx'] = bbox_df['LabelName'].map(self.cat_id_to_label).astype(np.int32)
139
+
140
+ self._anns = dict(
141
+ bbox=bbox_df[['XMin', 'YMin', 'XMax', 'YMax']].values.astype(np.float32),
142
+ label=bbox_df[['LabelIdx', 'IsGroupOf']].values.astype(np.int32),
143
+ )
144
+ _, ri, rc = np.unique(ann_img_idx, return_index=True, return_counts=True)
145
+ self._img_to_ann = list(zip(ri, rc)) # index, count tuples
146
+ else:
147
+ _load_img_info(img_info_filename)
148
+
149
+ _logger.info('Annotations loaded!')
150
+
151
+ def get_ann_info(self, idx):
152
+ if not self.has_labels:
153
+ return dict()
154
+ start_idx, num_ann = self._img_to_ann[idx]
155
+ ann_keys = tuple(self._anns.keys())
156
+ ann_values = tuple(self._anns[k][start_idx:start_idx + num_ann] for k in ann_keys)
157
+ return self._parse_ann_info(idx, ann_keys, ann_values)
158
+
159
+ def _parse_ann_info(self, img_idx, ann_keys, ann_values):
160
+ """
161
+ """
162
+ gt_bboxes = []
163
+ gt_labels = []
164
+ gt_bboxes_ignore = []
165
+ if self.include_masks:
166
+ assert 'mask_path' in ann_keys
167
+ gt_masks = []
168
+
169
+ for ann in zip(*ann_values):
170
+ ann = dict(zip(ann_keys, ann))
171
+ x1, y1, x2, y2 = ann['bbox']
172
+ if x2 - x1 < 1 or y2 - y1 < 1:
173
+ continue
174
+ label = ann['label'][0]
175
+ iscrowd = False
176
+ if len(ann['label']) > 1:
177
+ iscrowd = ann['label'][1]
178
+ if self.yxyx:
179
+ bbox = np.array([y1, x1, y2, x2], dtype=np.float32)
180
+ else:
181
+ bbox = ann['bbox']
182
+ if iscrowd:
183
+ gt_bboxes_ignore.append(bbox)
184
+ else:
185
+ gt_bboxes.append(bbox)
186
+ gt_labels.append(label)
187
+ # if self.include_masks:
188
+ # img_info = self.img_infos[img_idx]
189
+ # mask_img = SegmentationMask(ann['mask_filename'], img_info['width'], img_info['height'])
190
+ # gt_masks.append(mask_img)
191
+
192
+ if gt_bboxes:
193
+ gt_bboxes = np.array(gt_bboxes, ndmin=2, dtype=np.float32)
194
+ gt_labels = np.array(gt_labels, dtype=np.int64)
195
+ else:
196
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
197
+ gt_labels = np.array([], dtype=np.int64)
198
+
199
+ if self.include_bboxes_ignore:
200
+ if gt_bboxes_ignore:
201
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, ndmin=2, dtype=np.float32)
202
+ else:
203
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
204
+
205
+ ann = dict(bbox=gt_bboxes, cls=gt_labels)
206
+
207
+ if self.include_bboxes_ignore:
208
+ ann.update(dict(bbox_ignore=gt_bboxes_ignore, cls_ignore=np.array([], dtype=np.int64)))
209
+ if self.include_masks:
210
+ ann['masks'] = gt_masks
211
+ return ann
efficientdet/effdet/data/parsers/parser_voc.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Pascal VOC dataset parser
2
+
3
+ Copyright 2020 Ross Wightman
4
+ """
5
+ import os
6
+ import xml.etree.ElementTree as ET
7
+ from collections import defaultdict
8
+ import numpy as np
9
+
10
+ from .parser import Parser
11
+ from .parser_config import VocParserCfg
12
+
13
+
14
+ class VocParser(Parser):
15
+
16
+ DEFAULT_CLASSES = (
17
+ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
18
+ 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant',
19
+ 'sheep', 'sofa', 'train', 'tvmonitor')
20
+
21
+ def __init__(self, cfg: VocParserCfg):
22
+ super().__init__(
23
+ bbox_yxyx=cfg.bbox_yxyx,
24
+ has_labels=cfg.has_labels,
25
+ include_masks=False, # FIXME to support someday
26
+ include_bboxes_ignore=False,
27
+ ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt,
28
+ min_img_size=cfg.min_img_size
29
+ )
30
+ self.correct_bbox = 1
31
+ self.keep_difficult = cfg.keep_difficult
32
+
33
+ self.anns = None
34
+ self.img_id_to_idx = {}
35
+ self._load_annotations(
36
+ split_filename=cfg.split_filename,
37
+ img_filename=cfg.img_filename,
38
+ ann_filename=cfg.ann_filename,
39
+ classes=cfg.classes,
40
+ )
41
+
42
+ def _load_annotations(
43
+ self,
44
+ split_filename: str,
45
+ img_filename: str,
46
+ ann_filename: str,
47
+ classes=None,
48
+ ):
49
+ classes = classes or self.DEFAULT_CLASSES
50
+ self.cat_names = list(classes)
51
+ self.cat_ids = self.cat_names
52
+ self.cat_id_to_label = {cat: i + self.label_offset for i, cat in enumerate(self.cat_ids)}
53
+
54
+ self.anns = []
55
+
56
+ with open(split_filename) as f:
57
+ ids = f.readlines()
58
+ for img_id in ids:
59
+ img_id = img_id.strip("\n")
60
+ filename = img_filename % img_id
61
+ xml_path = ann_filename % img_id
62
+ tree = ET.parse(xml_path)
63
+ root = tree.getroot()
64
+ size = root.find('size')
65
+ width = int(size.find('width').text)
66
+ height = int(size.find('height').text)
67
+ if min(width, height) < self.min_img_size:
68
+ continue
69
+
70
+ anns = []
71
+ for obj_idx, obj in enumerate(root.findall('object')):
72
+ name = obj.find('name').text
73
+ label = self.cat_id_to_label[name]
74
+ difficult = int(obj.find('difficult').text)
75
+ bnd_box = obj.find('bndbox')
76
+ bbox = [
77
+ int(bnd_box.find('xmin').text),
78
+ int(bnd_box.find('ymin').text),
79
+ int(bnd_box.find('xmax').text),
80
+ int(bnd_box.find('ymax').text)
81
+ ]
82
+ anns.append(dict(label=label, bbox=bbox, difficult=difficult))
83
+
84
+ if not self.ignore_empty_gt or len(anns):
85
+ self.anns.append(anns)
86
+ self.img_infos.append(dict(id=img_id, file_name=filename, width=width, height=height))
87
+ self.img_ids.append(img_id)
88
+ else:
89
+ self.img_ids_invalid.append(img_id)
90
+
91
+ def merge(self, other):
92
+ assert len(self.cat_ids) == len(other.cat_ids)
93
+ self.img_ids.extend(other.img_ids)
94
+ self.img_infos.extend(other.img_infos)
95
+ self.anns.extend(other.anns)
96
+
97
+ def get_ann_info(self, idx):
98
+ return self._parse_ann_info(self.anns[idx])
99
+
100
+ def _parse_ann_info(self, ann_info):
101
+ bboxes = []
102
+ labels = []
103
+ bboxes_ignore = []
104
+ labels_ignore = []
105
+ for ann in ann_info:
106
+ ignore = False
107
+ x1, y1, x2, y2 = ann['bbox']
108
+ label = ann['label']
109
+ w = x2 - x1
110
+ h = y2 - y1
111
+ if w < 1 or h < 1:
112
+ ignore = True
113
+ if self.yxyx:
114
+ bbox = [y1, x1, y2, x2]
115
+ else:
116
+ bbox = ann['bbox']
117
+ if ignore or (ann['difficult'] and not self.keep_difficult):
118
+ bboxes_ignore.append(bbox)
119
+ labels_ignore.append(label)
120
+ else:
121
+ bboxes.append(bbox)
122
+ labels.append(label)
123
+
124
+ if not bboxes:
125
+ bboxes = np.zeros((0, 4), dtype=np.float32)
126
+ labels = np.zeros((0, ), dtype=np.float32)
127
+ else:
128
+ bboxes = np.array(bboxes, ndmin=2, dtype=np.float32) - self.correct_bbox
129
+ labels = np.array(labels, dtype=np.float32)
130
+
131
+ if self.include_bboxes_ignore:
132
+ if not bboxes_ignore:
133
+ bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
134
+ labels_ignore = np.zeros((0, ), dtype=np.float32)
135
+ else:
136
+ bboxes_ignore = np.array(bboxes_ignore, ndmin=2, dtype=np.float32) - self.correct_bbox
137
+ labels_ignore = np.array(labels_ignore, dtype=np.float32)
138
+
139
+ ann = dict(
140
+ bbox=bboxes.astype(np.float32),
141
+ cls=labels.astype(np.int64))
142
+
143
+ if self.include_bboxes_ignore:
144
+ ann.update(dict(
145
+ bbox_ignore=bboxes_ignore.astype(np.float32),
146
+ cls_ignore=labels_ignore.astype(np.int64)))
147
+ return ann
148
+
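
The VOC config takes printf-style templates that are filled with each image id from the split file; a sketch assuming a standard VOCdevkit layout (paths are placeholders):

    from effdet.data.parsers import VocParserCfg
    from effdet.data.parsers.parser_voc import VocParser

    cfg = VocParserCfg(
        split_filename='VOCdevkit/VOC2012/ImageSets/Main/train.txt',
        ann_filename='VOCdevkit/VOC2012/Annotations/%s.xml',
        img_filename='%s.jpg')
    parser = VocParser(cfg)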
efficientdet/effdet/data/random_erasing.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Multi-Scale RandomErasing
2
+
3
+ Copyright 2020 Ross Wightman
4
+ """
5
+ import random
6
+ import math
7
+ import torch
8
+
9
+
10
+ def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
11
+ # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
12
+ # paths, flip the order so normal is run on CPU if this becomes a problem
13
+ # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
14
+ if per_pixel:
15
+ return torch.empty(patch_size, dtype=dtype, device=device).normal_()
16
+ elif rand_color:
17
+ return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
18
+ else:
19
+ return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
20
+
21
+
22
+ class RandomErasing:
23
+ """ Randomly selects a rectangle region in an image and erases its pixels.
24
+ 'Random Erasing Data Augmentation' by Zhong et al.
25
+ See https://arxiv.org/pdf/1708.04896.pdf
26
+
27
+ This variant of RandomErasing is tweaked for multi-scale obj detection training.
28
+ Args:
29
+ probability: Probability that the Random Erasing operation will be performed.
30
+ min_area: Minimum percentage of erased area wrt input image area.
31
+ max_area: Maximum percentage of erased area wrt input image area.
32
+ min_aspect: Minimum aspect ratio of erased area.
33
+ mode: pixel color mode, one of 'const', 'rand', or 'pixel'
34
+ 'const' - erase block is constant color of 0 for all channels
35
+ 'rand' - erase block is same per-channel random (normal) color
36
+ 'pixel' - erase block is per-pixel random (normal) color
37
+ max_count: maximum number of erasing blocks per image, area per box is scaled by count.
38
+ per-image count is randomly chosen between 1 and this value.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ probability=0.5, min_area=0.02, max_area=1/4, min_aspect=0.3, max_aspect=None,
44
+ mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
45
+ self.probability = probability
46
+ self.min_area = min_area
47
+ self.max_area = max_area
48
+ max_aspect = max_aspect or 1 / min_aspect
49
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
50
+ self.min_count = min_count
51
+ self.max_count = max_count or min_count
52
+ self.num_splits = num_splits
53
+ mode = mode.lower()
54
+ self.rand_color = False
55
+ self.per_pixel = False
56
+ if mode == 'rand':
57
+ self.rand_color = True # per block random normal
58
+ elif mode == 'pixel':
59
+ self.per_pixel = True # per pixel random normal
60
+ else:
61
+ assert not mode or mode == 'const'
62
+ self.device = device
63
+
64
+ def _erase(self, img, chan, img_h, img_w, dtype):
65
+ if random.random() > self.probability:
66
+ return
67
+ area = img_h * img_w
68
+ count = self.min_count if self.min_count == self.max_count else \
69
+ random.randint(self.min_count, self.max_count)
70
+ for _ in range(count):
71
+ for attempt in range(10):
72
+ target_area = random.uniform(self.min_area, self.max_area) * area / count
73
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
74
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
75
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
76
+ if w < img_w and h < img_h:
77
+ top = random.randint(0, img_h - h)
78
+ left = random.randint(0, img_w - w)
79
+ img[:, top:top + h, left:left + w] = _get_pixels(
80
+ self.per_pixel, self.rand_color, (chan, h, w),
81
+ dtype=dtype, device=self.device)
82
+ break
83
+
84
+ def __call__(self, input, target):
85
+ batch_size, chan, input_h, input_w = input.shape
86
+ img_scales = target['img_scale']
87
+ img_size = (target['img_size'] / img_scales.unsqueeze(1)).int()
88
+ img_size[:, 0] = img_size[:, 0].clamp(max=input_w)
89
+ img_size[:, 1] = img_size[:, 1].clamp(max=input_h)
90
+ # skip first slice of batch if num_splits is set (for clean portion of samples)
91
+ batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
92
+ for i in range(batch_start, batch_size):
93
+ self._erase(input[i], chan, img_size[i, 1], img_size[i, 0], input.dtype)
94
+ return input
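
In this pipeline RandomErasing is applied inside `PrefetchLoader` when `re_prob > 0` is passed to `create_loader`; a rough standalone sketch on CPU (the target keys mirror what the transforms produce and are an assumption here):

    import torch
    from effdet.data.random_erasing import RandomErasing

    erase = RandomErasing(probability=1.0, mode='pixel', device='cpu')
    images = torch.randn(2, 3, 512, 512)
    target = dict(img_scale=torch.ones(2),
                  img_size=torch.tensor([[512., 512.], [512., 512.]]))
    images = erase(images, target)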
efficientdet/effdet/data/transforms.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ COCO transforms (quick and dirty)
2
+
3
+ Hacked together by Ross Wightman
4
+ """
5
+ import torch
6
+ from PIL import Image
7
+ import numpy as np
8
+ import random
9
+ import math
10
+
11
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
12
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
13
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
14
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
15
+
16
+
17
+ class ImageToNumpy:
18
+
19
+ def __call__(self, pil_img, annotations: dict):
20
+ np_img = np.array(pil_img, dtype=np.uint8)
21
+ if np_img.ndim < 3:
22
+ np_img = np.expand_dims(np_img, axis=-1)
23
+ np_img = np.moveaxis(np_img, 2, 0) # HWC to CHW
24
+ return np_img, annotations
25
+
26
+
27
+ class ImageToTensor:
28
+
29
+ def __init__(self, dtype=torch.float32):
30
+ self.dtype = dtype
31
+
32
+ def __call__(self, pil_img, annotations: dict):
33
+ np_img = np.array(pil_img, dtype=np.uint8)
34
+ if np_img.ndim < 3:
35
+ np_img = np.expand_dims(np_img, axis=-1)
36
+ np_img = np.moveaxis(np_img, 2, 0) # HWC to CHW
37
+ return torch.from_numpy(np_img).to(dtype=self.dtype), annotations
38
+
39
+
40
+ def _pil_interp(method):
41
+ if method == 'bicubic':
42
+ return Image.BICUBIC
43
+ elif method == 'lanczos':
44
+ return Image.LANCZOS
45
+ elif method == 'hamming':
46
+ return Image.HAMMING
47
+ else:
48
+ # default bilinear, do we want to allow nearest?
49
+ return Image.BILINEAR
50
+
51
+
52
+ def clip_boxes_(boxes, img_size):
53
+ height, width = img_size
54
+ clip_upper = np.array([height, width] * 2, dtype=boxes.dtype)
55
+ np.clip(boxes, 0, clip_upper, out=boxes)
56
+
57
+
58
+ def clip_boxes(boxes, img_size):
59
+ clipped_boxes = boxes.copy()
60
+ clip_boxes_(clipped_boxes, img_size)
61
+ return clipped_boxes
62
+
63
+
64
+ def _size_tuple(size):
65
+ if isinstance(size, int):
66
+ return size, size
67
+ else:
68
+ assert len(size) == 2
69
+ return size
70
+
71
+
72
+ class ResizePad:
73
+
74
+ def __init__(self, target_size: int, interpolation: str = 'bilinear', fill_color: tuple = (0, 0, 0)):
75
+ self.target_size = _size_tuple(target_size)
76
+ self.interpolation = interpolation
77
+ self.fill_color = fill_color
78
+
79
+ def __call__(self, img, anno: dict):
80
+ width, height = img.size
81
+
82
+ img_scale_y = self.target_size[0] / height
83
+ img_scale_x = self.target_size[1] / width
84
+ img_scale = min(img_scale_y, img_scale_x)
85
+ scaled_h = int(height * img_scale)
86
+ scaled_w = int(width * img_scale)
87
+
88
+ new_img = Image.new("RGB", (self.target_size[1], self.target_size[0]), color=self.fill_color)
89
+ interp_method = _pil_interp(self.interpolation)
90
+ img = img.resize((scaled_w, scaled_h), interp_method)
91
+ new_img.paste(img)
92
+
93
+ if 'bbox' in anno:
94
+ # FIXME haven't tested this path since not currently using dataset annotations for train/eval
95
+ bbox = anno['bbox']
96
+ bbox[:, :4] *= img_scale
97
+ clip_boxes_(bbox, (scaled_h, scaled_w))
98
+ valid_indices = (bbox[:, :2] < bbox[:, 2:4]).all(axis=1)
99
+ anno['bbox'] = bbox[valid_indices, :]
100
+ anno['cls'] = anno['cls'][valid_indices]
101
+
102
+ anno['img_scale'] = 1. / img_scale # back to original
103
+
104
+ return new_img, anno
105
+
106
+
107
+ class RandomResizePad:
108
+
109
+ def __init__(self, target_size: int, scale: tuple = (0.1, 2.0), interpolation: str = 'bilinear',
110
+ fill_color: tuple = (0, 0, 0)):
111
+ self.target_size = _size_tuple(target_size)
112
+ self.scale = scale
113
+ self.interpolation = interpolation
114
+ self.fill_color = fill_color
115
+
116
+ def _get_params(self, img):
117
+ # Select a random scale factor.
118
+ scale_factor = random.uniform(*self.scale)
119
+ scaled_target_height = scale_factor * self.target_size[0]
120
+ scaled_target_width = scale_factor * self.target_size[1]
121
+
122
+ # Recompute the accurate scale_factor using rounded scaled image size.
123
+ width, height = img.size
124
+ img_scale_y = scaled_target_height / height
125
+ img_scale_x = scaled_target_width / width
126
+ img_scale = min(img_scale_y, img_scale_x)
127
+
128
+ # Select non-zero random offset (x, y) if scaled image is larger than target size
129
+ scaled_h = int(height * img_scale)
130
+ scaled_w = int(width * img_scale)
131
+ offset_y = scaled_h - self.target_size[0]
132
+ offset_x = scaled_w - self.target_size[1]
133
+ offset_y = int(max(0.0, float(offset_y)) * random.uniform(0, 1))
134
+ offset_x = int(max(0.0, float(offset_x)) * random.uniform(0, 1))
135
+ return scaled_h, scaled_w, offset_y, offset_x, img_scale
136
+
137
+ def __call__(self, img, anno: dict):
138
+ scaled_h, scaled_w, offset_y, offset_x, img_scale = self._get_params(img)
139
+
140
+ interp_method = _pil_interp(self.interpolation)
141
+ img = img.resize((scaled_w, scaled_h), interp_method)
142
+ right, lower = min(scaled_w, offset_x + self.target_size[1]), min(scaled_h, offset_y + self.target_size[0])
143
+ img = img.crop((offset_x, offset_y, right, lower))
144
+ new_img = Image.new("RGB", (self.target_size[1], self.target_size[0]), color=self.fill_color)
145
+ new_img.paste(img)
146
+
147
+ if 'bbox' in anno:
148
+ # FIXME not fully tested
149
+ bbox = anno['bbox'].copy() # FIXME copy for debugger inspection, back to inplace
150
+ bbox[:, :4] *= img_scale
151
+ box_offset = np.stack([offset_y, offset_x] * 2)
152
+ bbox -= box_offset
153
+ clip_boxes_(bbox, (scaled_h, scaled_w))
154
+ valid_indices = (bbox[:, :2] < bbox[:, 2:4]).all(axis=1)
155
+ anno['bbox'] = bbox[valid_indices, :]
156
+ anno['cls'] = anno['cls'][valid_indices]
157
+
158
+ anno['img_scale'] = 1. / img_scale # back to original
159
+
160
+ return new_img, anno
161
+
162
+
163
+ class RandomFlip:
164
+
165
+ def __init__(self, horizontal=True, vertical=False, prob=0.5):
166
+ self.horizontal = horizontal
167
+ self.vertical = vertical
168
+ self.prob = prob
169
+
170
+ def _get_params(self):
171
+ do_horizontal = random.random() < self.prob if self.horizontal else False
172
+ do_vertical = random.random() < self.prob if self.vertical else False
173
+ return do_horizontal, do_vertical
174
+
175
+ def __call__(self, img, annotations: dict):
176
+ do_horizontal, do_vertical = self._get_params()
177
+ width, height = img.size
178
+
179
+ def _fliph(bbox):
180
+ x_max = width - bbox[:, 1]
181
+ x_min = width - bbox[:, 3]
182
+ bbox[:, 1] = x_min
183
+ bbox[:, 3] = x_max
184
+
185
+ def _flipv(bbox):
186
+ y_max = height - bbox[:, 0]
187
+ y_min = height - bbox[:, 2]
188
+ bbox[:, 0] = y_min
189
+ bbox[:, 2] = y_max
190
+
191
+ if do_horizontal and do_vertical:
192
+ img = img.transpose(Image.ROTATE_180)
193
+ if 'bbox' in annotations:
194
+ _fliph(annotations['bbox'])
195
+ _flipv(annotations['bbox'])
196
+ elif do_horizontal:
197
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
198
+ if 'bbox' in annotations:
199
+ _fliph(annotations['bbox'])
200
+ elif do_vertical:
201
+ img = img.transpose(Image.FLIP_TOP_BOTTOM)
202
+ if 'bbox' in annotations:
203
+ _flipv(annotations['bbox'])
204
+
205
+ return img, annotations
206
+
207
+
208
+ def resolve_fill_color(fill_color, img_mean=IMAGENET_DEFAULT_MEAN):
209
+ if isinstance(fill_color, tuple):
210
+ assert len(fill_color) == 3
211
+ fill_color = fill_color
212
+ else:
213
+ try:
214
+ int_color = int(fill_color)
215
+ fill_color = (int_color,) * 3
216
+ except ValueError:
217
+ assert fill_color == 'mean'
218
+ fill_color = tuple([int(round(255 * x)) for x in img_mean])
219
+ return fill_color
220
+
221
+
222
+ class Compose:
223
+
224
+ def __init__(self, transforms: list):
225
+ self.transforms = transforms
226
+
227
+ def __call__(self, img, annotations: dict):
228
+ for t in self.transforms:
229
+ img, annotations = t(img, annotations)
230
+ return img, annotations
231
+
232
+
233
+ def transforms_coco_eval(
234
+ img_size=224,
235
+ interpolation='bilinear',
236
+ use_prefetcher=False,
237
+ fill_color='mean',
238
+ mean=IMAGENET_DEFAULT_MEAN,
239
+ std=IMAGENET_DEFAULT_STD):
240
+
241
+ fill_color = resolve_fill_color(fill_color, mean)
242
+
243
+ image_tfl = [
244
+ ResizePad(
245
+ target_size=img_size, interpolation=interpolation, fill_color=fill_color),
246
+ ImageToNumpy(),
247
+ ]
248
+
249
+ assert use_prefetcher, "Only supporting prefetcher usage right now"
250
+
251
+ image_tf = Compose(image_tfl)
252
+ return image_tf
253
+
254
+
255
+ def transforms_coco_train(
256
+ img_size=224,
257
+ interpolation='random',
258
+ use_prefetcher=False,
259
+ fill_color='mean',
260
+ mean=IMAGENET_DEFAULT_MEAN,
261
+ std=IMAGENET_DEFAULT_STD):
262
+
263
+ fill_color = resolve_fill_color(fill_color, mean)
264
+
265
+ image_tfl = [
266
+ RandomFlip(horizontal=True, prob=0.5),
267
+ RandomResizePad(
268
+ target_size=img_size, interpolation=interpolation, fill_color=fill_color),
269
+ ImageToNumpy(),
270
+ ]
271
+
272
+ assert use_prefetcher, "Only supporting prefetcher usage right now"
273
+
274
+ image_tf = Compose(image_tfl)
275
+ return image_tf
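A hedged usage sketch of the eval pipeline defined above (image and box values are illustrative; boxes follow the [ymin, xmin, ymax, xmax] convention used by clip_boxes_):

    import numpy as np
    from PIL import Image

    tf = transforms_coco_eval(img_size=512, interpolation='bilinear',
                              use_prefetcher=True, fill_color='mean')
    img = Image.new('RGB', (640, 480))                  # stand-in for a real photo
    anno = {'bbox': np.array([[40., 50., 220., 200.]], dtype=np.float32),
            'cls': np.array([1])}
    np_img, anno = tf(img, anno)
    print(np_img.shape, anno['img_scale'])              # (3, 512, 512) 1.25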
efficientdet/effdet/data/transforms_albumentation.py ADDED
@@ -0,0 +1,23 @@
1
+ import albumentations as A
2
+
3
+ from albumentations.augmentations.transforms import (
4
+ RandomBrightness, Downscale, RandomFog, RandomRain, RandomSnow)
5
+
6
+ from albumentations.augmentations.blur.transforms import Blur
7
+
8
+ def get_transform():
9
+ transforms = A.Compose([
10
+ #HorizontalFlip(p=0.5),
11
+ #VerticalFlip(p=0.5),
12
+ #RandomSizedBBoxSafeCrop(700, 700, erosion_rate=0.0, interpolation=1, always_apply=False, p=0.5),
13
+ Blur(blur_limit=7, always_apply=False, p=0.5),
14
+ RandomBrightness(limit=0.2, always_apply=False, p=0.5),
15
+ #Downscale(scale_min=0.5, scale_max=0.9, interpolation=0, always_apply=False, p=0.5),
16
+ #PadIfNeeded(min_height=1024, min_width=1024, pad_height_divisor=None, pad_width_divisor=None, border_mode=4, value=None, mask_value=None, always_apply=False, p=1.0),
17
+ #RandomFog(fog_coef_lower=0.3, fog_coef_upper=1, alpha_coef=0.08, always_apply=False, p=0.2),
18
+ #RandomRain(slant_lower=-10, slant_upper=10, drop_length=20, drop_width=1, drop_color=(200, 200, 200), p=0.2),
19
+ #RandomSnow(snow_point_lower=0.1, snow_point_upper=0.3, brightness_coeff=2.5, always_apply=False, p=0.2)
20
+ ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_classes'])
21
+ )
22
+ return transforms
23
+
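A hedged example of applying this albumentations pipeline (image and box below are placeholders; pascal_voc boxes are (x_min, y_min, x_max, y_max)):

    import numpy as np

    aug = get_transform()
    image = np.zeros((480, 640, 3), dtype=np.uint8)
    out = aug(image=image, bboxes=[(40, 50, 220, 200)], bbox_classes=[1])
    aug_image, aug_boxes, aug_classes = out['image'], out['bboxes'], out['bbox_classes']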
efficientdet/effdet/distributed.py ADDED
@@ -0,0 +1,308 @@
1
+ """ PyTorch distributed helpers
2
+
3
+ Some of this lifted from Detectron2 with other fns added by myself. Some of the Detectron2 fns
4
+ were intended for use with GLOO PG. I am using NCCL here with default PG so not everything will work
5
+ as is -RW
6
+ """
7
+ import functools
8
+ import logging
9
+ import numpy as np
10
+ import pickle
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ _LOCAL_PROCESS_GROUP = None
15
+ """
16
+ A torch process group which only includes processes that are on the same machine as the current process.
17
+ This variable is set when processes are spawned by `launch()` in "engine/launch.py".
18
+ """
19
+
20
+
21
+ def get_world_size() -> int:
22
+ if not dist.is_available():
23
+ return 1
24
+ if not dist.is_initialized():
25
+ return 1
26
+ return dist.get_world_size()
27
+
28
+
29
+ def get_rank() -> int:
30
+ if not dist.is_available():
31
+ return 0
32
+ if not dist.is_initialized():
33
+ return 0
34
+ return dist.get_rank()
35
+
36
+
37
+ def get_local_rank() -> int:
38
+ """
39
+ Returns:
40
+ The rank of the current process within the local (per-machine) process group.
41
+ """
42
+ if not dist.is_available():
43
+ return 0
44
+ if not dist.is_initialized():
45
+ return 0
46
+ assert _LOCAL_PROCESS_GROUP is not None
47
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
48
+
49
+
50
+ def get_local_size() -> int:
51
+ """
52
+ Returns:
53
+ The size of the per-machine process group,
54
+ i.e. the number of processes per machine.
55
+ """
56
+ if not dist.is_available():
57
+ return 1
58
+ if not dist.is_initialized():
59
+ return 1
60
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
61
+
62
+
63
+ def is_main_process() -> bool:
64
+ return get_rank() == 0
65
+
66
+
67
+ def synchronize():
68
+ """
69
+ Helper function to synchronize (barrier) among all processes when
70
+ using distributed training
71
+ """
72
+ if not dist.is_available():
73
+ return
74
+ if not dist.is_initialized():
75
+ return
76
+ world_size = dist.get_world_size()
77
+ if world_size == 1:
78
+ return
79
+ dist.barrier()
80
+
81
+
82
+ @functools.lru_cache()
83
+ def _get_global_gloo_group():
84
+ """
85
+ Return a process group based on gloo backend, containing all the ranks
86
+ The result is cached.
87
+ """
88
+ if dist.get_backend() == "nccl":
89
+ return dist.new_group(backend="gloo")
90
+ else:
91
+ return dist.group.WORLD
92
+
93
+
94
+ def _serialize_to_tensor(data, group):
95
+ backend = dist.get_backend(group)
96
+ assert backend in ["gloo", "nccl"]
97
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
98
+
99
+ buffer = pickle.dumps(data)
100
+ if len(buffer) > 1024 ** 3:
101
+ logger = logging.getLogger(__name__)
102
+ logger.warning(
103
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
104
+ get_rank(), len(buffer) / (1024 ** 3), device
105
+ )
106
+ )
107
+ storage = torch.ByteStorage.from_buffer(buffer)
108
+ tensor = torch.ByteTensor(storage).to(device=device)
109
+ return tensor
110
+
111
+
112
+ def _pad_to_largest_tensor(tensor, group):
113
+ """
114
+ Returns:
115
+ list[int]: size of the tensor, on each rank
116
+ Tensor: padded tensor that has the max size
117
+ """
118
+ world_size = dist.get_world_size(group=group)
119
+ assert (
120
+ world_size >= 1
121
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
122
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
123
+ size_list = [
124
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
125
+ ]
126
+ dist.all_gather(size_list, local_size, group=group)
127
+ size_list = [int(size.item()) for size in size_list]
128
+
129
+ max_size = max(size_list)
130
+
131
+ # we pad the tensor because torch all_gather does not support
132
+ # gathering tensors of different shapes
133
+ if local_size != max_size:
134
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
135
+ tensor = torch.cat((tensor, padding), dim=0)
136
+ return size_list, tensor
137
+
138
+
139
+ def all_gather(data, group=None):
140
+ """
141
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
142
+ Args:
143
+ data: any picklable object
144
+ group: a torch process group. By default, will use a group which
145
+ contains all ranks on gloo backend.
146
+ Returns:
147
+ list[data]: list of data gathered from each rank
148
+ """
149
+ if get_world_size() == 1:
150
+ return [data]
151
+ if group is None:
152
+ group = _get_global_gloo_group()
153
+ if dist.get_world_size(group) == 1:
154
+ return [data]
155
+
156
+ tensor = _serialize_to_tensor(data, group)
157
+
158
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
159
+ max_size = max(size_list)
160
+
161
+ # receiving Tensor from all ranks
162
+ tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list]
163
+ dist.all_gather(tensor_list, tensor, group=group)
164
+
165
+ data_list = []
166
+ for size, tensor in zip(size_list, tensor_list):
167
+ buffer = tensor.cpu().numpy().tobytes()[:size]
168
+ data_list.append(pickle.loads(buffer))
169
+
170
+ return data_list
171
+
172
+
173
+ def gather(data, dst=0, group=None):
174
+ """
175
+ Run gather on arbitrary picklable data (not necessarily tensors).
176
+ Args:
177
+ data: any picklable object
178
+ dst (int): destination rank
179
+ group: a torch process group. By default, will use a group which
180
+ contains all ranks on gloo backend.
181
+ Returns:
182
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
183
+ an empty list.
184
+ """
185
+ if get_world_size() == 1:
186
+ return [data]
187
+ if group is None:
188
+ group = _get_global_gloo_group()
189
+ if dist.get_world_size(group=group) == 1:
190
+ return [data]
191
+ rank = dist.get_rank(group=group)
192
+
193
+ tensor = _serialize_to_tensor(data, group)
194
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
195
+
196
+ # receiving Tensor from all ranks
197
+ if rank == dst:
198
+ max_size = max(size_list)
199
+ tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list]
200
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
201
+
202
+ data_list = []
203
+ for size, tensor in zip(size_list, tensor_list):
204
+ buffer = tensor.cpu().numpy().tobytes()[:size]
205
+ data_list.append(pickle.loads(buffer))
206
+ return data_list
207
+ else:
208
+ dist.gather(tensor, [], dst=dst, group=group)
209
+ return []
210
+
211
+
212
+ def shared_random_seed():
213
+ """
214
+ Returns:
215
+ int: a random number that is the same across all workers.
216
+ If workers need a shared RNG, they can use this shared seed to
217
+ create one.
218
+ All workers must call this function, otherwise it will deadlock.
219
+ """
220
+ ints = np.random.randint(2 ** 31)
221
+ all_ints = all_gather(ints)
222
+ return all_ints[0]
223
+
224
+
225
+ def reduce_dict(input_dict, average=True):
226
+ """
227
+ Reduce the values in the dictionary from all processes so that process with rank
228
+ 0 has the reduced results.
229
+ Args:
230
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
231
+ average (bool): whether to do average or sum
232
+ Returns:
233
+ a dict with the same keys as input_dict, after reduction.
234
+ """
235
+ world_size = get_world_size()
236
+ if world_size < 2:
237
+ return input_dict
238
+ with torch.no_grad():
239
+ names = []
240
+ values = []
241
+ # sort the keys so that they are consistent across processes
242
+ for k in sorted(input_dict.keys()):
243
+ names.append(k)
244
+ values.append(input_dict[k])
245
+ values = torch.stack(values, dim=0)
246
+ dist.reduce(values, dst=0)
247
+ if dist.get_rank() == 0 and average:
248
+ # only main process gets accumulated, so only divide by
249
+ # world_size in this case
250
+ values /= world_size
251
+ reduced_dict = {k: v for k, v in zip(names, values)}
252
+ return reduced_dict
253
+
254
+
255
+ def all_gather_container(container, group=None, cat_dim=0):
256
+ group = group or dist.group.WORLD
257
+ world_size = dist.get_world_size(group)
258
+
259
+ def _do_gather(tensor):
260
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
261
+ dist.all_gather(tensor_list, tensor, group=group)
262
+ return torch.cat(tensor_list, dim=cat_dim)
263
+
264
+ if isinstance(container, dict):
265
+ gathered = dict()
266
+ for k, v in container.items():
267
+ v = _do_gather(v)
268
+ gathered[k] = v
269
+ return gathered
270
+ elif isinstance(container, (list, tuple)):
271
+ gathered = [_do_gather(v) for v in container]
272
+ if isinstance(container, tuple):
273
+ gathered = tuple(gathered)
274
+ return gathered
275
+ else:
276
+ # if not a dict, list, tuple, expect a singular tensor
277
+ assert isinstance(container, torch.Tensor)
278
+ return _do_gather(container)
279
+
280
+
281
+ def gather_container(container, dst, group=None, cat_dim=0):
282
+ group = group or dist.group.WORLD
283
+ world_size = dist.get_world_size(group)
284
+ this_rank = dist.get_rank(group)
285
+
286
+ def _do_gather(tensor):
287
+ if this_rank == dst:
288
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
289
+ else:
290
+ tensor_list = None
291
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
292
+ return torch.cat(tensor_list, dim=cat_dim)
293
+
294
+ if isinstance(container, dict):
295
+ gathered = dict()
296
+ for k, v in container.items():
297
+ v = _do_gather(v)
298
+ gathered[k] = v
299
+ return gathered
300
+ elif isinstance(container, (list, tuple)):
301
+ gathered = [_do_gather(v) for v in container]
302
+ if isinstance(container, tuple):
303
+ gathered = tuple(gathered)
304
+ return gathered
305
+ else:
306
+ # if not a dict, list, tuple, expect a singular tensor
307
+ assert isinstance(container, torch.Tensor)
308
+ return _do_gather(container)
efficientdet/effdet/efficientdet.py ADDED
@@ -0,0 +1,557 @@
 
1
+ """ PyTorch EfficientDet model
2
+
3
+ Based on official Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet
4
+ Paper: https://arxiv.org/abs/1911.09070
5
+
6
+ Hacked together by Ross Wightman
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ import logging
11
+ import math
12
+ from collections import OrderedDict
13
+ from typing import List, Callable
14
+ from functools import partial
15
+
16
+
17
+ from timm import create_model
18
+ from timm.models.layers import create_conv2d, drop_path, create_pool2d, Swish, get_act_layer
19
+ from .config import get_fpn_config, set_config_writeable, set_config_readonly
20
+
21
+ _DEBUG = False
22
+
23
+ _ACT_LAYER = Swish
24
+
25
+
26
+ class SequentialList(nn.Sequential):
27
+ """ This module exists to work around torchscript typing issues list -> list"""
28
+ def __init__(self, *args):
29
+ super(SequentialList, self).__init__(*args)
30
+
31
+ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
32
+ for module in self:
33
+ x = module(x)
34
+ return x
35
+
36
+
37
+ class ConvBnAct2d(nn.Module):
38
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding='', bias=False,
39
+ norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER):
40
+ super(ConvBnAct2d, self).__init__()
41
+ self.conv = create_conv2d(
42
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias)
43
+ self.bn = None if norm_layer is None else norm_layer(out_channels)
44
+ self.act = None if act_layer is None else act_layer(inplace=True)
45
+
46
+ def forward(self, x):
47
+ x = self.conv(x)
48
+ if self.bn is not None:
49
+ x = self.bn(x)
50
+ if self.act is not None:
51
+ x = self.act(x)
52
+ return x
53
+
54
+
55
+ class SeparableConv2d(nn.Module):
56
+ """ Separable Conv
57
+ """
58
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
59
+ channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER):
60
+ super(SeparableConv2d, self).__init__()
61
+ self.conv_dw = create_conv2d(
62
+ in_channels, int(in_channels * channel_multiplier), kernel_size,
63
+ stride=stride, dilation=dilation, padding=padding, depthwise=True)
64
+
65
+ self.conv_pw = create_conv2d(
66
+ int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
67
+
68
+ self.bn = None if norm_layer is None else norm_layer(out_channels)
69
+ self.act = None if act_layer is None else act_layer(inplace=True)
70
+
71
+ def forward(self, x):
72
+ x = self.conv_dw(x)
73
+ x = self.conv_pw(x)
74
+ if self.bn is not None:
75
+ x = self.bn(x)
76
+ if self.act is not None:
77
+ x = self.act(x)
78
+ return x
79
+
80
+
81
+ class ResampleFeatureMap(nn.Sequential):
82
+
83
+ def __init__(self, in_channels, out_channels, reduction_ratio=1., pad_type='', pooling_type='max',
84
+ norm_layer=nn.BatchNorm2d, apply_bn=False, conv_after_downsample=False, redundant_bias=False):
85
+ super(ResampleFeatureMap, self).__init__()
86
+ pooling_type = pooling_type or 'max'
87
+ self.in_channels = in_channels
88
+ self.out_channels = out_channels
89
+ self.reduction_ratio = reduction_ratio
90
+ self.conv_after_downsample = conv_after_downsample
91
+
92
+ conv = None
93
+ if in_channels != out_channels:
94
+ conv = ConvBnAct2d(
95
+ in_channels, out_channels, kernel_size=1, padding=pad_type,
96
+ norm_layer=norm_layer if apply_bn else None,
97
+ bias=not apply_bn or redundant_bias, act_layer=None)
98
+
99
+ if reduction_ratio > 1:
100
+ stride_size = int(reduction_ratio)
101
+ if conv is not None and not self.conv_after_downsample:
102
+ self.add_module('conv', conv)
103
+ self.add_module(
104
+ 'downsample',
105
+ create_pool2d(
106
+ pooling_type, kernel_size=stride_size + 1, stride=stride_size, padding=pad_type))
107
+ if conv is not None and self.conv_after_downsample:
108
+ self.add_module('conv', conv)
109
+ else:
110
+ if conv is not None:
111
+ self.add_module('conv', conv)
112
+ if reduction_ratio < 1:
113
+ scale = int(1 // reduction_ratio)
114
+ self.add_module('upsample', nn.UpsamplingNearest2d(scale_factor=scale))
115
+
116
+ # def forward(self, x):
117
+ # # here for debugging only
118
+ # assert x.shape[1] == self.in_channels
119
+ # if self.reduction_ratio > 1:
120
+ # if hasattr(self, 'conv') and not self.conv_after_downsample:
121
+ # x = self.conv(x)
122
+ # x = self.downsample(x)
123
+ # if hasattr(self, 'conv') and self.conv_after_downsample:
124
+ # x = self.conv(x)
125
+ # else:
126
+ # if hasattr(self, 'conv'):
127
+ # x = self.conv(x)
128
+ # if self.reduction_ratio < 1:
129
+ # x = self.upsample(x)
130
+ # return x
131
+
132
+
133
+ class FpnCombine(nn.Module):
134
+ def __init__(self, feature_info, fpn_config, fpn_channels, inputs_offsets, target_reduction, pad_type='',
135
+ pooling_type='max', norm_layer=nn.BatchNorm2d, apply_bn_for_resampling=False,
136
+ conv_after_downsample=False, redundant_bias=False, weight_method='attn'):
137
+ super(FpnCombine, self).__init__()
138
+ self.inputs_offsets = inputs_offsets
139
+ self.weight_method = weight_method
140
+
141
+ self.resample = nn.ModuleDict()
142
+ for idx, offset in enumerate(inputs_offsets):
143
+ in_channels = fpn_channels
144
+ if offset < len(feature_info):
145
+ in_channels = feature_info[offset]['num_chs']
146
+ input_reduction = feature_info[offset]['reduction']
147
+ else:
148
+ node_idx = offset - len(feature_info)
149
+ input_reduction = fpn_config.nodes[node_idx]['reduction']
150
+ reduction_ratio = target_reduction / input_reduction
151
+ self.resample[str(offset)] = ResampleFeatureMap(
152
+ in_channels, fpn_channels, reduction_ratio=reduction_ratio, pad_type=pad_type,
153
+ pooling_type=pooling_type, norm_layer=norm_layer, apply_bn=apply_bn_for_resampling,
154
+ conv_after_downsample=conv_after_downsample, redundant_bias=redundant_bias)
155
+
156
+ if weight_method == 'attn' or weight_method == 'fastattn':
157
+ self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets)), requires_grad=True) # WSM
158
+ else:
159
+ self.edge_weights = None
160
+
161
+ def forward(self, x: List[torch.Tensor]):
162
+ dtype = x[0].dtype
163
+ nodes = []
164
+ for offset, resample in zip(self.inputs_offsets, self.resample.values()):
165
+ input_node = x[offset]
166
+ input_node = resample(input_node)
167
+ nodes.append(input_node)
168
+
169
+ if self.weight_method == 'attn':
170
+ normalized_weights = torch.softmax(self.edge_weights.to(dtype=dtype), dim=0)
171
+ out = torch.stack(nodes, dim=-1) * normalized_weights
172
+ elif self.weight_method == 'fastattn':
173
+ edge_weights = nn.functional.relu(self.edge_weights.to(dtype=dtype))
174
+ weights_sum = torch.sum(edge_weights)
175
+ out = torch.stack(
176
+ [(nodes[i] * edge_weights[i]) / (weights_sum + 0.0001) for i in range(len(nodes))], dim=-1)
177
+ elif self.weight_method == 'sum':
178
+ out = torch.stack(nodes, dim=-1)
179
+ else:
180
+ raise ValueError('unknown weight_method {}'.format(self.weight_method))
181
+ out = torch.sum(out, dim=-1)
182
+ return out
183
+
184
+
185
+ class Fnode(nn.Module):
186
+ """ A simple wrapper used in place of nn.Sequential for torchscript typing
187
+ Handles input type List[Tensor] -> output type Tensor
188
+ """
189
+ def __init__(self, combine: nn.Module, after_combine: nn.Module):
190
+ super(Fnode, self).__init__()
191
+ self.combine = combine
192
+ self.after_combine = after_combine
193
+
194
+ def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
195
+ return self.after_combine(self.combine(x))
196
+
197
+
198
+ class BiFpnLayer(nn.Module):
199
+ def __init__(self, feature_info, fpn_config, fpn_channels, num_levels=5, pad_type='',
200
+ pooling_type='max', norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER,
201
+ apply_bn_for_resampling=False, conv_after_downsample=True, conv_bn_relu_pattern=False,
202
+ separable_conv=True, redundant_bias=False):
203
+ super(BiFpnLayer, self).__init__()
204
+ self.num_levels = num_levels
205
+ self.conv_bn_relu_pattern = False
206
+
207
+ self.feature_info = []
208
+ self.fnode = nn.ModuleList()
209
+ for i, fnode_cfg in enumerate(fpn_config.nodes):
210
+ logging.debug('fnode {} : {}'.format(i, fnode_cfg))
211
+ reduction = fnode_cfg['reduction']
212
+ combine = FpnCombine(
213
+ feature_info, fpn_config, fpn_channels, tuple(fnode_cfg['inputs_offsets']),
214
+ target_reduction=reduction, pad_type=pad_type, pooling_type=pooling_type, norm_layer=norm_layer,
215
+ apply_bn_for_resampling=apply_bn_for_resampling, conv_after_downsample=conv_after_downsample,
216
+ redundant_bias=redundant_bias, weight_method=fnode_cfg['weight_method'])
217
+
218
+ after_combine = nn.Sequential()
219
+ conv_kwargs = dict(
220
+ in_channels=fpn_channels, out_channels=fpn_channels, kernel_size=3, padding=pad_type,
221
+ bias=False, norm_layer=norm_layer, act_layer=act_layer)
222
+ if not conv_bn_relu_pattern:
223
+ conv_kwargs['bias'] = redundant_bias
224
+ conv_kwargs['act_layer'] = None
225
+ after_combine.add_module('act', act_layer(inplace=True))
226
+ after_combine.add_module(
227
+ 'conv', SeparableConv2d(**conv_kwargs) if separable_conv else ConvBnAct2d(**conv_kwargs))
228
+
229
+ self.fnode.append(Fnode(combine=combine, after_combine=after_combine))
230
+ self.feature_info.append(dict(num_chs=fpn_channels, reduction=reduction))
231
+
232
+ self.feature_info = self.feature_info[-num_levels::]
233
+
234
+ def forward(self, x: List[torch.Tensor]):
235
+ for fn in self.fnode:
236
+ x.append(fn(x))
237
+ return x[-self.num_levels::]
238
+
239
+
240
+ class BiFpn(nn.Module):
241
+
242
+ def __init__(self, config, feature_info):
243
+ super(BiFpn, self).__init__()
244
+ self.num_levels = config.num_levels
245
+ norm_layer = config.norm_layer or nn.BatchNorm2d
246
+ if config.norm_kwargs:
247
+ norm_layer = partial(norm_layer, **config.norm_kwargs)
248
+ act_layer = get_act_layer(config.act_type) or _ACT_LAYER
249
+ fpn_config = config.fpn_config or get_fpn_config(
250
+ config.fpn_name, min_level=config.min_level, max_level=config.max_level)
251
+
252
+ self.resample = nn.ModuleDict()
253
+ for level in range(config.num_levels):
254
+ if level < len(feature_info):
255
+ in_chs = feature_info[level]['num_chs']
256
+ reduction = feature_info[level]['reduction']
257
+ else:
258
+ # Adds a coarser level by downsampling the last feature map
259
+ reduction_ratio = 2
260
+ self.resample[str(level)] = ResampleFeatureMap(
261
+ in_channels=in_chs,
262
+ out_channels=config.fpn_channels,
263
+ pad_type=config.pad_type,
264
+ pooling_type=config.pooling_type,
265
+ norm_layer=norm_layer,
266
+ reduction_ratio=reduction_ratio,
267
+ apply_bn=config.apply_bn_for_resampling,
268
+ conv_after_downsample=config.conv_after_downsample,
269
+ redundant_bias=config.redundant_bias,
270
+ )
271
+ in_chs = config.fpn_channels
272
+ reduction = int(reduction * reduction_ratio)
273
+ feature_info.append(dict(num_chs=in_chs, reduction=reduction))
274
+
275
+ self.cell = SequentialList()
276
+ for rep in range(config.fpn_cell_repeats):
277
+ logging.debug('building cell {}'.format(rep))
278
+ fpn_layer = BiFpnLayer(
279
+ feature_info=feature_info,
280
+ fpn_config=fpn_config,
281
+ fpn_channels=config.fpn_channels,
282
+ num_levels=config.num_levels,
283
+ pad_type=config.pad_type,
284
+ pooling_type=config.pooling_type,
285
+ norm_layer=norm_layer,
286
+ act_layer=act_layer,
287
+ separable_conv=config.separable_conv,
288
+ apply_bn_for_resampling=config.apply_bn_for_resampling,
289
+ conv_after_downsample=config.conv_after_downsample,
290
+ conv_bn_relu_pattern=config.conv_bn_relu_pattern,
291
+ redundant_bias=config.redundant_bias,
292
+ )
293
+ self.cell.add_module(str(rep), fpn_layer)
294
+ feature_info = fpn_layer.feature_info
295
+
296
+ def forward(self, x: List[torch.Tensor]):
297
+ for resample in self.resample.values():
298
+ x.append(resample(x[-1]))
299
+ x = self.cell(x)
300
+ return x
301
+
302
+
303
+ class HeadNet(nn.Module):
304
+
305
+ def __init__(self, config, num_outputs):
306
+ super(HeadNet, self).__init__()
307
+ self.num_levels = config.num_levels
308
+ self.bn_level_first = getattr(config, 'head_bn_level_first', False)
309
+ norm_layer = config.norm_layer or nn.BatchNorm2d
310
+ if config.norm_kwargs:
311
+ norm_layer = partial(norm_layer, **config.norm_kwargs)
312
+ act_layer = get_act_layer(config.act_type) or _ACT_LAYER
313
+
314
+ # Build convolution repeats
315
+ conv_fn = SeparableConv2d if config.separable_conv else ConvBnAct2d
316
+ conv_kwargs = dict(
317
+ in_channels=config.fpn_channels, out_channels=config.fpn_channels, kernel_size=3,
318
+ padding=config.pad_type, bias=config.redundant_bias, act_layer=None, norm_layer=None)
319
+ self.conv_rep = nn.ModuleList([conv_fn(**conv_kwargs) for _ in range(config.box_class_repeats)])
320
+
321
+ # Build batchnorm repeats. There is a unique batchnorm per feature level for each repeat.
322
+ # This can be organized with repeats first or feature levels first in module lists; the original models
323
+ # and weights were set up with repeats first, while levels first is required for efficient torchscript usage.
324
+ self.bn_rep = nn.ModuleList()
325
+ if self.bn_level_first:
326
+ for _ in range(self.num_levels):
327
+ self.bn_rep.append(nn.ModuleList([
328
+ norm_layer(config.fpn_channels) for _ in range(config.box_class_repeats)]))
329
+ else:
330
+ for _ in range(config.box_class_repeats):
331
+ self.bn_rep.append(nn.ModuleList([
332
+ nn.Sequential(OrderedDict([('bn', norm_layer(config.fpn_channels))]))
333
+ for _ in range(self.num_levels)]))
334
+
335
+ self.act = act_layer(inplace=True)
336
+
337
+ # Prediction (output) layer. Has bias with special init reqs, see init fn.
338
+ num_anchors = len(config.aspect_ratios) * config.num_scales
339
+ predict_kwargs = dict(
340
+ in_channels=config.fpn_channels, out_channels=num_outputs * num_anchors, kernel_size=3,
341
+ padding=config.pad_type, bias=True, norm_layer=None, act_layer=None)
342
+ self.predict = conv_fn(**predict_kwargs)
343
+
344
+ @torch.jit.ignore()
345
+ def toggle_bn_level_first(self):
346
+ """ Toggle the batchnorm layers between feature level first vs repeat first access pattern
347
+ Limitations in torchscript require feature levels to be iterated over first.
348
+
349
+ This function can be used to allow loading weights in the original order, and then toggle before
350
+ jit scripting the model.
351
+ """
352
+ with torch.no_grad():
353
+ new_bn_rep = nn.ModuleList()
354
+ for i in range(len(self.bn_rep[0])):
355
+ bn_first = nn.ModuleList()
356
+ for r in self.bn_rep.children():
357
+ m = r[i]
358
+ # NOTE original rep first model def has extra Sequential container with 'bn', this was
359
+ # flattened in the level first definition.
360
+ bn_first.append(m[0] if isinstance(m, nn.Sequential) else nn.Sequential(OrderedDict([('bn', m)])))
361
+ new_bn_rep.append(bn_first)
362
+ self.bn_level_first = not self.bn_level_first
363
+ self.bn_rep = new_bn_rep
364
+
365
+ @torch.jit.ignore()
366
+ def _forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
367
+ outputs = []
368
+ for level in range(self.num_levels):
369
+ x_level = x[level]
370
+ for conv, bn in zip(self.conv_rep, self.bn_rep):
371
+ x_level = conv(x_level)
372
+ x_level = bn[level](x_level) # this is not allowed in torchscript
373
+ x_level = self.act(x_level)
374
+ outputs.append(self.predict(x_level))
375
+ return outputs
376
+
377
+ def _forward_level_first(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
378
+ outputs = []
379
+ for level, bn_rep in enumerate(self.bn_rep): # iterating over first bn dim first makes TS happy
380
+ x_level = x[level]
381
+ for conv, bn in zip(self.conv_rep, bn_rep):
382
+ x_level = conv(x_level)
383
+ x_level = bn(x_level)
384
+ x_level = self.act(x_level)
385
+ outputs.append(self.predict(x_level))
386
+ return outputs
387
+
388
+ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
389
+ if self.bn_level_first:
390
+ return self._forward_level_first(x)
391
+ else:
392
+ return self._forward(x)
393
+
394
+
395
+ def _init_weight(m, n='', ):
396
+ """ Weight initialization as per Tensorflow official implementations.
397
+ """
398
+
399
+ def _fan_in_out(w, groups=1):
400
+ dimensions = w.dim()
401
+ if dimensions < 2:
402
+ raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
403
+ num_input_fmaps = w.size(1)
404
+ num_output_fmaps = w.size(0)
405
+ receptive_field_size = 1
406
+ if w.dim() > 2:
407
+ receptive_field_size = w[0][0].numel()
408
+ fan_in = num_input_fmaps * receptive_field_size
409
+ fan_out = num_output_fmaps * receptive_field_size
410
+ fan_out //= groups
411
+ return fan_in, fan_out
412
+
413
+ def _glorot_uniform(w, gain=1, groups=1):
414
+ fan_in, fan_out = _fan_in_out(w, groups)
415
+ gain /= max(1., (fan_in + fan_out) / 2.) # fan avg
416
+ limit = math.sqrt(3.0 * gain)
417
+ w.data.uniform_(-limit, limit)
418
+
419
+ def _variance_scaling(w, gain=1, groups=1):
420
+ fan_in, fan_out = _fan_in_out(w, groups)
421
+ gain /= max(1., fan_in) # fan in
422
+ # gain /= max(1., (fan_in + fan_out) / 2.) # fan
423
+
424
+ # should it be normal or trunc normal? using normal for now since no good trunc in PT
425
+ # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
426
+ # std = math.sqrt(gain) / .87962566103423978
427
+ # w.data.trunc_normal(std=std)
428
+ std = math.sqrt(gain)
429
+ w.data.normal_(std=std)
430
+
431
+ if isinstance(m, SeparableConv2d):
432
+ if 'box_net' in n or 'class_net' in n:
433
+ _variance_scaling(m.conv_dw.weight, groups=m.conv_dw.groups)
434
+ _variance_scaling(m.conv_pw.weight)
435
+ if m.conv_pw.bias is not None:
436
+ if 'class_net.predict' in n:
437
+ m.conv_pw.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
438
+ else:
439
+ m.conv_pw.bias.data.zero_()
440
+ else:
441
+ _glorot_uniform(m.conv_dw.weight, groups=m.conv_dw.groups)
442
+ _glorot_uniform(m.conv_pw.weight)
443
+ if m.conv_pw.bias is not None:
444
+ m.conv_pw.bias.data.zero_()
445
+ elif isinstance(m, ConvBnAct2d):
446
+ if 'box_net' in n or 'class_net' in n:
447
+ m.conv.weight.data.normal_(std=.01)
448
+ if m.conv.bias is not None:
449
+ if 'class_net.predict' in n:
450
+ m.conv.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
451
+ else:
452
+ m.conv.bias.data.zero_()
453
+ else:
454
+ _glorot_uniform(m.conv.weight)
455
+ if m.conv.bias is not None:
456
+ m.conv.bias.data.zero_()
457
+ elif isinstance(m, nn.BatchNorm2d):
458
+ # looks like all bn init the same?
459
+ m.weight.data.fill_(1.0)
460
+ m.bias.data.zero_()
461
+
462
+
463
+ def _init_weight_alt(m, n='', ):
464
+ """ Weight initialization alternative, based on EfficientNet bacbkone init w/ class bias addition
465
+ NOTE: this will likely be removed after some experimentation
466
+ """
467
+ if isinstance(m, nn.Conv2d):
468
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
469
+ fan_out //= m.groups
470
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
471
+ if m.bias is not None:
472
+ if 'class_net.predict' in n:
473
+ m.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
474
+ else:
475
+ m.bias.data.zero_()
476
+ elif isinstance(m, nn.BatchNorm2d):
477
+ m.weight.data.fill_(1.0)
478
+ m.bias.data.zero_()
479
+
480
+
481
+ def get_feature_info(backbone):
482
+ if isinstance(backbone.feature_info, Callable):
483
+ # old accessor for timm versions <= 0.1.30, efficientnet and mobilenetv3 and related nets only
484
+ feature_info = [dict(num_chs=f['num_chs'], reduction=f['reduction'])
485
+ for i, f in enumerate(backbone.feature_info())]
486
+ else:
487
+ # new feature info accessor, timm >= 0.2, all models supported
488
+ feature_info = backbone.feature_info.get_dicts(keys=['num_chs', 'reduction'])
489
+ return feature_info
490
+
491
+
492
+ class EfficientDet(nn.Module):
493
+
494
+ def __init__(self, config, pretrained_backbone=True, alternate_init=False):
495
+ super(EfficientDet, self).__init__()
496
+ self.config = config
497
+ set_config_readonly(self.config)
498
+ self.backbone = create_model(
499
+ config.backbone_name, features_only=True, out_indices=(2, 3, 4),
500
+ pretrained=pretrained_backbone, **config.backbone_args)
501
+ feature_info = get_feature_info(self.backbone)
502
+ self.fpn = BiFpn(self.config, feature_info)
503
+ self.class_net = HeadNet(self.config, num_outputs=self.config.num_classes)
504
+ self.box_net = HeadNet(self.config, num_outputs=4)
505
+
506
+ for n, m in self.named_modules():
507
+ if 'backbone' not in n:
508
+ if alternate_init:
509
+ _init_weight_alt(m, n)
510
+ else:
511
+ _init_weight(m, n)
512
+
513
+ @torch.jit.ignore()
514
+ def reset_head(self, num_classes=None, aspect_ratios=None, num_scales=None, alternate_init=False):
515
+ reset_class_head = False
516
+ reset_box_head = False
517
+ set_config_writeable(self.config)
518
+ if num_classes is not None:
519
+ reset_class_head = True
520
+ self.config.num_classes = num_classes
521
+ if aspect_ratios is not None:
522
+ reset_box_head = True
523
+ self.config.aspect_ratios = aspect_ratios
524
+ if num_scales is not None:
525
+ reset_box_head = True
526
+ self.config.num_scales = num_scales
527
+ set_config_readonly(self.config)
528
+
529
+ if reset_class_head:
530
+ self.class_net = HeadNet(self.config, num_outputs=self.config.num_classes)
531
+ for n, m in self.class_net.named_modules(prefix='class_net'):
532
+ if alternate_init:
533
+ _init_weight_alt(m, n)
534
+ else:
535
+ _init_weight(m, n)
536
+
537
+ if reset_box_head:
538
+ self.box_net = HeadNet(self.config, num_outputs=4)
539
+ for n, m in self.box_net.named_modules(prefix='box_net'):
540
+ if alternate_init:
541
+ _init_weight_alt(m, n)
542
+ else:
543
+ _init_weight(m, n)
544
+
545
+ @torch.jit.ignore()
546
+ def toggle_head_bn_level_first(self):
547
+ """ Toggle the head batchnorm layers between being access with feature_level first vs repeat
548
+ """
549
+ self.class_net.toggle_bn_level_first()
550
+ self.box_net.toggle_bn_level_first()
551
+
552
+ def forward(self, x):
553
+ x = self.backbone(x)
554
+ x = self.fpn(x)
555
+ x_class = self.class_net(x)
556
+ x_box = self.box_net(x)
557
+ return x_class, x_box
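A hedged construction sketch (not part of the diff): it assumes the vendored config helper get_efficientdet_config from efficientdet/effdet/config/model_config.py is importable as shown; the model name, input size, and class count are illustrative.

    import torch
    from effdet.config import get_efficientdet_config   # adjust to how the vendored package is on the path

    config = get_efficientdet_config('tf_efficientdet_d0')
    model = EfficientDet(config, pretrained_backbone=False)
    model.reset_head(num_classes=6)                      # hypothetical custom class count
    class_out, box_out = model(torch.randn(1, 3, 512, 512))
    # class_out / box_out: one tensor per pyramid level (P3..P7)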
efficientdet/effdet/evaluation/README.md ADDED
@@ -0,0 +1,7 @@
1
+ # Tensorflow Models Evaluation
2
+
3
+ The code in this folder has been extracted and adapted from evaluation/evaluator code at https://github.com/tensorflow/models/tree/master/research/object_detection/utils
4
+
5
+ Original code is licensed Apache 2.0, Copyright Google Inc.
6
+ https://github.com/tensorflow/models/blob/master/LICENSE
7
+
efficientdet/effdet/evaluation/__init__.py ADDED
File without changes
efficientdet/effdet/evaluation/detection_evaluator.py ADDED
@@ -0,0 +1,590 @@
1
+ from abc import ABCMeta
2
+ from abc import abstractmethod
3
+ #import collections
4
+ import logging
5
+ import unicodedata
6
+ import numpy as np
7
+
8
+ from .fields import InputDataFields, DetectionResultFields
9
+ from .object_detection_evaluation import ObjectDetectionEvaluation
10
+
11
+
12
+ def create_category_index(categories):
13
+ """Creates dictionary of COCO compatible categories keyed by category id.
14
+ Args:
15
+ categories: a list of dicts, each of which has the following keys:
16
+ 'id': (required) an integer id uniquely identifying this category.
17
+ 'name': (required) string representing category name e.g., 'cat', 'dog', 'pizza'.
18
+ Returns:
19
+ category_index: a dict containing the same entries as categories, but keyed
20
+ by the 'id' field of each category.
21
+ """
22
+ category_index = {}
23
+ for cat in categories:
24
+ category_index[cat['id']] = cat
25
+ return category_index
26
+
27
+
28
+ class DetectionEvaluator(metaclass=ABCMeta):
29
+ """Interface for object detection evaluation classes.
30
+ Example usage of the Evaluator:
31
+ ------------------------------
32
+ evaluator = DetectionEvaluator(categories)
33
+ # Detections and groundtruth for image 1.
34
+ evaluator.add_single_ground_truth_image_info(...)
35
+ evaluator.add_single_detected_image_info(...)
36
+ # Detections and groundtruth for image 2.
37
+ evaluator.add_single_ground_truth_image_info(...)
38
+ evaluator.add_single_detected_image_info(...)
39
+ metrics_dict = evaluator.evaluate()
40
+ """
41
+
42
+ def __init__(self, categories):
43
+ """Constructor.
44
+ Args:
45
+ categories: A list of dicts, each of which has the following keys -
46
+ 'id': (required) an integer id uniquely identifying this category.
47
+ 'name': (required) string representing category name e.g., 'cat', 'dog'.
48
+ """
49
+ self._categories = categories
50
+
51
+ def observe_result_dict_for_single_example(self, eval_dict):
52
+ """Observes an evaluation result dict for a single example.
53
+ When executing eagerly, once all observations have been observed by this
54
+ method you can use `.evaluate()` to get the final metrics.
55
+ When using `tf.estimator.Estimator` for evaluation this function is used by
56
+ `get_estimator_eval_metric_ops()` to construct the metric update op.
57
+ Args:
58
+ eval_dict: A dictionary that holds tensors for evaluating an object
59
+ detection model, returned from
60
+ eval_util.result_dict_for_single_example().
61
+ Returns:
62
+ None when executing eagerly, or an update_op that can be used to update
63
+ the eval metrics in `tf.estimator.EstimatorSpec`.
64
+ """
65
+ raise NotImplementedError('Not implemented for this evaluator!')
66
+
67
+ @abstractmethod
68
+ def add_single_ground_truth_image_info(self, image_id, gt_dict):
69
+ """Adds groundtruth for a single image to be used for evaluation.
70
+ Args:
71
+ image_id: A unique string/integer identifier for the image.
72
+ gt_dict: A dictionary of groundtruth numpy arrays required for evaluations.
73
+ """
74
+ pass
75
+
76
+ @abstractmethod
77
+ def add_single_detected_image_info(self, image_id, detections_dict):
78
+ """Adds detections for a single image to be used for evaluation.
79
+ Args:
80
+ image_id: A unique string/integer identifier for the image.
81
+ detections_dict: A dictionary of detection numpy arrays required for evaluation.
82
+ """
83
+ pass
84
+
85
+ @abstractmethod
86
+ def evaluate(self):
87
+ """Evaluates detections and returns a dictionary of metrics."""
88
+ pass
89
+
90
+ @abstractmethod
91
+ def clear(self):
92
+ """Clears the state to prepare for a fresh evaluation."""
93
+ pass
94
+
95
+
96
+ class ObjectDetectionEvaluator(DetectionEvaluator):
97
+ """A class to evaluate detections."""
98
+
99
+ def __init__(self,
100
+ categories,
101
+ matching_iou_threshold=0.5,
102
+ recall_lower_bound=0.0,
103
+ recall_upper_bound=1.0,
104
+ evaluate_corlocs=False,
105
+ evaluate_precision_recall=False,
106
+ metric_prefix=None,
107
+ use_weighted_mean_ap=False,
108
+ evaluate_masks=False,
109
+ group_of_weight=0.0):
110
+ """Constructor.
111
+ Args:
112
+ categories: A list of dicts, each of which has the following keys -
113
+ 'id': (required) an integer id uniquely identifying this category.
114
+ 'name': (required) string representing category name e.g., 'cat', 'dog'.
115
+ matching_iou_threshold: IOU threshold to use for matching groundtruth boxes to detection boxes.
116
+ recall_lower_bound: lower bound of recall operating area.
117
+ recall_upper_bound: upper bound of recall operating area.
118
+ evaluate_corlocs: (optional) boolean which determines if corloc scores are to be returned or not.
119
+ evaluate_precision_recall: (optional) boolean which determines if
120
+ precision and recall values are to be returned or not.
121
+ metric_prefix: (optional) string prefix for metric name; if None, no prefix is used.
122
+ use_weighted_mean_ap: (optional) boolean which determines if the mean
123
+ average precision is computed directly from the scores and tp_fp_labels of all classes.
124
+ evaluate_masks: If False, evaluation will be performed based on boxes. If
125
+ True, mask evaluation will be performed instead.
126
+ group_of_weight: Weight of group-of boxes. If set to 0, detections of the
127
+ correct class within a group-of box are ignored. If weight is > 0, then
128
+ if at least one detection falls within a group-of box with
129
+ matching_iou_threshold, weight group_of_weight is added to true
130
+ positives. Consequently, if no detection falls within a group-of box,
131
+ weight group_of_weight is added to false negatives.
132
+ Raises:
133
+ ValueError: If the category ids are not 1-indexed.
134
+ """
135
+ super(ObjectDetectionEvaluator, self).__init__(categories)
136
+ self._num_classes = max([cat['id'] for cat in categories])
137
+ if min(cat['id'] for cat in categories) < 1:
138
+ raise ValueError('Classes should be 1-indexed.')
139
+ self._matching_iou_threshold = matching_iou_threshold
140
+ self._recall_lower_bound = recall_lower_bound
141
+ self._recall_upper_bound = recall_upper_bound
142
+ self._use_weighted_mean_ap = use_weighted_mean_ap
143
+ self._label_id_offset = 1
144
+ self._evaluate_masks = evaluate_masks
145
+ self._group_of_weight = group_of_weight
146
+ self._evaluation = ObjectDetectionEvaluation(
147
+ num_gt_classes=self._num_classes,
148
+ matching_iou_threshold=self._matching_iou_threshold,
149
+ recall_lower_bound=self._recall_lower_bound,
150
+ recall_upper_bound=self._recall_upper_bound,
151
+ use_weighted_mean_ap=self._use_weighted_mean_ap,
152
+ label_id_offset=self._label_id_offset,
153
+ group_of_weight=self._group_of_weight)
154
+ self._image_ids = set([])
155
+ self._evaluate_corlocs = evaluate_corlocs
156
+ self._evaluate_precision_recall = evaluate_precision_recall
157
+ self._metric_prefix = (metric_prefix + '_') if metric_prefix else ''
158
+ self._build_metric_names()
159
+
160
+ def _build_metric_names(self):
161
+ """Builds a list with metric names."""
162
+ if self._recall_lower_bound > 0.0 or self._recall_upper_bound < 1.0:
163
+ self._metric_names = [
164
+ self._metric_prefix + 'Precision/mAP@{}IOU@[{:.1f},{:.1f}]Recall'.format(
165
+ self._matching_iou_threshold, self._recall_lower_bound, self._recall_upper_bound)
166
+ ]
167
+ else:
168
+ self._metric_names = [
169
+ self._metric_prefix + 'Precision/mAP@{}IOU'.format(self._matching_iou_threshold)
170
+ ]
171
+ if self._evaluate_corlocs:
172
+ self._metric_names.append(
173
+ self._metric_prefix + 'Precision/meanCorLoc@{}IOU'.format(self._matching_iou_threshold))
174
+
175
+ category_index = create_category_index(self._categories)
176
+ for idx in range(self._num_classes):
177
+ if idx + self._label_id_offset in category_index:
178
+ category_name = category_index[idx + self._label_id_offset]['name']
179
+ category_name = unicodedata.normalize('NFKD', category_name)
180
+ self._metric_names.append(
181
+ self._metric_prefix + 'PerformanceByCategory/AP@{}IOU/{}'.format(
182
+ self._matching_iou_threshold, category_name))
183
+ if self._evaluate_corlocs:
184
+ self._metric_names.append(
185
+ self._metric_prefix + 'PerformanceByCategory/CorLoc@{}IOU/{}'.format(
186
+ self._matching_iou_threshold, category_name))
187
+
188
+ def add_single_ground_truth_image_info(self, image_id, gt_dict):
189
+ """Adds groundtruth for a single image to be used for evaluation.
190
+ Args:
191
+ image_id: A unique string/integer identifier for the image.
192
+ gt_dict: A dictionary containing -
193
+ InputDataFields.gt_boxes: float32 numpy array
194
+ of shape [num_boxes, 4] containing `num_boxes` groundtruth boxes of
195
+ the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
196
+ InputDataFields.gt_classes: integer numpy array
197
+ of shape [num_boxes] containing 1-indexed groundtruth classes for the boxes.
198
+ InputDataFields.gt_difficult: Optional length M numpy boolean array
199
+ denoting whether a ground truth box is a difficult instance or not.
200
+ This field is optional to support the case that no boxes are difficult.
201
+ InputDataFields.gt_instance_masks: Optional numpy array of shape
202
+ [num_boxes, height, width] with values in {0, 1}.
203
+ Raises:
204
+ ValueError: On adding groundtruth for an image more than once. Will also
205
+ raise error if instance masks are not in groundtruth dictionary.
206
+ """
207
+ if image_id in self._image_ids:
208
+ return
209
+
210
+ gt_classes = gt_dict[InputDataFields.gt_classes] - self._label_id_offset
211
+ # If the key is not present in the gt_dict or the array is empty
212
+ # (unless there are no annotations for the groundtruth on this image)
213
+ # use values from the dictionary or insert None otherwise.
214
+ if (InputDataFields.gt_difficult in gt_dict and
215
+ (gt_dict[InputDataFields.gt_difficult].size or not gt_classes.size)):
216
+ gt_difficult = gt_dict[InputDataFields.gt_difficult]
217
+ else:
218
+ gt_difficult = None
219
+ # FIXME disable difficult flag warning, will support flag eventually
220
+ # if not len(self._image_ids) % 1000:
221
+ # logging.warning('image %s does not have groundtruth difficult flag specified', image_id)
222
+ gt_masks = None
223
+ if self._evaluate_masks:
224
+ if InputDataFields.gt_instance_masks not in gt_dict:
225
+ raise ValueError('Instance masks not in groundtruth dictionary.')
226
+ gt_masks = gt_dict[InputDataFields.gt_instance_masks]
227
+ self._evaluation.add_single_ground_truth_image_info(
228
+ image_key=image_id,
229
+ gt_boxes=gt_dict[InputDataFields.gt_boxes],
230
+ gt_class_labels=gt_classes,
231
+ gt_is_difficult_list=gt_difficult,
232
+ gt_masks=gt_masks)
233
+ self._image_ids.update([image_id])
234
+
235
+ def add_single_detected_image_info(self, image_id, detections_dict):
236
+ """Adds detections for a single image to be used for evaluation.
237
+ Args:
238
+ image_id: A unique string/integer identifier for the image.
239
+ detections_dict: A dictionary containing -
240
+ DetectionResultFields.detection_boxes: float32 numpy
241
+ array of shape [num_boxes, 4] containing `num_boxes` detection boxes
242
+ of the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
243
+ DetectionResultFields.detection_scores: float32 numpy
244
+ array of shape [num_boxes] containing detection scores for the boxes.
245
+ DetectionResultFields.detection_classes: integer numpy
246
+ array of shape [num_boxes] containing 1-indexed detection classes for the boxes.
247
+ DetectionResultFields.detection_masks: uint8 numpy array
248
+ of shape [num_boxes, height, width] containing `num_boxes` masks of
249
+ values ranging between 0 and 1.
250
+ Raises:
251
+ ValueError: If detection masks are not in detections dictionary.
252
+ """
253
+ detection_classes = detections_dict[DetectionResultFields.detection_classes] - self._label_id_offset
254
+ detection_masks = None
255
+ if self._evaluate_masks:
256
+ if DetectionResultFields.detection_masks not in detections_dict:
257
+ raise ValueError('Detection masks not in detections dictionary.')
258
+ detection_masks = detections_dict[DetectionResultFields.detection_masks]
259
+ self._evaluation.add_single_detected_image_info(
260
+ image_key=image_id,
261
+ detected_boxes=detections_dict[DetectionResultFields.detection_boxes],
262
+ detected_scores=detections_dict[DetectionResultFields.detection_scores],
263
+ detected_class_labels=detection_classes,
264
+ detected_masks=detection_masks)
265
+
266
+ def evaluate(self):
267
+ """Compute evaluation result.
268
+ Returns:
269
+ A dictionary of metrics with the following fields -
270
+ 1. summary_metrics:
271
+ '<prefix if not empty>_Precision/mAP@<matching_iou_threshold>IOU': mean
272
+ average precision at the specified IOU threshold.
273
+ 2. per_category_ap: category specific results with keys of the form
274
+ '<prefix if not empty>_PerformanceByCategory/
275
+ mAP@<matching_iou_threshold>IOU/category'.
276
+ """
277
+ metrics = self._evaluation.evaluate()
278
+ pascal_metrics = {self._metric_names[0]: metrics['mean_ap']}
279
+ if self._evaluate_corlocs:
280
+ pascal_metrics[self._metric_names[1]] = metrics['mean_corloc']
281
+ category_index = create_category_index(self._categories)
282
+ for idx in range(metrics['per_class_ap'].size):
283
+ if idx + self._label_id_offset in category_index:
284
+ category_name = category_index[idx + self._label_id_offset]['name']
285
+ category_name = unicodedata.normalize('NFKD', category_name)
286
+ display_name = self._metric_prefix + 'PerformanceByCategory/AP@{}IOU/{}'.format(
287
+ self._matching_iou_threshold, category_name)
288
+ pascal_metrics[display_name] = metrics['per_class_ap'][idx]
289
+
290
+ # Optionally add precision and recall values
291
+ if self._evaluate_precision_recall:
292
+ display_name = self._metric_prefix + 'PerformanceByCategory/Precision@{}IOU/{}'.format(
293
+ self._matching_iou_threshold, category_name)
294
+ pascal_metrics[display_name] = metrics['per_class_precision'][idx]
295
+ display_name = self._metric_prefix + 'PerformanceByCategory/Recall@{}IOU/{}'.format(
296
+ self._matching_iou_threshold, category_name)
297
+ pascal_metrics[display_name] = metrics['per_class_recall'][idx]
298
+
299
+ # Optionally add CorLoc metrics.
300
+ if self._evaluate_corlocs:
301
+ display_name = self._metric_prefix + 'PerformanceByCategory/CorLoc@{}IOU/{}'.format(
302
+ self._matching_iou_threshold, category_name)
303
+ pascal_metrics[display_name] = metrics['per_class_corloc'][idx]
304
+
305
+ return pascal_metrics
306
+
307
+ def clear(self):
308
+ """Clears the state to prepare for a fresh evaluation."""
309
+ self._evaluation = ObjectDetectionEvaluation(
310
+ num_gt_classes=self._num_classes,
311
+ matching_iou_threshold=self._matching_iou_threshold,
312
+ use_weighted_mean_ap=self._use_weighted_mean_ap,
313
+ label_id_offset=self._label_id_offset)
314
+ self._image_ids.clear()
315
+
316
+
317
+ class PascalDetectionEvaluator(ObjectDetectionEvaluator):
318
+ """A class to evaluation detections using PASCAL metrics."""
319
+
320
+ def __init__(self, categories, matching_iou_threshold=0.5):
321
+ super(PascalDetectionEvaluator, self).__init__(
322
+ categories,
323
+ matching_iou_threshold=matching_iou_threshold,
324
+ evaluate_corlocs=False,
325
+ metric_prefix='PascalBoxes',
326
+ use_weighted_mean_ap=False)
327
+
328
+
329
+ class WeightedPascalDetectionEvaluator(ObjectDetectionEvaluator):
330
+ """A class to evaluation detections using weighted PASCAL metrics.
331
+ Weighted PASCAL metrics computes the mean average precision as the average
332
+ precision given the scores and tp_fp_labels of all classes. In comparison,
333
+ PASCAL metrics computes the mean average precision as the mean of the
334
+ per-class average precisions.
335
+ This definition is very similar to the mean of the per-class average
336
+ precisions weighted by class frequency. However, they are typically not the
337
+ same, since average precision is not a linear function of the scores and
338
+ tp_fp_labels.
339
+ """
340
+
341
+ def __init__(self, categories, matching_iou_threshold=0.5):
342
+ super(WeightedPascalDetectionEvaluator, self).__init__(
343
+ categories,
344
+ matching_iou_threshold=matching_iou_threshold,
345
+ evaluate_corlocs=False,
346
+ metric_prefix='WeightedPascalBoxes',
347
+ use_weighted_mean_ap=True)
348
+
349
+
350
+ class PrecisionAtRecallDetectionEvaluator(ObjectDetectionEvaluator):
351
+ """A class to evaluation detections using precision@recall metrics."""
352
+
353
+ def __init__(self,
354
+ categories,
355
+ matching_iou_threshold=0.5,
356
+ recall_lower_bound=0.,
357
+ recall_upper_bound=1.0):
358
+ super(PrecisionAtRecallDetectionEvaluator, self).__init__(
359
+ categories,
360
+ matching_iou_threshold=matching_iou_threshold,
361
+ recall_lower_bound=recall_lower_bound,
362
+ recall_upper_bound=recall_upper_bound,
363
+ evaluate_corlocs=False,
364
+ metric_prefix='PrecisionAtRecallBoxes',
365
+ use_weighted_mean_ap=False)
366
+
367
+
368
+ class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):
369
+ """A class to evaluation detections using Open Images V2 metrics.
370
+ Open Images V2 introduces the group_of type of bounding boxes, and this metric
371
+ handles those boxes appropriately.
372
+ """
373
+
374
+ def __init__(self,
375
+ categories,
376
+ matching_iou_threshold=0.5,
377
+ evaluate_masks=False,
378
+ evaluate_corlocs=False,
379
+ metric_prefix='OpenImagesV5',
380
+ group_of_weight=0.0):
381
+ """Constructor.
382
+ Args:
383
+ categories: A list of dicts, each of which has the following keys -
384
+ 'id': (required) an integer id uniquely identifying this category.
385
+ 'name': (required) string representing category name e.g., 'cat', 'dog'.
386
+ matching_iou_threshold: IOU threshold to use for matching groundtruth
387
+ boxes to detection boxes.
388
+ evaluate_masks: if True, evaluator evaluates masks.
389
+ evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
390
+ metric_prefix: Prefix name of the metric.
391
+ group_of_weight: Weight of the group-of bounding box. If set to 0 (default
392
+ for Open Images V2 detection protocol), detections of the correct class
393
+ within a group-of box are ignored. If weight is > 0, then if at least
394
+ one detection falls within a group-of box with matching_iou_threshold,
395
+ weight group_of_weight is added to true positives. Consequently, if no
396
+ detection falls within a group-of box, weight group_of_weight is added
397
+ to false negatives.
398
+ """
399
+
400
+ super(OpenImagesDetectionEvaluator, self).__init__(
401
+ categories,
402
+ matching_iou_threshold,
403
+ evaluate_corlocs,
404
+ metric_prefix=metric_prefix,
405
+ group_of_weight=group_of_weight,
406
+ evaluate_masks=evaluate_masks)
407
+
408
+ def add_single_ground_truth_image_info(self, image_id, gt_dict):
409
+ """Adds groundtruth for a single image to be used for evaluation.
410
+ Args:
411
+ image_id: A unique string/integer identifier for the image.
412
+ gt_dict: A dictionary containing -
413
+ InputDataFields.gt_boxes: float32 numpy array
414
+ of shape [num_boxes, 4] containing `num_boxes` groundtruth boxes of
415
+ the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
416
+ InputDataFields.gt_classes: integer numpy array
417
+ of shape [num_boxes] containing 1-indexed groundtruth classes for the boxes.
418
+ InputDataFields.gt_group_of: Optional length M
419
+ numpy boolean array denoting whether a groundtruth box contains a group of instances.
420
+ Raises:
421
+ ValueError: On adding groundtruth for an image more than once.
422
+ """
423
+ if image_id in self._image_ids:
424
+ return
425
+
426
+ gt_classes = (gt_dict[InputDataFields.gt_classes] - self._label_id_offset)
427
+ # If the key is not present in the gt_dict or the array is empty
428
+ # (unless there are no annotations for the groundtruth on this image)
429
+ # use values from the dictionary or insert None otherwise.
430
+ if (InputDataFields.gt_group_of in gt_dict and
431
+ (gt_dict[InputDataFields.gt_group_of].size or not gt_classes.size)):
432
+ gt_group_of = gt_dict[InputDataFields.gt_group_of]
433
+ else:
434
+ gt_group_of = None
435
+ # FIXME disable warning for now, will add group_of flag eventually
436
+ # if not len(self._image_ids) % 1000:
437
+ # logging.warning('image %s does not have groundtruth group_of flag specified', image_id)
438
+ if self._evaluate_masks:
439
+ gt_masks = gt_dict[InputDataFields.gt_instance_masks]
440
+ else:
441
+ gt_masks = None
442
+
443
+ self._evaluation.add_single_ground_truth_image_info(
444
+ image_id,
445
+ gt_dict[InputDataFields.gt_boxes],
446
+ gt_classes,
447
+ gt_is_difficult_list=None,
448
+ gt_is_group_of_list=gt_group_of,
449
+ gt_masks=gt_masks)
450
+ self._image_ids.update([image_id])
451
+
452
+
453
+ class OpenImagesChallengeEvaluator(OpenImagesDetectionEvaluator):
454
+ """A class implements Open Images Challenge metrics.
455
+ Both Detection and Instance Segmentation evaluation metrics are implemented.
456
+ Open Images Challenge Detection metric has two major changes in comparison
457
+ with Open Images V2 detection metric:
458
+ - a custom weight might be specified for detecting an object contained in a group-of box.
459
+ - verified image-level labels should be explicitly provided for evaluation: in case an
460
+ image has neither positive nor negative image level label of class c, all detections of
461
+ this class on this image will be ignored.
462
+
463
+ Open Images Challenge Instance Segmentation metric allows measuring the performance
464
+ of models with incomplete annotations: some instances are
465
+ annotated only at box level and some only at image level. In addition,
466
+ image-level labels are taken into account as in detection metric.
467
+
468
+ Open Images Challenge Detection metric default parameters:
469
+ evaluate_masks = False
470
+ group_of_weight = 1.0
471
+
472
+ Open Images Challenge Instance Segmentation metric default parameters:
473
+ evaluate_masks = True
474
+ (group_of_weight will not matter)
475
+ """
476
+
477
+ def __init__(
478
+ self,
479
+ categories,
480
+ evaluate_masks=False,
481
+ matching_iou_threshold=0.5,
482
+ evaluate_corlocs=False,
483
+ group_of_weight=1.0):
484
+ """Constructor.
485
+ Args:
486
+ categories: A list of dicts, each of which has the following keys -
487
+ 'id': (required) an integer id uniquely identifying this category.
488
+ 'name': (required) string representing category name e.g., 'cat', 'dog'.
489
+ evaluate_masks: set to true for instance segmentation metric and to false
490
+ for detection metric.
491
+ matching_iou_threshold: IOU threshold to use for matching groundtruth
492
+ boxes to detection boxes.
493
+ evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
494
+ group_of_weight: Weight of group-of boxes. If set to 0, detections of the
495
+ correct class within a group-of box are ignored. If weight is > 0, then
496
+ if at least one detection falls within a group-of box with
497
+ matching_iou_threshold, weight group_of_weight is added to true
498
+ positives. Consequently, if no detection falls within a group-of box,
499
+ weight group_of_weight is added to false negatives.
500
+ """
501
+ if not evaluate_masks:
502
+ metrics_prefix = 'OpenImagesDetectionChallenge'
503
+ else:
504
+ metrics_prefix = 'OpenImagesInstanceSegmentationChallenge'
505
+
506
+ super(OpenImagesChallengeEvaluator, self).__init__(
507
+ categories,
508
+ matching_iou_threshold,
509
+ evaluate_masks=evaluate_masks,
510
+ evaluate_corlocs=evaluate_corlocs,
511
+ group_of_weight=group_of_weight,
512
+ metric_prefix=metrics_prefix)
513
+
514
+ self._evaluatable_labels = {}
515
+
516
+ def add_single_ground_truth_image_info(self, image_id, gt_dict):
517
+ """Adds groundtruth for a single image to be used for evaluation.
518
+ Args:
519
+ image_id: A unique string/integer identifier for the image.
520
+ gt_dict: A dictionary containing -
521
+ InputDataFields.gt_boxes: float32 numpy array of shape [num_boxes, 4]
522
+ containing `num_boxes` groundtruth boxes of the format [ymin, xmin, ymax, xmax]
523
+ in absolute image coordinates.
524
+ InputDataFields.gt_classes: integer numpy array of shape [num_boxes]
525
+ containing 1-indexed groundtruth classes for the boxes.
526
+ InputDataFields.gt_image_classes: integer 1D
527
+ numpy array containing all classes for which labels are verified.
528
+ InputDataFields.gt_group_of: Optional length M
529
+ numpy boolean array denoting whether a groundtruth box contains a group of instances.
530
+ Raises:
531
+ ValueError: On adding groundtruth for an image more than once.
532
+ """
533
+ super(OpenImagesChallengeEvaluator,
534
+ self).add_single_ground_truth_image_info(image_id, gt_dict)
535
+ input_fields = InputDataFields
536
+ gt_classes = gt_dict[input_fields.gt_classes] - self._label_id_offset
537
+ image_classes = np.array([], dtype=int)
538
+ if input_fields.gt_image_classes in gt_dict:
539
+ image_classes = gt_dict[input_fields.gt_image_classes]
540
+ elif input_fields.gt_labeled_classes in gt_dict:
541
+ image_classes = gt_dict[input_fields.gt_labeled_classes]
542
+ image_classes -= self._label_id_offset
543
+ self._evaluatable_labels[image_id] = np.unique(
544
+ np.concatenate((image_classes, gt_classes)))
545
+
546
+ def add_single_detected_image_info(self, image_id, detections_dict):
547
+ """Adds detections for a single image to be used for evaluation.
548
+ Args:
549
+ image_id: A unique string/integer identifier for the image.
550
+ detections_dict: A dictionary containing -
551
+ DetectionResultFields.detection_boxes: float32 numpy
552
+ array of shape [num_boxes, 4] containing `num_boxes` detection boxes
553
+ of the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
554
+ DetectionResultFields.detection_scores: float32 numpy
555
+ array of shape [num_boxes] containing detection scores for the boxes.
556
+ DetectionResultFields.detection_classes: integer numpy
557
+ array of shape [num_boxes] containing 1-indexed detection classes for
558
+ the boxes.
559
+ Raises:
560
+ ValueError: If detection masks are not in detections dictionary.
561
+ """
562
+ if image_id not in self._image_ids:
563
+ # The evaluator assumes groundtruth is inserted first; if it is not, register the
564
+ # image with an empty set of evaluatable labels so its detections are ignored.
565
+ self._image_ids.update([image_id])
566
+ self._evaluatable_labels[image_id] = np.array([])
567
+
568
+ detection_classes = detections_dict[DetectionResultFields.detection_classes] - self._label_id_offset
569
+ allowed_classes = np.where(np.isin(detection_classes, self._evaluatable_labels[image_id]))
570
+ detection_classes = detection_classes[allowed_classes]
571
+ detected_boxes = detections_dict[DetectionResultFields.detection_boxes][allowed_classes]
572
+ detected_scores = detections_dict[DetectionResultFields.detection_scores][allowed_classes]
573
+
574
+ if self._evaluate_masks:
575
+ detection_masks = detections_dict[DetectionResultFields.detection_masks][allowed_classes]
576
+ else:
577
+ detection_masks = None
578
+ self._evaluation.add_single_detected_image_info(
579
+ image_key=image_id,
580
+ detected_boxes=detected_boxes,
581
+ detected_scores=detected_scores,
582
+ detected_class_labels=detection_classes,
583
+ detected_masks=detection_masks)
584
+
585
+ def clear(self):
586
+ """Clears stored data."""
587
+
588
+ super(OpenImagesChallengeEvaluator, self).clear()
589
+ self._evaluatable_labels.clear()
590
+
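The evaluator classes above share an add-groundtruth / add-detections / evaluate workflow, and groundtruth must be registered before the corresponding detections. A minimal usage sketch follows; the module paths are inferred from this commit's directory layout, and the category names, boxes, and scores are made-up illustration values, not outputs of this app.

import numpy as np
from effdet.evaluation.detection_evaluator import PascalDetectionEvaluator
from effdet.evaluation.fields import InputDataFields, DetectionResultFields

categories = [{'id': 1, 'name': 'plastic'}, {'id': 2, 'name': 'metal'}]
evaluator = PascalDetectionEvaluator(categories, matching_iou_threshold=0.5)

# Groundtruth first, then the detections for the same image id.
evaluator.add_single_ground_truth_image_info(
    image_id='img_0',
    gt_dict={
        InputDataFields.gt_boxes: np.array([[10., 10., 50., 50.]], dtype=np.float32),
        InputDataFields.gt_classes: np.array([1], dtype=np.int64),
    })
evaluator.add_single_detected_image_info(
    image_id='img_0',
    detections_dict={
        DetectionResultFields.detection_boxes: np.array([[12., 11., 48., 52.]], dtype=np.float32),
        DetectionResultFields.detection_scores: np.array([0.9], dtype=np.float32),
        DetectionResultFields.detection_classes: np.array([1], dtype=np.int64),
    })
metrics = evaluator.evaluate()  # keys follow the '<prefix>_Precision/mAP@<iou>IOU' pattern described above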
efficientdet/effdet/evaluation/fields.py ADDED
@@ -0,0 +1,105 @@
1
+
2
+ class InputDataFields(object):
3
+ """Names for the input tensors.
4
+ Holds the standard data field names to use for identifying input tensors. This
5
+ should be used by the decoder to identify keys for the returned tensor_dict
6
+ containing input tensors. And it should be used by the model to identify the
7
+ tensors it needs.
8
+ Attributes:
9
+ image: image.
10
+ image_additional_channels: additional channels.
11
+ key: unique key corresponding to image.
12
+ filename: original filename of the dataset (without common path).
13
+ gt_image_classes: image-level class labels.
14
+ gt_image_confidences: image-level class confidences.
15
+ gt_labeled_classes: image-level annotation that indicates the
16
+ classes for which an image has been labeled.
17
+ gt_boxes: coordinates of the ground truth boxes in the image.
18
+ gt_classes: box-level class labels.
19
+ gt_confidences: box-level class confidences. The shape should be
20
+ the same as the shape of gt_classes.
21
+ gt_label_types: box-level label types (e.g. explicit negative).
22
+ gt_is_crowd: [DEPRECATED, use gt_group_of instead]
23
+ is the groundtruth a single object or a crowd.
24
+ gt_area: area of a groundtruth segment.
25
+ gt_difficult: is a `difficult` object
26
+ gt_group_of: is a `group_of` objects, e.g. multiple objects of the
27
+ same class, forming a connected group, where instances are heavily
28
+ occluding each other.
29
+ gt_instance_masks: ground truth instance masks.
30
+ gt_instance_boundaries: ground truth instance boundaries.
31
+ gt_instance_classes: instance mask-level class labels.
32
+ gt_label_weights: groundtruth label weights.
33
+ gt_weights: groundtruth weight factor for bounding boxes.
34
+ image_height: height of images, used to decode
35
+ image_width: width of images, used to decode
36
+ """
37
+ image = 'image'
38
+ key = 'image_id'
39
+ filename = 'filename'
40
+ gt_boxes = 'bbox'
41
+ gt_classes = 'cls'
42
+ gt_confidences = 'confidences'
43
+ gt_label_types = 'label_types'
44
+ gt_image_classes = 'img_cls'
45
+ gt_image_confidences = 'img_confidences'
46
+ gt_labeled_classes = 'labeled_cls'
47
+ gt_is_crowd = 'is_crowd'
48
+ gt_area = 'area'
49
+ gt_difficult = 'difficult'
50
+ gt_group_of = 'group_of'
51
+ gt_instance_masks = 'instance_masks'
52
+ gt_instance_boundaries = 'instance_boundaries'
53
+ gt_instance_classes = 'instance_classes'
54
+ image_height = 'img_height'
55
+ image_width = 'img_width'
56
+ image_size = 'img_size'
57
+
58
+
59
+ class DetectionResultFields(object):
60
+ """Naming conventions for storing the output of the detector.
61
+ Attributes:
62
+ source_id: source of the original image.
63
+ key: unique key corresponding to image.
64
+ detection_boxes: coordinates of the detection boxes in the image.
65
+ detection_scores: detection scores for the detection boxes in the image.
66
+ detection_multiclass_scores: class score distribution (including background)
67
+ for detection boxes in the image.
68
+ detection_classes: detection-level class labels.
69
+ detection_masks: contains a segmentation mask for each detection box.
70
+ """
71
+
72
+ key = 'image_id'
73
+ detection_boxes = 'bbox'
74
+ detection_scores = 'score'
75
+ detection_classes = 'cls'
76
+ detection_masks = 'masks'
77
+
78
+
79
+ class BoxListFields(object):
80
+ """Naming conventions for BoxLists.
81
+ Attributes:
82
+ boxes: bounding box coordinates.
83
+ classes: classes per bounding box.
84
+ scores: scores per bounding box.
85
+ weights: sample weights per bounding box.
86
+ objectness: objectness score per bounding box.
87
+ masks: masks per bounding box.
88
+ boundaries: boundaries per bounding box.
89
+ keypoints: keypoints per bounding box.
90
+ keypoint_heatmaps: keypoint heatmaps per bounding box.
91
+ is_crowd: is_crowd annotation per bounding box.
92
+ """
93
+ boxes = 'boxes'
94
+ classes = 'classes'
95
+ scores = 'scores'
96
+ weights = 'weights'
97
+ confidences = 'confidences'
98
+ objectness = 'objectness'
99
+ masks = 'masks'
100
+ boundaries = 'boundaries'
101
+ keypoints = 'keypoints'
102
+ keypoint_visibilities = 'keypoint_visibilities'
103
+ keypoint_heatmaps = 'keypoint_heatmaps'
104
+ is_crowd = 'is_crowd'
105
+ group_of = 'group_of'
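These field classes are plain namespaces of shared string constants; the evaluators index groundtruth and detection dictionaries with them instead of raw strings. A tiny sketch, with the module path inferred from this commit's layout:

import numpy as np
from effdet.evaluation.fields import InputDataFields

gt_dict = {
    InputDataFields.gt_boxes: np.zeros((0, 4), dtype=np.float32),
    InputDataFields.gt_classes: np.zeros((0,), dtype=np.int64),
}
# InputDataFields.gt_boxes is just the string 'bbox', so both spellings address the same entry.
assert gt_dict[InputDataFields.gt_boxes] is gt_dict['bbox']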
efficientdet/effdet/evaluation/metrics.py ADDED
@@ -0,0 +1,148 @@
1
+ import numpy as np
2
+
3
+
4
+ def compute_precision_recall(scores, labels, num_gt):
5
+ """Compute precision and recall.
6
+ Args:
7
+ scores: A float numpy array representing detection score
8
+ labels: A float numpy array representing weighted true/false positive labels
9
+ num_gt: Number of ground truth instances
10
+ Raises:
11
+ ValueError: if the input is not of the correct format
12
+ Returns:
13
+ precision: Fraction of positive instances over detected ones. This value is
14
+ None if no ground truth labels are present.
15
+ recall: Fraction of detected positive instance over all positive instances.
16
+ This value is None if no ground truth labels are present.
17
+ """
18
+ if not isinstance(labels, np.ndarray) or len(labels.shape) != 1:
19
+ raise ValueError("labels must be single dimension numpy array")
20
+
21
+ if labels.dtype != float and labels.dtype != bool:
22
+ raise ValueError("labels type must be either bool or float")
23
+
24
+ if not isinstance(scores, np.ndarray) or len(scores.shape) != 1:
25
+ raise ValueError("scores must be single dimension numpy array")
26
+
27
+ if num_gt < np.sum(labels):
28
+ raise ValueError("Number of true positives must be smaller than num_gt.")
29
+
30
+ if len(scores) != len(labels):
31
+ raise ValueError("scores and labels must be of the same size.")
32
+
33
+ if num_gt == 0:
34
+ return None, None
35
+
36
+ sorted_indices = np.argsort(scores)
37
+ sorted_indices = sorted_indices[::-1]
38
+ true_positive_labels = labels[sorted_indices]
39
+ false_positive_labels = (true_positive_labels <= 0).astype(float)
40
+ cum_true_positives = np.cumsum(true_positive_labels)
41
+ cum_false_positives = np.cumsum(false_positive_labels)
42
+ precision = cum_true_positives.astype(float) / (cum_true_positives + cum_false_positives)
43
+ recall = cum_true_positives.astype(float) / num_gt
44
+ return precision, recall
45
+
46
+
47
+ def compute_average_precision(precision, recall):
48
+ """Compute Average Precision according to the definition in VOCdevkit.
49
+ Precision is modified to ensure that it does not decrease as recall
50
+ decreases.
51
+ Args:
52
+ precision: A float [N, 1] numpy array of precisions
53
+ recall: A float [N, 1] numpy array of recalls
54
+ Raises:
55
+ ValueError: if the input is not of the correct format
56
+ Returns:
57
+ average_precision: The area under the precision recall curve. NaN if
58
+ precision and recall are None.
59
+ """
60
+ if precision is None:
61
+ if recall is not None:
62
+ raise ValueError("If precision is None, recall must also be None")
63
+ return np.nan
64
+
65
+ if not isinstance(precision, np.ndarray) or not isinstance(recall, np.ndarray):
66
+ raise ValueError("precision and recall must be numpy array")
67
+ if precision.dtype != float or recall.dtype != float:
68
+ raise ValueError("input must be float numpy array.")
69
+ if len(precision) != len(recall):
70
+ raise ValueError("precision and recall must be of the same size.")
71
+ if not precision.size:
72
+ return 0.0
73
+ if np.amin(precision) < 0 or np.amax(precision) > 1:
74
+ raise ValueError("Precision must be in the range of [0, 1].")
75
+ if np.amin(recall) < 0 or np.amax(recall) > 1:
76
+ raise ValueError("recall must be in the range of [0, 1].")
77
+ if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)):
78
+ raise ValueError("recall must be a non-decreasing array")
79
+
80
+ recall = np.concatenate([[0], recall, [1]])
81
+ precision = np.concatenate([[0], precision, [0]])
82
+
83
+ # Preprocess precision to be a non-decreasing array
84
+ for i in range(len(precision) - 2, -1, -1):
85
+ precision[i] = np.maximum(precision[i], precision[i + 1])
86
+
87
+ indices = np.where(recall[1:] != recall[:-1])[0] + 1
88
+ average_precision = np.sum((recall[indices] - recall[indices - 1]) * precision[indices])
89
+ return average_precision
90
+
91
+
92
+ def compute_cor_loc(num_gt_imgs_per_class, num_images_correctly_detected_per_class):
93
+ """Compute CorLoc according to the definition in the following paper.
94
+ https://www.robots.ox.ac.uk/~vgg/rg/papers/deselaers-eccv10.pdf
95
+ Returns nans if there are no ground truth images for a class.
96
+ Args:
97
+ num_gt_imgs_per_class: 1D array, representing number of images containing
98
+ at least one object instance of a particular class
99
+ num_images_correctly_detected_per_class: 1D array, representing number of
100
+ images in which at least one object instance of a particular class was correctly detected
101
+ Returns:
102
+ corloc_per_class: A float numpy array representing the corloc score of each class
103
+ """
104
+ return np.where(
105
+ num_gt_imgs_per_class == 0, np.nan,
106
+ num_images_correctly_detected_per_class / num_gt_imgs_per_class)
107
+
108
+
109
+ def compute_median_rank_at_k(tp_fp_list, k):
110
+ """Computes MedianRank@k, where k is the top-scoring labels.
111
+ Args:
112
+ tp_fp_list: a list of numpy arrays; each numpy array corresponds to all
113
+ detections on a single image, where the detections are sorted by score in
114
+ descending order. Further, each numpy array element can have boolean or
115
+ float values. True positive elements have either value >0.0 or True;
116
+ any other value is considered false positive.
117
+ k: number of top-scoring proposals to take.
118
+ Returns:
119
+ median_rank: median rank of all true positive proposals among top k by score.
120
+ """
121
+ ranks = []
122
+ for i in range(len(tp_fp_list)):
123
+ ranks.append(np.where(tp_fp_list[i][0:min(k, tp_fp_list[i].shape[0])] > 0)[0])
124
+ concatenated_ranks = np.concatenate(ranks)
125
+ return np.median(concatenated_ranks)
126
+
127
+
128
+ def compute_recall_at_k(tp_fp_list, num_gt, k):
129
+ """Computes Recall@k, MedianRank@k, where k is the top-scoring labels.
130
+ Args:
131
+ tp_fp_list: a list of numpy arrays; each numpy array corresponds to all
132
+ detections on a single image, where the detections are sorted by score in
133
+ descending order. Further, each numpy array element can have boolean or
134
+ float values. True positive elements have either value >0.0 or True;
135
+ any other value is considered false positive.
136
+ num_gt: number of groundtruth annotations.
137
+ k: number of top-scoring proposals to take.
138
+ Returns:
139
+ recall: recall evaluated on the top k by score detections.
140
+ """
141
+
142
+ tp_fp_eval = []
143
+ for i in range(len(tp_fp_list)):
144
+ tp_fp_eval.append(tp_fp_list[i][0:min(k, tp_fp_list[i].shape[0])])
145
+
146
+ tp_fp_eval = np.concatenate(tp_fp_eval)
147
+
148
+ return np.sum(tp_fp_eval) / num_gt
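A self-contained sketch of how these helpers compose; the scores and labels below are toy values, not outputs of this repository's model.

import numpy as np
from effdet.evaluation.metrics import compute_precision_recall, compute_average_precision

scores = np.array([0.9, 0.8, 0.6, 0.3], dtype=float)  # detection confidences, any order
labels = np.array([1.0, 0.0, 1.0, 0.0], dtype=float)  # weighted true/false positive flags
precision, recall = compute_precision_recall(scores, labels, num_gt=2)
ap = compute_average_precision(precision, recall)     # area under the interpolated PR curve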
efficientdet/effdet/evaluation/np_box_list.py ADDED
@@ -0,0 +1,696 @@
1
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Bounding Box List operations for Numpy BoxLists.
17
+
18
+ Example box operations that are supported:
19
+ * Areas: compute bounding box areas
20
+ * IOU: pairwise intersection-over-union scores
21
+ """
22
+ import numpy as np
23
+
24
+
25
+ class BoxList(object):
26
+ """Box collection.
27
+ BoxList represents a list of bounding boxes as numpy array, where each
28
+ bounding box is represented as a row of 4 numbers,
29
+ [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a
30
+ given list correspond to a single image.
31
+ Optionally, users can add additional related fields (such as
32
+ objectness/classification scores).
33
+ """
34
+
35
+ def __init__(self, data):
36
+ """Constructs box collection.
37
+ Args:
38
+ data: a numpy array of shape [N, 4] representing box coordinates
39
+ Raises:
40
+ ValueError: if bbox data is not a numpy array
41
+ ValueError: if invalid dimensions for bbox data
42
+ """
43
+ if not isinstance(data, np.ndarray):
44
+ raise ValueError('data must be a numpy array.')
45
+ if len(data.shape) != 2 or data.shape[1] != 4:
46
+ raise ValueError('Invalid dimensions for box data.')
47
+ if data.dtype != np.float32 and data.dtype != np.float64:
48
+ raise ValueError('Invalid data type for box data: float is required.')
49
+ if not self._is_valid_boxes(data):
50
+ raise ValueError('Invalid box data. data must be a numpy array of '
51
+ 'N*[y_min, x_min, y_max, x_max]')
52
+ self.data = {'boxes': data}
53
+
54
+ def num_boxes(self):
55
+ """Return number of boxes held in collections."""
56
+ return self.data['boxes'].shape[0]
57
+
58
+ def get_extra_fields(self):
59
+ """Return all non-box fields."""
60
+ return [k for k in self.data.keys() if k != 'boxes']
61
+
62
+ def has_field(self, field):
63
+ return field in self.data
64
+
65
+ def add_field(self, field, field_data):
66
+ """Add data to a specified field.
67
+ Args:
68
+ field: a string parameter used to specify a related field to be accessed.
69
+ field_data: a numpy array of [N, ...] representing the data associated
70
+ with the field.
71
+ Raises:
72
+ ValueError: if the field already exists or the dimension of the field
73
+ data does not match the number of boxes.
74
+ """
75
+ if self.has_field(field):
76
+ raise ValueError('Field ' + field + ' already exists')
77
+ if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes():
78
+ raise ValueError('Invalid dimensions for field data')
79
+ self.data[field] = field_data
80
+
81
+ def get(self):
82
+ """Convenience function for accesssing box coordinates.
83
+ Returns:
84
+ a numpy array of shape [N, 4] representing box corners
85
+ """
86
+ return self.get_field('boxes')
87
+
88
+ def get_field(self, field):
89
+ """Accesses data associated with the specified field in the box collection.
90
+ Args:
91
+ field: a string parameter used to specify a related field to be accessed.
92
+ Returns:
93
+ a numpy 1-d array representing data of an associated field
94
+ Raises:
95
+ ValueError: if invalid field
96
+ """
97
+ if not self.has_field(field):
98
+ raise ValueError('field {} does not exist'.format(field))
99
+ return self.data[field]
100
+
101
+ def get_coordinates(self):
102
+ """Get corner coordinates of boxes.
103
+ Returns:
104
+ a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max]
105
+ """
106
+ box_coordinates = self.get()
107
+ y_min = box_coordinates[:, 0]
108
+ x_min = box_coordinates[:, 1]
109
+ y_max = box_coordinates[:, 2]
110
+ x_max = box_coordinates[:, 3]
111
+ return [y_min, x_min, y_max, x_max]
112
+
113
+ def _is_valid_boxes(self, data):
114
+ """Check whether data fullfills the format of N*[ymin, xmin, ymax, xmin].
115
+ Args:
116
+ data: a numpy array of shape [N, 4] representing box coordinates
117
+ Returns:
118
+ a boolean indicating whether all ymax of boxes are equal or greater than
119
+ ymin, and all xmax of boxes are equal or greater than xmin.
120
+ """
121
+ if data.shape[0] > 0:
122
+ for i in range(data.shape[0]):
123
+ if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]:
124
+ return False
125
+ return True
126
+
127
+
128
+ def area(boxes):
129
+ """Computes area of boxes.
130
+
131
+ Args:
132
+ boxes: Numpy array with shape [N, 4] holding N boxes
133
+
134
+ Returns:
135
+ a numpy array with shape [N*1] representing box areas
136
+ """
137
+ return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
138
+
139
+
140
+ def intersection(boxes1, boxes2):
141
+ """Compute pairwise intersection areas between boxes.
142
+
143
+ Args:
144
+ boxes1: a numpy array with shape [N, 4] holding N boxes
145
+ boxes2: a numpy array with shape [M, 4] holding M boxes
146
+
147
+ Returns:
148
+ a numpy array with shape [N*M] representing pairwise intersection area
149
+ """
150
+ [y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1)
151
+ [y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1)
152
+
153
+ all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2))
154
+ all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2))
155
+ intersect_heights = np.maximum(np.zeros(all_pairs_max_ymin.shape), all_pairs_min_ymax - all_pairs_max_ymin)
156
+ all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2))
157
+ all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2))
158
+ intersect_widths = np.maximum(np.zeros(all_pairs_max_xmin.shape), all_pairs_min_xmax - all_pairs_max_xmin)
159
+ return intersect_heights * intersect_widths
160
+
161
+
162
+ def iou(boxes1, boxes2):
163
+ """Computes pairwise intersection-over-union between box collections.
164
+
165
+ Args:
166
+ boxes1: a numpy array with shape [N, 4] holding N boxes.
167
+ boxes2: a numpy array with shape [M, 4] holding M boxes.
168
+
169
+ Returns:
170
+ a numpy array with shape [N, M] representing pairwise iou scores.
171
+ """
172
+ intersect = intersection(boxes1, boxes2)
173
+ area1 = area(boxes1)
174
+ area2 = area(boxes2)
175
+ union = np.expand_dims(area1, axis=1) + np.expand_dims(area2, axis=0) - intersect
176
+ return intersect / union
177
+
178
+
179
+ def ioa(boxes1, boxes2):
180
+ """Computes pairwise intersection-over-area between box collections.
181
+
182
+ Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
183
+ their intersection area over box2's area. Note that ioa is not symmetric,
184
+ that is, IOA(box1, box2) != IOA(box2, box1).
185
+
186
+ Args:
187
+ boxes1: a numpy array with shape [N, 4] holding N boxes.
188
+ boxes2: a numpy array with shape [M, 4] holding M boxes.
189
+
190
+ Returns:
191
+ a numpy array with shape [N, M] representing pairwise ioa scores.
192
+ """
193
+ intersect = intersection(boxes1, boxes2)
194
+ areas = np.expand_dims(area(boxes2), axis=0)
195
+ return intersect / areas
196
+
197
+
198
+ class SortOrder(object):
199
+ """Enum class for sort order.
200
+
201
+ Attributes:
202
+ ascend: ascend order.
203
+ descend: descend order.
204
+ """
205
+ ASCEND = 1
206
+ DESCEND = 2
207
+
208
+
209
+ def area_boxlist(boxlist):
210
+ """Computes area of boxes.
211
+
212
+ Args:
213
+ boxlist: BoxList holding N boxes
214
+
215
+ Returns:
216
+ a numpy array with shape [N*1] representing box areas
217
+ """
218
+ y_min, x_min, y_max, x_max = boxlist.get_coordinates()
219
+ return (y_max - y_min) * (x_max - x_min)
220
+
221
+
222
+ def intersection_boxlist(boxlist1, boxlist2):
223
+ """Compute pairwise intersection areas between boxes.
224
+
225
+ Args:
226
+ boxlist1: BoxList holding N boxes
227
+ boxlist2: BoxList holding M boxes
228
+
229
+ Returns:
230
+ a numpy array with shape [N*M] representing pairwise intersection area
231
+ """
232
+ return intersection(boxlist1.get(), boxlist2.get())
233
+
234
+
235
+ def iou_boxlist(boxlist1, boxlist2):
236
+ """Computes pairwise intersection-over-union between box collections.
237
+
238
+ Args:
239
+ boxlist1: BoxList holding N boxes
240
+ boxlist2: BoxList holding M boxes
241
+
242
+ Returns:
243
+ a numpy array with shape [N, M] representing pairwise iou scores.
244
+ """
245
+ return iou(boxlist1.get(), boxlist2.get())
246
+
247
+
248
+ def ioa_boxlist(boxlist1, boxlist2):
249
+ """Computes pairwise intersection-over-area between box collections.
250
+
251
+ Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
252
+ their intersection area over box2's area. Note that ioa is not symmetric,
253
+ that is, IOA(box1, box2) != IOA(box2, box1).
254
+
255
+ Args:
256
+ boxlist1: BoxList holding N boxes
257
+ boxlist2: BoxList holding M boxes
258
+
259
+ Returns:
260
+ a numpy array with shape [N, M] representing pairwise ioa scores.
261
+ """
262
+ return ioa(boxlist1.get(), boxlist2.get())
263
+
264
+
265
+ def gather_boxlist(boxlist, indices, fields=None):
266
+ """Gather boxes from BoxList according to indices and return new BoxList.
267
+
268
+ By default, gather returns boxes corresponding to the input index list, as
269
+ well as all additional fields stored in the boxlist (indexing into the
270
+ first dimension). However one can optionally only gather from a
271
+ subset of fields.
272
+
273
+ Args:
274
+ boxlist: BoxList holding N boxes
275
+ indices: a 1-d numpy array of type int_
276
+ fields: (optional) list of fields to also gather from. If None (default),
277
+ all fields are gathered from. Pass an empty fields list to only gather the box coordinates.
278
+
279
+ Returns:
280
+ subboxlist: a BoxList corresponding to the subset of the input BoxList specified by indices
281
+
282
+ Raises:
283
+ ValueError: if specified field is not contained in boxlist or if the indices are not of type int_
284
+ """
285
+ if indices.size:
286
+ if np.amax(indices) >= boxlist.num_boxes() or np.amin(indices) < 0:
287
+ raise ValueError('indices are out of valid range.')
288
+ subboxlist = BoxList(boxlist.get()[indices, :])
289
+ if fields is None:
290
+ fields = boxlist.get_extra_fields()
291
+ for field in fields:
292
+ extra_field_data = boxlist.get_field(field)
293
+ subboxlist.add_field(field, extra_field_data[indices, ...])
294
+ return subboxlist
295
+
296
+
297
+ def sort_by_field_boxlist(boxlist, field, order=SortOrder.DESCEND):
298
+ """Sort boxes and associated fields according to a scalar field.
299
+
300
+ A common use case is reordering the boxes according to descending scores.
301
+
302
+ Args:
303
+ boxlist: BoxList holding N boxes.
304
+ field: A BoxList field for sorting and reordering the BoxList.
305
+ order: (Optional) 'descend' or 'ascend'. Default is descend.
306
+
307
+ Returns:
308
+ sorted_boxlist: A sorted BoxList with the field in the specified order.
309
+
310
+ Raises:
311
+ ValueError: if specified field does not exist or is not of single dimension.
312
+ ValueError: if the order is not either descend or ascend.
313
+ """
314
+ if not boxlist.has_field(field):
315
+ raise ValueError('Field ' + field + ' does not exist')
316
+ if len(boxlist.get_field(field).shape) != 1:
317
+ raise ValueError('Field ' + field + ' should be single dimension.')
318
+ if order != SortOrder.DESCEND and order != SortOrder.ASCEND:
319
+ raise ValueError('Invalid sort order')
320
+
321
+ field_to_sort = boxlist.get_field(field)
322
+ sorted_indices = np.argsort(field_to_sort)
323
+ if order == SortOrder.DESCEND:
324
+ sorted_indices = sorted_indices[::-1]
325
+ return gather_boxlist(boxlist, sorted_indices)
326
+
327
+
328
+ def non_max_suppression(boxlist, max_output_size=10000, iou_threshold=1.0, score_threshold=-10.0):
329
+ """Non maximum suppression.
330
+
331
+ This op greedily selects a subset of detection bounding boxes, pruning
332
+ away boxes that have high IOU (intersection over union) overlap (> thresh)
333
+ with already selected boxes. In each iteration, the detected bounding box with
334
+ highest score in the available pool is selected.
335
+
336
+ Args:
337
+ boxlist: BoxList holding N boxes. Must contain a 'scores' field
338
+ representing detection scores. All scores belong to the same class.
339
+ max_output_size: maximum number of retained boxes
340
+ iou_threshold: intersection over union threshold.
341
+ score_threshold: minimum score threshold. Remove the boxes with scores less than
342
+ this value. The default of -10 is a very low threshold that passes virtually
343
+ all boxes unless the user sets a different score threshold.
344
+
345
+ Returns:
346
+ a BoxList holding M boxes where M <= max_output_size
347
+ Raises:
348
+ ValueError: if 'scores' field does not exist
349
+ ValueError: if threshold is not in [0, 1]
350
+ ValueError: if max_output_size < 0
351
+ """
352
+ if not boxlist.has_field('scores'):
353
+ raise ValueError('Field scores does not exist')
354
+ if iou_threshold < 0. or iou_threshold > 1.0:
355
+ raise ValueError('IOU threshold must be in [0, 1]')
356
+ if max_output_size < 0:
357
+ raise ValueError('max_output_size must be non-negative.')
358
+
359
+ boxlist = filter_scores_greater_than(boxlist, score_threshold)
360
+ if boxlist.num_boxes() == 0:
361
+ return boxlist
362
+
363
+ boxlist = sort_by_field_boxlist(boxlist, 'scores')
364
+
365
+ # Prevent further computation if NMS is disabled.
366
+ if iou_threshold == 1.0:
367
+ if boxlist.num_boxes() > max_output_size:
368
+ selected_indices = np.arange(max_output_size)
369
+ return gather_boxlist(boxlist, selected_indices)
370
+ else:
371
+ return boxlist
372
+
373
+ boxes = boxlist.get()
374
+ num_boxes = boxlist.num_boxes()
375
+ # is_index_valid is True for boxes that are still valid candidates for selection.
376
+ is_index_valid = np.full(num_boxes, 1, dtype=bool)
377
+ selected_indices = []
378
+ num_output = 0
379
+ for i in range(num_boxes):
380
+ if num_output < max_output_size:
381
+ if is_index_valid[i]:
382
+ num_output += 1
383
+ selected_indices.append(i)
384
+ is_index_valid[i] = False
385
+ valid_indices = np.where(is_index_valid)[0]
386
+ if valid_indices.size == 0:
387
+ break
388
+
389
+ intersect_over_union = iou(np.expand_dims(boxes[i, :], axis=0), boxes[valid_indices, :])
390
+ intersect_over_union = np.squeeze(intersect_over_union, axis=0)
391
+ is_index_valid[valid_indices] = np.logical_and(
392
+ is_index_valid[valid_indices],
393
+ intersect_over_union <= iou_threshold)
394
+ return gather_boxlist(boxlist, np.array(selected_indices))
395
+
396
+
397
+ def multi_class_non_max_suppression(boxlist, score_thresh, iou_thresh, max_output_size):
398
+ """Multi-class version of non maximum suppression.
399
+
400
+ This op greedily selects a subset of detection bounding boxes, pruning
401
+ away boxes that have high IOU (intersection over union) overlap (> thresh)
402
+ with already selected boxes. It operates independently for each class for
403
+ which scores are provided (via the scores field of the input box_list),
404
+ pruning boxes with score less than a provided threshold prior to
405
+ applying NMS.
406
+
407
+ Args:
408
+ boxlist: BoxList holding N boxes. Must contain a 'scores' field
409
+ representing detection scores. This scores field is a tensor that can
410
+ be 1-dimensional (in the case of a single class) or 2-dimensional, in
411
+ which case we assume that it takes the shape [num_boxes, num_classes].
412
+ We further assume that this rank is known statically and that
413
+ scores.shape[1] is also known (i.e., the number of classes is fixed
414
+ and known at graph construction time).
415
+ score_thresh: scalar threshold for score (low scoring boxes are removed).
416
+ iou_thresh: scalar threshold for IOU (boxes that have high IOU overlap
417
+ with previously selected boxes are removed).
418
+ max_output_size: maximum number of retained boxes per class.
419
+
420
+ Returns:
421
+ a BoxList holding M boxes with a rank-1 scores field representing
422
+ corresponding scores for each box with scores sorted in decreasing order
423
+ and a rank-1 classes field representing a class label for each box.
424
+ Raises:
425
+ ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
426
+ a valid scores field.
427
+ """
428
+ if not 0 <= iou_thresh <= 1.0:
429
+ raise ValueError('thresh must be between 0 and 1')
430
+ if not isinstance(boxlist, BoxList):
431
+ raise ValueError('boxlist must be a BoxList')
432
+ if not boxlist.has_field('scores'):
433
+ raise ValueError('input boxlist must have \'scores\' field')
434
+ scores = boxlist.get_field('scores')
435
+ if len(scores.shape) == 1:
436
+ scores = np.reshape(scores, [-1, 1])
437
+ elif len(scores.shape) == 2:
438
+ if scores.shape[1] is None:
439
+ raise ValueError('scores field must have statically defined second dimension')
440
+ else:
441
+ raise ValueError('scores field must be of rank 1 or 2')
442
+ num_boxes = boxlist.num_boxes()
443
+ num_scores = scores.shape[0]
444
+ num_classes = scores.shape[1]
445
+
446
+ if num_boxes != num_scores:
447
+ raise ValueError('Incorrect scores field length: actual {} vs expected {}.'.format(num_scores, num_boxes))
448
+
449
+ selected_boxes_list = []
450
+ for class_idx in range(num_classes):
451
+ boxlist_and_class_scores = BoxList(boxlist.get())
452
+ class_scores = np.reshape(scores[0:num_scores, class_idx], [-1])
453
+ boxlist_and_class_scores.add_field('scores', class_scores)
454
+ boxlist_filt = filter_scores_greater_than(boxlist_and_class_scores, score_thresh)
455
+ nms_result = non_max_suppression(
456
+ boxlist_filt, max_output_size=max_output_size, iou_threshold=iou_thresh, score_threshold=score_thresh)
457
+ nms_result.add_field('classes', np.zeros_like(nms_result.get_field('scores')) + class_idx)
458
+ selected_boxes_list.append(nms_result)
459
+ selected_boxes = concatenate_boxlist(selected_boxes_list)
460
+ sorted_boxes = sort_by_field_boxlist(selected_boxes, 'scores')
461
+ return sorted_boxes
462
+
463
+
464
+ def scale(boxlist, y_scale, x_scale):
465
+ """Scale box coordinates in x and y dimensions.
466
+
467
+ Args:
468
+ boxlist: BoxList holding N boxes
469
+ y_scale: float
470
+ x_scale: float
471
+
472
+ Returns:
473
+ boxlist: BoxList holding N boxes
474
+ """
475
+ y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
476
+ y_min = y_scale * y_min
477
+ y_max = y_scale * y_max
478
+ x_min = x_scale * x_min
479
+ x_max = x_scale * x_max
480
+ scaled_boxlist = BoxList(np.hstack([y_min, x_min, y_max, x_max]))
481
+
482
+ fields = boxlist.get_extra_fields()
483
+ for field in fields:
484
+ extra_field_data = boxlist.get_field(field)
485
+ scaled_boxlist.add_field(field, extra_field_data)
486
+
487
+ return scaled_boxlist
488
+
489
+
490
+ def clip_to_window(boxlist, window, filter_nonoverlapping=True):
491
+ """Clip bounding boxes to a window.
492
+
493
+ This op clips input bounding boxes (represented by bounding box
494
+ corners) to a window, optionally filtering out boxes that do not
495
+ overlap at all with the window.
496
+
497
+ Args:
498
+ boxlist: BoxList holding M_in boxes
499
+ window: a numpy array of shape [4] representing the [y_min, x_min, y_max, x_max]
500
+ window to which the op should clip boxes.
501
+ filter_nonoverlapping: whether to filter out boxes that do not overlap at all with the window.
502
+
503
+ Returns:
504
+ a BoxList holding M_out boxes where M_out <= M_in
505
+ """
506
+ y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
507
+ win_y_min = window[0]
508
+ win_x_min = window[1]
509
+ win_y_max = window[2]
510
+ win_x_max = window[3]
511
+ y_min_clipped = np.fmax(np.fmin(y_min, win_y_max), win_y_min)
512
+ y_max_clipped = np.fmax(np.fmin(y_max, win_y_max), win_y_min)
513
+ x_min_clipped = np.fmax(np.fmin(x_min, win_x_max), win_x_min)
514
+ x_max_clipped = np.fmax(np.fmin(x_max, win_x_max), win_x_min)
515
+ clipped = BoxList(np.hstack([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped]))
516
+ clipped = _copy_extra_fields(clipped, boxlist)
517
+ if filter_nonoverlapping:
518
+ areas = area_boxlist(clipped)
519
+ nonzero_area_indices = np.reshape(np.nonzero(np.greater(areas, 0.0)), [-1]).astype(np.int32)
520
+ clipped = gather_boxlist(clipped, nonzero_area_indices)
521
+ return clipped
522
+
523
+
524
+ def prune_non_overlapping_boxes(boxlist1, boxlist2, minoverlap=0.0):
525
+ """Prunes the boxes in boxlist1 that overlap less than thresh with boxlist2.
526
+
527
+ For each box in boxlist1, we want its IOA to be more than minoverlap with
528
+ at least one of the boxes in boxlist2. If it does not, we remove it.
529
+
530
+ Args:
531
+ boxlist1: BoxList holding N boxes.
532
+ boxlist2: BoxList holding M boxes.
533
+ minoverlap: Minimum required overlap between boxes, to count them as overlapping.
534
+
535
+ Returns:
536
+ A pruned boxlist with size [N', 4].
537
+ """
538
+ intersection_over_area = ioa_boxlist(boxlist2, boxlist1) # [M, N] tensor
539
+ intersection_over_area = np.amax(intersection_over_area, axis=0) # [N] tensor
540
+ keep_bool = np.greater_equal(intersection_over_area, np.array(minoverlap))
541
+ keep_inds = np.nonzero(keep_bool)[0]
542
+ new_boxlist1 = gather_boxlist(boxlist1, keep_inds)
543
+ return new_boxlist1
544
+
545
+
546
+ def prune_outside_window(boxlist, window):
547
+ """Prunes bounding boxes that fall outside a given window.
548
+
549
+ This function prunes bounding boxes that even partially fall outside the given
550
+ window. See also ClipToWindow which only prunes bounding boxes that fall
551
+ completely outside the window, and clips any bounding boxes that partially
552
+ overflow.
553
+
554
+ Args:
555
+ boxlist: a BoxList holding M_in boxes.
556
+ window: a numpy array of size 4, representing [ymin, xmin, ymax, xmax] of the window.
557
+
558
+ Returns:
559
+ pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in.
560
+ valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes in the input tensor.
561
+ """
562
+
563
+ y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
564
+ win_y_min = window[0]
565
+ win_x_min = window[1]
566
+ win_y_max = window[2]
567
+ win_x_max = window[3]
568
+ coordinate_violations = np.hstack([
569
+ np.less(y_min, win_y_min), np.less(x_min, win_x_min),
570
+ np.greater(y_max, win_y_max), np.greater(x_max, win_x_max)])
571
+ valid_indices = np.reshape(np.where(np.logical_not(np.max(coordinate_violations, axis=1))), [-1])
572
+ return gather_boxlist(boxlist, valid_indices), valid_indices
573
+
574
+
575
+ def concatenate_boxlist(boxlists, fields=None):
576
+ """Concatenate list of BoxLists.
577
+
578
+ This op concatenates a list of input BoxLists into a larger BoxList. It also
579
+ handles concatenation of BoxList fields as long as the field tensor shapes
580
+ are equal except for the first dimension.
581
+
582
+ Args:
583
+ boxlists: list of BoxList objects
584
+ fields: optional list of fields to also concatenate. By default, all
585
+ fields from the first BoxList in the list are included in the concatenation.
586
+
587
+ Returns:
588
+ a BoxList with number of boxes equal to
589
+ sum([boxlist.num_boxes() for boxlist in BoxList])
590
+ Raises:
591
+ ValueError: if boxlists is invalid (i.e., is not a list, is empty, or
592
+ contains non BoxList objects), or if requested fields are not contained in all boxlists
593
+ """
594
+ if not isinstance(boxlists, list):
595
+ raise ValueError('boxlists should be a list')
596
+ if not boxlists:
597
+ raise ValueError('boxlists should have nonzero length')
598
+ for boxlist in boxlists:
599
+ if not isinstance(boxlist, BoxList):
600
+ raise ValueError('all elements of boxlists should be BoxList objects')
601
+ concatenated = BoxList(np.vstack([boxlist.get() for boxlist in boxlists]))
602
+ if fields is None:
603
+ fields = boxlists[0].get_extra_fields()
604
+ for field in fields:
605
+ first_field_shape = boxlists[0].get_field(field).shape
606
+ first_field_shape = first_field_shape[1:]
607
+ for boxlist in boxlists:
608
+ if not boxlist.has_field(field):
609
+ raise ValueError('boxlist must contain all requested fields')
610
+ field_shape = boxlist.get_field(field).shape
611
+ field_shape = field_shape[1:]
612
+ if field_shape != first_field_shape:
613
+ raise ValueError('field %s must have same shape for all boxlists '
614
+ 'except for the 0th dimension.' % field)
615
+ concatenated_field = np.concatenate([boxlist.get_field(field) for boxlist in boxlists], axis=0)
616
+ concatenated.add_field(field, concatenated_field)
617
+ return concatenated
618
+
619
+
620
+ def filter_scores_greater_than(boxlist, thresh):
621
+ """Filter to keep only boxes with score exceeding a given threshold.
622
+
623
+ This op keeps the collection of boxes whose corresponding scores are
624
+ greater than the input threshold.
625
+
626
+ Args:
627
+ boxlist: BoxList holding N boxes. Must contain a 'scores' field representing detection scores.
628
+ thresh: scalar threshold
629
+
630
+ Returns:
631
+ a BoxList holding M boxes where M <= N
632
+
633
+ Raises:
634
+ ValueError: if boxlist not a BoxList object or if it does not have a scores field
635
+ """
636
+ if not isinstance(boxlist, BoxList):
637
+ raise ValueError('boxlist must be a BoxList')
638
+ if not boxlist.has_field('scores'):
639
+ raise ValueError('input boxlist must have \'scores\' field')
640
+ scores = boxlist.get_field('scores')
641
+ if len(scores.shape) > 2:
642
+ raise ValueError('Scores should have rank 1 or 2')
643
+ if len(scores.shape) == 2 and scores.shape[1] != 1:
644
+ raise ValueError('Scores should have rank 1 or have shape '
645
+ 'consistent with [None, 1]')
646
+ high_score_indices = np.reshape(np.where(np.greater(scores, thresh)), [-1]).astype(np.int32)
647
+ return gather_boxlist(boxlist, high_score_indices)
648
+
649
+
650
+ def change_coordinate_frame(boxlist, window):
651
+ """Change coordinate frame of the boxlist to be relative to window's frame.
652
+
653
+ Given a window of the form [ymin, xmin, ymax, xmax],
654
+ changes bounding box coordinates from boxlist to be relative to this window
655
+ (e.g., the min corner maps to (0,0) and the max corner maps to (1,1)).
656
+
657
+ An example use case is data augmentation: where we are given groundtruth
658
+ boxes (boxlist) and would like to randomly crop the image to some
659
+ window (window). In this case we need to change the coordinate frame of
660
+ each groundtruth box to be relative to this new window.
661
+
662
+ Args:
663
+ boxlist: A BoxList object holding N boxes.
664
+ window: a size 4 1-D numpy array.
665
+
666
+ Returns:
667
+ Returns a BoxList object with N boxes.
668
+ """
669
+ win_height = window[2] - window[0]
670
+ win_width = window[3] - window[1]
671
+ boxlist_new = scale(
672
+ BoxList(boxlist.get() - [window[0], window[1], window[0], window[1]]), 1.0 / win_height, 1.0 / win_width)
673
+ _copy_extra_fields(boxlist_new, boxlist)
674
+
675
+ return boxlist_new
676
+
677
+
678
+ def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
679
+ """Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.
680
+
681
+ Args:
682
+ boxlist_to_copy_to: BoxList to which extra fields are copied.
683
+ boxlist_to_copy_from: BoxList from which fields are copied.
684
+
685
+ Returns:
686
+ boxlist_to_copy_to with extra fields.
687
+ """
688
+ for field in boxlist_to_copy_from.get_extra_fields():
689
+ boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field))
690
+ return boxlist_to_copy_to
691
+
692
+
693
+ def _update_valid_indices_by_removing_high_iou_boxes(
694
+ selected_indices, is_index_valid, intersect_over_union, threshold):
695
+ max_iou = np.max(intersect_over_union[:, selected_indices], axis=1)
696
+ return np.logical_and(is_index_valid, max_iou <= threshold)
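A short sketch of the BoxList utilities above; the module path is inferred from the commit layout and the coordinates and scores are illustrative only.

import numpy as np
from effdet.evaluation.np_box_list import BoxList, iou_boxlist, non_max_suppression

boxes = np.array([[0., 0., 10., 10.],
                  [1., 1., 11., 11.],
                  [20., 20., 30., 30.]], dtype=np.float32)
boxlist = BoxList(boxes)
boxlist.add_field('scores', np.array([0.9, 0.8, 0.7], dtype=np.float32))

print(iou_boxlist(boxlist, boxlist).shape)  # (3, 3) pairwise IoU matrix
kept = non_max_suppression(boxlist, max_output_size=10, iou_threshold=0.5)
print(kept.num_boxes())  # the second box overlaps the first heavily and is suppressed -> 2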
efficientdet/effdet/evaluation/np_mask_list.py ADDED
@@ -0,0 +1,478 @@
1
+ import numpy as np
2
+ from .np_box_list import *
3
+
4
+ EPSILON = 1e-7
5
+
6
+
7
+ class MaskList(BoxList):
8
+ """Convenience wrapper for BoxList with masks.
9
+
10
+ MaskList extends the np_box_list.BoxList to contain masks as well.
11
+ In particular, its constructor receives both boxes and masks. Note that the
12
+ masks correspond to the full image.
13
+ """
14
+
15
+ def __init__(self, box_data, mask_data):
16
+ """Constructs box collection.
17
+
18
+ Args:
19
+ box_data: a numpy array of shape [N, 4] representing box coordinates
20
+ mask_data: a numpy array of shape [N, height, width] representing masks
21
+ with values in {0,1}. The masks correspond to the full
22
+ image. The height and the width will be equal to image height and width.
23
+
24
+ Raises:
25
+ ValueError: if bbox data is not a numpy array
26
+ ValueError: if invalid dimensions for bbox data
27
+ ValueError: if mask data is not a numpy array
28
+ ValueError: if invalid dimension for mask data
29
+ """
30
+ super(MaskList, self).__init__(box_data)
31
+ if not isinstance(mask_data, np.ndarray):
32
+ raise ValueError('Mask data must be a numpy array.')
33
+ if len(mask_data.shape) != 3:
34
+ raise ValueError('Invalid dimensions for mask data.')
35
+ if mask_data.dtype != np.uint8:
36
+ raise ValueError('Invalid data type for mask data: uint8 is required.')
37
+ if mask_data.shape[0] != box_data.shape[0]:
38
+ raise ValueError('There should be the same number of boxes and masks.')
39
+ self.data['masks'] = mask_data
40
+
41
+ def get_masks(self):
42
+ """Convenience function for accessing masks.
43
+
44
+ Returns:
45
+ a numpy array of shape [N, height, width] representing masks
46
+ """
47
+ return self.get_field('masks')
48
+
49
+
50
+ def boxlist_to_masklist(boxlist):
51
+ """Converts a BoxList containing 'masks' into a BoxMaskList.
52
+
53
+ Args:
54
+ boxlist: An np_box_list.BoxList object.
55
+
56
+ Returns:
57
+ A MaskList object.
58
+
59
+ Raises:
60
+ ValueError: If boxlist does not contain `masks` as a field.
61
+ """
62
+ if not boxlist.has_field('masks'):
63
+ raise ValueError('boxlist does not contain mask field.')
64
+ masklist = MaskList(box_data=boxlist.get(), mask_data=boxlist.get_field('masks'))
65
+ extra_fields = boxlist.get_extra_fields()
66
+ for key in extra_fields:
67
+ if key != 'masks':
68
+ masklist.data[key] = boxlist.get_field(key)
69
+ return masklist
70
+
71
+
72
+ def area_mask(masks):
73
+ """Computes area of masks.
74
+
75
+ Args:
76
+ masks: Numpy array with shape [N, height, width] holding N masks. Masks
77
+ values are of type np.uint8 and values are in {0,1}.
78
+
79
+ Returns:
80
+ a numpy array with shape [N*1] representing mask areas.
81
+
82
+ Raises:
83
+ ValueError: If masks.dtype is not np.uint8
84
+ """
85
+ if masks.dtype != np.uint8:
86
+ raise ValueError('Masks type should be np.uint8')
87
+ return np.sum(masks, axis=(1, 2), dtype=np.float32)
88
+
89
+
90
+ def intersection_mask(masks1, masks2):
91
+ """Compute pairwise intersection areas between masks.
92
+
93
+ Args:
94
+ masks1: a numpy array with shape [N, height, width] holding N masks. Masks
95
+ values are of type np.uint8 and values are in {0,1}.
96
+ masks2: a numpy array with shape [M, height, width] holding M masks. Masks
97
+ values are of type np.uint8 and values are in {0,1}.
98
+
99
+ Returns:
100
+ a numpy array with shape [N*M] representing pairwise intersection area.
101
+
102
+ Raises:
103
+ ValueError: If masks1 and masks2 are not of type np.uint8.
104
+ """
105
+ if masks1.dtype != np.uint8 or masks2.dtype != np.uint8:
106
+ raise ValueError('masks1 and masks2 should be of type np.uint8')
107
+ n = masks1.shape[0]
108
+ m = masks2.shape[0]
109
+ answer = np.zeros([n, m], dtype=np.float32)
110
+ for i in np.arange(n):
111
+ for j in np.arange(m):
112
+ answer[i, j] = np.sum(np.minimum(masks1[i], masks2[j]), dtype=np.float32)
113
+ return answer
114
+
115
+
116
+ def iou_mask(masks1, masks2):
117
+ """Computes pairwise intersection-over-union between mask collections.
118
+
119
+ Args:
120
+ masks1: a numpy array with shape [N, height, width] holding N masks. Masks
121
+ values are of type np.uint8 and values are in {0,1}.
122
+ masks2: a numpy array with shape [M, height, width] holding M masks. Masks
123
+ values are of type np.uint8 and values are in {0,1}.
124
+
125
+ Returns:
126
+ a numpy array with shape [N, M] representing pairwise iou scores.
127
+
128
+ Raises:
129
+ ValueError: If masks1 and masks2 are not of type np.uint8.
130
+ """
131
+ if masks1.dtype != np.uint8 or masks2.dtype != np.uint8:
132
+ raise ValueError('masks1 and masks2 should be of type np.uint8')
133
+ intersect = intersection_mask(masks1, masks2)
134
+ area1 = area_mask(masks1)
135
+ area2 = area_mask(masks2)
136
+ union = np.expand_dims(area1, axis=1) + np.expand_dims(area2, axis=0) - intersect
137
+ return intersect / np.maximum(union, EPSILON)
138
+
139
+
140
+ def ioa_mask(masks1, masks2):
141
+ """Computes pairwise intersection-over-area between box collections.
142
+
143
+ Intersection-over-area (ioa) between two masks, mask1 and mask2 is defined as
144
+ their intersection area over mask2's area. Note that ioa is not symmetric,
145
+ that is, IOA(mask1, mask2) != IOA(mask2, mask1).
146
+
147
+ Args:
148
+ masks1: a numpy array with shape [N, height, width] holding N masks. Masks
149
+ values are of type np.uint8 and values are in {0,1}.
150
+ masks2: a numpy array with shape [M, height, width] holding M masks. Masks
151
+ values are of type np.uint8 and values are in {0,1}.
152
+
153
+ Returns:
154
+ a numpy array with shape [N, M] representing pairwise ioa scores.
155
+
156
+ Raises:
157
+ ValueError: If masks1 and masks2 are not of type np.uint8.
158
+ """
159
+ if masks1.dtype != np.uint8 or masks2.dtype != np.uint8:
160
+ raise ValueError('masks1 and masks2 should be of type np.uint8')
161
+ intersect = intersection_mask(masks1, masks2)
162
+ areas = np.expand_dims(area_mask(masks2), axis=0)
163
+ return intersect / (areas + EPSILON)
164
+
165
+
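As a quick sanity check of the pairwise mask metrics defined above, a toy sketch with two single-mask collections; the `effdet.evaluation.np_mask_list` import path is assumed from how the sibling modules import each other.

```python
import numpy as np
from effdet.evaluation.np_mask_list import ioa_mask, iou_mask

# Two 4x4 binary masks (uint8, values in {0, 1}) overlapping in exactly one pixel.
masks_a = np.zeros((1, 4, 4), dtype=np.uint8)
masks_a[0, :2, :2] = 1            # 2x2 square, area 4
masks_b = np.zeros((1, 4, 4), dtype=np.uint8)
masks_b[0, 1:3, 1:3] = 1          # shifted 2x2 square, area 4

print(iou_mask(masks_a, masks_b))  # intersection 1 / union 7 ~= 0.143
print(ioa_mask(masks_a, masks_b))  # intersection 1 / area(masks_b) 4 = 0.25
```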
166
+ def area_masklist(masklist):
167
+ """Computes area of masks.
168
+
169
+ Args:
170
+ masklist: BoxMaskList holding N boxes and masks
171
+
172
+ Returns:
173
+ a numpy array with shape [N*1] representing mask areas
174
+ """
175
+ return area_mask(masklist.get_masks())
176
+
177
+
178
+ def intersection_masklist(masklist1, masklist2):
179
+ """Compute pairwise intersection areas between masks.
180
+
181
+ Args:
182
+ masklist1: BoxMaskList holding N boxes and masks
183
+ masklist2: BoxMaskList holding M boxes and masks
184
+
185
+ Returns:
186
+ a numpy array with shape [N*M] representing pairwise intersection area
187
+ """
188
+ return intersection_mask(masklist1.get_masks(), masklist2.get_masks())
189
+
190
+
191
+ def iou_masklist(masklist1, masklist2):
192
+ """Computes pairwise intersection-over-union between box and mask collections.
193
+
194
+ Args:
195
+ masklist1: BoxMaskList holding N boxes and masks
196
+ masklist2: BoxMaskList holding M boxes and masks
197
+
198
+ Returns:
199
+ a numpy array with shape [N, M] representing pairwise iou scores.
200
+ """
201
+ return iou_mask(masklist1.get_masks(), masklist2.get_masks())
202
+
203
+
204
+ def ioa_masklist(masklist1, masklist2):
205
+ """Computes pairwise intersection-over-area between box and mask collections.
206
+
207
+ Intersection-over-area (ioa) between two masks mask1 and mask2 is defined as
208
+ their intersection area over mask2's area. Note that ioa is not symmetric,
209
+ that is, IOA(mask1, mask2) != IOA(mask2, mask1).
210
+
211
+ Args:
212
+ masklist1: BoxMaskList holding N boxes and masks
213
+ masklist2: BoxMaskList holding M boxes and masks
214
+
215
+ Returns:
216
+ a numpy array with shape [N, M] representing pairwise ioa scores.
217
+ """
218
+ return ioa_mask(masklist1.get_masks(), masklist2.get_masks())
219
+
220
+
221
+ def gather_masklist(masklist, indices, fields=None):
222
+ """Gather boxes from BoxMaskList according to indices.
223
+
224
+ By default, gather returns boxes corresponding to the input index list, as
225
+ well as all additional fields stored in the masklist (indexing into the
226
+ first dimension). However one can optionally only gather from a
227
+ subset of fields.
228
+
229
+ Args:
230
+ masklist: BoxMaskList holding N boxes
231
+ indices: a 1-d numpy array of type int_
232
+ fields: (optional) list of fields to also gather from. If None (default), all fields
233
+ are gathered from. Pass an empty fields list to only gather the box coordinates.
234
+
235
+ Returns:
236
+ submasklist: a BoxMaskList corresponding to the subset of the input masklist specified by indices
237
+
238
+ Raises:
239
+ ValueError: if specified field is not contained in masklist or if the indices are not of type int_
240
+ """
241
+ if fields is not None:
242
+ if 'masks' not in fields:
243
+ fields.append('masks')
244
+ return boxlist_to_masklist(gather_boxlist(boxlist=masklist, indices=indices, fields=fields))
245
+
246
+
247
+ def sort_by_field_masklist(masklist, field, order=SortOrder.DESCEND):
248
+ """Sort boxes and associated fields according to a scalar field.
249
+
250
+ A common use case is reordering the boxes according to descending scores.
251
+
252
+ Args:
253
+ masklist: BoxMaskList holding N boxes.
254
+ field: A BoxMaskList field for sorting and reordering the BoxMaskList.
255
+ order: (Optional) 'descend' or 'ascend'. Default is descend.
256
+
257
+ Returns:
258
+ sorted_masklist: A sorted BoxMaskList with the field in the specified order.
259
+ """
260
+ return boxlist_to_masklist(sort_by_field_boxlist(boxlist=masklist, field=field, order=order))
261
+
262
+
263
+ def non_max_suppression_mask(masklist, max_output_size=10000, iou_threshold=1.0, score_threshold=-10.0):
264
+ """Non maximum suppression.
265
+
266
+ This op greedily selects a subset of detection bounding boxes, pruning
267
+ away boxes that have high IOU (intersection over union) overlap (> thresh)
268
+ with already selected boxes. In each iteration, the detected bounding box with
269
+ highest score in the available pool is selected.
270
+
271
+ Args:
272
+ masklist: BoxMaskList holding N boxes. Must contain a 'scores' field representing
273
+ detection scores. All scores belong to the same class.
274
+ max_output_size: maximum number of retained boxes
275
+ iou_threshold: intersection over union threshold.
276
+ score_threshold: minimum score threshold. Remove the boxes with scores
277
+ less than this value. Default value is set to -10. A very
278
+ low threshold to pass pretty much all the boxes, unless
279
+ the user sets a different score threshold.
280
+
281
+ Returns:
282
+ a MaskList holding M boxes where M <= max_output_size
283
+
284
+ Raises:
285
+ ValueError: if 'scores' field does not exist
286
+ ValueError: if threshold is not in [0, 1]
287
+ ValueError: if max_output_size < 0
288
+ """
289
+ if not masklist.has_field('scores'):
290
+ raise ValueError('Field scores does not exist')
291
+ if iou_threshold < 0. or iou_threshold > 1.0:
292
+ raise ValueError('IOU threshold must be in [0, 1]')
293
+ if max_output_size < 0:
294
+ raise ValueError('max_output_size must be non-negative.')
295
+
296
+ masklist = filter_scores_greater_than_masklist(masklist, score_threshold)
297
+ if masklist.num_boxes() == 0:
298
+ return masklist
299
+
300
+ masklist = sort_by_field_masklist(masklist, 'scores')
301
+
302
+ # Prevent further computation if NMS is disabled.
303
+ if iou_threshold == 1.0:
304
+ if masklist.num_boxes() > max_output_size:
305
+ selected_indices = np.arange(max_output_size)
306
+ return gather_masklist(masklist, selected_indices)
307
+ else:
308
+ return masklist
309
+
310
+ masks = masklist.get_masks()
311
+ num_masks = masklist.num_boxes()
312
+
313
+ # is_index_valid is True only for all remaining valid boxes,
314
+ is_index_valid = np.full(num_masks, 1, dtype=bool)
315
+ selected_indices = []
316
+ num_output = 0
317
+ for i in range(num_masks):
318
+ if num_output < max_output_size:
319
+ if is_index_valid[i]:
320
+ num_output += 1
321
+ selected_indices.append(i)
322
+ is_index_valid[i] = False
323
+ valid_indices = np.where(is_index_valid)[0]
324
+ if valid_indices.size == 0:
325
+ break
326
+
327
+ intersect_over_union = iou_mask(np.expand_dims(masks[i], axis=0), masks[valid_indices])
328
+ intersect_over_union = np.squeeze(intersect_over_union, axis=0)
329
+ is_index_valid[valid_indices] = np.logical_and(
330
+ is_index_valid[valid_indices],
331
+ intersect_over_union <= iou_threshold)
332
+ return gather_masklist(masklist, np.array(selected_indices))
333
+
334
+
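For reference, a short sketch of how the greedy mask NMS above is intended to be driven; the import path is assumed, and the boxes, masks, and scores are made up for illustration.

```python
import numpy as np
from effdet.evaluation.np_mask_list import MaskList, non_max_suppression_mask

# Three same-class detections; the first two share an identical mask.
boxes = np.array([[0., 0., 2., 2.],
                  [0., 0., 2., 2.],
                  [2., 2., 4., 4.]], dtype=np.float32)
masks = np.zeros((3, 4, 4), dtype=np.uint8)
masks[0, 0:2, 0:2] = 1
masks[1, 0:2, 0:2] = 1            # duplicate of the first mask
masks[2, 2:4, 2:4] = 1

masklist = MaskList(box_data=boxes, mask_data=masks)
masklist.add_field('scores', np.array([0.9, 0.8, 0.7], dtype=np.float32))

# Greedy NMS on mask IoU: the duplicate (IoU 1.0 with the top-scoring mask) is pruned.
kept = non_max_suppression_mask(masklist, max_output_size=10, iou_threshold=0.5)
print(kept.num_boxes())           # 2
print(kept.get_field('scores'))   # [0.9, 0.7]
```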
335
+ def multi_class_non_max_suppression_mask(masklist, score_thresh, iou_thresh, max_output_size):
336
+ """Multi-class version of non maximum suppression.
337
+
338
+ This op greedily selects a subset of detection bounding boxes, pruning away boxes that have
339
+ high IOU (intersection over union) overlap (> thresh) with already selected boxes. It
340
+ operates independently for each class for which scores are provided (via the scores field
341
+ of the input box_list), pruning boxes with score less than a provided threshold prior to
342
+ applying NMS.
343
+
344
+ Args:
345
+ masklist: BoxMaskList holding N boxes. Must contain a 'scores' field representing detection
346
+ scores. This scores field is a tensor that can be 1 dimensional (in the case of a
347
+ single class) or 2-dimensional, in which case we assume that it takes the shape
348
+ [num_boxes, num_classes]. We further assume that this rank is known statically and
349
+ that scores.shape[1] is also known (i.e., the number of classes is fixed and known
350
+ at graph construction time).
351
+ score_thresh: scalar threshold for score (low scoring boxes are removed).
352
+ iou_thresh: scalar threshold for IOU (boxes that have high IOU overlap with previously
353
+ selected boxes are removed).
354
+ max_output_size: maximum number of retained boxes per class.
355
+
356
+ Returns:
357
+ a masklist holding M boxes with a rank-1 scores field representing
358
+ corresponding scores for each box with scores sorted in decreasing order
359
+ and a rank-1 classes field representing a class label for each box.
360
+ Raises:
361
+ ValueError: if iou_thresh is not in [0, 1] or if input masklist does not have a valid scores field.
362
+ """
363
+ if not 0 <= iou_thresh <= 1.0:
364
+ raise ValueError('thresh must be between 0 and 1')
365
+ if not isinstance(masklist, MaskList):
366
+ raise ValueError('masklist must be a MaskList')
367
+ if not masklist.has_field('scores'):
368
+ raise ValueError('input masklist must have \'scores\' field')
369
+ scores = masklist.get_field('scores')
370
+ if len(scores.shape) == 1:
371
+ scores = np.reshape(scores, [-1, 1])
372
+ elif len(scores.shape) == 2:
373
+ if scores.shape[1] is None:
374
+ raise ValueError('scores field must have statically defined second dimension')
375
+ else:
376
+ raise ValueError('scores field must be of rank 1 or 2')
377
+
378
+ num_boxes = masklist.num_boxes()
379
+ num_scores = scores.shape[0]
380
+ num_classes = scores.shape[1]
381
+
382
+ if num_boxes != num_scores:
383
+ raise ValueError('Incorrect scores field length: actual vs expected.')
384
+
385
+ selected_boxes_list = []
386
+ for class_idx in range(num_classes):
387
+ masklist_and_class_scores = MaskList(box_data=masklist.get(), mask_data=masklist.get_masks())
388
+ class_scores = np.reshape(scores[0:num_scores, class_idx], [-1])
389
+ masklist_and_class_scores.add_field('scores', class_scores)
390
+ masklist_filt = filter_scores_greater_than_masklist(masklist_and_class_scores, score_thresh)
391
+ nms_result = non_max_suppression_mask(
392
+ masklist_filt,
393
+ max_output_size=max_output_size,
394
+ iou_threshold=iou_thresh,
395
+ score_threshold=score_thresh)
396
+ nms_result.add_field('classes', np.zeros_like(nms_result.get_field('scores')) + class_idx)
397
+ selected_boxes_list.append(nms_result)
398
+ selected_boxes = concatenate_boxlist(selected_boxes_list)
399
+ sorted_boxes = sort_by_field_boxlist(selected_boxes, 'scores')
400
+ return boxlist_to_masklist(boxlist=sorted_boxes)
401
+
402
+
403
+ def prune_non_overlapping_masklist(masklist1, masklist2, minoverlap=0.0):
404
+ """Prunes the boxes in list1 that overlap less than thresh with list2.
405
+
406
+ For each mask in masklist1, we want its IOA to be more than minoverlap
407
+ with at least one of the masks in masklist2. If it does not, we remove
408
+ it. If the masks are not full size image, we do the pruning based on boxes.
409
+
410
+ Args:
411
+ masklist1: BoxMaskList holding N boxes and masks.
412
+ masklist2: BoxMaskList holding M boxes and masks.
413
+ minoverlap: Minimum required overlap between boxes, to count them as overlapping.
414
+
415
+ Returns:
416
+ A pruned masklist with size [N', 4].
417
+ """
418
+ intersection_over_area = ioa_masklist(masklist2, masklist1) # [M, N] tensor
419
+ intersection_over_area = np.amax(intersection_over_area, axis=0) # [N] tensor
420
+ keep_bool = np.greater_equal(intersection_over_area, np.array(minoverlap))
421
+ keep_inds = np.nonzero(keep_bool)[0]
422
+ new_masklist1 = gather_masklist(masklist1, keep_inds)
423
+ return new_masklist1
424
+
425
+
426
+ def concatenate_masklist(masklists, fields=None):
427
+ """Concatenate list of masklists.
428
+
429
+ This op concatenates a list of input masklists into a larger
430
+ masklist. It also
431
+ handles concatenation of masklist fields as long as the field tensor
432
+ shapes are equal except for the first dimension.
433
+
434
+ Args:
435
+ masklists: list of BoxMaskList objects
436
+ fields: optional list of fields to also concatenate. By default, all
437
+ fields from the first BoxMaskList in the list are included in the concatenation.
438
+
439
+ Returns:
440
+ a masklist with number of boxes equal to sum([masklist.num_boxes() for masklist in masklists])
441
+ Raises:
442
+ ValueError: if masklists is invalid (i.e., is not a list, is empty, or contains non
443
+ masklist objects), or if requested fields are not contained in all masklists
444
+ """
445
+ if fields is not None:
446
+ if 'masks' not in fields:
447
+ fields.append('masks')
448
+ return boxlist_to_masklist(concatenate_boxlist(boxlists=masklists, fields=fields))
449
+
450
+
451
+ def filter_scores_greater_than_masklist(masklist, thresh):
452
+ """Filter to keep only boxes and masks with score exceeding a given threshold.
453
+
454
+ This op keeps the collection of boxes and masks whose corresponding scores are
455
+ greater than the input threshold.
456
+
457
+ Args:
458
+ masklist: BoxMaskList holding N boxes and masks. Must contain a
459
+ 'scores' field representing detection scores.
460
+ thresh: scalar threshold
461
+
462
+ Returns:
463
+ a BoxMaskList holding M boxes and masks where M <= N
464
+
465
+ Raises:
466
+ ValueError: if masklist not a BoxMaskList object or if it does not have a scores field
467
+ """
468
+ if not isinstance(masklist, MaskList):
469
+ raise ValueError('masklist must be a BoxMaskList')
470
+ if not masklist.has_field('scores'):
471
+ raise ValueError('input masklist must have \'scores\' field')
472
+ scores = masklist.get_field('scores')
473
+ if len(scores.shape) > 2:
474
+ raise ValueError('Scores should have rank 1 or 2')
475
+ if len(scores.shape) == 2 and scores.shape[1] != 1:
476
+ raise ValueError('Scores should have rank 1 or have shape consistent with [None, 1]')
477
+ high_score_indices = np.reshape(np.where(np.greater(scores, thresh)), [-1]).astype(np.int32)
478
+ return gather_masklist(masklist, high_score_indices)
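Closing out this module, a hedged sketch of the multi-class entry point with a [num_boxes, num_classes] scores field; it assumes the box-list helpers imported at the top of the file carry the extra 'masks' and 'classes' fields through gathering and concatenation, as the code above expects.

```python
import numpy as np
from effdet.evaluation.np_mask_list import MaskList, multi_class_non_max_suppression_mask

boxes = np.array([[0., 0., 2., 2.],
                  [2., 2., 4., 4.]], dtype=np.float32)
masks = np.zeros((2, 4, 4), dtype=np.uint8)
masks[0, 0:2, 0:2] = 1
masks[1, 2:4, 2:4] = 1

masklist = MaskList(box_data=boxes, mask_data=masks)
# Per-class scores, shape [num_boxes, num_classes]: box 0 favours class 0, box 1 class 1.
masklist.add_field('scores', np.array([[0.9, 0.1],
                                       [0.2, 0.8]], dtype=np.float32))

result = multi_class_non_max_suppression_mask(
    masklist, score_thresh=0.5, iou_thresh=0.5, max_output_size=10)
print(result.get_field('scores'))   # [0.9, 0.8], sorted in decreasing order
print(result.get_field('classes'))  # [0., 1.]
```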
efficientdet/effdet/evaluation/object_detection_evaluation.py ADDED
@@ -0,0 +1,273 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+
5
+ from effdet.evaluation.metrics import compute_precision_recall, compute_average_precision, compute_cor_loc
6
+ from effdet.evaluation.per_image_evaluation import PerImageEvaluation
7
+
8
+
9
+ class ObjectDetectionEvaluation:
10
+ """Internal implementation of Pascal object detection metrics."""
11
+
12
+ def __init__(self,
13
+ num_gt_classes,
14
+ matching_iou_threshold=0.5,
15
+ nms_iou_threshold=1.0,
16
+ nms_max_output_boxes=10000,
17
+ recall_lower_bound=0.0,
18
+ recall_upper_bound=1.0,
19
+ use_weighted_mean_ap=False,
20
+ label_id_offset=0,
21
+ group_of_weight=0.0,
22
+ per_image_eval_class=PerImageEvaluation):
23
+ """Constructor.
24
+ Args:
25
+ num_gt_classes: Number of ground-truth classes.
26
+ matching_iou_threshold: IOU threshold used for matching detected boxes to ground-truth boxes.
27
+ nms_iou_threshold: IOU threshold used for non-maximum suppression.
28
+ nms_max_output_boxes: Maximum number of boxes returned by non-maximum suppression.
29
+ recall_lower_bound: lower bound of recall operating area
30
+ recall_upper_bound: upper bound of recall operating area
31
+ use_weighted_mean_ap: (optional) boolean which determines if the mean
32
+ average precision is computed directly from the scores and tp_fp_labels of all classes.
33
+ label_id_offset: The label id offset.
34
+ group_of_weight: Weight of group-of boxes. If set to 0, detections of the
35
+ correct class within a group-of box are ignored. If weight is > 0, then
36
+ if at least one detection falls within a group-of box with
37
+ matching_iou_threshold, weight group_of_weight is added to true
38
+ positives. Consequently, if no detection falls within a group-of box,
39
+ weight group_of_weight is added to false negatives.
40
+ per_image_eval_class: The class that contains functions for computing per image metrics.
41
+ Raises:
42
+ ValueError: if num_gt_classes is smaller than 1.
43
+ """
44
+ if num_gt_classes < 1:
45
+ raise ValueError('Need at least 1 groundtruth class for evaluation.')
46
+
47
+ self.per_image_eval = per_image_eval_class(
48
+ num_gt_classes=num_gt_classes,
49
+ matching_iou_threshold=matching_iou_threshold,
50
+ nms_iou_threshold=nms_iou_threshold,
51
+ nms_max_output_boxes=nms_max_output_boxes,
52
+ group_of_weight=group_of_weight)
53
+ self.recall_lower_bound = recall_lower_bound
54
+ self.recall_upper_bound = recall_upper_bound
55
+ self.group_of_weight = group_of_weight
56
+ self.num_class = num_gt_classes
57
+ self.use_weighted_mean_ap = use_weighted_mean_ap
58
+ self.label_id_offset = label_id_offset
59
+
60
+ self.gt_boxes = {}
61
+ self.gt_class_labels = {}
62
+ self.gt_masks = {}
63
+ self.gt_is_difficult_list = {}
64
+ self.gt_is_group_of_list = {}
65
+ self.num_gt_instances_per_class = np.zeros(self.num_class, dtype=float)
66
+ self.num_gt_imgs_per_class = np.zeros(self.num_class, dtype=int)
67
+
68
+ self._initialize_detections()
69
+
70
+ def _initialize_detections(self):
71
+ """Initializes internal data structures."""
72
+ self.detection_keys = set()
73
+ self.scores_per_class = [[] for _ in range(self.num_class)]
74
+ self.tp_fp_labels_per_class = [[] for _ in range(self.num_class)]
75
+ self.num_images_correctly_detected_per_class = np.zeros(self.num_class)
76
+ self.average_precision_per_class = np.empty(self.num_class, dtype=float)
77
+ self.average_precision_per_class.fill(np.nan)
78
+ self.precisions_per_class = [np.nan] * self.num_class
79
+ self.recalls_per_class = [np.nan] * self.num_class
80
+ self.sum_tp_class = [np.nan] * self.num_class
81
+
82
+ self.corloc_per_class = np.ones(self.num_class, dtype=float)
83
+
84
+ def clear_detections(self):
85
+ self._initialize_detections()
86
+
87
+ def add_single_ground_truth_image_info(
88
+ self, image_key, gt_boxes, gt_class_labels,
89
+ gt_is_difficult_list=None, gt_is_group_of_list=None, gt_masks=None):
90
+ """Adds groundtruth for a single image to be used for evaluation.
91
+ Args:
92
+ image_key: A unique string/integer identifier for the image.
93
+ gt_boxes: float32 numpy array of shape [num_boxes, 4] containing
94
+ `num_boxes` groundtruth boxes of the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
95
+ gt_class_labels: integer numpy array of shape [num_boxes]
96
+ containing 0-indexed groundtruth classes for the boxes.
97
+ gt_is_difficult_list: A length M numpy boolean array denoting
98
+ whether a ground truth box is a difficult instance or not. To support
99
+ the case that no boxes are difficult, it is by default set as None.
100
+ gt_is_group_of_list: A length M numpy boolean array denoting
101
+ whether a ground truth box is a group-of box or not. To support the case
102
+ that no boxes are groups-of, it is by default set as None.
103
+ gt_masks: uint8 numpy array of shape [num_boxes, height, width]
104
+ containing `num_boxes` groundtruth masks. The mask values range from 0 to 1.
105
+ """
106
+ if image_key in self.gt_boxes:
107
+ logging.warning('image %s has already been added to the ground truth database.', image_key)
108
+ return
109
+
110
+ self.gt_boxes[image_key] = gt_boxes
111
+ self.gt_class_labels[image_key] = gt_class_labels
112
+ self.gt_masks[image_key] = gt_masks
113
+ if gt_is_difficult_list is None:
114
+ num_boxes = gt_boxes.shape[0]
115
+ gt_is_difficult_list = np.zeros(num_boxes, dtype=bool)
116
+ gt_is_difficult_list = gt_is_difficult_list.astype(dtype=bool)
117
+ self.gt_is_difficult_list[image_key] = gt_is_difficult_list
118
+ if gt_is_group_of_list is None:
119
+ num_boxes = gt_boxes.shape[0]
120
+ gt_is_group_of_list = np.zeros(num_boxes, dtype=bool)
121
+ if gt_masks is None:
122
+ num_boxes = gt_boxes.shape[0]
123
+ mask_presence_indicator = np.zeros(num_boxes, dtype=bool)
124
+ else:
125
+ mask_presence_indicator = (np.sum(gt_masks, axis=(1, 2)) == 0).astype(dtype=bool)
126
+
127
+ gt_is_group_of_list = gt_is_group_of_list.astype(dtype=bool)
128
+ self.gt_is_group_of_list[image_key] = gt_is_group_of_list
129
+
130
+ # ignore boxes without masks
131
+ masked_gt_is_difficult_list = gt_is_difficult_list | mask_presence_indicator
132
+ for class_index in range(self.num_class):
133
+ num_gt_instances = np.sum(
134
+ gt_class_labels[~masked_gt_is_difficult_list & ~gt_is_group_of_list] == class_index)
135
+ num_groupof_gt_instances = self.group_of_weight * np.sum(
136
+ gt_class_labels[gt_is_group_of_list & ~masked_gt_is_difficult_list] == class_index)
137
+ self.num_gt_instances_per_class[class_index] += num_gt_instances + num_groupof_gt_instances
138
+ if np.any(gt_class_labels == class_index):
139
+ self.num_gt_imgs_per_class[class_index] += 1
140
+
141
+ def add_single_detected_image_info(
142
+ self, image_key, detected_boxes, detected_scores, detected_class_labels, detected_masks=None):
143
+ """Adds detections for a single image to be used for evaluation.
144
+ Args:
145
+ image_key: A unique string/integer identifier for the image.
146
+ detected_boxes: float32 numpy array of shape [num_boxes, 4] containing
147
+ `num_boxes` detection boxes of the format [ymin, xmin, ymax, xmax] in
148
+ absolute image coordinates.
149
+ detected_scores: float32 numpy array of shape [num_boxes] containing
150
+ detection scores for the boxes.
151
+ detected_class_labels: integer numpy array of shape [num_boxes] containing
152
+ 0-indexed detection classes for the boxes.
153
+ detected_masks: np.uint8 numpy array of shape [num_boxes, height, width]
154
+ containing `num_boxes` detection masks with values ranging between 0 and 1.
155
+ Raises:
156
+ ValueError: if the number of boxes, scores and class labels differ in length.
157
+ """
158
+ if len(detected_boxes) != len(detected_scores) or len(detected_boxes) != len(detected_class_labels):
159
+ raise ValueError(
160
+ 'detected_boxes, detected_scores and '
161
+ 'detected_class_labels should all have same lengths. Got '
162
+ '[%d, %d, %d]' % (len(detected_boxes), len(detected_scores),
163
+ len(detected_class_labels)))
164
+
165
+ if image_key in self.detection_keys:
166
+ logging.warning('image %s has already been added to the detection result database', image_key)
167
+ return
168
+
169
+ self.detection_keys.add(image_key)
170
+ if image_key in self.gt_boxes:
171
+ gt_boxes = self.gt_boxes[image_key]
172
+ gt_class_labels = self.gt_class_labels[image_key]
173
+ # Masks are popped instead of look up. The reason is that we do not want
174
+ # to keep all masks in memory which can cause memory overflow.
175
+ gt_masks = self.gt_masks.pop(image_key)
176
+ gt_is_difficult_list = self.gt_is_difficult_list[image_key]
177
+ gt_is_group_of_list = self.gt_is_group_of_list[image_key]
178
+ else:
179
+ gt_boxes = np.empty(shape=[0, 4], dtype=float)
180
+ gt_class_labels = np.array([], dtype=int)
181
+ if detected_masks is None:
182
+ gt_masks = None
183
+ else:
184
+ gt_masks = np.empty(shape=[0, 1, 1], dtype=float)
185
+ gt_is_difficult_list = np.array([], dtype=bool)
186
+ gt_is_group_of_list = np.array([], dtype=bool)
187
+ scores, tp_fp_labels, is_class_correctly_detected_in_image = \
188
+ self.per_image_eval.compute_object_detection_metrics(
189
+ detected_boxes=detected_boxes,
190
+ detected_scores=detected_scores,
191
+ detected_class_labels=detected_class_labels,
192
+ gt_boxes=gt_boxes,
193
+ gt_class_labels=gt_class_labels,
194
+ gt_is_difficult_list=gt_is_difficult_list,
195
+ gt_is_group_of_list=gt_is_group_of_list,
196
+ detected_masks=detected_masks,
197
+ gt_masks=gt_masks)
198
+
199
+ for i in range(self.num_class):
200
+ if scores[i].shape[0] > 0:
201
+ self.scores_per_class[i].append(scores[i])
202
+ self.tp_fp_labels_per_class[i].append(tp_fp_labels[i])
203
+ self.num_images_correctly_detected_per_class += is_class_correctly_detected_in_image
204
+
205
+ def evaluate(self):
206
+ """Compute evaluation result.
207
+ Returns:
208
+ A dict with the following fields -
209
+ average_precision: float numpy array of average precision for each class.
210
+ mean_ap: mean average precision of all classes, float scalar
211
+ precisions: List of precisions, each precision is a float numpy array
212
+ recalls: List of recalls, each recall is a float numpy array
213
+ corloc: numpy float array
214
+ mean_corloc: Mean CorLoc score for each class, float scalar
215
+ """
216
+ if (self.num_gt_instances_per_class == 0).any():
217
+ logging.warning(
218
+ 'The following classes have no ground truth examples: %s',
219
+ np.squeeze(np.argwhere(self.num_gt_instances_per_class == 0)) + self.label_id_offset)
220
+
221
+ if self.use_weighted_mean_ap:
222
+ all_scores = np.array([], dtype=float)
223
+ all_tp_fp_labels = np.array([], dtype=bool)
224
+ for class_index in range(self.num_class):
225
+ if self.num_gt_instances_per_class[class_index] == 0:
226
+ continue
227
+ if not self.scores_per_class[class_index]:
228
+ scores = np.array([], dtype=float)
229
+ tp_fp_labels = np.array([], dtype=float)
230
+ else:
231
+ scores = np.concatenate(self.scores_per_class[class_index])
232
+ tp_fp_labels = np.concatenate(self.tp_fp_labels_per_class[class_index])
233
+ if self.use_weighted_mean_ap:
234
+ all_scores = np.append(all_scores, scores)
235
+ all_tp_fp_labels = np.append(all_tp_fp_labels, tp_fp_labels)
236
+ precision, recall = compute_precision_recall(
237
+ scores, tp_fp_labels, self.num_gt_instances_per_class[class_index])
238
+ recall_within_bound_indices = [
239
+ index for index, value in enumerate(recall) if
240
+ value >= self.recall_lower_bound and value <= self.recall_upper_bound
241
+ ]
242
+ recall_within_bound = recall[recall_within_bound_indices]
243
+ precision_within_bound = precision[recall_within_bound_indices]
244
+
245
+ self.precisions_per_class[class_index] = precision_within_bound
246
+ self.recalls_per_class[class_index] = recall_within_bound
247
+ self.sum_tp_class[class_index] = tp_fp_labels.sum()
248
+ average_precision = compute_average_precision(precision_within_bound, recall_within_bound)
249
+ self.average_precision_per_class[class_index] = average_precision
250
+ logging.debug('average_precision: %f', average_precision)
251
+
252
+ self.corloc_per_class = compute_cor_loc(
253
+ self.num_gt_imgs_per_class, self.num_images_correctly_detected_per_class)
254
+
255
+ if self.use_weighted_mean_ap:
256
+ num_gt_instances = np.sum(self.num_gt_instances_per_class)
257
+ precision, recall = compute_precision_recall(all_scores, all_tp_fp_labels, num_gt_instances)
258
+ recall_within_bound_indices = [
259
+ index for index, value in enumerate(recall) if
260
+ value >= self.recall_lower_bound and value <= self.recall_upper_bound
261
+ ]
262
+ recall_within_bound = recall[recall_within_bound_indices]
263
+ precision_within_bound = precision[recall_within_bound_indices]
264
+ mean_ap = compute_average_precision(precision_within_bound, recall_within_bound)
265
+ else:
266
+ mean_ap = np.nanmean(self.average_precision_per_class)
267
+ mean_corloc = np.nanmean(self.corloc_per_class)
268
+
269
+ return dict(
270
+ per_class_ap=self.average_precision_per_class, mean_ap=mean_ap,
271
+ per_class_precision=self.precisions_per_class,
272
+ per_class_recall=self.recalls_per_class,
273
+ per_class_corlocs=self.corloc_per_class, mean_corloc=mean_corloc)
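A minimal end-to-end sketch of this evaluator on a single image, assuming the package is importable as `effdet.evaluation` (the import style used at the top of this file); boxes, scores, and labels are illustrative.

```python
import numpy as np
from effdet.evaluation.object_detection_evaluation import ObjectDetectionEvaluation

evaluator = ObjectDetectionEvaluation(num_gt_classes=2, matching_iou_threshold=0.5)

# One image with a single class-0 groundtruth box in [ymin, xmin, ymax, xmax].
evaluator.add_single_ground_truth_image_info(
    image_key='img_0',
    gt_boxes=np.array([[10., 10., 50., 50.]], dtype=np.float32),
    gt_class_labels=np.array([0], dtype=int))

# Two detections: one good class-0 match, one spurious class-1 box.
evaluator.add_single_detected_image_info(
    image_key='img_0',
    detected_boxes=np.array([[12., 11., 51., 49.],
                             [60., 60., 90., 90.]], dtype=np.float32),
    detected_scores=np.array([0.95, 0.30], dtype=np.float32),
    detected_class_labels=np.array([0, 1], dtype=int))

metrics = evaluator.evaluate()
print(metrics['mean_ap'], metrics['per_class_ap'])
```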
efficientdet/effdet/evaluation/per_image_evaluation.py ADDED
@@ -0,0 +1,538 @@
1
+ from .np_mask_list import *
2
+ from .metrics import *
3
+
4
+
5
+ class PerImageEvaluation:
6
+ """Evaluate detection result of a single image."""
7
+
8
+ def __init__(self,
9
+ num_gt_classes,
10
+ matching_iou_threshold=0.5,
11
+ nms_iou_threshold=0.3,
12
+ nms_max_output_boxes=50,
13
+ group_of_weight=0.0):
14
+ """Initialized PerImageEvaluation by evaluation parameters.
15
+ Args:
16
+ num_gt_classes: Number of ground truth object classes
17
+ matching_iou_threshold: A ratio of area intersection to union, which is
18
+ the threshold to consider whether a detection is true positive or not
19
+ nms_iou_threshold: IOU threshold used in Non Maximum Suppression.
20
+ nms_max_output_boxes: Number of maximum output boxes in NMS.
21
+ group_of_weight: Weight of the group-of boxes.
22
+ """
23
+ self.matching_iou_threshold = matching_iou_threshold
24
+ self.nms_iou_threshold = nms_iou_threshold
25
+ self.nms_max_output_boxes = nms_max_output_boxes
26
+ self.num_gt_classes = num_gt_classes
27
+ self.group_of_weight = group_of_weight
28
+
29
+ def compute_object_detection_metrics(
30
+ self, detected_boxes, detected_scores, detected_class_labels,
31
+ gt_boxes, gt_class_labels, gt_is_difficult_list, gt_is_group_of_list,
32
+ detected_masks=None, gt_masks=None):
33
+ """Evaluates detections as being tp, fp or weighted from a single image.
34
+ The evaluation is done in two stages:
35
+ 1. All detections are matched to non group-of boxes; true positives are
36
+ determined and detections matched to difficult boxes are ignored.
37
+ 2. Detections that are determined as false positives are matched against
38
+ group-of boxes and weighted if matched.
39
+ Args:
40
+ detected_boxes: A float numpy array of shape [N, 4], representing N
41
+ regions of detected object instances. Each row is of the format [y_min, x_min, y_max, x_max]
42
+ detected_scores: A float numpy array of shape [N, 1], representing the
43
+ confidence scores of the detected N object instances.
44
+ detected_class_labels: An integer numpy array of shape [N, 1], representing
45
+ the class labels of the detected N object instances.
46
+ gt_boxes: A float numpy array of shape [M, 4], representing M
47
+ regions of object instances in ground truth
48
+ gt_class_labels: An integer numpy array of shape [M, 1],
49
+ representing M class labels of object instances in ground truth
50
+ gt_is_difficult_list: A boolean numpy array of length M denoting
51
+ whether a ground truth box is a difficult instance or not
52
+ gt_is_group_of_list: A boolean numpy array of length M denoting
53
+ whether a ground truth box has group-of tag
54
+ detected_masks: (optional) A uint8 numpy array of shape [N, height,
55
+ width]. If not None, the metrics will be computed based on masks.
56
+ gt_masks: (optional) A uint8 numpy array of shape [M, height,
57
+ width]. Can have empty masks, i.e. where all values are 0.
58
+ Returns:
59
+ scores: A list of C float numpy arrays. Each numpy array is of
60
+ shape [K, 1], representing K scores detected with object class label c
61
+ tp_fp_labels: A list of C boolean numpy arrays. Each numpy array
62
+ is of shape [K, 1], representing K True/False positive label of
63
+ object instances detected with class label c
64
+ is_class_correctly_detected_in_image: a numpy integer array of
65
+ shape [C, 1], indicating whether the corresponding class has at least
66
+ one instance being correctly detected in the image
67
+ """
68
+ detected_boxes, detected_scores, detected_class_labels, detected_masks = (
69
+ self._remove_invalid_boxes(detected_boxes, detected_scores, detected_class_labels, detected_masks))
70
+
71
+ scores, tp_fp_labels = self._compute_tp_fp(
72
+ detected_boxes=detected_boxes,
73
+ detected_scores=detected_scores,
74
+ detected_class_labels=detected_class_labels,
75
+ gt_boxes=gt_boxes,
76
+ gt_class_labels=gt_class_labels,
77
+ gt_is_difficult_list=gt_is_difficult_list,
78
+ gt_is_group_of_list=gt_is_group_of_list,
79
+ detected_masks=detected_masks,
80
+ gt_masks=gt_masks)
81
+
82
+ is_class_correctly_detected_in_image = self._compute_cor_loc(
83
+ detected_boxes=detected_boxes,
84
+ detected_scores=detected_scores,
85
+ detected_class_labels=detected_class_labels,
86
+ gt_boxes=gt_boxes,
87
+ gt_class_labels=gt_class_labels,
88
+ detected_masks=detected_masks,
89
+ gt_masks=gt_masks)
90
+
91
+ return scores, tp_fp_labels, is_class_correctly_detected_in_image
92
+
93
+ def _compute_cor_loc(
94
+ self, detected_boxes, detected_scores, detected_class_labels,
95
+ gt_boxes, gt_class_labels, detected_masks=None, gt_masks=None):
96
+ """Compute CorLoc score for object detection result.
97
+ Args:
98
+ detected_boxes: A float numpy array of shape [N, 4], representing N
99
+ regions of detected object instances. Each row is of the format [y_min, x_min, y_max, x_max]
100
+ detected_scores: A float numpy array of shape [N, 1], representing the
101
+ confidence scores of the detected N object instances.
102
+ detected_class_labels: An integer numpy array of shape [N, 1], representing
103
+ the class labels of the detected N object instances.
104
+ gt_boxes: A float numpy array of shape [M, 4], representing M
105
+ regions of object instances in ground truth
106
+ gt_class_labels: An integer numpy array of shape [M, 1],
107
+ representing M class labels of object instances in ground truth
108
+ detected_masks: (optional) A uint8 numpy array of shape [N, height, width].
109
+ If not None, the scores will be computed based on masks.
110
+ gt_masks: (optional) A uint8 numpy array of shape [M, height, width].
111
+ Returns:
112
+ is_class_correctly_detected_in_image: a numpy integer array of
113
+ shape [C, 1], indicating whether the corresponding class has at least
114
+ one instance being correctly detected in the image
115
+ Raises:
116
+ ValueError: If detected masks is not None but groundtruth masks are None,
117
+ or the other way around.
118
+ """
119
+ if (detected_masks is not None and gt_masks is None) or (
120
+ detected_masks is None and gt_masks is not None):
121
+ raise ValueError(
122
+ 'If `detected_masks` is provided, then `gt_masks` should also be provided.')
123
+
124
+ is_class_correctly_detected_in_image = np.zeros(
125
+ self.num_gt_classes, dtype=int)
126
+ for i in range(self.num_gt_classes):
127
+ (gt_boxes_at_ith_class, gt_masks_at_ith_class,
128
+ detected_boxes_at_ith_class, detected_scores_at_ith_class,
129
+ detected_masks_at_ith_class) = self._get_ith_class_arrays(
130
+ detected_boxes, detected_scores, detected_masks,
131
+ detected_class_labels, gt_boxes, gt_masks,
132
+ gt_class_labels, i)
133
+ is_class_correctly_detected_in_image[i] = (
134
+ self._compute_is_class_correctly_detected_in_image(
135
+ detected_boxes=detected_boxes_at_ith_class,
136
+ detected_scores=detected_scores_at_ith_class,
137
+ gt_boxes=gt_boxes_at_ith_class,
138
+ detected_masks=detected_masks_at_ith_class,
139
+ gt_masks=gt_masks_at_ith_class))
140
+
141
+ return is_class_correctly_detected_in_image
142
+
143
+ def _compute_is_class_correctly_detected_in_image(
144
+ self, detected_boxes, detected_scores, gt_boxes, detected_masks=None, gt_masks=None):
145
+ """Compute CorLoc score for a single class.
146
+ Args:
147
+ detected_boxes: A numpy array of shape [N, 4] representing detected box coordinates
148
+ detected_scores: A 1-d numpy array of length N representing classification score
149
+ gt_boxes: A numpy array of shape [M, 4] representing ground truth box coordinates
150
+ detected_masks: (optional) A np.uint8 numpy array of shape [N, height, width].
151
+ If not None, the scores will be computed based on masks.
152
+ gt_masks: (optional) A np.uint8 numpy array of shape [M, height, width].
153
+ Returns:
154
+ is_class_correctly_detected_in_image: An integer 1 or 0 denoting whether a
155
+ class is correctly detected in the image or not
156
+ """
157
+ if detected_boxes.size > 0:
158
+ if gt_boxes.size > 0:
159
+ max_score_id = np.argmax(detected_scores)
160
+ mask_mode = False
161
+ if detected_masks is not None and gt_masks is not None:
162
+ mask_mode = True
163
+ if mask_mode:
164
+ detected_boxlist = MaskList(
165
+ box_data=np.expand_dims(detected_boxes[max_score_id], axis=0),
166
+ mask_data=np.expand_dims(detected_masks[max_score_id], axis=0))
167
+ gt_boxlist = MaskList(box_data=gt_boxes, mask_data=gt_masks)
168
+ iou = iou_masklist(detected_boxlist, gt_boxlist)
169
+ else:
170
+ detected_boxlist = BoxList(np.expand_dims(detected_boxes[max_score_id, :], axis=0))
171
+ gt_boxlist = BoxList(gt_boxes)
172
+ iou = iou_boxlist(detected_boxlist, gt_boxlist)
173
+ if np.max(iou) >= self.matching_iou_threshold:
174
+ return 1
175
+ return 0
176
+
177
+ def _compute_tp_fp(
178
+ self, detected_boxes, detected_scores, detected_class_labels,
179
+ gt_boxes, gt_class_labels, gt_is_difficult_list, gt_is_group_of_list, detected_masks=None, gt_masks=None):
180
+ """Labels true/false positives of detections of an image across all classes.
181
+ Args:
182
+ detected_boxes: A float numpy array of shape [N, 4], representing N
183
+ regions of detected object instances. Each row is of the format [y_min, x_min, y_max, x_max]
184
+ detected_scores: A float numpy array of shape [N, 1], representing the
185
+ confidence scores of the detected N object instances.
186
+ detected_class_labels: A integer numpy array of shape [N, 1], representing
187
+ the class labels of the detected N object instances.
188
+ gt_boxes: A float numpy array of shape [M, 4], representing M
189
+ regions of object instances in ground truth
190
+ gt_class_labels: An integer numpy array of shape [M, 1],
191
+ representing M class labels of object instances in ground truth
192
+ gt_is_difficult_list: A boolean numpy array of length M denoting
193
+ whether a ground truth box is a difficult instance or not
194
+ gt_is_group_of_list: A boolean numpy array of length M denoting
195
+ whether a ground truth box has group-of tag
196
+ detected_masks: (optional) A np.uint8 numpy array of shape [N, height,
197
+ width]. If not None, the scores will be computed based on masks.
198
+ gt_masks: (optional) A np.uint8 numpy array of shape [M, height, width].
199
+ Returns:
200
+ result_scores: A list of float numpy arrays. Each numpy array is of
201
+ shape [K, 1], representing K scores detected with object class label c
202
+ result_tp_fp_labels: A list of boolean numpy array. Each numpy array is of
203
+ shape [K, 1], representing K True/False positive label of object
204
+ instances detected with class label c
205
+ Raises:
206
+ ValueError: If detected masks is not None but groundtruth masks are None,
207
+ or the other way around.
208
+ """
209
+ if detected_masks is not None and gt_masks is None:
210
+ raise ValueError(
211
+ 'Detected masks is available but groundtruth masks is not.')
212
+ if detected_masks is None and gt_masks is not None:
213
+ raise ValueError(
214
+ 'Groundtruth masks is available but detected masks is not.')
215
+
216
+ result_scores = []
217
+ result_tp_fp_labels = []
218
+ for i in range(self.num_gt_classes):
219
+ gt_is_difficult_list_at_ith_class = (
220
+ gt_is_difficult_list[gt_class_labels == i])
221
+ gt_is_group_of_list_at_ith_class = (
222
+ gt_is_group_of_list[gt_class_labels == i])
223
+ (gt_boxes_at_ith_class, gt_masks_at_ith_class,
224
+ detected_boxes_at_ith_class, detected_scores_at_ith_class,
225
+ detected_masks_at_ith_class) = self._get_ith_class_arrays(
226
+ detected_boxes, detected_scores, detected_masks,
227
+ detected_class_labels, gt_boxes, gt_masks,
228
+ gt_class_labels, i)
229
+ scores, tp_fp_labels = self._compute_tp_fp_for_single_class(
230
+ detected_boxes=detected_boxes_at_ith_class,
231
+ detected_scores=detected_scores_at_ith_class,
232
+ gt_boxes=gt_boxes_at_ith_class,
233
+ gt_is_difficult_list=gt_is_difficult_list_at_ith_class,
234
+ gt_is_group_of_list=gt_is_group_of_list_at_ith_class,
235
+ detected_masks=detected_masks_at_ith_class,
236
+ gt_masks=gt_masks_at_ith_class)
237
+ result_scores.append(scores)
238
+ result_tp_fp_labels.append(tp_fp_labels)
239
+ return result_scores, result_tp_fp_labels
240
+
241
+ def _get_overlaps_and_scores_mask_mode(
242
+ self, detected_boxes, detected_scores, detected_masks,
243
+ gt_boxes, gt_masks, gt_is_group_of_list):
244
+ """Computes overlaps and scores between detected and groudntruth masks.
245
+ Args:
246
+ detected_boxes: A numpy array of shape [N, 4] representing detected box coordinates
247
+ detected_scores: A 1-d numpy array of length N representing classification score
248
+ detected_masks: A uint8 numpy array of shape [N, height, width]. If not
249
+ None, the scores will be computed based on masks.
250
+ gt_boxes: A numpy array of shape [M, 4] representing ground truth box coordinates
251
+ gt_masks: A uint8 numpy array of shape [M, height, width].
252
+ gt_is_group_of_list: A boolean numpy array of length M denoting
253
+ whether a ground truth box has group-of tag. If a groundtruth box is
254
+ group-of box, every detection matching this box is ignored.
255
+ Returns:
256
+ iou: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
257
+ gt_non_group_of_boxlist.num_boxes() == 0 it will be None.
258
+ ioa: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
259
+ gt_group_of_boxlist.num_boxes() == 0 it will be None.
260
+ scores: The score of the detected boxlist.
261
+ num_boxes: Number of non-maximum suppressed detected boxes.
262
+ """
263
+ detected_boxlist = MaskList(box_data=detected_boxes, mask_data=detected_masks)
264
+ detected_boxlist.add_field('scores', detected_scores)
265
+ detected_boxlist = non_max_suppression_mask(detected_boxlist, self.nms_max_output_boxes, self.nms_iou_threshold)
266
+ gt_non_group_of_boxlist = MaskList(
267
+ box_data=gt_boxes[~gt_is_group_of_list], mask_data=gt_masks[~gt_is_group_of_list])
268
+ gt_group_of_boxlist = MaskList(
269
+ box_data=gt_boxes[gt_is_group_of_list], mask_data=gt_masks[gt_is_group_of_list])
270
+ iou_b = iou_masklist(detected_boxlist, gt_non_group_of_boxlist)
271
+ ioa_b = np.transpose(ioa_masklist(gt_group_of_boxlist, detected_boxlist))
272
+ scores = detected_boxlist.get_field('scores')
273
+ num_boxes = detected_boxlist.num_boxes()
274
+ return iou_b, ioa_b, scores, num_boxes
275
+
276
+ def _get_overlaps_and_scores_box_mode(
277
+ self, detected_boxes, detected_scores, gt_boxes, gt_is_group_of_list):
278
+ """Computes overlaps and scores between detected and groudntruth boxes.
279
+ Args:
280
+ detected_boxes: A numpy array of shape [N, 4] representing detected box coordinates
281
+ detected_scores: A 1-d numpy array of length N representing classification score
282
+ gt_boxes: A numpy array of shape [M, 4] representing ground truth box coordinates
283
+ gt_is_group_of_list: A boolean numpy array of length M denoting
284
+ whether a ground truth box has group-of tag. If a groundtruth box is
285
+ group-of box, every detection matching this box is ignored.
286
+ Returns:
287
+ iou: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
288
+ gt_non_group_of_boxlist.num_boxes() == 0 it will be None.
289
+ ioa: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
290
+ gt_group_of_boxlist.num_boxes() == 0 it will be None.
291
+ scores: The score of the detected boxlist.
292
+ num_boxes: Number of non-maximum suppressed detected boxes.
293
+ """
294
+ detected_boxlist = BoxList(detected_boxes)
295
+ detected_boxlist.add_field('scores', detected_scores)
296
+ detected_boxlist = non_max_suppression(detected_boxlist, self.nms_max_output_boxes, self.nms_iou_threshold)
297
+ gt_non_group_of_boxlist = BoxList(gt_boxes[~gt_is_group_of_list])
298
+ gt_group_of_boxlist = BoxList(gt_boxes[gt_is_group_of_list])
299
+ iou_b = iou_boxlist(detected_boxlist, gt_non_group_of_boxlist)
300
+ ioa_b = np.transpose(ioa_boxlist(gt_group_of_boxlist, detected_boxlist))
301
+ scores = detected_boxlist.get_field('scores')
302
+ num_boxes = detected_boxlist.num_boxes()
303
+ return iou_b, ioa_b, scores, num_boxes
304
+
305
+ def _compute_tp_fp_for_single_class(
306
+ self, detected_boxes, detected_scores, gt_boxes,
307
+ gt_is_difficult_list, gt_is_group_of_list, detected_masks=None, gt_masks=None):
308
+ """Labels boxes detected with the same class from the same image as tp/fp.
309
+ Args:
310
+ detected_boxes: A numpy array of shape [N, 4] representing detected box coordinates
311
+ detected_scores: A 1-d numpy array of length N representing classification score
312
+ gt_boxes: A numpy array of shape [M, 4] representing ground truth box coordinates
313
+ gt_is_difficult_list: A boolean numpy array of length M denoting
314
+ whether a ground truth box is a difficult instance or not. If a
315
+ groundtruth box is difficult, every detection matching this box is ignored.
316
+ gt_is_group_of_list: A boolean numpy array of length M denoting
317
+ whether a ground truth box has group-of tag. If a groundtruth box is
318
+ group-of box, every detection matching this box is ignored.
319
+ detected_masks: (optional) A uint8 numpy array of shape [N, height,
320
+ width]. If not None, the scores will be computed based on masks.
321
+ gt_masks: (optional) A uint8 numpy array of shape [M, height, width].
322
+ Returns:
323
+ Two arrays of the same size, containing all boxes that were evaluated as
324
+ being true positives or false positives; if a box matched to a difficult
325
+ box or to a group-of box, it is ignored.
326
+ scores: A numpy array representing the detection scores.
327
+ tp_fp_labels: a boolean numpy array indicating whether a detection is a true positive.
328
+ """
329
+ if detected_boxes.size == 0:
330
+ return np.array([], dtype=float), np.array([], dtype=bool)
331
+
332
+ mask_mode = False
333
+ if detected_masks is not None and gt_masks is not None:
334
+ mask_mode = True
335
+
336
+ iou_b = np.ndarray([0, 0])
337
+ ioa_b = np.ndarray([0, 0])
338
+ iou_m = np.ndarray([0, 0])
339
+ ioa_m = np.ndarray([0, 0])
340
+ if mask_mode:
341
+ # For Instance Segmentation Evaluation on Open Images V5, not all boxed
342
+ # instances have corresponding segmentation annotations. Those boxes that
343
+ # dont have segmentation annotations are represented as empty masks in
344
+ # gt_masks nd array.
345
+ mask_presence_indicator = (np.sum(gt_masks, axis=(1, 2)) > 0)
346
+
347
+ iou_m, ioa_m, scores, num_detected_boxes = self._get_overlaps_and_scores_mask_mode(
348
+ detected_boxes=detected_boxes,
349
+ detected_scores=detected_scores,
350
+ detected_masks=detected_masks,
351
+ gt_boxes=gt_boxes[mask_presence_indicator, :],
352
+ gt_masks=gt_masks[mask_presence_indicator, :],
353
+ gt_is_group_of_list=gt_is_group_of_list[mask_presence_indicator])
354
+
355
+ if sum(mask_presence_indicator) < len(mask_presence_indicator):
356
+ # Not all masks are present - some masks are empty
357
+ iou_b, ioa_b, _, num_detected_boxes = self._get_overlaps_and_scores_box_mode(
358
+ detected_boxes=detected_boxes,
359
+ detected_scores=detected_scores,
360
+ gt_boxes=gt_boxes[~mask_presence_indicator, :],
361
+ gt_is_group_of_list=gt_is_group_of_list[~mask_presence_indicator])
362
+ num_detected_boxes = detected_boxes.shape[0]
363
+ else:
364
+ mask_presence_indicator = np.zeros(gt_is_group_of_list.shape, dtype=bool)
365
+ iou_b, ioa_b, scores, num_detected_boxes = self._get_overlaps_and_scores_box_mode(
366
+ detected_boxes=detected_boxes,
367
+ detected_scores=detected_scores,
368
+ gt_boxes=gt_boxes,
369
+ gt_is_group_of_list=gt_is_group_of_list)
370
+
371
+ if gt_boxes.size == 0:
372
+ return scores, np.zeros(num_detected_boxes, dtype=bool)
373
+
374
+ tp_fp_labels = np.zeros(num_detected_boxes, dtype=bool)
375
+ is_matched_to_box = np.zeros(num_detected_boxes, dtype=bool)
376
+ is_matched_to_difficult = np.zeros(num_detected_boxes, dtype=bool)
377
+ is_matched_to_group_of = np.zeros(num_detected_boxes, dtype=bool)
378
+
379
+ def compute_match_iou(iou_matrix, gt_nongroup_of_is_difficult_list, is_box):
380
+ """Computes TP/FP for non group-of box matching.
381
+ The function updates the following local variables:
382
+ tp_fp_labels - set to True for detections matched to a non-difficult, not yet detected groundtruth box.
383
+ is_matched_to_difficult - set to True for detections that were matched to a
384
+ difficult groundtruth box at this stage.
385
+ is_matched_to_box - the detections that were processed at this stage are marked as is_box.
386
+ Args:
387
+ iou_matrix: intersection-over-union matrix [num_gt_boxes]x[num_det_boxes].
388
+ gt_nongroup_of_is_difficult_list: boolean that specifies if gt box is difficult.
389
+ is_box: boolean that specifies if currently boxes or masks are processed.
390
+ """
391
+ max_overlap_gt_ids = np.argmax(iou_matrix, axis=1)
392
+ is_gt_detected = np.zeros(iou_matrix.shape[1], dtype=bool)
393
+ for i in range(num_detected_boxes):
394
+ gt_id = max_overlap_gt_ids[i]
395
+ is_evaluatable = (
396
+ not tp_fp_labels[i] and
397
+ not is_matched_to_difficult[i] and
398
+ iou_matrix[i, gt_id] >= self.matching_iou_threshold and
399
+ not is_matched_to_group_of[i])
400
+ if is_evaluatable:
401
+ if not gt_nongroup_of_is_difficult_list[gt_id]:
402
+ if not is_gt_detected[gt_id]:
403
+ tp_fp_labels[i] = True
404
+ is_gt_detected[gt_id] = True
405
+ is_matched_to_box[i] = is_box
406
+ else:
407
+ is_matched_to_difficult[i] = True
408
+
409
+ def compute_match_ioa(ioa_matrix, is_box):
410
+ """Computes TP/FP for group-of box matching.
411
+ The function updates the following local variables:
412
+ is_matched_to_group_of - if a box is matched to group-of
413
+ is_matched_to_box - the detections that were processed at this stage are marked as is_box.
414
+ Args:
415
+ ioa_matrix: intersection-over-area matrix [num_gt_boxes]x[num_det_boxes].
416
+ is_box: boolean that specifies if currently boxes or masks are processed.
417
+ Returns:
418
+ scores_group_of: scores of detections matched to group-of boxes, shape [num_groupof_matched].
419
+ tp_fp_labels_group_of: boolean array of size [num_groupof_matched], all values are True.
420
+ """
421
+ scores_group_of = np.zeros(ioa_matrix.shape[1], dtype=float)
422
+ tp_fp_labels_group_of = self.group_of_weight * np.ones(ioa_matrix.shape[1], dtype=float)
423
+ max_overlap_group_of_gt_ids = np.argmax(ioa_matrix, axis=1)
424
+ for i in range(num_detected_boxes):
425
+ gt_id = max_overlap_group_of_gt_ids[i]
426
+ is_evaluatable = (
427
+ not tp_fp_labels[i] and
428
+ not is_matched_to_difficult[i] and
429
+ ioa_matrix[i, gt_id] >= self.matching_iou_threshold and
430
+ not is_matched_to_group_of[i])
431
+ if is_evaluatable:
432
+ is_matched_to_group_of[i] = True
433
+ is_matched_to_box[i] = is_box
434
+ scores_group_of[gt_id] = max(scores_group_of[gt_id], scores[i])
435
+ selector = np.where((scores_group_of > 0) & (tp_fp_labels_group_of > 0))
436
+ scores_group_of = scores_group_of[selector]
437
+ tp_fp_labels_group_of = tp_fp_labels_group_of[selector]
438
+
439
+ return scores_group_of, tp_fp_labels_group_of
440
+
441
+ # The evaluation is done in two stages:
442
+ # 1. Evaluate all objects that actually have instance level masks.
443
+ # 2. Evaluate all objects that are not already evaluated as boxes.
444
+ if iou_m.shape[1] > 0:
445
+ gt_is_difficult_mask_list = gt_is_difficult_list[mask_presence_indicator]
446
+ gt_is_group_of_mask_list = gt_is_group_of_list[mask_presence_indicator]
447
+ compute_match_iou(iou_m, gt_is_difficult_mask_list[~gt_is_group_of_mask_list], is_box=False)
448
+
449
+ scores_mask_group_of = np.ndarray([0], dtype=float)
450
+ tp_fp_labels_mask_group_of = np.ndarray([0], dtype=float)
451
+ if ioa_m.shape[1] > 0:
452
+ scores_mask_group_of, tp_fp_labels_mask_group_of = compute_match_ioa(ioa_m, is_box=False)
453
+
454
+ # Tp-fp evaluation for non-group of boxes (if any).
455
+ if iou_b.shape[1] > 0:
456
+ gt_is_difficult_box_list = gt_is_difficult_list[~mask_presence_indicator]
457
+ gt_is_group_of_box_list = gt_is_group_of_list[~mask_presence_indicator]
458
+ compute_match_iou(iou_b, gt_is_difficult_box_list[~gt_is_group_of_box_list], is_box=True)
459
+
460
+ scores_box_group_of = np.ndarray([0], dtype=float)
461
+ tp_fp_labels_box_group_of = np.ndarray([0], dtype=float)
462
+ if ioa_b.shape[1] > 0:
463
+ scores_box_group_of, tp_fp_labels_box_group_of = compute_match_ioa(ioa_b, is_box=True)
464
+
465
+ if mask_mode:
466
+ # Note: here crowds are treated as ignore regions.
467
+ valid_entries = (~is_matched_to_difficult & ~is_matched_to_group_of & ~is_matched_to_box)
468
+ return np.concatenate((scores[valid_entries], scores_mask_group_of)),\
469
+ np.concatenate((tp_fp_labels[valid_entries].astype(float), tp_fp_labels_mask_group_of))
470
+ else:
471
+ valid_entries = (~is_matched_to_difficult & ~is_matched_to_group_of)
472
+ return np.concatenate((scores[valid_entries], scores_box_group_of)),\
473
+ np.concatenate((tp_fp_labels[valid_entries].astype(float), tp_fp_labels_box_group_of))
474
+
475
+ def _get_ith_class_arrays(
476
+ self, detected_boxes, detected_scores, detected_masks, detected_class_labels,
477
+ gt_boxes, gt_masks, gt_class_labels, class_index):
478
+ """Returns numpy arrays belonging to class with index `class_index`.
479
+ Args:
480
+ detected_boxes: A numpy array containing detected boxes.
481
+ detected_scores: A numpy array containing detected scores.
482
+ detected_masks: A numpy array containing detected masks.
483
+ detected_class_labels: A numpy array containing detected class labels.
484
+ gt_boxes: A numpy array containing groundtruth boxes.
485
+ gt_masks: A numpy array containing groundtruth masks.
486
+ gt_class_labels: A numpy array containing groundtruth class labels.
487
+ class_index: An integer index.
488
+ Returns:
489
+ gt_boxes_at_ith_class: A numpy array containing groundtruth boxes labeled as ith class.
490
+ gt_masks_at_ith_class: A numpy array containing groundtruth masks labeled as ith class.
491
+ detected_boxes_at_ith_class: A numpy array containing detected boxes corresponding to the ith class.
492
+ detected_scores_at_ith_class: A numpy array containing detected scores corresponding to the ith class.
493
+ detected_masks_at_ith_class: A numpy array containing detected masks corresponding to the ith class.
494
+ """
495
+ selected_groundtruth = (gt_class_labels == class_index)
496
+ gt_boxes_at_ith_class = gt_boxes[selected_groundtruth]
497
+ if gt_masks is not None:
498
+ gt_masks_at_ith_class = gt_masks[selected_groundtruth]
499
+ else:
500
+ gt_masks_at_ith_class = None
501
+ selected_detections = (detected_class_labels == class_index)
502
+ detected_boxes_at_ith_class = detected_boxes[selected_detections]
503
+ detected_scores_at_ith_class = detected_scores[selected_detections]
504
+ if detected_masks is not None:
505
+ detected_masks_at_ith_class = detected_masks[selected_detections]
506
+ else:
507
+ detected_masks_at_ith_class = None
508
+ return (gt_boxes_at_ith_class, gt_masks_at_ith_class,
509
+ detected_boxes_at_ith_class, detected_scores_at_ith_class,
510
+ detected_masks_at_ith_class)
511
+
512
+ def _remove_invalid_boxes(
513
+ self, detected_boxes, detected_scores, detected_class_labels, detected_masks=None):
514
+ """Removes entries with invalid boxes.
515
+ A box is invalid if either its xmax is smaller than its xmin, or its ymax is smaller than its ymin.
516
+ Args:
517
+ detected_boxes: A float numpy array of size [num_boxes, 4] containing box
518
+ coordinates in [ymin, xmin, ymax, xmax] format.
519
+ detected_scores: A float numpy array of size [num_boxes].
520
+ detected_class_labels: A int32 numpy array of size [num_boxes].
521
+ detected_masks: A uint8 numpy array of size [num_boxes, height, width].
522
+ Returns:
523
+ valid_detected_boxes: A float numpy array of size [num_valid_boxes, 4]
524
+ containing box coordinates in [ymin, xmin, ymax, xmax] format.
525
+ valid_detected_scores: A float numpy array of size [num_valid_boxes].
526
+ valid_detected_class_labels: A int32 numpy array of size [num_valid_boxes].
527
+ valid_detected_masks: A uint8 numpy array of size [num_valid_boxes, height, width].
528
+ """
529
+ valid_indices = np.logical_and(
530
+ detected_boxes[:, 0] < detected_boxes[:, 2], detected_boxes[:, 1] < detected_boxes[:, 3])
531
+ detected_boxes = detected_boxes[valid_indices]
532
+ detected_scores = detected_scores[valid_indices]
533
+ detected_class_labels = detected_class_labels[valid_indices]
534
+ if detected_masks is not None:
535
+ detected_masks = detected_masks[valid_indices]
536
+ return [detected_boxes, detected_scores, detected_class_labels, detected_masks]
537
+
538
+
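For reference, a minimal NumPy sketch of the filtering rule used by `_remove_invalid_boxes` above: boxes in [ymin, xmin, ymax, xmax] format are kept only when ymax > ymin and xmax > xmin. The example boxes and scores are made up for illustration.

```python
import numpy as np

# Hypothetical detections in [ymin, xmin, ymax, xmax] format; the second box is
# degenerate (ymax < ymin) and should be dropped, mirroring _remove_invalid_boxes.
boxes = np.array([[0.1, 0.1, 0.5, 0.6],
                  [0.7, 0.2, 0.4, 0.9],   # invalid: ymax < ymin
                  [0.2, 0.3, 0.8, 0.35]])
scores = np.array([0.9, 0.8, 0.3])

valid = np.logical_and(boxes[:, 0] < boxes[:, 2], boxes[:, 1] < boxes[:, 3])
print(boxes[valid])   # keeps rows 0 and 2
print(scores[valid])  # [0.9, 0.3]
```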
efficientdet/effdet/evaluator.py ADDED
@@ -0,0 +1,195 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ import abc
4
+ import json
5
+ import logging
6
+ import time
7
+ import numpy as np
8
+
9
+ from .distributed import synchronize, is_main_process, all_gather_container
10
+ from pycocotools.cocoeval import COCOeval
11
+
12
+ # FIXME experimenting with speedups for OpenImages eval, it's slow
13
+ #import pyximport; py_importer, pyx_importer = pyximport.install(pyimport=True)
14
+ import effdet.evaluation.detection_evaluator as tfm_eval
15
+ #pyximport.uninstall(py_importer, pyx_importer)
16
+
17
+ _logger = logging.getLogger(__name__)
18
+
19
+
20
+ __all__ = ['CocoEvaluator', 'PascalEvaluator', 'OpenImagesEvaluator', 'create_evaluator']
21
+
22
+
23
+ class Evaluator:
24
+
25
+ def __init__(self, distributed=False, pred_yxyx=False):
26
+ self.distributed = distributed
27
+ self.distributed_device = None
28
+ self.pred_yxyx = pred_yxyx
29
+ self.img_indices = []
30
+ self.predictions = []
31
+
32
+ def add_predictions(self, detections, target):
33
+ if self.distributed:
34
+ if self.distributed_device is None:
35
+ # cache for use later to broadcast end metric
36
+ self.distributed_device = detections.device
37
+ synchronize()
38
+ detections = all_gather_container(detections)
39
+ img_indices = all_gather_container(target['img_idx'])
40
+ if not is_main_process():
41
+ return
42
+ else:
43
+ img_indices = target['img_idx']
44
+
45
+ detections = detections.cpu().numpy()
46
+ img_indices = img_indices.cpu().numpy()
47
+ for img_idx, img_dets in zip(img_indices, detections):
48
+ self.img_indices.append(img_idx)
49
+ self.predictions.append(img_dets)
50
+
51
+ def _coco_predictions(self):
52
+ # generate coco-style predictions
53
+ coco_predictions = []
54
+ coco_ids = []
55
+ for img_idx, img_dets in zip(self.img_indices, self.predictions):
56
+ img_id = self._dataset.img_ids[img_idx]
57
+ coco_ids.append(img_id)
58
+ if self.pred_yxyx:
59
+ # to xyxy
60
+ img_dets[:, 0:4] = img_dets[:, [1, 0, 3, 2]]
61
+ # to xywh
62
+ img_dets[:, 2] -= img_dets[:, 0]
63
+ img_dets[:, 3] -= img_dets[:, 1]
64
+ for det in img_dets:
65
+ score = float(det[4])
66
+ if score < .001: # stop when below this threshold, scores in descending order
67
+ break
68
+ coco_det = dict(
69
+ image_id=int(img_id),
70
+ bbox=det[0:4].tolist(),
71
+ score=score,
72
+ category_id=int(det[5]))
73
+ coco_predictions.append(coco_det)
74
+ return coco_predictions, coco_ids
75
+
76
+ @abc.abstractmethod
77
+ def evaluate(self):
78
+ pass
79
+
80
+ def save(self, result_file):
81
+ # save results in coco style, override to save in an alternate form
82
+ if not self.distributed or dist.get_rank() == 0:
83
+ assert len(self.predictions)
84
+ coco_predictions, coco_ids = self._coco_predictions()
85
+ json.dump(coco_predictions, open(result_file, 'w'), indent=4)
86
+
87
+
88
+ class CocoEvaluator(Evaluator):
89
+
90
+ def __init__(self, dataset, neptune=None, distributed=False, pred_yxyx=False):
91
+ super().__init__(distributed=distributed, pred_yxyx=pred_yxyx)
92
+ self._dataset = dataset.parser
93
+ self.coco_api = dataset.parser.coco
94
+ self.neptune = neptune
95
+
96
+ def reset(self):
97
+ self.img_indices = []
98
+ self.predictions = []
99
+
100
+ def evaluate(self):
101
+ if not self.distributed or dist.get_rank() == 0:
102
+ assert len(self.predictions)
103
+ coco_predictions, coco_ids = self._coco_predictions()
104
+ json.dump(coco_predictions, open('./temp.json', 'w'), indent=4)
105
+ results = self.coco_api.loadRes('./temp.json')
106
+ coco_eval = COCOeval(self.coco_api, results, 'bbox')
107
+ coco_eval.params.imgIds = coco_ids # score only ids we've used
108
+ coco_eval.evaluate()
109
+ coco_eval.accumulate()
110
+ coco_eval.summarize()
111
+ metric = coco_eval.stats[0] # mAP 0.5-0.95
112
+ if self.neptune:
113
+ self.neptune.log_metric('valid/mAP/0.5-0.95IOU', metric)
114
+ self.neptune.log_metric('valid/mAP/0.5IOU', coco_eval.stats[1])
115
+ if self.distributed:
116
+ dist.broadcast(torch.tensor(metric, device=self.distributed_device), 0)
117
+ else:
118
+ metric = torch.tensor(0, device=self.distributed_device)
119
+ dist.broadcast(metric, 0)
120
+ metric = metric.item()
121
+ self.reset()
122
+ return metric
123
+
124
+
125
+ class TfmEvaluator(Evaluator):
126
+ """ Tensorflow Models Evaluator Wrapper """
127
+ def __init__(
128
+ self, dataset, neptune=None, distributed=False, pred_yxyx=False,
129
+ evaluator_cls=tfm_eval.ObjectDetectionEvaluator):
130
+ super().__init__(distributed=distributed, pred_yxyx=pred_yxyx)
131
+ self._evaluator = evaluator_cls(categories=dataset.parser.cat_dicts)
132
+ self._eval_metric_name = self._evaluator._metric_names[0]
133
+ self._dataset = dataset.parser
134
+ self.neptune = neptune
135
+
136
+ def reset(self):
137
+ self._evaluator.clear()
138
+ self.img_indices = []
139
+ self.predictions = []
140
+
141
+ def evaluate(self):
142
+ if not self.distributed or dist.get_rank() == 0:
143
+ for img_idx, img_dets in zip(self.img_indices, self.predictions):
144
+ gt = self._dataset.get_ann_info(img_idx)
145
+ self._evaluator.add_single_ground_truth_image_info(img_idx, gt)
146
+
147
+ bbox = img_dets[:, 0:4] if self.pred_yxyx else img_dets[:, [1, 0, 3, 2]]
148
+ det = dict(bbox=bbox, score=img_dets[:, 4], cls=img_dets[:, 5])
149
+ self._evaluator.add_single_detected_image_info(img_idx, det)
150
+
151
+ metrics = self._evaluator.evaluate()
152
+ _logger.info('Metrics:')
153
+ for k, v in metrics.items():
154
+ _logger.info(f'{k}: {v}')
155
+ if self.neptune:
156
+ key = 'valid/mAP/' + str(k).split('/')[-1]
157
+ self.neptune.log_metric(key, v)
158
+
159
+ map_metric = metrics[self._eval_metric_name]
160
+ if self.distributed:
161
+ dist.broadcast(torch.tensor(map_metric, device=self.distributed_device), 0)
162
+ else:
163
+ map_metric = torch.tensor(0, device=self.distributed_device)
164
+ wait = dist.broadcast(map_metric, 0, async_op=True)
165
+ while not wait.is_completed():
166
+ # wait without spinning the cpu @ 100%, no need for low latency here
167
+ time.sleep(0.5)
168
+ map_metric = map_metric.item()
169
+ self.reset()
170
+ return map_metric
171
+
172
+
173
+ class PascalEvaluator(TfmEvaluator):
174
+
175
+ def __init__(self, dataset, neptune=None, distributed=False, pred_yxyx=False):
176
+ super().__init__(
177
+ dataset, neptune, distributed=distributed, pred_yxyx=pred_yxyx, evaluator_cls=tfm_eval.PascalDetectionEvaluator)
178
+
179
+
180
+ class OpenImagesEvaluator(TfmEvaluator):
181
+
182
+ def __init__(self, dataset, distributed=False, pred_yxyx=False):
183
+ super().__init__(
184
+ dataset, distributed=distributed, pred_yxyx=pred_yxyx, evaluator_cls=tfm_eval.OpenImagesDetectionEvaluator)
185
+
186
+
187
+ def create_evaluator(name, dataset, neptune=None, distributed=False, pred_yxyx=False):
188
+ # FIXME support OpenImages Challenge2019 metric w/ image level label consideration
189
+ if 'coco' in name:
190
+ return CocoEvaluator(dataset, neptune, distributed=distributed, pred_yxyx=pred_yxyx)
191
+ elif 'openimages' in name:
192
+ return OpenImagesEvaluator(dataset, distributed=distributed, pred_yxyx=pred_yxyx)
193
+ else:
194
+ return CocoEvaluator(dataset, neptune, distributed=distributed, pred_yxyx=pred_yxyx)
195
+ #return PascalEvaluator(dataset, neptune, distributed=distributed, pred_yxyx=pred_yxyx)
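A rough usage sketch of the evaluator API above. It assumes a validation dataset and loader built with this repo's `effdet.data` helpers and a model wrapped in the prediction bench; `val_dataset`, `val_loader`, and `bench` are placeholders, and the exact bench call signature is an assumption, not part of this file.

```python
import torch
from effdet.evaluator import create_evaluator

# Detections are assumed to be the per-image [x, y, x, y, score, class] rows
# produced by the prediction bench (pred_yxyx=False), as _coco_predictions expects.
evaluator = create_evaluator('coco', val_dataset, distributed=False, pred_yxyx=False)

bench.eval()
with torch.no_grad():
    for images, targets in val_loader:
        detections = bench(images, img_info=targets)  # assumed DetBenchPredict signature
        evaluator.add_predictions(detections, targets)

mean_ap = evaluator.evaluate()  # mAP @ 0.5:0.95 for the COCO evaluator
print(mean_ap)
```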
efficientdet/effdet/factory.py ADDED
@@ -0,0 +1,54 @@
1
+ from .efficientdet import EfficientDet, HeadNet
2
+ from .bench import DetBenchTrain, DetBenchPredict
3
+ from .config import get_efficientdet_config
4
+ from .helpers import load_pretrained, load_checkpoint
5
+
6
+
7
+ def create_model(
8
+ model_name, bench_task='', num_classes=None, pretrained=False,
9
+ checkpoint_path='', checkpoint_ema=False, **kwargs):
10
+
11
+ config = get_efficientdet_config(model_name)
12
+ return create_model_from_config(
13
+ config, bench_task=bench_task, num_classes=num_classes, pretrained=pretrained,
14
+ checkpoint_path=checkpoint_path, checkpoint_ema=checkpoint_ema, **kwargs)
15
+
16
+
17
+ def create_model_from_config(
18
+ config, bench_task='', num_classes=None, pretrained=False,
19
+ checkpoint_path='', checkpoint_ema=False, **kwargs):
20
+
21
+ pretrained_backbone = kwargs.pop('pretrained_backbone', True)
22
+ if pretrained or checkpoint_path:
23
+ pretrained_backbone = False # no point in loading backbone weights
24
+
25
+ # Config overrides, override some config values via kwargs.
26
+ overrides = ('redundant_bias', 'label_smoothing', 'new_focal', 'jit_loss')
27
+ for ov in overrides:
28
+ value = kwargs.pop(ov, None)
29
+ if value is not None:
30
+ setattr(config, ov, value)
31
+
32
+ labeler = kwargs.pop('bench_labeler', False)
33
+
34
+ # create the base model
35
+ model = EfficientDet(config, pretrained_backbone=pretrained_backbone, **kwargs)
36
+
37
+ # pretrained weights are always spec'd for original config, load them before we change the model
38
+ if pretrained:
39
+ load_pretrained(model, config.url)
40
+
41
+ # reset model head if num_classes doesn't match configs
42
+ if num_classes is not None and num_classes != config.num_classes:
43
+ model.reset_head(num_classes=num_classes)
44
+
45
+ # load an argument specified training checkpoint
46
+ if checkpoint_path:
47
+ load_checkpoint(model, checkpoint_path, use_ema=checkpoint_ema)
48
+
49
+ # wrap model in task specific training/prediction bench if set
50
+ if bench_task == 'train':
51
+ model = DetBenchTrain(model, create_labeler=labeler)
52
+ elif bench_task == 'predict':
53
+ model = DetBenchPredict(model)
54
+ return model
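A small sketch of how the factory above is typically called. The config name, number of classes, and input size are assumptions (any name known to `get_efficientdet_config` would work), and the bench is assumed to accept a bare image batch when no scaling info is supplied.

```python
import torch
from effdet.factory import create_model

# Build a D0-sized model wrapped in the prediction bench; pretrained=False so the
# sketch does not depend on downloading weights (set True to pull config.url).
model = create_model('efficientdet_d0', bench_task='predict', num_classes=7, pretrained=False)
model.eval()

dummy = torch.randn(1, 3, 512, 512)  # 512x512 is the assumed input size for a D0 config
with torch.no_grad():
    detections = model(dummy)        # assumed output: [batch, max_dets, 6] -> x, y, x, y, score, class
print(detections.shape)
```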
efficientdet/effdet/helpers.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import os
3
+ import logging
4
+ from collections import OrderedDict
5
+
6
+ from timm.models import load_checkpoint
7
+
8
+ try:
9
+ from torch.hub import load_state_dict_from_url
10
+ except ImportError:
11
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
12
+
13
+
14
+ def load_pretrained(model, url, filter_fn=None, strict=True):
15
+ if not url:
16
+ logging.warning("Pretrained model URL is empty, using random initialization. "
17
+ "Did you intend to use a `tf_` variant of the model?")
18
+ return
19
+ state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')
20
+ if filter_fn is not None:
21
+ state_dict = filter_fn(state_dict)
22
+ model.load_state_dict(state_dict, strict=strict)
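`filter_fn` lets the caller rewrite the downloaded state dict before it is loaded. A minimal sketch of such a filter; the `module.` prefix below is hypothetical (e.g. left over from `DataParallel` wrapping).

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Example filter_fn: drop a hypothetical 'module.' prefix from every key."""
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())

# load_pretrained(model, url, filter_fn=strip_module_prefix) would then load the
# cleaned weights; passing strict=False additionally tolerates partial checkpoints.
```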
efficientdet/effdet/loss.py ADDED
@@ -0,0 +1,259 @@
1
+ """ EfficientDet Focal, Huber/Smooth L1 loss fns w/ jit support
2
+
3
+ Based on loss fn in Google's automl EfficientDet repository (Apache 2.0 license).
4
+ https://github.com/google/automl/tree/master/efficientdet
5
+
6
+ Copyright 2020 Ross Wightman
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from typing import Optional, List, Tuple
13
+
14
+
15
+ def focal_loss_legacy(logits, targets, alpha: float, gamma: float, normalizer):
16
+ """Compute the focal loss between `logits` and the golden `target` values.
17
+
18
+ 'Legacy' focal loss matches the loss used in the official Tensorflow impl for initial
19
+ model releases and some time after that. It eventually transitioned to the 'New' loss
20
+ defined below.
21
+
22
+ Focal loss = -(1-pt)^gamma * log(pt)
23
+ where pt is the probability of being classified to the true class.
24
+
25
+ Args:
26
+ logits: A float32 tensor of size [batch, height_in, width_in, num_predictions].
27
+
28
+ targets: A float32 tensor of size [batch, height_in, width_in, num_predictions].
29
+
30
+ alpha: A float32 scalar multiplying alpha to the loss from positive examples
31
+ and (1-alpha) to the loss from negative examples.
32
+
33
+ gamma: A float32 scalar modulating loss from hard and easy examples.
34
+
35
+ normalizer: A float32 scalar normalizes the total loss from all examples.
36
+
37
+ Returns:
38
+ loss: A float32 scalar representing normalized total loss.
39
+ """
40
+ positive_label_mask = targets == 1.0
41
+ cross_entropy = F.binary_cross_entropy_with_logits(logits, targets.to(logits.dtype), reduction='none')
42
+ neg_logits = -1.0 * logits
43
+ modulator = torch.exp(gamma * targets * neg_logits - gamma * torch.log1p(torch.exp(neg_logits)))
44
+
45
+ loss = modulator * cross_entropy
46
+ weighted_loss = torch.where(positive_label_mask, alpha * loss, (1.0 - alpha) * loss)
47
+ return weighted_loss / normalizer
48
+
49
+
50
+ def new_focal_loss(logits, targets, alpha: float, gamma: float, normalizer, label_smoothing: float = 0.01):
51
+ """Compute the focal loss between `logits` and the golden `target` values.
52
+
53
+ 'New' is not the best descriptor, but this focal loss impl matches recent versions of
54
+ the official Tensorflow impl of EfficientDet. It has support for label smoothing, however
55
+ it is a bit slower, doesn't jit optimize well, and uses more memory.
56
+
57
+ Focal loss = -(1-pt)^gamma * log(pt)
58
+ where pt is the probability of being classified to the true class.
59
+ Args:
60
+ logits: A float32 tensor of size [batch, height_in, width_in, num_predictions].
61
+ targets: A float32 tensor of size [batch, height_in, width_in, num_predictions].
62
+ alpha: A float32 scalar multiplying alpha to the loss from positive examples
63
+ and (1-alpha) to the loss from negative examples.
64
+ gamma: A float32 scalar modulating loss from hard and easy examples.
65
+ normalizer: Divide loss by this value.
66
+ label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.
67
+ Returns:
68
+ loss: A float32 scalar representing normalized total loss.
69
+ """
70
+ # compute focal loss multipliers before label smoothing, such that it will not blow up the loss.
71
+ pred_prob = logits.sigmoid()
72
+ targets = targets.to(logits.dtype)
73
+ onem_targets = 1. - targets
74
+ p_t = (targets * pred_prob) + (onem_targets * (1. - pred_prob))
75
+ alpha_factor = targets * alpha + onem_targets * (1. - alpha)
76
+ modulating_factor = (1. - p_t) ** gamma
77
+
78
+ # apply label smoothing for cross_entropy for each entry.
79
+ if label_smoothing > 0.:
80
+ targets = targets * (1. - label_smoothing) + .5 * label_smoothing
81
+ ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
82
+
83
+ # compute the final loss and return
84
+ return (1 / normalizer) * alpha_factor * modulating_factor * ce
85
+
86
+
87
+ def huber_loss(
88
+ input, target, delta: float = 1., weights: Optional[torch.Tensor] = None, size_average: bool = True):
89
+ """
90
+ """
91
+ err = input - target
92
+ abs_err = err.abs()
93
+ quadratic = torch.clamp(abs_err, max=delta)
94
+ linear = abs_err - quadratic
95
+ loss = 0.5 * quadratic.pow(2) + delta * linear
96
+ if weights is not None:
97
+ loss *= weights
98
+ if size_average:
99
+ return loss.mean()
100
+ else:
101
+ return loss.sum()
102
+
103
+
104
+ def smooth_l1_loss(
105
+ input, target, beta: float = 1. / 9, weights: Optional[torch.Tensor] = None, size_average: bool = True):
106
+ """
107
+ very similar to the smooth_l1_loss from pytorch, but with the extra beta parameter
108
+ """
109
+ if beta < 1e-5:
110
+ # if beta == 0, then torch.where will result in nan gradients when
111
+ # the chain rule is applied due to pytorch implementation details
112
+ # (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of
113
+ # zeros, rather than "no gradient"). To avoid this issue, we define
114
+ # small values of beta to be exactly l1 loss.
115
+ loss = torch.abs(input - target)
116
+ else:
117
+ err = torch.abs(input - target)
118
+ loss = torch.where(err < beta, 0.5 * err.pow(2) / beta, err - 0.5 * beta)
119
+ if weights is not None:
120
+ loss *= weights
121
+ if size_average:
122
+ return loss.mean()
123
+ else:
124
+ return loss.sum()
125
+
126
+
127
+ def _box_loss(box_outputs, box_targets, num_positives, delta: float = 0.1):
128
+ """Computes box regression loss."""
129
+ # delta is typically around the mean value of regression target.
130
+ # for instance, the regression targets of 512x512 input with 6 anchors on
131
+ # P3-P7 pyramid is about [0.1, 0.1, 0.2, 0.2].
132
+ normalizer = num_positives * 4.0
133
+ mask = box_targets != 0.0
134
+ box_loss = huber_loss(box_outputs, box_targets, weights=mask, delta=delta, size_average=False)
135
+ return box_loss / normalizer
136
+
137
+
138
+ def one_hot(x, num_classes: int):
139
+ # NOTE: PyTorch one-hot does not handle -ve entries (no hot) like Tensorflow, so mask them out
140
+ x_non_neg = (x >= 0).unsqueeze(-1)
141
+ onehot = torch.zeros(x.shape + (num_classes,), device=x.device, dtype=torch.float32)
142
+ return onehot.scatter(-1, x.unsqueeze(-1) * x_non_neg, 1) * x_non_neg
143
+
144
+
145
+ def loss_fn(
146
+ cls_outputs: List[torch.Tensor],
147
+ box_outputs: List[torch.Tensor],
148
+ cls_targets: List[torch.Tensor],
149
+ box_targets: List[torch.Tensor],
150
+ num_positives: torch.Tensor,
151
+ num_classes: int,
152
+ alpha: float,
153
+ gamma: float,
154
+ delta: float,
155
+ box_loss_weight: float,
156
+ label_smoothing: float = 0.,
157
+ new_focal: bool = False,
158
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
159
+ """Computes total detection loss.
160
+ Computes total detection loss including box and class loss from all levels.
161
+ Args:
162
+ cls_outputs: a List with values representing logits in [batch_size, height, width, num_anchors].
163
+ at each feature level (index)
164
+
165
+ box_outputs: a List with values representing box regression targets in
166
+ [batch_size, height, width, num_anchors * 4] at each feature level (index)
167
+
168
+ cls_targets: groundtruth class targets.
169
+
170
+ box_targets: groundtrusth box targets.
171
+
172
+ num_positives: number of positive groundtruth anchors
173
+
174
+ Returns:
175
+ total_loss: a float tensor representing the total loss reduced from class and box losses across all levels.
176
+
177
+ cls_loss: a float tensor representing the total class loss.
178
+
179
+ box_loss: a float tensor representing the total box regression loss.
180
+ """
181
+ # Sum all positives in a batch for normalization and avoid zero
182
+ # num_positives_sum, which would lead to inf loss during training
183
+ num_positives_sum = (num_positives.sum() + 1.0).float()
184
+ levels = len(cls_outputs)
185
+
186
+ cls_losses = []
187
+ box_losses = []
188
+ for l in range(levels):
189
+ cls_targets_at_level = cls_targets[l]
190
+ box_targets_at_level = box_targets[l]
191
+
192
+ # Onehot encoding for classification labels.
193
+ cls_targets_at_level_oh = one_hot(cls_targets_at_level, num_classes)
194
+
195
+ bs, height, width, _, _ = cls_targets_at_level_oh.shape
196
+ cls_targets_at_level_oh = cls_targets_at_level_oh.view(bs, height, width, -1)
197
+ cls_outputs_at_level = cls_outputs[l].permute(0, 2, 3, 1).float()
198
+ if new_focal:
199
+ cls_loss = new_focal_loss(
200
+ cls_outputs_at_level, cls_targets_at_level_oh,
201
+ alpha=alpha, gamma=gamma, normalizer=num_positives_sum, label_smoothing=label_smoothing)
202
+ else:
203
+ cls_loss = focal_loss_legacy(
204
+ cls_outputs_at_level, cls_targets_at_level_oh,
205
+ alpha=alpha, gamma=gamma, normalizer=num_positives_sum)
206
+ cls_loss = cls_loss.view(bs, height, width, -1, num_classes)
207
+ cls_loss = cls_loss * (cls_targets_at_level != -2).unsqueeze(-1)
208
+ cls_losses.append(cls_loss.sum()) # FIXME reference code added a clamp here at some point ...clamp(0, 2))
209
+
210
+ box_losses.append(_box_loss(
211
+ box_outputs[l].permute(0, 2, 3, 1).float(),
212
+ box_targets_at_level,
213
+ num_positives_sum,
214
+ delta=delta))
215
+
216
+ # Sum per level losses to total loss.
217
+ cls_loss = torch.sum(torch.stack(cls_losses, dim=-1), dim=-1)
218
+ box_loss = torch.sum(torch.stack(box_losses, dim=-1), dim=-1)
219
+ total_loss = cls_loss + box_loss_weight * box_loss
220
+ return total_loss, cls_loss, box_loss
221
+
222
+
223
+ loss_jit = torch.jit.script(loss_fn)
224
+
225
+
226
+ class DetectionLoss(nn.Module):
227
+
228
+ __constants__ = ['num_classes']
229
+
230
+ def __init__(self, config):
231
+ super(DetectionLoss, self).__init__()
232
+ self.config = config
233
+ self.num_classes = config.num_classes
234
+ self.alpha = config.alpha
235
+ self.gamma = config.gamma
236
+ self.delta = config.delta
237
+ self.box_loss_weight = config.box_loss_weight
238
+ self.label_smoothing = config.label_smoothing
239
+ self.new_focal = config.new_focal
240
+ self.use_jit = config.jit_loss
241
+
242
+ def forward(
243
+ self,
244
+ cls_outputs: List[torch.Tensor],
245
+ box_outputs: List[torch.Tensor],
246
+ cls_targets: List[torch.Tensor],
247
+ box_targets: List[torch.Tensor],
248
+ num_positives: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
249
+
250
+ l_fn = loss_fn
251
+ if not torch.jit.is_scripting() and self.use_jit:
252
+ # This branch only active if parent / bench itself isn't being scripted
253
+ # NOTE: I haven't figured out what to do here wrt to tracing, is it an issue?
254
+ l_fn = loss_jit
255
+
256
+ return l_fn(
257
+ cls_outputs, box_outputs, cls_targets, box_targets, num_positives,
258
+ num_classes=self.num_classes, alpha=self.alpha, gamma=self.gamma, delta=self.delta,
259
+ box_loss_weight=self.box_loss_weight, label_smoothing=self.label_smoothing, new_focal=self.new_focal)
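To make the focal-loss formula above concrete, here is a tiny stand-alone sketch of the per-element computation used by `new_focal_loss` (without label smoothing); the logits, targets, and hyperparameters are arbitrary.

```python
import torch
import torch.nn.functional as F

alpha, gamma, normalizer = 0.25, 1.5, 1.0
logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])

p = logits.sigmoid()
p_t = targets * p + (1 - targets) * (1 - p)                   # prob of the true class
alpha_factor = targets * alpha + (1 - targets) * (1 - alpha)  # class balancing
modulating = (1 - p_t) ** gamma                               # down-weights easy examples
ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')

loss = alpha_factor * modulating * ce / normalizer            # per-element focal loss
print(loss)
```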
efficientdet/effdet/object_detection/README.md ADDED
@@ -0,0 +1,3 @@
1
+ # Tensorflow Object Detection
2
+
3
+ All of this code is adapted/ported/copied from https://github.com/google/automl/tree/552d0facd14f4fe9205a67fb13ecb5690a4d1c94/efficientdet/object_detection
efficientdet/effdet/object_detection/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright 2020 Google Research. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Object detection data loaders and libraries are mostly based on RetinaNet:
16
+ # https://github.com/tensorflow/tpu/tree/master/models/official/retinanet
17
+ from .argmax_matcher import ArgMaxMatcher
18
+ from .box_coder import FasterRcnnBoxCoder
19
+ from .box_list import BoxList
20
+ from .matcher import Match
21
+ from .region_similarity_calculator import IouSimilarity
22
+ from .target_assigner import TargetAssigner
efficientdet/effdet/object_detection/argmax_matcher.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright 2020 Google Research. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Argmax matcher implementation.
16
+
17
+ This class takes a similarity matrix and matches columns to rows based on the
18
+ maximum value per column. One can specify matched_threshold to require a minimum
19
+ similarity for a column to match a row (columns below it generally result in a negative
20
+ training example) and unmatched_threshold to ignore in-between matches (generally
21
+ resulting in neither a positive nor a negative training example).
22
+
23
+ This matcher is used in Fast(er)-RCNN.
24
+
25
+ Note: matchers are used in TargetAssigners. There is a create_target_assigner
26
+ factory function for popular implementations.
27
+ """
28
+ import torch
29
+ from .matcher import Match
30
+ from typing import Optional
31
+
32
+
33
+ def one_hot_bool(x, num_classes: int):
34
+ # for improved perf over PyTorch builtin one_hot, scatter to bool
35
+ onehot = torch.zeros(x.size(0), num_classes, device=x.device, dtype=torch.bool)
36
+ return onehot.scatter_(1, x.unsqueeze(1), 1)
37
+
38
+
39
+ @torch.jit.script
40
+ class ArgMaxMatcher(object): # cannot inherit with torchscript
41
+ """Matcher based on highest value.
42
+
43
+ This class computes matches from a similarity matrix. Each column is matched
44
+ to a single row.
45
+
46
+ To support object detection target assignment this class enables setting both
47
+ matched_threshold (upper threshold) and unmatched_threshold (lower thresholds)
48
+ defining three categories of similarity which define whether examples are
49
+ positive, negative, or ignored:
50
+ (1) similarity >= matched_threshold: Highest similarity. Matched/Positive!
51
+ (2) matched_threshold > similarity >= unmatched_threshold: Medium similarity.
52
+ Depending on negatives_lower_than_unmatched, this is either
53
+ Unmatched/Negative OR Ignore.
54
+ (3) unmatched_threshold > similarity: Lowest similarity. Depending on flag
55
+ negatives_lower_than_unmatched, either Unmatched/Negative OR Ignore.
56
+ For ignored matches this class sets the values in the Match object to -2.
57
+ """
58
+
59
+ def __init__(self,
60
+ matched_threshold: float,
61
+ unmatched_threshold: Optional[float] = None,
62
+ negatives_lower_than_unmatched: bool = True,
63
+ force_match_for_each_row: bool = False):
64
+ """Construct ArgMaxMatcher.
65
+
66
+ Args:
67
+ matched_threshold: Threshold for positive matches. Positive if
68
+ sim >= matched_threshold, where sim is the maximum value of the
69
+ similarity matrix for a given column. Set to None for no threshold.
70
+ unmatched_threshold: Threshold for negative matches. Negative if
71
+ sim < unmatched_threshold. Defaults to matched_threshold
72
+ when set to None.
73
+ negatives_lower_than_unmatched: Boolean which defaults to True. If True
74
+ then negative matches are the ones below the unmatched_threshold,
75
+ whereas ignored matches are in between the matched and unmatched
76
+ threshold. If False, then negative matches are in between the matched
77
+ and unmatched threshold, and everything lower than unmatched is ignored.
78
+ force_match_for_each_row: If True, ensures that each row is matched to
79
+ at least one column (which is not guaranteed otherwise if the
80
+ matched_threshold is high). Defaults to False. See
81
+ argmax_matcher_test.testMatcherForceMatch() for an example.
82
+
83
+ Raises:
84
+ ValueError: if unmatched_threshold is set but matched_threshold is not set
85
+ or if unmatched_threshold > matched_threshold.
86
+ """
87
+ if (matched_threshold is None) and (unmatched_threshold is not None):
88
+ raise ValueError('Need to also define matched_threshold when unmatched_threshold is defined')
89
+ self._matched_threshold = matched_threshold
90
+ self._unmatched_threshold: float = 0.
91
+ if unmatched_threshold is None:
92
+ self._unmatched_threshold = matched_threshold
93
+ else:
94
+ if unmatched_threshold > matched_threshold:
95
+ raise ValueError('unmatched_threshold needs to be smaller or equal to matched_threshold')
96
+ self._unmatched_threshold = unmatched_threshold
97
+ if not negatives_lower_than_unmatched:
98
+ if self._unmatched_threshold == self._matched_threshold:
99
+ raise ValueError('When negatives are in between matched and unmatched thresholds, these '
100
+ 'cannot be of equal value. matched: %s, unmatched: %s' %
101
+ (self._matched_threshold, self._unmatched_threshold))
102
+ self._force_match_for_each_row = force_match_for_each_row
103
+ self._negatives_lower_than_unmatched = negatives_lower_than_unmatched
104
+
105
+ def _match_when_rows_are_empty(self, similarity_matrix):
106
+ """Performs matching when the rows of similarity matrix are empty.
107
+
108
+ When the rows are empty, all detections are false positives. So we return
109
+ a tensor of -1's to indicate that the columns do not match to any rows.
110
+
111
+ Returns:
112
+ matches: int32 tensor indicating the row each column matches to.
113
+ """
114
+ return -1 * torch.ones(similarity_matrix.shape[1], dtype=torch.long, device=similarity_matrix.device)
115
+
116
+ def _match_when_rows_are_non_empty(self, similarity_matrix):
117
+ """Performs matching when the rows of similarity matrix are non empty.
118
+
119
+ Returns:
120
+ matches: int32 tensor indicating the row each column matches to.
121
+ """
122
+ # Matches for each column
123
+ matched_vals, matches = torch.max(similarity_matrix, 0)
124
+
125
+ # Deal with matched and unmatched threshold
126
+ if self._matched_threshold is not None:
127
+ # Get boolean masks of unmatched and in-between (ignored) columns
128
+ below_unmatched_threshold = self._unmatched_threshold > matched_vals
129
+ between_thresholds = (matched_vals >= self._unmatched_threshold) & \
130
+ (self._matched_threshold > matched_vals)
131
+
132
+ if self._negatives_lower_than_unmatched:
133
+ matches = self._set_values_using_indicator(matches, below_unmatched_threshold, -1)
134
+ matches = self._set_values_using_indicator(matches, between_thresholds, -2)
135
+ else:
136
+ matches = self._set_values_using_indicator(matches, below_unmatched_threshold, -2)
137
+ matches = self._set_values_using_indicator(matches, between_thresholds, -1)
138
+
139
+ if self._force_match_for_each_row:
140
+ force_match_column_ids = torch.argmax(similarity_matrix, 1)
141
+ force_match_column_indicators = one_hot_bool(force_match_column_ids, similarity_matrix.shape[1])
142
+ force_match_column_mask, force_match_row_ids = torch.max(force_match_column_indicators, 0)
143
+ final_matches = torch.where(force_match_column_mask, force_match_row_ids, matches)
144
+ return final_matches
145
+ else:
146
+ return matches
147
+
148
+ def match(self, similarity_matrix):
149
+ """Tries to match each column of the similarity matrix to a row.
150
+
151
+ Args:
152
+ similarity_matrix: tensor of shape [N, M] representing any similarity metric.
153
+
154
+ Returns:
155
+ Match object with corresponding matches for each of M columns.
156
+ """
157
+ if similarity_matrix.shape[0] == 0:
158
+ return Match(self._match_when_rows_are_empty(similarity_matrix))
159
+ else:
160
+ return Match(self._match_when_rows_are_non_empty(similarity_matrix))
161
+
162
+ def _set_values_using_indicator(self, x, indicator, val: int):
163
+ """Set the indicated fields of x to val.
164
+
165
+ Args:
166
+ x: tensor.
167
+ indicator: boolean with same shape as x.
168
+ val: scalar with value to set.
169
+
170
+ Returns:
171
+ modified tensor.
172
+ """
173
+ indicator = indicator.to(dtype=x.dtype)
174
+ return x * (1 - indicator) + val * indicator
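A toy illustration of the three-way threshold rule the matcher implements (positive / ignore / negative), written directly with torch ops rather than through the class, since `Match` and `BoxList` live in files not shown here; the thresholds and similarity matrix are made up.

```python
import torch

matched_thr, unmatched_thr = 0.5, 0.4
# rows = groundtruth boxes, columns = anchors (similarity, e.g. IoU)
sim = torch.tensor([[0.7, 0.45, 0.1],
                    [0.2, 0.30, 0.3]])

matched_vals, matches = torch.max(sim, 0)  # best gt row per anchor column
below = matched_vals < unmatched_thr                                       # negatives (default flags)
between = (matched_vals >= unmatched_thr) & (matched_vals < matched_thr)   # ignored

matches[below] = -1    # negative / unmatched
matches[between] = -2  # ignored
print(matches)         # tensor([ 0, -2, -1])
```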
efficientdet/effdet/object_detection/box_coder.py ADDED
@@ -0,0 +1,172 @@
1
+ # Copyright 2020 Google Research. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Base box coder.
16
+
17
+ Box coders convert between coordinate frames, namely image-centric
18
+ (with (0,0) on the top left of image) and anchor-centric (with (0,0) being
19
+ defined by a specific anchor).
20
+
21
+ Users of a BoxCoder can call two methods:
22
+ encode: which encodes a box with respect to a given anchor
23
+ (or rather, a tensor of boxes wrt a corresponding tensor of anchors) and
24
+ decode: which inverts this encoding with a decode operation.
25
+ In both cases, the arguments are assumed to be in 1-1 correspondence already;
26
+ it is not the job of a BoxCoder to perform matching.
27
+ """
28
+ import torch
29
+ from typing import List, Optional
30
+ from .box_list import BoxList
31
+
32
+ # Box coder types.
33
+ FASTER_RCNN = 'faster_rcnn'
34
+ KEYPOINT = 'keypoint'
35
+ MEAN_STDDEV = 'mean_stddev'
36
+ SQUARE = 'square'
37
+
38
+
39
+ """Faster RCNN box coder.
40
+
41
+ Faster RCNN box coder follows the coding schema described below:
42
+ ty = (y - ya) / ha
43
+ tx = (x - xa) / wa
44
+ th = log(h / ha)
45
+ tw = log(w / wa)
46
+ where x, y, w, h denote the box's center coordinates, width and height
47
+ respectively. Similarly, xa, ya, wa, ha denote the anchor's center
48
+ coordinates, width and height. tx, ty, tw and th denote the anchor-encoded
49
+ center, width and height respectively.
50
+
51
+ See http://arxiv.org/abs/1506.01497 for details.
52
+ """
53
+
54
+
55
+ EPS = 1e-8
56
+
57
+
58
+ #@torch.jit.script
59
+ class FasterRcnnBoxCoder(object):
60
+ """Faster RCNN box coder."""
61
+
62
+ def __init__(self, scale_factors: Optional[List[float]] = None, eps: float = EPS):
63
+ """Constructor for FasterRcnnBoxCoder.
64
+
65
+ Args:
66
+ scale_factors: List of 4 positive scalars to scale ty, tx, th and tw.
67
+ If set to None, does not perform scaling. For Faster RCNN,
68
+ the open-source implementation recommends using [10.0, 10.0, 5.0, 5.0].
69
+ """
70
+ self._scale_factors = scale_factors
71
+ if scale_factors is not None:
72
+ assert len(scale_factors) == 4
73
+ for scalar in scale_factors:
74
+ assert scalar > 0
75
+ self.eps = eps
76
+
77
+ #@property
78
+ def code_size(self):
79
+ return 4
80
+
81
+ def encode(self, boxes: BoxList, anchors: BoxList):
82
+ """Encode a box collection with respect to anchor collection.
83
+
84
+ Args:
85
+ boxes: BoxList holding N boxes to be encoded.
86
+ anchors: BoxList of anchors.
87
+
88
+ Returns:
89
+ a tensor representing N anchor-encoded boxes of the format [ty, tx, th, tw].
90
+ """
91
+ # Convert anchors to the center coordinate representation.
92
+ ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()
93
+ ycenter, xcenter, h, w = boxes.get_center_coordinates_and_sizes()
94
+ # Avoid NaN in division and log below.
95
+ ha += self.eps
96
+ wa += self.eps
97
+ h += self.eps
98
+ w += self.eps
99
+
100
+ tx = (xcenter - xcenter_a) / wa
101
+ ty = (ycenter - ycenter_a) / ha
102
+ tw = torch.log(w / wa)
103
+ th = torch.log(h / ha)
104
+ # Scales location targets as used in paper for joint training.
105
+ if self._scale_factors is not None:
106
+ ty *= self._scale_factors[0]
107
+ tx *= self._scale_factors[1]
108
+ th *= self._scale_factors[2]
109
+ tw *= self._scale_factors[3]
110
+ return torch.stack([ty, tx, th, tw]).t()
111
+
112
+ def decode(self, rel_codes, anchors: BoxList):
113
+ """Decode relative codes to boxes.
114
+
115
+ Args:
116
+ rel_codes: a tensor representing N anchor-encoded boxes.
117
+ anchors: BoxList of anchors.
118
+
119
+ Returns:
120
+ boxes: BoxList holding N bounding boxes.
121
+ """
122
+ ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()
123
+
124
+ ty, tx, th, tw = rel_codes.t().unbind()
125
+ if self._scale_factors is not None:
126
+ ty /= self._scale_factors[0]
127
+ tx /= self._scale_factors[1]
128
+ th /= self._scale_factors[2]
129
+ tw /= self._scale_factors[3]
130
+ w = torch.exp(tw) * wa
131
+ h = torch.exp(th) * ha
132
+ ycenter = ty * ha + ycenter_a
133
+ xcenter = tx * wa + xcenter_a
134
+ ymin = ycenter - h / 2.
135
+ xmin = xcenter - w / 2.
136
+ ymax = ycenter + h / 2.
137
+ xmax = xcenter + w / 2.
138
+ return BoxList(torch.stack([ymin, xmin, ymax, xmax]).t())
139
+
140
+
141
+ def batch_decode(encoded_boxes, box_coder: FasterRcnnBoxCoder, anchors: BoxList):
142
+ """Decode a batch of encoded boxes.
143
+
144
+ This op takes a batch of encoded bounding boxes and transforms
145
+ them to a batch of bounding boxes specified by their corners in
146
+ the order of [y_min, x_min, y_max, x_max].
147
+
148
+ Args:
149
+ encoded_boxes: a float32 tensor of shape [batch_size, num_anchors,
150
+ code_size] representing the location of the objects.
151
+ box_coder: a BoxCoder object.
152
+ anchors: a BoxList of anchors used to encode `encoded_boxes`.
153
+
154
+ Returns:
155
+ decoded_boxes: a float32 tensor of shape [batch_size, num_anchors, coder_size]
156
+ representing the corners of the objects in the order of [y_min, x_min, y_max, x_max].
157
+
158
+ Raises:
159
+ ValueError: if batch sizes of the inputs are inconsistent, or if
160
+ the number of anchors inferred from encoded_boxes and anchors are inconsistent.
161
+ """
162
+ assert len(encoded_boxes.shape) == 3
163
+ if encoded_boxes.shape[1] != anchors.num_boxes():
164
+ raise ValueError('The number of anchors inferred from encoded_boxes'
165
+ ' and anchors are inconsistent: shape[1] of encoded_boxes'
166
+ ' %s should be equal to the number of anchors: %s.' %
167
+ (encoded_boxes.shape[1], anchors.num_boxes()))
168
+
169
+ decoded_boxes = torch.stack([
170
+ box_coder.decode(boxes, anchors).boxes for boxes in encoded_boxes.unbind()
171
+ ])
172
+ return decoded_boxes
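A stand-alone numeric sketch of the Faster R-CNN encode/decode equations above, applied to a single box/anchor pair already in center form rather than through `BoxList` (which is defined in a file not shown here); the coordinates are arbitrary.

```python
import torch

scale = [10.0, 10.0, 5.0, 5.0]  # [ty, tx, th, tw] scale factors recommended above

# anchor and groundtruth box as (ycenter, xcenter, h, w)
ya, xa, ha, wa = 0.5, 0.5, 0.4, 0.4
y, x, h, w = torch.tensor(0.55), torch.tensor(0.45), torch.tensor(0.5), torch.tensor(0.3)

# encode: offsets relative to the anchor, log-scaled sizes
ty = (y - ya) / ha * scale[0]
tx = (x - xa) / wa * scale[1]
th = torch.log(h / ha) * scale[2]
tw = torch.log(w / wa) * scale[3]

# decode (the inverse) recovers the original box
h2 = torch.exp(th / scale[2]) * ha
w2 = torch.exp(tw / scale[3]) * wa
y2 = (ty / scale[0]) * ha + ya
x2 = (tx / scale[1]) * wa + xa
print(torch.allclose(torch.stack([y2, x2, h2, w2]), torch.stack([y, x, h, w])))  # True
```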