Upload 116 files
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +131 -0
- examples/exp.png +0 -0
- examples/reaction1.png +0 -0
- examples/reaction2.png +0 -0
- examples/reaction3.png +0 -0
- examples/reaction4.png +0 -0
- getReaction.py +78 -0
- molscribe/__init__.py +1 -0
- molscribe/__pycache__/__init__.cpython-310.pyc +0 -0
- molscribe/__pycache__/augment.cpython-310.pyc +0 -0
- molscribe/__pycache__/chemistry.cpython-310.pyc +0 -0
- molscribe/__pycache__/constants.cpython-310.pyc +0 -0
- molscribe/__pycache__/dataset.cpython-310.pyc +0 -0
- molscribe/__pycache__/evaluate.cpython-310.pyc +0 -0
- molscribe/__pycache__/interface.cpython-310.pyc +0 -0
- molscribe/__pycache__/loss.cpython-310.pyc +0 -0
- molscribe/__pycache__/model.cpython-310.pyc +0 -0
- molscribe/__pycache__/tokenizer.cpython-310.pyc +0 -0
- molscribe/__pycache__/utils.cpython-310.pyc +0 -0
- molscribe/augment.py +282 -0
- molscribe/chemistry.py +649 -0
- molscribe/constants.py +130 -0
- molscribe/dataset.py +594 -0
- molscribe/evaluate.py +79 -0
- molscribe/indigo/__init__.py +0 -0
- molscribe/indigo/__pycache__/__init__.cpython-310.pyc +0 -0
- molscribe/indigo/__pycache__/bingo.cpython-310.pyc +0 -0
- molscribe/indigo/__pycache__/inchi.cpython-310.pyc +0 -0
- molscribe/indigo/__pycache__/renderer.cpython-310.pyc +0 -0
- molscribe/indigo/bingo.py +334 -0
- molscribe/indigo/inchi.py +84 -0
- molscribe/indigo/renderer.py +113 -0
- molscribe/inference/__init__.py +4 -0
- molscribe/inference/__pycache__/__init__.cpython-310.pyc +0 -0
- molscribe/inference/__pycache__/beam_search.cpython-310.pyc +0 -0
- molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc +0 -0
- molscribe/inference/__pycache__/greedy_search.cpython-310.pyc +0 -0
- molscribe/inference/beam_search.py +190 -0
- molscribe/inference/decode_strategy.py +63 -0
- molscribe/inference/greedy_search.py +128 -0
- molscribe/interface.py +223 -0
- molscribe/loss.py +125 -0
- molscribe/model.py +397 -0
- molscribe/tokenizer.py +524 -0
- molscribe/transformer/__init__.py +3 -0
- molscribe/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- molscribe/transformer/__pycache__/decoder.cpython-310.pyc +0 -0
- molscribe/transformer/__pycache__/embedding.cpython-310.pyc +0 -0
- molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc +0 -0
- molscribe/transformer/decoder.py +487 -0
app.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import json
|
4 |
+
from rxnim import RXNIM
|
5 |
+
from getReaction import generate_combined_image
|
6 |
+
import torch
|
7 |
+
from rxn.reaction import Reaction
|
8 |
+
|
9 |
+
PROMPT_DIR = "prompts/"
|
10 |
+
ckpt_path = "./rxn/model/model.ckpt"
|
11 |
+
model = Reaction(ckpt_path, device=torch.device('cpu'))
|
12 |
+
|
13 |
+
# Map prompt file names to user-friendly display names.
|
14 |
+
PROMPT_NAMES = {
|
15 |
+
"2_RxnOCR.txt": "Reaction Image Parsing Workflow",
|
16 |
+
}
|
17 |
+
example_diagram = "examples/exp.png"
|
18 |
+
|
19 |
+
def list_prompt_files_with_names():
    """
    Map a display name to every ``.txt`` prompt file found in ``PROMPT_DIR``.

    Files listed in ``PROMPT_NAMES`` use their predefined display name; any
    other file falls back to ``"Task: <file name without extension>"``.

    Returns:
        dict: ``{friendly_name: filename}``, inserted in sorted filename
        order so the task list is identical on every run and machine.
    """
    prompt_files = {}
    # os.listdir() yields entries in arbitrary order; sort for a stable UI.
    for filename in sorted(os.listdir(PROMPT_DIR)):
        if not filename.endswith(".txt"):
            continue
        fallback_name = f"Task: {os.path.splitext(filename)[0]}"
        prompt_files[PROMPT_NAMES.get(filename, fallback_name)] = filename
    return prompt_files
|
31 |
+
|
32 |
+
def parse_reactions(output_json):
    """
    Parse the reaction JSON and format each reaction as colour-coded HTML.

    Args:
        output_json: JSON text of an object with an optional ``"reactions"``
            list. Each reaction may carry ``"reaction_id"``, ``"reactants"``,
            ``"conditions"`` and ``"products"``; missing fields fall back to
            ``"Unknown ID"``, ``"Unknown"`` or an empty list.

    Returns:
        list[str]: One HTML fragment per reaction, listing reactants (blue),
        conditions (red), products (orange) and a black
        ``reactants>>products | conditions`` summary line.

    Raises:
        json.JSONDecodeError: If ``output_json`` is not valid JSON.
    """
    import html  # function-scope import: the module header does not import html

    def escape(value):
        # Every value ends up inside markup, so neutralise <, >, & and quotes.
        return html.escape(str(value))

    def colored(color, text):
        return f"<span style='color:{color}'>{text}</span>"

    def condition_label(cond):
        # Prefer the SMILES of a condition, then its free text.
        name = cond.get("smiles", cond.get("text", "Unknown"))
        role = cond.get("role", "Unknown")
        return f"{escape(name)}[{escape(role)}]"

    reactions_list = json.loads(output_json).get("reactions", [])
    detailed_output = []

    for reaction in reactions_list:
        reaction_id = escape(reaction.get("reaction_id", "Unknown ID"))
        reactants = [escape(r.get("smiles", "Unknown")) for r in reaction.get("reactants", [])]
        cond_labels = [condition_label(c) for c in reaction.get("conditions", [])]
        product_smiles = [escape(p.get("smiles", "Unknown")) for p in reaction.get("products", [])]

        # Coloured variants for the itemised lines, black ones for the summary.
        red_conditions = ", ".join(colored("red", c) for c in cond_labels)
        black_conditions = ", ".join(colored("black", c) for c in cond_labels)
        orange_products = ", ".join(colored("orange", p) for p in product_smiles)
        black_products = ".".join(colored("black", p) for p in product_smiles)

        full_reaction = colored(
            "black", f"{'.'.join(reactants)}>>{black_products} | {black_conditions}"
        )

        detailed_output.append(
            f"<b>Reaction: </b> {reaction_id}<br>"
            f" Reactants: {colored('blue', ', '.join(reactants))}<br>"
            f" Conditions: {red_conditions}<br>"
            f" Products: {orange_products}<br>"
            f" <b>Full Reaction:</b> {full_reaction}<br>"
            "<br>"
        )

    return detailed_output
|
68 |
+
|
69 |
+
def process_chem_image(image, selected_task):
    """
    Gradio callback: parse a reaction image and build all four UI outputs.

    Args:
        image: Uploaded image object exposing ``save(path)`` (the interface
            requests ``type="pil"``), or ``None`` when nothing was uploaded.
        selected_task: Friendly task name; must be a key of
            ``prompts_with_names``.

    Returns:
        tuple: ``(html, combined_image_path, example_diagram, json_file_path)``
        where ``html`` is the formatted reaction list, ``combined_image_path``
        the tiled prediction drawing and ``json_file_path`` the saved raw
        result.

    Raises:
        gr.Error: If no image was uploaded or no known task was selected.
    """
    # Fail with a readable message instead of an AttributeError / KeyError
    # further down when the form is submitted incomplete.
    if image is None:
        raise gr.Error("Please upload a reaction image.")
    if selected_task not in prompts_with_names:
        raise gr.Error("Please select a predefined task.")

    chem_mllm = RXNIM()

    # Translate the friendly name back to the actual prompt file.
    prompt_path = os.path.join(PROMPT_DIR, prompts_with_names[selected_task])

    # NOTE(review): the fixed file names below are shared by every request, so
    # concurrent calls would overwrite each other - confirm the queue settings.
    image_path = "temp_image.png"
    image.save(image_path)

    # Run RXNIM on the saved image with the chosen prompt.
    rxnim_result = chem_mllm.process(image_path, prompt_path)

    # Turn the JSON result into structured HTML output.
    detailed_reactions = parse_reactions(rxnim_result)

    # Run the RxnScribe model and tile its drawings into one image.
    predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
    combined_image_path = generate_combined_image(predictions, image_path)

    # Re-serialise the raw result, pretty-printed, for download.
    json_file_path = "output.json"
    with open(json_file_path, "w") as json_file:
        json.dump(json.loads(rxnim_result), json_file, indent=4)

    return "\n\n".join(detailed_reactions), combined_image_path, example_diagram, json_file_path
|
94 |
+
|
95 |
+
|
96 |
+
# Load the prompt files together with their friendly names.
|
97 |
+
prompts_with_names = list_prompt_files_with_names()
|
98 |
+
|
99 |
+
# Example inputs: image path + task option.
examples = [
    ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction2.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction3.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction4.png", "Reaction Image Parsing Workflow"],
]

# Gradio interface definition: components are built first, then wired up.
_image_input = gr.Image(type="pil", label="Upload Reaction Image")
_task_input = gr.Radio(
    choices=list(prompts_with_names),  # friendly task names shown to the user
    label="Select a predefined task",
)
_reaction_html = gr.HTML(label="Reaction outputs")
_visualization = gr.Image(label="Visualization")  # tiled prediction drawing
_schematic = gr.Image(value=example_diagram, label="Schematic Diagram")
_json_download = gr.File(label="Download JSON File")

demo = gr.Interface(
    fn=process_chem_image,
    inputs=[_image_input, _task_input],
    outputs=[_reaction_html, _visualization, _schematic, _json_download],
    title="Towards Large-scale Chemical Reaction Image Parsing via a Multimodal Large Language Model",
    description="Upload a reaction image and select a predefined task prompt.",
    examples=examples,  # nested list: one [image, task] pair per example
    examples_per_page=20,
)

demo.launch()
|
examples/exp.png
ADDED
examples/reaction1.png
ADDED
examples/reaction2.png
ADDED
examples/reaction3.png
ADDED
examples/reaction4.png
ADDED
getReaction.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('./rxn/')
|
3 |
+
import torch
|
4 |
+
from rxn.reaction import Reaction
|
5 |
+
import json
|
6 |
+
from matplotlib import pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
ckpt_path = "./rxn/model/model.ckpt"
|
10 |
+
model = Reaction(ckpt_path, device=torch.device('cpu'))
|
11 |
+
device = torch.device('cpu')
|
12 |
+
|
13 |
+
def get_reaction(image_path: str) -> str:
    """
    Return the reactions predicted for an image, serialised as a JSON string.

    Args:
        image_path: Path of the reaction image to analyse.

    Returns:
        str: ``json.dumps`` of the model's predictions (run with MolScribe
        and OCR enabled).
    """
    predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
    return json.dumps(predictions)
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
def generate_combined_image(predictions, image_file):
    """
    Draw the predicted reactions and tile the drawings into a single PNG.

    The grid uses one column per drawing up to a maximum of three, and as
    many rows as needed. Unused grid cells are removed from the figure.

    Args:
        predictions: Model predictions, passed to ``model.draw_predictions``.
        image_file: Path of the source image the predictions refer to.

    Returns:
        str: ``"combined_output.png"``, the path of the saved figure, or
        ``image_file`` unchanged when there is nothing to draw.
    """
    # Assumes draw_predictions returns a sequence of numpy image arrays
    # (``.shape`` / ``.dtype`` are read below) - TODO confirm.
    drawings = model.draw_predictions(predictions, image_file=image_file)
    n_images = len(drawings)
    if n_images == 0:
        # plt.subplots() rejects a zero-row grid; show the input image instead.
        return image_file

    n_cols = min(n_images, 3)
    n_rows = (n_images + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always yields a 2-D axes array, including the 1x1 case.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    try:
        axes = axes.flatten()

        for idx, img in enumerate(drawings):
            if len(img.shape) == 2:  # grayscale -> RGB
                img = np.stack([img] * 3, axis=-1)
            elif img.shape[2] > 3:  # RGBA -> keep only the RGB channels
                img = img[:, :, :3]
            if img.dtype == np.float32 or img.dtype == np.float64:
                # NOTE(review): presumably float images are in [0, 1] - confirm.
                img = (img * 255).astype(np.uint8)

            ax = axes[idx]
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(f"Reaction {idx + 1}")

        # Drop the grid cells that have no drawing.
        for ax in axes[n_images:]:
            fig.delaxes(ax)

        combined_image_path = "combined_output.png"
        fig.tight_layout()
        fig.savefig(combined_image_path)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    return combined_image_path
|
molscribe/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .interface import MolScribe
|
molscribe/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (190 Bytes). View file
|
|
molscribe/__pycache__/augment.cpython-310.pyc
ADDED
Binary file (8.98 kB). View file
|
|
molscribe/__pycache__/chemistry.cpython-310.pyc
ADDED
Binary file (17.5 kB). View file
|
|
molscribe/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (6.18 kB). View file
|
|
molscribe/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (17.6 kB). View file
|
|
molscribe/__pycache__/evaluate.cpython-310.pyc
ADDED
Binary file (3.48 kB). View file
|
|
molscribe/__pycache__/interface.cpython-310.pyc
ADDED
Binary file (8.95 kB). View file
|
|
molscribe/__pycache__/loss.cpython-310.pyc
ADDED
Binary file (4.25 kB). View file
|
|
molscribe/__pycache__/model.cpython-310.pyc
ADDED
Binary file (13.2 kB). View file
|
|
molscribe/__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (16.8 kB). View file
|
|
molscribe/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (6.33 kB). View file
|
|
molscribe/augment.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import albumentations as A
|
2 |
+
from albumentations.augmentations.geometric.functional import safe_rotate_enlarged_img_size, _maybe_process_in_chunks, \
|
3 |
+
keypoint_rotate
|
4 |
+
import cv2
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
def safe_rotate(
    img: np.ndarray,
    angle: int = 0,
    interpolation: int = cv2.INTER_LINEAR,
    value: int = None,
    border_mode: int = cv2.BORDER_REFLECT_101,
):
    """
    Rotate ``img`` by ``angle`` onto a canvas sized by
    ``safe_rotate_enlarged_img_size``, keeping the image centred.

    Args:
        img: Image array; only its first two dimensions (rows, cols) are read.
        angle: Rotation angle handed to ``cv2.getRotationMatrix2D``.
        interpolation: OpenCV interpolation flag for ``cv2.warpAffine``.
        value: Border value forwarded as ``borderValue``.
        border_mode: OpenCV border mode forwarded as ``borderMode``.

    Returns:
        The warped image of size ``(new_rows, new_cols)``.
    """
    src_rows, src_cols = img.shape[:2]

    # OpenCV wants the centre as (x, y), i.e. (width, height) order.
    center_x, center_y = src_cols / 2, src_rows / 2

    # Size of the enlarged (uncropped) output canvas.
    dst_rows, dst_cols = safe_rotate_enlarged_img_size(angle=angle, rows=src_rows, cols=src_cols)

    matrix = cv2.getRotationMatrix2D((center_x, center_y), angle, 1.0)

    # Translate so the source centre lands on the centre of the larger canvas.
    matrix[0, 2] += dst_cols / 2 - center_x
    matrix[1, 2] += dst_rows / 2 - center_y

    warp = _maybe_process_in_chunks(
        cv2.warpAffine,
        M=matrix,
        dsize=(dst_cols, dst_rows),
        flags=interpolation,
        borderMode=border_mode,
        borderValue=value,
    )
    return warp(img)
|
47 |
+
|
48 |
+
|
49 |
+
def keypoint_safe_rotate(keypoint, angle, rows, cols):
    """
    Apply the ``safe_rotate`` geometry to a single keypoint.

    Args:
        keypoint: Sequence ``(x, y, angle, scale)``; only the first four
            entries are read.
        angle: Rotation angle, forwarded to ``keypoint_rotate``.
        rows: Height of the image before rotation.
        cols: Width of the image before rotation.

    Returns:
        The keypoint returned by ``keypoint_rotate`` on the enlarged canvas.
    """
    # Canvas size after rotation without cropping.
    big_rows, big_cols = safe_rotate_enlarged_img_size(angle=angle, rows=rows, cols=cols)

    # The original image sits centred in the larger canvas: shift by half the growth.
    shift_x = (big_cols - cols) / 2
    shift_y = (big_rows - rows) / 2
    moved = (int(keypoint[0] + shift_x), int(keypoint[1] + shift_y), keypoint[2], keypoint[3])

    return keypoint_rotate(moved, angle, rows=big_rows, cols=big_cols)
|
66 |
+
|
67 |
+
|
68 |
+
class SafeRotate(A.SafeRotate):
    """``A.SafeRotate`` whose image and keypoint transforms use this module's
    ``safe_rotate`` / ``keypoint_safe_rotate`` helpers."""

    def __init__(
            self,
            limit=90,
            interpolation=cv2.INTER_LINEAR,
            border_mode=cv2.BORDER_REFLECT_101,
            value=None,
            mask_value=None,
            always_apply=False,
            p=0.5,
    ):
        # All options are handed straight to the parent transform.
        super().__init__(
            limit=limit,
            interpolation=interpolation,
            border_mode=border_mode,
            value=value,
            mask_value=mask_value,
            always_apply=always_apply,
            p=p,
        )

    def apply(self, img, angle=0, interpolation=cv2.INTER_LINEAR, **params):
        """Rotate the image using the transform's own border settings."""
        return safe_rotate(
            img=img,
            angle=angle,
            interpolation=interpolation,
            value=self.value,
            border_mode=self.border_mode,
        )

    def apply_to_keypoint(self, keypoint, angle=0, **params):
        """Rotate one keypoint; ``params`` must provide ``rows`` and ``cols``."""
        rows, cols = params["rows"], params["cols"]
        return keypoint_safe_rotate(keypoint, angle=angle, rows=rows, cols=cols)
|
95 |
+
|
96 |
+
|
97 |
+
class CropWhite(A.DualTransform):
    """Crop away the border rows/columns whose pixels all equal ``value``,
    then pad the result by ``pad`` pixels of ``value`` on every side."""

    def __init__(self, value=(255, 255, 255), pad=0, p=1.0):
        super().__init__(p=p)
        self.value = value
        self.pad = pad
        assert pad >= 0

    def update_params(self, params, **kwargs):
        """Compute the crop margins from the image and store them in ``params``.

        No crop keys are added when every pixel equals ``value``.
        """
        super().update_params(params, **kwargs)
        assert "image" in kwargs
        image = kwargs["image"]
        height, width, _ = image.shape

        # Per-pixel count of channels that differ from the background value.
        foreground = (image != self.value).sum(axis=2)
        if foreground.sum() == 0:
            return params

        # Indices of rows / columns holding at least one foreground pixel.
        used_rows = np.flatnonzero(foreground.sum(axis=1))
        used_cols = np.flatnonzero(foreground.sum(axis=0))
        top, bottom = int(used_rows[0]), int(used_rows[-1]) + 1
        left, right = int(used_cols[0]), int(used_cols[-1]) + 1

        # The margins are the exact background border; ``pad`` is added in apply().
        params.update({
            "crop_top": top,
            "crop_bottom": height - bottom,
            "crop_left": left,
            "crop_right": width - right,
        })
        return params

    def apply(self, img, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params):
        """Crop the computed margins, then pad with ``self.pad`` pixels of ``self.value``."""
        height, width, _ = img.shape
        cropped = img[crop_top:height - crop_bottom, crop_left:width - crop_right]
        return A.augmentations.pad_with_params(
            cropped, self.pad, self.pad, self.pad, self.pad,
            border_mode=cv2.BORDER_CONSTANT, value=self.value)

    def apply_to_keypoint(self, keypoint, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params):
        """Shift a keypoint by the top-left crop and the added padding."""
        x, y, angle, scale = keypoint[:4]
        new_x = x - crop_left + self.pad
        new_y = y - crop_top + self.pad
        return new_x, new_y, angle, scale

    def get_transform_init_args_names(self):
        return ('value', 'pad')
|
150 |
+
|
151 |
+
|
152 |
+
class PadWhite(A.DualTransform):
    """Pad one randomly chosen side of the image with ``value`` by a random
    amount of up to ``pad_ratio`` times the image size along that axis."""

    def __init__(self, pad_ratio=0.2, p=0.5, value=(255, 255, 255)):
        super().__init__(p=p)
        self.pad_ratio = pad_ratio
        self.value = value

    def update_params(self, params, **kwargs):
        """Pick the side and the pad amount and store it in ``params``."""
        super().update_params(params, **kwargs)
        assert "image" in kwargs
        height, width, _ = kwargs["image"].shape

        # Index 0..3 selects the side; vertical sides scale with the height,
        # horizontal sides with the width.
        sides = (
            ("pad_top", height),
            ("pad_bottom", height),
            ("pad_left", width),
            ("pad_right", width),
        )
        key, extent = sides[random.randrange(4)]
        params[key] = int(extent * self.pad_ratio * random.random())
        return params

    def apply(self, img, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params):
        """Pad the image with ``self.value`` by the given amounts."""
        # The unpack enforces a 3-D (H, W, C) array; the values are not used.
        height, width, _ = img.shape
        return A.augmentations.pad_with_params(
            img, pad_top, pad_bottom, pad_left, pad_right,
            border_mode=cv2.BORDER_CONSTANT, value=self.value)

    def apply_to_keypoint(self, keypoint, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params):
        """Shift a keypoint by the padding added on the left and on top."""
        x, y, angle, scale = keypoint[:4]
        shifted_x = x + pad_left
        shifted_y = y + pad_top
        return shifted_x, shifted_y, angle, scale

    def get_transform_init_args_names(self):
        return ('value', 'pad_ratio')
|
187 |
+
|
188 |
+
|
189 |
+
class SaltAndPepperNoise(A.DualTransform):
    """Overwrite a random number of pixels (0..``num_dots``) with ``value``."""

    def __init__(self, num_dots, value=(0, 0, 0), p=0.5):
        # Pass ``p`` by keyword, like the other transforms in this module.
        # NOTE(review): in albumentations releases whose base ``__init__``
        # takes ``always_apply`` as first positional parameter, a positional
        # ``p`` is read as ``always_apply`` and the transform fires every
        # time - confirm against the pinned version.
        super().__init__(p=p)
        self.num_dots = num_dots
        self.value = value

    def apply(self, img, **params):
        """Set randomly chosen pixels of ``img`` to ``self.value``.

        The image is modified in place and also returned.
        """
        height, width, _ = img.shape
        num_dots = random.randrange(self.num_dots + 1)
        for _ in range(num_dots):
            row = random.randrange(height)
            col = random.randrange(width)
            img[row, col] = self.value
        return img

    def apply_to_keypoint(self, keypoint, **params):
        """Noise does not move keypoints; return the keypoint unchanged."""
        return keypoint

    def get_transform_init_args_names(self):
        return ('value', 'num_dots')
|
210 |
+
|
211 |
+
class ResizePad(A.DualTransform):
    """Shrink the image to at most ``height`` x ``width`` and pad it with
    ``value`` to exactly that size, keeping the content centred."""

    def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, value=(255, 255, 255)):
        super().__init__(always_apply=True)
        self.height = height
        self.width = width
        self.interpolation = interpolation
        self.value = value

    def apply(self, img, interpolation=cv2.INTER_LINEAR, **params):
        """Resize each over-long axis down to the target, then pad to the target size."""
        src_h, src_w, _ = img.shape
        # min() leaves an axis untouched when it already fits the target.
        resized = A.augmentations.geometric.functional.resize(
            img,
            height=min(src_h, self.height),
            width=min(src_w, self.width),
            interpolation=interpolation
        )
        new_h, new_w, _ = resized.shape

        # Split the missing rows/columns evenly; the odd pixel goes to the
        # bottom/right side.
        missing_h = self.height - new_h
        missing_w = self.width - new_w
        pad_top = missing_h // 2
        pad_left = missing_w // 2
        return A.augmentations.pad_with_params(
            resized,
            pad_top,
            missing_h - pad_top,
            pad_left,
            missing_w - pad_left,
            border_mode=cv2.BORDER_CONSTANT,
            value=self.value,
        )
|
243 |
+
|
244 |
+
|
245 |
+
def normalized_grid_distortion(
        img,
        num_steps=10,
        xsteps=(),
        ysteps=(),
        *args,
        **kwargs
):
    """
    Grid distortion with step sizes rescaled before the actual warp.

    The last step of each axis is first scaled by the ratio between the last
    grid cell and a regular cell, then all steps are normalised by their sum.

    Args:
        img: Image array; only its first two dimensions are read.
        num_steps: Number of grid steps per axis.
        xsteps: Step multipliers along the width. Must be non-empty.
        ysteps: Step multipliers along the height. Must be non-empty.
        *args, **kwargs: Forwarded unchanged to albumentations'
            ``grid_distortion``.

    Returns:
        The distorted image returned by ``grid_distortion``.
    """
    height, width = img.shape[:2]

    # Work on private float copies: scaling the last step in place would
    # modify the caller's sequence, so a second call sharing the same step
    # list (image, then mask) would apply the correction twice.
    xsteps = np.array(xsteps, dtype=float)
    ysteps = np.array(ysteps, dtype=float)

    # Compensate for smaller last steps in the source image.
    # NOTE(review): x_step / y_step are 0 when the image is smaller than
    # num_steps, which makes the divisions below fail - confirm callers.
    x_step = width // num_steps
    last_x_step = min(width, ((num_steps + 1) * x_step)) - (num_steps * x_step)
    xsteps[-1] *= last_x_step / x_step

    y_step = height // num_steps
    last_y_step = min(height, ((num_steps + 1) * y_step)) - (num_steps * y_step)
    ysteps[-1] *= last_y_step / y_step

    # Normalise so the distortion never leaves the image bounds.
    tx = width / math.floor(width / num_steps)
    ty = height / math.floor(height / num_steps)
    xsteps = xsteps * (tx / np.sum(xsteps))
    ysteps = ysteps * (ty / np.sum(ysteps))

    # Do the actual distortion.
    return A.augmentations.functional.grid_distortion(img, num_steps, xsteps, ysteps, *args, **kwargs)
|
272 |
+
|
273 |
+
|
274 |
+
class NormalizedGridDistortion(A.augmentations.transforms.GridDistortion):
    """``GridDistortion`` that warps images and masks through
    ``normalized_grid_distortion``."""

    def apply(self, img, stepsx=(), stepsy=(), interpolation=cv2.INTER_LINEAR, **params):
        """Distort an image with the requested interpolation and the image border value."""
        warp_args = (interpolation, self.border_mode, self.value)
        return normalized_grid_distortion(img, self.num_steps, stepsx, stepsy, *warp_args)

    def apply_to_mask(self, img, stepsx=(), stepsy=(), **params):
        """Distort a mask with nearest-neighbour interpolation and the mask border value."""
        warp_args = (cv2.INTER_NEAREST, self.border_mode, self.mask_value)
        return normalized_grid_distortion(img, self.num_steps, stepsx, stepsy, *warp_args)
|
282 |
+
|
molscribe/chemistry.py
ADDED
@@ -0,0 +1,649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import traceback
|
3 |
+
import numpy as np
|
4 |
+
import multiprocessing
|
5 |
+
|
6 |
+
import rdkit
|
7 |
+
import rdkit.Chem as Chem
|
8 |
+
|
9 |
+
rdkit.RDLogger.DisableLog('rdApp.*')
|
10 |
+
|
11 |
+
from SmilesPE.pretokenizer import atomwise_tokenizer
|
12 |
+
|
13 |
+
from .constants import RGROUP_SYMBOLS, ABBREVIATIONS, VALENCES, FORMULA_REGEX
|
14 |
+
|
15 |
+
|
16 |
+
def is_valid_mol(s, format_='atomtok'):
    """
    Tell whether RDKit can parse ``s`` as a molecule.

    Args:
        s: SMILES string (``format_='atomtok'``) or InChI string
            (``format_='inchi'``). An InChI without the ``InChI=1S`` prefix
            gets ``InChI=1S/`` prepended.
        format_: ``'atomtok'`` or ``'inchi'``.

    Returns:
        bool: ``True`` when parsing yields a molecule, ``False`` otherwise.

    Raises:
        NotImplementedError: If ``format_`` is not a supported format.
    """
    if format_ == 'atomtok':
        mol = Chem.MolFromSmiles(s)
    elif format_ == 'inchi':
        if not s.startswith('InChI=1S'):
            s = f"InChI=1S/{s}"
        mol = Chem.MolFromInchi(s)
    else:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # raising it produces a confusing TypeError.
        raise NotImplementedError(f"Unsupported molecule format: {format_!r}")
    return mol is not None
