# AceMath-7B-Instruct/evaluation/calculate_scores.py
# Uploaded by zihanliu ("Upload 3 files", commit 5689fb0, verified; 2.58 kB)
from grader import is_equal
import json
import re
# Benchmarks whose gold-answer layout get_gold_list knows how to parse.
_SUPPORTED_DATASETS = ("gsm8k", "math", "minerva_math", "gaokao2023en", "olympiadbench", "collegemath")

# Captures the contents of a \boxed{...} in a minerva_math solution. The pattern
# tolerates at most two levels of nested braces inside \boxed{...}. Compiled once
# here instead of once per input line.
_BOXED_RE = re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL)

# Matches an answer string that is wrapped, as a whole, in a "$...$" pair.
_DOLLAR_WRAP_RE = re.compile(r'^\$(.*)\$$')


def _strip_dollars(text):
    """Return *text* with one enclosing ``$...$`` pair removed, if present."""
    return _DOLLAR_WRAP_RE.sub(r'\1', text)


def _extract_gold(item, dataset_name):
    """Return the gold answer stored in one parsed JSONL record of *dataset_name*."""
    if dataset_name == "gsm8k":
        # The final answer is the text after the last "#### " marker.
        return item['answer'].split("#### ")[-1]
    if dataset_name == "math":
        return item['answer']
    if dataset_name == "minerva_math":
        # Use the last \boxed{...} in the worked solution; None if none matches.
        matches = _BOXED_RE.findall(item['solution'])
        return matches[-1] if matches else None
    if dataset_name == "olympiadbench":
        # 'final_answer' is indexed; only its first element is used.
        return _strip_dollars(item['final_answer'][0])
    # Remaining supported datasets: gaokao2023en and collegemath.
    return _strip_dollars(item['answer'])


def get_gold_list(datapath, dataset_name):
    """Read the gold answers of a benchmark stored as JSON lines.

    Args:
        datapath: path to a UTF-8 JSONL file, one benchmark question per line.
            Blank lines are ignored.
        dataset_name: one of "gsm8k", "math", "minerva_math", "gaokao2023en",
            "olympiadbench", "collegemath"; selects which field holds the answer
            and how it is cleaned.

    Returns:
        A list with one gold answer per record, in file order. For minerva_math
        an entry is None when the solution contains no matching \\boxed{...}.

    Raises:
        ValueError: if dataset_name is not supported.
    """
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if dataset_name not in _SUPPORTED_DATASETS:
        raise ValueError(
            "unsupported dataset_name %r; expected one of %s"
            % (dataset_name, ", ".join(_SUPPORTED_DATASETS))
        )
    gold_list = []
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Skip blank/trailing empty lines instead of crashing json.loads.
                continue
            gold_list.append(_extract_gold(json.loads(line), dataset_name))
    return gold_list
def get_scores_on_math_benchmarks(model_output_path, test_gold_path, dataset_name, output_list=None):
    """Print and return the accuracy of model responses on a math benchmark.

    Args:
        model_output_path: path to the model's responses. Loading responses from
            this path is still a TODO (see below); pass ``output_list`` instead.
        test_gold_path: path to the benchmark JSONL file read by get_gold_list.
        dataset_name: benchmark name accepted by get_gold_list.
        output_list: optional list of model responses (List[str]), one per
            benchmark question, in the same order as the gold file.

    Returns:
        The fraction of responses that is_equal judges to match the gold answer.

    Raises:
        NotImplementedError: if output_list is None (the file loader is a TODO).
        ValueError: if the numbers of responses and gold answers differ, or the
            benchmark yields no gold answers.
        Errors raised by get_gold_list propagate unchanged.
    """
    gold_list = get_gold_list(test_gold_path, dataset_name)
    if output_list is None:
        # TODO: build output_list (List[str]) from model_output_path. Each string is
        # the model's response to the corresponding benchmark question, so
        # len(output_list) must equal len(gold_list).
        raise NotImplementedError(
            "loading model outputs from %r is not implemented; "
            "pass output_list explicitly" % (model_output_path,)
        )
    # zip() would silently drop unmatched items while the denominator stays
    # len(gold_list), understating accuracy; fail loudly instead.
    if len(output_list) != len(gold_list):
        raise ValueError(
            "got %d model outputs but %d gold answers for %s"
            % (len(output_list), len(gold_list), dataset_name)
        )
    if not gold_list:
        raise ValueError("no gold answers found in %r" % (test_gold_path,))

    correct = 0
    for output, gold in zip(output_list, gold_list):
        if is_equal(output, gold, dataset_name):
            correct += 1
    accuracy = correct / len(gold_list)
    print("accuracy on %s is %.4f" % (dataset_name, accuracy))
    return accuracy
if __name__ == "__main__":
    # Download the test benchmarks from Qwen2.5-Math
    # (https://github.com/QwenLM/Qwen2.5-Math/tree/main/evaluation/data), then pass
    # the benchmark file, your model's output file and the dataset name below.
    # The defaults are the original placeholders, so a bare run behaves as before.
    import argparse

    parser = argparse.ArgumentParser(description="Score model outputs on a math benchmark.")
    parser.add_argument(
        "--test_gold_path",
        default="PATH_OF_THE_BENCHMARK",
        help="benchmark JSONL file",
    )
    parser.add_argument(
        "--model_output_path",
        default="PATH_OF_YOUR_MODEL_OUTPUTS",
        help="file containing the model's responses",
    )
    parser.add_argument(
        "--dataset_name",
        default="DATASET_NAME",
        help="one of: gsm8k, math, minerva_math, gaokao2023en, olympiadbench, collegemath",
    )
    args = parser.parse_args()
    get_scores_on_math_benchmarks(args.model_output_path, args.test_gold_path, args.dataset_name)