wgcban's picture
Upload 98 files
803ef9e
raw
history blame
503 Bytes
import torch
def eval_knn(x_train, y_train, x_test, y_test, k=200):
    """Return the accuracy of a k-nearest-neighbours majority-vote classifier.

    Each test point is assigned the most frequent label among its ``k``
    nearest training points (Euclidean distance via ``torch.cdist``). Ties
    in the vote are broken in favour of the smallest label value.

    Args:
        x_train: Training features, shape ``(n_train, d)``.
        y_train: Training labels, 1-D tensor of shape ``(n_train,)``; any
            label values are allowed (they need not be contiguous or
            non-negative). Must be on the same device as ``x_test``.
        x_test: Test features, shape ``(n_test, d)``.
        y_test: Ground-truth test labels, shape ``(n_test,)``.
        k: Number of neighbours to vote. Clamped to ``n_train`` so a small
            training set does not make ``torch.topk`` fail.

    Returns:
        Fraction of test points whose predicted label equals ``y_test``, as
        a Python float (NaN if ``x_test`` is empty).
    """
    # Evaluation only: avoid building an autograd graph over the potentially
    # large (n_test, n_train) distance matrix.
    with torch.no_grad():
        # Map arbitrary label values to contiguous ids 0..C-1. torch.unique
        # sorts by default, so a lower id means a smaller label value.
        classes, train_ids = torch.unique(y_train, return_inverse=True)

        k = min(k, x_train.shape[0])
        dist = torch.cdist(x_test, x_train)
        nn_idx = torch.topk(dist, k=k, dim=1, largest=False).indices
        nn_ids = train_ids[nn_idx]  # (n_test, k) class ids of neighbours

        # Per-row vote histogram over the C classes.
        votes = torch.zeros(
            nn_ids.shape[0], classes.numel(),
            dtype=torch.long, device=nn_ids.device,
        )
        votes.scatter_add_(1, nn_ids, torch.ones_like(nn_ids))

        # argmax returns the first maximal index, i.e. the smallest label on
        # ties -- the same tie-break as the original unique()+argmax loop.
        pred = classes[votes.argmax(dim=1)].to(y_test.dtype)
        return (pred == y_test).float().mean().cpu().item()