|
26 |
+
|
27 |
+
|
28 |
+
def _convert_smiles_to_inchi(smiles):
    """
    Convert one SMILES string to InChI, best effort.

    Returns:
        The InChI produced by RDKit, or ``None`` when parsing or conversion
        raises.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        inchi = Chem.MolToInchi(mol)
    except Exception:
        # Deliberate best effort, but let KeyboardInterrupt / SystemExit
        # through so worker processes can still be stopped.
        inchi = None
    return inchi
|
35 |
+
|
36 |
+
|
37 |
+
def convert_smiles_to_inchi(smiles_list, num_workers=16):
    """
    Convert SMILES strings to InChI in parallel.

    Args:
        smiles_list: Iterable of SMILES strings.
        num_workers: Size of the multiprocessing pool.

    Returns:
        tuple: ``(inchi_list, r_success)``. Failed conversions are replaced
        by the water InChI ``'InChI=1S/H2O/h1H2'``; ``r_success`` is the
        fraction of entries that were not ``None``, or ``0.0`` for empty
        input.
    """
    # Nothing to convert: skip the pool and avoid dividing by zero below.
    if hasattr(smiles_list, '__len__') and len(smiles_list) == 0:
        return [], 0.0

    with multiprocessing.Pool(num_workers) as p:
        inchi_list = p.map(_convert_smiles_to_inchi, smiles_list, chunksize=128)
    n_success = sum(x is not None for x in inchi_list)
    # Unsized empty iterables (e.g. generators) only show up as empty here.
    r_success = n_success / len(inchi_list) if inchi_list else 0.0
    inchi_list = [x if x else 'InChI=1S/H2O/h1H2' for x in inchi_list]
    return inchi_list, r_success
|
44 |
+
|
45 |
+
|
46 |
+
def merge_inchi(inchi1, inchi2):
    """
    Fill the placeholder entries of ``inchi1`` with the values of ``inchi2``.

    An entry equal to the water InChI ``'InChI=1S/H2O/h1H2'`` is treated as a
    placeholder and replaced by the entry at the same index of ``inchi2``.
    ``inchi1`` itself is left untouched; a deep copy is modified.

    Returns:
        tuple: ``(merged, replaced)`` - the merged copy and the number of
        entries that were replaced.
    """
    merged = copy.deepcopy(inchi1)
    n_replaced = 0
    for idx in range(len(merged)):
        if merged[idx] != 'InChI=1S/H2O/h1H2':
            continue
        merged[idx] = inchi2[idx]
        n_replaced += 1
    return merged, n_replaced
|
54 |
+
|
55 |
+
|
56 |
+
def _get_num_atoms(smiles):
    """
    Return the atom count of the molecule parsed from ``smiles``.

    Returns:
        int: ``GetNumAtoms()`` of the parsed molecule, or ``0`` when parsing
        fails or raises.
    """
    try:
        return Chem.MolFromSmiles(smiles).GetNumAtoms()
    except Exception:
        # Best effort (an unparsable SMILES gives None -> AttributeError),
        # without swallowing KeyboardInterrupt / SystemExit.
        return 0
|
61 |
+
|
62 |
+
|
63 |
+
def get_num_atoms(smiles, num_workers=16):
    """
    Count atoms for one SMILES string or for a collection of them.

    Args:
        smiles: A single SMILES string, or an iterable of SMILES strings.
        num_workers: Pool size used for the iterable case.

    Returns:
        int for a single string, otherwise a list of ints in input order.
    """
    # isinstance also covers str subclasses, which would otherwise be
    # iterated character by character in the pool.
    if isinstance(smiles, str):
        return _get_num_atoms(smiles)
    with multiprocessing.Pool(num_workers) as p:
        num_atoms = p.map(_get_num_atoms, smiles)
    return num_atoms
|
69 |
+
|
70 |
+
|
71 |
+
def normalize_nodes(nodes, flip_y=True):
    """
    Rescale 2-D node coordinates to the unit square.

    Args:
        nodes: Array-like of shape ``(n, 2)`` holding ``(x, y)`` per row.
        flip_y: When true the y axis is inverted, so the largest input y
            maps to 0 and the smallest to 1.

    Returns:
        np.ndarray: Array of shape ``(n, 2)`` with both columns scaled into
        ``[0, 1]``. A degenerate axis (all values equal) maps to 0. Empty
        input yields an empty ``(0, 2)`` array.
    """
    nodes = np.asarray(nodes)
    if nodes.size == 0:
        # min()/max() are undefined without data; return an empty result.
        return np.empty((0, 2), dtype=float)

    x, y = nodes[:, 0], nodes[:, 1]
    minx, maxx = x.min(), x.max()
    miny, maxy = y.min(), y.max()

    # The 1e-6 floor avoids dividing by zero when an axis has no extent.
    x = (x - minx) / max(maxx - minx, 1e-6)
    if flip_y:
        y = (maxy - y) / max(maxy - miny, 1e-6)
    else:
        y = (y - miny) / max(maxy - miny, 1e-6)
    return np.stack([x, y], axis=1)
|
81 |
+
|
82 |
+
|
83 |
+
def _verify_chirality(mol, coords, symbols, edges, debug=False):
    """
    Assign stereochemistry to ``mol`` from 2-D coordinates and wedge/dash edges.

    Args:
        mol: Editable RDKit molecule (``GetMol``, ``AddBond`` and
            ``RemoveBond`` are called on it); it is modified in place.
        coords: Iterable of ``(x, y)`` pairs, one per atom in atom order.
            y is stored as ``1 - y`` - presumably coordinates are normalised
            image coordinates with a downward y axis; confirm with caller.
        symbols: Not used by this function.
        edges: Indexable as ``edges[i][j]``; value 5 is turned into a
            ``BEGINWEDGE`` bond from i to j and value 6 into ``BEGINDASH``.
        debug: When true, any exception is re-raised instead of swallowed.

    Returns:
        The molecule from ``mol.GetMol()`` on success. If a step raises and
        ``debug`` is false, ``mol`` is returned in whatever state it had
        reached.
    """
    try:
        n = mol.GetNumAtoms()
        # Make a temp mol to find chiral centers
        mol_tmp = mol.GetMol()
        Chem.SanitizeMol(mol_tmp)

        chiral_centers = Chem.FindMolChiralCenters(
            mol_tmp, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False)
        chiral_center_ids = [idx for idx, _ in chiral_centers]  # List[Tuple[int, any]] -> List[int]

        # correction to clear pre-condition violation (for some corner cases)
        for bond in mol.GetBonds():
            if bond.GetBondType() == Chem.BondType.SINGLE:
                bond.SetBondDir(Chem.BondDir.NONE)

        # Create conformer from 2D coordinate
        conf = Chem.Conformer(n)
        conf.Set3D(True)
        for i, (x, y) in enumerate(coords):
            conf.SetAtomPosition(i, (x, 1 - y, 0))
        mol.AddConformer(conf)
        Chem.SanitizeMol(mol)
        Chem.AssignStereochemistryFrom3D(mol)
        # NOTE: seems that only AssignStereochemistryFrom3D can handle double bond E/Z
        # So we do this first, remove the conformer and add back the 2D conformer for chiral correction

        mol.RemoveAllConformers()
        conf = Chem.Conformer(n)
        conf.Set3D(False)
        for i, (x, y) in enumerate(coords):
            conf.SetAtomPosition(i, (x, 1 - y, 0))
        mol.AddConformer(conf)

        # Magic, inferring chirality from coordinates and BondDir. DO NOT CHANGE.
        Chem.SanitizeMol(mol)
        Chem.AssignChiralTypesFromBondDirs(mol)
        Chem.AssignStereochemistry(mol, force=True)

        # Second loop to reset any wedge/dash bond to be starting from the chiral center)
        for i in chiral_center_ids:
            for j in range(n):
                if edges[i][j] == 5:
                    # assert edges[j][i] == 6
                    # Re-create the bond so it begins at the chiral centre i.
                    mol.RemoveBond(i, j)
                    mol.AddBond(i, j, Chem.BondType.SINGLE)
                    mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINWEDGE)
                elif edges[i][j] == 6:
                    # assert edges[j][i] == 5
                    mol.RemoveBond(i, j)
                    mol.AddBond(i, j, Chem.BondType.SINGLE)
                    mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINDASH)
            # NOTE(review): the nesting of the next two calls inside the
            # per-centre loop is reconstructed - confirm against upstream.
            Chem.AssignChiralTypesFromBondDirs(mol)
            Chem.AssignStereochemistry(mol, force=True)

        # reset chiral tags for non-carbon atom
        for atom in mol.GetAtoms():
            if atom.GetSymbol() != "C":
                atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
        mol = mol.GetMol()

    except Exception as e:
        # Best effort: without debug, fall through and return mol as it is.
        if debug:
            raise e
        pass
    return mol
|
149 |
+
|
150 |
+
|
151 |
+
def _parse_tokens(tokens: list):
|
152 |
+
"""
|
153 |
+
Parse tokens of condensed formula into list of pairs `(elt, num)`
|
154 |
+
where `num` is the multiplicity of the atom (or nested condensed formula) `elt`
|
155 |
+
Used by `_parse_formula`, which does the same thing but takes a formula in string form as input
|
156 |
+
"""
|
157 |
+
elements = []
|
158 |
+
i = 0
|
159 |
+
j = 0
|
160 |
+
while i < len(tokens):
|
161 |
+
if tokens[i] == '(':
|
162 |
+
while j < len(tokens) and tokens[j] != ')':
|
163 |
+
j += 1
|
164 |
+
elt = _parse_tokens(tokens[i + 1:j])
|
165 |
+
else:
|
166 |
+
elt = tokens[i]
|
167 |
+
j += 1
|
168 |
+
if j < len(tokens) and tokens[j].isnumeric():
|
169 |
+
num = int(tokens[j])
|
170 |
+
j += 1
|
171 |
+
else:
|
172 |
+
num = 1
|
173 |
+
elements.append((elt, num))
|
174 |
+
i = j
|
175 |
+
return elements
|
176 |
+
|
177 |
+
|
178 |
+
def _parse_formula(formula: str):
    """
    Split a condensed formula string into tokens with `FORMULA_REGEX` and parse
    them into a list of pairs `(elt, num)`, where `num` is the subscript of the
    atom (or nested condensed formula) `elt`.
    Example: "C2H4O" -> [('C', 2), ('H', 4), ('O', 1)]
    """
    return _parse_tokens(FORMULA_REGEX.findall(formula))
|
188 |
+
|
189 |
+
|
190 |
+
def _expand_carbon(elements: list):
|
191 |
+
"""
|
192 |
+
Given list of pairs `(elt, num)`, output single list of all atoms in order,
|
193 |
+
expanding carbon sequences (CaXb where a > 1 and X is halogen) if necessary
|
194 |
+
Example: [('C', 2), ('H', 4), ('O', 1)] -> ['C', 'H', 'H', 'C', 'H', 'H', 'O'])
|
195 |
+
"""
|
196 |
+
expanded = []
|
197 |
+
i = 0
|
198 |
+
while i < len(elements):
|
199 |
+
elt, num = elements[i]
|
200 |
+
# expand carbon sequence
|
201 |
+
if elt == 'C' and num > 1 and i + 1 < len(elements):
|
202 |
+
next_elt, next_num = elements[i + 1]
|
203 |
+
quotient, remainder = next_num // num, next_num % num
|
204 |
+
for _ in range(num):
|
205 |
+
expanded.append('C')
|
206 |
+
for _ in range(quotient):
|
207 |
+
expanded.append(next_elt)
|
208 |
+
for _ in range(remainder):
|
209 |
+
expanded.append(next_elt)
|
210 |
+
i += 2
|
211 |
+
# recurse if `elt` itself is a list (nested formula)
|
212 |
+
elif isinstance(elt, list):
|
213 |
+
new_elt = _expand_carbon(elt)
|
214 |
+
for _ in range(num):
|
215 |
+
expanded.append(new_elt)
|
216 |
+
i += 1
|
217 |
+
# simplest case: simply append `elt` `num` times
|
218 |
+
else:
|
219 |
+
for _ in range(num):
|
220 |
+
expanded.append(elt)
|
221 |
+
i += 1
|
222 |
+
return expanded
|
223 |
+
|
224 |
+
|
225 |
+
def _expand_abbreviation(abbrev):
    """
    Return the SMILES fragment for a single formula token.

    Known abbreviations map to their stored SMILES; R-group symbols map to a
    wildcard (`[n*]` when the symbol carries a numeric suffix, plain `*` otherwise);
    anything else is wrapped in brackets as an atom.
    Used in `_condensed_formula_list_to_smiles` when encountering abbrev. in condensed formula
    """
    if abbrev in ABBREVIATIONS:
        return ABBREVIATIONS[abbrev].smiles
    if abbrev not in RGROUP_SYMBOLS and not (abbrev[0] == 'R' and abbrev[1:].isdigit()):
        return f'[{abbrev}]'
    suffix = abbrev[1:]
    return f'[{suffix}*]' if suffix.isdigit() else '*'
|
237 |
+
|
238 |
+
|
239 |
+
def _get_bond_symb(bond_num):
|
240 |
+
"""
|
241 |
+
Get SMILES symbol for a bond given bond order
|
242 |
+
Used in `_condensed_formula_list_to_smiles` while writing the SMILES string
|
243 |
+
"""
|
244 |
+
if bond_num == 0:
|
245 |
+
return '.'
|
246 |
+
if bond_num == 1:
|
247 |
+
return ''
|
248 |
+
if bond_num == 2:
|
249 |
+
return '='
|
250 |
+
if bond_num == 3:
|
251 |
+
return '#'
|
252 |
+
return ''
|
253 |
+
|
254 |
+
|
255 |
+
def _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond=None, direction=None):
    """
    Converts condensed formula (in the form of a list of symbols) to smiles
    by a depth-first search over the possible valences of each atom.
    Input:
    `formula_list`: e.g. ['C', 'H', 'H', 'N', ['C', 'H', 'H', 'H'], ['C', 'H', 'H', 'H']] for CH2N(CH3)2
    `start_bond`: # bonds attached to beginning of formula
    `end_bond`: # bonds attached to end of formula (deduce automatically if None)
    `direction` (1, -1, or None): direction in which to process the list (1: left to right; -1: right to left; None: deduce automatically)
    Returns:
    `smiles`: smiles corresponding to input condensed formula (None if `direction` is None and both directions fail)
    `bonds_left`: bonds remaining at the end of the formula (for connecting back to main molecule); should equal `end_bond` if specified
    `num_trials`: number of trials
    `success` (bool): whether conversion was successful
    """
    # `direction` not specified: try left to right; if fails, try right to left
    if direction is None:
        num_trials = 1
        for dir_choice in [1, -1]:
            smiles, bonds_left, trials, success = _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond, dir_choice)
            num_trials += trials
            if success:
                return smiles, bonds_left, num_trials, success
        return None, None, num_trials, False
    assert direction == 1 or direction == -1

    def dfs(smiles, bonds_left, cur_idx, add_idx):
        """
        `smiles`: SMILES string so far
        `bonds_left`: bonds remaining on current atom for subsequent atoms to be attached to
        `cur_idx`: index (in `formula_list`) of current atom (i.e. atom to which subsequent atoms are being attached)
        `add_idx`: index (in `formula_list`) of atom to be attached to current atom
        Returns the same 4-tuple as the enclosing function.
        Note: "atom" could refer to nested condensed formula (e.g. CH3 in CH2N(CH3)2)
        """
        num_trials = 1
        # end of formula: return result
        if (direction == 1 and add_idx == len(formula_list)) or (direction == -1 and add_idx == -1):
            if end_bond is not None and end_bond != bonds_left:
                return smiles, bonds_left, num_trials, False
            return smiles, bonds_left, num_trials, True

        # no more bonds but there are atoms remaining: conversion failed
        if bonds_left <= 0:
            return smiles, bonds_left, num_trials, False
        to_add = formula_list[add_idx]  # atom to be added to current atom

        if isinstance(to_add, list):  # "atom" added is a list (i.e. nested condensed formula): assume valence of 1
            if bonds_left > 1:
                # "atom" added does not use up remaining bonds of current atom
                # get smiles of "atom" (which is itself a condensed formula)
                add_str, val, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                # bonds left over on the nested group are folded into the order of the connecting bond
                if val > 0:
                    add_str = _get_bond_symb(val + 1) + add_str
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                # put smiles of "atom" in parentheses and append to smiles; go to next atom to add to current atom
                result = dfs(smiles + f'({add_str})', bonds_left - 1, cur_idx, add_idx + direction)
            else:
                # "atom" added uses up remaining bonds of current atom
                # get smiles of "atom" and bonds left on it
                add_str, bonds_left, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                # append smiles of "atom" (without parentheses) to smiles; it becomes new current atom
                result = dfs(smiles + add_str, bonds_left, add_idx, add_idx + direction)
            smiles, bonds_left, trials, success = result
            num_trials += trials
            return smiles, bonds_left, num_trials, success

        # atom added is a single symbol (as opposed to nested condensed formula)
        for val in VALENCES.get(to_add, [1]):  # try all possible valences of atom added
            add_str = _expand_abbreviation(to_add)  # expand to smiles if symbol is abbreviation
            if bonds_left > val:  # atom added does not use up remaining bonds of current atom; go to next atom to add to current atom
                if cur_idx >= 0:
                    add_str = _get_bond_symb(val) + add_str
                result = dfs(smiles + f'({add_str})', bonds_left - val, cur_idx, add_idx + direction)
            else:  # atom added uses up remaining bonds of current atom; it becomes new current atom
                if cur_idx >= 0:
                    add_str = _get_bond_symb(bonds_left) + add_str
                result = dfs(smiles + add_str, val - bonds_left, add_idx, add_idx + direction)
            trials, success = result[2:]
            num_trials += trials
            if success:
                return result[0], result[1], num_trials, success
            # give up on further valences once this subtree has used more than 10000 trials
            if num_trials > 10000:
                break
        return smiles, bonds_left, num_trials, False

    # NOTE(review): for direction == 1 the initial cur_idx is -1, so the `cur_idx >= 0` checks skip the
    # bond symbol on the first atom; for direction == -1 the initial cur_idx is len(formula_list), which
    # is >= 0, so the first atom does get a bond symbol prefix. Confirm whether this asymmetry is intended.
    cur_idx = -1 if direction == 1 else len(formula_list)
    add_idx = 0 if direction == 1 else len(formula_list) - 1
    return dfs('', start_bond, cur_idx, add_idx)
|
349 |
+
|
350 |
+
|
351 |
+
def get_smiles_from_symbol(symbol, mol, atom, bonds):
    """
    Convert symbol (abbrev. or condensed formula) to smiles.

    Resolution order:
    1. a known abbreviation returns its stored SMILES;
    2. symbols longer than 20 characters are rejected (returns None);
    3. the symbol is parsed as a condensed formula, using the summed bond orders of
       `bonds` as the number of bonds at the start of the formula; the result is
       returned if RDKit can parse it;
    4. otherwise the symbol itself is returned if RDKit can parse it as SMILES.

    `mol` and `atom` are accepted for interface compatibility but are not used here.
    Returns the SMILES string, or None if the symbol cannot be converted.
    """
    if symbol in ABBREVIATIONS:
        return ABBREVIATIONS[symbol].smiles
    if len(symbol) > 20:
        return None

    total_bonds = int(sum(bond.GetBondTypeAsDouble() for bond in bonds))
    formula_list = _expand_carbon(_parse_formula(symbol))
    smiles, _, _, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None)
    if success:
        mol_check = Chem.MolFromSmiles(smiles)  # Check if the SMILES is valid
        if mol_check:
            return smiles

    # Fall back to treating the symbol itself as a SMILES string.
    mol_check = Chem.MolFromSmiles(symbol)
    if mol_check:
        return symbol

    return None
|
383 |
+
|
384 |
+
|
385 |
+
def _replace_functional_group(smiles):
    """
    Rewrite a predicted SMILES so that RDKit can parse it.

    '<unk>' becomes 'C'; bracketed R-group symbols become wildcards ('[n*]' for
    numbered 'Rn', '*' otherwise). Every remaining bracket token that is a known
    abbreviation, or that RDKit cannot parse as an atom, is replaced with a
    placeholder '[{isotope}*]' whose isotope number (starting at 50) is unused so far.

    Returns `(smiles, mappings)` where `mappings` maps isotope -> original symbol.
    """
    smiles = smiles.replace('<unk>', 'C')
    for rgroup in RGROUP_SYMBOLS:
        bracketed = f'[{rgroup}]'
        if bracketed not in smiles:
            continue
        if rgroup[0] == 'R' and rgroup[1:].isdigit():
            replacement = f'[{int(rgroup[1:])}*]'
        else:
            replacement = '*'
        smiles = smiles.replace(bracketed, replacement)

    # For unknown tokens (i.e. rdkit cannot parse), replace them with [{isotope}*], where isotope is an identifier.
    rewritten = []
    mappings = {}  # isotope : symbol
    isotope = 50
    for token in atomwise_tokenizer(smiles):
        needs_placeholder = token[0] == '[' and (
            token[1:-1] in ABBREVIATIONS or Chem.AtomFromSmiles(token) is None)
        if not needs_placeholder:
            rewritten.append(token)
            continue
        # skip isotope numbers already present in the input or already handed out
        while f'[{isotope}*]' in smiles or f'[{isotope}*]' in rewritten:
            isotope += 1
        mappings[isotope] = token[1:-1]
        rewritten.append(f'[{isotope}*]')
    return ''.join(rewritten), mappings
|
411 |
+
|
412 |
+
|
413 |
+
def convert_smiles_to_mol(smiles):
    """
    Parse `smiles` with RDKit and return the resulting molecule.

    Returns None when `smiles` is None or empty, or when parsing raises an exception;
    otherwise returns whatever `Chem.MolFromSmiles` returns.
    """
    if smiles is None or smiles == '':
        return None
    try:
        return Chem.MolFromSmiles(smiles)
    except Exception:
        # Best-effort parse; Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        return None
|
421 |
+
|
422 |
+
|
423 |
+
# Maps an integer bond order (1, 2, 3) to the corresponding RDKit bond type;
# used by `_expand_functional_group` when re-attaching an expanded substituent.
BOND_TYPES = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE, 3: Chem.rdchem.BondType.TRIPLE}
|
424 |
+
|
425 |
+
|
426 |
+
def _expand_functional_group(mol, mappings, debug=False):
    """
    Replace wildcard ('*') atoms that stand for abbreviations or condensed formulas
    with the atoms of the corresponding substructure.

    `mol`: RDKit molecule whose '*' atoms may carry an atom alias or a placeholder isotope
    `mappings`: isotope -> symbol, as produced by `_replace_functional_group`
    `debug`: accepted but not used in this function

    Returns `(smiles, mol)`: the SMILES of the (possibly expanded) molecule and the molecule itself.
    """
    def _need_expand(mol, mappings):
        # expansion is needed if any atom has an alias or any placeholder mapping exists
        return any([len(Chem.GetAtomAlias(atom)) > 0 for atom in mol.GetAtoms()]) or len(mappings) > 0

    if _need_expand(mol, mappings):
        mol_w = Chem.RWMol(mol)
        num_atoms = mol_w.GetNumAtoms()
        for i, atom in enumerate(mol_w.GetAtoms()):  # reset radical electrons
            atom.SetNumRadicalElectrons(0)

        atoms_to_remove = []
        # only the original atoms are visited; atoms appended by expansion have indices >= num_atoms
        for i in range(num_atoms):
            atom = mol_w.GetAtomWithIdx(i)
            if atom.GetSymbol() == '*':
                symbol = Chem.GetAtomAlias(atom)
                isotope = atom.GetIsotope()
                # a placeholder isotope takes precedence over the atom alias
                if isotope > 0 and isotope in mappings:
                    symbol = mappings[isotope]
                if not (isinstance(symbol, str) and len(symbol) > 0):
                    continue
                # rgroups do not need to be expanded
                if symbol in RGROUP_SYMBOLS:
                    continue

                bonds = atom.GetBonds()
                sub_smiles = get_smiles_from_symbol(symbol, mol_w, atom, bonds)

                # create mol object for abbreviation/condensed formula from its SMILES
                mol_r = convert_smiles_to_mol(sub_smiles)

                if mol_r is None:
                    # could not expand: keep the wildcard atom but clear its placeholder isotope
                    atom.SetIsotope(0)
                    continue

                # remove bonds connected to abbreviation/condensed formula
                adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds]
                for adjacent_idx in adjacent_indices:
                    mol_w.RemoveBond(i, adjacent_idx)

                # temporarily store each removed bond's order as radical electrons on the neighbour
                adjacent_atoms = [mol_w.GetAtomWithIdx(adjacent_idx) for adjacent_idx in adjacent_indices]
                for adjacent_atom, bond in zip(adjacent_atoms, bonds):
                    adjacent_atom.SetNumRadicalElectrons(int(bond.GetBondTypeAsDouble()))

                # get indices of atoms of main body that connect to substituent
                bonding_atoms_w = adjacent_indices
                # assume indices are concatenated after combining mol_w and mol_r
                bonding_atoms_r = [mol_w.GetNumAtoms()]
                for atm in mol_r.GetAtoms():
                    if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0:
                        bonding_atoms_r.append(mol_w.GetNumAtoms() + atm.GetIdx())

                # combine main body and substituent into a single molecule object
                combo = Chem.CombineMols(mol_w, mol_r)

                # connect substituent to main body with bonds
                mol_w = Chem.RWMol(combo)
                # every neighbour is bonded to the first atom of the substituent (bonding_atoms_r[0])
                for atm in bonding_atoms_w:
                    bond_order = mol_w.GetAtomWithIdx(atm).GetNumRadicalElectrons()
                    mol_w.AddBond(atm, bonding_atoms_r[0], order=BOND_TYPES[bond_order])

                # reset radical electrons
                for atm in bonding_atoms_w:
                    mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
                for atm in bonding_atoms_r:
                    mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
                atoms_to_remove.append(i)

        # Remove atom in the end, otherwise the id will change
        # Reverse the order and remove atoms with larger id first
        atoms_to_remove.sort(reverse=True)
        for i in atoms_to_remove:
            mol_w.RemoveAtom(i)
        smiles = Chem.MolToSmiles(mol_w)
        mol = mol_w.GetMol()
    else:
        smiles = Chem.MolToSmiles(mol)
    return smiles, mol
|
505 |
+
|
506 |
+
|
507 |
+
def _convert_graph_to_smiles(coords, symbols, edges, image=None, debug=False):
    """
    Build an RDKit molecule from a predicted graph and convert it to SMILES.

    `coords`: per-atom (x, y) coordinates
    `symbols`: per-atom symbols (atom SMILES, R-group, abbreviation or condensed formula)
    `edges`: n x n matrix of edge codes; 1/2/3 = single/double/triple, 4 = aromatic,
             5 = single with wedge direction, 6 = single with dash direction
    `image`: optional image array; if given, its shape is used to rescale `coords` by the aspect ratio
    `debug`: if True, print tracebacks on failure and also return the molecule

    Returns `(pred_smiles, pred_molblock, success)`, or
    `(pred_smiles, pred_molblock, mol, success)` when `debug` is True.
    On failure `pred_smiles` is '<invalid>' and `pred_molblock` is ''.
    """
    mol = Chem.RWMol()
    n = len(symbols)
    ids = []
    for i in range(n):
        symbol = symbols[i]
        if symbol[0] == '[':
            symbol = symbol[1:-1]
        if symbol in RGROUP_SYMBOLS:
            atom = Chem.Atom("*")
            if symbol[0] == 'R' and symbol[1:].isdigit():
                atom.SetIsotope(int(symbol[1:]))
            Chem.SetAtomAlias(atom, symbol)
        elif symbol in ABBREVIATIONS:
            atom = Chem.Atom("*")
            Chem.SetAtomAlias(atom, symbol)
        else:
            try:  # try to get SMILES of atom
                atom = Chem.AtomFromSmiles(symbols[i])
                atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
            except Exception:  # otherwise, abbreviation or condensed formula
                # Exception rather than a bare except, so Ctrl-C / SystemExit still propagate.
                atom = Chem.Atom("*")
                Chem.SetAtomAlias(atom, symbol)

        if atom.GetSymbol() == '*':
            atom.SetProp('molFileAlias', symbol)

        idx = mol.AddAtom(atom)
        assert idx == i
        ids.append(idx)

    # Only the upper triangle (i < j) of `edges` is read when adding bonds.
    for i in range(n):
        for j in range(i + 1, n):
            if edges[i][j] == 1:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
            elif edges[i][j] == 2:
                mol.AddBond(ids[i], ids[j], Chem.BondType.DOUBLE)
            elif edges[i][j] == 3:
                mol.AddBond(ids[i], ids[j], Chem.BondType.TRIPLE)
            elif edges[i][j] == 4:
                mol.AddBond(ids[i], ids[j], Chem.BondType.AROMATIC)
            elif edges[i][j] == 5:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
                mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINWEDGE)
            elif edges[i][j] == 6:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
                mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINDASH)

    pred_smiles = '<invalid>'

    try:
        # TODO: move to an util function
        if image is not None:
            height, width, _ = image.shape
            ratio = width / height
            coords = [[x * ratio * 10, y * 10] for x, y in coords]
        mol = _verify_chirality(mol, coords, symbols, edges, debug)
        # molblock is obtained before expanding func groups, otherwise the expanded group won't have coordinates.
        # TODO: make sure molblock has the abbreviation information
        pred_molblock = Chem.MolToMolBlock(mol)
        pred_smiles, mol = _expand_functional_group(mol, {}, debug)
        success = True
    except Exception:
        if debug:
            print(traceback.format_exc())
        pred_molblock = ''
        success = False

    if debug:
        return pred_smiles, pred_molblock, mol, success
    return pred_smiles, pred_molblock, success
|
578 |
+
|
579 |
+
|
580 |
+
def convert_graph_to_smiles(coords, symbols, edges, images=None, num_workers=16):
    """
    Run `_convert_graph_to_smiles` over a batch of graphs in a process pool.

    Returns `(smiles_list, molblock_list, r_success)`, where `r_success` is the
    mean of the per-item success flags.
    """
    if images is None:
        columns = (coords, symbols, edges)
    else:
        columns = (coords, symbols, edges, images)
    with multiprocessing.Pool(num_workers) as pool:
        results = pool.starmap(_convert_graph_to_smiles, zip(*columns), chunksize=128)
    smiles_list, molblock_list, success = zip(*results)
    return smiles_list, molblock_list, np.mean(success)
|
589 |
+
|
590 |
+
|
591 |
+
def _postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, debug=False):
|
592 |
+
if type(smiles) is not str or smiles == '':
|
593 |
+
return '', False
|
594 |
+
mol = None
|
595 |
+
pred_molblock = ''
|
596 |
+
try:
|
597 |
+
pred_smiles = smiles
|
598 |
+
pred_smiles, mappings = _replace_functional_group(pred_smiles)
|
599 |
+
if coords is not None and symbols is not None and edges is not None:
|
600 |
+
pred_smiles = pred_smiles.replace('@', '').replace('/', '').replace('\\', '')
|
601 |
+
mol = Chem.RWMol(Chem.MolFromSmiles(pred_smiles, sanitize=False))
|
602 |
+
mol = _verify_chirality(mol, coords, symbols, edges, debug)
|
603 |
+
else:
|
604 |
+
mol = Chem.MolFromSmiles(pred_smiles, sanitize=False)
|
605 |
+
# pred_smiles = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
|
606 |
+
if molblock:
|
607 |
+
pred_molblock = Chem.MolToMolBlock(mol)
|
608 |
+
pred_smiles, mol = _expand_functional_group(mol, mappings)
|
609 |
+
success = True
|
610 |
+
except Exception as e:
|
611 |
+
if debug:
|
612 |
+
print(traceback.format_exc())
|
613 |
+
pred_smiles = smiles
|
614 |
+
pred_molblock = ''
|
615 |
+
success = False
|
616 |
+
if debug:
|
617 |
+
return pred_smiles, pred_molblock, mol, success
|
618 |
+
return pred_smiles, pred_molblock, success
|
619 |
+
|
620 |
+
|
621 |
+
def postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, num_workers=16):
    """
    Run `_postprocess_smiles` over a batch of SMILES strings in a process pool.

    `coords`, `symbols`, `edges`: optional per-item graph predictions; they are used
        only when all three are given
    `molblock`: forwarded to every worker call, so molblocks are produced when True
    `num_workers`: size of the process pool

    Returns `(smiles_list, molblock_list, r_success)`, where `r_success` is the
    mean of the per-item success flags.
    """
    import functools  # local import: only needed to bind `molblock` for the workers

    worker = functools.partial(_postprocess_smiles, molblock=molblock)
    with multiprocessing.Pool(num_workers) as p:
        if coords is not None and symbols is not None and edges is not None:
            results = p.starmap(worker, zip(smiles, coords, symbols, edges), chunksize=128)
        else:
            results = p.map(worker, smiles, chunksize=128)
    smiles_list, molblock_list, success = zip(*results)
    r_success = np.mean(success)
    return smiles_list, molblock_list, r_success
|
630 |
+
|
631 |
+
|
632 |
+
def _keep_main_molecule(smiles, debug=False):
    """
    If `smiles` describes several disconnected fragments, return the SMILES of the
    fragment with the most atoms; otherwise return `smiles` unchanged.
    Any exception leaves `smiles` unchanged (the traceback is printed when `debug` is True).
    """
    try:
        fragments = Chem.GetMolFrags(Chem.MolFromSmiles(smiles), asMols=True)
        if len(fragments) > 1:
            sizes = [fragment.GetNumAtoms() for fragment in fragments]
            largest = fragments[np.argmax(sizes)]
            smiles = Chem.MolToSmiles(largest)
    except Exception:
        if debug:
            print(traceback.format_exc())
    return smiles
|
644 |
+
|
645 |
+
|
646 |
+
def keep_main_molecule(smiles, num_workers=16):
    """Apply `_keep_main_molecule` to every SMILES in `smiles` using a process pool; returns the list of results."""
    with multiprocessing.Pool(num_workers) as pool:
        main_smiles = pool.map(_keep_main_molecule, smiles, chunksize=128)
    return main_smiles
|
molscribe/constants.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import re
|
3 |
+
|
4 |
+
# Element symbols; presumably the SMILES "organic subset" (atoms writable without brackets) - TODO confirm usage.
ORGANIC_SET = {'B', 'C', 'N', 'O', 'P', 'S', 'F', 'Cl', 'Br', 'I'}

# Symbols treated as R-groups: chemistry.py turns them into wildcard ('*') atoms and never expands them.
RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12', "R'",
                  'Ra', 'Rb', 'Rc', 'Rd', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar']

# NOTE(review): looks like rarely-used element symbols reserved as stand-ins; usage is not visible here - confirm.
PLACEHOLDER_ATOMS = ["Lv", "Lu", "Nd", "Yb", "At", "Fm", "Er"]
|
10 |
+
|
11 |
+
|
12 |
+
class Substitution(object):
    """One chemical shorthand: its abbreviations plus the SMARTS, SMILES and probability stored for it."""

    def __init__(self, abbrvs, smarts, smiles, probability):
        # `abbrvs` must be exactly a list (all spellings of the shorthand)
        assert type(abbrvs) is list
        self.abbrvs, self.smarts = abbrvs, smarts
        self.smiles, self.probability = smiles, probability
|
20 |
+
|
21 |
+
|
22 |
+
# Table of known shorthands. Each entry: (abbreviations, SMARTS, SMILES, probability).
# The SMILES is what chemistry.py parses when expanding an abbreviation, so it must be valid SMILES.
SUBSTITUTIONS: List[Substitution] = [
    Substitution(['NO2', 'O2N'], '[N+](=O)[O-]', "[N+](=O)[O-]", 0.5),
    # fixed: the SMILES previously ended with a stray ']' and could not be parsed
    Substitution(['OCOCH3'], '[#8]-[#6](=[#8])-[#6]', "[O]C(=O)C", 0.5),
    Substitution(['CHO', 'OHC'], '[CH1](=O)', "[CH1](=O)", 0.5),
    Substitution(['CO2Et', 'COOEt', 'EtO2C'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5),

    Substitution(['OAc'], '[OH0;X2]C(=O)[CH3]', "[O]C(=O)C", 0.7),
    Substitution(['NHAc'], '[NH1;D2]C(=O)[CH3]', "[NH]C(=O)C", 0.7),
    Substitution(['Ac'], 'C(=O)[CH3]', "[C](=O)C", 0.1),

    Substitution(['OBz'], '[OH0;D2]C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[O]C(=O)c1ccccc1", 0.7),  # Benzoyl
    Substitution(['Bz'], 'C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)c1ccccc1", 0.2),  # Benzoyl

    Substitution(['OBn'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[O]Cc1ccccc1", 0.7),  # Benzyl
    Substitution(['Bn'], '[CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[CH2]c1ccccc1", 0.2),  # Benzyl

    Substitution(['NHBoc'], '[NH1;D2]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6),
    Substitution(['NBoc'], '[NH0;D3]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6),
    Substitution(['Boc'], 'C(=O)OC([CH3])([CH3])[CH3]', "[C](=O)OC(C)(C)C", 0.2),

    Substitution(['Cbm'], 'C(=O)[NH2;D1]', "[C](=O)N", 0.2),
    Substitution(['Cbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[C](=O)OCc1ccccc1", 0.4),
    Substitution(['Cy'], '[CH1;X3]1[CH2][CH2][CH2][CH2][CH2]1', "[CH1]1CCCCC1", 0.3),
    Substitution(['Fmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3',
                 "[C](=O)OCC1c(cccc2)c2c3c1cccc3", 0.6),
    Substitution(['Mes'], '[cH0]1c([CH3])cc([CH3])cc([CH3])1', "[c]1c(C)cc(C)cc(C)1", 0.5),
    Substitution(['OMs'], '[OH0;D2]S(=O)(=O)[CH3]', "[O]S(=O)(=O)C", 0.7),
    Substitution(['Ms'], 'S(=O)(=O)[CH3]', "[S](=O)(=O)C", 0.2),
    Substitution(['Ph'], '[cH0]1[cH][cH][cH1][cH][cH]1', "[c]1ccccc1", 0.5),
    Substitution(['PMB'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[CH2]c1ccc(OC)cc1", 0.2),
    Substitution(['Py'], '[cH0]1[n;+0][cH1][cH1][cH1][cH1]1', "[c]1ncccc1", 0.1),
    # fixed: silicon is not in the SMILES organic subset and must be written as a bracket atom
    Substitution(['SEM'], '[CH2;D2][CH2][Si]([CH3])([CH3])[CH3]', "[CH2]C[Si](C)(C)C", 0.2),
    Substitution(['Suc'], 'C(=O)[CH2][CH2]C(=O)[OH]', "[C](=O)CCC(=O)O", 0.2),
    Substitution(['TBS'], '[Si]([CH3])([CH3])C([CH3])([CH3])[CH3]', "[Si](C)(C)C(C)(C)C", 0.5),
    Substitution(['TBZ'], 'C(=S)[cH]1[cH][cH][cH1][cH][cH]1', "[C](=S)c1ccccc1", 0.2),
    Substitution(['OTf'], '[OH0;D2]S(=O)(=O)C(F)(F)F', "[O]S(=O)(=O)C(F)(F)F", 0.7),
    Substitution(['Tf'], 'S(=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.2),
    Substitution(['TFA'], 'C(=O)C(F)(F)F', "[C](=O)C(F)(F)F", 0.3),
    Substitution(['TMS'], '[Si]([CH3])([CH3])[CH3]', "[Si](C)(C)C", 0.5),
    Substitution(['Ts'], 'S(=O)(=O)c1[cH1][cH1][cH0]([CH3])[cH1][cH1]1', "[S](=O)(=O)c1ccc(C)cc1", 0.6),  # Tos

    # Alkyl chains
    Substitution(['OMe', 'MeO'], '[OH0;D2][CH3;D1]', "[O]C", 0.3),
    Substitution(['SMe', 'MeS'], '[SH0;D2][CH3;D1]', "[S]C", 0.3),
    Substitution(['NMe', 'MeN'], '[N;X3][CH3;D1]', "[NH]C", 0.3),
    Substitution(['Me'], '[CH3;D1]', "[CH3]", 0.1),
    Substitution(['OEt', 'EtO'], '[OH0;D2][CH2;D2][CH3]', "[O]CC", 0.5),
    Substitution(['Et', 'C2H5'], '[CH2;D2][CH3]', "[CH2]C", 0.3),
    Substitution(['Pr', 'nPr', 'n-Pr'], '[CH2;D2][CH2;D2][CH3]', "[CH2]CC", 0.3),
    Substitution(['Bu', 'nBu', 'n-Bu'], '[CH2;D2][CH2;D2][CH2;D2][CH3]', "[CH2]CCC", 0.3),

    # Branched
    Substitution(['iPr', 'i-Pr'], '[CH1;D3]([CH3])[CH3]', "[CH1](C)C", 0.2),
    Substitution(['iBu', 'i-Bu'], '[CH2;D2][CH1;D3]([CH3])[CH3]', "[CH2]C(C)C", 0.2),
    Substitution(['OiBu'], '[OH0;D2][CH2;D2][CH1;D3]([CH3])[CH3]', "[O]CC(C)C", 0.2),
    Substitution(['OtBu'], '[OH0;D2][CH0]([CH3])([CH3])[CH3]', "[O]C(C)(C)C", 0.6),
    Substitution(['tBu', 't-Bu'], '[CH0]([CH3])([CH3])[CH3]', "[C](C)(C)C", 0.3),

    # Other shorthands (MIGHT NOT WANT ALL OF THESE)
    Substitution(['CF3', 'F3C'], '[CH0;D4](F)(F)F', "[C](F)(F)F", 0.5),
    Substitution(['NCF3', 'F3CN'], '[N;X3][CH0;D4](F)(F)F', "[NH]C(F)(F)F", 0.5),
    Substitution(['OCF3', 'F3CO'], '[OH0;X2][CH0;D4](F)(F)F', "[O]C(F)(F)F", 0.5),
    Substitution(['CCl3'], '[CH0;D4](Cl)(Cl)Cl', "[C](Cl)(Cl)Cl", 0.5),
    Substitution(['CO2H', 'HO2C', 'COOH'], 'C(=O)[OH]', "[C](=O)O", 0.5),  # COOH
    Substitution(['CN', 'NC'], 'C#[ND1]', "[C]#N", 0.5),
    Substitution(['OCH3', 'H3CO'], '[OH0;D2][CH3]', "[O]C", 0.4),
    Substitution(['SO3H'], 'S(=O)(=O)[OH]', "[S](=O)(=O)O", 0.4),
    Substitution(['CH3O'], '[OH0;D2][CH3]', "[O]C", 0),
    # NOTE(review): the two entries below reuse the methoxy SMARTS of the OCH3 entry, and their SMILES
    # start with a plain 'C' rather than a bracketed attachment atom like the entries above - confirm
    # the intended attachment point and SMARTS.
    Substitution(['PhCH2CH2'], '[OH0;D2][CH3]', "C1=CC=CC=C1CC", 0),
    Substitution(['SO2ToI', 'SO2Tol'], '[OH0;D2][CH3]', "CS(=O)(=O)C1=CC=CC=C1", 0),
]
|
99 |
+
|
100 |
+
# Lookup from every abbreviation spelling to its Substitution entry.
ABBREVIATIONS = {abbrv: sub for sub in SUBSTITUTIONS for abbrv in sub.abbrvs}

# Candidate valences per element symbol; chemistry.py tries them in the listed order
# when converting a condensed formula to SMILES.
VALENCES = {
    "H": [1], "Li": [1], "Be": [2], "B": [3], "C": [4], "N": [3, 5], "O": [2], "F": [1],
    "Na": [1], "Mg": [2], "Al": [3], "Si": [4], "P": [5, 3], "S": [6, 2, 4], "Cl": [1], "K": [1], "Ca": [2],
    "Br": [1], "I": [1]
}

# Element symbols of the periodic table, H through Og.
ELEMENTS = [
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
    "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
    "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
    "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
    "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
    "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
    "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
    "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
    "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
    "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
    "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
    "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
]

# Single-letter color codes mapped to comma-separated 'r,g,b' component strings.
COLORS = {
    u'c': '0.0,0.75,0.75', u'b': '0.0,0.0,1.0', u'g': '0.0,0.5,0.0', u'y': '0.75,0.75,0',
    u'k': '0.0,0.0,0.0', u'r': '1.0,0.0,0.0', u'm': '0.75,0,0.75'
}

# tokens of condensed formula: a known abbreviation (tried first, in table order), an R-group with
# optional digits, an element-like symbol, a number, or a parenthesis.
# Abbreviations are escaped so they always match literally, and the fixed part is a raw string so
# that `\(` and `\)` are not invalid string escapes.
FORMULA_REGEX = re.compile(
    '(' + '|'.join(re.escape(abbrv) for abbrv in ABBREVIATIONS) + r'|R[0-9]*|[A-Z][a-z]+|[A-Z]|[0-9]+|\(|\))')
|
molscribe/dataset.py
ADDED
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import time
|
4 |
+
import random
|
5 |
+
import re
|
6 |
+
import string
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.utils.data import DataLoader, Dataset
|
12 |
+
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
|
13 |
+
import albumentations as A
|
14 |
+
from albumentations.pytorch import ToTensorV2
|
15 |
+
|
16 |
+
from .indigo import Indigo
|
17 |
+
from .indigo.renderer import IndigoRenderer
|
18 |
+
|
19 |
+
from .augment import SafeRotate, CropWhite, PadWhite, SaltAndPepperNoise
|
20 |
+
from .utils import FORMAT_INFO
|
21 |
+
from .tokenizer import PAD_ID
|
22 |
+
from .chemistry import get_num_atoms, normalize_nodes
|
23 |
+
from .constants import RGROUP_SYMBOLS, SUBSTITUTIONS, ELEMENTS, COLORS
|
24 |
+
|
25 |
+
cv2.setNumThreads(1)
|
26 |
+
|
27 |
+
# Probabilities that gate the random molecule / rendering augmentations below.
# NOTE(review): "HYGROGEN" and "DEARMOTIZE" look like misspellings of
# "HYDROGEN" and "DEAROMATIZE"; the names are kept because other code uses them.
INDIGO_HYGROGEN_PROB = 0.2  # chance that add_explicit_hydrogen draws explicit H atoms
INDIGO_FUNCTIONAL_GROUP_PROB = 0.8  # chance that add_functional_group attempts abbreviations
INDIGO_CONDENSED_PROB = 0.5  # chance that add_rand_condensed attaches a random label
INDIGO_RGROUP_PROB = 0.5  # chance that add_rgroup attaches an R-group symbol
INDIGO_COMMENT_PROB = 0.3  # chance that add_comment enables a rendered comment
INDIGO_DEARMOTIZE_PROB = 0.8  # chance of dearomatize(); otherwise aromatize() is called
INDIGO_COLOR_PROB = 0.2  # used independently for each colouring decision in add_color
|
34 |
+
|
35 |
+
|
36 |
+
def get_transforms(input_size, augment=True, rotate=True, debug=False):
    """Assemble the albumentations pipeline applied to a molecule image.

    Args:
        input_size: side length passed to ``A.Resize`` (output is square).
        augment: when true, add the random degradation transforms.
        rotate: when true (and ``augment`` is true), prepend ``SafeRotate``.
        debug: when true, skip grayscale conversion, normalisation and
            tensor conversion.

    Returns:
        An ``A.Compose`` that also transforms ``xy`` keypoints and keeps
        keypoints that fall outside the image.
    """
    steps = []
    if augment and rotate:
        steps.append(SafeRotate(limit=90, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255)))
    steps.append(CropWhite(pad=5))
    if augment:
        steps.extend([
            A.CropAndPad(percent=[-0.01, 0.00], keep_size=False, p=0.5),
            PadWhite(pad_ratio=0.4, p=0.2),
            A.Downscale(scale_min=0.2, scale_max=0.5, interpolation=3),
            A.Blur(),
            A.GaussNoise(),
            SaltAndPepperNoise(num_dots=20, p=0.5),
        ])
    steps.append(A.Resize(input_size, input_size))
    if not debug:
        steps.extend([
            A.ToGray(p=1),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])
    keypoint_params = A.KeypointParams(format='xy', remove_invisible=False)
    return A.Compose(steps, keypoint_params=keypoint_params)
|
61 |
+
|
62 |
+
|
63 |
+
def add_functional_group(indigo, mol, debug=False):
    """Randomly replace matched functional groups with abbreviated pseudo atoms.

    With probability ``1 - INDIGO_FUNCTIONAL_GROUP_PROB`` the molecule is
    returned untouched. Otherwise every entry of ``SUBSTITUTIONS`` is tried in
    random order; each SMARTS match is collapsed (with probability
    ``sub.probability``, or always when ``debug`` is true) into a single new
    atom labelled with one of ``sub.abbrvs``.

    Args:
        indigo: Indigo session used to load SMARTS queries and build matchers.
        mol: molecule to modify in place.
        debug: force every match to be substituted.

    Returns:
        The same ``mol`` object.
    """
    if random.random() > INDIGO_FUNCTIONAL_GROUP_PROB:
        return mol
    # Delete functional group and add a pseudo atom with its abbrv
    substitutions = list(SUBSTITUTIONS)
    random.shuffle(substitutions)
    for sub in substitutions:
        query = indigo.loadSmarts(sub.smarts)
        # A fresh matcher per substitution, built from the molecule as modified so far.
        matcher = indigo.substructureMatcher(mol)
        matched_atoms_ids = set()
        for match in matcher.iterateMatches(query):
            if random.random() < sub.probability or debug:
                atoms = []
                atoms_ids = set()
                for item in query.iterateAtoms():
                    atom = match.mapAtom(item)
                    atoms.append(atom)
                    atoms_ids.add(atom.index())
                # Skip matches that overlap atoms already consumed by an earlier match.
                if not matched_atoms_ids.isdisjoint(atoms_ids):
                    continue
                abbrv = random.choice(sub.abbrvs)
                superatom = mol.addAtom(abbrv)
                for atom in atoms:
                    for nei in atom.iterateNeighbors():
                        if nei.index() not in atoms_ids:
                            if nei.symbol() == 'H':
                                # indigo won't match explicit hydrogen, so remove them explicitly
                                atoms_ids.add(nei.index())
                            else:
                                # Re-attach the outside neighbour to the pseudo atom.
                                superatom.addBond(nei, nei.bond().bondOrder())
                for atom_id in atoms_ids:
                    mol.getAtom(atom_id).remove()
                matched_atoms_ids |= atoms_ids
    return mol
|
97 |
+
|
98 |
+
|
99 |
+
def add_explicit_hydrogen(indigo, mol):
    """Occasionally draw the implicit hydrogens of one random atom explicitly.

    Atoms whose implicit-hydrogen count cannot be computed are skipped. With
    probability ``INDIGO_HYGROGEN_PROB`` one atom that has implicit hydrogens
    is picked and that many ``H`` atoms are bonded to it.

    Args:
        indigo: unused; kept for a uniform augmentation signature.
        mol: molecule to modify in place.

    Returns:
        The same ``mol`` object.
    """
    candidates = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
        except Exception:
            # Best effort: this atom cannot report a hydrogen count, so ignore it.
            continue
        if hs > 0:
            candidates.append((atom, hs))
    if candidates and random.random() < INDIGO_HYGROGEN_PROB:
        atom, hs = random.choice(candidates)
        for _ in range(hs):
            h = mol.addAtom('H')
            h.addBond(atom, 1)
    return mol
|
114 |
+
|
115 |
+
|
116 |
+
def add_rgroup(indigo, mol, smiles):
    """Occasionally attach a random R-group symbol to an atom with free valence.

    Nothing is added when ``smiles`` already contains ``*`` or when no atom
    has implicit hydrogens. Otherwise, with probability ``INDIGO_RGROUP_PROB``,
    a pseudo atom labelled with a random entry of ``RGROUP_SYMBOLS`` is
    single-bonded to a randomly chosen candidate atom.

    Args:
        indigo: unused; kept for a uniform augmentation signature.
        mol: molecule to modify in place.
        smiles: SMILES of ``mol`` before augmentation.

    Returns:
        The same ``mol`` object.
    """
    candidates = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
        except Exception:
            # Best effort: this atom cannot report a hydrogen count, so ignore it.
            continue
        if hs > 0:
            candidates.append(atom)
    if candidates and '*' not in smiles and random.random() < INDIGO_RGROUP_PROB:
        atom = random.choice(candidates)
        symbol = random.choice(RGROUP_SYMBOLS)
        r = mol.addAtom(symbol)
        r.addBond(atom, 1)
    return mol
|
134 |
+
|
135 |
+
|
136 |
+
def get_rand_symb():
    """Return a random atom-like symbol for a synthetic condensed formula.

    Starts from a random entry of ``ELEMENTS``; three independent 10% draws
    may then append a lowercase letter, append an uppercase letter, and
    replace the whole symbol with a parenthesised random condensed formula.
    """
    symbol = random.choice(ELEMENTS)
    if random.random() < 0.1:
        symbol = symbol + random.choice(string.ascii_lowercase)
    if random.random() < 0.1:
        symbol = symbol + random.choice(string.ascii_uppercase)
    if random.random() < 0.1:
        symbol = '(' + gen_rand_condensed() + ')'
    return symbol
|
145 |
+
|
146 |
+
|
147 |
+
def get_rand_num():
    """Return a random count suffix for a condensed-formula token.

    The result is the empty string (most likely), a single digit ``2``-``9``,
    or a two-digit string ``12``-``19``.
    """
    if random.random() >= 0.9:
        return '1' + str(random.randint(2, 9))
    if random.random() < 0.8:
        return ''
    return str(random.randint(2, 9))
|
155 |
+
|
156 |
+
|
157 |
+
def gen_rand_condensed():
    """Generate a random condensed-formula string of one to five symbol/count pairs.

    The first pair is always emitted; before each further pair there is an
    80% chance of stopping.
    """
    pieces = [get_rand_symb(), get_rand_num()]
    for _ in range(4):
        if random.random() < 0.8:
            break
        pieces.append(get_rand_symb())
        pieces.append(get_rand_num())
    return ''.join(pieces)
|
165 |
+
|
166 |
+
|
167 |
+
def add_rand_condensed(indigo, mol):
    """Occasionally attach a randomly generated condensed-formula pseudo atom.

    With probability ``INDIGO_CONDENSED_PROB`` a pseudo atom labelled by
    ``gen_rand_condensed()`` is single-bonded to a random atom that has
    implicit hydrogens. Atoms whose hydrogen count cannot be computed are
    skipped.

    Args:
        indigo: unused; kept for a uniform augmentation signature.
        mol: molecule to modify in place.

    Returns:
        The same ``mol`` object.
    """
    candidates = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
        except Exception:
            # Best effort: this atom cannot report a hydrogen count, so ignore it.
            continue
        if hs > 0:
            candidates.append(atom)
    if candidates and random.random() < INDIGO_CONDENSED_PROB:
        atom = random.choice(candidates)
        symbol = gen_rand_condensed()
        r = mol.addAtom(symbol)
        r.addBond(atom, 1)
    return mol
|
182 |
+
|
183 |
+
|
184 |
+
def generate_output_smiles(indigo, mol):
    """Serialise ``mol`` to SMILES with pseudo-atom labels written inline.

    ``mol.smiles()`` may return ``"<core> <extension>"``. When the SMILES
    contains ``*``, the non-empty ``;``-separated entries found between the
    ``$`` signs of the extension are substituted, in order, for each ``*`` of
    the core as ``[label]``. Otherwise only the extension is stripped.

    Args:
        indigo: Indigo session used to reload the molecule from its SMILES.
        mol: molecule to serialise.

    Returns:
        ``(reloaded_mol, output_smiles)``.

    Raises:
        ValueError: if the SMILES contains ``*`` but the extension carries no
            ``$...$`` label block, or fewer labels than wildcards.
    """
    # TODO: if using mol.canonicalSmiles(), explicit H will be removed
    smiles = mol.smiles()
    mol = indigo.loadMolecule(smiles)
    core, _, extension = smiles.partition(' ')
    if '*' not in smiles:
        # special cases with extension: keep only the core SMILES
        return mol, core
    label_block = re.search(r'\$.*\$', extension)
    if label_block is None:
        raise ValueError(f'no pseudo-atom labels found in SMILES: {smiles!r}')
    labels = [t for t in label_block.group(0)[1:-1].split(';') if len(t) > 0]
    remaining = iter(labels)
    pieces = []
    for ch in core:
        if ch != '*':
            pieces.append(ch)
            continue
        label = next(remaining, None)
        if label is None:
            raise ValueError(f'fewer pseudo-atom labels than wildcards in SMILES: {smiles!r}')
        pieces.append(f'[{label}]')
    return mol, ''.join(pieces)
|
206 |
+
|
207 |
+
|
208 |
+
def add_comment(indigo):
    """With probability ``INDIGO_COMMENT_PROB``, configure a random rendered comment.

    The comment text is a number from 1 to 20 followed by one ASCII letter;
    font size, alignment, position and offset are randomised as well.
    """
    if random.random() >= INDIGO_COMMENT_PROB:
        return
    text = str(random.randint(1, 20)) + random.choice(string.ascii_letters)
    indigo.setOption('render-comment', text)
    font_size = random.randint(40, 60)
    indigo.setOption('render-comment-font-size', font_size)
    alignment = random.choice([0, 0.5, 1])
    indigo.setOption('render-comment-alignment', alignment)
    position = random.choice(['top', 'bottom'])
    indigo.setOption('render-comment-position', position)
    offset = random.randint(2, 30)
    indigo.setOption('render-comment-offset', offset)
|
215 |
+
|
216 |
+
|
217 |
+
def add_color(indigo, mol):
    """Randomly enable colouring and atom highlighting for the next render.

    Three independent draws against ``INDIGO_COLOR_PROB`` decide whether to
    enable element colouring, pick a random base colour from ``COLORS``, and
    enter highlight mode. In highlight mode the highlight colour and
    thickness are each enabled with probability 0.5 and every atom is
    highlighted with probability 0.1.

    Returns:
        The same ``mol`` object.
    """
    if random.random() < INDIGO_COLOR_PROB:
        indigo.setOption('render-coloring', True)
    if random.random() < INDIGO_COLOR_PROB:
        base_color = random.choice(list(COLORS.values()))
        indigo.setOption('render-base-color', base_color)
    if random.random() >= INDIGO_COLOR_PROB:
        return mol
    if random.random() < 0.5:
        indigo.setOption('render-highlight-color-enabled', True)
        highlight_color = random.choice(list(COLORS.values()))
        indigo.setOption('render-highlight-color', highlight_color)
    if random.random() < 0.5:
        indigo.setOption('render-highlight-thickness-enabled', True)
    for atom in mol.iterateAtoms():
        if random.random() < 0.1:
            atom.highlight()
    return mol
|
232 |
+
|
233 |
+
|
234 |
+
def get_graph(mol, image, shuffle_nodes=False, pseudo_coords=False):
    """Extract atom coordinates, symbols and a bond matrix from a molecule.

    Args:
        mol: molecule; ``layout()`` is called on it first.
        image: rendered image, only read (for its height/width) when
            ``pseudo_coords`` is true.
        shuffle_nodes: randomise the atom order of the output.
        pseudo_coords: use ``atom.xyz()`` passed through ``normalize_nodes``
            and scaled to the image size, instead of ``atom.coords()``.

    Returns:
        Dict with ``coords``, ``symbols``, ``edges`` (``n x n`` int array)
        and ``num_atoms``.
    """
    mol.layout()
    atoms = list(mol.iterateAtoms())
    if shuffle_nodes:
        random.shuffle(atoms)
    coords = []
    symbols = []
    index_map = {}  # molecule atom index -> position in the output lists
    for position, atom in enumerate(atoms):
        if pseudo_coords:
            x, y, _z = atom.xyz()
        else:
            x, y = atom.coords()
        coords.append([x, y])
        symbols.append(atom.symbol())
        index_map[atom.index()] = position
    if pseudo_coords:
        coords = normalize_nodes(np.array(coords))
        height, width, _ = image.shape
        coords[:, 0] = coords[:, 0] * width
        coords[:, 1] = coords[:, 1] * height
    num_atoms = len(symbols)
    edges = np.zeros((num_atoms, num_atoms), dtype=int)
    for bond in mol.iterateBonds():
        src = index_map[bond.source().index()]
        dst = index_map[bond.destination().index()]
        # 1/2/3/4 : single/double/triple/aromatic
        order = bond.bondOrder()
        edges[src, dst] = order
        edges[dst, src] = order
        stereo = bond.bondStereo()
        if stereo in (5, 6):
            # Stereo bonds are directional: the reverse entry stores 11 - stereo.
            edges[src, dst] = stereo
            edges[dst, src] = 11 - stereo
    return {
        'coords': coords,
        'symbols': symbols,
        'edges': edges,
        'num_atoms': len(symbols),
    }
|
272 |
+
|
273 |
+
|
274 |
+
def generate_indigo_image(smiles, mol_augment=True, default_option=False, shuffle_nodes=False, pseudo_coords=False,
                          include_condensed=True, debug=False):
    """Render ``smiles`` to an image with Indigo, optionally with random augmentation.

    Args:
        smiles: input SMILES string.
        mol_augment: apply the molecule-level augmentations (aromaticity
            toggle, comment, explicit H, R-group, condensed label, functional
            group abbreviation, colour) and regenerate the output SMILES.
        default_option: keep the fixed rendering options instead of
            randomising line width, font, label mode, etc.
        shuffle_nodes: forwarded to ``get_graph``.
        pseudo_coords: forwarded to ``get_graph``.
        include_condensed: allow ``add_rand_condensed`` during augmentation.
        debug: forwarded to ``add_functional_group``; also re-raises failures
            instead of returning a placeholder.

    Returns:
        ``(img, smiles, graph, success)``. On failure ``img`` is a 10x10 white
        float32 array, ``graph`` is ``{}`` and ``success`` is ``False``.
    """
    indigo = Indigo()
    renderer = IndigoRenderer(indigo)
    indigo.setOption('render-output-format', 'png')
    indigo.setOption('render-background-color', '1,1,1')
    indigo.setOption('render-stereo-style', 'none')
    indigo.setOption('render-label-mode', 'hetero')
    indigo.setOption('render-font-family', 'Arial')
    if not default_option:
        thickness = random.uniform(0.5, 2)  # limit the sum of the following two parameters to be smaller than 4
        indigo.setOption('render-relative-thickness', thickness)
        indigo.setOption('render-bond-line-width', random.uniform(1, 4 - thickness))
        if random.random() < 0.5:
            indigo.setOption('render-font-family', random.choice(['Arial', 'Times', 'Courier', 'Helvetica']))
        indigo.setOption('render-label-mode', random.choice(['hetero', 'terminal-hetero']))
        indigo.setOption('render-implicit-hydrogens-visible', random.choice([True, False]))
        if random.random() < 0.1:
            indigo.setOption('render-stereo-style', 'old')
        if random.random() < 0.2:
            indigo.setOption('render-atom-ids-visible', True)

    try:
        mol = indigo.loadMolecule(smiles)
        if mol_augment:
            if random.random() < INDIGO_DEARMOTIZE_PROB:
                mol.dearomatize()
            else:
                mol.aromatize()
            smiles = mol.canonicalSmiles()
            add_comment(indigo)
            mol = add_explicit_hydrogen(indigo, mol)
            mol = add_rgroup(indigo, mol, smiles)
            if include_condensed:
                mol = add_rand_condensed(indigo, mol)
            mol = add_functional_group(indigo, mol, debug)
            mol = add_color(indigo, mol)
            mol, smiles = generate_output_smiles(indigo, mol)

        buf = renderer.renderToBuffer(mol)
        img = cv2.imdecode(np.asarray(bytearray(buf), dtype=np.uint8), 1)  # decode buffer to image
        graph = get_graph(mol, img, shuffle_nodes, pseudo_coords)
        success = True
    except Exception:
        if debug:
            # Re-raise the original error so its type and traceback are kept.
            raise
        # Best effort: hand back a blank placeholder and let the caller skip it.
        img = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
        graph = {}
        success = False
    return img, smiles, graph, success
|
325 |
+
|
326 |
+
|
327 |
+
class TrainDataset(Dataset):
    """Dataset yielding ``(idx, image, ref)`` triples for molecule image recognition.

    Images are either rendered on the fly with Indigo (``dynamic_indigo``) or
    read from the paths in ``df['file_path']``. ``ref`` maps each requested
    format name (plus ``atom_indices``, ``coords``, ``edges`` where
    applicable) to its target tensor.
    """

    def __init__(self, args, df, tokenizer, split='train', dynamic_indigo=False):
        """Store the dataframe and configuration and build the image transform.

        Args:
            args: configuration namespace; the attributes read here are
                ``data_path``, ``formats``, ``input_size``, ``augment``,
                ``coords_file`` and ``pseudo_coords``.
            df: dataframe with optional ``file_path`` and ``SMILES`` columns.
            tokenizer: mapping from format name to tokenizer.
            split: only ``'train'`` is treated as labelled.
            dynamic_indigo: render images on the fly (training split only).
        """
        super().__init__()
        self.df = df
        self.args = args
        self.tokenizer = tokenizer
        if 'file_path' in df.columns:
            self.file_paths = df['file_path'].values
            # Paths are made relative to data_path unless the first one already is.
            if not self.file_paths[0].startswith(args.data_path):
                self.file_paths = [os.path.join(args.data_path, path) for path in df['file_path']]
        self.smiles = df['SMILES'].values if 'SMILES' in df.columns else None
        self.formats = args.formats
        self.labelled = (split == 'train')
        if self.labelled:
            self.labels = {}
            for format_ in self.formats:
                if format_ in ['atomtok', 'inchi']:
                    field = FORMAT_INFO[format_]['name']
                    if field in df.columns:
                        self.labels[format_] = df[field].values
        self.transform = get_transforms(args.input_size,
                                        augment=(self.labelled and args.augment))
        self.dynamic_indigo = (dynamic_indigo and split == 'train')
        if self.labelled and not dynamic_indigo and args.coords_file is not None:
            if args.coords_file == 'aux_file':
                self.coords_df = df
                self.pseudo_coords = True
            else:
                self.coords_df = pd.read_csv(args.coords_file)
                self.pseudo_coords = False
        else:
            self.coords_df = None
            self.pseudo_coords = args.pseudo_coords

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return len(self.df)

    def image_transform(self, image, coords=None, renormalize=False):
        """Apply ``self.transform`` to an image and, optionally, its keypoints.

        Args:
            image: BGR image array (converted to RGB before transforming).
            coords: keypoints in pixel units; ``None`` or empty means none.
            renormalize: rescale the transformed keypoints with
                ``normalize_nodes`` instead of dividing by the image size.

        Returns:
            ``image`` alone when no keypoints were given, otherwise
            ``(image, coords)`` with coords clipped to ``[0, 1]``.
        """
        if coords is None:
            coords = []
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        augmented = self.transform(image=image, keypoints=coords)
        image = augmented['image']
        if len(coords) == 0:
            return image
        coords = np.array(augmented['keypoints'])
        if renormalize:
            coords = normalize_nodes(coords, flip_y=False)
        else:
            # NOTE(review): assumes a channel-first image here — confirm for debug transforms.
            _, height, width = image.shape
            coords[:, 0] = coords[:, 0] / width
            coords[:, 1] = coords[:, 1] / height
        coords = np.array(coords).clip(0, 1)
        return image, coords

    def __getitem__(self, idx):
        """Return ``getitem(idx)``; on failure write the error to a log file and re-raise."""
        try:
            return self.getitem(idx)
        except Exception as e:
            with open(os.path.join(self.args.save_path, f'error_dataset_{int(time.time())}.log'), 'w') as f:
                f.write(str(e))
            raise

    def getitem(self, idx):
        """Build one example.

        Returns:
            ``(idx, image, ref)``; ``(idx, None, {})`` when dynamic rendering fails.
        """
        ref = {}
        if self.dynamic_indigo:
            begin = time.time()
            image, smiles, graph, success = generate_indigo_image(
                self.smiles[idx], mol_augment=self.args.mol_augment, default_option=self.args.default_option,
                shuffle_nodes=self.args.shuffle_nodes, pseudo_coords=self.pseudo_coords,
                include_condensed=self.args.include_condensed)
            end = time.time()
            if idx < 30 and self.args.save_image:
                path = os.path.join(self.args.save_path, 'images')
                os.makedirs(path, exist_ok=True)
                cv2.imwrite(os.path.join(path, f'{idx}.png'), image)
            if not success:
                return idx, None, {}
            image, coords = self.image_transform(image, graph['coords'], renormalize=self.pseudo_coords)
            graph['coords'] = coords
            ref['time'] = end - begin
            if 'atomtok' in self.formats:
                max_len = FORMAT_INFO['atomtok']['max_len']
                label = self.tokenizer['atomtok'].text_to_sequence(smiles, tokenized=False)
                ref['atomtok'] = torch.LongTensor(label[:max_len])
            if 'edges' in self.formats and 'atomtok_coords' not in self.formats and 'chartok_coords' not in self.formats:
                ref['edges'] = torch.tensor(graph['edges'])
            if 'atomtok_coords' in self.formats:
                self._process_atomtok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
                                             mask_ratio=self.args.mask_ratio)
            if 'chartok_coords' in self.formats:
                self._process_chartok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
                                             mask_ratio=self.args.mask_ratio)
            return idx, image, ref

        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        if image is None:
            # Fall back to a blank image so one missing file does not stop the run.
            image = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
            print(file_path, 'not found!')
        if self.coords_df is not None:
            h, w, _ = image.shape
            # SECURITY(review): eval() executes the cell as Python code; only
            # safe while the coords file is trusted.
            coords = np.array(eval(self.coords_df.loc[idx, 'node_coords']))
            if self.pseudo_coords:
                coords = normalize_nodes(coords)
            coords[:, 0] = coords[:, 0] * w
            coords[:, 1] = coords[:, 1] * h
            image, coords = self.image_transform(image, coords, renormalize=self.pseudo_coords)
        else:
            image = self.image_transform(image)
            coords = None
        if self.labelled:
            smiles = self.smiles[idx]
            if 'atomtok' in self.formats:
                max_len = FORMAT_INFO['atomtok']['max_len']
                label = self.tokenizer['atomtok'].text_to_sequence(smiles, False)
                ref['atomtok'] = torch.LongTensor(label[:max_len])
            if 'atomtok_coords' in self.formats:
                if coords is not None:
                    self._process_atomtok_coords(idx, ref, smiles, coords, mask_ratio=0)
                else:
                    self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1)
            if 'chartok_coords' in self.formats:
                if coords is not None:
                    self._process_chartok_coords(idx, ref, smiles, coords, mask_ratio=0)
                else:
                    self._process_chartok_coords(idx, ref, smiles, mask_ratio=1)
        if self.args.predict_coords and ('atomtok_coords' in self.formats or 'chartok_coords' in self.formats):
            smiles = self.smiles[idx]
            if 'atomtok_coords' in self.formats:
                self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1)
            if 'chartok_coords' in self.formats:
                self._process_chartok_coords(idx, ref, smiles, mask_ratio=1)
        return idx, image, ref

    def _process_atomtok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
        """Fill ``ref`` with the ``atomtok_coords`` targets for example ``idx``."""
        self._process_coords_format('atomtok_coords', idx, ref, smiles, coords, edges, mask_ratio)

    def _process_chartok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
        """Fill ``ref`` with the ``chartok_coords`` targets for example ``idx``."""
        self._process_coords_format('chartok_coords', idx, ref, smiles, coords, edges, mask_ratio)

    def _process_coords_format(self, format_, idx, ref, smiles, coords, edges, mask_ratio):
        """Shared implementation for the two ``*_coords`` formats.

        Writes ``ref[format_]``, ``ref['atom_indices']``, ``ref['edges']`` and,
        for tokenizers with ``continuous_coords``, ``ref['coords']``.
        """
        max_len = FORMAT_INFO[format_]['max_len']
        tokenizer = self.tokenizer[format_]
        if not isinstance(smiles, str):
            # Missing labels (None, NaN, ...) are tokenised as an empty SMILES.
            smiles = ""
        label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio)
        ref[format_] = torch.LongTensor(label[:max_len])
        # Drop atom positions that fall beyond the truncated sequence.
        indices = [i for i in indices if i < max_len]
        ref['atom_indices'] = torch.LongTensor(indices)
        n = len(indices)
        if tokenizer.continuous_coords:
            if coords is not None:
                ref['coords'] = torch.tensor(coords)
            else:
                ref['coords'] = torch.ones(n, 2) * -1.
        if edges is not None:
            ref['edges'] = torch.tensor(edges)[:n, :n]
        elif 'edges' in self.df.columns:
            # SECURITY(review): eval() executes the cell as Python code; only
            # safe while the dataframe comes from a trusted file.
            edge_list = eval(self.df.loc[idx, 'edges'])
            ref['edges'] = self._build_edge_matrix(edge_list, n)
        else:
            # No edge information: fill with -100.
            ref['edges'] = torch.ones(n, n, dtype=torch.long) * (-100)

    @staticmethod
    def _build_edge_matrix(edge_list, n):
        """Convert ``(u, v, type)`` triples into an ``n x n`` long tensor.

        Types up to 4 are stored symmetrically; larger types are stored as
        ``type`` at ``[u, v]`` and ``11 - type`` at ``[v, u]``. Triples that
        reference an index ``>= n`` are skipped.
        """
        edges = torch.zeros((n, n), dtype=torch.long)
        for u, v, t in edge_list:
            if u < n and v < n:
                edges[u, v] = t
                edges[v, u] = t if t <= 4 else 11 - t
        return edges
|
527 |
+
|
528 |
+
|
529 |
+
class AuxTrainDataset(Dataset):
    """Concatenation of a main training dataset and an auxiliary one.

    Indices below ``len(train_dataset)`` address the main dataset; the rest
    address the auxiliary dataset, which never uses dynamic rendering.
    """

    def __init__(self, args, train_df, aux_df, tokenizer):
        super().__init__()
        self.train_dataset = TrainDataset(args, train_df, tokenizer, dynamic_indigo=args.dynamic_indigo)
        self.aux_dataset = TrainDataset(args, aux_df, tokenizer, dynamic_indigo=False)

    def __len__(self):
        """Return the combined number of examples."""
        return len(self.train_dataset) + len(self.aux_dataset)

    def __getitem__(self, idx):
        """Return the example at ``idx`` from whichever dataset owns it."""
        num_train = len(self.train_dataset)
        if idx >= num_train:
            return self.aux_dataset[idx - num_train]
        return self.train_dataset[idx]
|
544 |
+
|
545 |
+
|
546 |
+
def pad_images(imgs):
    """Zero-pad images on the right/bottom to a common size and stack them.

    Args:
        imgs: sequence of tensors whose last two dimensions are height and width.

    Returns:
        A tensor of shape ``(B, C, H, W)`` with the largest height and width
        found in ``imgs``.
    """
    # B, C, H, W
    target_w = max([0] + [img.shape[-1] for img in imgs])
    target_h = max([0] + [img.shape[-2] for img in imgs])
    padded = []
    for img in imgs:
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        amounts = [0, target_w - img.shape[-1], 0, target_h - img.shape[-2]]
        padded.append(F.pad(img, amounts, value=0))
    return torch.stack(padded)
|
559 |
+
|
560 |
+
|
561 |
+
def bms_collate(batch):
    """Collate ``(idx, image, ref)`` examples into padded batch tensors.

    Examples whose image is ``None`` are dropped. The set of formats is taken
    from the first remaining example.

    Returns:
        ``(ids, images, refs)``. For sequence formats ``refs[key]`` is
        ``[padded_sequences, lengths]``; ``refs['coords']`` is padded with
        ``-1`` and ``refs['edges']`` with ``-100``.
    """
    # NOTE(review): if every example was dropped, batch[0] below raises IndexError.
    batch = [ex for ex in batch if ex[1] is not None]
    formats = list(batch[0][2].keys())
    sequence_keys = ['atomtok', 'inchi', 'nodes', 'atomtok_coords', 'chartok_coords', 'atom_indices']
    seq_formats = [k for k in formats if k in sequence_keys]
    ids = [ex[0] for ex in batch]
    imgs = [ex[1] for ex in batch]
    refs = {}
    # Sequence
    for key in seq_formats:
        sequences = [ex[2][key] for ex in batch]
        lengths = [torch.LongTensor([len(seq)]) for seq in sequences]
        # this padding should work for atomtok_with_coords too, each of which has shape (length, 4)
        refs[key] = [
            pad_sequence(sequences, batch_first=True, padding_value=PAD_ID),
            torch.stack(lengths).reshape(-1, 1),
        ]
    # Coords
    if 'coords' in formats:
        coords_list = [ex[2]['coords'] for ex in batch]
        refs['coords'] = pad_sequence(coords_list, batch_first=True, padding_value=-1.)
    # Edges
    if 'edges' in formats:
        edges_list = [ex[2]['edges'] for ex in batch]
        max_len = max([len(edges) for edges in edges_list])
        padded_edges = []
        for edges in edges_list:
            extra = max_len - len(edges)
            padded_edges.append(F.pad(edges, (0, extra, 0, extra), value=-100))
        refs['edges'] = torch.stack(padded_edges, dim=0)
    return ids, pad_images(imgs), refs
|
molscribe/evaluate.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import multiprocessing
|
3 |
+
|
4 |
+
import rdkit
|
5 |
+
import rdkit.Chem as Chem
|
6 |
+
rdkit.RDLogger.DisableLog('rdApp.*')
|
7 |
+
from SmilesPE.pretokenizer import atomwise_tokenizer
|
8 |
+
|
9 |
+
|
10 |
+
def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True):
    """Canonicalise a SMILES string with RDKit.

    Args:
        smiles: input SMILES; anything that is not a non-empty string yields
            ``('', False)``.
        ignore_chiral: canonicalise with ``useChiral=False``.
        ignore_cistrans: strip ``/`` and ``\\`` bond markers first.
        replace_rgroup: rewrite bracket tokens ``[R<digits>]`` as
            ``[<digits>*]`` and any other bracket token RDKit cannot parse
            as ``*``.

    Returns:
        ``(canonical_smiles, success)``. When canonicalisation fails, the
        (possibly rewritten) input is returned with ``success=False``.
    """
    if not isinstance(smiles, str) or smiles == '':
        return '', False
    if ignore_cistrans:
        smiles = smiles.replace('/', '').replace('\\', '')
    if replace_rgroup:
        tokens = atomwise_tokenizer(smiles)
        for j, token in enumerate(tokens):
            if token[0] == '[' and token[-1] == ']':
                symbol = token[1:-1]
                # Slice instead of index so an empty bracket cannot raise IndexError.
                if symbol[:1] == 'R' and symbol[1:].isdigit():
                    tokens[j] = f'[{symbol[1:]}*]'
                elif Chem.AtomFromSmiles(token) is None:
                    tokens[j] = '*'
        smiles = ''.join(tokens)
    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral))
        success = True
    except Exception:
        # Invalid SMILES: report failure but keep the text for comparison.
        canon_smiles = smiles
        success = False
    return canon_smiles, success
|
32 |
+
|
33 |
+
|
34 |
+
def convert_smiles_to_canonsmiles(
        smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16):
    """Canonicalise many SMILES strings in parallel.

    Args:
        smiles_list: iterable of SMILES strings.
        ignore_chiral, ignore_cistrans, replace_rgroup: forwarded to
            ``canonicalize_smiles`` for every entry.
        num_workers: size of the ``multiprocessing.Pool``.

    Returns:
        ``(canonical_smiles_list, success_rate)``. An empty input returns
        ``([], 0.0)`` without starting a pool.
    """
    jobs = [(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list]
    if not jobs:
        # zip(*results) cannot be unpacked into two names when there are no results.
        return [], 0.0
    with multiprocessing.Pool(num_workers) as p:
        results = p.starmap(canonicalize_smiles, jobs, chunksize=128)
    canon_smiles, success = zip(*results)
    return list(canon_smiles), np.mean(success)
|
42 |
+
|
43 |
+
|
44 |
+
class SmilesEvaluator(object):
    """Compare predicted SMILES against gold SMILES after canonicalisation."""

    def __init__(self, gold_smiles, num_workers=16):
        """Pre-compute the canonical forms of the gold SMILES.

        Args:
            gold_smiles: list of reference SMILES strings.
            num_workers: pool size for canonicalisation, used here and in
                ``evaluate``.
        """
        self.gold_smiles = gold_smiles
        self.num_workers = num_workers
        self.gold_canon_smiles, self.gold_valid = convert_smiles_to_canonsmiles(gold_smiles, num_workers=num_workers)
        self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                   ignore_chiral=True, num_workers=num_workers)
        self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                     ignore_cistrans=True, num_workers=num_workers)
        self.gold_canon_smiles = self._replace_empty(self.gold_canon_smiles)
        self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral)
        self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans)

    def _replace_empty(self, smiles_list):
        """Replace empty SMILES in the gold, otherwise it will be considered correct if both pred and gold is empty."""
        return [smiles if smiles is not None and type(smiles) is str and smiles != "" else "<empty>"
                for smiles in smiles_list]

    def evaluate(self, pred_smiles):
        """Score predictions against the gold SMILES.

        Args:
            pred_smiles: predicted SMILES, in the same order as the gold list.

        Returns:
            Dict with ``gold_valid``, ``pred_valid``, ``canon_smiles_em``
            (exact match), ``graph`` (chirality ignored), ``canon_smiles``
            (cis/trans ignored), ``chiral_ratio`` and ``chiral`` (accuracy on
            gold entries containing ``@``; ``-1`` when there are none).
        """
        num_workers = self.num_workers
        results = {}
        results['gold_valid'] = self.gold_valid
        # Canon SMILES
        pred_canon_smiles, pred_valid = convert_smiles_to_canonsmiles(pred_smiles, num_workers=num_workers)
        results['canon_smiles_em'] = (np.array(self.gold_canon_smiles) == np.array(pred_canon_smiles)).mean()
        results['pred_valid'] = pred_valid
        # Ignore chirality (Graph exact match)
        pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_chiral=True,
                                                              num_workers=num_workers)
        results['graph'] = (np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral)).mean()
        # Ignore double bond cis/trans
        pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_cistrans=True,
                                                                num_workers=num_workers)
        results['canon_smiles'] = (np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans)).mean()
        # Evaluate on molecules with chiral centers
        chiral = np.array([[g, p] for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans) if '@' in g])
        results['chiral_ratio'] = len(chiral) / len(self.gold_smiles)
        results['chiral'] = (chiral[:, 0] == chiral[:, 1]).mean() if len(chiral) > 0 else -1
        return results
|
molscribe/indigo/__init__.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
molscribe/indigo/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (96.8 kB). View file
|
|
molscribe/indigo/__pycache__/bingo.cpython-310.pyc
ADDED
Binary file (11.4 kB). View file
|
|
molscribe/indigo/__pycache__/inchi.cpython-310.pyc
ADDED
Binary file (2.73 kB). View file
|
|
molscribe/indigo/__pycache__/renderer.cpython-310.pyc
ADDED
Binary file (2.8 kB). View file
|
|
molscribe/indigo/bingo.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) from 2009 to Present EPAM Systems.
|
3 |
+
#
|
4 |
+
# This file is part of Indigo toolkit.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import os
|
19 |
+
from . import *
|
20 |
+
|
21 |
+
|
22 |
+
class BingoException(Exception):
    """Error reported by the Bingo library; ``value`` holds the raw message."""

    def __init__(self, value):
        # Initialise the base class so ``args`` carries the message too.
        super(BingoException, self).__init__(value)
        self.value = value

    def __str__(self):
        """Return ``repr`` of the message, decoding it first if it is bytes on Python 3."""
        value = self.value
        # Only bytes need decoding; a str message previously raised AttributeError here.
        if sys.version_info > (3, 0) and isinstance(value, bytes):
            value = value.decode('ascii')
        return repr(value)
|
32 |
+
|
33 |
+
|
34 |
+
class Bingo(object):
    """Handle to a Bingo database accessed through the native bingo library.

    Instances are normally created with :meth:`createDatabaseFile` or
    :meth:`loadDatabaseFile`. Every call first selects the owning Indigo
    session via ``indigo._setSessionId()`` and converts negative native
    return codes into :class:`BingoException`.

    The object can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(self, bingoId, indigo, lib):
        """Wrap an already opened database.

        Args:
            bingoId: database id returned by the native create/load call.
            indigo: owning Indigo session object.
            lib: loaded native bingo library (ctypes ``CDLL``).
        """
        self._id = bingoId
        self._indigo = indigo
        self._lib = lib
        Bingo._declareSignatures(lib)

    @staticmethod
    def _declareSignatures(lib):
        """Set ``restype``/``argtypes`` on every native function used here."""
        signatures = (
            ("bingoVersion", c_char_p, None),
            ("bingoCreateDatabaseFile", c_int, [c_char_p, c_char_p, c_char_p]),
            ("bingoLoadDatabaseFile", c_int, [c_char_p, c_char_p]),
            ("bingoCloseDatabase", c_int, [c_int]),
            ("bingoInsertRecordObj", c_int, [c_int, c_int]),
            ("bingoInsertRecordObjWithExtFP", c_int, [c_int, c_int, c_int]),
            ("bingoGetRecordObj", c_int, [c_int, c_int]),
            ("bingoInsertRecordObjWithId", c_int, [c_int, c_int, c_int]),
            ("bingoInsertRecordObjWithIdAndExtFP", c_int, [c_int, c_int, c_int, c_int]),
            ("bingoDeleteRecord", c_int, [c_int, c_int]),
            ("bingoSearchSub", c_int, [c_int, c_int, c_char_p]),
            ("bingoSearchExact", c_int, [c_int, c_int, c_char_p]),
            ("bingoSearchMolFormula", c_int, [c_int, c_char_p, c_char_p]),
            ("bingoSearchSim", c_int, [c_int, c_int, c_float, c_float, c_char_p]),
            ("bingoSearchSimWithExtFP", c_int, [c_int, c_int, c_float, c_float, c_int, c_char_p]),
            ("bingoSearchSimTopN", c_int, [c_int, c_int, c_int, c_float, c_char_p]),
            ("bingoSearchSimTopNWithExtFP", c_int, [c_int, c_int, c_int, c_float, c_int, c_char_p]),
            ("bingoEnumerateId", c_int, [c_int]),
            ("bingoNext", c_int, [c_int]),
            ("bingoGetCurrentId", c_int, [c_int]),
            ("bingoGetObject", c_int, [c_int]),
            ("bingoEndSearch", c_int, [c_int]),
            ("bingoGetCurrentSimilarityValue", c_float, [c_int]),
            ("bingoOptimize", c_int, [c_int]),
            ("bingoEstimateRemainingResultsCount", c_int, [c_int]),
            ("bingoEstimateRemainingResultsCountError", c_int, [c_int]),
            ("bingoEstimateRemainingTime", c_int, [c_int, POINTER(c_float)]),
            ("bingoContainersCount", c_int, [c_int]),
            ("bingoCellsCount", c_int, [c_int]),
            ("bingoCurrentCell", c_int, [c_int]),
            ("bingoMinCell", c_int, [c_int]),
            ("bingoMaxCell", c_int, [c_int]),
        )
        for name, restype, argtypes in signatures:
            func = getattr(lib, name)
            func.restype = restype
            func.argtypes = argtypes

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the database if it is still open; safe to call repeatedly."""
        self._indigo._setSessionId()
        if self._id >= 0:
            Bingo._checkResult(self._indigo, self._lib.bingoCloseDatabase(self._id))
            # Mark as closed so a later close()/__del__ does not close twice.
            self._id = -1

    @staticmethod
    def _checkResult(indigo, result):
        """Return ``result``; raise BingoException if it is negative."""
        if result < 0:
            raise BingoException(indigo._lib.indigoGetLastError())
        return result

    @staticmethod
    def _checkResultPtr(indigo, result):
        """Return ``result``; raise BingoException if it is ``None``."""
        if result is None:
            raise BingoException(indigo._lib.indigoGetLastError())
        return result

    @staticmethod
    def _checkResultString(indigo, result):
        """Validate a native string result and convert it to the native ``str`` type."""
        res = Bingo._checkResultPtr(indigo, result)
        if sys.version_info >= (3, 0):
            return res.decode('ascii')
        else:
            return res.encode('ascii')

    @staticmethod
    def _getLib(indigo):
        """Load the platform-specific bingo shared library from ``indigo.dllpath``."""
        if os.name == 'posix' and not platform.mac_ver()[0] and not platform.system().startswith("CYGWIN"):
            _lib = CDLL(indigo.dllpath + "/libbingo.so")
        elif os.name == 'nt' or platform.system().startswith("CYGWIN"):
            _lib = CDLL(indigo.dllpath + "/bingo.dll")
        elif platform.mac_ver()[0]:
            _lib = CDLL(indigo.dllpath + "/libbingo.dylib")
        else:
            raise BingoException("unsupported OS: " + os.name)
        return _lib

    @staticmethod
    def createDatabaseFile(indigo, path, databaseType, options=''):
        """Create a database file at ``path`` and return a :class:`Bingo` for it.

        A falsy ``options`` (e.g. ``None``) is treated as an empty string.
        All string arguments are ASCII-encoded before the native call.
        """
        indigo._setSessionId()
        if not options:
            options = ''
        lib = Bingo._getLib(indigo)
        lib.bingoCreateDatabaseFile.restype = c_int
        lib.bingoCreateDatabaseFile.argtypes = [c_char_p, c_char_p, c_char_p]
        databaseId = Bingo._checkResult(
            indigo,
            lib.bingoCreateDatabaseFile(
                path.encode('ascii'), databaseType.encode('ascii'), options.encode('ascii')))
        return Bingo(databaseId, indigo, lib)

    @staticmethod
    def loadDatabaseFile(indigo, path, options=''):
        """Open the database file at ``path`` and return a :class:`Bingo` for it.

        A falsy ``options`` (e.g. ``None``) is treated as an empty string.
        """
        indigo._setSessionId()
        if not options:
            options = ''
        lib = Bingo._getLib(indigo)
        lib.bingoLoadDatabaseFile.restype = c_int
        lib.bingoLoadDatabaseFile.argtypes = [c_char_p, c_char_p]
        databaseId = Bingo._checkResult(
            indigo,
            lib.bingoLoadDatabaseFile(path.encode('ascii'), options.encode('ascii')))
        return Bingo(databaseId, indigo, lib)

    def version(self):
        """Return the native library's version string."""
        self._indigo._setSessionId()
        return Bingo._checkResultString(self._indigo, self._lib.bingoVersion())

    def insert(self, indigoObject, index=None):
        """Insert ``indigoObject`` and return the native call's result.

        Args:
            indigoObject: object whose ``id`` is passed to the library.
            index: explicit record id, or ``None`` to let the library pick
                one. ``0`` is forwarded as an explicit id (it used to be
                mistaken for "no id" by a truthiness test).
        """
        self._indigo._setSessionId()
        if index is None:
            return Bingo._checkResult(self._indigo, self._lib.bingoInsertRecordObj(self._id, indigoObject.id))
        return Bingo._checkResult(
            self._indigo,
            self._lib.bingoInsertRecordObjWithId(self._id, indigoObject.id, index))

    def insertWithExtFP(self, indigoObject, ext_fp, index=None):
        """Insert ``indigoObject`` together with the external fingerprint ``ext_fp``.

        ``index`` follows the same rules as in :meth:`insert`: ``None`` means
        no explicit id, any integer (including ``0``) is passed through.
        """
        self._indigo._setSessionId()
        if index is None:
            return Bingo._checkResult(
                self._indigo,
                self._lib.bingoInsertRecordObjWithExtFP(self._id, indigoObject.id, ext_fp.id))
        return Bingo._checkResult(
            self._indigo,
            self._lib.bingoInsertRecordObjWithIdAndExtFP(self._id, indigoObject.id, index, ext_fp.id))

    def delete(self, index):
        """Delete the record with id ``index``."""
        self._indigo._setSessionId()
        Bingo._checkResult(self._indigo, self._lib.bingoDeleteRecord(self._id, index))

    def _wrapSearch(self, searchId):
        """Validate a native search id and wrap it in a :class:`BingoObject`."""
        return BingoObject(Bingo._checkResult(self._indigo, searchId), self._indigo, self)

    def searchSub(self, query, options=''):
        """Start a substructure search for ``query``; returns a :class:`BingoObject`."""
        self._indigo._setSessionId()
        if not options:
            options = ''
        return self._wrapSearch(
            self._lib.bingoSearchSub(self._id, query.id, options.encode('ascii')))

    def searchExact(self, query, options=''):
        """Start an exact search for ``query``; returns a :class:`BingoObject`."""
        self._indigo._setSessionId()
        if not options:
            options = ''
        return self._wrapSearch(
            self._lib.bingoSearchExact(self._id, query.id, options.encode('ascii')))

    def searchSim(self, query, minSim, maxSim, metric='tanimoto'):
        """Start a similarity search bounded by ``minSim``/``maxSim``.

        A falsy ``metric`` falls back to ``'tanimoto'``.
        """
        self._indigo._setSessionId()
        if not metric:
            metric = 'tanimoto'
        return self._wrapSearch(
            self._lib.bingoSearchSim(self._id, query.id, minSim, maxSim, metric.encode('ascii')))

    def searchSimWithExtFP(self, query, minSim, maxSim, ext_fp, metric='tanimoto'):
        """Like :meth:`searchSim`, additionally passing the fingerprint ``ext_fp``."""
        self._indigo._setSessionId()
        if not metric:
            metric = 'tanimoto'
        return self._wrapSearch(
            self._lib.bingoSearchSimWithExtFP(
                self._id, query.id, minSim, maxSim, ext_fp.id, metric.encode('ascii')))

    def searchSimTopN(self, query, limit, minSim, metric='tanimoto'):
        """Start a similarity search passing ``limit`` and ``minSim`` to the library.

        A falsy ``metric`` falls back to ``'tanimoto'``.
        """
        self._indigo._setSessionId()
        if not metric:
            metric = 'tanimoto'
        return self._wrapSearch(
            self._lib.bingoSearchSimTopN(self._id, query.id, limit, minSim, metric.encode('ascii')))

    def searchSimTopNWithExtFP(self, query, limit, minSim, ext_fp, metric='tanimoto'):
        """Like :meth:`searchSimTopN`, additionally passing the fingerprint ``ext_fp``."""
        self._indigo._setSessionId()
        if not metric:
            metric = 'tanimoto'
        return self._wrapSearch(
            self._lib.bingoSearchSimTopNWithExtFP(
                self._id, query.id, limit, minSim, ext_fp.id, metric.encode('ascii')))

    def enumerateId(self):
        """Start an enumeration of record ids; returns a :class:`BingoObject`."""
        self._indigo._setSessionId()
        return self._wrapSearch(self._lib.bingoEnumerateId(self._id))

    def searchMolFormula(self, query, options=''):
        """Start a molecular-formula search; ``query`` is a string, not an object."""
        self._indigo._setSessionId()
        if not options:
            options = ''
        return self._wrapSearch(
            self._lib.bingoSearchMolFormula(self._id, query.encode('ascii'), options.encode('ascii')))

    def optimize(self):
        """Ask the native library to optimize the database."""
        self._indigo._setSessionId()
        Bingo._checkResult(self._indigo, self._lib.bingoOptimize(self._id))

    def getRecordById(self, id):
        """Return the stored record ``id`` wrapped in an ``IndigoObject``."""
        self._indigo._setSessionId()
        return IndigoObject(self._indigo, Bingo._checkResult(self._indigo, self._lib.bingoGetRecordObj(self._id, id)))
