import torch | |
def _input_device(model):
    """Return the device input batches should be moved to before calling ``model``.

    Uses the device of the model's first parameter. If the model exposes no
    parameters (a plain callable or a parameter-free module), it falls back to
    sending inputs to CUDA when CUDA is available. Otherwise it returns ``None``,
    meaning inputs are left wherever the loader produced them.
    """
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cuda") if torch.cuda.is_available() else None


def get_data(model, loader, output_size, device):
    """Encode a whole dataset into embeddings.

    Runs ``model`` over every ``(x, y)`` batch yielded by ``loader`` under
    ``torch.no_grad()`` and concatenates the per-batch results.

    Args:
        model: callable mapping an input batch to embeddings; each output is
            expected to reshape to ``(batch, output_size)``. Inputs are moved to
            the device of the model's parameters. The model's train/eval mode is
            not changed here; call ``model.eval()`` beforehand if required.
        loader: iterable of ``(x, y)`` batches (e.g. a ``DataLoader``). Batches
            may differ in size, so a short final batch is handled correctly.
        output_size: embedding dimension produced by ``model``.
        device: device on which the returned tensors are stored.

    Returns:
        ``(xs, ys)``: ``xs`` is a float32 tensor of shape
        ``(num_samples, output_size)`` and ``ys`` is a long tensor of shape
        ``(num_samples,)``, both on ``device``.
    """
    input_device = _input_device(model)
    embeddings, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            if input_device is not None:
                x = x.to(input_device)
            # Cast explicitly so the result dtype matches the previous
            # preallocated float32/long buffers regardless of model precision.
            embeddings.append(model(x).to(device=device, dtype=torch.float32))
            labels.append(y.to(device=device, dtype=torch.long))
    if not embeddings:
        # Empty loader: torch.cat rejects an empty list, so build the
        # empty-but-correctly-shaped result directly.
        return (
            torch.empty(0, output_size, dtype=torch.float32, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )
    # Concatenating (rather than writing into fixed-size slots) keeps a short
    # final batch from raising or being broadcast/duplicated.
    xs = torch.cat(embeddings).reshape(-1, output_size)
    ys = torch.cat(labels).reshape(-1)
    return xs, ys