Samuel Mueller
working locally
f50f696
import math
import random
import torch
from torch.utils import data
from torchvision import transforms
import numpy as np
from datasets import omniglotNshot
import utils
def _compute_maxtranslations(single_image_tensor, dim, background):
assert len(single_image_tensor.shape) == 2
content_rows = ((single_image_tensor == background).all(dim=1 - dim) == False).nonzero()
begin, end = content_rows[0], content_rows[-1]
return torch.cat([-begin, single_image_tensor.shape[dim] - end - 1]).cpu().tolist()
def compute_maxtranslations_x_y(single_image_tensor, background):
return _compute_maxtranslations(single_image_tensor, 1, background), _compute_maxtranslations(single_image_tensor,
0, background)
def translate(img, trans_x, trans_y):
return transforms.functional.affine(img.unsqueeze(0), angle=0.0, translate=[trans_x, trans_y], scale=1.0,
interpolation=transforms.InterpolationMode.NEAREST, shear=[0.0, 0.0],
fill=0.).squeeze(0)
def translate_omniglot(image_tensor, background=0.):
flat_image_tensor = image_tensor.view(-1, *image_tensor.shape[-2:])
for i, image in enumerate(flat_image_tensor):
max_x, max_y = compute_maxtranslations_x_y(image, background)
flat_image_tensor[i] = translate(image, random.randint(*max_x), random.randint(*max_y))
return flat_image_tensor.view(*image_tensor.shape)
class DataLoader(data.DataLoader):
def __init__(self, num_steps, batch_size, seq_len, num_features, num_outputs, num_classes_used=1200, fuse_x_y=False, train=True, translations=True, jonas_style=False):
# TODO position before last is predictable by counting..
utils.set_locals_in_self(locals())
assert not fuse_x_y, 'So far don\' support fusing.'
imgsz = math.isqrt(num_features)
assert imgsz * imgsz == num_features
assert ((seq_len-1) // num_outputs) * num_outputs == seq_len - 1
if jonas_style:
self.d = omniglotNshot.OmniglotNShotJonas('omniglot', batchsz=batch_size, n_way=num_outputs,
k_shot=((seq_len - 1) // num_outputs),
k_query=1, imgsz=imgsz)
else:
self.d = omniglotNshot.OmniglotNShot('omniglot', batchsz=batch_size, n_way=num_outputs,
k_shot=((seq_len - 1) // num_outputs),
k_query=1, imgsz=imgsz, num_train_classes_used=num_classes_used)
def __len__(self):
return self.num_steps
def __iter__(self):
# Eval at pos
def t(x, y, x_q, y_q):
x = np.concatenate([x,x_q[:,:1]], 1)
y = np.concatenate([y,y_q[:,:1]], 1)
y = torch.from_numpy(y).transpose(0, 1)
target_y = y.clone().detach()
target_y[:-1] = -100
x = torch.from_numpy(x)
if self.translations and self.train:
x = translate_omniglot(x)
image_tensor = x.view(*x.shape[:2], -1).transpose(0, 1), y
return image_tensor, target_y
return (t(*self.d.next(mode='train' if self.train else 'test')) for _ in range(self.num_steps))
@torch.no_grad()
def validate(self, finetuned_model, eval_pos=-1):
finetuned_model.eval()
device = next(iter(finetuned_model.parameters())).device
if not hasattr(self, 't_dl'):
self.t_dl = DataLoader(num_steps=self.num_steps, batch_size=self.batch_size, seq_len=self.seq_len,
num_features=self.num_features, num_outputs=self.num_outputs, fuse_x_y=self.fuse_x_y,
train=False)
ps = []
ys = []
for x,y in self.t_dl:
p = finetuned_model(tuple(e.to(device) for e in x), single_eval_pos=eval_pos)
ps.append(p)
ys.append(y)
ps = torch.cat(ps,1)
ys = torch.cat(ys,1)
def acc(ps,ys):
return (ps.argmax(-1)==ys.to(ps.device)).float().mean()
a = acc(ps[eval_pos], ys[eval_pos]).cpu()
return a