|
255 |
+
|
256 |
+
class BingoObject(object):
    """Cursor over the results of a Bingo search or id enumeration.

    Created by the search methods of ``Bingo``. Supports the iterator
    protocol (each step advances the cursor and yields the cursor itself)
    and the context-manager protocol (leaving the block ends the search).
    """

    def __init__(self, objId, indigo, bingo):
        self._id, self._indigo, self._bingo = objId, indigo, bingo

    def __del__(self):
        self.close()

    def _invoke(self, funcName):
        """Select the session, call native ``funcName(self._id)`` and check the result."""
        self._indigo._setSessionId()
        nativeFunc = getattr(self._bingo._lib, funcName)
        return Bingo._checkResult(self._indigo, nativeFunc(self._id))

    def close(self):
        """End the native search if it is still active."""
        self._indigo._setSessionId()
        if self._id < 0:
            return
        Bingo._checkResult(self._indigo, self._bingo._lib.bingoEndSearch(self._id))
        self._id = -1

    def next(self):
        """Advance the cursor; ``True`` when the native call returned 1."""
        return self._invoke("bingoNext") == 1

    def getCurrentId(self):
        """Return the id reported by the library for the current position."""
        return self._invoke("bingoGetCurrentId")

    def getIndigoObject(self):
        """Return the current result wrapped in an ``IndigoObject``."""
        handle = self._invoke("bingoGetObject")
        return IndigoObject(self._indigo, handle)

    def getCurrentSimilarityValue(self):
        """Return the similarity value of the current result."""
        return self._invoke("bingoGetCurrentSimilarityValue")

    def estimateRemainingResultsCount(self):
        """Return the library's estimate of the remaining result count."""
        return self._invoke("bingoEstimateRemainingResultsCount")

    def estimateRemainingResultsCountError(self):
        """Return the error of the remaining-results estimate."""
        return self._invoke("bingoEstimateRemainingResultsCountError")

    def estimateRemainingTime(self):
        """Return the library's estimate of the remaining search time."""
        self._indigo._setSessionId()
        estimate = c_float()
        # The native call writes its answer through the pointer.
        status = self._bingo._lib.bingoEstimateRemainingTime(self._id, pointer(estimate))
        Bingo._checkResult(self._indigo, status)
        return estimate.value

    def containersCount(self):
        return self._invoke("bingoContainersCount")

    def cellsCount(self):
        return self._invoke("bingoCellsCount")

    def currentCell(self):
        return self._invoke("bingoCurrentCell")

    def minCell(self):
        return self._invoke("bingoMinCell")

    def maxCell(self):
        return self._invoke("bingoMaxCell")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __iter__(self):
        return self

    def __next__(self):
        # Yields the cursor itself; callers read the current hit from it.
        if not self.next():
            raise StopIteration
        return self
