import math
import random

import numpy as np
import torch
from torch.utils import data
from torchvision import transforms

from datasets import omniglotNshot
import utils

def _compute_maxtranslations(single_image_tensor, dim, background):
    # Return the (most negative, most positive) shift along `dim` that keeps all
    # non-background pixels of a single HxW image inside the frame. Assumes the
    # image contains at least one non-background pixel.
    assert len(single_image_tensor.shape) == 2
    content_rows = (~(single_image_tensor == background).all(dim=1 - dim)).nonzero()
    begin, end = content_rows[0], content_rows[-1]
    return torch.cat([-begin, single_image_tensor.shape[dim] - end - 1]).cpu().tolist()
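
# For intuition, a tiny worked example (values assumed, not from the repo): for a
# 5x5 image with background 0. whose ink occupies rows 1..3, dim=0 marks rows
# 1, 2 and 3 as content, so the function returns [-1, 1], i.e. the image can
# shift at most one pixel up or one pixel down before ink leaves the frame.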

def compute_maxtranslations_x_y(single_image_tensor, background):
    # x-shifts are bounded by the content extent over columns (dim 1),
    # y-shifts by the extent over rows (dim 0).
    return (_compute_maxtranslations(single_image_tensor, 1, background),
            _compute_maxtranslations(single_image_tensor, 0, background))

def translate(img, trans_x, trans_y):
    # `affine` expects a channel dimension, hence the unsqueeze/squeeze pair.
    # NEAREST interpolation keeps the strokes crisp; fill=0. pads uncovered
    # pixels with the background value.
    return transforms.functional.affine(img.unsqueeze(0), angle=0.0, translate=[trans_x, trans_y], scale=1.0,
                                        interpolation=transforms.InterpolationMode.NEAREST, shear=[0.0, 0.0],
                                        fill=0.).squeeze(0)
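
# For example (values assumed), translate(img, 2, -3) shifts a glyph by two
# pixels along x and three along y in the opposite direction; NEAREST
# resampling introduces no intermediate gray values.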

def translate_omniglot(image_tensor, background=0.):
    # Randomly translate every image in a batch of arbitrary leading shape,
    # sampling each shift uniformly from the range that keeps all strokes in
    # frame. Note `view` aliases the input, so the augmentation happens in place.
    flat_image_tensor = image_tensor.view(-1, *image_tensor.shape[-2:])
    for i, image in enumerate(flat_image_tensor):
        max_x, max_y = compute_maxtranslations_x_y(image, background)
        flat_image_tensor[i] = translate(image, random.randint(*max_x), random.randint(*max_y))
    return flat_image_tensor.view(*image_tensor.shape)
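
# A minimal sanity check (shapes assumed for illustration): the augmentation
# preserves the input shape, e.g.
#
#   imgs = torch.zeros(4, 5, 28, 28)
#   imgs[..., 10:18, 10:18] = 1.
#   assert translate_omniglot(imgs).shape == (4, 5, 28, 28)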

class DataLoader(data.DataLoader):
    def __init__(self, num_steps, batch_size, seq_len, num_features, num_outputs, num_classes_used=1200,
                 fuse_x_y=False, train=True, translations=True, jonas_style=False):
        # Stores every constructor argument on self (self.seq_len etc.); __len__,
        # __iter__ and validate below rely on these attributes.
        utils.set_locals_in_self(locals())
        assert not fuse_x_y, "So far we don't support fusing."
        # Images are flattened into num_features values, so they must be square.
        imgsz = math.isqrt(num_features)
        assert imgsz * imgsz == num_features
        # The first seq_len - 1 positions form the support set and the last one
        # the query, so seq_len - 1 must split evenly into shots per class.
        assert ((seq_len - 1) // num_outputs) * num_outputs == seq_len - 1
        if jonas_style:
            self.d = omniglotNshot.OmniglotNShotJonas('omniglot', batchsz=batch_size, n_way=num_outputs,
                                                      k_shot=(seq_len - 1) // num_outputs,
                                                      k_query=1, imgsz=imgsz)
        else:
            self.d = omniglotNshot.OmniglotNShot('omniglot', batchsz=batch_size, n_way=num_outputs,
                                                 k_shot=(seq_len - 1) // num_outputs,
                                                 k_query=1, imgsz=imgsz,
                                                 num_train_classes_used=num_classes_used)

    def __len__(self):
        return self.num_steps

    def __iter__(self):
        def t(x, y, x_q, y_q):
            # Append one query example per batch element to the support set,
            # giving sequences of length seq_len.
            x = np.concatenate([x, x_q[:, :1]], 1)
            y = np.concatenate([y, y_q[:, :1]], 1)
            y = torch.from_numpy(y).transpose(0, 1)
            # Only the final (query) position keeps its label; the rest are set
            # to -100, cross-entropy's default ignore_index.
            target_y = y.clone().detach()
            target_y[:-1] = -100
            x = torch.from_numpy(x)
            if self.translations and self.train:
                x = translate_omniglot(x)
            # Flatten images to feature vectors and put the sequence dimension
            # first, yielding x of shape (seq_len, batch, num_features).
            x_and_y = x.view(*x.shape[:2], -1).transpose(0, 1), y
            return x_and_y, target_y

        return (t(*self.d.next(mode='train' if self.train else 'test')) for _ in range(self.num_steps))
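
    # For reference, each step this generator yields (shapes follow from `t`):
    #
    #   (x, y), target_y = step
    #   # x: (seq_len, batch_size, num_features); y, target_y: (seq_len, batch_size)
    #   # target_y is -100 everywhere except at the final query position.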

    @torch.no_grad()
    def validate(self, finetuned_model, eval_pos=-1):
        finetuned_model.eval()
        device = next(iter(finetuned_model.parameters())).device

        # Lazily create a test-split loader with the same configuration.
        if not hasattr(self, 't_dl'):
            self.t_dl = DataLoader(num_steps=self.num_steps, batch_size=self.batch_size, seq_len=self.seq_len,
                                   num_features=self.num_features, num_outputs=self.num_outputs,
                                   fuse_x_y=self.fuse_x_y, train=False)

        ps = []
        ys = []
        for x, y in self.t_dl:
            p = finetuned_model(tuple(e.to(device) for e in x), single_eval_pos=eval_pos)
            ps.append(p)
            ys.append(y)

        # Concatenate along the batch dimension and compute accuracy at the
        # evaluation position only (by default the final query position).
        ps = torch.cat(ps, 1)
        ys = torch.cat(ys, 1)

        def acc(ps, ys):
            return (ps.argmax(-1) == ys.to(ps.device)).float().mean()

        return acc(ps[eval_pos], ys[eval_pos]).cpu()
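
# A minimal end-to-end sketch (hypothetical values; assumes the `omniglot` data
# directory and this repo's `datasets`/`utils` modules are available): a 5-way,
# 5-shot loader over 28x28 glyphs, i.e. seq_len = 5 * 5 + 1 query position.
#
#   dl = DataLoader(num_steps=100, batch_size=16, seq_len=26,
#                   num_features=28 * 28, num_outputs=5)
#   (x, y), target_y = next(iter(dl))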