arshy committed on
Commit
da3b15a
·
1 Parent(s): 8aabc99

automation codes

Browse files
Files changed (2) hide show
  1. automate/automate.py +29 -0
  2. automate/run_benchmark.py +288 -0
automate/automate.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import subprocess

from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.schedulers.blocking import BlockingScheduler
4
+
5
+
6
def run_command(command, shell=True):
    """Run ``command`` to completion, print its output and return its exit code.

    Args:
        command: The command to execute. A single string when ``shell`` is
            true, or an argument list when it is false.
        shell: Forwarded to ``subprocess``; defaults to ``True``.

    Returns:
        The exit status of the child process (``0`` means success).
    """
    # NOTE(review): with shell=True the string is interpreted by the shell,
    # so only trusted, fixed command lines should be passed in.
    completed = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=shell,
    )

    # errors="replace" keeps a stray non-UTF-8 byte in the child's output
    # from raising UnicodeDecodeError while we are only trying to report it.
    if completed.returncode == 0:
        print("Command executed successfully")
        print(completed.stdout.decode(errors="replace"))
    else:
        print("Command failed")
        print(completed.stderr.decode(errors="replace"))

    return completed.returncode
16
+
17
+
18
def run_benchmark():
    """Run ``run_benchmark.py`` in a child process and print its outcome.

    The script is looked up next to this file and launched with the
    interpreter that is running this module, so the job does not depend on
    the current working directory or on which ``python`` is first on PATH.
    """
    import sys  # local import: only this job needs the interpreter path

    script_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "run_benchmark.py"
    )
    # An argument list with shell=False avoids quoting problems when the
    # path contains spaces.
    run_command([sys.executable, script_path], shell=False)
20
+
21
+
22
# A blocking scheduler is used because this script has nothing else to do:
# start() keeps the process alive so the cron job can actually fire. A
# background scheduler's start() returns immediately and the script would
# simply exit.
scheduler = BlockingScheduler()

# Fire once a week: Sunday at 00:00 UTC.
scheduler.add_job(
    run_benchmark,
    "cron",
    day_of_week="sun",
    hour=0,
    timezone="UTC",
)

try:
    scheduler.start()
except (KeyboardInterrupt, SystemExit):
    # Let Ctrl+C / process termination stop the scheduler without a traceback.
    pass
automate/run_benchmark.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import csv
5
+ import json
6
+ import time
7
+ import pickle
8
+ import openai
9
+ import pandas as pd
10
+ from pathlib import Path
11
+ from tqdm import tqdm
12
+ from dotenv import load_dotenv
13
+ from mech.packages.valory.customs.prediction_request import prediction_request
14
+ from benchmark.utils import get_logger, TokenCounterCallback
15
+
16
# Load variables from a local .env file first: the API keys are read from the
# environment further down in this script.
load_dotenv()

# Module-wide logger shared by every function in this file.
logger = get_logger(__name__)

# Directory that contains this script; dataset paths are built relative to it.
this_dir = Path(__file__).parent
19
+
20
+
21
def tool_map(tool):
    """Return the implementation registered for the tool name ``tool``.

    Args:
        tool: Tool identifier, e.g. ``"prediction-online"``.

    Returns:
        The object registered under that name; the benchmark calls its
        ``run`` attribute.

    Raises:
        ValueError: If no tool is registered under ``tool``. The message
            contains the requested name.
    """
    tool_dict = {
        "prediction-online": prediction_request,
        "prediction-offline": prediction_request,
    }

    # Look the name up without rebinding ``tool`` so the error message can
    # still report what the caller actually asked for.
    try:
        return tool_dict[tool]
    except KeyError:
        raise ValueError(f"Tool {tool} not found.") from None
35
+
36
+
37
def prepare_questions(kwargs):
    """Load the Autocast data files and pick the true/false questions to run.

    Args:
        kwargs: Benchmark options. ``num_questions`` is popped from it (so it
            is not forwarded to the tools later); when absent, every eligible
            question is used.

    Returns:
        A ``(questions, url_to_content)`` tuple: the selected question dicts
        and the object unpickled from the companion ``.pkl`` file.
    """
    data_dir = this_dir / "olas-predict-benchmark/benchmark/data/autocast"

    # Context managers make sure both file handles are closed again.
    with open(data_dir / "autocast_questions_filtered.json", encoding="utf-8") as f:
        test_questions = json.load(f)

    # NOTE(review): unpickling can execute arbitrary code; this file must come
    # from a trusted source.
    with open(data_dir / "autocast_questions_filtered.pkl", "rb") as f:
        url_to_content = pickle.load(f)

    num_questions = kwargs.pop("num_questions", len(test_questions))

    questions = []
    for q in test_questions:
        # Check the limit before appending so that a limit of 0 really
        # yields no questions.
        if len(questions) >= num_questions:
            break
        if q["qtype"] == "t/f" and q["answer"] is not None:
            questions.append(q)

    return questions, url_to_content
