"""
This script is adapted from Qwen2.5-Math
https://github.com/QwenLM/Qwen2.5-Math/blob/main/evaluation/grader.py

Answer-grading utilities: extract the last boxed answer from a model output,
normalize it, and decide whether it matches a reference answer using string,
numeric, fraction, interval/matrix, equation and SymPy-based comparisons.
"""

import ast
import multiprocessing
import operator
import queue
import re
from collections import defaultdict
from math import isclose
from typing import Union

import regex
from sympy import simplify, N
from sympy.parsing.sympy_parser import parse_expr
from sympy.parsing.latex import parse_latex


def latex2sympy(sympy: str, variable_values=None):
    """Convert a LaTeX string into a parsed expression (or a list of them).

    The input is first normalized: ``\\dfrac``/``\\tfrac`` become ``\\frac``,
    ``\\mathrm{T}`` / ``\\mathrm{d}`` / ``{\\rm d}`` shortcuts are rewritten,
    ``\\left[\\begin{matrix}...\\end{matrix}\\right]`` becomes ``bmatrix``,
    ``(n)_{k}`` is rewritten as a permutation fraction, and spacing commands
    and ``$`` are blanked out. The result is then fed to a lexer/parser
    pipeline.

    NOTE(review): ``MathErrorListener``, ``InputStream``, ``PSLexer``,
    ``CommonTokenStream``, ``PSParser`` and ``convert_relation`` are neither
    defined nor imported in this file. They look like they come from the
    latex2sympy2 / ANTLR runtime packages -- TODO confirm. Unless they are
    provided elsewhere, this function raises NameError after normalization.
    ``symbolic_equal`` catches that and falls back to other parsers.

    Side effects: sets the module globals ``frac_type`` (only when a frac
    command is present) and ``VARIABLE_VALUES``.

    Args:
        sympy: LaTeX source text (name kept for backward compatibility).
        variable_values: optional mapping stored in ``VARIABLE_VALUES``;
            an empty dict is stored when it is None or empty.

    Returns:
        A list with one converted relation per item when the input parses as
        a list, otherwise a single converted relation.
    """
    # record which frac flavour the input used
    global frac_type
    if sympy.find(r'\frac') != -1:
        frac_type = r'\frac'
    if sympy.find(r'\dfrac') != -1:
        frac_type = r'\dfrac'
    if sympy.find(r'\tfrac') != -1:
        frac_type = r'\tfrac'
    sympy = sympy.replace(r'\dfrac', r'\frac')
    sympy = sympy.replace(r'\tfrac', r'\frac')
    # Translate Transpose
    sympy = sympy.replace(r'\mathrm{T}', 'T', -1)
    # Translate Derivative
    sympy = sympy.replace(r'\mathrm{d}', 'd', -1).replace(r'{\rm d}', 'd', -1)
    # Translate Matrix
    sympy = sympy.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}', -1).replace(r'\end{matrix}\right]', r'\end{bmatrix}', -1)
    # Translate Permutation: (n)_{k} -> n! / (n-k)!
    sympy = re.sub(r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}", r"\\frac{(\1)!}{((\1)-(\2))!}", sympy)
    # Remove \displaystyle
    sympy = sympy.replace(r'\displaystyle', ' ', -1)
    # Remove \quad and other spacing commands
    sympy = sympy.replace(r'\quad', ' ', -1).replace(r'\qquad', ' ', -1).replace(r'~', ' ', -1).replace(r'\,', ' ', -1)
    # Remove $
    sympy = sympy.replace(r'$', ' ', -1)

    # variable values (a None default avoids a shared mutable default)
    global VARIABLE_VALUES
    if variable_values:
        VARIABLE_VALUES = variable_values
    else:
        VARIABLE_VALUES = {}

    # setup listener
    matherror = MathErrorListener(sympy)

    # stream input
    stream = InputStream(sympy)
    lex = PSLexer(stream)
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    tokens = CommonTokenStream(lex)
    parser = PSParser(tokens)

    # remove default console error listener
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    # process the input
    return_data = None
    math = parser.math()

    # if a list
    if math.relation_list():
        return_data = []

        # go over list items
        relation_list = math.relation_list().relation_list_content()
        for list_item in relation_list.relation():
            expr = convert_relation(list_item)
            return_data.append(expr)

    # if not, do default
    else:
        relation = math.relation()
        return_data = convert_relation(relation)

    return return_data


