# Provenance (from the hosting page this file was copied from):
# uploaded by wgcban, commit 803ef9e ("Upload 98 files"), 1.17 kB.
import torch
import torch.nn as nn
import torch.optim as optim
def eval_sgd(x_train, y_train, x_test, y_test, topk=(1, 5), epoch=500, batch_size=1000):
    """Train a linear classifier on fixed features and return top-k test accuracy.

    Despite the name, the classifier is optimized with Adam (weight decay
    5e-6). Its learning rate decays exponentially from 1e-2 to 1e-6 over
    ``epoch`` epochs.

    Args:
        x_train: training features; ``x_train.shape[1]`` is used as the input
            dimension (assumes a 2-D ``(N_train, D)`` tensor).
        y_train: integer class labels for ``x_train``. The number of classes
            is taken as ``y_train.max() + 1``. Must be on the same device as
            ``x_train``.
        x_test: test features with the same feature dimension as ``x_train``.
        y_test: integer class labels for ``x_test``, one per row.
        topk: iterable of k values to report accuracy for.
        epoch: number of passes over the training set (0 skips training).
        batch_size: mini-batch size. The last batch of an epoch may be smaller.

    Returns:
        dict mapping each k in ``topk`` to the fraction of test samples whose
        label is among the classifier's k highest-scoring classes.
    """
    device = x_train.device
    # The features are fixed inputs to the probe. Detaching them prevents
    # backward() from walking into whatever graph produced them.
    x_train, x_test = x_train.detach(), x_test.detach()

    lr_start, lr_end = 1e-2, 1e-6
    # Per-epoch decay factor so the lr reaches lr_end after `epoch` epochs.
    # Guarded so that epoch=0 does not divide by zero.
    gamma = (lr_end / lr_start) ** (1 / epoch) if epoch > 0 else 1.0

    output_size = x_train.shape[1]
    num_class = y_train.max().item() + 1
    clf = nn.Linear(output_size, num_class).to(device)
    clf.train()
    optimizer = optim.Adam(clf.parameters(), lr=lr_start, weight_decay=5e-6)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    criterion = nn.CrossEntropyLoss()

    for _ in range(epoch):
        # split() keeps a trailing partial batch, so the training-set size
        # does not have to be a multiple of batch_size.
        perm = torch.randperm(len(x_train), device=device)
        for idx in perm.split(batch_size):
            optimizer.zero_grad()
            criterion(clf(x_train[idx]), y_train[idx]).backward()
            optimizer.step()
        scheduler.step()

    clf.eval()
    with torch.no_grad():
        y_pred = clf(x_test)
        # topk() raises when k exceeds the number of classes. For such k every
        # class is in the top-k, so slicing all available columns is correct.
        k_max = min(max(topk), num_class)
        pred_top = y_pred.topk(k_max, 1, largest=True, sorted=True).indices
        acc = {
            t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().item()
            for t in topk
        }
    return acc