commit fa84113 · committed by santit96

Create the streamlit app that classifies the trash in an image into classes

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
- .gitattributes +4 -0
- .github/workflows/main.yml +20 -0
- .gitignore +5 -0
- README.md +14 -0
- app.py +82 -0
- constants.py +8 -0
- efficientdet/__init__.py +0 -0
- efficientdet/effdet/__init__.py +7 -0
- efficientdet/effdet/anchors.py +421 -0
- efficientdet/effdet/bench.py +143 -0
- efficientdet/effdet/config/__init__.py +4 -0
- efficientdet/effdet/config/config_utils.py +9 -0
- efficientdet/effdet/config/fpn_config.py +184 -0
- efficientdet/effdet/config/model_config.py +538 -0
- efficientdet/effdet/config/train_config.py +34 -0
- efficientdet/effdet/data/__init__.py +6 -0
- efficientdet/effdet/data/dataset.py +145 -0
- efficientdet/effdet/data/dataset_config.py +194 -0
- efficientdet/effdet/data/dataset_factory.py +85 -0
- efficientdet/effdet/data/input_config.py +60 -0
- efficientdet/effdet/data/loader.py +226 -0
- efficientdet/effdet/data/parsers/__init__.py +2 -0
- efficientdet/effdet/data/parsers/parser.py +82 -0
- efficientdet/effdet/data/parsers/parser_coco.py +93 -0
- efficientdet/effdet/data/parsers/parser_config.py +49 -0
- efficientdet/effdet/data/parsers/parser_factory.py +19 -0
- efficientdet/effdet/data/parsers/parser_open_images.py +211 -0
- efficientdet/effdet/data/parsers/parser_voc.py +148 -0
- efficientdet/effdet/data/random_erasing.py +94 -0
- efficientdet/effdet/data/transforms.py +275 -0
- efficientdet/effdet/data/transforms_albumentation.py +23 -0
- efficientdet/effdet/distributed.py +308 -0
- efficientdet/effdet/efficientdet.py +557 -0
- efficientdet/effdet/evaluation/README.md +7 -0
- efficientdet/effdet/evaluation/__init__.py +0 -0
- efficientdet/effdet/evaluation/detection_evaluator.py +590 -0
- efficientdet/effdet/evaluation/fields.py +105 -0
- efficientdet/effdet/evaluation/metrics.py +148 -0
- efficientdet/effdet/evaluation/np_box_list.py +696 -0
- efficientdet/effdet/evaluation/np_mask_list.py +478 -0
- efficientdet/effdet/evaluation/object_detection_evaluation.py +273 -0
- efficientdet/effdet/evaluation/per_image_evaluation.py +538 -0
- efficientdet/effdet/evaluator.py +195 -0
- efficientdet/effdet/factory.py +54 -0
- efficientdet/effdet/helpers.py +22 -0
- efficientdet/effdet/loss.py +259 -0
- efficientdet/effdet/object_detection/README.md +3 -0
- efficientdet/effdet/object_detection/__init__.py +22 -0
- efficientdet/effdet/object_detection/argmax_matcher.py +174 -0
- efficientdet/effdet/object_detection/box_coder.py +172 -0
.gitattributes
ADDED
@@ -0,0 +1,4 @@
*.psd filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
.github/workflows/main.yml
ADDED
@@ -0,0 +1,20 @@
name: Sync to Hugging Face hub
on:
  push:
    branches: [main]

  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
          lfs: true
      - name: Push to hub
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: git push --force https://santit96:[email protected]/spaces/rootstrap-org/waste-classifier main
.gitignore
ADDED
@@ -0,0 +1,5 @@
__pycache__
.DS_Store
*.jpg
*.png
*.jpeg
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: Waste Classifier
emoji: ♻️
colorFrom: green
colorTo: gray
sdk: streamlit
sdk_version: 1.25.0
pinned: false
---

Waste Classifier
==============================

Waste Detection and Classifier tool
app.py
ADDED
@@ -0,0 +1,82 @@
"""
Streamlit app
"""
import sys

import streamlit as st

from constants import (CLAS_FILEPATH, CLAS_THRESHOLD, CLASSES, DET_FILEPATH,
                       DET_NAME, DET_THRESHOLD, DEVICE, OUTPUT_IMG_FILEPATH)

sys.path.append("./efficientdet")

from PIL import Image

from efficientdet.efficientdet import plot_results
from trash_detector import detect_trash


def initial_config():
    """
    Initial configuration of streamlit page
    """
    st.set_page_config(
        page_title="Waste Classifier",
        page_icon="♻️",
    )


def render():
    """
    Render the streamlit app
    """
    st.title("Waste classifier")
    st.markdown("""Classify your waste into different classes""")

    # Image loader and button
    uploaded_file = st.file_uploader(
        "Upload image with trash", type=["jpg", "jpeg", "png", "gif", "bmp"]
    )
    classify_button = st.button("Classify trash")

    if classify_button:
        if not uploaded_file:
            st.error("Upload an image")
        else:
            # Create two columns
            col1, col2 = st.columns(2)

            # Column 1: Uploaded image
            with col1:
                st.write("Uploaded image")
                st.image(
                    uploaded_file, caption="Uploaded Image.", use_column_width=True
                )

            # Column 2: Classified image
            with col2:
                with st.spinner(text="Classifying the trash..."):
                    img = Image.open(uploaded_file).convert("RGB")
                    cls_prob, bboxes_final = detect_trash(
                        img,
                        DET_NAME,
                        DET_FILEPATH,
                        CLAS_FILEPATH,
                        DEVICE,
                        DET_THRESHOLD,
                        CLAS_THRESHOLD,
                    )
                    # plot and save demo image
                    plot_results(
                        img, cls_prob, bboxes_final, CLASSES, OUTPUT_IMG_FILEPATH
                    )
                    output_img = Image.open(OUTPUT_IMG_FILEPATH)
                    st.write("Classified image")
                    st.image(
                        output_img, caption="Classified Image.", use_column_width=True
                    )


if __name__ == "__main__":
    initial_config()
    render()
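For local testing outside Streamlit, the same pipeline can be driven from a short script. This is a minimal sketch, assuming the model files referenced in constants.py are present and that trash_detector.py (not among the 50 files shown in this view) exposes detect_trash with the signature app.py uses; the sample image path is hypothetical.

import sys
sys.path.append("./efficientdet")

from PIL import Image
from constants import (CLAS_FILEPATH, CLAS_THRESHOLD, CLASSES, DET_FILEPATH,
                       DET_NAME, DET_THRESHOLD, DEVICE, OUTPUT_IMG_FILEPATH)
from efficientdet.efficientdet import plot_results
from trash_detector import detect_trash

# hypothetical local test image
img = Image.open("sample_trash.jpg").convert("RGB")
cls_prob, bboxes = detect_trash(
    img, DET_NAME, DET_FILEPATH, CLAS_FILEPATH, DEVICE, DET_THRESHOLD, CLAS_THRESHOLD)
plot_results(img, cls_prob, bboxes, CLASSES, OUTPUT_IMG_FILEPATH)
print(f"annotated image written to {OUTPUT_IMG_FILEPATH}")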
constants.py
ADDED
@@ -0,0 +1,8 @@
CLAS_FILEPATH = "models/resnet50-classifier.pkl"
DET_FILEPATH = "models/efficientdet-d2-detector.pth.tar"
CLASSES = ["cardboard", "compost", "glass", "metal", "paper", "plastic", "trash"]
DET_NAME = "tf_efficientdet_d2"
CLAS_THRESHOLD = 0.5
DET_THRESHOLD = 0.17
DEVICE = "cpu"
OUTPUT_IMG_FILEPATH = "classified_image.jpg"
efficientdet/__init__.py
ADDED
File without changes
efficientdet/effdet/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .efficientdet import EfficientDet
from .bench import DetBenchPredict, DetBenchTrain, unwrap_bench
from .data import create_dataset, create_loader, create_parser, DetectionDatset, SkipSubset
from .evaluator import CocoEvaluator, PascalEvaluator, OpenImagesEvaluator, create_evaluator
from .config import get_efficientdet_config, default_detection_model_configs
from .factory import create_model, create_model_from_config
from .helpers import load_checkpoint, load_pretrained
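These exports form the package's public API. As a small, hedged example (get_efficientdet_config is defined in model_config.py, which is only partially shown later in this view), the detector name used by this Space's constants.py can be resolved to its model config:

from effdet import get_efficientdet_config

# DET_NAME from constants.py
cfg = get_efficientdet_config("tf_efficientdet_d2")
print(cfg.num_classes, cfg.image_size)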
efficientdet/effdet/anchors.py
ADDED
@@ -0,0 +1,421 @@
""" RetinaNet / EfficientDet Anchor Gen

Adapted for PyTorch from Tensorflow impl at
https://github.com/google/automl/blob/6f6694cec1a48cdb33d5d1551a2d5db8ad227798/efficientdet/anchors.py

Hacked together by Ross Wightman, original copyright below
"""
# Copyright 2020 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Anchor definition.

This module is borrowed from TPU RetinaNet implementation:
https://github.com/tensorflow/tpu/blob/master/models/official/retinanet/anchors.py
"""
from typing import Optional, Tuple, Sequence

import numpy as np
import torch
import torch.nn as nn
#import torchvision.ops.boxes as tvb
from torchvision.ops.boxes import batched_nms, remove_small_boxes
from typing import List

from effdet.object_detection import ArgMaxMatcher, FasterRcnnBoxCoder, BoxList, IouSimilarity, TargetAssigner
from .soft_nms import batched_soft_nms


# The minimum score to consider a logit for identifying detections.
MIN_CLASS_SCORE = -5.0

# The score for a dummy detection
_DUMMY_DETECTION_SCORE = -1e5

# The maximum number of (anchor,class) pairs to keep for non-max suppression.
MAX_DETECTION_POINTS = 5000

# The maximum number of detections per image.
MAX_DETECTIONS_PER_IMAGE = 100


def decode_box_outputs(rel_codes, anchors, output_xyxy: bool=False):
    """Transforms relative regression coordinates to absolute positions.

    Network predictions are normalized and relative to a given anchor; this
    reverses the transformation and outputs absolute coordinates for the input image.

    Args:
        rel_codes: box regression targets.

        anchors: anchors on all feature levels.

    Returns:
        outputs: bounding boxes.

    """
    ycenter_a = (anchors[:, 0] + anchors[:, 2]) / 2
    xcenter_a = (anchors[:, 1] + anchors[:, 3]) / 2
    ha = anchors[:, 2] - anchors[:, 0]
    wa = anchors[:, 3] - anchors[:, 1]

    ty, tx, th, tw = rel_codes.unbind(dim=1)

    w = torch.exp(tw) * wa
    h = torch.exp(th) * ha
    ycenter = ty * ha + ycenter_a
    xcenter = tx * wa + xcenter_a
    ymin = ycenter - h / 2.
    xmin = xcenter - w / 2.
    ymax = ycenter + h / 2.
    xmax = xcenter + w / 2.
    if output_xyxy:
        out = torch.stack([xmin, ymin, xmax, ymax], dim=1)
    else:
        out = torch.stack([ymin, xmin, ymax, xmax], dim=1)
    return out


def clip_boxes_xyxy(boxes: torch.Tensor, size: torch.Tensor):
    boxes = boxes.clamp(min=0)
    size = torch.cat([size, size], dim=0)
    boxes = boxes.min(size)
    return boxes


def generate_detections(
        cls_outputs, box_outputs, anchor_boxes, indices, classes,
        img_scale: Optional[torch.Tensor], img_size: Optional[torch.Tensor],
        max_det_per_image: int = MAX_DETECTIONS_PER_IMAGE, soft_nms: bool = False):
    """Generates detections with RetinaNet model outputs and anchors.

    Args:
        cls_outputs: a torch tensor with shape [N, 1], which has the highest class
            scores on all feature levels. The N is the number of selected
            top-K total anchors on all levels. (k being MAX_DETECTION_POINTS)

        box_outputs: a torch tensor with shape [N, 4], which stacks box regression
            outputs on all feature levels. The N is the number of selected top-k
            total anchors on all levels. (k being MAX_DETECTION_POINTS)

        anchor_boxes: a torch tensor with shape [N, 4], which stacks anchors on all
            feature levels. The N is the number of selected top-k total anchors on all levels.

        indices: a torch tensor with shape [N], which is the indices from top-k selection.

        classes: a torch tensor with shape [N], which represents the class
            prediction on all selected anchors from top-k selection.

        img_scale: a float tensor representing the scale between original image
            and input image for the detector. It is used to rescale detections for
            evaluating with the original groundtruth annotations.

        max_det_per_image: an int constant, added as argument to make torchscript happy

    Returns:
        detections: detection results in a tensor with shape [MAX_DETECTION_POINTS, 6],
            each row representing [x_min, y_min, x_max, y_max, score, class]
    """
    assert box_outputs.shape[-1] == 4
    assert anchor_boxes.shape[-1] == 4
    assert cls_outputs.shape[-1] == 1

    anchor_boxes = anchor_boxes[indices, :]

    # Appply bounding box regression to anchors, boxes are converted to xyxy
    # here since PyTorch NMS expects them in that form.
    boxes = decode_box_outputs(box_outputs.float(), anchor_boxes, output_xyxy=True)
    if img_scale is not None and img_size is not None:
        boxes = clip_boxes_xyxy(boxes, img_size / img_scale)  # clip before NMS better?

    scores = cls_outputs.sigmoid().squeeze(1).float()
    if soft_nms:
        top_detection_idx, soft_scores = batched_soft_nms(
            boxes, scores, classes, method_gaussian=True, iou_threshold=0.3, score_threshold=.001)
        scores[top_detection_idx] = soft_scores
    else:
        top_detection_idx = batched_nms(boxes, scores, classes, iou_threshold=0.5)

    # keep only topk scoring predictions
    top_detection_idx = top_detection_idx[:max_det_per_image]
    boxes = boxes[top_detection_idx]
    scores = scores[top_detection_idx, None]
    classes = classes[top_detection_idx, None] + 1  # back to class idx with background class = 0

    if img_scale is not None:
        boxes = boxes * img_scale

    # FIXME add option to convert boxes back to yxyx? Otherwise must be handled downstream if
    # that is the preferred output format.

    # stack em and pad out to MAX_DETECTIONS_PER_IMAGE if necessary
    num_det = len(top_detection_idx)
    detections = torch.cat([boxes, scores, classes.float()], dim=1)
    if num_det < max_det_per_image:
        detections = torch.cat([
            detections,
            torch.zeros((max_det_per_image - num_det, 6), device=detections.device, dtype=detections.dtype)
        ], dim=0)
    return detections


def get_feat_sizes(image_size: Tuple[int, int], max_level: int):
    """Get feat widths and heights for all levels.
    Args:
        image_size: a tuple (H, W)
        max_level: maximum feature level.
    Returns:
        feat_sizes: a list of tuples (height, width) for each level.
    """
    feat_size = image_size
    feat_sizes = [feat_size]
    for _ in range(1, max_level + 1):
        feat_size = ((feat_size[0] - 1) // 2 + 1, (feat_size[1] - 1) // 2 + 1)
        feat_sizes.append(feat_size)
    return feat_sizes


class Anchors(nn.Module):
    """RetinaNet Anchors class."""

    def __init__(self, min_level, max_level, num_scales, aspect_ratios, anchor_scale, image_size: Tuple[int, int]):
        """Constructs multiscale RetinaNet anchors.

        Args:
            min_level: integer number of minimum level of the output feature pyramid.

            max_level: integer number of maximum level of the output feature pyramid.

            num_scales: integer number representing intermediate scales added
                on each level. For instances, num_scales=2 adds two additional
                anchor scales [2^0, 2^0.5] on each level.

            aspect_ratios: list of tuples representing the aspect ratio anchors added
                on each level. For instances, aspect_ratios =
                [(1, 1), (1.4, 0.7), (0.7, 1.4)] adds three anchors on each level.

            anchor_scale: float number representing the scale of size of the base
                anchor to the feature stride 2^level.

            image_size: Sequence specifying input image size of model (H, W).
                The image_size should be divided by the largest feature stride 2^max_level.
        """
        super(Anchors, self).__init__()
        self.min_level = min_level
        self.max_level = max_level
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios
        if isinstance(anchor_scale, Sequence):
            assert len(anchor_scale) == max_level - min_level + 1
            self.anchor_scales = anchor_scale
        else:
            self.anchor_scales = [anchor_scale] * (max_level - min_level + 1)

        assert isinstance(image_size, Sequence) and len(image_size) == 2
        # FIXME this restriction can likely be relaxed with some additional changes
        assert image_size[0] % 2 ** max_level == 0, 'Image size must be divisible by 2 ** max_level (128)'
        assert image_size[1] % 2 ** max_level == 0, 'Image size must be divisible by 2 ** max_level (128)'
        self.image_size = tuple(image_size)
        self.feat_sizes = get_feat_sizes(image_size, max_level)
        self.config = self._generate_configs()
        self.register_buffer('boxes', self._generate_boxes())

    @classmethod
    def from_config(cls, config):
        return cls(
            config.min_level, config.max_level,
            config.num_scales, config.aspect_ratios,
            config.anchor_scale, config.image_size)

    def _generate_configs(self):
        """Generate configurations of anchor boxes."""
        anchor_configs = {}
        feat_sizes = self.feat_sizes
        for level in range(self.min_level, self.max_level + 1):
            anchor_configs[level] = []
            for scale_octave in range(self.num_scales):
                for aspect in self.aspect_ratios:
                    anchor_configs[level].append(
                        ((feat_sizes[0][0] // feat_sizes[level][0],
                          feat_sizes[0][1] // feat_sizes[level][1]),
                         scale_octave / float(self.num_scales), aspect,
                         self.anchor_scales[level - self.min_level]))
        return anchor_configs

    def _generate_boxes(self):
        """Generates multiscale anchor boxes."""
        boxes_all = []
        for _, configs in self.config.items():
            boxes_level = []
            for config in configs:
                stride, octave_scale, aspect, anchor_scale = config
                base_anchor_size_x = anchor_scale * stride[1] * 2 ** octave_scale
                base_anchor_size_y = anchor_scale * stride[0] * 2 ** octave_scale
                if isinstance(aspect, Sequence):
                    aspect_x = aspect[0]
                    aspect_y = aspect[1]
                else:
                    aspect_x = np.sqrt(aspect)
                    aspect_y = 1.0 / aspect_x
                anchor_size_x_2 = base_anchor_size_x * aspect_x / 2.0
                anchor_size_y_2 = base_anchor_size_y * aspect_y / 2.0

                x = np.arange(stride[1] / 2, self.image_size[1], stride[1])
                y = np.arange(stride[0] / 2, self.image_size[0], stride[0])
                xv, yv = np.meshgrid(x, y)
                xv = xv.reshape(-1)
                yv = yv.reshape(-1)

                boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
                                   yv + anchor_size_y_2, xv + anchor_size_x_2))
                boxes = np.swapaxes(boxes, 0, 1)
                boxes_level.append(np.expand_dims(boxes, axis=1))

            # concat anchors on the same level to the reshape NxAx4
            boxes_level = np.concatenate(boxes_level, axis=1)
            boxes_all.append(boxes_level.reshape([-1, 4]))

        anchor_boxes = np.vstack(boxes_all)
        anchor_boxes = torch.from_numpy(anchor_boxes).float()
        return anchor_boxes

    def get_anchors_per_location(self):
        return self.num_scales * len(self.aspect_ratios)


class AnchorLabeler(object):
    """Labeler for multiscale anchor boxes.
    """

    def __init__(self, anchors, num_classes: int, match_threshold: float = 0.5):
        """Constructs anchor labeler to assign labels to anchors.

        Args:
            anchors: an instance of class Anchors.

            num_classes: integer number representing number of classes in the dataset.

            match_threshold: float number between 0 and 1 representing the threshold
                to assign positive labels for anchors.
        """
        similarity_calc = IouSimilarity()
        matcher = ArgMaxMatcher(
            match_threshold,
            unmatched_threshold=match_threshold,
            negatives_lower_than_unmatched=True,
            force_match_for_each_row=True)
        box_coder = FasterRcnnBoxCoder()

        self.target_assigner = TargetAssigner(similarity_calc, matcher, box_coder)
        self.anchors = anchors
        self.match_threshold = match_threshold
        self.num_classes = num_classes
        self.indices_cache = {}

    def label_anchors(self, gt_boxes, gt_classes, filter_valid=True):
        """Labels anchors with ground truth inputs.

        Args:
            gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
                For each row, it stores [y0, x0, y1, x1] for four corners of a box.

            gt_classes: A integer tensor with shape [N, 1] representing groundtruth classes.

            filter_valid: Filter out any boxes w/ gt class <= -1 before assigning

        Returns:
            cls_targets_dict: ordered dictionary with keys [min_level, min_level+1, ..., max_level].
                The values are tensor with shape [height_l, width_l, num_anchors]. The height_l and width_l
                represent the dimension of class logits at l-th level.

            box_targets_dict: ordered dictionary with keys [min_level, min_level+1, ..., max_level].
                The values are tensor with shape [height_l, width_l, num_anchors * 4]. The height_l and
                width_l represent the dimension of bounding box regression output at l-th level.

            num_positives: scalar tensor storing number of positives in an image.
        """
        cls_targets_out = []
        box_targets_out = []

        if filter_valid:
            valid_idx = gt_classes > -1  # filter gt targets w/ label <= -1
            gt_boxes = gt_boxes[valid_idx]
            gt_classes = gt_classes[valid_idx]

        cls_targets, box_targets, matches = self.target_assigner.assign(
            BoxList(self.anchors.boxes), BoxList(gt_boxes), gt_classes)

        # class labels start from 1 and the background class = -1
        cls_targets = (cls_targets - 1).long()

        # Unpack labels.
        """Unpacks an array of cls/box into multiple scales."""
        count = 0
        for level in range(self.anchors.min_level, self.anchors.max_level + 1):
            feat_size = self.anchors.feat_sizes[level]
            steps = feat_size[0] * feat_size[1] * self.anchors.get_anchors_per_location()
            cls_targets_out.append(cls_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
            box_targets_out.append(box_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
            count += steps

        num_positives = (matches.match_results > -1).float().sum()

        return cls_targets_out, box_targets_out, num_positives

    def batch_label_anchors(self, gt_boxes, gt_classes, filter_valid=True):
        batch_size = len(gt_boxes)
        assert batch_size == len(gt_classes)
        num_levels = self.anchors.max_level - self.anchors.min_level + 1
        cls_targets_out = [[] for _ in range(num_levels)]
        box_targets_out = [[] for _ in range(num_levels)]
        num_positives_out = []

        anchor_box_list = BoxList(self.anchors.boxes)
        for i in range(batch_size):
            last_sample = i == batch_size - 1

            if filter_valid:
                valid_idx = gt_classes[i] > -1  # filter gt targets w/ label <= -1
                gt_box_list = BoxList(gt_boxes[i][valid_idx])
                gt_class_i = gt_classes[i][valid_idx]
            else:
                gt_box_list = BoxList(gt_boxes[i])
                gt_class_i = gt_classes[i]
            cls_targets, box_targets, matches = self.target_assigner.assign(anchor_box_list, gt_box_list, gt_class_i)

            # class labels start from 1 and the background class = -1
            cls_targets = (cls_targets - 1).long()

            # Unpack labels.
            """Unpacks an array of cls/box into multiple scales."""
            count = 0
            for level in range(self.anchors.min_level, self.anchors.max_level + 1):
                level_idx = level - self.anchors.min_level
                feat_size = self.anchors.feat_sizes[level]
                steps = feat_size[0] * feat_size[1] * self.anchors.get_anchors_per_location()
                cls_targets_out[level_idx].append(
                    cls_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
                box_targets_out[level_idx].append(
                    box_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
                count += steps
                if last_sample:
                    cls_targets_out[level_idx] = torch.stack(cls_targets_out[level_idx])
                    box_targets_out[level_idx] = torch.stack(box_targets_out[level_idx])

            num_positives_out.append((matches.match_results > -1).float().sum())
            if last_sample:
                num_positives_out = torch.stack(num_positives_out)

        return cls_targets_out, box_targets_out, num_positives_out
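A quick sketch of how the anchor utilities above fit together, assuming the rest of the vendored effdet package (object_detection, soft_nms) is importable: build the anchor grid for a D2-style 768x768 input and decode zero box offsets, which should reproduce the anchors themselves in xyxy order.

import torch
from effdet.anchors import Anchors, decode_box_outputs

anchors = Anchors(
    min_level=3, max_level=7, num_scales=3,
    aspect_ratios=[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)],
    anchor_scale=4.0, image_size=(768, 768))
print(anchors.boxes.shape)  # (num_anchors, 4), stored as [ymin, xmin, ymax, xmax]

# zero regression offsets decode back to the anchor boxes, converted to xyxy
rel_codes = torch.zeros_like(anchors.boxes)
boxes = decode_box_outputs(rel_codes, anchors.boxes, output_xyxy=True)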
efficientdet/effdet/bench.py
ADDED
@@ -0,0 +1,143 @@
""" PyTorch EfficientDet support benches

Hacked together by Ross Wightman
"""
from typing import Optional, Dict, List
import torch
import torch.nn as nn
from timm.utils import ModelEma
from .anchors import Anchors, AnchorLabeler, generate_detections, MAX_DETECTION_POINTS
from .loss import DetectionLoss


def _post_process(
        cls_outputs: List[torch.Tensor],
        box_outputs: List[torch.Tensor],
        num_levels: int,
        num_classes: int,
        max_detection_points: int = MAX_DETECTION_POINTS,
):
    """Selects top-k predictions.

    Post-proc code adapted from Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet
    and optimized for PyTorch.

    Args:
        cls_outputs: an OrderDict with keys representing levels and values
            representing logits in [batch_size, height, width, num_anchors].

        box_outputs: an OrderDict with keys representing levels and values
            representing box regression targets in [batch_size, height, width, num_anchors * 4].

        num_levels (int): number of feature levels

        num_classes (int): number of output classes
    """
    batch_size = cls_outputs[0].shape[0]
    cls_outputs_all = torch.cat([
        cls_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, num_classes])
        for level in range(num_levels)], 1)

    box_outputs_all = torch.cat([
        box_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, 4])
        for level in range(num_levels)], 1)

    _, cls_topk_indices_all = torch.topk(cls_outputs_all.reshape(batch_size, -1), dim=1, k=max_detection_points)
    indices_all = cls_topk_indices_all // num_classes
    classes_all = cls_topk_indices_all % num_classes

    box_outputs_all_after_topk = torch.gather(
        box_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, 4))

    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all, 1, indices_all.unsqueeze(2).expand(-1, -1, num_classes))
    cls_outputs_all_after_topk = torch.gather(
        cls_outputs_all_after_topk, 2, classes_all.unsqueeze(2))

    return cls_outputs_all_after_topk, box_outputs_all_after_topk, indices_all, classes_all


@torch.jit.script
def _batch_detection(
        batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
        img_scale: Optional[torch.Tensor] = None, img_size: Optional[torch.Tensor] = None):
    batch_detections = []
    # FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome
    for i in range(batch_size):
        img_scale_i = None if img_scale is None else img_scale[i]
        img_size_i = None if img_size is None else img_size[i]
        detections = generate_detections(
            class_out[i], box_out[i], anchor_boxes, indices[i], classes[i], img_scale_i, img_size_i)
        batch_detections.append(detections)
    return torch.stack(batch_detections, dim=0)


class DetBenchPredict(nn.Module):
    def __init__(self, model):
        super(DetBenchPredict, self).__init__()
        self.model = model
        self.config = model.config  # FIXME remove this when we can use @property (torchscript limitation)
        self.num_levels = model.config.num_levels
        self.num_classes = model.config.num_classes
        self.anchors = Anchors.from_config(model.config)

    def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
        class_out, box_out = self.model(x)
        class_out, box_out, indices, classes = _post_process(
            class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
        if img_info is None:
            img_scale, img_size = None, None
        else:
            img_scale, img_size = img_info['img_scale'], img_info['img_size']
        return _batch_detection(
            x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes, img_scale, img_size)


class DetBenchTrain(nn.Module):
    def __init__(self, model, create_labeler=True):
        super(DetBenchTrain, self).__init__()
        self.model = model
        self.config = model.config  # FIXME remove this when we can use @property (torchscript limitation)
        self.num_levels = model.config.num_levels
        self.num_classes = model.config.num_classes
        self.anchors = Anchors.from_config(model.config)
        self.anchor_labeler = None
        if create_labeler:
            self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
        self.loss_fn = DetectionLoss(model.config)

    def forward(self, x, target: Dict[str, torch.Tensor]):
        class_out, box_out = self.model(x)
        if self.anchor_labeler is None:
            # target should contain pre-computed anchor labels if labeler not present in bench
            assert 'label_num_positives' in target
            cls_targets = [target[f'label_cls_{l}'] for l in range(self.num_levels)]
            box_targets = [target[f'label_bbox_{l}'] for l in range(self.num_levels)]
            num_positives = target['label_num_positives']
        else:
            cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
                target['bbox'], target['cls'])

        loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
        output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss}
        if not self.training:
            # if eval mode, output detections for evaluation
            class_out_pp, box_out_pp, indices, classes = _post_process(
                class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
            output['detections'] = _batch_detection(
                x.shape[0], class_out_pp, box_out_pp, self.anchors.boxes, indices, classes,
                target['img_scale'], target['img_size'])
        return output


def unwrap_bench(model):
    # Unwrap a model in support bench so that various other fns can access the weights and attribs of the
    # underlying model directly
    if isinstance(model, ModelEma):  # unwrap ModelEma
        return unwrap_bench(model.ema)
    elif hasattr(model, 'module'):  # unwrap DDP
        return unwrap_bench(model.module)
    elif hasattr(model, 'model'):  # unwrap Bench -> model
        return unwrap_bench(model.model)
    else:
        return model
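A hedged inference sketch for the bench wrappers: wrap a bare EfficientDet in DetBenchPredict and run a dummy batch. The EfficientDet constructor lives in efficientdet.py, which is not shown in this view, so its arguments below are an assumption based only on the package exports and configs shown here.

import torch
from effdet import EfficientDet, DetBenchPredict, get_efficientdet_config

config = get_efficientdet_config('efficientdet_d0')     # 512x512 input per model_config.py
net = EfficientDet(config, pretrained_backbone=False)   # constructor signature assumed
bench = DetBenchPredict(net).eval()
with torch.no_grad():
    detections = bench(torch.randn(1, 3, 512, 512))
# one row per detection: [x_min, y_min, x_max, y_max, score, class]
print(detections.shape)  # (1, MAX_DETECTIONS_PER_IMAGE, 6)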
efficientdet/effdet/config/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .config_utils import set_config_readonly, set_config_writeable
from .fpn_config import get_fpn_config
from .model_config import get_efficientdet_config, default_detection_model_configs
from .train_config import default_detection_train_config
efficientdet/effdet/config/config_utils.py
ADDED
@@ -0,0 +1,9 @@
from omegaconf import OmegaConf


def set_config_readonly(conf):
    OmegaConf.set_readonly(conf, True)


def set_config_writeable(conf):
    OmegaConf.set_readonly(conf, False)
efficientdet/effdet/config/fpn_config.py
ADDED
@@ -0,0 +1,184 @@
import itertools

from omegaconf import OmegaConf


def bifpn_config(min_level, max_level, weight_method=None):
    """BiFPN config.
    Adapted from https://github.com/google/automl/blob/56815c9986ffd4b508fe1d68508e268d129715c1/efficientdet/keras/fpn_configs.py
    """
    p = OmegaConf.create()
    weight_method = weight_method or 'fastattn'

    num_levels = max_level - min_level + 1
    node_ids = {min_level + i: [i] for i in range(num_levels)}

    level_last_id = lambda level: node_ids[level][-1]
    level_all_ids = lambda level: node_ids[level]
    id_cnt = itertools.count(num_levels)

    p.nodes = []
    for i in range(max_level - 1, min_level - 1, -1):
        # top-down path.
        p.nodes.append({
            'reduction': 1 << i,
            'inputs_offsets': [level_last_id(i), level_last_id(i + 1)],
            'weight_method': weight_method,
        })
        node_ids[i].append(next(id_cnt))

    for i in range(min_level + 1, max_level + 1):
        # bottom-up path.
        p.nodes.append({
            'reduction': 1 << i,
            'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)],
            'weight_method': weight_method,
        })
        node_ids[i].append(next(id_cnt))
    return p


def panfpn_config(min_level, max_level, weight_method=None):
    """PAN FPN config.

    This defines FPN layout from Path Aggregation Networks as an alternate to
    BiFPN, it does not implement the full PAN spec.

    Paper: https://arxiv.org/abs/1803.01534
    """
    p = OmegaConf.create()
    weight_method = weight_method or 'fastattn'

    num_levels = max_level - min_level + 1
    node_ids = {min_level + i: [i] for i in range(num_levels)}
    level_last_id = lambda level: node_ids[level][-1]
    id_cnt = itertools.count(num_levels)

    p.nodes = []
    for i in range(max_level, min_level - 1, -1):
        # top-down path.
        offsets = [level_last_id(i), level_last_id(i + 1)] if i != max_level else [level_last_id(i)]
        p.nodes.append({
            'reduction': 1 << i,
            'inputs_offsets': offsets,
            'weight_method': weight_method,
        })
        node_ids[i].append(next(id_cnt))

    for i in range(min_level, max_level + 1):
        # bottom-up path.
        offsets = [level_last_id(i), level_last_id(i - 1)] if i != min_level else [level_last_id(i)]
        p.nodes.append({
            'reduction': 1 << i,
            'inputs_offsets': offsets,
            'weight_method': weight_method,
        })
        node_ids[i].append(next(id_cnt))

    return p


def qufpn_config(min_level, max_level, weight_method=None):
    """A dynamic quad fpn config that can adapt to different min/max levels.

    It extends the idea of BiFPN, and has four paths:
        (up_down -> bottom_up) + (bottom_up -> up_down).

    Paper: https://ieeexplore.ieee.org/document/9225379
    Ref code: From contribution to TF EfficientDet
    https://github.com/google/automl/blob/eb74c6739382e9444817d2ad97c4582dbe9a9020/efficientdet/keras/fpn_configs.py
    """
    p = OmegaConf.create()
    weight_method = weight_method or 'fastattn'
    quad_method = 'fastattn'
    num_levels = max_level - min_level + 1
    node_ids = {min_level + i: [i] for i in range(num_levels)}
    level_last_id = lambda level: node_ids[level][-1]
    level_all_ids = lambda level: node_ids[level]
    level_first_id = lambda level: node_ids[level][0]
    id_cnt = itertools.count(num_levels)

    p.nodes = []
    for i in range(max_level - 1, min_level - 1, -1):
        # top-down path 1.
        p.nodes.append({
            'reduction': 1 << i,
            'inputs_offsets': [level_last_id(i), level_last_id(i + 1)],
            'weight_method': weight_method
        })
        node_ids[i].append(next(id_cnt))
    node_ids[max_level].append(node_ids[max_level][-1])

    for i in range(min_level + 1, max_level):
        # bottom-up path 2.
        p.nodes.append({
            'reduction': 1 << i,
            'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)],
            'weight_method': weight_method
        })
        node_ids[i].append(next(id_cnt))

    i = max_level
    p.nodes.append({
        'reduction': 1 << i,
        'inputs_offsets': [level_first_id(i)] + [level_last_id(i - 1)],
        'weight_method': weight_method
    })
    node_ids[i].append(next(id_cnt))
    node_ids[min_level].append(node_ids[min_level][-1])

    for i in range(min_level + 1, max_level + 1, 1):
        # bottom-up path 3.
        p.nodes.append({
            'reduction': 1 << i,
            'inputs_offsets': [
                level_first_id(i), level_last_id(i - 1) if i != min_level + 1 else level_first_id(i - 1)],
            'weight_method': weight_method
        })
        node_ids[i].append(next(id_cnt))
    node_ids[min_level].append(node_ids[min_level][-1])

    for i in range(max_level - 1, min_level, -1):
        # top-down path 4.
        p.nodes.append({
            'reduction': 1 << i,
            'inputs_offsets': [node_ids[i][0]] + [node_ids[i][-1]] + [level_last_id(i + 1)],
            'weight_method': weight_method
        })
        node_ids[i].append(next(id_cnt))
    i = min_level
    p.nodes.append({
        'reduction': 1 << i,
        'inputs_offsets': [node_ids[i][0]] + [level_last_id(i + 1)],
        'weight_method': weight_method
    })
    node_ids[i].append(next(id_cnt))
    node_ids[max_level].append(node_ids[max_level][-1])

    # NOTE: the order of the quad path is reversed from the original, my code expects the output of
    # each FPN repeat to be same as input from backbone, in order of increasing reductions
    for i in range(min_level, max_level + 1):
        # quad-add path.
        p.nodes.append({
            'reduction': 1 << i,
            'inputs_offsets': [node_ids[i][2], node_ids[i][4]],
            'weight_method': quad_method
        })
        node_ids[i].append(next(id_cnt))

    return p


def get_fpn_config(fpn_name, min_level=3, max_level=7):
    if not fpn_name:
        fpn_name = 'bifpn_fa'
    name_to_config = {
        'bifpn_sum': bifpn_config(min_level=min_level, max_level=max_level, weight_method='sum'),
        'bifpn_attn': bifpn_config(min_level=min_level, max_level=max_level, weight_method='attn'),
        'bifpn_fa': bifpn_config(min_level=min_level, max_level=max_level, weight_method='fastattn'),
        'pan_sum': panfpn_config(min_level=min_level, max_level=max_level, weight_method='sum'),
        'pan_fa': panfpn_config(min_level=min_level, max_level=max_level, weight_method='fastattn'),
        'qufpn_sum': qufpn_config(min_level=min_level, max_level=max_level, weight_method='sum'),
        'qufpn_fa': qufpn_config(min_level=min_level, max_level=max_level, weight_method='fastattn'),
    }
    return name_to_config[fpn_name]
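For reference, the FPN layouts above can be inspected directly. With the default min_level=3 and max_level=7, the BiFPN graph built here has four top-down and four bottom-up fusion nodes; a short sketch:

from effdet.config import get_fpn_config

p = get_fpn_config('bifpn_fa', min_level=3, max_level=7)
# each node records its feature reduction, the node indices it fuses, and the
# weighting method ('fastattn' here)
for node in p.nodes:
    print(node['reduction'], node['inputs_offsets'], node['weight_method'])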
efficientdet/effdet/config/model_config.py
ADDED
@@ -0,0 +1,538 @@
"""EfficientDet Configurations

Adapted from official impl at https://github.com/google/automl/tree/master/efficientdet

TODO use a different config system (OmegaConfig -> Hydra?), separate model from train specific hparams
"""

from omegaconf import OmegaConf
from copy import deepcopy


def default_detection_model_configs():
    """Returns a default detection configs."""
    h = OmegaConf.create()

    # model name.
    h.name = 'tf_efficientdet_d1'

    h.backbone_name = 'tf_efficientnet_b1'
    h.backbone_args = None  # FIXME sort out kwargs vs config for backbone creation

    # model specific, input preprocessing parameters
    h.image_size = (640, 640)

    # dataset specific head parameters
    h.num_classes = 90

    # feature + anchor config
    h.min_level = 3
    h.max_level = 7
    h.num_levels = h.max_level - h.min_level + 1
    h.num_scales = 3
    h.aspect_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
    # ratio w/h: 2.0 means w=1.4, h=0.7. Can be computed with k-mean per dataset.
    #h.aspect_ratios = [1.0, 2.0, 0.5]
    h.anchor_scale = 4.0

    # FPN and head config
    h.pad_type = 'same'  # original TF models require an equivalent of Tensorflow 'SAME' padding
    h.act_type = 'swish'
    h.norm_layer = None  # defaults to batch norm when None
    h.norm_kwargs = dict(eps=.001, momentum=.01)
    h.box_class_repeats = 3
    h.fpn_cell_repeats = 3
    h.fpn_channels = 88
    h.separable_conv = True
    h.apply_bn_for_resampling = True
    h.conv_after_downsample = False
    h.conv_bn_relu_pattern = False
    h.use_native_resize_op = False
    h.pooling_type = None
    h.redundant_bias = True  # original TF models have back to back bias + BN layers, not necessary!
    h.head_bn_level_first = False  # change order of BN in head repeat list of lists, True for torchscript compat

    h.fpn_name = None
    h.fpn_config = None
    h.fpn_drop_path_rate = 0.  # No stochastic depth in default. NOTE not currently used, unstable training

    # classification loss (used by train bench)
    h.alpha = 0.25
    h.gamma = 1.5
    h.label_smoothing = 0.  # only supported if new_focal == True
    h.new_focal = False  # use new focal loss (supports label smoothing but uses more mem, less optimal w/ jit script)
    h.jit_loss = False  # torchscript jit for loss fn speed improvement, can impact stability and/or increase mem usage

    # localization loss (used by train bench)
    h.delta = 0.1
    h.box_loss_weight = 50.0

    return h


efficientdet_model_param_dict = dict(
    # Models with PyTorch friendly padding and my PyTorch pretrained backbones, training TBD
    efficientdet_d0=dict(
        name='efficientdet_d0',
        backbone_name='efficientnet_b0',
        image_size=(512, 512),
        fpn_channels=64,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        pad_type='',
        redundant_bias=False,
        backbone_args=dict(drop_path_rate=0.1),
        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_d0-f3276ba8.pth',
    ),
    efficientdet_d1=dict(
        name='efficientdet_d1',
        backbone_name='efficientnet_b1',
        image_size=(640, 640),
        fpn_channels=88,
        fpn_cell_repeats=4,
        box_class_repeats=3,
        pad_type='',
        redundant_bias=False,
        backbone_args=dict(drop_path_rate=0.2),
        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_d1-bb7e98fe.pth',
    ),
    efficientdet_d2=dict(
        name='efficientdet_d2',
        backbone_name='efficientnet_b2',
        image_size=(768, 768),
        fpn_channels=112,
        fpn_cell_repeats=5,
        box_class_repeats=3,
        pad_type='',
        redundant_bias=False,
        backbone_args=dict(drop_path_rate=0.2),
        url='',  # no pretrained weights yet
    ),
    efficientdet_d3=dict(
        name='efficientdet_d3',
        backbone_name='efficientnet_b3',
        image_size=(896, 896),
        fpn_channels=160,
        fpn_cell_repeats=6,
        box_class_repeats=4,
        pad_type='',
        redundant_bias=False,
        backbone_args=dict(drop_path_rate=0.2),
        url='',  # no pretrained weights yet
    ),
    efficientdet_d4=dict(
        name='efficientdet_d4',
        backbone_name='efficientnet_b4',
        image_size=(1024, 1024),
        fpn_channels=224,
        fpn_cell_repeats=7,
        box_class_repeats=4,
        backbone_args=dict(drop_path_rate=0.2),
    ),
    efficientdet_d5=dict(
        name='efficientdet_d5',
        backbone_name='efficientnet_b5',
        image_size=(1280, 1280),
        fpn_channels=288,
        fpn_cell_repeats=7,
        box_class_repeats=4,
        backbone_args=dict(drop_path_rate=0.2),
        url='',
    ),

    # My own experimental configs with alternate models, training TBD
    # Note: any 'timm' model in the EfficientDet family can be used as a backbone here.
    resdet50=dict(
        name='resdet50',
        backbone_name='resnet50',
        image_size=(640, 640),
        fpn_channels=88,
        fpn_cell_repeats=4,
        box_class_repeats=3,
        pad_type='',
        act_type='relu',
        redundant_bias=False,
        separable_conv=False,
        backbone_args=dict(drop_path_rate=0.2),
        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/resdet50_416-08676892.pth',
    ),
    cspresdet50=dict(
        name='cspresdet50',
        backbone_name='cspresnet50',
        image_size=(640, 640),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=88,
        fpn_cell_repeats=4,
        box_class_repeats=3,
        pad_type='',
        act_type='leaky_relu',
        redundant_bias=False,
        separable_conv=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.2),
        url='',
    ),
    cspresdext50=dict(
        name='cspresdext50',
        backbone_name='cspresnext50',
        image_size=(640, 640),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=88,
        fpn_cell_repeats=4,
        box_class_repeats=3,
        pad_type='',
        act_type='leaky_relu',
        redundant_bias=False,
        separable_conv=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.2),
        url='',
    ),
    cspresdext50pan=dict(
        name='cspresdext50pan',
        backbone_name='cspresnext50',
        image_size=(640, 640),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=88,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        pad_type='',
        act_type='leaky_relu',
        fpn_name='pan_fa',  # PAN FPN experiment
        redundant_bias=False,
        separable_conv=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.2),
        url='',
    ),
    cspdarkdet53=dict(
        name='cspdarkdet53',
        backbone_name='cspdarknet53',
        image_size=(640, 640),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=88,
        fpn_cell_repeats=4,
        box_class_repeats=3,
        pad_type='',
        act_type='leaky_relu',
        redundant_bias=False,
        separable_conv=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.2),
        url='',
    ),
    mixdet_m=dict(
        name='mixdet_m',
        backbone_name='mixnet_m',
        image_size=(512, 512),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=64,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        pad_type='',
        redundant_bias=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.1),
        url='',  # no pretrained weights yet
    ),
    mixdet_l=dict(
        name='mixdet_l',
        backbone_name='mixnet_l',
        image_size=(640, 640),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=88,
        fpn_cell_repeats=4,
        box_class_repeats=3,
        pad_type='',
        redundant_bias=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.2),
        url='',  # no pretrained weights yet
    ),
    mobiledetv2_110d=dict(
        name='mobiledetv2_110d',
        backbone_name='mobilenetv2_110d',
        image_size=(384, 384),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=48,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        pad_type='',
        act_type='relu6',
        redundant_bias=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.05),
        url='',  # no pretrained weights yet
    ),
    mobiledetv2_120d=dict(
        name='mobiledetv2_120d',
        backbone_name='mobilenetv2_120d',
        image_size=(512, 512),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=56,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        pad_type='',
        act_type='relu6',
        redundant_bias=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.1),
        url='',  # no pretrained weights yet
    ),
    mobiledetv3_large=dict(
        name='mobiledetv3_large',
        backbone_name='mobilenetv3_large_100',
        image_size=(512, 512),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=64,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        pad_type='',
        act_type='hard_swish',
        redundant_bias=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.1),
        url='',  # no pretrained weights yet
    ),
    efficientdet_q0=dict(
        name='efficientdet_q0',
        backbone_name='efficientnet_b0',
        image_size=(512, 512),
        fpn_channels=64,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        pad_type='',
        fpn_name='qufpn_fa',  # quad-fpn + fast attn experiment
        redundant_bias=False,
        head_bn_level_first=True,
        backbone_args=dict(drop_path_rate=0.1),
        url='',
    ),
    efficientdet_w0=dict(
        name='efficientdet_w0',  # 'wide'
        backbone_name='efficientnet_b0',
        image_size=(512, 512),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=80,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        pad_type='',
        redundant_bias=False,
        head_bn_level_first=True,
        backbone_args=dict(
            drop_path_rate=0.1,
            feature_location='depthwise'),  # features from after DW/SE in IR block
        url='',  # no pretrained weights yet
    ),
    efficientdet_es=dict(
        name='efficientdet_es',  #EdgeTPU-Small
        backbone_name='efficientnet_es',
        image_size=(512, 512),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=72,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        pad_type='',
        act_type='relu',
        redundant_bias=False,
        head_bn_level_first=True,
        separable_conv=False,
        backbone_args=dict(drop_path_rate=0.1),
        url='',
    ),
    efficientdet_em=dict(
        name='efficientdet_em',  # Edge-TPU Medium
        backbone_name='efficientnet_em',
        image_size=(640, 640),
        aspect_ratios=[1.0, 2.0, 0.5],
        fpn_channels=96,
        fpn_cell_repeats=4,
        box_class_repeats=3,
        pad_type='',
        act_type='relu',
        redundant_bias=False,
        head_bn_level_first=True,
        separable_conv=False,
        backbone_args=dict(drop_path_rate=0.2),
        url='',  # no pretrained weights yet
    ),
    efficientdet_lite0=dict(
        name='efficientdet_lite0',
        backbone_name='efficientnet_lite0',
        image_size=(512, 512),
        fpn_channels=64,
        fpn_cell_repeats=3,
        box_class_repeats=3,
        act_type='relu',
        redundant_bias=False,
        head_bn_level_first=True,
head_bn_level_first=True,
|
369 |
+
backbone_args=dict(drop_path_rate=0.1),
|
370 |
+
url='',
|
371 |
+
),
|
372 |
+
|
373 |
+
# Models ported from Tensorflow with pretrained backbones ported from Tensorflow
|
374 |
+
tf_efficientdet_d0=dict(
|
375 |
+
name='tf_efficientdet_d0',
|
376 |
+
backbone_name='tf_efficientnet_b0',
|
377 |
+
image_size=(512, 512),
|
378 |
+
fpn_channels=64,
|
379 |
+
fpn_cell_repeats=3,
|
380 |
+
box_class_repeats=3,
|
381 |
+
backbone_args=dict(drop_path_rate=0.2),
|
382 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d0_34-f153e0cf.pth',
|
383 |
+
),
|
384 |
+
tf_efficientdet_d1=dict(
|
385 |
+
name='tf_efficientdet_d1',
|
386 |
+
backbone_name='tf_efficientnet_b1',
|
387 |
+
image_size=(640, 640),
|
388 |
+
fpn_channels=88,
|
389 |
+
fpn_cell_repeats=4,
|
390 |
+
box_class_repeats=3,
|
391 |
+
backbone_args=dict(drop_path_rate=0.2),
|
392 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d1_40-a30f94af.pth'
|
393 |
+
),
|
394 |
+
tf_efficientdet_d2=dict(
|
395 |
+
name='tf_efficientdet_d2',
|
396 |
+
backbone_name='tf_efficientnet_b2',
|
397 |
+
image_size=(768, 768),
|
398 |
+
fpn_channels=112,
|
399 |
+
fpn_cell_repeats=5,
|
400 |
+
box_class_repeats=3,
|
401 |
+
backbone_args=dict(drop_path_rate=0.2),
|
402 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d2_43-8107aa99.pth',
|
403 |
+
),
|
404 |
+
tf_efficientdet_d3=dict(
|
405 |
+
name='tf_efficientdet_d3',
|
406 |
+
backbone_name='tf_efficientnet_b3',
|
407 |
+
image_size=(896, 896),
|
408 |
+
fpn_channels=160,
|
409 |
+
fpn_cell_repeats=6,
|
410 |
+
box_class_repeats=4,
|
411 |
+
backbone_args=dict(drop_path_rate=0.2),
|
412 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d3_47-0b525f35.pth',
|
413 |
+
),
|
414 |
+
tf_efficientdet_d4=dict(
|
415 |
+
name='tf_efficientdet_d4',
|
416 |
+
backbone_name='tf_efficientnet_b4',
|
417 |
+
image_size=(1024, 1024),
|
418 |
+
fpn_channels=224,
|
419 |
+
fpn_cell_repeats=7,
|
420 |
+
box_class_repeats=4,
|
421 |
+
backbone_args=dict(drop_path_rate=0.2),
|
422 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d4_49-f56376d9.pth',
|
423 |
+
),
|
424 |
+
tf_efficientdet_d5=dict(
|
425 |
+
name='tf_efficientdet_d5',
|
426 |
+
backbone_name='tf_efficientnet_b5',
|
427 |
+
image_size=(1280, 1280),
|
428 |
+
fpn_channels=288,
|
429 |
+
fpn_cell_repeats=7,
|
430 |
+
box_class_repeats=4,
|
431 |
+
backbone_args=dict(drop_path_rate=0.2),
|
432 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d5_51-c79f9be6.pth',
|
433 |
+
),
|
434 |
+
tf_efficientdet_d6=dict(
|
435 |
+
name='tf_efficientdet_d6',
|
436 |
+
backbone_name='tf_efficientnet_b6',
|
437 |
+
image_size=(1280, 1280),
|
438 |
+
fpn_channels=384,
|
439 |
+
fpn_cell_repeats=8,
|
440 |
+
box_class_repeats=5,
|
441 |
+
fpn_name='bifpn_sum', # Use unweighted sum for training stability.
|
442 |
+
backbone_args=dict(drop_path_rate=0.2),
|
443 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d6_52-4eda3773.pth'
|
444 |
+
),
|
445 |
+
tf_efficientdet_d7=dict(
|
446 |
+
name='tf_efficientdet_d7',
|
447 |
+
backbone_name='tf_efficientnet_b6',
|
448 |
+
image_size=(1536, 1536),
|
449 |
+
fpn_channels=384,
|
450 |
+
fpn_cell_repeats=8,
|
451 |
+
box_class_repeats=5,
|
452 |
+
anchor_scale=5.0,
|
453 |
+
fpn_name='bifpn_sum', # Use unweighted sum for training stability.
|
454 |
+
backbone_args=dict(drop_path_rate=0.2),
|
455 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d7_53-6d1d7a95.pth'
|
456 |
+
),
|
457 |
+
tf_efficientdet_d7x=dict(
|
458 |
+
name='tf_efficientdet_d7x',
|
459 |
+
backbone_name='tf_efficientnet_b7',
|
460 |
+
image_size=(1536, 1536),
|
461 |
+
fpn_channels=384,
|
462 |
+
fpn_cell_repeats=8,
|
463 |
+
box_class_repeats=5,
|
464 |
+
anchor_scale=4.0,
|
465 |
+
max_level=8,
|
466 |
+
fpn_name='bifpn_sum', # Use unweighted sum for training stability.
|
467 |
+
backbone_args=dict(drop_path_rate=0.2),
|
468 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d7x-f390b87c.pth'
|
469 |
+
),
|
470 |
+
|
471 |
+
# The lite configs are in TF automl repository but no weights yet and listed as 'not final'
|
472 |
+
tf_efficientdet_lite0=dict(
|
473 |
+
name='tf_efficientdet_lite0',
|
474 |
+
backbone_name='tf_efficientnet_lite0',
|
475 |
+
image_size=(512, 512),
|
476 |
+
fpn_channels=64,
|
477 |
+
fpn_cell_repeats=3,
|
478 |
+
box_class_repeats=3,
|
479 |
+
act_type='relu',
|
480 |
+
redundant_bias=False,
|
481 |
+
backbone_args=dict(drop_path_rate=0.1),
|
482 |
+
# unlike other tf_ models, this was not ported from tf automl impl, but trained from tf pretrained efficient lite
|
483 |
+
# weights using this code, will likely replace if/when official det-lite weights are released
|
484 |
+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_lite0-f5f303a9.pth',
|
485 |
+
),
|
486 |
+
tf_efficientdet_lite1=dict(
|
487 |
+
name='tf_efficientdet_lite1',
|
488 |
+
backbone_name='tf_efficientnet_lite1',
|
489 |
+
image_size=(640, 640),
|
490 |
+
fpn_channels=88,
|
491 |
+
fpn_cell_repeats=4,
|
492 |
+
box_class_repeats=3,
|
493 |
+
act_type='relu',
|
494 |
+
backbone_args=dict(drop_path_rate=0.2),
|
495 |
+
url='', # no pretrained weights yet
|
496 |
+
),
|
497 |
+
tf_efficientdet_lite2=dict(
|
498 |
+
name='tf_efficientdet_lite2',
|
499 |
+
backbone_name='tf_efficientnet_lite2',
|
500 |
+
image_size=(768, 768),
|
501 |
+
fpn_channels=112,
|
502 |
+
fpn_cell_repeats=5,
|
503 |
+
box_class_repeats=3,
|
504 |
+
act_type='relu',
|
505 |
+
backbone_args=dict(drop_path_rate=0.2),
|
506 |
+
url='',
|
507 |
+
),
|
508 |
+
tf_efficientdet_lite3=dict(
|
509 |
+
name='tf_efficientdet_lite3',
|
510 |
+
backbone_name='tf_efficientnet_lite3',
|
511 |
+
image_size=(896, 896),
|
512 |
+
fpn_channels=160,
|
513 |
+
fpn_cell_repeats=6,
|
514 |
+
box_class_repeats=4,
|
515 |
+
act_type='relu',
|
516 |
+
backbone_args=dict(drop_path_rate=0.2),
|
517 |
+
url='',
|
518 |
+
),
|
519 |
+
tf_efficientdet_lite4=dict(
|
520 |
+
name='tf_efficientdet_lite4',
|
521 |
+
backbone_name='tf_efficientnet_lite4',
|
522 |
+
image_size=(1024, 1024),
|
523 |
+
fpn_channels=224,
|
524 |
+
fpn_cell_repeats=7,
|
525 |
+
box_class_repeats=4,
|
526 |
+
act_type='relu',
|
527 |
+
backbone_args=dict(drop_path_rate=0.2),
|
528 |
+
url='',
|
529 |
+
),
|
530 |
+
)
|
531 |
+
|
532 |
+
|
533 |
+
def get_efficientdet_config(model_name='tf_efficientdet_d1'):
|
534 |
+
"""Get the default config for EfficientDet based on model name."""
|
535 |
+
h = default_detection_model_configs()
|
536 |
+
h.update(efficientdet_model_param_dict[model_name])
|
537 |
+
h.num_levels = h.max_level - h.min_level + 1
|
538 |
+
return deepcopy(h) # may be unnecessary, ensure no references to param dict values
|
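The helper above merges a named preset into the default detection config and returns a deep copy, so callers can override fields freely. A minimal usage sketch (not part of this commit; the import path, model name and num_classes value are illustrative assumptions):

    from effdet.config import get_efficientdet_config  # import path assumed

    cfg = get_efficientdet_config('tf_efficientdet_d2')  # any key from efficientdet_model_param_dict
    cfg.num_classes = 7          # e.g. number of waste categories (illustrative)
    cfg.image_size = (768, 768)  # overrides never touch the preset dict thanks to deepcopy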
efficientdet/effdet/config/train_config.py
ADDED
@@ -0,0 +1,34 @@
+from omegaconf import OmegaConf
+
+
+def default_detection_train_config():
+    # FIXME currently using args for train config, will revisit, perhaps move to Hydra
+    h = OmegaConf.create()
+
+    # dataset
+    h.skip_crowd_during_training = True
+
+    # augmentation
+    h.input_rand_hflip = True
+    h.train_scale_min = 0.1
+    h.train_scale_max = 2.0
+    h.autoaugment_policy = None
+
+    # optimization
+    h.momentum = 0.9
+    h.learning_rate = 0.08
+    h.lr_warmup_init = 0.008
+    h.lr_warmup_epoch = 1.0
+    h.first_lr_drop_epoch = 200.0
+    h.second_lr_drop_epoch = 250.0
+    h.clip_gradients_norm = 10.0
+    h.num_epochs = 300
+
+    # regularization l2 loss.
+    h.weight_decay = 4e-5
+
+    h.lr_decay_method = 'cosine'
+    h.moving_average_decay = 0.9998
+    h.ckpt_var_scope = None
+
+    return h
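A short hedged sketch of how these OmegaConf defaults might be overridden by a training script (values are illustrative, not settings from this repo):

    train_cfg = default_detection_train_config()
    train_cfg.learning_rate = 0.01  # attribute-style overrides, as with any OmegaConf node
    train_cfg.num_epochs = 50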
efficientdet/effdet/data/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from .dataset_factory import create_dataset
+from .dataset import DetectionDatset, SkipSubset
+from .input_config import resolve_input_config
+from .loader import create_loader
+from .parsers import create_parser
+from .transforms import *
efficientdet/effdet/data/dataset.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
""" Detection dataset
|
2 |
+
|
3 |
+
Hacked together by Ross Wightman
|
4 |
+
"""
|
5 |
+
import torch.utils.data as data
|
6 |
+
import numpy as np
|
7 |
+
import albumentations as A
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from PIL import Image
|
11 |
+
from .parsers import create_parser
|
12 |
+
|
13 |
+
|
14 |
+
class DetectionDatset(data.Dataset):
|
15 |
+
"""`Object Detection Dataset. Use with parsers for COCO, VOC, and OpenImages.
|
16 |
+
Args:
|
17 |
+
parser (string, Parser):
|
18 |
+
transform (callable, optional): A function/transform that takes in a PIL image
|
19 |
+
and returns a transformed version. E.g, ``transforms.ToTensor``
|
20 |
+
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, data_dir, parser=None, parser_kwargs=None, transform=None, transforms=None):
|
24 |
+
super(DetectionDatset, self).__init__()
|
25 |
+
parser_kwargs = parser_kwargs or {}
|
26 |
+
self.data_dir = data_dir
|
27 |
+
if isinstance(parser, str):
|
28 |
+
self._parser = create_parser(parser, **parser_kwargs)
|
29 |
+
else:
|
30 |
+
assert parser is not None and len(parser.img_ids)
|
31 |
+
self._parser = parser
|
32 |
+
self._transform = transform
|
33 |
+
self._transforms = transforms
|
34 |
+
|
35 |
+
def __getitem__(self, index):
|
36 |
+
"""
|
37 |
+
Args:
|
38 |
+
index (int): Index
|
39 |
+
Returns:
|
40 |
+
tuple: Tuple (image, annotations (target)).
|
41 |
+
"""
|
42 |
+
img_info = self._parser.img_infos[index]
|
43 |
+
target = dict(img_idx=index, img_size=(img_info['width'], img_info['height']))
|
44 |
+
if self._parser.has_labels:
|
45 |
+
ann = self._parser.get_ann_info(index)
|
46 |
+
target.update(ann)
|
47 |
+
img_path = self.data_dir / img_info['file_name']
|
48 |
+
img = Image.open(img_path).convert('RGB')
|
49 |
+
if self.transforms is not None:
|
50 |
+
img = torch.as_tensor(np.array(img), dtype=torch.uint8)
|
51 |
+
voc_boxes = []
|
52 |
+
for coord in target['bbox']:
|
53 |
+
xmin = coord[1]
|
54 |
+
ymin = coord[0]
|
55 |
+
xmax = coord[3]
|
56 |
+
ymax = coord[2]
|
57 |
+
if xmin<1:
|
58 |
+
xmin = 1
|
59 |
+
if ymin<1:
|
60 |
+
ymin = 1
|
61 |
+
if xmax>=img.shape[1]-1:
|
62 |
+
xmax = img.shape[1]-1
|
63 |
+
if ymax>=img.shape[0]-1:
|
64 |
+
ymax = img.shape[0]-1
|
65 |
+
voc_boxes.append([xmin, ymin, xmax, ymax])
|
66 |
+
transformed = self.transforms(image=np.array(img), bbox_classes=target['cls'], bboxes=voc_boxes)
|
67 |
+
img = torch.as_tensor(transformed['image'], dtype=torch.uint8)
|
68 |
+
target['bbox'] = []
|
69 |
+
for coord in transformed['bboxes']:
|
70 |
+
ymin = int(coord[1])
|
71 |
+
xmin = int(coord[0])
|
72 |
+
ymax = int(coord[3])
|
73 |
+
xmax = int(coord[2])
|
74 |
+
target['bbox'].append([ymin, xmin, ymax, xmax])
|
75 |
+
target['bbox'] = np.array(target['bbox'], dtype=np.float32)
|
76 |
+
target['cls'] = np.array(transformed['bbox_classes'])
|
77 |
+
img = Image.fromarray(np.array(img).astype('uint8'), 'RGB')
|
78 |
+
target['img_size'] = img.size
|
79 |
+
|
80 |
+
if self.transform is not None:
|
81 |
+
img, target = self.transform(img, target)
|
82 |
+
|
83 |
+
return img, target
|
84 |
+
|
85 |
+
def __len__(self):
|
86 |
+
return len(self._parser.img_ids)
|
87 |
+
|
88 |
+
@property
|
89 |
+
def parser(self):
|
90 |
+
return self._parser
|
91 |
+
|
92 |
+
@property
|
93 |
+
def transform(self):
|
94 |
+
return self._transform
|
95 |
+
|
96 |
+
@transform.setter
|
97 |
+
def transform(self, t):
|
98 |
+
self._transform = t
|
99 |
+
|
100 |
+
@property
|
101 |
+
def transforms(self):
|
102 |
+
return self._transforms
|
103 |
+
|
104 |
+
@transforms.setter
|
105 |
+
def transforms(self, t):
|
106 |
+
self._transforms = t
|
107 |
+
|
108 |
+
class SkipSubset(data.Dataset):
|
109 |
+
r"""
|
110 |
+
Subset of a dataset at specified indices.
|
111 |
+
|
112 |
+
Arguments:
|
113 |
+
dataset (Dataset): The whole Dataset
|
114 |
+
n (int): skip rate (select every nth)
|
115 |
+
"""
|
116 |
+
def __init__(self, dataset, n=2):
|
117 |
+
self.dataset = dataset
|
118 |
+
assert n >= 1
|
119 |
+
self.indices = np.arange(len(dataset))[::n]
|
120 |
+
|
121 |
+
def __getitem__(self, idx):
|
122 |
+
return self.dataset[self.indices[idx]]
|
123 |
+
|
124 |
+
def __len__(self):
|
125 |
+
return len(self.indices)
|
126 |
+
|
127 |
+
@property
|
128 |
+
def parser(self):
|
129 |
+
return self.dataset.parser
|
130 |
+
|
131 |
+
@property
|
132 |
+
def transform(self):
|
133 |
+
return self.dataset.transform
|
134 |
+
|
135 |
+
@transform.setter
|
136 |
+
def transform(self, t):
|
137 |
+
self.dataset.transform = t
|
138 |
+
|
139 |
+
@property
|
140 |
+
def transforms(self):
|
141 |
+
return self.dataset.transforms
|
142 |
+
|
143 |
+
@transforms.setter
|
144 |
+
def transforms(self, t):
|
145 |
+
self.dataset.transforms = t
|
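A hedged sketch of building the dataset by hand, assuming a COCO-style annotation file; the paths are placeholders and the parser/config classes are the ones added later in this commit:

    from pathlib import Path
    from effdet.data.parsers import CocoParserCfg, create_parser  # import path assumed

    parser_cfg = CocoParserCfg(ann_filename='/data/annotations_train.json')  # placeholder path
    dataset = DetectionDatset(
        data_dir=Path('/data/images'),  # placeholder path
        parser=create_parser('coco', cfg=parser_cfg),
    )
    img, target = dataset[0]  # PIL image plus dict with 'bbox' (yxyx by default) and 'cls'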
efficientdet/effdet/data/dataset_config.py
ADDED
@@ -0,0 +1,194 @@
1 |
+
""" COCO detect-waste dataset configurations
|
2 |
+
|
3 |
+
Updated 2021 Wimlds in Detect Waste in Pomerania
|
4 |
+
"""
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Dict
|
7 |
+
|
8 |
+
|
9 |
+
@dataclass
|
10 |
+
class CocoCfg:
|
11 |
+
variant: str = None
|
12 |
+
parser: str = 'coco'
|
13 |
+
num_classes: int = 80
|
14 |
+
splits: Dict[str, dict] = None
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class TACOCfg(CocoCfg):
|
19 |
+
root: str = ""
|
20 |
+
ann: str = ""
|
21 |
+
variant: str = '2017'
|
22 |
+
num_classes: int = 28
|
23 |
+
|
24 |
+
def add_split(self):
|
25 |
+
self.splits = {
|
26 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
27 |
+
'img_dir': self.root,
|
28 |
+
'has_labels': True},
|
29 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
30 |
+
'img_dir': self.root,
|
31 |
+
'has_labels': True}
|
32 |
+
}
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class DetectwasteCfg(CocoCfg):
|
37 |
+
root: str = ""
|
38 |
+
ann: str = ""
|
39 |
+
variant: str = '2017'
|
40 |
+
num_classes: int = 7
|
41 |
+
|
42 |
+
def add_split(self):
|
43 |
+
self.splits = {
|
44 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
45 |
+
'img_dir': self.root,
|
46 |
+
'has_labels': True},
|
47 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
48 |
+
'img_dir': self.root,
|
49 |
+
'has_labels': True}
|
50 |
+
}
|
51 |
+
|
52 |
+
|
53 |
+
@dataclass
|
54 |
+
class BinaryCfg(CocoCfg):
|
55 |
+
root: str = ""
|
56 |
+
ann: str = ""
|
57 |
+
variant: str = '2017'
|
58 |
+
num_classes: int = 1
|
59 |
+
|
60 |
+
def add_split(self):
|
61 |
+
self.splits = {
|
62 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
63 |
+
'img_dir': self.root,
|
64 |
+
'has_labels': True},
|
65 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
66 |
+
'img_dir': self.root,
|
67 |
+
'has_labels': True}
|
68 |
+
}
|
69 |
+
|
70 |
+
|
71 |
+
@dataclass
|
72 |
+
class BinaryMultiCfg(CocoCfg):
|
73 |
+
root: str = ""
|
74 |
+
ann: str = ""
|
75 |
+
variant: str = '2017'
|
76 |
+
num_classes: int = 1
|
77 |
+
|
78 |
+
def add_split(self):
|
79 |
+
self.splits = {
|
80 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
81 |
+
'img_dir': self.root,
|
82 |
+
'has_labels': True},
|
83 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
84 |
+
'img_dir': self.root,
|
85 |
+
'has_labels': True}
|
86 |
+
}
|
87 |
+
|
88 |
+
|
89 |
+
@dataclass
|
90 |
+
class TrashCanCfg(CocoCfg):
|
91 |
+
root: str = ""
|
92 |
+
ann: str = ""
|
93 |
+
variant: str = '2017'
|
94 |
+
num_classes: int = 8
|
95 |
+
|
96 |
+
def add_split(self):
|
97 |
+
self.splits = {
|
98 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
99 |
+
'img_dir': self.root,
|
100 |
+
'has_labels': True},
|
101 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
102 |
+
'img_dir': self.root,
|
103 |
+
'has_labels': True}
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
@dataclass
|
108 |
+
class UAVVasteCfg(CocoCfg):
|
109 |
+
root: str = ""
|
110 |
+
ann: str = ""
|
111 |
+
variant: str = '2017'
|
112 |
+
num_classes: int = 1
|
113 |
+
|
114 |
+
def add_split(self):
|
115 |
+
self.splits = {
|
116 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
117 |
+
'img_dir': self.root,
|
118 |
+
'has_labels': True},
|
119 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
120 |
+
'img_dir': self.root,
|
121 |
+
'has_labels': True}
|
122 |
+
}
|
123 |
+
|
124 |
+
|
125 |
+
@dataclass
|
126 |
+
class ICRACfg(CocoCfg):
|
127 |
+
root: str = ""
|
128 |
+
ann: str = ""
|
129 |
+
variant: str = '2017'
|
130 |
+
num_classes: int = 7
|
131 |
+
|
132 |
+
def add_split(self):
|
133 |
+
self.splits = {
|
134 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
135 |
+
'img_dir': self.root,
|
136 |
+
'has_labels': True},
|
137 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
138 |
+
'img_dir': self.root,
|
139 |
+
'has_labels': True}
|
140 |
+
}
|
141 |
+
|
142 |
+
|
143 |
+
@dataclass
|
144 |
+
class DrinkWasteCfg(CocoCfg):
|
145 |
+
root: str = ""
|
146 |
+
ann: str = ""
|
147 |
+
variant: str = '2017'
|
148 |
+
num_classes: int = 4
|
149 |
+
|
150 |
+
def add_split(self):
|
151 |
+
self.splits = {
|
152 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
153 |
+
'img_dir': self.root,
|
154 |
+
'has_labels': True},
|
155 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
156 |
+
'img_dir': self.root,
|
157 |
+
'has_labels': True}
|
158 |
+
}
|
159 |
+
|
160 |
+
|
161 |
+
@dataclass
|
162 |
+
class MJU_WasteCfg(CocoCfg):
|
163 |
+
root: str = ""
|
164 |
+
ann: str = ""
|
165 |
+
variant: str = '2017'
|
166 |
+
num_classes: int = 1
|
167 |
+
|
168 |
+
def add_split(self):
|
169 |
+
self.splits = {
|
170 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
171 |
+
'img_dir': self.root,
|
172 |
+
'has_labels': True},
|
173 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
174 |
+
'img_dir': self.root,
|
175 |
+
'has_labels': True}
|
176 |
+
}
|
177 |
+
|
178 |
+
|
179 |
+
@dataclass
|
180 |
+
class WadeCfg(CocoCfg):
|
181 |
+
root: str = ""
|
182 |
+
ann: str = ""
|
183 |
+
variant: str = '2017'
|
184 |
+
num_classes: int = 1
|
185 |
+
|
186 |
+
def add_split(self):
|
187 |
+
self.splits = {
|
188 |
+
'train': {'ann_filename': self.ann+'_train.json',
|
189 |
+
'img_dir': self.root,
|
190 |
+
'has_labels': True},
|
191 |
+
'val': {'ann_filename': self.ann+'_test.json',
|
192 |
+
'img_dir': self.root,
|
193 |
+
'has_labels': True}
|
194 |
+
}
|
efficientdet/effdet/data/dataset_factory.py
ADDED
@@ -0,0 +1,85 @@
1 |
+
""" Dataset factory
|
2 |
+
|
3 |
+
Updated 2021 Wimlds in Detect Waste in Pomerania
|
4 |
+
"""
|
5 |
+
from collections import OrderedDict
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
from .dataset_config import *
|
9 |
+
from .parsers import *
|
10 |
+
from .dataset import DetectionDatset
|
11 |
+
from .parsers import create_parser
|
12 |
+
|
13 |
+
# list of detect-waste datasets
|
14 |
+
waste_datasets_list = ['taco', 'detectwaste', 'binary', 'multi',
|
15 |
+
'uav', 'mju', 'trashcan', 'wade', 'icra',
|
16 |
+
'drinkwaste']
|
17 |
+
|
18 |
+
|
19 |
+
def create_dataset(name, root, ann, splits=('train', 'val')):
|
20 |
+
if isinstance(splits, str):
|
21 |
+
splits = (splits,)
|
22 |
+
name = name.lower()
|
23 |
+
root = Path(root)
|
24 |
+
dataset_cls = DetectionDatset
|
25 |
+
datasets = OrderedDict()
|
26 |
+
if name.startswith('coco'):
|
27 |
+
if 'coco2014' in name:
|
28 |
+
dataset_cfg = Coco2014Cfg()
|
29 |
+
else:
|
30 |
+
dataset_cfg = Coco2017Cfg()
|
31 |
+
for s in splits:
|
32 |
+
if s not in dataset_cfg.splits:
|
33 |
+
raise RuntimeError(f'{s} split not found in config')
|
34 |
+
split_cfg = dataset_cfg.splits[s]
|
35 |
+
ann_file = root / split_cfg['ann_filename']
|
36 |
+
parser_cfg = CocoParserCfg(
|
37 |
+
ann_filename=ann_file,
|
38 |
+
has_labels=split_cfg['has_labels']
|
39 |
+
)
|
40 |
+
datasets[s] = dataset_cls(
|
41 |
+
data_dir=root / Path(split_cfg['img_dir']),
|
42 |
+
parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
|
43 |
+
)
|
44 |
+
datasets = OrderedDict()
|
45 |
+
elif name in waste_datasets_list:
|
46 |
+
if name.startswith('taco'):
|
47 |
+
dataset_cfg = TACOCfg(root=root, ann=ann)
|
48 |
+
elif name.startswith('detectwaste'):
|
49 |
+
dataset_cfg = DetectwasteCfg(root=root, ann=ann)
|
50 |
+
elif name.startswith('binary'):
|
51 |
+
dataset_cfg = BinaryCfg(root=root, ann=ann)
|
52 |
+
elif name.startswith('multi'):
|
53 |
+
dataset_cfg = BinaryMultiCfg(root=root, ann=ann)
|
54 |
+
elif name.startswith('uav'):
|
55 |
+
dataset_cfg = UAVVasteCfg(root=root, ann=ann)
|
56 |
+
elif name.startswith('trashcan'):
|
57 |
+
dataset_cfg = TrashCanCfg(root=root, ann=ann)
|
58 |
+
elif name.startswith('drinkwaste'):
|
59 |
+
dataset_cfg = DrinkWasteCfg(root=root, ann=ann)
|
60 |
+
elif name.startswith('mju'):
|
61 |
+
dataset_cfg = MJU_WasteCfg(root=root, ann=ann)
|
62 |
+
elif name.startswith('wade'):
|
63 |
+
dataset_cfg = WadeCfg(root=root, ann=ann)
|
64 |
+
elif name.startswith('icra'):
|
65 |
+
dataset_cfg = ICRACfg(root=root, ann=ann)
|
66 |
+
else:
|
67 |
+
assert False, f'Unknown dataset parser ({name})'
|
68 |
+
dataset_cfg.add_split()
|
69 |
+
for s in splits:
|
70 |
+
if s not in dataset_cfg.splits:
|
71 |
+
raise RuntimeError(f'{s} split not found in config')
|
72 |
+
split_cfg = dataset_cfg.splits[s]
|
73 |
+
parser_cfg = CocoParserCfg(
|
74 |
+
ann_filename=split_cfg['ann_filename'],
|
75 |
+
has_labels=split_cfg['has_labels']
|
76 |
+
)
|
77 |
+
datasets[s] = dataset_cls(
|
78 |
+
data_dir=split_cfg['img_dir'],
|
79 |
+
parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
|
80 |
+
)
|
81 |
+
else:
|
82 |
+
assert False, f'Unknown dataset parser ({name})'
|
83 |
+
|
84 |
+
datasets = list(datasets.values())
|
85 |
+
return datasets if len(datasets) > 1 else datasets[0]
|
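A hedged example of calling the factory for one of the waste datasets; directory and annotation prefix are placeholders:

    train_ds, val_ds = create_dataset(
        name='taco',
        root='/data/TACO/images',      # placeholder image directory
        ann='/data/TACO/annotations',  # prefix; '_train.json' / '_test.json' are appended by add_split()
        splits=('train', 'val'),
    )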
efficientdet/effdet/data/input_config.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
from .transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
2 |
+
|
3 |
+
|
4 |
+
def resolve_input_config(args, model_config=None, model=None):
|
5 |
+
if not isinstance(args, dict):
|
6 |
+
args = vars(args)
|
7 |
+
input_config = {}
|
8 |
+
if not model_config and model is not None and hasattr(model, 'config'):
|
9 |
+
model_config = model.config
|
10 |
+
|
11 |
+
# Resolve input/image size
|
12 |
+
in_chans = 3
|
13 |
+
input_size = (in_chans, 512, 512)
|
14 |
+
|
15 |
+
if 'input_size' in model_config:
|
16 |
+
input_size = tuple(model_config['input_size'])
|
17 |
+
elif 'image_size' in model_config:
|
18 |
+
input_size = (in_chans,) + tuple(model_config['image_size'])
|
19 |
+
assert isinstance(input_size, tuple) and len(input_size) == 3
|
20 |
+
input_config['input_size'] = input_size
|
21 |
+
|
22 |
+
# resolve interpolation method
|
23 |
+
input_config['interpolation'] = 'bicubic'
|
24 |
+
if 'interpolation' in args and args['interpolation']:
|
25 |
+
input_config['interpolation'] = args['interpolation']
|
26 |
+
elif 'interpolation' in model_config:
|
27 |
+
input_config['interpolation'] = model_config['interpolation']
|
28 |
+
|
29 |
+
# resolve dataset + model mean for normalization
|
30 |
+
input_config['mean'] = IMAGENET_DEFAULT_MEAN
|
31 |
+
if 'mean' in args and args['mean'] is not None:
|
32 |
+
mean = tuple(args['mean'])
|
33 |
+
if len(mean) == 1:
|
34 |
+
mean = tuple(list(mean) * in_chans)
|
35 |
+
else:
|
36 |
+
assert len(mean) == in_chans
|
37 |
+
input_config['mean'] = mean
|
38 |
+
elif 'mean' in model_config:
|
39 |
+
input_config['mean'] = model_config['mean']
|
40 |
+
|
41 |
+
# resolve dataset + model std deviation for normalization
|
42 |
+
input_config['std'] = IMAGENET_DEFAULT_STD
|
43 |
+
if 'std' in args and args['std'] is not None:
|
44 |
+
std = tuple(args['std'])
|
45 |
+
if len(std) == 1:
|
46 |
+
std = tuple(list(std) * in_chans)
|
47 |
+
else:
|
48 |
+
assert len(std) == in_chans
|
49 |
+
input_config['std'] = std
|
50 |
+
elif 'std' in model_config:
|
51 |
+
input_config['std'] = model_config['std']
|
52 |
+
|
53 |
+
# resolve letterbox fill color
|
54 |
+
input_config['fill_color'] = 'mean'
|
55 |
+
if 'fill_color' in args and args['fill_color'] is not None:
|
56 |
+
input_config['fill_color'] = args['fill_color']
|
57 |
+
elif 'fill_color' in model_config:
|
58 |
+
input_config['fill_color'] = model_config['fill_color']
|
59 |
+
|
60 |
+
return input_config
|
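A short hedged example of resolving the preprocessing settings from a model config; the args dict stands in for the usual argparse namespace and every value shown is an assumption:

    args = dict(interpolation=None, mean=None, std=None, fill_color=None)  # CLI stand-in
    input_cfg = resolve_input_config(args, model_config=cfg)  # cfg as returned by get_efficientdet_config
    # input_cfg -> {'input_size': (3, H, W), 'interpolation': ..., 'mean': ..., 'std': ..., 'fill_color': ...}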
efficientdet/effdet/data/loader.py
ADDED
@@ -0,0 +1,226 @@
1 |
+
""" Object detection loader/collate
|
2 |
+
|
3 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
4 |
+
"""
|
5 |
+
import torch.utils.data
|
6 |
+
from .transforms import *
|
7 |
+
from .transforms_albumentation import get_transform
|
8 |
+
from .random_erasing import RandomErasing
|
9 |
+
from effdet.anchors import AnchorLabeler
|
10 |
+
from timm.data.distributed_sampler import OrderedDistributedSampler
|
11 |
+
import os
|
12 |
+
|
13 |
+
MAX_NUM_INSTANCES = 100
|
14 |
+
|
15 |
+
|
16 |
+
class DetectionFastCollate:
|
17 |
+
""" A detection specific, optimized collate function w/ a bit of state.
|
18 |
+
|
19 |
+
Optionally performs anchor labelling. Doing this here offloads some work from the
|
20 |
+
GPU and the main training process thread and increases the load on the dataloader
|
21 |
+
threads.
|
22 |
+
|
23 |
+
"""
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
instance_keys=None,
|
27 |
+
instance_shapes=None,
|
28 |
+
instance_fill=-1,
|
29 |
+
max_instances=MAX_NUM_INSTANCES,
|
30 |
+
anchor_labeler=None,
|
31 |
+
):
|
32 |
+
instance_keys = instance_keys or {'bbox', 'bbox_ignore', 'cls'}
|
33 |
+
instance_shapes = instance_shapes or dict(
|
34 |
+
bbox=(max_instances, 4), bbox_ignore=(max_instances, 4), cls=(max_instances,))
|
35 |
+
self.instance_info = {k: dict(fill=instance_fill, shape=instance_shapes[k]) for k in instance_keys}
|
36 |
+
self.max_instances = max_instances
|
37 |
+
self.anchor_labeler = anchor_labeler
|
38 |
+
|
39 |
+
def __call__(self, batch):
|
40 |
+
batch_size = len(batch)
|
41 |
+
target = dict()
|
42 |
+
labeler_outputs = dict()
|
43 |
+
img_tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
44 |
+
for i in range(batch_size):
|
45 |
+
img_tensor[i] += torch.from_numpy(batch[i][0])
|
46 |
+
labeler_inputs = {}
|
47 |
+
for tk, tv in batch[i][1].items():
|
48 |
+
instance_info = self.instance_info.get(tk, None)
|
49 |
+
if instance_info is not None:
|
50 |
+
# target tensor is associated with a detection instance
|
51 |
+
tv = torch.from_numpy(tv).to(dtype=torch.float32)
|
52 |
+
if self.anchor_labeler is None:
|
53 |
+
if i == 0:
|
54 |
+
shape = (batch_size,) + instance_info['shape']
|
55 |
+
target_tensor = torch.full(shape, instance_info['fill'], dtype=torch.float32)
|
56 |
+
target[tk] = target_tensor
|
57 |
+
else:
|
58 |
+
target_tensor = target[tk]
|
59 |
+
num_elem = min(tv.shape[0], self.max_instances)
|
60 |
+
target_tensor[i, 0:num_elem] = tv[0:num_elem]
|
61 |
+
else:
|
62 |
+
# no need to pass gt tensors through when labeler in use
|
63 |
+
if tk in ('bbox', 'cls'):
|
64 |
+
labeler_inputs[tk] = tv
|
65 |
+
else:
|
66 |
+
# target tensor is an image-level annotation / metadata
|
67 |
+
if i == 0:
|
68 |
+
# first batch elem, create destination tensors
|
69 |
+
if isinstance(tv, (tuple, list)):
|
70 |
+
# per batch elem sequence
|
71 |
+
shape = (batch_size, len(tv))
|
72 |
+
dtype = torch.float32 if isinstance(tv[0], (float, np.floating)) else torch.int32
|
73 |
+
else:
|
74 |
+
# per batch elem scalar
|
75 |
+
shape = batch_size,
|
76 |
+
dtype = torch.float32 if isinstance(tv, (float, np.floating)) else torch.int64
|
77 |
+
target_tensor = torch.zeros(shape, dtype=dtype)
|
78 |
+
target[tk] = target_tensor
|
79 |
+
else:
|
80 |
+
target_tensor = target[tk]
|
81 |
+
target_tensor[i] = torch.tensor(tv, dtype=target_tensor.dtype)
|
82 |
+
|
83 |
+
if self.anchor_labeler is not None:
|
84 |
+
cls_targets, box_targets, num_positives = self.anchor_labeler.label_anchors(
|
85 |
+
labeler_inputs['bbox'], labeler_inputs['cls'], filter_valid=False)
|
86 |
+
if i == 0:
|
87 |
+
# first batch elem, create destination tensors, separate key per level
|
88 |
+
for j, (ct, bt) in enumerate(zip(cls_targets, box_targets)):
|
89 |
+
labeler_outputs[f'label_cls_{j}'] = torch.zeros(
|
90 |
+
(batch_size,) + ct.shape, dtype=torch.int64)
|
91 |
+
labeler_outputs[f'label_bbox_{j}'] = torch.zeros(
|
92 |
+
(batch_size,) + bt.shape, dtype=torch.float32)
|
93 |
+
labeler_outputs['label_num_positives'] = torch.zeros(batch_size)
|
94 |
+
for j, (ct, bt) in enumerate(zip(cls_targets, box_targets)):
|
95 |
+
labeler_outputs[f'label_cls_{j}'][i] = ct
|
96 |
+
labeler_outputs[f'label_bbox_{j}'][i] = bt
|
97 |
+
labeler_outputs['label_num_positives'][i] = num_positives
|
98 |
+
if labeler_outputs:
|
99 |
+
target.update(labeler_outputs)
|
100 |
+
|
101 |
+
return img_tensor, target
|
102 |
+
|
103 |
+
|
104 |
+
class PrefetchLoader:
|
105 |
+
|
106 |
+
def __init__(self,
|
107 |
+
loader,
|
108 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
109 |
+
std=IMAGENET_DEFAULT_STD,
|
110 |
+
re_prob=0.,
|
111 |
+
re_mode='pixel',
|
112 |
+
re_count=1,
|
113 |
+
):
|
114 |
+
self.loader = loader
|
115 |
+
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
|
116 |
+
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
|
117 |
+
if re_prob > 0.:
|
118 |
+
self.random_erasing = RandomErasing(probability=re_prob, mode=re_mode, max_count=re_count)
|
119 |
+
else:
|
120 |
+
self.random_erasing = None
|
121 |
+
|
122 |
+
def __iter__(self):
|
123 |
+
stream = torch.cuda.Stream()
|
124 |
+
first = True
|
125 |
+
|
126 |
+
for next_input, next_target in self.loader:
|
127 |
+
with torch.cuda.stream(stream):
|
128 |
+
next_input = next_input.cuda(non_blocking=True)
|
129 |
+
next_input = next_input.float().sub_(self.mean).div_(self.std)
|
130 |
+
next_target = {k: v.cuda(non_blocking=True) for k, v in next_target.items()}
|
131 |
+
if self.random_erasing is not None:
|
132 |
+
next_input = self.random_erasing(next_input, next_target)
|
133 |
+
|
134 |
+
if not first:
|
135 |
+
yield input, target
|
136 |
+
else:
|
137 |
+
first = False
|
138 |
+
|
139 |
+
torch.cuda.current_stream().wait_stream(stream)
|
140 |
+
input = next_input
|
141 |
+
target = next_target
|
142 |
+
|
143 |
+
yield input, target
|
144 |
+
|
145 |
+
def __len__(self):
|
146 |
+
return len(self.loader)
|
147 |
+
|
148 |
+
@property
|
149 |
+
def sampler(self):
|
150 |
+
return self.loader.sampler
|
151 |
+
|
152 |
+
@property
|
153 |
+
def dataset(self):
|
154 |
+
return self.loader.dataset
|
155 |
+
|
156 |
+
|
157 |
+
def create_loader(
|
158 |
+
dataset,
|
159 |
+
input_size,
|
160 |
+
batch_size,
|
161 |
+
is_training=False,
|
162 |
+
use_prefetcher=True,
|
163 |
+
re_prob=0.,
|
164 |
+
re_mode='pixel',
|
165 |
+
re_count=1,
|
166 |
+
interpolation='bilinear',
|
167 |
+
fill_color='mean',
|
168 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
169 |
+
std=IMAGENET_DEFAULT_STD,
|
170 |
+
num_workers=1,
|
171 |
+
distributed=False,
|
172 |
+
pin_mem=False,
|
173 |
+
anchor_labeler=None,
|
174 |
+
):
|
175 |
+
if isinstance(input_size, tuple):
|
176 |
+
img_size = input_size[-2:]
|
177 |
+
else:
|
178 |
+
img_size = input_size
|
179 |
+
|
180 |
+
if is_training:
|
181 |
+
transforms = get_transform()
|
182 |
+
transform = transforms_coco_train(
|
183 |
+
img_size,
|
184 |
+
interpolation=interpolation,
|
185 |
+
use_prefetcher=use_prefetcher,
|
186 |
+
fill_color=fill_color,
|
187 |
+
mean=mean,
|
188 |
+
std=std)
|
189 |
+
else:
|
190 |
+
transforms = None
|
191 |
+
transform = transforms_coco_eval(
|
192 |
+
img_size,
|
193 |
+
interpolation=interpolation,
|
194 |
+
use_prefetcher=use_prefetcher,
|
195 |
+
fill_color=fill_color,
|
196 |
+
mean=mean,
|
197 |
+
std=std)
|
198 |
+
dataset.transforms = transforms
|
199 |
+
dataset.transform = transform
|
200 |
+
|
201 |
+
sampler = None
|
202 |
+
if distributed:
|
203 |
+
if is_training:
|
204 |
+
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
205 |
+
else:
|
206 |
+
# This will add extra duplicate entries to result in equal num
|
207 |
+
# of samples per-process, will slightly alter validation results
|
208 |
+
sampler = OrderedDistributedSampler(dataset)
|
209 |
+
|
210 |
+
collate_fn = DetectionFastCollate(anchor_labeler=anchor_labeler)
|
211 |
+
loader = torch.utils.data.DataLoader(
|
212 |
+
dataset,
|
213 |
+
batch_size=batch_size,
|
214 |
+
shuffle=sampler is None and is_training,
|
215 |
+
num_workers=num_workers,
|
216 |
+
sampler=sampler,
|
217 |
+
pin_memory=pin_mem,
|
218 |
+
collate_fn=collate_fn,
|
219 |
+
)
|
220 |
+
if use_prefetcher:
|
221 |
+
if is_training:
|
222 |
+
loader = PrefetchLoader(loader, mean=mean, std=std, re_prob=re_prob, re_mode=re_mode, re_count=re_count)
|
223 |
+
else:
|
224 |
+
loader = PrefetchLoader(loader, mean=mean, std=std)
|
225 |
+
|
226 |
+
return loader
|
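A hedged sketch wiring a dataset into the loader; batch size, workers and input size are illustrative, and the prefetcher path assumes a CUDA device because PrefetchLoader normalizes on the GPU:

    loader = create_loader(
        train_ds,                  # e.g. a dataset returned by create_dataset
        input_size=(3, 512, 512),
        batch_size=8,
        is_training=True,
        use_prefetcher=True,       # set False on CPU-only machines
        num_workers=4,
    )
    for images, targets in loader:
        break  # images: normalized float tensors; targets: dict of padded annotation tensors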
efficientdet/effdet/data/parsers/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .parser_config import OpenImagesParserCfg, CocoParserCfg, VocParserCfg
+from .parser_factory import create_parser
efficientdet/effdet/data/parsers/parser.py
ADDED
@@ -0,0 +1,82 @@
1 |
+
from numbers import Integral
|
2 |
+
from typing import List, Union, Dict, Any
|
3 |
+
|
4 |
+
|
5 |
+
class Parser:
|
6 |
+
""" Parser base class.
|
7 |
+
|
8 |
+
The attributes listed below make up a public interface common to all parsers. They can be accessed directly
|
9 |
+
once the dataset is constructed and annotations are populated.
|
10 |
+
|
11 |
+
Attributes:
|
12 |
+
|
13 |
+
cat_names (list[str]):
|
14 |
+
list of category (class) names, with background class at position 0.
|
15 |
+
cat_ids (list[union[str, int]):
|
16 |
+
list of dataset specific, unique integer or string category ids, does not include background
|
17 |
+
cat_id_to_label (dict):
|
18 |
+
map from category id to integer 1-indexed class label
|
19 |
+
|
20 |
+
img_ids (list):
|
21 |
+
list of dataset specific, unique image ids corresponding to valid samples in dataset
|
22 |
+
img_ids_invalid (list):
|
23 |
+
list of image ids corresponding to invalid images, not used as samples
|
24 |
+
img_infos (list[dict]):
|
25 |
+
image info, list of info dicts with filename, width, height for each image sample
|
26 |
+
"""
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
bbox_yxyx: bool = False,
|
30 |
+
has_labels: bool = True,
|
31 |
+
include_masks: bool = False,
|
32 |
+
include_bboxes_ignore: bool = False,
|
33 |
+
ignore_empty_gt: bool = False,
|
34 |
+
min_img_size: int = 32,
|
35 |
+
):
|
36 |
+
"""
|
37 |
+
Args:
|
38 |
+
yxyx (bool): output coords in yxyx format, otherwise xyxy
|
39 |
+
has_labels (bool): dataset has labels (for training validation, False usually for test sets)
|
40 |
+
include_masks (bool): include segmentation masks in target output (not supported yet for any dataset)
|
41 |
+
include_bboxes_ignore (bool): include ignored bbox in target output
|
42 |
+
ignore_empty_gt (bool): ignore images with no ground truth (no negative images)
|
43 |
+
min_img_size (bool): ignore images with width or height smaller than this number
|
44 |
+
sub_sample (int): sample every N images from the dataset
|
45 |
+
"""
|
46 |
+
# parser config, determines how dataset parsed and validated
|
47 |
+
self.yxyx = bbox_yxyx
|
48 |
+
self.has_labels = has_labels
|
49 |
+
self.include_masks = include_masks
|
50 |
+
self.include_bboxes_ignore = include_bboxes_ignore
|
51 |
+
self.ignore_empty_gt = ignore_empty_gt
|
52 |
+
self.min_img_size = min_img_size
|
53 |
+
self.label_offset = 1
|
54 |
+
|
55 |
+
# Category (class) metadata. Populated by _load_annotations()
|
56 |
+
self.cat_names: List[str] = []
|
57 |
+
self.cat_ids: List[Union[str, Integral]] = []
|
58 |
+
self.cat_id_to_label: Dict[Union[str, Integral], Integral] = dict()
|
59 |
+
|
60 |
+
# Image metadata. Populated by _load_annotations()
|
61 |
+
self.img_ids: List[Union[str, Integral]] = []
|
62 |
+
self.img_ids_invalid: List[Union[str, Integral]] = []
|
63 |
+
self.img_infos: List[Dict[str, Any]] = []
|
64 |
+
|
65 |
+
@property
|
66 |
+
def cat_dicts(self):
|
67 |
+
"""return category names and labels in format compatible with TF Models Evaluator
|
68 |
+
list[dict(name=<class name>, id=<class label>)]
|
69 |
+
"""
|
70 |
+
return [
|
71 |
+
dict(
|
72 |
+
name=name,
|
73 |
+
id=cat_id if not self.cat_id_to_label else self.cat_id_to_label[cat_id]
|
74 |
+
) for name, cat_id in zip(self.cat_names, self.cat_ids)]
|
75 |
+
|
76 |
+
@property
|
77 |
+
def max_label(self):
|
78 |
+
if self.cat_id_to_label:
|
79 |
+
return max(self.cat_id_to_label.values())
|
80 |
+
else:
|
81 |
+
assert len(self.cat_ids) and isinstance(self.cat_ids[0], Integral)
|
82 |
+
return max(self.cat_ids)
|
efficientdet/effdet/data/parsers/parser_coco.py
ADDED
@@ -0,0 +1,93 @@
1 |
+
""" COCO dataset parser
|
2 |
+
|
3 |
+
Copyright 2020 Ross Wightman
|
4 |
+
"""
|
5 |
+
import numpy as np
|
6 |
+
from pycocotools.coco import COCO
|
7 |
+
from .parser import Parser
|
8 |
+
from .parser_config import CocoParserCfg
|
9 |
+
|
10 |
+
|
11 |
+
class CocoParser(Parser):
|
12 |
+
|
13 |
+
def __init__(self, cfg: CocoParserCfg):
|
14 |
+
super().__init__(
|
15 |
+
bbox_yxyx=cfg.bbox_yxyx,
|
16 |
+
has_labels=cfg.has_labels,
|
17 |
+
include_masks=cfg.include_masks,
|
18 |
+
include_bboxes_ignore=cfg.include_bboxes_ignore,
|
19 |
+
ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt,
|
20 |
+
min_img_size=cfg.min_img_size
|
21 |
+
)
|
22 |
+
self.cat_ids_as_labels = True # this is the default for original TF EfficientDet models
|
23 |
+
self.coco = None
|
24 |
+
self._load_annotations(cfg.ann_filename)
|
25 |
+
|
26 |
+
def get_ann_info(self, idx):
|
27 |
+
img_id = self.img_ids[idx]
|
28 |
+
return self._parse_img_ann(img_id)
|
29 |
+
|
30 |
+
def _load_annotations(self, ann_file):
|
31 |
+
assert self.coco is None
|
32 |
+
self.coco = COCO(ann_file)
|
33 |
+
self.cat_ids = self.coco.getCatIds()
|
34 |
+
self.cat_names = [c['name'] for c in self.coco.loadCats(ids=self.cat_ids)]
|
35 |
+
if not self.cat_ids_as_labels:
|
36 |
+
self.cat_id_to_label = {cat_id: i + self.label_offset for i, cat_id in enumerate(self.cat_ids)}
|
37 |
+
img_ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
|
38 |
+
for img_id in sorted(self.coco.imgs.keys()):
|
39 |
+
info = self.coco.loadImgs([img_id])[0]
|
40 |
+
if (min(info['width'], info['height']) < self.min_img_size or
|
41 |
+
(self.ignore_empty_gt and img_id not in img_ids_with_ann)):
|
42 |
+
self.img_ids_invalid.append(img_id)
|
43 |
+
continue
|
44 |
+
self.img_ids.append(img_id)
|
45 |
+
self.img_infos.append(info)
|
46 |
+
|
47 |
+
def _parse_img_ann(self, img_id):
|
48 |
+
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
|
49 |
+
ann_info = self.coco.loadAnns(ann_ids)
|
50 |
+
bboxes = []
|
51 |
+
bboxes_ignore = []
|
52 |
+
cls = []
|
53 |
+
|
54 |
+
for i, ann in enumerate(ann_info):
|
55 |
+
if ann.get('ignore', False):
|
56 |
+
continue
|
57 |
+
x1, y1, w, h = ann['bbox']
|
58 |
+
if self.include_masks and ann['area'] <= 0:
|
59 |
+
continue
|
60 |
+
if w < 1 or h < 1:
|
61 |
+
continue
|
62 |
+
|
63 |
+
if self.yxyx:
|
64 |
+
bbox = [y1, x1, y1 + h, x1 + w]
|
65 |
+
else:
|
66 |
+
bbox = [x1, y1, x1 + w, y1 + h]
|
67 |
+
|
68 |
+
if ann.get('iscrowd', False):
|
69 |
+
if self.include_bboxes_ignore:
|
70 |
+
bboxes_ignore.append(bbox)
|
71 |
+
else:
|
72 |
+
bboxes.append(bbox)
|
73 |
+
cls.append(self.cat_id_to_label[ann['category_id']] if self.cat_id_to_label else ann['category_id'])
|
74 |
+
|
75 |
+
if bboxes:
|
76 |
+
bboxes = np.array(bboxes, ndmin=2, dtype=np.float32)
|
77 |
+
cls = np.array(cls, dtype=np.int64)
|
78 |
+
else:
|
79 |
+
bboxes = np.zeros((0, 4), dtype=np.float32)
|
80 |
+
cls = np.array([], dtype=np.int64)
|
81 |
+
|
82 |
+
if self.include_bboxes_ignore:
|
83 |
+
if bboxes_ignore:
|
84 |
+
bboxes_ignore = np.array(bboxes_ignore, ndmin=2, dtype=np.float32)
|
85 |
+
else:
|
86 |
+
bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
|
87 |
+
|
88 |
+
ann = dict(bbox=bboxes, cls=cls)
|
89 |
+
|
90 |
+
if self.include_bboxes_ignore:
|
91 |
+
ann['bbox_ignore'] = bboxes_ignore
|
92 |
+
|
93 |
+
return ann
|
efficientdet/effdet/data/parsers/parser_config.py
ADDED
@@ -0,0 +1,49 @@
+""" Dataset parser configs
+
+Copyright 2020 Ross Wightman
+"""
+from dataclasses import dataclass
+
+__all__ = ['CocoParserCfg', 'OpenImagesParserCfg', 'VocParserCfg']
+
+
+@dataclass
+class CocoParserCfg:
+    ann_filename: str  # absolute path
+    include_masks: bool = False
+    include_bboxes_ignore: bool = False
+    has_labels: bool = True
+    bbox_yxyx: bool = True
+    min_img_size: int = 32
+    ignore_empty_gt: bool = False
+
+
+@dataclass
+class VocParserCfg:
+    split_filename: str
+    ann_filename: str
+    img_filename: str = '%s.jpg'
+    keep_difficult: bool = True
+    classes: list = None
+    add_background: bool = True
+    has_labels: bool = True
+    bbox_yxyx: bool = True
+    min_img_size: int = 32
+    ignore_empty_gt: bool = False
+
+
+@dataclass
+class OpenImagesParserCfg:
+    categories_filename: str
+    img_info_filename: str
+    bbox_filename: str
+    img_label_filename: str = ''
+    masks_filename: str = ''
+    img_filename: str = '%s.jpg'  # relative to dataset img_dir
+    task: str = 'obj'
+    prefix_levels: int = 1
+    add_background: bool = True
+    has_labels: bool = True
+    bbox_yxyx: bool = True
+    min_img_size: int = 32
+    ignore_empty_gt: bool = False
efficientdet/effdet/data/parsers/parser_factory.py
ADDED
@@ -0,0 +1,19 @@
+""" Parser factory
+
+Copyright 2020 Ross Wightman
+"""
+from .parser_coco import CocoParser
+from .parser_voc import VocParser
+from .parser_open_images import OpenImagesParser
+
+
+def create_parser(name, **kwargs):
+    if name == 'coco':
+        parser = CocoParser(**kwargs)
+    elif name == 'voc':
+        parser = VocParser(**kwargs)
+    elif name == 'openimages':
+        parser = OpenImagesParser(**kwargs)
+    else:
+        assert False, f'Unknown dataset parser ({name})'
+    return parser
efficientdet/effdet/data/parsers/parser_open_images.py
ADDED
@@ -0,0 +1,211 @@
1 |
+
""" OpenImages dataset parser
|
2 |
+
|
3 |
+
Copyright 2020 Ross Wightman
|
4 |
+
"""
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
import logging
|
8 |
+
|
9 |
+
from .parser import Parser
|
10 |
+
from .parser_config import OpenImagesParserCfg
|
11 |
+
|
12 |
+
_logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
class OpenImagesParser(Parser):
|
16 |
+
|
17 |
+
def __init__(self, cfg: OpenImagesParserCfg):
|
18 |
+
super().__init__(
|
19 |
+
bbox_yxyx=cfg.bbox_yxyx,
|
20 |
+
has_labels=cfg.has_labels,
|
21 |
+
include_masks=False, # FIXME to support someday
|
22 |
+
include_bboxes_ignore=False,
|
23 |
+
ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt,
|
24 |
+
min_img_size=cfg.min_img_size
|
25 |
+
)
|
26 |
+
self.img_prefix_levels = cfg.prefix_levels
|
27 |
+
self.mask_prefix_levels = 1
|
28 |
+
self._anns = None # access via get_ann_info()
|
29 |
+
self._img_to_ann = None
|
30 |
+
self._load_annotations(
|
31 |
+
categories_filename=cfg.categories_filename,
|
32 |
+
img_info_filename=cfg.img_info_filename,
|
33 |
+
img_filename=cfg.img_filename,
|
34 |
+
masks_filename=cfg.masks_filename,
|
35 |
+
bbox_filename=cfg.bbox_filename
|
36 |
+
)
|
37 |
+
|
38 |
+
def _load_annotations(
|
39 |
+
self,
|
40 |
+
categories_filename: str,
|
41 |
+
img_info_filename: str,
|
42 |
+
img_filename: str,
|
43 |
+
masks_filename: str,
|
44 |
+
bbox_filename: str,
|
45 |
+
):
|
46 |
+
import pandas as pd # For now, blow up on pandas req only when trying to load open images anno
|
47 |
+
|
48 |
+
_logger.info('Loading categories...')
|
49 |
+
classes_df = pd.read_csv(categories_filename, header=None)
|
50 |
+
self.cat_ids = classes_df[0].tolist()
|
51 |
+
self.cat_names = classes_df[1].tolist()
|
52 |
+
self.cat_id_to_label = {c: i + self.label_offset for i, c in enumerate(self.cat_ids)}
|
53 |
+
|
54 |
+
def _img_filename(img_id):
|
55 |
+
# build image filenames that are relative to img_dir
|
56 |
+
filename = img_filename % img_id
|
57 |
+
if self.img_prefix_levels:
|
58 |
+
levels = [c for c in img_id[:self.img_prefix_levels]]
|
59 |
+
filename = os.path.join(*levels, filename)
|
60 |
+
return filename
|
61 |
+
|
62 |
+
def _mask_filename(mask_path):
|
63 |
+
# FIXME finish
|
64 |
+
if self.mask_prefix_levels:
|
65 |
+
levels = [c for c in mask_path[:self.mask_prefix_levels]]
|
66 |
+
mask_path = os.path.join(*levels, mask_path)
|
67 |
+
return mask_path
|
68 |
+
|
69 |
+
def _load_img_info(csv_file, select_img_ids=None):
|
70 |
+
_logger.info('Read img_info csv...')
|
71 |
+
img_info_df = pd.read_csv(csv_file, index_col='id')
|
72 |
+
|
73 |
+
_logger.info('Filter images...')
|
74 |
+
if select_img_ids is not None:
|
75 |
+
img_info_df = img_info_df.loc[select_img_ids]
|
76 |
+
img_info_df = img_info_df[
|
77 |
+
(img_info_df['width'] >= self.min_img_size) & (img_info_df['height'] >= self.min_img_size)]
|
78 |
+
|
79 |
+
_logger.info('Mapping ids...')
|
80 |
+
img_info_df['img_id'] = img_info_df.index
|
81 |
+
img_info_df['file_name'] = img_info_df.index.map(lambda x: _img_filename(x))
|
82 |
+
img_info_df = img_info_df[['img_id', 'file_name', 'width', 'height']]
|
83 |
+
img_sizes = img_info_df[['width', 'height']].values
|
84 |
+
self.img_infos = img_info_df.to_dict('records')
|
85 |
+
self.img_ids = img_info_df.index.values.tolist()
|
86 |
+
img_id_to_idx = {img_id: idx for idx, img_id in enumerate(self.img_ids)}
|
87 |
+
return img_sizes, img_id_to_idx
|
88 |
+
|
89 |
+
if self.include_masks and self.has_labels:
|
90 |
+
masks_df = pd.read_csv(masks_filename)
|
91 |
+
|
92 |
+
# NOTE currently using dataset masks anno ImageIDs to form valid img_ids from the dataset
|
93 |
+
anno_img_ids = sorted(masks_df['ImageID'].unique())
|
94 |
+
img_sizes, img_id_to_idx = _load_img_info(img_info_filename, select_img_ids=anno_img_ids)
|
95 |
+
|
96 |
+
masks_df['ImageIdx'] = masks_df['ImageID'].map(img_id_to_idx)
|
97 |
+
if np.issubdtype(masks_df.ImageIdx.dtype, np.floating):
|
98 |
+
masks_df = masks_df.dropna(axis='rows')
|
99 |
+
masks_df['ImageIdx'] = masks_df.ImageIdx.astype(np.int32)
|
100 |
+
masks_df.sort_values('ImageIdx', inplace=True)
|
101 |
+
ann_img_idx = masks_df['ImageIdx'].values
|
102 |
+
img_sizes = img_sizes[ann_img_idx]
|
103 |
+
masks_df['BoxXMin'] = masks_df['BoxXMin'] * img_sizes[:, 0]
|
104 |
+
masks_df['BoxXMax'] = masks_df['BoxXMax'] * img_sizes[:, 0]
|
105 |
+
masks_df['BoxYMin'] = masks_df['BoxYMin'] * img_sizes[:, 1]
|
106 |
+
masks_df['BoxYMax'] = masks_df['BoxYMax'] * img_sizes[:, 1]
|
107 |
+
masks_df['LabelIdx'] = masks_df['LabelName'].map(self.cat_id_to_label)
|
108 |
+
# FIXME remap mask filename with _mask_filename
|
109 |
+
|
110 |
+
self._anns = dict(
|
111 |
+
bbox=masks_df[['BoxXMin', 'BoxYMin', 'BoxXMax', 'BoxYMax']].values.astype(np.float32),
|
112 |
+
label=masks_df[['LabelIdx']].values.astype(np.int32),
|
113 |
+
mask_path=masks_df[['MaskPath']].values
|
114 |
+
)
|
115 |
+
_, ri, rc = np.unique(ann_img_idx, return_index=True, return_counts=True)
|
116 |
+
self._img_to_ann = list(zip(ri, rc)) # index, count tuples
|
117 |
+
elif self.has_labels:
|
118 |
+
_logger.info('Loading bbox...')
|
119 |
+
bbox_df = pd.read_csv(bbox_filename)
|
120 |
+
|
121 |
+
# NOTE currently using dataset box anno ImageIDs to form valid img_ids from the larger dataset.
|
122 |
+
# FIXME use *imagelabels.csv or imagelabels-boxable.csv for negative examples (without box?)
|
123 |
+
anno_img_ids = sorted(bbox_df['ImageID'].unique())
|
124 |
+
img_sizes, img_id_to_idx = _load_img_info(img_info_filename, select_img_ids=anno_img_ids)
|
125 |
+
|
126 |
+
_logger.info('Process bbox...')
|
127 |
+
bbox_df['ImageIdx'] = bbox_df['ImageID'].map(img_id_to_idx)
|
128 |
+
if np.issubdtype(bbox_df.ImageIdx.dtype, np.floating):
|
129 |
+
bbox_df = bbox_df.dropna(axis='rows')
|
130 |
+
bbox_df['ImageIdx'] = bbox_df.ImageIdx.astype(np.int32)
|
131 |
+
bbox_df.sort_values('ImageIdx', inplace=True)
|
132 |
+
ann_img_idx = bbox_df['ImageIdx'].values
|
133 |
+
img_sizes = img_sizes[ann_img_idx]
|
134 |
+
bbox_df['XMin'] = bbox_df['XMin'] * img_sizes[:, 0]
|
135 |
+
bbox_df['XMax'] = bbox_df['XMax'] * img_sizes[:, 0]
|
136 |
+
bbox_df['YMin'] = bbox_df['YMin'] * img_sizes[:, 1]
|
137 |
+
bbox_df['YMax'] = bbox_df['YMax'] * img_sizes[:, 1]
|
138 |
+
bbox_df['LabelIdx'] = bbox_df['LabelName'].map(self.cat_id_to_label).astype(np.int32)
|
139 |
+
|
140 |
+
self._anns = dict(
|
141 |
+
bbox=bbox_df[['XMin', 'YMin', 'XMax', 'YMax']].values.astype(np.float32),
|
142 |
+
label=bbox_df[['LabelIdx', 'IsGroupOf']].values.astype(np.int32),
|
143 |
+
)
|
144 |
+
_, ri, rc = np.unique(ann_img_idx, return_index=True, return_counts=True)
|
145 |
+
self._img_to_ann = list(zip(ri, rc)) # index, count tuples
|
146 |
+
else:
|
147 |
+
_load_img_info(img_info_filename)
|
148 |
+
|
149 |
+
_logger.info('Annotations loaded!')
|
150 |
+
|
151 |
+
def get_ann_info(self, idx):
|
152 |
+
if not self.has_labels:
|
153 |
+
return dict()
|
154 |
+
start_idx, num_ann = self._img_to_ann[idx]
|
155 |
+
ann_keys = tuple(self._anns.keys())
|
156 |
+
ann_values = tuple(self._anns[k][start_idx:start_idx + num_ann] for k in ann_keys)
|
157 |
+
return self._parse_ann_info(idx, ann_keys, ann_values)
|
158 |
+
|
159 |
+
def _parse_ann_info(self, img_idx, ann_keys, ann_values):
|
160 |
+
"""
|
161 |
+
"""
|
162 |
+
gt_bboxes = []
|
163 |
+
gt_labels = []
|
164 |
+
gt_bboxes_ignore = []
|
165 |
+
if self.include_masks:
|
166 |
+
assert 'mask_path' in ann_keys
|
167 |
+
gt_masks = []
|
168 |
+
|
169 |
+
for ann in zip(*ann_values):
|
170 |
+
ann = dict(zip(ann_keys, ann))
|
171 |
+
x1, y1, x2, y2 = ann['bbox']
|
172 |
+
if x2 - x1 < 1 or y2 - y1 < 1:
|
173 |
+
continue
|
174 |
+
label = ann['label'][0]
|
175 |
+
iscrowd = False
|
176 |
+
if len(ann['label']) > 1:
|
177 |
+
iscrowd = ann['label'][1]
|
178 |
+
if self.yxyx:
|
179 |
+
bbox = np.array([y1, x1, y2, x2], dtype=np.float32)
|
180 |
+
else:
|
181 |
+
bbox = ann['bbox']
|
182 |
+
if iscrowd:
|
183 |
+
gt_bboxes_ignore.append(bbox)
|
184 |
+
else:
|
185 |
+
gt_bboxes.append(bbox)
|
186 |
+
gt_labels.append(label)
|
187 |
+
# if self.include_masks:
|
188 |
+
# img_info = self.img_infos[img_idx]
|
189 |
+
# mask_img = SegmentationMask(ann['mask_filename'], img_info['width'], img_info['height'])
|
190 |
+
# gt_masks.append(mask_img)
|
191 |
+
|
192 |
+
if gt_bboxes:
|
193 |
+
gt_bboxes = np.array(gt_bboxes, ndmin=2, dtype=np.float32)
|
194 |
+
gt_labels = np.array(gt_labels, dtype=np.int64)
|
195 |
+
else:
|
196 |
+
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
|
197 |
+
gt_labels = np.array([], dtype=np.int64)
|
198 |
+
|
199 |
+
if self.include_bboxes_ignore:
|
200 |
+
if gt_bboxes_ignore:
|
201 |
+
gt_bboxes_ignore = np.array(gt_bboxes_ignore, ndmin=2, dtype=np.float32)
|
202 |
+
else:
|
203 |
+
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
|
204 |
+
|
205 |
+
ann = dict(bbox=gt_bboxes, cls=gt_labels)
|
206 |
+
|
207 |
+
if self.include_bboxes_ignore:
|
208 |
+
ann.update(dict(bbox_ignore=gt_bboxes_ignore, cls_ignore=np.array([], dtype=np.int64)))
|
209 |
+
if self.include_masks:
|
210 |
+
ann['masks'] = gt_masks
|
211 |
+
return ann
|
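A minimal sketch of the (start index, count) lookup that get_ann_info above relies on; the arrays, labels and split points below are made-up placeholders rather than values from any real Open Images annotation file.

import numpy as np

# flat, image-sorted annotation arrays, in the shape _load_annotations builds them
bbox = np.array([[0, 0, 10, 10], [5, 5, 20, 20], [1, 1, 4, 4]], dtype=np.float32)
label = np.array([[1], [2], [1]], dtype=np.int32)

# per-image (start index, count) tuples: image 0 owns rows 0-1, image 1 owns row 2
img_to_ann = [(0, 2), (2, 1)]

def anns_for_image(idx):
    start, num = img_to_ann[idx]
    return bbox[start:start + num], label[start:start + num]

boxes, labels = anns_for_image(0)  # the two boxes belonging to the first image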
efficientdet/effdet/data/parsers/parser_voc.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
""" Pascal VOC dataset parser
|
2 |
+
|
3 |
+
Copyright 2020 Ross Wightman
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
import xml.etree.ElementTree as ET
|
7 |
+
from collections import defaultdict
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from .parser import Parser
|
11 |
+
from .parser_config import VocParserCfg
|
12 |
+
|
13 |
+
|
14 |
+
class VocParser(Parser):
|
15 |
+
|
16 |
+
DEFAULT_CLASSES = (
|
17 |
+
'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
|
18 |
+
'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant',
|
19 |
+
'sheep', 'sofa', 'train', 'tvmonitor')
|
20 |
+
|
21 |
+
def __init__(self, cfg: VocParserCfg):
|
22 |
+
super().__init__(
|
23 |
+
bbox_yxyx=cfg.bbox_yxyx,
|
24 |
+
has_labels=cfg.has_labels,
|
25 |
+
include_masks=False, # FIXME to support someday
|
26 |
+
include_bboxes_ignore=False,
|
27 |
+
ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt,
|
28 |
+
min_img_size=cfg.min_img_size
|
29 |
+
)
|
30 |
+
self.correct_bbox = 1
|
31 |
+
self.keep_difficult = cfg.keep_difficult
|
32 |
+
|
33 |
+
self.anns = None
|
34 |
+
self.img_id_to_idx = {}
|
35 |
+
self._load_annotations(
|
36 |
+
split_filename=cfg.split_filename,
|
37 |
+
img_filename=cfg.img_filename,
|
38 |
+
ann_filename=cfg.ann_filename,
|
39 |
+
classes=cfg.classes,
|
40 |
+
)
|
41 |
+
|
42 |
+
def _load_annotations(
|
43 |
+
self,
|
44 |
+
split_filename: str,
|
45 |
+
img_filename: str,
|
46 |
+
ann_filename: str,
|
47 |
+
classes=None,
|
48 |
+
):
|
49 |
+
classes = classes or self.DEFAULT_CLASSES
|
50 |
+
self.cat_names = list(classes)
|
51 |
+
self.cat_ids = self.cat_names
|
52 |
+
self.cat_id_to_label = {cat: i + self.label_offset for i, cat in enumerate(self.cat_ids)}
|
53 |
+
|
54 |
+
self.anns = []
|
55 |
+
|
56 |
+
with open(split_filename) as f:
|
57 |
+
ids = f.readlines()
|
58 |
+
for img_id in ids:
|
59 |
+
img_id = img_id.strip("\n")
|
60 |
+
filename = img_filename % img_id
|
61 |
+
xml_path = ann_filename % img_id
|
62 |
+
tree = ET.parse(xml_path)
|
63 |
+
root = tree.getroot()
|
64 |
+
size = root.find('size')
|
65 |
+
width = int(size.find('width').text)
|
66 |
+
height = int(size.find('height').text)
|
67 |
+
if min(width, height) < self.min_img_size:
|
68 |
+
continue
|
69 |
+
|
70 |
+
anns = []
|
71 |
+
for obj_idx, obj in enumerate(root.findall('object')):
|
72 |
+
name = obj.find('name').text
|
73 |
+
label = self.cat_id_to_label[name]
|
74 |
+
difficult = int(obj.find('difficult').text)
|
75 |
+
bnd_box = obj.find('bndbox')
|
76 |
+
bbox = [
|
77 |
+
int(bnd_box.find('xmin').text),
|
78 |
+
int(bnd_box.find('ymin').text),
|
79 |
+
int(bnd_box.find('xmax').text),
|
80 |
+
int(bnd_box.find('ymax').text)
|
81 |
+
]
|
82 |
+
anns.append(dict(label=label, bbox=bbox, difficult=difficult))
|
83 |
+
|
84 |
+
if not self.ignore_empty_gt or len(anns):
|
85 |
+
self.anns.append(anns)
|
86 |
+
self.img_infos.append(dict(id=img_id, file_name=filename, width=width, height=height))
|
87 |
+
self.img_ids.append(img_id)
|
88 |
+
else:
|
89 |
+
self.img_ids_invalid.append(img_id)
|
90 |
+
|
91 |
+
def merge(self, other):
|
92 |
+
assert len(self.cat_ids) == len(other.cat_ids)
|
93 |
+
self.img_ids.extend(other.img_ids)
|
94 |
+
self.img_infos.extend(other.img_infos)
|
95 |
+
self.anns.extend(other.anns)
|
96 |
+
|
97 |
+
def get_ann_info(self, idx):
|
98 |
+
return self._parse_ann_info(self.anns[idx])
|
99 |
+
|
100 |
+
def _parse_ann_info(self, ann_info):
|
101 |
+
bboxes = []
|
102 |
+
labels = []
|
103 |
+
bboxes_ignore = []
|
104 |
+
labels_ignore = []
|
105 |
+
for ann in ann_info:
|
106 |
+
ignore = False
|
107 |
+
x1, y1, x2, y2 = ann['bbox']
|
108 |
+
label = ann['label']
|
109 |
+
w = x2 - x1
|
110 |
+
h = y2 - y1
|
111 |
+
if w < 1 or h < 1:
|
112 |
+
ignore = True
|
113 |
+
if self.yxyx:
|
114 |
+
bbox = [y1, x1, y2, x2]
|
115 |
+
else:
|
116 |
+
bbox = ann['bbox']
|
117 |
+
if ignore or (ann['difficult'] and not self.keep_difficult):
|
118 |
+
bboxes_ignore.append(bbox)
|
119 |
+
labels_ignore.append(label)
|
120 |
+
else:
|
121 |
+
bboxes.append(bbox)
|
122 |
+
labels.append(label)
|
123 |
+
|
124 |
+
if not bboxes:
|
125 |
+
bboxes = np.zeros((0, 4), dtype=np.float32)
|
126 |
+
labels = np.zeros((0, ), dtype=np.float32)
|
127 |
+
else:
|
128 |
+
bboxes = np.array(bboxes, ndmin=2, dtype=np.float32) - self.correct_bbox
|
129 |
+
labels = np.array(labels, dtype=np.float32)
|
130 |
+
|
131 |
+
if self.include_bboxes_ignore:
|
132 |
+
if not bboxes_ignore:
|
133 |
+
bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
|
134 |
+
labels_ignore = np.zeros((0, ), dtype=np.float32)
|
135 |
+
else:
|
136 |
+
bboxes_ignore = np.array(bboxes_ignore, ndmin=2, dtype=np.float32) - self.correct_bbox
|
137 |
+
labels_ignore = np.array(labels_ignore, dtype=np.float32)
|
138 |
+
|
139 |
+
ann = dict(
|
140 |
+
bbox=bboxes.astype(np.float32),
|
141 |
+
cls=labels.astype(np.int64))
|
142 |
+
|
143 |
+
if self.include_bboxes_ignore:
|
144 |
+
ann.update(dict(
|
145 |
+
bbox_ignore=bboxes_ignore.astype(np.float32),
|
146 |
+
cls_ignore=labels_ignore.astype(np.int64)))
|
147 |
+
return ann
|
148 |
+
|
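The parser above builds file paths with old-style % templates (img_filename % img_id, ann_filename % img_id). A minimal sketch of that convention, with a made-up VOCdevkit layout and a hypothetical image id:

img_filename = 'VOCdevkit/VOC2012/JPEGImages/%s.jpg'   # placeholder layout
ann_filename = 'VOCdevkit/VOC2012/Annotations/%s.xml'  # placeholder layout

img_id = '2008_000008'  # hypothetical id as read from the split file
print(img_filename % img_id)   # VOCdevkit/VOC2012/JPEGImages/2008_000008.jpg
print(ann_filename % img_id)   # VOCdevkit/VOC2012/Annotations/2008_000008.xml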
efficientdet/effdet/data/random_erasing.py
ADDED
@@ -0,0 +1,94 @@
1 |
+
""" Multi-Scale RandomErasing
|
2 |
+
|
3 |
+
Copyright 2020 Ross Wightman
|
4 |
+
"""
|
5 |
+
import random
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
|
11 |
+
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
|
12 |
+
# paths, flip the order so normal is run on CPU if this becomes a problem
|
13 |
+
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
|
14 |
+
if per_pixel:
|
15 |
+
return torch.empty(patch_size, dtype=dtype, device=device).normal_()
|
16 |
+
elif rand_color:
|
17 |
+
return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
|
18 |
+
else:
|
19 |
+
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
|
20 |
+
|
21 |
+
|
22 |
+
class RandomErasing:
|
23 |
+
""" Randomly selects a rectangle region in an image and erases its pixels.
|
24 |
+
'Random Erasing Data Augmentation' by Zhong et al.
|
25 |
+
See https://arxiv.org/pdf/1708.04896.pdf
|
26 |
+
|
27 |
+
This variant of RandomErasing is tweaked for multi-scale obj detection training.
|
28 |
+
Args:
|
29 |
+
probability: Probability that the Random Erasing operation will be performed.
|
30 |
+
min_area: Minimum percentage of erased area wrt input image area.
|
31 |
+
max_area: Maximum percentage of erased area wrt input image area.
|
32 |
+
min_aspect: Minimum aspect ratio of erased area.
|
33 |
+
mode: pixel color mode, one of 'const', 'rand', or 'pixel'
|
34 |
+
'const' - erase block is constant color of 0 for all channels
|
35 |
+
'rand' - erase block is same per-channel random (normal) color
|
36 |
+
'pixel' - erase block is per-pixel random (normal) color
|
37 |
+
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
|
38 |
+
per-image count is randomly chosen between 1 and this value.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
probability=0.5, min_area=0.02, max_area=1/4, min_aspect=0.3, max_aspect=None,
|
44 |
+
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
|
45 |
+
self.probability = probability
|
46 |
+
self.min_area = min_area
|
47 |
+
self.max_area = max_area
|
48 |
+
max_aspect = max_aspect or 1 / min_aspect
|
49 |
+
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
|
50 |
+
self.min_count = min_count
|
51 |
+
self.max_count = max_count or min_count
|
52 |
+
self.num_splits = num_splits
|
53 |
+
mode = mode.lower()
|
54 |
+
self.rand_color = False
|
55 |
+
self.per_pixel = False
|
56 |
+
if mode == 'rand':
|
57 |
+
self.rand_color = True # per block random normal
|
58 |
+
elif mode == 'pixel':
|
59 |
+
self.per_pixel = True # per pixel random normal
|
60 |
+
else:
|
61 |
+
assert not mode or mode == 'const'
|
62 |
+
self.device = device
|
63 |
+
|
64 |
+
def _erase(self, img, chan, img_h, img_w, dtype):
|
65 |
+
if random.random() > self.probability:
|
66 |
+
return
|
67 |
+
area = img_h * img_w
|
68 |
+
count = self.min_count if self.min_count == self.max_count else \
|
69 |
+
random.randint(self.min_count, self.max_count)
|
70 |
+
for _ in range(count):
|
71 |
+
for attempt in range(10):
|
72 |
+
target_area = random.uniform(self.min_area, self.max_area) * area / count
|
73 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
74 |
+
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
75 |
+
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
76 |
+
if w < img_w and h < img_h:
|
77 |
+
top = random.randint(0, img_h - h)
|
78 |
+
left = random.randint(0, img_w - w)
|
79 |
+
img[:, top:top + h, left:left + w] = _get_pixels(
|
80 |
+
self.per_pixel, self.rand_color, (chan, h, w),
|
81 |
+
dtype=dtype, device=self.device)
|
82 |
+
break
|
83 |
+
|
84 |
+
def __call__(self, input, target):
|
85 |
+
batch_size, chan, input_h, input_w = input.shape
|
86 |
+
img_scales = target['img_scale']
|
87 |
+
img_size = (target['img_size'] / img_scales.unsqueeze(1)).int()
|
88 |
+
img_size[:, 0] = img_size[:, 0].clamp(max=input_w)
|
89 |
+
img_size[:, 1] = img_size[:, 1].clamp(max=input_h)
|
90 |
+
# skip first slice of batch if num_splits is set (for clean portion of samples)
|
91 |
+
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
|
92 |
+
for i in range(batch_start, batch_size):
|
93 |
+
self._erase(input[i], chan, img_size[i, 1], img_size[i, 0], input.dtype)
|
94 |
+
return input
|
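A minimal usage sketch for the multi-scale RandomErasing above, assuming a prefetched NCHW batch and a target dict carrying per-image img_scale and img_size; the tensors, sizes and import path are placeholders, and device='cpu' keeps the sketch runnable without a GPU.

import torch
from effdet.data.random_erasing import RandomErasing  # assumed import path for the module above

erase = RandomErasing(probability=0.5, mode='pixel', max_count=2, device='cpu')

images = torch.randn(4, 3, 512, 512)                     # NCHW batch after resize/pad
target = {
    'img_scale': torch.ones(4),                          # original-to-network scale per image
    'img_size': torch.tensor([[512, 512]] * 4).float(),  # original (width, height) per image
}
images = erase(images, target)  # erases random rectangles inside each image's valid region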
efficientdet/effdet/data/transforms.py
ADDED
@@ -0,0 +1,275 @@
1 |
+
""" COCO transforms (quick and dirty)
|
2 |
+
|
3 |
+
Hacked together by Ross Wightman
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
import math
|
10 |
+
|
11 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
12 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
13 |
+
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
14 |
+
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
15 |
+
|
16 |
+
|
17 |
+
class ImageToNumpy:
|
18 |
+
|
19 |
+
def __call__(self, pil_img, annotations: dict):
|
20 |
+
np_img = np.array(pil_img, dtype=np.uint8)
|
21 |
+
if np_img.ndim < 3:
|
22 |
+
np_img = np.expand_dims(np_img, axis=-1)
|
23 |
+
np_img = np.moveaxis(np_img, 2, 0) # HWC to CHW
|
24 |
+
return np_img, annotations
|
25 |
+
|
26 |
+
|
27 |
+
class ImageToTensor:
|
28 |
+
|
29 |
+
def __init__(self, dtype=torch.float32):
|
30 |
+
self.dtype = dtype
|
31 |
+
|
32 |
+
def __call__(self, pil_img, annotations: dict):
|
33 |
+
np_img = np.array(pil_img, dtype=np.uint8)
|
34 |
+
if np_img.ndim < 3:
|
35 |
+
np_img = np.expand_dims(np_img, axis=-1)
|
36 |
+
np_img = np.moveaxis(np_img, 2, 0) # HWC to CHW
|
37 |
+
return torch.from_numpy(np_img).to(dtype=self.dtype), annotations
|
38 |
+
|
39 |
+
|
40 |
+
def _pil_interp(method):
|
41 |
+
if method == 'bicubic':
|
42 |
+
return Image.BICUBIC
|
43 |
+
elif method == 'lanczos':
|
44 |
+
return Image.LANCZOS
|
45 |
+
elif method == 'hamming':
|
46 |
+
return Image.HAMMING
|
47 |
+
else:
|
48 |
+
# default bilinear, do we want to allow nearest?
|
49 |
+
return Image.BILINEAR
|
50 |
+
|
51 |
+
|
52 |
+
def clip_boxes_(boxes, img_size):
|
53 |
+
height, width = img_size
|
54 |
+
clip_upper = np.array([height, width] * 2, dtype=boxes.dtype)
|
55 |
+
np.clip(boxes, 0, clip_upper, out=boxes)
|
56 |
+
|
57 |
+
|
58 |
+
def clip_boxes(boxes, img_size):
|
59 |
+
clipped_boxes = boxes.copy()
|
60 |
+
clip_boxes_(clipped_boxes, img_size)
|
61 |
+
return clipped_boxes
|
62 |
+
|
63 |
+
|
64 |
+
def _size_tuple(size):
|
65 |
+
if isinstance(size, int):
|
66 |
+
return size, size
|
67 |
+
else:
|
68 |
+
assert len(size) == 2
|
69 |
+
return size
|
70 |
+
|
71 |
+
|
72 |
+
class ResizePad:
|
73 |
+
|
74 |
+
def __init__(self, target_size: int, interpolation: str = 'bilinear', fill_color: tuple = (0, 0, 0)):
|
75 |
+
self.target_size = _size_tuple(target_size)
|
76 |
+
self.interpolation = interpolation
|
77 |
+
self.fill_color = fill_color
|
78 |
+
|
79 |
+
def __call__(self, img, anno: dict):
|
80 |
+
width, height = img.size
|
81 |
+
|
82 |
+
img_scale_y = self.target_size[0] / height
|
83 |
+
img_scale_x = self.target_size[1] / width
|
84 |
+
img_scale = min(img_scale_y, img_scale_x)
|
85 |
+
scaled_h = int(height * img_scale)
|
86 |
+
scaled_w = int(width * img_scale)
|
87 |
+
|
88 |
+
new_img = Image.new("RGB", (self.target_size[1], self.target_size[0]), color=self.fill_color)
|
89 |
+
interp_method = _pil_interp(self.interpolation)
|
90 |
+
img = img.resize((scaled_w, scaled_h), interp_method)
|
91 |
+
new_img.paste(img)
|
92 |
+
|
93 |
+
if 'bbox' in anno:
|
94 |
+
# FIXME haven't tested this path since not currently using dataset annotations for train/eval
|
95 |
+
bbox = anno['bbox']
|
96 |
+
bbox[:, :4] *= img_scale
|
97 |
+
clip_boxes_(bbox, (scaled_h, scaled_w))
|
98 |
+
valid_indices = (bbox[:, :2] < bbox[:, 2:4]).all(axis=1)
|
99 |
+
anno['bbox'] = bbox[valid_indices, :]
|
100 |
+
anno['cls'] = anno['cls'][valid_indices]
|
101 |
+
|
102 |
+
anno['img_scale'] = 1. / img_scale # back to original
|
103 |
+
|
104 |
+
return new_img, anno
|
105 |
+
|
106 |
+
|
107 |
+
class RandomResizePad:
|
108 |
+
|
109 |
+
def __init__(self, target_size: int, scale: tuple = (0.1, 2.0), interpolation: str = 'bilinear',
|
110 |
+
fill_color: tuple = (0, 0, 0)):
|
111 |
+
self.target_size = _size_tuple(target_size)
|
112 |
+
self.scale = scale
|
113 |
+
self.interpolation = interpolation
|
114 |
+
self.fill_color = fill_color
|
115 |
+
|
116 |
+
def _get_params(self, img):
|
117 |
+
# Select a random scale factor.
|
118 |
+
scale_factor = random.uniform(*self.scale)
|
119 |
+
scaled_target_height = scale_factor * self.target_size[0]
|
120 |
+
scaled_target_width = scale_factor * self.target_size[1]
|
121 |
+
|
122 |
+
# Recompute the accurate scale_factor using rounded scaled image size.
|
123 |
+
width, height = img.size
|
124 |
+
img_scale_y = scaled_target_height / height
|
125 |
+
img_scale_x = scaled_target_width / width
|
126 |
+
img_scale = min(img_scale_y, img_scale_x)
|
127 |
+
|
128 |
+
# Select non-zero random offset (x, y) if scaled image is larger than target size
|
129 |
+
scaled_h = int(height * img_scale)
|
130 |
+
scaled_w = int(width * img_scale)
|
131 |
+
offset_y = scaled_h - self.target_size[0]
|
132 |
+
offset_x = scaled_w - self.target_size[1]
|
133 |
+
offset_y = int(max(0.0, float(offset_y)) * random.uniform(0, 1))
|
134 |
+
offset_x = int(max(0.0, float(offset_x)) * random.uniform(0, 1))
|
135 |
+
return scaled_h, scaled_w, offset_y, offset_x, img_scale
|
136 |
+
|
137 |
+
def __call__(self, img, anno: dict):
|
138 |
+
scaled_h, scaled_w, offset_y, offset_x, img_scale = self._get_params(img)
|
139 |
+
|
140 |
+
interp_method = _pil_interp(self.interpolation)
|
141 |
+
img = img.resize((scaled_w, scaled_h), interp_method)
|
142 |
+
right, lower = min(scaled_w, offset_x + self.target_size[1]), min(scaled_h, offset_y + self.target_size[0])
|
143 |
+
img = img.crop((offset_x, offset_y, right, lower))
|
144 |
+
new_img = Image.new("RGB", (self.target_size[1], self.target_size[0]), color=self.fill_color)
|
145 |
+
new_img.paste(img)
|
146 |
+
|
147 |
+
if 'bbox' in anno:
|
148 |
+
# FIXME not fully tested
|
149 |
+
bbox = anno['bbox'].copy() # FIXME copy for debugger inspection, back to inplace
|
150 |
+
bbox[:, :4] *= img_scale
|
151 |
+
box_offset = np.stack([offset_y, offset_x] * 2)
|
152 |
+
bbox -= box_offset
|
153 |
+
clip_boxes_(bbox, (scaled_h, scaled_w))
|
154 |
+
valid_indices = (bbox[:, :2] < bbox[:, 2:4]).all(axis=1)
|
155 |
+
anno['bbox'] = bbox[valid_indices, :]
|
156 |
+
anno['cls'] = anno['cls'][valid_indices]
|
157 |
+
|
158 |
+
anno['img_scale'] = 1. / img_scale # back to original
|
159 |
+
|
160 |
+
return new_img, anno
|
161 |
+
|
162 |
+
|
163 |
+
class RandomFlip:
|
164 |
+
|
165 |
+
def __init__(self, horizontal=True, vertical=False, prob=0.5):
|
166 |
+
self.horizontal = horizontal
|
167 |
+
self.vertical = vertical
|
168 |
+
self.prob = prob
|
169 |
+
|
170 |
+
def _get_params(self):
|
171 |
+
do_horizontal = random.random() < self.prob if self.horizontal else False
|
172 |
+
do_vertical = random.random() < self.prob if self.vertical else False
|
173 |
+
return do_horizontal, do_vertical
|
174 |
+
|
175 |
+
def __call__(self, img, annotations: dict):
|
176 |
+
do_horizontal, do_vertical = self._get_params()
|
177 |
+
width, height = img.size
|
178 |
+
|
179 |
+
def _fliph(bbox):
|
180 |
+
x_max = width - bbox[:, 1]
|
181 |
+
x_min = width - bbox[:, 3]
|
182 |
+
bbox[:, 1] = x_min
|
183 |
+
bbox[:, 3] = x_max
|
184 |
+
|
185 |
+
def _flipv(bbox):
|
186 |
+
y_max = height - bbox[:, 0]
|
187 |
+
y_min = height - bbox[:, 2]
|
188 |
+
bbox[:, 0] = y_min
|
189 |
+
bbox[:, 2] = y_max
|
190 |
+
|
191 |
+
if do_horizontal and do_vertical:
|
192 |
+
img = img.transpose(Image.ROTATE_180)
|
193 |
+
if 'bbox' in annotations:
|
194 |
+
_fliph(annotations['bbox'])
|
195 |
+
_flipv(annotations['bbox'])
|
196 |
+
elif do_horizontal:
|
197 |
+
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
198 |
+
if 'bbox' in annotations:
|
199 |
+
_fliph(annotations['bbox'])
|
200 |
+
elif do_vertical:
|
201 |
+
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
202 |
+
if 'bbox' in annotations:
|
203 |
+
_flipv(annotations['bbox'])
|
204 |
+
|
205 |
+
return img, annotations
|
206 |
+
|
207 |
+
|
208 |
+
def resolve_fill_color(fill_color, img_mean=IMAGENET_DEFAULT_MEAN):
|
209 |
+
if isinstance(fill_color, tuple):
|
210 |
+
assert len(fill_color) == 3
|
211 |
+
fill_color = fill_color
|
212 |
+
else:
|
213 |
+
try:
|
214 |
+
int_color = int(fill_color)
|
215 |
+
fill_color = (int_color,) * 3
|
216 |
+
except ValueError:
|
217 |
+
assert fill_color == 'mean'
|
218 |
+
fill_color = tuple([int(round(255 * x)) for x in img_mean])
|
219 |
+
return fill_color
|
220 |
+
|
221 |
+
|
222 |
+
class Compose:
|
223 |
+
|
224 |
+
def __init__(self, transforms: list):
|
225 |
+
self.transforms = transforms
|
226 |
+
|
227 |
+
def __call__(self, img, annotations: dict):
|
228 |
+
for t in self.transforms:
|
229 |
+
img, annotations = t(img, annotations)
|
230 |
+
return img, annotations
|
231 |
+
|
232 |
+
|
233 |
+
def transforms_coco_eval(
|
234 |
+
img_size=224,
|
235 |
+
interpolation='bilinear',
|
236 |
+
use_prefetcher=False,
|
237 |
+
fill_color='mean',
|
238 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
239 |
+
std=IMAGENET_DEFAULT_STD):
|
240 |
+
|
241 |
+
fill_color = resolve_fill_color(fill_color, mean)
|
242 |
+
|
243 |
+
image_tfl = [
|
244 |
+
ResizePad(
|
245 |
+
target_size=img_size, interpolation=interpolation, fill_color=fill_color),
|
246 |
+
ImageToNumpy(),
|
247 |
+
]
|
248 |
+
|
249 |
+
assert use_prefetcher, "Only supporting prefetcher usage right now"
|
250 |
+
|
251 |
+
image_tf = Compose(image_tfl)
|
252 |
+
return image_tf
|
253 |
+
|
254 |
+
|
255 |
+
def transforms_coco_train(
|
256 |
+
img_size=224,
|
257 |
+
interpolation='random',
|
258 |
+
use_prefetcher=False,
|
259 |
+
fill_color='mean',
|
260 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
261 |
+
std=IMAGENET_DEFAULT_STD):
|
262 |
+
|
263 |
+
fill_color = resolve_fill_color(fill_color, mean)
|
264 |
+
|
265 |
+
image_tfl = [
|
266 |
+
RandomFlip(horizontal=True, prob=0.5),
|
267 |
+
RandomResizePad(
|
268 |
+
target_size=img_size, interpolation=interpolation, fill_color=fill_color),
|
269 |
+
ImageToNumpy(),
|
270 |
+
]
|
271 |
+
|
272 |
+
assert use_prefetcher, "Only supporting prefetcher usage right now"
|
273 |
+
|
274 |
+
image_tf = Compose(image_tfl)
|
275 |
+
return image_tf
|
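A minimal sketch of the eval pipeline above: resize-and-pad a PIL image to the network input size and convert it to a CHW uint8 array. The 640 target size, the blank test image and the import path are placeholders, not values used elsewhere in this repo.

from PIL import Image
from effdet.data.transforms import transforms_coco_eval  # assumed import path for the module above

tf = transforms_coco_eval(img_size=640, use_prefetcher=True, fill_color='mean')

pil_img = Image.new('RGB', (1280, 720))   # stand-in for a real photo
np_img, anno = tf(pil_img, {})            # np_img: (3, 640, 640) uint8, CHW
print(np_img.shape, anno['img_scale'])    # img_scale maps predictions back to original pixels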
efficientdet/effdet/data/transforms_albumentation.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
import albumentations as A
|
2 |
+
|
3 |
+
from albumentations.augmentations.transforms import (
|
4 |
+
RandomBrightness, Downscale, RandomFog, RandomRain, RandomSnow)
|
5 |
+
|
6 |
+
from albumentations.augmentations.blur.transforms import Blur
|
7 |
+
|
8 |
+
def get_transform():
|
9 |
+
transforms = A.Compose([
|
10 |
+
#HorizontalFlip(p=0.5),
|
11 |
+
#VerticalFlip(p=0.5),
|
12 |
+
#RandomSizedBBoxSafeCrop(700, 700, erosion_rate=0.0, interpolation=1, always_apply=False, p=0.5),
|
13 |
+
Blur(blur_limit=7, always_apply=False, p=0.5),
|
14 |
+
RandomBrightness(limit=0.2, always_apply=False, p=0.5),
|
15 |
+
#Downscale(scale_min=0.5, scale_max=0.9, interpolation=0, always_apply=False, p=0.5),
|
16 |
+
#PadIfNeeded(min_height=1024, min_width=1024, pad_height_divisor=None, pad_width_divisor=None, border_mode=4, value=None, mask_value=None, always_apply=False, p=1.0),
|
17 |
+
#RandomFog(fog_coef_lower=0.3, fog_coef_upper=1, alpha_coef=0.08, always_apply=False, p=0.2),
|
18 |
+
#RandomRain(slant_lower=-10, slant_upper=10, drop_length=20, drop_width=1, drop_color=(200, 200, 200), p=0.2),
|
19 |
+
#RandomSnow(snow_point_lower=0.1, snow_point_upper=0.3, brightness_coeff=2.5, always_apply=False, p=0.2)
|
20 |
+
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_classes'])
|
21 |
+
)
|
22 |
+
return transforms
|
23 |
+
|
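A minimal sketch of calling the albumentations pipeline above; the image and the single box are made-up placeholders. Albumentations expects an HWC numpy image with pascal_voc (x_min, y_min, x_max, y_max) boxes, and the labels are passed under the bbox_classes key declared in bbox_params.

import numpy as np
from effdet.data.transforms_albumentation import get_transform  # assumed import path for the module above

aug = get_transform()

image = np.zeros((720, 1280, 3), dtype=np.uint8)   # stand-in HWC image
bboxes = [(100, 150, 400, 500)]                    # one pascal_voc box
bbox_classes = [3]                                 # its label

out = aug(image=image, bboxes=bboxes, bbox_classes=bbox_classes)
aug_img, aug_boxes = out['image'], out['bboxes']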
efficientdet/effdet/distributed.py
ADDED
@@ -0,0 +1,308 @@
1 |
+
""" PyTorch distributed helpers
|
2 |
+
|
3 |
+
Some of this lifted from Detectron2 with other fns added by myself. Some of the Detectron2 fns
|
4 |
+
were intended for use with GLOO PG. I am using NCCL here with default PG so not everything will work
|
5 |
+
as is -RW
|
6 |
+
"""
|
7 |
+
import functools
|
8 |
+
import logging
|
9 |
+
import numpy as np
|
10 |
+
import pickle
|
11 |
+
import torch
|
12 |
+
import torch.distributed as dist
|
13 |
+
|
14 |
+
_LOCAL_PROCESS_GROUP = None
|
15 |
+
"""
|
16 |
+
A torch process group which only includes processes that are on the same machine as the current process.
|
17 |
+
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
def get_world_size() -> int:
|
22 |
+
if not dist.is_available():
|
23 |
+
return 1
|
24 |
+
if not dist.is_initialized():
|
25 |
+
return 1
|
26 |
+
return dist.get_world_size()
|
27 |
+
|
28 |
+
|
29 |
+
def get_rank() -> int:
|
30 |
+
if not dist.is_available():
|
31 |
+
return 0
|
32 |
+
if not dist.is_initialized():
|
33 |
+
return 0
|
34 |
+
return dist.get_rank()
|
35 |
+
|
36 |
+
|
37 |
+
def get_local_rank() -> int:
|
38 |
+
"""
|
39 |
+
Returns:
|
40 |
+
The rank of the current process within the local (per-machine) process group.
|
41 |
+
"""
|
42 |
+
if not dist.is_available():
|
43 |
+
return 0
|
44 |
+
if not dist.is_initialized():
|
45 |
+
return 0
|
46 |
+
assert _LOCAL_PROCESS_GROUP is not None
|
47 |
+
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
48 |
+
|
49 |
+
|
50 |
+
def get_local_size() -> int:
|
51 |
+
"""
|
52 |
+
Returns:
|
53 |
+
The size of the per-machine process group,
|
54 |
+
i.e. the number of processes per machine.
|
55 |
+
"""
|
56 |
+
if not dist.is_available():
|
57 |
+
return 1
|
58 |
+
if not dist.is_initialized():
|
59 |
+
return 1
|
60 |
+
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
61 |
+
|
62 |
+
|
63 |
+
def is_main_process() -> bool:
|
64 |
+
return get_rank() == 0
|
65 |
+
|
66 |
+
|
67 |
+
def synchronize():
|
68 |
+
"""
|
69 |
+
Helper function to synchronize (barrier) among all processes when
|
70 |
+
using distributed training
|
71 |
+
"""
|
72 |
+
if not dist.is_available():
|
73 |
+
return
|
74 |
+
if not dist.is_initialized():
|
75 |
+
return
|
76 |
+
world_size = dist.get_world_size()
|
77 |
+
if world_size == 1:
|
78 |
+
return
|
79 |
+
dist.barrier()
|
80 |
+
|
81 |
+
|
82 |
+
@functools.lru_cache()
|
83 |
+
def _get_global_gloo_group():
|
84 |
+
"""
|
85 |
+
Return a process group based on gloo backend, containing all the ranks
|
86 |
+
The result is cached.
|
87 |
+
"""
|
88 |
+
if dist.get_backend() == "nccl":
|
89 |
+
return dist.new_group(backend="gloo")
|
90 |
+
else:
|
91 |
+
return dist.group.WORLD
|
92 |
+
|
93 |
+
|
94 |
+
def _serialize_to_tensor(data, group):
|
95 |
+
backend = dist.get_backend(group)
|
96 |
+
assert backend in ["gloo", "nccl"]
|
97 |
+
device = torch.device("cpu" if backend == "gloo" else "cuda")
|
98 |
+
|
99 |
+
buffer = pickle.dumps(data)
|
100 |
+
if len(buffer) > 1024 ** 3:
|
101 |
+
logger = logging.getLogger(__name__)
|
102 |
+
logger.warning(
|
103 |
+
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
|
104 |
+
get_rank(), len(buffer) / (1024 ** 3), device
|
105 |
+
)
|
106 |
+
)
|
107 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
108 |
+
tensor = torch.ByteTensor(storage).to(device=device)
|
109 |
+
return tensor
|
110 |
+
|
111 |
+
|
112 |
+
def _pad_to_largest_tensor(tensor, group):
|
113 |
+
"""
|
114 |
+
Returns:
|
115 |
+
list[int]: size of the tensor, on each rank
|
116 |
+
Tensor: padded tensor that has the max size
|
117 |
+
"""
|
118 |
+
world_size = dist.get_world_size(group=group)
|
119 |
+
assert (
|
120 |
+
world_size >= 1
|
121 |
+
), "comm.gather/all_gather must be called from ranks within the given group!"
|
122 |
+
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
|
123 |
+
size_list = [
|
124 |
+
torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
|
125 |
+
]
|
126 |
+
dist.all_gather(size_list, local_size, group=group)
|
127 |
+
size_list = [int(size.item()) for size in size_list]
|
128 |
+
|
129 |
+
max_size = max(size_list)
|
130 |
+
|
131 |
+
# we pad the tensor because torch all_gather does not support
|
132 |
+
# gathering tensors of different shapes
|
133 |
+
if local_size != max_size:
|
134 |
+
padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
|
135 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
136 |
+
return size_list, tensor
|
137 |
+
|
138 |
+
|
139 |
+
def all_gather(data, group=None):
|
140 |
+
"""
|
141 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors).
|
142 |
+
Args:
|
143 |
+
data: any picklable object
|
144 |
+
group: a torch process group. By default, will use a group which
|
145 |
+
contains all ranks on gloo backend.
|
146 |
+
Returns:
|
147 |
+
list[data]: list of data gathered from each rank
|
148 |
+
"""
|
149 |
+
if get_world_size() == 1:
|
150 |
+
return [data]
|
151 |
+
if group is None:
|
152 |
+
group = _get_global_gloo_group()
|
153 |
+
if dist.get_world_size(group) == 1:
|
154 |
+
return [data]
|
155 |
+
|
156 |
+
tensor = _serialize_to_tensor(data, group)
|
157 |
+
|
158 |
+
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
159 |
+
max_size = max(size_list)
|
160 |
+
|
161 |
+
# receiving Tensor from all ranks
|
162 |
+
tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list]
|
163 |
+
dist.all_gather(tensor_list, tensor, group=group)
|
164 |
+
|
165 |
+
data_list = []
|
166 |
+
for size, tensor in zip(size_list, tensor_list):
|
167 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
168 |
+
data_list.append(pickle.loads(buffer))
|
169 |
+
|
170 |
+
return data_list
|
171 |
+
|
172 |
+
|
173 |
+
def gather(data, dst=0, group=None):
|
174 |
+
"""
|
175 |
+
Run gather on arbitrary picklable data (not necessarily tensors).
|
176 |
+
Args:
|
177 |
+
data: any picklable object
|
178 |
+
dst (int): destination rank
|
179 |
+
group: a torch process group. By default, will use a group which
|
180 |
+
contains all ranks on gloo backend.
|
181 |
+
Returns:
|
182 |
+
list[data]: on dst, a list of data gathered from each rank. Otherwise,
|
183 |
+
an empty list.
|
184 |
+
"""
|
185 |
+
if get_world_size() == 1:
|
186 |
+
return [data]
|
187 |
+
if group is None:
|
188 |
+
group = _get_global_gloo_group()
|
189 |
+
if dist.get_world_size(group=group) == 1:
|
190 |
+
return [data]
|
191 |
+
rank = dist.get_rank(group=group)
|
192 |
+
|
193 |
+
tensor = _serialize_to_tensor(data, group)
|
194 |
+
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
195 |
+
|
196 |
+
# receiving Tensor from all ranks
|
197 |
+
if rank == dst:
|
198 |
+
max_size = max(size_list)
|
199 |
+
tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list]
|
200 |
+
dist.gather(tensor, tensor_list, dst=dst, group=group)
|
201 |
+
|
202 |
+
data_list = []
|
203 |
+
for size, tensor in zip(size_list, tensor_list):
|
204 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
205 |
+
data_list.append(pickle.loads(buffer))
|
206 |
+
return data_list
|
207 |
+
else:
|
208 |
+
dist.gather(tensor, [], dst=dst, group=group)
|
209 |
+
return []
|
210 |
+
|
211 |
+
|
212 |
+
def shared_random_seed():
|
213 |
+
"""
|
214 |
+
Returns:
|
215 |
+
int: a random number that is the same across all workers.
|
216 |
+
If workers need a shared RNG, they can use this shared seed to
|
217 |
+
create one.
|
218 |
+
All workers must call this function, otherwise it will deadlock.
|
219 |
+
"""
|
220 |
+
ints = np.random.randint(2 ** 31)
|
221 |
+
all_ints = all_gather(ints)
|
222 |
+
return all_ints[0]
|
223 |
+
|
224 |
+
|
225 |
+
def reduce_dict(input_dict, average=True):
|
226 |
+
"""
|
227 |
+
Reduce the values in the dictionary from all processes so that process with rank
|
228 |
+
0 has the reduced results.
|
229 |
+
Args:
|
230 |
+
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
|
231 |
+
average (bool): whether to do average or sum
|
232 |
+
Returns:
|
233 |
+
a dict with the same keys as input_dict, after reduction.
|
234 |
+
"""
|
235 |
+
world_size = get_world_size()
|
236 |
+
if world_size < 2:
|
237 |
+
return input_dict
|
238 |
+
with torch.no_grad():
|
239 |
+
names = []
|
240 |
+
values = []
|
241 |
+
# sort the keys so that they are consistent across processes
|
242 |
+
for k in sorted(input_dict.keys()):
|
243 |
+
names.append(k)
|
244 |
+
values.append(input_dict[k])
|
245 |
+
values = torch.stack(values, dim=0)
|
246 |
+
dist.reduce(values, dst=0)
|
247 |
+
if dist.get_rank() == 0 and average:
|
248 |
+
# only main process gets accumulated, so only divide by
|
249 |
+
# world_size in this case
|
250 |
+
values /= world_size
|
251 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
252 |
+
return reduced_dict
|
253 |
+
|
254 |
+
|
255 |
+
def all_gather_container(container, group=None, cat_dim=0):
|
256 |
+
group = group or dist.group.WORLD
|
257 |
+
world_size = dist.get_world_size(group)
|
258 |
+
|
259 |
+
def _do_gather(tensor):
|
260 |
+
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
261 |
+
dist.all_gather(tensor_list, tensor, group=group)
|
262 |
+
return torch.cat(tensor_list, dim=cat_dim)
|
263 |
+
|
264 |
+
if isinstance(container, dict):
|
265 |
+
gathered = dict()
|
266 |
+
for k, v in container.items():
|
267 |
+
v = _do_gather(v)
|
268 |
+
gathered[k] = v
|
269 |
+
return gathered
|
270 |
+
elif isinstance(container, (list, tuple)):
|
271 |
+
gathered = [_do_gather(v) for v in container]
|
272 |
+
if isinstance(container, tuple):
|
273 |
+
gathered = tuple(gathered)
|
274 |
+
return gathered
|
275 |
+
else:
|
276 |
+
# if not a dict, list, tuple, expect a singular tensor
|
277 |
+
assert isinstance(container, torch.Tensor)
|
278 |
+
return _do_gather(container)
|
279 |
+
|
280 |
+
|
281 |
+
def gather_container(container, dst, group=None, cat_dim=0):
|
282 |
+
group = group or dist.group.WORLD
|
283 |
+
world_size = dist.get_world_size(group)
|
284 |
+
this_rank = dist.get_rank(group)
|
285 |
+
|
286 |
+
def _do_gather(tensor):
|
287 |
+
if this_rank == dst:
|
288 |
+
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
289 |
+
else:
|
290 |
+
tensor_list = None
|
291 |
+
dist.gather(tensor, tensor_list, dst=dst, group=group)
|
292 |
+
return torch.cat(tensor_list, dim=cat_dim)
|
293 |
+
|
294 |
+
if isinstance(container, dict):
|
295 |
+
gathered = dict()
|
296 |
+
for k, v in container.items():
|
297 |
+
v = _do_gather(v)
|
298 |
+
gathered[k] = v
|
299 |
+
return gathered
|
300 |
+
elif isinstance(container, (list, tuple)):
|
301 |
+
gathered = [_do_gather(v) for v in container]
|
302 |
+
if isinstance(container, tuple):
|
303 |
+
gathered = tuple(gathered)
|
304 |
+
return gathered
|
305 |
+
else:
|
306 |
+
# if not a dict, list, tuple, expect a singular tensor
|
307 |
+
assert isinstance(container, torch.Tensor)
|
308 |
+
return _do_gather(container)
|
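A minimal sketch of how the helpers above are meant to be used inside an already-initialized torch.distributed run; the loss values, the gathered payload and the import path are placeholders, and nothing here launches the worker processes for you.

import torch
from effdet.distributed import reduce_dict, all_gather, is_main_process, get_rank  # assumed import path

# assumes dist.init_process_group(...) was already called by the launcher
loss_dict = {'box_loss': torch.tensor(0.3, device='cuda'),
             'class_loss': torch.tensor(0.7, device='cuda')}

averaged = reduce_dict(loss_dict, average=True)   # rank 0 ends up with the mean across ranks
if is_main_process():
    print({k: v.item() for k, v in averaged.items()})

# gather arbitrary picklable per-rank results (e.g. eval stats) onto every rank
all_stats = all_gather({'rank': get_rank(), 'num_images': 128})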
efficientdet/effdet/efficientdet.py
ADDED
@@ -0,0 +1,557 @@
1 |
+
""" PyTorch EfficientDet model
|
2 |
+
|
3 |
+
Based on official Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet
|
4 |
+
Paper: https://arxiv.org/abs/1911.09070
|
5 |
+
|
6 |
+
Hacked together by Ross Wightman
|
7 |
+
"""
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import logging
|
11 |
+
import math
|
12 |
+
from collections import OrderedDict
|
13 |
+
from typing import List, Callable
|
14 |
+
from functools import partial
|
15 |
+
|
16 |
+
|
17 |
+
from timm import create_model
|
18 |
+
from timm.models.layers import create_conv2d, drop_path, create_pool2d, Swish, get_act_layer
|
19 |
+
from .config import get_fpn_config, set_config_writeable, set_config_readonly
|
20 |
+
|
21 |
+
_DEBUG = False
|
22 |
+
|
23 |
+
_ACT_LAYER = Swish
|
24 |
+
|
25 |
+
|
26 |
+
class SequentialList(nn.Sequential):
|
27 |
+
""" This module exists to work around torchscript typing issues list -> list"""
|
28 |
+
def __init__(self, *args):
|
29 |
+
super(SequentialList, self).__init__(*args)
|
30 |
+
|
31 |
+
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
|
32 |
+
for module in self:
|
33 |
+
x = module(x)
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
class ConvBnAct2d(nn.Module):
|
38 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding='', bias=False,
|
39 |
+
norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER):
|
40 |
+
super(ConvBnAct2d, self).__init__()
|
41 |
+
self.conv = create_conv2d(
|
42 |
+
in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias)
|
43 |
+
self.bn = None if norm_layer is None else norm_layer(out_channels)
|
44 |
+
self.act = None if act_layer is None else act_layer(inplace=True)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
x = self.conv(x)
|
48 |
+
if self.bn is not None:
|
49 |
+
x = self.bn(x)
|
50 |
+
if self.act is not None:
|
51 |
+
x = self.act(x)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
class SeparableConv2d(nn.Module):
|
56 |
+
""" Separable Conv
|
57 |
+
"""
|
58 |
+
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
|
59 |
+
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER):
|
60 |
+
super(SeparableConv2d, self).__init__()
|
61 |
+
self.conv_dw = create_conv2d(
|
62 |
+
in_channels, int(in_channels * channel_multiplier), kernel_size,
|
63 |
+
stride=stride, dilation=dilation, padding=padding, depthwise=True)
|
64 |
+
|
65 |
+
self.conv_pw = create_conv2d(
|
66 |
+
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
|
67 |
+
|
68 |
+
self.bn = None if norm_layer is None else norm_layer(out_channels)
|
69 |
+
self.act = None if act_layer is None else act_layer(inplace=True)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = self.conv_dw(x)
|
73 |
+
x = self.conv_pw(x)
|
74 |
+
if self.bn is not None:
|
75 |
+
x = self.bn(x)
|
76 |
+
if self.act is not None:
|
77 |
+
x = self.act(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class ResampleFeatureMap(nn.Sequential):
|
82 |
+
|
83 |
+
def __init__(self, in_channels, out_channels, reduction_ratio=1., pad_type='', pooling_type='max',
|
84 |
+
norm_layer=nn.BatchNorm2d, apply_bn=False, conv_after_downsample=False, redundant_bias=False):
|
85 |
+
super(ResampleFeatureMap, self).__init__()
|
86 |
+
pooling_type = pooling_type or 'max'
|
87 |
+
self.in_channels = in_channels
|
88 |
+
self.out_channels = out_channels
|
89 |
+
self.reduction_ratio = reduction_ratio
|
90 |
+
self.conv_after_downsample = conv_after_downsample
|
91 |
+
|
92 |
+
conv = None
|
93 |
+
if in_channels != out_channels:
|
94 |
+
conv = ConvBnAct2d(
|
95 |
+
in_channels, out_channels, kernel_size=1, padding=pad_type,
|
96 |
+
norm_layer=norm_layer if apply_bn else None,
|
97 |
+
bias=not apply_bn or redundant_bias, act_layer=None)
|
98 |
+
|
99 |
+
if reduction_ratio > 1:
|
100 |
+
stride_size = int(reduction_ratio)
|
101 |
+
if conv is not None and not self.conv_after_downsample:
|
102 |
+
self.add_module('conv', conv)
|
103 |
+
self.add_module(
|
104 |
+
'downsample',
|
105 |
+
create_pool2d(
|
106 |
+
pooling_type, kernel_size=stride_size + 1, stride=stride_size, padding=pad_type))
|
107 |
+
if conv is not None and self.conv_after_downsample:
|
108 |
+
self.add_module('conv', conv)
|
109 |
+
else:
|
110 |
+
if conv is not None:
|
111 |
+
self.add_module('conv', conv)
|
112 |
+
if reduction_ratio < 1:
|
113 |
+
scale = int(1 // reduction_ratio)
|
114 |
+
self.add_module('upsample', nn.UpsamplingNearest2d(scale_factor=scale))
|
115 |
+
|
116 |
+
# def forward(self, x):
|
117 |
+
# # here for debugging only
|
118 |
+
# assert x.shape[1] == self.in_channels
|
119 |
+
# if self.reduction_ratio > 1:
|
120 |
+
# if hasattr(self, 'conv') and not self.conv_after_downsample:
|
121 |
+
# x = self.conv(x)
|
122 |
+
# x = self.downsample(x)
|
123 |
+
# if hasattr(self, 'conv') and self.conv_after_downsample:
|
124 |
+
# x = self.conv(x)
|
125 |
+
# else:
|
126 |
+
# if hasattr(self, 'conv'):
|
127 |
+
# x = self.conv(x)
|
128 |
+
# if self.reduction_ratio < 1:
|
129 |
+
# x = self.upsample(x)
|
130 |
+
# return x
|
131 |
+
|
132 |
+
|
133 |
+
class FpnCombine(nn.Module):
|
134 |
+
def __init__(self, feature_info, fpn_config, fpn_channels, inputs_offsets, target_reduction, pad_type='',
|
135 |
+
pooling_type='max', norm_layer=nn.BatchNorm2d, apply_bn_for_resampling=False,
|
136 |
+
conv_after_downsample=False, redundant_bias=False, weight_method='attn'):
|
137 |
+
super(FpnCombine, self).__init__()
|
138 |
+
self.inputs_offsets = inputs_offsets
|
139 |
+
self.weight_method = weight_method
|
140 |
+
|
141 |
+
self.resample = nn.ModuleDict()
|
142 |
+
for idx, offset in enumerate(inputs_offsets):
|
143 |
+
in_channels = fpn_channels
|
144 |
+
if offset < len(feature_info):
|
145 |
+
in_channels = feature_info[offset]['num_chs']
|
146 |
+
input_reduction = feature_info[offset]['reduction']
|
147 |
+
else:
|
148 |
+
node_idx = offset - len(feature_info)
|
149 |
+
input_reduction = fpn_config.nodes[node_idx]['reduction']
|
150 |
+
reduction_ratio = target_reduction / input_reduction
|
151 |
+
self.resample[str(offset)] = ResampleFeatureMap(
|
152 |
+
in_channels, fpn_channels, reduction_ratio=reduction_ratio, pad_type=pad_type,
|
153 |
+
pooling_type=pooling_type, norm_layer=norm_layer, apply_bn=apply_bn_for_resampling,
|
154 |
+
conv_after_downsample=conv_after_downsample, redundant_bias=redundant_bias)
|
155 |
+
|
156 |
+
if weight_method == 'attn' or weight_method == 'fastattn':
|
157 |
+
self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets)), requires_grad=True) # WSM
|
158 |
+
else:
|
159 |
+
self.edge_weights = None
|
160 |
+
|
161 |
+
def forward(self, x: List[torch.Tensor]):
|
162 |
+
dtype = x[0].dtype
|
163 |
+
nodes = []
|
164 |
+
for offset, resample in zip(self.inputs_offsets, self.resample.values()):
|
165 |
+
input_node = x[offset]
|
166 |
+
input_node = resample(input_node)
|
167 |
+
nodes.append(input_node)
|
168 |
+
|
169 |
+
if self.weight_method == 'attn':
|
170 |
+
normalized_weights = torch.softmax(self.edge_weights.to(dtype=dtype), dim=0)
|
171 |
+
out = torch.stack(nodes, dim=-1) * normalized_weights
|
172 |
+
elif self.weight_method == 'fastattn':
|
173 |
+
edge_weights = nn.functional.relu(self.edge_weights.to(dtype=dtype))
|
174 |
+
weights_sum = torch.sum(edge_weights)
|
175 |
+
out = torch.stack(
|
176 |
+
[(nodes[i] * edge_weights[i]) / (weights_sum + 0.0001) for i in range(len(nodes))], dim=-1)
|
177 |
+
elif self.weight_method == 'sum':
|
178 |
+
out = torch.stack(nodes, dim=-1)
|
179 |
+
else:
|
180 |
+
raise ValueError('unknown weight_method {}'.format(self.weight_method))
|
181 |
+
out = torch.sum(out, dim=-1)
|
182 |
+
return out
|
183 |
+
|
184 |
+
|
185 |
+
class Fnode(nn.Module):
|
186 |
+
""" A simple wrapper used in place of nn.Sequential for torchscript typing
|
187 |
+
Handles input type List[Tensor] -> output type Tensor
|
188 |
+
"""
|
189 |
+
def __init__(self, combine: nn.Module, after_combine: nn.Module):
|
190 |
+
super(Fnode, self).__init__()
|
191 |
+
self.combine = combine
|
192 |
+
self.after_combine = after_combine
|
193 |
+
|
194 |
+
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
|
195 |
+
return self.after_combine(self.combine(x))
|
196 |
+
|
197 |
+
|
198 |
+
class BiFpnLayer(nn.Module):
|
199 |
+
def __init__(self, feature_info, fpn_config, fpn_channels, num_levels=5, pad_type='',
|
200 |
+
pooling_type='max', norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER,
|
201 |
+
apply_bn_for_resampling=False, conv_after_downsample=True, conv_bn_relu_pattern=False,
|
202 |
+
separable_conv=True, redundant_bias=False):
|
203 |
+
super(BiFpnLayer, self).__init__()
|
204 |
+
self.num_levels = num_levels
|
205 |
+
self.conv_bn_relu_pattern = False
|
206 |
+
|
207 |
+
self.feature_info = []
|
208 |
+
self.fnode = nn.ModuleList()
|
209 |
+
for i, fnode_cfg in enumerate(fpn_config.nodes):
|
210 |
+
logging.debug('fnode {} : {}'.format(i, fnode_cfg))
|
211 |
+
reduction = fnode_cfg['reduction']
|
212 |
+
combine = FpnCombine(
|
213 |
+
feature_info, fpn_config, fpn_channels, tuple(fnode_cfg['inputs_offsets']),
|
214 |
+
target_reduction=reduction, pad_type=pad_type, pooling_type=pooling_type, norm_layer=norm_layer,
|
215 |
+
apply_bn_for_resampling=apply_bn_for_resampling, conv_after_downsample=conv_after_downsample,
|
216 |
+
redundant_bias=redundant_bias, weight_method=fnode_cfg['weight_method'])
|
217 |
+
|
218 |
+
after_combine = nn.Sequential()
|
219 |
+
conv_kwargs = dict(
|
220 |
+
in_channels=fpn_channels, out_channels=fpn_channels, kernel_size=3, padding=pad_type,
|
221 |
+
bias=False, norm_layer=norm_layer, act_layer=act_layer)
|
222 |
+
if not conv_bn_relu_pattern:
|
223 |
+
conv_kwargs['bias'] = redundant_bias
|
224 |
+
conv_kwargs['act_layer'] = None
|
225 |
+
after_combine.add_module('act', act_layer(inplace=True))
|
226 |
+
after_combine.add_module(
|
227 |
+
'conv', SeparableConv2d(**conv_kwargs) if separable_conv else ConvBnAct2d(**conv_kwargs))
|
228 |
+
|
229 |
+
self.fnode.append(Fnode(combine=combine, after_combine=after_combine))
|
230 |
+
self.feature_info.append(dict(num_chs=fpn_channels, reduction=reduction))
|
231 |
+
|
232 |
+
self.feature_info = self.feature_info[-num_levels::]
|
233 |
+
|
234 |
+
def forward(self, x: List[torch.Tensor]):
|
235 |
+
for fn in self.fnode:
|
236 |
+
x.append(fn(x))
|
237 |
+
return x[-self.num_levels::]
|
238 |
+
|
239 |
+
|
240 |
+
class BiFpn(nn.Module):
|
241 |
+
|
242 |
+
def __init__(self, config, feature_info):
|
243 |
+
super(BiFpn, self).__init__()
|
244 |
+
self.num_levels = config.num_levels
|
245 |
+
norm_layer = config.norm_layer or nn.BatchNorm2d
|
246 |
+
if config.norm_kwargs:
|
247 |
+
norm_layer = partial(norm_layer, **config.norm_kwargs)
|
248 |
+
act_layer = get_act_layer(config.act_type) or _ACT_LAYER
|
249 |
+
fpn_config = config.fpn_config or get_fpn_config(
|
250 |
+
config.fpn_name, min_level=config.min_level, max_level=config.max_level)
|
251 |
+
|
252 |
+
self.resample = nn.ModuleDict()
|
253 |
+
for level in range(config.num_levels):
|
254 |
+
if level < len(feature_info):
|
255 |
+
in_chs = feature_info[level]['num_chs']
|
256 |
+
reduction = feature_info[level]['reduction']
|
257 |
+
else:
|
258 |
+
# Adds a coarser level by downsampling the last feature map
|
259 |
+
reduction_ratio = 2
|
260 |
+
self.resample[str(level)] = ResampleFeatureMap(
|
261 |
+
in_channels=in_chs,
|
262 |
+
out_channels=config.fpn_channels,
|
263 |
+
pad_type=config.pad_type,
|
264 |
+
                    pooling_type=config.pooling_type,
                    norm_layer=norm_layer,
                    reduction_ratio=reduction_ratio,
                    apply_bn=config.apply_bn_for_resampling,
                    conv_after_downsample=config.conv_after_downsample,
                    redundant_bias=config.redundant_bias,
                )
                in_chs = config.fpn_channels
                reduction = int(reduction * reduction_ratio)
                feature_info.append(dict(num_chs=in_chs, reduction=reduction))

        self.cell = SequentialList()
        for rep in range(config.fpn_cell_repeats):
            logging.debug('building cell {}'.format(rep))
            fpn_layer = BiFpnLayer(
                feature_info=feature_info,
                fpn_config=fpn_config,
                fpn_channels=config.fpn_channels,
                num_levels=config.num_levels,
                pad_type=config.pad_type,
                pooling_type=config.pooling_type,
                norm_layer=norm_layer,
                act_layer=act_layer,
                separable_conv=config.separable_conv,
                apply_bn_for_resampling=config.apply_bn_for_resampling,
                conv_after_downsample=config.conv_after_downsample,
                conv_bn_relu_pattern=config.conv_bn_relu_pattern,
                redundant_bias=config.redundant_bias,
            )
            self.cell.add_module(str(rep), fpn_layer)
            feature_info = fpn_layer.feature_info

    def forward(self, x: List[torch.Tensor]):
        for resample in self.resample.values():
            x.append(resample(x[-1]))
        x = self.cell(x)
        return x


class HeadNet(nn.Module):

    def __init__(self, config, num_outputs):
        super(HeadNet, self).__init__()
        self.num_levels = config.num_levels
        self.bn_level_first = getattr(config, 'head_bn_level_first', False)
        norm_layer = config.norm_layer or nn.BatchNorm2d
        if config.norm_kwargs:
            norm_layer = partial(norm_layer, **config.norm_kwargs)
        act_layer = get_act_layer(config.act_type) or _ACT_LAYER

        # Build convolution repeats
        conv_fn = SeparableConv2d if config.separable_conv else ConvBnAct2d
        conv_kwargs = dict(
            in_channels=config.fpn_channels, out_channels=config.fpn_channels, kernel_size=3,
            padding=config.pad_type, bias=config.redundant_bias, act_layer=None, norm_layer=None)
        self.conv_rep = nn.ModuleList([conv_fn(**conv_kwargs) for _ in range(config.box_class_repeats)])

        # Build batchnorm repeats. There is a unique batchnorm per feature level for each repeat.
        # This can be organized with repeats first or feature levels first in module lists, the original models
        # and weights were setup with repeats first, levels first is required for efficient torchscript usage.
        self.bn_rep = nn.ModuleList()
        if self.bn_level_first:
            for _ in range(self.num_levels):
                self.bn_rep.append(nn.ModuleList([
                    norm_layer(config.fpn_channels) for _ in range(config.box_class_repeats)]))
        else:
            for _ in range(config.box_class_repeats):
                self.bn_rep.append(nn.ModuleList([
                    nn.Sequential(OrderedDict([('bn', norm_layer(config.fpn_channels))]))
                    for _ in range(self.num_levels)]))

        self.act = act_layer(inplace=True)

        # Prediction (output) layer. Has bias with special init reqs, see init fn.
        num_anchors = len(config.aspect_ratios) * config.num_scales
        predict_kwargs = dict(
            in_channels=config.fpn_channels, out_channels=num_outputs * num_anchors, kernel_size=3,
            padding=config.pad_type, bias=True, norm_layer=None, act_layer=None)
        self.predict = conv_fn(**predict_kwargs)

    @torch.jit.ignore()
    def toggle_bn_level_first(self):
        """ Toggle the batchnorm layers between feature level first vs repeat first access pattern
        Limitations in torchscript require feature levels to be iterated over first.

        This function can be used to allow loading weights in the original order, and then toggle before
        jit scripting the model.
        """
        with torch.no_grad():
            new_bn_rep = nn.ModuleList()
            for i in range(len(self.bn_rep[0])):
                bn_first = nn.ModuleList()
                for r in self.bn_rep.children():
                    m = r[i]
                    # NOTE original rep first model def has extra Sequential container with 'bn', this was
                    # flattened in the level first definition.
                    bn_first.append(m[0] if isinstance(m, nn.Sequential) else nn.Sequential(OrderedDict([('bn', m)])))
                new_bn_rep.append(bn_first)
            self.bn_level_first = not self.bn_level_first
            self.bn_rep = new_bn_rep

    @torch.jit.ignore()
    def _forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
        outputs = []
        for level in range(self.num_levels):
            x_level = x[level]
            for conv, bn in zip(self.conv_rep, self.bn_rep):
                x_level = conv(x_level)
                x_level = bn[level](x_level)  # this is not allowed in torchscript
                x_level = self.act(x_level)
            outputs.append(self.predict(x_level))
        return outputs

    def _forward_level_first(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
        outputs = []
        for level, bn_rep in enumerate(self.bn_rep):  # iterating over first bn dim first makes TS happy
            x_level = x[level]
            for conv, bn in zip(self.conv_rep, bn_rep):
                x_level = conv(x_level)
                x_level = bn(x_level)
                x_level = self.act(x_level)
            outputs.append(self.predict(x_level))
        return outputs

    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
        if self.bn_level_first:
            return self._forward_level_first(x)
        else:
            return self._forward(x)


def _init_weight(m, n='', ):
    """ Weight initialization as per Tensorflow official implementations.
    """

    def _fan_in_out(w, groups=1):
        dimensions = w.dim()
        if dimensions < 2:
            raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
        num_input_fmaps = w.size(1)
        num_output_fmaps = w.size(0)
        receptive_field_size = 1
        if w.dim() > 2:
            receptive_field_size = w[0][0].numel()
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
        fan_out //= groups
        return fan_in, fan_out

    def _glorot_uniform(w, gain=1, groups=1):
        fan_in, fan_out = _fan_in_out(w, groups)
        gain /= max(1., (fan_in + fan_out) / 2.)  # fan avg
        limit = math.sqrt(3.0 * gain)
        w.data.uniform_(-limit, limit)

    def _variance_scaling(w, gain=1, groups=1):
        fan_in, fan_out = _fan_in_out(w, groups)
        gain /= max(1., fan_in)  # fan in
        # gain /= max(1., (fan_in + fan_out) / 2.)  # fan

        # should it be normal or trunc normal? using normal for now since no good trunc in PT
        # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
        # std = math.sqrt(gain) / .87962566103423978
        # w.data.trunc_normal(std=std)
        std = math.sqrt(gain)
        w.data.normal_(std=std)

    if isinstance(m, SeparableConv2d):
        if 'box_net' in n or 'class_net' in n:
            _variance_scaling(m.conv_dw.weight, groups=m.conv_dw.groups)
            _variance_scaling(m.conv_pw.weight)
            if m.conv_pw.bias is not None:
                if 'class_net.predict' in n:
                    m.conv_pw.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
                else:
                    m.conv_pw.bias.data.zero_()
        else:
            _glorot_uniform(m.conv_dw.weight, groups=m.conv_dw.groups)
            _glorot_uniform(m.conv_pw.weight)
            if m.conv_pw.bias is not None:
                m.conv_pw.bias.data.zero_()
    elif isinstance(m, ConvBnAct2d):
        if 'box_net' in n or 'class_net' in n:
            m.conv.weight.data.normal_(std=.01)
            if m.conv.bias is not None:
                if 'class_net.predict' in n:
                    m.conv.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
                else:
                    m.conv.bias.data.zero_()
        else:
            _glorot_uniform(m.conv.weight)
            if m.conv.bias is not None:
                m.conv.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        # looks like all bn init the same?
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()


def _init_weight_alt(m, n='', ):
    """ Weight initialization alternative, based on EfficientNet backbone init w/ class bias addition
    NOTE: this will likely be removed after some experimentation
    """
    if isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            if 'class_net.predict' in n:
                m.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
            else:
                m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()


def get_feature_info(backbone):
    if isinstance(backbone.feature_info, Callable):
        # old accessor for timm versions <= 0.1.30, efficientnet and mobilenetv3 and related nets only
        feature_info = [dict(num_chs=f['num_chs'], reduction=f['reduction'])
                        for i, f in enumerate(backbone.feature_info())]
    else:
        # new feature info accessor, timm >= 0.2, all models supported
        feature_info = backbone.feature_info.get_dicts(keys=['num_chs', 'reduction'])
    return feature_info


class EfficientDet(nn.Module):

    def __init__(self, config, pretrained_backbone=True, alternate_init=False):
        super(EfficientDet, self).__init__()
        self.config = config
        set_config_readonly(self.config)
        self.backbone = create_model(
            config.backbone_name, features_only=True, out_indices=(2, 3, 4),
            pretrained=pretrained_backbone, **config.backbone_args)
        feature_info = get_feature_info(self.backbone)
        self.fpn = BiFpn(self.config, feature_info)
        self.class_net = HeadNet(self.config, num_outputs=self.config.num_classes)
        self.box_net = HeadNet(self.config, num_outputs=4)

        for n, m in self.named_modules():
            if 'backbone' not in n:
                if alternate_init:
                    _init_weight_alt(m, n)
                else:
                    _init_weight(m, n)

    @torch.jit.ignore()
    def reset_head(self, num_classes=None, aspect_ratios=None, num_scales=None, alternate_init=False):
        reset_class_head = False
        reset_box_head = False
        set_config_writeable(self.config)
        if num_classes is not None:
            reset_class_head = True
            self.config.num_classes = num_classes
        if aspect_ratios is not None:
            reset_box_head = True
            self.config.aspect_ratios = aspect_ratios
        if num_scales is not None:
            reset_box_head = True
            self.config.num_scales = num_scales
        set_config_readonly(self.config)

        if reset_class_head:
            self.class_net = HeadNet(self.config, num_outputs=self.config.num_classes)
            for n, m in self.class_net.named_modules(prefix='class_net'):
                if alternate_init:
                    _init_weight_alt(m, n)
                else:
                    _init_weight(m, n)

        if reset_box_head:
            self.box_net = HeadNet(self.config, num_outputs=4)
            for n, m in self.box_net.named_modules(prefix='box_net'):
                if alternate_init:
                    _init_weight_alt(m, n)
                else:
                    _init_weight(m, n)

    @torch.jit.ignore()
    def toggle_head_bn_level_first(self):
        """ Toggle the head batchnorm layers between being accessed with feature_level first vs repeat first
        """
        self.class_net.toggle_bn_level_first()
        self.box_net.toggle_bn_level_first()

    def forward(self, x):
        x = self.backbone(x)
        x = self.fpn(x)
        x_class = self.class_net(x)
        x_box = self.box_net(x)
        return x_class, x_box
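A minimal sketch of how the EfficientDet module above can be constructed and run, assuming the package's `get_efficientdet_config` helper (defined in effdet/config/model_config.py in this commit) is exported from `effdet`; the model key and the six-class head are illustrative only.

import torch
from effdet import EfficientDet, get_efficientdet_config  # exports assumed from this package

config = get_efficientdet_config('tf_efficientdet_d0')    # model key assumed
model = EfficientDet(config, pretrained_backbone=False)
model.reset_head(num_classes=6)                            # hypothetical number of waste classes
model.eval()

dummy = torch.randn(1, 3, 512, 512)                        # D0 works on 512x512 inputs
with torch.no_grad():
    class_out, box_out = model(dummy)                      # per-level class logits and box offsets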
efficientdet/effdet/evaluation/README.md
ADDED
@@ -0,0 +1,7 @@
# Tensorflow Models Evaluation

The code in this folder has been extracted and adapted from evaluation/evaluator code at https://github.com/tensorflow/models/tree/master/research/object_detection/utils

Original code is licensed Apache 2.0, Copyright Google Inc.
https://github.com/tensorflow/models/blob/master/LICENSE
efficientdet/effdet/evaluation/__init__.py
ADDED
File without changes
efficientdet/effdet/evaluation/detection_evaluator.py
ADDED
@@ -0,0 +1,590 @@
from abc import ABCMeta
from abc import abstractmethod
#import collections
import logging
import unicodedata
import numpy as np

from .fields import InputDataFields, DetectionResultFields
from .object_detection_evaluation import ObjectDetectionEvaluation


def create_category_index(categories):
    """Creates dictionary of COCO compatible categories keyed by category id.
    Args:
        categories: a list of dicts, each of which has the following keys:
            'id': (required) an integer id uniquely identifying this category.
            'name': (required) string representing category name e.g., 'cat', 'dog', 'pizza'.
    Returns:
        category_index: a dict containing the same entries as categories, but keyed
            by the 'id' field of each category.
    """
    category_index = {}
    for cat in categories:
        category_index[cat['id']] = cat
    return category_index


class DetectionEvaluator(metaclass=ABCMeta):
    """Interface for object detection evaluation classes.
    Example usage of the Evaluator:
    ------------------------------
    evaluator = DetectionEvaluator(categories)
    # Detections and groundtruth for image 1.
    evaluator.add_single_ground_truth_image_info(...)
    evaluator.add_single_detected_image_info(...)
    # Detections and groundtruth for image 2.
    evaluator.add_single_ground_truth_image_info(...)
    evaluator.add_single_detected_image_info(...)
    metrics_dict = evaluator.evaluate()
    """

    def __init__(self, categories):
        """Constructor.
        Args:
            categories: A list of dicts, each of which has the following keys -
                'id': (required) an integer id uniquely identifying this category.
                'name': (required) string representing category name e.g., 'cat', 'dog'.
        """
        self._categories = categories

    def observe_result_dict_for_single_example(self, eval_dict):
        """Observes an evaluation result dict for a single example.
        When executing eagerly, once all observations have been observed by this
        method you can use `.evaluate()` to get the final metrics.
        When using `tf.estimator.Estimator` for evaluation this function is used by
        `get_estimator_eval_metric_ops()` to construct the metric update op.
        Args:
            eval_dict: A dictionary that holds tensors for evaluating an object
                detection model, returned from
                eval_util.result_dict_for_single_example().
        Returns:
            None when executing eagerly, or an update_op that can be used to update
            the eval metrics in `tf.estimator.EstimatorSpec`.
        """
        raise NotImplementedError('Not implemented for this evaluator!')

    @abstractmethod
    def add_single_ground_truth_image_info(self, image_id, gt_dict):
        """Adds groundtruth for a single image to be used for evaluation.
        Args:
            image_id: A unique string/integer identifier for the image.
            gt_dict: A dictionary of groundtruth numpy arrays required for evaluations.
        """
        pass

    @abstractmethod
    def add_single_detected_image_info(self, image_id, detections_dict):
        """Adds detections for a single image to be used for evaluation.
        Args:
            image_id: A unique string/integer identifier for the image.
            detections_dict: A dictionary of detection numpy arrays required for evaluation.
        """
        pass

    @abstractmethod
    def evaluate(self):
        """Evaluates detections and returns a dictionary of metrics."""
        pass

    @abstractmethod
    def clear(self):
        """Clears the state to prepare for a fresh evaluation."""
        pass


class ObjectDetectionEvaluator(DetectionEvaluator):
    """A class to evaluate detections."""

    def __init__(self,
                 categories,
                 matching_iou_threshold=0.5,
                 recall_lower_bound=0.0,
                 recall_upper_bound=1.0,
                 evaluate_corlocs=False,
                 evaluate_precision_recall=False,
                 metric_prefix=None,
                 use_weighted_mean_ap=False,
                 evaluate_masks=False,
                 group_of_weight=0.0):
        """Constructor.
        Args:
            categories: A list of dicts, each of which has the following keys -
                'id': (required) an integer id uniquely identifying this category.
                'name': (required) string representing category name e.g., 'cat', 'dog'.
            matching_iou_threshold: IOU threshold to use for matching groundtruth boxes to detection boxes.
            recall_lower_bound: lower bound of recall operating area.
            recall_upper_bound: upper bound of recall operating area.
            evaluate_corlocs: (optional) boolean which determines if corloc scores are to be returned or not.
            evaluate_precision_recall: (optional) boolean which determines if
                precision and recall values are to be returned or not.
            metric_prefix: (optional) string prefix for metric name; if None, no prefix is used.
            use_weighted_mean_ap: (optional) boolean which determines if the mean
                average precision is computed directly from the scores and tp_fp_labels of all classes.
            evaluate_masks: If False, evaluation will be performed based on boxes. If
                True, mask evaluation will be performed instead.
            group_of_weight: Weight of group-of boxes. If set to 0, detections of the
                correct class within a group-of box are ignored. If weight is > 0, then
                if at least one detection falls within a group-of box with
                matching_iou_threshold, weight group_of_weight is added to true
                positives. Consequently, if no detection falls within a group-of box,
                weight group_of_weight is added to false negatives.
        Raises:
            ValueError: If the category ids are not 1-indexed.
        """
        super(ObjectDetectionEvaluator, self).__init__(categories)
        self._num_classes = max([cat['id'] for cat in categories])
        if min(cat['id'] for cat in categories) < 1:
            raise ValueError('Classes should be 1-indexed.')
        self._matching_iou_threshold = matching_iou_threshold
        self._recall_lower_bound = recall_lower_bound
        self._recall_upper_bound = recall_upper_bound
        self._use_weighted_mean_ap = use_weighted_mean_ap
        self._label_id_offset = 1
        self._evaluate_masks = evaluate_masks
        self._group_of_weight = group_of_weight
        self._evaluation = ObjectDetectionEvaluation(
            num_gt_classes=self._num_classes,
            matching_iou_threshold=self._matching_iou_threshold,
            recall_lower_bound=self._recall_lower_bound,
            recall_upper_bound=self._recall_upper_bound,
            use_weighted_mean_ap=self._use_weighted_mean_ap,
            label_id_offset=self._label_id_offset,
            group_of_weight=self._group_of_weight)
        self._image_ids = set([])
        self._evaluate_corlocs = evaluate_corlocs
        self._evaluate_precision_recall = evaluate_precision_recall
        self._metric_prefix = (metric_prefix + '_') if metric_prefix else ''
        self._build_metric_names()

    def _build_metric_names(self):
        """Builds a list with metric names."""
        if self._recall_lower_bound > 0.0 or self._recall_upper_bound < 1.0:
            self._metric_names = [
                self._metric_prefix + 'Precision/mAP@{}IOU@[{:.1f},{:.1f}]Recall'.format(
                    self._matching_iou_threshold, self._recall_lower_bound, self._recall_upper_bound)
            ]
        else:
            self._metric_names = [
                self._metric_prefix + 'Precision/mAP@{}IOU'.format(self._matching_iou_threshold)
            ]
        if self._evaluate_corlocs:
            self._metric_names.append(
                self._metric_prefix + 'Precision/meanCorLoc@{}IOU'.format(self._matching_iou_threshold))

        category_index = create_category_index(self._categories)
        for idx in range(self._num_classes):
            if idx + self._label_id_offset in category_index:
                category_name = category_index[idx + self._label_id_offset]['name']
                category_name = unicodedata.normalize('NFKD', category_name)
                self._metric_names.append(
                    self._metric_prefix + 'PerformanceByCategory/AP@{}IOU/{}'.format(
                        self._matching_iou_threshold, category_name))
                if self._evaluate_corlocs:
                    self._metric_names.append(
                        self._metric_prefix + 'PerformanceByCategory/CorLoc@{}IOU/{}'.format(
                            self._matching_iou_threshold, category_name))

    def add_single_ground_truth_image_info(self, image_id, gt_dict):
        """Adds groundtruth for a single image to be used for evaluation.
        Args:
            image_id: A unique string/integer identifier for the image.
            gt_dict: A dictionary containing -
                InputDataFields.gt_boxes: float32 numpy array
                    of shape [num_boxes, 4] containing `num_boxes` groundtruth boxes of
                    the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
                InputDataFields.gt_classes: integer numpy array
                    of shape [num_boxes] containing 1-indexed groundtruth classes for the boxes.
                InputDataFields.gt_difficult: Optional length M numpy boolean array
                    denoting whether a ground truth box is a difficult instance or not.
                    This field is optional to support the case that no boxes are difficult.
                InputDataFields.gt_instance_masks: Optional numpy array of shape
                    [num_boxes, height, width] with values in {0, 1}.
        Raises:
            ValueError: On adding groundtruth for an image more than once. Will also
                raise error if instance masks are not in groundtruth dictionary.
        """
        if image_id in self._image_ids:
            return

        gt_classes = gt_dict[InputDataFields.gt_classes] - self._label_id_offset
        # If the key is not present in the gt_dict or the array is empty
        # (unless there are no annotations for the groundtruth on this image)
        # use values from the dictionary or insert None otherwise.
        if (InputDataFields.gt_difficult in gt_dict and
                (gt_dict[InputDataFields.gt_difficult].size or not gt_classes.size)):
            gt_difficult = gt_dict[InputDataFields.gt_difficult]
        else:
            gt_difficult = None
            # FIXME disable difficult flag warning, will support flag eventually
            # if not len(self._image_ids) % 1000:
            #     logging.warning('image %s does not have groundtruth difficult flag specified', image_id)
        gt_masks = None
        if self._evaluate_masks:
            if InputDataFields.gt_instance_masks not in gt_dict:
                raise ValueError('Instance masks not in groundtruth dictionary.')
            gt_masks = gt_dict[InputDataFields.gt_instance_masks]
        self._evaluation.add_single_ground_truth_image_info(
            image_key=image_id,
            gt_boxes=gt_dict[InputDataFields.gt_boxes],
            gt_class_labels=gt_classes,
            gt_is_difficult_list=gt_difficult,
            gt_masks=gt_masks)
        self._image_ids.update([image_id])

    def add_single_detected_image_info(self, image_id, detections_dict):
        """Adds detections for a single image to be used for evaluation.
        Args:
            image_id: A unique string/integer identifier for the image.
            detections_dict: A dictionary containing -
                DetectionResultFields.detection_boxes: float32 numpy
                    array of shape [num_boxes, 4] containing `num_boxes` detection boxes
                    of the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
                DetectionResultFields.detection_scores: float32 numpy
                    array of shape [num_boxes] containing detection scores for the boxes.
                DetectionResultFields.detection_classes: integer numpy
                    array of shape [num_boxes] containing 1-indexed detection classes for the boxes.
                DetectionResultFields.detection_masks: uint8 numpy array
                    of shape [num_boxes, height, width] containing `num_boxes` masks of
                    values ranging between 0 and 1.
        Raises:
            ValueError: If detection masks are not in detections dictionary.
        """
        detection_classes = detections_dict[DetectionResultFields.detection_classes] - self._label_id_offset
        detection_masks = None
        if self._evaluate_masks:
            if DetectionResultFields.detection_masks not in detections_dict:
                raise ValueError('Detection masks not in detections dictionary.')
            detection_masks = detections_dict[DetectionResultFields.detection_masks]
        self._evaluation.add_single_detected_image_info(
            image_key=image_id,
            detected_boxes=detections_dict[DetectionResultFields.detection_boxes],
            detected_scores=detections_dict[DetectionResultFields.detection_scores],
            detected_class_labels=detection_classes,
            detected_masks=detection_masks)

    def evaluate(self):
        """Compute evaluation result.
        Returns:
            A dictionary of metrics with the following fields -
            1. summary_metrics:
                '<prefix if not empty>_Precision/mAP@<matching_iou_threshold>IOU': mean
                average precision at the specified IOU threshold.
            2. per_category_ap: category specific results with keys of the form
                '<prefix if not empty>_PerformanceByCategory/
                mAP@<matching_iou_threshold>IOU/category'.
        """
        metrics = self._evaluation.evaluate()
        pascal_metrics = {self._metric_names[0]: metrics['mean_ap']}
        if self._evaluate_corlocs:
            pascal_metrics[self._metric_names[1]] = metrics['mean_corloc']
        category_index = create_category_index(self._categories)
        for idx in range(metrics['per_class_ap'].size):
            if idx + self._label_id_offset in category_index:
                category_name = category_index[idx + self._label_id_offset]['name']
                category_name = unicodedata.normalize('NFKD', category_name)
                display_name = self._metric_prefix + 'PerformanceByCategory/AP@{}IOU/{}'.format(
                    self._matching_iou_threshold, category_name)
                pascal_metrics[display_name] = metrics['per_class_ap'][idx]

                # Optionally add precision and recall values
                if self._evaluate_precision_recall:
                    display_name = self._metric_prefix + 'PerformanceByCategory/Precision@{}IOU/{}'.format(
                        self._matching_iou_threshold, category_name)
                    pascal_metrics[display_name] = metrics['per_class_precision'][idx]
                    display_name = self._metric_prefix + 'PerformanceByCategory/Recall@{}IOU/{}'.format(
                        self._matching_iou_threshold, category_name)
                    pascal_metrics[display_name] = metrics['per_class_precision'][idx]

                # Optionally add CorLoc metrics.classes
                if self._evaluate_corlocs:
                    display_name = self._metric_prefix + 'PerformanceByCategory/CorLoc@{}IOU/{}'.format(
                        self._matching_iou_threshold, category_name)
                    pascal_metrics[display_name] = metrics['per_class_corloc'][idx]

        return pascal_metrics

    def clear(self):
        """Clears the state to prepare for a fresh evaluation."""
        self._evaluation = ObjectDetectionEvaluation(
            num_gt_classes=self._num_classes,
            matching_iou_threshold=self._matching_iou_threshold,
            use_weighted_mean_ap=self._use_weighted_mean_ap,
            label_id_offset=self._label_id_offset)
        self._image_ids.clear()


class PascalDetectionEvaluator(ObjectDetectionEvaluator):
    """A class to evaluate detections using PASCAL metrics."""

    def __init__(self, categories, matching_iou_threshold=0.5):
        super(PascalDetectionEvaluator, self).__init__(
            categories,
            matching_iou_threshold=matching_iou_threshold,
            evaluate_corlocs=False,
            metric_prefix='PascalBoxes',
            use_weighted_mean_ap=False)


class WeightedPascalDetectionEvaluator(ObjectDetectionEvaluator):
    """A class to evaluate detections using weighted PASCAL metrics.
    Weighted PASCAL metrics computes the mean average precision as the average
    precision given the scores and tp_fp_labels of all classes. In comparison,
    PASCAL metrics computes the mean average precision as the mean of the
    per-class average precisions.
    This definition is very similar to the mean of the per-class average
    precisions weighted by class frequency. However, they are typically not the
    same as the average precision is not a linear function of the scores and
    tp_fp_labels.
    """

    def __init__(self, categories, matching_iou_threshold=0.5):
        super(WeightedPascalDetectionEvaluator, self).__init__(
            categories,
            matching_iou_threshold=matching_iou_threshold,
            evaluate_corlocs=False,
            metric_prefix='WeightedPascalBoxes',
            use_weighted_mean_ap=True)


class PrecisionAtRecallDetectionEvaluator(ObjectDetectionEvaluator):
    """A class to evaluate detections using precision@recall metrics."""

    def __init__(self,
                 categories,
                 matching_iou_threshold=0.5,
                 recall_lower_bound=0.,
                 recall_upper_bound=1.0):
        super(PrecisionAtRecallDetectionEvaluator, self).__init__(
            categories,
            matching_iou_threshold=matching_iou_threshold,
            recall_lower_bound=recall_lower_bound,
            recall_upper_bound=recall_upper_bound,
            evaluate_corlocs=False,
            metric_prefix='PrecisionAtRecallBoxes',
            use_weighted_mean_ap=False)


class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):
    """A class to evaluate detections using Open Images V2 metrics.
    Open Images V2 introduce group_of type of bounding boxes and this metric
    handles those boxes appropriately.
    """

    def __init__(self,
                 categories,
                 matching_iou_threshold=0.5,
                 evaluate_masks=False,
                 evaluate_corlocs=False,
                 metric_prefix='OpenImagesV5',
                 group_of_weight=0.0):
        """Constructor.
        Args:
            categories: A list of dicts, each of which has the following keys -
                'id': (required) an integer id uniquely identifying this category.
                'name': (required) string representing category name e.g., 'cat', 'dog'.
            matching_iou_threshold: IOU threshold to use for matching groundtruth
                boxes to detection boxes.
            evaluate_masks: if True, evaluator evaluates masks.
            evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
            metric_prefix: Prefix name of the metric.
            group_of_weight: Weight of the group-of bounding box. If set to 0 (default
                for Open Images V2 detection protocol), detections of the correct class
                within a group-of box are ignored. If weight is > 0, then if at least
                one detection falls within a group-of box with matching_iou_threshold,
                weight group_of_weight is added to true positives. Consequently, if no
                detection falls within a group-of box, weight group_of_weight is added
                to false negatives.
        """

        super(OpenImagesDetectionEvaluator, self).__init__(
            categories,
            matching_iou_threshold,
            evaluate_corlocs,
            metric_prefix=metric_prefix,
            group_of_weight=group_of_weight,
            evaluate_masks=evaluate_masks)

    def add_single_ground_truth_image_info(self, image_id, gt_dict):
        """Adds groundtruth for a single image to be used for evaluation.
        Args:
            image_id: A unique string/integer identifier for the image.
            gt_dict: A dictionary containing -
                InputDataFields.gt_boxes: float32 numpy array
                    of shape [num_boxes, 4] containing `num_boxes` groundtruth boxes of
                    the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
                InputDataFields.gt_classes: integer numpy array
                    of shape [num_boxes] containing 1-indexed groundtruth classes for the boxes.
                InputDataFields.gt_group_of: Optional length M
                    numpy boolean array denoting whether a groundtruth box contains a group of instances.
        Raises:
            ValueError: On adding groundtruth for an image more than once.
        """
        if image_id in self._image_ids:
            return

        gt_classes = (gt_dict[InputDataFields.gt_classes] - self._label_id_offset)
        # If the key is not present in the gt_dict or the array is empty
        # (unless there are no annotations for the groundtruth on this image)
        # use values from the dictionary or insert None otherwise.
        if (InputDataFields.gt_group_of in gt_dict and
                (gt_dict[InputDataFields.gt_group_of].size or not gt_classes.size)):
            gt_group_of = gt_dict[InputDataFields.gt_group_of]
        else:
            gt_group_of = None
            # FIXME disable warning for now, will add group_of flag eventually
            # if not len(self._image_ids) % 1000:
            #     logging.warning('image %s does not have groundtruth group_of flag specified', image_id)
        if self._evaluate_masks:
            gt_masks = gt_dict[InputDataFields.gt_instance_masks]
        else:
            gt_masks = None

        self._evaluation.add_single_ground_truth_image_info(
            image_id,
            gt_dict[InputDataFields.gt_boxes],
            gt_classes,
            gt_is_difficult_list=None,
            gt_is_group_of_list=gt_group_of,
            gt_masks=gt_masks)
        self._image_ids.update([image_id])


class OpenImagesChallengeEvaluator(OpenImagesDetectionEvaluator):
    """A class that implements Open Images Challenge metrics.
    Both Detection and Instance Segmentation evaluation metrics are implemented.
    Open Images Challenge Detection metric has two major changes in comparison
    with Open Images V2 detection metric:
    - a custom weight might be specified for detecting an object contained in a group-of box.
    - verified image-level labels should be explicitly provided for evaluation: in case an
      image has neither positive nor negative image level label of class c, all detections of
      this class on this image will be ignored.

    Open Images Challenge Instance Segmentation metric allows to measure performance
    of models in case of incomplete annotations: some instances are
    annotated only on box level and some - on image-level. In addition,
    image-level labels are taken into account as in detection metric.

    Open Images Challenge Detection metric default parameters:
    evaluate_masks = False
    group_of_weight = 1.0

    Open Images Challenge Instance Segmentation metric default parameters:
    evaluate_masks = True
    (group_of_weight will not matter)
    """

    def __init__(
            self,
            categories,
            evaluate_masks=False,
            matching_iou_threshold=0.5,
            evaluate_corlocs=False,
            group_of_weight=1.0):
        """Constructor.
        Args:
            categories: A list of dicts, each of which has the following keys -
                'id': (required) an integer id uniquely identifying this category.
                'name': (required) string representing category name e.g., 'cat', 'dog'.
            evaluate_masks: set to true for instance segmentation metric and to false
                for detection metric.
            matching_iou_threshold: IOU threshold to use for matching groundtruth
                boxes to detection boxes.
            evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
            group_of_weight: Weight of group-of boxes. If set to 0, detections of the
                correct class within a group-of box are ignored. If weight is > 0, then
                if at least one detection falls within a group-of box with
                matching_iou_threshold, weight group_of_weight is added to true
                positives. Consequently, if no detection falls within a group-of box,
                weight group_of_weight is added to false negatives.
        """
        if not evaluate_masks:
            metrics_prefix = 'OpenImagesDetectionChallenge'
        else:
            metrics_prefix = 'OpenImagesInstanceSegmentationChallenge'

        super(OpenImagesChallengeEvaluator, self).__init__(
            categories,
            matching_iou_threshold,
            evaluate_masks=evaluate_masks,
            evaluate_corlocs=evaluate_corlocs,
            group_of_weight=group_of_weight,
            metric_prefix=metrics_prefix)

        self._evaluatable_labels = {}

    def add_single_ground_truth_image_info(self, image_id, gt_dict):
        """Adds groundtruth for a single image to be used for evaluation.
        Args:
            image_id: A unique string/integer identifier for the image.
            gt_dict: A dictionary containing -
                InputDataFields.gt_boxes: float32 numpy array of shape [num_boxes, 4]
                    containing `num_boxes` groundtruth boxes of the format [ymin, xmin, ymax, xmax]
                    in absolute image coordinates.
                InputDataFields.gt_classes: integer numpy array of shape [num_boxes]
                    containing 1-indexed groundtruth classes for the boxes.
                InputDataFields.gt_image_classes: integer 1D
                    numpy array containing all classes for which labels are verified.
                InputDataFields.gt_group_of: Optional length M
                    numpy boolean array denoting whether a groundtruth box contains a group of instances.
        Raises:
            ValueError: On adding groundtruth for an image more than once.
        """
        super(OpenImagesChallengeEvaluator,
              self).add_single_ground_truth_image_info(image_id, gt_dict)
        input_fields = InputDataFields
        gt_classes = gt_dict[input_fields.gt_classes] - self._label_id_offset
        image_classes = np.array([], dtype=int)
        if input_fields.gt_image_classes in gt_dict:
            image_classes = gt_dict[input_fields.gt_image_classes]
        elif input_fields.gt_labeled_classes in gt_dict:
            image_classes = gt_dict[input_fields.gt_labeled_classes]
        image_classes -= self._label_id_offset
        self._evaluatable_labels[image_id] = np.unique(
            np.concatenate((image_classes, gt_classes)))

    def add_single_detected_image_info(self, image_id, detections_dict):
        """Adds detections for a single image to be used for evaluation.
        Args:
            image_id: A unique string/integer identifier for the image.
            detections_dict: A dictionary containing -
                DetectionResultFields.detection_boxes: float32 numpy
                    array of shape [num_boxes, 4] containing `num_boxes` detection boxes
                    of the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
                DetectionResultFields.detection_scores: float32 numpy
                    array of shape [num_boxes] containing detection scores for the boxes.
                DetectionResultFields.detection_classes: integer numpy
                    array of shape [num_boxes] containing 1-indexed detection classes for
                    the boxes.
        Raises:
            ValueError: If detection masks are not in detections dictionary.
        """
        if image_id not in self._image_ids:
            # Since for the correct work of evaluator it is assumed that groundtruth
            # is inserted first we make sure to break the code if it is not the case.
            self._image_ids.update([image_id])
            self._evaluatable_labels[image_id] = np.array([])

        detection_classes = detections_dict[DetectionResultFields.detection_classes] - self._label_id_offset
        allowed_classes = np.where(np.isin(detection_classes, self._evaluatable_labels[image_id]))
        detection_classes = detection_classes[allowed_classes]
        detected_boxes = detections_dict[DetectionResultFields.detection_boxes][allowed_classes]
        detected_scores = detections_dict[DetectionResultFields.detection_scores][allowed_classes]

        if self._evaluate_masks:
            detection_masks = detections_dict[DetectionResultFields.detection_masks][allowed_classes]
        else:
            detection_masks = None
        self._evaluation.add_single_detected_image_info(
            image_key=image_id,
            detected_boxes=detected_boxes,
            detected_scores=detected_scores,
            detected_class_labels=detection_classes,
            detected_masks=detection_masks)

    def clear(self):
        """Clears stored data."""

        super(OpenImagesChallengeEvaluator, self).clear()
        self._evaluatable_labels.clear()
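A minimal sketch of feeding one image through the PascalDetectionEvaluator defined above; the import paths assume the vendored package root is importable as `effdet`, and the category list, boxes and scores are made up for illustration.

import numpy as np
from effdet.evaluation.fields import InputDataFields, DetectionResultFields
from effdet.evaluation.detection_evaluator import PascalDetectionEvaluator

categories = [{'id': 1, 'name': 'bottle'}, {'id': 2, 'name': 'can'}]   # illustrative category list
evaluator = PascalDetectionEvaluator(categories, matching_iou_threshold=0.5)

# Groundtruth should be registered before the detections for the same image id.
evaluator.add_single_ground_truth_image_info('img_0', {
    InputDataFields.gt_boxes: np.array([[10., 10., 100., 100.]], dtype=np.float32),
    InputDataFields.gt_classes: np.array([1]),
})
evaluator.add_single_detected_image_info('img_0', {
    DetectionResultFields.detection_boxes: np.array([[12., 11., 98., 99.]], dtype=np.float32),
    DetectionResultFields.detection_scores: np.array([0.9], dtype=np.float32),
    DetectionResultFields.detection_classes: np.array([1]),
})
metrics = evaluator.evaluate()   # e.g. {'PascalBoxes_Precision/mAP@0.5IOU': ..., per-category APs}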
efficientdet/effdet/evaluation/fields.py
ADDED
@@ -0,0 +1,105 @@

class InputDataFields(object):
    """Names for the input tensors.
    Holds the standard data field names to use for identifying input tensors. This
    should be used by the decoder to identify keys for the returned tensor_dict
    containing input tensors. And it should be used by the model to identify the
    tensors it needs.
    Attributes:
        image: image.
        image_additional_channels: additional channels.
        key: unique key corresponding to image.
        filename: original filename of the dataset (without common path).
        gt_image_classes: image-level class labels.
        gt_image_confidences: image-level class confidences.
        gt_labeled_classes: image-level annotation that indicates the
            classes for which an image has been labeled.
        gt_boxes: coordinates of the ground truth boxes in the image.
        gt_classes: box-level class labels.
        gt_confidences: box-level class confidences. The shape should be
            the same as the shape of gt_classes.
        gt_label_types: box-level label types (e.g. explicit negative).
        gt_is_crowd: [DEPRECATED, use gt_group_of instead]
            is the groundtruth a single object or a crowd.
        gt_area: area of a groundtruth segment.
        gt_difficult: is a `difficult` object
        gt_group_of: is a `group_of` objects, e.g. multiple objects of the
            same class, forming a connected group, where instances are heavily
            occluding each other.
        gt_instance_masks: ground truth instance masks.
        gt_instance_boundaries: ground truth instance boundaries.
        gt_instance_classes: instance mask-level class labels.
        gt_label_weights: groundtruth label weights.
        gt_weights: groundtruth weight factor for bounding boxes.
        image_height: height of images, used to decode
        image_width: width of images, used to decode
    """
    image = 'image'
    key = 'image_id'
    filename = 'filename'
    gt_boxes = 'bbox'
    gt_classes = 'cls'
    gt_confidences = 'confidences'
    gt_label_types = 'label_types'
    gt_image_classes = 'img_cls'
    gt_image_confidences = 'img_confidences'
    gt_labeled_classes = 'labeled_cls'
    gt_is_crowd = 'is_crowd'
    gt_area = 'area'
    gt_difficult = 'difficult'
    gt_group_of = 'group_of'
    gt_instance_masks = 'instance_masks'
    gt_instance_boundaries = 'instance_boundaries'
    gt_instance_classes = 'instance_classes'
    image_height = 'img_height'
    image_width = 'img_width'
    image_size = 'img_size'


class DetectionResultFields(object):
    """Naming conventions for storing the output of the detector.
    Attributes:
        source_id: source of the original image.
        key: unique key corresponding to image.
        detection_boxes: coordinates of the detection boxes in the image.
        detection_scores: detection scores for the detection boxes in the image.
        detection_multiclass_scores: class score distribution (including background)
            for detection boxes in the image including background class.
        detection_classes: detection-level class labels.
        detection_masks: contains a segmentation mask for each detection box.
    """

    key = 'image_id'
    detection_boxes = 'bbox'
    detection_scores = 'score'
    detection_classes = 'cls'
    detection_masks = 'masks'


class BoxListFields(object):
    """Naming conventions for BoxLists.
    Attributes:
        boxes: bounding box coordinates.
        classes: classes per bounding box.
        scores: scores per bounding box.
        weights: sample weights per bounding box.
        objectness: objectness score per bounding box.
        masks: masks per bounding box.
        boundaries: boundaries per bounding box.
        keypoints: keypoints per bounding box.
        keypoint_heatmaps: keypoint heatmaps per bounding box.
        is_crowd: is_crowd annotation per bounding box.
    """
    boxes = 'boxes'
    classes = 'classes'
    scores = 'scores'
    weights = 'weights'
    confidences = 'confidences'
    objectness = 'objectness'
    masks = 'masks'
    boundaries = 'boundaries'
    keypoints = 'keypoints'
    keypoint_visibilities = 'keypoint_visibilities'
    keypoint_heatmaps = 'keypoint_heatmaps'
    is_crowd = 'is_crowd'
    group_of = 'group_of'
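A short sketch of how these field classes are meant to be used: they are plain namespaces mapping descriptive attribute names to the short dictionary keys passed between the dataset, model and evaluator code, so call sites do not hard-code the raw strings. The import path is assumed.

import numpy as np
from effdet.evaluation.fields import InputDataFields, DetectionResultFields  # import path assumed

# Attributes resolve to the short keys above ('bbox', 'cls', 'score', ...).
assert InputDataFields.gt_boxes == 'bbox'
assert DetectionResultFields.detection_scores == 'score'

gt_dict = {
    InputDataFields.gt_boxes: np.zeros((0, 4), dtype=np.float32),   # image with no annotations
    InputDataFields.gt_classes: np.zeros((0,), dtype=np.int64),
}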
efficientdet/effdet/evaluation/metrics.py
ADDED
@@ -0,0 +1,148 @@
import numpy as np


def compute_precision_recall(scores, labels, num_gt):
    """Compute precision and recall.
    Args:
        scores: A float numpy array representing detection score
        labels: A float numpy array representing weighted true/false positive labels
        num_gt: Number of ground truth instances
    Raises:
        ValueError: if the input is not of the correct format
    Returns:
        precision: Fraction of positive instances over detected ones. This value is
            None if no ground truth labels are present.
        recall: Fraction of detected positive instances over all positive instances.
            This value is None if no ground truth labels are present.
    """
    if not isinstance(labels, np.ndarray) or len(labels.shape) != 1:
        raise ValueError("labels must be single dimension numpy array")

    if labels.dtype != np.float and labels.dtype != np.bool:
        raise ValueError("labels type must be either bool or float")

    if not isinstance(scores, np.ndarray) or len(scores.shape) != 1:
        raise ValueError("scores must be single dimension numpy array")

    if num_gt < np.sum(labels):
        raise ValueError("Number of true positives must be smaller than num_gt.")

    if len(scores) != len(labels):
        raise ValueError("scores and labels must be of the same size.")

    if num_gt == 0:
        return None, None

    sorted_indices = np.argsort(scores)
    sorted_indices = sorted_indices[::-1]
    true_positive_labels = labels[sorted_indices]
    false_positive_labels = (true_positive_labels <= 0).astype(float)
    cum_true_positives = np.cumsum(true_positive_labels)
    cum_false_positives = np.cumsum(false_positive_labels)
    precision = cum_true_positives.astype(float) / (cum_true_positives + cum_false_positives)
    recall = cum_true_positives.astype(float) / num_gt
    return precision, recall


def compute_average_precision(precision, recall):
    """Compute Average Precision according to the definition in VOCdevkit.
    Precision is modified to ensure that it does not decrease as recall
    decreases.
    Args:
        precision: A float [N, 1] numpy array of precisions
        recall: A float [N, 1] numpy array of recalls
    Raises:
        ValueError: if the input is not of the correct format
    Returns:
        average_precision: The area under the precision recall curve. NaN if
            precision and recall are None.
    """
    if precision is None:
        if recall is not None:
            raise ValueError("If precision is None, recall must also be None")
        return np.NAN

    if not isinstance(precision, np.ndarray) or not isinstance(recall, np.ndarray):
        raise ValueError("precision and recall must be numpy array")
    if precision.dtype != np.float or recall.dtype != np.float:
        raise ValueError("input must be float numpy array.")
    if len(precision) != len(recall):
        raise ValueError("precision and recall must be of the same size.")
    if not precision.size:
        return 0.0
    if np.amin(precision) < 0 or np.amax(precision) > 1:
        raise ValueError("Precision must be in the range of [0, 1].")
    if np.amin(recall) < 0 or np.amax(recall) > 1:
        raise ValueError("recall must be in the range of [0, 1].")
    if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)):
        raise ValueError("recall must be a non-decreasing array")

    recall = np.concatenate([[0], recall, [1]])
    precision = np.concatenate([[0], precision, [0]])

    # Preprocess precision to be a non-decreasing array
    for i in range(len(precision) - 2, -1, -1):
        precision[i] = np.maximum(precision[i], precision[i + 1])

    indices = np.where(recall[1:] != recall[:-1])[0] + 1
    average_precision = np.sum((recall[indices] - recall[indices - 1]) * precision[indices])
    return average_precision


def compute_cor_loc(num_gt_imgs_per_class, num_images_correctly_detected_per_class):
    """Compute CorLoc according to the definition in the following paper.
    https://www.robots.ox.ac.uk/~vgg/rg/papers/deselaers-eccv10.pdf
    Returns nans if there are no ground truth images for a class.
    Args:
        num_gt_imgs_per_class: 1D array, representing number of images containing
            at least one object instance of a particular class
        num_images_correctly_detected_per_class: 1D array, representing number of
            images in which at least one object instance of a particular class was correctly detected
    Returns:
        corloc_per_class: A float numpy array representing the corloc score of each class
    """
    return np.where(
        num_gt_imgs_per_class == 0, np.nan,
        num_images_correctly_detected_per_class / num_gt_imgs_per_class)


def compute_median_rank_at_k(tp_fp_list, k):
    """Computes MedianRank@k, where k is the top-scoring labels.
    Args:
        tp_fp_list: a list of numpy arrays; each numpy array corresponds to all
            detections on a single image, where the detections are sorted by score in
            descending order. Further, each numpy array element can have boolean or
            float values. True positive elements have either value >0.0 or True;
            any other value is considered false positive.
        k: number of top-scoring proposals to take.
    Returns:
        median_rank: median rank of all true positive proposals among top k by score.
    """
    ranks = []
    for i in range(len(tp_fp_list)):
        ranks.append(np.where(tp_fp_list[i][0:min(k, tp_fp_list[i].shape[0])] > 0)[0])
    concatenated_ranks = np.concatenate(ranks)
    return np.median(concatenated_ranks)


def compute_recall_at_k(tp_fp_list, num_gt, k):
    """Computes Recall@k, MedianRank@k, where k is the top-scoring labels.
    Args:
        tp_fp_list: a list of numpy arrays; each numpy array corresponds to all
            detections on a single image, where the detections are sorted by score in
            descending order. Further, each numpy array element can have boolean or
            float values. True positive elements have either value >0.0 or True;
            any other value is considered false positive.
        num_gt: number of groundtruth annotations.
        k: number of top-scoring proposals to take.
    Returns:
        recall: recall evaluated on the top k by score detections.
    """

    tp_fp_eval = []
    for i in range(len(tp_fp_list)):
        tp_fp_eval.append(tp_fp_list[i][0:min(k, tp_fp_list[i].shape[0])])

    tp_fp_eval = np.concatenate(tp_fp_eval)

    return np.sum(tp_fp_eval) / num_gt
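A small worked example of the two precision/recall helpers above, with the import path assumed: three detections against two groundtruth boxes, where the highest- and lowest-scoring detections are true positives.

import numpy as np
from effdet.evaluation.metrics import compute_precision_recall, compute_average_precision  # path assumed

scores = np.array([0.9, 0.8, 0.3], dtype=float)   # detection confidences
labels = np.array([1., 0., 1.], dtype=float)      # 1.0 = true positive, 0.0 = false positive
precision, recall = compute_precision_recall(scores, labels, num_gt=2)
# precision -> [1.0, 0.5, 0.667], recall -> [0.5, 0.5, 1.0]

ap = compute_average_precision(precision, recall)
# area under the interpolated curve: 0.5 * 1.0 + 0.5 * 0.667 = ~0.83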
efficientdet/effdet/evaluation/np_box_list.py
ADDED
@@ -0,0 +1,696 @@
1 |
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Bounding Box List operations for Numpy BoxLists.
|
17 |
+
|
18 |
+
Example box operations that are supported:
|
19 |
+
* Areas: compute bounding box areas
|
20 |
+
* IOU: pairwise intersection-over-union scores
|
21 |
+
"""
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
|
25 |
+
class BoxList(object):
|
26 |
+
"""Box collection.
|
27 |
+
BoxList represents a list of bounding boxes as numpy array, where each
|
28 |
+
bounding box is represented as a row of 4 numbers,
|
29 |
+
[y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a
|
30 |
+
given list correspond to a single image.
|
31 |
+
Optionally, users can add additional related fields (such as
|
32 |
+
objectness/classification scores).
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, data):
|
36 |
+
"""Constructs box collection.
|
37 |
+
Args:
|
38 |
+
data: a numpy array of shape [N, 4] representing box coordinates
|
39 |
+
Raises:
|
40 |
+
ValueError: if bbox data is not a numpy array
|
41 |
+
ValueError: if invalid dimensions for bbox data
|
42 |
+
"""
|
43 |
+
if not isinstance(data, np.ndarray):
|
44 |
+
raise ValueError('data must be a numpy array.')
|
45 |
+
if len(data.shape) != 2 or data.shape[1] != 4:
|
46 |
+
raise ValueError('Invalid dimensions for box data.')
|
47 |
+
if data.dtype != np.float32 and data.dtype != np.float64:
|
48 |
+
raise ValueError('Invalid data type for box data: float is required.')
|
49 |
+
if not self._is_valid_boxes(data):
|
50 |
+
raise ValueError('Invalid box data. data must be a numpy array of '
|
51 |
+
'N*[y_min, x_min, y_max, x_max]')
|
52 |
+
self.data = {'boxes': data}
|
53 |
+
|
54 |
+
def num_boxes(self):
|
55 |
+
"""Return number of boxes held in collections."""
|
56 |
+
return self.data['boxes'].shape[0]
|
57 |
+
|
58 |
+
def get_extra_fields(self):
|
59 |
+
"""Return all non-box fields."""
|
60 |
+
return [k for k in self.data.keys() if k != 'boxes']
|
61 |
+
|
62 |
+
def has_field(self, field):
|
63 |
+
return field in self.data
|
64 |
+
|
65 |
+
def add_field(self, field, field_data):
|
66 |
+
"""Add data to a specified field.
|
67 |
+
Args:
|
68 |
+
field: a string parameter used to specify a related field to be accessed.
|
69 |
+
field_data: a numpy array of [N, ...] representing the data associated
|
70 |
+
with the field.
|
71 |
+
Raises:
|
72 |
+
ValueError: if the field already exists or the dimension of the field
|
73 |
+
data does not match the number of boxes.
|
74 |
+
"""
|
75 |
+
if self.has_field(field):
|
76 |
+
raise ValueError('Field ' + field + ' already exists')
|
77 |
+
if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes():
|
78 |
+
raise ValueError('Invalid dimensions for field data')
|
79 |
+
self.data[field] = field_data
|
80 |
+
|
81 |
+
def get(self):
|
82 |
+
"""Convenience function for accesssing box coordinates.
|
83 |
+
Returns:
|
84 |
+
a numpy array of shape [N, 4] representing box corners
|
85 |
+
"""
|
86 |
+
return self.get_field('boxes')
|
87 |
+
|
88 |
+
def get_field(self, field):
|
89 |
+
"""Accesses data associated with the specified field in the box collection.
|
90 |
+
Args:
|
91 |
+
field: a string parameter used to specify a related field to be accessed.
|
92 |
+
Returns:
|
93 |
+
a numpy 1-d array representing data of an associated field
|
94 |
+
Raises:
|
95 |
+
ValueError: if invalid field
|
96 |
+
"""
|
97 |
+
if not self.has_field(field):
|
98 |
+
raise ValueError('field {} does not exist'.format(field))
|
99 |
+
return self.data[field]
|
100 |
+
|
101 |
+
def get_coordinates(self):
|
102 |
+
"""Get corner coordinates of boxes.
|
103 |
+
Returns:
|
104 |
+
a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max]
|
105 |
+
"""
|
106 |
+
box_coordinates = self.get()
|
107 |
+
y_min = box_coordinates[:, 0]
|
108 |
+
x_min = box_coordinates[:, 1]
|
109 |
+
y_max = box_coordinates[:, 2]
|
110 |
+
x_max = box_coordinates[:, 3]
|
111 |
+
return [y_min, x_min, y_max, x_max]
|
112 |
+
|
113 |
+
def _is_valid_boxes(self, data):
|
114 |
+
"""Check whether data fullfills the format of N*[ymin, xmin, ymax, xmin].
|
115 |
+
Args:
|
116 |
+
data: a numpy array of shape [N, 4] representing box coordinates
|
117 |
+
Returns:
|
118 |
+
a boolean indicating whether all ymax of boxes are equal or greater than
|
119 |
+
ymin, and all xmax of boxes are equal or greater than xmin.
|
120 |
+
"""
|
121 |
+
if data.shape[0] > 0:
|
122 |
+
for i in range(data.shape[0]):
|
123 |
+
if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]:
|
124 |
+
return False
|
125 |
+
return True
|
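A short sketch of how the BoxList container above is used (illustrative values, assuming the module is importable as effdet.evaluation.np_box_list):

    import numpy as np
    from effdet.evaluation.np_box_list import BoxList

    boxes = np.array([[0.1, 0.1, 0.5, 0.5],
                      [0.2, 0.2, 0.9, 0.8]], dtype=np.float32)  # [y_min, x_min, y_max, x_max]
    boxlist = BoxList(boxes)
    boxlist.add_field('scores', np.array([0.9, 0.4], dtype=np.float32))
    print(boxlist.num_boxes())         # 2
    print(boxlist.get_extra_fields())  # ['scores']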
126 |
+
|
127 |
+
|
128 |
+
def area(boxes):
|
129 |
+
"""Computes area of boxes.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
boxes: Numpy array with shape [N, 4] holding N boxes
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
a numpy array with shape [N*1] representing box areas
|
136 |
+
"""
|
137 |
+
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
138 |
+
|
139 |
+
|
140 |
+
def intersection(boxes1, boxes2):
|
141 |
+
"""Compute pairwise intersection areas between boxes.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
boxes1: a numpy array with shape [N, 4] holding N boxes
|
145 |
+
boxes2: a numpy array with shape [M, 4] holding M boxes
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
a numpy array with shape [N*M] representing pairwise intersection area
|
149 |
+
"""
|
150 |
+
[y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1)
|
151 |
+
[y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1)
|
152 |
+
|
153 |
+
all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2))
|
154 |
+
all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2))
|
155 |
+
intersect_heights = np.maximum(np.zeros(all_pairs_max_ymin.shape), all_pairs_min_ymax - all_pairs_max_ymin)
|
156 |
+
all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2))
|
157 |
+
all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2))
|
158 |
+
intersect_widths = np.maximum(np.zeros(all_pairs_max_xmin.shape), all_pairs_min_xmax - all_pairs_max_xmin)
|
159 |
+
return intersect_heights * intersect_widths
|
160 |
+
|
161 |
+
|
162 |
+
def iou(boxes1, boxes2):
|
163 |
+
"""Computes pairwise intersection-over-union between box collections.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
boxes1: a numpy array with shape [N, 4] holding N boxes.
|
167 |
+
boxes2: a numpy array with shape [M, 4] holding M boxes.
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
a numpy array with shape [N, M] representing pairwise iou scores.
|
171 |
+
"""
|
172 |
+
intersect = intersection(boxes1, boxes2)
|
173 |
+
area1 = area(boxes1)
|
174 |
+
area2 = area(boxes2)
|
175 |
+
union = np.expand_dims(area1, axis=1) + np.expand_dims(area2, axis=0) - intersect
|
176 |
+
return intersect / union
|
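A quick sanity check of the pairwise area/intersection/IOU helpers (made-up boxes, not part of the commit):

    import numpy as np
    from effdet.evaluation.np_box_list import area, iou

    a = np.array([[0.0, 0.0, 1.0, 1.0]], dtype=np.float32)
    b = np.array([[0.0, 0.0, 1.0, 1.0],
                  [0.0, 0.5, 1.0, 1.5]], dtype=np.float32)
    print(area(a))    # [1.]
    print(iou(a, b))  # [[1.0, ~0.333]]: full overlap vs. a half-overlapping box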
177 |
+
|
178 |
+
|
179 |
+
def ioa(boxes1, boxes2):
|
180 |
+
"""Computes pairwise intersection-over-area between box collections.
|
181 |
+
|
182 |
+
Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
|
183 |
+
their intersection area over box2's area. Note that ioa is not symmetric,
|
184 |
+
that is, IOA(box1, box2) != IOA(box2, box1).
|
185 |
+
|
186 |
+
Args:
|
187 |
+
boxes1: a numpy array with shape [N, 4] holding N boxes.
|
188 |
+
boxes2: a numpy array with shape [M, 4] holding M boxes.
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
a numpy array with shape [N, M] representing pairwise ioa scores.
|
192 |
+
"""
|
193 |
+
intersect = intersection(boxes1, boxes2)
|
194 |
+
areas = np.expand_dims(area(boxes2), axis=0)
|
195 |
+
return intersect / areas
|
196 |
+
|
197 |
+
|
198 |
+
class SortOrder(object):
|
199 |
+
"""Enum class for sort order.
|
200 |
+
|
201 |
+
Attributes:
|
202 |
+
ascend: ascend order.
|
203 |
+
descend: descend order.
|
204 |
+
"""
|
205 |
+
ASCEND = 1
|
206 |
+
DESCEND = 2
|
207 |
+
|
208 |
+
|
209 |
+
def area_boxlist(boxlist):
|
210 |
+
"""Computes area of boxes.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
boxlist: BoxList holding N boxes
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
a numpy array with shape [N*1] representing box areas
|
217 |
+
"""
|
218 |
+
y_min, x_min, y_max, x_max = boxlist.get_coordinates()
|
219 |
+
return (y_max - y_min) * (x_max - x_min)
|
220 |
+
|
221 |
+
|
222 |
+
def intersection_boxlist(boxlist1, boxlist2):
|
223 |
+
"""Compute pairwise intersection areas between boxes.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
boxlist1: BoxList holding N boxes
|
227 |
+
boxlist2: BoxList holding M boxes
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
a numpy array with shape [N*M] representing pairwise intersection area
|
231 |
+
"""
|
232 |
+
return intersection(boxlist1.get(), boxlist2.get())
|
233 |
+
|
234 |
+
|
235 |
+
def iou_boxlist(boxlist1, boxlist2):
|
236 |
+
"""Computes pairwise intersection-over-union between box collections.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
boxlist1: BoxList holding N boxes
|
240 |
+
boxlist2: BoxList holding M boxes
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
a numpy array with shape [N, M] representing pairwise iou scores.
|
244 |
+
"""
|
245 |
+
return iou(boxlist1.get(), boxlist2.get())
|
246 |
+
|
247 |
+
|
248 |
+
def ioa_boxlist(boxlist1, boxlist2):
|
249 |
+
"""Computes pairwise intersection-over-area between box collections.
|
250 |
+
|
251 |
+
Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
|
252 |
+
their intersection area over box2's area. Note that ioa is not symmetric,
|
253 |
+
that is, IOA(box1, box2) != IOA(box2, box1).
|
254 |
+
|
255 |
+
Args:
|
256 |
+
boxlist1: BoxList holding N boxes
|
257 |
+
boxlist2: BoxList holding M boxes
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
a numpy array with shape [N, M] representing pairwise ioa scores.
|
261 |
+
"""
|
262 |
+
return ioa(boxlist1.get(), boxlist2.get())
|
263 |
+
|
264 |
+
|
265 |
+
def gather_boxlist(boxlist, indices, fields=None):
|
266 |
+
"""Gather boxes from BoxList according to indices and return new BoxList.
|
267 |
+
|
268 |
+
By default, gather returns boxes corresponding to the input index list, as
|
269 |
+
well as all additional fields stored in the boxlist (indexing into the
|
270 |
+
first dimension). However one can optionally only gather from a
|
271 |
+
subset of fields.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
boxlist: BoxList holding N boxes
|
275 |
+
indices: a 1-d numpy array of type int_
|
276 |
+
fields: (optional) list of fields to also gather from. If None (default),
|
277 |
+
all fields are gathered from. Pass an empty fields list to only gather the box coordinates.
|
278 |
+
|
279 |
+
Returns:
|
280 |
+
subboxlist: a BoxList corresponding to the subset of the input BoxList specified by indices
|
281 |
+
|
282 |
+
Raises:
|
283 |
+
ValueError: if specified field is not contained in boxlist or if the indices are not of type int_
|
284 |
+
"""
|
285 |
+
if indices.size:
|
286 |
+
if np.amax(indices) >= boxlist.num_boxes() or np.amin(indices) < 0:
|
287 |
+
raise ValueError('indices are out of valid range.')
|
288 |
+
subboxlist = BoxList(boxlist.get()[indices, :])
|
289 |
+
if fields is None:
|
290 |
+
fields = boxlist.get_extra_fields()
|
291 |
+
for field in fields:
|
292 |
+
extra_field_data = boxlist.get_field(field)
|
293 |
+
subboxlist.add_field(field, extra_field_data[indices, ...])
|
294 |
+
return subboxlist
|
295 |
+
|
296 |
+
|
297 |
+
def sort_by_field_boxlist(boxlist, field, order=SortOrder.DESCEND):
|
298 |
+
"""Sort boxes and associated fields according to a scalar field.
|
299 |
+
|
300 |
+
A common use case is reordering the boxes according to descending scores.
|
301 |
+
|
302 |
+
Args:
|
303 |
+
boxlist: BoxList holding N boxes.
|
304 |
+
field: A BoxList field for sorting and reordering the BoxList.
|
305 |
+
order: (Optional) 'descend' or 'ascend'. Default is descend.
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
sorted_boxlist: A sorted BoxList with the field in the specified order.
|
309 |
+
|
310 |
+
Raises:
|
311 |
+
ValueError: if specified field does not exist or is not of single dimension.
|
312 |
+
ValueError: if the order is not either descend or ascend.
|
313 |
+
"""
|
314 |
+
if not boxlist.has_field(field):
|
315 |
+
raise ValueError('Field ' + field + ' does not exist')
|
316 |
+
if len(boxlist.get_field(field).shape) != 1:
|
317 |
+
raise ValueError('Field ' + field + ' should be single dimension.')
|
318 |
+
if order != SortOrder.DESCEND and order != SortOrder.ASCEND:
|
319 |
+
raise ValueError('Invalid sort order')
|
320 |
+
|
321 |
+
field_to_sort = boxlist.get_field(field)
|
322 |
+
sorted_indices = np.argsort(field_to_sort)
|
323 |
+
if order == SortOrder.DESCEND:
|
324 |
+
sorted_indices = sorted_indices[::-1]
|
325 |
+
return gather_boxlist(boxlist, sorted_indices)
|
326 |
+
|
327 |
+
|
328 |
+
def non_max_suppression(boxlist, max_output_size=10000, iou_threshold=1.0, score_threshold=-10.0):
|
329 |
+
"""Non maximum suppression.
|
330 |
+
|
331 |
+
This op greedily selects a subset of detection bounding boxes, pruning
|
332 |
+
away boxes that have high IOU (intersection over union) overlap (> thresh)
|
333 |
+
with already selected boxes. In each iteration, the detected bounding box with
|
334 |
+
highest score in the available pool is selected.
|
335 |
+
|
336 |
+
Args:
|
337 |
+
boxlist: BoxList holding N boxes. Must contain a 'scores' field
|
338 |
+
representing detection scores. All scores belong to the same class.
|
339 |
+
max_output_size: maximum number of retained boxes
|
340 |
+
iou_threshold: intersection over union threshold.
|
341 |
+
score_threshold: minimum score threshold. Remove the boxes with scores less than
|
342 |
+
this value. The default of -10 is a very low threshold that passes essentially
|
343 |
+
all boxes, unless the user sets a different score threshold.
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
a BoxList holding M boxes where M <= max_output_size
|
347 |
+
Raises:
|
348 |
+
ValueError: if 'scores' field does not exist
|
349 |
+
ValueError: if threshold is not in [0, 1]
|
350 |
+
ValueError: if max_output_size < 0
|
351 |
+
"""
|
352 |
+
if not boxlist.has_field('scores'):
|
353 |
+
raise ValueError('Field scores does not exist')
|
354 |
+
if iou_threshold < 0. or iou_threshold > 1.0:
|
355 |
+
raise ValueError('IOU threshold must be in [0, 1]')
|
356 |
+
if max_output_size < 0:
|
357 |
+
raise ValueError('max_output_size must be non-negative.')
|
358 |
+
|
359 |
+
boxlist = filter_scores_greater_than(boxlist, score_threshold)
|
360 |
+
if boxlist.num_boxes() == 0:
|
361 |
+
return boxlist
|
362 |
+
|
363 |
+
boxlist = sort_by_field_boxlist(boxlist, 'scores')
|
364 |
+
|
365 |
+
# Prevent further computation if NMS is disabled.
|
366 |
+
if iou_threshold == 1.0:
|
367 |
+
if boxlist.num_boxes() > max_output_size:
|
368 |
+
selected_indices = np.arange(max_output_size)
|
369 |
+
return gather_boxlist(boxlist, selected_indices)
|
370 |
+
else:
|
371 |
+
return boxlist
|
372 |
+
|
373 |
+
boxes = boxlist.get()
|
374 |
+
num_boxes = boxlist.num_boxes()
|
375 |
+
# is_index_valid is True only for all remaining valid boxes,
|
376 |
+
is_index_valid = np.full(num_boxes, 1, dtype=bool)
|
377 |
+
selected_indices = []
|
378 |
+
num_output = 0
|
379 |
+
for i in range(num_boxes):
|
380 |
+
if num_output < max_output_size:
|
381 |
+
if is_index_valid[i]:
|
382 |
+
num_output += 1
|
383 |
+
selected_indices.append(i)
|
384 |
+
is_index_valid[i] = False
|
385 |
+
valid_indices = np.where(is_index_valid)[0]
|
386 |
+
if valid_indices.size == 0:
|
387 |
+
break
|
388 |
+
|
389 |
+
intersect_over_union = iou(np.expand_dims(boxes[i, :], axis=0), boxes[valid_indices, :])
|
390 |
+
intersect_over_union = np.squeeze(intersect_over_union, axis=0)
|
391 |
+
is_index_valid[valid_indices] = np.logical_and(
|
392 |
+
is_index_valid[valid_indices],
|
393 |
+
intersect_over_union <= iou_threshold)
|
394 |
+
return gather_boxlist(boxlist, np.array(selected_indices))
|
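Sketch of the single-class NMS defined above on a toy BoxList (all values are assumptions for illustration):

    import numpy as np
    from effdet.evaluation.np_box_list import BoxList, non_max_suppression

    boxes = np.array([[0.0, 0.0, 1.0, 1.0],
                      [0.0, 0.1, 1.0, 1.1],   # heavily overlaps the first box
                      [0.0, 5.0, 1.0, 6.0]], dtype=np.float32)
    bl = BoxList(boxes)
    bl.add_field('scores', np.array([0.9, 0.8, 0.7], dtype=np.float32))
    kept = non_max_suppression(bl, max_output_size=10, iou_threshold=0.5, score_threshold=0.0)
    print(kept.num_boxes())  # 2: the overlapping lower-scoring box is suppressed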
395 |
+
|
396 |
+
|
397 |
+
def multi_class_non_max_suppression(boxlist, score_thresh, iou_thresh, max_output_size):
|
398 |
+
"""Multi-class version of non maximum suppression.
|
399 |
+
|
400 |
+
This op greedily selects a subset of detection bounding boxes, pruning
|
401 |
+
away boxes that have high IOU (intersection over union) overlap (> thresh)
|
402 |
+
with already selected boxes. It operates independently for each class for
|
403 |
+
which scores are provided (via the scores field of the input box_list),
|
404 |
+
pruning boxes with score less than a provided threshold prior to
|
405 |
+
applying NMS.
|
406 |
+
|
407 |
+
Args:
|
408 |
+
boxlist: BoxList holding N boxes. Must contain a 'scores' field
|
409 |
+
representing detection scores. This scores field is a tensor that can
|
410 |
+
be 1 dimensional (in the case of a single class) or 2-dimensional, in
|
411 |
+
which case we assume that it takes the shape [num_boxes, num_classes].
|
412 |
+
We further assume that this rank is known statically and that
|
413 |
+
scores.shape[1] is also known (i.e., the number of classes is fixed
|
414 |
+
and known at graph construction time).
|
415 |
+
score_thresh: scalar threshold for score (low scoring boxes are removed).
|
416 |
+
iou_thresh: scalar threshold for IOU (boxes that have high IOU overlap
|
417 |
+
with previously selected boxes are removed).
|
418 |
+
max_output_size: maximum number of retained boxes per class.
|
419 |
+
|
420 |
+
Returns:
|
421 |
+
a BoxList holding M boxes with a rank-1 scores field representing
|
422 |
+
corresponding scores for each box with scores sorted in decreasing order
|
423 |
+
and a rank-1 classes field representing a class label for each box.
|
424 |
+
Raises:
|
425 |
+
ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
|
426 |
+
a valid scores field.
|
427 |
+
"""
|
428 |
+
if not 0 <= iou_thresh <= 1.0:
|
429 |
+
raise ValueError('thresh must be between 0 and 1')
|
430 |
+
if not isinstance(boxlist, BoxList):
|
431 |
+
raise ValueError('boxlist must be a BoxList')
|
432 |
+
if not boxlist.has_field('scores'):
|
433 |
+
raise ValueError('input boxlist must have \'scores\' field')
|
434 |
+
scores = boxlist.get_field('scores')
|
435 |
+
if len(scores.shape) == 1:
|
436 |
+
scores = np.reshape(scores, [-1, 1])
|
437 |
+
elif len(scores.shape) == 2:
|
438 |
+
if scores.shape[1] is None:
|
439 |
+
raise ValueError('scores field must have statically defined second dimension')
|
440 |
+
else:
|
441 |
+
raise ValueError('scores field must be of rank 1 or 2')
|
442 |
+
num_boxes = boxlist.num_boxes()
|
443 |
+
num_scores = scores.shape[0]
|
444 |
+
num_classes = scores.shape[1]
|
445 |
+
|
446 |
+
if num_boxes != num_scores:
|
447 |
+
raise ValueError('Incorrect scores field length: actual vs expected.')
|
448 |
+
|
449 |
+
selected_boxes_list = []
|
450 |
+
for class_idx in range(num_classes):
|
451 |
+
boxlist_and_class_scores = BoxList(boxlist.get())
|
452 |
+
class_scores = np.reshape(scores[0:num_scores, class_idx], [-1])
|
453 |
+
boxlist_and_class_scores.add_field('scores', class_scores)
|
454 |
+
boxlist_filt = filter_scores_greater_than(boxlist_and_class_scores, score_thresh)
|
455 |
+
nms_result = non_max_suppression(
|
456 |
+
boxlist_filt, max_output_size=max_output_size, iou_threshold=iou_thresh, score_threshold=score_thresh)
|
457 |
+
nms_result.add_field('classes', np.zeros_like(nms_result.get_field('scores')) + class_idx)
|
458 |
+
selected_boxes_list.append(nms_result)
|
459 |
+
selected_boxes = concatenate_boxlist(selected_boxes_list)
|
460 |
+
sorted_boxes = sort_by_field_boxlist(selected_boxes, 'scores')
|
461 |
+
return sorted_boxes
|
462 |
+
|
463 |
+
|
464 |
+
def scale(boxlist, y_scale, x_scale):
|
465 |
+
"""Scale box coordinates in x and y dimensions.
|
466 |
+
|
467 |
+
Args:
|
468 |
+
boxlist: BoxList holding N boxes
|
469 |
+
y_scale: float
|
470 |
+
x_scale: float
|
471 |
+
|
472 |
+
Returns:
|
473 |
+
boxlist: BoxList holding N boxes
|
474 |
+
"""
|
475 |
+
y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
|
476 |
+
y_min = y_scale * y_min
|
477 |
+
y_max = y_scale * y_max
|
478 |
+
x_min = x_scale * x_min
|
479 |
+
x_max = x_scale * x_max
|
480 |
+
scaled_boxlist = BoxList(np.hstack([y_min, x_min, y_max, x_max]))
|
481 |
+
|
482 |
+
fields = boxlist.get_extra_fields()
|
483 |
+
for field in fields:
|
484 |
+
extra_field_data = boxlist.get_field(field)
|
485 |
+
scaled_boxlist.add_field(field, extra_field_data)
|
486 |
+
|
487 |
+
return scaled_boxlist
|
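For reference, scale simply multiplies the y/x coordinates and carries any extra fields over unchanged (illustrative values):

    import numpy as np
    from effdet.evaluation.np_box_list import BoxList, scale

    bl = BoxList(np.array([[0.0, 0.0, 1.0, 1.0]], dtype=np.float32))
    print(scale(bl, y_scale=2.0, x_scale=3.0).get())  # [[0. 0. 2. 3.]]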
488 |
+
|
489 |
+
|
490 |
+
def clip_to_window(boxlist, window, filter_nonoverlapping=True):
|
491 |
+
"""Clip bounding boxes to a window.
|
492 |
+
|
493 |
+
This op clips input bounding boxes (represented by bounding box
|
494 |
+
corners) to a window, optionally filtering out boxes that do not
|
495 |
+
overlap at all with the window.
|
496 |
+
|
497 |
+
Args:
|
498 |
+
boxlist: BoxList holding M_in boxes
|
499 |
+
window: a numpy array of shape [4] representing the [y_min, x_min, y_max, x_max]
|
500 |
+
window to which the op should clip boxes.
|
501 |
+
filter_nonoverlapping: whether to filter out boxes that do not overlap at all with the window.
|
502 |
+
|
503 |
+
Returns:
|
504 |
+
a BoxList holding M_out boxes where M_out <= M_in
|
505 |
+
"""
|
506 |
+
y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
|
507 |
+
win_y_min = window[0]
|
508 |
+
win_x_min = window[1]
|
509 |
+
win_y_max = window[2]
|
510 |
+
win_x_max = window[3]
|
511 |
+
y_min_clipped = np.fmax(np.fmin(y_min, win_y_max), win_y_min)
|
512 |
+
y_max_clipped = np.fmax(np.fmin(y_max, win_y_max), win_y_min)
|
513 |
+
x_min_clipped = np.fmax(np.fmin(x_min, win_x_max), win_x_min)
|
514 |
+
x_max_clipped = np.fmax(np.fmin(x_max, win_x_max), win_x_min)
|
515 |
+
clipped = BoxList(np.hstack([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped]))
|
516 |
+
clipped = _copy_extra_fields(clipped, boxlist)
|
517 |
+
if filter_nonoverlapping:
|
518 |
+
areas = area_boxlist(clipped)
|
519 |
+
nonzero_area_indices = np.reshape(np.nonzero(np.greater(areas, 0.0)), [-1]).astype(np.int32)
|
520 |
+
clipped = gather_boxlist(clipped, nonzero_area_indices)
|
521 |
+
return clipped
|
522 |
+
|
523 |
+
|
524 |
+
def prune_non_overlapping_boxes(boxlist1, boxlist2, minoverlap=0.0):
|
525 |
+
"""Prunes the boxes in boxlist1 that overlap less than thresh with boxlist2.
|
526 |
+
|
527 |
+
For each box in boxlist1, we want its IOA to be more than minoverlap with
|
528 |
+
at least one of the boxes in boxlist2. If it does not, we remove it.
|
529 |
+
|
530 |
+
Args:
|
531 |
+
boxlist1: BoxList holding N boxes.
|
532 |
+
boxlist2: BoxList holding M boxes.
|
533 |
+
minoverlap: Minimum required overlap between boxes, to count them as overlapping.
|
534 |
+
|
535 |
+
Returns:
|
536 |
+
A pruned boxlist with size [N', 4].
|
537 |
+
"""
|
538 |
+
intersection_over_area = ioa_boxlist(boxlist2, boxlist1)  # [M, N] tensor
|
539 |
+
intersection_over_area = np.amax(intersection_over_area, axis=0) # [N] tensor
|
540 |
+
keep_bool = np.greater_equal(intersection_over_area, np.array(minoverlap))
|
541 |
+
keep_inds = np.nonzero(keep_bool)[0]
|
542 |
+
new_boxlist1 = gather_boxlist(boxlist1, keep_inds)
|
543 |
+
return new_boxlist1
|
544 |
+
|
545 |
+
|
546 |
+
def prune_outside_window(boxlist, window):
|
547 |
+
"""Prunes bounding boxes that fall outside a given window.
|
548 |
+
|
549 |
+
This function prunes bounding boxes that even partially fall outside the given
|
550 |
+
window. See also ClipToWindow which only prunes bounding boxes that fall
|
551 |
+
completely outside the window, and clips any bounding boxes that partially
|
552 |
+
overflow.
|
553 |
+
|
554 |
+
Args:
|
555 |
+
boxlist: a BoxList holding M_in boxes.
|
556 |
+
window: a numpy array of size 4, representing [ymin, xmin, ymax, xmax] of the window.
|
557 |
+
|
558 |
+
Returns:
|
559 |
+
pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in.
|
560 |
+
valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes in the input tensor.
|
561 |
+
"""
|
562 |
+
|
563 |
+
y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
|
564 |
+
win_y_min = window[0]
|
565 |
+
win_x_min = window[1]
|
566 |
+
win_y_max = window[2]
|
567 |
+
win_x_max = window[3]
|
568 |
+
coordinate_violations = np.hstack([
|
569 |
+
np.less(y_min, win_y_min), np.less(x_min, win_x_min),
|
570 |
+
np.greater(y_max, win_y_max), np.greater(x_max, win_x_max)])
|
571 |
+
valid_indices = np.reshape(np.where(np.logical_not(np.max(coordinate_violations, axis=1))), [-1])
|
572 |
+
return gather_boxlist(boxlist, valid_indices), valid_indices
|
573 |
+
|
574 |
+
|
575 |
+
def concatenate_boxlist(boxlists, fields=None):
|
576 |
+
"""Concatenate list of BoxLists.
|
577 |
+
|
578 |
+
This op concatenates a list of input BoxLists into a larger BoxList. It also
|
579 |
+
handles concatenation of BoxList fields as long as the field tensor shapes
|
580 |
+
are equal except for the first dimension.
|
581 |
+
|
582 |
+
Args:
|
583 |
+
boxlists: list of BoxList objects
|
584 |
+
fields: optional list of fields to also concatenate. By default, all
|
585 |
+
fields from the first BoxList in the list are included in the concatenation.
|
586 |
+
|
587 |
+
Returns:
|
588 |
+
a BoxList with number of boxes equal to
|
589 |
+
sum([boxlist.num_boxes() for boxlist in BoxList])
|
590 |
+
Raises:
|
591 |
+
ValueError: if boxlists is invalid (i.e., is not a list, is empty, or
|
592 |
+
contains non BoxList objects), or if requested fields are not contained in all boxlists
|
593 |
+
"""
|
594 |
+
if not isinstance(boxlists, list):
|
595 |
+
raise ValueError('boxlists should be a list')
|
596 |
+
if not boxlists:
|
597 |
+
raise ValueError('boxlists should have nonzero length')
|
598 |
+
for boxlist in boxlists:
|
599 |
+
if not isinstance(boxlist, BoxList):
|
600 |
+
raise ValueError('all elements of boxlists should be BoxList objects')
|
601 |
+
concatenated = BoxList(np.vstack([boxlist.get() for boxlist in boxlists]))
|
602 |
+
if fields is None:
|
603 |
+
fields = boxlists[0].get_extra_fields()
|
604 |
+
for field in fields:
|
605 |
+
first_field_shape = boxlists[0].get_field(field).shape
|
606 |
+
first_field_shape = first_field_shape[1:]
|
607 |
+
for boxlist in boxlists:
|
608 |
+
if not boxlist.has_field(field):
|
609 |
+
raise ValueError('boxlist must contain all requested fields')
|
610 |
+
field_shape = boxlist.get_field(field).shape
|
611 |
+
field_shape = field_shape[1:]
|
612 |
+
if field_shape != first_field_shape:
|
613 |
+
raise ValueError('field %s must have same shape for all boxlists '
|
614 |
+
'except for the 0th dimension.' % field)
|
615 |
+
concatenated_field = np.concatenate([boxlist.get_field(field) for boxlist in boxlists], axis=0)
|
616 |
+
concatenated.add_field(field, concatenated_field)
|
617 |
+
return concatenated
|
618 |
+
|
619 |
+
|
620 |
+
def filter_scores_greater_than(boxlist, thresh):
|
621 |
+
"""Filter to keep only boxes with score exceeding a given threshold.
|
622 |
+
|
623 |
+
This op keeps the collection of boxes whose corresponding scores are
|
624 |
+
greater than the input threshold.
|
625 |
+
|
626 |
+
Args:
|
627 |
+
boxlist: BoxList holding N boxes. Must contain a 'scores' field representing detection scores.
|
628 |
+
thresh: scalar threshold
|
629 |
+
|
630 |
+
Returns:
|
631 |
+
a BoxList holding M boxes where M <= N
|
632 |
+
|
633 |
+
Raises:
|
634 |
+
ValueError: if boxlist not a BoxList object or if it does not have a scores field
|
635 |
+
"""
|
636 |
+
if not isinstance(boxlist, BoxList):
|
637 |
+
raise ValueError('boxlist must be a BoxList')
|
638 |
+
if not boxlist.has_field('scores'):
|
639 |
+
raise ValueError('input boxlist must have \'scores\' field')
|
640 |
+
scores = boxlist.get_field('scores')
|
641 |
+
if len(scores.shape) > 2:
|
642 |
+
raise ValueError('Scores should have rank 1 or 2')
|
643 |
+
if len(scores.shape) == 2 and scores.shape[1] != 1:
|
644 |
+
raise ValueError('Scores should have rank 1 or have shape '
|
645 |
+
'consistent with [None, 1]')
|
646 |
+
high_score_indices = np.reshape(np.where(np.greater(scores, thresh)), [-1]).astype(np.int32)
|
647 |
+
return gather_boxlist(boxlist, high_score_indices)
|
648 |
+
|
649 |
+
|
650 |
+
def change_coordinate_frame(boxlist, window):
|
651 |
+
"""Change coordinate frame of the boxlist to be relative to window's frame.
|
652 |
+
|
653 |
+
Given a window of the form [ymin, xmin, ymax, xmax],
|
654 |
+
changes bounding box coordinates from boxlist to be relative to this window
|
655 |
+
(e.g., the min corner maps to (0,0) and the max corner maps to (1,1)).
|
656 |
+
|
657 |
+
An example use case is data augmentation: where we are given groundtruth
|
658 |
+
boxes (boxlist) and would like to randomly crop the image to some
|
659 |
+
window (window). In this case we need to change the coordinate frame of
|
660 |
+
each groundtruth box to be relative to this new window.
|
661 |
+
|
662 |
+
Args:
|
663 |
+
boxlist: A BoxList object holding N boxes.
|
664 |
+
window: a size 4 1-D numpy array.
|
665 |
+
|
666 |
+
Returns:
|
667 |
+
Returns a BoxList object with N boxes.
|
668 |
+
"""
|
669 |
+
win_height = window[2] - window[0]
|
670 |
+
win_width = window[3] - window[1]
|
671 |
+
boxlist_new = scale(
|
672 |
+
BoxList(boxlist.get() - [window[0], window[1], window[0], window[1]]), 1.0 / win_height, 1.0 / win_width)
|
673 |
+
_copy_extra_fields(boxlist_new, boxlist)
|
674 |
+
|
675 |
+
return boxlist_new
|
676 |
+
|
677 |
+
|
678 |
+
def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
|
679 |
+
"""Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.
|
680 |
+
|
681 |
+
Args:
|
682 |
+
boxlist_to_copy_to: BoxList to which extra fields are copied.
|
683 |
+
boxlist_to_copy_from: BoxList from which fields are copied.
|
684 |
+
|
685 |
+
Returns:
|
686 |
+
boxlist_to_copy_to with extra fields.
|
687 |
+
"""
|
688 |
+
for field in boxlist_to_copy_from.get_extra_fields():
|
689 |
+
boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field))
|
690 |
+
return boxlist_to_copy_to
|
691 |
+
|
692 |
+
|
693 |
+
def _update_valid_indices_by_removing_high_iou_boxes(
|
694 |
+
selected_indices, is_index_valid, intersect_over_union, threshold):
|
695 |
+
max_iou = np.max(intersect_over_union[:, selected_indices], axis=1)
|
696 |
+
return np.logical_and(is_index_valid, max_iou <= threshold)
|
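To close out this module, a sketch of the multi-class NMS entry point with a 2-box, 2-class score matrix (all numbers invented for the example):

    import numpy as np
    from effdet.evaluation.np_box_list import BoxList, multi_class_non_max_suppression

    bl = BoxList(np.array([[0.0, 0.0, 1.0, 1.0],
                           [0.0, 0.1, 1.0, 1.1]], dtype=np.float32))
    bl.add_field('scores', np.array([[0.9, 0.1],
                                     [0.2, 0.8]], dtype=np.float32))  # [num_boxes, num_classes]
    result = multi_class_non_max_suppression(bl, score_thresh=0.5, iou_thresh=0.5, max_output_size=10)
    print(result.get_field('scores'))   # [0.9 0.8]
    print(result.get_field('classes'))  # [0. 1.] -> one detection survives per class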
efficientdet/effdet/evaluation/np_mask_list.py
ADDED
@@ -0,0 +1,478 @@
1 |
+
import numpy as np
|
2 |
+
from .np_box_list import *
|
3 |
+
|
4 |
+
EPSILON = 1e-7
|
5 |
+
|
6 |
+
|
7 |
+
class MaskList(BoxList):
|
8 |
+
"""Convenience wrapper for BoxList with masks.
|
9 |
+
|
10 |
+
BoxMaskList extends the np_box_list.BoxList to contain masks as well.
|
11 |
+
In particular, its constructor receives both boxes and masks. Note that the
|
12 |
+
masks correspond to the full image.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, box_data, mask_data):
|
16 |
+
"""Constructs box collection.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
box_data: a numpy array of shape [N, 4] representing box coordinates
|
20 |
+
mask_data: a numpy array of shape [N, height, width] representing masks
|
21 |
+
with values in {0,1}. The masks correspond to the full
|
22 |
+
image. The height and the width will be equal to image height and width.
|
23 |
+
|
24 |
+
Raises:
|
25 |
+
ValueError: if bbox data is not a numpy array
|
26 |
+
ValueError: if invalid dimensions for bbox data
|
27 |
+
ValueError: if mask data is not a numpy array
|
28 |
+
ValueError: if invalid dimension for mask data
|
29 |
+
"""
|
30 |
+
super(MaskList, self).__init__(box_data)
|
31 |
+
if not isinstance(mask_data, np.ndarray):
|
32 |
+
raise ValueError('Mask data must be a numpy array.')
|
33 |
+
if len(mask_data.shape) != 3:
|
34 |
+
raise ValueError('Invalid dimensions for mask data.')
|
35 |
+
if mask_data.dtype != np.uint8:
|
36 |
+
raise ValueError('Invalid data type for mask data: uint8 is required.')
|
37 |
+
if mask_data.shape[0] != box_data.shape[0]:
|
38 |
+
raise ValueError('There should be the same number of boxes and masks.')
|
39 |
+
self.data['masks'] = mask_data
|
40 |
+
|
41 |
+
def get_masks(self):
|
42 |
+
"""Convenience function for accessing masks.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
a numpy array of shape [N, height, width] representing masks
|
46 |
+
"""
|
47 |
+
return self.get_field('masks')
|
48 |
+
|
49 |
+
|
50 |
+
def boxlist_to_masklist(boxlist):
|
51 |
+
"""Converts a BoxList containing 'masks' into a BoxMaskList.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
boxlist: An np_box_list.BoxList object.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
A MaskList object.
|
58 |
+
|
59 |
+
Raises:
|
60 |
+
ValueError: If boxlist does not contain `masks` as a field.
|
61 |
+
"""
|
62 |
+
if not boxlist.has_field('masks'):
|
63 |
+
raise ValueError('boxlist does not contain mask field.')
|
64 |
+
masklist = MaskList(box_data=boxlist.get(), mask_data=boxlist.get_field('masks'))
|
65 |
+
extra_fields = boxlist.get_extra_fields()
|
66 |
+
for key in extra_fields:
|
67 |
+
if key != 'masks':
|
68 |
+
masklist.data[key] = boxlist.get_field(key)
|
69 |
+
return masklist
|
70 |
+
|
71 |
+
|
72 |
+
def area_mask(masks):
|
73 |
+
"""Computes area of masks.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
masks: Numpy array with shape [N, height, width] holding N masks. Masks
|
77 |
+
values are of type np.uint8 and values are in {0,1}.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
a numpy array with shape [N*1] representing mask areas.
|
81 |
+
|
82 |
+
Raises:
|
83 |
+
ValueError: If masks.dtype is not np.uint8
|
84 |
+
"""
|
85 |
+
if masks.dtype != np.uint8:
|
86 |
+
raise ValueError('Masks type should be np.uint8')
|
87 |
+
return np.sum(masks, axis=(1, 2), dtype=np.float32)
|
88 |
+
|
89 |
+
|
90 |
+
def intersection_mask(masks1, masks2):
|
91 |
+
"""Compute pairwise intersection areas between masks.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
masks1: a numpy array with shape [N, height, width] holding N masks. Masks
|
95 |
+
values are of type np.uint8 and values are in {0,1}.
|
96 |
+
masks2: a numpy array with shape [M, height, width] holding M masks. Masks
|
97 |
+
values are of type np.uint8 and values are in {0,1}.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
a numpy array with shape [N*M] representing pairwise intersection area.
|
101 |
+
|
102 |
+
Raises:
|
103 |
+
ValueError: If masks1 and masks2 are not of type np.uint8.
|
104 |
+
"""
|
105 |
+
if masks1.dtype != np.uint8 or masks2.dtype != np.uint8:
|
106 |
+
raise ValueError('masks1 and masks2 should be of type np.uint8')
|
107 |
+
n = masks1.shape[0]
|
108 |
+
m = masks2.shape[0]
|
109 |
+
answer = np.zeros([n, m], dtype=np.float32)
|
110 |
+
for i in np.arange(n):
|
111 |
+
for j in np.arange(m):
|
112 |
+
answer[i, j] = np.sum(np.minimum(masks1[i], masks2[j]), dtype=np.float32)
|
113 |
+
return answer
|
114 |
+
|
115 |
+
|
116 |
+
def iou_mask(masks1, masks2):
|
117 |
+
"""Computes pairwise intersection-over-union between mask collections.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
masks1: a numpy array with shape [N, height, width] holding N masks. Masks
|
121 |
+
values are of type np.uint8 and values are in {0,1}.
|
122 |
+
masks2: a numpy array with shape [M, height, width] holding M masks. Masks
|
123 |
+
values are of type np.uint8 and values are in {0,1}.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
a numpy array with shape [N, M] representing pairwise iou scores.
|
127 |
+
|
128 |
+
Raises:
|
129 |
+
ValueError: If masks1 and masks2 are not of type np.uint8.
|
130 |
+
"""
|
131 |
+
if masks1.dtype != np.uint8 or masks2.dtype != np.uint8:
|
132 |
+
raise ValueError('masks1 and masks2 should be of type np.uint8')
|
133 |
+
intersect = intersection_mask(masks1, masks2)
|
134 |
+
area1 = area_mask(masks1)
|
135 |
+
area2 = area_mask(masks2)
|
136 |
+
union = np.expand_dims(area1, axis=1) + np.expand_dims(area2, axis=0) - intersect
|
137 |
+
return intersect / np.maximum(union, EPSILON)
|
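A tiny check of the mask helpers above (masks must be uint8 with values in {0, 1}; these 4x4 masks are made up for illustration):

    import numpy as np
    from effdet.evaluation.np_mask_list import area_mask, iou_mask

    m1 = np.zeros((1, 4, 4), dtype=np.uint8)
    m1[0, :2, :2] = 1  # 4-pixel mask
    m2 = np.zeros((1, 4, 4), dtype=np.uint8)
    m2[0, :2, :] = 1   # 8-pixel mask
    print(area_mask(m1))     # [4.]
    print(iou_mask(m1, m2))  # [[0.5]] = 4 / (4 + 8 - 4)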
138 |
+
|
139 |
+
|
140 |
+
def ioa_mask(masks1, masks2):
|
141 |
+
"""Computes pairwise intersection-over-area between box collections.
|
142 |
+
|
143 |
+
Intersection-over-area (ioa) between two masks, mask1 and mask2 is defined as
|
144 |
+
their intersection area over mask2's area. Note that ioa is not symmetric,
|
145 |
+
that is, IOA(mask1, mask2) != IOA(mask2, mask1).
|
146 |
+
|
147 |
+
Args:
|
148 |
+
masks1: a numpy array with shape [N, height, width] holding N masks. Masks
|
149 |
+
values are of type np.uint8 and values are in {0,1}.
|
150 |
+
masks2: a numpy array with shape [M, height, width] holding M masks. Masks
|
151 |
+
values are of type np.uint8 and values are in {0,1}.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
a numpy array with shape [N, M] representing pairwise ioa scores.
|
155 |
+
|
156 |
+
Raises:
|
157 |
+
ValueError: If masks1 and masks2 are not of type np.uint8.
|
158 |
+
"""
|
159 |
+
if masks1.dtype != np.uint8 or masks2.dtype != np.uint8:
|
160 |
+
raise ValueError('masks1 and masks2 should be of type np.uint8')
|
161 |
+
intersect = intersection_mask(masks1, masks2)
|
162 |
+
areas = np.expand_dims(area_mask(masks2), axis=0)
|
163 |
+
return intersect / (areas + EPSILON)
|
164 |
+
|
165 |
+
|
166 |
+
def area_masklist(masklist):
|
167 |
+
"""Computes area of masks.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
masklist: BoxMaskList holding N boxes and masks
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
a numpy array with shape [N*1] representing mask areas
|
174 |
+
"""
|
175 |
+
return area_mask(masklist.get_masks())
|
176 |
+
|
177 |
+
|
178 |
+
def intersection_masklist(masklist1, masklist2):
|
179 |
+
"""Compute pairwise intersection areas between masks.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
masklist1: BoxMaskList holding N boxes and masks
|
183 |
+
masklist2: BoxMaskList holding M boxes and masks
|
184 |
+
|
185 |
+
Returns:
|
186 |
+
a numpy array with shape [N*M] representing pairwise intersection area
|
187 |
+
"""
|
188 |
+
return intersection_mask(masklist1.get_masks(), masklist2.get_masks())
|
189 |
+
|
190 |
+
|
191 |
+
def iou_masklist(masklist1, masklist2):
|
192 |
+
"""Computes pairwise intersection-over-union between box and mask collections.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
masklist1: BoxMaskList holding N boxes and masks
|
196 |
+
masklist2: BoxMaskList holding M boxes and masks
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
a numpy array with shape [N, M] representing pairwise iou scores.
|
200 |
+
"""
|
201 |
+
return iou_mask(masklist1.get_masks(), masklist2.get_masks())
|
202 |
+
|
203 |
+
|
204 |
+
def ioa_masklist(masklist1, masklist2):
|
205 |
+
"""Computes pairwise intersection-over-area between box and mask collections.
|
206 |
+
|
207 |
+
Intersection-over-area (ioa) between two masks mask1 and mask2 is defined as
|
208 |
+
their intersection area over mask2's area. Note that ioa is not symmetric,
|
209 |
+
that is, IOA(mask1, mask2) != IOA(mask2, mask1).
|
210 |
+
|
211 |
+
Args:
|
212 |
+
masklist1: BoxMaskList holding N boxes and masks
|
213 |
+
masklist2: BoxMaskList holding M boxes and masks
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
a numpy array with shape [N, M] representing pairwise ioa scores.
|
217 |
+
"""
|
218 |
+
return ioa_mask(masklist1.get_masks(), masklist2.get_masks())
|
219 |
+
|
220 |
+
|
221 |
+
def gather_masklist(masklist, indices, fields=None):
|
222 |
+
"""Gather boxes from BoxMaskList according to indices.
|
223 |
+
|
224 |
+
By default, gather returns boxes corresponding to the input index list, as
|
225 |
+
well as all additional fields stored in the masklist (indexing into the
|
226 |
+
first dimension). However one can optionally only gather from a
|
227 |
+
subset of fields.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
masklist: BoxMaskList holding N boxes
|
231 |
+
indices: a 1-d numpy array of type int_
|
232 |
+
fields: (optional) list of fields to also gather from. If None (default), all fields
|
233 |
+
are gathered from. Pass an empty fields list to only gather the box coordinates.
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
submasklist: a BoxMaskList corresponding to the subset of the input masklist specified by indices
|
237 |
+
|
238 |
+
Raises:
|
239 |
+
ValueError: if specified field is not contained in masklist or if the indices are not of type int_
|
240 |
+
"""
|
241 |
+
if fields is not None:
|
242 |
+
if 'masks' not in fields:
|
243 |
+
fields.append('masks')
|
244 |
+
return boxlist_to_masklist(gather_boxlist(boxlist=masklist, indices=indices, fields=fields))
|
245 |
+
|
246 |
+
|
247 |
+
def sort_by_field_masklist(masklist, field, order=SortOrder.DESCEND):
|
248 |
+
"""Sort boxes and associated fields according to a scalar field.
|
249 |
+
|
250 |
+
A common use case is reordering the boxes according to descending scores.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
masklist: BoxMaskList holding N boxes.
|
254 |
+
field: A BoxMaskList field for sorting and reordering the BoxMaskList.
|
255 |
+
order: (Optional) 'descend' or 'ascend'. Default is descend.
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
sorted_masklist: A sorted BoxMaskList with the field in the specified order.
|
259 |
+
"""
|
260 |
+
return boxlist_to_masklist(sort_by_field_boxlist(boxlist=masklist, field=field, order=order))
|
261 |
+
|
262 |
+
|
263 |
+
def non_max_suppression_mask(masklist, max_output_size=10000, iou_threshold=1.0, score_threshold=-10.0):
|
264 |
+
"""Non maximum suppression.
|
265 |
+
|
266 |
+
This op greedily selects a subset of detection bounding boxes, pruning
|
267 |
+
away boxes that have high IOU (intersection over union) overlap (> thresh)
|
268 |
+
with already selected boxes. In each iteration, the detected bounding box with
|
269 |
+
highest score in the available pool is selected.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
masklist: BoxMaskList holding N boxes. Must contain a 'scores' field representing
|
273 |
+
detection scores. All scores belong to the same class.
|
274 |
+
max_output_size: maximum number of retained boxes
|
275 |
+
iou_threshold: intersection over union threshold.
|
276 |
+
score_threshold: minimum score threshold. Remove the boxes with scores
|
277 |
+
less than this value. The default of -10 is a very
|
278 |
+
low threshold that passes essentially all boxes, unless
|
279 |
+
the user sets a different score threshold.
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
a BoxMaskList holding M boxes where M <= max_output_size
|
283 |
+
|
284 |
+
Raises:
|
285 |
+
ValueError: if 'scores' field does not exist
|
286 |
+
ValueError: if threshold is not in [0, 1]
|
287 |
+
ValueError: if max_output_size < 0
|
288 |
+
"""
|
289 |
+
if not masklist.has_field('scores'):
|
290 |
+
raise ValueError('Field scores does not exist')
|
291 |
+
if iou_threshold < 0. or iou_threshold > 1.0:
|
292 |
+
raise ValueError('IOU threshold must be in [0, 1]')
|
293 |
+
if max_output_size < 0:
|
294 |
+
raise ValueError('max_output_size must be non-negative.')
|
295 |
+
|
296 |
+
masklist = filter_scores_greater_than_masklist(masklist, score_threshold)
|
297 |
+
if masklist.num_boxes() == 0:
|
298 |
+
return masklist
|
299 |
+
|
300 |
+
masklist = sort_by_field_masklist(masklist, 'scores')
|
301 |
+
|
302 |
+
# Prevent further computation if NMS is disabled.
|
303 |
+
if iou_threshold == 1.0:
|
304 |
+
if masklist.num_boxes() > max_output_size:
|
305 |
+
selected_indices = np.arange(max_output_size)
|
306 |
+
return gather_masklist(masklist, selected_indices)
|
307 |
+
else:
|
308 |
+
return masklist
|
309 |
+
|
310 |
+
masks = masklist.get_masks()
|
311 |
+
num_masks = masklist.num_boxes()
|
312 |
+
|
313 |
+
# is_index_valid is True only for all remaining valid boxes,
|
314 |
+
is_index_valid = np.full(num_masks, 1, dtype=bool)
|
315 |
+
selected_indices = []
|
316 |
+
num_output = 0
|
317 |
+
for i in range(num_masks):
|
318 |
+
if num_output < max_output_size:
|
319 |
+
if is_index_valid[i]:
|
320 |
+
num_output += 1
|
321 |
+
selected_indices.append(i)
|
322 |
+
is_index_valid[i] = False
|
323 |
+
valid_indices = np.where(is_index_valid)[0]
|
324 |
+
if valid_indices.size == 0:
|
325 |
+
break
|
326 |
+
|
327 |
+
intersect_over_union = iou_mask(np.expand_dims(masks[i], axis=0), masks[valid_indices])
|
328 |
+
intersect_over_union = np.squeeze(intersect_over_union, axis=0)
|
329 |
+
is_index_valid[valid_indices] = np.logical_and(
|
330 |
+
is_index_valid[valid_indices],
|
331 |
+
intersect_over_union <= iou_threshold)
|
332 |
+
return gather_masklist(masklist, np.array(selected_indices))
|
333 |
+
|
334 |
+
|
335 |
+
def multi_class_non_max_suppression_mask(masklist, score_thresh, iou_thresh, max_output_size):
|
336 |
+
"""Multi-class version of non maximum suppression.
|
337 |
+
|
338 |
+
This op greedily selects a subset of detection bounding boxes, pruning away boxes that have
|
339 |
+
high IOU (intersection over union) overlap (> thresh) with already selected boxes. It
|
340 |
+
operates independently for each class for which scores are provided (via the scores field
|
341 |
+
of the input box_list), pruning boxes with score less than a provided threshold prior to
|
342 |
+
applying NMS.
|
343 |
+
|
344 |
+
Args:
|
345 |
+
masklist: BoxMaskList holding N boxes. Must contain a 'scores' field representing detection
|
346 |
+
scores. This scores field is a tensor that can be 1 dimensional (in the case of a
|
347 |
+
single class) or 2-dimensional, in which case we assume that it takes the shape
|
348 |
+
[num_boxes, num_classes]. We further assume that this rank is known statically and
|
349 |
+
that scores.shape[1] is also known (i.e., the number of classes is fixed and known
|
350 |
+
at graph construction time).
|
351 |
+
score_thresh: scalar threshold for score (low scoring boxes are removed).
|
352 |
+
iou_thresh: scalar threshold for IOU (boxes that have high IOU overlap with previously
|
353 |
+
selected boxes are removed).
|
354 |
+
max_output_size: maximum number of retained boxes per class.
|
355 |
+
|
356 |
+
Returns:
|
357 |
+
a masklist holding M boxes with a rank-1 scores field representing
|
358 |
+
corresponding scores for each box with scores sorted in decreasing order
|
359 |
+
and a rank-1 classes field representing a class label for each box.
|
360 |
+
Raises:
|
361 |
+
ValueError: if iou_thresh is not in [0, 1] or if input masklist does not have a valid scores field.
|
362 |
+
"""
|
363 |
+
if not 0 <= iou_thresh <= 1.0:
|
364 |
+
raise ValueError('thresh must be between 0 and 1')
|
365 |
+
if not isinstance(masklist, MaskList):
|
366 |
+
raise ValueError('masklist must be a MaskList')
|
367 |
+
if not masklist.has_field('scores'):
|
368 |
+
raise ValueError('input masklist must have \'scores\' field')
|
369 |
+
scores = masklist.get_field('scores')
|
370 |
+
if len(scores.shape) == 1:
|
371 |
+
scores = np.reshape(scores, [-1, 1])
|
372 |
+
elif len(scores.shape) == 2:
|
373 |
+
if scores.shape[1] is None:
|
374 |
+
raise ValueError('scores field must have statically defined second dimension')
|
375 |
+
else:
|
376 |
+
raise ValueError('scores field must be of rank 1 or 2')
|
377 |
+
|
378 |
+
num_boxes = masklist.num_boxes()
|
379 |
+
num_scores = scores.shape[0]
|
380 |
+
num_classes = scores.shape[1]
|
381 |
+
|
382 |
+
if num_boxes != num_scores:
|
383 |
+
raise ValueError('Incorrect scores field length: actual vs expected.')
|
384 |
+
|
385 |
+
selected_boxes_list = []
|
386 |
+
for class_idx in range(num_classes):
|
387 |
+
masklist_and_class_scores = MaskList(box_data=masklist.get(), mask_data=masklist.get_masks())
|
388 |
+
class_scores = np.reshape(scores[0:num_scores, class_idx], [-1])
|
389 |
+
masklist_and_class_scores.add_field('scores', class_scores)
|
390 |
+
masklist_filt = filter_scores_greater_than(masklist_and_class_scores, score_thresh)
|
391 |
+
nms_result = non_max_suppression(
|
392 |
+
masklist_filt,
|
393 |
+
max_output_size=max_output_size,
|
394 |
+
iou_threshold=iou_thresh,
|
395 |
+
score_threshold=score_thresh)
|
396 |
+
nms_result.add_field('classes', np.zeros_like(nms_result.get_field('scores')) + class_idx)
|
397 |
+
selected_boxes_list.append(nms_result)
|
398 |
+
selected_boxes = concatenate_boxlist(selected_boxes_list)
|
399 |
+
sorted_boxes = sort_by_field_boxlist(selected_boxes, 'scores')
|
400 |
+
return boxlist_to_masklist(boxlist=sorted_boxes)
|
401 |
+
|
402 |
+
|
403 |
+
def prune_non_overlapping_masklist(masklist1, masklist2, minoverlap=0.0):
|
404 |
+
"""Prunes the boxes in list1 that overlap less than thresh with list2.
|
405 |
+
|
406 |
+
For each mask in masklist1, we want its IOA to be more than minoverlap
|
407 |
+
with at least one of the masks in masklist2. If it does not, we remove
|
408 |
+
it. If the masks are not full size image, we do the pruning based on boxes.
|
409 |
+
|
410 |
+
Args:
|
411 |
+
masklist1: BoxMaskList holding N boxes and masks.
|
412 |
+
masklist2: BoxMaskList holding M boxes and masks.
|
413 |
+
minoverlap: Minimum required overlap between boxes, to count them as overlapping.
|
414 |
+
|
415 |
+
Returns:
|
416 |
+
A pruned masklist with size [N', 4].
|
417 |
+
"""
|
418 |
+
intersection_over_area = ioa_masklist(masklist2, masklist1) # [M, N] tensor
|
419 |
+
intersection_over_area = np.amax(intersection_over_area, axis=0) # [N] tensor
|
420 |
+
keep_bool = np.greater_equal(intersection_over_area, np.array(minoverlap))
|
421 |
+
keep_inds = np.nonzero(keep_bool)[0]
|
422 |
+
new_masklist1 = gather_masklist(masklist1, keep_inds)
|
423 |
+
return new_masklist1
|
424 |
+
|
425 |
+
|
426 |
+
def concatenate_masklist(masklists, fields=None):
|
427 |
+
"""Concatenate list of masklists.
|
428 |
+
|
429 |
+
This op concatenates a list of input masklists into a larger
|
430 |
+
masklist. It also
|
431 |
+
handles concatenation of masklist fields as long as the field tensor
|
432 |
+
shapes are equal except for the first dimension.
|
433 |
+
|
434 |
+
Args:
|
435 |
+
masklists: list of BoxMaskList objects
|
436 |
+
fields: optional list of fields to also concatenate. By default, all
|
437 |
+
fields from the first BoxMaskList in the list are included in the concatenation.
|
438 |
+
|
439 |
+
Returns:
|
440 |
+
a masklist with number of boxes equal to sum([masklist.num_boxes() for masklist in masklist])
|
441 |
+
Raises:
|
442 |
+
ValueError: if masklists is invalid (i.e., is not a list, is empty, or contains non
|
443 |
+
masklist objects), or if requested fields are not contained in all masklists
|
444 |
+
"""
|
445 |
+
if fields is not None:
|
446 |
+
if 'masks' not in fields:
|
447 |
+
fields.append('masks')
|
448 |
+
return boxlist_to_masklist(concatenate_boxlist(boxlists=masklists, fields=fields))
|
449 |
+
|
450 |
+
|
451 |
+
def filter_scores_greater_than_masklist(masklist, thresh):
|
452 |
+
"""Filter to keep only boxes and masks with score exceeding a given threshold.
|
453 |
+
|
454 |
+
This op keeps the collection of boxes and masks whose corresponding scores are
|
455 |
+
greater than the input threshold.
|
456 |
+
|
457 |
+
Args:
|
458 |
+
masklist: BoxMaskList holding N boxes and masks. Must contain a
|
459 |
+
'scores' field representing detection scores.
|
460 |
+
thresh: scalar threshold
|
461 |
+
|
462 |
+
Returns:
|
463 |
+
a BoxMaskList holding M boxes and masks where M <= N
|
464 |
+
|
465 |
+
Raises:
|
466 |
+
ValueError: if masklist not a BoxMaskList object or if it does not have a scores field
|
467 |
+
"""
|
468 |
+
if not isinstance(masklist, MaskList):
|
469 |
+
raise ValueError('masklist must be a BoxMaskList')
|
470 |
+
if not masklist.has_field('scores'):
|
471 |
+
raise ValueError('input masklist must have \'scores\' field')
|
472 |
+
scores = masklist.get_field('scores')
|
473 |
+
if len(scores.shape) > 2:
|
474 |
+
raise ValueError('Scores should have rank 1 or 2')
|
475 |
+
if len(scores.shape) == 2 and scores.shape[1] != 1:
|
476 |
+
raise ValueError('Scores should have rank 1 or have shape consistent with [None, 1]')
|
477 |
+
high_score_indices = np.reshape(np.where(np.greater(scores, thresh)), [-1]).astype(np.int32)
|
478 |
+
return gather_masklist(masklist, high_score_indices)
|
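Finally, a sketch of constructing the MaskList wrapper defined in this module (illustrative box and mask values):

    import numpy as np
    from effdet.evaluation.np_mask_list import MaskList

    boxes = np.array([[0.0, 0.0, 2.0, 2.0]], dtype=np.float32)
    masks = np.zeros((1, 4, 4), dtype=np.uint8)
    masks[0, :2, :2] = 1
    ml = MaskList(box_data=boxes, mask_data=masks)
    print(ml.num_boxes())        # 1
    print(ml.get_masks().shape)  # (1, 4, 4)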
efficientdet/effdet/evaluation/object_detection_evaluation.py
ADDED
@@ -0,0 +1,273 @@
1 |
+
import logging
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from effdet.evaluation.metrics import compute_precision_recall, compute_average_precision, compute_cor_loc
|
6 |
+
from effdet.evaluation.per_image_evaluation import PerImageEvaluation
|
7 |
+
|
8 |
+
|
9 |
+
class ObjectDetectionEvaluation:
|
10 |
+
"""Internal implementation of Pascal object detection metrics."""
|
11 |
+
|
12 |
+
def __init__(self,
|
13 |
+
num_gt_classes,
|
14 |
+
matching_iou_threshold=0.5,
|
15 |
+
nms_iou_threshold=1.0,
|
16 |
+
nms_max_output_boxes=10000,
|
17 |
+
recall_lower_bound=0.0,
|
18 |
+
recall_upper_bound=1.0,
|
19 |
+
use_weighted_mean_ap=False,
|
20 |
+
label_id_offset=0,
|
21 |
+
group_of_weight=0.0,
|
22 |
+
per_image_eval_class=PerImageEvaluation):
|
23 |
+
"""Constructor.
|
24 |
+
Args:
|
25 |
+
num_gt_classes: Number of ground-truth classes.
|
26 |
+
matching_iou_threshold: IOU threshold used for matching detected boxes to ground-truth boxes.
|
27 |
+
nms_iou_threshold: IOU threshold used for non-maximum suppression.
|
28 |
+
nms_max_output_boxes: Maximum number of boxes returned by non-maximum suppression.
|
29 |
+
recall_lower_bound: lower bound of recall operating area
|
30 |
+
recall_upper_bound: upper bound of recall operating area
|
31 |
+
use_weighted_mean_ap: (optional) boolean which determines if the mean
|
32 |
+
average precision is computed directly from the scores and tp_fp_labels of all classes.
|
33 |
+
label_id_offset: The label id offset.
|
34 |
+
group_of_weight: Weight of group-of boxes. If set to 0, detections of the
|
35 |
+
correct class within a group-of box are ignored. If weight is > 0, then
|
36 |
+
if at least one detection falls within a group-of box with
|
37 |
+
matching_iou_threshold, weight group_of_weight is added to true
|
38 |
+
positives. Consequently, if no detection falls within a group-of box,
|
39 |
+
weight group_of_weight is added to false negatives.
|
40 |
+
per_image_eval_class: The class that contains functions for computing per image metrics.
|
41 |
+
Raises:
|
42 |
+
ValueError: if num_gt_classes is smaller than 1.
|
43 |
+
"""
|
44 |
+
if num_gt_classes < 1:
|
45 |
+
raise ValueError('Need at least 1 groundtruth class for evaluation.')
|
46 |
+
|
47 |
+
self.per_image_eval = per_image_eval_class(
|
48 |
+
num_gt_classes=num_gt_classes,
|
49 |
+
matching_iou_threshold=matching_iou_threshold,
|
50 |
+
nms_iou_threshold=nms_iou_threshold,
|
51 |
+
nms_max_output_boxes=nms_max_output_boxes,
|
52 |
+
group_of_weight=group_of_weight)
|
53 |
+
self.recall_lower_bound = recall_lower_bound
|
54 |
+
self.recall_upper_bound = recall_upper_bound
|
55 |
+
self.group_of_weight = group_of_weight
|
56 |
+
self.num_class = num_gt_classes
|
57 |
+
self.use_weighted_mean_ap = use_weighted_mean_ap
|
58 |
+
self.label_id_offset = label_id_offset
|
59 |
+
|
60 |
+
self.gt_boxes = {}
|
61 |
+
self.gt_class_labels = {}
|
62 |
+
self.gt_masks = {}
|
63 |
+
self.gt_is_difficult_list = {}
|
64 |
+
self.gt_is_group_of_list = {}
|
65 |
+
self.num_gt_instances_per_class = np.zeros(self.num_class, dtype=float)
|
66 |
+
self.num_gt_imgs_per_class = np.zeros(self.num_class, dtype=int)
|
67 |
+
|
68 |
+
self._initialize_detections()
|
69 |
+
|
70 |
+
def _initialize_detections(self):
|
71 |
+
"""Initializes internal data structures."""
|
72 |
+
self.detection_keys = set()
|
73 |
+
self.scores_per_class = [[] for _ in range(self.num_class)]
|
74 |
+
self.tp_fp_labels_per_class = [[] for _ in range(self.num_class)]
|
75 |
+
self.num_images_correctly_detected_per_class = np.zeros(self.num_class)
|
76 |
+
self.average_precision_per_class = np.empty(self.num_class, dtype=float)
|
77 |
+
self.average_precision_per_class.fill(np.nan)
|
78 |
+
self.precisions_per_class = [np.nan] * self.num_class
|
79 |
+
self.recalls_per_class = [np.nan] * self.num_class
|
80 |
+
self.sum_tp_class = [np.nan] * self.num_class
|
81 |
+
|
82 |
+
self.corloc_per_class = np.ones(self.num_class, dtype=float)
|
83 |
+
|
84 |
+
def clear_detections(self):
|
85 |
+
self._initialize_detections()
|
86 |
+
|
87 |
+
def add_single_ground_truth_image_info(
|
88 |
+
self, image_key, gt_boxes, gt_class_labels,
|
89 |
+
gt_is_difficult_list=None, gt_is_group_of_list=None, gt_masks=None):
|
90 |
+
"""Adds groundtruth for a single image to be used for evaluation.
|
91 |
+
Args:
|
92 |
+
image_key: A unique string/integer identifier for the image.
|
93 |
+
gt_boxes: float32 numpy array of shape [num_boxes, 4] containing
|
94 |
+
`num_boxes` groundtruth boxes of the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
|
95 |
+
gt_class_labels: integer numpy array of shape [num_boxes]
|
96 |
+
containing 0-indexed groundtruth classes for the boxes.
|
97 |
+
gt_is_difficult_list: A length M numpy boolean array denoting
|
98 |
+
whether a ground truth box is a difficult instance or not. To support
|
99 |
+
the case that no boxes are difficult, it is by default set as None.
|
100 |
+
gt_is_group_of_list: A length M numpy boolean array denoting
|
101 |
+
whether a ground truth box is a group-of box or not. To support the case
|
102 |
+
that no boxes are groups-of, it is by default set as None.
|
103 |
+
gt_masks: uint8 numpy array of shape [num_boxes, height, width]
|
104 |
+
containing `num_boxes` groundtruth masks. The mask values range from 0 to 1.
|
105 |
+
"""
|
106 |
+
if image_key in self.gt_boxes:
|
107 |
+
logging.warning('image %s has already been added to the ground truth database.', image_key)
|
108 |
+
return
|
109 |
+
|
110 |
+
self.gt_boxes[image_key] = gt_boxes
|
111 |
+
self.gt_class_labels[image_key] = gt_class_labels
|
112 |
+
self.gt_masks[image_key] = gt_masks
|
113 |
+
if gt_is_difficult_list is None:
|
114 |
+
num_boxes = gt_boxes.shape[0]
|
115 |
+
gt_is_difficult_list = np.zeros(num_boxes, dtype=bool)
|
116 |
+
gt_is_difficult_list = gt_is_difficult_list.astype(dtype=bool)
|
117 |
+
self.gt_is_difficult_list[image_key] = gt_is_difficult_list
|
118 |
+
if gt_is_group_of_list is None:
|
119 |
+
num_boxes = gt_boxes.shape[0]
|
120 |
+
gt_is_group_of_list = np.zeros(num_boxes, dtype=bool)
|
121 |
+
if gt_masks is None:
|
122 |
+
num_boxes = gt_boxes.shape[0]
|
123 |
+
mask_presence_indicator = np.zeros(num_boxes, dtype=bool)
|
124 |
+
else:
|
125 |
+
mask_presence_indicator = (np.sum(gt_masks, axis=(1, 2)) == 0).astype(dtype=bool)
|
126 |
+
|
127 |
+
gt_is_group_of_list = gt_is_group_of_list.astype(dtype=bool)
|
128 |
+
self.gt_is_group_of_list[image_key] = gt_is_group_of_list
|
129 |
+
|
130 |
+
# ignore boxes without masks
|
131 |
+
masked_gt_is_difficult_list = gt_is_difficult_list | mask_presence_indicator
|
132 |
+
for class_index in range(self.num_class):
|
133 |
+
num_gt_instances = np.sum(
|
134 |
+
gt_class_labels[~masked_gt_is_difficult_list & ~gt_is_group_of_list] == class_index)
|
135 |
+
num_groupof_gt_instances = self.group_of_weight * np.sum(
|
136 |
+
gt_class_labels[gt_is_group_of_list & ~masked_gt_is_difficult_list] == class_index)
|
137 |
+
self.num_gt_instances_per_class[class_index] += num_gt_instances + num_groupof_gt_instances
|
138 |
+
if np.any(gt_class_labels == class_index):
|
139 |
+
self.num_gt_imgs_per_class[class_index] += 1
|
140 |
+
|
141 |
+
def add_single_detected_image_info(
|
142 |
+
self, image_key, detected_boxes, detected_scores, detected_class_labels, detected_masks=None):
|
143 |
+
"""Adds detections for a single image to be used for evaluation.
|
144 |
+
Args:
|
145 |
+
image_key: A unique string/integer identifier for the image.
|
146 |
+
detected_boxes: float32 numpy array of shape [num_boxes, 4] containing
|
147 |
+
`num_boxes` detection boxes of the format [ymin, xmin, ymax, xmax] in
|
148 |
+
absolute image coordinates.
|
149 |
+
detected_scores: float32 numpy array of shape [num_boxes] containing
|
150 |
+
detection scores for the boxes.
|
151 |
+
detected_class_labels: integer numpy array of shape [num_boxes] containing
|
152 |
+
0-indexed detection classes for the boxes.
|
153 |
+
detected_masks: np.uint8 numpy array of shape [num_boxes, height, width]
|
154 |
+
containing `num_boxes` detection masks with values ranging between 0 and 1.
|
155 |
+
Raises:
|
156 |
+
ValueError: if the number of boxes, scores and class labels differ in length.
|
157 |
+
"""
|
158 |
+
if len(detected_boxes) != len(detected_scores) or len(detected_boxes) != len(detected_class_labels):
|
159 |
+
raise ValueError(
|
160 |
+
'detected_boxes, detected_scores and '
|
161 |
+
'detected_class_labels should all have same lengths. Got '
|
162 |
+
'[%d, %d, %d]' % (len(detected_boxes), len(detected_scores),
|
163 |
+
len(detected_class_labels)))
|
164 |
+
|
165 |
+
if image_key in self.detection_keys:
|
166 |
+
logging.warning('image %s has already been added to the detection result database', image_key)
|
167 |
+
return
|
168 |
+
|
169 |
+
self.detection_keys.add(image_key)
|
170 |
+
if image_key in self.gt_boxes:
|
171 |
+
gt_boxes = self.gt_boxes[image_key]
|
172 |
+
gt_class_labels = self.gt_class_labels[image_key]
|
173 |
+
# Masks are popped instead of look up. The reason is that we do not want
|
174 |
+
# to keep all masks in memory which can cause memory overflow.
|
175 |
+
gt_masks = self.gt_masks.pop(image_key)
|
176 |
+
gt_is_difficult_list = self.gt_is_difficult_list[image_key]
|
177 |
+
gt_is_group_of_list = self.gt_is_group_of_list[image_key]
|
178 |
+
else:
|
179 |
+
gt_boxes = np.empty(shape=[0, 4], dtype=float)
|
180 |
+
gt_class_labels = np.array([], dtype=int)
|
181 |
+
if detected_masks is None:
|
182 |
+
gt_masks = None
|
183 |
+
else:
|
184 |
+
gt_masks = np.empty(shape=[0, 1, 1], dtype=float)
|
185 |
+
gt_is_difficult_list = np.array([], dtype=bool)
|
186 |
+
gt_is_group_of_list = np.array([], dtype=bool)
|
187 |
+
scores, tp_fp_labels, is_class_correctly_detected_in_image = \
|
188 |
+
self.per_image_eval.compute_object_detection_metrics(
|
189 |
+
detected_boxes=detected_boxes,
|
190 |
+
detected_scores=detected_scores,
|
191 |
+
detected_class_labels=detected_class_labels,
|
192 |
+
gt_boxes=gt_boxes,
|
193 |
+
gt_class_labels=gt_class_labels,
|
194 |
+
gt_is_difficult_list=gt_is_difficult_list,
|
195 |
+
gt_is_group_of_list=gt_is_group_of_list,
|
196 |
+
detected_masks=detected_masks,
|
197 |
+
gt_masks=gt_masks)
|
198 |
+
|
199 |
+
for i in range(self.num_class):
|
200 |
+
if scores[i].shape[0] > 0:
|
201 |
+
self.scores_per_class[i].append(scores[i])
|
202 |
+
self.tp_fp_labels_per_class[i].append(tp_fp_labels[i])
|
203 |
+
self.num_images_correctly_detected_per_class += is_class_correctly_detected_in_image
|
204 |
+
|
205 |
+
def evaluate(self):
|
206 |
+
"""Compute evaluation result.
|
207 |
+
Returns:
|
208 |
+
A dict with the following fields -
|
209 |
+
average_precision: float numpy array of average precision for each class.
|
210 |
+
mean_ap: mean average precision of all classes, float scalar
|
211 |
+
precisions: List of precisions, each precision is a float numpy array
|
212 |
+
recalls: List of recalls, each recall is a float numpy array
|
213 |
+
corloc: numpy float array
|
214 |
+
mean_corloc: mean CorLoc score averaged over all classes, float scalar
|
215 |
+
"""
|
216 |
+
if (self.num_gt_instances_per_class == 0).any():
|
217 |
+
logging.warning(
|
218 |
+
'The following classes have no ground truth examples: %s',
|
219 |
+
np.squeeze(np.argwhere(self.num_gt_instances_per_class == 0)) + self.label_id_offset)
|
220 |
+
|
221 |
+
if self.use_weighted_mean_ap:
|
222 |
+
all_scores = np.array([], dtype=float)
|
223 |
+
all_tp_fp_labels = np.array([], dtype=bool)
|
224 |
+
for class_index in range(self.num_class):
|
225 |
+
if self.num_gt_instances_per_class[class_index] == 0:
|
226 |
+
continue
|
227 |
+
if not self.scores_per_class[class_index]:
|
228 |
+
scores = np.array([], dtype=float)
|
229 |
+
tp_fp_labels = np.array([], dtype=float)
|
230 |
+
else:
|
231 |
+
scores = np.concatenate(self.scores_per_class[class_index])
|
232 |
+
tp_fp_labels = np.concatenate(self.tp_fp_labels_per_class[class_index])
|
233 |
+
if self.use_weighted_mean_ap:
|
234 |
+
all_scores = np.append(all_scores, scores)
|
235 |
+
all_tp_fp_labels = np.append(all_tp_fp_labels, tp_fp_labels)
|
236 |
+
precision, recall = compute_precision_recall(
|
237 |
+
scores, tp_fp_labels, self.num_gt_instances_per_class[class_index])
|
238 |
+
recall_within_bound_indices = [
|
239 |
+
index for index, value in enumerate(recall) if
|
240 |
+
value >= self.recall_lower_bound and value <= self.recall_upper_bound
|
241 |
+
]
|
242 |
+
recall_within_bound = recall[recall_within_bound_indices]
|
243 |
+
precision_within_bound = precision[recall_within_bound_indices]
|
244 |
+
|
245 |
+
self.precisions_per_class[class_index] = precision_within_bound
|
246 |
+
self.recalls_per_class[class_index] = recall_within_bound
|
247 |
+
self.sum_tp_class[class_index] = tp_fp_labels.sum()
|
248 |
+
average_precision = compute_average_precision(precision_within_bound, recall_within_bound)
|
249 |
+
self.average_precision_per_class[class_index] = average_precision
|
250 |
+
logging.debug('average_precision: %f', average_precision)
|
251 |
+
|
252 |
+
self.corloc_per_class = compute_cor_loc(
|
253 |
+
self.num_gt_imgs_per_class, self.num_images_correctly_detected_per_class)
|
254 |
+
|
255 |
+
if self.use_weighted_mean_ap:
|
256 |
+
num_gt_instances = np.sum(self.num_gt_instances_per_class)
|
257 |
+
precision, recall = compute_precision_recall(all_scores, all_tp_fp_labels, num_gt_instances)
|
258 |
+
recall_within_bound_indices = [
|
259 |
+
index for index, value in enumerate(recall) if
|
260 |
+
value >= self.recall_lower_bound and value <= self.recall_upper_bound
|
261 |
+
]
|
262 |
+
recall_within_bound = recall[recall_within_bound_indices]
|
263 |
+
precision_within_bound = precision[recall_within_bound_indices]
|
264 |
+
mean_ap = compute_average_precision(precision_within_bound, recall_within_bound)
|
265 |
+
else:
|
266 |
+
mean_ap = np.nanmean(self.average_precision_per_class)
|
267 |
+
mean_corloc = np.nanmean(self.corloc_per_class)
|
268 |
+
|
269 |
+
return dict(
|
270 |
+
per_class_ap=self.average_precision_per_class, mean_ap=mean_ap,
|
271 |
+
per_class_precision=self.precisions_per_class,
|
272 |
+
per_class_recall=self.recalls_per_class,
|
273 |
+
per_class_corlocs=self.corloc_per_class, mean_corloc=mean_corloc)
|
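
A hedged usage sketch of the evaluator defined above (not part of the commit; the image key, boxes, and scores are made-up toy values):

import numpy as np
from effdet.evaluation.object_detection_evaluation import ObjectDetectionEvaluation

evaluation = ObjectDetectionEvaluation(num_gt_classes=2, matching_iou_threshold=0.5)
# one groundtruth box of class 0 for image 'img0'
evaluation.add_single_ground_truth_image_info(
    image_key='img0',
    gt_boxes=np.array([[0., 0., 50., 50.]], dtype=np.float32),
    gt_class_labels=np.array([0], dtype=int))
# one overlapping detection of the same class
evaluation.add_single_detected_image_info(
    image_key='img0',
    detected_boxes=np.array([[1., 1., 48., 49.]], dtype=np.float32),
    detected_scores=np.array([0.8], dtype=np.float32),
    detected_class_labels=np.array([0], dtype=int))
metrics = evaluation.evaluate()
print(metrics['mean_ap'])  # 1.0 for this toy case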
efficientdet/effdet/evaluation/per_image_evaluation.py
ADDED
@@ -0,0 +1,538 @@
1 |
+
from .np_mask_list import *
|
2 |
+
from .metrics import *
|
3 |
+
|
4 |
+
|
5 |
+
class PerImageEvaluation:
|
6 |
+
"""Evaluate detection result of a single image."""
|
7 |
+
|
8 |
+
def __init__(self,
|
9 |
+
num_gt_classes,
|
10 |
+
matching_iou_threshold=0.5,
|
11 |
+
nms_iou_threshold=0.3,
|
12 |
+
nms_max_output_boxes=50,
|
13 |
+
group_of_weight=0.0):
|
14 |
+
"""Initialized PerImageEvaluation by evaluation parameters.
|
15 |
+
Args:
|
16 |
+
num_gt_classes: Number of ground truth object classes
|
17 |
+
matching_iou_threshold: A ratio of area intersection to union, which is
|
18 |
+
the threshold to consider whether a detection is true positive or not
|
19 |
+
nms_iou_threshold: IOU threshold used in Non Maximum Suppression.
|
20 |
+
nms_max_output_boxes: Number of maximum output boxes in NMS.
|
21 |
+
group_of_weight: Weight of the group-of boxes.
|
22 |
+
"""
|
23 |
+
self.matching_iou_threshold = matching_iou_threshold
|
24 |
+
self.nms_iou_threshold = nms_iou_threshold
|
25 |
+
self.nms_max_output_boxes = nms_max_output_boxes
|
26 |
+
self.num_gt_classes = num_gt_classes
|
27 |
+
self.group_of_weight = group_of_weight
|
28 |
+
|
29 |
+
def compute_object_detection_metrics(
|
30 |
+
self, detected_boxes, detected_scores, detected_class_labels,
|
31 |
+
gt_boxes, gt_class_labels, gt_is_difficult_list, gt_is_group_of_list,
|
32 |
+
detected_masks=None, gt_masks=None):
|
33 |
+
"""Evaluates detections as being tp, fp or weighted from a single image.
|
34 |
+
The evaluation is done in two stages:
|
35 |
+
1. All detections are matched to non group-of boxes; true positives are
|
36 |
+
determined and detections matched to difficult boxes are ignored.
|
37 |
+
2. Detections that are determined as false positives are matched against
|
38 |
+
group-of boxes and weighted if matched.
|
39 |
+
Args:
|
40 |
+
detected_boxes: A float numpy array of shape [N, 4], representing N
|
41 |
+
regions of detected object regions. Each row is of the format [y_min, x_min, y_max, x_max]
|
42 |
+
detected_scores: A float numpy array of shape [N, 1], representing the
|
43 |
+
confidence scores of the detected N object instances.
|
44 |
+
detected_class_labels: An integer numpy array of shape [N, 1], representing
|
45 |
+
the class labels of the detected N object instances.
|
46 |
+
gt_boxes: A float numpy array of shape [M, 4], representing M
|
47 |
+
regions of object instances in ground truth
|
48 |
+
gt_class_labels: An integer numpy array of shape [M, 1],
|
49 |
+
representing M class labels of object instances in ground truth
|
50 |
+
gt_is_difficult_list: A boolean numpy array of length M denoting
|
51 |
+
whether a ground truth box is a difficult instance or not
|
52 |
+
gt_is_group_of_list: A boolean numpy array of length M denoting
|
53 |
+
whether a ground truth box has group-of tag
|
54 |
+
detected_masks: (optional) A uint8 numpy array of shape [N, height,
|
55 |
+
width]. If not None, the metrics will be computed based on masks.
|
56 |
+
gt_masks: (optional) A uint8 numpy array of shape [M, height,
|
57 |
+
width]. Can have empty masks, i.e. where all values are 0.
|
58 |
+
Returns:
|
59 |
+
scores: A list of C float numpy arrays. Each numpy array is of
|
60 |
+
shape [K, 1], representing K scores detected with object class label c
|
61 |
+
tp_fp_labels: A list of C boolean numpy arrays. Each numpy array
|
62 |
+
is of shape [K, 1], representing K True/False positive label of
|
63 |
+
object instances detected with class label c
|
64 |
+
is_class_correctly_detected_in_image: a numpy integer array of
|
65 |
+
shape [C, 1], indicating whether the corresponding class has at least
|
66 |
+
one instance being correctly detected in the image
|
67 |
+
"""
|
68 |
+
detected_boxes, detected_scores, detected_class_labels, detected_masks = (
|
69 |
+
self._remove_invalid_boxes(detected_boxes, detected_scores, detected_class_labels, detected_masks))
|
70 |
+
|
71 |
+
scores, tp_fp_labels = self._compute_tp_fp(
|
72 |
+
detected_boxes=detected_boxes,
|
73 |
+
detected_scores=detected_scores,
|
74 |
+
detected_class_labels=detected_class_labels,
|
75 |
+
gt_boxes=gt_boxes,
|
76 |
+
gt_class_labels=gt_class_labels,
|
77 |
+
gt_is_difficult_list=gt_is_difficult_list,
|
78 |
+
gt_is_group_of_list=gt_is_group_of_list,
|
79 |
+
detected_masks=detected_masks,
|
80 |
+
gt_masks=gt_masks)
|
81 |
+
|
82 |
+
is_class_correctly_detected_in_image = self._compute_cor_loc(
|
83 |
+
detected_boxes=detected_boxes,
|
84 |
+
detected_scores=detected_scores,
|
85 |
+
detected_class_labels=detected_class_labels,
|
86 |
+
gt_boxes=gt_boxes,
|
87 |
+
gt_class_labels=gt_class_labels,
|
88 |
+
detected_masks=detected_masks,
|
89 |
+
gt_masks=gt_masks)
|
90 |
+
|
91 |
+
return scores, tp_fp_labels, is_class_correctly_detected_in_image
|
92 |
+
|
93 |
+
def _compute_cor_loc(
|
94 |
+
self, detected_boxes, detected_scores, detected_class_labels,
|
95 |
+
gt_boxes, gt_class_labels, detected_masks=None, gt_masks=None):
|
96 |
+
"""Compute CorLoc score for object detection result.
|
97 |
+
Args:
|
98 |
+
detected_boxes: A float numpy array of shape [N, 4], representing N
|
99 |
+
regions of detected object regions. Each row is of the format [y_min, x_min, y_max, x_max]
|
100 |
+
detected_scores: A float numpy array of shape [N, 1], representing the
|
101 |
+
confidence scores of the detected N object instances.
|
102 |
+
detected_class_labels: An integer numpy array of shape [N, 1], representing
|
103 |
+
the class labels of the detected N object instances.
|
104 |
+
gt_boxes: A float numpy array of shape [M, 4], representing M
|
105 |
+
regions of object instances in ground truth
|
106 |
+
gt_class_labels: An integer numpy array of shape [M, 1],
|
107 |
+
representing M class labels of object instances in ground truth
|
108 |
+
detected_masks: (optional) A uint8 numpy array of shape [N, height, width].
|
109 |
+
If not None, the scores will be computed based on masks.
|
110 |
+
gt_masks: (optional) A uint8 numpy array of shape [M, height, width].
|
111 |
+
Returns:
|
112 |
+
is_class_correctly_detected_in_image: a numpy integer array of
|
113 |
+
shape [C, 1], indicating whether the corresponding class has at least
|
114 |
+
one instance being correctly detected in the image
|
115 |
+
Raises:
|
116 |
+
ValueError: If detected masks is not None but groundtruth masks are None,
|
117 |
+
or the other way around.
|
118 |
+
"""
|
119 |
+
if (detected_masks is not None and gt_masks is None) or (
|
120 |
+
detected_masks is None and gt_masks is not None):
|
121 |
+
raise ValueError(
|
122 |
+
'If `detected_masks` is provided, then `gt_masks` should also be provided.')
|
123 |
+
|
124 |
+
is_class_correctly_detected_in_image = np.zeros(
|
125 |
+
self.num_gt_classes, dtype=int)
|
126 |
+
for i in range(self.num_gt_classes):
|
127 |
+
(gt_boxes_at_ith_class, gt_masks_at_ith_class,
|
128 |
+
detected_boxes_at_ith_class, detected_scores_at_ith_class,
|
129 |
+
detected_masks_at_ith_class) = self._get_ith_class_arrays(
|
130 |
+
detected_boxes, detected_scores, detected_masks,
|
131 |
+
detected_class_labels, gt_boxes, gt_masks,
|
132 |
+
gt_class_labels, i)
|
133 |
+
is_class_correctly_detected_in_image[i] = (
|
134 |
+
self._compute_is_class_correctly_detected_in_image(
|
135 |
+
detected_boxes=detected_boxes_at_ith_class,
|
136 |
+
detected_scores=detected_scores_at_ith_class,
|
137 |
+
gt_boxes=gt_boxes_at_ith_class,
|
138 |
+
detected_masks=detected_masks_at_ith_class,
|
139 |
+
gt_masks=gt_masks_at_ith_class))
|
140 |
+
|
141 |
+
return is_class_correctly_detected_in_image
|
142 |
+
|
143 |
+
def _compute_is_class_correctly_detected_in_image(
|
144 |
+
self, detected_boxes, detected_scores, gt_boxes, detected_masks=None, gt_masks=None):
|
145 |
+
"""Compute CorLoc score for a single class.
|
146 |
+
Args:
|
147 |
+
detected_boxes: A numpy array of shape [N, 4] representing detected box coordinates
|
148 |
+
detected_scores: A 1-d numpy array of length N representing classification score
|
149 |
+
gt_boxes: A numpy array of shape [M, 4] representing ground truth box coordinates
|
150 |
+
detected_masks: (optional) A np.uint8 numpy array of shape [N, height, width].
|
151 |
+
If not None, the scores will be computed based on masks.
|
152 |
+
gt_masks: (optional) A np.uint8 numpy array of shape [M, height, width].
|
153 |
+
Returns:
|
154 |
+
is_class_correctly_detected_in_image: An integer 1 or 0 denoting whether a
|
155 |
+
class is correctly detected in the image or not
|
156 |
+
"""
|
157 |
+
if detected_boxes.size > 0:
|
158 |
+
if gt_boxes.size > 0:
|
159 |
+
max_score_id = np.argmax(detected_scores)
|
160 |
+
mask_mode = False
|
161 |
+
if detected_masks is not None and gt_masks is not None:
|
162 |
+
mask_mode = True
|
163 |
+
if mask_mode:
|
164 |
+
detected_boxlist = MaskList(
|
165 |
+
box_data=np.expand_dims(detected_boxes[max_score_id], axis=0),
|
166 |
+
mask_data=np.expand_dims(detected_masks[max_score_id], axis=0))
|
167 |
+
gt_boxlist = MaskList(box_data=gt_boxes, mask_data=gt_masks)
|
168 |
+
iou = iou_masklist(detected_boxlist, gt_boxlist)
|
169 |
+
else:
|
170 |
+
detected_boxlist = BoxList(np.expand_dims(detected_boxes[max_score_id, :], axis=0))
|
171 |
+
gt_boxlist = BoxList(gt_boxes)
|
172 |
+
iou = iou_boxlist(detected_boxlist, gt_boxlist)
|
173 |
+
if np.max(iou) >= self.matching_iou_threshold:
|
174 |
+
return 1
|
175 |
+
return 0
|
176 |
+
|
177 |
+
def _compute_tp_fp(
|
178 |
+
self, detected_boxes, detected_scores, detected_class_labels,
|
179 |
+
gt_boxes, gt_class_labels, gt_is_difficult_list, gt_is_group_of_list, detected_masks=None, gt_masks=None):
|
180 |
+
"""Labels true/false positives of detections of an image across all classes.
|
181 |
+
Args:
|
182 |
+
detected_boxes: A float numpy array of shape [N, 4], representing N
|
183 |
+
regions of detected object regions. Each row is of the format [y_min, x_min, y_max, x_max]
|
184 |
+
detected_scores: A float numpy array of shape [N, 1], representing the
|
185 |
+
confidence scores of the detected N object instances.
|
186 |
+
detected_class_labels: An integer numpy array of shape [N, 1], representing
|
187 |
+
the class labels of the detected N object instances.
|
188 |
+
gt_boxes: A float numpy array of shape [M, 4], representing M
|
189 |
+
regions of object instances in ground truth
|
190 |
+
gt_class_labels: An integer numpy array of shape [M, 1],
|
191 |
+
representing M class labels of object instances in ground truth
|
192 |
+
gt_is_difficult_list: A boolean numpy array of length M denoting
|
193 |
+
whether a ground truth box is a difficult instance or not
|
194 |
+
gt_is_group_of_list: A boolean numpy array of length M denoting
|
195 |
+
whether a ground truth box has group-of tag
|
196 |
+
detected_masks: (optional) A np.uint8 numpy array of shape [N, height,
|
197 |
+
width]. If not None, the scores will be computed based on masks.
|
198 |
+
gt_masks: (optional) A np.uint8 numpy array of shape [M, height, width].
|
199 |
+
Returns:
|
200 |
+
result_scores: A list of float numpy arrays. Each numpy array is of
|
201 |
+
shape [K, 1], representing K scores detected with object class label c
|
202 |
+
result_tp_fp_labels: A list of boolean numpy array. Each numpy array is of
|
203 |
+
shape [K, 1], representing K True/False positive label of object
|
204 |
+
instances detected with class label c
|
205 |
+
Raises:
|
206 |
+
ValueError: If detected masks is not None but groundtruth masks are None,
|
207 |
+
or the other way around.
|
208 |
+
"""
|
209 |
+
if detected_masks is not None and gt_masks is None:
|
210 |
+
raise ValueError(
|
211 |
+
'Detected masks is available but groundtruth masks is not.')
|
212 |
+
if detected_masks is None and gt_masks is not None:
|
213 |
+
raise ValueError(
|
214 |
+
'Groundtruth masks is available but detected masks is not.')
|
215 |
+
|
216 |
+
result_scores = []
|
217 |
+
result_tp_fp_labels = []
|
218 |
+
for i in range(self.num_gt_classes):
|
219 |
+
gt_is_difficult_list_at_ith_class = (
|
220 |
+
gt_is_difficult_list[gt_class_labels == i])
|
221 |
+
gt_is_group_of_list_at_ith_class = (
|
222 |
+
gt_is_group_of_list[gt_class_labels == i])
|
223 |
+
(gt_boxes_at_ith_class, gt_masks_at_ith_class,
|
224 |
+
detected_boxes_at_ith_class, detected_scores_at_ith_class,
|
225 |
+
detected_masks_at_ith_class) = self._get_ith_class_arrays(
|
226 |
+
detected_boxes, detected_scores, detected_masks,
|
227 |
+
detected_class_labels, gt_boxes, gt_masks,
|
228 |
+
gt_class_labels, i)
|
229 |
+
scores, tp_fp_labels = self._compute_tp_fp_for_single_class(
|
230 |
+
detected_boxes=detected_boxes_at_ith_class,
|
231 |
+
detected_scores=detected_scores_at_ith_class,
|
232 |
+
gt_boxes=gt_boxes_at_ith_class,
|
233 |
+
gt_is_difficult_list=gt_is_difficult_list_at_ith_class,
|
234 |
+
gt_is_group_of_list=gt_is_group_of_list_at_ith_class,
|
235 |
+
detected_masks=detected_masks_at_ith_class,
|
236 |
+
gt_masks=gt_masks_at_ith_class)
|
237 |
+
result_scores.append(scores)
|
238 |
+
result_tp_fp_labels.append(tp_fp_labels)
|
239 |
+
return result_scores, result_tp_fp_labels
|
240 |
+
|
241 |
+
def _get_overlaps_and_scores_mask_mode(
|
242 |
+
self, detected_boxes, detected_scores, detected_masks,
|
243 |
+
gt_boxes, gt_masks, gt_is_group_of_list):
|
244 |
+
"""Computes overlaps and scores between detected and groudntruth masks.
|
245 |
+
Args:
|
246 |
+
detected_boxes: A numpy array of shape [N, 4] representing detected box coordinates
|
247 |
+
detected_scores: A 1-d numpy array of length N representing classification score
|
248 |
+
detected_masks: A uint8 numpy array of shape [N, height, width]. If not
|
249 |
+
None, the scores will be computed based on masks.
|
250 |
+
gt_boxes: A numpy array of shape [M, 4] representing ground truth box coordinates
|
251 |
+
gt_masks: A uint8 numpy array of shape [M, height, width].
|
252 |
+
gt_is_group_of_list: A boolean numpy array of length M denoting
|
253 |
+
whether a ground truth box has group-of tag. If a groundtruth box is
|
254 |
+
group-of box, every detection matching this box is ignored.
|
255 |
+
Returns:
|
256 |
+
iou: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
|
257 |
+
gt_non_group_of_boxlist.num_boxes() == 0 it will be None.
|
258 |
+
ioa: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
|
259 |
+
gt_group_of_boxlist.num_boxes() == 0 it will be None.
|
260 |
+
scores: The score of the detected boxlist.
|
261 |
+
num_boxes: Number of non-maximum suppressed detected boxes.
|
262 |
+
"""
|
263 |
+
detected_boxlist = MaskList(box_data=detected_boxes, mask_data=detected_masks)
|
264 |
+
detected_boxlist.add_field('scores', detected_scores)
|
265 |
+
detected_boxlist = non_max_suppression(detected_boxlist, self.nms_max_output_boxes, self.nms_iou_threshold)
|
266 |
+
gt_non_group_of_boxlist = MaskList(
|
267 |
+
box_data=gt_boxes[~gt_is_group_of_list], mask_data=gt_masks[~gt_is_group_of_list])
|
268 |
+
gt_group_of_boxlist = MaskList(
|
269 |
+
box_data=gt_boxes[gt_is_group_of_list], mask_data=gt_masks[gt_is_group_of_list])
|
270 |
+
iou_b = iou_masklist(detected_boxlist, gt_non_group_of_boxlist)
|
271 |
+
ioa_b = np.transpose(ioa_masklist(gt_group_of_boxlist, detected_boxlist))
|
272 |
+
scores = detected_boxlist.get_field('scores')
|
273 |
+
num_boxes = detected_boxlist.num_boxes()
|
274 |
+
return iou_b, ioa_b, scores, num_boxes
|
275 |
+
|
276 |
+
def _get_overlaps_and_scores_box_mode(
|
277 |
+
self, detected_boxes, detected_scores, gt_boxes, gt_is_group_of_list):
|
278 |
+
"""Computes overlaps and scores between detected and groudntruth boxes.
|
279 |
+
Args:
|
280 |
+
detected_boxes: A numpy array of shape [N, 4] representing detected box coordinates
|
281 |
+
detected_scores: A 1-d numpy array of length N representing classification score
|
282 |
+
gt_boxes: A numpy array of shape [M, 4] representing ground truth box coordinates
|
283 |
+
gt_is_group_of_list: A boolean numpy array of length M denoting
|
284 |
+
whether a ground truth box has group-of tag. If a groundtruth box is
|
285 |
+
group-of box, every detection matching this box is ignored.
|
286 |
+
Returns:
|
287 |
+
iou: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
|
288 |
+
gt_non_group_of_boxlist.num_boxes() == 0 it will be None.
|
289 |
+
ioa: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
|
290 |
+
gt_group_of_boxlist.num_boxes() == 0 it will be None.
|
291 |
+
scores: The score of the detected boxlist.
|
292 |
+
num_boxes: Number of non-maximum suppressed detected boxes.
|
293 |
+
"""
|
294 |
+
detected_boxlist = BoxList(detected_boxes)
|
295 |
+
detected_boxlist.add_field('scores', detected_scores)
|
296 |
+
detected_boxlist = non_max_suppression(detected_boxlist, self.nms_max_output_boxes, self.nms_iou_threshold)
|
297 |
+
gt_non_group_of_boxlist = BoxList(gt_boxes[~gt_is_group_of_list])
|
298 |
+
gt_group_of_boxlist = BoxList(gt_boxes[gt_is_group_of_list])
|
299 |
+
iou_b = iou_boxlist(detected_boxlist, gt_non_group_of_boxlist)
|
300 |
+
ioa_b = np.transpose(ioa_boxlist(gt_group_of_boxlist, detected_boxlist))
|
301 |
+
scores = detected_boxlist.get_field('scores')
|
302 |
+
num_boxes = detected_boxlist.num_boxes()
|
303 |
+
return iou_b, ioa_b, scores, num_boxes
|
304 |
+
|
305 |
+
def _compute_tp_fp_for_single_class(
|
306 |
+
self, detected_boxes, detected_scores, gt_boxes,
|
307 |
+
gt_is_difficult_list, gt_is_group_of_list, detected_masks=None, gt_masks=None):
|
308 |
+
"""Labels boxes detected with the same class from the same image as tp/fp.
|
309 |
+
Args:
|
310 |
+
detected_boxes: A numpy array of shape [N, 4] representing detected box coordinates
|
311 |
+
detected_scores: A 1-d numpy array of length N representing classification score
|
312 |
+
gt_boxes: A numpy array of shape [M, 4] representing ground truth box coordinates
|
313 |
+
gt_is_difficult_list: A boolean numpy array of length M denoting
|
314 |
+
whether a ground truth box is a difficult instance or not. If a
|
315 |
+
groundtruth box is difficult, every detection matching this box is ignored.
|
316 |
+
gt_is_group_of_list: A boolean numpy array of length M denoting
|
317 |
+
whether a ground truth box has group-of tag. If a groundtruth box is
|
318 |
+
group-of box, every detection matching this box is ignored.
|
319 |
+
detected_masks: (optional) A uint8 numpy array of shape [N, height,
|
320 |
+
width]. If not None, the scores will be computed based on masks.
|
321 |
+
gt_masks: (optional) A uint8 numpy array of shape [M, height, width].
|
322 |
+
Returns:
|
323 |
+
Two arrays of the same size, containing all boxes that were evaluated as
|
324 |
+
being true positives or false positives; if a box matched to a difficult
|
325 |
+
box or to a group-of box, it is ignored.
|
326 |
+
scores: A numpy array representing the detection scores.
|
327 |
+
tp_fp_labels: a boolean numpy array indicating whether a detection is a true positive.
|
328 |
+
"""
|
329 |
+
if detected_boxes.size == 0:
|
330 |
+
return np.array([], dtype=float), np.array([], dtype=bool)
|
331 |
+
|
332 |
+
mask_mode = False
|
333 |
+
if detected_masks is not None and gt_masks is not None:
|
334 |
+
mask_mode = True
|
335 |
+
|
336 |
+
iou_b = np.ndarray([0, 0])
|
337 |
+
ioa_b = np.ndarray([0, 0])
|
338 |
+
iou_m = np.ndarray([0, 0])
|
339 |
+
ioa_m = np.ndarray([0, 0])
|
340 |
+
if mask_mode:
|
341 |
+
# For Instance Segmentation Evaluation on Open Images V5, not all boxed
|
342 |
+
# instances have corresponding segmentation annotations. Those boxes that
|
343 |
+
# don't have segmentation annotations are represented as empty masks in
|
344 |
+
# the gt_masks ndarray.
|
345 |
+
mask_presence_indicator = (np.sum(gt_masks, axis=(1, 2)) > 0)
|
346 |
+
|
347 |
+
iou_m, ioa_m, scores, num_detected_boxes = self._get_overlaps_and_scores_mask_mode(
|
348 |
+
detected_boxes=detected_boxes,
|
349 |
+
detected_scores=detected_scores,
|
350 |
+
detected_masks=detected_masks,
|
351 |
+
gt_boxes=gt_boxes[mask_presence_indicator, :],
|
352 |
+
gt_masks=gt_masks[mask_presence_indicator, :],
|
353 |
+
gt_is_group_of_list=gt_is_group_of_list[mask_presence_indicator])
|
354 |
+
|
355 |
+
if sum(mask_presence_indicator) < len(mask_presence_indicator):
|
356 |
+
# Not all masks are present - some masks are empty
|
357 |
+
iou_b, ioa_b, _, num_detected_boxes = self._get_overlaps_and_scores_box_mode(
|
358 |
+
detected_boxes=detected_boxes,
|
359 |
+
detected_scores=detected_scores,
|
360 |
+
gt_boxes=gt_boxes[~mask_presence_indicator, :],
|
361 |
+
gt_is_group_of_list=gt_is_group_of_list[~mask_presence_indicator])
|
362 |
+
num_detected_boxes = detected_boxes.shape[0]
|
363 |
+
else:
|
364 |
+
mask_presence_indicator = np.zeros(gt_is_group_of_list.shape, dtype=bool)
|
365 |
+
iou_b, ioa_b, scores, num_detected_boxes = self._get_overlaps_and_scores_box_mode(
|
366 |
+
detected_boxes=detected_boxes,
|
367 |
+
detected_scores=detected_scores,
|
368 |
+
gt_boxes=gt_boxes,
|
369 |
+
gt_is_group_of_list=gt_is_group_of_list)
|
370 |
+
|
371 |
+
if gt_boxes.size == 0:
|
372 |
+
return scores, np.zeros(num_detected_boxes, dtype=bool)
|
373 |
+
|
374 |
+
tp_fp_labels = np.zeros(num_detected_boxes, dtype=bool)
|
375 |
+
is_matched_to_box = np.zeros(num_detected_boxes, dtype=bool)
|
376 |
+
is_matched_to_difficult = np.zeros(num_detected_boxes, dtype=bool)
|
377 |
+
is_matched_to_group_of = np.zeros(num_detected_boxes, dtype=bool)
|
378 |
+
|
379 |
+
def compute_match_iou(iou_matrix, gt_nongroup_of_is_difficult_list, is_box):
|
380 |
+
"""Computes TP/FP for non group-of box matching.
|
381 |
+
The function updates the following local variables:
|
382 |
+
tp_fp_labels - detections that are determined to be true positives
|
383 |
+
is_matched_to_difficult - the detections that were processed at this stage are
|
384 |
+
matched to a difficult box.
|
385 |
+
is_matched_to_box - the detections that were processed at this stage are marked as is_box.
|
386 |
+
Args:
|
387 |
+
iou_matrix: intersection-over-union matrix [num_gt_boxes]x[num_det_boxes].
|
388 |
+
gt_nongroup_of_is_difficult_list: boolean that specifies if gt box is difficult.
|
389 |
+
is_box: boolean that specifies if currently boxes or masks are processed.
|
390 |
+
"""
|
391 |
+
max_overlap_gt_ids = np.argmax(iou_matrix, axis=1)
|
392 |
+
is_gt_detected = np.zeros(iou_matrix.shape[1], dtype=bool)
|
393 |
+
for i in range(num_detected_boxes):
|
394 |
+
gt_id = max_overlap_gt_ids[i]
|
395 |
+
is_evaluatable = (
|
396 |
+
not tp_fp_labels[i] and
|
397 |
+
not is_matched_to_difficult[i] and
|
398 |
+
iou_matrix[i, gt_id] >= self.matching_iou_threshold and
|
399 |
+
not is_matched_to_group_of[i])
|
400 |
+
if is_evaluatable:
|
401 |
+
if not gt_nongroup_of_is_difficult_list[gt_id]:
|
402 |
+
if not is_gt_detected[gt_id]:
|
403 |
+
tp_fp_labels[i] = True
|
404 |
+
is_gt_detected[gt_id] = True
|
405 |
+
is_matched_to_box[i] = is_box
|
406 |
+
else:
|
407 |
+
is_matched_to_difficult[i] = True
|
408 |
+
|
409 |
+
def compute_match_ioa(ioa_matrix, is_box):
|
410 |
+
"""Computes TP/FP for group-of box matching.
|
411 |
+
The function updates the following local variables:
|
412 |
+
is_matched_to_group_of - if a box is matched to group-of
|
413 |
+
is_matched_to_box - the detections that were processed at this stage are marked as is_box.
|
414 |
+
Args:
|
415 |
+
ioa_matrix: intersection-over-area matrix [num_gt_boxes]x[num_det_boxes].
|
416 |
+
is_box: boolean that specifies if currently boxes or masks are processed.
|
417 |
+
Returns:
|
418 |
+
scores_group_of: scores of detections matched to group-of boxes, shape [num_groupof_matched].
|
419 |
+
tp_fp_labels_group_of: boolean array of size [num_groupof_matched], all values are True.
|
420 |
+
"""
|
421 |
+
scores_group_of = np.zeros(ioa_matrix.shape[1], dtype=float)
|
422 |
+
tp_fp_labels_group_of = self.group_of_weight * np.ones(ioa_matrix.shape[1], dtype=float)
|
423 |
+
max_overlap_group_of_gt_ids = np.argmax(ioa_matrix, axis=1)
|
424 |
+
for i in range(num_detected_boxes):
|
425 |
+
gt_id = max_overlap_group_of_gt_ids[i]
|
426 |
+
is_evaluatable = (
|
427 |
+
not tp_fp_labels[i] and
|
428 |
+
not is_matched_to_difficult[i] and
|
429 |
+
ioa_matrix[i, gt_id] >= self.matching_iou_threshold and
|
430 |
+
not is_matched_to_group_of[i])
|
431 |
+
if is_evaluatable:
|
432 |
+
is_matched_to_group_of[i] = True
|
433 |
+
is_matched_to_box[i] = is_box
|
434 |
+
scores_group_of[gt_id] = max(scores_group_of[gt_id], scores[i])
|
435 |
+
selector = np.where((scores_group_of > 0) & (tp_fp_labels_group_of > 0))
|
436 |
+
scores_group_of = scores_group_of[selector]
|
437 |
+
tp_fp_labels_group_of = tp_fp_labels_group_of[selector]
|
438 |
+
|
439 |
+
return scores_group_of, tp_fp_labels_group_of
|
440 |
+
|
441 |
+
# The evaluation is done in two stages:
|
442 |
+
# 1. Evaluate all objects that actually have instance level masks.
|
443 |
+
# 2. Evaluate all objects that are not already evaluated as boxes.
|
444 |
+
if iou_m.shape[1] > 0:
|
445 |
+
gt_is_difficult_mask_list = gt_is_difficult_list[mask_presence_indicator]
|
446 |
+
gt_is_group_of_mask_list = gt_is_group_of_list[mask_presence_indicator]
|
447 |
+
compute_match_iou(iou_m, gt_is_difficult_mask_list[~gt_is_group_of_mask_list], is_box=False)
|
448 |
+
|
449 |
+
scores_mask_group_of = np.ndarray([0], dtype=float)
|
450 |
+
tp_fp_labels_mask_group_of = np.ndarray([0], dtype=float)
|
451 |
+
if ioa_m.shape[1] > 0:
|
452 |
+
scores_mask_group_of, tp_fp_labels_mask_group_of = compute_match_ioa(ioa_m, is_box=False)
|
453 |
+
|
454 |
+
# Tp-fp evaluation for non-group of boxes (if any).
|
455 |
+
if iou_b.shape[1] > 0:
|
456 |
+
gt_is_difficult_box_list = gt_is_difficult_list[~mask_presence_indicator]
|
457 |
+
gt_is_group_of_box_list = gt_is_group_of_list[~mask_presence_indicator]
|
458 |
+
compute_match_iou(iou_b, gt_is_difficult_box_list[~gt_is_group_of_box_list], is_box=True)
|
459 |
+
|
460 |
+
scores_box_group_of = np.ndarray([0], dtype=float)
|
461 |
+
tp_fp_labels_box_group_of = np.ndarray([0], dtype=float)
|
462 |
+
if ioa_b.shape[1] > 0:
|
463 |
+
scores_box_group_of, tp_fp_labels_box_group_of = compute_match_ioa(ioa_b, is_box=True)
|
464 |
+
|
465 |
+
if mask_mode:
|
466 |
+
# Note: here crowds are treated as ignore regions.
|
467 |
+
valid_entries = (~is_matched_to_difficult & ~is_matched_to_group_of & ~is_matched_to_box)
|
468 |
+
return np.concatenate((scores[valid_entries], scores_mask_group_of)),\
|
469 |
+
np.concatenate((tp_fp_labels[valid_entries].astype(float), tp_fp_labels_mask_group_of))
|
470 |
+
else:
|
471 |
+
valid_entries = (~is_matched_to_difficult & ~is_matched_to_group_of)
|
472 |
+
return np.concatenate((scores[valid_entries], scores_box_group_of)),\
|
473 |
+
np.concatenate((tp_fp_labels[valid_entries].astype(float), tp_fp_labels_box_group_of))
|
474 |
+
|
475 |
+
def _get_ith_class_arrays(
|
476 |
+
self, detected_boxes, detected_scores, detected_masks, detected_class_labels,
|
477 |
+
gt_boxes, gt_masks, gt_class_labels, class_index):
|
478 |
+
"""Returns numpy arrays belonging to class with index `class_index`.
|
479 |
+
Args:
|
480 |
+
detected_boxes: A numpy array containing detected boxes.
|
481 |
+
detected_scores: A numpy array containing detected scores.
|
482 |
+
detected_masks: A numpy array containing detected masks.
|
483 |
+
detected_class_labels: A numpy array containing detected class labels.
|
484 |
+
gt_boxes: A numpy array containing groundtruth boxes.
|
485 |
+
gt_masks: A numpy array containing groundtruth masks.
|
486 |
+
gt_class_labels: A numpy array containing groundtruth class labels.
|
487 |
+
class_index: An integer index.
|
488 |
+
Returns:
|
489 |
+
gt_boxes_at_ith_class: A numpy array containing groundtruth boxes labeled as ith class.
|
490 |
+
gt_masks_at_ith_class: A numpy array containing groundtruth masks labeled as ith class.
|
491 |
+
detected_boxes_at_ith_class: A numpy array containing detected boxes corresponding to the ith class.
|
492 |
+
detected_scores_at_ith_class: A numpy array containing detected scores corresponding to the ith class.
|
493 |
+
detected_masks_at_ith_class: A numpy array containing detected masks corresponding to the ith class.
|
494 |
+
"""
|
495 |
+
selected_groundtruth = (gt_class_labels == class_index)
|
496 |
+
gt_boxes_at_ith_class = gt_boxes[selected_groundtruth]
|
497 |
+
if gt_masks is not None:
|
498 |
+
gt_masks_at_ith_class = gt_masks[selected_groundtruth]
|
499 |
+
else:
|
500 |
+
gt_masks_at_ith_class = None
|
501 |
+
selected_detections = (detected_class_labels == class_index)
|
502 |
+
detected_boxes_at_ith_class = detected_boxes[selected_detections]
|
503 |
+
detected_scores_at_ith_class = detected_scores[selected_detections]
|
504 |
+
if detected_masks is not None:
|
505 |
+
detected_masks_at_ith_class = detected_masks[selected_detections]
|
506 |
+
else:
|
507 |
+
detected_masks_at_ith_class = None
|
508 |
+
return (gt_boxes_at_ith_class, gt_masks_at_ith_class,
|
509 |
+
detected_boxes_at_ith_class, detected_scores_at_ith_class,
|
510 |
+
detected_masks_at_ith_class)
|
511 |
+
|
512 |
+
def _remove_invalid_boxes(
|
513 |
+
self, detected_boxes, detected_scores, detected_class_labels, detected_masks=None):
|
514 |
+
"""Removes entries with invalid boxes.
|
515 |
+
A box is invalid if either its xmax is smaller than its xmin, or its ymax is smaller than its ymin.
|
516 |
+
Args:
|
517 |
+
detected_boxes: A float numpy array of size [num_boxes, 4] containing box
|
518 |
+
coordinates in [ymin, xmin, ymax, xmax] format.
|
519 |
+
detected_scores: A float numpy array of size [num_boxes].
|
520 |
+
detected_class_labels: A int32 numpy array of size [num_boxes].
|
521 |
+
detected_masks: A uint8 numpy array of size [num_boxes, height, width].
|
522 |
+
Returns:
|
523 |
+
valid_detected_boxes: A float numpy array of size [num_valid_boxes, 4]
|
524 |
+
containing box coordinates in [ymin, xmin, ymax, xmax] format.
|
525 |
+
valid_detected_scores: A float numpy array of size [num_valid_boxes].
|
526 |
+
valid_detected_class_labels: A int32 numpy array of size [num_valid_boxes].
|
527 |
+
valid_detected_masks: A uint8 numpy array of size [num_valid_boxes, height, width].
|
528 |
+
"""
|
529 |
+
valid_indices = np.logical_and(
|
530 |
+
detected_boxes[:, 0] < detected_boxes[:, 2], detected_boxes[:, 1] < detected_boxes[:, 3])
|
531 |
+
detected_boxes = detected_boxes[valid_indices]
|
532 |
+
detected_scores = detected_scores[valid_indices]
|
533 |
+
detected_class_labels = detected_class_labels[valid_indices]
|
534 |
+
if detected_masks is not None:
|
535 |
+
detected_masks = detected_masks[valid_indices]
|
536 |
+
return [detected_boxes, detected_scores, detected_class_labels, detected_masks]
|
537 |
+
|
538 |
+
|
efficientdet/effdet/evaluator.py
ADDED
@@ -0,0 +1,195 @@
1 |
+
import torch
|
2 |
+
import torch.distributed as dist
|
3 |
+
import abc
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import time
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from .distributed import synchronize, is_main_process, all_gather_container
|
10 |
+
from pycocotools.cocoeval import COCOeval
|
11 |
+
|
12 |
+
# FIXME experimenting with speedups for OpenImages eval, it's slow
|
13 |
+
#import pyximport; py_importer, pyx_importer = pyximport.install(pyimport=True)
|
14 |
+
import effdet.evaluation.detection_evaluator as tfm_eval
|
15 |
+
#pyximport.uninstall(py_importer, pyx_importer)
|
16 |
+
|
17 |
+
_logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
__all__ = ['CocoEvaluator', 'PascalEvaluator', 'OpenImagesEvaluator', 'create_evaluator']
|
21 |
+
|
22 |
+
|
23 |
+
class Evaluator:
|
24 |
+
|
25 |
+
def __init__(self, distributed=False, pred_yxyx=False):
|
26 |
+
self.distributed = distributed
|
27 |
+
self.distributed_device = None
|
28 |
+
self.pred_yxyx = pred_yxyx
|
29 |
+
self.img_indices = []
|
30 |
+
self.predictions = []
|
31 |
+
|
32 |
+
def add_predictions(self, detections, target):
|
33 |
+
if self.distributed:
|
34 |
+
if self.distributed_device is None:
|
35 |
+
# cache for use later to broadcast end metric
|
36 |
+
self.distributed_device = detections.device
|
37 |
+
synchronize()
|
38 |
+
detections = all_gather_container(detections)
|
39 |
+
img_indices = all_gather_container(target['img_idx'])
|
40 |
+
if not is_main_process():
|
41 |
+
return
|
42 |
+
else:
|
43 |
+
img_indices = target['img_idx']
|
44 |
+
|
45 |
+
detections = detections.cpu().numpy()
|
46 |
+
img_indices = img_indices.cpu().numpy()
|
47 |
+
for img_idx, img_dets in zip(img_indices, detections):
|
48 |
+
self.img_indices.append(img_idx)
|
49 |
+
self.predictions.append(img_dets)
|
50 |
+
|
51 |
+
def _coco_predictions(self):
|
52 |
+
# generate coco-style predictions
|
53 |
+
coco_predictions = []
|
54 |
+
coco_ids = []
|
55 |
+
for img_idx, img_dets in zip(self.img_indices, self.predictions):
|
56 |
+
img_id = self._dataset.img_ids[img_idx]
|
57 |
+
coco_ids.append(img_id)
|
58 |
+
if self.pred_yxyx:
|
59 |
+
# to xyxy
|
60 |
+
img_dets[:, 0:4] = img_dets[:, [1, 0, 3, 2]]
|
61 |
+
# to xywh
|
62 |
+
img_dets[:, 2] -= img_dets[:, 0]
|
63 |
+
img_dets[:, 3] -= img_dets[:, 1]
|
64 |
+
for det in img_dets:
|
65 |
+
score = float(det[4])
|
66 |
+
if score < .001: # stop when below this threshold, scores in descending order
|
67 |
+
break
|
68 |
+
coco_det = dict(
|
69 |
+
image_id=int(img_id),
|
70 |
+
bbox=det[0:4].tolist(),
|
71 |
+
score=score,
|
72 |
+
category_id=int(det[5]))
|
73 |
+
coco_predictions.append(coco_det)
|
74 |
+
return coco_predictions, coco_ids
|
75 |
+
|
76 |
+
@abc.abstractmethod
|
77 |
+
def evaluate(self):
|
78 |
+
pass
|
79 |
+
|
80 |
+
def save(self, result_file):
|
81 |
+
# save results in coco style, override to save in an alternate form
|
82 |
+
if not self.distributed or dist.get_rank() == 0:
|
83 |
+
assert len(self.predictions)
|
84 |
+
coco_predictions, coco_ids = self._coco_predictions()
|
85 |
+
json.dump(coco_predictions, open(result_file, 'w'), indent=4)
|
86 |
+
|
87 |
+
|
88 |
+
class CocoEvaluator(Evaluator):
|
89 |
+
|
90 |
+
def __init__(self, dataset, neptune=None, distributed=False, pred_yxyx=False):
|
91 |
+
super().__init__(distributed=distributed, pred_yxyx=pred_yxyx)
|
92 |
+
self._dataset = dataset.parser
|
93 |
+
self.coco_api = dataset.parser.coco
|
94 |
+
self.neptune = neptune
|
95 |
+
|
96 |
+
def reset(self):
|
97 |
+
self.img_indices = []
|
98 |
+
self.predictions = []
|
99 |
+
|
100 |
+
def evaluate(self):
|
101 |
+
if not self.distributed or dist.get_rank() == 0:
|
102 |
+
assert len(self.predictions)
|
103 |
+
coco_predictions, coco_ids = self._coco_predictions()
|
104 |
+
json.dump(coco_predictions, open('./temp.json', 'w'), indent=4)
|
105 |
+
results = self.coco_api.loadRes('./temp.json')
|
106 |
+
coco_eval = COCOeval(self.coco_api, results, 'bbox')
|
107 |
+
coco_eval.params.imgIds = coco_ids # score only ids we've used
|
108 |
+
coco_eval.evaluate()
|
109 |
+
coco_eval.accumulate()
|
110 |
+
coco_eval.summarize()
|
111 |
+
metric = coco_eval.stats[0] # mAP 0.5-0.95
|
112 |
+
if self.neptune:
|
113 |
+
self.neptune.log_metric('valid/mAP/0.5-0.95IOU', metric)
|
114 |
+
self.neptune.log_metric('valid/mAP/0.5IOU', coco_eval.stats[1])
|
115 |
+
if self.distributed:
|
116 |
+
dist.broadcast(torch.tensor(metric, device=self.distributed_device), 0)
|
117 |
+
else:
|
118 |
+
metric = torch.tensor(0, device=self.distributed_device)
|
119 |
+
dist.broadcast(metric, 0)
|
120 |
+
metric = metric.item()
|
121 |
+
self.reset()
|
122 |
+
return metric
|
123 |
+
|
124 |
+
|
125 |
+
class TfmEvaluator(Evaluator):
|
126 |
+
""" Tensorflow Models Evaluator Wrapper """
|
127 |
+
def __init__(
|
128 |
+
self, dataset, neptune=None, distributed=False, pred_yxyx=False,
|
129 |
+
evaluator_cls=tfm_eval.ObjectDetectionEvaluator):
|
130 |
+
super().__init__(distributed=distributed, pred_yxyx=pred_yxyx)
|
131 |
+
self._evaluator = evaluator_cls(categories=dataset.parser.cat_dicts)
|
132 |
+
self._eval_metric_name = self._evaluator._metric_names[0]
|
133 |
+
self._dataset = dataset.parser
|
134 |
+
self.neptune = neptune
|
135 |
+
|
136 |
+
def reset(self):
|
137 |
+
self._evaluator.clear()
|
138 |
+
self.img_indices = []
|
139 |
+
self.predictions = []
|
140 |
+
|
141 |
+
def evaluate(self):
|
142 |
+
if not self.distributed or dist.get_rank() == 0:
|
143 |
+
for img_idx, img_dets in zip(self.img_indices, self.predictions):
|
144 |
+
gt = self._dataset.get_ann_info(img_idx)
|
145 |
+
self._evaluator.add_single_ground_truth_image_info(img_idx, gt)
|
146 |
+
|
147 |
+
bbox = img_dets[:, 0:4] if self.pred_yxyx else img_dets[:, [1, 0, 3, 2]]
|
148 |
+
det = dict(bbox=bbox, score=img_dets[:, 4], cls=img_dets[:, 5])
|
149 |
+
self._evaluator.add_single_detected_image_info(img_idx, det)
|
150 |
+
|
151 |
+
metrics = self._evaluator.evaluate()
|
152 |
+
_logger.info('Metrics:')
|
153 |
+
for k, v in metrics.items():
|
154 |
+
_logger.info(f'{k}: {v}')
|
155 |
+
if self.neptune:
|
156 |
+
key = 'valid/mAP/' + str(k).split('/')[-1]
|
157 |
+
self.neptune.log_metric(key, v)
|
158 |
+
|
159 |
+
map_metric = metrics[self._eval_metric_name]
|
160 |
+
if self.distributed:
|
161 |
+
dist.broadcast(torch.tensor(map_metric, device=self.distributed_device), 0)
|
162 |
+
else:
|
163 |
+
map_metric = torch.tensor(0, device=self.distributed_device)
|
164 |
+
wait = dist.broadcast(map_metric, 0, async_op=True)
|
165 |
+
while not wait.is_completed():
|
166 |
+
# wait without spinning the cpu @ 100%, no need for low latency here
|
167 |
+
time.sleep(0.5)
|
168 |
+
map_metric = map_metric.item()
|
169 |
+
self.reset()
|
170 |
+
return map_metric
|
171 |
+
|
172 |
+
|
173 |
+
class PascalEvaluator(TfmEvaluator):
|
174 |
+
|
175 |
+
def __init__(self, dataset, neptune=None, distributed=False, pred_yxyx=False):
|
176 |
+
super().__init__(
|
177 |
+
dataset, neptune, distributed=distributed, pred_yxyx=pred_yxyx, evaluator_cls=tfm_eval.PascalDetectionEvaluator)
|
178 |
+
|
179 |
+
|
180 |
+
class OpenImagesEvaluator(TfmEvaluator):
|
181 |
+
|
182 |
+
def __init__(self, dataset, distributed=False, pred_yxyx=False):
|
183 |
+
super().__init__(
|
184 |
+
dataset, distributed=distributed, pred_yxyx=pred_yxyx, evaluator_cls=tfm_eval.OpenImagesDetectionEvaluator)
|
185 |
+
|
186 |
+
|
187 |
+
def create_evaluator(name, dataset, neptune=None, distributed=False, pred_yxyx=False):
|
188 |
+
# FIXME support OpenImages Challenge2019 metric w/ image level label consideration
|
189 |
+
if 'coco' in name:
|
190 |
+
return CocoEvaluator(dataset, neptune, distributed=distributed, pred_yxyx=pred_yxyx)
|
191 |
+
elif 'openimages' in name:
|
192 |
+
return OpenImagesEvaluator(dataset, distributed=distributed, pred_yxyx=pred_yxyx)
|
193 |
+
else:
|
194 |
+
return CocoEvaluator(dataset, neptune, distributed=distributed, pred_yxyx=pred_yxyx)
|
195 |
+
#return PascalEvaluator(dataset, neptune, distributed=distributed, pred_yxyx=pred_yxyx)
|
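
For context (not part of the commit), a hedged sketch of how one of these evaluators is typically driven from a validation loop; `val_dataset`, `val_loader`, and the `bench` model wrapper are assumed to come from the dataset/loader/bench factories added elsewhere in this commit, and the 'detections' output key follows the DetBenchTrain convention:

from effdet.evaluator import create_evaluator

evaluator = create_evaluator('coco', val_dataset, pred_yxyx=False)
for images, targets in val_loader:
    output = bench(images, targets)           # DetBenchTrain wrapper in eval mode
    evaluator.add_predictions(output['detections'], targets)
mean_ap = evaluator.evaluate()                # COCO mAP @ 0.5:0.95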
efficientdet/effdet/factory.py
ADDED
@@ -0,0 +1,54 @@
1 |
+
from .efficientdet import EfficientDet, HeadNet
|
2 |
+
from .bench import DetBenchTrain, DetBenchPredict
|
3 |
+
from .config import get_efficientdet_config
|
4 |
+
from .helpers import load_pretrained, load_checkpoint
|
5 |
+
|
6 |
+
|
7 |
+
def create_model(
|
8 |
+
model_name, bench_task='', num_classes=None, pretrained=False,
|
9 |
+
checkpoint_path='', checkpoint_ema=False, **kwargs):
|
10 |
+
|
11 |
+
config = get_efficientdet_config(model_name)
|
12 |
+
return create_model_from_config(
|
13 |
+
config, bench_task=bench_task, num_classes=num_classes, pretrained=pretrained,
|
14 |
+
checkpoint_path=checkpoint_path, checkpoint_ema=checkpoint_ema, **kwargs)
|
15 |
+
|
16 |
+
|
17 |
+
def create_model_from_config(
|
18 |
+
config, bench_task='', num_classes=None, pretrained=False,
|
19 |
+
checkpoint_path='', checkpoint_ema=False, **kwargs):
|
20 |
+
|
21 |
+
pretrained_backbone = kwargs.pop('pretrained_backbone', True)
|
22 |
+
if pretrained or checkpoint_path:
|
23 |
+
pretrained_backbone = False # no point in loading backbone weights
|
24 |
+
|
25 |
+
# Config overrides, override some config values via kwargs.
|
26 |
+
overrides = ('redundant_bias', 'label_smoothing', 'new_focal', 'jit_loss')
|
27 |
+
for ov in overrides:
|
28 |
+
value = kwargs.pop(ov, None)
|
29 |
+
if value is not None:
|
30 |
+
setattr(config, ov, value)
|
31 |
+
|
32 |
+
labeler = kwargs.pop('bench_labeler', False)
|
33 |
+
|
34 |
+
# create the base model
|
35 |
+
model = EfficientDet(config, pretrained_backbone=pretrained_backbone, **kwargs)
|
36 |
+
|
37 |
+
# pretrained weights are always spec'd for original config, load them before we change the model
|
38 |
+
if pretrained:
|
39 |
+
load_pretrained(model, config.url)
|
40 |
+
|
41 |
+
# reset model head if num_classes doesn't match configs
|
42 |
+
if num_classes is not None and num_classes != config.num_classes:
|
43 |
+
model.reset_head(num_classes=num_classes)
|
44 |
+
|
45 |
+
# load a training checkpoint specified as an argument
|
46 |
+
if checkpoint_path:
|
47 |
+
load_checkpoint(model, checkpoint_path, use_ema=checkpoint_ema)
|
48 |
+
|
49 |
+
# wrap model in task specific training/prediction bench if set
|
50 |
+
if bench_task == 'train':
|
51 |
+
model = DetBenchTrain(model, create_labeler=labeler)
|
52 |
+
elif bench_task == 'predict':
|
53 |
+
model = DetBenchPredict(model)
|
54 |
+
return model
|
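
A hedged sketch of calling the factory above; 'tf_efficientdet_d0' is assumed to be one of the config names known to get_efficientdet_config:

    bench = create_model(
        'tf_efficientdet_d0',   # assumed config name
        bench_task='predict',   # wraps the raw EfficientDet in DetBenchPredict
        num_classes=6,          # head is reset because this differs from the config default
        pretrained=True,        # pretrained weights load before the head reset
    )
    bench.eval()
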
efficientdet/effdet/helpers.py
ADDED
@@ -0,0 +1,22 @@
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
from timm.models import load_checkpoint
|
7 |
+
|
8 |
+
try:
|
9 |
+
from torch.hub import load_state_dict_from_url
|
10 |
+
except ImportError:
|
11 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
12 |
+
|
13 |
+
|
14 |
+
def load_pretrained(model, url, filter_fn=None, strict=True):
|
15 |
+
if not url:
|
16 |
+
logging.warning("Pretrained model URL is empty, using random initialization. "
|
17 |
+
"Did you intend to use a `tf_` variant of the model?")
|
18 |
+
return
|
19 |
+
state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')
|
20 |
+
if filter_fn is not None:
|
21 |
+
state_dict = filter_fn(state_dict)
|
22 |
+
model.load_state_dict(state_dict, strict=strict)
|
efficientdet/effdet/loss.py
ADDED
@@ -0,0 +1,259 @@
|
1 |
+
""" EfficientDet Focal, Huber/Smooth L1 loss fns w/ jit support
|
2 |
+
|
3 |
+
Based on loss fn in Google's automl EfficientDet repository (Apache 2.0 license).
|
4 |
+
https://github.com/google/automl/tree/master/efficientdet
|
5 |
+
|
6 |
+
Copyright 2020 Ross Wightman
|
7 |
+
"""
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from typing import Optional, List, Tuple
|
13 |
+
|
14 |
+
|
15 |
+
def focal_loss_legacy(logits, targets, alpha: float, gamma: float, normalizer):
|
16 |
+
"""Compute the focal loss between `logits` and the golden `target` values.
|
17 |
+
|
18 |
+
'Legacy' focal loss matches the loss used in the official TensorFlow impl for initial
|
19 |
+
model releases and some time after that. It eventually transitioned to the 'New' loss
|
20 |
+
defined below.
|
21 |
+
|
22 |
+
Focal loss = -(1-pt)^gamma * log(pt)
|
23 |
+
where pt is the probability of being classified to the true class.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
logits: A float32 tensor of size [batch, height_in, width_in, num_predictions].
|
27 |
+
|
28 |
+
targets: A float32 tensor of size [batch, height_in, width_in, num_predictions].
|
29 |
+
|
30 |
+
alpha: A float32 scalar multiplying alpha to the loss from positive examples
|
31 |
+
and (1-alpha) to the loss from negative examples.
|
32 |
+
|
33 |
+
gamma: A float32 scalar modulating loss from hard and easy examples.
|
34 |
+
|
35 |
+
normalizer: A float32 scalar normalizes the total loss from all examples.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
loss: A float32 scalar representing normalized total loss.
|
39 |
+
"""
|
40 |
+
positive_label_mask = targets == 1.0
|
41 |
+
cross_entropy = F.binary_cross_entropy_with_logits(logits, targets.to(logits.dtype), reduction='none')
|
42 |
+
neg_logits = -1.0 * logits
|
43 |
+
modulator = torch.exp(gamma * targets * neg_logits - gamma * torch.log1p(torch.exp(neg_logits)))
|
44 |
+
|
45 |
+
loss = modulator * cross_entropy
|
46 |
+
weighted_loss = torch.where(positive_label_mask, alpha * loss, (1.0 - alpha) * loss)
|
47 |
+
return weighted_loss / normalizer
|
48 |
+
|
49 |
+
|
50 |
+
def new_focal_loss(logits, targets, alpha: float, gamma: float, normalizer, label_smoothing: float = 0.01):
|
51 |
+
"""Compute the focal loss between `logits` and the golden `target` values.
|
52 |
+
|
53 |
+
'New' is not the best descriptor, but this focal loss impl matches recent versions of
|
54 |
+
the official Tensorflow impl of EfficientDet. It has support for label smoothing, however
|
55 |
+
it is a bit slower, doesn't jit optimize well, and uses more memory.
|
56 |
+
|
57 |
+
Focal loss = -(1-pt)^gamma * log(pt)
|
58 |
+
where pt is the probability of being classified to the true class.
|
59 |
+
Args:
|
60 |
+
logits: A float32 tensor of size [batch, height_in, width_in, num_predictions].
|
61 |
+
targets: A float32 tensor of size [batch, height_in, width_in, num_predictions].
|
62 |
+
alpha: A float32 scalar multiplying alpha to the loss from positive examples
|
63 |
+
and (1-alpha) to the loss from negative examples.
|
64 |
+
gamma: A float32 scalar modulating loss from hard and easy examples.
|
65 |
+
normalizer: Divide loss by this value.
|
66 |
+
label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.
|
67 |
+
Returns:
|
68 |
+
loss: A float32 scalar representing normalized total loss.
|
69 |
+
"""
|
70 |
+
# compute focal loss multipliers before label smoothing, such that it will not blow up the loss.
|
71 |
+
pred_prob = logits.sigmoid()
|
72 |
+
targets = targets.to(logits.dtype)
|
73 |
+
onem_targets = 1. - targets
|
74 |
+
p_t = (targets * pred_prob) + (onem_targets * (1. - pred_prob))
|
75 |
+
alpha_factor = targets * alpha + onem_targets * (1. - alpha)
|
76 |
+
modulating_factor = (1. - p_t) ** gamma
|
77 |
+
|
78 |
+
# apply label smoothing for cross_entropy for each entry.
|
79 |
+
if label_smoothing > 0.:
|
80 |
+
targets = targets * (1. - label_smoothing) + .5 * label_smoothing
|
81 |
+
ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
|
82 |
+
|
83 |
+
# compute the final loss and return
|
84 |
+
return (1 / normalizer) * alpha_factor * modulating_factor * ce
|
85 |
+
|
86 |
+
|
87 |
+
def huber_loss(
|
88 |
+
input, target, delta: float = 1., weights: Optional[torch.Tensor] = None, size_average: bool = True):
|
89 |
+
"""
|
90 |
+
"""
|
91 |
+
err = input - target
|
92 |
+
abs_err = err.abs()
|
93 |
+
quadratic = torch.clamp(abs_err, max=delta)
|
94 |
+
linear = abs_err - quadratic
|
95 |
+
loss = 0.5 * quadratic.pow(2) + delta * linear
|
96 |
+
if weights is not None:
|
97 |
+
loss *= weights
|
98 |
+
if size_average:
|
99 |
+
return loss.mean()
|
100 |
+
else:
|
101 |
+
return loss.sum()
|
102 |
+
|
103 |
+
|
104 |
+
def smooth_l1_loss(
|
105 |
+
input, target, beta: float = 1. / 9, weights: Optional[torch.Tensor] = None, size_average: bool = True):
|
106 |
+
"""
|
107 |
+
Very similar to the smooth_l1_loss from PyTorch, but with the extra beta parameter.
|
108 |
+
"""
|
109 |
+
if beta < 1e-5:
|
110 |
+
# if beta == 0, then torch.where will result in nan gradients when
|
111 |
+
# the chain rule is applied due to pytorch implementation details
|
112 |
+
# (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of
|
113 |
+
# zeros, rather than "no gradient"). To avoid this issue, we define
|
114 |
+
# small values of beta to be exactly l1 loss.
|
115 |
+
loss = torch.abs(input - target)
|
116 |
+
else:
|
117 |
+
err = torch.abs(input - target)
|
118 |
+
loss = torch.where(err < beta, 0.5 * err.pow(2) / beta, err - 0.5 * beta)
|
119 |
+
if weights is not None:
|
120 |
+
loss *= weights
|
121 |
+
if size_average:
|
122 |
+
return loss.mean()
|
123 |
+
else:
|
124 |
+
return loss.sum()
|
125 |
+
|
126 |
+
|
127 |
+
def _box_loss(box_outputs, box_targets, num_positives, delta: float = 0.1):
|
128 |
+
"""Computes box regression loss."""
|
129 |
+
# delta is typically around the mean value of regression target.
|
130 |
+
# for instance, the regression targets of a 512x512 input with 6 anchors on
|
131 |
+
# P3-P7 pyramid is about [0.1, 0.1, 0.2, 0.2].
|
132 |
+
normalizer = num_positives * 4.0
|
133 |
+
mask = box_targets != 0.0
|
134 |
+
box_loss = huber_loss(box_outputs, box_targets, weights=mask, delta=delta, size_average=False)
|
135 |
+
return box_loss / normalizer
|
136 |
+
|
137 |
+
|
138 |
+
def one_hot(x, num_classes: int):
|
139 |
+
# NOTE: PyTorch one-hot does not handle negative entries (no hot) like TensorFlow, so mask them out
|
140 |
+
x_non_neg = (x >= 0).unsqueeze(-1)
|
141 |
+
onehot = torch.zeros(x.shape + (num_classes,), device=x.device, dtype=torch.float32)
|
142 |
+
return onehot.scatter(-1, x.unsqueeze(-1) * x_non_neg, 1) * x_non_neg
|
143 |
+
|
144 |
+
|
145 |
+
def loss_fn(
|
146 |
+
cls_outputs: List[torch.Tensor],
|
147 |
+
box_outputs: List[torch.Tensor],
|
148 |
+
cls_targets: List[torch.Tensor],
|
149 |
+
box_targets: List[torch.Tensor],
|
150 |
+
num_positives: torch.Tensor,
|
151 |
+
num_classes: int,
|
152 |
+
alpha: float,
|
153 |
+
gamma: float,
|
154 |
+
delta: float,
|
155 |
+
box_loss_weight: float,
|
156 |
+
label_smoothing: float = 0.,
|
157 |
+
new_focal: bool = False,
|
158 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
159 |
+
"""Computes total detection loss.
|
160 |
+
Computes total detection loss including box and class loss from all levels.
|
161 |
+
Args:
|
162 |
+
cls_outputs: a List with values representing logits in [batch_size, height, width, num_anchors].
|
163 |
+
at each feature level (index)
|
164 |
+
|
165 |
+
box_outputs: a List with values representing box regression targets in
|
166 |
+
[batch_size, height, width, num_anchors * 4] at each feature level (index)
|
167 |
+
|
168 |
+
cls_targets: groundtruth class targets.
|
169 |
+
|
170 |
+
box_targets: groundtruth box targets.
|
171 |
+
|
172 |
+
num_positives: number of positive groundtruth anchors
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
total_loss: a float tensor representing the total loss, reduced from the class and box losses of all levels.
|
176 |
+
|
177 |
+
cls_loss: a float tensor representing the total class loss.
|
178 |
+
|
179 |
+
box_loss: a float tensor representing the total box regression loss.
|
180 |
+
"""
|
181 |
+
# Sum all positives in a batch for normalization and avoid zero
|
182 |
+
# num_positives_sum, which would lead to inf loss during training
|
183 |
+
num_positives_sum = (num_positives.sum() + 1.0).float()
|
184 |
+
levels = len(cls_outputs)
|
185 |
+
|
186 |
+
cls_losses = []
|
187 |
+
box_losses = []
|
188 |
+
for l in range(levels):
|
189 |
+
cls_targets_at_level = cls_targets[l]
|
190 |
+
box_targets_at_level = box_targets[l]
|
191 |
+
|
192 |
+
# Onehot encoding for classification labels.
|
193 |
+
cls_targets_at_level_oh = one_hot(cls_targets_at_level, num_classes)
|
194 |
+
|
195 |
+
bs, height, width, _, _ = cls_targets_at_level_oh.shape
|
196 |
+
cls_targets_at_level_oh = cls_targets_at_level_oh.view(bs, height, width, -1)
|
197 |
+
cls_outputs_at_level = cls_outputs[l].permute(0, 2, 3, 1).float()
|
198 |
+
if new_focal:
|
199 |
+
cls_loss = new_focal_loss(
|
200 |
+
cls_outputs_at_level, cls_targets_at_level_oh,
|
201 |
+
alpha=alpha, gamma=gamma, normalizer=num_positives_sum, label_smoothing=label_smoothing)
|
202 |
+
else:
|
203 |
+
cls_loss = focal_loss_legacy(
|
204 |
+
cls_outputs_at_level, cls_targets_at_level_oh,
|
205 |
+
alpha=alpha, gamma=gamma, normalizer=num_positives_sum)
|
206 |
+
cls_loss = cls_loss.view(bs, height, width, -1, num_classes)
|
207 |
+
cls_loss = cls_loss * (cls_targets_at_level != -2).unsqueeze(-1)
|
208 |
+
cls_losses.append(cls_loss.sum()) # FIXME reference code added a clamp here at some point ...clamp(0, 2))
|
209 |
+
|
210 |
+
box_losses.append(_box_loss(
|
211 |
+
box_outputs[l].permute(0, 2, 3, 1).float(),
|
212 |
+
box_targets_at_level,
|
213 |
+
num_positives_sum,
|
214 |
+
delta=delta))
|
215 |
+
|
216 |
+
# Sum per level losses to total loss.
|
217 |
+
cls_loss = torch.sum(torch.stack(cls_losses, dim=-1), dim=-1)
|
218 |
+
box_loss = torch.sum(torch.stack(box_losses, dim=-1), dim=-1)
|
219 |
+
total_loss = cls_loss + box_loss_weight * box_loss
|
220 |
+
return total_loss, cls_loss, box_loss
|
221 |
+
|
222 |
+
|
223 |
+
loss_jit = torch.jit.script(loss_fn)
|
224 |
+
|
225 |
+
|
226 |
+
class DetectionLoss(nn.Module):
|
227 |
+
|
228 |
+
__constants__ = ['num_classes']
|
229 |
+
|
230 |
+
def __init__(self, config):
|
231 |
+
super(DetectionLoss, self).__init__()
|
232 |
+
self.config = config
|
233 |
+
self.num_classes = config.num_classes
|
234 |
+
self.alpha = config.alpha
|
235 |
+
self.gamma = config.gamma
|
236 |
+
self.delta = config.delta
|
237 |
+
self.box_loss_weight = config.box_loss_weight
|
238 |
+
self.label_smoothing = config.label_smoothing
|
239 |
+
self.new_focal = config.new_focal
|
240 |
+
self.use_jit = config.jit_loss
|
241 |
+
|
242 |
+
def forward(
|
243 |
+
self,
|
244 |
+
cls_outputs: List[torch.Tensor],
|
245 |
+
box_outputs: List[torch.Tensor],
|
246 |
+
cls_targets: List[torch.Tensor],
|
247 |
+
box_targets: List[torch.Tensor],
|
248 |
+
num_positives: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
249 |
+
|
250 |
+
l_fn = loss_fn
|
251 |
+
if not torch.jit.is_scripting() and self.use_jit:
|
252 |
+
# This branch only active if parent / bench itself isn't being scripted
|
253 |
+
# NOTE: I haven't figured out what to do here wrt to tracing, is it an issue?
|
254 |
+
l_fn = loss_jit
|
255 |
+
|
256 |
+
return l_fn(
|
257 |
+
cls_outputs, box_outputs, cls_targets, box_targets, num_positives,
|
258 |
+
num_classes=self.num_classes, alpha=self.alpha, gamma=self.gamma, delta=self.delta,
|
259 |
+
box_loss_weight=self.box_loss_weight, label_smoothing=self.label_smoothing, new_focal=self.new_focal)
|
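
A toy sketch exercising the two focal-loss variants above on random tensors (shapes are illustrative only, not what the anchor labeler produces); with label_smoothing=0 the two implementations should agree closely:

    import torch

    logits = torch.randn(2, 4, 4, 9 * 6)                # [batch, H, W, anchors * classes]
    targets = (torch.rand_like(logits) > 0.95).float()  # sparse positive labels
    normalizer = targets.sum() + 1.0

    legacy = focal_loss_legacy(logits, targets, alpha=0.25, gamma=1.5, normalizer=normalizer)
    new = new_focal_loss(logits, targets, alpha=0.25, gamma=1.5,
                         normalizer=normalizer, label_smoothing=0.)
    print(legacy.sum().item(), new.sum().item())
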
efficientdet/effdet/object_detection/README.md
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
# Tensorflow Object Detection
|
2 |
+
|
3 |
+
All of this code is adapted/ported/copied from https://github.com/google/automl/tree/552d0facd14f4fe9205a67fb13ecb5690a4d1c94/efficientdet/object_detection
|
efficientdet/effdet/object_detection/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
1 |
+
# Copyright 2020 Google Research. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
# Object detection data loaders and libraries are mostly based on RetinaNet:
|
16 |
+
# https://github.com/tensorflow/tpu/tree/master/models/official/retinanet
|
17 |
+
from .argmax_matcher import ArgMaxMatcher
|
18 |
+
from .box_coder import FasterRcnnBoxCoder
|
19 |
+
from .box_list import BoxList
|
20 |
+
from .matcher import Match
|
21 |
+
from .region_similarity_calculator import IouSimilarity
|
22 |
+
from .target_assigner import TargetAssigner
|
efficientdet/effdet/object_detection/argmax_matcher.py
ADDED
@@ -0,0 +1,174 @@
|
1 |
+
# Copyright 2020 Google Research. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Argmax matcher implementation.
|
16 |
+
|
17 |
+
This class takes a similarity matrix and matches columns to rows based on the
|
18 |
+
maximum value per column. One can specify a matched_threshold and an
|
19 |
+
unmatched_threshold to control whether a column becomes a positive match,
|
20 |
+
a negative (unmatched) training example, or is ignored (resulting in
|
21 |
+
neither a positive nor a negative training example).
|
22 |
+
|
23 |
+
This matcher is used in Fast(er)-RCNN.
|
24 |
+
|
25 |
+
Note: matchers are used in TargetAssigners. There is a create_target_assigner
|
26 |
+
factory function for popular implementations.
|
27 |
+
"""
|
28 |
+
import torch
|
29 |
+
from .matcher import Match
|
30 |
+
from typing import Optional
|
31 |
+
|
32 |
+
|
33 |
+
def one_hot_bool(x, num_classes: int):
|
34 |
+
# for improved perf over PyTorch builtin one_hot, scatter to bool
|
35 |
+
onehot = torch.zeros(x.size(0), num_classes, device=x.device, dtype=torch.bool)
|
36 |
+
return onehot.scatter_(1, x.unsqueeze(1), 1)
|
37 |
+
|
38 |
+
|
39 |
+
@torch.jit.script
|
40 |
+
class ArgMaxMatcher(object): # cannot inherit with torchscript
|
41 |
+
"""Matcher based on highest value.
|
42 |
+
|
43 |
+
This class computes matches from a similarity matrix. Each column is matched
|
44 |
+
to a single row.
|
45 |
+
|
46 |
+
To support object detection target assignment this class enables setting both
|
47 |
+
matched_threshold (upper threshold) and unmatched_threshold (lower thresholds)
|
48 |
+
defining three categories of similarity which define whether examples are
|
49 |
+
positive, negative, or ignored:
|
50 |
+
(1) similarity >= matched_threshold: Highest similarity. Matched/Positive!
|
51 |
+
(2) matched_threshold > similarity >= unmatched_threshold: Medium similarity.
|
52 |
+
Depending on negatives_lower_than_unmatched, this is either
|
53 |
+
Unmatched/Negative OR Ignore.
|
54 |
+
(3) unmatched_threshold > similarity: Lowest similarity. Depending on flag
|
55 |
+
negatives_lower_than_unmatched, either Unmatched/Negative OR Ignore.
|
56 |
+
For ignored matches this class sets the values in the Match object to -2.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(self,
|
60 |
+
matched_threshold: float,
|
61 |
+
unmatched_threshold: Optional[float] = None,
|
62 |
+
negatives_lower_than_unmatched: bool = True,
|
63 |
+
force_match_for_each_row: bool = False):
|
64 |
+
"""Construct ArgMaxMatcher.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
matched_threshold: Threshold for positive matches. Positive if
|
68 |
+
sim >= matched_threshold, where sim is the maximum value of the
|
69 |
+
similarity matrix for a given column. Set to None for no threshold.
|
70 |
+
unmatched_threshold: Threshold for negative matches. Negative if
|
71 |
+
sim < unmatched_threshold. Defaults to matched_threshold
|
72 |
+
when set to None.
|
73 |
+
negatives_lower_than_unmatched: Boolean which defaults to True. If True
|
74 |
+
then negative matches are the ones below the unmatched_threshold,
|
75 |
+
whereas ignored matches are in between the matched and unmatched
|
76 |
+
threshold. If False, then negative matches are in between the matched
|
77 |
+
and unmatched threshold, and everything lower than unmatched is ignored.
|
78 |
+
force_match_for_each_row: If True, ensures that each row is matched to
|
79 |
+
at least one column (which is not guaranteed otherwise if the
|
80 |
+
matched_threshold is high). Defaults to False. See
|
81 |
+
argmax_matcher_test.testMatcherForceMatch() for an example.
|
82 |
+
|
83 |
+
Raises:
|
84 |
+
ValueError: if unmatched_threshold is set but matched_threshold is not set
|
85 |
+
or if unmatched_threshold > matched_threshold.
|
86 |
+
"""
|
87 |
+
if (matched_threshold is None) and (unmatched_threshold is not None):
|
88 |
+
raise ValueError('Need to also define matched_threshold when unmatched_threshold is defined')
|
89 |
+
self._matched_threshold = matched_threshold
|
90 |
+
self._unmatched_threshold: float = 0.
|
91 |
+
if unmatched_threshold is None:
|
92 |
+
self._unmatched_threshold = matched_threshold
|
93 |
+
else:
|
94 |
+
if unmatched_threshold > matched_threshold:
|
95 |
+
raise ValueError('unmatched_threshold needs to be smaller or equal to matched_threshold')
|
96 |
+
self._unmatched_threshold = unmatched_threshold
|
97 |
+
if not negatives_lower_than_unmatched:
|
98 |
+
if self._unmatched_threshold == self._matched_threshold:
|
99 |
+
raise ValueError('When negatives are in between matched and unmatched thresholds, these '
|
100 |
+
'cannot be of equal value. matched: %s, unmatched: %s',
|
101 |
+
self._matched_threshold, self._unmatched_threshold)
|
102 |
+
self._force_match_for_each_row = force_match_for_each_row
|
103 |
+
self._negatives_lower_than_unmatched = negatives_lower_than_unmatched
|
104 |
+
|
105 |
+
def _match_when_rows_are_empty(self, similarity_matrix):
|
106 |
+
"""Performs matching when the rows of similarity matrix are empty.
|
107 |
+
|
108 |
+
When the rows are empty, all detections are false positives. So we return
|
109 |
+
a tensor of -1's to indicate that the columns do not match to any rows.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
matches: int32 tensor indicating the row each column matches to.
|
113 |
+
"""
|
114 |
+
return -1 * torch.ones(similarity_matrix.shape[1], dtype=torch.long, device=similarity_matrix.device)
|
115 |
+
|
116 |
+
def _match_when_rows_are_non_empty(self, similarity_matrix):
|
117 |
+
"""Performs matching when the rows of similarity matrix are non empty.
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
matches: int32 tensor indicating the row each column matches to.
|
121 |
+
"""
|
122 |
+
# Matches for each column
|
123 |
+
matched_vals, matches = torch.max(similarity_matrix, 0)
|
124 |
+
|
125 |
+
# Deal with matched and unmatched threshold
|
126 |
+
if self._matched_threshold is not None:
|
127 |
+
# Boolean masks of unmatched (below threshold) and ignored (between thresholds) columns
|
128 |
+
below_unmatched_threshold = self._unmatched_threshold > matched_vals
|
129 |
+
between_thresholds = (matched_vals >= self._unmatched_threshold) & \
|
130 |
+
(self._matched_threshold > matched_vals)
|
131 |
+
|
132 |
+
if self._negatives_lower_than_unmatched:
|
133 |
+
matches = self._set_values_using_indicator(matches, below_unmatched_threshold, -1)
|
134 |
+
matches = self._set_values_using_indicator(matches, between_thresholds, -2)
|
135 |
+
else:
|
136 |
+
matches = self._set_values_using_indicator(matches, below_unmatched_threshold, -2)
|
137 |
+
matches = self._set_values_using_indicator(matches, between_thresholds, -1)
|
138 |
+
|
139 |
+
if self._force_match_for_each_row:
|
140 |
+
force_match_column_ids = torch.argmax(similarity_matrix, 1)
|
141 |
+
force_match_column_indicators = one_hot_bool(force_match_column_ids, similarity_matrix.shape[1])
|
142 |
+
force_match_column_mask, force_match_row_ids = torch.max(force_match_column_indicators, 0)
|
143 |
+
final_matches = torch.where(force_match_column_mask, force_match_row_ids, matches)
|
144 |
+
return final_matches
|
145 |
+
else:
|
146 |
+
return matches
|
147 |
+
|
148 |
+
def match(self, similarity_matrix):
|
149 |
+
"""Tries to match each column of the similarity matrix to a row.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
similarity_matrix: tensor of shape [N, M] representing any similarity metric.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
Match object with corresponding matches for each of M columns.
|
156 |
+
"""
|
157 |
+
if similarity_matrix.shape[0] == 0:
|
158 |
+
return Match(self._match_when_rows_are_empty(similarity_matrix))
|
159 |
+
else:
|
160 |
+
return Match(self._match_when_rows_are_non_empty(similarity_matrix))
|
161 |
+
|
162 |
+
def _set_values_using_indicator(self, x, indicator, val: int):
|
163 |
+
"""Set the indicated fields of x to val.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
x: tensor.
|
167 |
+
indicator: boolean with same shape as x.
|
168 |
+
val: scalar with value to set.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
modified tensor.
|
172 |
+
"""
|
173 |
+
indicator = indicator.to(dtype=x.dtype)
|
174 |
+
return x * (1 - indicator) + val * indicator
|
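
A small worked example of the threshold behaviour documented above (values chosen to hit all three cases; the Match wrapper is defined in matcher.py, not shown in this diff):

    import torch

    sim = torch.tensor([[0.9, 0.4, 0.10, 0.0],    # 2 ground-truth rows
                        [0.2, 0.6, 0.45, 0.0]])   # 4 anchor columns
    matcher = ArgMaxMatcher(matched_threshold=0.5, unmatched_threshold=0.4)
    match = matcher.match(sim)
    # column 0 -> row 0 (0.9 >= 0.5, positive), column 1 -> row 1 (0.6 >= 0.5, positive),
    # column 2 -> -2 (0.45 lies between the thresholds, ignored),
    # column 3 -> -1 (0.0 < 0.4, negative/unmatched)
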
efficientdet/effdet/object_detection/box_coder.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Google Research. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Base box coder.
|
16 |
+
|
17 |
+
Box coders convert between coordinate frames, namely image-centric
|
18 |
+
(with (0,0) on the top left of image) and anchor-centric (with (0,0) being
|
19 |
+
defined by a specific anchor).
|
20 |
+
|
21 |
+
Users of a BoxCoder can call two methods:
|
22 |
+
encode: which encodes a box with respect to a given anchor
|
23 |
+
(or rather, a tensor of boxes wrt a corresponding tensor of anchors) and
|
24 |
+
decode: which inverts this encoding with a decode operation.
|
25 |
+
In both cases, the arguments are assumed to be in 1-1 correspondence already;
|
26 |
+
it is not the job of a BoxCoder to perform matching.
|
27 |
+
"""
|
28 |
+
import torch
|
29 |
+
from typing import List, Optional
|
30 |
+
from .box_list import BoxList
|
31 |
+
|
32 |
+
# Box coder types.
|
33 |
+
FASTER_RCNN = 'faster_rcnn'
|
34 |
+
KEYPOINT = 'keypoint'
|
35 |
+
MEAN_STDDEV = 'mean_stddev'
|
36 |
+
SQUARE = 'square'
|
37 |
+
|
38 |
+
|
39 |
+
"""Faster RCNN box coder.
|
40 |
+
|
41 |
+
Faster RCNN box coder follows the coding schema described below:
|
42 |
+
ty = (y - ya) / ha
|
43 |
+
tx = (x - xa) / wa
|
44 |
+
th = log(h / ha)
|
45 |
+
tw = log(w / wa)
|
46 |
+
where x, y, w, h denote the box's center coordinates, width and height
|
47 |
+
respectively. Similarly, xa, ya, wa, ha denote the anchor's center
|
48 |
+
coordinates, width and height. tx, ty, tw and th denote the anchor-encoded
|
49 |
+
center, width and height respectively.
|
50 |
+
|
51 |
+
See http://arxiv.org/abs/1506.01497 for details.
|
52 |
+
"""
|
53 |
+
|
54 |
+
|
55 |
+
EPS = 1e-8
|
56 |
+
|
57 |
+
|
58 |
+
#@torch.jit.script
|
59 |
+
class FasterRcnnBoxCoder(object):
|
60 |
+
"""Faster RCNN box coder."""
|
61 |
+
|
62 |
+
def __init__(self, scale_factors: Optional[List[float]] = None, eps: float = EPS):
|
63 |
+
"""Constructor for FasterRcnnBoxCoder.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
scale_factors: List of 4 positive scalars to scale ty, tx, th and tw.
|
67 |
+
If set to None, does not perform scaling. For Faster RCNN,
|
68 |
+
the open-source implementation recommends using [10.0, 10.0, 5.0, 5.0].
|
69 |
+
"""
|
70 |
+
self._scale_factors = scale_factors
|
71 |
+
if scale_factors is not None:
|
72 |
+
assert len(scale_factors) == 4
|
73 |
+
for scalar in scale_factors:
|
74 |
+
assert scalar > 0
|
75 |
+
self.eps = eps
|
76 |
+
|
77 |
+
#@property
|
78 |
+
def code_size(self):
|
79 |
+
return 4
|
80 |
+
|
81 |
+
def encode(self, boxes: BoxList, anchors: BoxList):
|
82 |
+
"""Encode a box collection with respect to anchor collection.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
boxes: BoxList holding N boxes to be encoded.
|
86 |
+
anchors: BoxList of anchors.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
a tensor representing N anchor-encoded boxes of the format [ty, tx, th, tw].
|
90 |
+
"""
|
91 |
+
# Convert anchors to the center coordinate representation.
|
92 |
+
ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()
|
93 |
+
ycenter, xcenter, h, w = boxes.get_center_coordinates_and_sizes()
|
94 |
+
# Avoid NaN in division and log below.
|
95 |
+
ha += self.eps
|
96 |
+
wa += self.eps
|
97 |
+
h += self.eps
|
98 |
+
w += self.eps
|
99 |
+
|
100 |
+
tx = (xcenter - xcenter_a) / wa
|
101 |
+
ty = (ycenter - ycenter_a) / ha
|
102 |
+
tw = torch.log(w / wa)
|
103 |
+
th = torch.log(h / ha)
|
104 |
+
# Scales location targets as used in paper for joint training.
|
105 |
+
if self._scale_factors is not None:
|
106 |
+
ty *= self._scale_factors[0]
|
107 |
+
tx *= self._scale_factors[1]
|
108 |
+
th *= self._scale_factors[2]
|
109 |
+
tw *= self._scale_factors[3]
|
110 |
+
return torch.stack([ty, tx, th, tw]).t()
|
111 |
+
|
112 |
+
def decode(self, rel_codes, anchors: BoxList):
|
113 |
+
"""Decode relative codes to boxes.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
rel_codes: a tensor representing N anchor-encoded boxes.
|
117 |
+
anchors: BoxList of anchors.
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
boxes: BoxList holding N bounding boxes.
|
121 |
+
"""
|
122 |
+
ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()
|
123 |
+
|
124 |
+
ty, tx, th, tw = rel_codes.t().unbind()
|
125 |
+
if self._scale_factors is not None:
|
126 |
+
ty /= self._scale_factors[0]
|
127 |
+
tx /= self._scale_factors[1]
|
128 |
+
th /= self._scale_factors[2]
|
129 |
+
tw /= self._scale_factors[3]
|
130 |
+
w = torch.exp(tw) * wa
|
131 |
+
h = torch.exp(th) * ha
|
132 |
+
ycenter = ty * ha + ycenter_a
|
133 |
+
xcenter = tx * wa + xcenter_a
|
134 |
+
ymin = ycenter - h / 2.
|
135 |
+
xmin = xcenter - w / 2.
|
136 |
+
ymax = ycenter + h / 2.
|
137 |
+
xmax = xcenter + w / 2.
|
138 |
+
return BoxList(torch.stack([ymin, xmin, ymax, xmax]).t())
|
139 |
+
|
140 |
+
|
141 |
+
def batch_decode(encoded_boxes, box_coder: FasterRcnnBoxCoder, anchors: BoxList):
|
142 |
+
"""Decode a batch of encoded boxes.
|
143 |
+
|
144 |
+
This op takes a batch of encoded bounding boxes and transforms
|
145 |
+
them to a batch of bounding boxes specified by their corners in
|
146 |
+
the order of [y_min, x_min, y_max, x_max].
|
147 |
+
|
148 |
+
Args:
|
149 |
+
encoded_boxes: a float32 tensor of shape [batch_size, num_anchors,
|
150 |
+
code_size] representing the location of the objects.
|
151 |
+
box_coder: a BoxCoder object.
|
152 |
+
anchors: a BoxList of anchors used to encode `encoded_boxes`.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
decoded_boxes: a float32 tensor of shape [batch_size, num_anchors, coder_size]
|
156 |
+
representing the corners of the objects in the order of [y_min, x_min, y_max, x_max].
|
157 |
+
|
158 |
+
Raises:
|
159 |
+
ValueError: if batch sizes of the inputs are inconsistent, or if
|
160 |
+
the number of anchors inferred from encoded_boxes and anchors are inconsistent.
|
161 |
+
"""
|
162 |
+
assert len(encoded_boxes.shape) == 3
|
163 |
+
if encoded_boxes.shape[1] != anchors.num_boxes():
|
164 |
+
raise ValueError('The number of anchors inferred from encoded_boxes'
|
165 |
+
' and anchors are inconsistent: shape[1] of encoded_boxes'
|
166 |
+
' %s should be equal to the number of anchors: %s.' %
|
167 |
+
(encoded_boxes.shape[1], anchors.num_boxes()))
|
168 |
+
|
169 |
+
decoded_boxes = torch.stack([
|
170 |
+
box_coder.decode(boxes, anchors).boxes for boxes in encoded_boxes.unbind()
|
171 |
+
])
|
172 |
+
return decoded_boxes
|
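
A round-trip sketch for the coder above; boxes are in [ymin, xmin, ymax, xmax] order and BoxList is the class imported at the top of this file:

    import torch

    anchors = BoxList(torch.tensor([[0.0, 0.0, 10.0, 10.0]]))
    boxes = BoxList(torch.tensor([[1.0, 2.0, 9.0, 8.0]]))

    coder = FasterRcnnBoxCoder(scale_factors=[10.0, 10.0, 5.0, 5.0])
    rel_codes = coder.encode(boxes, anchors)    # one row of [ty, tx, th, tw]
    decoded = coder.decode(rel_codes, anchors)  # decoded.boxes approximately equals boxes.boxes
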