liuganghuggingface
commited on
Upload graph_decoder/visualize_utils.py with huggingface_hub
Browse files
graph_decoder/visualize_utils.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from rdkit import Chem
|
3 |
+
from rdkit.Chem import Draw, AllChem
|
4 |
+
from rdkit.Geometry import Point3D
|
5 |
+
from rdkit import RDLogger
|
6 |
+
import numpy as np
|
7 |
+
import rdkit.Chem
|
8 |
+
|
9 |
+
class MolecularVisualization:
|
10 |
+
def __init__(self, atom_decoder):
|
11 |
+
self.atom_decoder = atom_decoder
|
12 |
+
|
13 |
+
def mol_from_graphs(self, node_list, adjacency_matrix):
|
14 |
+
"""
|
15 |
+
Convert graphs to rdkit molecules
|
16 |
+
node_list: the nodes of a batch of nodes (bs x n)
|
17 |
+
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
|
18 |
+
"""
|
19 |
+
# dictionary to map integer value to the char of atom
|
20 |
+
atom_decoder = self.atom_decoder
|
21 |
+
|
22 |
+
# create empty editable mol object
|
23 |
+
mol = Chem.RWMol()
|
24 |
+
|
25 |
+
# add atoms to mol and keep track of index
|
26 |
+
node_to_idx = {}
|
27 |
+
for i in range(len(node_list)):
|
28 |
+
if node_list[i] == -1:
|
29 |
+
continue
|
30 |
+
a = Chem.Atom(atom_decoder[int(node_list[i])])
|
31 |
+
molIdx = mol.AddAtom(a)
|
32 |
+
node_to_idx[i] = molIdx
|
33 |
+
|
34 |
+
for ix, row in enumerate(adjacency_matrix):
|
35 |
+
for iy, bond in enumerate(row):
|
36 |
+
# only traverse half the symmetric matrix
|
37 |
+
if iy <= ix:
|
38 |
+
continue
|
39 |
+
if bond == 1:
|
40 |
+
bond_type = Chem.rdchem.BondType.SINGLE
|
41 |
+
elif bond == 2:
|
42 |
+
bond_type = Chem.rdchem.BondType.DOUBLE
|
43 |
+
elif bond == 3:
|
44 |
+
bond_type = Chem.rdchem.BondType.TRIPLE
|
45 |
+
elif bond == 4:
|
46 |
+
bond_type = Chem.rdchem.BondType.AROMATIC
|
47 |
+
else:
|
48 |
+
continue
|
49 |
+
mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
|
50 |
+
|
51 |
+
try:
|
52 |
+
mol = mol.GetMol()
|
53 |
+
except rdkit.Chem.KekulizeException:
|
54 |
+
print("Can't kekulize molecule")
|
55 |
+
mol = None
|
56 |
+
return mol
|
57 |
+
|
58 |
+
def visualize_chain(self, nodes_list, adjacency_matrix):
|
59 |
+
RDLogger.DisableLog('rdApp.*')
|
60 |
+
# convert graphs to the rdkit molecules
|
61 |
+
mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
|
62 |
+
|
63 |
+
# find the coordinates of atoms in the final molecule
|
64 |
+
final_molecule = mols[-1]
|
65 |
+
AllChem.Compute2DCoords(final_molecule)
|
66 |
+
|
67 |
+
coords = []
|
68 |
+
for i, atom in enumerate(final_molecule.GetAtoms()):
|
69 |
+
positions = final_molecule.GetConformer().GetAtomPosition(i)
|
70 |
+
coords.append((positions.x, positions.y, positions.z))
|
71 |
+
|
72 |
+
# align all the molecules
|
73 |
+
for i, mol in enumerate(mols):
|
74 |
+
AllChem.Compute2DCoords(mol)
|
75 |
+
conf = mol.GetConformer()
|
76 |
+
for j, atom in enumerate(mol.GetAtoms()):
|
77 |
+
x, y, z = coords[j]
|
78 |
+
conf.SetAtomPosition(j, Point3D(x, y, z))
|
79 |
+
|
80 |
+
# create list of molecule images
|
81 |
+
mol_images = []
|
82 |
+
for frame, mol in enumerate(mols):
|
83 |
+
img = Draw.MolToImage(mol, size=(300, 300), legend=f"Frame {frame}")
|
84 |
+
mol_images.append(img)
|
85 |
+
|
86 |
+
return mol_images
|