Samuel Mueller
working locally
f50f696
raw
history blame
4.3 kB
import math
import random
import torch
from torch.utils import data
from torchvision import transforms
import numpy as np
from datasets import omniglotNshot
import utils
def _compute_maxtranslations(single_image_tensor, dim, background):
    """Return the inclusive ``[min_shift, max_shift]`` along ``dim`` that keeps content in frame.

    "Content" is every pixel that differs from ``background``. ``min_shift`` is
    ``-first_content_index`` (the furthest move towards index 0) and
    ``max_shift`` is ``size - last_content_index - 1`` (the furthest move
    towards the far edge), so any shift in that range keeps the content's
    bounding extent inside the image.

    Args:
        single_image_tensor: 2-D tensor holding one image.
        dim: Axis to measure along (0 = rows, 1 = columns).
        background: Pixel value treated as empty.

    Returns:
        A two-element list of Python ints ``[min_shift, max_shift]``. For an
        image that is entirely background the result is ``[0, 0]`` (there are
        no content bounds to measure, so no shift is proposed).
    """
    assert len(single_image_tensor.shape) == 2
    # Reduce over the *other* axis (1 - dim): an index along `dim` counts as
    # content if any pixel in that row/column differs from the background.
    content_idx = (single_image_tensor != background).any(dim=1 - dim).nonzero().flatten()
    if content_idx.numel() == 0:
        # Blank image: indexing content_idx[0] would raise IndexError.
        return [0, 0]
    first, last = int(content_idx[0]), int(content_idx[-1])
    return [-first, single_image_tensor.shape[dim] - last - 1]


def compute_maxtranslations_x_y(single_image_tensor, background):
    """Return ``(x_range, y_range)`` of allowed shifts for one 2-D image.

    ``x_range`` is measured along dim 1 (columns) and ``y_range`` along dim 0
    (rows); each is an inclusive ``[min_shift, max_shift]`` list of ints.
    """
    x_range = _compute_maxtranslations(single_image_tensor, 1, background)
    y_range = _compute_maxtranslations(single_image_tensor, 0, background)
    return x_range, y_range
def translate(img, trans_x, trans_y):
    """Shift ``img`` by ``(trans_x, trans_y)`` using a pure-translation affine transform.

    No rotation, scaling or shear is applied; sampling is nearest-neighbour and
    pixels that become exposed are filled with ``0.``.

    Args:
        img: Image tensor; a leading singleton dimension is added for the
            torchvision call and removed again afterwards.
        trans_x: First component of the ``translate`` argument passed to
            ``transforms.functional.affine``.
        trans_y: Second component of that ``translate`` argument.

    Returns:
        The shifted image with the same number of dimensions as ``img``.
    """
    with_leading_dim = img.unsqueeze(0)
    shifted = transforms.functional.affine(
        with_leading_dim,
        angle=0.0,
        translate=[trans_x, trans_y],
        scale=1.0,
        shear=[0.0, 0.0],
        interpolation=transforms.InterpolationMode.NEAREST,
        fill=0.,
    )
    return shifted.squeeze(0)
def translate_omniglot(image_tensor, background=0.):
    """Randomly shift every image in ``image_tensor`` within its allowed range.

    The last two dimensions are treated as one image; all leading dimensions
    are flattened so each image is handled independently. For each image the
    allowed x/y shift ranges come from ``compute_maxtranslations_x_y`` and one
    integer shift per axis is drawn uniformly (inclusive) with ``random.randint``.

    The flattened tensor is a ``view`` of the input, so the input's storage is
    overwritten in place.

    Args:
        image_tensor: Tensor whose trailing two dimensions are image height/width.
        background: Pixel value regarded as empty when measuring content bounds.

    Returns:
        A tensor with the same shape as ``image_tensor`` holding the shifted images.
    """
    original_shape = image_tensor.shape
    images = image_tensor.view(-1, *original_shape[-2:])
    for index, single_image in enumerate(images):
        x_bounds, y_bounds = compute_maxtranslations_x_y(single_image, background)
        # Draw x first, then y, to keep the random stream consumption order stable.
        shift_x = random.randint(*x_bounds)
        shift_y = random.randint(*y_bounds)
        images[index] = translate(single_image, shift_x, shift_y)
    return images.view(*original_shape)
