HugoVoxx's picture
Upload 96 files
be3b34d verified
raw
history blame
9.8 kB
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements DAG-level traceback."""
from typing import Any, Optional

import geometry as gm
import pretty as pt
import problem
# Module-level alias for `pt.pretty`.
pretty = pt.pretty
def point_levels(
    setup: list[problem.Dependency], existing_points: list[gm.Point]
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
  """Bucket construction dependencies and their points by point level.

  Each dependency is filed under the highest `plevel` found among its point
  arguments. Each point argument that is not in `existing_points` is recorded
  in the bucket of that point's own `plevel`.

  Args:
    setup: Dependencies to organise.
    existing_points: Points to leave out of the per-level point sets. When
      falsy (empty or None), no point is left out.

  Returns:
    A list of `(points, dependencies)` pairs ordered by increasing level;
    levels holding neither points nor dependencies are dropped.

  Raises:
    ValueError: If a dependency has no `gm.Point` among its arguments.
  """
  buckets = []
  for dep in setup:
    dep_points = [arg for arg in dep.args if isinstance(arg, gm.Point)]
    top = max(point.plevel for point in dep_points)

    # Make sure a bucket exists for every level up to and including `top`.
    while len(buckets) <= top:
      buckets.append((set(), []))

    for point in dep_points:
      if not (existing_points and point in existing_points):
        buckets[point.plevel][0].add(point)

    buckets[top][1].append(dep)

  return [(points, deps) for points, deps in buckets if points or deps]
def point_log(
    setup: list[problem.Dependency],
    ref_id: dict[tuple[str, ...], int],
    existing_points: Optional[list[gm.Point]] = None,
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
  """Reformat setup into groups of point constructions.

  Args:
    setup: Dependencies to group; grouping is delegated to `point_levels`.
    ref_id: Mapping from a dependency's `hashed()` key to a reference number.
      Updated in place: every grouped dependency whose key is missing gets
      the next number, i.e. the size of the mapping at that moment.
    existing_points: Points to leave out of the per-group point sets. None
      (the default) or an empty list leaves nothing out.

  Returns:
    The `(points, dependencies)` groups produced by `point_levels`, in order.
  """
  log = []
  for points, cons in point_levels(setup, existing_points):
    for con in cons:
      key = con.hashed()
      if key not in ref_id:
        ref_id[key] = len(ref_id)
    log.append((points, cons))
  return log
def setup_to_levels(
    setup: list[problem.Dependency],
) -> list[list[problem.Dependency]]:
  """Group dependencies by the highest point level they mention.

  Args:
    setup: Dependencies to group. Each one is filed under the maximum
      `plevel` among its `gm.Point` arguments.

  Returns:
    The non-empty groups, ordered by increasing level. Within a group the
    dependencies keep their order from `setup`.

  Raises:
    ValueError: If a dependency has no `gm.Point` among its arguments.
  """
  by_level = []
  for dep in setup:
    top = max(arg.plevel for arg in dep.args if isinstance(arg, gm.Point))
    # Grow the list until index `top` is valid.
    while len(by_level) <= top:
      by_level.append([])
    by_level[top].append(dep)

  return [group for group in by_level if group]
def separate_dependency_difference(
    query: problem.Dependency,
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    list[problem.Dependency],
    list[problem.Dependency],
    set[gm.Point],
    set[gm.Point],
]:
  """Split a traceback log into proof steps, setup and auxiliary setup.

  Args:
    query: The conclusion. Its arguments seed the set of points it relies on.
    log: `(premises, conclusions)` entries to split.

  Returns:
    A 5-tuple `(log, setup, aux_setup, points, aux_points)`:
      log: Entries that have premises and at least one conclusion whose
        `rule_name` is not 'c0', with premises named 'ind' removed.
      setup: Setup dependencies whose point arguments are all in `points`.
      aux_setup: Setup dependencies mentioning a point outside `points`.
      points: The arguments of `query` plus every point reachable from its
        point arguments by following `rely_on`.
      aux_points: The points outside `points` mentioned by `aux_setup`.
  """
  # Pass 1: peel the setup dependencies off the log.
  candidates = []
  proof_steps = []
  for prems, cons in log:
    if not prems:
      # A premise-free entry contributes all its conclusions to the setup.
      candidates.extend(cons)
      continue

    derived = []
    for con in cons:
      if con.rule_name == 'c0':
        candidates.append(con)
      else:
        derived.append(con)

    if derived:
      kept_prems = [p for p in prems if p.name != 'ind']
      proof_steps.append((kept_prems, derived))

  # Pass 2: collect every point the query transitively relies on.
  points = set(query.args)
  frontier = list(query.args)
  while frontier:
    upcoming = []
    for item in frontier:
      if not isinstance(item, gm.Point):
        continue
      for parent in item.rely_on:
        if parent not in points:
          points.add(parent)
          upcoming.append(parent)
    frontier = upcoming

  # Pass 3: a setup dependency that mentions a point outside `points` is
  # auxiliary; dependencies named 'ind' are dropped altogether.
  setup = []
  aux_setup = []
  aux_points = set()
  for con in candidates:
    if con.name == 'ind':
      continue
    missing = [
        p for p in con.args if isinstance(p, gm.Point) and p not in points
    ]
    if missing:
      aux_setup.append(con)
      aux_points.update(missing)
    else:
      setup.append(con)

  return proof_steps, setup, aux_setup, points, aux_points
def recursive_traceback(
    query: problem.Dependency,
) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]:
  """Recursively traceback from the query, i.e. the conclusion.

  Args:
    query: The dependency to trace back from.

  Returns:
    A list of `(premises, [conclusion])` entries, one conclusion per entry.
    Conclusions that share the same premise hashes are emitted next to each
    other, at the position where the first of them was recorded.
  """
  visited = set()
  log = []
  # Dependencies whose hashed key starts with one of these are not traced.
  untraced = ('ncoll', 'npara', 'nperp', 'diff', 'sameside')

  def read(q: problem.Dependency) -> None:
    q = q.remove_loop()
    hashed = q.hashed()
    if hashed in visited:
      return
    if hashed[0] in untraced:
      return

    prems = []
    if q.rule_name != problem.CONSTRUCTION_RULE:
      # Drop repeated reasons, keeping the first occurrence of each.
      all_deps = []
      dep_names = set()
      for d in q.why:
        if d.hashed() in dep_names:
          continue
        dep_names.add(d.hashed())
        all_deps.append(d)

      for d in all_deps:
        h = d.hashed()
        if h not in visited:
          read(d)
        # Only reasons that ended up recorded count as premises; `read` may
        # return early without marking `h` as visited.
        if h in visited:
          prems.append(d)

    visited.add(hashed)

    # Attach `q` to an existing entry with the same premises, if there is one.
    hashs = sorted([d.hashed() for d in prems])
    for ps, qs in log:
      if sorted([d.hashed() for d in ps]) == hashs:
        qs.append(q)
        break
    else:
      log.append((prems, [q]))

  read(query)

  # Post-process: separate multi-conclusion entries into one per conclusion.
  return [(ps, [q]) for ps, qs in log for q in qs]
def collx_to_coll_setup(
    setup: list[problem.Dependency],
) -> list[problem.Dependency]:
  """Rewrite 'collx' setup dependencies as 'coll' and drop duplicates.

  Dependencies named 'collx' are modified in place: the name becomes 'coll'
  and the arguments are replaced by a list of their distinct values.
  Duplicates (same `hashed()` key) are removed within each level returned by
  `setup_to_levels`; the same key may still appear in different levels.

  Args:
    setup: Setup dependencies to convert.

  Returns:
    The surviving dependencies, flattened in level order.
  """
  converted = []
  for level in setup_to_levels(setup):
    seen = set()
    for dep in level:
      if dep.name == 'collx':
        dep.name = 'coll'
        dep.args = list(set(dep.args))

      key = dep.hashed()
      if key not in seen:
        seen.add(key)
        converted.append(dep)

  return converted
def collx_to_coll(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
  """Rewrite 'collx' as 'coll' across setup, aux setup and log, deduplicating.

  Dependencies named 'collx' are modified in place. Premises are
  deduplicated within each log entry. Conclusions are deduplicated across
  the whole log and against everything in the converted setup and aux setup.
  Log entries left without premises or without conclusions are dropped.

  Args:
    setup: Setup dependencies.
    aux_setup: Auxiliary setup dependencies.
    log: `(premises, conclusions)` proof entries.

  Returns:
    The converted `(setup, aux_setup, log)`.
  """

  def coll_key(dep: problem.Dependency) -> tuple[str, ...]:
    """Turns a 'collx' dependency into 'coll' in place; returns its key."""
    if dep.name == 'collx':
      dep.name = 'coll'
      dep.args = list(set(dep.args))
    return dep.hashed()

  setup = collx_to_coll_setup(setup)
  aux_setup = collx_to_coll_setup(aux_setup)
  known_cons = {dep.hashed() for dep in setup + aux_setup}

  new_log = []
  for prems, cons in log:
    seen_prems = set()
    kept_prems = []
    for prem in prems:
      key = coll_key(prem)
      if key not in seen_prems:
        seen_prems.add(key)
        kept_prems.append(prem)

    # Conclusions are registered even when the entry is dropped below.
    kept_cons = []
    for con in cons:
      key = coll_key(con)
      if key not in known_cons:
        known_cons.add(key)
        kept_cons.append(con)

    if kept_prems and kept_cons:
      new_log.append((kept_prems, kept_cons))

  return setup, aux_setup, new_log
def get_logs(
    query: problem.Dependency, g: Any, merge_trivials: bool = False
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    set[gm.Point],
]:
  """Trace a conclusion back into setup, auxiliary setup and proof steps.

  Runs, in order: `recursive_traceback`, `separate_dependency_difference`,
  `collx_to_coll` and `shorten_and_shave`.

  Args:
    query: The conclusion to explain.
    g: Passed to `query.why_me_or_cache` together with `query.level`
      (presumably the proof graph; confirm against the caller).
    merge_trivials: Forwarded to `shorten_and_shave`.

  Returns:
    A 4-tuple `(setup, aux_setup, log, setup_points)`, where `setup_points`
    is the point set computed by `separate_dependency_difference`.
  """
  query = query.why_me_or_cache(g, query.level)

  traced = recursive_traceback(query)
  log, setup, aux_setup, setup_points, _aux_points = (
      separate_dependency_difference(query, traced)
  )

  setup, aux_setup, log = collx_to_coll(setup, aux_setup, log)
  setup, aux_setup, log = shorten_and_shave(
      setup, aux_setup, log, merge_trivials
  )

  return setup, aux_setup, log, setup_points
def shorten_and_shave(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
  """Shorten the proof by removing unused predicates.

  Args:
    setup: Setup dependencies.
    aux_setup: Auxiliary setup dependencies.
    log: `(premises, conclusions)` proof entries.
    merge_trivials: Forwarded to `shorten_proof`.

  Returns:
    A 3-tuple `(setup, aux_setup, log)`: the log as shortened by
    `shorten_proof`, and the two setup lists restricted to dependencies
    whose `hashed()` key occurs among the premises of that shortened log.
  """
  log, _ = shorten_proof(log, merge_trivials=merge_trivials)

  # Collect the premise keys in a single pass; there is no need to build a
  # flattened list of premises first.
  used = {prem.hashed() for prems, _cons in log for prem in prems}

  setup = [dep for dep in setup if dep.hashed() in used]
  aux_setup = [dep for dep in aux_setup if dep.hashed() in used]
  return setup, aux_setup, log
def join_prems(
    con: problem.Dependency,
    con2prems: dict[tuple[str, ...], list[problem.Dependency]],
    expanded: set[tuple[str, ...]],
) -> list[problem.Dependency]:
  """Expand a dependency into the premises it is ultimately joined from.

  Args:
    con: The dependency to expand.
    con2prems: Maps a dependency's `hashed()` key to the premises that
      replace it.
    expanded: Keys that must not be expanded even if present in `con2prems`.

  Returns:
    `[con]` when its key is in `expanded` or absent from `con2prems`;
    otherwise the concatenation, in order, of the recursive expansion of
    each of its premises.
  """
  key = con.hashed()
  if key not in expanded and key in con2prems:
    joined = []
    for prem in con2prems[key]:
      joined.extend(join_prems(prem, con2prems, expanded))
    return joined
  return [con]
def shorten_proof(
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    dict[tuple[str, ...], list[problem.Dependency]],
]:
  """Fold trivial proof steps into the steps that consume them.

  A step is trivial when its conclusion has an empty `rule_name`. Trivial
  steps are removed from the log and, wherever their conclusion is used as
  a premise, replaced by their own premises (see `join_prems`). The last
  step of the log is always kept.

  Args:
    log: Proof entries; every entry must have exactly one conclusion.
    merge_trivials: When False, a trivial step that is a premise of a
      non-trivial step is kept as a step of its own instead of being folded.

  Returns:
    A pair `(log, con2prem)`: the shortened log with duplicate premises
    removed per step, and the mapping from each folded conclusion's
    `hashed()` key to its premises.
  """
  protected = set()
  con2prem = {}
  for prems, cons in log:
    assert len(cons) == 1
    (con,) = cons
    if con.rule_name == '':  # pylint: disable=g-explicit-bool-comparison
      con2prem[con.hashed()] = prems
    elif not merge_trivials:
      # Premises of non-trivial steps must stay visible as separate steps.
      for prem in prems:
        protected.add(prem.hashed())

  for key in protected:
    con2prem.pop(key, None)

  expanded = set()
  shortened = []
  last = len(log) - 1
  for i, (prems, cons) in enumerate(log):
    con = cons[0]
    if i < last and con.hashed() in con2prem:
      continue

    # Expand every premise, then keep the first occurrence of each key.
    candidates = []
    for prem in prems:
      candidates.extend(join_prems(prem, con2prem, expanded))

    seen = set()
    unique_prems = []
    for cand in candidates:
      if cand.hashed() not in seen:
        unique_prems.append(cand)
        seen.add(cand.hashed())

    shortened.append((unique_prems, [con]))
    expanded.add(con.hashed())

  return shortened, con2prem