|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
import os
import random
|
import numpy as np |
|
import math
|
import torch |
|
import torchvision.transforms as transforms |
|
from torch.utils.data import DataLoader |
|
|
|
import shutil |
|
import torch.distributed as dist |
|
|
|
sys.path.insert(1, os.path.join(sys.path[0], "..")) |
|
from data_utils.resnet import resnet50 |
|
import data_utils.datasets_common as dset |
|
from data_utils.cocostuff_dataset import CocoStuff |
|
|
|
|
|
class CenterCropLongEdge(object):
    """Crops the given PIL Image along its long edge.

    The crop is centered and square, with a side length equal to the image's
    shorter edge, so no target size needs to be passed in.
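
    Example
    -------
    Illustrative usage (a minimal sketch; the input image is a placeholder)::

        from PIL import Image
        img = Image.new("RGB", (640, 480))
        out = CenterCropLongEdge()(img)  # 480x480 center crop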
|
""" |
|
|
|
def __call__(self, img): |
|
""" |
|
Args: |
|
img (PIL Image): Image to be cropped. |
|
Returns: |
|
PIL Image: Cropped image. |
|
""" |
|
return transforms.functional.center_crop(img, min(img.size)) |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ |
|
|
|
|
|
|
|
class DistributedSampler(torch.utils.data.sampler.Sampler): |
|
"""Sampler that restricts data loading to a subset of the dataset. |
|
|
|
It is especially useful in conjunction with |
|
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each |
|
process can pass a DistributedSampler instance as a DataLoader sampler, |
|
and load a subset of the original dataset that is exclusive to it. |
|
|
|
.. note:: |
|
Dataset is assumed to be of constant size. |
|
|
|
Arguments: |
|
dataset: Dataset used for sampling. |
|
num_replicas (optional): Number of processes participating in |
|
distributed training. |
|
rank (optional): Rank of the current process within num_replicas. |
|
shuffle (optional): If true (default), sampler will shuffle the indices |
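
    Example:
        Illustrative usage under DDP (a sketch; ``dataset`` is a placeholder)::

            sampler = DistributedSampler(dataset, num_replicas=4, rank=0)
            loader = DataLoader(dataset, batch_size=64, sampler=sampler)
            for epoch in range(10):
                sampler.set_epoch(epoch)  # re-seed the shuffling each epoch
                for batch in loader:
                    ...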
|
""" |
|
|
|
def __init__( |
|
self, dataset, num_replicas=None, rank=None, shuffle=True, weights=None |
|
): |
|
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
|
self.dataset = dataset |
|
self.num_replicas = num_replicas |
|
self.rank = rank |
|
self.epoch = 0 |
|
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) |
|
self.total_size = self.num_samples * self.num_replicas |
|
self.shuffle = shuffle |
|
self.weights = weights |
|
|
|
def __iter__(self): |
|
|
|
g = torch.Generator() |
|
g.manual_seed(self.epoch) |
|
if self.shuffle: |
|
if self.weights is not None: |
|
print("using class balanced!") |
|
indices = torch.multinomial( |
|
self.weights, len(self.dataset), replacement=True, generator=g |
|
).tolist() |
|
else: |
|
indices = torch.randperm(len(self.dataset), generator=g).tolist() |
|
else: |
|
indices = list(range(len(self.dataset))) |
|
|
|
|
|
        # Pad with a few repeated indices so the total splits evenly across replicas.
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Keep only the strided slice of indices assigned to this rank.
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples
|
|
|
return iter(indices) |
|
|
|
def __len__(self): |
|
return self.num_samples |
|
|
|
def set_epoch(self, epoch): |
|
self.epoch = epoch |
|
|
|
|
|
class CheckpointedSampler(torch.utils.data.Sampler): |
|
"""Resumable sample with a random generated initialized with a given seed. |
|
|
|
Arguments |
|
--------- |
|
data_source: Dataset |
|
Dataset to sample from. |
|
    start_itr: int, optional
        Iteration at which to start (or resume) sampling.
    start_epoch: int, optional
        Epoch at which to start (or resume) sampling.
|
batch_size: int, optional |
|
Batch size. |
|
class_balanced: bool, optional |
|
Sample the data with a class balancing approach. |
|
custom_distrib_gen: bool, optional |
|
Use a temperature controlled class balancing. |
|
samples_per_class: list, optional |
|
A list of int values that indicate the number of samples per class. |
|
class_probabilities: list, optional |
|
A list of float values indicating the probability of a class in the dataset. |
|
longtail_temperature: float, optional |
|
Temperature value to smooth the longtail distribution with a softmax function. |
|
seed: int, optional |
|
Random seed used. |
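
    Example
    -------
    Illustrative usage when resuming training (a sketch; ``dataset`` and the
    resume counters are placeholders)::

        sampler = CheckpointedSampler(dataset, start_itr=5000, start_epoch=3,
                                      batch_size=128, seed=0)
        loader = DataLoader(dataset, batch_size=128, sampler=sampler)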
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
data_source, |
|
start_itr=0, |
|
start_epoch=0, |
|
batch_size=128, |
|
class_balanced=False, |
|
custom_distrib_gen=False, |
|
samples_per_class=None, |
|
class_probabilities=None, |
|
longtail_temperature=1, |
|
seed=0, |
|
): |
|
self.data_source = data_source |
|
self.num_samples = len(self.data_source) |
|
self.start_itr = start_itr % (len(self.data_source) // batch_size) |
|
self.start_epoch = start_epoch |
|
self.batch_size = batch_size |
|
self.class_balanced = class_balanced |
|
self.custom_distrib_gen = custom_distrib_gen |
|
self.generator = torch.Generator() |
|
self.generator.manual_seed(seed) |
|
|
|
if self.class_balanced: |
|
print("Class balanced sampling.") |
|
self.weights = make_weights_for_balanced_classes( |
|
samples_per_class, |
|
self.data_source.labels, |
|
1000, |
|
self.custom_distrib_gen, |
|
longtail_temperature, |
|
class_probabilities=class_probabilities, |
|
) |
|
self.weights = torch.DoubleTensor(self.weights) |
|
|
|
|
|
print( |
|
"Using the generator ", |
|
self.start_epoch, |
|
" times to resume where we left off.", |
|
) |
|
|
|
for epoch in range(self.start_epoch): |
|
self._sample_epoch_perm() |
|
|
|
if not isinstance(self.num_samples, int) or self.num_samples <= 0: |
|
raise ValueError( |
|
"num_samples should be a positive integeral " |
|
"value, but got num_samples={}".format(self.num_samples) |
|
) |
|
|
|
def _sample_epoch_perm(self): |
|
if self.class_balanced: |
|
out = [ |
|
torch.multinomial( |
|
self.weights, |
|
len(self.data_source), |
|
replacement=True, |
|
generator=self.generator, |
|
) |
|
] |
|
else: |
|
out = [torch.randperm(len(self.data_source), generator=self.generator)] |
|
return out |
|
|
|
def __iter__(self): |
|
out = self._sample_epoch_perm() |
|
output = torch.cat(out).tolist() |
|
return iter(output) |
|
|
|
def __len__(self): |
|
return len(self.data_source) |
|
|
|
|
|
def make_weights_for_balanced_classes( |
|
samples_per_class, |
|
labels=None, |
|
nclasses=None, |
|
custom_distrib_gen=False, |
|
longtail_temperature=1, |
|
class_probabilities=None, |
|
): |
|
"""It prepares the sampling weights for the DataLoader. |
|
|
|
Arguments |
|
--------- |
|
samples_per_class: list |
|
A list of int values (size C) that indicate the number of samples per class, |
|
for all C classes. |
|
    labels: list, NumPy array or torch Tensor, optional
        A sequence of size N that contains a class label for each sample.
    nclasses: int, optional
        Number of classes in the dataset.
    custom_distrib_gen: bool, optional
        Use a temperature controlled class balancing.
    longtail_temperature: float, optional
        Temperature value to smooth the longtail distribution with a softmax function.
    class_probabilities: list, optional
        A list of float values (size C) indicating the probability of a class in the dataset.

    Returns
    -------
    A list of size N (dataset size) with a sampling weight for each individual data
    sample. If custom_distrib_gen is True, the per-class weights come from a
    temperature-smoothed class distribution; otherwise they are inversely
    proportional to the observed class frequencies.
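
    Example
    -------
    A minimal sketch with 3 classes and 5 samples (values are illustrative)::

        weights = make_weights_for_balanced_classes(
            samples_per_class=[3, 1, 1],
            labels=[0, 0, 0, 1, 2],
            nclasses=3,
        )
        # Rare classes get larger weights, e.g. for use with WeightedRandomSampler.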
|
|
|
""" |
|
if custom_distrib_gen: |
|
|
|
        print(
            "Temperature controlled distribution for balanced classes! Temperature:",
            longtail_temperature,
        )
|
class_prob = torch.log(torch.DoubleTensor(class_probabilities)) |
|
weight_per_class = torch.exp(class_prob / longtail_temperature) / torch.sum( |
|
torch.exp(class_prob / longtail_temperature) |
|
) |
|
else: |
|
count = [0] * nclasses |
|
for item in labels: |
|
count[item] += 1 |
|
weight_per_class = [0.0] * nclasses |
|
N = float(sum(count)) |
|
for i in range(nclasses): |
|
|
|
weight_per_class[i] = N / float(count[i]) |
|
|
|
    # Per-sample weight: the per-class weight scaled by the inverse class frequency.
    weight = [0] * len(labels)
    for idx, val in enumerate(labels):
        weight[idx] = (1 / samples_per_class[val]) * weight_per_class[val]
|
return weight |
|
|
|
|
|
def load_pretrained_feature_extractor( |
|
pretrained_path="", |
|
feature_extractor="classification", |
|
backbone_feature_extractor="resnet50", |
|
): |
|
"""It loads a pre-trained feature extractor. |
|
|
|
Arguments |
|
--------- |
|
pretrained_path: str, optional |
|
Path to the feature extractor's weights. |
|
feature_extractor: str, optional |
|
If "classification" a network trained on ImageNet for classification will be used. If |
|
"selfsupervised", a network trained on ImageNet with self-supervision will be used. |
|
backbone_feature_extractor: str, optional |
|
Name of the backbone for the feature extractor. Currently, only ResNet50 is supported. |
|
Returns |
|
------- |
|
A Pytorch network initialized with pre-trained weights. |
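
    Example
    -------
    Illustrative usage (a sketch; the checkpoint path is a placeholder and a
    GPU is assumed)::

        net = load_pretrained_feature_extractor(
            pretrained_path="",  # an empty path loads default ImageNet weights
            feature_extractor="classification",
            backbone_feature_extractor="resnet50",
        )
        net.eval()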
|
|
|
""" |
|
if backbone_feature_extractor == "resnet50": |
|
print("using resnet50 to extract features") |
|
        net = resnet50(
            pretrained=(pretrained_path == ""), classifier_run=False
        ).cuda()
|
else: |
|
raise ValueError("Not implemented for backbones other than ResNet50.") |
|
if pretrained_path != "": |
|
print("Loading pretrained weights from: ", pretrained_path) |
|
|
|
|
|
state_dict = torch.load(pretrained_path) |
|
if not feature_extractor == "selfsupervised": |
|
state_dict = state_dict["state_dict_best"]["feat_model"] |
|
|
|
|
|
from collections import OrderedDict |
|
|
|
new_state_dict = OrderedDict() |
|
for k, v in state_dict.items(): |
|
if "module." in k: |
|
name = k[7:] |
|
elif "_feature_blocks." in k: |
|
name = k.replace("_feature_blocks.", "") |
|
else: |
|
name = k |
|
if name in net.state_dict().keys(): |
|
new_state_dict[name] = v |
|
else: |
|
print("key ", name, " not in dict") |
|
|
|
for key in net.state_dict().keys(): |
|
if key not in new_state_dict.keys(): |
|
print("Network key ", key, " not in dict to load") |
|
if not feature_extractor == "selfsupervised": |
|
state_dict = torch.load(pretrained_path)["state_dict_best"]["classifier"] |
|
|
|
            for k, v in state_dict.items():
                name = k[7:]  # strip the "module." prefix added by DataParallel
                new_state_dict[name] = v
|
|
|
net.load_state_dict( |
|
new_state_dict, |
|
strict=False if feature_extractor == "selfsupervised" else True, |
|
) |
|
else: |
|
print("Using pretrained weights on full ImageNet.") |
|
return net |
|
|
|
|
|
def get_dataset_images( |
|
resolution, |
|
data_path, |
|
load_in_mem=False, |
|
augment=False, |
|
longtail=False, |
|
split="train", |
|
test_part=False, |
|
which_dataset="imagenet", |
|
instance_json="", |
|
stuff_json="", |
|
**kwargs |
|
): |
|
"""It prepares a dataset that reads the files from a folder. |
|
|
|
Arguments |
|
--------- |
|
resolution: int |
|
Dataset resolution. |
|
data_path: str |
|
Path where to find the data. |
|
load_in_mem: bool, optional |
|
If True, load all data in memory. |
|
augment: bool, optional |
|
If True, use horizontal flips as data augmentation. |
|
longtail: bool, optional |
|
If True, use the longtailed version of ImageNet (ImageNet-LT). |
|
split: str, optional |
|
Split name to use. |
|
test_part: bool, optional |
|
Only used for COCO-Stuff. If True, use the evaluation set instead of the validation set. |
|
which_dataset: str, optional |
|
Dataset name. |
|
instance_json: str, optional |
|
Path where to find the JSON data for COCO-Stuff instances. |
|
stuff_json: str, optional |
|
Path where to find the JSON data for COCO-Stuff stuff. |
|
Returns |
|
------- |
|
A Dataset class. |
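
    Example
    -------
    Illustrative usage for an ImageFolder-style dataset (the path is a
    placeholder)::

        dataset = get_dataset_images(
            resolution=128,
            data_path="/path/to/imagenet/train",
            augment=True,
            split="train",
            which_dataset="imagenet",
        )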
|
|
|
""" |
|
|
|
    norm_mean = [0.5, 0.5, 0.5]
    norm_std = [0.5, 0.5, 0.5]
    if which_dataset not in ["coco"]:
        transform_list = [CenterCropLongEdge(), transforms.Resize(resolution)]
    else:
        transform_list = [transforms.Resize(resolution)]
    transform_list += [transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)]
    if augment:
        # Append the flip before composing: adding a list to an already-built
        # Compose object would raise a TypeError.
        transform_list += [transforms.RandomHorizontalFlip()]
    transform_list = transforms.Compose(transform_list)
|
|
|
if which_dataset not in ["coco"]: |
|
which_dataset_file = dset.ImageFolder |
|
dataset_kwargs = {} |
|
else: |
|
print("Using coco-stuff dataset class") |
|
which_dataset_file = CocoStuff |
|
dataset_kwargs = { |
|
"image_dir": data_path, |
|
"instances_json": instance_json, |
|
"stuff_json": stuff_json, |
|
"image_size": resolution, |
|
"iscrowd": True if split == "train" else False, |
|
"test_part": test_part, |
|
} |
|
dataset = which_dataset_file( |
|
root=data_path, |
|
transform=transform_list, |
|
load_in_mem=load_in_mem, |
|
split=split, |
|
longtail=longtail, |
|
**dataset_kwargs |
|
) |
|
return dataset |
|
|
|
|
|
def get_dataset_hdf5( |
|
resolution, |
|
data_path, |
|
augment=False, |
|
longtail=False, |
|
local_rank=0, |
|
copy_locally=False, |
|
ddp=True, |
|
tmp_dir="", |
|
class_cond=True, |
|
instance_cond=False, |
|
feature_extractor="classification", |
|
backbone_feature_extractor="resnext50", |
|
which_nn_balance="instance_balance", |
|
which_dataset="imagenet", |
|
split="train", |
|
test_part=False, |
|
kmeans_subsampled=-1, |
|
n_subsampled_data=-1, |
|
feature_augmentation=False, |
|
filter_hd=-1, |
|
k_nn=50, |
|
load_in_mem_feats=False, |
|
compute_nns=False, |
|
**kwargs |
|
): |
|
"""It prepares a dataset that reads the data from HDF5 files. |
|
|
|
Arguments |
|
--------- |
|
resolution: int |
|
Dataset resolution. |
|
data_path: str |
|
Path where to find the data. |
|
|
augment: bool, optional |
|
If True, use horizontal flips as data augmentation. |
|
longtail: bool, optional |
|
If True, use the longtailed version of ImageNet (ImageNet-LT). |
|
local_rank: int, optional |
|
Index indicating the rank of the DistributedDataParallel (DDP) process in the local |
|
machine. It is set to 0 by default or if DDP is not used. |
|
copy_locally: bool, optional |
|
If true, the HDF5 files will be copied locally to the machine. |
|
Useful if the data is in a server. |
|
ddp: bool, optional |
|
If True, use DistributedDataParallel (DDP). |
|
tmp_dir: str, optional |
|
Path where to copy the dataset HDF5 files locally. |
|
class_cond: bool, optional |
|
If True, the dataset will load the labels of the neighbor real samples. |
|
instance_cond: bool, optional |
|
If True, the dataset will load the instance features. |
|
feature_extractor: str, optional |
|
If "classification" a network trained on ImageNet for classification will be used. If |
|
"selfsupervised", a network trained on ImageNet with self-supervision will be used. |
|
backbone_feature_extractor: str, optional |
|
Name of the backbone for the feature extractor. Currently, only ResNet50 is supported. |
|
which_nn_balance: str, optional |
|
Whether to sample an instance or a neighbor class first. By default, |
|
``instance_balance`` is used. Using ``nnclass_balance`` allows class balancing |
|
        to be applied.
    which_dataset: str, optional
        Dataset name.
|
split: str, optional |
|
Split name to use. |
|
test_part: bool, optional |
|
Only used for COCO-Stuff. If True, use the evaluation set instead of the validation set. |
|
kmeans_subsampled: int, optional |
|
If other than -1, that number of data points are selected with k-means from the dataset. |
|
It reduces the amount of available data to train or test the model. |
|
n_subsampled_data: int, optional |
|
If other than -1, that number of data points are randomly selected from the dataset. |
|
It reduces the amount of available data to train or test the model. |
|
feature_augmentation: bool, optional |
|
Use the instance features of the flipped ground-truth image instances as |
|
conditioning, with a 50% probability. |
|
filter_hd: int, optional |
|
Only used for COCO-Stuff dataset. If -1, all COCO-Stuff evaluation set is used. |
|
If 0, only images with seen class combinations are used. |
|
If 1, only images with unseen class combinations are used. |
|
k_nn: int, optional |
|
Size of the neighborhood obtained with the k-NN algorithm. |
|
load_in_mem_feats: bool, optional |
|
Load all instance features in memory. |
|
compute_nns: bool, optional |
|
If True, compute the nearest neighbors. If False, load them from a file with |
|
pre-computed neighbors. |
|
Returns |
|
------- |
|
A Dataset class. |
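
    Example
    -------
    Illustrative usage for class- and instance-conditioned ImageNet data
    (a sketch; the path is a placeholder and pre-computed HDF5 files are
    assumed to exist)::

        dataset = get_dataset_hdf5(
            resolution=64,
            data_path="/path/to/hdf5_files",
            ddp=False,
            class_cond=True,
            instance_cond=True,
            feature_extractor="selfsupervised",
            k_nn=50,
            load_in_mem_feats=True,
        )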
|
|
|
""" |
|
|
|
if which_dataset in ["imagenet", "imagenet_lt"]: |
|
dataset_name_prefix = "ILSVRC" |
|
elif which_dataset == "coco": |
|
dataset_name_prefix = "COCO" |
|
else: |
|
dataset_name_prefix = which_dataset |
|
|
|
hdf5_filename = "%s%i%s%s%s" % ( |
|
dataset_name_prefix, |
|
resolution, |
|
"" if not longtail else "longtail", |
|
"_val" if split == "val" else "", |
|
"_test" if test_part else "", |
|
) |
|
|
|
|
|
    # HDF5 paths: images and labels ("xy"), instance features and pre-computed NNs.
    data_path_xy = os.path.join(data_path, hdf5_filename + "_xy.hdf5")
    data_path_feats, data_path_nns, kmeans_file = None, None, None
|
if instance_cond: |
|
data_path_feats = os.path.join( |
|
data_path, |
|
hdf5_filename |
|
+ "_feats_%s_%s.hdf5" % (feature_extractor, backbone_feature_extractor), |
|
) |
|
if not compute_nns: |
|
data_path_nns = os.path.join( |
|
data_path, |
|
hdf5_filename |
|
+ "_feats_%s_%s_nn_k%i.hdf5" |
|
% (feature_extractor, backbone_feature_extractor, k_nn), |
|
) |
|
|
|
if kmeans_subsampled > -1: |
|
if which_dataset == "imagenet": |
|
d_name = "IN" |
|
elif which_dataset == "coco": |
|
d_name = "COCO" |
|
else: |
|
d_name = which_dataset |
|
kmeans_file = ( |
|
d_name |
|
+ "_res" |
|
+ str(resolution) |
|
+ "_rn50_" |
|
+ feature_extractor |
|
+ "_kmeans_k" |
|
+ str(kmeans_subsampled) |
|
+ ".npy" |
|
) |
|
kmeans_file = os.path.join(data_path, kmeans_file) |
|
|
|
|
|
    # Optionally copy the HDF5 files to local storage to avoid slow network reads.
    if copy_locally:
        tmp_file = os.path.join(tmp_dir, hdf5_filename + "_xy.hdf5")
|
print(tmp_file) |
|
if instance_cond: |
|
tmp_file_feats = os.path.join( |
|
tmp_dir, |
|
hdf5_filename |
|
+ "_feats_%s_%s.hdf5" % (feature_extractor, backbone_feature_extractor), |
|
) |
|
print(tmp_file_feats) |
|
|
|
|
|
        # Only the first process on each machine copies; the others reuse the files.
        if local_rank == 0:
|
shutil.copy2(data_path_xy, tmp_file) |
|
if instance_cond: |
|
shutil.copy2(data_path_feats, tmp_file_feats) |
|
data_path_xy = tmp_file |
|
if instance_cond: |
|
data_path_feats = tmp_file_feats |
|
|
|
|
|
    if ddp:
        # Wait until the rank-0 copy has finished before any process opens the files.
        dist.barrier()
|
|
|
|
|
if augment: |
|
transform_list = transforms.RandomHorizontalFlip() |
|
else: |
|
transform_list = None |
|
|
|
dataset = dset.ILSVRC_HDF5_feats( |
|
root=data_path_xy, |
|
root_feats=data_path_feats, |
|
root_nns=data_path_nns, |
|
transform=transform_list, |
|
load_labels=class_cond, |
|
load_features=instance_cond, |
|
load_in_mem_images=False, |
|
load_in_mem_labels=True, |
|
load_in_mem_feats=load_in_mem_feats, |
|
k_nn=k_nn, |
|
which_nn_balance=which_nn_balance, |
|
kmeans_file=kmeans_file, |
|
n_subsampled_data=n_subsampled_data, |
|
feature_augmentation=feature_augmentation, |
|
filter_hd=filter_hd, |
|
) |
|
return dataset |
|
|
|
|
|
def filter_by_hd(ood_distance): |
|
"""Pre-select image indexes in COCO-Stuff evaluation set according to its class composition. |
|
|
|
Parameters |
|
---------- |
|
ood_distance: int |
|
Minimum hamming distance (HD) between the set of classes present in the evaluation image |
|
and all training images. |
|
If 0, pre-selected images will be the ones that only contain class sets already seen |
|
during training. |
|
If other than 0, all other images with unseen class sets will be selected, |
|
regardless of the hamming distance (HD>0). |
|
Returns |
|
------- |
|
List of pre-selected images. |
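
    Example
    -------
    Illustrative usage (a sketch; the index files are expected under
    ``../coco_stuff_val_indexes``)::

        seen_idxs = filter_by_hd(0)    # images with class sets seen in training
        unseen_idxs = filter_by_hd(1)  # images with unseen class sets (HD > 0)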
|
""" |
|
|
|
image_ids_original = np.load( |
|
"../coco_stuff_val_indexes/cocostuff_val2_all_idxs.npy", allow_pickle=True |
|
) |
|
print("Filtering new ids!") |
|
odd_image_ids = np.load( |
|
os.path.join( |
|
"../coco_stuff_val_indexes", "val2" + "_image_ids_by_hd_75ktraining_im.npy" |
|
), |
|
allow_pickle=True, |
|
) |
|
if ood_distance == 0: |
|
image_ids = odd_image_ids[ood_distance] |
|
else: |
|
total_img_ids = [] |
|
for ood_dist in range(1, len(odd_image_ids)): |
|
total_img_ids += odd_image_ids[ood_dist] |
|
image_ids = total_img_ids |
|
|
|
allowed_idxs = [] |
|
for i_idx, id in enumerate(image_ids_original): |
|
if id in image_ids: |
|
allowed_idxs.append(i_idx) |
|
allowed_idxs = np.array(allowed_idxs) |
|
print("Num images after filtering ", len(allowed_idxs)) |
|
return allowed_idxs |
|
|
|
|
|
def get_dataloader( |
|
dataset, |
|
batch_size=64, |
|
num_workers=8, |
|
shuffle=True, |
|
pin_memory=True, |
|
drop_last=True, |
|
start_itr=0, |
|
start_epoch=0, |
|
use_checkpointable_sampler=False, |
|
use_balanced_sampler=False, |
|
custom_distrib_gen=False, |
|
samples_per_class=None, |
|
class_probabilities=None, |
|
seed=0, |
|
longtail_temperature=1, |
|
rank=0, |
|
world_size=-1, |
|
**kwargs |
|
): |
|
"""Get DataLoader to iterate over the dataset. |
|
|
|
Parameters |
|
---------- |
|
dataset: Dataset |
|
Class with the specified dataset characteristics. |
|
batch_size: int, optional |
|
Batch size. |
|
num_workers: int, optional |
|
Number of workers for the dataloader. |
|
shuffle: bool, optional |
|
If True, the data is shuffled. If a sampler is used (use_checkpointable_sampler=True, |
|
use_balanced_sampler=True or world_size>-1), this parameter is not used. |
|
pin_memory: bool, optional |
|
Pin memory in the dataloader. |
|
drop_last: bool, optional |
|
Drop last incomplete batch in the dataloader. |
|
start_itr: int, optional |
|
Iteration number to resume the sample from. Only used with |
|
use_checkpointable_sampler=True. |
|
start_epoch: int, optional |
|
Epoch number to resume the sample from. Only used with |
|
use_checkpointable_sampler=True. |
|
use_checkpointable_sampler: bool, optional |
|
If True, use the CheckpointedSampler class to resume jobs from the last seen batch |
|
(deterministic). |
|
use_balanced_sampler: bool, optional |
|
If True, balance the data according to a specific class distribution. Use in conjunction |
|
with ``custom_distrib_gen``, ``samples_per_class``, ``class_probabilities`` and |
|
``longtail_temperature``. |
|
custom_distrib_gen: bool, optional |
|
Use a temperature controlled class balancing. |
|
samples_per_class: list, optional |
|
A list of int values that indicate the number of samples per class. |
|
class_probabilities: list, optional |
|
A list of float values indicating the probability of a class in the dataset. |
|
longtail_temperature: float, optional |
|
Temperature value to smooth the longtail distribution with a softmax function. |
|
seed: int, optional |
|
Random seed used. |
|
rank: int, optional |
|
Rank of the current process (if using DistributedDataParallel training). |
|
world_size: int, optional |
|
World size (if using DistributedDataParallel training). |
|
Returns |
|
------- |
|
An instance of DataLoader. |
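
    Example
    -------
    Illustrative usage with the resumable sampler (a sketch; ``dataset`` is a
    placeholder)::

        loader = get_dataloader(
            dataset,
            batch_size=64,
            num_workers=8,
            use_checkpointable_sampler=True,
            start_itr=0,
            start_epoch=0,
            seed=0,
        )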
|
""" |
|
|
|
|
|
|
|
|
|
loader_kwargs = { |
|
"num_workers": num_workers, |
|
"pin_memory": pin_memory, |
|
"drop_last": drop_last, |
|
} |
|
print("Dropping last batch? ", drop_last) |
|
|
|
|
|
|
|
if use_checkpointable_sampler: |
|
print( |
|
"Using checkpointable sampler from start_itr %d..., using seed %d" |
|
% (start_itr, seed) |
|
) |
|
|
|
sampler = CheckpointedSampler( |
|
dataset, |
|
start_itr, |
|
start_epoch, |
|
batch_size, |
|
class_balanced=use_balanced_sampler, |
|
custom_distrib_gen=custom_distrib_gen, |
|
longtail_temperature=longtail_temperature, |
|
samples_per_class=samples_per_class, |
|
class_probabilities=class_probabilities, |
|
seed=seed, |
|
) |
|
loader = DataLoader( |
|
dataset, |
|
batch_size=batch_size, |
|
sampler=sampler, |
|
shuffle=False, |
|
worker_init_fn=seed_worker, |
|
**loader_kwargs |
|
) |
|
else: |
|
if use_balanced_sampler: |
|
print("Balancing real data! Custom? ", custom_distrib_gen) |
|
weights = make_weights_for_balanced_classes( |
|
samples_per_class, |
|
dataset.labels, |
|
1000, |
|
custom_distrib_gen, |
|
longtail_temperature, |
|
class_probabilities=class_probabilities, |
|
) |
|
weights = torch.DoubleTensor(weights) |
|
else: |
|
weights = None |
|
if world_size == -1: |
|
if use_balanced_sampler: |
|
sampler = torch.utils.data.sampler.WeightedRandomSampler( |
|
weights, len(weights) |
|
) |
|
shuffle = False |
|
else: |
|
sampler = None |
|
else: |
|
sampler = DistributedSampler( |
|
dataset, num_replicas=world_size, rank=rank, weights=weights |
|
) |
|
shuffle = False |
|
print("Loader workers?", loader_kwargs, " with shuffle?", shuffle) |
|
loader = DataLoader( |
|
dataset, |
|
batch_size=batch_size, |
|
shuffle=shuffle, |
|
sampler=sampler, |
|
            worker_init_fn=None,  # this branch runs only when use_checkpointable_sampler is False
|
**loader_kwargs |
|
) |
|
|
|
return loader |
|
|
|
|
|
def sample_conditioning_values( |
|
z_, |
|
y_, |
|
ddp=False, |
|
batch_size=1, |
|
weights_sampling=None, |
|
dataset=None, |
|
constant_conditioning=False, |
|
class_cond=True, |
|
instance_cond=False, |
|
nn_sampling_strategy="instance_balance", |
|
): |
|
"""It samples conditionings from the noise distribution and dataset statistics. |
|
|
|
Arguments |
|
--------- |
|
z_: Distribution |
|
Noise distribution. |
|
    y_: Distribution
        Labels distribution.
|
ddp: bool, optional |
|
If True, use DistributedDataParallel (DDP). |
|
batch_size: int, optional |
|
Batch size. |
|
weights_sampling: NumPy array, optional |
|
Weights to balance the sampling of the conditionings. |
|
dataset: Dataset |
|
Instance of a dataset. |
|
constant_conditioning: bool, optional |
|
If True, set all labels to zero. |
|
class_cond: bool, optional |
|
If True, the dataset will load the labels of the neighbor real samples. |
|
instance_cond: bool, optional |
|
If True, the dataset will load the instance features. |
|
nn_sampling_strategy: str, optional |
|
Whether to sample an instance or a neighbor class first. By default, |
|
``instance_balance`` is used. Using ``nnclass_balance`` allows class balancing |
|
to be applied. |
|
Returns |
|
------- |
|
If not using labels (class_cond=False) nor instance features (instance_cond=False), |
|
return the sampled noise vectors. |
|
If not using labels (class_cond=False), return the sampled noise vectors and instance |
|
feature vectors, sampled according to the ``nn_sampling_strategy`` and ``weights_sampling``. |
|
If using labels (class_cond=True), return the sampled noise vectors, instance feature |
|
vectors and the neighbor class labels. |
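
    Example
    -------
    Illustrative usage for class- and instance-conditioned sampling (a sketch;
    ``z_``, ``y_`` and ``dataset`` come from ``prepare_z_y`` and the HDF5
    dataset, respectively)::

        z, labels_g, f_g = sample_conditioning_values(
            z_, y_, batch_size=64, dataset=dataset,
            class_cond=True, instance_cond=True,
        )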
|
|
|
""" |
|
with torch.no_grad(): |
|
z_.sample_() |
|
if not class_cond and not instance_cond: |
|
return z_ |
|
elif class_cond and not instance_cond: |
|
y_.sample_() |
|
if constant_conditioning: |
|
return z_, torch.zeros_like(y_) |
|
else: |
|
if ddp: |
|
return z_, y_ |
|
else: |
|
return z_, y_.data.clone() |
|
else: |
|
            if nn_sampling_strategy == "instance_balance":
                sampling_funct = dataset.sample_conditioning_instance_balance
            elif nn_sampling_strategy == "nnclass_balance":
                sampling_funct = dataset.sample_conditioning_nnclass_balance
            else:
                raise ValueError(
                    "Unknown nn_sampling_strategy: %s" % nn_sampling_strategy
                )

            labels_g, f_g = sampling_funct(batch_size, weights_sampling)
|
if instance_cond and not class_cond: |
|
return z_, f_g |
|
elif instance_cond and class_cond: |
|
return z_, labels_g, f_g |
|
|
|
|
|
|
|
def prepare_z_y( |
|
G_batch_size, |
|
dim_z, |
|
nclasses, |
|
device="cuda", |
|
fp16=False, |
|
z_var=1.0, |
|
longtail_gen=False, |
|
custom_distrib=False, |
|
longtail_temperature=1, |
|
class_probabilities=None, |
|
): |
|
"""Prepare the noise and label distributions. |
|
|
|
Arguments |
|
--------- |
|
G_batch_size: int |
|
Batch size for the generator. |
|
dim_z: int |
|
Noise vector dimensionality. |
|
nclasses: int |
|
        Number of classes in the dataset.
|
fp16: bool, optional |
|
Float16. |
|
z_var: float, optional |
|
Variance for the noise normal distribution. |
|
longtail_gen: bool, optional |
|
If true, use the longtail distribution for the classes (ImageNet-LT) |
|
custom_distrib: bool, optional |
|
If true, use a temperature annealed class distribution. |
|
longtail_temperature: float, optional |
|
Temperature value to smooth the longtail distribution with a softmax function. |
|
class_probabilities: list, optional |
|
A list of float values indicating the probability of a class in the dataset. |
|
|
|
Returns |
|
------- |
|
The noise and class distributions. |
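
    Example
    -------
    Illustrative usage for a 1000-class dataset (values are placeholders)::

        z_, y_ = prepare_z_y(G_batch_size=64, dim_z=120, nclasses=1000)
        z_.sample_()
        y_.sample_()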
|
""" |
|
z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False)) |
|
z_.init_distribution("normal", mean=0, var=z_var) |
|
|
|
|
|
if fp16: |
|
z_ = z_.half() |
|
|
|
y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False)) |
|
if longtail_gen: |
|
y_.init_distribution( |
|
"categorical_longtail", |
|
num_categories=nclasses, |
|
class_prob=class_probabilities, |
|
) |
|
elif custom_distrib: |
|
y_.init_distribution( |
|
"categorical_longtail_temperature", |
|
num_categories=nclasses, |
|
temperature=longtail_temperature, |
|
class_prob=class_probabilities, |
|
) |
|
else: |
|
y_.init_distribution("categorical", num_categories=nclasses) |
|
|
|
return z_, y_ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Distribution(torch.Tensor):
    """A torch.Tensor subclass that carries its own sampling distribution.

    Call ``init_distribution`` once to set the distribution parameters, then call
    ``sample_`` to re-sample the tensor values in place.
    """

    def init_distribution(self, dist_type, class_prob=None, **kwargs):
|
self.dist_type = dist_type |
|
self.dist_kwargs = kwargs |
|
if self.dist_type == "normal": |
|
self.mean, self.var = kwargs["mean"], kwargs["var"] |
|
elif self.dist_type == "categorical": |
|
self.num_categories = kwargs["num_categories"] |
|
elif self.dist_type == "categorical_longtail": |
|
print("(class conditioning sampler) using longtail distribution") |
|
self.num_categories = kwargs["num_categories"] |
|
self.class_prob = torch.DoubleTensor(class_prob) |
|
elif self.dist_type == "categorical_longtail_temperature": |
|
print( |
|
"(class conditioning sampler) Softening the long-tail distribution with temperature ", |
|
kwargs["temperature"], |
|
) |
|
self.num_categories = kwargs["num_categories"] |
|
self.class_prob = torch.log(torch.DoubleTensor(class_prob)) |
|
self.class_prob = torch.exp( |
|
self.class_prob / kwargs["temperature"] |
|
) / torch.sum(torch.exp(self.class_prob / kwargs["temperature"])) |
|
|
|
    def seed_generator(self, seed):
        # ``sample_`` uses the global RNG, so create the generator lazily to avoid
        # an AttributeError when no generator has been attached to this tensor.
        if not hasattr(self, "generator"):
            self.generator = torch.Generator()
        self.generator.manual_seed(seed)
|
|
|
def sample_(self): |
|
if self.dist_type == "normal": |
|
self.normal_(self.mean, self.var) |
|
elif self.dist_type == "categorical": |
|
self.random_(0, self.num_categories) |
|
        elif "categorical_longtail" in self.dist_type:
            # Matches both "categorical_longtail" and "categorical_longtail_temperature".
|
self.data = torch.multinomial( |
|
self.class_prob, len(self), replacement=True |
|
).to(self.device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def seed_worker(worker_id):
    """Seed NumPy and random in each DataLoader worker for reproducible loading."""
    worker_seed = (torch.initial_seed() + worker_id) % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
|
|