55
+
56
+
57
def parse_response(response, test_q):
    """Merge a tool response into the result row ``test_q`` and return it.

    ``response`` is read positionally: item 0 is the tool's JSON answer,
    item 1 the prompt/response text, and item 3 either ``None`` or an object
    whose ``cost_dict`` holds the token and cost counters.

    The row always gets ``prediction`` and ``Correct``. ``prediction`` is
    ``None`` when the answer is not a JSON object, when ``p_yes`` or ``p_no``
    is missing, or when the two probabilities are equal.

    Args:
        response: Indexable tool response as described above.
        test_q: Result row for the current question; must contain ``answer``.
            It is updated in place.

    Returns:
        The updated ``test_q``.
    """
    try:
        result = json.loads(response[0])
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers non-string input.
        result = None

    # Valid JSON that is not an object (a list, a number, ...) cannot carry
    # the expected keys, so it is treated like an unparsable answer.
    if not isinstance(result, dict):
        print("The response is not json-format compatible")
        print(f"################### response[0] = {response[0]}")
        test_q["Correct"] = False
        test_q["prediction"] = None
        return test_q

    for key in ("p_yes", "p_no", "confidence", "info_utility"):
        value = result.get(key)
        test_q[key] = None if value is None else float(value)

    counter = response[3]
    if counter is not None:
        for key in (
            "input_tokens",
            "output_tokens",
            "total_tokens",
            "input_cost",
            "output_cost",
            "total_cost",
        ):
            test_q[key] = counter.cost_dict[key]

    # Strip line separators so the text stays on a single CSV line.
    test_q["prompt_response"] = response[1].replace(os.linesep, "")

    p_yes, p_no = test_q["p_yes"], test_q["p_no"]
    # Both probabilities are required; a tie gives no usable prediction.
    if p_yes is None or p_no is None or p_yes == p_no:
        test_q["prediction"] = None
    else:
        test_q["prediction"] = "yes" if p_yes > p_no else "no"

    test_q["Correct"] = test_q["prediction"] == test_q["answer"]
    return test_q
102
+
103
+
104
def write_results(csv_file_path):
    """Aggregate a per-question results CSV into a per-tool/model summary CSV.

    Rows without a prediction are dropped before aggregating. The summary is
    written next to the input file as ``summary_<suffix>.csv``, where
    ``<suffix>`` is everything after the first underscore of the input file's
    stem.

    Args:
        csv_file_path: Path (``str`` or ``Path``) of the results CSV to read.
    """
    # Accept plain strings as well as Path objects.
    csv_file_path = Path(csv_file_path)
    results_path = csv_file_path.parent
    time_string = csv_file_path.stem.split("_", 1)[-1]

    results_df = pd.read_csv(csv_file_path)

    # count() only counts non-null cells, i.e. rows that recorded an error.
    num_errors = results_df["error"].count()
    logger.info(f"Num errors: {num_errors}")

    results_df = results_df.dropna(subset=["prediction"])
    grouped_df = results_df.groupby(["tool", "model"]).agg(
        {
            "Correct": ["mean", "sum", "count"],
            "crowd_correct": ["mean"],
            "input_tokens": ["mean"],
            "output_tokens": ["mean"],
            "total_tokens": ["mean"],
            "input_cost": ["mean"],
            "output_cost": ["mean"],
            "total_cost": ["mean"],
        }
    )

    # Flatten the (column, aggregation) MultiIndex into "column_aggregation".
    grouped_df.columns = ["_".join(col).strip() for col in grouped_df.columns.values]
    summary_df = grouped_df.reset_index().rename(
        columns={
            "Correct_mean": "accuracy",
            "Correct_sum": "correct",
            "Correct_count": "total",
            "crowd_correct_mean": "crowd_accuracy",
        }
    )

    logger.info(f"Results:\n\n {results_df}")
    summary_df.to_csv(results_path / f"summary_{time_string}.csv", index=False)