def math_answer_cleaning(answer, dataset_name):
    """
    Remove irrelevant strings and unify the answer format before checking
    whether the answers are equal.

    Strips ``\\text{}`` wrappers, thousands separators, ``\\$``, degree
    marks, spacing and newlines. Rewrites ``a\\times10^{b}`` as ``aeb`` and
    ``2^{10}`` as ``2^10``, then lowercases. Extra dataset-specific cleanup
    is applied for "collegemath", "gsm8k" and "gaokao2023en".
    """
    def _is_completely_wrapped_by_text(input_string):
        pattern = r'^\\text{(.*)}$'
        match = re.match(pattern, input_string)
        if match:
            ## input_string is completely wrapped by \text{}
            extracted_content = match.group(1)
            extracted_content = extracted_content.replace("(", "").replace(")", "").replace(",", "")
            return extracted_content
        else:
            return None

    ## remove irrelevant \\text and space
    extracted_content = _is_completely_wrapped_by_text(answer)
    answer = extracted_content if extracted_content else answer

    ## e.g., convert 5,\!460 into 5460; convert 14{,}916 into 14916 convert \$4 into 4
    answer = answer.replace(r",\!", "").replace("{,}", "").replace(r"\$", "")
    ## e.g., convert \dfrac{3}{2} into \frac{3}{2}
    answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
    ## e.g., convert 121^\circ into 121
    answer = answer.replace(r"^\circ", "")
    answer = answer.replace(r"^{\circ}", "")
    ## remove \quad
    answer = answer.replace(r"\quad", "")
    ## remove space
    answer = answer.replace(" ", "")
    ## remove \n (real newlines and the literal two-character sequence)
    answer = answer.replace("\n", "").replace("\\n", "")
    ## e.g., convert 3.54\times10^{10} into 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
    ## e.g., convert 3.54\times10^10 into 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
    ## e.g., convert 558\,\text{nm} into 558
    answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
    ## e.g., convert 558\text{nm} into 558
    answer = re.sub(r'\\text\{.*?\}', '', answer)
    ## e.g., convert 2^{10} into 2^10
    answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
    ## lowercase
    answer = answer.lower()

    if dataset_name == "collegemath":
        ## convert 558\mathrm{ft} into 558
        answer = re.sub(r'\\mathrm\{.*?\}', '', answer)
        ## clean noisy answer
        answer = re.sub(r'\$\([^)]*\)', '', answer)
        if answer.endswith("-"):
            answer = answer[:-1]
        if answer.endswith("."):
            answer = answer[:-1]
        if answer.endswith("hours"):
            answer = answer[:-len("hours")]
        ## extract final answer after '=' or ':'
        if "=" in answer:
            answer = answer.split("=", 1)[1]
        if ":" in answer:
            answer = answer.split(":", 1)[1]
        ## \emptyset and \oslash both represent the empty set in latex
        answer = answer.replace("\\emptyset", "\\oslash")
    if dataset_name == "gsm8k":
        # Example: 5,600 -> 5600
        answer = answer.replace(',', '')
    if dataset_name == "gaokao2023en":
        # Units must be lowercase: the answer was lowercased above, so a
        # mixed-case entry could never match.
        unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters', 'degreesonthebreadusscale', '$', 'a.m.', 'am', 'minutes']
        for unit in unit_strings:
            answer = answer.replace(unit, "")

    return answer


# Matches \boxed{...} allowing up to two levels of nested braces inside it.
_BOXED_PATTERN = re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL)


def extract_final_answer(output):
    """Return ``(last_boxed_content_or_None, all_boxed_contents)`` for *output*."""
    all_matches = _BOXED_PATTERN.findall(output)
    extracted_answer = all_matches[-1] if all_matches else None
    return extracted_answer, all_matches


def round_number(answer):
    """Round small-magnitude numeric strings to 2 significant digits.

    If *answer* converts to a float with absolute value below 1, return it
    formatted with ``.2g`` (e.g. "5.56e-10" -> "5.6e-10"). Anything else is
    returned unchanged. ``abs`` matters: without it every negative number,
    e.g. -1000 vs -1001, would be rounded and compared at 2 significant
    digits.
    """
    try:
        value = float(answer)
    except (TypeError, ValueError, OverflowError):
        return answer

    if abs(value) < 1:
        ## still return a string type
        return f"{value:.2g}"

    return answer