|
molscribe/indigo/inchi.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) from 2009 to Present EPAM Systems.
|
3 |
+
#
|
4 |
+
# This file is part of Indigo toolkit.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
from . import *
|
19 |
+
|
20 |
+
|
21 |
+
class IndigoInchi(object):
    """Wrapper around the native indigo-inchi plugin library.

    Every method first selects the owning Indigo session and routes the
    native return value through the session's ``_checkResult`` /
    ``_checkResultString`` helpers.
    """

    def __init__(self, indigo):
        """Load the platform-specific plugin from ``indigo.dllpath``.

        Raises:
            IndigoException: if the platform is not recognised.
        """
        self.indigo = indigo

        if os.name == 'posix' and not platform.mac_ver()[0] and not platform.system().startswith("CYGWIN"):
            self._lib = CDLL(indigo.dllpath + "/libindigo-inchi.so")
        elif os.name == 'nt' or platform.system().startswith("CYGWIN"):
            # Backslash written as an explicit escape: "\i" is an invalid
            # escape sequence and triggers a warning on current Pythons.
            self._lib = CDLL(indigo.dllpath + "\\indigo-inchi.dll")
        elif platform.mac_ver()[0]:
            self._lib = CDLL(indigo.dllpath + "/libindigo-inchi.dylib")
        else:
            raise IndigoException("unsupported OS: " + os.name)

        self._lib.indigoInchiVersion.restype = c_char_p
        self._lib.indigoInchiVersion.argtypes = []
        self._lib.indigoInchiResetOptions.restype = c_int
        self._lib.indigoInchiResetOptions.argtypes = []
        self._lib.indigoInchiLoadMolecule.restype = c_int
        self._lib.indigoInchiLoadMolecule.argtypes = [c_char_p]
        self._lib.indigoInchiGetInchi.restype = c_char_p
        self._lib.indigoInchiGetInchi.argtypes = [c_int]
        self._lib.indigoInchiGetInchiKey.restype = c_char_p
        self._lib.indigoInchiGetInchiKey.argtypes = [c_char_p]
        self._lib.indigoInchiGetWarning.restype = c_char_p
        self._lib.indigoInchiGetWarning.argtypes = []
        self._lib.indigoInchiGetLog.restype = c_char_p
        self._lib.indigoInchiGetLog.argtypes = []
        self._lib.indigoInchiGetAuxInfo.restype = c_char_p
        self._lib.indigoInchiGetAuxInfo.argtypes = []

    def resetOptions(self):
        """Reset the plugin's options via the native call."""
        self.indigo._setSessionId()
        self.indigo._checkResult(self._lib.indigoInchiResetOptions())

    def loadMolecule(self, inchi):
        """Load a molecule from the InChI string ``inchi``.

        Returns:
            ``None`` when the native call returns 0, otherwise the object
            built by ``self.indigo.IndigoObject`` for the returned handle.
        """
        self.indigo._setSessionId()
        res = self.indigo._checkResult(self._lib.indigoInchiLoadMolecule(inchi.encode('ascii')))
        if res == 0:
            return None
        return self.indigo.IndigoObject(self.indigo, res)

    def version(self):
        """Return the plugin's version string."""
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiVersion())

    def getInchi(self, molecule):
        """Return the InChI string for ``molecule`` (passed by its ``id``)."""
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetInchi(molecule.id))

    def getInchiKey(self, inchi):
        """Return the InChIKey for the InChI string ``inchi``."""
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetInchiKey(inchi.encode('ascii')))

    def getWarning(self):
        """Return the plugin's current warning text."""
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetWarning())

    def getLog(self):
        """Return the plugin's current log text."""
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetLog())

    def getAuxInfo(self):
        """Return the plugin's current auxiliary-info text."""
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetAuxInfo())
|
molscribe/indigo/renderer.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) from 2009 to Present EPAM Systems.
|
3 |
+
#
|
4 |
+
# This file is part of Indigo toolkit.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import os
|
19 |
+
import platform
|
20 |
+
from ctypes import CDLL, POINTER, c_char_p, c_int
|
21 |
+
|
22 |
+
from . import IndigoException
|
23 |
+
|
24 |
+
|
25 |
+
class IndigoRenderer(object):
    """Wrapper around the native indigo-renderer plugin library.

    Every method first selects the owning Indigo session and routes the
    native return code through the session's ``_checkResult``.
    """

    def __init__(self, indigo):
        """Load the platform-specific plugin from ``indigo.dllpath``.

        Raises:
            IndigoException: if the platform is not recognised.
        """
        self.indigo = indigo

        if (
            os.name == "posix"
            and not platform.mac_ver()[0]
            and not platform.system().startswith("CYGWIN")
        ):
            self._lib = CDLL(indigo.dllpath + "/libindigo-renderer.so")
        elif os.name == "nt" or platform.system().startswith("CYGWIN"):
            # Backslash written as an explicit escape: "\i" is an invalid
            # escape sequence and triggers a warning on current Pythons.
            self._lib = CDLL(indigo.dllpath + "\\indigo-renderer.dll")
        elif platform.mac_ver()[0]:
            self._lib = CDLL(indigo.dllpath + "/libindigo-renderer.dylib")
        else:
            raise IndigoException("unsupported OS: " + os.name)

        self._lib.indigoRender.restype = c_int
        self._lib.indigoRender.argtypes = [c_int, c_int]
        self._lib.indigoRenderToFile.restype = c_int
        self._lib.indigoRenderToFile.argtypes = [c_int, c_char_p]
        self._lib.indigoRenderGrid.restype = c_int
        self._lib.indigoRenderGrid.argtypes = [
            c_int,
            POINTER(c_int),
            c_int,
            c_int,
        ]
        self._lib.indigoRenderGridToFile.restype = c_int
        self._lib.indigoRenderGridToFile.argtypes = [
            c_int,
            POINTER(c_int),
            c_int,
            c_char_p,
        ]
        self._lib.indigoRenderReset.restype = c_int
        self._lib.indigoRenderReset.argtypes = [c_int]

    @staticmethod
    def _refatomsArray(objects, refatoms, caller):
        """Convert ``refatoms`` to a ctypes int array, or ``None`` if it is falsy.

        Raises:
            IndigoException: if ``len(refatoms)`` differs from
                ``objects.count()``; ``caller`` names the public method in
                the message.
        """
        if not refatoms:
            return None
        if len(refatoms) != objects.count():
            raise IndigoException(
                caller + "(): refatoms[] size must be equal to the number of objects"
            )
        return (c_int * len(refatoms))(*refatoms)

    def renderToBuffer(self, obj):
        """Render ``obj`` into a fresh write buffer and return its contents."""
        self.indigo._setSessionId()
        wb = self.indigo.writeBuffer()
        try:
            self.indigo._checkResult(self._lib.indigoRender(obj.id, wb.id))
            return wb.toBuffer()
        finally:
            # Release the native buffer even when rendering fails.
            wb.dispose()

    def renderToFile(self, obj, filename):
        """Render ``obj`` to ``filename`` (ASCII-encoded for the native call)."""
        self.indigo._setSessionId()
        self.indigo._checkResult(
            self._lib.indigoRenderToFile(obj.id, filename.encode("ascii"))
        )

    def renderGridToFile(self, objects, refatoms, ncolumns, filename):
        """Render ``objects`` as a grid with ``ncolumns`` columns to ``filename``.

        ``refatoms`` may be falsy (a null pointer is passed); otherwise it
        must contain one entry per object.
        """
        self.indigo._setSessionId()
        arr = self._refatomsArray(objects, refatoms, "renderGridToFile")
        self.indigo._checkResult(
            self._lib.indigoRenderGridToFile(
                objects.id, arr, ncolumns, filename.encode("ascii")
            )
        )

    def renderGridToBuffer(self, objects, refatoms, ncolumns):
        """Render ``objects`` as a grid and return the write buffer's contents.

        ``refatoms`` may be falsy (a null pointer is passed); otherwise it
        must contain one entry per object.
        """
        self.indigo._setSessionId()
        arr = self._refatomsArray(objects, refatoms, "renderGridToBuffer")
        wb = self.indigo.writeBuffer()
        try:
            self.indigo._checkResult(
                self._lib.indigoRenderGrid(objects.id, arr, ncolumns, wb.id)
            )
            return wb.toBuffer()
        finally:
            # Release the native buffer even when rendering fails.
            wb.dispose()
