File size: 503 Bytes
803ef9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch
def eval_knn(x_train, y_train, x_test, y_test, k=200):
    """Return the accuracy of a majority-vote k-nearest-neighbours classifier.

    Each test point is assigned the label occurring most often among its ``k``
    nearest training points (distances from ``torch.cdist``, p=2 by default).
    Ties are broken in favour of the smallest label value.

    Args:
        x_train: Training features passed to ``torch.cdist``; presumably
            (n_train, dim) — TODO confirm with callers.
        y_train: Training labels, one per row of ``x_train``; expected on the
            same device as the features.
        x_test: Test features with the same feature layout as ``x_train``.
        y_test: Ground-truth labels, one per row of ``x_test``.
        k: Number of neighbours that vote. Values larger than the number of
            training points are clamped to that number.

    Returns:
        Fraction of test points whose predicted label equals ``y_test``, as a
        Python float (``nan`` when there are no test points).

    Raises:
        ValueError: If ``k < 1`` or ``x_train`` is empty.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    n_train = len(x_train)
    if n_train == 0:
        raise ValueError("x_train must contain at least one sample")
    # topk raises if k exceeds the size of the searched dimension.
    k = min(k, n_train)

    # The metric never backpropagates; without no_grad, inputs that require
    # grad would make cdist/topk retain the full distance matrix for backward.
    with torch.no_grad():
        dist = torch.cdist(x_test, x_train)
        neighbor_idx = dist.topk(k, dim=1, largest=False).indices
        del dist  # release the (n_test, n_train) matrix before tallying votes

        # Map arbitrary label values to dense indices 0..C-1 so all votes can be
        # counted with a single scatter_add_; `classes` is sorted ascending.
        classes, label_idx = torch.unique(y_train, return_inverse=True)
        votes = label_idx[neighbor_idx]  # class index of each neighbour
        counts = torch.zeros(
            votes.shape[0], classes.numel(), dtype=torch.long, device=votes.device
        )
        counts.scatter_add_(1, votes, torch.ones_like(votes))

        # argmax returns the first maximal index, i.e. the smallest label on ties.
        pred = classes[counts.argmax(dim=1)]
        target = y_test.reshape(-1).to(pred.device)
        return (pred == target).float().mean().item()
|