import os
import sys

import torch
import torchvision
from fvcore.nn import FlopCountAnalysis
from torch import nn

# Make the torchvision reference segmentation utilities importable; assumes a
# checkout of https://github.com/pytorch/vision under "vision/" in the working
# directory.
sys.path.append("vision/references/segmentation")
from transforms import Compose
from coco_utils import ConvertCocoPolysToMask
from coco_utils import FilterAndRemapCocoCategories
from coco_utils import _coco_remove_images_without_annotations
from utils import ConfusionMatrix


class NanSafeConfusionMatrix(ConfusionMatrix):
    """Confusion matrix with replacement nans to zeros."""

    def compute(self):
        """Compute metrics based on confusion matrix."""
        confusion_matrix = self.mat.float()
        acc_global = torch.nan_to_num(torch.diag(confusion_matrix).sum() / confusion_matrix.sum())
        acc = torch.nan_to_num(torch.diag(confusion_matrix) / confusion_matrix.sum(1))
        intersection_over_unions = torch.nan_to_num(
            torch.diag(confusion_matrix)
            / (confusion_matrix.sum(1) + confusion_matrix.sum(0) - torch.diag(confusion_matrix))
        )
        return acc_global, acc, intersection_over_unions
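
# A minimal usage sketch (hedged: `update` and `reset` are inherited from the
# reference ConfusionMatrix; the model call and tensor shapes are illustrative):
#
#     matrix = NanSafeConfusionMatrix(num_classes=21)
#     logits = model(images)["out"]                        # (N, C, H, W)
#     matrix.update(targets.flatten(), logits.argmax(1).flatten())
#     acc_global, acc, ious = matrix.compute()
#     mean_iou = ious.mean().item()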


def flops_calculation_function(model: nn.Module, input_sample: torch.Tensor) -> float:
    """Calculate number of flops in millions."""
    counter = FlopCountAnalysis(
        model=model.eval(),
        inputs=input_sample,
    )
    # silence fvcore warnings about unsupported ops and uncalled modules
    counter.unsupported_ops_warnings(False)
    counter.uncalled_modules_warnings(False)

    # fvcore reports FLOPs for the whole batch; normalize to a single sample
    flops = counter.total() / input_sample.shape[0]

    return flops / 1e6
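
# A minimal usage sketch (hedged: the model choice and input shape are
# illustrative, nothing in this module prescribes them):
#
#     model = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=21)
#     sample = torch.randn(2, 3, 520, 520)                # batch of 2 images
#     mflops = flops_calculation_function(model, sample)  # MFLOPs per image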


def get_coco(root, image_set, transforms, use_v2=False, use_orig=False):
    """Get COCO dataset with VOC or COCO classes."""
    paths = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
        # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
    }
    if use_orig:
        # keep the original COCO label space (ids 0-80)
        classes_list = list(range(81))
    else:
        # the 21 PASCAL VOC classes (0 is background) mapped to COCO category ids
        classes_list = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    img_folder, ann_file = paths[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    # The two "Compose" pipelines below achieve the same thing: converting coco
    # detection samples into segmentation-compatible samples. They just do it
    # with slightly different implementations. We could refactor and unify, but
    # keeping them separate helps keep the v2 version clean.
    if use_v2:
        import v2_extras  # pylint: disable=import-outside-toplevel
        from torchvision.datasets import wrap_dataset_for_transforms_v2  # pylint: disable=import-outside-toplevel

        transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
    else:
        transforms = Compose(
            [FilterAndRemapCocoCategories(classes_list, remap=True), ConvertCocoPolysToMask(), transforms]
        )
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset, classes_list)

    return dataset
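
# A minimal usage sketch (hedged: the root path is illustrative and must contain
# the standard COCO layout with train2017/, val2017/ and annotations/; the
# transform pipeline is a hypothetical placeholder, not defined above):
#
#     val_dataset = get_coco(
#         root="/datasets/coco",
#         image_set="val",
#         transforms=eval_transforms,  # hypothetical preset transform
#     )
#     image, target = val_dataset[0]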