|
molscribe/inference/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .greedy_search import GreedySearch
|
2 |
+
from .beam_search import BeamSearch
|
3 |
+
|
4 |
+
__all__ = ["GreedySearch", "BeamSearch"]
|
molscribe/inference/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (282 Bytes). View file
|
|
molscribe/inference/__pycache__/beam_search.cpython-310.pyc
ADDED
Binary file (5.44 kB). View file
|
|
molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc
ADDED
Binary file (2.68 kB). View file
|
|
molscribe/inference/__pycache__/greedy_search.cpython-310.pyc
ADDED
Binary file (4.11 kB). View file
|
|
molscribe/inference/beam_search.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .decode_strategy import DecodeStrategy
|
3 |
+
|
4 |
+
|
5 |
+
class BeamSearch(DecodeStrategy):
    """Generation with beam search.

    Keeps ``beam_size`` candidate sequences per batch entry, ranks them by
    length-averaged log-probability at every step, and stores finished
    hypotheses until the top beam of an entry has finished and at least
    ``n_best`` hypotheses are available for it.
    """

    def __init__(self, pad, bos, eos, batch_size, beam_size, n_best, min_length,
                 return_attention, max_length):
        # Pass the trailing options by keyword: the base class declares them
        # in the order (min_length, max_length, return_attention), so the
        # previous positional call swapped max_length and return_attention.
        super(BeamSearch, self).__init__(
            pad, bos, eos, batch_size, beam_size,
            min_length=min_length,
            max_length=max_length,
            return_attention=return_attention)
        self.beam_size = beam_size
        self.n_best = n_best

        # result caching
        self.hypotheses = [[] for _ in range(batch_size)]

        # beam state
        self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool)

        # Maps each still-active row back to its original batch position.
        self._batch_offset = torch.arange(batch_size, dtype=torch.long)

        self.select_indices = None
        self.done = False

    def initialize(self, memory_bank, device=None):
        """Repeat src objects `beam_size` times.

        Args:
            memory_bank: tensor repeated ``beam_size`` times along dim 0.
            device: device for the beam buffers; defaults to the device of
                ``memory_bank``.

        Returns:
            ``(fn_map_state, memory_bank)`` where ``fn_map_state(state, dim)``
            repeats a state tensor ``beam_size`` times along ``dim``.
        """

        def fn_map_state(state, dim):
            return torch.repeat_interleave(state, self.beam_size, dim=dim)

        memory_bank = torch.repeat_interleave(memory_bank, self.beam_size, dim=0)
        if device is None:
            device = memory_bank.device

        self.memory_length = memory_bank.size(1)
        super().initialize(memory_bank, device)

        self.best_scores = torch.full([self.batch_size], -1e10, dtype=torch.float, device=device)
        self._beam_offset = torch.arange(
            0, self.batch_size * self.beam_size, step=self.beam_size, dtype=torch.long, device=device)
        # Only the first beam of each entry starts "alive" (log-prob 0); the
        # others start at -inf so the first step does not pick duplicates.
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
        ).repeat(self.batch_size)
        # buffers for the topk scores and 'backpointer'
        self.topk_scores = torch.empty((self.batch_size, self.beam_size), dtype=torch.float, device=device)
        self.topk_ids = torch.empty((self.batch_size, self.beam_size), dtype=torch.long, device=device)
        self._batch_index = torch.empty([self.batch_size, self.beam_size], dtype=torch.long, device=device)

        return fn_map_state, memory_bank

    @property
    def current_predictions(self):
        """Last token of every alive sequence."""
        return self.alive_seq[:, -1]

    @property
    def current_backptr(self):
        # for testing
        return self.select_indices.view(self.batch_size, self.beam_size)

    @property
    def batch_offset(self):
        return self._batch_offset

    def _pick(self, log_probs):
        """Return token decision for a step.

        Args:
            log_probs (FloatTensor): (B, vocab_size)

        Returns:
            topk_scores (FloatTensor): (B, beam_size)
            topk_ids (LongTensor): (B, beam_size)
        """
        vocab_size = log_probs.size(-1)

        # Flatten probs into a list of probabilities.
        curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
        topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
        return topk_scores, topk_ids

    def advance(self, log_probs, attn):
        """Extend every alive beam by one token.

        Args:
            log_probs: (B * beam_size, vocab_size); modified in place.
            attn: attention for this step; only read when
                ``self.return_attention`` is set.
        """
        vocab_size = log_probs.size(-1)

        # (non-finished) batch_size
        _B = log_probs.shape[0] // self.beam_size

        step = len(self)  # alive_seq
        self.ensure_min_length(log_probs)

        # Multiply probs by the beam probability
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)

        curr_length = step + 1
        curr_scores = log_probs / curr_length  # avg log_prob
        self.topk_scores, self.topk_ids = self._pick(curr_scores)
        # topk_scores/topk_ids: (batch_size, beam_size)

        # Recover log probs
        torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs)

        # Resolve beam origin and map to batch index flat representation.
        self._batch_index = self.topk_ids // vocab_size
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)
        self.topk_ids.fmod_(vocab_size)  # resolve true word ids

        # Append last prediction.
        self.alive_seq = torch.cat(
            [self.alive_seq.index_select(0, self.select_indices),
             self.topk_ids.view(_B * self.beam_size, 1)], -1)

        if self.return_attention:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
            else:
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)

        self.is_finished = self.topk_ids.eq(self.eos)
        self.ensure_max_length()

    def update_finished(self):
        """Store finished hypotheses and drop completed batch entries.

        Sets ``self.done`` when no batch entry is left to decode.
        """
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # len(self)
        # Finished beams must not be extended further.
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)

        self.is_finished = self.is_finished.to('cpu')
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
        attention = (
            self.alive_attn.view(
                step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
            if self.alive_attn is not None else None)
        non_finished_batch = []
        for i in range(self.is_finished.size(0)):
            b = self._batch_offset[i]
            finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
            # Store finished hypothesis for this batch.
            for j in finished_hyp:  # Beam level: finished beam j in batch i
                self.hypotheses[b].append((
                    self.topk_scores[i, j],
                    predictions[i, j, 1:],  # Ignore start token
                    attention[:, i, j, :self.memory_length]
                    if attention is not None else None))
            # End condition is the top beam finished and we can return
            # n_best hypotheses.
            finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                best_hyp = sorted(
                    self.hypotheses[b], key=lambda x: x[0], reverse=True)
                for n, (score, pred, attn) in enumerate(best_hyp):
                    if n >= self.n_best:
                        break
                    self.scores[b].append(score.item())
                    self.predictions[b].append(pred)
                    self.attention[b].append(
                        attn if attn is not None else [])
            else:
                non_finished_batch.append(i)
        non_finished = torch.tensor(non_finished_batch)

        if len(non_finished) == 0:
            self.done = True
            return

        _B_new = non_finished.shape[0]
        # Remove finished batches for the next step
        self.top_beam_finished = self.top_beam_finished.index_select(0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        # The two tensors above are created without a device argument in
        # __init__; the remaining buffers are on the decoding device.
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished).view(-1, self.alive_seq.size(-1))
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)

        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(1, non_finished) \
                .view(step - 1, _B_new * self.beam_size, inp_seq_len)
