# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Implements the combination DD+AR."""
import time

from absl import logging
import dd
import graph as gh
import problem as pr
from problem import Dependency  # pylint: disable=g-importing-member
import trace_back


def saturate_or_goal(
    g: gh.Graph,
    theorems: list[pr.Theorem],
    level_times: list[float],
    p: pr.Problem,
    max_level: int = 100,
    timeout: int = 600,
) -> tuple[
    list[dict[str, list[tuple[gh.Point, ...]]]],
    list[dict[str, list[tuple[gh.Point, ...]]]],
    list[int],
    list[pr.Dependency],
]:
    """Run DD until saturation or goal found.

    Repeatedly calls `dd.bfs_one_level` on `g`, one level per iteration, and
    stops as soon as one of the following holds:

    * `len(level_times)` has reached `max_level`;
    * `p.goal` is set and `g.check` reports it as holding;
    * the last level added nothing (saturated);
    * the last level took longer than `timeout` seconds.

    Args:
      g: The graph that DD operates on; it is passed to `dd.bfs_one_level`.
      theorems: The theorems handed to `dd.bfs_one_level`.
      level_times: Wall-clock duration of every level run so far. This list is
        MUTATED IN PLACE: the duration of each level run here is appended, and
        its length determines the next level number.
      p: The problem; its `goal` (if not None) is checked after every level.
      max_level: Upper bound on `len(level_times)`.
      timeout: Forwarded to `dd.bfs_one_level`, and also used as the per-level
        time limit after which this function stops.

    Returns:
      A tuple `(derives, eq4s, branching, all_added)`:
        derives: the second value returned by `dd.bfs_one_level`, one entry per
          level run in this call.
        eq4s: the third value returned by `dd.bfs_one_level`, one per level.
        branching: the fourth value returned by `dd.bfs_one_level`, one per
          level.
        all_added: concatenation of the first value returned by
          `dd.bfs_one_level` over all levels run in this call.
    """
    derives = []
    eq4s = []
    branching = []
    all_added = []

    while len(level_times) < max_level:
        level = len(level_times) + 1

        t = time.time()
        added, derv, eq4, n_branching = dd.bfs_one_level(
            g, theorems, level, p, verbose=False, nm_check=True, timeout=timeout
        )
        all_added += added
        branching.append(n_branching)
        derives.append(derv)
        eq4s.append(eq4)
        level_time = time.time() - t

        # Lazy %-style arguments: same text as the former f-string, but the
        # message is only formatted if the record is actually emitted.
        logging.info('Depth %s/%s time = %s', level, max_level, level_time)
        level_times.append(level_time)

        if p.goal is not None:
            # Each goal argument is looked up via `g.get`, with a zero-argument
            # fallback that converts the argument to `int`. The name is bound
            # as a default so every fallback keeps its own argument.
            goal_args = [
                g.get(arg, lambda arg=arg: int(arg)) for arg in p.goal.args
            ]
            if g.check(p.goal.name, goal_args):  # found goal
                break

        if not added:  # saturated
            break

        if level_time > timeout:
            break

    return derives, eq4s, branching, all_added


