Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files
ag4masses/alphageometry/alphageometry.py
CHANGED
@@ -33,6 +33,8 @@ import problem as pr
|
|
33 |
#=============
|
34 |
import sys, os, math, re
|
35 |
import multiprocessing
|
|
|
|
|
36 |
model = None # global variable used in multi-processing workers
|
37 |
|
38 |
_GIN_SEARCH_PATHS = flags.DEFINE_list(
|
@@ -199,12 +201,11 @@ def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
|
|
199 |
rule_name = r2name.get(con.rule_name, '')
|
200 |
nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
|
201 |
solution += '{:03}. '.format(i + 1) + nl + '\n'
|
202 |
-
logging.info(solution)
|
203 |
if out_file:
|
204 |
with open(out_file, 'w') as f:
|
205 |
f.write(solution)
|
206 |
-
logging.info('Solution written to %s.', out_file)
|
207 |
-
|
208 |
|
209 |
def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
|
210 |
lm.parse_gin_configuration(
|
@@ -233,7 +234,6 @@ def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
|
|
233 |
return False
|
234 |
|
235 |
write_solution(g, p, out_file)
|
236 |
-
|
237 |
gh.nm.draw(
|
238 |
g.type2nodes[gh.Point],
|
239 |
g.type2nodes[gh.Line],
|
@@ -598,7 +598,7 @@ def bqsearch(i_nd, srch_inputs, out_file) -> tuple[int, bool, list]: # ( iNode,
|
|
598 |
return (i_nd, False, ret)
|
599 |
|
600 |
def run_alphageometry(
|
601 |
-
#
|
602 |
p: pr.Problem,
|
603 |
search_depth: int,
|
604 |
beam_size: int,
|
@@ -739,9 +739,9 @@ def main(_):
|
|
739 |
run_ddar(g, this_problem, _OUT_FILE.value)
|
740 |
|
741 |
elif _MODE.value == 'alphageometry':
|
742 |
-
|
743 |
run_alphageometry(
|
744 |
-
|
745 |
this_problem,
|
746 |
_SEARCH_DEPTH.value,
|
747 |
_BEAM_SIZE.value,
|
|
|
33 |
#=============
|
34 |
import sys, os, math, re
|
35 |
import multiprocessing
|
36 |
+
import warnings
|
37 |
+
warnings.filterwarnings("ignore")
|
38 |
model = None # global variable used in multi-processing workers
|
39 |
|
40 |
_GIN_SEARCH_PATHS = flags.DEFINE_list(
|
|
|
201 |
rule_name = r2name.get(con.rule_name, '')
|
202 |
nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
|
203 |
solution += '{:03}. '.format(i + 1) + nl + '\n'
|
204 |
+
# logging.info(solution)
|
205 |
if out_file:
|
206 |
with open(out_file, 'w') as f:
|
207 |
f.write(solution)
|
208 |
+
# logging.info('Solution written to %s.', out_file)
|
|
|
209 |
|
210 |
def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
|
211 |
lm.parse_gin_configuration(
|
|
|
234 |
return False
|
235 |
|
236 |
write_solution(g, p, out_file)
|
|
|
237 |
gh.nm.draw(
|
238 |
g.type2nodes[gh.Point],
|
239 |
g.type2nodes[gh.Line],
|
|
|
598 |
return (i_nd, False, ret)
|
599 |
|
600 |
def run_alphageometry(
|
601 |
+
# model: lm.LanguageModelInference,
|
602 |
p: pr.Problem,
|
603 |
search_depth: int,
|
604 |
beam_size: int,
|
|
|
739 |
run_ddar(g, this_problem, _OUT_FILE.value)
|
740 |
|
741 |
elif _MODE.value == 'alphageometry':
|
742 |
+
model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
|
743 |
run_alphageometry(
|
744 |
+
model,
|
745 |
this_problem,
|
746 |
_SEARCH_DEPTH.value,
|
747 |
_BEAM_SIZE.value,
|
ag4masses/alphageometry/numericals.py
CHANGED
@@ -29,6 +29,10 @@ from numpy.random import uniform as unif # pylint: disable=g-importing-member
|
|
29 |
import graph as gh
|
30 |
from collections import defaultdict
|
31 |
from itertools import combinations
|
|
|
|
|
|
|
|
|
32 |
|
33 |
matplotlib.use('TkAgg')
|
34 |
|
@@ -1060,7 +1064,7 @@ def _draw_line(
|
|
1060 |
|
1061 |
|
1062 |
def draw_line(
|
1063 |
-
ax: matplotlib.axes.Axes, line: Line, color: Any = 'white'
|
1064 |
) -> tuple[Point, Point]:
|
1065 |
"""Draw a line."""
|
1066 |
points = line.neighbors(gm.Point)
|
@@ -1080,8 +1084,11 @@ def draw_line(
|
|
1080 |
pmax = p, v
|
1081 |
|
1082 |
p1, p2 = pmin[0], pmax[0]
|
1083 |
-
|
1084 |
-
|
|
|
|
|
|
|
1085 |
|
1086 |
|
1087 |
def _draw_circle(
|
@@ -1315,10 +1322,6 @@ def highlight(
|
|
1315 |
) -> None:
|
1316 |
"""Draw highlights."""
|
1317 |
args = list(map(lambda x: x.num if isinstance(x, gm.Point) else x, args))
|
1318 |
-
|
1319 |
-
if name == 'cyclic':
|
1320 |
-
a, b, c, d = args
|
1321 |
-
_draw_circle(ax, Circle(p1=a, p2=b, p3=c), color=color1, lw=2.0)
|
1322 |
if name == 'coll':
|
1323 |
a, b, c = args
|
1324 |
a, b = max(a, b, c), min(a, b, c)
|
@@ -1407,6 +1410,90 @@ def find_pairs_with_same_distance(line_lengths):
|
|
1407 |
|
1408 |
return result
|
1409 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1410 |
def _draw(
|
1411 |
ax: matplotlib.axes.Axes,
|
1412 |
points: list[gm.Point],
|
@@ -1471,6 +1558,7 @@ def _draw(
|
|
1471 |
|
1472 |
# Call the highlight function with the determined color
|
1473 |
highlight(ax, 'cong', [p1, p2, p3, p4], lcolor, color, color)
|
|
|
1474 |
if equals:
|
1475 |
for i, segs in enumerate(equals['segments']):
|
1476 |
color = colors[i % len(colors)]
|
|
|
29 |
import graph as gh
|
30 |
from collections import defaultdict
|
31 |
from itertools import combinations
|
32 |
+
import numpy as np
|
33 |
+
import matplotlib.patches
|
34 |
+
import matplotlib.pyplot as plt
|
35 |
+
from itertools import combinations
|
36 |
|
37 |
matplotlib.use('TkAgg')
|
38 |
|
|
|
1064 |
|
1065 |
|
1066 |
def draw_line(
|
1067 |
+
ax: matplotlib.axes.Axes, line: Line, color: Any = 'white', draw: bool = True
|
1068 |
) -> tuple[Point, Point]:
|
1069 |
"""Draw a line."""
|
1070 |
points = line.neighbors(gm.Point)
|
|
|
1084 |
pmax = p, v
|
1085 |
|
1086 |
p1, p2 = pmin[0], pmax[0]
|
1087 |
+
if draw:
|
1088 |
+
_draw_line(ax, p1, p2, color=color)
|
1089 |
+
return p1, p2
|
1090 |
+
else:
|
1091 |
+
return p1, p2
|
1092 |
|
1093 |
|
1094 |
def _draw_circle(
|
|
|
1322 |
) -> None:
|
1323 |
"""Draw highlights."""
|
1324 |
args = list(map(lambda x: x.num if isinstance(x, gm.Point) else x, args))
|
|
|
|
|
|
|
|
|
1325 |
if name == 'coll':
|
1326 |
a, b, c = args
|
1327 |
a, b = max(a, b, c), min(a, b, c)
|
|
|
1410 |
|
1411 |
return result
|
1412 |
|
1413 |
+
def calculate_angle(p1, p2, p3, p4):
|
1414 |
+
"""Calculates the angle between two lines formed by points (p1, p2) and (p3, p4) in degrees."""
|
1415 |
+
# Determine the common point
|
1416 |
+
if p2 == p3:
|
1417 |
+
common = p2
|
1418 |
+
other_points = (p1, p4)
|
1419 |
+
v1 = np.array([p1.x - p2.x, p1.y - p2.y])
|
1420 |
+
v2 = np.array([p4.x - p2.x, p4.y - p2.y])
|
1421 |
+
elif p1 == p4:
|
1422 |
+
common = p1
|
1423 |
+
other_points = (p2, p3)
|
1424 |
+
v1 = np.array([p2.x - p1.x, p2.y - p1.y])
|
1425 |
+
v2 = np.array([p3.x - p1.x, p3.y - p1.y])
|
1426 |
+
elif p1 == p3:
|
1427 |
+
common = p1
|
1428 |
+
other_points = (p2, p4)
|
1429 |
+
v1 = np.array([p2.x - p1.x, p2.y - p1.y])
|
1430 |
+
v2 = np.array([p4.x - p1.x, p4.y - p1.y])
|
1431 |
+
elif p2 == p4:
|
1432 |
+
common = p2
|
1433 |
+
other_points = (p1, p3)
|
1434 |
+
v1 = np.array([p1.x - p2.x, p1.y - p2.y])
|
1435 |
+
v2 = np.array([p3.x - p2.x, p3.y - p2.y])
|
1436 |
+
else:
|
1437 |
+
return None, None, None # No shared point, angle cannot be calculated
|
1438 |
+
|
1439 |
+
# Calculate the angle
|
1440 |
+
cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
|
1441 |
+
cos_angle = np.clip(cos_angle, -1, 1) # Ensure valid range for acos
|
1442 |
+
angle_rad = np.arccos(cos_angle)
|
1443 |
+
return common, other_points, round(np.degrees(angle_rad), 5)
|
1444 |
+
|
1445 |
+
def highlight_angle2(ax, origin, p1, p2, radius, color):
|
1446 |
+
"""Highlights the angle formed by two vectors meeting at 'origin'."""
|
1447 |
+
# Calculate angles of vectors
|
1448 |
+
angle1 = np.arctan2(p1.y - origin.y, p1.x - origin.x)
|
1449 |
+
angle2 = np.arctan2(p2.y - origin.y, p2.x - origin.x)
|
1450 |
+
|
1451 |
+
# Convert to degrees and ensure the smaller angle is first
|
1452 |
+
angle1_deg, angle2_deg = sorted(np.degrees([angle1, angle2]))
|
1453 |
+
if angle2_deg - angle1_deg > 180:
|
1454 |
+
angle1_deg, angle2_deg = angle2_deg, angle1_deg
|
1455 |
+
|
1456 |
+
# Draw the wedge
|
1457 |
+
wedge = matplotlib.patches.Wedge(
|
1458 |
+
center=(origin.x, origin.y),
|
1459 |
+
r=radius,
|
1460 |
+
theta1=angle1_deg,
|
1461 |
+
theta2=angle2_deg,
|
1462 |
+
color=color,
|
1463 |
+
alpha=0.5
|
1464 |
+
)
|
1465 |
+
ax.add_patch(wedge)
|
1466 |
+
# print("Angle highlighted with color:", color)
|
1467 |
+
|
1468 |
+
def search_in_dict(num, my_dict):
|
1469 |
+
for key in my_dict.keys():
|
1470 |
+
if round(key, 3) == round(num, 3):
|
1471 |
+
return True
|
1472 |
+
return False
|
1473 |
+
|
1474 |
+
def highlight_same_angle(ax, lines, color_list):
|
1475 |
+
"""Highlights angles formed at shared points by pairs of lines."""
|
1476 |
+
lines_list = [(draw_line(ax, l, draw=False)) for l in lines] # Extract points for all lines
|
1477 |
+
angle_color_radius = {}
|
1478 |
+
|
1479 |
+
for line1, line2 in combinations(lines_list, 2):
|
1480 |
+
# Calculate the angle
|
1481 |
+
common_point, other_points, angle = calculate_angle(*line1, *line2)
|
1482 |
+
if angle is None or angle > 90:
|
1483 |
+
continue # Skip invalid or small angles
|
1484 |
+
|
1485 |
+
if search_in_dict(angle, angle_color_radius) == False:
|
1486 |
+
# Assign color and radius for this unique angle
|
1487 |
+
color = color_list[len(angle_color_radius) % len(color_list)]
|
1488 |
+
radius = 0.1 + len(angle_color_radius) * 0.05
|
1489 |
+
angle_color_radius[round(angle,3)] = (color, radius)
|
1490 |
+
# print(type(angle))
|
1491 |
+
else:
|
1492 |
+
color, radius = angle_color_radius[round(angle, 3)]
|
1493 |
+
# print(type(angle))
|
1494 |
+
|
1495 |
+
# Highlight the angle
|
1496 |
+
highlight_angle2(ax, common_point, *other_points, radius, color)
|
1497 |
def _draw(
|
1498 |
ax: matplotlib.axes.Axes,
|
1499 |
points: list[gm.Point],
|
|
|
1558 |
|
1559 |
# Call the highlight function with the determined color
|
1560 |
highlight(ax, 'cong', [p1, p2, p3, p4], lcolor, color, color)
|
1561 |
+
highlight_same_angle(ax, lines, color_list=colors_highlight)
|
1562 |
if equals:
|
1563 |
for i, segs in enumerate(equals['segments']):
|
1564 |
color = colors[i % len(colors)]
|