|
molscribe/inference/decode_strategy.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class DecodeStrategy(object):
    """Base class holding the bookkeeping shared by decoding strategies.

    Subclasses implement ``advance`` (consume one step of log-probabilities)
    and ``update_finished`` (move finished paths into the result lists).

    :param pad: id of the padding token.
    :param bos: id of the begin-of-sequence token used to seed ``alive_seq``.
    :param eos: id of the end-of-sequence token.
    :param batch_size: number of examples decoded together.
    :param parallel_paths: number of paths kept per example.
    :param min_length: while ``len(self) <= min_length`` the eos token is suppressed.
    :param max_length: once ``len(self) == max_length + 1`` every path is marked finished.
    :param return_attention: stored flag, read by subclasses.
    :param return_hidden: stored flag, read by subclasses.
    """

    def __init__(self, pad, bos, eos, batch_size, parallel_paths, min_length, max_length,
                 return_attention=False, return_hidden=False):
        self.pad = pad
        self.bos = bos
        self.eos = eos

        self.batch_size = batch_size
        self.parallel_paths = parallel_paths
        # result caching: one (initially empty) list per example in the batch
        self.predictions = [[] for _ in range(batch_size)]
        self.scores = [[] for _ in range(batch_size)]
        self.token_scores = [[] for _ in range(batch_size)]
        self.attention = [[] for _ in range(batch_size)]
        self.hidden = [[] for _ in range(batch_size)]

        self.alive_attn = None
        self.alive_hidden = None

        self.min_length = min_length
        self.max_length = max_length

        self.return_attention = return_attention
        self.return_hidden = return_hidden

        self.done = False

    def initialize(self, memory_bank, device=None):
        """Allocate the per-path state tensors.

        :param memory_bank: returned unchanged as the second tuple element.
        :param device: device for the state tensors; CPU when ``None``.
        :return: ``(None, memory_bank)``.
        """
        if device is None:
            device = torch.device('cpu')
        n_paths = self.batch_size * self.parallel_paths
        # Every path starts with a single <bos> token.
        self.alive_seq = torch.full(
            [n_paths, 1], self.bos,
            dtype=torch.long, device=device)
        self.is_finished = torch.zeros(
            [self.batch_size, self.parallel_paths],
            dtype=torch.uint8, device=device)
        # Zero-width tensor; subclasses append one column per decoding step.
        self.alive_log_token_scores = torch.zeros(
            [n_paths, 0],
            dtype=torch.float, device=device)

        return None, memory_bank

    def __len__(self):
        """Return the current length of the alive sequences (including <bos>)."""
        return self.alive_seq.shape[1]

    def ensure_min_length(self, log_probs):
        """Suppress eos in-place in ``log_probs`` while the sequence is too short."""
        if len(self) <= self.min_length:
            log_probs[:, self.eos] = -1e20  # forced non-end

    def ensure_max_length(self):
        """Mark every path finished once the maximum length has been reached."""
        if len(self) == self.max_length + 1:
            self.is_finished.fill_(1)

    def advance(self, log_probs, attn):
        """Consume one decoding step; must be provided by subclasses."""
        raise NotImplementedError()

    def update_finished(self):
        """Collect finished paths; must be provided by subclasses."""
        raise NotImplementedError()
|
63 |
+
|
molscribe/inference/greedy_search.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .decode_strategy import DecodeStrategy
|
3 |
+
|
4 |
+
|
5 |
+
def sample_with_temperature(logits, sampling_temp, keep_topk):
    """Pick one next token per row, either greedily or by top-k sampling.

    When ``sampling_temp`` is 0 or ``keep_topk`` is 1 the arg-max token is
    returned (its score divided by the temperature if that is positive).
    Otherwise the logits are divided by ``sampling_temp``; if ``keep_topk`` is
    positive everything below the k-th best value of each row is masked to
    -10000, and one token per row is drawn from the resulting distribution.

    Returns ``(token_ids, token_scores)``, both with a trailing dimension of 1.
    """
    greedy = sampling_temp == 0.0 or keep_topk == 1
    if greedy:
        # argmax
        best_scores, best_ids = logits.topk(1, dim=-1)
        if sampling_temp > 0:
            best_scores /= sampling_temp
        return best_ids, best_scores

    scaled = torch.div(logits, sampling_temp)
    if keep_topk > 0:
        kept_values, _ = torch.topk(scaled, keep_topk, dim=1)
        # Smallest kept value of every row, expanded to the full vocabulary width.
        threshold = kept_values[:, -1].view([-1, 1])
        threshold = threshold.repeat([1, scaled.shape[1]]).float()
        below_threshold = torch.lt(scaled, threshold)
        scaled = scaled.masked_fill(below_threshold, -10000)

    # A single multinomial draw yields a one-hot row; argmax recovers the token id.
    sampler = torch.distributions.Multinomial(logits=scaled, total_count=1)
    picked_ids = torch.argmax(sampler.sample(), dim=1, keepdim=True)
    picked_scores = scaled.gather(dim=1, index=picked_ids)

    return picked_ids, picked_scores
|
31 |
+
|
32 |
+
|
33 |
+
class GreedySearch(DecodeStrategy):
    """Single-path decoding strategy.

    One path is kept per example. The next token is chosen by
    ``sample_with_temperature`` using ``sampling_temp`` and ``keep_topk``.
    """

    def __init__(self, pad, bos, eos, batch_size, min_length, max_length,
                 return_attention=False, return_hidden=False, sampling_temp=1, keep_topk=1):
        # The base class is always configured with exactly one parallel path.
        super().__init__(
            pad, bos, eos, batch_size, 1, min_length, max_length, return_attention, return_hidden)
        self.sampling_temp = sampling_temp
        self.keep_topk = keep_topk
        self.topk_scores = None

    def initialize(self, memory_bank, device=None):
        """Set up decoding state; the device defaults to that of ``memory_bank``.

        Returns ``(None, memory_bank)``.
        """
        target_device = memory_bank.device if device is None else device

        self.memory_length = memory_bank.size(1)
        super().initialize(memory_bank, target_device)

        # Both index tensors start out as 0..batch_size-1.
        self.select_indices = torch.arange(
            self.batch_size, dtype=torch.long, device=target_device)
        self.original_batch_idx = torch.arange(
            self.batch_size, dtype=torch.long, device=target_device)

        return None, memory_bank

    @property
    def current_predictions(self):
        """Most recently appended token of every alive sequence."""
        return self.alive_seq[:, -1]

    @property
    def batch_offset(self):
        """Alias of ``select_indices``."""
        return self.select_indices

    def _pick(self, log_probs):
        """Choose the next tokens; returns ``(token_ids, token_scores)``."""
        return sample_with_temperature(log_probs, self.sampling_temp, self.keep_topk)

    def advance(self, log_probs, attn=None, hidden=None, label=None):
        """Append one token to every alive sequence.

        When ``label`` is given, the finished flags are taken from ``label``
        (compared against eos) rather than from the picked tokens.
        """
        self.ensure_min_length(log_probs)
        # log_probs: b x v; next_ids & self.topk_scores: b x (t=1)
        next_ids, self.topk_scores = self._pick(log_probs)

        finished = next_ids.eq(self.eos)
        if label is not None:
            finished = label.view_as(finished).eq(self.eos)
        self.is_finished = finished

        # b x (l+1) (first element is <bos>; note l = len(self)-1)
        self.alive_seq = torch.cat([self.alive_seq, next_ids], -1)
        self.alive_log_token_scores = torch.cat([self.alive_log_token_scores, self.topk_scores], -1)

        if self.return_attention:
            self.alive_attn = attn if self.alive_attn is None \
                else torch.cat([self.alive_attn, attn], 1)
        if self.return_hidden:
            # b x l x h
            self.alive_hidden = hidden if self.alive_hidden is None \
                else torch.cat([self.alive_hidden, hidden], 1)
        self.ensure_max_length()

    def update_finished(self):
        """Move finished sequences into the result lists and drop them from the alive state."""
        finished_flat = self.is_finished.view(-1)
        for row in finished_flat.nonzero().view(-1):
            slot = self.original_batch_idx[row]
            log_token_scores = self.alive_log_token_scores[row]
            # scores/predictions/attention are lists (to be compatible with beam-search)
            self.scores[slot].append(torch.exp(torch.mean(log_token_scores)).item())
            self.token_scores[slot].append(torch.exp(log_token_scores).tolist())
            self.predictions[slot].append(self.alive_seq[row, 1:])  # skip <bos>
            if self.alive_attn is not None:
                self.attention[slot].append(self.alive_attn[row, :, :self.memory_length])
            else:
                self.attention[slot].append([])
            if self.alive_hidden is not None:
                self.hidden[slot].append(self.alive_hidden[row, :])
            else:
                self.hidden[slot].append([])

        self.done = self.is_finished.all()
        if self.done:
            return

        # Keep only the rows that are still being decoded.
        still_alive = ~self.is_finished.view(-1)
        self.alive_seq = self.alive_seq[still_alive]
        self.alive_log_token_scores = self.alive_log_token_scores[still_alive]
        if self.alive_attn is not None:
            self.alive_attn = self.alive_attn[still_alive]
        if self.alive_hidden is not None:
            self.alive_hidden = self.alive_hidden[still_alive]
        self.select_indices = still_alive.nonzero().view(-1)
        self.original_batch_idx = self.original_batch_idx[still_alive]
        # select_indices is equal to original_batch_idx for greedy search?