class DataLoader(data.DataLoader):
    """Omniglot N-shot loader that yields ``((x, y), target_y)`` batches.

    Each step pulls one episode batch from ``self.d`` (an ``omniglotNshot``
    dataset), appends the first query example to the support examples, and
    masks every target except the last position with ``-100``.

    The base ``data.DataLoader.__init__`` is not called; iteration and length
    are provided entirely by ``__iter__`` and ``__len__`` below.
    """

    def __init__(self, num_steps, batch_size, seq_len, num_features, num_outputs, num_classes_used=1200,
                 fuse_x_y=False, train=True, translations=True, jonas_style=False):
        """Validate the configuration and build the underlying Omniglot dataset.

        Args:
            num_steps: Number of batches produced per iteration (also ``len(self)``).
            batch_size: Passed to the dataset as ``batchsz``.
            seq_len: Sequence length; ``seq_len - 1`` must be divisible by ``num_outputs``.
            num_features: Flattened image size; must be a perfect square.
            num_outputs: Number of classes per episode (``n_way``).
            num_classes_used: Forwarded as ``num_train_classes_used`` (non-Jonas dataset only).
            fuse_x_y: Must be False; fusing is not supported.
            train: Selects the ``'train'`` or ``'test'`` split when iterating.
            translations: Apply random translations to training batches.
            jonas_style: Use ``OmniglotNShotJonas`` instead of ``OmniglotNShot``.

        Raises:
            AssertionError: If ``fuse_x_y`` is set, ``num_features`` is not a
                perfect square, or ``seq_len - 1`` is not a multiple of ``num_outputs``.
        """
        # Validate before anything is written to `self`. No new local may be
        # created ahead of `set_locals_in_self(locals())`, because that call
        # receives the whole locals dict.
        assert not fuse_x_y, "So far don't support fusing."
        assert math.isqrt(num_features) ** 2 == num_features
        assert ((seq_len - 1) // num_outputs) * num_outputs == seq_len - 1

        # TODO position before last is predictable by counting..
        # Presumably copies every argument onto `self`: the other methods read
        # self.num_steps, self.train, self.translations, ... which are set nowhere else.
        utils.set_locals_in_self(locals())

        imgsz = math.isqrt(num_features)
        k_shot = (seq_len - 1) // num_outputs
        if jonas_style:
            self.d = omniglotNshot.OmniglotNShotJonas('omniglot', batchsz=batch_size, n_way=num_outputs,
                                                      k_shot=k_shot, k_query=1, imgsz=imgsz)
        else:
            self.d = omniglotNshot.OmniglotNShot('omniglot', batchsz=batch_size, n_way=num_outputs,
                                                 k_shot=k_shot, k_query=1, imgsz=imgsz,
                                                 num_train_classes_used=num_classes_used)

    def __len__(self):
        """Return the number of batches produced by one pass over the loader."""
        return self.num_steps

    def _make_batch(self, x, y, x_q, y_q):
        """Turn one dataset episode batch into ``((x, y), target_y)`` tensors.

        The first query example (``x_q[:, :1]`` / ``y_q[:, :1]``) is appended
        along axis 1, so the final position is the one to be predicted.
        Assumes axis 0 is the batch axis and axis 1 the example axis -- TODO
        confirm against ``omniglotNshot``.
        """
        x = torch.from_numpy(np.concatenate([x, x_q[:, :1]], 1))
        y = torch.from_numpy(np.concatenate([y, y_q[:, :1]], 1)).transpose(0, 1)

        # Only the last position keeps its label; -100 matches the default
        # `ignore_index` of torch's cross-entropy loss (presumably what the
        # training loop relies on -- confirm there).
        target_y = y.clone().detach()
        target_y[:-1] = -100

        if self.translations and self.train:
            x = translate_omniglot(x)

        # Flatten each image to a feature vector and swap the first two axes,
        # matching the layout of `y` after its transpose above.
        x = x.view(*x.shape[:2], -1).transpose(0, 1)
        return (x, y), target_y

    def __iter__(self):
        """Lazily yield ``num_steps`` batches from the train or test split."""
        # Eval at pos
        return (self._make_batch(*self.d.next(mode='train' if self.train else 'test'))
                for _ in range(self.num_steps))

    @torch.no_grad()
    def validate(self, finetuned_model, eval_pos=-1):
        """Return the model's accuracy at ``eval_pos`` on the test split.

        The model is switched to eval mode (and left there). A test-split
        loader is created on first use and cached on ``self.t_dl``.

        Args:
            finetuned_model: Callable module invoked as
                ``model((x, y), single_eval_pos=eval_pos)``.
            eval_pos: Position whose predictions and targets are compared.
                Targets are ``-100`` everywhere except the last position, so
                only the last position carries real labels.

        Returns:
            A 0-dim CPU tensor with the mean accuracy.
        """
        finetuned_model.eval()
        device = next(iter(finetuned_model.parameters())).device

        if not hasattr(self, 't_dl'):
            # NOTE(review): jonas_style and num_classes_used are not forwarded,
            # so the test loader always uses OmniglotNShot with the defaults.
            # Confirm this is intended when training with jonas_style=True.
            self.t_dl = DataLoader(num_steps=self.num_steps, batch_size=self.batch_size, seq_len=self.seq_len,
                                   num_features=self.num_features, num_outputs=self.num_outputs,
                                   fuse_x_y=self.fuse_x_y, train=False)

        predictions = []
        targets = []
        for inputs, target in self.t_dl:
            on_device = tuple(e.to(device) for e in inputs)
            predictions.append(finetuned_model(on_device, single_eval_pos=eval_pos))
            targets.append(target)

        # Batches are stacked along axis 1 (the batch axis after the transpose).
        predictions = torch.cat(predictions, 1)
        targets = torch.cat(targets, 1)

        picked = predictions[eval_pos]
        correct = picked.argmax(-1) == targets[eval_pos].to(picked.device)
        return correct.float().mean().cpu()