def choice_answer_clean(pred: str):
    """Reduce *pred* to a single multiple-choice letter (A-E) when present.

    Returns the last standalone A-E letter found (case-insensitive), or the
    cleaned input string when no such letter exists.
    """
    pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    # Clean the answer based on the dataset
    tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
    if tmp:
        pred = tmp
    else:
        pred = [pred.strip().strip(".")]
    pred = pred[-1]
    # Remove the period at the end, again!
    pred = pred.rstrip(".").rstrip("/")
    return pred


def parse_digits(num):
    """Parse *num* as a float, ignoring commas; "50%" / "50\\%" -> 0.5.

    Returns None when the value cannot be parsed.
    """
    num = str(num).replace(",", "")
    try:
        return float(num)
    except ValueError:
        if num.endswith("%"):
            num = num[:-1]
            if num.endswith("\\"):
                num = num[:-1]
            try:
                return float(num) / 100
            except ValueError:
                pass
    return None


def is_digit(num):
    """Return True when parse_digits can turn *num* into a number."""
    # paired with parse_digits
    return parse_digits(num) is not None


def str_to_pmatrix(input_str):
    """Turn each ``{a,b,...}`` group into a column ``pmatrix``.

    Entries are joined with the LaTeX row separator ``\\\\``, which is what
    math_equal splits rows on when comparing matrices.
    """
    input_str = input_str.strip()
    matrix_str = re.findall(r"\{.*,.*\}", input_str)
    pmatrix_list = []

    for m in matrix_str:
        m = m.strip("{}")
        pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\\\") + r"\end{pmatrix}"
        pmatrix_list.append(pmatrix)

    return ", ".join(pmatrix_list)


