liuganghuggingface commited on
Commit
58dad72
·
verified ·
1 Parent(s): 78b3cac

Upload graph_decoder/visualize_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. graph_decoder/visualize_utils.py +86 -0
graph_decoder/visualize_utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from rdkit import Chem
3
+ from rdkit.Chem import Draw, AllChem
4
+ from rdkit.Geometry import Point3D
5
+ from rdkit import RDLogger
6
+ import numpy as np
7
+ import rdkit.Chem
8
+
9
class MolecularVisualization:
    """Turn graph encodings of molecules into RDKit molecules and images."""

    def __init__(self, atom_decoder):
        """Remember the lookup that maps node labels to atom symbols.

        Args:
            atom_decoder: Indexable container; ``atom_decoder[int(label)]``
                must yield a value accepted by ``Chem.Atom``.
        """
        self.atom_decoder = atom_decoder

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Convert a single graph into an RDKit molecule.

        Args:
            node_list: Node labels of one graph (length n). A label of ``-1``
                marks a slot without an atom and is skipped.
            adjacency_matrix: Bond codes of the same graph (n x n). Only the
                upper triangle is read. Codes 1, 2, 3 and 4 become single,
                double, triple and aromatic bonds; any other value means
                "no bond".

        Returns:
            The assembled molecule, or ``None`` if ``KekulizeException`` is
            raised while finalizing it.
        """
        atom_decoder = self.atom_decoder

        # Bond code -> RDKit bond type.  Matched with ``==`` (not a dict
        # lookup) so array/tensor scalar entries compare the same way the
        # plain integer codes do.
        bond_types = (
            (1, Chem.rdchem.BondType.SINGLE),
            (2, Chem.rdchem.BondType.DOUBLE),
            (3, Chem.rdchem.BondType.TRIPLE),
            (4, Chem.rdchem.BondType.AROMATIC),
        )

        # Editable molecule that atoms and bonds are appended to.
        mol = Chem.RWMol()

        # Graph node position -> atom index inside ``mol``.  The two differ
        # as soon as a -1 slot has been skipped.
        node_to_idx = {}
        for node_pos, label in enumerate(node_list):
            if label == -1:
                continue
            atom = Chem.Atom(atom_decoder[int(label)])
            node_to_idx[node_pos] = mol.AddAtom(atom)

        for ix, row in enumerate(adjacency_matrix):
            # A row belonging to a skipped node cannot contribute bonds.
            if ix not in node_to_idx:
                continue
            for iy, bond in enumerate(row):
                # Only traverse the upper half of the symmetric matrix.
                if iy <= ix:
                    continue
                # Ignore entries that point at a skipped node instead of
                # failing with a KeyError on the index lookup.
                if iy not in node_to_idx:
                    continue
                bond_type = next(
                    (kind for code, kind in bond_types if bond == code),
                    None,
                )
                if bond_type is None:
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        # NOTE(review): GetMol() does not appear to sanitize/kekulize, so this
        # handler may never fire -- confirm before relying on the None return.
        try:
            mol = mol.GetMol()
        except Chem.KekulizeException:
            print("Can't kekulize molecule")
            mol = None
        return mol

    def visualize_chain(self, nodes_list, adjacency_matrix):
        """Render a sequence of graphs as images sharing one 2D layout.

        Every frame is drawn with the atom coordinates computed for the last
        graph in the sequence, so atoms keep their position across frames.

        Args:
            nodes_list: Sequence of per-frame node label lists (frames x n).
            adjacency_matrix: Sequence of per-frame bond code matrices
                (frames x n x n), indexed in step with ``nodes_list``.

        Returns:
            A list with one 300x300 image per frame, each labelled
            ``"Frame <index>"``. Empty if ``nodes_list`` is empty.
        """
        # Silences RDKit logging for the whole process, not just this call.
        RDLogger.DisableLog('rdApp.*')

        # Convert every frame of the chain into an RDKit molecule.
        mols = [
            self.mol_from_graphs(nodes, adjacency_matrix[frame])
            for frame, nodes in enumerate(nodes_list)
        ]
        if not mols:
            return []

        # The final molecule provides the reference atom coordinates.
        final_molecule = mols[-1]
        AllChem.Compute2DCoords(final_molecule)
        final_conf = final_molecule.GetConformer()
        coords = []
        for atom_idx in range(final_molecule.GetNumAtoms()):
            position = final_conf.GetAtomPosition(atom_idx)
            # Stored as plain tuples: the final molecule's conformer is
            # recomputed in the alignment loop below.
            coords.append((position.x, position.y, position.z))

        # Align all molecules onto the reference coordinates.
        for mol in mols:
            # Creates the conformer whose positions are overwritten next.
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            # Atoms beyond the reference's atom count keep the coordinates
            # just computed instead of raising an IndexError.
            shared_atoms = min(mol.GetNumAtoms(), len(coords))
            for atom_idx in range(shared_atoms):
                x, y, z = coords[atom_idx]
                conf.SetAtomPosition(atom_idx, Point3D(x, y, z))

        # One image per frame, in chain order.
        return [
            Draw.MolToImage(mol, size=(300, 300), legend=f"Frame {frame}")
            for frame, mol in enumerate(mols)
        ]