import torch


def get_data(model, loader, output_size, device):
    """Encode the whole dataset into embeddings.

    Assumes the loader uses drop_last=True so every batch is full;
    otherwise the preallocated buffers would not match the last batch.
    """
    # Preallocate one row of embeddings / labels per batch.
    xs = torch.empty(
        len(loader), loader.batch_size, output_size, dtype=torch.float32, device=device
    )
    ys = torch.empty(len(loader), loader.batch_size, dtype=torch.long, device=device)
    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            x = x.cuda()                    # run the model on the GPU
            xs[i] = model(x).to(device)     # store embeddings on the target device
            ys[i] = y.to(device)
    # Flatten (num_batches, batch_size, ...) buffers into (N, ...) tensors.
    xs = xs.view(-1, output_size)
    ys = ys.view(-1)
    return xs, ys
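
# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, assumed example of calling get_data: a ResNet-18 backbone with its
# classifier head removed encodes CIFAR-10 into 512-d embeddings stored on CPU.
# The model, dataset, and batch size below are illustrative assumptions only.
if __name__ == "__main__":
    import torchvision
    from torchvision import transforms

    model = torchvision.models.resnet18(weights=None)
    model.fc = torch.nn.Identity()  # expose the 512-d pooled features
    model = model.cuda().eval()

    dataset = torchvision.datasets.CIFAR10(
        root="data", train=True, download=True, transform=transforms.ToTensor()
    )
    # drop_last=True keeps every batch full, matching the preallocated buffers.
    loader = torch.utils.data.DataLoader(dataset, batch_size=256, drop_last=True)

    xs, ys = get_data(model, loader, output_size=512, device="cpu")
    print(xs.shape, ys.shape)  # torch.Size([N, 512]) torch.Size([N])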