|
molscribe/interface.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
9 |
+
|
10 |
+
from .dataset import get_transforms
|
11 |
+
from .model import Encoder, Decoder
|
12 |
+
from .chemistry import convert_graph_to_smiles
|
13 |
+
from .tokenizer import get_tokenizer
|
14 |
+
|
15 |
+
|
16 |
+
BOND_TYPES = ["", "single", "double", "triple", "aromatic", "solid wedge", "dashed wedge"]
|
17 |
+
|
18 |
+
|
19 |
+
def safe_load(module, module_states):
    """Load ``module_states`` into ``module`` non-strictly.

    A leading ``'module.'`` on a key is stripped before loading. Only the
    prefix is removed: a plain substring replacement would also corrupt keys
    such as ``'submodule.weight'``.

    :param module: object exposing ``load_state_dict(state_dict, strict=...)``.
    :param module_states: mapping from parameter name to value.
    """
    prefix = 'module.'

    def remove_prefix(state_dict):
        return {(k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()}

    # strict=False: missing or unexpected keys are tolerated, as before.
    module.load_state_dict(remove_prefix(module_states), strict=False)
    return
|
24 |
+
|
25 |
+
|
26 |
+
class MolScribe:
    """Inference front-end: loads a checkpoint and converts molecule images to SMILES/molfile output."""

    def __init__(self, model_path, device=None):
        """
        MolScribe Interface
        :param model_path: path of the model checkpoint.
        :param device: torch device, defaults to be CPU.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        args = self._get_args(checkpoint['args'])
        self.device = torch.device('cpu') if device is None else device
        self.tokenizer = get_tokenizer(args)
        self.encoder, self.decoder = self._get_model(args, self.tokenizer, self.device, checkpoint)
        self.transform = get_transforms(args.input_size, augment=False)

    def _get_args(self, args_states=None):
        """Build the default argument namespace, then overlay the values in ``args_states``."""
        parser = argparse.ArgumentParser()
        # Model
        parser.add_argument('--encoder', type=str, default='swin_base')
        parser.add_argument('--decoder', type=str, default='transformer')
        parser.add_argument('--trunc_encoder', action='store_true')  # use the hidden states before downsample
        parser.add_argument('--no_pretrained', action='store_true')
        parser.add_argument('--use_checkpoint', action='store_true', default=True)
        parser.add_argument('--dropout', type=float, default=0.5)
        parser.add_argument('--embed_dim', type=int, default=256)
        parser.add_argument('--enc_pos_emb', action='store_true')
        transformer_group = parser.add_argument_group("transformer_options")
        transformer_group.add_argument(
            "--dec_num_layers", help="No. of layers in transformer decoder", type=int, default=6)
        transformer_group.add_argument(
            "--dec_hidden_size", help="Decoder hidden size", type=int, default=256)
        transformer_group.add_argument(
            "--dec_attn_heads", help="Decoder no. of attention heads", type=int, default=8)
        transformer_group.add_argument("--dec_num_queries", type=int, default=128)
        transformer_group.add_argument(
            "--hidden_dropout", help="Hidden dropout", type=float, default=0.1)
        transformer_group.add_argument(
            "--attn_dropout", help="Attention dropout", type=float, default=0.1)
        transformer_group.add_argument(
            "--max_relative_positions", help="Max relative positions", type=int, default=0)
        parser.add_argument('--continuous_coords', action='store_true')
        parser.add_argument('--compute_confidence', action='store_true')
        # Data
        parser.add_argument('--input_size', type=int, default=384)
        parser.add_argument('--vocab_file', type=str, default=None)
        parser.add_argument('--coord_bins', type=int, default=64)
        parser.add_argument('--sep_xy', action='store_true', default=True)

        # Parse an empty command line so that only the defaults are produced.
        args = parser.parse_args([])
        if args_states:
            # Values stored with the checkpoint win over the defaults (and may add new keys).
            vars(args).update(args_states.items())
        return args

    def _get_model(self, args, tokenizer, device, states):
        """Instantiate encoder and decoder, load their weights, move them to ``device`` in eval mode."""
        encoder = Encoder(args, pretrained=False)
        args.encoder_dim = encoder.n_features
        decoder = Decoder(args, tokenizer)

        safe_load(encoder, states['encoder'])
        safe_load(decoder, states['decoder'])

        for net in (encoder, decoder):
            net.to(device)
        for net in (encoder, decoder):
            net.eval()
        return encoder, decoder

    def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=16):
        """Run the model over ``input_images`` in chunks of ``batch_size`` and format the results."""
        predictions = []
        self.decoder.compute_confidence = return_confidence

        for start in range(0, len(input_images), batch_size):
            chunk = input_images[start:start + batch_size]
            tensors = [self.transform(image=img, keypoints=[])['image'] for img in chunk]
            batch = torch.stack(tensors, dim=0).to(self.device)
            with torch.no_grad():
                features, hiddens = self.encoder(batch)
                predictions.extend(self.decoder.decode(features, hiddens))

        return self.convert_graph_to_output(predictions, input_images, return_confidence, return_atoms_bonds)

    def convert_graph_to_output(self, predictions, input_images, return_confidence=True, return_atoms_bonds=True):
        """Turn raw decoder predictions into one output dictionary per image."""
        node_coords, node_symbols, edges = [], [], []
        for pred in predictions:
            node_coords.append(pred['chartok_coords']['coords'])
            node_symbols.append(pred['chartok_coords']['symbols'])
            edges.append(pred['edges'])
        smiles_list, molblock_list, r_success = convert_graph_to_smiles(
            node_coords, node_symbols, edges, images=input_images)

        outputs = []
        for smiles, molblock, pred in zip(smiles_list, molblock_list, predictions):
            chartok = pred['chartok_coords']
            # The misspelled keys below are part of the returned structure; keep them as they are.
            pred_dict = {
                "smiles": smiles,
                "molfile": molblock,
                "oringinal_coords": chartok['coords'],
                "original_symbols": chartok['symbols'],
                "orignal_edges": pred['edges'],
            }
            if return_confidence:
                pred_dict["confidence"] = pred["overall_score"]
            if return_atoms_bonds:
                coords = chartok['coords']
                symbols = chartok['symbols']

                # get atoms info
                atom_list = []
                for idx, (symbol, coord) in enumerate(zip(symbols, coords)):
                    atom_dict = {"atom_symbol": symbol, "x": round(coord[0], 3), "y": round(coord[1], 3)}
                    if return_confidence:
                        atom_dict["confidence"] = chartok['atom_scores'][idx]
                    atom_list.append(atom_dict)
                pred_dict["atoms"] = atom_list

                # get bonds info: only the upper triangle (i < j) of the edge matrix is read
                bond_list = []
                num_atoms = len(symbols)
                for i in range(num_atoms - 1):
                    for j in range(i + 1, num_atoms):
                        bond_type_int = pred['edges'][i][j]
                        if bond_type_int == 0:
                            continue
                        bond_dict = {"bond_type": BOND_TYPES[bond_type_int], "endpoint_atoms": (i, j)}
                        if return_confidence:
                            bond_dict["confidence"] = pred["edge_scores"][i][j]
                        bond_list.append(bond_dict)
                pred_dict["bonds"] = bond_list
            outputs.append(pred_dict)
        return outputs

    def predict_image(self, image, return_atoms_bonds=False, return_confidence=False):
        """Predict a single in-memory image; returns its output dictionary."""
        results = self.predict_images(
            [image], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)
        return results[0]

    def predict_image_files(self, image_files: List, return_atoms_bonds=False, return_confidence=False):
        """Read every path with OpenCV, convert BGR to RGB, and predict the whole list."""
        input_images = [cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) for path in image_files]
        return self.predict_images(
            input_images, return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)

    def predict_image_file(self, image_file: str, return_atoms_bonds=False, return_confidence=False):
        """Predict a single image file; returns its output dictionary."""
        results = self.predict_image_files(
            [image_file], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)
        return results[0]

    def draw_prediction(self, prediction, image, notebook=False):
        """Draw predicted atoms and bonds on a canvas sized like ``image`` (longest side scaled to 400).

        ``prediction`` must contain "atoms" and "bonds". Unless ``notebook`` is
        true, the figure is rendered, closed and returned as an RGBA array.
        """
        if not ("atoms" in prediction and "bonds" in prediction):
            raise ValueError("atoms and bonds information are not provided.")
        h, w, _ = image.shape
        h, w = np.array([h, w]) * 400 / max(h, w)
        image = cv2.resize(image, (int(w), int(h)))
        fig, ax = plt.subplots(1, 1)
        ax.axis('off')
        ax.set_xlim(-0.05 * w, w * 1.05)
        ax.set_ylim(1.05 * h, -0.05 * h)
        # The image is drawn fully transparent.
        plt.imshow(image, alpha=0.)

        atoms = prediction['atoms']
        x = [atom['x'] * w for atom in atoms]
        y = [atom['y'] * h for atom in atoms]
        markersize = min(w, h) / 3
        plt.scatter(x, y, marker='o', s=markersize, color='lightskyblue', zorder=10)
        for idx, atom in enumerate(atoms):
            label = atom['atom_symbol'].lstrip('[').rstrip(']')
            plt.annotate(label, xy=(x[idx], y[idx]), ha='center', va='center', color='black', zorder=100)

        for bond in prediction['bonds']:
            u, v = bond['endpoint_atoms']
            x1, y1, x2, y2 = x[u], y[u], x[v], y[v]
            bond_type = bond['bond_type']
            if bond_type in ('single', 'aromatic'):
                color = 'tab:green' if bond_type == 'single' else 'tab:purple'
                ax.plot([x1, x2], [y1, y2], color, linewidth=4)
            elif bond_type == 'double':
                # Thick coloured line with a thin white line on top.
                ax.plot([x1, x2], [y1, y2], color='tab:green', linewidth=7)
                ax.plot([x1, x2], [y1, y2], color='w', linewidth=1.5, zorder=2.1)
            elif bond_type == 'triple':
                color = 'tab:green'
                x1s, x2s = 0.8 * x1 + 0.2 * x2, 0.2 * x1 + 0.8 * x2
                y1s, y2s = 0.8 * y1 + 0.2 * y2, 0.2 * y1 + 0.8 * y2
                ax.plot([x1s, x2s], [y1s, y2s], color=color, linewidth=9)
                ax.plot([x1, x2], [y1, y2], color='w', linewidth=5, zorder=2.05)
                ax.plot([x1, x2], [y1, y2], color=color, linewidth=2, zorder=2.1)
            else:
                # Wedge bonds are drawn as arrows; the two wedge kinds point in opposite directions.
                if bond_type == 'solid wedge':
                    head, tail = (x1, y1), (x2, y2)
                else:
                    head, tail = (x2, y2), (x1, y1)
                ax.annotate('', xy=head, xytext=tail,
                            arrowprops=dict(color='tab:green', width=3, headwidth=10, headlength=10), zorder=2)
        fig.tight_layout()
        if not notebook:
            canvas = FigureCanvasAgg(fig)
            canvas.draw()
            result_image = np.asarray(canvas.buffer_rgba())
            plt.close(fig)
            return result_image
|
molscribe/loss.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from scipy.optimize import linear_sum_assignment
|
5 |
+
from .tokenizer import PAD_ID, MASK, MASK_ID
|
6 |
+
|
7 |
+
|
8 |
+
class LabelSmoothingLoss(nn.Module):
    """
    With label smoothing,
    KL-divergence between q_{smoothed ground truth prob.}(w)
    and p_{prob. computed by model}(w) is minimized.

    :param label_smoothing: probability mass spread over the non-target classes, in (0, 1].
    :param tgt_vocab_size: number of classes.
    :param ignore_index: target value whose rows contribute nothing to the loss.
        If it is a valid class index it also receives no smoothing mass. A value
        outside the vocabulary (such as the default -100) is supported as well.
    """
    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
        assert 0.0 < label_smoothing <= 1.0
        super(LabelSmoothingLoss, self).__init__()
        self.ignore_index = ignore_index

        # An out-of-range ignore_index must not be used to index the smoothing vector.
        ignore_in_vocab = -tgt_vocab_size <= ignore_index < tgt_vocab_size
        # Smoothing mass is shared by every class except the target (and the ignored class, if any).
        n_smoothed = tgt_vocab_size - 2 if ignore_in_vocab else tgt_vocab_size - 1
        smoothing_value = label_smoothing / n_smoothed
        one_hot = torch.full((tgt_vocab_size,), smoothing_value)
        if ignore_in_vocab:
            one_hot[ignore_index] = 0
        self.register_buffer('one_hot', one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
        """
        output (FloatTensor): batch_size x n_classes
        target (LongTensor): batch_size
        """
        # assuming output is raw logits
        # convert to log_probs
        log_probs = F.log_softmax(output, dim=-1)

        ignored = target == self.ignore_index
        # Ignored rows get a dummy class so scatter_ never sees an invalid index;
        # those rows are zeroed right afterwards.
        safe_target = target.masked_fill(ignored, 0)
        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, safe_target.unsqueeze(1), self.confidence)
        model_prob.masked_fill_(ignored.unsqueeze(1), 0)

        # 'batchmean' divides by the number of rows, ignored rows included.
        return F.kl_div(log_probs, model_prob, reduction='batchmean')
|
41 |
+
|
42 |
+
|
43 |
+
class SequenceLoss(nn.Module):
    """Token-level loss over a batch of sequences.

    Uses ``nn.CrossEntropyLoss`` when ``label_smoothing`` is 0, otherwise
    ``LabelSmoothingLoss``.

    :param label_smoothing: smoothing amount; 0 selects plain cross-entropy.
    :param vocab_size: number of classes.
    :param ignore_index: target value excluded from the loss.
    :param ignore_indices: optional list of target values to exclude. When it is
        non-empty its first element replaces ``ignore_index``, and the remaining
        values are mapped onto it in ``forward``.
    """

    def __init__(self, label_smoothing, vocab_size, ignore_index=-100, ignore_indices=None):
        super(SequenceLoss, self).__init__()
        # Copy so that neither a shared default nor the caller's list is aliased.
        ignore_indices = list(ignore_indices) if ignore_indices else []
        if ignore_indices:
            ignore_index = ignore_indices[0]
        self.ignore_index = ignore_index
        self.ignore_indices = ignore_indices
        if label_smoothing == 0:
            self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean')
        else:
            self.criterion = LabelSmoothingLoss(label_smoothing, vocab_size, ignore_index)

    def forward(self, output, target):
        """
        :param output: [batch, len, vocab]
        :param target: [batch, len]
        :return: the loss computed by the underlying criterion
        """
        batch_size, max_len, vocab_size = output.size()
        output = output.reshape(-1, vocab_size)
        target = target.reshape(-1)
        for idx in self.ignore_indices:
            if idx != self.ignore_index:
                # Out-of-place: reshape may return a view, and an in-place fill
                # would silently overwrite the caller's target tensor.
                target = target.masked_fill(target == idx, self.ignore_index)
        loss = self.criterion(output, target)
        return loss
|
70 |
+
|
71 |
+
|
72 |
+
class GraphLoss(nn.Module):
    """Losses for graph outputs: masked L1 on 'coords' and weighted cross-entropy on 'edges'."""

    def __init__(self):
        super(GraphLoss, self).__init__()
        # Seven edge classes; class 0 gets weight 1, every other class weight 10.
        weight = torch.ones(7) * 10
        weight[0] = 1
        self.criterion = nn.CrossEntropyLoss(weight, ignore_index=-100)

    def forward(self, outputs, targets):
        """Compute a loss for each of 'coords' / 'edges' present in ``outputs``.

        :param outputs: dict of predictions.
        :param targets: dict of references with the same keys; they are cropped
            to the length of the predictions.
        :return: dict mapping each handled key to a scalar loss.
        """
        results = {}
        if 'coords' in outputs:
            pred = outputs['coords']
            max_len = pred.size(1)
            target = targets['coords'][:, :max_len]
            # Negative target values are treated as invalid and excluded.
            mask = target.ge(0)
            loss = F.l1_loss(pred, target, reduction='none')
            # Clamp the denominator so a batch with no valid coordinate gives 0 rather than NaN.
            results['coords'] = (loss * mask).sum() / mask.sum().clamp(min=1)
        if 'edges' in outputs:
            pred = outputs['edges']
            max_len = pred.size(-1)
            target = targets['edges'][:, :max_len, :max_len]
            results['edges'] = self.criterion(pred, target)
        return results
|
95 |
+
|
96 |
+
|
97 |
+
class Criterion(nn.Module):
    """Holds one loss module per output format and evaluates them on model results."""

    def __init__(self, args, tokenizer):
        super(Criterion, self).__init__()
        per_format = {}
        for format_ in args.formats:
            if format_ == 'edges':
                per_format['edges'] = GraphLoss()
                continue
            # When the tokenizer knows the mask token, mask ids are ignored together with padding.
            has_mask = MASK in tokenizer[format_].stoi
            ignore_indices = [PAD_ID, MASK_ID] if has_mask else []
            per_format[format_] = SequenceLoss(
                args.label_smoothing, len(tokenizer[format_]),
                ignore_index=PAD_ID, ignore_indices=ignore_indices)
        self.criterion = nn.ModuleDict(per_format)

    def forward(self, results, refs):
        """Return a dict of losses.

        Each entry of ``results`` is a sequence whose first two items are the
        predictions and the targets. A criterion that returns a dict has its
        entries merged into the result; any other value is stored under the
        format name, averaged first if it holds more than one element.
        """
        losses = {}
        for format_ in results:
            predictions, targets, *_ = results[format_]
            value = self.criterion[format_](predictions, targets)
            if type(value) is dict:
                losses.update(value)
                continue
            losses[format_] = value.mean() if value.numel() > 1 else value
        return losses
|
molscribe/model.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import timm
|
8 |
+
|
9 |
+
from .utils import FORMAT_INFO, to_device
|
10 |
+
from .tokenizer import SOS_ID, EOS_ID, PAD_ID, MASK_ID
|
11 |
+
from .inference import GreedySearch, BeamSearch
|
12 |
+
from .transformer import TransformerDecoder, Embeddings
|
13 |
+
|
14 |
+
|
15 |
+
class Encoder(nn.Module):
    """Image encoder built on a timm backbone.

    ``args.encoder`` selects the backbone family: names starting with
    'resnet' or 'swin', or names containing 'efficientnet'. ``n_features`` is
    set to the backbone's ``num_features``.

    :raises NotImplementedError: for any other encoder name.
    """

    def __init__(self, args, pretrained=False):
        super().__init__()
        model_name = args.encoder
        self.model_name = model_name
        if model_name.startswith('resnet'):
            self.model_type = 'resnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features  # encoder_dim
            # Replace pooling and classifier with identities.
            self.cnn.global_pool = nn.Identity()
            self.cnn.fc = nn.Identity()
        elif model_name.startswith('swin'):
            self.model_type = 'swin'
            self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False,
                                                 use_checkpoint=args.use_checkpoint)
            self.n_features = self.transformer.num_features
            self.transformer.head = nn.Identity()
        elif 'efficientnet' in model_name:
            self.model_type = 'efficientnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features
            self.cnn.global_pool = nn.Identity()
            self.cnn.classifier = nn.Identity()
        else:
            # `NotImplemented` is a comparison sentinel, not an exception; raising it is a TypeError.
            raise NotImplementedError(f"Unsupported encoder: {model_name!r}")

    def swin_forward(self, transformer, x):
        """Run a Swin transformer stage by stage, collecting each stage's output.

        :return: ``(x, hiddens)`` where ``x`` is the normalised final output and
            ``hiddens`` holds one (B, H, W, C) tensor per stage, taken before that
            stage's downsampling; the last entry is replaced by the normalised output.
        """
        x = transformer.patch_embed(x)
        if transformer.absolute_pos_embed is not None:
            x = x + transformer.absolute_pos_embed
        x = transformer.pos_drop(x)

        def layer_forward(layer, x, hiddens):
            for blk in layer.blocks:
                if not torch.jit.is_scripting() and layer.use_checkpoint:
                    x = torch.utils.checkpoint.checkpoint(blk, x)
                else:
                    x = blk(x)
            H, W = layer.input_resolution
            B, L, C = x.shape
            hiddens.append(x.view(B, H, W, C))
            if layer.downsample is not None:
                x = layer.downsample(x)
            return x, hiddens

        hiddens = []
        for layer in transformer.layers:
            x, hiddens = layer_forward(layer, x, hiddens)
        x = transformer.norm(x)  # B L C
        hiddens[-1] = x.view_as(hiddens[-1])
        return x, hiddens

    def forward(self, x, refs=None):
        """Encode a batch of images.

        :return: ``(features, hiddens)``; ``hiddens`` is an empty list for the
            CNN backbones.
        :raises NotImplementedError: if ``model_type`` is not a supported family.
        """
        if self.model_type in ['resnet', 'efficientnet']:
            features = self.cnn(x)
            # Move dim 1 to the end (presumably channels -> channels-last; confirm against timm output).
            features = features.permute(0, 2, 3, 1)
            hiddens = []
        elif self.model_type == 'swin':
            if 'patch' in self.model_name:
                features, hiddens = self.swin_forward(self.transformer, x)
            else:
                features, hiddens = self.transformer(x)
        else:
            raise NotImplementedError(f"Unsupported encoder type: {self.model_type!r}")
        return features, hiddens
|
80 |
+
|
81 |
+
|
82 |
+
class TransformerDecoderBase(nn.Module):
    """Shared pieces of the transformer decoders: encoder-output projection and the decoder stack."""

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Projects encoder features to the decoder width.
        # (A LayerNorm after the projection is intentionally not used:
        #  nn.LayerNorm(args.dec_hidden_size, eps=1e-6))
        projection = nn.Linear(args.encoder_dim, args.dec_hidden_size)
        self.enc_trans_layer = nn.Sequential(projection)

        # Optional learned position embedding with 144 positions.
        if args.enc_pos_emb:
            self.enc_pos_emb = nn.Embedding(144, args.encoder_dim)
        else:
            self.enc_pos_emb = None

        decoder_config = dict(
            num_layers=args.dec_num_layers,
            d_model=args.dec_hidden_size,
            heads=args.dec_attn_heads,
            d_ff=args.dec_hidden_size * 4,
            copy_attn=False,
            self_attn_type="scaled-dot",
            dropout=args.hidden_dropout,
            attention_dropout=args.attn_dropout,
            max_relative_positions=args.max_relative_positions,
            aan_useffn=False,
            full_context_alignment=False,
            alignment_layer=0,
            alignment_heads=0,
            pos_ffn_activation_fn='gelu',
        )
        self.decoder = TransformerDecoder(**decoder_config)

    def enc_transform(self, encoder_out):
        """Flatten the encoder output to (batch, positions, dim), add position embeddings if enabled, and project it."""
        n_batch = encoder_out.size(0)
        feature_dim = encoder_out.size(-1)
        # (batch_size, num_pixels, encoder_dim)
        flat = encoder_out.view(n_batch, -1, feature_dim)
        if self.enc_pos_emb is not None:
            positions = torch.arange(flat.size(1), device=flat.device)
            flat = flat + self.enc_pos_emb(positions).unsqueeze(0)
        return self.enc_trans_layer(flat)
|
122 |
+
|
123 |
+
|
124 |
+
class TransformerDecoderAR(TransformerDecoderBase):
    """Autoregressive Transformer Decoder.

    Adds token embeddings and a vocabulary projection on top of ``TransformerDecoderBase`` and provides
    teacher-forced training (``forward``) and step-by-step inference (``decode``).
    """

    def __init__(self, args, tokenizer):
        """Build the output projection and token embeddings sized by ``len(tokenizer)``."""
        super().__init__(args)
        self.tokenizer = tokenizer
        self.vocab_size = len(self.tokenizer)
        # Maps decoder hidden states to vocabulary logits.
        self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True)
        self.embeddings = Embeddings(
            word_vec_size=args.dec_hidden_size,
            word_vocab_size=self.vocab_size,
            word_padding_idx=PAD_ID,
            position_encoding=True,
            dropout=args.hidden_dropout)

    def dec_embedding(self, tgt, step=None):
        """Embed target token ids and build the matching padding mask.

        Returns:
            (emb, tgt_pad_mask): the 3-dimensional embedding and a boolean mask that is True where
            ``tgt`` equals the padding index.
        """
        pad_idx = self.embeddings.word_padding_idx
        tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2)  # [B, 1, T_tgt]
        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # batch x len x embedding_dim
        return emb, tgt_pad_mask

    def forward(self, encoder_out, labels, label_lengths):
        """Training mode: run the decoder over the full label sequence (teacher forcing).

        ``label_lengths`` is accepted but not used in this method.

        Returns:
            (logits[:, :-1], labels[:, 1:], dec_out): logits shifted so that position t is paired with
            the label at t + 1, plus the unshifted decoder hidden states.
        """
        batch_size, max_len, _ = encoder_out.size()
        memory_bank = self.enc_transform(encoder_out)

        tgt = labels.unsqueeze(-1)  # (b, t, 1)
        tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
        dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask)

        logits = self.output_layer(dec_out)  # (b, t, h) -> (b, t, v)
        return logits[:, :-1], labels[:, 1:], dec_out

    def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256,
               labels=None):
        """Inference mode: autoregressively decode the sequence.

        Only greedy search (``beam_size == 1``) is supported now; the beam search branch is out-dated.
        ``labels`` enables partial prediction, i.e. part of the sequence is given: positions whose label
        is not MASK_ID are forced to the label, positions equal to MASK_ID are predicted. In standard
        decoding ``labels`` is None.

        Returns:
            (predictions, scores, token_scores, hidden) as collected by the decode strategy.
        """
        batch_size, max_len, _ = encoder_out.size()
        memory_bank = self.enc_transform(encoder_out)
        # Keep the untouched labels: `labels` is re-indexed below as sequences finish.
        orig_labels = labels

        if beam_size == 1:
            decode_strategy = GreedySearch(
                sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length, max_length=max_length,
                pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
                return_attention=False, return_hidden=True)
        else:
            decode_strategy = BeamSearch(
                beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length, max_length=max_length,
                pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
                return_attention=False)

        # adapted from onmt.translate.translator
        results = {
            "predictions": None,
            "scores": None,
            "attention": None
        }

        # (2) prep decode_strategy. Possibly repeat src objects.
        _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank)

        # (3) Begin decoding step by step:
        for step in range(decode_strategy.max_length):
            tgt = decode_strategy.current_predictions.view(-1, 1, 1)
            if labels is not None:
                # Partial prediction: use the given label unless it is MASK_ID.
                label = labels[:, step].view(-1, 1, 1)
                mask = label.eq(MASK_ID).long()
                tgt = tgt * mask + label * (1 - mask)
            tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
            dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank,
                                                 tgt_pad_mask=tgt_pad_mask, step=step)

            attn = dec_attn.get("std", None)

            dec_logits = self.output_layer(dec_out)  # [b, t, h] => [b, t, v]
            dec_logits = dec_logits.squeeze(1)
            log_probs = F.log_softmax(dec_logits, dim=-1)

            if self.tokenizer.output_constraint:
                # Suppress tokens the tokenizer disallows after the current input token.
                output_mask = [self.tokenizer.get_output_mask(id) for id in tgt.view(-1).tolist()]
                output_mask = torch.tensor(output_mask, device=log_probs.device)
                log_probs.masked_fill_(output_mask, -10000)

            # Label of the next position (if any), handed to the strategy together with the scores.
            label = labels[:, step + 1] if labels is not None and step + 1 < labels.size(1) else None
            decode_strategy.advance(log_probs, attn, dec_out, label)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices
            if any_finished:
                # Reorder states.
                memory_bank = memory_bank.index_select(0, select_indices)
                if labels is not None:
                    labels = labels.index_select(0, select_indices)
                self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = decode_strategy.scores  # fixed to be average of token scores
        results["token_scores"] = decode_strategy.token_scores
        results["predictions"] = decode_strategy.predictions
        results["attention"] = decode_strategy.attention
        results["hidden"] = decode_strategy.hidden
        if orig_labels is not None:
            # Overwrite the top prediction with the given labels wherever they were not masked.
            for i in range(batch_size):
                pred = results["predictions"][i][0]
                label = orig_labels[i][1:len(pred) + 1]
                mask = label.eq(MASK_ID).long()
                pred = pred[:len(label)]
                results["predictions"][i][0] = pred * mask + label * (1 - mask)

        return results["predictions"], results['scores'], results["token_scores"], results["hidden"]

    # adapted from onmt.decoders.transformer
    def map_state(self, fn):
        """Apply ``fn(tensor, batch_dim)`` in place to every non-None leaf of the decoder cache."""
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.decoder.state["cache"] is not None:
            _recursive_map(self.decoder.state["cache"])
|
253 |
+
|
254 |
+
|
255 |
+
class GraphPredictor(nn.Module):
    """Predicts pairwise edge logits (7 classes) and, optionally, 2-d coordinates from hidden states.

    Args:
        decoder_dim: size of the hidden vectors fed to the predictor.
        coords: when True, an extra MLP regresses two values per selected position.
    """

    def __init__(self, decoder_dim, coords=False):
        super().__init__()
        self.coords = coords
        # Scores every ordered pair of selected positions from the concatenation of both hidden vectors.
        self.mlp = nn.Sequential(
            nn.Linear(decoder_dim * 2, decoder_dim),
            nn.GELU(),
            nn.Linear(decoder_dim, 7),
        )
        if coords:
            self.coords_mlp = nn.Sequential(
                nn.Linear(decoder_dim, decoder_dim),
                nn.GELU(),
                nn.Linear(decoder_dim, 2),
            )

    def forward(self, hidden, indices=None):
        """Compute edge logits (and coordinates) for the selected positions of ``hidden``.

        Args:
            hidden: tensor of shape (batch, length, dim).
            indices: optional (batch, k) tensor of positions to select per sample; when omitted,
                positions 3, 6, 9, ... are used.

        Returns:
            Dict with ``'edges'`` of shape (batch, 7, k, k) and, if enabled, ``'coords'`` of shape
            (batch, k, 2).
        """
        batch, length, dim = hidden.size()
        if indices is not None:
            rows = torch.arange(batch).unsqueeze(1).expand_as(indices).reshape(-1)
            cols = indices.view(-1)
            hidden = hidden[rows, cols].view(batch, -1, dim)
        else:
            hidden = hidden[:, list(range(3, length, 3))]

        batch, length, dim = hidden.size()
        first = hidden.unsqueeze(2).expand(batch, length, length, dim)
        second = hidden.unsqueeze(1).expand(batch, length, length, dim)
        pairs = torch.cat([first, second], dim=3)

        results = {'edges': self.mlp(pairs).permute(0, 3, 1, 2)}
        if self.coords:
            results['coords'] = self.coords_mlp(hidden)
        return results
|
286 |
+
|
287 |
+
|
288 |
+
def get_edge_prediction(edge_prob):
    """Symmetrize pairwise edge probabilities and pick the most likely class for every pair.

    For every pair (i, j) with i != j, classes 0-4 are averaged between (i, j) and (j, i). Classes 5 and
    6 are treated as mirrored: class 5 of (i, j) is averaged with class 6 of (j, i) and vice versa.
    Diagonal entries are left unchanged. The input is not modified.

    Args:
        edge_prob: n x n x 7 nested list or array of class probabilities; may be empty or None.

    Returns:
        (prediction, score): two n x n nested lists holding the arg-max class and its probability.
        Both are empty lists when the input is empty.
    """
    if edge_prob is None or len(edge_prob) == 0:
        return [], []

    # Work on a float copy so the caller's data is never mutated.
    prob = np.array(edge_prob, dtype=float)
    mirrored = prob.transpose(1, 0, 2)  # mirrored[i, j] == prob[j, i]

    sym = np.empty_like(prob)
    sym[:, :, :5] = (prob[:, :, :5] + mirrored[:, :, :5]) / 2
    sym[:, :, 5] = (prob[:, :, 5] + mirrored[:, :, 6]) / 2
    sym[:, :, 6] = (prob[:, :, 6] + mirrored[:, :, 5]) / 2

    # Only off-diagonal pairs are symmetrized; keep the diagonal exactly as given.
    diag = np.arange(prob.shape[0])
    sym[diag, diag] = prob[diag, diag]

    prediction = np.argmax(sym, axis=2).tolist()
    score = np.max(sym, axis=2).tolist()
    return prediction, score
|
306 |
+
|
307 |
+
|
308 |
+
class Decoder(nn.Module):
    """This class is a wrapper for different decoder architectures, and support multiple decoders.

    One sub-decoder is created per entry of ``args.formats``: a ``GraphPredictor`` for ``'edges'`` and a
    ``TransformerDecoderAR`` (with the matching tokenizer) for every other format.
    """

    def __init__(self, args, tokenizer):
        super(Decoder, self).__init__()
        self.args = args
        self.formats = args.formats
        self.tokenizer = tokenizer
        decoder = {}
        for format_ in args.formats:
            if format_ == 'edges':
                decoder['edges'] = GraphPredictor(args.dec_hidden_size, coords=args.continuous_coords)
            else:
                decoder[format_] = TransformerDecoderAR(args, tokenizer[format_])
        self.decoder = nn.ModuleDict(decoder)
        self.compute_confidence = args.compute_confidence

    @staticmethod
    def _find_atom_format(results):
        """Return the coordinate-bearing sequence format already present in ``results``.

        Raises:
            NotImplementedError: if neither 'atomtok_coords' nor 'chartok_coords' has been computed,
                i.e. the edge predictor has no hidden states to work on.
        """
        for atom_format in ('atomtok_coords', 'chartok_coords'):
            if atom_format in results:
                return atom_format
        # `raise NotImplemented` would itself fail with a TypeError; raise the real exception class.
        raise NotImplementedError(
            "'edges' requires 'atomtok_coords' or 'chartok_coords' to be processed before it")

    def forward(self, encoder_out, hiddens, refs):
        """Training mode. Compute the logits with teacher forcing.

        Args:
            encoder_out: encoder features passed to every sequence decoder.
            hiddens: accepted for interface compatibility; not used here.
            refs: dict of reference targets; moved to the device of ``encoder_out``.

        Returns:
            Dict keyed by format. Sequence formats map to the tuple returned by the sequence decoder;
            ``'edges'`` maps to ``(predictions, targets)``.

        Raises:
            NotImplementedError: if 'edges' is requested before a coordinate sequence format.
        """
        results = {}
        refs = to_device(refs, encoder_out.device)
        for format_ in self.formats:
            if format_ == 'edges':
                atom_format = self._find_atom_format(results)
                dec_out = results[atom_format][2]  # hidden states of the sequence decoder
                predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0])
                targets = {'edges': refs['edges']}
                if 'coords' in predictions:
                    targets['coords'] = refs['coords']
                results['edges'] = (predictions, targets)
            else:
                labels, label_lengths = refs[format_]
                results[format_] = self.decoder[format_](encoder_out, labels, label_lengths)
        return results

    def decode(self, encoder_out, hiddens=None, refs=None, beam_size=1, n_best=1):
        """Inference mode. Call each decoder's decode method (if required), convert the output format
        (e.g. token to sequence). Beam search is not supported yet.

        Returns:
            A list with one dict per sample. It holds the decoded structure of the last sequence format
            processed, plus ``'edges'`` (and confidence scores when ``compute_confidence`` is set).

        Raises:
            NotImplementedError: if 'edges' is requested before a coordinate sequence format.
        """
        results = {}
        predictions = []
        for format_ in self.formats:
            if format_ in ['atomtok', 'atomtok_coords', 'chartok_coords']:
                max_len = FORMAT_INFO[format_]['max_len']
                results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best, max_length=max_len)
                outputs, scores, token_scores, *_ = results[format_]
                beam_preds = [[self.tokenizer[format_].sequence_to_smiles(x.tolist()) for x in pred]
                              for pred in outputs]
                # Only the best hypothesis of every sample is kept.
                predictions = [{format_: pred[0]} for pred in beam_preds]
                if self.compute_confidence:
                    for i in range(len(predictions)):
                        # -1: y score, -2: x score, -3: symbol score
                        indices = np.array(predictions[i][format_]['indices']) - 3
                        if format_ == 'chartok_coords':
                            # A symbol spans several character tokens: use the geometric mean of their scores.
                            atom_scores = []
                            for symbol, index in zip(predictions[i][format_]['symbols'], indices):
                                atom_score = (np.prod(token_scores[i][0][index - len(symbol) + 1:index + 1])
                                              ** (1 / len(symbol))).item()
                                atom_scores.append(atom_score)
                        else:
                            atom_scores = np.array(token_scores[i][0])[indices].tolist()
                        predictions[i][format_]['atom_scores'] = atom_scores
                        predictions[i][format_]['average_token_score'] = scores[i][0]
            if format_ == 'edges':
                atom_format = self._find_atom_format(results)
                dec_out = results[atom_format][3]  # batch x n_best x len x dim
                for i in range(len(dec_out)):
                    hidden = dec_out[i][0].unsqueeze(0)  # 1 * len * dim
                    indices = torch.LongTensor(predictions[i][atom_format]['indices']).unsqueeze(0)  # 1 * k
                    pred = self.decoder['edges'](hidden, indices)  # k * k
                    prob = F.softmax(pred['edges'].squeeze(0).permute(1, 2, 0), dim=2).tolist()  # k * k * 7
                    edge_pred, edge_score = get_edge_prediction(prob)
                    predictions[i]['edges'] = edge_pred
                    if self.compute_confidence:
                        predictions[i]['edge_scores'] = edge_score
                        edge_score_product = np.sqrt(np.prod(edge_score)).item()
                        # The average token score is folded into the overall score and then dropped.
                        average_token_score = predictions[i][atom_format].pop('average_token_score')
                        predictions[i]['overall_score'] = average_token_score * edge_score_product
        return predictions
|
molscribe/tokenizer.py
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from SmilesPE.pretokenizer import atomwise_tokenizer
|
6 |
+
|
7 |
+
# Special tokens used by the tokenizers in this module.
PAD = '<pad>'
SOS = '<sos>'
EOS = '<eos>'
UNK = '<unk>'
MASK = '<mask>'
# Fixed integer ids of the special tokens; the fit_* methods assert that a fitted vocabulary
# assigns exactly these ids.
PAD_ID = 0
SOS_ID = 1
EOS_ID = 2
UNK_ID = 3
MASK_ID = 4
|
17 |
+
|
18 |
+
|
19 |
+
class Tokenizer(object):
    """Vocabulary that maps tokens to integer ids (``stoi``) and back (``itos``).

    Args:
        path: optional path of a JSON vocabulary (token -> id) to load on construction.
    """

    def __init__(self, path=None):
        self.stoi = {}
        self.itos = {}
        if path:
            self.load(path)

    def __len__(self):
        return len(self.stoi)

    @property
    def output_constraint(self):
        """Whether decoding must restrict the next token; never for the plain tokenizer."""
        return False

    def save(self, path):
        """Write the token -> id mapping to ``path`` as JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.stoi, f)

    def load(self, path):
        """Read a token -> id mapping from the JSON file ``path`` and rebuild the inverse mapping."""
        with open(path, encoding='utf-8') as f:
            self.stoi = json.load(f)
        self.itos = {idx: token for token, idx in self.stoi.items()}

    def fit_on_texts(self, texts):
        """Build the vocabulary from space-separated token strings.

        Special tokens get the fixed ids 0-3; the remaining tokens follow in sorted order so that the
        resulting ids are reproducible across runs. Any previous vocabulary is discarded.
        """
        specials = [PAD, SOS, EOS, UNK]
        vocab = set()
        for text in texts:
            vocab.update(text.split(' '))
        # Special tokens occurring in the texts must not receive a second id.
        vocab.difference_update(specials)
        self.stoi = {token: idx for idx, token in enumerate(specials + sorted(vocab))}
        self.itos = {idx: token for token, idx in self.stoi.items()}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID

    def text_to_sequence(self, text, tokenized=True):
        """Convert a text to ids wrapped in <sos> ... <eos>; unknown tokens map to <unk>.

        Args:
            text: space-separated tokens when ``tokenized`` is True, otherwise a raw SMILES string
                that is split with ``atomwise_tokenizer``.
        """
        tokens = text.split(' ') if tokenized else atomwise_tokenizer(text)
        unk_id = self.stoi['<unk>']
        sequence = [self.stoi['<sos>']]
        sequence.extend(self.stoi.get(token, unk_id) for token in tokens)
        sequence.append(self.stoi['<eos>'])
        return sequence

    def texts_to_sequences(self, texts):
        """Apply ``text_to_sequence`` to every text."""
        return [self.text_to_sequence(text) for text in texts]

    def sequence_to_text(self, sequence):
        """Concatenate the tokens of all ids in ``sequence`` (special tokens included)."""
        return ''.join(self.itos[i] for i in sequence)

    def sequences_to_texts(self, sequences):
        """Apply ``sequence_to_text`` to every sequence."""
        return [self.sequence_to_text(sequence) for sequence in sequences]

    def predict_caption(self, sequence):
        """Concatenate tokens up to (not including) the first <eos> or <pad>."""
        stop_ids = (self.stoi['<eos>'], self.stoi['<pad>'])
        tokens = []
        for i in sequence:
            if i in stop_ids:
                break
            tokens.append(self.itos[i])
        return ''.join(tokens)

    def predict_captions(self, sequences):
        """Apply ``predict_caption`` to every sequence."""
        return [self.predict_caption(sequence) for sequence in sequences]

    def sequence_to_smiles(self, sequence):
        """Return ``{'smiles': <decoded string>}`` for an id sequence."""
        return {'smiles': self.predict_caption(sequence)}
|
104 |
+
|
105 |
+
|
106 |
+
class NodeTokenizer(Tokenizer):
    """Tokenizer whose id space is the symbol vocabulary followed by discretized coordinate bins.

    Ids ``[0, offset)`` are the symbols stored in ``stoi``. With ``sep_xy`` the x bins occupy
    ``[offset, offset + maxx)`` and the y bins start at ``offset + maxx``; otherwise x and y share the
    bins starting at ``offset``.

    Args:
        input_size: number of bins per axis (used for both ``maxx`` and ``maxy``).
        path: optional JSON vocabulary to load.
        sep_xy: use separate id ranges for x and y bins.
        continuous_coords: when True, coordinates are not emitted as tokens.
        debug: print tokens that are missing from the vocabulary.
    """

    def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
        super().__init__(path)
        self.maxx = input_size  # height
        self.maxy = input_size  # width
        self.sep_xy = sep_xy
        self.special_tokens = [PAD, SOS, EOS, UNK, MASK]
        self.continuous_coords = continuous_coords
        self.debug = debug

    def __len__(self):
        if self.sep_xy:
            return self.offset + self.maxx + self.maxy
        return self.offset + max(self.maxx, self.maxy)

    @property
    def offset(self):
        """First id used for coordinate bins (the number of symbols)."""
        return len(self.stoi)

    @property
    def output_constraint(self):
        """Decoding is constrained whenever coordinates are emitted as tokens."""
        return not self.continuous_coords

    def len_symbols(self):
        """Number of symbol tokens (special tokens included)."""
        return len(self.stoi)

    def fit_atom_symbols(self, atoms):
        """Build the vocabulary from atom symbols.

        Special tokens get ids 0-4; the distinct atoms follow in sorted order so ids are reproducible
        across runs. Any previous vocabulary is discarded.
        """
        symbols = sorted(set(atoms) - set(self.special_tokens))
        self.stoi = {token: idx for idx, token in enumerate(self.special_tokens + symbols)}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID
        assert self.stoi[MASK] == MASK_ID
        self.itos = {idx: token for token, idx in self.stoi.items()}

    def is_x(self, x):
        """True if id ``x`` lies in the x-bin range."""
        return self.offset <= x < self.offset + self.maxx

    def is_y(self, y):
        """True if id ``y`` is at or beyond the first y bin (no upper bound is checked)."""
        if self.sep_xy:
            return self.offset + self.maxx <= y
        return self.offset <= y

    def is_symbol(self, s):
        """True if id ``s`` is a non-special symbol or <unk>."""
        return len(self.special_tokens) <= s < self.offset or s == UNK_ID

    def is_atom(self, id):
        """True if ``id`` is a symbol whose token looks like an atom."""
        if self.is_symbol(id):
            return self.is_atom_token(self.itos[id])
        return False

    def is_atom_token(self, token):
        """True for alphabetic tokens, bracket atoms, the wildcard '*' and <unk>."""
        return token.isalpha() or token.startswith("[") or token == '*' or token == UNK

    def x_to_id(self, x):
        """Map a normalized x in [0, 1] to its bin id."""
        return self.offset + round(x * (self.maxx - 1))

    def y_to_id(self, y):
        """Map a normalized y in [0, 1] to its bin id."""
        if self.sep_xy:
            return self.offset + self.maxx + round(y * (self.maxy - 1))
        return self.offset + round(y * (self.maxy - 1))

    def id_to_x(self, id):
        """Inverse of ``x_to_id`` (returns the bin position normalized to [0, 1])."""
        return (id - self.offset) / (self.maxx - 1)

    def id_to_y(self, id):
        """Inverse of ``y_to_id`` (returns the bin position normalized to [0, 1])."""
        if self.sep_xy:
            return (id - self.offset - self.maxx) / (self.maxy - 1)
        return (id - self.offset) / (self.maxy - 1)

    def get_output_mask(self, id):
        """Return a per-token mask of ids that must not follow ``id`` (True = forbidden).

        After an atom only x bins are allowed, after an x bin only y bins, after a y bin only symbols.
        Any other token, or continuous-coordinate mode, imposes no constraint.
        """
        # NOTE(review): the constrained masks have offset + maxx + maxy entries, which equals
        # len(self) only when sep_xy is True — confirm callers never combine sep_xy=False with
        # token coordinates.
        if self.continuous_coords:
            return [False] * len(self)
        if self.is_atom(id):
            return [True] * self.offset + [False] * self.maxx + [True] * self.maxy
        if self.is_x(id):
            return [True] * (self.offset + self.maxx) + [False] * self.maxy
        if self.is_y(id):
            return [False] * self.offset + [True] * (self.maxx + self.maxy)
        return [False] * len(self)

    def symbol_to_id(self, symbol):
        """Return the id of ``symbol``, or UNK_ID if it is not in the vocabulary."""
        return self.stoi.get(symbol, UNK_ID)

    def symbols_to_labels(self, symbols):
        """Map every symbol to its id."""
        return [self.symbol_to_id(symbol) for symbol in symbols]

    def labels_to_symbols(self, labels):
        """Map every id back to its symbol."""
        return [self.itos[label] for label in labels]

    def nodes_to_grid(self, nodes):
        """Rasterize nodes into a (maxx, maxy) integer grid holding symbol ids (0 = empty)."""
        coords, symbols = nodes['coords'], nodes['symbols']
        grid = np.zeros((self.maxx, self.maxy), dtype=int)
        for [x, y], symbol in zip(coords, symbols):
            x = round(x * (self.maxx - 1))
            y = round(y * (self.maxy - 1))
            grid[x][y] = self.symbol_to_id(symbol)
        return grid

    def grid_to_nodes(self, grid):
        """Collect every non-zero grid cell as a node with normalized coordinates and grid indices."""
        coords, symbols, indices = [], [], []
        for i in range(self.maxx):
            for j in range(self.maxy):
                if grid[i][j] != 0:
                    x = i / (self.maxx - 1)
                    y = j / (self.maxy - 1)
                    coords.append([x, y])
                    symbols.append(self.itos[grid[i][j]])
                    indices.append([i, j])
        return {'coords': coords, 'symbols': symbols, 'indices': indices}

    def nodes_to_sequence(self, nodes):
        """Encode nodes as <sos> (x, y, symbol)* <eos>."""
        coords, symbols = nodes['coords'], nodes['symbols']
        labels = [SOS_ID]
        for (x, y), symbol in zip(coords, symbols):
            assert 0 <= x <= 1
            assert 0 <= y <= 1
            labels.append(self.x_to_id(x))
            labels.append(self.y_to_id(y))
            labels.append(self.symbol_to_id(symbol))
        labels.append(EOS_ID)
        return labels

    def sequence_to_nodes(self, sequence):
        """Decode consecutive (x, y, symbol) triples; malformed triples are skipped.

        An empty sequence yields empty results.
        """
        coords, symbols = [], []
        # Skip a leading <sos>; an empty sequence simply produces no nodes.
        i = 1 if len(sequence) > 0 and sequence[0] == SOS_ID else 0
        while i + 2 < len(sequence):
            if sequence[i] == EOS_ID:
                break
            if self.is_x(sequence[i]) and self.is_y(sequence[i + 1]) and self.is_symbol(sequence[i + 2]):
                coords.append([self.id_to_x(sequence[i]), self.id_to_y(sequence[i + 1])])
                symbols.append(self.itos[sequence[i + 2]])
            i += 3
        return {'coords': coords, 'symbols': symbols}

    def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
        """Encode a SMILES string, inserting coordinate tokens after every atom token.

        Args:
            smiles: SMILES string, split with ``atomwise_tokenizer``.
            coords: optional list of normalized (x, y) per atom; atoms beyond its length get random
                coordinates.
            mask_ratio: probability of replacing an atom's coordinates by two MASK tokens.
            atom_only: drop all non-atom tokens.

        Returns:
            (labels, indices): the id sequence and, for every atom, the index of its last token.
        """
        tokens = atomwise_tokenizer(smiles)
        labels = [SOS_ID]
        indices = []
        atom_idx = -1
        for token in tokens:
            if atom_only and not self.is_atom_token(token):
                continue
            if token in self.stoi:
                labels.append(self.stoi[token])
            else:
                if self.debug:
                    print(f'{token} not in vocab')
                labels.append(UNK_ID)
            if self.is_atom_token(token):
                atom_idx += 1
                if not self.continuous_coords:
                    if mask_ratio > 0 and random.random() < mask_ratio:
                        labels.append(MASK_ID)
                        labels.append(MASK_ID)
                    elif coords is not None:
                        if atom_idx < len(coords):
                            x, y = coords[atom_idx]
                            assert 0 <= x <= 1
                            assert 0 <= y <= 1
                        else:
                            x = random.random()
                            y = random.random()
                        labels.append(self.x_to_id(x))
                        labels.append(self.y_to_id(y))
                indices.append(len(labels) - 1)
        labels.append(EOS_ID)
        return labels, indices

    def sequence_to_smiles(self, sequence):
        """Decode an id sequence into a SMILES string plus per-atom symbols, indices and coordinates.

        Decoding stops at the first <eos> or <pad>. With token coordinates, an atom is recorded only
        if it is followed by an x bin, a y bin and at least one more token.

        Returns:
            Dict with 'smiles', 'symbols', 'indices' and, when coordinates are tokens, 'coords'.
        """
        has_coords = not self.continuous_coords
        pieces = []
        coords, symbols, indices = [], [], []
        for i, label in enumerate(sequence):
            if label == EOS_ID or label == PAD_ID:
                break
            if self.is_x(label) or self.is_y(label):
                continue
            token = self.itos[label]
            pieces.append(token)
            if not self.is_atom_token(token):
                continue
            if has_coords:
                if i + 3 < len(sequence) and self.is_x(sequence[i + 1]) and self.is_y(sequence[i + 2]):
                    coords.append([self.id_to_x(sequence[i + 1]), self.id_to_y(sequence[i + 2])])
                    symbols.append(token)
                    indices.append(i + 3)
            elif i + 1 < len(sequence):
                symbols.append(token)
                indices.append(i + 1)
        results = {'smiles': ''.join(pieces), 'symbols': symbols, 'indices': indices}
        if has_coords:
            results['coords'] = coords
        return results
