# agentVerse / scripts / evaluate_logic.py
# AgentVerse's picture
# first commit
# 7569f5d
import re
import json
import subprocess
from importlib import reload
from argparse import ArgumentParser
# Command-line interface.
#   --path      required; file that the scoring loop below opens and reads
#               line by line.
#   --max_line  optional; zero-based line index at which reading stops.
#               The default (10**12) is large enough to act as "no limit".
parser = ArgumentParser()
parser.add_argument("--path", required=True, type=str)
parser.add_argument("--max_line", default=10**12, type=int)
args = parser.parse_args()
def check_corr(result: str, correct_solution: str, tol: float = 1e-3) -> int:
    """Return 1 if *result* matches *correct_solution*, otherwise 0.

    Commas are removed from *result* (so ``"1,000"`` reads as ``"1000"``)
    and both strings are stripped of surrounding whitespace.  The two are
    considered a match when they are textually identical, or when both
    parse as floats whose absolute difference is strictly below *tol*.

    Args:
        result: Candidate answer text.
        correct_solution: Reference answer text.
        tol: Absolute tolerance used for the numeric comparison.

    Returns:
        ``1`` on a match, ``0`` otherwise (including when either string
        cannot be parsed as a number).
    """
    candidate = result.replace(",", "").strip()
    expected = correct_solution.strip()
    if candidate == expected:
        return 1
    try:
        # int(...) keeps the return type uniform with the other branches.
        return int(abs(float(candidate) - float(expected)) < tol)
    except ValueError:
        # At least one side is not numeric and the texts already differ.
        return 0
# Captures the contents of a LaTeX ``\boxed{...}`` up to the first "}".
_BOXED_RE = re.compile(r"\\boxed\{(.+?)\}")
# Removes a ``\text{...`` fragment.  The closing brace is optional because
# the non-greedy capture above stops at the first "}", so the captured text
# usually contains an unclosed ``\text{``.
_TEXT_RE = re.compile(r"\\text\{.+\}?")
# Substrings removed from the captured answer, in this order, before it is
# parsed as an integer (ordinal suffixes, the word "House", backslashes).
_FILLER_TOKENS = ("rd", "nd", "st", "th", "House", "house", "\\")


def _select_response(record, pass_idx):
    """Return the text to score for *record* on evaluation pass *pass_idx*.

    Pass 0 uses ``record["response"]``.  Any other pass uses the content of
    a log entry: ``logs[1]`` when ``logs[0]`` was produced by the
    "Role Assigner" module, otherwise ``logs[0]``.
    """
    if pass_idx == 0:
        return record["response"]
    logs = record["logs"]
    first_entry = logs[0]
    if first_entry["module"] == "Role Assigner":
        return logs[1]["content"]
    return first_entry["content"]


def _extract_answer(response):
    """Return the cleaned first ``\\boxed{...}`` answer, or None if absent."""
    boxed = _BOXED_RE.findall(response)
    if not boxed:
        return None
    answer = _TEXT_RE.sub("", boxed[0])
    for token in _FILLER_TOKENS:
        answer = answer.replace(token, "")
    return answer


# The file is scored twice: pass 0 reads each record's "response" field,
# pass 1 reads a log entry instead (see _select_response).  For every pass
# we record the accuracy and the number of responses without a \boxed{}.
final_accs = []
err_cnts = []
for i in range(2):
    acc = 0
    total = 0
    err_cnt = 0
    with open(args.path, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            if idx == args.max_line:
                break
            record = json.loads(line)
            label = str(record["label"])
            response = _select_response(record, i)
            total += 1
            result = _extract_answer(response)
            if result is None:
                # No boxed answer: counted in the total and as an error.
                err_cnt += 1
                continue
            # Answers are compared as integers (check_corr is not used).
            try:
                acc += int(result) == int(label)
            except ValueError:
                # Unparseable answer: shown for inspection, scored as wrong.
                print(result)
    # Avoid ZeroDivisionError when no record was read (empty file or
    # --max_line 0); report an accuracy of 0.0 in that case.
    final_accs.append(acc / total if total else 0.0)
    err_cnts.append(err_cnt)
print(final_accs)
print(err_cnts)