File size: 7,697 Bytes
246d201 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import os
import random
import re
import shutil
from pyke import knowledge_engine
class PykeProgram:
    """A Logic-LM style logic program translated into Pyke fact and rule files.

    The program text is split into ``Predicates:``, ``Facts:``, ``Rules:`` and
    ``Query:`` sections. Each section is stored on an attribute of the same
    name (a list of statements, or ``None`` when the section is missing).
    Facts and rules are written to ``<workspace_mount_path>/.cache_program`` as
    ``facts.kfb`` and ``rules.krb``, which ``execute_program`` loads with Pyke.

    Attributes:
        flag: True when the fact and rule files were written successfully,
            False otherwise (LogicInferenceEngine treats False as a parsing
            error).
    """

    def __init__(
        self, logic_program: str, dataset_name='ProntoQA', workspace_mount_path='./'
    ) -> None:
        self.logic_program = logic_program
        # Populates self.Query / self.Rules / self.Facts / self.Predicates.
        # The validation result is superseded below on purpose: a program that
        # validate_program had to repair is still usable if its files can be
        # written.
        self.flag = self.parse_logic_program()
        self.dataset_name = dataset_name
        self.cache_dir = os.path.join(workspace_mount_path, '.cache_program')
        # prepare the files for facts and rules
        try:
            # The cache directory may not exist yet; without it every open()
            # below fails and every program is reported as a parsing error.
            os.makedirs(self.cache_dir, exist_ok=True)
            self.create_fact_file(self.Facts)
            self.create_rule_file(self.Rules)
            self.flag = True
        except Exception:
            self.flag = False
        # Maps dataset name -> function(result, value_to_check) -> answer letter.
        self.answer_map = {
            'ProntoQA': self.answer_map_prontoqa,
            'ProofWriter': self.answer_map_proofwriter,
        }

    def parse_logic_program(self):
        """Split the program text into its sections, then validate them.

        Sections are peeled off the end of the text in the order Query, Rules,
        Facts, Predicates, so the program is expected to list them as
        Predicates, Facts, Rules, Query. A section whose keyword is missing
        (or appears more than once) is stored as None.

        Returns:
            The result of ``validate_program``.
        """
        keywords = ['Query:', 'Rules:', 'Facts:', 'Predicates:']
        program_str = self.logic_program
        for keyword in keywords:
            attr_name = keyword[:-1]  # drop the trailing ':'
            try:
                program_str, segment_list = self._parse_segment(program_str, keyword)
                setattr(self, attr_name, segment_list)
            except Exception:
                setattr(self, attr_name, None)
        return self.validate_program()

    def _parse_segment(self, program_str, key_phrase):
        """Split off the section introduced by ``key_phrase``.

        Returns:
            ``(text_before_keyword, statements)``. Each statement is one
            non-blank line of the section with any ``::: comment`` suffix
            removed.

        Raises:
            ValueError: if ``key_phrase`` does not occur exactly once.
        """
        remain_program_str, segment = program_str.split(key_phrase)
        statements = (line.split(':::')[0].strip() for line in segment.split('\n'))
        # Drop blank and comment-only lines: an empty rule would make
        # parse_forward_rule raise and fail the whole program.
        segment_list = [statement for statement in statements if statement]
        return remain_program_str, segment_list

    def validate_program(self):
        """Check that Facts and Rules are both present; repair them if not.

        When either section is missing or empty, the statements of both
        sections are pooled and re-sorted: anything containing ``>>>`` becomes
        a rule, everything else a fact. This recovers programs where rules
        were written under ``Facts:`` or facts under ``Rules:``.

        Returns:
            True if both sections were already non-empty; False if both were
            missing or the program had to be repaired.
        """
        if self.Rules and self.Facts:
            return True
        if self.Facts is None and self.Rules is None:
            return False
        # Pool both sections so that an empty Facts section does not cause
        # the statements listed under Rules to be discarded.
        statements = (self.Facts or []) + (self.Rules or [])
        self.Rules = [s for s in statements if '>>>' in s]
        self.Facts = [s for s in statements if '>>>' not in s]
        return False

    def create_fact_file(self, facts):
        """Write ``facts`` (one per line) to ``facts.kfb`` in the cache dir.

        Statements that still contain the variable ``$x`` are treated as
        invalid facts and skipped.
        """
        fact_path = os.path.join(self.cache_dir, 'facts.kfb')
        with open(fact_path, 'w', encoding='utf-8') as f:
            for fact in facts:
                if '$x' not in fact:
                    f.write(fact + '\n')

    def create_rule_file(self, rules):
        """Translate ``rules`` into Pyke rules and write them to ``rules.krb``.

        Rules are numbered from 1 (``fact1``, ``fact2``, ...) and separated by
        a blank line.
        """
        pyke_rules = [
            self.parse_forward_rule(index, rule)
            for index, rule in enumerate(rules, start=1)
        ]
        rule_path = os.path.join(self.cache_dir, 'rules.krb')
        with open(rule_path, 'w', encoding='utf-8') as f:
            f.write('\n\n'.join(pyke_rules))

    def parse_forward_rule(self, f_index, rule):
        """Translate one rule into a Pyke forward-chaining rule ``fact{f_index}``.

        Example: ``Furry($x, True) && Quite($x, True) >>> White($x, True)``
        becomes (tab-indented in the generated text)::

            fact1
                foreach
                    facts.Furry($x, True)
                    facts.Quite($x, True)
                assert
                    facts.White($x, True)

        Raises:
            ValueError: if ``rule`` does not contain exactly one ``>>>``.
        """
        premise, conclusion = rule.split('>>>')
        # Both sides may be conjunctions joined by '&&'.
        premise_list = [p.strip() for p in premise.split('&&')]
        conclusion_list = [c.strip() for c in conclusion.split('&&')]
        lines = [f'fact{f_index}', '\tforeach']
        lines += [f'\t\tfacts.{p}' for p in premise_list]
        lines.append('\tassert')
        lines += [f'\t\tfacts.{c}' for c in conclusion_list]
        return '\n'.join(lines)

    def check_specific_predicate(self, subject_name, predicate_name, engine):
        """Collect the truth label derived for ``predicate_name(subject_name)``.

        For example, "Is Marvin from Mars?" is asked as the goal
        ``FromMars(Marvin, $label)`` against the ``facts`` and ``rules``
        knowledge bases.

        Returns:
            The single label if exactly one is found, ``label1 and label2`` if
            two are found, otherwise None.
        """
        results = []
        for kb_name in ('facts', 'rules'):
            goal = f'{kb_name}.{predicate_name}({subject_name}, $label)'
            with engine.prove_goal(goal) as gen:
                for bindings, _plan in gen:
                    results.append(bindings['label'])
        if len(results) == 1:
            return results[0]
        if len(results) == 2:
            return results[0] and results[1]
        # NOTE(review): zero labels and more than two labels are both reported
        # as None (treated as "unknown" by answer_map_proofwriter).
        return None

    def parse_query(self, query):
        """Split a query such as ``Metallic(Wren, False)`` into its parts.

        Returns:
            ``(predicate, subject, value)``. ``value`` is True only when the
            second argument is the literal ``True``.

        Raises:
            ValueError: if ``query`` is not of the form ``Name(arg1, arg2)``.
        """
        match = re.match(r'(\w+)\(([^,]+),\s*([^)]+)\)', query)
        if not match:
            raise ValueError(f'Invalid query: {query}')
        function_name, arg1, arg2 = match.groups()
        # Strip so that e.g. "Metallic(Wren, True )" still yields True.
        return function_name, arg1.strip(), arg2.strip() == 'True'

    def execute_program(self):
        """Run the query against the fact and rule files with Pyke.

        Returns:
            ``(answer, '')`` on success, where ``answer`` comes from the
            dataset's answer map; ``(None, err)`` with the raised exception if
            any step fails (including an unknown dataset name).
        """
        # Delete previously compiled rule bases, presumably so Pyke does not
        # reuse output compiled for an earlier program.
        # NOTE(review): this path is relative to the current working directory
        # and may not be where Pyke compiles files from self.cache_dir — confirm.
        compiled_krb_dir = './models/compiled_krb'
        if os.path.exists(compiled_krb_dir):
            print('removing compiled_krb')
            shutil.rmtree(compiled_krb_dir)
        try:
            engine = knowledge_engine.engine(self.cache_dir)
            engine.reset()
            engine.activate('rules')
            engine.get_kb('facts')
            # parse the logic query into pyke query
            predicate, subject, value_to_check = self.parse_query(self.Query[0])
            result = self.check_specific_predicate(subject, predicate, engine)
            answer = self.answer_map[self.dataset_name](result, value_to_check)
        except Exception as err:
            return None, err
        return answer, ''

    def answer_mapping(self, answer):
        """Post-process a mapped answer; currently the identity."""
        return answer

    def answer_map_prontoqa(self, result, value_to_check):
        """Return 'A' if the derived label matches the query value, else 'B'."""
        return 'A' if result == value_to_check else 'B'

    def answer_map_proofwriter(self, result, value_to_check):
        """Return 'C' when nothing was derived, else 'A' on a match, 'B' otherwise."""
        if result is None:
            return 'C'
        return 'A' if result == value_to_check else 'B'
