HugoVoxx commited on
Commit
6c69a8f
·
verified ·
1 Parent(s): 21e3880

Upload 2 files

Browse files
ag4masses/alphageometry/alphageometry.py CHANGED
@@ -33,6 +33,8 @@ import problem as pr
33
  #=============
34
  import sys, os, math, re
35
  import multiprocessing
 
 
36
  model = None # global variable used in multi-processing workers
37
 
38
  _GIN_SEARCH_PATHS = flags.DEFINE_list(
@@ -199,12 +201,11 @@ def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
199
  rule_name = r2name.get(con.rule_name, '')
200
  nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
201
  solution += '{:03}. '.format(i + 1) + nl + '\n'
202
- logging.info(solution)
203
  if out_file:
204
  with open(out_file, 'w') as f:
205
  f.write(solution)
206
- logging.info('Solution written to %s.', out_file)
207
-
208
 
209
  def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
210
  lm.parse_gin_configuration(
@@ -233,7 +234,6 @@ def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
233
  return False
234
 
235
  write_solution(g, p, out_file)
236
-
237
  gh.nm.draw(
238
  g.type2nodes[gh.Point],
239
  g.type2nodes[gh.Line],
@@ -598,7 +598,7 @@ def bqsearch(i_nd, srch_inputs, out_file) -> tuple[int, bool, list]: # ( iNode,
598
  return (i_nd, False, ret)
599
 
600
  def run_alphageometry(
601
- #XX model: lm.LanguageModelInference,
602
  p: pr.Problem,
603
  search_depth: int,
604
  beam_size: int,
@@ -739,9 +739,9 @@ def main(_):
739
  run_ddar(g, this_problem, _OUT_FILE.value)
740
 
741
  elif _MODE.value == 'alphageometry':
742
- #XX model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
743
  run_alphageometry(
744
- #XX model,
745
  this_problem,
746
  _SEARCH_DEPTH.value,
747
  _BEAM_SIZE.value,
 
33
  #=============
34
  import sys, os, math, re
35
  import multiprocessing
36
+ import warnings
37
+ warnings.filterwarnings("ignore")
38
  model = None # global variable used in multi-processing workers
39
 
40
  _GIN_SEARCH_PATHS = flags.DEFINE_list(
 
201
  rule_name = r2name.get(con.rule_name, '')
202
  nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
203
  solution += '{:03}. '.format(i + 1) + nl + '\n'
204
+ # logging.info(solution)
205
  if out_file:
206
  with open(out_file, 'w') as f:
207
  f.write(solution)
208
+ # logging.info('Solution written to %s.', out_file)
 
209
 
210
  def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
211
  lm.parse_gin_configuration(
 
234
  return False
235
 
236
  write_solution(g, p, out_file)
 
237
  gh.nm.draw(
238
  g.type2nodes[gh.Point],
239
  g.type2nodes[gh.Line],
 
598
  return (i_nd, False, ret)
599
 
600
  def run_alphageometry(
601
+ # model: lm.LanguageModelInference,
602
  p: pr.Problem,
603
  search_depth: int,
604
  beam_size: int,
 
739
  run_ddar(g, this_problem, _OUT_FILE.value)
740
 
741
  elif _MODE.value == 'alphageometry':
742
+ model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
743
  run_alphageometry(
744
+ model,
745
  this_problem,
746
  _SEARCH_DEPTH.value,
747
  _BEAM_SIZE.value,
ag4masses/alphageometry/numericals.py CHANGED
@@ -29,6 +29,10 @@ from numpy.random import uniform as unif # pylint: disable=g-importing-member
29
  import graph as gh
30
  from collections import defaultdict
31
  from itertools import combinations
 
 
 
 
32
 
33
  matplotlib.use('TkAgg')
34
 
@@ -1060,7 +1064,7 @@ def _draw_line(
1060
 
1061
 
1062
  def draw_line(
1063
- ax: matplotlib.axes.Axes, line: Line, color: Any = 'white'
1064
  ) -> tuple[Point, Point]:
1065
  """Draw a line."""
1066
  points = line.neighbors(gm.Point)
@@ -1080,8 +1084,11 @@ def draw_line(
1080
  pmax = p, v
1081
 
1082
  p1, p2 = pmin[0], pmax[0]
1083
- _draw_line(ax, p1, p2, color=color)
1084
- return p1, p2
 
 
 
1085
 
1086
 
1087
  def _draw_circle(
@@ -1315,10 +1322,6 @@ def highlight(
1315
  ) -> None:
1316
  """Draw highlights."""
1317
  args = list(map(lambda x: x.num if isinstance(x, gm.Point) else x, args))
1318
-
1319
- if name == 'cyclic':
1320
- a, b, c, d = args
1321
- _draw_circle(ax, Circle(p1=a, p2=b, p3=c), color=color1, lw=2.0)
1322
  if name == 'coll':
1323
  a, b, c = args
1324
  a, b = max(a, b, c), min(a, b, c)
@@ -1407,6 +1410,90 @@ def find_pairs_with_same_distance(line_lengths):
1407
 
1408
  return result
1409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1410
  def _draw(
1411
  ax: matplotlib.axes.Axes,
1412
  points: list[gm.Point],
@@ -1471,6 +1558,7 @@ def _draw(
1471
 
1472
  # Call the highlight function with the determined color
1473
  highlight(ax, 'cong', [p1, p2, p3, p4], lcolor, color, color)
 
1474
  if equals:
1475
  for i, segs in enumerate(equals['segments']):
1476
  color = colors[i % len(colors)]
 
29
  import graph as gh
30
  from collections import defaultdict
31
  from itertools import combinations
32
+ import numpy as np
33
+ import matplotlib.patches
34
+ import matplotlib.pyplot as plt
35
+ from itertools import combinations
36
 
37
  matplotlib.use('TkAgg')
38
 
 
1064
 
1065
 
1066
  def draw_line(
1067
+ ax: matplotlib.axes.Axes, line: Line, color: Any = 'white', draw: bool = True
1068
  ) -> tuple[Point, Point]:
1069
  """Draw a line."""
1070
  points = line.neighbors(gm.Point)
 
1084
  pmax = p, v
1085
 
1086
  p1, p2 = pmin[0], pmax[0]
1087
+ if draw:
1088
+ _draw_line(ax, p1, p2, color=color)
1089
+ return p1, p2
1090
+ else:
1091
+ return p1, p2
1092
 
1093
 
1094
  def _draw_circle(
 
1322
  ) -> None:
1323
  """Draw highlights."""
1324
  args = list(map(lambda x: x.num if isinstance(x, gm.Point) else x, args))
 
 
 
 
1325
  if name == 'coll':
1326
  a, b, c = args
1327
  a, b = max(a, b, c), min(a, b, c)
 
1410
 
1411
  return result
1412
 
1413
+ def calculate_angle(p1, p2, p3, p4):
1414
+ """Calculates the angle between two lines formed by points (p1, p2) and (p3, p4) in degrees."""
1415
+ # Determine the common point
1416
+ if p2 == p3:
1417
+ common = p2
1418
+ other_points = (p1, p4)
1419
+ v1 = np.array([p1.x - p2.x, p1.y - p2.y])
1420
+ v2 = np.array([p4.x - p2.x, p4.y - p2.y])
1421
+ elif p1 == p4:
1422
+ common = p1
1423
+ other_points = (p2, p3)
1424
+ v1 = np.array([p2.x - p1.x, p2.y - p1.y])
1425
+ v2 = np.array([p3.x - p1.x, p3.y - p1.y])
1426
+ elif p1 == p3:
1427
+ common = p1
1428
+ other_points = (p2, p4)
1429
+ v1 = np.array([p2.x - p1.x, p2.y - p1.y])
1430
+ v2 = np.array([p4.x - p1.x, p4.y - p1.y])
1431
+ elif p2 == p4:
1432
+ common = p2
1433
+ other_points = (p1, p3)
1434
+ v1 = np.array([p1.x - p2.x, p1.y - p2.y])
1435
+ v2 = np.array([p3.x - p2.x, p3.y - p2.y])
1436
+ else:
1437
+ return None, None, None # No shared point, angle cannot be calculated
1438
+
1439
+ # Calculate the angle
1440
+ cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
1441
+ cos_angle = np.clip(cos_angle, -1, 1) # Ensure valid range for acos
1442
+ angle_rad = np.arccos(cos_angle)
1443
+ return common, other_points, round(np.degrees(angle_rad), 5)
1444
+
1445
+ def highlight_angle2(ax, origin, p1, p2, radius, color):
1446
+ """Highlights the angle formed by two vectors meeting at 'origin'."""
1447
+ # Calculate angles of vectors
1448
+ angle1 = np.arctan2(p1.y - origin.y, p1.x - origin.x)
1449
+ angle2 = np.arctan2(p2.y - origin.y, p2.x - origin.x)
1450
+
1451
+ # Convert to degrees and ensure the smaller angle is first
1452
+ angle1_deg, angle2_deg = sorted(np.degrees([angle1, angle2]))
1453
+ if angle2_deg - angle1_deg > 180:
1454
+ angle1_deg, angle2_deg = angle2_deg, angle1_deg
1455
+
1456
+ # Draw the wedge
1457
+ wedge = matplotlib.patches.Wedge(
1458
+ center=(origin.x, origin.y),
1459
+ r=radius,
1460
+ theta1=angle1_deg,
1461
+ theta2=angle2_deg,
1462
+ color=color,
1463
+ alpha=0.5
1464
+ )
1465
+ ax.add_patch(wedge)
1466
+ # print("Angle highlighted with color:", color)
1467
+
1468
+ def search_in_dict(num, my_dict):
1469
+ for key in my_dict.keys():
1470
+ if round(key, 3) == round(num, 3):
1471
+ return True
1472
+ return False
1473
+
1474
+ def highlight_same_angle(ax, lines, color_list):
1475
+ """Highlights angles formed at shared points by pairs of lines."""
1476
+ lines_list = [(draw_line(ax, l, draw=False)) for l in lines] # Extract points for all lines
1477
+ angle_color_radius = {}
1478
+
1479
+ for line1, line2 in combinations(lines_list, 2):
1480
+ # Calculate the angle
1481
+ common_point, other_points, angle = calculate_angle(*line1, *line2)
1482
+ if angle is None or angle > 90:
1483
+ continue # Skip invalid or small angles
1484
+
1485
+ if search_in_dict(angle, angle_color_radius) == False:
1486
+ # Assign color and radius for this unique angle
1487
+ color = color_list[len(angle_color_radius) % len(color_list)]
1488
+ radius = 0.1 + len(angle_color_radius) * 0.05
1489
+ angle_color_radius[round(angle,3)] = (color, radius)
1490
+ # print(type(angle))
1491
+ else:
1492
+ color, radius = angle_color_radius[round(angle, 3)]
1493
+ # print(type(angle))
1494
+
1495
+ # Highlight the angle
1496
+ highlight_angle2(ax, common_point, *other_points, radius, color)
1497
  def _draw(
1498
  ax: matplotlib.axes.Axes,
1499
  points: list[gm.Point],
 
1558
 
1559
  # Call the highlight function with the determined color
1560
  highlight(ax, 'cong', [p1, p2, p3, p4], lcolor, color, color)
1561
+ highlight_same_angle(ax, lines, color_list=colors_highlight)
1562
  if equals:
1563
  for i, segs in enumerate(equals['segments']):
1564
  color = colors[i % len(colors)]