# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Implements geometric objects used in the graph representation."""

from __future__ import annotations

from collections import defaultdict  # pylint: disable=g-importing-member
from typing import Any, Iterator, Type

# pylint: disable=protected-access


def _dep_usable(dep: Any, level: int | None) -> bool:
    """Return True if a merge-graph edge may be traversed at `level`.

    An edge is skipped only when a level is given, the edge's dependency
    carries a non-None `.level`, and that level is >= the requested one.
    """
    if level is None:
        return True
    return dep.level is None or dep.level < level


class Node:
    r"""Node in the proof state graph.

    Can be Point, Line, Circle, etc.

    Each node maintains a merge history to
    other nodes if they are (found out to be) equivalent

    a -> b -
            \
        c -> d -> e -> f -> g

    d.merged_to = e
    d.rep = g
    d.merged_from = {a, b, c, d}
    d.equivs = {a, b, c, d, e, f, g}
    """

    def __init__(self, name: str = '', graph: Any = None):
        self.name = name or str(self)
        self.graph = graph

        self.edge_graph = {}
        # Edge graph: what other nodes is connected to this node.
        # edge graph = {
        #   other1: {self1: deps, self2: deps},
        #   other2: {self2: deps, self3: deps}
        # }

        self.merge_graph = {}
        # Merge graph: history of merges with other nodes.
        # merge_graph = {self1: {self2: deps1, self3: deps2}}

        self.rep_by = None  # represented by.
        self.members = {self}

        self._val = None
        self._obj = None

        self.deps = []

        # numerical representation.
        self.num = None
        self.change = set()  # what other nodes' num rely on this node?

    def set_rep(self, node: Node) -> None:
        """Make `node` the representative of self, folding self's data in."""
        if node == self:
            return
        self.rep_by = node
        node.merge_edge_graph(self.edge_graph)
        node.members.update(self.members)

    def rep(self) -> Node:
        """Follow the `rep_by` chain to the final representative."""
        x = self
        while x.rep_by:
            x = x.rep_by
        return x

    def why_rep(self) -> list[Any] | None:
        """Dependencies explaining why self equals its representative."""
        return self.why_equal([self.rep()], None)

    def rep_and_why(self) -> tuple[Node, list[Any] | None]:
        """Return the representative together with the reason for it."""
        rep = self.rep()
        return rep, self.why_equal([rep], None)

    def neighbors(
        self, oftype: Type[Node], return_set: bool = False, do_rep: bool = True
    ) -> list[Node] | set[Node]:
        """Neighbors of this node in the proof state graph.

        Args:
          oftype: only neighbors that are instances of this type are kept;
            None keeps every neighbor.
          return_set: return a set instead of a list.
          do_rep: look up the edges of self's representative and map every
            neighbor to its own representative.

        Returns:
          The matching neighbors, as a list or (if `return_set`) a set.
        """
        source = self.rep() if do_rep else self
        result = set()

        for n in source.edge_graph:
            if oftype is None or (oftype and isinstance(n, oftype)):
                result.add(n.rep() if do_rep else n)

        if return_set:
            return result
        return list(result)

    def merge_edge_graph(
        self, new_edge_graph: dict[Node, dict[Node, list[Node]]]
    ) -> None:
        """Merge another edge graph into self's, copying the inner dicts."""
        for x, xdict in new_edge_graph.items():
            if x in self.edge_graph:
                self.edge_graph[x].update(dict(xdict))
            else:
                self.edge_graph[x] = dict(xdict)

    def merge(self, nodes: list[Node], deps: list[Any]) -> None:
        """Merge every node in `nodes` into self with the same `deps`."""
        for node in nodes:
            self.merge_one(node, deps)

    def merge_one(self, node: Node, deps: list[Any]) -> None:
        """Merge `node` into self and record the merge in both merge graphs."""
        node.rep().set_rep(self.rep())

        # Keep the first recorded reason if the pair was merged before.
        if node in self.merge_graph:
            return

        self.merge_graph[node] = deps
        node.merge_graph[self] = deps

    def is_val(self, node: Node) -> bool:
        """Whether `node` is of the value type that matches self's type."""
        return (
            isinstance(self, Line)
            and isinstance(node, Direction)
            or isinstance(self, Segment)
            and isinstance(node, Length)
            or isinstance(self, Angle)
            and isinstance(node, Measure)
            or isinstance(self, Ratio)
            and isinstance(node, Value)
        )

    def set_val(self, node: Node) -> None:
        self._val = node

    def set_obj(self, node: Node) -> None:
        self._obj = node

    @property
    def val(self) -> Node | None:
        """Representative of the stored value node, or None if unset."""
        if self._val is None:
            return None
        return self._val.rep()

    @property
    def obj(self) -> Node | None:
        """Representative of the stored object node, or None if unset."""
        if self._obj is None:
            return None
        return self._obj.rep()

    def equivs(self) -> set[Node]:
        """All members recorded on self's representative."""
        return self.rep().members

    def connect_to(self, node: Node, deps: list[Any] = None) -> None:
        """Record an edge from self (stored on its representative) to `node`.

        If `node` is the value type matching self (see `is_val`), the two are
        additionally linked through `set_val` / `set_obj`.
        """
        rep = self.rep()

        if node in rep.edge_graph:
            rep.edge_graph[node].update({self: deps})
        else:
            rep.edge_graph[node] = {self: deps}

        if self.is_val(node):
            self.set_val(node)
            node.set_obj(self)

    def equivs_upto(self, level: int) -> dict[Node, Node]:
        """What are the equivalent nodes up to a certain level.

        Returns:
          A BFS parent map over the merge graph, rooted at self (whose parent
          is None), following only edges usable at `level`.
        """
        parent = {self: None}
        visited = set()
        queue = [self]
        i = 0

        while i < len(queue):
            current = queue[i]
            i += 1
            visited.add(current)

            for neighbor, dep in current.merge_graph.items():
                if not _dep_usable(dep, level):
                    continue
                if neighbor not in visited:
                    queue.append(neighbor)
                    parent[neighbor] = current

        return parent

    def why_equal(self, others: list[Node], level: int) -> list[Any] | None:
        """BFS why this node is equal to other nodes.

        Returns:
          The merge dependencies on the paths from self to every node in
          `others`, or None if some node is not reachable at `level`.
        """
        others = set(others)
        found = 0

        parent = {}
        queue = [self]
        i = 0

        while i < len(queue):
            current = queue[i]
            if current in others:
                found += 1
            if found == len(others):
                break

            i += 1

            for neighbor, dep in current.merge_graph.items():
                if not _dep_usable(dep, level):
                    continue
                if neighbor not in parent:
                    queue.append(neighbor)
                    parent[neighbor] = current

        return bfs_backtrack(self, others, parent)

    def why_equal_groups(
        self, groups: list[list[Node]], level: int
    ) -> tuple[list[Any] | None, list[Node]]:
        """BFS for why self is equal to at least one member of each group.

        Returns:
          (deps, chosen) where chosen[j] is the first member of groups[j] that
          the search reached (None if none was), and deps is the backtracked
          dependency list (None if some group was not reached).
        """
        others = [None for _ in groups]
        found = 0

        parent = {}
        queue = [self]
        i = 0

        while i < len(queue):
            current = queue[i]

            for j, grp in enumerate(groups):
                if others[j] is None and current in grp:
                    others[j] = current
                    found += 1

            if found == len(others):
                break

            i += 1

            for neighbor, dep in current.merge_graph.items():
                if not _dep_usable(dep, level):
                    continue
                if neighbor not in parent:
                    queue.append(neighbor)
                    parent[neighbor] = current

        return bfs_backtrack(self, others, parent), others

    def why_val(self, level: int) -> list[Any] | None:
        """Why the stored value node equals its representative."""
        return self._val.why_equal([self.val], level)

    def why_connect(self, node: Node, level: int = None) -> list[Any] | None:
        """Why self is connected to `node`.

        Returns:
          The dependency of the first recorded edge to `node` followed by the
          reason self equals the member that owns that edge; None if there is
          no such edge or no merge path at `level`.
        """
        rep = self.rep()
        equivs = list(rep.edge_graph[node].keys())
        if not equivs:
            return None
        equiv = equivs[0]
        dep = rep.edge_graph[node][equiv]
        # why_equal expects a collection of nodes, not a bare node.
        why = self.why_equal([equiv], level)
        if why is None:
            return None
        return [dep] + why


