HugoVoxx's picture
Upload 96 files
be3b34d verified
raw
history blame
20.4 kB
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementing Algebraic Reasoning (AR)."""
from collections import defaultdict # pylint: disable=g-importing-member
from fractions import Fraction as frac # pylint: disable=g-importing-member
from typing import Any, Generator
import geometry as gm
import numpy as np
import problem as pr
from scipy import optimize
class InfQuotientError(Exception):
  """Raised by get_quotient when the denominator search exceeds its limit."""
def _gcd(x: int, y: int) -> int:
  """Return a greatest common divisor of x and y using Euclid's algorithm.

  The sign of the result follows Python's `%` operator, so it can be negative
  when `y` is negative.
  """
  while y != 0:
    remainder = x % y
    x = y
    y = remainder
  return x
def simplify(n: int, d: int) -> tuple[int, int]:
  """Reduce the fraction n/d by the common divisor returned by _gcd."""
  divisor = _gcd(n, d)
  numerator = n // divisor
  denominator = d // divisor
  return numerator, denominator
# Largest denominator get_quotient() tries before raising InfQuotientError.
MAX_DENOMINATOR = 1000000
# Tolerance get_quotient() uses to decide that a multiple of its input is close
# enough to an integer.
TOL = 1e-15
def get_quotient(v: float) -> tuple[int, int]:
  """Approximate v by a fraction and return it as (numerator, denominator).

  Accumulates v, 2v, 3v, ... and stops at the first multiple that lies within
  TOL of an integer; the resulting fraction is reduced with simplify().

  Args:
    v: the value to approximate.

  Returns:
    The reduced (numerator, denominator) pair.

  Raises:
    InfQuotientError: if the denominator grows past MAX_DENOMINATOR.
  """
  denominator = 1
  multiple = v
  while True:
    if abs(multiple - round(multiple)) <= TOL:
      break
    denominator += 1
    multiple += v
    if denominator > MAX_DENOMINATOR:
      raise InfQuotientError(v)
  return simplify(int(round(multiple)), denominator)
def fix_v(v: float) -> float:
  """Return v replaced by the float value of its get_quotient() fraction."""
  numerator, denominator = get_quotient(v)
  return numerator / denominator
def fix(e: dict[str, float]) -> dict[str, float]:
  """Apply fix_v to every coefficient of the expression e."""
  fixed = {}
  for name, coef in e.items():
    fixed[name] = fix_v(coef)
  return fixed
def frac_string(f: frac) -> str:
  """Format f as 'numerator/denominator' using get_quotient()."""
  numerator, denominator = get_quotient(f)
  return f'{numerator}/{denominator}'
def hashed(e: dict[str, float]) -> tuple[tuple[str, float], ...]:
  """Return the (variable, coefficient) items of e as a sorted tuple.

  The result is hashable, so it can serve as a dictionary key for e.
  """
  items = sorted(e.items())
  return tuple(items)
def is_zero(e: dict[str, float]) -> bool:
  """Return True when e has no non-zero coefficient left after strip()."""
  return not strip(e)
def strip(e: dict[str, float]) -> dict[str, float]:
  """Return a copy of e without the entries whose coefficient equals 0."""
  kept = {}
  for name, coef in e.items():
    if coef == 0:
      continue
    kept[name] = coef
  return kept
def plus(e1: dict[str, float], e2: dict[str, float]) -> dict[str, float]:
  """Return e1 + e2 as a new expression with zero coefficients removed."""
  total = dict(e1)  # copy, so that e1 itself is left untouched
  for name, coef in e2.items():
    total[name] = total[name] + coef if name in total else coef
  return strip(total)
def plus_all(*es: dict[str, float]) -> dict[str, float]:
  """Return the sum of all given expressions.

  Args:
    *es: the expressions to add; each positional argument is one
      {variable: coefficient} dict.

  Returns:
    The accumulated sum, built by repeated plus(); an empty dict when no
    expression is given.
  """
  result = {}
  for e in es:
    result = plus(result, e)
  return result
def mult(e: dict[str, float], m: float) -> dict[str, float]:
  """Return a new expression with every coefficient of e multiplied by m.

  Zero coefficients are kept; nothing is stripped here.
  """
  scaled = {}
  for name, coef in e.items():
    scaled[name] = m * coef
  return scaled
def minus(e1: dict[str, float], e2: dict[str, float]) -> dict[str, float]:
  """Return e1 - e2, computed as e1 plus the negation of e2."""
  negated = mult(e2, -1)
  return plus(e1, negated)
def div(e1: dict[str, float], e2: dict[str, float]) -> 'frac | None':
  """Divide e1 by e2.

  Args:
    e1: the numerator expression {variable: coefficient}.
    e2: the denominator expression {variable: coefficient}.

  Returns:
    The Fraction k such that e1 == k * e2, or None when no such constant
    exists: the two expressions use different variables, their coefficients
    are not proportional, or both are zero (0/0 is undefined).
  """
  # Drop zero coefficients first (the same filtering strip() performs).
  e1 = {v: c for v, c in e1.items() if c != 0}
  e2 = {v: c for v, c in e2.items() if c != 0}

  # An empty e1 used to reach frac(None) below and raise TypeError.
  if not e1 or e1.keys() != e2.keys():
    return None

  n, d = None, None
  for v, c1 in e1.items():
    c2 = e2[v]  # we want c1/c2 = n/d => c1*d=c2*n
    if n is not None and c1 * d != c2 * n:
      return None
    n, d = c1, c2
  return frac(n) / frac(d)
def recon(e: dict[str, float], const: str) -> tuple[str, dict[str, float]]:
  """Reconcile one variable in the expression e=0, given const.

  Picks the first variable of e (after strip()) that is not `const` and solves
  e = 0 for it.

  Args:
    e: the expression known to equal zero.
    const: name of the constant variable, which is never solved for.

  Returns:
    (v0, expr) where v0 is the chosen variable and expr the expression it is
    equal to, or None when e is zero or only contains `const`.
  """
  remaining = strip(e)
  if not remaining:
    return None

  v0 = next((name for name in remaining if name != const), None)
  if v0 is None:
    return None

  c0 = remaining.pop(v0)
  return v0, {name: -coef / c0 for name, coef in remaining.items()}
def replace(
    e: dict[str, float], v0: str, e0: dict[str, float]
) -> dict[str, float]:
  """Substitute the expression e0 for variable v0 inside e.

  Returns e itself (not a copy) when v0 does not occur in it.
  """
  if v0 in e:
    remaining = dict(e)
    multiplier = remaining.pop(v0)
    return plus(remaining, mult(e0, multiplier))
  return e
def comb2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]:
  """Yield every unordered pair (elems[i], elems[j]) with i < j."""
  count = len(elems)
  for i in range(count - 1):
    first = elems[i]
    for j in range(i + 1, count):
      yield first, elems[j]
def perm2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]:
  """Yield every ordered pair of distinct positions in elems.

  Each pair produced by comb2() is immediately followed by its reverse.
  """
  for first, second in comb2(elems):
    yield from ((first, second), (second, first))
def chain2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]:
  """Yield each pair of consecutive elements: (e0, e1), (e1, e2), ..."""
  yield from zip(elems, elems[1:])
def update_groups(
    groups1: list[Any], groups2: list[Any]
) -> tuple[list[Any], list[tuple[Any, Any]], list[list[Any]]]:
  """Update groups of equivalent elements.

  Given groups1 = [set1, set2, set3, ..]
  where all elems within each set_i are defined to be "equivalent" to each
  other (but not across the sets).

  Incoming groups2 = [set1, set2, ...] is similar to groups1 - it is the
  additional equivalence information on elements in groups1.

  Return the new updated groups1 and the list of links
  that make it that way.

  Example:
    groups1 = [{1, 2}, {3, 4, 5}, {6, 7}]
    groups2 = [{2, 3, 8}, {9, 10, 11}]

  => new groups1 and links:
    groups1 = [{1, 2, 3, 4, 5, 8}, {6, 7}, {9, 10, 11}]
    links = (2, 3), (3, 8), (9, 10), (10, 11)

  Explain: since groups2 says 2 and 3 are equivalent (with {2, 3, 8}),
  then {1, 2} and {3, 4, 5} in groups1 will be merged,
  because 2 and 3 each belong to those 2 groups.
  Additionally 8 also belongs to this same group.
  {6, 7} is left alone, while {9, 10, 11} is a completely new set.

  The links that make all this happen are:
    (2, 3): to merge {1, 2} and {3, 4, 5}
    (3, 8): to link 8 into the merged({1, 2, 3, 4, 5})
    (9, 10) and (10, 11): to make the new group {9, 10, 11}

  The exact links depend on the iteration order of each set in groups2.

  Args:
    groups1: a list of sets.
    groups2: a list of sets.

  Returns:
    groups1, links, history: the updated groups, the links that produced
    them, and the value of groups1 after each set of groups2 was processed.
  """
  history = []
  links = []
  for g2 in groups2:
    joins = [None] * len(groups1)  # mark which one in groups1 is merged
    merged_g1 = set()  # merge them into this.
    old = None  # any elem in g2 that belong to any set in groups1 (old)
    new = []  # all elem in g2 that is new

    for e in g2:
      found = False
      for i, g1 in enumerate(groups1):
        if e not in g1:
          continue

        found = True
        if joins[i]:  # this group was already pulled in by an earlier elem.
          continue

        joins[i] = True
        merged_g1.update(g1)

        if old is not None:
          links.append((old, e))  # link to make merging happen.
        old = e

      if not found:  # e is new!
        new.append(e)

    # now chain elems in new together.
    if old is not None and new:
      # attach the new elems to the merged group through one extra link.
      links.append((old, new[0]))
      merged_g1.update(new)

    links += chain2(new)

    new_groups1 = []
    if merged_g1:  # put the merged_g1 in first
      new_groups1.append(merged_g1)

    # put the remaining (unjoined) groups in
    new_groups1 += [g1 for j, g1 in zip(joins, groups1) if not j]

    # g2 touched no existing group: its elems form a brand new group.
    if old is None and new:
      new_groups1 += [set(new)]

    groups1 = new_groups1
    history.append(groups1)

  return groups1, links, history
class Table:
  """The coefficient matrix.

  Every known variable is stored in `v2e` as a linear expression
  {free_variable: coefficient}. New equalities are folded into the table with
  `add_expr`; `register` additionally records them as columns of `A` so that
  `why` can trace a derived equality back to the dependencies producing it.
  """

  def __init__(self, const: str = '1'):
    """Create a table whose only variable is the constant `const`."""
    self.const = const
    self.v2e = {}  # the table {var: expression}
    self.add_free(const)

    # to cache what is already derived/inputted
    self.eqs = set()
    self.groups = []  # groups of equal pairs.

    # for why (linprog)
    self.c = []  # objective coefficients, two per registered equality.
    self.v2i = {}  # v -> index of row in A.
    self.deps = []  # one dependency per registered equality (2 columns each).
    self.A = np.zeros([0, 0])  # pylint: disable=invalid-name
    self.do_why = True  # when False, why() returns [] without solving.

  def add_free(self, v: str) -> None:
    """Add v as a free variable, i.e. expressed as itself."""
    self.v2e[v] = {v: frac(1)}

  def replace(self, v0: str, e0: dict[str, float]) -> None:
    """Substitute e0 for v0 in every expression of the table."""
    for v, e in list(self.v2e.items()):
      self.v2e[v] = replace(e, v0, e0)

  def add_expr(self, vc: list[tuple[str, float]]) -> bool:
    """Add a new equality, represented by the list of tuples vc=[(v, c), ..].

    The equality reads sum(c * v) = 0.

    Args:
      vc: (variable, coefficient) pairs of the equality.

    Returns:
      False when the table is left unchanged (the equality is already implied,
      or nothing but the constant remains to solve for); True otherwise.
    """
    result = {}
    free = []  # variables the table has not seen yet.

    for v, c in vc:
      c = frac(c)
      if v in self.v2e:
        result = plus(result, mult(self.v2e[v], c))
      else:
        free += [(v, c)]

    if not free:
      # Every variable is known: either nothing new, or one of the currently
      # free variables becomes dependent on the others.
      if is_zero(self.modulo(result)):
        return False
      result = recon(result, self.const)
      if result is None:
        return False
      v, e = result
      self.replace(v, e)

    elif len(free) == 1:
      # Exactly one unknown variable: express it with the known ones.
      v, m = free[0]
      self.v2e[v] = mult(result, frac(-1, m))

    else:
      # Several unknowns: the first non-constant one becomes dependent, the
      # others are added as free variables.
      dependent_v = None
      for v, m in free:
        if dependent_v is None and v != self.const:
          dependent_v = (v, m)
          continue

        self.add_free(v)
        result = plus(result, {v: m})

      v, m = dependent_v
      self.v2e[v] = mult(result, frac(-1, m))

    return True

  def register(self, vc: list[tuple[str, float]], dep: pr.Dependency) -> None:
    """Register a new equality vc=[(v, c), ..] with traceback dependency dep.

    Appends two columns to A (the coefficients of vc and their negation), two
    entries to the objective `c`, and `dep` to `deps`. Does nothing when the
    coefficients of vc sum to the zero expression.
    """
    result = plus_all(*[{v: c} for v, c in vc])
    if is_zero(result):
      return

    # Give every variable seen for the first time its own row index.
    vs, _ = zip(*vc)
    for v in vs:
      if v not in self.v2i:
        self.v2i[v] = len(self.v2i)

    # Grow A with zero rows for the new variables.
    (m, n), l = self.A.shape, len(self.v2i)
    if l > m:
      self.A = np.concatenate([self.A, np.zeros([l - m, n])], 0)

    new_column = np.zeros([len(self.v2i), 2])  # N, 2
    for v, c in vc:
      new_column[self.v2i[v], 0] += float(c)
      new_column[self.v2i[v], 1] -= float(c)

    self.A = np.concatenate([self.A, new_column], 1)
    # NOTE(review): the negated column gets cost -1.0; an L1-style objective
    # would use +1.0 for both columns - confirm this is intended.
    self.c += [1.0, -1.0]
    self.deps += [dep]

  def register2(
      self, a: str, b: str, m: float, n: float, dep: pr.Dependency
  ) -> None:
    """Register m*a - n*b = 0 with dependency dep."""
    self.register([(a, m), (b, -n)], dep)

  def register3(self, a: str, b: str, f: float, dep: pr.Dependency) -> None:
    """Register a - b - f*const = 0 with dependency dep."""
    self.register([(a, 1), (b, -1), (self.const, -f)], dep)

  def register4(
      self, a: str, b: str, c: str, d: str, dep: pr.Dependency
  ) -> None:
    """Register a - b - c + d = 0 with dependency dep."""
    self.register([(a, 1), (b, -1), (c, -1), (d, 1)], dep)

  def why(self, e: dict[str, float]) -> list[Any]:
    """AR traceback == MILP.

    Explains why the expression e equals zero by writing it as a combination
    of the registered equalities.

    Args:
      e: the expression {variable: coefficient} to explain.

    Returns:
      The dependencies whose columns get a weight above 1e-12 in the solution,
      without duplicates; [] when `do_why` is False or e is zero.
    """
    if not self.do_why:
      return []

    # why expr == 0?
    # Solve min(c^Tx) s.t. A_eq * x = b_eq, x >= 0
    e = strip(e)
    if not e:
      return []

    b_eq = [0] * len(self.v2i)
    for v, c in e.items():
      b_eq[self.v2i[v]] += float(c)

    try:
      x = optimize.linprog(c=self.c, A_eq=self.A, b_eq=b_eq, method='highs')[
          'x'
      ]
    except Exception:  # pylint: disable=broad-except
      # Retry with linprog's default method (presumably for SciPy versions
      # that do not know 'highs'). Unlike a bare `except`, this lets
      # KeyboardInterrupt and SystemExit propagate.
      x = optimize.linprog(
          c=self.c,
          A_eq=self.A,
          b_eq=b_eq,
      )['x']

    deps = []
    for i, dep in enumerate(self.deps):
      # Columns 2i and 2i+1 are the +/- copies of the i-th equality.
      if x[2 * i] > 1e-12 or x[2 * i + 1] > 1e-12:
        if dep not in deps:
          deps.append(dep)
    return deps

  def record_eq(self, v1: str, v2: str, v3: str, v4: str) -> None:
    """Cache v1 - v2 = v3 - v4 together with its three equivalent orderings."""
    self.eqs.add((v1, v2, v3, v4))
    self.eqs.add((v2, v1, v4, v3))
    self.eqs.add((v3, v4, v1, v2))
    self.eqs.add((v4, v3, v2, v1))

  def check_record_eq(self, v1: str, v2: str, v3: str, v4: str) -> bool:
    """Return True if v1 - v2 = v3 - v4 is cached under any of its orderings."""
    orderings = (
        (v1, v2, v3, v4),
        (v2, v1, v4, v3),
        (v3, v4, v1, v2),
        (v4, v3, v2, v1),
    )
    return any(key in self.eqs for key in orderings)

  def add_eq2(
      self, a: str, b: str, m: float, n: float, dep: pr.Dependency
  ) -> None:
    """Add a/b = m/n, i.e. n*a - m*b = 0, with dependency dep.

    Returns [] when the equality adds nothing new, None otherwise.
    """
    # a/b = m/n
    expr = [(a, n), (b, -m)]
    if not self.add_expr(expr):
      return []
    # Register exactly the expression that was added. Going through
    # register2(a, b, m, n) would record m*a - n*b = 0 instead, which is a
    # different equality whenever m != n and so cannot be used by why().
    self.register(expr, dep)

  def add_eq3(self, a: str, b: str, f: float, dep: pr.Dependency) -> None:
    """Add a - b = f * constant with dependency dep.

    Returns [] when the equality adds nothing new, None otherwise.
    """
    # a - b = f * constant
    self.eqs.add((a, b, frac(f)))
    # NOTE(review): caching b - a as (1 - f) matches tables that reduce the
    # constant modulo 1 (see AngleTable.modulo); confirm for other tables.
    self.eqs.add((b, a, frac(1 - f)))

    if not self.add_expr([(a, 1), (b, -1), (self.const, -f)]):
      return []

    self.register3(a, b, f, dep)

  def add_eq4(self, a: str, b: str, c: str, d: str, dep: pr.Dependency) -> None:
    """Add a - b = c - d with dependency dep.

    Returns [] when the equality adds nothing new, None otherwise.
    """
    # a - b = c - d
    self.record_eq(a, b, c, d)
    self.record_eq(a, c, b, d)

    expr = list(minus({a: 1, b: -1}, {c: 1, d: -1}).items())

    if not self.add_expr(expr):
      return []

    self.register4(a, b, c, d, dep)
    self.groups, _, _ = update_groups(
        self.groups, [{(a, b), (c, d)}, {(b, a), (d, c)}]
    )

  def pairs(self) -> Generator[tuple[str, str], None, None]:
    """Yield every ordered pair of distinct non-constant variables."""
    for v1, v2 in perm2(list(self.v2e.keys())):  # pylint: disable=g-builtin-op
      if v1 == self.const or v2 == self.const:
        continue
      yield v1, v2

  def modulo(self, e: dict[str, float]) -> dict[str, float]:
    """Normalize e before comparison; here this only strips zeros."""
    return strip(e)

  def get_all_eqs(
      self,
  ) -> dict[tuple[tuple[str, float], ...], list[tuple[str, str]]]:
    """Group all variable pairs by the hashed value of their difference."""
    h2pairs = defaultdict(list)
    for v1, v2 in self.pairs():
      e1, e2 = self.v2e[v1], self.v2e[v2]
      e12 = minus(e1, e2)
      h12 = hashed(self.modulo(e12))
      h2pairs[h12].append((v1, v2))
    return h2pairs

  def get_all_eqs_and_why(
      self, return_quads: bool = True
  ) -> Generator[Any, None, None]:
    """Check all 4/3/2-permutations for new equalities.

    Args:
      return_quads: when False, stop after the 2- and 3-element results.

    Yields:
      (v1, v2, why) for v1 == v2,
      (v1, v2, (num, den), why) for v1 - v2 == num/den * const, and, when
      return_quads is set, (v1, v2, v3, v4, why) for v1 - v2 == v3 - v4.
      Each result is yielded once; `why` is the list returned by why().
    """
    groups = []

    for h, vv in self.get_all_eqs().items():
      if h == ():  # pylint: disable=g-explicit-bool-comparison
        # The difference is zero: v1 == v2.
        for v1, v2 in vv:
          if (v1, v2) in self.eqs or (v2, v1) in self.eqs:
            continue
          self.eqs.add((v1, v2))
          # why v1 - v2 = e12 ? (note modulo(e12) == 0)
          why_dict = minus({v1: 1, v2: -1}, minus(self.v2e[v1], self.v2e[v2]))
          yield v1, v2, self.why(why_dict)
        continue

      if len(h) == 1 and h[0][0] == self.const:
        # The difference is a multiple of the constant alone. The coefficient
        # is kept under a name that does not shadow the module-level `frac`.
        coef = h[0][1]
        for v1, v2 in vv:
          if (v1, v2, coef) in self.eqs:
            continue
          self.eqs.add((v1, v2, coef))
          # why v1 - v2 = e12 ? (note modulo(e12) == 0)
          why_dict = minus({v1: 1, v2: -1}, minus(self.v2e[v1], self.v2e[v2]))
          value = simplify(coef.numerator, coef.denominator)
          yield v1, v2, value, self.why(why_dict)
        continue

      groups.append(vv)

    if not return_quads:
      return

    self.groups, links, _ = update_groups(self.groups, groups)
    for (v1, v2), (v3, v4) in links:
      if self.check_record_eq(v1, v2, v3, v4):
        continue
      e12 = minus(self.v2e[v1], self.v2e[v2])
      e34 = minus(self.v2e[v3], self.v2e[v4])

      why_dict = minus(  # why (v1-v2)-(v3-v4)=e12-e34?
          minus({v1: 1, v2: -1}, {v3: 1, v4: -1}), minus(e12, e34)
      )
      self.record_eq(v1, v2, v3, v4)
      yield v1, v2, v3, v4, self.why(why_dict)
class GeometricTable(Table):
  """Abstract class representing the coefficient matrix (table) A.

  Variables are the names of objects; `v2obj` maps each name back to its
  object so that results can be reported as objects.
  """

  def __init__(self, name: str = ''):
    """Create the table, using `name` as the name of its constant."""
    super().__init__(name)
    self.v2obj = {}

  def get_name(self, objs: list[Any]) -> list[str]:
    """Remember objs under their `.name` and return those names in order."""
    for obj in objs:
      self.v2obj[obj.name] = obj
    return [obj.name for obj in objs]

  def map2obj(self, names: list[str]) -> list[Any]:
    """Return the objects previously remembered under `names`."""
    objs = []
    for name in names:
      objs.append(self.v2obj[name])
    return objs

  def get_all_eqs_and_why(
      self, return_quads: bool
  ) -> Generator[Any, None, None]:
    """Yield the results of Table.get_all_eqs_and_why with names mapped to objects."""
    for out in super().get_all_eqs_and_why(return_quads):
      size = len(out)
      if size == 3:
        name1, name2, why = out
        obj1, obj2 = self.map2obj([name1, name2])
        yield obj1, obj2, why
      elif size == 4:
        name1, name2, value, why = out
        obj1, obj2 = self.map2obj([name1, name2])
        yield obj1, obj2, value, why
      elif size == 5:
        name1, name2, name3, name4, why = out
        obj1, obj2, obj3, obj4 = self.map2obj([name1, name2, name3, name4])
        yield obj1, obj2, obj3, obj4, why
class RatioTable(GeometricTable):
  """Coefficient matrix A for log(distance)."""

  def __init__(self, name: str = ''):
    """Create the table; the constant is named '1' unless `name` is given."""
    super().__init__(name or '1')
    self.one = self.const

  def add_eq(self, l1: gm.Length, l2: gm.Length, dep: pr.Dependency) -> None:
    """Add l1 == l2, i.e. a difference of 0.0 times the constant."""
    name1, name2 = self.get_name([l1, l2])
    return super().add_eq3(name1, name2, 0.0, dep)

  def add_const_ratio(
      self, l1: gm.Length, l2: gm.Length, m: float, n: float, dep: pr.Dependency
  ) -> None:
    """Add the constant ratio m/n between l1 and l2 through add_eq2."""
    name1, name2 = self.get_name([l1, l2])
    return super().add_eq2(name1, name2, m, n, dep)

  def add_eqratio(
      self,
      l1: gm.Length,
      l2: gm.Length,
      l3: gm.Length,
      l4: gm.Length,
      dep: pr.Dependency,
  ) -> None:
    """Add l1 - l2 == l3 - l4 (on the table's variables) through add_eq4."""
    name1, name2, name3, name4 = self.get_name([l1, l2, l3, l4])
    return self.add_eq4(name1, name2, name3, name4, dep)

  def get_all_eqs_and_why(self) -> Generator[Any, None, None]:
    """Return all new equalities, including the 4-element ones."""
    return super().get_all_eqs_and_why(True)
class AngleTable(GeometricTable):
  """Coefficient matrix A for slope(direction)."""

  def __init__(self, name: str = ''):
    """Create the table; the constant is named 'pi' unless `name` is given."""
    super().__init__(name or 'pi')
    self.pi = self.const

  def modulo(self, e: dict[str, float]) -> dict[str, float]:
    """Strip e and reduce the coefficient of the constant modulo 1."""
    reduced = strip(e)
    if self.pi in reduced:
      reduced[self.pi] = reduced[self.pi] % 1
      return strip(reduced)
    return super().modulo(reduced)

  def add_para(
      self, d1: gm.Direction, d2: gm.Direction, dep: pr.Dependency
  ) -> None:
    """Add that d1 and d2 are parallel: a constant angle of 0 between them."""
    return self.add_const_angle(d1, d2, 0, dep)

  def add_const_angle(
      self, d1: gm.Direction, d2: gm.Direction, ang: float, dep: pr.Dependency
  ) -> None:
    """Add d1 - d2 = (ang / 180) * constant.

    For a non-zero `ang`, the two directions are swapped (and `ang` replaced
    by 180 - ang) when d2's underlying object has the larger `num`.
    """
    swap = ang and d2._obj.num > d1._obj.num  # pylint: disable=protected-access
    if swap:
      d1, d2 = d2, d1
      ang = 180 - ang

    name1, name2 = self.get_name([d1, d2])

    num, den = simplify(ang, 180)
    fraction = frac(int(num), int(den))
    return super().add_eq3(name1, name2, fraction, dep)

  def add_eqangle(
      self,
      d1: gm.Direction,
      d2: gm.Direction,
      d3: gm.Direction,
      d4: gm.Direction,
      dep: pr.Dependency,
  ) -> None:
    """Add the equality d1-d2=d3-d4.

    One unit of the constant is added to a side whose second direction has
    the larger `num` on its underlying object. Returns [] when the equality
    adds nothing new, None otherwise.
    """
    # pylint: disable=protected-access
    num1, num2, num3, num4 = (d._obj.num for d in (d1, d2, d3, d4))
    # pylint: enable=protected-access

    # Use string as variables.
    name1, name2, name3, name4 = self.get_name([d1, d2, d3, d4])

    left = {name1: 1, name2: -1}
    right = {name3: 1, name4: -1}

    if num2 > num1:
      left = plus({self.pi: 1}, left)
    if num4 > num3:
      right = plus({self.pi: 1}, right)

    difference = minus(left, right)

    self.record_eq(name1, name2, name3, name4)
    self.record_eq(name1, name3, name2, name4)

    expr = list(difference.items())
    if not self.add_expr(expr):
      return []

    self.register(expr, dep)

  def get_all_eqs_and_why(self) -> Generator[Any, None, None]:
    """Return all new equalities, including the 4-element ones."""
    return super().get_all_eqs_and_why(True)
class DistanceTable(GeometricTable):
  """Coefficient matrix A for position(point, line)."""

  def __init__(self, name: str = ''):
    """Create the table; the constant is named '1:1' unless `name` is given."""
    self.merged = {}  # (v1, v2) -> representative pair with the same difference.
    self.ratios = set()  # (v1, v2, v3, v4) quadruples already reported.
    super().__init__(name or '1:1')

  def pairs(self) -> Generator[tuple[str, str], None, None]:
    """Yield ordered pairs of variables that share the same line prefix."""
    points_by_line = defaultdict(list)
    for v in list(self.v2e.keys()):  # pylint: disable=g-builtin-op
      if v == self.const:
        continue
      line_name, point_name = v.split(':')
      points_by_line[line_name].append(point_name)

    for line_name, point_names in points_by_line.items():
      for p1, p2 in perm2(point_names):
        yield line_name + ':' + p1, line_name + ':' + p2

  def name(self, l: gm.Line, p: gm.Point) -> str:
    """Return the variable 'line:point' and remember the (l, p) behind it."""
    variable = l.name + ':' + p.name
    self.v2obj[variable] = (l, p)
    return variable

  def map2obj(self, names: list[str]) -> list[gm.Point]:
    """Return the point remembered for each variable name."""
    points = []
    for variable in names:
      points.append(self.v2obj[variable][1])
    return points

  def add_cong(
      self,
      l12: gm.Line,
      l34: gm.Line,
      p1: gm.Point,
      p2: gm.Point,
      p3: gm.Point,
      p4: gm.Point,
      dep: pr.Dependency,
  ) -> None:
    """Add that distance between p1 and p2 (on l12) == p3 and p4 (on l34)."""
    # Within each pair, put the point with the larger `num` first.
    if p2.num > p1.num:
      p1, p2 = p2, p1
    if p4.num > p3.num:
      p3, p4 = p4, p3

    v1 = self.name(l12, p1)
    v2 = self.name(l12, p2)
    v3 = self.name(l34, p3)
    v4 = self.name(l34, p4)
    return super().add_eq4(v1, v2, v3, v4, dep)

  def get_all_eqs_and_why(self) -> Generator[Any, None, None]:
    """Yield all new equalities, then all new constant ratios.

    After the results of the parent class, yields
    (p1, p2, p3, p4, |n|, |d|, why) for each new pair of differences with
    (v1 - v2) * d == (v3 - v4) * n and a ratio n/d between 0 and 1.
    """
    yield from super().get_all_eqs_and_why(True)

    # Now we figure out all the const ratios.
    h2pairs = defaultdict(list)
    for v1, v2 in self.pairs():
      if (v1, v2) in self.merged:
        continue
      e12 = minus(self.v2e[v1], self.v2e[v2])
      h2pairs[hashed(e12)].append((v1, v2, e12))

    for (_, group1), (_, group2) in perm2(list(h2pairs.items())):
      # The first pair of each group represents it; the others are marked
      # as merged into it so that later calls skip them.
      v1, v2, e12 = group1[0]
      for other1, other2, _ in group1[1:]:
        self.merged[(other1, other2)] = (v1, v2)

      v3, v4, e34 = group2[0]
      for other3, other4, _ in group2[1:]:
        self.merged[(other3, other4)] = (v3, v4)

      if (v1, v2, v3, v4) in self.ratios:
        continue

      ratio = div(e12, e34)
      if ratio is None or ratio > 1 or ratio < 0:
        continue

      self.ratios.add((v1, v2, v3, v4))
      self.ratios.add((v2, v1, v4, v3))

      n, d = ratio.numerator, ratio.denominator

      # (v1 - v2) * d = (v3 - v4) * n
      why_dict = minus(
          minus({v1: d, v2: -d}, {v3: n, v4: -n}),
          minus(mult(e12, d), mult(e34, n)),  # there is no modulo, so this is 0
      )

      p1, p2, p3, p4 = self.map2obj([v1, v2, v3, v4])
      yield p1, p2, p3, p4, abs(n), abs(d), self.why(why_dict)