def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Checks run in order:
    - case-insensitive string match;
    - multiple-choice letter;
    - fraction/arithmetic value;
    - 2-significant-digit rounding for magnitudes below 1;
    - numeric comparison, optionally allowing a x100 percentage factor;
    - bracket-insensitive string match;
    - element-wise tuple/interval and matrix comparison;
    - equation comparison;
    - SymPy symbolic comparison.

    Args:
        prediction: candidate answer.
        reference: gold answer.
        include_percentage: also accept reference / 100 and reference * 100.
        is_close: compare numbers with relative tolerance 1e-4 instead of ==.
        timeout: run the final SymPy comparison in a child process limited
            to 1 second (see call_with_timeout).
    """
    if prediction is None or reference is None:
        return False
    # str() first: the signature allows bool/float, which have no .strip().
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if (
        reference in ["A", "B", "C", "D", "E"]
        and choice_answer_clean(str(prediction)) == reference
    ):
        return True

    # fraction equal
    if fraction_equal(prediction, reference):
        return True

    try:
        # numerical equal
        if round_number(prediction) == round_number(reference):
            return True
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if "pmatrix" in prediction and "pmatrix" not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and all(
            math_equal(pred_part, ref_part, include_percentage, is_close)
            for pred_part, ref_part in zip(pred_parts, ref_parts)
        ):
            return True

    ## matrices: compare row by row, cell by cell
    if (
        (
            prediction.startswith("\\begin{pmatrix}")
            or prediction.startswith("\\begin{bmatrix}")
        )
        and (
            prediction.endswith("\\end{pmatrix}")
            or prediction.endswith("\\end{bmatrix}")
        )
        and (
            reference.startswith("\\begin{pmatrix}")
            or reference.startswith("\\begin{bmatrix}")
        )
        and (
            reference.endswith("\\end{pmatrix}")
            or reference.endswith("\\end{bmatrix}")
        )
    ):
        # pmatrix and bmatrix delimiters have identical lengths, so one
        # slice works for both.
        pred_lines = [
            line.strip()
            for line in prediction[
                len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
            ].split("\\\\")
            if line.strip()
        ]
        ref_lines = [
            line.strip()
            for line in reference[
                len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
            ].split("\\\\")
            if line.strip()
        ]
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_cells = pred_line.split("&")
                ref_cells = ref_line.split("&")
                if len(pred_cells) != len(ref_cells) or not all(
                    math_equal(pred_cell, ref_cell, include_percentage, is_close)
                    for pred_cell, ref_cell in zip(pred_cells, ref_cells)
                ):
                    matched = False
                    break
        if matched:
            return True

    ## equations: compare "lhs - (rhs)" forms, or strip a short "x=" prefix
    if prediction.count("=") == 1 and reference.count("=") == 1:
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close
        ):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close
        ):
            return True

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    elif symbolic_equal(prediction, reference):
        return True

    return False


def numeric_equal(prediction: float, reference: float):
    """Return True when the two numbers agree within a relative tolerance of 1e-4."""
    # Per the original author's note, the relative tolerance has a significant
    # impact on results for the synthesized GSM-Hard dataset.
    return isclose(reference, prediction, rel_tol=1e-4)


# Arithmetic allowed when evaluating answer strings in fraction_equal.
# ``^`` (ast.BitXor) is deliberately excluded. In Python it means XOR, so
# "2^10" would evaluate to 8 instead of 1024 and produce false matches.
_SAFE_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_SAFE_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Upper bound on the bit size of an int ** int result, so inputs such as
# "9**9**9" are rejected instead of stalling the grader.
_MAX_POW_RESULT_BITS = 100_000
_NUMBER_TYPES = (int, float, complex)


def _safe_arithmetic_eval(expression):
    """Evaluate a purely numeric expression string without using ``eval``.

    Supports int/float/complex literals, + - * / // % ** (with a bounded
    int ** int result), unary +/-, parentheses, and tuple/list displays of
    such expressions. Returns None for anything else, or when evaluation
    fails (division by zero, overflow, oversized power, deep nesting).
    """
    def _number(value):
        # Arithmetic is only applied to numbers, never to tuples/lists,
        # which blocks memory blow-ups such as "[1]*10**9".
        if not isinstance(value, _NUMBER_TYPES):
            raise TypeError("non-numeric operand")
        return value

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in _NUMBER_TYPES:
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in _SAFE_UNARY_OPS:
            return _SAFE_UNARY_OPS[type(node.op)](_number(_eval(node.operand)))
        if isinstance(node, ast.BinOp) and type(node.op) in _SAFE_BINARY_OPS:
            left = _number(_eval(node.left))
            right = _number(_eval(node.right))
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and abs(left).bit_length() * abs(right) > _MAX_POW_RESULT_BITS
            ):
                raise ValueError("power result too large")
            return _SAFE_BINARY_OPS[type(node.op)](left, right)
        if isinstance(node, ast.Tuple):
            return tuple(_eval(element) for element in node.elts)
        if isinstance(node, ast.List):
            return [_eval(element) for element in node.elts]
        raise ValueError(f"unsupported expression node: {type(node).__name__}")

    try:
        return _eval(ast.parse(expression, mode="eval"))
    except (SyntaxError, ValueError, TypeError, ArithmeticError, MemoryError, RecursionError):
        return None


def fraction_equal(prediction, reference):
    """Return True when both answers denote the same fraction / arithmetic value.

    ``\\frac{a}{b}`` is rewritten to ``(a/b)``. The strings are then compared
    textually and, failing that, evaluated as plain arithmetic. The
    evaluation uses _safe_arithmetic_eval rather than ``eval`` because
    *prediction* is untrusted model output.
    """
    frac_pattern = r'\\frac{(.*?)}{(.*?)}'
    reference = re.sub(frac_pattern, r'(\1/\2)', str(reference))
    prediction = re.sub(frac_pattern, r'(\1/\2)', str(prediction))
    if reference == prediction:
        return True

    reference_value = _safe_arithmetic_eval(reference)
    prediction_value = _safe_arithmetic_eval(prediction)
    # None means "could not evaluate"; two failures must not count as a
    # match, but a legitimate value of 0 must.
    return reference_value is not None and reference_value == prediction_value


def symbolic_equal(a, b):
    """Return True when *a* and *b* are equal according to SymPy.

    Each side is parsed with parse_latex, then parse_expr, then latex2sympy
    (first with doubled backslashes collapsed, then as-is); an unparseable
    side stays a string. The parsed values are then compared by string or
    ``==`` equality, by ``equals`` or ``simplify(a - b) == 0``, as equations
    (``|lhs - rhs|``), numerically via ``N``, and as matrices rounded to 3
    decimals.
    """
    def _parse(s):
        # SECURITY NOTE(review): sympy's parse_expr evaluates its input with
        # eval(); *s* can be model output, so this path is not safe for
        # untrusted text. Consider disabling it or sandboxing the grader.
        for f in [parse_latex, parse_expr, latex2sympy]:
            try:
                return f(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return f(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    # direct equal
    try:
        if str(a) == str(b) or a == b:
            return True
    except Exception:
        pass

    # simplify equal
    try:
        if a.equals(b) or simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # equation equal
    try:
        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
            return True
    except Exception:
        pass

    # numeric equal after numerical evaluation
    try:
        if numeric_equal(float(N(a)), float(N(b))):
            return True
    except Exception:
        pass

    # matrix
    try:
        # if a and b are matrix
        if a.shape == b.shape:
            _a = a.applyfunc(lambda x: round(x, 3))
            _b = b.applyfunc(lambda x: round(x, 3))
            if _a.equals(_b):
                return True
    except Exception:
        pass

    return False


def symbolic_equal_process(a, b, output_queue):
    """Child-process entry point: put symbolic_equal(a, b) on *output_queue*."""
    result = symbolic_equal(a, b)
    output_queue.put(result)


def math_equal_process(prediction, reference, output_queue):
    """Child-process entry point: put math_equal(..., timeout=True) on *output_queue*."""
    result = math_equal(prediction, reference, timeout=True)
    output_queue.put(result)


def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func(*args, output_queue, **kwargs)`` in a child process.

    Returns the value *func* put on the queue. Returns False if the child
    is still running after *timeout* seconds (it is then terminated), or if
    it exited without producing a result (e.g. it raised).
    """
    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    # A child that raised exits non-zero without ever putting a result; a
    # blocking get() would then wait forever.
    if process.exitcode != 0:
        return False
    try:
        return output_queue.get(timeout=timeout)
    except queue.Empty:
        return False