def why_connect(*pairs: list[tuple[Node, Node]]) -> list[Any]:
    """Concatenate `node1.why_connect(node2)` over all given pairs."""
    result = []
    for node1, node2 in pairs:
        result += node1.why_connect(node2)
    return result


def is_equiv(x: Node, y: Node, level: int = None) -> bool:
    """Whether a merge path from x to y exists below `level`."""
    level = level or float('inf')
    return x.why_equal([y], level) is not None


def is_equal(x: Node, y: Node, level: int = None) -> bool:
    """Whether x and y are the same node or have equivalent values."""
    if x == y:
        return True
    if x._val is None or y._val is None:
        return False
    if x.val != y.val:
        return False
    return is_equiv(x._val, y._val, level)


def bfs_backtrack(
    root: Node, leafs: list[Node], parent: dict[Node, Node]
) -> list[Any] | None:
    """Return the path given BFS trace of parent nodes.

    Returns:
      The merge dependencies along the parent chains from each leaf back to
      `root`, or None if a leaf is None or was never reached.
    """
    backtracked = {root}  # no need to backtrack further when touching this set.
    deps = []
    for node in leafs:
        if node is None:
            return None
        if node in backtracked:
            continue
        if node not in parent:
            return None
        while node not in backtracked:
            backtracked.add(node)
            deps.append(node.merge_graph[parent[node]])
            node = parent[node]

    return deps


