Spaces:
Build error
Build error
File size: 2,742 Bytes
7569f5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import re
import json
import subprocess
from importlib import reload
from argparse import ArgumentParser
# Command-line options read by the grading loop below: --path is the JSONL
# file that gets opened, --max_line is the line index at which reading stops,
# and --ci_smoke_test turns on the final accuracy check.
parser = ArgumentParser()
parser.add_argument("--path", required=True, type=str)
parser.add_argument("--max_line", default=10 ** 12, type=int)
parser.add_argument("--ci_smoke_test", action="store_true")
args = parser.parse_args()
def check_corr(result: str, correct_solution: str, tol: float = 1e-3):
    """Return 1 if *result* matches *correct_solution*, else 0.

    Two answers match when either:

    * their text is identical after removing commas and surrounding
      whitespace from both, or
    * both parse as floats that differ by less than *tol* (an absolute
      tolerance).

    Args:
        result: Answer extracted from the model output.
        correct_solution: Reference answer, as a string.
        tol: Absolute tolerance for the numeric comparison.

    Returns:
        ``1`` on a match, ``0`` otherwise. Always an ``int``, so callers can
        sum the result directly.
    """
    # Remove commas from BOTH sides so a reference that itself contains commas
    # (thousands separators like "1,000", or tuples like "(1,2)") can match.
    result_clean = result.replace(",", "").strip()
    solution_clean = correct_solution.replace(",", "").strip()
    if result_clean == solution_clean:
        return 1
    try:
        return int(abs(float(result_clean) - float(solution_clean)) < tol)
    except ValueError:
        # At least one side is not numeric and the texts differ: no match.
        return 0
# final_accs = []
# for i in range(2):
# acc = 0
# total = 0
# with open(args.path) as f:
# for line in f:
# line = json.loads(line)
# label = str(line["label"])
# if i == 0:
# code = line["response"]
# else:
# code = line["logs"][0]["content"]
# total += 1
# code = code.strip().replace("```", "")
# code = code.lstrip("python3")
# code = code.lstrip("python")
# with open("tmp.py", "w") as f:
# f.write(code)
# try:
# import tmp
# reload(tmp)
# result = str(tmp.solution())
# is_corr = check_corr(result, label)
# is_corr = int(is_corr)
# # Step 2
# if is_corr:
# acc += 1
# except:
# print(code)
# final_accs.append(acc / total)
# print(final_accs)
_BOXED_OPEN = "\\boxed{"


def _select_response(record, from_logs):
    """Return the text to grade for one parsed JSONL record.

    When *from_logs* is false, this is the top-level ``response`` field.
    Otherwise it is the ``content`` of the first ``logs`` entry, or of the
    second entry when the first one's ``module`` is ``"Role Assigner"``.
    """
    if not from_logs:
        return record["response"]
    logs = record["logs"]
    entry = logs[1] if logs[0]["module"] == "Role Assigner" else logs[0]
    return entry["content"]


def _first_boxed(text):
    """Return the contents of the first non-empty ``\\boxed{...}`` in *text*.

    Braces are matched by nesting depth, so ``\\boxed{\\frac{1}{2}}`` yields
    ``\\frac{1}{2}`` instead of stopping at the first ``}``. Returns ``None``
    when *text* contains no balanced, non-empty ``\\boxed{...}``.
    """
    start = text.find(_BOXED_OPEN)
    while start != -1:
        begin = start + len(_BOXED_OPEN)
        depth = 1
        for pos in range(begin, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    if pos > begin:
                        return text[begin:pos]
                    break  # empty \boxed{}: try the next occurrence
        start = text.find(_BOXED_OPEN, start + 1)
    return None


final_accs = []
err_cnts = []
# Pass 0 grades the top-level "response"; pass 1 grades the logged content.
for from_logs in (False, True):
    correct = 0
    graded = 0
    missing_box = 0
    with open(args.path, encoding="utf-8") as f:
        for idx, raw in enumerate(f):
            if idx == args.max_line:
                break
            if not raw.strip():
                # Tolerate blank lines in the JSONL instead of crashing json.loads.
                continue
            record = json.loads(raw)
            label = str(record["label"])
            response = _select_response(record, from_logs)
            graded += 1
            answer = _first_boxed(response)
            if answer is None:
                # Counted as wrong; echoed so unparseable outputs can be inspected.
                missing_box += 1
                print(response)
                continue
            correct += check_corr(answer, label)
    # Avoid ZeroDivisionError on an empty file or --max_line 0.
    final_accs.append(correct / graded if graded else 0.0)
    err_cnts.append(missing_box)
print(final_accs)
print(err_cnts)
if args.ci_smoke_test and final_accs[0] != 1.0:
    # Explicit exit rather than `assert`, which is stripped under `python -O`.
    raise SystemExit(f"CI smoke test failed: accuracy {final_accs[0]} != 1.0")
|