def check_correctness_of_multiple_answer_cases(prediction, reference, all_matches):
    """Fallback check (used for collegemath) for answers with several parts.

    Returns True when the answers match after dropping "," and "$". It
    returns False when the text after the last "=" differs. If the reference
    lists two answers (separated by "or", "and" or ","), the last two boxed
    answers must match them in either order. Otherwise the matching last
    parts are accepted.

    NOTE(review): the math_equal calls here run SymPy without a timeout in
    the calling process.
    """
    if prediction.replace(",", "").replace("$", "") == reference.replace(",", "").replace("$", ""):
        return True

    if prediction.split("=")[-1] != reference.split("=")[-1].replace("$", ""):
        return False

    if "," in reference or "or" in reference or "and" in reference:
        ## there are multiple answers
        if len(all_matches) <= 1:
            return False

        prediction1 = prediction.split("=")[-1]
        prediction2 = all_matches[-2].split("=")[-1]
        reference = reference.replace("$", "")
        if "or" in reference:
            gold_list = reference.split("or", 1)
        elif "and" in reference:
            gold_list = reference.split("and", 1)
        else:
            gold_list = reference.split(",", 1)

        reference1 = gold_list[-1].split("=")[-1]
        reference2 = gold_list[-2].split("=")[-1]

        if math_equal(prediction1, reference1) and math_equal(prediction2, reference2):
            return True
        elif math_equal(prediction2, reference1) and math_equal(prediction1, reference2):
            return True

        return False

    return True


def is_equal(model_output, reference, dataset_name):
    """Grade *model_output* against *reference* for *dataset_name*.

    Extracts the last boxed answer and cleans both sides with
    math_answer_cleaning. It then runs math_equal in a child process with a
    1-second limit. For "collegemath" it falls back to
    check_correctness_of_multiple_answer_cases.
    """
    extracted_model_answer, all_matches = extract_final_answer(model_output)
    if extracted_model_answer is None or reference is None:
        return False

    extracted_model_answer = math_answer_cleaning(extracted_model_answer, dataset_name)
    reference = math_answer_cleaning(reference, dataset_name)

    if call_with_timeout(math_equal_process, extracted_model_answer, reference):
        return True

    if dataset_name == "collegemath":
        return check_correctness_of_multiple_answer_cases(extracted_model_answer, reference, all_matches)

    return False