def _why_all_connected(
    obj: Node, points: list[Point], level: int = None
) -> list[Any] | None:
    """Shortest found reason that every point is connected to `obj`.

    For each point, collects the members of `obj` whose edge to that point is
    usable at `level`, then tries each candidate for the first point and keeps
    the smallest dependency list. None dependencies are dropped from the result.
    """
    level = level or float('inf')

    groups = []
    for p in points:
        group = [
            member
            for member, d in obj.edge_graph[p].items()
            if d is None or d.level < level
        ]
        if not group:
            return None
        groups.append(group)

    min_deps = None
    for start in groups[0]:
        deps, others = start.why_equal_groups(groups[1:], level)
        if deps is None:
            continue
        for p, o in zip(points, [start] + others):
            deps.append(obj.edge_graph[p][o])
        if min_deps is None or len(deps) < len(min_deps):
            min_deps = deps

    if min_deps is None:
        return None
    return [d for d in min_deps if d is not None]


class Point(Node):
    pass


class Line(Node):
    """Node of type Line."""

    def new_val(self) -> Direction:
        return Direction()

    def why_coll(self, points: list[Point], level: int = None) -> list[Any]:
        """Why points are connected to self."""
        return _why_all_connected(self, points, level)


class Segment(Node):

    def new_val(self) -> Length:
        return Length()


class Circle(Node):
    """Node of type Circle."""

    def why_cyclic(self, points: list[Point], level: int = None) -> list[Any]:
        """Why points are connected to self."""
        return _why_all_connected(self, points, level)


def why_equal(x: Node, y: Node, level: int = None) -> list[Any] | None:
    """Why x and y are equal through their value nodes.

    Returns:
      [] if they are the same node or share the same value node, None if
      either has no value, otherwise the reason their value nodes are equal.
    """
    if x == y:
        return []
    if x._val is None or y._val is None:
        return None
    if x._val == y._val:
        return []
    return x._val.why_equal([y._val], level)


class Direction(Node):
    pass


def get_lines_thru_all(*points: list[Point]) -> list[Line]:
    """Lines that are neighbors of every given point."""
    line2count = defaultdict(int)
    points = set(points)
    for p in points:
        for l in p.neighbors(Line):
            line2count[l] += 1
    return [l for l, count in line2count.items() if count == len(points)]


def line_of_and_why(
    points: list[Point], level: int = None
) -> tuple[Line | None, list[Any] | None]:
    """Why points are collinear.

    Returns:
      (line, why) for the first line found with an explanation, else
      (None, None).
    """
    for l0 in get_lines_thru_all(*points):
        for l in l0.equivs():
            if all(p in l.edge_graph for p in points):
                # `points` is an attribute set on the line outside this file.
                x, y = l.points
                colls = list({x, y} | set(points))
                why = l.why_coll(colls, level)
                if why is not None:
                    return l, why

    return None, None


