# mix-bt / ssl-sota / eval / get_data.py
# Uploaded by wgcban ("Upload 98 files"), commit 803ef9e, 558 bytes.
import torch
def get_data(model, loader, output_size, device, input_device="cuda"):
    """Encode a whole dataset into embeddings.

    Runs ``model`` on every ``(x, y)`` batch of ``loader`` under
    ``torch.no_grad()`` and concatenates the per-batch outputs, so batches of
    unequal size (e.g. a short final batch when ``drop_last=False``) are
    supported.

    Args:
        model: callable mapping an input batch to a tensor with
            ``batch * output_size`` elements. It is called as-is; switching it
            to eval mode is the caller's responsibility.
        loader: iterable yielding ``(x, y)`` batches, e.g. a ``DataLoader``.
        output_size: embedding dimension produced by ``model``.
        device: device on which the returned tensors are stored.
        input_device: device each input batch is moved to before it is fed to
            ``model``. Defaults to ``"cuda"`` (equivalent to ``x.cuda()``).

    Returns:
        ``(xs, ys)`` where ``xs`` is a float32 tensor of shape
        ``(N, output_size)`` and ``ys`` is an int64 tensor of shape ``(N,)``,
        both on ``device``; ``N`` is the total number of samples yielded.
    """
    feats, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            out = model(x.to(input_device))
            # Cast explicitly: embeddings are stored as float32 and labels as
            # int64 regardless of what the model / dataset produce.
            feats.append(
                out.to(device, dtype=torch.float32).reshape(x.shape[0], output_size)
            )
            labels.append(y.to(device, dtype=torch.long).reshape(-1))

    if not feats:
        # Empty loader: torch.cat([]) would raise, so return empty tensors
        # with the documented shapes and dtypes instead.
        return (
            torch.empty(0, output_size, dtype=torch.float32, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )
    return torch.cat(feats), torch.cat(labels)