|
320 |
+
|
321 |
+
|
322 |
+
class CharTokenizer(NodeTokenizer):
    """Character-level variant of ``NodeTokenizer``: every symbol is spelled as single-character tokens."""

    def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
        super().__init__(input_size, path, sep_xy, continuous_coords, debug)

    def fit_on_texts(self, texts):
        """Build a character vocabulary from texts (spaces are ignored).

        All special tokens, including <mask>, receive their fixed ids; the remaining characters follow
        in sorted order so ids are reproducible. Any previous vocabulary is discarded.
        """
        chars = set()
        for text in texts:
            chars.update(text)
        chars.discard(' ')
        # Id 4 is reserved for <mask> (MASK_ID); no character may take it.
        vocab = self.special_tokens + sorted(chars)
        self.stoi = {token: idx for idx, token in enumerate(vocab)}
        self.itos = {idx: token for token, idx in self.stoi.items()}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID
        assert self.stoi[MASK] == MASK_ID

    def text_to_sequence(self, text, tokenized=True):
        """Convert a text to character ids wrapped in <sos> ... <eos>; unknown characters map to <unk>.

        With ``tokenized`` the text must consist of space-separated single characters.
        """
        if tokenized:
            tokens = text.split(' ')
            assert all(len(s) == 1 for s in tokens)
        else:
            tokens = list(text)
        unk_id = self.stoi['<unk>']
        sequence = [self.stoi['<sos>']]
        sequence.extend(self.stoi.get(token, unk_id) for token in tokens)
        sequence.append(self.stoi['<eos>'])
        return sequence

    def fit_atom_symbols(self, atoms):
        """Build the vocabulary from the distinct characters of all atom symbols.

        Characters are de-duplicated and sorted, so ids are contiguous and reproducible. Any previous
        vocabulary is discarded.
        """
        chars = sorted({char for atom in atoms for char in atom})
        self.stoi = {token: idx for idx, token in enumerate(self.special_tokens + chars)}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID
        assert self.stoi[MASK] == MASK_ID
        self.itos = {idx: token for token, idx in self.stoi.items()}

    def get_output_mask(self, id):
        """Return a per-token mask of ids that must not follow ``id`` (True = forbidden).

        After an x bin only y bins are allowed, after a y bin only symbols; nothing is constrained
        after a character token.
        """
        # TODO: marked "TO FIX" by the original author — the constraint after characters is missing.
        if self.continuous_coords:
            return [False] * len(self)
        if self.is_x(id):
            return [True] * (self.offset + self.maxx) + [False] * self.maxy
        if self.is_y(id):
            return [False] * self.offset + [True] * (self.maxx + self.maxy)
        return [False] * len(self)

    def nodes_to_sequence(self, nodes):
        """Encode nodes as <sos> (x, y, char+)* <eos>."""
        coords, symbols = nodes['coords'], nodes['symbols']
        labels = [SOS_ID]
        for (x, y), symbol in zip(coords, symbols):
            assert 0 <= x <= 1
            assert 0 <= y <= 1
            labels.append(self.x_to_id(x))
            labels.append(self.y_to_id(y))
            for char in symbol:
                labels.append(self.symbol_to_id(char))
        labels.append(EOS_ID)
        return labels

    def sequence_to_nodes(self, sequence):
        """Decode (x, y, char+) groups into nodes; tokens that do not start a group are skipped.

        An empty sequence yields empty results.
        """
        coords, symbols = [], []
        i = 1 if len(sequence) > 0 and sequence[0] == SOS_ID else 0
        while i < len(sequence):
            if sequence[i] == EOS_ID:
                break
            if (i + 2 < len(sequence) and self.is_x(sequence[i]) and self.is_y(sequence[i + 1])
                    and self.is_symbol(sequence[i + 2])):
                x = self.id_to_x(sequence[i])
                y = self.id_to_y(sequence[i + 1])
                # j ends one past the run of character tokens, even if the run reaches the end.
                j = i + 2
                while j < len(sequence) and self.is_symbol(sequence[j]):
                    j += 1
                symbol = ''.join(self.itos[sequence[k]] for k in range(i + 2, j))
                coords.append([x, y])
                symbols.append(symbol)
                i = j
            else:
                i += 1
        return {'coords': coords, 'symbols': symbols}

    def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
        """Encode a SMILES string character by character, with coordinate tokens after every atom.

        Args:
            smiles: SMILES string, split into atom-level tokens with ``atomwise_tokenizer``.
            coords: optional list of normalized (x, y) per atom; atoms beyond its length get random
                coordinates.
            mask_ratio: probability of replacing an atom's coordinates by two MASK tokens.
            atom_only: drop all non-atom tokens.

        Returns:
            (labels, indices): the id sequence and, for every atom, the index of its last token.
        """
        tokens = atomwise_tokenizer(smiles)
        labels = [SOS_ID]
        indices = []
        atom_idx = -1
        for token in tokens:
            if atom_only and not self.is_atom_token(token):
                continue
            for c in token:
                if c in self.stoi:
                    labels.append(self.stoi[c])
                else:
                    if self.debug:
                        print(f'{c} not in vocab')
                    labels.append(UNK_ID)
            if self.is_atom_token(token):
                atom_idx += 1
                if not self.continuous_coords:
                    if mask_ratio > 0 and random.random() < mask_ratio:
                        labels.append(MASK_ID)
                        labels.append(MASK_ID)
                    elif coords is not None:
                        if atom_idx < len(coords):
                            x, y = coords[atom_idx]
                            assert 0 <= x <= 1
                            assert 0 <= y <= 1
                        else:
                            x = random.random()
                            y = random.random()
                        labels.append(self.x_to_id(x))
                        labels.append(self.y_to_id(y))
                indices.append(len(labels) - 1)
        labels.append(EOS_ID)
        return labels, indices

    def sequence_to_smiles(self, sequence):
        """Decode character ids into a SMILES string plus per-atom symbols, indices and coordinates.

        Characters are regrouped into atom tokens: a bracket atom runs from '[' through ']', and
        'Cl' / 'Br' are read as two-character atoms. Decoding stops at the first <eos> or <pad>.

        Returns:
            Dict with 'smiles', 'symbols', 'indices' and, when coordinates are tokens, 'coords'.
        """
        has_coords = not self.continuous_coords
        pieces = []
        coords, symbols, indices = [], [], []
        i = 0
        while i < len(sequence):
            label = sequence[i]
            if label == EOS_ID or label == PAD_ID:
                break
            if self.is_x(label) or self.is_y(label):
                i += 1
                continue
            if not self.is_atom(label):
                pieces.append(self.itos[label])
                i += 1
                continue
            # Find j, one past the last character of the atom token starting at i.
            if self.itos[label] == '[':
                j = i + 1
                while j < len(sequence):
                    if not self.is_symbol(sequence[j]):
                        break
                    if self.itos[sequence[j]] == ']':
                        j += 1
                        break
                    j += 1
            else:
                two_chars = (
                    i + 1 < len(sequence)
                    and self.is_symbol(sequence[i + 1])
                    and (self.itos[label], self.itos[sequence[i + 1]]) in (('C', 'l'), ('B', 'r'))
                )
                j = i + 2 if two_chars else i + 1
            token = ''.join(self.itos[sequence[k]] for k in range(i, j))
            pieces.append(token)
            if has_coords:
                if j + 2 < len(sequence) and self.is_x(sequence[j]) and self.is_y(sequence[j + 1]):
                    coords.append([self.id_to_x(sequence[j]), self.id_to_y(sequence[j + 1])])
                    symbols.append(token)
                    indices.append(j + 2)
                    i = j + 2
                else:
                    i = j
            else:
                if j < len(sequence):
                    symbols.append(token)
                    indices.append(j)
                i = j
        results = {'smiles': ''.join(pieces), 'symbols': symbols, 'indices': indices}
        if has_coords:
            results['coords'] = coords
        return results
|
505 |
+
|
506 |
+
|
507 |
+
def get_tokenizer(args):
    """Build one tokenizer per requested output format.

    Args:
        args: namespace providing ``formats`` and ``vocab_file``; the
            coordinate formats additionally read ``coord_bins``, ``sep_xy``
            and ``continuous_coords``.

    Returns:
        dict mapping each recognised format name (``'atomtok'``,
        ``'atomtok_coords'``, ``'chartok_coords'``) to its tokenizer.
        Unrecognised format names are ignored.

    Note:
        When ``args.vocab_file`` is ``None`` it is overwritten in place with
        the default vocabulary of the first recognised format, and that
        value is then reused by every later format in ``args.formats``.
    """

    def resolve_vocab(default_relpath):
        # Fill in the packaged default only when no vocabulary is set yet.
        if args.vocab_file is None:
            args.vocab_file = os.path.join(os.path.dirname(__file__), default_relpath)
        return args.vocab_file

    built = {}
    for fmt in args.formats:
        if fmt == 'atomtok':
            vocab = resolve_vocab('vocab/vocab_uspto.json')
            built['atomtok'] = Tokenizer(vocab)
        elif fmt == "atomtok_coords":
            vocab = resolve_vocab('vocab/vocab_uspto.json')
            built["atomtok_coords"] = NodeTokenizer(
                args.coord_bins, vocab, args.sep_xy,
                continuous_coords=args.continuous_coords)
        elif fmt == "chartok_coords":
            vocab = resolve_vocab('vocab/vocab_chars.json')
            built["chartok_coords"] = CharTokenizer(
                args.coord_bins, vocab, args.sep_xy,
                continuous_coords=args.continuous_coords)
    return built
|
molscribe/transformer/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .decoder import TransformerDecoder
|
2 |
+
from .embedding import Embeddings
|
3 |
+
from .swin_transformer import swin_base, swin_large
|
molscribe/transformer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (324 Bytes). View file
|
|
molscribe/transformer/__pycache__/decoder.cpython-310.pyc
ADDED
Binary file (13.5 kB). View file
|
|
molscribe/transformer/__pycache__/embedding.cpython-310.pyc
ADDED
Binary file (7.91 kB). View file
|
|
molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc
ADDED
Binary file (21.2 kB). View file
|
|
molscribe/transformer/decoder.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of "Attention is All You Need" and of
|
3 |
+
subsequent transformer based architectures
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from onmt.decoders.decoder import DecoderBase
|
10 |
+
from onmt.modules import MultiHeadedAttention, AverageAttention
|
11 |
+
from onmt.modules.position_ffn import PositionwiseFeedForward
|
12 |
+
from onmt.modules.position_ffn import ActivationFunction
|
13 |
+
from onmt.utils.misc import sequence_mask
|
14 |
+
|
15 |
+
|
16 |
+
class TransformerDecoderLayerBase(nn.Module):
    """Shared scaffolding for one Transformer decoder layer.

    Builds the self-attention module, the position-wise feed-forward block,
    the first layer norm and the residual dropout. Subclasses provide
    :meth:`_forward`; :meth:`forward` wraps it and can additionally return
    alignment attention.
    """

    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """
        Args:
            d_model (int): the dimension of keys/values/queries in
                :class:`MultiHeadedAttention`, also the input size of
                the first-layer of the :class:`PositionwiseFeedForward`.
            heads (int): the number of heads for MultiHeadedAttention.
            d_ff (int): the second-layer of the
                :class:`PositionwiseFeedForward`.
            dropout (float): dropout applied on the residual branches and
                passed to the feed-forward block.
            attention_dropout (float): dropout passed to the self-attention
                module (both the scaled-dot and the average variant).
            self_attn_type (string): type of self-attention, ``"scaled-dot"``
                or ``"average"``. Any other value leaves ``self.self_attn``
                unset.
            max_relative_positions (int):
                Max distance between inputs in relative positions
                representations
            aan_useffn (bool): Turn on the FFN layer in the AAN decoder
            full_context_alignment (bool):
                whether enable an extra full context decoder forward for
                alignment
            alignment_heads (int):
                N. of cross attention heads to use for alignment guiding
            pos_ffn_activation_fn (ActivationFunction):
                activation function choice for PositionwiseFeedForward layer

        """
        super(TransformerDecoderLayerBase, self).__init__()

        if self_attn_type == "scaled-dot":
            self.self_attn = MultiHeadedAttention(
                heads,
                d_model,
                dropout=attention_dropout,
                max_relative_positions=max_relative_positions,
            )
        elif self_attn_type == "average":
            self.self_attn = AverageAttention(
                d_model, dropout=attention_dropout, aan_useffn=aan_useffn
            )

        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout,
                                                    pos_ffn_activation_fn
                                                    )
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        self.full_context_alignment = full_context_alignment
        self.alignment_heads = alignment_heads

    def forward(self, *args, **kwargs):
        """Extend `_forward` for (possibly) multiple decoder pass:
        Always a default (future masked) decoder forward pass,
        Possibly a second future aware decoder pass for joint learn
        full context alignment, :cite:`garg2019jointly`.

        Args:
            * All arguments of _forward.
            with_align (bool): whether return alignment attention.

        Returns:
            (FloatTensor, FloatTensor, FloatTensor or None):

            * output ``(batch_size, T, model_dim)``
            * top_attn ``(batch_size, T, src_len)``
            * attn_align ``(batch_size, T, src_len)`` or None
        """
        with_align = kwargs.pop("with_align", False)
        output, attns = self._forward(*args, **kwargs)
        top_attn = attns[:, 0, :, :].contiguous()
        attn_align = None
        if with_align:
            if self.full_context_alignment:
                # return _, (B, Q_len, K_len)
                # Override any caller-supplied `future` instead of passing
                # the keyword twice, which would raise TypeError.
                full_context_kwargs = {**kwargs, "future": True}
                _, attns = self._forward(*args, **full_context_kwargs)

            if self.alignment_heads > 0:
                attns = attns[:, : self.alignment_heads, :, :].contiguous()
            # layer average attention across heads, get ``(B, Q, K)``
            # Case 1: no full_context, no align heads -> layer avg baseline
            # Case 2: no full_context, 1 align heads -> guided align
            # Case 3: full_context, 1 align heads -> full cte guided align
            attn_align = attns.mean(dim=1)
        return output, top_attn, attn_align

    def update_dropout(self, dropout, attention_dropout):
        """Propagate new dropout rates to the layer's sub-modules."""
        self.self_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.drop.p = dropout

    def _forward(self, *args, **kwargs):
        """Single decoder pass; must be implemented by subclasses."""
        raise NotImplementedError

    def _compute_dec_mask(self, tgt_pad_mask, future):
        """Combine the target padding mask with the causal (future) mask.

        Args:
            tgt_pad_mask: padding mask whose last dimension is the target
                length.
            future (bool): when True only padding is masked and
                ``tgt_pad_mask`` is returned unchanged.

        Returns:
            the mask handed to self-attention.
        """
        tgt_len = tgt_pad_mask.size(-1)
        if not future:  # apply future_mask, result mask in (B, T, T)
            future_mask = torch.ones(
                [tgt_len, tgt_len],
                device=tgt_pad_mask.device,
                dtype=torch.uint8,
            )
            # keep only the strict upper triangle: positions after the query
            future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
            # BoolTensor was introduced in pytorch 1.2
            try:
                future_mask = future_mask.bool()
            except AttributeError:
                pass
            dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
        else:  # only mask padding, result mask in (B, 1, T)
            dec_mask = tgt_pad_mask
        return dec_mask

    def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
        """Run self-attention with the call signature of the configured type.

        Raises:
            ValueError: if ``self.self_attn`` is neither a
                ``MultiHeadedAttention`` nor an ``AverageAttention``.
        """
        if isinstance(self.self_attn, MultiHeadedAttention):
            return self.self_attn(
                inputs_norm,
                inputs_norm,
                inputs_norm,
                mask=dec_mask,
                layer_cache=layer_cache,
                attn_type="self",
            )
        elif isinstance(self.self_attn, AverageAttention):
            return self.self_attn(
                inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
            )
        else:
            raise ValueError(
                f"self attention {type(self.self_attn)} not supported"
            )
|
160 |
+
|
161 |
+
|
162 |
+
class TransformerDecoderLayer(TransformerDecoderLayerBase):
    """Transformer Decoder layer block in Pre-Norm style.

    Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style,
    providing better converge speed and performance. This is also the actual
    implementation in tensor2tensor and also available in fairseq.
    See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.

    .. mermaid::

        graph LR
        %% "*SubLayer" can be self-attn, src-attn or feed forward block
            A(input) --> B[Norm]
            B --> C["*SubLayer"]
            C --> D[Drop]
            D --> E((+))
            A --> E
            E --> F(out)

    """

    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """Build the base layer, then add cross-attention and its layer norm.

        Args:
            See TransformerDecoderLayerBase
        """
        super().__init__(
            d_model,
            heads,
            d_ff,
            dropout,
            attention_dropout,
            self_attn_type=self_attn_type,
            max_relative_positions=max_relative_positions,
            aan_useffn=aan_useffn,
            full_context_alignment=full_context_alignment,
            alignment_heads=alignment_heads,
            pos_ffn_activation_fn=pos_ffn_activation_fn,
        )
        # cross-attention over the encoder memory bank
        self.context_attn = MultiHeadedAttention(
            heads, d_model, dropout=attention_dropout
        )
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)

    def update_dropout(self, dropout, attention_dropout):
        """Update dropout on the base sub-modules and on cross-attention."""
        super().update_dropout(dropout, attention_dropout)
        self.context_attn.update_dropout(attention_dropout)

    def _forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        layer_cache=None,
        step=None,
        future=False,
    ):
        """A naive forward pass for transformer decoder.

        # T: could be 1 in the case of stepwise decoding or tgt_len

        Args:
            inputs (FloatTensor): ``(batch_size, T, model_dim)``
            memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
            src_pad_mask (bool): ``(batch_size, 1, src_len)``
            tgt_pad_mask (bool): ``(batch_size, 1, T)``
            layer_cache (dict or None): cached layer info when stepwise decode
            step (int or None): stepwise decoding counter
            future (bool): If set True, do not apply future_mask.

        Returns:
            (FloatTensor, FloatTensor):

            * output ``(batch_size, T, model_dim)``
            * attns ``(batch_size, head, T, src_len)``

        """
        # masking is only necessary when more than one position is decoded
        if inputs.size(1) > 1:
            self_mask = self._compute_dec_mask(tgt_pad_mask, future)
        else:
            self_mask = None

        # self-attention sub-layer with residual connection
        self_out, _ = self._forward_self_attn(
            self.layer_norm_1(inputs), self_mask, layer_cache, step
        )
        residual = self.drop(self_out) + inputs

        # cross-attention sub-layer: keys/values from memory, query from target
        context, attns = self.context_attn(
            memory_bank,
            memory_bank,
            self.layer_norm_2(residual),
            mask=src_pad_mask,
            layer_cache=layer_cache,
            attn_type="context",
        )

        return self.feed_forward(self.drop(context) + residual), attns
|
280 |
+
|
281 |
+
|
282 |
+
class TransformerDecoderBase(DecoderBase):
    """State handling and option-based construction shared by decoders.

    Holds the decoder ``state`` dict (``"src"`` and ``"cache"`` entries), the
    final layer norm, the copy-attention flag and the index of the layer used
    for alignment. Subclasses are expected to define ``transformer_layers``
    and implement :meth:`forward` and :meth:`detach_state`.
    """

    def __init__(self, d_model, copy_attn, alignment_layer):
        """
        Args:
            d_model (int): size of the model, used for the final layer norm.
            copy_attn (bool): whether ``forward`` should also expose the
                attention under the ``"copy"`` key.
            alignment_layer (int): index of the layer whose alignment
                attention is reported.
        """
        super(TransformerDecoderBase, self).__init__()

        # Decoder State
        self.state = {}

        # previously, there was a GlobalAttention module here for copy
        # attention. But it was never actually used -- the "copy" attention
        # just reuses the context attention.
        self._copy = copy_attn
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        self.alignment_layer = alignment_layer

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor.

        ``embeddings`` is accepted for interface compatibility but is not
        forwarded: the decoder constructor in this module takes no embeddings
        argument, and passing it positionally shifted every later argument.
        """
        dropout = opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout
        attention_dropout = (
            opt.attention_dropout[0]
            if isinstance(opt.attention_dropout, list)
            else opt.attention_dropout
        )
        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.copy_attn,
            opt.self_attn_type,
            dropout,
            attention_dropout,
            opt.max_relative_positions,
            opt.aan_useffn,
            opt.full_context_alignment,
            opt.alignment_layer,
            alignment_heads=opt.alignment_heads,
            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
        )

    def init_state(self, src, memory_bank, enc_hidden):
        """Initialize decoder state."""
        self.state["src"] = src
        self.state["cache"] = None

    def map_state(self, fn):
        """Apply ``fn(tensor, batch_dim)`` to every tensor held in the state.

        ``"src"`` is mapped with batch dimension 1, cache entries with batch
        dimension 0. Missing or ``None`` entries are skipped, so this is safe
        to call before ``init_state``.
        """
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.state.get("src") is not None:
            self.state["src"] = fn(self.state["src"], 1)
        if self.state.get("cache") is not None:
            _recursive_map(self.state["cache"])

    def detach_state(self):
        """Detach the state from the graph; implemented by subclasses."""
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """Decode; implemented by subclasses."""
        raise NotImplementedError

    def update_dropout(self, dropout, attention_dropout):
        """Propagate new dropout rates to the embeddings (if any) and layers."""
        # This module never assigns `self.embeddings`; only update it when a
        # subclass or caller has attached one.
        embeddings = getattr(self, "embeddings", None)
        if embeddings is not None:
            embeddings.update_dropout(dropout)
        for layer in self.transformer_layers:
            layer.update_dropout(dropout, attention_dropout)
|
347 |
+
|
348 |
+
|
349 |
+
class TransformerDecoder(TransformerDecoderBase):
    """The Transformer decoder from "Attention is All You Need".
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    .. mermaid::

       graph BT
          A[input]
          B[multi-head self-attn]
          BB[multi-head src-attn]
          C[feed forward]
          O[output]
          A --> B
          B --> BB
          BB --> C
          C --> O


    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        copy_attn (bool): if using a separate copy attention
        self_attn_type (str): type of self-attention scaled-dot, average
        dropout (float): dropout on the residual branches and feed-forward
        attention_dropout (float): dropout inside the attention modules
        max_relative_positions (int):
            Max distance between inputs in relative positions representations
        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
        full_context_alignment (bool):
            whether enable an extra full context decoder forward for alignment
        alignment_layer (int): N° Layer to supervise with for alignment guiding
        alignment_heads (int):
            N. of cross attention heads to use for alignment guiding
        pos_ffn_activation_fn (ActivationFunction):
            activation function choice for PositionwiseFeedForward layer
    """

    def __init__(
        self,
        num_layers,
        d_model,
        heads,
        d_ff,
        copy_attn,
        self_attn_type,
        dropout,
        attention_dropout,
        max_relative_positions,
        aan_useffn,
        full_context_alignment,
        alignment_layer,
        alignment_heads,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        super(TransformerDecoder, self).__init__(
            d_model, copy_attn, alignment_layer
        )

        self.transformer_layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    self_attn_type=self_attn_type,
                    max_relative_positions=max_relative_positions,
                    aan_useffn=aan_useffn,
                    full_context_alignment=full_context_alignment,
                    alignment_heads=alignment_heads,
                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                )
                for _ in range(num_layers)
            ]
        )

    def detach_state(self):
        """Detach the stored source tensor, if one has been set."""
        # `init_state` may store None and may never have been called at all,
        # so guard the same way `map_state` does.
        src = self.state.get("src")
        if src is not None:
            self.state["src"] = src.detach()

    def forward(self, tgt_emb, memory_bank, src_pad_mask=None, tgt_pad_mask=None, step=None, **kwargs):
        """Decode, possibly stepwise.

        Args:
            tgt_emb: target embeddings, unpacked as ``(batch, tgt_len, dim)``.
            memory_bank: encoder output, unpacked as ``(batch, src_len, dim)``.
            src_pad_mask: source padding mask; defaults to an all-False
                ``(batch, 1, src_len)`` bool tensor.
            tgt_pad_mask: target padding mask; defaults to an all-False
                ``(batch, 1, tgt_len)`` bool tensor.
            step (int or None): stepwise decoding counter. ``0`` resets the
                per-layer cache; ``None`` disables caching.
            **kwargs: ``future`` (bool) and ``with_align`` (bool) are
                forwarded to every layer.

        Returns:
            (output, attns, hiddens): the layer-normed output of the last
            layer, a dict with ``"std"`` (plus ``"copy"`` / ``"align"`` when
            enabled) attention, and the list of every layer's raw output.
        """
        if step == 0:
            self._init_cache(memory_bank)

        batch_size, src_len, _ = memory_bank.size()
        device = memory_bank.device
        if src_pad_mask is None:
            src_pad_mask = torch.zeros((batch_size, 1, src_len), dtype=torch.bool, device=device)
        output = tgt_emb
        batch_size, tgt_len, _ = tgt_emb.size()
        if tgt_pad_mask is None:
            tgt_pad_mask = torch.zeros((batch_size, 1, tgt_len), dtype=torch.bool, device=device)

        future = kwargs.pop("future", False)
        with_align = kwargs.pop("with_align", False)
        attn_aligns = []
        hiddens = []
        # stays None only when the decoder has no layers
        attn = None

        for i, layer in enumerate(self.transformer_layers):
            layer_cache = (
                self.state["cache"]["layer_{}".format(i)]
                if step is not None
                else None
            )
            output, attn, attn_align = layer(
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align,
                future=future
            )
            hiddens.append(output)
            if attn_align is not None:
                attn_aligns.append(attn_align)

        output = self.layer_norm(output)  # (B, L, D)

        # attention of the last layer
        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`
            # attns["align"] = torch.stack(attn_aligns, 0).mean(0)  # All avg

        # TODO change the way attns is returned dict => list or tuple (onnx)
        return output, attns, hiddens

    def _init_cache(self, memory_bank):
        """Reset the stepwise cache with one empty entry per layer."""
        self.state["cache"] = {}
        for i in range(len(self.transformer_layers)):
            layer_cache = {"memory_keys": None, "memory_values": None, "self_keys": None, "self_values": None}
            self.state["cache"]["layer_{}".format(i)] = layer_cache
|
487 |
+
|