def solve(
    g: gh.Graph,
    theorems: list[pr.Theorem],
    controller: pr.Problem,
    max_level: int = 1000,
    timeout: int = 600,
) -> tuple[gh.Graph, list[float], str, list[int], list[pr.Dependency]]:
    """Alternate between DD and AR until goal is found.

    Each round first runs `saturate_or_goal` (DD). If the goal of `controller`
    does not hold afterwards, pending derivations are applied to the graph with
    `dd.apply_derivations`: first the queued `derives` entries, and only if
    none of those adds anything, the queued `eq4s` entries. The loop ends when
    the goal holds, when there is nothing left to apply, when applying adds
    nothing, or when `max_level` levels have been run.

    Args:
      g: The graph to work on; it is modified by DD and by the applied
        derivations, and returned.
      theorems: The theorems forwarded to `saturate_or_goal`.
      controller: The problem whose `goal` (if not None) is checked after each
        DD phase.
      max_level: Upper bound on the total number of DD levels.
      timeout: Forwarded to `saturate_or_goal`.

    Returns:
      A tuple `(g, level_times, status, branches, all_added)`:
        g: the same graph object that was passed in.
        level_times: duration of every DD level that was run.
        status: 'solved' if the goal check succeeded, otherwise 'saturated'.
        branches: the per-level branching values from `saturate_or_goal`.
        all_added: everything reported as added, both by DD and by every
          `dd.apply_derivations` call made here.
    """
    status = 'saturated'
    level_times = []

    dervs, eq4 = g.derive_algebra(level=0, verbose=False)
    derives = [dervs]
    eq4s = [eq4]
    branches = []
    all_added = []

    while len(level_times) < max_level:
        dervs, eq4, next_branches, added = saturate_or_goal(
            g, theorems, level_times, controller, max_level, timeout=timeout
        )
        all_added += added

        derives += dervs
        eq4s += eq4
        branches += next_branches

        # Now, it is either goal or saturated
        if controller.goal is not None:
            goal_args = g.names2points(controller.goal.args)
            if g.check(controller.goal.name, goal_args):  # found goal
                status = 'solved'
                break

        if not derives:  # officially saturated.
            logging.info("derives empty, breaking")
            break

        # Now we resort to algebra derivations.
        added = []
        while derives and not added:
            added += dd.apply_derivations(g, derives.pop(0))

        # Final help from AR, only needed when the pass above added nothing.
        if not added:
            while eq4s and not added:
                added += dd.apply_derivations(g, eq4s.pop(0))

        # Record additions from BOTH passes. Previously a `continue` after the
        # `derives` pass skipped this line, so those additions were silently
        # missing from the returned list.
        all_added += added

        if not added:  # Nothing left. saturated.
            logging.info("Nothing added, breaking")
            break

    return g, level_times, status, branches, all_added


def get_proof_steps(
    g: gh.Graph, goal: pr.Clause, merge_trivials: bool = False
) -> tuple[
    list[pr.Dependency],
    list[pr.Dependency],
    list[tuple[list[pr.Dependency], list[pr.Dependency]]],
    dict[tuple[str, ...], int],
]:
    """Extract proof steps from the built DAG.

    Builds a `Dependency` query from the goal's name and its arguments
    (resolved through `g.names2nodes`), traces it back with
    `trace_back.get_logs`, and post-processes the setup and auxiliary parts
    with `trace_back.point_log`.

    Args:
      g: The graph to trace back through.
      goal: The clause to explain; its `name` and `args` form the query.
      merge_trivials: Forwarded unchanged to `trace_back.get_logs`.

    Returns:
      A tuple `(setup, aux, log, refs)`:
        setup: list of `(prems, [tuple(head)])` pairs built from the setup part
          of the trace.
        aux: list of pairs of the same shape, built from the auxiliary part.
        log: the log returned by `trace_back.get_logs`, unmodified.
        refs: the dict filled in by the two `trace_back.point_log` calls.
    """
    goal_args = g.names2nodes(goal.args)
    query = Dependency(goal.name, goal_args, None, None)

    setup, aux, log, setup_points = trace_back.get_logs(
        query, g, merge_trivials=merge_trivials
    )

    # `refs` is shared by both calls, so the second call sees whatever the
    # first one stored in it; the call order therefore matters.
    refs = {}
    setup = trace_back.point_log(setup, refs, set())
    aux = trace_back.point_log(aux, refs, setup_points)

    # Swap each (head, prems) pair into (prems, [tuple(head)]).
    setup = [(prems, [tuple(head)]) for head, prems in setup]
    aux = [(prems, [tuple(head)]) for head, prems in aux]

    return setup, aux, log, refs