def get_circles_thru_all(*points: list[Point]) -> list[Circle]:
    """Circles that are neighbors of every given point."""
    circle2count = defaultdict(int)
    points = set(points)
    for p in points:
        for c in p.neighbors(Circle):
            circle2count[c] += 1
    return [c for c, count in circle2count.items() if count == len(points)]


def circle_of_and_why(
    points: list[Point], level: int = None
) -> tuple[Circle | None, list[Any] | None]:
    """Why points are concyclic.

    Returns:
      (circle, why) for the first circle found with an explanation, else
      (None, None).
    """
    for c0 in get_circles_thru_all(*points):
        for c in c0.equivs():
            if all(p in c.edge_graph for p in points):
                cycls = list(set(points))
                why = c.why_cyclic(cycls, level)
                if why is not None:
                    return c, why

    return None, None


def name_map(struct: Any) -> Any:
    """Recursively replace every leaf of a nested container by its `name`.

    Lists, tuples, sets and dicts are rebuilt with the same container type;
    any other object maps to its `name` attribute, or '' if it has none.
    """
    if isinstance(struct, list):
        return [name_map(x) for x in struct]
    elif isinstance(struct, tuple):
        return tuple(name_map(x) for x in struct)
    elif isinstance(struct, set):
        return {name_map(x) for x in struct}
    elif isinstance(struct, dict):
        return {name_map(x): name_map(y) for x, y in struct.items()}
    else:
        return getattr(struct, 'name', '')


class Angle(Node):
    """Node of type Angle."""

    def new_val(self) -> Measure:
        return Measure()

    def set_directions(self, d1: Direction, d2: Direction) -> None:
        self._d = d1, d2

    @property
    def directions(self) -> tuple[Direction, Direction]:
        """The stored directions, mapped to representatives if both are set."""
        d1, d2 = self._d
        if d1 is None or d2 is None:
            return d1, d2
        return d1.rep(), d2.rep()


class Measure(Node):
    pass


class Length(Node):
    pass


class Ratio(Node):
    """Node of type Ratio."""

    def new_val(self) -> Value:
        return Value()

    def set_lengths(self, l1: Length, l2: Length) -> None:
        self._l = l1, l2

    @property
    def lengths(self) -> tuple[Length, Length]:
        """The stored lengths, mapped to representatives if both are set."""
        l1, l2 = self._l
        if l1 is None or l2 is None:
            return l1, l2
        return l1.rep(), l2.rep()


class Value(Node):
    pass


def all_angles(
    d1: Direction, d2: Direction, level: int = None
) -> Iterator[
    tuple[Angle, dict[Direction, Direction], dict[Direction, Direction]]
]:
    """Yield angles whose stored directions are equivalent to d1 and d2.

    Each item is (angle, d1s, d2s), where d1s / d2s are the `equivs_upto`
    parent maps of d1 / d2 at `level`.
    """
    level = level or float('inf')
    d1s = d1.equivs_upto(level)
    d2s = d2.equivs_upto(level)

    for ang in d1.rep().neighbors(Angle):
        d1_, d2_ = ang._d
        if d1_ in d1s and d2_ in d2s:
            yield ang, d1s, d2s


def all_ratios(
    d1, d2, level=None
) -> Iterator[tuple[Ratio, dict[Length, Length], dict[Length, Length]]]:
    """Yield ratios whose stored lengths are equivalent to d1 and d2.

    Each item is (ratio, d1s, d2s), where d1s / d2s are the `equivs_upto`
    parent maps of d1 / d2 at `level`.
    """
    level = level or float('inf')
    d1s = d1.equivs_upto(level)
    d2s = d2.equivs_upto(level)

    for ratio in d1.rep().neighbors(Ratio):
        d1_, d2_ = ratio._l
        if d1_ in d1s and d2_ in d2s:
            yield ratio, d1s, d2s


RANKING = {
    Point: 0,
    Line: 1,
    Segment: 2,
    Circle: 3,
    Direction: 4,
    Length: 5,
    Angle: 6,
    Ratio: 7,
    Measure: 8,
    Value: 9,
}


def val_type(x: Node) -> Type[Node] | None:
    """Value class paired with x's type, or None if x has no value type."""
    if isinstance(x, Line):
        return Direction
    if isinstance(x, Segment):
        return Length
    if isinstance(x, Angle):
        return Measure
    if isinstance(x, Ratio):
        return Value
    return None