|
import os |
|
import sys |
|
|
|
import torch |
|
import torchvision |
|
from fvcore.nn import FlopCountAnalysis |
|
from torch import nn |
|
|
|
sys.path.append("vision/references/segmentation") |
|
from transforms import Compose |
|
from coco_utils import ConvertCocoPolysToMask |
|
from coco_utils import FilterAndRemapCocoCategories |
|
from coco_utils import _coco_remove_images_without_annotations |
|
from utils import ConfusionMatrix |
|
|
|
|
|
class NanSafeConfusionMatrix(ConfusionMatrix):
    """Confusion matrix whose derived metrics have NaN entries replaced by zeros.

    The parent's ratios divide by row sums (accuracy) and by row + column sums
    minus the diagonal (IoU). When such a denominator is zero the result is NaN;
    ``torch.nan_to_num`` maps those entries to 0 instead.
    """

    def __init__(self, num_classes):
        # Pure delegation to the parent constructor.
        super().__init__(num_classes=num_classes)

    def compute(self):
        """Return ``(global_accuracy, per_class_accuracy, per_class_iou)`` with NaNs zeroed.

        ``self.mat`` is provided by the parent ``ConfusionMatrix``. It is assumed to
        hold non-negative counts (TODO confirm against the parent implementation).
        """
        mat = self.mat.float()
        hits = torch.diag(mat)
        row_totals = mat.sum(1)
        col_totals = mat.sum(0)

        global_accuracy = torch.nan_to_num(hits.sum() / mat.sum())
        per_class_accuracy = torch.nan_to_num(hits / row_totals)

        # Per-class union: row total + column total, minus the double-counted diagonal.
        union = row_totals + col_totals - hits
        per_class_iou = torch.nan_to_num(hits / union)

        return global_accuracy, per_class_accuracy, per_class_iou
|
|
|
|
|
def flops_calculation_function(model: nn.Module, input_sample: torch.Tensor) -> float:
    """Calculate the number of FLOPs per sample, in millions, using fvcore.

    The model is put into eval mode while fvcore analyses it. Its original
    training/eval mode is restored afterwards, so the caller's model state is
    not changed as a side effect.

    Args:
        model: Module to analyse.
        input_sample: Input passed to the model for the analysis. Its first
            dimension is treated as the batch size; the total count is divided by it.

    Returns:
        ``counter.total() / input_sample.shape[0] / 1e6``.
    """
    was_training = model.training
    try:
        counter = FlopCountAnalysis(
            model=model.eval(),
            inputs=input_sample,
        )
        # Silence fvcore's warnings about unsupported ops and submodules that are never called.
        counter.unsupported_ops_warnings(False)
        counter.uncalled_modules_warnings(False)

        # fvcore may defer running the model until total() is called, so the
        # training mode must only be restored after this line.
        flops = counter.total() / input_sample.shape[0]
    finally:
        model.train(was_training)

    return flops / 1e6
|
|
|
|
|
def get_coco(root, image_set, transforms, use_v2=False, use_orig=False):
    """Get COCO 2017 dataset prepared for segmentation with VOC or COCO classes.

    Args:
        root: Dataset root. Images are read from ``root/train2017`` or
            ``root/val2017``; annotations from ``root/annotations``.
        image_set: ``"train"`` or ``"val"``.
        transforms: User transform, appended as the last step of the
            ``Compose`` pipeline.
        use_v2: Convert samples with ``v2_extras.CocoDetectionToVOCSegmentation``
            and wrap the dataset with
            ``torchvision.datasets.wrap_dataset_for_transforms_v2``.
        use_orig: Use category ids ``0..80`` instead of the 21-entry VOC list.
            Only supported when ``use_v2`` is False.

    Returns:
        The ``CocoDetection``-based dataset. For ``"train"`` it is additionally
        passed through ``_coco_remove_images_without_annotations``.

    Raises:
        ValueError: If ``use_v2`` and ``use_orig`` are both True.
        KeyError: If ``image_set`` is neither ``"train"`` nor ``"val"``.
    """
    # The v2 branch produces labels via CocoDetectionToVOCSegmentation and never
    # uses classes_list for labelling. With use_orig=True the list would only change
    # which training images are kept, silently mismatching the produced labels.
    if use_v2 and use_orig:
        raise ValueError(
            "use_orig=True is not supported with use_v2=True: the v2 pipeline "
            "(CocoDetectionToVOCSegmentation) ignores the COCO class list"
        )

    paths = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
    }
    if use_orig:
        # NOTE(review): COCO 2017 category ids span 1..90 with gaps, so range(81)
        # keeps ids 0..80 (including unused gap ids) and would drop objects with
        # ids 81..90. Confirm this is intended before treating it as "all COCO classes".
        classes_list = list(range(81))
    else:
        # Presumably the COCO category ids of the VOC classes, with 0 as background.
        # With remap=True, each id is expected to be relabelled by its position here.
        classes_list = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    img_folder, ann_file = paths[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    if use_v2:
        import v2_extras
        from torchvision.datasets import wrap_dataset_for_transforms_v2

        transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
    else:
        transforms = Compose(
            [FilterAndRemapCocoCategories(classes_list, remap=True), ConvertCocoPolysToMask(), transforms]
        )
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

    if image_set == "train":
        # coco_utils helper. Judging by its name, it drops training images lacking
        # annotations in classes_list (verify against coco_utils).
        dataset = _coco_remove_images_without_annotations(dataset, classes_list)

    return dataset
|
|