pminervini commited on
Commit
b06387f
·
1 Parent(s): 6524ea0
Files changed (1) hide show
  1. cli/analysis-cli.py +136 -0
cli/analysis-cli.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import sys
5
+ import json
6
+
7
+ import numpy as np
8
+
9
+ import pandas as pd
10
+ import seaborn as sns
11
+ import matplotlib.pyplot as plt
12
+
13
+ from scipy.cluster.hierarchy import linkage
14
+
15
+ from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task
16
+
17
+ from src.envs import QUEUE_REPO, RESULTS_REPO, API
18
+ from src.utils import my_snapshot_download
19
+
20
+
21
+ def find_json_files(json_path):
22
+ res = []
23
+ for root, dirs, files in os.walk(json_path):
24
+ for file in files:
25
+ if file.endswith(".json"):
26
+ res.append(os.path.join(root, file))
27
+ return res
28
+
29
+
30
+ my_snapshot_download(repo_id=RESULTS_REPO, revision="main", local_dir=EVAL_RESULTS_PATH_BACKEND, repo_type="dataset", max_workers=60)
31
+ my_snapshot_download(repo_id=QUEUE_REPO, revision="main", local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset", max_workers=60)
32
+
33
+ result_path_lst = find_json_files(EVAL_RESULTS_PATH_BACKEND)
34
+ request_path_lst = find_json_files(EVAL_REQUESTS_PATH_BACKEND)
35
+
36
+ model_name_to_model_map = {}
37
+
38
+ for path in request_path_lst:
39
+ with open(path, 'r') as f:
40
+ data = json.load(f)
41
+ model_name_to_model_map[data["model"]] = data
42
+
43
+ model_dataset_metric_to_result_map = {}
44
+ data_map = {}
45
+
46
+ for path in result_path_lst:
47
+ with open(path, 'r') as f:
48
+ data = json.load(f)
49
+ model_name = data["config"]["model_name"]
50
+ for dataset_name, results_dict in data["results"].items():
51
+ for metric_name, value in results_dict.items():
52
+
53
+ # print(model_name, dataset_name, metric_name, value)
54
+
55
+ if ',' in metric_name and '_stderr' not in metric_name \
56
+ and 'f1' not in metric_name \
57
+ and 'selfcheckgpt' not in dataset_name \
58
+ and model_name_to_model_map[model_name]["likes"] > 256:
59
+
60
+ to_add = True
61
+
62
+ if 'nq_open' in dataset_name or 'triviaqa' in dataset_name:
63
+ to_add = False
64
+ # pass
65
+
66
+ # breakpoint()
67
+
68
+ if 'bertscore' in metric_name:
69
+ if 'precision' not in metric_name:
70
+ to_add = False
71
+
72
+ if 'correctness,' in metric_name or 'em,' in metric_name:
73
+ to_add = False
74
+
75
+ if 'rouge' in metric_name:
76
+ if 'rougeL' not in metric_name:
77
+ to_add = False
78
+
79
+ if 'ifeval' in dataset_name:
80
+ if 'prompt_level_strict_acc' not in metric_name:
81
+ to_add = False
82
+
83
+ if 'squad' in dataset_name:
84
+ to_add = False
85
+
86
+ if 'fever' in dataset_name:
87
+ to_add = False
88
+
89
+ if 'rouge' in metric_name:
90
+ value /= 100.0
91
+
92
+ if to_add:
93
+ sanitised_metric_name = metric_name.split(',')[0]
94
+ model_dataset_metric_to_result_map[(model_name, dataset_name, sanitised_metric_name)] = value
95
+
96
+ # if (model_name, dataset_name) not in data_map:
97
+ # data_map[(model_name, dataset_name)] = {}
98
+ # data_map[(model_name, dataset_name)][metric_name] = value
99
+
100
+ if model_name not in data_map:
101
+ data_map[model_name] = {}
102
+ data_map[model_name][(dataset_name, sanitised_metric_name)] = value
103
+
104
+ print('model_name', model_name, 'dataset_name', dataset_name, 'metric_name', metric_name, 'value', value)
105
+
106
+ model_name_lst = [m for m in data_map.keys()]
107
+ for m in model_name_lst:
108
+ if len(data_map[m]) < 8:
109
+ del data_map[m]
110
+
111
+ df = pd.DataFrame.from_dict(data_map, orient='index')
112
+ o_df = df.copy(deep=True)
113
+
114
+ print(df)
115
+
116
+ # Check for NaN or infinite values and replace them
117
+ df.replace([np.inf, -np.inf], np.nan, inplace=True) # Replace infinities with NaN
118
+ df.fillna(0, inplace=True) # Replace NaN with 0 (or use another imputation strategy)
119
+
120
+ from sklearn.preprocessing import MinMaxScaler
121
+
122
+ # scaler = MinMaxScaler()
123
+ # df = pd.DataFrame(scaler.fit_transform(df), index=df.index, columns=df.columns)
124
+
125
+ sns.set_context("notebook", font_scale=1.0)
126
+
127
+ # fig = sns.clustermap(df, method='average', metric='cosine', cmap='coolwarm', figsize=(16, 12), annot=True)
128
+ fig = sns.clustermap(df, method='ward', metric='euclidean', cmap='coolwarm', figsize=(16, 12), annot=True, mask=o_df.isnull())
129
+
130
+ # Adjust the size of the cells (less wide)
131
+ plt.setp(fig.ax_heatmap.get_yticklabels(), rotation=0)
132
+ plt.setp(fig.ax_heatmap.get_xticklabels(), rotation=90)
133
+
134
+ # Save the clustermap to file
135
+ fig.savefig('plots/clustermap.pdf')
136
+ fig.savefig('plots/clustermap.png')