138
+
139
+
140
def run_benchmark(kwargs):
    """Run every configured tool over the benchmark questions and save the results.

    ``kwargs`` is consumed in place: ``tools``, ``model`` (a list whose first
    entry is used), ``max_retries`` and, through ``prepare_questions``,
    ``num_questions`` are popped. Whatever remains is forwarded unchanged to
    each tool's ``run`` call together with the per-question fields.

    One row per question and tool is appended to
    ``results/results_<timestamp>.csv`` (relative to the current working
    directory). Once that file is closed, ``write_results`` produces the
    matching summary file.

    Args:
        kwargs: Benchmark options as described above.
    """
    logger.info("Running benchmark tests...")

    tools = kwargs.pop("tools")
    model = kwargs.pop("model")[0]
    max_retries = kwargs.pop("max_retries", 3)
    questions, url_to_content = prepare_questions(kwargs)
    # Read the flag once; a missing key means "do not attach source links".
    # The key itself stays in kwargs and is still forwarded to the tool.
    provide_source_links = kwargs.get("provide_source_links", False)
    logger.info(f"Running {len(questions)} questions for each tool: {tools}")

    results_path = Path("results")
    results_path.mkdir(exist_ok=True)

    start_time = time.time()
    time_string = time.strftime("%y%m%d%H%M%S", time.localtime(start_time))
    csv_file_path = results_path / f"results_{time_string}.csv"

    fieldnames = [
        "prompt",
        "answer",
        "tool",
        "model",
        "p_yes",
        "p_no",
        "confidence",
        "info_utility",
        "prediction",
        "Correct",
        "input_tokens",
        "output_tokens",
        "total_tokens",
        "input_cost",
        "output_cost",
        "total_cost",
        "prompt_response",
        "error",
        "crowd_prediction",
        "crowd_correct",
    ]

    logger.info("Creating csv files...")
    with open(csv_file_path, mode="a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)

        # Only a brand-new (empty) file needs the header row.
        if file.tell() == 0:
            writer.writeheader()

        for t in tools:
            logger.info("Loading the tool...")
            try:
                tool = tool_map(t)
            except Exception:
                # Report the requested name ``t``; ``tool`` is unbound (or
                # stale from a previous iteration) when the lookup fails.
                logger.error(f"Error while loading the tool={t}")
                continue

            correct_answers = 0
            total_answers = 0
            for test_question in tqdm(
                questions, desc=f"Running tool {t}", total=len(questions)
            ):
                # Turn the latest crowd forecast into a yes/no vote; exactly
                # 0.5 counts as no prediction.
                crowd_forecast = test_question["crowd"][-1]["forecast"]
                if crowd_forecast > 0.5:
                    crowd_prediction = "yes"
                elif crowd_forecast < 0.5:
                    crowd_prediction = "no"
                else:
                    crowd_prediction = None

                test_q = {
                    "prompt": test_question["question"],
                    "answer": test_question["answer"],
                    "crowd_prediction": crowd_prediction,
                    "tool": t,
                    "model": model,
                    "counter_callback": TokenCounterCallback(),
                    "prompt_response": None,
                }

                if provide_source_links:
                    test_q["source_links"] = {
                        source_link: url_to_content[source_link]
                        for source_link in test_question["source_links"]
                    }

                test_q["crowd_correct"] = crowd_prediction == test_q["answer"]

                retries = 0
                while True:
                    try:
                        response = tool.run(**{**test_q, **kwargs})
                        test_q = parse_response(response, test_q)
                        if test_q.get("Correct"):
                            correct_answers += 1
                        if test_q.get("prediction") is not None:
                            total_answers += 1
                        # Skip the running accuracy until at least one
                        # question has produced a prediction; otherwise the
                        # division below would raise and be recorded as a
                        # tool error.
                        if total_answers:
                            print(
                                f"===========ACCURACY============== {correct_answers/total_answers*100}%"
                            )
                        break
                    except openai.APIError as e:
                        logger.error(f"Error running benchmark for tool {t}: {e}")
                        retries += 1
                        if retries > max_retries:
                            logger.error(
                                f"Max retries reached for tool {t}. Skipping question."
                            )
                            test_q["error"] = e
                            break
                        logger.info(
                            f"Retrying tool {t} for question {test_q['prompt']}"
                        )
                    except Exception as e:
                        # Any other failure is recorded on the row and the
                        # benchmark moves on to the next question.
                        logger.error(f"Error running benchmark for tool {t}: {e}")
                        test_q["error"] = e
                        break

                # These entries are inputs for the tool, not CSV columns;
                # DictWriter rejects keys that are not in ``fieldnames``.
                test_q.pop("source_links", None)
                test_q.pop("counter_callback", None)

                writer.writerow(test_q)

    # Summarise only after the CSV has been closed, so every row is on disk.
    write_results(csv_file_path)

    end_time = time.time()
    total_time = end_time - start_time
    logger.info(f"Total Time: {total_time} seconds")
269
+
270
+
271
if __name__ == "__main__":
    # Settings for a small run: ten questions, one tool, one model. API keys
    # are taken from the environment.
    kwargs = {
        "num_questions": 10,
        "tools": ["prediction-online"],
        "model": ["gpt-3.5-turbo-0125"],
        "api_keys": {
            "openai": os.getenv("OPENAI_API_KEY"),
            "anthropic": os.getenv("ANTHROPIC_API_KEY"),
            "openrouter": os.getenv("OPENROUTER_API_KEY"),
        },
        "num_urls": 3,
        "num_words": 300,
        "provide_source_links": True,
    }
    run_benchmark(kwargs)