class LogicInferenceEngine:
    """Execute logic programs with PykeProgram, guessing randomly on failure.

    The dataset name selects the answer mapping used by PykeProgram and the
    label set that ``random_backup`` draws from.
    """

    # Candidate answers for the random fallback, keyed by dataset name.
    _BACKUP_CHOICES = {
        'ProntoQA': ['A', 'B'],
        'ProofWriter': ['A', 'B', 'C'],
    }

    def __init__(self, dataset_name=None, workspace_mount_path='/workspace'):
        """Create an engine.

        Args:
            dataset_name: 'ProntoQA' or 'ProofWriter'. When None (the
                default), it is read from the DATASET_NAME environment
                variable, falling back to 'ProofWriter'.
            workspace_mount_path: directory under which PykeProgram writes
                its ``.cache_program`` files.
        """
        if dataset_name is None:
            dataset_name = os.environ.get('DATASET_NAME', 'ProofWriter')
        self.dataset_name = dataset_name
        self.workspace_mount_path = workspace_mount_path

    def random_backup(self):
        """Return a random answer letter for the dataset, or None if unknown."""
        choices = self._BACKUP_CHOICES.get(self.dataset_name)
        if choices is None:
            return None
        return random.choice(choices)

    def safe_execute_program(self, logic_program):
        """Run ``logic_program`` and never raise.

        Returns:
            ``(answer, status, error_message)`` where ``status`` is one of
            'parsing error', 'execution error' or 'success'. On either error
            the answer is a random guess from ``random_backup``. For an
            execution error, ``error_message`` is whatever
            ``PykeProgram.execute_program`` returned (the raised exception);
            otherwise it is ''.
        """
        program = PykeProgram(
            logic_program, self.dataset_name, self.workspace_mount_path
        )
        # cannot parse the program
        if not program.flag:
            return self.random_backup(), 'parsing error', ''
        # execute the program
        answer, error_message = program.execute_program()
        # not executable
        if answer is None:
            return self.random_backup(), 'execution error', error_message
        # successfully executed
        return program.answer_mapping(answer), 'success', ''
|