zihanliu commited on
Commit
5689fb0
·
verified ·
1 Parent(s): 8a35aa3

Upload 3 files

Browse files
evaluation/README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ ## Introduction
5
+ This is the evaluation script used to reproduce math benchmarks scores for AceMath-1.5B/7B/72B-Instruct models based on their outputs. The benchmark can be downloaded from [Qwen2.5-Math](https://github.com/QwenLM/Qwen2.5-Math/tree/main/evaluation/data).
6
+
7
+ ## Calculate Scores
8
+ ```console
9
+ python calculate_scores.py
10
+ ```
evaluation/calculate_scores.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from grader import is_equal
3
+ import json
4
+ import re
5
+
6
+
7
+ def get_gold_list(datapath, dataset_name):
8
+
9
+ assert dataset_name in ["gsm8k", "math", "minerva_math", "gaokao2023en", "olympiadbench", "collegemath"]
10
+
11
+ gold_list = []
12
+ with open(datapath, "r") as f:
13
+ for line in f:
14
+ item = json.loads(line)
15
+
16
+ if dataset_name == "gsm8k":
17
+ gold = item['answer'].split("#### ")[-1]
18
+
19
+ elif dataset_name == "math":
20
+ gold = item['answer']
21
+
22
+ elif dataset_name == "minerva_math":
23
+ pattern = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
24
+ pattern_re = re.compile(pattern, re.DOTALL)
25
+ solution = item['solution']
26
+ matches = pattern_re.findall(solution)
27
+ if len(matches) == 0:
28
+ gold = None
29
+ else:
30
+ gold = matches[-1]
31
+
32
+ elif dataset_name == "gaokao2023en":
33
+ gold = re.sub(r'^\$(.*)\$$', r'\1', item['answer'])
34
+
35
+ elif dataset_name == "olympiadbench":
36
+ gold = re.sub(r'^\$(.*)\$$', r'\1', item['final_answer'][0])
37
+
38
+ elif dataset_name == "collegemath":
39
+ gold = re.sub(r'^\$(.*)\$$', r'\1', item['answer'])
40
+
41
+ gold_list.append(gold)
42
+
43
+ return gold_list
44
+
45
+
46
+ def get_scores_on_math_benchmarks(model_output_path, test_gold_path, dataset_name):
47
+
48
+ gold_list = get_gold_list(test_gold_path, dataset_name)
49
+
50
+ """TODO
51
+ Get the output_list from model_output_path
52
+ output_list is a list of string (List[str])
53
+ Each string represents the model's response for a corresponding question in the benchmark
54
+ Therefore, the length of output_list must match the length of gold_list.
55
+
56
+ output_list = ...
57
+ """
58
+
59
+ correct = 0
60
+ for output, gold in zip(output_list, gold_list):
61
+ if is_equal(output, gold, dataset_name):
62
+ correct += 1
63
+
64
+ print("accuracy on %s is %.4f" % (dataset_name, correct / len(gold_list)))
65
+
66
+
67
+ if __name__ == "__main__":
68
+ """TODO
69
+ Download test benchmarks from Qwen2.5-Math
70
+ https://github.com/QwenLM/Qwen2.5-Math/tree/main/evaluation/data
71
+
72
+ Prepare model_output_path and test_gold_path for each dataset
73
+ """
74
+
75
+ test_gold_path = "PATH_OF_THE_BENCHMARK"
76
+ model_output_path = "PATH_OF_YOUR_MODEL_OUTPUTS"
77
+ dataset_name = "DATASET_NAME" # e.g., gsm8k, math, "minerva_math", "gaokao2023en", "olympiadbench", "collegemath"
78
+
79
+ get_scores_on_math_benchmarks(model_output_path, test_gold_path, dataset_name)
evaluation/grader.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ This script is adapted from Qwen2.5-Math
4
+ https://github.com/QwenLM/Qwen2.5-Math/blob/main/evaluation/grader.py
5
+ """
6
+
7
+ import re
8
+ import regex
9
+ import multiprocessing
10
+ from math import isclose
11
+ from typing import Union
12
+ from collections import defaultdict
13
+
14
+ from sympy import simplify, N
15
+ from sympy.parsing.sympy_parser import parse_expr
16
+ from sympy.parsing.latex import parse_latex
17
+
18
+
19
+ def latex2sympy(sympy: str, variable_values={}):
20
+ # record frac
21
+ global frac_type
22
+ if sympy.find(r'\frac') != -1:
23
+ frac_type = r'\frac'
24
+ if sympy.find(r'\dfrac') != -1:
25
+ frac_type = r'\dfrac'
26
+ if sympy.find(r'\tfrac') != -1:
27
+ frac_type = r'\tfrac'
28
+ sympy = sympy.replace(r'\dfrac', r'\frac')
29
+ sympy = sympy.replace(r'\tfrac', r'\frac')
30
+ # Translate Transpose
31
+ sympy = sympy.replace(r'\mathrm{T}', 'T', -1)
32
+ # Translate Derivative
33
+ sympy = sympy.replace(r'\mathrm{d}', 'd', -1).replace(r'{\rm d}', 'd', -1)
34
+ # Translate Matrix
35
+ sympy = sympy.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}', -1).replace(r'\end{matrix}\right]', r'\end{bmatrix}', -1)
36
+ # Translate Permutation
37
+ sympy = re.sub(r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}", r"\\frac{(\1)!}{((\1)-(\2))!}", sympy)
38
+ # Remove \displaystyle
39
+ sympy = sympy.replace(r'\displaystyle', ' ', -1)
40
+ # Remove \quad
41
+ sympy = sympy.replace(r'\quad', ' ', -1).replace(r'\qquad', ' ', -1).replace(r'~', ' ', -1).replace(r'\,', ' ', -1)
42
+ # Remove $
43
+ sympy = sympy.replace(r'$', ' ', -1)
44
+
45
+ # variable values
46
+ global VARIABLE_VALUES
47
+ if len(variable_values) > 0:
48
+ VARIABLE_VALUES = variable_values
49
+ else:
50
+ VARIABLE_VALUES = {}
51
+
52
+ # setup listener
53
+ matherror = MathErrorListener(sympy)
54
+
55
+ # stream input
56
+ stream = InputStream(sympy)
57
+ lex = PSLexer(stream)
58
+ lex.removeErrorListeners()
59
+ lex.addErrorListener(matherror)
60
+
61
+ tokens = CommonTokenStream(lex)
62
+ parser = PSParser(tokens)
63
+
64
+ # remove default console error listener
65
+ parser.removeErrorListeners()
66
+ parser.addErrorListener(matherror)
67
+
68
+ # process the input
69
+ return_data = None
70
+ math = parser.math()
71
+
72
+ # if a list
73
+ if math.relation_list():
74
+ return_data = []
75
+
76
+ # go over list items
77
+ relation_list = math.relation_list().relation_list_content()
78
+ for list_item in relation_list.relation():
79
+ expr = convert_relation(list_item)
80
+ return_data.append(expr)
81
+
82
+ # if not, do default
83
+ else:
84
+ relation = math.relation()
85
+ return_data = convert_relation(relation)
86
+
87
+ return return_data
88
+
89
+
90
+ def math_answer_cleaning(answer, dataset_name):
91
+ """
92
+ remove irrelevant strings and unify the answer format before checking whether the answers are equal
93
+ """
94
+ def _is_completely_wrapped_by_text(input_string):
95
+ pattern = r'^\\text{(.*)}$'
96
+ match = re.match(pattern, input_string)
97
+ if match:
98
+ ## input_string is completely wrapped by \text{}
99
+ extracted_content = match.group(1)
100
+ extracted_content = extracted_content.replace("(", "").replace(")", "").replace(",", "")
101
+ return extracted_content
102
+ else:
103
+ return None
104
+
105
+ ## remove irrelevant \\text and space
106
+ extracted_content = _is_completely_wrapped_by_text(answer)
107
+ answer = extracted_content if extracted_content else answer
108
+
109
+ ## e.g., convert 5,\!460 into 5460; convert 14{,}916 into 14916 convert \$4 into 4
110
+ answer = answer.replace(",\!", "").replace("{,}", "").replace("\$", "")
111
+ ## e.g., convert \dfrac{3}{2} into frac{3}{2}
112
+ answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
113
+ ## e.g., convert 121^\circ into 121
114
+ answer = answer.replace("^\circ", "")
115
+ answer = answer.replace("^{\circ}", "")
116
+ ## remove \quad
117
+ answer = answer.replace("\quad", "")
118
+ ## remove space
119
+ answer = answer.replace(" ", "")
120
+ ## remove \n
121
+ answer = answer.replace("\n", "").replace("\\n", "")
122
+ ## e.g., convert 3.54\times10^{10} into 3.54e10
123
+ answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
124
+ ## e.g., convert 3.54\times10^10 into 3.54e10
125
+ answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
126
+ ## e.g., convert 558\,\text{nm} into 558
127
+ answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
128
+ ## e.g., convert 558\text{nm} into 558
129
+ answer = re.sub(r'\\text\{.*?\}', '', answer)
130
+ ## e.g., convert 2^{10} into 2^10
131
+ answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
132
+ ## lowercase
133
+ answer = answer.lower()
134
+
135
+ if dataset_name == "collegemath":
136
+ ## convert 558\mathrm{ft} into 558
137
+ answer = re.sub(r'\\mathrm\{.*?\}', '', answer)
138
+ ## clean noisy answer
139
+ answer = re.sub(r'\$\([^)]*\)', '', answer)
140
+ if answer.endswith("-"):
141
+ answer = answer[:-1]
142
+ if answer.endswith("."):
143
+ answer = answer[:-1]
144
+ if answer.endswith("hours"):
145
+ answer = answer[:-len("hours")]
146
+ ## extract final answer after '=' or ':'
147
+ if "=" in answer:
148
+ answer = answer.split("=", 1)[1]
149
+ if ":" in answer:
150
+ answer = answer.split(":", 1)[1]
151
+ ## \emptyset and \oslash both reprsent empty set in latex
152
+ answer = answer.replace("\\emptyset", "\\oslash")
153
+ if dataset_name == "gsm8k":
154
+ # Example: 5,600 -> 5600
155
+ answer = answer.replace(',', '')
156
+ if dataset_name == "gaokao2023en":
157
+ unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters', 'degreesontheBreadusscale', '$', 'a.m.', 'am', 'minutes']
158
+ for unit in unit_strings:
159
+ answer = answer.replace(unit, "")
160
+
161
+ return answer
162
+
163
+
164
+ def extract_final_answer(output):
165
+ pattern_re = re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL)
166
+ all_matches = pattern_re.findall(output)
167
+
168
+ if len(all_matches) >= 1:
169
+ extracted_answer = all_matches[-1]
170
+ else:
171
+ extracted_answer = None
172
+
173
+ return extracted_answer, all_matches
174
+
175
+
176
+ def round_number(answer):
177
+ def _is_float(string):
178
+ try:
179
+ float(string)
180
+ return True
181
+ except:
182
+ return False
183
+
184
+ if _is_float(answer) and float(answer) < 1:
185
+ ## to consider the case like 5.56e-10 (convert 5.56e-10 into 5.6e-10)
186
+ ## still return a string type
187
+ return f"{float(answer):.2g}"
188
+
189
+ return answer
190
+
191
+
192
+ def choice_answer_clean(pred: str):
193
+ pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
194
+ # Clean the answer based on the dataset
195
+ tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
196
+ if tmp:
197
+ pred = tmp
198
+ else:
199
+ pred = [pred.strip().strip(".")]
200
+ pred = pred[-1]
201
+ # Remove the period at the end, again!
202
+ pred = pred.rstrip(".").rstrip("/")
203
+ return pred
204
+
205
+
206
+ def parse_digits(num):
207
+ num = regex.sub(",", "", str(num))
208
+ try:
209
+ return float(num)
210
+ except:
211
+ if num.endswith("%"):
212
+ num = num[:-1]
213
+ if num.endswith("\\"):
214
+ num = num[:-1]
215
+ try:
216
+ return float(num) / 100
217
+ except:
218
+ pass
219
+ return None
220
+
221
+
222
+ def is_digit(num):
223
+ # paired with parse_digits
224
+ return parse_digits(num) is not None
225
+
226
+
227
+ def str_to_pmatrix(input_str):
228
+ input_str = input_str.strip()
229
+ matrix_str = re.findall(r"\{.*,.*\}", input_str)
230
+ pmatrix_list = []
231
+
232
+ for m in matrix_str:
233
+ m = m.strip("{}")
234
+ pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}"
235
+ pmatrix_list.append(pmatrix)
236
+
237
+ return ", ".join(pmatrix_list)
238
+
239
+
240
+ def math_equal(
241
+ prediction: Union[bool, float, str],
242
+ reference: Union[float, str],
243
+ include_percentage: bool = True,
244
+ is_close: bool = True,
245
+ timeout: bool = False,
246
+ ) -> bool:
247
+ """
248
+ Exact match of math if and only if:
249
+ 1. numerical equal: both can convert to float and are equal
250
+ 2. symbolic equal: both can convert to sympy expression and are equal
251
+ """
252
+ if prediction is None or reference is None:
253
+ return False
254
+ if str(prediction.strip().lower()) == str(reference.strip().lower()):
255
+ return True
256
+ if (
257
+ reference in ["A", "B", "C", "D", "E"]
258
+ and choice_answer_clean(prediction) == reference
259
+ ):
260
+ return True
261
+
262
+ # fraction equal
263
+ if fraction_equal(prediction, reference):
264
+ return True
265
+
266
+ try: # numerical equal
267
+ if round_number(prediction) == round_number(reference):
268
+ return True
269
+ if is_digit(prediction) and is_digit(reference):
270
+ prediction = parse_digits(prediction)
271
+ reference = parse_digits(reference)
272
+ # number questions
273
+ if include_percentage:
274
+ gt_result = [reference / 100, reference, reference * 100]
275
+ else:
276
+ gt_result = [reference]
277
+ for item in gt_result:
278
+ try:
279
+ if is_close:
280
+ if numeric_equal(prediction, item):
281
+ return True
282
+ else:
283
+ if item == prediction:
284
+ return True
285
+ except Exception:
286
+ continue
287
+ return False
288
+ except:
289
+ pass
290
+
291
+ if not prediction and prediction not in [0, False]:
292
+ return False
293
+
294
+ # symbolic equal
295
+ reference = str(reference).strip()
296
+ prediction = str(prediction).strip()
297
+
298
+ ## pmatrix (amps)
299
+ if "pmatrix" in prediction and not "pmatrix" in reference:
300
+ reference = str_to_pmatrix(reference)
301
+
302
+ ## deal with [], (), {}
303
+ pred_str, ref_str = prediction, reference
304
+ if (
305
+ prediction.startswith("[")
306
+ and prediction.endswith("]")
307
+ and not reference.startswith("(")
308
+ ) or (
309
+ prediction.startswith("(")
310
+ and prediction.endswith(")")
311
+ and not reference.startswith("[")
312
+ ):
313
+ pred_str = pred_str.strip("[]()")
314
+ ref_str = ref_str.strip("[]()")
315
+ for s in ["{", "}", "(", ")"]:
316
+ ref_str = ref_str.replace(s, "")
317
+ pred_str = pred_str.replace(s, "")
318
+ if pred_str.lower() == ref_str.lower():
319
+ return True
320
+
321
+ ## [a, b] vs. [c, d], return a==c and b==d
322
+ if (
323
+ regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
324
+ and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
325
+ ):
326
+ pred_parts = prediction[1:-1].split(",")
327
+ ref_parts = reference[1:-1].split(",")
328
+ if len(pred_parts) == len(ref_parts):
329
+ if all(
330
+ [
331
+ math_equal(
332
+ pred_parts[i], ref_parts[i], include_percentage, is_close
333
+ )
334
+ for i in range(len(pred_parts))
335
+ ]
336
+ ):
337
+ return True
338
+ if (
339
+ (
340
+ prediction.startswith("\\begin{pmatrix}")
341
+ or prediction.startswith("\\begin{bmatrix}")
342
+ )
343
+ and (
344
+ prediction.endswith("\\end{pmatrix}")
345
+ or prediction.endswith("\\end{bmatrix}")
346
+ )
347
+ and (
348
+ reference.startswith("\\begin{pmatrix}")
349
+ or reference.startswith("\\begin{bmatrix}")
350
+ )
351
+ and (
352
+ reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
353
+ )
354
+ ):
355
+ pred_lines = [
356
+ line.strip()
357
+ for line in prediction[
358
+ len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
359
+ ].split("\\\\")
360
+ if line.strip()
361
+ ]
362
+ ref_lines = [
363
+ line.strip()
364
+ for line in reference[
365
+ len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
366
+ ].split("\\\\")
367
+ if line.strip()
368
+ ]
369
+ matched = True
370
+ if len(pred_lines) == len(ref_lines):
371
+ for pred_line, ref_line in zip(pred_lines, ref_lines):
372
+ pred_parts = pred_line.split("&")
373
+ ref_parts = ref_line.split("&")
374
+ if len(pred_parts) == len(ref_parts):
375
+ if not all(
376
+ [
377
+ math_equal(
378
+ pred_parts[i],
379
+ ref_parts[i],
380
+ include_percentage,
381
+ is_close,
382
+ )
383
+ for i in range(len(pred_parts))
384
+ ]
385
+ ):
386
+ matched = False
387
+ break
388
+ else:
389
+ matched = False
390
+ if not matched:
391
+ break
392
+ else:
393
+ matched = False
394
+ if matched:
395
+ return True
396
+
397
+ if prediction.count("=") == 1 and reference.count("=") == 1:
398
+ pred = prediction.split("=")
399
+ pred = f"{pred[0].strip()} - ({pred[1].strip()})"
400
+ ref = reference.split("=")
401
+ ref = f"{ref[0].strip()} - ({ref[1].strip()})"
402
+ if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
403
+ return True
404
+ elif (
405
+ prediction.count("=") == 1
406
+ and len(prediction.split("=")[0].strip()) <= 2
407
+ and "=" not in reference
408
+ ):
409
+ if math_equal(
410
+ prediction.split("=")[1], reference, include_percentage, is_close
411
+ ):
412
+ return True
413
+ elif (
414
+ reference.count("=") == 1
415
+ and len(reference.split("=")[0].strip()) <= 2
416
+ and "=" not in prediction
417
+ ):
418
+ if math_equal(
419
+ prediction, reference.split("=")[1], include_percentage, is_close
420
+ ):
421
+ return True
422
+
423
+ # symbolic equal with sympy
424
+ if timeout:
425
+ if call_with_timeout(symbolic_equal_process, prediction, reference):
426
+ return True
427
+ else:
428
+ if symbolic_equal(prediction, reference):
429
+ return True
430
+
431
+ return False
432
+
433
+
434
+ def numeric_equal(prediction: float, reference: float):
435
+ # Note that relative tolerance has significant impact
436
+ # on the result of the synthesized GSM-Hard dataset
437
+ # if reference.is_integer():
438
+ # return isclose(reference, round(prediction), abs_tol=1e-4)
439
+ # else:
440
+ # prediction = round(prediction, len(str(reference).split(".")[-1]))
441
+ return isclose(reference, prediction, rel_tol=1e-4)
442
+
443
+
444
+ def fraction_equal(prediction, reference):
445
+ def _calculate_numbers(input_string):
446
+ try:
447
+ result = eval(input_string)
448
+ return result
449
+ except:
450
+ return None
451
+
452
+ reference = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', reference)
453
+ prediction = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', prediction)
454
+
455
+ if reference == prediction:
456
+ return True
457
+
458
+ reference = _calculate_numbers(reference)
459
+ prediction = _calculate_numbers(prediction)
460
+
461
+ if reference and reference == prediction:
462
+ return True
463
+
464
+ return False
465
+
466
+ def symbolic_equal(a, b):
467
+ def _parse(s):
468
+ for f in [parse_latex, parse_expr, latex2sympy]:
469
+ try:
470
+ return f(s.replace("\\\\", "\\"))
471
+ except:
472
+ try:
473
+ return f(s)
474
+ except:
475
+ pass
476
+ return s
477
+
478
+ a = _parse(a)
479
+ b = _parse(b)
480
+
481
+ # direct equal
482
+ try:
483
+ if str(a) == str(b) or a == b:
484
+ return True
485
+ except:
486
+ pass
487
+
488
+ # simplify equal
489
+ try:
490
+ if a.equals(b) or simplify(a - b) == 0:
491
+ return True
492
+ except:
493
+ pass
494
+
495
+ # equation equal
496
+ try:
497
+ if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
498
+ return True
499
+ except:
500
+ pass
501
+
502
+ try:
503
+ if numeric_equal(float(N(a)), float(N(b))):
504
+ return True
505
+ except:
506
+ pass
507
+
508
+ # matrix
509
+ try:
510
+ # if a and b are matrix
511
+ if a.shape == b.shape:
512
+ _a = a.applyfunc(lambda x: round(x, 3))
513
+ _b = b.applyfunc(lambda x: round(x, 3))
514
+ if _a.equals(_b):
515
+ return True
516
+ except:
517
+ pass
518
+
519
+ return False
520
+
521
+
522
+ def symbolic_equal_process(a, b, output_queue):
523
+ result = symbolic_equal(a, b)
524
+ output_queue.put(result)
525
+
526
+
527
+ def math_equal_process(prediction, reference, output_queue):
528
+ result = math_equal(prediction, reference, timeout=True)
529
+ output_queue.put(result)
530
+
531
+
532
+ def call_with_timeout(func, *args, timeout=1, **kwargs):
533
+ output_queue = multiprocessing.Queue()
534
+ process_args = args + (output_queue,)
535
+ process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
536
+ process.start()
537
+ process.join(timeout)
538
+
539
+ if process.is_alive():
540
+ process.terminate()
541
+ process.join()
542
+ return False
543
+
544
+ return output_queue.get()
545
+
546
+
547
+ def check_correctness_of_multiple_answer_cases(prediction, reference, all_matches):
548
+
549
+ if prediction.replace(",", "").replace("$", "") == reference.replace(",", "").replace("$", ""):
550
+ return True
551
+
552
+ if not prediction.split("=")[-1] == reference.split("=")[-1].replace("$", ""):
553
+ return False
554
+
555
+ if "," in reference or "or" in reference or "and" in reference:
556
+ ## there are multiple answers
557
+ if len(all_matches) <= 1:
558
+ return False
559
+
560
+ prediction1 = prediction.split("=")[-1]
561
+ prediction2 = all_matches[-2].split("=")[-1]
562
+ reference = reference.replace("$", "")
563
+ if "or" in reference:
564
+ gold_list = reference.split("or", 1)
565
+ elif "and" in reference:
566
+ gold_list = reference.split("and", 1)
567
+ else:
568
+ gold_list = reference.split(",", 1)
569
+
570
+ reference1 = gold_list[-1].split("=")[-1]
571
+ reference2 = gold_list[-2].split("=")[-1]
572
+
573
+ if math_equal(prediction1, reference1) and math_equal(prediction2, reference2):
574
+ return True
575
+ elif math_equal(prediction2, reference1) and math_equal(prediction1, reference2):
576
+ return True
577
+
578
+ return False
579
+
580
+ else:
581
+ return True
582
+
583
+
584
+ def is_equal(model_output, reference, dataset_name):
585
+
586
+ extracted_model_answer, all_matches = extract_final_answer(model_output)
587
+ if extracted_model_answer is None or reference is None:
588
+ return False
589
+
590
+ extracted_model_answer = math_answer_cleaning(extracted_model_answer, dataset_name)
591
+ reference = math_answer_cleaning(reference, dataset_name)
592
+
593
+ # if math_equal(prediction, reference, timeout=True):
594
+ if call_with_timeout(math_equal_process, extracted_model_answer, reference):
595
+ return True
596
+
597
+ if dataset_name == "collegemath":
598
+ return check_correctness_of_multiple_answer_cases(extracted_model_answer, reference, all_matches)
599
+
600
+ return False