HugoVoxx's picture
Upload 96 files
be3b34d verified
raw
history blame
4.41 kB
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the combination DD+AR."""
import time
from absl import logging
import dd
import graph as gh
import problem as pr
from problem import Dependency # pylint: disable=g-importing-member
import trace_back
def saturate_or_goal(
    g: gh.Graph,
    theorems: list[pr.Theorem],
    level_times: list[float],
    p: pr.Problem,
    max_level: int = 100,
    timeout: int = 600,
) -> tuple[
    list[dict[str, list[tuple[gh.Point, ...]]]],
    list[dict[str, list[tuple[gh.Point, ...]]]],
    list[int],
    list[pr.Dependency],
]:
    """Run DD one BFS level at a time until the goal, saturation or timeout.

    Each iteration calls `dd.bfs_one_level` at depth `len(level_times) + 1`
    and records what it returns. After each level the loop stops, checked in
    this order, when:

      1. `p.goal` is set and `g.check` reports that it holds;
      2. the level added nothing (saturation);
      3. the level took longer than `timeout` seconds.

    The loop also ends once `level_times` holds `max_level` entries.

    Args:
      g: The proof graph handed to `dd.bfs_one_level` and queried for the goal.
      theorems: Rules forwarded to `dd.bfs_one_level`.
      level_times: Wall-clock seconds per level. Appended to in place; its
        length is the current depth, so it carries state across calls.
      p: The problem; when `p.goal` is None no goal check is performed.
      max_level: Upper bound on `len(level_times)`.
      timeout: Forwarded to `dd.bfs_one_level` and also used as the per-level
        time limit (in seconds) for stopping.

    Returns:
      `(derives, eq4s, branching, all_added)`: the first three hold one entry
      per level run during this call; `all_added` concatenates the "added"
      results of those levels.
    """
    derives, eq4s, branching, all_added = [], [], [], []

    def resolve(name):
        # Names the graph does not know fall back to an int literal (numeric
        # goal arguments). NOTE(review): assumes g.get only invokes the
        # fallback for unknown names — confirm against graph.py.
        return g.get(name, lambda: int(name))

    while len(level_times) < max_level:
        level = len(level_times) + 1
        started = time.time()
        new_items, level_derives, level_eq4, n_branching = dd.bfs_one_level(
            g, theorems, level, p, verbose=False, nm_check=True, timeout=timeout
        )
        all_added.extend(new_items)
        branching.append(n_branching)
        derives.append(level_derives)
        eq4s.append(level_eq4)
        level_time = time.time() - started
        logging.info(f'Depth {level}/{max_level} time = {level_time}')  # pylint: disable=logging-fstring-interpolation
        level_times.append(level_time)

        goal = p.goal
        if goal is not None and g.check(
            goal.name, [resolve(arg) for arg in goal.args]
        ):
            break  # found goal
        if not new_items or level_time > timeout:
            break  # saturated, or this level exceeded the time budget

    return derives, eq4s, branching, all_added
def _resolve_goal_args(g: gh.Graph, names: list[str]) -> list:
    """Map goal argument names to graph objects, allowing integer literals.

    Each name is looked up with `g.get`, using `int(name)` as the fallback, so
    numeric goal arguments are accepted. A name that is neither known to the
    graph nor an integer literal makes `int()` raise ValueError.
    NOTE(review): assumes `g.get` only invokes the fallback for names it does
    not know — confirm against graph.py.
    """

    def resolve(name):
        # A nested function gives each lookup its own `name` binding.
        return g.get(name, lambda: int(name))

    return [resolve(name) for name in names]


def solve(
    g: gh.Graph,
    theorems: list[pr.Theorem],
    controller: pr.Problem,
    max_level: int = 1000,
    timeout: int = 600,
) -> tuple[gh.Graph, list[float], str, list[int], list[pr.Dependency]]:
    """Alternate between DD and AR until the goal is found or nothing is new.

    Each round runs `saturate_or_goal` (DD levels). If the goal then holds,
    the search stops with status 'solved'. Otherwise pending algebraic
    derivations are applied one batch at a time -- the `derives` queue
    first, then the `eq4s` queue -- until a batch adds something, and DD is
    run again. The search stops when no batch adds anything or when
    `level_times` reaches `max_level`.

    Args:
      g: The proof graph. It is handed to DD and to `dd.apply_derivations`
        (presumably mutated by them) and returned.
      theorems: Rules forwarded to `saturate_or_goal`.
      controller: The problem whose `goal` (may be None) is checked.
      max_level: Cap on the total number of DD levels across all rounds.
      timeout: Forwarded to `saturate_or_goal`.

    Returns:
      `(g, level_times, status, branches, all_added)`, where:
        - `status` is 'solved' if the goal check succeeded and 'saturated'
          otherwise (which also covers stopping because `max_level` was hit);
        - `level_times` and `branches` hold one entry per DD level;
        - `all_added` collects everything reported as added by DD and by the
          applied derivations (from both `derives` and `eq4s`).
    """
    status = 'saturated'
    level_times = []

    # Level-0 algebraic derivations seed both queues.
    dervs, eq4 = g.derive_algebra(level=0, verbose=False)
    derives = [dervs]
    eq4s = [eq4]
    branches = []
    all_added = []

    while len(level_times) < max_level:
        dervs, eq4, next_branches, added = saturate_or_goal(
            g, theorems, level_times, controller, max_level, timeout=timeout
        )
        all_added += added
        derives += dervs
        eq4s += eq4
        branches += next_branches

        # Now, it is either goal or saturated.
        if controller.goal is not None:
            # Resolve arguments like saturate_or_goal does, so numeric goal
            # arguments are accepted here too.
            goal_args = _resolve_goal_args(g, controller.goal.args)
            if g.check(controller.goal.name, goal_args):  # found goal
                status = 'solved'
                break

        # Defensive: saturate_or_goal runs at least one level whenever this
        # loop body is entered, so `derives` is not expected to be empty here.
        if not derives:  # officially saturated.
            logging.info("derives empty, breaking")
            break

        # Now we resort to algebra derivations: consume batches from `derives`
        # until one of them adds something.
        added = []
        while derives and not added:
            added += dd.apply_derivations(g, derives.pop(0))

        # Final help from AR, consulted only when the derivations added nothing.
        while eq4s and not added:
            added += dd.apply_derivations(g, eq4s.pop(0))

        # Record additions from either source in all_added.
        all_added += added

        if not added:  # Nothing left. saturated.
            logging.info("Nothing added, breaking")
            break

    return g, level_times, status, branches, all_added
def get_proof_steps(
    g: gh.Graph, goal: pr.Clause, merge_trivials: bool = False
) -> tuple[
    list[pr.Dependency],
    list[pr.Dependency],
    list[tuple[list[pr.Dependency], list[pr.Dependency]]],
    dict[tuple[str, ...], int],
]:
    """Extract the proof steps for `goal` from the graph's dependency DAG.

    The goal is wrapped in a `Dependency` query and traced back with
    `trace_back.get_logs`. The setup and auxiliary logs are then passed
    through `trace_back.point_log`:
      - setup first, with an empty point set;
      - aux second, with the setup points.
    Both calls share one `refs` dict. Each resulting `(p, prems)` entry is
    regrouped as `(prems, [tuple(p)])`.

    Args:
      g: The graph whose DAG is traced; goal argument names are resolved with
        `g.names2nodes`.
      goal: The clause to explain; its `name` and `args` are used.
      merge_trivials: Forwarded to `trace_back.get_logs`.

    Returns:
      `(setup, aux, log, refs)`:
        - `setup`, `aux`: the regrouped setup and auxiliary steps;
        - `log`: the proof log returned by `trace_back.get_logs`;
        - `refs`: the dict as left by the two `point_log` calls.
      NOTE(review): the first two return annotations say
      `list[pr.Dependency]`, but the elements built here are
      `(prems, [tuple(p)])` pairs.
    """
    goal_nodes = g.names2nodes(goal.args)
    query = Dependency(goal.name, goal_nodes, None, None)
    setup_log, aux_log, proof_log, setup_points = trace_back.get_logs(
        query, g, merge_trivials=merge_trivials
    )

    # Keep this order: both calls share `refs` and may both write to it.
    refs = {}
    setup_log = trace_back.point_log(setup_log, refs, set())
    aux_log = trace_back.point_log(aux_log, refs, setup_points)

    def regroup(entries):
        """Turn `(points, premises)` entries into `(premises, [tuple(points)])`."""
        return [(premises, [tuple(points)]) for points, premises in entries]

    return regroup(setup_log), regroup(aux_log), proof_log, refs