CYF200127 committed
Commit 5e9bd47 · verified · 1 Parent(s): 23b9d28

Upload 116 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. app.py +131 -0
  2. examples/exp.png +0 -0
  3. examples/reaction1.png +0 -0
  4. examples/reaction2.png +0 -0
  5. examples/reaction3.png +0 -0
  6. examples/reaction4.png +0 -0
  7. getReaction.py +78 -0
  8. molscribe/__init__.py +1 -0
  9. molscribe/__pycache__/__init__.cpython-310.pyc +0 -0
  10. molscribe/__pycache__/augment.cpython-310.pyc +0 -0
  11. molscribe/__pycache__/chemistry.cpython-310.pyc +0 -0
  12. molscribe/__pycache__/constants.cpython-310.pyc +0 -0
  13. molscribe/__pycache__/dataset.cpython-310.pyc +0 -0
  14. molscribe/__pycache__/evaluate.cpython-310.pyc +0 -0
  15. molscribe/__pycache__/interface.cpython-310.pyc +0 -0
  16. molscribe/__pycache__/loss.cpython-310.pyc +0 -0
  17. molscribe/__pycache__/model.cpython-310.pyc +0 -0
  18. molscribe/__pycache__/tokenizer.cpython-310.pyc +0 -0
  19. molscribe/__pycache__/utils.cpython-310.pyc +0 -0
  20. molscribe/augment.py +282 -0
  21. molscribe/chemistry.py +649 -0
  22. molscribe/constants.py +130 -0
  23. molscribe/dataset.py +594 -0
  24. molscribe/evaluate.py +79 -0
  25. molscribe/indigo/__init__.py +0 -0
  26. molscribe/indigo/__pycache__/__init__.cpython-310.pyc +0 -0
  27. molscribe/indigo/__pycache__/bingo.cpython-310.pyc +0 -0
  28. molscribe/indigo/__pycache__/inchi.cpython-310.pyc +0 -0
  29. molscribe/indigo/__pycache__/renderer.cpython-310.pyc +0 -0
  30. molscribe/indigo/bingo.py +334 -0
  31. molscribe/indigo/inchi.py +84 -0
  32. molscribe/indigo/renderer.py +113 -0
  33. molscribe/inference/__init__.py +4 -0
  34. molscribe/inference/__pycache__/__init__.cpython-310.pyc +0 -0
  35. molscribe/inference/__pycache__/beam_search.cpython-310.pyc +0 -0
  36. molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc +0 -0
  37. molscribe/inference/__pycache__/greedy_search.cpython-310.pyc +0 -0
  38. molscribe/inference/beam_search.py +190 -0
  39. molscribe/inference/decode_strategy.py +63 -0
  40. molscribe/inference/greedy_search.py +128 -0
  41. molscribe/interface.py +223 -0
  42. molscribe/loss.py +125 -0
  43. molscribe/model.py +397 -0
  44. molscribe/tokenizer.py +524 -0
  45. molscribe/transformer/__init__.py +3 -0
  46. molscribe/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  47. molscribe/transformer/__pycache__/decoder.cpython-310.pyc +0 -0
  48. molscribe/transformer/__pycache__/embedding.cpython-310.pyc +0 -0
  49. molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc +0 -0
  50. molscribe/transformer/decoder.py +487 -0
app.py ADDED
@@ -0,0 +1,131 @@
+ import os
+ import gradio as gr
+ import json
+ from rxnim import RXNIM
+ from getReaction import generate_combined_image
+ import torch
+ from rxn.reaction import Reaction
+
+ PROMPT_DIR = "prompts/"
+ ckpt_path = "./rxn/model/model.ckpt"
+ model = Reaction(ckpt_path, device=torch.device('cpu'))
+
+ # Map prompt file names to friendly display names
+ PROMPT_NAMES = {
+     "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
+ }
+ example_diagram = "examples/exp.png"
+
+ def list_prompt_files_with_names():
+     """
+     List all .txt files in the prompts directory, generating a default name
+     for any file that has none. Returns a {friendly_name: filename} mapping.
+     """
+     prompt_files = {}
+     for f in os.listdir(PROMPT_DIR):
+         if f.endswith(".txt"):
+             # Use the predefined name if the file has one
+             friendly_name = PROMPT_NAMES.get(f, f"Task: {os.path.splitext(f)[0]}")
+             prompt_files[friendly_name] = f
+     return prompt_files
+
+ def parse_reactions(output_json):
+     """
+     Parse reaction data in JSON format and produce formatted output with custom colors.
+     """
+     reactions_data = json.loads(output_json)  # convert the JSON string to a dict
+     reactions_list = reactions_data.get("reactions", [])
+     detailed_output = []
+
+     for reaction in reactions_list:
+         reaction_id = reaction.get("reaction_id", "Unknown ID")
+         reactants = [r.get("smiles", "Unknown") for r in reaction.get("reactants", [])]
+         conditions = [
+             f"<span style='color:red'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>"
+             for c in reaction.get("conditions", [])
+         ]
+         conditions_1 = [
+             f"<span style='color:black'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>"
+             for c in reaction.get("conditions", [])
+         ]
+         products = [f"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
+         products_1 = [f"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
+
+         # Build the full reaction string with custom font colors
+         full_reaction = f"{'.'.join(reactants)}>>{'.'.join(products_1)} | {', '.join(conditions_1)}"
+         full_reaction = f"<span style='color:black'>{full_reaction}</span>"
+
+         # Detailed formatted output for the reaction
+         reaction_output = f"<b>Reaction: </b> {reaction_id}<br>"
+         reaction_output += f"  Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>"
+         reaction_output += f"  Conditions: {', '.join(conditions)}<br>"
+         reaction_output += f"  Products: {', '.join(products)}<br>"
+         reaction_output += f"  <b>Full Reaction:</b> {full_reaction}<br>"
+         reaction_output += "<br>"
+         detailed_output.append(reaction_output)
+
+     return detailed_output
+
+ def process_chem_image(image, selected_task):
+     chem_mllm = RXNIM()
+
+     # Convert the friendly name back to the actual prompt file name
+     prompt_path = os.path.join(PROMPT_DIR, prompts_with_names[selected_task])
+     image_path = "temp_image.png"
+     image.save(image_path)
+
+     # Run RXNIM processing
+     rxnim_result = chem_mllm.process(image_path, prompt_path)
+
+     # Parse the JSON result into structured output
+     detailed_reactions = parse_reactions(rxnim_result)
+
+     # Run the RxnScribe model and build the combined visualization
+     predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
+     combined_image_path = generate_combined_image(predictions, image_path)
+
+     json_file_path = "output.json"
+     with open(json_file_path, "w") as json_file:
+         json.dump(json.loads(rxnim_result), json_file, indent=4)
+
+     # Return the detailed reactions and the combined image
+     return "\n\n".join(detailed_reactions), combined_image_path, example_diagram, json_file_path
+
+
+ # Get the prompts and their friendly names
+ prompts_with_names = list_prompt_files_with_names()
+
+ # Example data: image path + task choice
+ examples = [
+     ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
+     ["examples/reaction2.png", "Reaction Image Parsing Workflow"],
+     ["examples/reaction3.png", "Reaction Image Parsing Workflow"],
+     ["examples/reaction4.png", "Reaction Image Parsing Workflow"],
+ ]
+
+ # Define the Gradio interface
+ demo = gr.Interface(
+     fn=process_chem_image,
+     inputs=[
+         gr.Image(type="pil", label="Upload Reaction Image"),
+         gr.Radio(
+             choices=list(prompts_with_names.keys()),  # show the task names
+             label="Select a predefined task",
+         ),
+     ],
+     outputs=[
+         gr.HTML(label="Reaction outputs"),
+         gr.Image(label="Visualization"),  # show the combined image
+         gr.Image(value=example_diagram, label="Schematic Diagram"),
+         gr.File(label="Download JSON File"),
+     ],
+     title="Towards Large-scale Chemical Reaction Image Parsing via a Multimodal Large Language Model",
+     description="Upload a reaction image and select a predefined task prompt.",
+     examples=examples,  # use the nested list as examples
+     examples_per_page=20,
+ )
+
+ demo.launch()
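For reference, parse_reactions consumes a JSON string shaped like the hand-written sketch below. The field names are taken from the accesses in the function above; the values are purely illustrative, not actual model output:

    import json

    example = {
        "reactions": [{
            "reaction_id": "1",
            "reactants": [{"smiles": "CCO"}],
            "conditions": [{"text": "HCl", "role": "reagent"}],  # conditions may carry "smiles" or "text"
            "products": [{"smiles": "CCCl"}],
        }]
    }
    html_blocks = parse_reactions(json.dumps(example))  # list of HTML strings, one per reaction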
examples/exp.png ADDED
examples/reaction1.png ADDED
examples/reaction2.png ADDED
examples/reaction3.png ADDED
examples/reaction4.png ADDED
getReaction.py ADDED
@@ -0,0 +1,78 @@
+ import sys
+ sys.path.append('./rxn/')
+ import torch
+ from rxn.reaction import Reaction
+ import json
+ from matplotlib import pyplot as plt
+ import numpy as np
+
+ ckpt_path = "./rxn/model/model.ckpt"
+ model = Reaction(ckpt_path, device=torch.device('cpu'))
+ device = torch.device('cpu')
+
+ def get_reaction(image_path: str) -> str:
+     '''Returns a JSON string with the reactions extracted from the image.'''
+     image_file = image_path
+     return json.dumps(model.predict_image_file(image_file, molscribe=True, ocr=True))
+
+
+ def generate_combined_image(predictions, image_file):
+     """
+     Combine the prediction images into a single, symmetric layout.
+     """
+     output = model.draw_predictions(predictions, image_file=image_file)
+     n_images = len(output)
+     if n_images == 1:
+         n_cols = 1
+     elif n_images == 2:
+         n_cols = 2
+     else:
+         n_cols = 3
+     n_rows = (n_images + n_cols - 1) // n_cols  # number of rows needed
+
+     # Make sure every image has the expected format
+     processed_images = []
+     for img in output:
+         if len(img.shape) == 2:  # grayscale image
+             img = np.stack([img] * 3, axis=-1)  # convert to RGB
+         elif img.shape[2] > 3:  # RGBA image
+             img = img[:, :, :3]  # keep only the RGB channels
+         if img.dtype == np.float32 or img.dtype == np.float64:
+             img = (img * 255).astype(np.uint8)  # convert to uint8
+         processed_images.append(img)
+     output = processed_images
+
+     # Add white placeholders for the unused subplot slots
+     if n_images < n_rows * n_cols:
+         blank_image = np.ones_like(output[0]) * 255  # white placeholder image
+         while len(output) < n_rows * n_cols:
+             output.append(blank_image)
+
+     # Create the subplot canvas
+     fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
+
+     # Make sure axes is a flat, one-dimensional array
+     if isinstance(axes, np.ndarray):
+         axes = axes.flatten()
+     else:
+         axes = [axes]  # single-subplot case
+
+     # Draw each image
+     for idx, img in enumerate(output):
+         ax = axes[idx]
+         ax.imshow(img)
+         ax.axis('off')
+         if idx < n_images:
+             ax.set_title(f"Reaction {idx + 1}")
+
+     # Delete the leftover subplots
+     for idx in range(n_images, len(axes)):
+         fig.delaxes(axes[idx])
+
+     # Save the combined image
+     combined_image_path = "combined_output.png"
+     plt.tight_layout()
+     plt.savefig(combined_image_path)
+     plt.close(fig)
+     return combined_image_path
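A minimal usage sketch (not part of the commit; assumes the checkpoint at ./rxn/model/model.ckpt is in place):

    from getReaction import model, generate_combined_image

    predictions = model.predict_image_file("examples/reaction1.png", molscribe=True, ocr=True)
    path = generate_combined_image(predictions, "examples/reaction1.png")
    print(path)  # "combined_output.png"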
molscribe/__init__.py ADDED
@@ -0,0 +1 @@
+ from .interface import MolScribe
molscribe/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (190 Bytes).
molscribe/__pycache__/augment.cpython-310.pyc ADDED
Binary file (8.98 kB).
molscribe/__pycache__/chemistry.cpython-310.pyc ADDED
Binary file (17.5 kB).
molscribe/__pycache__/constants.cpython-310.pyc ADDED
Binary file (6.18 kB).
molscribe/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (17.6 kB).
molscribe/__pycache__/evaluate.cpython-310.pyc ADDED
Binary file (3.48 kB).
molscribe/__pycache__/interface.cpython-310.pyc ADDED
Binary file (8.95 kB).
molscribe/__pycache__/loss.cpython-310.pyc ADDED
Binary file (4.25 kB).
molscribe/__pycache__/model.cpython-310.pyc ADDED
Binary file (13.2 kB).
molscribe/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (16.8 kB).
molscribe/__pycache__/utils.cpython-310.pyc ADDED
Binary file (6.33 kB).
molscribe/augment.py ADDED
@@ -0,0 +1,282 @@
+ import albumentations as A
+ from albumentations.augmentations.geometric.functional import safe_rotate_enlarged_img_size, _maybe_process_in_chunks, \
+     keypoint_rotate
+ import cv2
+ import math
+ import random
+ import numpy as np
+
+
+ def safe_rotate(
+         img: np.ndarray,
+         angle: int = 0,
+         interpolation: int = cv2.INTER_LINEAR,
+         value: int = None,
+         border_mode: int = cv2.BORDER_REFLECT_101,
+ ):
+
+     old_rows, old_cols = img.shape[:2]
+
+     # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
+     image_center = (old_cols / 2, old_rows / 2)
+
+     # Rows and columns of the rotated image (not cropped)
+     new_rows, new_cols = safe_rotate_enlarged_img_size(angle=angle, rows=old_rows, cols=old_cols)
+
+     # Rotation Matrix
+     rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
+
+     # Shift the image to create padding
+     rotation_mat[0, 2] += new_cols / 2 - image_center[0]
+     rotation_mat[1, 2] += new_rows / 2 - image_center[1]
+
+     # CV2 Transformation function
+     warp_affine_fn = _maybe_process_in_chunks(
+         cv2.warpAffine,
+         M=rotation_mat,
+         dsize=(new_cols, new_rows),
+         flags=interpolation,
+         borderMode=border_mode,
+         borderValue=value,
+     )
+
+     # rotate image with the new bounds
+     rotated_img = warp_affine_fn(img)
+
+     return rotated_img
+
+
+ def keypoint_safe_rotate(keypoint, angle, rows, cols):
+     old_rows = rows
+     old_cols = cols
+
+     # Rows and columns of the rotated image (not cropped)
+     new_rows, new_cols = safe_rotate_enlarged_img_size(angle=angle, rows=old_rows, cols=old_cols)
+
+     col_diff = (new_cols - old_cols) / 2
+     row_diff = (new_rows - old_rows) / 2
+
+     # Shift keypoint
+     shifted_keypoint = (int(keypoint[0] + col_diff), int(keypoint[1] + row_diff), keypoint[2], keypoint[3])
+
+     # Rotate keypoint
+     rotated_keypoint = keypoint_rotate(shifted_keypoint, angle, rows=new_rows, cols=new_cols)
+
+     return rotated_keypoint
+
+
+ class SafeRotate(A.SafeRotate):
+
+     def __init__(
+             self,
+             limit=90,
+             interpolation=cv2.INTER_LINEAR,
+             border_mode=cv2.BORDER_REFLECT_101,
+             value=None,
+             mask_value=None,
+             always_apply=False,
+             p=0.5,
+     ):
+         super(SafeRotate, self).__init__(
+             limit=limit,
+             interpolation=interpolation,
+             border_mode=border_mode,
+             value=value,
+             mask_value=mask_value,
+             always_apply=always_apply,
+             p=p)
+
+     def apply(self, img, angle=0, interpolation=cv2.INTER_LINEAR, **params):
+         return safe_rotate(
+             img=img, value=self.value, angle=angle, interpolation=interpolation, border_mode=self.border_mode)
+
+     def apply_to_keypoint(self, keypoint, angle=0, **params):
+         return keypoint_safe_rotate(keypoint, angle=angle, rows=params["rows"], cols=params["cols"])
+
+
+ class CropWhite(A.DualTransform):
+
+     def __init__(self, value=(255, 255, 255), pad=0, p=1.0):
+         super(CropWhite, self).__init__(p=p)
+         self.value = value
+         self.pad = pad
+         assert pad >= 0
+
+     def update_params(self, params, **kwargs):
+         super().update_params(params, **kwargs)
+         assert "image" in kwargs
+         img = kwargs["image"]
+         height, width, _ = img.shape
+         x = (img != self.value).sum(axis=2)
+         if x.sum() == 0:
+             return params
+         row_sum = x.sum(axis=1)
+         top = 0
+         while row_sum[top] == 0 and top + 1 < height:
+             top += 1
+         bottom = height
+         while row_sum[bottom - 1] == 0 and bottom - 1 > top:
+             bottom -= 1
+         col_sum = x.sum(axis=0)
+         left = 0
+         while col_sum[left] == 0 and left + 1 < width:
+             left += 1
+         right = width
+         while col_sum[right - 1] == 0 and right - 1 > left:
+             right -= 1
+         # crop_top = max(0, top - self.pad)
+         # crop_bottom = max(0, height - bottom - self.pad)
+         # crop_left = max(0, left - self.pad)
+         # crop_right = max(0, width - right - self.pad)
+         # params.update({"crop_top": crop_top, "crop_bottom": crop_bottom,
+         #                "crop_left": crop_left, "crop_right": crop_right})
+         params.update({"crop_top": top, "crop_bottom": height - bottom,
+                        "crop_left": left, "crop_right": width - right})
+         return params
+
+     def apply(self, img, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params):
+         height, width, _ = img.shape
+         img = img[crop_top:height - crop_bottom, crop_left:width - crop_right]
+         img = A.augmentations.pad_with_params(
+             img, self.pad, self.pad, self.pad, self.pad, border_mode=cv2.BORDER_CONSTANT, value=self.value)
+         return img
+
+     def apply_to_keypoint(self, keypoint, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params):
+         x, y, angle, scale = keypoint[:4]
+         return x - crop_left + self.pad, y - crop_top + self.pad, angle, scale
+
+     def get_transform_init_args_names(self):
+         return ('value', 'pad')
+
+
+ class PadWhite(A.DualTransform):
+
+     def __init__(self, pad_ratio=0.2, p=0.5, value=(255, 255, 255)):
+         super(PadWhite, self).__init__(p=p)
+         self.pad_ratio = pad_ratio
+         self.value = value
+
+     def update_params(self, params, **kwargs):
+         super().update_params(params, **kwargs)
+         assert "image" in kwargs
+         img = kwargs["image"]
+         height, width, _ = img.shape
+         side = random.randrange(4)
+         if side == 0:
+             params['pad_top'] = int(height * self.pad_ratio * random.random())
+         elif side == 1:
+             params['pad_bottom'] = int(height * self.pad_ratio * random.random())
+         elif side == 2:
+             params['pad_left'] = int(width * self.pad_ratio * random.random())
+         elif side == 3:
+             params['pad_right'] = int(width * self.pad_ratio * random.random())
+         return params
+
+     def apply(self, img, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params):
+         height, width, _ = img.shape
+         img = A.augmentations.pad_with_params(
+             img, pad_top, pad_bottom, pad_left, pad_right, border_mode=cv2.BORDER_CONSTANT, value=self.value)
+         return img
+
+     def apply_to_keypoint(self, keypoint, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params):
+         x, y, angle, scale = keypoint[:4]
+         return x + pad_left, y + pad_top, angle, scale
+
+     def get_transform_init_args_names(self):
+         return ('value', 'pad_ratio')
+
+
+ class SaltAndPepperNoise(A.DualTransform):
+
+     def __init__(self, num_dots, value=(0, 0, 0), p=0.5):
+         super().__init__(p)
+         self.num_dots = num_dots
+         self.value = value
+
+     def apply(self, img, **params):
+         height, width, _ = img.shape
+         num_dots = random.randrange(self.num_dots + 1)
+         for i in range(num_dots):
+             x = random.randrange(height)
+             y = random.randrange(width)
+             img[x, y] = self.value
+         return img
+
+     def apply_to_keypoint(self, keypoint, **params):
+         return keypoint
+
+     def get_transform_init_args_names(self):
+         return ('value', 'num_dots')
+
+
+ class ResizePad(A.DualTransform):
+
+     def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, value=(255, 255, 255)):
+         super(ResizePad, self).__init__(always_apply=True)
+         self.height = height
+         self.width = width
+         self.interpolation = interpolation
+         self.value = value
+
+     def apply(self, img, interpolation=cv2.INTER_LINEAR, **params):
+         h, w, _ = img.shape
+         img = A.augmentations.geometric.functional.resize(
+             img,
+             height=min(h, self.height),
+             width=min(w, self.width),
+             interpolation=interpolation
+         )
+         h, w, _ = img.shape
+         pad_top = (self.height - h) // 2
+         pad_bottom = (self.height - h) - pad_top
+         pad_left = (self.width - w) // 2
+         pad_right = (self.width - w) - pad_left
+         img = A.augmentations.pad_with_params(
+             img,
+             pad_top,
+             pad_bottom,
+             pad_left,
+             pad_right,
+             border_mode=cv2.BORDER_CONSTANT,
+             value=self.value,
+         )
+         return img
+
+
+ def normalized_grid_distortion(
+         img,
+         num_steps=10,
+         xsteps=(),
+         ysteps=(),
+         *args,
+         **kwargs
+ ):
+     height, width = img.shape[:2]
+
+     # compensate for smaller last steps in source image.
+     x_step = width // num_steps
+     last_x_step = min(width, ((num_steps + 1) * x_step)) - (num_steps * x_step)
+     xsteps[-1] *= last_x_step / x_step
+
+     y_step = height // num_steps
+     last_y_step = min(height, ((num_steps + 1) * y_step)) - (num_steps * y_step)
+     ysteps[-1] *= last_y_step / y_step
+
+     # now normalize such that distortion never leaves image bounds.
+     tx = width / math.floor(width / num_steps)
+     ty = height / math.floor(height / num_steps)
+     xsteps = np.array(xsteps) * (tx / np.sum(xsteps))
+     ysteps = np.array(ysteps) * (ty / np.sum(ysteps))
+
+     # do actual distortion.
+     return A.augmentations.functional.grid_distortion(img, num_steps, xsteps, ysteps, *args, **kwargs)
+
+
+ class NormalizedGridDistortion(A.augmentations.transforms.GridDistortion):
+     def apply(self, img, stepsx=(), stepsy=(), interpolation=cv2.INTER_LINEAR, **params):
+         return normalized_grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode,
+                                           self.value)
+
+     def apply_to_mask(self, img, stepsx=(), stepsy=(), **params):
+         return normalized_grid_distortion(
+             img, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
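These custom transforms compose like any other albumentations transform; a minimal sketch (assumed usage, mirroring how get_transforms in dataset.py below composes them):

    import numpy as np
    import albumentations as A
    from molscribe.augment import CropWhite, PadWhite, SaltAndPepperNoise

    aug = A.Compose([
        CropWhite(pad=5),                        # trim white margins, keep a 5px border
        PadWhite(pad_ratio=0.4, p=0.2),          # randomly pad one side with white
        SaltAndPepperNoise(num_dots=20, p=0.5),  # sprinkle up to 20 black dots
    ])
    white = np.full((64, 64, 3), 255, dtype=np.uint8)  # blank white canvas
    augmented = aug(image=white)["image"]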
molscribe/chemistry.py ADDED
@@ -0,0 +1,649 @@
+ import copy
+ import traceback
+ import numpy as np
+ import multiprocessing
+
+ import rdkit
+ import rdkit.Chem as Chem
+
+ rdkit.RDLogger.DisableLog('rdApp.*')
+
+ from SmilesPE.pretokenizer import atomwise_tokenizer
+
+ from .constants import RGROUP_SYMBOLS, ABBREVIATIONS, VALENCES, FORMULA_REGEX
+
+
+ def is_valid_mol(s, format_='atomtok'):
+     if format_ == 'atomtok':
+         mol = Chem.MolFromSmiles(s)
+     elif format_ == 'inchi':
+         if not s.startswith('InChI=1S'):
+             s = f"InChI=1S/{s}"
+         mol = Chem.MolFromInchi(s)
+     else:
+         raise NotImplementedError
+     return mol is not None
+
+
+ def _convert_smiles_to_inchi(smiles):
+     try:
+         mol = Chem.MolFromSmiles(smiles)
+         inchi = Chem.MolToInchi(mol)
+     except:
+         inchi = None
+     return inchi
+
+
+ def convert_smiles_to_inchi(smiles_list, num_workers=16):
+     with multiprocessing.Pool(num_workers) as p:
+         inchi_list = p.map(_convert_smiles_to_inchi, smiles_list, chunksize=128)
+     n_success = sum([x is not None for x in inchi_list])
+     r_success = n_success / len(inchi_list)
+     inchi_list = [x if x else 'InChI=1S/H2O/h1H2' for x in inchi_list]
+     return inchi_list, r_success
+
+
+ def merge_inchi(inchi1, inchi2):
+     replaced = 0
+     inchi1 = copy.deepcopy(inchi1)
+     for i in range(len(inchi1)):
+         if inchi1[i] == 'InChI=1S/H2O/h1H2':
+             inchi1[i] = inchi2[i]
+             replaced += 1
+     return inchi1, replaced
+
+
+ def _get_num_atoms(smiles):
+     try:
+         return Chem.MolFromSmiles(smiles).GetNumAtoms()
+     except:
+         return 0
+
+
+ def get_num_atoms(smiles, num_workers=16):
+     if type(smiles) is str:
+         return _get_num_atoms(smiles)
+     with multiprocessing.Pool(num_workers) as p:
+         num_atoms = p.map(_get_num_atoms, smiles)
+     return num_atoms
+
+
+ def normalize_nodes(nodes, flip_y=True):
+     x, y = nodes[:, 0], nodes[:, 1]
+     minx, maxx = min(x), max(x)
+     miny, maxy = min(y), max(y)
+     x = (x - minx) / max(maxx - minx, 1e-6)
+     if flip_y:
+         y = (maxy - y) / max(maxy - miny, 1e-6)
+     else:
+         y = (y - miny) / max(maxy - miny, 1e-6)
+     return np.stack([x, y], axis=1)
+
+
+ def _verify_chirality(mol, coords, symbols, edges, debug=False):
+     try:
+         n = mol.GetNumAtoms()
+         # Make a temp mol to find chiral centers
+         mol_tmp = mol.GetMol()
+         Chem.SanitizeMol(mol_tmp)
+
+         chiral_centers = Chem.FindMolChiralCenters(
+             mol_tmp, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False)
+         chiral_center_ids = [idx for idx, _ in chiral_centers]  # List[Tuple[int, any]] -> List[int]
+
+         # correction to clear pre-condition violation (for some corner cases)
+         for bond in mol.GetBonds():
+             if bond.GetBondType() == Chem.BondType.SINGLE:
+                 bond.SetBondDir(Chem.BondDir.NONE)
+
+         # Create conformer from 2D coordinates
+         conf = Chem.Conformer(n)
+         conf.Set3D(True)
+         for i, (x, y) in enumerate(coords):
+             conf.SetAtomPosition(i, (x, 1 - y, 0))
+         mol.AddConformer(conf)
+         Chem.SanitizeMol(mol)
+         Chem.AssignStereochemistryFrom3D(mol)
+         # NOTE: seems that only AssignStereochemistryFrom3D can handle double bond E/Z,
+         # so we do this first, remove the conformer and add back the 2D conformer for chiral correction
+
+         mol.RemoveAllConformers()
+         conf = Chem.Conformer(n)
+         conf.Set3D(False)
+         for i, (x, y) in enumerate(coords):
+             conf.SetAtomPosition(i, (x, 1 - y, 0))
+         mol.AddConformer(conf)
+
+         # Magic, inferring chirality from coordinates and BondDir. DO NOT CHANGE.
+         Chem.SanitizeMol(mol)
+         Chem.AssignChiralTypesFromBondDirs(mol)
+         Chem.AssignStereochemistry(mol, force=True)
+
+         # Second loop to reset any wedge/dash bond to be starting from the chiral center
+         for i in chiral_center_ids:
+             for j in range(n):
+                 if edges[i][j] == 5:
+                     # assert edges[j][i] == 6
+                     mol.RemoveBond(i, j)
+                     mol.AddBond(i, j, Chem.BondType.SINGLE)
+                     mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINWEDGE)
+                 elif edges[i][j] == 6:
+                     # assert edges[j][i] == 5
+                     mol.RemoveBond(i, j)
+                     mol.AddBond(i, j, Chem.BondType.SINGLE)
+                     mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINDASH)
+         Chem.AssignChiralTypesFromBondDirs(mol)
+         Chem.AssignStereochemistry(mol, force=True)
+
+         # reset chiral tags for non-carbon atoms
+         for atom in mol.GetAtoms():
+             if atom.GetSymbol() != "C":
+                 atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
+         mol = mol.GetMol()
+
+     except Exception as e:
+         if debug:
+             raise e
+         pass
+     return mol
+
+
+ def _parse_tokens(tokens: list):
+     """
+     Parse tokens of condensed formula into list of pairs `(elt, num)`
+     where `num` is the multiplicity of the atom (or nested condensed formula) `elt`
+     Used by `_parse_formula`, which does the same thing but takes a formula in string form as input
+     """
+     elements = []
+     i = 0
+     j = 0
+     while i < len(tokens):
+         if tokens[i] == '(':
+             while j < len(tokens) and tokens[j] != ')':
+                 j += 1
+             elt = _parse_tokens(tokens[i + 1:j])
+         else:
+             elt = tokens[i]
+         j += 1
+         if j < len(tokens) and tokens[j].isnumeric():
+             num = int(tokens[j])
+             j += 1
+         else:
+             num = 1
+         elements.append((elt, num))
+         i = j
+     return elements
+
+
+ def _parse_formula(formula: str):
+     """
+     Parse condensed formula into list of pairs `(elt, num)`
+     where `num` is the subscript to the atom (or nested condensed formula) `elt`
+     Example: "C2H4O" -> [('C', 2), ('H', 4), ('O', 1)]
+     """
+     tokens = FORMULA_REGEX.findall(formula)
+     # if ''.join(tokens) != formula:
+     #     tokens = FORMULA_REGEX_BACKUP.findall(formula)
+     return _parse_tokens(tokens)
+
+
+ def _expand_carbon(elements: list):
+     """
+     Given list of pairs `(elt, num)`, output single list of all atoms in order,
+     expanding carbon sequences (CaXb where a > 1 and X is halogen) if necessary
+     Example: [('C', 2), ('H', 4), ('O', 1)] -> ['C', 'H', 'H', 'C', 'H', 'H', 'O']
+     """
+     expanded = []
+     i = 0
+     while i < len(elements):
+         elt, num = elements[i]
+         # expand carbon sequence
+         if elt == 'C' and num > 1 and i + 1 < len(elements):
+             next_elt, next_num = elements[i + 1]
+             quotient, remainder = next_num // num, next_num % num
+             for _ in range(num):
+                 expanded.append('C')
+                 for _ in range(quotient):
+                     expanded.append(next_elt)
+             for _ in range(remainder):
+                 expanded.append(next_elt)
+             i += 2
+         # recurse if `elt` itself is a list (nested formula)
+         elif isinstance(elt, list):
+             new_elt = _expand_carbon(elt)
+             for _ in range(num):
+                 expanded.append(new_elt)
+             i += 1
+         # simplest case: simply append `elt` `num` times
+         else:
+             for _ in range(num):
+                 expanded.append(elt)
+             i += 1
+     return expanded
+
+
+ def _expand_abbreviation(abbrev):
+     """
+     Expand abbreviation into its SMILES; also converts [Rn] to [n*]
+     Used in `_condensed_formula_list_to_smiles` when encountering abbrev. in condensed formula
+     """
+     if abbrev in ABBREVIATIONS:
+         return ABBREVIATIONS[abbrev].smiles
+     if abbrev in RGROUP_SYMBOLS or (abbrev[0] == 'R' and abbrev[1:].isdigit()):
+         if abbrev[1:].isdigit():
+             return f'[{abbrev[1:]}*]'
+         return '*'
+     return f'[{abbrev}]'
+
+
+ def _get_bond_symb(bond_num):
+     """
+     Get SMILES symbol for a bond given bond order
+     Used in `_condensed_formula_list_to_smiles` while writing the SMILES string
+     """
+     if bond_num == 0:
+         return '.'
+     if bond_num == 1:
+         return ''
+     if bond_num == 2:
+         return '='
+     if bond_num == 3:
+         return '#'
+     return ''
+
+
+ def _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond=None, direction=None):
+     """
+     Converts condensed formula (in the form of a list of symbols) to smiles
+     Input:
+         `formula_list`: e.g. ['C', 'H', 'H', 'N', ['C', 'H', 'H', 'H'], ['C', 'H', 'H', 'H']] for CH2N(CH3)2
+         `start_bond`: # bonds attached to beginning of formula
+         `end_bond`: # bonds attached to end of formula (deduce automatically if None)
+         `direction` (1, -1, or None): direction in which to process the list (1: left to right; -1: right to left; None: deduce automatically)
+     Returns:
+         `smiles`: smiles corresponding to input condensed formula
+         `bonds_left`: bonds remaining at the end of the formula (for connecting back to main molecule); should equal `end_bond` if specified
+         `num_trials`: number of trials
+         `success` (bool): whether conversion was successful
+     """
+     # `direction` not specified: try left to right; if it fails, try right to left
+     if direction is None:
+         num_trials = 1
+         for dir_choice in [1, -1]:
+             smiles, bonds_left, trials, success = _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond, dir_choice)
+             num_trials += trials
+             if success:
+                 return smiles, bonds_left, num_trials, success
+         return None, None, num_trials, False
+     assert direction == 1 or direction == -1
+
+     def dfs(smiles, bonds_left, cur_idx, add_idx):
+         """
+         `smiles`: SMILES string so far
+         `cur_idx`: index (in list `formula`) of current atom (i.e. atom to which subsequent atoms are being attached)
+         `bonds_left`: bonds remaining on current atom for subsequent atoms to be attached to
+         `add_idx`: index (in list `formula`) of atom to be attached to current atom
+         Note: "atom" could refer to nested condensed formula (e.g. CH3 in CH2N(CH3)2)
+         """
+         num_trials = 1
+         # end of formula: return result
+         if (direction == 1 and add_idx == len(formula_list)) or (direction == -1 and add_idx == -1):
+             if end_bond is not None and end_bond != bonds_left:
+                 return smiles, bonds_left, num_trials, False
+             return smiles, bonds_left, num_trials, True
+
+         # no more bonds but there are atoms remaining: conversion failed
+         if bonds_left <= 0:
+             return smiles, bonds_left, num_trials, False
+         to_add = formula_list[add_idx]  # atom to be added to current atom
+
+         if isinstance(to_add, list):  # "atom" added is a list (i.e. nested condensed formula): assume valence of 1
+             if bonds_left > 1:
+                 # "atom" added does not use up remaining bonds of current atom
+                 # get smiles of "atom" (which is itself a condensed formula)
+                 add_str, val, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
+                 if val > 0:
+                     add_str = _get_bond_symb(val + 1) + add_str
+                 num_trials += trials
+                 if not success:
+                     return smiles, bonds_left, num_trials, False
+                 # put smiles of "atom" in parentheses and append to smiles; go to next atom to add to current atom
+                 result = dfs(smiles + f'({add_str})', bonds_left - 1, cur_idx, add_idx + direction)
+             else:
+                 # "atom" added uses up remaining bonds of current atom
+                 # get smiles of "atom" and bonds left on it
+                 add_str, bonds_left, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
+                 num_trials += trials
+                 if not success:
+                     return smiles, bonds_left, num_trials, False
+                 # append smiles of "atom" (without parentheses) to smiles; it becomes the new current atom
+                 result = dfs(smiles + add_str, bonds_left, add_idx, add_idx + direction)
+             smiles, bonds_left, trials, success = result
+             num_trials += trials
+             return smiles, bonds_left, num_trials, success
+
+         # atom added is a single symbol (as opposed to nested condensed formula)
+         for val in VALENCES.get(to_add, [1]):  # try all possible valences of atom added
+             add_str = _expand_abbreviation(to_add)  # expand to smiles if symbol is abbreviation
+             if bonds_left > val:  # atom added does not use up remaining bonds of current atom; go to next atom to add to current atom
+                 if cur_idx >= 0:
+                     add_str = _get_bond_symb(val) + add_str
+                 result = dfs(smiles + f'({add_str})', bonds_left - val, cur_idx, add_idx + direction)
+             else:  # atom added uses up remaining bonds of current atom; it becomes the new current atom
+                 if cur_idx >= 0:
+                     add_str = _get_bond_symb(bonds_left) + add_str
+                 result = dfs(smiles + add_str, val - bonds_left, add_idx, add_idx + direction)
+             trials, success = result[2:]
+             num_trials += trials
+             if success:
+                 return result[0], result[1], num_trials, success
+             if num_trials > 10000:
+                 break
+         return smiles, bonds_left, num_trials, False
+
+     cur_idx = -1 if direction == 1 else len(formula_list)
+     add_idx = 0 if direction == 1 else len(formula_list) - 1
+     return dfs('', start_bond, cur_idx, add_idx)
+
+
+ def get_smiles_from_symbol(symbol, mol, atom, bonds):
+     """
+     Convert symbol (abbrev. or condensed formula) to smiles
+     If condensed formula, determine parsing direction and num. bonds on each side using coordinates
+     """
+     print(symbol)
+     if symbol in ABBREVIATIONS:
+         return ABBREVIATIONS[symbol].smiles
+     if len(symbol) > 20:
+         return None
+
+     # mol_check = Chem.MolFromSmiles(symbol)
+     # if mol_check:
+     #     print(symbol)  # print the symbol to debug
+     #     return symbol
+
+     total_bonds = int(sum([bond.GetBondTypeAsDouble() for bond in bonds]))
+     formula_list = _expand_carbon(_parse_formula(symbol))
+     smiles, bonds_left, num_trials, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None)
+     if success:
+         mol_check = Chem.MolFromSmiles(smiles)  # check if the SMILES is valid
+         if mol_check:
+             print(f"smiles:{smiles}")  # print to debug
+             return smiles
+
+     mol_check = Chem.MolFromSmiles(symbol)
+     if mol_check:
+         print(f"symbol:{symbol}")  # print to debug
+         return symbol
+
+     return None
+
+
+ def _replace_functional_group(smiles):
+     smiles = smiles.replace('<unk>', 'C')
+     for i, r in enumerate(RGROUP_SYMBOLS):
+         symbol = f'[{r}]'
+         if symbol in smiles:
+             if r[0] == 'R' and r[1:].isdigit():
+                 smiles = smiles.replace(symbol, f'[{int(r[1:])}*]')
+             else:
+                 smiles = smiles.replace(symbol, '*')
+     # For unknown tokens (i.e. rdkit cannot parse), replace them with [{isotope}*], where isotope is an identifier.
+     tokens = atomwise_tokenizer(smiles)
+     new_tokens = []
+     mappings = {}  # isotope : symbol
+     isotope = 50
+     for token in tokens:
+         if token[0] == '[':
+             if token[1:-1] in ABBREVIATIONS or Chem.AtomFromSmiles(token) is None:
+                 while f'[{isotope}*]' in smiles or f'[{isotope}*]' in new_tokens:
+                     isotope += 1
+                 placeholder = f'[{isotope}*]'
+                 mappings[isotope] = token[1:-1]
+                 new_tokens.append(placeholder)
+                 continue
+         new_tokens.append(token)
+     smiles = ''.join(new_tokens)
+     return smiles, mappings
+
+
+ def convert_smiles_to_mol(smiles):
+     if smiles is None or smiles == '':
+         return None
+     try:
+         mol = Chem.MolFromSmiles(smiles)
+     except:
+         return None
+     return mol
+
+
+ BOND_TYPES = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE, 3: Chem.rdchem.BondType.TRIPLE}
+
+
+ def _expand_functional_group(mol, mappings, debug=False):
+     def _need_expand(mol, mappings):
+         return any([len(Chem.GetAtomAlias(atom)) > 0 for atom in mol.GetAtoms()]) or len(mappings) > 0
+
+     if _need_expand(mol, mappings):
+         mol_w = Chem.RWMol(mol)
+         num_atoms = mol_w.GetNumAtoms()
+         for i, atom in enumerate(mol_w.GetAtoms()):  # reset radical electrons
+             atom.SetNumRadicalElectrons(0)
+
+         atoms_to_remove = []
+         for i in range(num_atoms):
+             atom = mol_w.GetAtomWithIdx(i)
+             if atom.GetSymbol() == '*':
+                 symbol = Chem.GetAtomAlias(atom)
+                 isotope = atom.GetIsotope()
+                 if isotope > 0 and isotope in mappings:
+                     symbol = mappings[isotope]
+                 if not (isinstance(symbol, str) and len(symbol) > 0):
+                     continue
+                 # rgroups do not need to be expanded
+                 if symbol in RGROUP_SYMBOLS:
+                     continue
+
+                 bonds = atom.GetBonds()
+                 sub_smiles = get_smiles_from_symbol(symbol, mol_w, atom, bonds)
+
+                 # create mol object for abbreviation/condensed formula from its SMILES
+                 mol_r = convert_smiles_to_mol(sub_smiles)
+
+                 if mol_r is None:
+                     # atom.SetAtomicNum(6)
+                     atom.SetIsotope(0)
+                     continue
+
+                 # remove bonds connected to the abbreviation/condensed formula
+                 adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds]
+                 for adjacent_idx in adjacent_indices:
+                     mol_w.RemoveBond(i, adjacent_idx)
+
+                 adjacent_atoms = [mol_w.GetAtomWithIdx(adjacent_idx) for adjacent_idx in adjacent_indices]
+                 for adjacent_atom, bond in zip(adjacent_atoms, bonds):
+                     adjacent_atom.SetNumRadicalElectrons(int(bond.GetBondTypeAsDouble()))
+
+                 # get indices of atoms of main body that connect to substituent
+                 bonding_atoms_w = adjacent_indices
+                 # assume indices are concatenated after combining mol_w and mol_r
+                 bonding_atoms_r = [mol_w.GetNumAtoms()]
+                 for atm in mol_r.GetAtoms():
+                     if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0:
+                         bonding_atoms_r.append(mol_w.GetNumAtoms() + atm.GetIdx())
+
+                 # combine main body and substituent into a single molecule object
+                 combo = Chem.CombineMols(mol_w, mol_r)
+
+                 # connect substituent to main body with bonds
+                 mol_w = Chem.RWMol(combo)
+                 # if len(bonding_atoms_r) == 1:  # substituent uses one atom to bond to main body
+                 for atm in bonding_atoms_w:
+                     bond_order = mol_w.GetAtomWithIdx(atm).GetNumRadicalElectrons()
+                     mol_w.AddBond(atm, bonding_atoms_r[0], order=BOND_TYPES[bond_order])
+
+                 # reset radical electrons
+                 for atm in bonding_atoms_w:
+                     mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
+                 for atm in bonding_atoms_r:
+                     mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
+                 atoms_to_remove.append(i)
+
+         # Remove atoms at the end, otherwise the ids will change;
+         # reverse the order and remove atoms with larger id first
+         atoms_to_remove.sort(reverse=True)
+         for i in atoms_to_remove:
+             mol_w.RemoveAtom(i)
+         smiles = Chem.MolToSmiles(mol_w)
+         mol = mol_w.GetMol()
+     else:
+         smiles = Chem.MolToSmiles(mol)
+     return smiles, mol
+
+
+ def _convert_graph_to_smiles(coords, symbols, edges, image=None, debug=False):
+     mol = Chem.RWMol()
+     n = len(symbols)
+     ids = []
+     for i in range(n):
+         symbol = symbols[i]
+         if symbol[0] == '[':
+             symbol = symbol[1:-1]
+         if symbol in RGROUP_SYMBOLS:
+             atom = Chem.Atom("*")
+             if symbol[0] == 'R' and symbol[1:].isdigit():
+                 atom.SetIsotope(int(symbol[1:]))
+             Chem.SetAtomAlias(atom, symbol)
+         elif symbol in ABBREVIATIONS:
+             atom = Chem.Atom("*")
+             Chem.SetAtomAlias(atom, symbol)
+         else:
+             try:  # try to get SMILES of atom
+                 atom = Chem.AtomFromSmiles(symbols[i])
+                 atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
+             except:  # otherwise, abbreviation or condensed formula
+                 atom = Chem.Atom("*")
+                 Chem.SetAtomAlias(atom, symbol)
+
+         if atom.GetSymbol() == '*':
+             atom.SetProp('molFileAlias', symbol)
+
+         idx = mol.AddAtom(atom)
+         assert idx == i
+         ids.append(idx)
+
+     for i in range(n):
+         for j in range(i + 1, n):
+             if edges[i][j] == 1:
+                 mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
+             elif edges[i][j] == 2:
+                 mol.AddBond(ids[i], ids[j], Chem.BondType.DOUBLE)
+             elif edges[i][j] == 3:
+                 mol.AddBond(ids[i], ids[j], Chem.BondType.TRIPLE)
+             elif edges[i][j] == 4:
+                 mol.AddBond(ids[i], ids[j], Chem.BondType.AROMATIC)
+             elif edges[i][j] == 5:
+                 mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
+                 mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINWEDGE)
+             elif edges[i][j] == 6:
+                 mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
+                 mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINDASH)
+
+     pred_smiles = '<invalid>'
+
+     try:
+         # TODO: move to a util function
+         if image is not None:
+             height, width, _ = image.shape
+             ratio = width / height
+             coords = [[x * ratio * 10, y * 10] for x, y in coords]
+         mol = _verify_chirality(mol, coords, symbols, edges, debug)
+         # molblock is obtained before expanding func groups, otherwise the expanded group won't have coordinates.
+         # TODO: make sure molblock has the abbreviation information
+         pred_molblock = Chem.MolToMolBlock(mol)
+         pred_smiles, mol = _expand_functional_group(mol, {}, debug)
+         success = True
+     except Exception as e:
+         if debug:
+             print(traceback.format_exc())
+         pred_molblock = ''
+         success = False
+
+     if debug:
+         return pred_smiles, pred_molblock, mol, success
+     return pred_smiles, pred_molblock, success
+
+
+ def convert_graph_to_smiles(coords, symbols, edges, images=None, num_workers=16):
+     with multiprocessing.Pool(num_workers) as p:
+         if images is None:
+             results = p.starmap(_convert_graph_to_smiles, zip(coords, symbols, edges), chunksize=128)
+         else:
+             results = p.starmap(_convert_graph_to_smiles, zip(coords, symbols, edges, images), chunksize=128)
+     smiles_list, molblock_list, success = zip(*results)
+     r_success = np.mean(success)
+     return smiles_list, molblock_list, r_success
+
+
+ def _postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, debug=False):
+     if type(smiles) is not str or smiles == '':
+         return '', '', False  # keep arity consistent with the normal return below
+     mol = None
+     pred_molblock = ''
+     try:
+         pred_smiles = smiles
+         pred_smiles, mappings = _replace_functional_group(pred_smiles)
+         if coords is not None and symbols is not None and edges is not None:
+             pred_smiles = pred_smiles.replace('@', '').replace('/', '').replace('\\', '')
+             mol = Chem.RWMol(Chem.MolFromSmiles(pred_smiles, sanitize=False))
+             mol = _verify_chirality(mol, coords, symbols, edges, debug)
+         else:
+             mol = Chem.MolFromSmiles(pred_smiles, sanitize=False)
+         # pred_smiles = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
+         if molblock:
+             pred_molblock = Chem.MolToMolBlock(mol)
+         pred_smiles, mol = _expand_functional_group(mol, mappings)
+         success = True
+     except Exception as e:
+         if debug:
+             print(traceback.format_exc())
+         pred_smiles = smiles
+         pred_molblock = ''
+         success = False
+     if debug:
+         return pred_smiles, pred_molblock, mol, success
+     return pred_smiles, pred_molblock, success
+
+
+ def postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, num_workers=16):
+     with multiprocessing.Pool(num_workers) as p:
+         if coords is not None and symbols is not None and edges is not None:
+             results = p.starmap(_postprocess_smiles, zip(smiles, coords, symbols, edges), chunksize=128)
+         else:
+             results = p.map(_postprocess_smiles, smiles, chunksize=128)
+     smiles_list, molblock_list, success = zip(*results)
+     r_success = np.mean(success)
+     return smiles_list, molblock_list, r_success
+
+
+ def _keep_main_molecule(smiles, debug=False):
+     try:
+         mol = Chem.MolFromSmiles(smiles)
+         frags = Chem.GetMolFrags(mol, asMols=True)
+         if len(frags) > 1:
+             num_atoms = [m.GetNumAtoms() for m in frags]
+             main_mol = frags[np.argmax(num_atoms)]
+             smiles = Chem.MolToSmiles(main_mol)
+     except Exception as e:
+         if debug:
+             print(traceback.format_exc())
+     return smiles
+
+
+ def keep_main_molecule(smiles, num_workers=16):
+     with multiprocessing.Pool(num_workers) as p:
+         results = p.map(_keep_main_molecule, smiles, chunksize=128)
+     return results
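A worked example of the condensed-formula helpers, using the values given in the docstrings above:

    from molscribe.chemistry import _parse_formula, _expand_carbon

    elements = _parse_formula("C2H4O")   # [('C', 2), ('H', 4), ('O', 1)]
    atoms = _expand_carbon(elements)     # ['C', 'H', 'H', 'C', 'H', 'H', 'O']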
molscribe/constants.py ADDED
@@ -0,0 +1,130 @@
+ from typing import List
+ import re
+
+ ORGANIC_SET = {'B', 'C', 'N', 'O', 'P', 'S', 'F', 'Cl', 'Br', 'I'}
+
+ RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12', "R'",
+                   'Ra', 'Rb', 'Rc', 'Rd', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar']
+
+ PLACEHOLDER_ATOMS = ["Lv", "Lu", "Nd", "Yb", "At", "Fm", "Er"]
+
+
+ class Substitution(object):
+     '''Define common substitutions for chemical shorthand'''
+     def __init__(self, abbrvs, smarts, smiles, probability):
+         assert type(abbrvs) is list
+         self.abbrvs = abbrvs
+         self.smarts = smarts
+         self.smiles = smiles
+         self.probability = probability
+
+
+ SUBSTITUTIONS: List[Substitution] = [
+     Substitution(['NO2', 'O2N'], '[N+](=O)[O-]', "[N+](=O)[O-]", 0.5),
+     Substitution(['OCOCH3'], '[#8]-[#6](=[#8])-[#6]', "[O]C(=O)C", 0.5),
+     Substitution(['CHO', 'OHC'], '[CH1](=O)', "[CH1](=O)", 0.5),
+     Substitution(['CO2Et', 'COOEt', 'EtO2C'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5),
+
+     Substitution(['OAc'], '[OH0;X2]C(=O)[CH3]', "[O]C(=O)C", 0.7),
+     Substitution(['NHAc'], '[NH1;D2]C(=O)[CH3]', "[NH]C(=O)C", 0.7),
+     Substitution(['Ac'], 'C(=O)[CH3]', "[C](=O)C", 0.1),
+
+     Substitution(['OBz'], '[OH0;D2]C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[O]C(=O)c1ccccc1", 0.7),  # Benzoyl
+     Substitution(['Bz'], 'C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)c1ccccc1", 0.2),  # Benzoyl
+
+     Substitution(['OBn'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[O]Cc1ccccc1", 0.7),  # Benzyl
+     Substitution(['Bn'], '[CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[CH2]c1ccccc1", 0.2),  # Benzyl
+
+     Substitution(['NHBoc'], '[NH1;D2]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6),
+     Substitution(['NBoc'], '[NH0;D3]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6),
+     Substitution(['Boc'], 'C(=O)OC([CH3])([CH3])[CH3]', "[C](=O)OC(C)(C)C", 0.2),
+
+     Substitution(['Cbm'], 'C(=O)[NH2;D1]', "[C](=O)N", 0.2),
+     Substitution(['Cbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[C](=O)OCc1ccccc1", 0.4),
+     Substitution(['Cy'], '[CH1;X3]1[CH2][CH2][CH2][CH2][CH2]1', "[CH1]1CCCCC1", 0.3),
+     Substitution(['Fmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3',
+                  "[C](=O)OCC1c(cccc2)c2c3c1cccc3", 0.6),
+     Substitution(['Mes'], '[cH0]1c([CH3])cc([CH3])cc([CH3])1', "[c]1c(C)cc(C)cc(C)1", 0.5),
+     Substitution(['OMs'], '[OH0;D2]S(=O)(=O)[CH3]', "[O]S(=O)(=O)C", 0.7),
+     Substitution(['Ms'], 'S(=O)(=O)[CH3]', "[S](=O)(=O)C", 0.2),
+     Substitution(['Ph'], '[cH0]1[cH][cH][cH1][cH][cH]1', "[c]1ccccc1", 0.5),
+     Substitution(['PMB'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[CH2]c1ccc(OC)cc1", 0.2),
+     Substitution(['Py'], '[cH0]1[n;+0][cH1][cH1][cH1][cH1]1', "[c]1ncccc1", 0.1),
+     Substitution(['SEM'], '[CH2;D2][CH2][Si]([CH3])([CH3])[CH3]', "[CH2]CSi(C)(C)C", 0.2),
+     Substitution(['Suc'], 'C(=O)[CH2][CH2]C(=O)[OH]', "[C](=O)CCC(=O)O", 0.2),
+     Substitution(['TBS'], '[Si]([CH3])([CH3])C([CH3])([CH3])[CH3]', "[Si](C)(C)C(C)(C)C", 0.5),
+     Substitution(['TBZ'], 'C(=S)[cH]1[cH][cH][cH1][cH][cH]1', "[C](=S)c1ccccc1", 0.2),
+     Substitution(['OTf'], '[OH0;D2]S(=O)(=O)C(F)(F)F', "[O]S(=O)(=O)C(F)(F)F", 0.7),
+     Substitution(['Tf'], 'S(=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.2),
+     Substitution(['TFA'], 'C(=O)C(F)(F)F', "[C](=O)C(F)(F)F", 0.3),
+     Substitution(['TMS'], '[Si]([CH3])([CH3])[CH3]', "[Si](C)(C)C", 0.5),
+     Substitution(['Ts'], 'S(=O)(=O)c1[cH1][cH1][cH0]([CH3])[cH1][cH1]1', "[S](=O)(=O)c1ccc(C)cc1", 0.6),  # Tos
+
+     # Alkyl chains
+     Substitution(['OMe', 'MeO'], '[OH0;D2][CH3;D1]', "[O]C", 0.3),
+     Substitution(['SMe', 'MeS'], '[SH0;D2][CH3;D1]', "[S]C", 0.3),
+     Substitution(['NMe', 'MeN'], '[N;X3][CH3;D1]', "[NH]C", 0.3),
+     Substitution(['Me'], '[CH3;D1]', "[CH3]", 0.1),
+     Substitution(['OEt', 'EtO'], '[OH0;D2][CH2;D2][CH3]', "[O]CC", 0.5),
+     Substitution(['Et', 'C2H5'], '[CH2;D2][CH3]', "[CH2]C", 0.3),
+     Substitution(['Pr', 'nPr', 'n-Pr'], '[CH2;D2][CH2;D2][CH3]', "[CH2]CC", 0.3),
+     Substitution(['Bu', 'nBu', 'n-Bu'], '[CH2;D2][CH2;D2][CH2;D2][CH3]', "[CH2]CCC", 0.3),
+
+     # Branched
+     Substitution(['iPr', 'i-Pr'], '[CH1;D3]([CH3])[CH3]', "[CH1](C)C", 0.2),
+     Substitution(['iBu', 'i-Bu'], '[CH2;D2][CH1;D3]([CH3])[CH3]', "[CH2]C(C)C", 0.2),
+     Substitution(['OiBu'], '[OH0;D2][CH2;D2][CH1;D3]([CH3])[CH3]', "[O]CC(C)C", 0.2),
+     Substitution(['OtBu'], '[OH0;D2][CH0]([CH3])([CH3])[CH3]', "[O]C(C)(C)C", 0.6),
+     Substitution(['tBu', 't-Bu'], '[CH0]([CH3])([CH3])[CH3]', "[C](C)(C)C", 0.3),
+
+     # Other shorthands (MIGHT NOT WANT ALL OF THESE)
+     Substitution(['CF3', 'F3C'], '[CH0;D4](F)(F)F', "[C](F)(F)F", 0.5),
+     Substitution(['NCF3', 'F3CN'], '[N;X3][CH0;D4](F)(F)F', "[NH]C(F)(F)F", 0.5),
+     Substitution(['OCF3', 'F3CO'], '[OH0;X2][CH0;D4](F)(F)F', "[O]C(F)(F)F", 0.5),
+     Substitution(['CCl3'], '[CH0;D4](Cl)(Cl)Cl', "[C](Cl)(Cl)Cl", 0.5),
+     Substitution(['CO2H', 'HO2C', 'COOH'], 'C(=O)[OH]', "[C](=O)O", 0.5),  # COOH
+     Substitution(['CN', 'NC'], 'C#[ND1]', "[C]#N", 0.5),
+     Substitution(['OCH3', 'H3CO'], '[OH0;D2][CH3]', "[O]C", 0.4),
+     Substitution(['SO3H'], 'S(=O)(=O)[OH]', "[S](=O)(=O)O", 0.4),
+     Substitution(['CH3O'], '[OH0;D2][CH3]', "[O]C", 0),
+     Substitution(['PhCH2CH2'], '[OH0;D2][CH3]', "C1=CC=CC=C1CC", 0),
+     Substitution(['SO2ToI', 'SO2Tol'], '[OH0;D2][CH3]', "CS(=O)(=O)C1=CC=CC=C1", 0),
+ ]
+
+ ABBREVIATIONS = {abbrv: sub for sub in SUBSTITUTIONS for abbrv in sub.abbrvs}
+
+ VALENCES = {
+     "H": [1], "Li": [1], "Be": [2], "B": [3], "C": [4], "N": [3, 5], "O": [2], "F": [1],
+     "Na": [1], "Mg": [2], "Al": [3], "Si": [4], "P": [5, 3], "S": [6, 2, 4], "Cl": [1], "K": [1], "Ca": [2],
+     "Br": [1], "I": [1]
+ }
+
+ ELEMENTS = [
+     "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
+     "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
+     "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
+     "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
+     "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
+     "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
+     "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
+     "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
+     "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
+     "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
+     "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
+     "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
+ ]
+
+ COLORS = {
+     u'c': '0.0,0.75,0.75', u'b': '0.0,0.0,1.0', u'g': '0.0,0.5,0.0', u'y': '0.75,0.75,0',
+     u'k': '0.0,0.0,0.0', u'r': '1.0,0.0,0.0', u'm': '0.75,0,0.75'
+ }
+
+ # tokens of condensed formula
+ FORMULA_REGEX = re.compile(
+     r'(' + '|'.join(list(ABBREVIATIONS.keys())) + r'|R[0-9]*|[A-Z][a-z]+|[A-Z]|[0-9]+|\(|\))')
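A short sketch of how these tables behave (expected outputs reasoned from the definitions above, not from running the commit):

    from molscribe.constants import ABBREVIATIONS, FORMULA_REGEX

    ABBREVIATIONS['Ph'].smiles       # "[c]1ccccc1"
    FORMULA_REGEX.findall('C2H4O')   # ['C', '2', 'H', '4', 'O']
    FORMULA_REGEX.findall('CO2Et')   # ['CO2Et'] -- abbreviations match as single tokens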
molscribe/dataset.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import cv2
import time
import random
import re
import string
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import albumentations as A
from albumentations.pytorch import ToTensorV2

from .indigo import Indigo
from .indigo.renderer import IndigoRenderer

from .augment import SafeRotate, CropWhite, PadWhite, SaltAndPepperNoise
from .utils import FORMAT_INFO
from .tokenizer import PAD_ID
from .chemistry import get_num_atoms, normalize_nodes
from .constants import RGROUP_SYMBOLS, SUBSTITUTIONS, ELEMENTS, COLORS

cv2.setNumThreads(1)

INDIGO_HYDROGEN_PROB = 0.2
INDIGO_FUNCTIONAL_GROUP_PROB = 0.8
INDIGO_CONDENSED_PROB = 0.5
INDIGO_RGROUP_PROB = 0.5
INDIGO_COMMENT_PROB = 0.3
INDIGO_DEAROMATIZE_PROB = 0.8
INDIGO_COLOR_PROB = 0.2


def get_transforms(input_size, augment=True, rotate=True, debug=False):
    trans_list = []
    if augment and rotate:
        trans_list.append(SafeRotate(limit=90, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255)))
    trans_list.append(CropWhite(pad=5))
    if augment:
        trans_list += [
            # NormalizedGridDistortion(num_steps=10, distort_limit=0.3),
            A.CropAndPad(percent=[-0.01, 0.00], keep_size=False, p=0.5),
            PadWhite(pad_ratio=0.4, p=0.2),
            A.Downscale(scale_min=0.2, scale_max=0.5, interpolation=3),
            A.Blur(),
            A.GaussNoise(),
            SaltAndPepperNoise(num_dots=20, p=0.5)
        ]
    trans_list.append(A.Resize(input_size, input_size))
    if not debug:
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        trans_list += [
            A.ToGray(p=1),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ]
    return A.Compose(trans_list, keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))


def add_functional_group(indigo, mol, debug=False):
    if random.random() > INDIGO_FUNCTIONAL_GROUP_PROB:
        return mol
    # Delete a functional group and add a pseudo atom with its abbreviation
    substitutions = [sub for sub in SUBSTITUTIONS]
    random.shuffle(substitutions)
    for sub in substitutions:
        query = indigo.loadSmarts(sub.smarts)
        matcher = indigo.substructureMatcher(mol)
        matched_atoms_ids = set()
        for match in matcher.iterateMatches(query):
            if random.random() < sub.probability or debug:
                atoms = []
                atoms_ids = set()
                for item in query.iterateAtoms():
                    atom = match.mapAtom(item)
                    atoms.append(atom)
                    atoms_ids.add(atom.index())
                if len(matched_atoms_ids.intersection(atoms_ids)) > 0:
                    continue
                abbrv = random.choice(sub.abbrvs)
                superatom = mol.addAtom(abbrv)
                for atom in atoms:
                    for nei in atom.iterateNeighbors():
                        if nei.index() not in atoms_ids:
                            if nei.symbol() == 'H':
                                # indigo won't match explicit hydrogens, so remove them explicitly
                                atoms_ids.add(nei.index())
                            else:
                                superatom.addBond(nei, nei.bond().bondOrder())
                for id in atoms_ids:
                    mol.getAtom(id).remove()
                matched_atoms_ids = matched_atoms_ids.union(atoms_ids)
    return mol


def add_explicit_hydrogen(indigo, mol):
    atoms = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
            if hs > 0:
                atoms.append((atom, hs))
        except Exception:
            continue
    if len(atoms) > 0 and random.random() < INDIGO_HYDROGEN_PROB:
        atom, hs = random.choice(atoms)
        for i in range(hs):
            h = mol.addAtom('H')
            h.addBond(atom, 1)
    return mol


def add_rgroup(indigo, mol, smiles):
    atoms = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
            if hs > 0:
                atoms.append(atom)
        except Exception:
            continue
    if len(atoms) > 0 and '*' not in smiles:
        if random.random() < INDIGO_RGROUP_PROB:
            atom_idx = random.choice(range(len(atoms)))
            atom = atoms[atom_idx]
            atoms.pop(atom_idx)
            symbol = random.choice(RGROUP_SYMBOLS)
            r = mol.addAtom(symbol)
            r.addBond(atom, 1)
    return mol


def get_rand_symb():
    symb = random.choice(ELEMENTS)
    if random.random() < 0.1:
        symb += random.choice(string.ascii_lowercase)
    if random.random() < 0.1:
        symb += random.choice(string.ascii_uppercase)
    if random.random() < 0.1:
        symb = f'({gen_rand_condensed()})'
    return symb


def get_rand_num():
    if random.random() < 0.9:
        if random.random() < 0.8:
            return ''
        else:
            return str(random.randint(2, 9))
    else:
        return '1' + str(random.randint(2, 9))


def gen_rand_condensed():
    tokens = []
    for i in range(5):
        if i >= 1 and random.random() < 0.8:
            break
        tokens.append(get_rand_symb())
        tokens.append(get_rand_num())
    return ''.join(tokens)


def add_rand_condensed(indigo, mol):
    atoms = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
            if hs > 0:
                atoms.append(atom)
        except Exception:
            continue
    if len(atoms) > 0 and random.random() < INDIGO_CONDENSED_PROB:
        atom = random.choice(atoms)
        symbol = gen_rand_condensed()
        r = mol.addAtom(symbol)
        r.addBond(atom, 1)
    return mol


def generate_output_smiles(indigo, mol):
    # TODO: if using mol.canonicalSmiles(), explicit H will be removed
    smiles = mol.smiles()
    mol = indigo.loadMolecule(smiles)
    if '*' in smiles:
        part_a, part_b = smiles.split(' ', maxsplit=1)
        part_b = re.search(r'\$.*\$', part_b).group(0)[1:-1]
        symbols = [t for t in part_b.split(';') if len(t) > 0]
        output = ''
        cnt = 0
        for i, c in enumerate(part_a):
            if c != '*':
                output += c
            else:
                output += f'[{symbols[cnt]}]'
                cnt += 1
        return mol, output
    else:
        if ' ' in smiles:
            # special cases with extension
            smiles = smiles.split(' ')[0]
        return mol, smiles


def add_comment(indigo):
    if random.random() < INDIGO_COMMENT_PROB:
        indigo.setOption('render-comment', str(random.randint(1, 20)) + random.choice(string.ascii_letters))
        indigo.setOption('render-comment-font-size', random.randint(40, 60))
        indigo.setOption('render-comment-alignment', random.choice([0, 0.5, 1]))
        indigo.setOption('render-comment-position', random.choice(['top', 'bottom']))
        indigo.setOption('render-comment-offset', random.randint(2, 30))


def add_color(indigo, mol):
    if random.random() < INDIGO_COLOR_PROB:
        indigo.setOption('render-coloring', True)
    if random.random() < INDIGO_COLOR_PROB:
        indigo.setOption('render-base-color', random.choice(list(COLORS.values())))
    if random.random() < INDIGO_COLOR_PROB:
        if random.random() < 0.5:
            indigo.setOption('render-highlight-color-enabled', True)
            indigo.setOption('render-highlight-color', random.choice(list(COLORS.values())))
        if random.random() < 0.5:
            indigo.setOption('render-highlight-thickness-enabled', True)
        for atom in mol.iterateAtoms():
            if random.random() < 0.1:
                atom.highlight()
    return mol


def get_graph(mol, image, shuffle_nodes=False, pseudo_coords=False):
    mol.layout()
    coords, symbols = [], []
    index_map = {}
    atoms = [atom for atom in mol.iterateAtoms()]
    if shuffle_nodes:
        random.shuffle(atoms)
    for i, atom in enumerate(atoms):
        if pseudo_coords:
            x, y, z = atom.xyz()
        else:
            x, y = atom.coords()
        coords.append([x, y])
        symbols.append(atom.symbol())
        index_map[atom.index()] = i
    if pseudo_coords:
        coords = normalize_nodes(np.array(coords))
        h, w, _ = image.shape
        coords[:, 0] = coords[:, 0] * w
        coords[:, 1] = coords[:, 1] * h
    n = len(symbols)
    edges = np.zeros((n, n), dtype=int)
    for bond in mol.iterateBonds():
        s = index_map[bond.source().index()]
        t = index_map[bond.destination().index()]
        # 1/2/3/4 : single/double/triple/aromatic
        edges[s, t] = bond.bondOrder()
        edges[t, s] = bond.bondOrder()
        if bond.bondStereo() in [5, 6]:
            edges[s, t] = bond.bondStereo()
            edges[t, s] = 11 - bond.bondStereo()
    graph = {
        'coords': coords,
        'symbols': symbols,
        'edges': edges,
        'num_atoms': len(symbols)
    }
    return graph


def generate_indigo_image(smiles, mol_augment=True, default_option=False, shuffle_nodes=False, pseudo_coords=False,
                          include_condensed=True, debug=False):
    indigo = Indigo()
    renderer = IndigoRenderer(indigo)
    indigo.setOption('render-output-format', 'png')
    indigo.setOption('render-background-color', '1,1,1')
    indigo.setOption('render-stereo-style', 'none')
    indigo.setOption('render-label-mode', 'hetero')
    indigo.setOption('render-font-family', 'Arial')
    if not default_option:
        thickness = random.uniform(0.5, 2)  # limit the sum of the following two parameters to be smaller than 4
        indigo.setOption('render-relative-thickness', thickness)
        indigo.setOption('render-bond-line-width', random.uniform(1, 4 - thickness))
        if random.random() < 0.5:
            indigo.setOption('render-font-family', random.choice(['Arial', 'Times', 'Courier', 'Helvetica']))
        indigo.setOption('render-label-mode', random.choice(['hetero', 'terminal-hetero']))
        indigo.setOption('render-implicit-hydrogens-visible', random.choice([True, False]))
        if random.random() < 0.1:
            indigo.setOption('render-stereo-style', 'old')
        if random.random() < 0.2:
            indigo.setOption('render-atom-ids-visible', True)

    try:
        mol = indigo.loadMolecule(smiles)
        if mol_augment:
            if random.random() < INDIGO_DEAROMATIZE_PROB:
                mol.dearomatize()
            else:
                mol.aromatize()
            smiles = mol.canonicalSmiles()
            add_comment(indigo)
            mol = add_explicit_hydrogen(indigo, mol)
            mol = add_rgroup(indigo, mol, smiles)
            if include_condensed:
                mol = add_rand_condensed(indigo, mol)
            mol = add_functional_group(indigo, mol, debug)
            mol = add_color(indigo, mol)
            mol, smiles = generate_output_smiles(indigo, mol)

        buf = renderer.renderToBuffer(mol)
        img = cv2.imdecode(np.asarray(bytearray(buf), dtype=np.uint8), 1)  # decode buffer to image
        # img = np.repeat(np.expand_dims(img, 2), 3, axis=2)  # expand to RGB
        graph = get_graph(mol, img, shuffle_nodes, pseudo_coords)
        success = True
    except Exception:
        if debug:
            raise
        img = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
        graph = {}
        success = False
    return img, smiles, graph, success


class TrainDataset(Dataset):

    def __init__(self, args, df, tokenizer, split='train', dynamic_indigo=False):
        super().__init__()
        self.df = df
        self.args = args
        self.tokenizer = tokenizer
        if 'file_path' in df.columns:
            self.file_paths = df['file_path'].values
            if not self.file_paths[0].startswith(args.data_path):
                self.file_paths = [os.path.join(args.data_path, path) for path in df['file_path']]
        self.smiles = df['SMILES'].values if 'SMILES' in df.columns else None
        self.formats = args.formats
        self.labelled = (split == 'train')
        if self.labelled:
            self.labels = {}
            for format_ in self.formats:
                if format_ in ['atomtok', 'inchi']:
                    field = FORMAT_INFO[format_]['name']
                    if field in df.columns:
                        self.labels[format_] = df[field].values
        self.transform = get_transforms(args.input_size,
                                        augment=(self.labelled and args.augment))
        # self.fix_transform = A.Compose([A.Transpose(p=1), A.VerticalFlip(p=1)])
        self.dynamic_indigo = (dynamic_indigo and split == 'train')
        if self.labelled and not dynamic_indigo and args.coords_file is not None:
            if args.coords_file == 'aux_file':
                self.coords_df = df
                self.pseudo_coords = True
            else:
                self.coords_df = pd.read_csv(args.coords_file)
                self.pseudo_coords = False
        else:
            self.coords_df = None
            self.pseudo_coords = args.pseudo_coords

    def __len__(self):
        return len(self.df)

    def image_transform(self, image, coords=[], renormalize=False):
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # .astype(np.float32)
        augmented = self.transform(image=image, keypoints=coords)
        image = augmented['image']
        if len(coords) > 0:
            coords = np.array(augmented['keypoints'])
            if renormalize:
                coords = normalize_nodes(coords, flip_y=False)
            else:
                _, height, width = image.shape
                coords[:, 0] = coords[:, 0] / width
                coords[:, 1] = coords[:, 1] / height
            coords = np.array(coords).clip(0, 1)
            return image, coords
        return image

    def __getitem__(self, idx):
        try:
            return self.getitem(idx)
        except Exception as e:
            with open(os.path.join(self.args.save_path, f'error_dataset_{int(time.time())}.log'), 'w') as f:
                f.write(str(e))
            raise e

    def getitem(self, idx):
        ref = {}
        if self.dynamic_indigo:
            begin = time.time()
            image, smiles, graph, success = generate_indigo_image(
                self.smiles[idx], mol_augment=self.args.mol_augment, default_option=self.args.default_option,
                shuffle_nodes=self.args.shuffle_nodes, pseudo_coords=self.pseudo_coords,
                include_condensed=self.args.include_condensed)
            # raw_image = image
            end = time.time()
            if idx < 30 and self.args.save_image:
                path = os.path.join(self.args.save_path, 'images')
                os.makedirs(path, exist_ok=True)
                cv2.imwrite(os.path.join(path, f'{idx}.png'), image)
            if not success:
                return idx, None, {}
            image, coords = self.image_transform(image, graph['coords'], renormalize=self.pseudo_coords)
            graph['coords'] = coords
            ref['time'] = end - begin
            if 'atomtok' in self.formats:
                max_len = FORMAT_INFO['atomtok']['max_len']
                label = self.tokenizer['atomtok'].text_to_sequence(smiles, tokenized=False)
                ref['atomtok'] = torch.LongTensor(label[:max_len])
            if 'edges' in self.formats and 'atomtok_coords' not in self.formats and 'chartok_coords' not in self.formats:
                ref['edges'] = torch.tensor(graph['edges'])
            if 'atomtok_coords' in self.formats:
                self._process_atomtok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
                                             mask_ratio=self.args.mask_ratio)
            if 'chartok_coords' in self.formats:
                self._process_chartok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
                                             mask_ratio=self.args.mask_ratio)
            return idx, image, ref
        else:
            file_path = self.file_paths[idx]
            image = cv2.imread(file_path)
            if image is None:
                image = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
                print(file_path, 'not found!')
            if self.coords_df is not None:
                h, w, _ = image.shape
                coords = np.array(eval(self.coords_df.loc[idx, 'node_coords']))
                if self.pseudo_coords:
                    coords = normalize_nodes(coords)
                coords[:, 0] = coords[:, 0] * w
                coords[:, 1] = coords[:, 1] * h
                image, coords = self.image_transform(image, coords, renormalize=self.pseudo_coords)
            else:
                image = self.image_transform(image)
                coords = None
            if self.labelled:
                smiles = self.smiles[idx]
                if 'atomtok' in self.formats:
                    max_len = FORMAT_INFO['atomtok']['max_len']
                    label = self.tokenizer['atomtok'].text_to_sequence(smiles, False)
                    ref['atomtok'] = torch.LongTensor(label[:max_len])
                if 'atomtok_coords' in self.formats:
                    if coords is not None:
                        self._process_atomtok_coords(idx, ref, smiles, coords, mask_ratio=0)
                    else:
                        self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1)
                if 'chartok_coords' in self.formats:
                    if coords is not None:
                        self._process_chartok_coords(idx, ref, smiles, coords, mask_ratio=0)
                    else:
                        self._process_chartok_coords(idx, ref, smiles, mask_ratio=1)
            if self.args.predict_coords and ('atomtok_coords' in self.formats or 'chartok_coords' in self.formats):
                smiles = self.smiles[idx]
                if 'atomtok_coords' in self.formats:
                    self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1)
                if 'chartok_coords' in self.formats:
                    self._process_chartok_coords(idx, ref, smiles, mask_ratio=1)
            return idx, image, ref

    def _process_atomtok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
        max_len = FORMAT_INFO['atomtok_coords']['max_len']
        tokenizer = self.tokenizer['atomtok_coords']
        if smiles is None or type(smiles) is not str:
            smiles = ""
        label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio)
        ref['atomtok_coords'] = torch.LongTensor(label[:max_len])
        indices = [i for i in indices if i < max_len]
        ref['atom_indices'] = torch.LongTensor(indices)
        if tokenizer.continuous_coords:
            if coords is not None:
                ref['coords'] = torch.tensor(coords)
            else:
                ref['coords'] = torch.ones(len(indices), 2) * -1.
        if edges is not None:
            ref['edges'] = torch.tensor(edges)[:len(indices), :len(indices)]
        else:
            if 'edges' in self.df.columns:
                edge_list = eval(self.df.loc[idx, 'edges'])
                n = len(indices)
                edges = torch.zeros((n, n), dtype=torch.long)
                for u, v, t in edge_list:
                    if u < n and v < n:
                        if t <= 4:
                            edges[u, v] = t
                            edges[v, u] = t
                        else:
                            edges[u, v] = t
                            edges[v, u] = 11 - t
                ref['edges'] = edges
            else:
                ref['edges'] = torch.ones(len(indices), len(indices), dtype=torch.long) * (-100)

    def _process_chartok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
        max_len = FORMAT_INFO['chartok_coords']['max_len']
        tokenizer = self.tokenizer['chartok_coords']
        if smiles is None or type(smiles) is not str:
            smiles = ""
        label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio)
        ref['chartok_coords'] = torch.LongTensor(label[:max_len])
        indices = [i for i in indices if i < max_len]
        ref['atom_indices'] = torch.LongTensor(indices)
        if tokenizer.continuous_coords:
            if coords is not None:
                ref['coords'] = torch.tensor(coords)
            else:
                ref['coords'] = torch.ones(len(indices), 2) * -1.
        if edges is not None:
            ref['edges'] = torch.tensor(edges)[:len(indices), :len(indices)]
        else:
            if 'edges' in self.df.columns:
                edge_list = eval(self.df.loc[idx, 'edges'])
                n = len(indices)
                edges = torch.zeros((n, n), dtype=torch.long)
                for u, v, t in edge_list:
                    if u < n and v < n:
                        if t <= 4:
                            edges[u, v] = t
                            edges[v, u] = t
                        else:
                            edges[u, v] = t
                            edges[v, u] = 11 - t
                ref['edges'] = edges
            else:
                ref['edges'] = torch.ones(len(indices), len(indices), dtype=torch.long) * (-100)


class AuxTrainDataset(Dataset):

    def __init__(self, args, train_df, aux_df, tokenizer):
        super().__init__()
        self.train_dataset = TrainDataset(args, train_df, tokenizer, dynamic_indigo=args.dynamic_indigo)
        self.aux_dataset = TrainDataset(args, aux_df, tokenizer, dynamic_indigo=False)

    def __len__(self):
        return len(self.train_dataset) + len(self.aux_dataset)

    def __getitem__(self, idx):
        if idx < len(self.train_dataset):
            return self.train_dataset[idx]
        else:
            return self.aux_dataset[idx - len(self.train_dataset)]


def pad_images(imgs):
    # B, C, H, W
    max_shape = [0, 0]
    for img in imgs:
        for i in range(len(max_shape)):
            max_shape[i] = max(max_shape[i], img.shape[-1 - i])
    stack = []
    for img in imgs:
        pad = []
        for i in range(len(max_shape)):
            pad = pad + [0, max_shape[i] - img.shape[-1 - i]]
        stack.append(F.pad(img, pad, value=0))
    return torch.stack(stack)


def bms_collate(batch):
    ids = []
    imgs = []
    batch = [ex for ex in batch if ex[1] is not None]
    formats = list(batch[0][2].keys())
    seq_formats = [k for k in formats if
                   k in ['atomtok', 'inchi', 'nodes', 'atomtok_coords', 'chartok_coords', 'atom_indices']]
    refs = {key: [[], []] for key in seq_formats}
    for ex in batch:
        ids.append(ex[0])
        imgs.append(ex[1])
        ref = ex[2]
        for key in seq_formats:
            refs[key][0].append(ref[key])
            refs[key][1].append(torch.LongTensor([len(ref[key])]))
    # Sequence
    for key in seq_formats:
        # this padding should work for atomtok_with_coords too, each of which has shape (length, 4)
        refs[key][0] = pad_sequence(refs[key][0], batch_first=True, padding_value=PAD_ID)
        refs[key][1] = torch.stack(refs[key][1]).reshape(-1, 1)
    # Time
    # if 'time' in formats:
    #     refs['time'] = [ex[2]['time'] for ex in batch]
    # Coords
    if 'coords' in formats:
        refs['coords'] = pad_sequence([ex[2]['coords'] for ex in batch], batch_first=True, padding_value=-1.)
    # Edges
    if 'edges' in formats:
        edges_list = [ex[2]['edges'] for ex in batch]
        max_len = max([len(edges) for edges in edges_list])
        refs['edges'] = torch.stack(
            [F.pad(edges, (0, max_len - len(edges), 0, max_len - len(edges)), value=-100) for edges in edges_list],
            dim=0)
    return ids, pad_images(imgs), refs
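A hypothetical training-loader wiring for TrainDataset with bms_collate; `args`, `df` and `tokenizer` stand in for the real training config, data frame and tokenizer dict:

    loader = DataLoader(
        TrainDataset(args, df, tokenizer, split='train', dynamic_indigo=True),
        batch_size=32, shuffle=True, num_workers=4, collate_fn=bms_collate)
    for ids, images, refs in loader:
        # images: (B, C, H, W) after pad_images; refs holds padded token sequences,
        # coords padded with -1, and edge matrices padded with -100
        break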
molscribe/evaluate.py ADDED
@@ -0,0 +1,79 @@
import numpy as np
import multiprocessing

import rdkit
import rdkit.Chem as Chem
rdkit.RDLogger.DisableLog('rdApp.*')
from SmilesPE.pretokenizer import atomwise_tokenizer


def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True):
    if type(smiles) is not str or smiles == '':
        return '', False
    if ignore_cistrans:
        smiles = smiles.replace('/', '').replace('\\', '')
    if replace_rgroup:
        tokens = atomwise_tokenizer(smiles)
        for j, token in enumerate(tokens):
            if token[0] == '[' and token[-1] == ']':
                symbol = token[1:-1]
                if symbol[0] == 'R' and symbol[1:].isdigit():
                    tokens[j] = f'[{symbol[1:]}*]'
                elif Chem.AtomFromSmiles(token) is None:
                    tokens[j] = '*'
        smiles = ''.join(tokens)
    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral))
        success = True
    except Exception:
        canon_smiles = smiles
        success = False
    return canon_smiles, success


def convert_smiles_to_canonsmiles(
        smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16):
    with multiprocessing.Pool(num_workers) as p:
        results = p.starmap(canonicalize_smiles,
                            [(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list],
                            chunksize=128)
    canon_smiles, success = zip(*results)
    return list(canon_smiles), np.mean(success)


class SmilesEvaluator(object):

    def __init__(self, gold_smiles, num_workers=16):
        self.gold_smiles = gold_smiles
        self.gold_canon_smiles, self.gold_valid = convert_smiles_to_canonsmiles(gold_smiles, num_workers=num_workers)
        self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                   ignore_chiral=True, num_workers=num_workers)
        self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                     ignore_cistrans=True, num_workers=num_workers)
        self.gold_canon_smiles = self._replace_empty(self.gold_canon_smiles)
        self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral)
        self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans)

    def _replace_empty(self, smiles_list):
        """Replace empty SMILES in the gold; otherwise a prediction would be counted
        as correct whenever both the prediction and the gold are empty."""
        return [smiles if smiles is not None and type(smiles) is str and smiles != "" else "<empty>"
                for smiles in smiles_list]

    def evaluate(self, pred_smiles):
        results = {}
        results['gold_valid'] = self.gold_valid
        # Canonical SMILES
        pred_canon_smiles, pred_valid = convert_smiles_to_canonsmiles(pred_smiles)
        results['canon_smiles_em'] = (np.array(self.gold_canon_smiles) == np.array(pred_canon_smiles)).mean()
        results['pred_valid'] = pred_valid
        # Ignore chirality (graph exact match)
        pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_chiral=True)
        results['graph'] = (np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral)).mean()
        # Ignore double-bond cis/trans
        pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_cistrans=True)
        results['canon_smiles'] = (np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans)).mean()
        # Evaluate on molecules with chiral centers
        chiral = np.array([[g, p] for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans) if '@' in g])
        results['chiral_ratio'] = len(chiral) / len(self.gold_smiles)
        results['chiral'] = (chiral[:, 0] == chiral[:, 1]).mean() if len(chiral) > 0 else -1
        return results
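A small usage sketch of SmilesEvaluator with toy SMILES:

    evaluator = SmilesEvaluator(['CCO', 'C[C@H](N)C(=O)O'], num_workers=2)
    scores = evaluator.evaluate(['CCO', 'CC(N)C(=O)O'])
    # scores['canon_smiles'] ignores cis/trans, scores['graph'] also ignores chirality,
    # and scores['chiral'] is accuracy on the chiral subset (-1 if the gold has none)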
molscribe/indigo/__init__.py ADDED
The diff for this file is too large to render. See raw diff
 
molscribe/indigo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (96.8 kB).
 
molscribe/indigo/__pycache__/bingo.cpython-310.pyc ADDED
Binary file (11.4 kB).
 
molscribe/indigo/__pycache__/inchi.cpython-310.pyc ADDED
Binary file (2.73 kB).
 
molscribe/indigo/__pycache__/renderer.cpython-310.pyc ADDED
Binary file (2.8 kB).
 
molscribe/indigo/bingo.py ADDED
@@ -0,0 +1,334 @@
#
# Copyright (C) from 2009 to Present EPAM Systems.
#
# This file is part of Indigo toolkit.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from . import *


class BingoException(Exception):

    def __init__(self, value):
        self.value = value

    def __str__(self):
        if sys.version_info > (3, 0):
            return repr(self.value.decode('ascii'))
        else:
            return repr(self.value)


class Bingo(object):
    def __init__(self, bingoId, indigo, lib):
        self._id = bingoId
        self._indigo = indigo
        self._lib = lib
        self._lib.bingoVersion.restype = c_char_p
        self._lib.bingoVersion.argtypes = None
        self._lib.bingoCreateDatabaseFile.restype = c_int
        self._lib.bingoCreateDatabaseFile.argtypes = [c_char_p, c_char_p, c_char_p]
        self._lib.bingoLoadDatabaseFile.restype = c_int
        self._lib.bingoLoadDatabaseFile.argtypes = [c_char_p, c_char_p]
        self._lib.bingoCloseDatabase.restype = c_int
        self._lib.bingoCloseDatabase.argtypes = [c_int]
        self._lib.bingoInsertRecordObj.restype = c_int
        self._lib.bingoInsertRecordObj.argtypes = [c_int, c_int]
        self._lib.bingoInsertRecordObjWithExtFP.restype = c_int
        self._lib.bingoInsertRecordObjWithExtFP.argtypes = [c_int, c_int, c_int]
        self._lib.bingoGetRecordObj.restype = c_int
        self._lib.bingoGetRecordObj.argtypes = [c_int, c_int]
        self._lib.bingoInsertRecordObjWithId.restype = c_int
        self._lib.bingoInsertRecordObjWithId.argtypes = [c_int, c_int, c_int]
        self._lib.bingoInsertRecordObjWithIdAndExtFP.restype = c_int
        self._lib.bingoInsertRecordObjWithIdAndExtFP.argtypes = [c_int, c_int, c_int, c_int]
        self._lib.bingoDeleteRecord.restype = c_int
        self._lib.bingoDeleteRecord.argtypes = [c_int, c_int]
        self._lib.bingoSearchSub.restype = c_int
        self._lib.bingoSearchSub.argtypes = [c_int, c_int, c_char_p]
        self._lib.bingoSearchExact.restype = c_int
        self._lib.bingoSearchExact.argtypes = [c_int, c_int, c_char_p]
        self._lib.bingoSearchMolFormula.restype = c_int
        self._lib.bingoSearchMolFormula.argtypes = [c_int, c_char_p, c_char_p]
        self._lib.bingoSearchSim.restype = c_int
        self._lib.bingoSearchSim.argtypes = [c_int, c_int, c_float, c_float, c_char_p]
        self._lib.bingoSearchSimWithExtFP.restype = c_int
        self._lib.bingoSearchSimWithExtFP.argtypes = [c_int, c_int, c_float, c_float, c_int, c_char_p]
        self._lib.bingoSearchSimTopN.restype = c_int
        self._lib.bingoSearchSimTopN.argtypes = [c_int, c_int, c_int, c_float, c_char_p]
        self._lib.bingoSearchSimTopNWithExtFP.restype = c_int
        self._lib.bingoSearchSimTopNWithExtFP.argtypes = [c_int, c_int, c_int, c_float, c_int, c_char_p]
        self._lib.bingoEnumerateId.restype = c_int
        self._lib.bingoEnumerateId.argtypes = [c_int]
        self._lib.bingoNext.restype = c_int
        self._lib.bingoNext.argtypes = [c_int]
        self._lib.bingoGetCurrentId.restype = c_int
        self._lib.bingoGetCurrentId.argtypes = [c_int]
        self._lib.bingoGetObject.restype = c_int
        self._lib.bingoGetObject.argtypes = [c_int]
        self._lib.bingoEndSearch.restype = c_int
        self._lib.bingoEndSearch.argtypes = [c_int]
        self._lib.bingoGetCurrentSimilarityValue.restype = c_float
        self._lib.bingoGetCurrentSimilarityValue.argtypes = [c_int]
        self._lib.bingoOptimize.restype = c_int
        self._lib.bingoOptimize.argtypes = [c_int]
        self._lib.bingoEstimateRemainingResultsCount.restype = c_int
        self._lib.bingoEstimateRemainingResultsCount.argtypes = [c_int]
        self._lib.bingoEstimateRemainingResultsCountError.restype = c_int
        self._lib.bingoEstimateRemainingResultsCountError.argtypes = [c_int]
        self._lib.bingoEstimateRemainingTime.restype = c_int
        self._lib.bingoEstimateRemainingTime.argtypes = [c_int, POINTER(c_float)]
        self._lib.bingoContainersCount.restype = c_int
        self._lib.bingoContainersCount.argtypes = [c_int]
        self._lib.bingoCellsCount.restype = c_int
        self._lib.bingoCellsCount.argtypes = [c_int]
        self._lib.bingoCurrentCell.restype = c_int
        self._lib.bingoCurrentCell.argtypes = [c_int]
        self._lib.bingoMinCell.restype = c_int
        self._lib.bingoMinCell.argtypes = [c_int]
        self._lib.bingoMaxCell.restype = c_int
        self._lib.bingoMaxCell.argtypes = [c_int]

    def __del__(self):
        self.close()

    def close(self):
        self._indigo._setSessionId()
        if self._id >= 0:
            Bingo._checkResult(self._indigo, self._lib.bingoCloseDatabase(self._id))
            self._id = -1

    @staticmethod
    def _checkResult(indigo, result):
        if result < 0:
            raise BingoException(indigo._lib.indigoGetLastError())
        return result

    @staticmethod
    def _checkResultPtr(indigo, result):
        if result is None:
            raise BingoException(indigo._lib.indigoGetLastError())
        return result

    @staticmethod
    def _checkResultString(indigo, result):
        res = Bingo._checkResultPtr(indigo, result)
        if sys.version_info >= (3, 0):
            return res.decode('ascii')
        else:
            return res.encode('ascii')

    @staticmethod
    def _getLib(indigo):
        if os.name == 'posix' and not platform.mac_ver()[0] and not platform.system().startswith("CYGWIN"):
            _lib = CDLL(indigo.dllpath + "/libbingo.so")
        elif os.name == 'nt' or platform.system().startswith("CYGWIN"):
            _lib = CDLL(indigo.dllpath + "/bingo.dll")
        elif platform.mac_ver()[0]:
            _lib = CDLL(indigo.dllpath + "/libbingo.dylib")
        else:
            raise BingoException("unsupported OS: " + os.name)
        return _lib

    @staticmethod
    def createDatabaseFile(indigo, path, databaseType, options=''):
        indigo._setSessionId()
        if not options:
            options = ''
        lib = Bingo._getLib(indigo)
        lib.bingoCreateDatabaseFile.restype = c_int
        lib.bingoCreateDatabaseFile.argtypes = [c_char_p, c_char_p, c_char_p]
        return Bingo(Bingo._checkResult(indigo, lib.bingoCreateDatabaseFile(
            path.encode('ascii'), databaseType.encode('ascii'), options.encode('ascii'))), indigo, lib)

    @staticmethod
    def loadDatabaseFile(indigo, path, options=''):
        indigo._setSessionId()
        if not options:
            options = ''
        lib = Bingo._getLib(indigo)
        lib.bingoLoadDatabaseFile.restype = c_int
        lib.bingoLoadDatabaseFile.argtypes = [c_char_p, c_char_p]
        return Bingo(Bingo._checkResult(indigo, lib.bingoLoadDatabaseFile(
            path.encode('ascii'), options.encode('ascii'))), indigo, lib)

    def version(self):
        self._indigo._setSessionId()
        return Bingo._checkResultString(self._indigo, self._lib.bingoVersion())

    def insert(self, indigoObject, index=None):
        self._indigo._setSessionId()
        if not index:
            return Bingo._checkResult(self._indigo, self._lib.bingoInsertRecordObj(self._id, indigoObject.id))
        else:
            return Bingo._checkResult(self._indigo,
                                      self._lib.bingoInsertRecordObjWithId(self._id, indigoObject.id, index))

    def insertWithExtFP(self, indigoObject, ext_fp, index=None):
        self._indigo._setSessionId()
        if not index:
            return Bingo._checkResult(self._indigo,
                                      self._lib.bingoInsertRecordObjWithExtFP(self._id, indigoObject.id, ext_fp.id))
        else:
            return Bingo._checkResult(self._indigo,
                                      self._lib.bingoInsertRecordObjWithIdAndExtFP(
                                          self._id, indigoObject.id, index, ext_fp.id))

    def delete(self, index):
        self._indigo._setSessionId()
        Bingo._checkResult(self._indigo, self._lib.bingoDeleteRecord(self._id, index))

    def searchSub(self, query, options=''):
        self._indigo._setSessionId()
        if not options:
            options = ''
        return BingoObject(Bingo._checkResult(self._indigo,
                                              self._lib.bingoSearchSub(self._id, query.id, options.encode('ascii'))),
                           self._indigo, self)

    def searchExact(self, query, options=''):
        self._indigo._setSessionId()
        if not options:
            options = ''
        return BingoObject(Bingo._checkResult(self._indigo,
                                              self._lib.bingoSearchExact(self._id, query.id, options.encode('ascii'))),
                           self._indigo, self)

    def searchSim(self, query, minSim, maxSim, metric='tanimoto'):
        self._indigo._setSessionId()
        if not metric:
            metric = 'tanimoto'
        return BingoObject(
            Bingo._checkResult(self._indigo,
                               self._lib.bingoSearchSim(self._id, query.id, minSim, maxSim, metric.encode('ascii'))),
            self._indigo, self)

    def searchSimWithExtFP(self, query, minSim, maxSim, ext_fp, metric='tanimoto'):
        self._indigo._setSessionId()
        if not metric:
            metric = 'tanimoto'
        return BingoObject(
            Bingo._checkResult(self._indigo,
                               self._lib.bingoSearchSimWithExtFP(self._id, query.id, minSim, maxSim, ext_fp.id,
                                                                 metric.encode('ascii'))),
            self._indigo, self)

    def searchSimTopN(self, query, limit, minSim, metric='tanimoto'):
        self._indigo._setSessionId()
        if not metric:
            metric = 'tanimoto'
        return BingoObject(
            Bingo._checkResult(self._indigo,
                               self._lib.bingoSearchSimTopN(self._id, query.id, limit, minSim,
                                                            metric.encode('ascii'))),
            self._indigo, self)

    def searchSimTopNWithExtFP(self, query, limit, minSim, ext_fp, metric='tanimoto'):
        self._indigo._setSessionId()
        if not metric:
            metric = 'tanimoto'
        return BingoObject(
            Bingo._checkResult(self._indigo,
                               self._lib.bingoSearchSimTopNWithExtFP(self._id, query.id, limit, minSim, ext_fp.id,
                                                                     metric.encode('ascii'))),
            self._indigo, self)

    def enumerateId(self):
        self._indigo._setSessionId()
        e = self._lib.bingoEnumerateId(self._id)
        result = Bingo._checkResult(self._indigo, e)
        return BingoObject(result, self._indigo, self)

    def searchMolFormula(self, query, options=''):
        self._indigo._setSessionId()
        if not options:
            options = ''
        return BingoObject(Bingo._checkResult(self._indigo,
                                              self._lib.bingoSearchMolFormula(self._id, query.encode('ascii'),
                                                                              options.encode('ascii'))),
                           self._indigo, self)

    def optimize(self):
        self._indigo._setSessionId()
        Bingo._checkResult(self._indigo, self._lib.bingoOptimize(self._id))

    def getRecordById(self, id):
        self._indigo._setSessionId()
        return IndigoObject(self._indigo, Bingo._checkResult(self._indigo, self._lib.bingoGetRecordObj(self._id, id)))


class BingoObject(object):
    def __init__(self, objId, indigo, bingo):
        self._id = objId
        self._indigo = indigo
        self._bingo = bingo

    def __del__(self):
        self.close()

    def close(self):
        self._indigo._setSessionId()
        if self._id >= 0:
            Bingo._checkResult(self._indigo, self._bingo._lib.bingoEndSearch(self._id))
            self._id = -1

    def next(self):
        self._indigo._setSessionId()
        return (Bingo._checkResult(self._indigo, self._bingo._lib.bingoNext(self._id)) == 1)

    def getCurrentId(self):
        self._indigo._setSessionId()
        return Bingo._checkResult(self._indigo, self._bingo._lib.bingoGetCurrentId(self._id))

    def getIndigoObject(self):
        self._indigo._setSessionId()
        return IndigoObject(self._indigo, Bingo._checkResult(self._indigo, self._bingo._lib.bingoGetObject(self._id)))

    def getCurrentSimilarityValue(self):
        self._indigo._setSessionId()
        return Bingo._checkResult(self._indigo, self._bingo._lib.bingoGetCurrentSimilarityValue(self._id))

    def estimateRemainingResultsCount(self):
        self._indigo._setSessionId()
        return Bingo._checkResult(self._indigo, self._bingo._lib.bingoEstimateRemainingResultsCount(self._id))

    def estimateRemainingResultsCountError(self):
        self._indigo._setSessionId()
        return Bingo._checkResult(self._indigo, self._bingo._lib.bingoEstimateRemainingResultsCountError(self._id))

    def estimateRemainingTime(self):
        self._indigo._setSessionId()
        value = c_float()
        Bingo._checkResult(self._indigo, self._bingo._lib.bingoEstimateRemainingTime(self._id, pointer(value)))
        return value.value

    def containersCount(self):
        self._indigo._setSessionId()
        return Bingo._checkResult(self._indigo, self._bingo._lib.bingoContainersCount(self._id))

    def cellsCount(self):
        self._indigo._setSessionId()
        return Bingo._checkResult(self._indigo, self._bingo._lib.bingoCellsCount(self._id))

    def currentCell(self):
        self._indigo._setSessionId()
        return Bingo._checkResult(self._indigo, self._bingo._lib.bingoCurrentCell(self._id))

    def minCell(self):
        self._indigo._setSessionId()
        return Bingo._checkResult(self._indigo, self._bingo._lib.bingoMinCell(self._id))

    def maxCell(self):
        self._indigo._setSessionId()
        return Bingo._checkResult(self._indigo, self._bingo._lib.bingoMaxCell(self._id))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __iter__(self):
        return self

    def __next__(self):
        next_item = self.next()
        if next_item:
            return self
        raise StopIteration
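A minimal end-to-end sketch of the Bingo wrapper, assuming an Indigo session and the bundled native libraries are available ('molecule' is assumed as the database type):

    indigo = Indigo()
    bingo = Bingo.createDatabaseFile(indigo, 'mols.db', 'molecule')
    bingo.insert(indigo.loadMolecule('c1ccccc1O'))
    with bingo.searchSub(indigo.loadQueryMolecule('c1ccccc1'), '') as result:
        while result.next():
            print(result.getCurrentId())
    bingo.close()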
molscribe/indigo/inchi.py ADDED
@@ -0,0 +1,84 @@
#
# Copyright (C) from 2009 to Present EPAM Systems.
#
# This file is part of Indigo toolkit.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import *


class IndigoInchi(object):

    def __init__(self, indigo):
        self.indigo = indigo

        if os.name == 'posix' and not platform.mac_ver()[0] and not platform.system().startswith("CYGWIN"):
            self._lib = CDLL(indigo.dllpath + "/libindigo-inchi.so")
        elif os.name == 'nt' or platform.system().startswith("CYGWIN"):
            self._lib = CDLL(indigo.dllpath + "\\indigo-inchi.dll")
        elif platform.mac_ver()[0]:
            self._lib = CDLL(indigo.dllpath + "/libindigo-inchi.dylib")
        else:
            raise IndigoException("unsupported OS: " + os.name)

        self._lib.indigoInchiVersion.restype = c_char_p
        self._lib.indigoInchiVersion.argtypes = []
        self._lib.indigoInchiResetOptions.restype = c_int
        self._lib.indigoInchiResetOptions.argtypes = []
        self._lib.indigoInchiLoadMolecule.restype = c_int
        self._lib.indigoInchiLoadMolecule.argtypes = [c_char_p]
        self._lib.indigoInchiGetInchi.restype = c_char_p
        self._lib.indigoInchiGetInchi.argtypes = [c_int]
        self._lib.indigoInchiGetInchiKey.restype = c_char_p
        self._lib.indigoInchiGetInchiKey.argtypes = [c_char_p]
        self._lib.indigoInchiGetWarning.restype = c_char_p
        self._lib.indigoInchiGetWarning.argtypes = []
        self._lib.indigoInchiGetLog.restype = c_char_p
        self._lib.indigoInchiGetLog.argtypes = []
        self._lib.indigoInchiGetAuxInfo.restype = c_char_p
        self._lib.indigoInchiGetAuxInfo.argtypes = []

    def resetOptions(self):
        self.indigo._setSessionId()
        self.indigo._checkResult(self._lib.indigoInchiResetOptions())

    def loadMolecule(self, inchi):
        self.indigo._setSessionId()
        res = self.indigo._checkResult(self._lib.indigoInchiLoadMolecule(inchi.encode('ascii')))
        if res == 0:
            return None
        return self.indigo.IndigoObject(self.indigo, res)

    def version(self):
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiVersion())

    def getInchi(self, molecule):
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetInchi(molecule.id))

    def getInchiKey(self, inchi):
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetInchiKey(inchi.encode('ascii')))

    def getWarning(self):
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetWarning())

    def getLog(self):
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetLog())

    def getAuxInfo(self):
        self.indigo._setSessionId()
        return self.indigo._checkResultString(self._lib.indigoInchiGetAuxInfo())
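A short round-trip sketch for IndigoInchi, using the standard InChI string for ethanol:

    indigo = Indigo()
    inchi = IndigoInchi(indigo)
    mol = inchi.loadMolecule('InChI=1S/C2H6O/c1-2-3/h3H,1-2H3')
    print(inchi.getInchi(mol))                                   # regenerate the InChI
    print(inchi.getInchiKey('InChI=1S/C2H6O/c1-2-3/h3H,1-2H3'))  # hash key for lookups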
molscribe/indigo/renderer.py ADDED
@@ -0,0 +1,113 @@
#
# Copyright (C) from 2009 to Present EPAM Systems.
#
# This file is part of Indigo toolkit.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import platform
from ctypes import CDLL, POINTER, c_char_p, c_int

from . import IndigoException


class IndigoRenderer(object):
    def __init__(self, indigo):
        self.indigo = indigo

        if (
            os.name == "posix"
            and not platform.mac_ver()[0]
            and not platform.system().startswith("CYGWIN")
        ):
            self._lib = CDLL(indigo.dllpath + "/libindigo-renderer.so")
        elif os.name == "nt" or platform.system().startswith("CYGWIN"):
            self._lib = CDLL(indigo.dllpath + "\\indigo-renderer.dll")
        elif platform.mac_ver()[0]:
            self._lib = CDLL(indigo.dllpath + "/libindigo-renderer.dylib")
        else:
            raise IndigoException("unsupported OS: " + os.name)

        self._lib.indigoRender.restype = c_int
        self._lib.indigoRender.argtypes = [c_int, c_int]
        self._lib.indigoRenderToFile.restype = c_int
        self._lib.indigoRenderToFile.argtypes = [c_int, c_char_p]
        self._lib.indigoRenderGrid.restype = c_int
        self._lib.indigoRenderGrid.argtypes = [
            c_int,
            POINTER(c_int),
            c_int,
            c_int,
        ]
        self._lib.indigoRenderGridToFile.restype = c_int
        self._lib.indigoRenderGridToFile.argtypes = [
            c_int,
            POINTER(c_int),
            c_int,
            c_char_p,
        ]
        self._lib.indigoRenderReset.restype = c_int
        self._lib.indigoRenderReset.argtypes = [c_int]

    def renderToBuffer(self, obj):
        self.indigo._setSessionId()
        wb = self.indigo.writeBuffer()
        try:
            self.indigo._checkResult(self._lib.indigoRender(obj.id, wb.id))
            return wb.toBuffer()
        finally:
            wb.dispose()

    def renderToFile(self, obj, filename):
        self.indigo._setSessionId()
        self.indigo._checkResult(
            self._lib.indigoRenderToFile(obj.id, filename.encode("ascii"))
        )

    def renderGridToFile(self, objects, refatoms, ncolumns, filename):
        self.indigo._setSessionId()
        arr = None
        if refatoms:
            if len(refatoms) != objects.count():
                raise IndigoException(
                    "renderGridToFile(): refatoms[] size must be equal to the number of objects"
                )
            arr = (c_int * len(refatoms))()
            for i in range(len(refatoms)):
                arr[i] = refatoms[i]
        self.indigo._checkResult(
            self._lib.indigoRenderGridToFile(
                objects.id, arr, ncolumns, filename.encode("ascii")
            )
        )

    def renderGridToBuffer(self, objects, refatoms, ncolumns):
        self.indigo._setSessionId()
        arr = None
        if refatoms:
            if len(refatoms) != objects.count():
                raise IndigoException(
                    "renderGridToBuffer(): refatoms[] size must be equal to the number of objects"
                )
            arr = (c_int * len(refatoms))()
            for i in range(len(refatoms)):
                arr[i] = refatoms[i]
        wb = self.indigo.writeBuffer()
        try:
            self.indigo._checkResult(
                self._lib.indigoRenderGrid(objects.id, arr, ncolumns, wb.id)
            )
            return wb.toBuffer()
        finally:
            wb.dispose()
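A sketch mirroring how dataset.py drives the renderer to rasterize one molecule into PNG bytes:

    indigo = Indigo()
    renderer = IndigoRenderer(indigo)
    indigo.setOption('render-output-format', 'png')
    buf = renderer.renderToBuffer(indigo.loadMolecule('CCO'))
    with open('ethanol.png', 'wb') as f:
        f.write(bytearray(buf))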
molscribe/inference/__init__.py ADDED
@@ -0,0 +1,4 @@
from .greedy_search import GreedySearch
from .beam_search import BeamSearch

__all__ = ["GreedySearch", "BeamSearch"]
molscribe/inference/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (282 Bytes).
 
molscribe/inference/__pycache__/beam_search.cpython-310.pyc ADDED
Binary file (5.44 kB).
 
molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc ADDED
Binary file (2.68 kB).
 
molscribe/inference/__pycache__/greedy_search.cpython-310.pyc ADDED
Binary file (4.11 kB).
 
molscribe/inference/beam_search.py ADDED
@@ -0,0 +1,190 @@
import torch
from .decode_strategy import DecodeStrategy


class BeamSearch(DecodeStrategy):
    """Generation with beam search."""

    def __init__(self, pad, bos, eos, batch_size, beam_size, n_best, min_length,
                 return_attention, max_length):
        # pass min_length/max_length/return_attention in the order DecodeStrategy expects
        super(BeamSearch, self).__init__(
            pad, bos, eos, batch_size, beam_size, min_length, max_length, return_attention)
        self.beam_size = beam_size
        self.n_best = n_best

        # result caching
        self.hypotheses = [[] for _ in range(batch_size)]

        # beam state
        self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool)

        self._batch_offset = torch.arange(batch_size, dtype=torch.long)

        self.select_indices = None
        self.done = False

    def initialize(self, memory_bank, device=None):
        """Repeat src objects `beam_size` times."""

        def fn_map_state(state, dim):
            return torch.repeat_interleave(state, self.beam_size, dim=dim)

        memory_bank = torch.repeat_interleave(memory_bank, self.beam_size, dim=0)
        if device is None:
            device = memory_bank.device

        self.memory_length = memory_bank.size(1)
        super().initialize(memory_bank, device)

        self.best_scores = torch.full([self.batch_size], -1e10, dtype=torch.float, device=device)
        self._beam_offset = torch.arange(
            0, self.batch_size * self.beam_size, step=self.beam_size, dtype=torch.long, device=device)
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
        ).repeat(self.batch_size)
        # buffers for the topk scores and 'backpointer'
        self.topk_scores = torch.empty((self.batch_size, self.beam_size), dtype=torch.float, device=device)
        self.topk_ids = torch.empty((self.batch_size, self.beam_size), dtype=torch.long, device=device)
        self._batch_index = torch.empty([self.batch_size, self.beam_size], dtype=torch.long, device=device)

        return fn_map_state, memory_bank

    @property
    def current_predictions(self):
        return self.alive_seq[:, -1]

    @property
    def current_backptr(self):
        # for testing
        return self.select_indices.view(self.batch_size, self.beam_size)

    @property
    def batch_offset(self):
        return self._batch_offset

    def _pick(self, log_probs):
        """Return token decision for a step.

        Args:
            log_probs (FloatTensor): (B, vocab_size)

        Returns:
            topk_scores (FloatTensor): (B, beam_size)
            topk_ids (LongTensor): (B, beam_size)
        """
        vocab_size = log_probs.size(-1)

        # Flatten probs into a list of probabilities.
        curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
        topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
        return topk_scores, topk_ids

    def advance(self, log_probs, attn):
        """
        Args:
            log_probs: (B * beam_size, vocab_size)
        """
        vocab_size = log_probs.size(-1)

        # (non-finished) batch_size
        _B = log_probs.shape[0] // self.beam_size

        step = len(self)  # alive_seq
        self.ensure_min_length(log_probs)

        # Multiply probs by the beam probability
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)

        curr_length = step + 1
        curr_scores = log_probs / curr_length  # avg log_prob
        self.topk_scores, self.topk_ids = self._pick(curr_scores)
        # topk_scores/topk_ids: (batch_size, beam_size)

        # Recover log probs
        torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs)

        # Resolve beam origin and map to batch index flat representation.
        self._batch_index = self.topk_ids // vocab_size
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)
        self.topk_ids.fmod_(vocab_size)  # resolve true word ids

        # Append last prediction.
        self.alive_seq = torch.cat(
            [self.alive_seq.index_select(0, self.select_indices),
             self.topk_ids.view(_B * self.beam_size, 1)], -1)

        if self.return_attention:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
            else:
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)

        self.is_finished = self.topk_ids.eq(self.eos)
        self.ensure_max_length()

    def update_finished(self):
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # len(self)
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)

        self.is_finished = self.is_finished.to('cpu')
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
        attention = (
            self.alive_attn.view(
                step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
            if self.alive_attn is not None else None)
        non_finished_batch = []
        for i in range(self.is_finished.size(0)):
            b = self._batch_offset[i]
            finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
            # Store finished hypotheses for this batch.
            for j in finished_hyp:  # Beam level: finished beam j in batch i
                self.hypotheses[b].append((
                    self.topk_scores[i, j],
                    predictions[i, j, 1:],  # Ignore start token
                    attention[:, i, j, :self.memory_length]
                    if attention is not None else None))
            # End condition is when the top beam is finished and we can return
            # n_best hypotheses.
            finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                best_hyp = sorted(
                    self.hypotheses[b], key=lambda x: x[0], reverse=True)
                for n, (score, pred, attn) in enumerate(best_hyp):
                    if n >= self.n_best:
                        break
                    self.scores[b].append(score.item())
                    self.predictions[b].append(pred)
                    self.attention[b].append(
                        attn if attn is not None else [])
            else:
                non_finished_batch.append(i)
        non_finished = torch.tensor(non_finished_batch)

        if len(non_finished) == 0:
            self.done = True
            return

        _B_new = non_finished.shape[0]
        # Remove finished batches for the next step
        self.top_beam_finished = self.top_beam_finished.index_select(0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished).view(-1, self.alive_seq.size(-1))
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)

        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(1, non_finished) \
                .view(step - 1, _B_new * self.beam_size, inp_seq_len)
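A schematic decode loop for BeamSearch; `memory_bank` (encoder output) and `decoder_step` (the model's per-step decoder, returning log-probs of shape (B * beam_size, vocab_size)) are stand-ins for the real model:

    bs = BeamSearch(pad=0, bos=1, eos=2, batch_size=4, beam_size=5, n_best=1,
                    min_length=1, return_attention=False, max_length=256)
    fn_map_state, memory_bank = bs.initialize(memory_bank)
    while not bs.done:
        log_probs, attn = decoder_step(bs.alive_seq, memory_bank)
        bs.advance(log_probs, attn)
        if bs.is_finished.any():
            bs.update_finished()
            if bs.done:
                break
            memory_bank = memory_bank.index_select(0, bs.select_indices)
    best = [preds[0] for preds in bs.predictions]  # top hypothesis per input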
molscribe/inference/decode_strategy.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch
2
+
3
+
4
+ class DecodeStrategy(object):
5
+ def __init__(self, pad, bos, eos, batch_size, parallel_paths, min_length, max_length,
6
+ return_attention=False, return_hidden=False):
7
+ self.pad = pad
8
+ self.bos = bos
9
+ self.eos = eos
10
+
11
+ self.batch_size = batch_size
12
+ self.parallel_paths = parallel_paths
13
+ # result catching
14
+ self.predictions = [[] for _ in range(batch_size)]
15
+ self.scores = [[] for _ in range(batch_size)]
16
+ self.token_scores = [[] for _ in range(batch_size)]
17
+ self.attention = [[] for _ in range(batch_size)]
18
+ self.hidden = [[] for _ in range(batch_size)]
19
+
20
+ self.alive_attn = None
21
+ self.alive_hidden = None
22
+
23
+ self.min_length = min_length
24
+ self.max_length = max_length
25
+
26
+ n_paths = batch_size * parallel_paths
27
+ self.return_attention = return_attention
28
+ self.return_hidden = return_hidden
29
+
30
+ self.done = False
31
+
32
+ def initialize(self, memory_bank, device=None):
33
+ if device is None:
34
+ device = torch.device('cpu')
35
+ self.alive_seq = torch.full(
36
+ [self.batch_size * self.parallel_paths, 1], self.bos,
37
+ dtype=torch.long, device=device)
38
+ self.is_finished = torch.zeros(
39
+ [self.batch_size, self.parallel_paths],
40
+ dtype=torch.uint8, device=device)
41
+ self.alive_log_token_scores = torch.zeros(
42
+ [self.batch_size * self.parallel_paths, 0],
43
+ dtype=torch.float, device=device)
44
+
45
+ return None, memory_bank
46
+
47
+ def __len__(self):
48
+ return self.alive_seq.shape[1]
49
+
50
+ def ensure_min_length(self, log_probs):
51
+ if len(self) <= self.min_length:
52
+ log_probs[:, self.eos] = -1e20 # forced non-end
53
+
54
+ def ensure_max_length(self):
55
+ if len(self) == self.max_length + 1:
56
+ self.is_finished.fill_(1)
57
+
58
+ def advance(self, log_probs, attn):
59
+ raise NotImplementedError()
60
+
61
+ def update_finished(self):
62
+ raise NotImplementedError
63
+
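A small sketch of the bookkeeping contract (assuming the `molscribe` package is importable; the shapes are made up): `initialize` seeds one `<bos>` token per decoding path, so `len(self)` is the current sequence length including `<bos>`, and `ensure_max_length` trips once the sequence reaches `max_length + 1` tokens.

import torch
from molscribe.inference.decode_strategy import DecodeStrategy

ds = DecodeStrategy(pad=0, bos=1, eos=2, batch_size=3, parallel_paths=1,
                    min_length=1, max_length=5)
ds.initialize(memory_bank=torch.zeros(3, 10, 8))
print(ds.alive_seq.shape)  # torch.Size([3, 1]) -- one <bos> per path
print(len(ds))             # 1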
molscribe/inference/greedy_search.py ADDED
@@ -0,0 +1,128 @@
+ import torch
+ from .decode_strategy import DecodeStrategy
+
+
+ def sample_with_temperature(logits, sampling_temp, keep_topk):
+     """Select next tokens randomly from the top k possible next tokens.
+
+     Samples from a categorical distribution over the ``keep_topk`` words using
+     the category probabilities ``logits / sampling_temp``.
+     """
+     if sampling_temp == 0.0 or keep_topk == 1:
+         # Greedy (argmax) decoding.
+         topk_scores, topk_ids = logits.topk(1, dim=-1)
+         if sampling_temp > 0:
+             topk_scores /= sampling_temp
+     else:
+         logits = torch.div(logits, sampling_temp)
+         if keep_topk > 0:
+             top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
+             kth_best = top_values[:, -1].view([-1, 1])
+             kth_best = kth_best.repeat([1, logits.shape[1]]).float()
+             # Mask out everything below the k-th best logit.
+             ignore = torch.lt(logits, kth_best)
+             logits = logits.masked_fill(ignore, -10000)
+
+         dist = torch.distributions.Multinomial(logits=logits, total_count=1)
+         topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
+         topk_scores = logits.gather(dim=1, index=topk_ids)
+
+     return topk_ids, topk_scores
+
+
+ class GreedySearch(DecodeStrategy):
+     """Greedy decoding strategy (one path per batch item), with optional
+     top-k temperature sampling."""
+
+     def __init__(self, pad, bos, eos, batch_size, min_length, max_length,
+                  return_attention=False, return_hidden=False, sampling_temp=1, keep_topk=1):
+         super().__init__(
+             pad, bos, eos, batch_size, 1, min_length, max_length, return_attention, return_hidden)
+         self.sampling_temp = sampling_temp
+         self.keep_topk = keep_topk
+         self.topk_scores = None
+
+     def initialize(self, memory_bank, device=None):
+         fn_map_state = None
+
+         if device is None:
+             device = memory_bank.device
+
+         self.memory_length = memory_bank.size(1)
+         super().initialize(memory_bank, device)
+
+         self.select_indices = torch.arange(
+             self.batch_size, dtype=torch.long, device=device)
+         self.original_batch_idx = torch.arange(
+             self.batch_size, dtype=torch.long, device=device)
+
+         return fn_map_state, memory_bank
+
+     @property
+     def current_predictions(self):
+         return self.alive_seq[:, -1]
+
+     @property
+     def batch_offset(self):
+         return self.select_indices
+
+     def _pick(self, log_probs):
+         """Pick the next tokens."""
+         topk_ids, topk_scores = sample_with_temperature(
+             log_probs, self.sampling_temp, self.keep_topk)
+         return topk_ids, topk_scores
+
+     def advance(self, log_probs, attn=None, hidden=None, label=None):
+         """Select the next token for each alive sequence."""
+         self.ensure_min_length(log_probs)
+         # log_probs: b x v; topk_ids and self.topk_scores: b x 1
+         topk_ids, self.topk_scores = self._pick(log_probs)
+         self.is_finished = topk_ids.eq(self.eos)
+         if label is not None:
+             label = label.view_as(self.is_finished)
+             self.is_finished = label.eq(self.eos)
+         # alive_seq: b x (l+1); the first element is <bos>, so l = len(self) - 1
+         self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1)
+         self.alive_log_token_scores = torch.cat([self.alive_log_token_scores, self.topk_scores], -1)
+
+         if self.return_attention:
+             if self.alive_attn is None:
+                 self.alive_attn = attn
+             else:
+                 self.alive_attn = torch.cat([self.alive_attn, attn], 1)
+         if self.return_hidden:
+             if self.alive_hidden is None:
+                 self.alive_hidden = hidden
+             else:
+                 self.alive_hidden = torch.cat([self.alive_hidden, hidden], 1)  # b x l x h
+         self.ensure_max_length()
+
+     def update_finished(self):
+         """Finalize scores and predictions."""
+         # is_finished marks sequences whose decoding is complete. Remove them
+         # from the batch and record the results.
+         finished_batches = self.is_finished.view(-1).nonzero()
+         for b in finished_batches.view(-1):
+             b_orig = self.original_batch_idx[b]
+             # scores/predictions/attention are lists
+             # (to stay compatible with beam search).
+             self.scores[b_orig].append(torch.exp(torch.mean(self.alive_log_token_scores[b])).item())
+             self.token_scores[b_orig].append(torch.exp(self.alive_log_token_scores[b]).tolist())
+             self.predictions[b_orig].append(self.alive_seq[b, 1:])  # skip <bos>
+             self.attention[b_orig].append(
+                 self.alive_attn[b, :, :self.memory_length] if self.alive_attn is not None else [])
+             self.hidden[b_orig].append(
+                 self.alive_hidden[b, :] if self.alive_hidden is not None else [])
+         self.done = self.is_finished.all()
+         if self.done:
+             return
+         is_alive = ~self.is_finished.view(-1)
+         self.alive_seq = self.alive_seq[is_alive]
+         self.alive_log_token_scores = self.alive_log_token_scores[is_alive]
+         if self.alive_attn is not None:
+             self.alive_attn = self.alive_attn[is_alive]
+         if self.alive_hidden is not None:
+             self.alive_hidden = self.alive_hidden[is_alive]
+         self.select_indices = is_alive.nonzero().view(-1)
+         self.original_batch_idx = self.original_batch_idx[is_alive]
+         # select_indices is equal to original_batch_idx for greedy search?
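With `sampling_temp=0.0` or `keep_topk=1`, `sample_with_temperature` degenerates to plain argmax decoding, which is the configuration MolScribe uses at inference. A toy check (values invented):

import torch
from molscribe.inference.greedy_search import sample_with_temperature

log_probs = torch.log_softmax(torch.tensor([[2.0, 0.5, 0.1],
                                            [0.1, 3.0, 0.2]]), dim=-1)
ids, scores = sample_with_temperature(log_probs, sampling_temp=0.0, keep_topk=1)
print(ids.view(-1).tolist())  # [0, 1] -- the argmax of each row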
molscribe/interface.py ADDED
@@ -0,0 +1,223 @@
+ import argparse
+ from typing import List
+
+ import cv2
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+ from .dataset import get_transforms
+ from .model import Encoder, Decoder
+ from .chemistry import convert_graph_to_smiles
+ from .tokenizer import get_tokenizer
+
+
+ BOND_TYPES = ["", "single", "double", "triple", "aromatic", "solid wedge", "dashed wedge"]
+
+
+ def safe_load(module, module_states):
+     def remove_prefix(state_dict):
+         return {k.replace('module.', ''): v for k, v in state_dict.items()}
+     missing_keys, unexpected_keys = module.load_state_dict(remove_prefix(module_states), strict=False)
+     return
+
+
+ class MolScribe:
+
+     def __init__(self, model_path, device=None):
+         """
+         MolScribe interface.
+         :param model_path: path of the model checkpoint.
+         :param device: torch device, defaults to CPU.
+         """
+         model_states = torch.load(model_path, map_location=torch.device('cpu'))
+         args = self._get_args(model_states['args'])
+         if device is None:
+             device = torch.device('cpu')
+         self.device = device
+         self.tokenizer = get_tokenizer(args)
+         self.encoder, self.decoder = self._get_model(args, self.tokenizer, self.device, model_states)
+         self.transform = get_transforms(args.input_size, augment=False)
+
+     def _get_args(self, args_states=None):
+         parser = argparse.ArgumentParser()
+         # Model
+         parser.add_argument('--encoder', type=str, default='swin_base')
+         parser.add_argument('--decoder', type=str, default='transformer')
+         parser.add_argument('--trunc_encoder', action='store_true')  # use the hidden states before downsample
+         parser.add_argument('--no_pretrained', action='store_true')
+         parser.add_argument('--use_checkpoint', action='store_true', default=True)
+         parser.add_argument('--dropout', type=float, default=0.5)
+         parser.add_argument('--embed_dim', type=int, default=256)
+         parser.add_argument('--enc_pos_emb', action='store_true')
+         group = parser.add_argument_group("transformer_options")
+         group.add_argument("--dec_num_layers", help="No. of layers in transformer decoder", type=int, default=6)
+         group.add_argument("--dec_hidden_size", help="Decoder hidden size", type=int, default=256)
+         group.add_argument("--dec_attn_heads", help="Decoder no. of attention heads", type=int, default=8)
+         group.add_argument("--dec_num_queries", type=int, default=128)
+         group.add_argument("--hidden_dropout", help="Hidden dropout", type=float, default=0.1)
+         group.add_argument("--attn_dropout", help="Attention dropout", type=float, default=0.1)
+         group.add_argument("--max_relative_positions", help="Max relative positions", type=int, default=0)
+         parser.add_argument('--continuous_coords', action='store_true')
+         parser.add_argument('--compute_confidence', action='store_true')
+         # Data
+         parser.add_argument('--input_size', type=int, default=384)
+         parser.add_argument('--vocab_file', type=str, default=None)
+         parser.add_argument('--coord_bins', type=int, default=64)
+         parser.add_argument('--sep_xy', action='store_true', default=True)
+
+         args = parser.parse_args([])
+         if args_states:
+             for key, value in args_states.items():
+                 args.__dict__[key] = value
+         return args
+
+     def _get_model(self, args, tokenizer, device, states):
+         encoder = Encoder(args, pretrained=False)
+         args.encoder_dim = encoder.n_features
+         decoder = Decoder(args, tokenizer)
+
+         safe_load(encoder, states['encoder'])
+         safe_load(decoder, states['decoder'])
+         # print(f"Model loaded from {load_path}")
+
+         encoder.to(device)
+         decoder.to(device)
+         encoder.eval()
+         decoder.eval()
+         return encoder, decoder
+
+     def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=16):
+         device = self.device
+         predictions = []
+         self.decoder.compute_confidence = return_confidence
+
+         for idx in range(0, len(input_images), batch_size):
+             batch_images = input_images[idx:idx + batch_size]
+             images = [self.transform(image=image, keypoints=[])['image'] for image in batch_images]
+             images = torch.stack(images, dim=0).to(device)
+             with torch.no_grad():
+                 features, hiddens = self.encoder(images)
+                 batch_predictions = self.decoder.decode(features, hiddens)
+             predictions += batch_predictions
+
+         return self.convert_graph_to_output(predictions, input_images, return_confidence, return_atoms_bonds)
+
+     def convert_graph_to_output(self, predictions, input_images, return_confidence=True, return_atoms_bonds=True):
+         node_coords = [pred['chartok_coords']['coords'] for pred in predictions]
+         node_symbols = [pred['chartok_coords']['symbols'] for pred in predictions]
+         edges = [pred['edges'] for pred in predictions]
+         # node_symbols = [r_groups[symbol] if symbol in r_groups else symbol for symbol in node_symbols]
+         smiles_list, molblock_list, r_success = convert_graph_to_smiles(
+             node_coords, node_symbols, edges, images=input_images)
+
+         outputs = []
+         for smiles, molblock, pred in zip(smiles_list, molblock_list, predictions):
+             pred_dict = {
+                 "smiles": smiles,
+                 "molfile": molblock,
+                 "original_coords": pred['chartok_coords']['coords'],
+                 "original_symbols": pred['chartok_coords']['symbols'],
+                 "original_edges": pred['edges'],
+             }
+             if return_confidence:
+                 pred_dict["confidence"] = pred["overall_score"]
+             if return_atoms_bonds:
+                 coords = pred['chartok_coords']['coords']
+                 symbols = pred['chartok_coords']['symbols']
+                 # Collect per-atom info.
+                 atom_list = []
+                 for i, (symbol, coord) in enumerate(zip(symbols, coords)):
+                     atom_dict = {"atom_symbol": symbol, "x": round(coord[0], 3), "y": round(coord[1], 3)}
+                     if return_confidence:
+                         atom_dict["confidence"] = pred['chartok_coords']['atom_scores'][i]
+                     atom_list.append(atom_dict)
+                 pred_dict["atoms"] = atom_list
+                 # Collect per-bond info.
+                 bond_list = []
+                 num_atoms = len(symbols)
+                 for i in range(num_atoms - 1):
+                     for j in range(i + 1, num_atoms):
+                         bond_type_int = pred['edges'][i][j]
+                         if bond_type_int != 0:
+                             bond_type_str = BOND_TYPES[bond_type_int]
+                             bond_dict = {"bond_type": bond_type_str, "endpoint_atoms": (i, j)}
+                             if return_confidence:
+                                 bond_dict["confidence"] = pred["edge_scores"][i][j]
+                             bond_list.append(bond_dict)
+                 pred_dict["bonds"] = bond_list
+             outputs.append(pred_dict)
+         return outputs
+
+     def predict_image(self, image, return_atoms_bonds=False, return_confidence=False):
+         return self.predict_images(
+             [image], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0]
+
+     def predict_image_files(self, image_files: List, return_atoms_bonds=False, return_confidence=False):
+         input_images = []
+         for path in image_files:
+             image = cv2.imread(path)
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             input_images.append(image)
+         return self.predict_images(
+             input_images, return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)
+
+     def predict_image_file(self, image_file: str, return_atoms_bonds=False, return_confidence=False):
+         return self.predict_image_files(
+             [image_file], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0]
+
+     def draw_prediction(self, prediction, image, notebook=False):
+         if "atoms" not in prediction or "bonds" not in prediction:
+             raise ValueError("atom and bond information is not provided.")
+         h, w, _ = image.shape
+         h, w = np.array([h, w]) * 400 / max(h, w)
+         image = cv2.resize(image, (int(w), int(h)))
+         fig, ax = plt.subplots(1, 1)
+         ax.axis('off')
+         ax.set_xlim(-0.05 * w, w * 1.05)
+         ax.set_ylim(1.05 * h, -0.05 * h)
+         plt.imshow(image, alpha=0.)
+         x = [a['x'] * w for a in prediction['atoms']]
+         y = [a['y'] * h for a in prediction['atoms']]
+         markersize = min(w, h) / 3
+         plt.scatter(x, y, marker='o', s=markersize, color='lightskyblue', zorder=10)
+         for i, atom in enumerate(prediction['atoms']):
+             symbol = atom['atom_symbol'].lstrip('[').rstrip(']')
+             plt.annotate(symbol, xy=(x[i], y[i]), ha='center', va='center', color='black', zorder=100)
+         for bond in prediction['bonds']:
+             u, v = bond['endpoint_atoms']
+             x1, y1, x2, y2 = x[u], y[u], x[v], y[v]
+             bond_type = bond['bond_type']
+             if bond_type == 'single':
+                 color = 'tab:green'
+                 ax.plot([x1, x2], [y1, y2], color, linewidth=4)
+             elif bond_type == 'aromatic':
+                 color = 'tab:purple'
+                 ax.plot([x1, x2], [y1, y2], color, linewidth=4)
+             elif bond_type == 'double':
+                 color = 'tab:green'
+                 ax.plot([x1, x2], [y1, y2], color=color, linewidth=7)
+                 ax.plot([x1, x2], [y1, y2], color='w', linewidth=1.5, zorder=2.1)
+             elif bond_type == 'triple':
+                 color = 'tab:green'
+                 x1s, x2s = 0.8 * x1 + 0.2 * x2, 0.2 * x1 + 0.8 * x2
+                 y1s, y2s = 0.8 * y1 + 0.2 * y2, 0.2 * y1 + 0.8 * y2
+                 ax.plot([x1s, x2s], [y1s, y2s], color=color, linewidth=9)
+                 ax.plot([x1, x2], [y1, y2], color='w', linewidth=5, zorder=2.05)
+                 ax.plot([x1, x2], [y1, y2], color=color, linewidth=2, zorder=2.1)
+             else:
+                 length = 10
+                 width = 10
+                 color = 'tab:green'
+                 if bond_type == 'solid wedge':
+                     ax.annotate('', xy=(x1, y1), xytext=(x2, y2),
+                                 arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2)
+                 else:
+                     ax.annotate('', xy=(x2, y2), xytext=(x1, y1),
+                                 arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2)
+         fig.tight_layout()
+         if not notebook:
+             canvas = FigureCanvasAgg(fig)
+             canvas.draw()
+             buf = canvas.buffer_rgba()
+             result_image = np.asarray(buf)
+             plt.close(fig)
+             return result_image
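A minimal usage sketch of this interface (the checkpoint path is a placeholder for a downloaded MolScribe checkpoint; `predict_image_file` returns the dictionary built in `convert_graph_to_output` above):

from molscribe.interface import MolScribe

# 'molscribe.pth' is a hypothetical path to a model checkpoint.
model = MolScribe('molscribe.pth')
result = model.predict_image_file('examples/exp.png',
                                  return_atoms_bonds=True, return_confidence=True)
print(result['smiles'])
print(result['atoms'][0])  # e.g. {'atom_symbol': 'C', 'x': ..., 'y': ..., 'confidence': ...}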
molscribe/loss.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from scipy.optimize import linear_sum_assignment
+ from .tokenizer import PAD_ID, MASK, MASK_ID
+
+
+ class LabelSmoothingLoss(nn.Module):
+     """
+     With label smoothing, the KL divergence between
+     q_{smoothed ground-truth prob.}(w) and p_{prob. computed by model}(w)
+     is minimized.
+     """
+     def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
+         assert 0.0 < label_smoothing <= 1.0
+         self.ignore_index = ignore_index
+         super(LabelSmoothingLoss, self).__init__()
+
+         smoothing_value = label_smoothing / (tgt_vocab_size - 2)
+         one_hot = torch.full((tgt_vocab_size,), smoothing_value)
+         one_hot[self.ignore_index] = 0
+         self.register_buffer('one_hot', one_hot.unsqueeze(0))
+
+         self.confidence = 1.0 - label_smoothing
+
+     def forward(self, output, target):
+         """
+         output (FloatTensor): batch_size x n_classes
+         target (LongTensor): batch_size
+         """
+         # output is assumed to be raw logits; convert to log-probs.
+         log_probs = F.log_softmax(output, dim=-1)
+
+         model_prob = self.one_hot.repeat(target.size(0), 1)
+         model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
+         model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
+
+         # reduction: mean or sum?
+         return F.kl_div(log_probs, model_prob, reduction='batchmean')
+
+
+ class SequenceLoss(nn.Module):
+
+     def __init__(self, label_smoothing, vocab_size, ignore_index=-100, ignore_indices=[]):
+         super(SequenceLoss, self).__init__()
+         if ignore_indices:
+             ignore_index = ignore_indices[0]
+         self.ignore_index = ignore_index
+         self.ignore_indices = ignore_indices
+         if label_smoothing == 0:
+             self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean')
+         else:
+             self.criterion = LabelSmoothingLoss(label_smoothing, vocab_size, ignore_index)
+
+     def forward(self, output, target):
+         """
+         :param output: [batch, len, vocab]
+         :param target: [batch, len]
+         """
+         batch_size, max_len, vocab_size = output.size()
+         output = output.reshape(-1, vocab_size)
+         target = target.reshape(-1)
+         # Map all other ignored ids onto the primary ignore_index.
+         for idx in self.ignore_indices:
+             if idx != self.ignore_index:
+                 target.masked_fill_((target == idx), self.ignore_index)
+         loss = self.criterion(output, target)
+         return loss
+
+
+ class GraphLoss(nn.Module):
+
+     def __init__(self):
+         super(GraphLoss, self).__init__()
+         # Down-weight the "no bond" class (index 0).
+         weight = torch.ones(7) * 10
+         weight[0] = 1
+         self.criterion = nn.CrossEntropyLoss(weight, ignore_index=-100)
+
+     def forward(self, outputs, targets):
+         results = {}
+         if 'coords' in outputs:
+             pred = outputs['coords']
+             max_len = pred.size(1)
+             target = targets['coords'][:, :max_len]
+             mask = target.ge(0)
+             loss = F.l1_loss(pred, target, reduction='none')
+             results['coords'] = (loss * mask).sum() / mask.sum()
+         if 'edges' in outputs:
+             pred = outputs['edges']
+             max_len = pred.size(-1)
+             target = targets['edges'][:, :max_len, :max_len]
+             results['edges'] = self.criterion(pred, target)
+         return results
+
+
+ class Criterion(nn.Module):
+
+     def __init__(self, args, tokenizer):
+         super(Criterion, self).__init__()
+         criterion = {}
+         for format_ in args.formats:
+             if format_ == 'edges':
+                 criterion['edges'] = GraphLoss()
+             else:
+                 if MASK in tokenizer[format_].stoi:
+                     ignore_indices = [PAD_ID, MASK_ID]
+                 else:
+                     ignore_indices = []
+                 criterion[format_] = SequenceLoss(args.label_smoothing, len(tokenizer[format_]),
+                                                   ignore_index=PAD_ID, ignore_indices=ignore_indices)
+         self.criterion = nn.ModuleDict(criterion)
+
+     def forward(self, results, refs):
+         losses = {}
+         for format_ in results:
+             predictions, targets, *_ = results[format_]
+             loss_ = self.criterion[format_](predictions, targets)
+             if type(loss_) is dict:
+                 losses.update(loss_)
+             else:
+                 if loss_.numel() > 1:
+                     loss_ = loss_.mean()
+                 losses[format_] = loss_
+         return losses
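A toy invocation of `SequenceLoss` showing the expected shapes (the numbers are invented; assumes the `molscribe` package and its dependencies are installed):

import torch
from molscribe.loss import SequenceLoss

loss_fn = SequenceLoss(label_smoothing=0.1, vocab_size=10, ignore_index=0)
logits = torch.randn(2, 5, 10)           # [batch, len, vocab]
target = torch.randint(1, 10, (2, 5))    # [batch, len]
print(loss_fn(logits, target).item())    # a non-negative scalar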
molscribe/model.py ADDED
@@ -0,0 +1,397 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import timm
+
+ from .utils import FORMAT_INFO, to_device
+ from .tokenizer import SOS_ID, EOS_ID, PAD_ID, MASK_ID
+ from .inference import GreedySearch, BeamSearch
+ from .transformer import TransformerDecoder, Embeddings
+
+
+ class Encoder(nn.Module):
+     def __init__(self, args, pretrained=False):
+         super().__init__()
+         model_name = args.encoder
+         self.model_name = model_name
+         if model_name.startswith('resnet'):
+             self.model_type = 'resnet'
+             self.cnn = timm.create_model(model_name, pretrained=pretrained)
+             self.n_features = self.cnn.num_features  # encoder_dim
+             self.cnn.global_pool = nn.Identity()
+             self.cnn.fc = nn.Identity()
+         elif model_name.startswith('swin'):
+             self.model_type = 'swin'
+             self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False,
+                                                  use_checkpoint=args.use_checkpoint)
+             self.n_features = self.transformer.num_features
+             self.transformer.head = nn.Identity()
+         elif 'efficientnet' in model_name:
+             self.model_type = 'efficientnet'
+             self.cnn = timm.create_model(model_name, pretrained=pretrained)
+             self.n_features = self.cnn.num_features
+             self.cnn.global_pool = nn.Identity()
+             self.cnn.classifier = nn.Identity()
+         else:
+             raise NotImplementedError
+
+     def swin_forward(self, transformer, x):
+         x = transformer.patch_embed(x)
+         if transformer.absolute_pos_embed is not None:
+             x = x + transformer.absolute_pos_embed
+         x = transformer.pos_drop(x)
+
+         def layer_forward(layer, x, hiddens):
+             for blk in layer.blocks:
+                 if not torch.jit.is_scripting() and layer.use_checkpoint:
+                     x = torch.utils.checkpoint.checkpoint(blk, x)
+                 else:
+                     x = blk(x)
+             H, W = layer.input_resolution
+             B, L, C = x.shape
+             hiddens.append(x.view(B, H, W, C))
+             if layer.downsample is not None:
+                 x = layer.downsample(x)
+             return x, hiddens
+
+         hiddens = []
+         for layer in transformer.layers:
+             x, hiddens = layer_forward(layer, x, hiddens)
+         x = transformer.norm(x)  # B L C
+         hiddens[-1] = x.view_as(hiddens[-1])
+         return x, hiddens
+
+     def forward(self, x, refs=None):
+         if self.model_type in ['resnet', 'efficientnet']:
+             features = self.cnn(x)
+             features = features.permute(0, 2, 3, 1)
+             hiddens = []
+         elif self.model_type == 'swin':
+             if 'patch' in self.model_name:
+                 features, hiddens = self.swin_forward(self.transformer, x)
+             else:
+                 features, hiddens = self.transformer(x)
+         else:
+             raise NotImplementedError
+         return features, hiddens
+
+
+ class TransformerDecoderBase(nn.Module):
+
+     def __init__(self, args):
+         super().__init__()
+         self.args = args
+
+         self.enc_trans_layer = nn.Sequential(
+             nn.Linear(args.encoder_dim, args.dec_hidden_size)
+             # nn.LayerNorm(args.dec_hidden_size, eps=1e-6)
+         )
+         self.enc_pos_emb = nn.Embedding(144, args.encoder_dim) if args.enc_pos_emb else None
+
+         self.decoder = TransformerDecoder(
+             num_layers=args.dec_num_layers,
+             d_model=args.dec_hidden_size,
+             heads=args.dec_attn_heads,
+             d_ff=args.dec_hidden_size * 4,
+             copy_attn=False,
+             self_attn_type="scaled-dot",
+             dropout=args.hidden_dropout,
+             attention_dropout=args.attn_dropout,
+             max_relative_positions=args.max_relative_positions,
+             aan_useffn=False,
+             full_context_alignment=False,
+             alignment_layer=0,
+             alignment_heads=0,
+             pos_ffn_activation_fn='gelu'
+         )
+
+     def enc_transform(self, encoder_out):
+         batch_size = encoder_out.size(0)
+         encoder_dim = encoder_out.size(-1)
+         encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
+         max_len = encoder_out.size(1)
+         device = encoder_out.device
+         if self.enc_pos_emb:
+             pos_emb = self.enc_pos_emb(torch.arange(max_len, device=device)).unsqueeze(0)
+             encoder_out = encoder_out + pos_emb
+         encoder_out = self.enc_trans_layer(encoder_out)
+         return encoder_out
+
+
+ class TransformerDecoderAR(TransformerDecoderBase):
+     """Autoregressive Transformer decoder."""
+
+     def __init__(self, args, tokenizer):
+         super().__init__(args)
+         self.tokenizer = tokenizer
+         self.vocab_size = len(self.tokenizer)
+         self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True)
+         self.embeddings = Embeddings(
+             word_vec_size=args.dec_hidden_size,
+             word_vocab_size=self.vocab_size,
+             word_padding_idx=PAD_ID,
+             position_encoding=True,
+             dropout=args.hidden_dropout)
+
+     def dec_embedding(self, tgt, step=None):
+         pad_idx = self.embeddings.word_padding_idx
+         tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2)  # [B, 1, T_tgt]
+         emb = self.embeddings(tgt, step=step)
+         assert emb.dim() == 3  # batch x len x embedding_dim
+         return emb, tgt_pad_mask
+
+     def forward(self, encoder_out, labels, label_lengths):
+         """Training mode."""
+         batch_size, max_len, _ = encoder_out.size()
+         memory_bank = self.enc_transform(encoder_out)
+
+         tgt = labels.unsqueeze(-1)  # (b, t, 1)
+         tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
+         dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask)
+
+         logits = self.output_layer(dec_out)  # (b, t, h) -> (b, t, v)
+         return logits[:, :-1], labels[:, 1:], dec_out
+
+     def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256,
+                labels=None):
+         """Inference mode. Autoregressively decode the sequence. Only greedy search is well supported;
+         beam search is outdated. `labels` is used for partial prediction, i.e. when part of the sequence
+         is given. In standard decoding, labels=None."""
+         batch_size, max_len, _ = encoder_out.size()
+         memory_bank = self.enc_transform(encoder_out)
+         orig_labels = labels
+
+         if beam_size == 1:
+             decode_strategy = GreedySearch(
+                 sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length, max_length=max_length,
+                 pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
+                 return_attention=False, return_hidden=True)
+         else:
+             decode_strategy = BeamSearch(
+                 beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length, max_length=max_length,
+                 pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
+                 return_attention=False)
+
+         # adapted from onmt.translate.translator
+         results = {
+             "predictions": None,
+             "scores": None,
+             "attention": None
+         }
+
+         # (2) Prep decode_strategy. Possibly repeat src objects.
+         _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank)
+
+         # (3) Begin decoding step by step:
+         for step in range(decode_strategy.max_length):
+             tgt = decode_strategy.current_predictions.view(-1, 1, 1)
+             if labels is not None:
+                 label = labels[:, step].view(-1, 1, 1)
+                 mask = label.eq(MASK_ID).long()
+                 tgt = tgt * mask + label * (1 - mask)
+             tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
+             dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank,
+                                                  tgt_pad_mask=tgt_pad_mask, step=step)
+
+             attn = dec_attn.get("std", None)
+
+             dec_logits = self.output_layer(dec_out)  # [b, t, h] => [b, t, v]
+             dec_logits = dec_logits.squeeze(1)
+             log_probs = F.log_softmax(dec_logits, dim=-1)
+
+             if self.tokenizer.output_constraint:
+                 output_mask = [self.tokenizer.get_output_mask(id) for id in tgt.view(-1).tolist()]
+                 output_mask = torch.tensor(output_mask, device=log_probs.device)
+                 log_probs.masked_fill_(output_mask, -10000)
+
+             label = labels[:, step + 1] if labels is not None and step + 1 < labels.size(1) else None
+             decode_strategy.advance(log_probs, attn, dec_out, label)
+             any_finished = decode_strategy.is_finished.any()
+             if any_finished:
+                 decode_strategy.update_finished()
+                 if decode_strategy.done:
+                     break
+
+             select_indices = decode_strategy.select_indices
+             if any_finished:
+                 # Reorder states.
+                 memory_bank = memory_bank.index_select(0, select_indices)
+                 if labels is not None:
+                     labels = labels.index_select(0, select_indices)
+                 self.map_state(lambda state, dim: state.index_select(dim, select_indices))
+
+         results["scores"] = decode_strategy.scores  # average of token scores
+         results["token_scores"] = decode_strategy.token_scores
+         results["predictions"] = decode_strategy.predictions
+         results["attention"] = decode_strategy.attention
+         results["hidden"] = decode_strategy.hidden
+         if orig_labels is not None:
+             for i in range(batch_size):
+                 pred = results["predictions"][i][0]
+                 label = orig_labels[i][1:len(pred) + 1]
+                 mask = label.eq(MASK_ID).long()
+                 pred = pred[:len(label)]
+                 results["predictions"][i][0] = pred * mask + label * (1 - mask)
+
+         return results["predictions"], results['scores'], results["token_scores"], results["hidden"]
+
+     # adapted from onmt.decoders.transformer
+     def map_state(self, fn):
+         def _recursive_map(struct, batch_dim=0):
+             for k, v in struct.items():
+                 if v is not None:
+                     if isinstance(v, dict):
+                         _recursive_map(v)
+                     else:
+                         struct[k] = fn(v, batch_dim)
+
+         if self.decoder.state["cache"] is not None:
+             _recursive_map(self.decoder.state["cache"])
+
+
+ class GraphPredictor(nn.Module):
+
+     def __init__(self, decoder_dim, coords=False):
+         super(GraphPredictor, self).__init__()
+         self.coords = coords
+         self.mlp = nn.Sequential(
+             nn.Linear(decoder_dim * 2, decoder_dim), nn.GELU(),
+             nn.Linear(decoder_dim, 7)
+         )
+         if coords:
+             self.coords_mlp = nn.Sequential(
+                 nn.Linear(decoder_dim, decoder_dim), nn.GELU(),
+                 nn.Linear(decoder_dim, 2)
+             )
+
+     def forward(self, hidden, indices=None):
+         b, l, dim = hidden.size()
+         if indices is None:
+             index = [i for i in range(3, l, 3)]
+             hidden = hidden[:, index]
+         else:
+             batch_id = torch.arange(b).unsqueeze(1).expand_as(indices).reshape(-1)
+             indices = indices.view(-1)
+             hidden = hidden[batch_id, indices].view(b, -1, dim)
+         b, l, dim = hidden.size()
+         results = {}
+         # Pairwise concatenation of atom hidden states: (b, l, l, 2 * dim).
+         hh = torch.cat([hidden.unsqueeze(2).expand(b, l, l, dim), hidden.unsqueeze(1).expand(b, l, l, dim)], dim=3)
+         results['edges'] = self.mlp(hh).permute(0, 3, 1, 2)
+         if self.coords:
+             results['coords'] = self.coords_mlp(hidden)
+         return results
+
+
+ def get_edge_prediction(edge_prob):
+     if not edge_prob:
+         return [], []
+     n = len(edge_prob)
+     # Symmetrize the pairwise bond probabilities. Bond classes 0-4 are
+     # direction-independent; classes 5 and 6 (solid/dashed wedge) swap
+     # when the bond direction is reversed.
+     for i in range(n):
+         for j in range(i + 1, n):
+             for k in range(5):
+                 edge_prob[i][j][k] = (edge_prob[i][j][k] + edge_prob[j][i][k]) / 2
+                 edge_prob[j][i][k] = edge_prob[i][j][k]
+             edge_prob[i][j][5] = (edge_prob[i][j][5] + edge_prob[j][i][6]) / 2
+             edge_prob[i][j][6] = (edge_prob[i][j][6] + edge_prob[j][i][5]) / 2
+             edge_prob[j][i][5] = edge_prob[i][j][6]
+             edge_prob[j][i][6] = edge_prob[i][j][5]
+     prediction = np.argmax(edge_prob, axis=2).tolist()
+     score = np.max(edge_prob, axis=2).tolist()
+     return prediction, score
+
+
+ class Decoder(nn.Module):
+     """Wrapper around the different decoder architectures; supports multiple decoders."""
+
+     def __init__(self, args, tokenizer):
+         super(Decoder, self).__init__()
+         self.args = args
+         self.formats = args.formats
+         self.tokenizer = tokenizer
+         decoder = {}
+         for format_ in args.formats:
+             if format_ == 'edges':
+                 decoder['edges'] = GraphPredictor(args.dec_hidden_size, coords=args.continuous_coords)
+             else:
+                 decoder[format_] = TransformerDecoderAR(args, tokenizer[format_])
+         self.decoder = nn.ModuleDict(decoder)
+         self.compute_confidence = args.compute_confidence
+
+     def forward(self, encoder_out, hiddens, refs):
+         """Training mode. Compute the logits with teacher forcing."""
+         results = {}
+         refs = to_device(refs, encoder_out.device)
+         for format_ in self.formats:
+             if format_ == 'edges':
+                 if 'atomtok_coords' in results:
+                     dec_out = results['atomtok_coords'][2]
+                     predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0])
+                 elif 'chartok_coords' in results:
+                     dec_out = results['chartok_coords'][2]
+                     predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0])
+                 else:
+                     raise NotImplementedError
+                 targets = {'edges': refs['edges']}
+                 if 'coords' in predictions:
+                     targets['coords'] = refs['coords']
+                 results['edges'] = (predictions, targets)
+             else:
+                 labels, label_lengths = refs[format_]
+                 results[format_] = self.decoder[format_](encoder_out, labels, label_lengths)
+         return results
+
+     def decode(self, encoder_out, hiddens=None, refs=None, beam_size=1, n_best=1):
+         """Inference mode. Call each decoder's decode method (if required) and convert the output
+         format (e.g. tokens to a sequence). Beam search is not supported yet."""
+         results = {}
+         predictions = []
+         for format_ in self.formats:
+             if format_ in ['atomtok', 'atomtok_coords', 'chartok_coords']:
+                 max_len = FORMAT_INFO[format_]['max_len']
+                 results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best, max_length=max_len)
+                 outputs, scores, token_scores, *_ = results[format_]
+                 beam_preds = [[self.tokenizer[format_].sequence_to_smiles(x.tolist()) for x in pred]
+                               for pred in outputs]
+                 predictions = [{format_: pred[0]} for pred in beam_preds]
+                 if self.compute_confidence:
+                     for i in range(len(predictions)):
+                         # indices point past each atom's coordinate tokens:
+                         # -1: y score, -2: x score, -3: symbol score
+                         indices = np.array(predictions[i][format_]['indices']) - 3
+                         if format_ == 'chartok_coords':
+                             atom_scores = []
+                             for symbol, index in zip(predictions[i][format_]['symbols'], indices):
+                                 atom_score = (np.prod(token_scores[i][0][index - len(symbol) + 1:index + 1])
+                                               ** (1 / len(symbol))).item()
+                                 atom_scores.append(atom_score)
+                         else:
+                             atom_scores = np.array(token_scores[i][0])[indices].tolist()
+                         predictions[i][format_]['atom_scores'] = atom_scores
+                         predictions[i][format_]['average_token_score'] = scores[i][0]
+             if format_ == 'edges':
+                 if 'atomtok_coords' in results:
+                     atom_format = 'atomtok_coords'
+                 elif 'chartok_coords' in results:
+                     atom_format = 'chartok_coords'
+                 else:
+                     raise NotImplementedError
+                 dec_out = results[atom_format][3]  # batch x n_best x len x dim
+                 for i in range(len(dec_out)):
+                     hidden = dec_out[i][0].unsqueeze(0)  # 1 x len x dim
+                     indices = torch.LongTensor(predictions[i][atom_format]['indices']).unsqueeze(0)  # 1 x k
+                     pred = self.decoder['edges'](hidden, indices)  # k x k
+                     prob = F.softmax(pred['edges'].squeeze(0).permute(1, 2, 0), dim=2).tolist()  # k x k x 7
+                     edge_pred, edge_score = get_edge_prediction(prob)
+                     predictions[i]['edges'] = edge_pred
+                     if self.compute_confidence:
+                         predictions[i]['edge_scores'] = edge_score
+                         predictions[i]['edge_score_product'] = np.sqrt(np.prod(edge_score)).item()
+                         predictions[i]['overall_score'] = predictions[i][atom_format]['average_token_score'] * \
+                             predictions[i]['edge_score_product']
+                         predictions[i][atom_format].pop('average_token_score')
+                         predictions[i].pop('edge_score_product')
+         return predictions
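The wedge handling in `get_edge_prediction` is easiest to see on a two-atom toy example: a solid wedge from atom 0 to atom 1 (class 5) is the same bond as a dashed wedge from atom 1 to atom 0 (class 6), so the two directions are averaged into one symmetric score (probabilities invented):

import numpy as np
from molscribe.model import get_edge_prediction

# 2 x 2 x 7 pairwise bond probabilities for two atoms.
prob = np.zeros((2, 2, 7)).tolist()
prob[0][1][5] = 0.8  # solid wedge 0 -> 1
prob[1][0][6] = 0.6  # dashed wedge 1 -> 0 (the same bond, reversed)
pred, score = get_edge_prediction(prob)
print(pred[0][1], score[0][1])  # 5 0.7 -- averaged over both directions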
molscribe/tokenizer.py ADDED
@@ -0,0 +1,524 @@
+ import os
+ import json
+ import random
+ import numpy as np
+ from SmilesPE.pretokenizer import atomwise_tokenizer
+
+ PAD = '<pad>'
+ SOS = '<sos>'
+ EOS = '<eos>'
+ UNK = '<unk>'
+ MASK = '<mask>'
+ PAD_ID = 0
+ SOS_ID = 1
+ EOS_ID = 2
+ UNK_ID = 3
+ MASK_ID = 4
+
+
+ class Tokenizer(object):
+
+     def __init__(self, path=None):
+         self.stoi = {}
+         self.itos = {}
+         if path:
+             self.load(path)
+
+     def __len__(self):
+         return len(self.stoi)
+
+     @property
+     def output_constraint(self):
+         return False
+
+     def save(self, path):
+         with open(path, 'w') as f:
+             json.dump(self.stoi, f)
+
+     def load(self, path):
+         with open(path) as f:
+             self.stoi = json.load(f)
+         self.itos = {item[1]: item[0] for item in self.stoi.items()}
+
+     def fit_on_texts(self, texts):
+         vocab = set()
+         for text in texts:
+             vocab.update(text.split(' '))
+         vocab = [PAD, SOS, EOS, UNK] + list(vocab)
+         for i, s in enumerate(vocab):
+             self.stoi[s] = i
+         self.itos = {item[1]: item[0] for item in self.stoi.items()}
+         assert self.stoi[PAD] == PAD_ID
+         assert self.stoi[SOS] == SOS_ID
+         assert self.stoi[EOS] == EOS_ID
+         assert self.stoi[UNK] == UNK_ID
+
+     def text_to_sequence(self, text, tokenized=True):
+         sequence = []
+         sequence.append(self.stoi['<sos>'])
+         if tokenized:
+             tokens = text.split(' ')
+         else:
+             tokens = atomwise_tokenizer(text)
+         for s in tokens:
+             if s not in self.stoi:
+                 s = '<unk>'
+             sequence.append(self.stoi[s])
+         sequence.append(self.stoi['<eos>'])
+         return sequence
+
+     def texts_to_sequences(self, texts):
+         sequences = []
+         for text in texts:
+             sequence = self.text_to_sequence(text)
+             sequences.append(sequence)
+         return sequences
+
+     def sequence_to_text(self, sequence):
+         return ''.join(list(map(lambda i: self.itos[i], sequence)))
+
+     def sequences_to_texts(self, sequences):
+         texts = []
+         for sequence in sequences:
+             text = self.sequence_to_text(sequence)
+             texts.append(text)
+         return texts
+
+     def predict_caption(self, sequence):
+         caption = ''
+         for i in sequence:
+             if i == self.stoi['<eos>'] or i == self.stoi['<pad>']:
+                 break
+             caption += self.itos[i]
+         return caption
+
+     def predict_captions(self, sequences):
+         captions = []
+         for sequence in sequences:
+             caption = self.predict_caption(sequence)
+             captions.append(caption)
+         return captions
+
+     def sequence_to_smiles(self, sequence):
+         return {'smiles': self.predict_caption(sequence)}
+
+
+ class NodeTokenizer(Tokenizer):
+
+     def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
+         super().__init__(path)
+         self.maxx = input_size  # height
+         self.maxy = input_size  # width
+         self.sep_xy = sep_xy
+         self.special_tokens = [PAD, SOS, EOS, UNK, MASK]
+         self.continuous_coords = continuous_coords
+         self.debug = debug
+
+     def __len__(self):
+         if self.sep_xy:
+             return self.offset + self.maxx + self.maxy
+         else:
+             return self.offset + max(self.maxx, self.maxy)
+
+     @property
+     def offset(self):
+         return len(self.stoi)
+
+     @property
+     def output_constraint(self):
+         return not self.continuous_coords
+
+     def len_symbols(self):
+         return len(self.stoi)
+
+     def fit_atom_symbols(self, atoms):
+         vocab = self.special_tokens + list(set(atoms))
+         for i, s in enumerate(vocab):
+             self.stoi[s] = i
+         assert self.stoi[PAD] == PAD_ID
+         assert self.stoi[SOS] == SOS_ID
+         assert self.stoi[EOS] == EOS_ID
+         assert self.stoi[UNK] == UNK_ID
+         assert self.stoi[MASK] == MASK_ID
+         self.itos = {item[1]: item[0] for item in self.stoi.items()}
+
+     def is_x(self, x):
+         return self.offset <= x < self.offset + self.maxx
+
+     def is_y(self, y):
+         if self.sep_xy:
+             return self.offset + self.maxx <= y
+         return self.offset <= y
+
+     def is_symbol(self, s):
+         return len(self.special_tokens) <= s < self.offset or s == UNK_ID
+
+     def is_atom(self, id):
+         if self.is_symbol(id):
+             return self.is_atom_token(self.itos[id])
+         return False
+
+     def is_atom_token(self, token):
+         return token.isalpha() or token.startswith("[") or token == '*' or token == UNK
+
+     def x_to_id(self, x):
+         return self.offset + round(x * (self.maxx - 1))
+
+     def y_to_id(self, y):
+         if self.sep_xy:
+             return self.offset + self.maxx + round(y * (self.maxy - 1))
+         return self.offset + round(y * (self.maxy - 1))
+
+     def id_to_x(self, id):
+         return (id - self.offset) / (self.maxx - 1)
+
+     def id_to_y(self, id):
+         if self.sep_xy:
+             return (id - self.offset - self.maxx) / (self.maxy - 1)
+         return (id - self.offset) / (self.maxy - 1)
+
+     def get_output_mask(self, id):
+         mask = [False] * len(self)
+         if self.continuous_coords:
+             return mask
+         if self.is_atom(id):
+             return [True] * self.offset + [False] * self.maxx + [True] * self.maxy
+         if self.is_x(id):
+             return [True] * (self.offset + self.maxx) + [False] * self.maxy
+         if self.is_y(id):
+             return [False] * self.offset + [True] * (self.maxx + self.maxy)
+         return mask
+
+     def symbol_to_id(self, symbol):
+         if symbol not in self.stoi:
+             return UNK_ID
+         return self.stoi[symbol]
+
+     def symbols_to_labels(self, symbols):
+         labels = []
+         for symbol in symbols:
+             labels.append(self.symbol_to_id(symbol))
+         return labels
+
+     def labels_to_symbols(self, labels):
+         symbols = []
+         for label in labels:
+             symbols.append(self.itos[label])
+         return symbols
+
+     def nodes_to_grid(self, nodes):
+         coords, symbols = nodes['coords'], nodes['symbols']
+         grid = np.zeros((self.maxx, self.maxy), dtype=int)
+         for [x, y], symbol in zip(coords, symbols):
+             x = round(x * (self.maxx - 1))
+             y = round(y * (self.maxy - 1))
+             grid[x][y] = self.symbol_to_id(symbol)
+         return grid
+
+     def grid_to_nodes(self, grid):
+         coords, symbols, indices = [], [], []
+         for i in range(self.maxx):
+             for j in range(self.maxy):
+                 if grid[i][j] != 0:
+                     x = i / (self.maxx - 1)
+                     y = j / (self.maxy - 1)
+                     coords.append([x, y])
+                     symbols.append(self.itos[grid[i][j]])
+                     indices.append([i, j])
+         return {'coords': coords, 'symbols': symbols, 'indices': indices}
+
+     def nodes_to_sequence(self, nodes):
+         coords, symbols = nodes['coords'], nodes['symbols']
+         labels = [SOS_ID]
+         for (x, y), symbol in zip(coords, symbols):
+             assert 0 <= x <= 1
+             assert 0 <= y <= 1
+             labels.append(self.x_to_id(x))
+             labels.append(self.y_to_id(y))
+             labels.append(self.symbol_to_id(symbol))
+         labels.append(EOS_ID)
+         return labels
+
+     def sequence_to_nodes(self, sequence):
+         coords, symbols = [], []
+         i = 0
+         if sequence[0] == SOS_ID:
+             i += 1
+         while i + 2 < len(sequence):
+             if sequence[i] == EOS_ID:
+                 break
+             if self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
+                 x = self.id_to_x(sequence[i])
+                 y = self.id_to_y(sequence[i+1])
+                 symbol = self.itos[sequence[i+2]]
+                 coords.append([x, y])
+                 symbols.append(symbol)
+             i += 3
+         return {'coords': coords, 'symbols': symbols}
+
+     def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
+         tokens = atomwise_tokenizer(smiles)
+         labels = [SOS_ID]
+         indices = []
+         atom_idx = -1
+         for token in tokens:
+             if atom_only and not self.is_atom_token(token):
+                 continue
+             if token in self.stoi:
+                 labels.append(self.stoi[token])
+             else:
+                 if self.debug:
+                     print(f'{token} not in vocab')
+                 labels.append(UNK_ID)
+             if self.is_atom_token(token):
+                 atom_idx += 1
+                 if not self.continuous_coords:
+                     if mask_ratio > 0 and random.random() < mask_ratio:
+                         labels.append(MASK_ID)
+                         labels.append(MASK_ID)
+                     elif coords is not None:
+                         if atom_idx < len(coords):
+                             x, y = coords[atom_idx]
+                             assert 0 <= x <= 1
+                             assert 0 <= y <= 1
+                         else:
+                             x = random.random()
+                             y = random.random()
+                         labels.append(self.x_to_id(x))
+                         labels.append(self.y_to_id(y))
+                 indices.append(len(labels) - 1)
+         labels.append(EOS_ID)
+         return labels, indices
+
+     def sequence_to_smiles(self, sequence):
+         has_coords = not self.continuous_coords
+         smiles = ''
+         coords, symbols, indices = [], [], []
+         for i, label in enumerate(sequence):
+             if label == EOS_ID or label == PAD_ID:
+                 break
+             if self.is_x(label) or self.is_y(label):
+                 continue
+             token = self.itos[label]
+             smiles += token
+             if self.is_atom_token(token):
+                 if has_coords:
+                     if i+3 < len(sequence) and self.is_x(sequence[i+1]) and self.is_y(sequence[i+2]):
+                         x = self.id_to_x(sequence[i+1])
+                         y = self.id_to_y(sequence[i+2])
+                         coords.append([x, y])
+                         symbols.append(token)
+                         indices.append(i+3)
+                 else:
+                     if i+1 < len(sequence):
+                         symbols.append(token)
+                         indices.append(i+1)
+         results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
+         if has_coords:
+             results['coords'] = coords
+         return results
+
+
+ class CharTokenizer(NodeTokenizer):
+
+     def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
+         super().__init__(input_size, path, sep_xy, continuous_coords, debug)
+
+     def fit_on_texts(self, texts):
+         vocab = set()
+         for text in texts:
+             vocab.update(list(text))
+         if ' ' in vocab:
+             vocab.remove(' ')
+         vocab = [PAD, SOS, EOS, UNK] + list(vocab)
+         for i, s in enumerate(vocab):
+             self.stoi[s] = i
+         self.itos = {item[1]: item[0] for item in self.stoi.items()}
+         assert self.stoi[PAD] == PAD_ID
+         assert self.stoi[SOS] == SOS_ID
+         assert self.stoi[EOS] == EOS_ID
+         assert self.stoi[UNK] == UNK_ID
+
+     def text_to_sequence(self, text, tokenized=True):
+         sequence = []
+         sequence.append(self.stoi['<sos>'])
+         if tokenized:
+             tokens = text.split(' ')
+             assert all(len(s) == 1 for s in tokens)
+         else:
+             tokens = list(text)
+         for s in tokens:
+             if s not in self.stoi:
+                 s = '<unk>'
+             sequence.append(self.stoi[s])
+         sequence.append(self.stoi['<eos>'])
+         return sequence
+
+     def fit_atom_symbols(self, atoms):
+         atoms = list(set(atoms))
+         chars = []
+         for atom in atoms:
+             chars.extend(list(atom))
+         vocab = self.special_tokens + chars
+         for i, s in enumerate(vocab):
+             self.stoi[s] = i
+         assert self.stoi[PAD] == PAD_ID
+         assert self.stoi[SOS] == SOS_ID
+         assert self.stoi[EOS] == EOS_ID
+         assert self.stoi[UNK] == UNK_ID
+         assert self.stoi[MASK] == MASK_ID
+         self.itos = {item[1]: item[0] for item in self.stoi.items()}
+
+     def get_output_mask(self, id):
+         ''' TO FIX '''
+         mask = [False] * len(self)
+         if self.continuous_coords:
+             return mask
+         if self.is_x(id):
+             return [True] * (self.offset + self.maxx) + [False] * self.maxy
+         if self.is_y(id):
+             return [False] * self.offset + [True] * (self.maxx + self.maxy)
+         return mask
+
+     def nodes_to_sequence(self, nodes):
+         coords, symbols = nodes['coords'], nodes['symbols']
+         labels = [SOS_ID]
+         for (x, y), symbol in zip(coords, symbols):
+             assert 0 <= x <= 1
+             assert 0 <= y <= 1
+             labels.append(self.x_to_id(x))
+             labels.append(self.y_to_id(y))
+             for char in symbol:
+                 labels.append(self.symbol_to_id(char))
+         labels.append(EOS_ID)
+         return labels
+
+     def sequence_to_nodes(self, sequence):
+         coords, symbols = [], []
+         i = 0
+         if sequence[0] == SOS_ID:
+             i += 1
+         while i < len(sequence):
+             if sequence[i] == EOS_ID:
+                 break
+             if i+2 < len(sequence) and self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
+                 x = self.id_to_x(sequence[i])
+                 y = self.id_to_y(sequence[i+1])
+                 for j in range(i+2, len(sequence)):
+                     if not self.is_symbol(sequence[j]):
+                         break
+                 symbol = ''.join(self.itos[sequence[k]] for k in range(i+2, j))
+                 coords.append([x, y])
+                 symbols.append(symbol)
+                 i = j
+             else:
+                 i += 1
+         return {'coords': coords, 'symbols': symbols}
+
+     def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
+         tokens = atomwise_tokenizer(smiles)
+         labels = [SOS_ID]
+         indices = []
+         atom_idx = -1
+         for token in tokens:
+             if atom_only and not self.is_atom_token(token):
+                 continue
+             for c in token:
+                 if c in self.stoi:
+                     labels.append(self.stoi[c])
+                 else:
+                     if self.debug:
+                         print(f'{c} not in vocab')
+                     labels.append(UNK_ID)
+             if self.is_atom_token(token):
+                 atom_idx += 1
+                 if not self.continuous_coords:
+                     if mask_ratio > 0 and random.random() < mask_ratio:
+                         labels.append(MASK_ID)
+                         labels.append(MASK_ID)
+                     elif coords is not None:
+                         if atom_idx < len(coords):
+                             x, y = coords[atom_idx]
+                             assert 0 <= x <= 1
+                             assert 0 <= y <= 1
+                         else:
+                             x = random.random()
+                             y = random.random()
+                         labels.append(self.x_to_id(x))
+                         labels.append(self.y_to_id(y))
+                 indices.append(len(labels) - 1)
+         labels.append(EOS_ID)
+         return labels, indices
+
+     def sequence_to_smiles(self, sequence):
+         has_coords = not self.continuous_coords
+         smiles = ''
+         coords, symbols, indices = [], [], []
+         i = 0
+         while i < len(sequence):
+             label = sequence[i]
+             if label == EOS_ID or label == PAD_ID:
+                 break
+             if self.is_x(label) or self.is_y(label):
+                 i += 1
+                 continue
+             if not self.is_atom(label):
+                 smiles += self.itos[label]
+                 i += 1
+                 continue
+             if self.itos[label] == '[':
+                 # Bracket atom: consume symbol tokens up to and including ']'.
+                 j = i + 1
+                 while j < len(sequence):
+                     if not self.is_symbol(sequence[j]):
+                         break
+                     if self.itos[sequence[j]] == ']':
+                         j += 1
+                         break
+                     j += 1
+             else:
+                 # Two-character elements (Cl, Br) are spelled with two tokens.
+                 if i+1 < len(sequence) and (self.itos[label] == 'C' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'l'
+                                             or self.itos[label] == 'B' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'r'):
+                     j = i+2
+                 else:
+                     j = i+1
+             token = ''.join(self.itos[sequence[k]] for k in range(i, j))
+             smiles += token
+             if has_coords:
+                 if j+2 < len(sequence) and self.is_x(sequence[j]) and self.is_y(sequence[j+1]):
+                     x = self.id_to_x(sequence[j])
+                     y = self.id_to_y(sequence[j+1])
+                     coords.append([x, y])
+                     symbols.append(token)
+                     indices.append(j+2)
+                     i = j+2
+                 else:
+                     i = j
+             else:
+                 if j < len(sequence):
+                     symbols.append(token)
+                     indices.append(j)
+                 i = j
+         results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
+         if has_coords:
+             results['coords'] = coords
+         return results
+
+
+ def get_tokenizer(args):
+     tokenizer = {}
+     for format_ in args.formats:
+         if format_ == 'atomtok':
+             if args.vocab_file is None:
+                 args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
+             tokenizer['atomtok'] = Tokenizer(args.vocab_file)
+         elif format_ == "atomtok_coords":
+             if args.vocab_file is None:
+                 args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
+             tokenizer["atomtok_coords"] = NodeTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
+                                                         continuous_coords=args.continuous_coords)
+         elif format_ == "chartok_coords":
+             if args.vocab_file is None:
+                 args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_chars.json')
+             tokenizer["chartok_coords"] = CharTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
+                                                         continuous_coords=args.continuous_coords)
+     return tokenizer
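The coordinate tokens are just quantized positions: `x_to_id` maps a normalized coordinate in [0, 1] to one of `maxx` discrete ids placed after the symbol vocabulary, and `id_to_x` inverts it up to about half a bin of quantization error. A quick round trip (a sketch; no vocab file is loaded, so `offset == 0`):

from molscribe.tokenizer import NodeTokenizer

tok = NodeTokenizer(input_size=64)      # 64 coordinate bins, empty symbol vocab
token_id = tok.x_to_id(0.5)             # round(0.5 * 63) = 32
print(token_id, tok.id_to_x(token_id))  # 32 0.5079... (error < one bin width)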
molscribe/transformer/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .decoder import TransformerDecoder
+ from .embedding import Embeddings
+ from .swin_transformer import swin_base, swin_large
molscribe/transformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (324 Bytes).
molscribe/transformer/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (13.5 kB).
molscribe/transformer/__pycache__/embedding.cpython-310.pyc ADDED
Binary file (7.91 kB).
molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc ADDED
Binary file (21.2 kB).
molscribe/transformer/decoder.py ADDED
@@ -0,0 +1,487 @@
"""
Implementation of "Attention is All You Need" and of
subsequent transformer based architectures
"""

import torch
import torch.nn as nn

from onmt.decoders.decoder import DecoderBase
from onmt.modules import MultiHeadedAttention, AverageAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask


class TransformerDecoderLayerBase(nn.Module):
    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """
        Args:
            d_model (int): the dimension of keys/values/queries in
                :class:`MultiHeadedAttention`, also the input size of
                the first layer of the :class:`PositionwiseFeedForward`.
            heads (int): the number of heads for MultiHeadedAttention.
            d_ff (int): the hidden size of the second layer of the
                :class:`PositionwiseFeedForward`.
            dropout (float): dropout in residual, self-attn(dot) and
                feed-forward
            attention_dropout (float): dropout in context_attn (and
                self-attn(avg))
            self_attn_type (string): type of self-attention, scaled-dot or
                average
            max_relative_positions (int):
                Max distance between inputs in relative positions
                representations
            aan_useffn (bool): Turn on the FFN layer in the AAN decoder
            full_context_alignment (bool):
                whether to enable an extra full context decoder forward for
                alignment
            alignment_heads (int):
                number of cross-attention heads to use for alignment guiding
            pos_ffn_activation_fn (ActivationFunction):
                activation function choice for PositionwiseFeedForward layer
        """
        super(TransformerDecoderLayerBase, self).__init__()

        if self_attn_type == "scaled-dot":
            self.self_attn = MultiHeadedAttention(
                heads,
                d_model,
                dropout=attention_dropout,
                max_relative_positions=max_relative_positions,
            )
        elif self_attn_type == "average":
            self.self_attn = AverageAttention(
                d_model, dropout=attention_dropout, aan_useffn=aan_useffn
            )

        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout,
                                                    pos_ffn_activation_fn)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        self.full_context_alignment = full_context_alignment
        self.alignment_heads = alignment_heads

    def forward(self, *args, **kwargs):
        """Extend `_forward` for (possibly) multiple decoder passes:
        always a default (future masked) decoder forward pass, and
        possibly a second future-aware decoder pass to jointly learn
        full context alignment, :cite:`garg2019jointly`.

        Args:
            * All arguments of _forward.
            with_align (bool): whether to return alignment attention.

        Returns:
            (FloatTensor, FloatTensor, FloatTensor or None):

            * output ``(batch_size, T, model_dim)``
            * top_attn ``(batch_size, T, src_len)``
            * attn_align ``(batch_size, T, src_len)`` or None
        """
        with_align = kwargs.pop("with_align", False)
        output, attns = self._forward(*args, **kwargs)
        top_attn = attns[:, 0, :, :].contiguous()
        attn_align = None
        if with_align:
            if self.full_context_alignment:
                # return _, (B, Q_len, K_len)
                _, attns = self._forward(*args, **kwargs, future=True)

            if self.alignment_heads > 0:
                attns = attns[:, : self.alignment_heads, :, :].contiguous()
            # layer average attention across heads, get ``(B, Q, K)``
            # Case 1: no full_context, no align heads -> layer avg baseline
            # Case 2: no full_context, 1 align heads -> guided align
            # Case 3: full_context, 1 align heads -> full ctx guided align
            attn_align = attns.mean(dim=1)
        return output, top_attn, attn_align

    def update_dropout(self, dropout, attention_dropout):
        self.self_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.drop.p = dropout

    def _forward(self, *args, **kwargs):
        raise NotImplementedError

    def _compute_dec_mask(self, tgt_pad_mask, future):
        tgt_len = tgt_pad_mask.size(-1)
        if not future:  # apply future_mask, result mask in (B, T, T)
            future_mask = torch.ones(
                [tgt_len, tgt_len],
                device=tgt_pad_mask.device,
                dtype=torch.uint8,
            )
            future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
            # BoolTensor was introduced in pytorch 1.2
            try:
                future_mask = future_mask.bool()
            except AttributeError:
                pass
            dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
        else:  # only mask padding, result mask in (B, 1, T)
            dec_mask = tgt_pad_mask
        return dec_mask
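
    # Worked example: for tgt_len = 3 with the last target position padded,
    # tgt_pad_mask (B, 1, T) broadcasts against future_mask (1, T, T), and
    # torch.gt(tgt_pad_mask + future_mask, 0) acts as an element-wise OR:
    #     future_mask          tgt_pad_mask         dec_mask
    #     [[0, 1, 1],          [[0, 0, 1]]          [[0, 1, 1],
    #      [0, 0, 1],    OR                   ->     [0, 0, 1],
    #      [0, 0, 0]]                                [0, 0, 1]]
    # True entries mark key positions a query may not attend to.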

    def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
        if isinstance(self.self_attn, MultiHeadedAttention):
            return self.self_attn(
                inputs_norm,
                inputs_norm,
                inputs_norm,
                mask=dec_mask,
                layer_cache=layer_cache,
                attn_type="self",
            )
        elif isinstance(self.self_attn, AverageAttention):
            return self.self_attn(
                inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
            )
        else:
            raise ValueError(
                f"self attention {type(self.self_attn)} not supported"
            )


class TransformerDecoderLayer(TransformerDecoderLayerBase):
    """Transformer decoder layer block in Pre-Norm style.
    Pre-Norm style is an improvement over the original paper's Post-Norm
    style, providing better convergence speed and performance. This is also
    the actual implementation in tensor2tensor and is also available in
    fairseq. See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.

    .. mermaid::

        graph LR
        %% "*SubLayer" can be self-attn, src-attn or feed forward block
        A(input) --> B[Norm]
        B --> C["*SubLayer"]
        C --> D[Drop]
        D --> E((+))
        A --> E
        E --> F(out)

    """

    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """
        Args:
            See TransformerDecoderLayerBase
        """
        super(TransformerDecoderLayer, self).__init__(
            d_model,
            heads,
            d_ff,
            dropout,
            attention_dropout,
            self_attn_type,
            max_relative_positions,
            aan_useffn,
            full_context_alignment,
            alignment_heads,
            pos_ffn_activation_fn=pos_ffn_activation_fn,
        )
        self.context_attn = MultiHeadedAttention(
            heads, d_model, dropout=attention_dropout
        )
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)

    def update_dropout(self, dropout, attention_dropout):
        super(TransformerDecoderLayer, self).update_dropout(
            dropout, attention_dropout
        )
        self.context_attn.update_dropout(attention_dropout)

    def _forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        layer_cache=None,
        step=None,
        future=False,
    ):
        """A naive forward pass for transformer decoder.

        # T: could be 1 in the case of stepwise decoding, or tgt_len

        Args:
            inputs (FloatTensor): ``(batch_size, T, model_dim)``
            memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
            src_pad_mask (bool): ``(batch_size, 1, src_len)``
            tgt_pad_mask (bool): ``(batch_size, 1, T)``
            layer_cache (dict or None): cached layer info when stepwise decode
            step (int or None): stepwise decoding counter
            future (bool): If set True, do not apply future_mask.

        Returns:
            (FloatTensor, FloatTensor):

            * output ``(batch_size, T, model_dim)``
            * attns ``(batch_size, head, T, src_len)``

        """
        dec_mask = None

        if inputs.size(1) > 1:
            # masking is necessary when sequence length is greater than one
            dec_mask = self._compute_dec_mask(tgt_pad_mask, future)

        inputs_norm = self.layer_norm_1(inputs)

        query, _ = self._forward_self_attn(
            inputs_norm, dec_mask, layer_cache, step
        )

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid, attns = self.context_attn(
            memory_bank,
            memory_bank,
            query_norm,
            mask=src_pad_mask,
            layer_cache=layer_cache,
            attn_type="context",
        )
        output = self.feed_forward(self.drop(mid) + query)

        return output, attns
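
    # In short, _forward is two pre-norm residual blocks; note that in
    # OpenNMT's PositionwiseFeedForward the final residual and layer norm
    # live inside the module itself:
    #   query  = inputs + Drop(SelfAttn(Norm1(inputs)))
    #   output = FFN(query + Drop(ContextAttn(Norm2(query), memory_bank)))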


class TransformerDecoderBase(DecoderBase):
    def __init__(self, d_model, copy_attn, alignment_layer):
        super(TransformerDecoderBase, self).__init__()

        # Decoder State
        self.state = {}

        # previously, there was a GlobalAttention module here for copy
        # attention. But it was never actually used -- the "copy" attention
        # just reuses the context attention.
        self._copy = copy_attn
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        self.alignment_layer = alignment_layer

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.copy_attn,
            opt.self_attn_type,
            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
            opt.attention_dropout[0] if type(opt.attention_dropout) is list
            else opt.attention_dropout,
            embeddings,
            opt.max_relative_positions,
            opt.aan_useffn,
            opt.full_context_alignment,
            opt.alignment_layer,
            alignment_heads=opt.alignment_heads,
            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
        )

    def init_state(self, src, memory_bank, enc_hidden):
        """Initialize decoder state."""
        self.state["src"] = src
        self.state["cache"] = None

    def map_state(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.state["src"] is not None:
            self.state["src"] = fn(self.state["src"], 1)
        if self.state["cache"] is not None:
            _recursive_map(self.state["cache"])

    def detach_state(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def update_dropout(self, dropout, attention_dropout):
        self.embeddings.update_dropout(dropout)
        for layer in self.transformer_layers:
            layer.update_dropout(dropout, attention_dropout)


class TransformerDecoder(TransformerDecoderBase):
    """The Transformer decoder from "Attention is All You Need".
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    .. mermaid::

        graph BT
        A[input]
        B[multi-head self-attn]
        BB[multi-head src-attn]
        C[feed forward]
        O[output]
        A --> B
        B --> BB
        BB --> C
        C --> O

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        copy_attn (bool): if using a separate copy attention
        self_attn_type (str): type of self-attention, scaled-dot or average
        dropout (float): dropout in residual, self-attn(dot) and feed-forward
        attention_dropout (float): dropout in context_attn (and self-attn(avg))
        embeddings (onmt.modules.Embeddings):
            embeddings to use, should have positional encodings
        max_relative_positions (int):
            Max distance between inputs in relative positions representations
        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
        full_context_alignment (bool):
            whether to enable an extra full context decoder forward for
            alignment
        alignment_layer (int): index of the layer to supervise for alignment
            guiding
        alignment_heads (int):
            number of cross-attention heads to use for alignment guiding
    """

    def __init__(
        self,
        num_layers,
        d_model,
        heads,
        d_ff,
        copy_attn,
        self_attn_type,
        dropout,
        attention_dropout,
        max_relative_positions,
        aan_useffn,
        full_context_alignment,
        alignment_layer,
        alignment_heads,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        super(TransformerDecoder, self).__init__(
            d_model, copy_attn, alignment_layer
        )

        self.transformer_layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    self_attn_type=self_attn_type,
                    max_relative_positions=max_relative_positions,
                    aan_useffn=aan_useffn,
                    full_context_alignment=full_context_alignment,
                    alignment_heads=alignment_heads,
                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                )
                for i in range(num_layers)
            ]
        )

    def detach_state(self):
        self.state["src"] = self.state["src"].detach()

    def forward(self, tgt_emb, memory_bank, src_pad_mask=None, tgt_pad_mask=None, step=None, **kwargs):
        """Decode, possibly stepwise."""
        if step == 0:
            self._init_cache(memory_bank)

        batch_size, src_len, src_dim = memory_bank.size()
        device = memory_bank.device
        if src_pad_mask is None:
            src_pad_mask = torch.zeros((batch_size, 1, src_len), dtype=torch.bool, device=device)
        output = tgt_emb
        batch_size, tgt_len, tgt_dim = tgt_emb.size()
        if tgt_pad_mask is None:
            tgt_pad_mask = torch.zeros((batch_size, 1, tgt_len), dtype=torch.bool, device=device)

        future = kwargs.pop("future", False)
        with_align = kwargs.pop("with_align", False)
        attn_aligns = []
        hiddens = []

        for i, layer in enumerate(self.transformer_layers):
            layer_cache = (
                self.state["cache"]["layer_{}".format(i)]
                if step is not None
                else None
            )
            output, attn, attn_align = layer(
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align,
                future=future,
            )
            hiddens.append(output)
            if attn_align is not None:
                attn_aligns.append(attn_align)

        output = self.layer_norm(output)  # (B, L, D)

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]  # ``(B, Q, K)``
            # attns["align"] = torch.stack(attn_aligns, 0).mean(0)  # All avg

        # TODO change the way attns is returned dict => list or tuple (onnx)
        return output, attns, hiddens

    def _init_cache(self, memory_bank):
        self.state["cache"] = {}
        for i, layer in enumerate(self.transformer_layers):
            layer_cache = {"memory_keys": None, "memory_values": None,
                           "self_keys": None, "self_values": None}
            self.state["cache"]["layer_{}".format(i)] = layer_cache
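
To make the stepwise decoding API concrete, here is a minimal usage sketch (hyper-parameters and shapes are illustrative; it assumes the onmt packages imported at the top of decoder.py are installed):

import torch
from molscribe.transformer import TransformerDecoder

decoder = TransformerDecoder(
    num_layers=6, d_model=256, heads=8, d_ff=2048,
    copy_attn=False, self_attn_type="scaled-dot",
    dropout=0.1, attention_dropout=0.1,
    max_relative_positions=0, aan_useffn=False,
    full_context_alignment=False, alignment_layer=-3, alignment_heads=0,
)
decoder.init_state(None, None, None)

memory_bank = torch.randn(2, 100, 256)  # (B, src_len, d_model) encoder output
tgt_emb = torch.randn(2, 1, 256)        # one embedded target token per step

# step=0 builds the per-layer key/value cache via _init_cache; later steps
# reuse it, so each call only pays for the newly generated token.
for step in range(3):
    output, attns, hiddens = decoder(tgt_emb, memory_bank, step=step)

# output: (2, 1, 256); attns['std']: cross-attention over the source;
# hiddens: the per-layer decoder outputs collected in forward. With beam
# search, map_state() reorders the cached keys/values between steps.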