File size: 1,883 Bytes
b24d496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np


def calculate_metrics_at_k(pred, true, k, compensate_div_0=False, dynamic_topk=True,
                           skip_empty_trues=False, skip_empty_preds=False):
    """Compute macro-averaged Precision@k, Recall@k and F1@k over queries.

    Args:
        pred: Mapping from query id to an ordered sequence of predicted
            document ids (it is sliced, so order matters in fixed-k mode).
        true: Mapping from query id to an iterable of relevant document ids.
            Must contain every query id present in ``pred``.
        k: Cutoff used when ``dynamic_topk`` is False.
        compensate_div_0: If True, a query with no retrieved and no relevant
            documents scores 1 for precision, recall and F1. This check runs
            before the ``skip_*`` flags are applied.
        dynamic_topk: If True, ignore ``k`` and use all distinct predictions
            of each query; the per-query cutoff is their count.
        skip_empty_trues: Exclude queries that have no relevant documents.
        skip_empty_preds: Exclude queries that have no retrieved documents.

    Returns:
        Dict with keys ``avg_precision@K``, ``avg_recall@K`` and ``avg_f1@K``
        holding the mean (``np.mean``) over the included queries, or 0 when
        no query was included. ``K`` is ``k``, except in dynamic mode where
        it is the cutoff of the last query processed (kept for backward
        compatibility of the key format).

    Raises:
        KeyError: If a query id in ``pred`` is missing from ``true``.
    """
    precisions_at_k = []
    recalls_at_k = []
    f1_scores_at_k = []
    # Cutoff reported in the result keys. In dynamic mode it tracks the last
    # query seen, matching the historical key format callers may rely on.
    label_k = k

    for query_id, predictions in pred.items():
        if dynamic_topk:
            # Take every distinct prediction. Slicing the list by the unique
            # count would drop documents whenever the list has duplicates.
            retrieved_documents = set(predictions)
            cutoff = len(retrieved_documents)
            label_k = cutoff
        else:
            cutoff = k
            retrieved_documents = set(predictions[:cutoff])

        relevant_documents = set(true[query_id])
        true_positives = len(retrieved_documents & relevant_documents)

        if compensate_div_0 and not retrieved_documents and not relevant_documents:
            # Nothing to find and nothing returned: count as a perfect score.
            precisions_at_k.append(1)
            recalls_at_k.append(1)
            f1_scores_at_k.append(1)
            continue

        if skip_empty_trues and not relevant_documents:
            continue

        if skip_empty_preds and not retrieved_documents:
            continue

        # Precision divides by the cutoff rather than by the number of
        # retrieved documents, so short lists are penalised in fixed-k mode.
        precision_at_k = true_positives / cutoff if cutoff else 0
        precisions_at_k.append(precision_at_k)

        recall_at_k = true_positives / len(relevant_documents) if relevant_documents else 0
        recalls_at_k.append(recall_at_k)

        # Harmonic mean of precision and recall; 0 when both are 0.
        if precision_at_k + recall_at_k > 0:
            f1_at_k = 2 * (precision_at_k * recall_at_k) / (precision_at_k + recall_at_k)
        else:
            f1_at_k = 0
        f1_scores_at_k.append(f1_at_k)

    # Macro averages over the included queries.
    avg_precision_at_k = np.mean(precisions_at_k) if precisions_at_k else 0
    avg_recall_at_k = np.mean(recalls_at_k) if recalls_at_k else 0
    avg_f1_at_k = np.mean(f1_scores_at_k) if f1_scores_at_k else 0

    return {
        f"avg_precision@{label_k}": avg_precision_at_k,
        f"avg_recall@{label_k}": avg_recall_at_k,
        f"avg_f1@{label_k}": avg_f1_at_k,
    }