diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f53b0417ac62aeb10987c0da74b3e2f141fe01eb
--- /dev/null
+++ b/app.py
@@ -0,0 +1,131 @@
+import os
+import gradio as gr
+import json
+from rxnim import RXNIM
+from getReaction import generate_combined_image
+import torch
+from rxn.reaction import Reaction
+
+PROMPT_DIR = "prompts/"
+ckpt_path = "./rxn/model/model.ckpt"
+model = Reaction(ckpt_path, device=torch.device('cpu'))
+
+# 定义 prompt 文件名到友好名字的映射
+PROMPT_NAMES = {
+ "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
+}
+example_diagram = "examples/exp.png"
+
+def list_prompt_files_with_names():
+ """
+ 列出 prompts 目录下的所有 .txt 文件,为没有名字的生成默认名字。
+ 返回 {friendly_name: filename} 映射。
+ """
+ prompt_files = {}
+ for f in os.listdir(PROMPT_DIR):
+ if f.endswith(".txt"):
+ # 如果文件名有预定义的名字,使用预定义名字
+ friendly_name = PROMPT_NAMES.get(f, f"Task: {os.path.splitext(f)[0]}")
+ prompt_files[friendly_name] = f
+ return prompt_files
+
+def parse_reactions(output_json):
+ """
+ 解析 JSON 格式的反应数据并格式化输出,包含颜色定制。
+ """
+ reactions_data = json.loads(output_json) # 转换 JSON 字符串为字典
+ reactions_list = reactions_data.get("reactions", [])
+ detailed_output = []
+
+ for reaction in reactions_list:
+ reaction_id = reaction.get("reaction_id", "Unknown ID")
+ reactants = [r.get("smiles", "Unknown") for r in reaction.get("reactants", [])]
+ conditions = [
+ f"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]"
+ for c in reaction.get("conditions", [])
+ ]
+ conditions_1 = [
+ f"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]"
+ for c in reaction.get("conditions", [])
+ ]
+ products = [f"{p.get('smiles', 'Unknown')}" for p in reaction.get("products", [])]
+ products_1 = [f"{p.get('smiles', 'Unknown')}" for p in reaction.get("products", [])]
+
+ # 构造反应的完整字符串,定制字体颜色
+ full_reaction = f"{'.'.join(reactants)}>>{'.'.join(products_1)} | {', '.join(conditions_1)}"
+ full_reaction = f"{full_reaction}"
+
+ # 详细反应格式化输出
+ reaction_output = f"Reaction: {reaction_id}
"
+ reaction_output += f" Reactants: {', '.join(reactants)}
"
+ reaction_output += f" Conditions: {', '.join(conditions)}
"
+ reaction_output += f" Products: {', '.join(products)}
"
+ reaction_output += f" Full Reaction: {full_reaction}
"
+ reaction_output += "
"
+ detailed_output.append(reaction_output)
+
+ return detailed_output
+
+def process_chem_image(image, selected_task):
+ chem_mllm = RXNIM()
+
+ # 将友好名字转换为实际文件名
+ prompt_path = os.path.join(PROMPT_DIR, prompts_with_names[selected_task])
+ image_path = "temp_image.png"
+ image.save(image_path)
+
+ # 调用 RXNIM 处理
+ rxnim_result = chem_mllm.process(image_path, prompt_path)
+
+ # 将 JSON 结果解析为结构化输出
+ detailed_reactions = parse_reactions(rxnim_result)
+
+ # 调用 RxnScribe 模型处理并生成整合图像
+ predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
+ combined_image_path = generate_combined_image(predictions, image_path)
+
+ json_file_path = "output.json"
+ with open(json_file_path, "w") as json_file:
+ json.dump(json.loads(rxnim_result), json_file, indent=4)
+
+
+ # 返回详细反应和整合图像
+ return "\n\n".join(detailed_reactions), combined_image_path, example_diagram, json_file_path
+
+
+# 获取 prompts 和友好名字
+prompts_with_names = list_prompt_files_with_names()
+
+# 示例数据:图像路径 + 任务选项
+examples = [
+
+ ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
+ ["examples/reaction2.png", "Reaction Image Parsing Workflow"],
+ ["examples/reaction3.png", "Reaction Image Parsing Workflow"],
+ ["examples/reaction4.png", "Reaction Image Parsing Workflow"],
+]
+
+# 定义 Gradio 界面
+demo = gr.Interface(
+ fn=process_chem_image,
+ inputs=[
+ gr.Image(type="pil", label="Upload Reaction Image"),
+ gr.Radio(
+ choices=list(prompts_with_names.keys()), # 显示任务名字
+ label="Select a predefined task",
+ ),
+ ],
+ outputs=[
+ gr.HTML(label="Reaction outputs"),
+ gr.Image(label="Visualization"), # 显示整合图像
+ gr.Image(value=example_diagram, label="Schematic Diagram"),
+ gr.File(label="Download JSON File"),
+
+ ],
+ title="Towards Large-scale Chemical Reaction Image Parsing via a Multimodal Large Language Model",
+ description="Upload a reaction image and select a predefined task prompt.",
+ examples=examples, # 使用嵌套列表作为示例
+ examples_per_page=20,
+)
+
+demo.launch()
diff --git a/examples/exp.png b/examples/exp.png
new file mode 100644
index 0000000000000000000000000000000000000000..0a91f5ef711bbf739bf2df1e5d6b983eef8bcd06
Binary files /dev/null and b/examples/exp.png differ
diff --git a/examples/reaction1.png b/examples/reaction1.png
new file mode 100644
index 0000000000000000000000000000000000000000..65f6291e5d0dabee694b667433bca915a830990e
Binary files /dev/null and b/examples/reaction1.png differ
diff --git a/examples/reaction2.png b/examples/reaction2.png
new file mode 100644
index 0000000000000000000000000000000000000000..111258a3502714b5205bf68e9a98f10605bb7fde
Binary files /dev/null and b/examples/reaction2.png differ
diff --git a/examples/reaction3.png b/examples/reaction3.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0efc39605b2e99b940f592582e8c5f5344f6062
Binary files /dev/null and b/examples/reaction3.png differ
diff --git a/examples/reaction4.png b/examples/reaction4.png
new file mode 100644
index 0000000000000000000000000000000000000000..627f120b9a70ddfb6bfd05aa8554d4087191b772
Binary files /dev/null and b/examples/reaction4.png differ
diff --git a/getReaction.py b/getReaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eddfc7d65270297622fcecda3775529b0113c01
--- /dev/null
+++ b/getReaction.py
@@ -0,0 +1,78 @@
+import sys
+sys.path.append('./rxn/')
+import torch
+from rxn.reaction import Reaction
+import json
+from matplotlib import pyplot as plt
+import numpy as np
+
+ckpt_path = "./rxn/model/model.ckpt"
+model = Reaction(ckpt_path, device=torch.device('cpu'))
+device = torch.device('cpu')
+
+def get_reaction(image_path: str) -> list:
+ '''Returns a list of reactions extracted from the image.'''
+ image_file = image_path
+ return json.dumps(model.predict_image_file(image_file, molscribe=True, ocr=True))
+
+
+
+def generate_combined_image(predictions, image_file):
+ """
+ 将预测的图像整合到一个对称的布局中输出。
+ """
+ output = model.draw_predictions(predictions, image_file=image_file)
+ n_images = len(output)
+ if n_images == 1:
+ n_cols = 1
+ elif n_images == 2:
+ n_cols = 2
+ else:
+ n_cols = 3
+ n_rows = (n_images + n_cols - 1) // n_cols # 计算需要的行数
+
+ # 确保每张图像符合要求
+ processed_images = []
+ for img in output:
+ if len(img.shape) == 2: # 灰度图像
+ img = np.stack([img] * 3, axis=-1) # 转换为 RGB 格式
+ elif img.shape[2] > 3: # RGBA 图像
+ img = img[:, :, :3] # 只保留 RGB 通道
+ if img.dtype == np.float32 or img.dtype == np.float64:
+ img = (img * 255).astype(np.uint8) # 转换为 uint8
+ processed_images.append(img)
+ output = processed_images
+
+ # 为不足的子图位置添加占位图
+ if n_images < n_rows * n_cols:
+ blank_image = np.ones_like(output[0]) * 255 # 生成一个白色占位图
+ while len(output) < n_rows * n_cols:
+ output.append(blank_image)
+
+ # 创建子图画布
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
+
+ # 确保 axes 是一维数组
+ if isinstance(axes, np.ndarray):
+ axes = axes.flatten()
+ else:
+ axes = [axes] # 单个子图的情况
+
+ # 绘制每张图像
+ for idx, img in enumerate(output):
+ ax = axes[idx]
+ ax.imshow(img)
+ ax.axis('off')
+ if idx < n_images:
+ ax.set_title(f"Reaction {idx + 1}")
+
+ # 删除多余的子图
+ for idx in range(n_images, len(axes)):
+ fig.delaxes(axes[idx])
+
+ # 保存整合图像
+ combined_image_path = "combined_output.png"
+ plt.tight_layout()
+ plt.savefig(combined_image_path)
+ plt.close(fig)
+ return combined_image_path
diff --git a/molscribe/__init__.py b/molscribe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4473879606f65474f5fe4d5ba1be97c695649852
--- /dev/null
+++ b/molscribe/__init__.py
@@ -0,0 +1 @@
+from .interface import MolScribe
diff --git a/molscribe/__pycache__/__init__.cpython-310.pyc b/molscribe/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdef7b6dc638aa80c4487d71d8ef73c65fe80d2e
Binary files /dev/null and b/molscribe/__pycache__/__init__.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/augment.cpython-310.pyc b/molscribe/__pycache__/augment.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf8ae46b7badb56a6fe34e219e9d4722b7dea3a6
Binary files /dev/null and b/molscribe/__pycache__/augment.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/chemistry.cpython-310.pyc b/molscribe/__pycache__/chemistry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b699bd0fd816d35417576ff2eb4daa7442dad17
Binary files /dev/null and b/molscribe/__pycache__/chemistry.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/constants.cpython-310.pyc b/molscribe/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb69a972acfeede2648dbf25e83ca301032784c1
Binary files /dev/null and b/molscribe/__pycache__/constants.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/dataset.cpython-310.pyc b/molscribe/__pycache__/dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..212331517e79bda247482c919f939e69acdbfeed
Binary files /dev/null and b/molscribe/__pycache__/dataset.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/evaluate.cpython-310.pyc b/molscribe/__pycache__/evaluate.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42d8e3a260ca800e72ffef8527d3ce4b232dafd6
Binary files /dev/null and b/molscribe/__pycache__/evaluate.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/interface.cpython-310.pyc b/molscribe/__pycache__/interface.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e65b90a47d4bc13f99c5ab024ef7b00198395eb0
Binary files /dev/null and b/molscribe/__pycache__/interface.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/loss.cpython-310.pyc b/molscribe/__pycache__/loss.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e90bdd4bd9059040639920684546380d7700c2ea
Binary files /dev/null and b/molscribe/__pycache__/loss.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/model.cpython-310.pyc b/molscribe/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a0581d46fe9f1ef3f6bcb6611343bbfa39a15f1
Binary files /dev/null and b/molscribe/__pycache__/model.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/tokenizer.cpython-310.pyc b/molscribe/__pycache__/tokenizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70a934bca40695fe3cc8f0d145bd84454ffba861
Binary files /dev/null and b/molscribe/__pycache__/tokenizer.cpython-310.pyc differ
diff --git a/molscribe/__pycache__/utils.cpython-310.pyc b/molscribe/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4561b81c5b88e04e2dd7e3859b0ca9fbaee12fce
Binary files /dev/null and b/molscribe/__pycache__/utils.cpython-310.pyc differ
diff --git a/molscribe/augment.py b/molscribe/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..a80ebc505c4698238386902affd3377b7a7ea885
--- /dev/null
+++ b/molscribe/augment.py
@@ -0,0 +1,282 @@
+import albumentations as A
+from albumentations.augmentations.geometric.functional import safe_rotate_enlarged_img_size, _maybe_process_in_chunks, \
+ keypoint_rotate
+import cv2
+import math
+import random
+import numpy as np
+
+
+def safe_rotate(
+ img: np.ndarray,
+ angle: int = 0,
+ interpolation: int = cv2.INTER_LINEAR,
+ value: int = None,
+ border_mode: int = cv2.BORDER_REFLECT_101,
+):
+
+ old_rows, old_cols = img.shape[:2]
+
+ # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
+ image_center = (old_cols / 2, old_rows / 2)
+
+ # Rows and columns of the rotated image (not cropped)
+ new_rows, new_cols = safe_rotate_enlarged_img_size(angle=angle, rows=old_rows, cols=old_cols)
+
+ # Rotation Matrix
+ rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
+
+ # Shift the image to create padding
+ rotation_mat[0, 2] += new_cols / 2 - image_center[0]
+ rotation_mat[1, 2] += new_rows / 2 - image_center[1]
+
+ # CV2 Transformation function
+ warp_affine_fn = _maybe_process_in_chunks(
+ cv2.warpAffine,
+ M=rotation_mat,
+ dsize=(new_cols, new_rows),
+ flags=interpolation,
+ borderMode=border_mode,
+ borderValue=value,
+ )
+
+ # rotate image with the new bounds
+ rotated_img = warp_affine_fn(img)
+
+ return rotated_img
+
+
+def keypoint_safe_rotate(keypoint, angle, rows, cols):
+ old_rows = rows
+ old_cols = cols
+
+ # Rows and columns of the rotated image (not cropped)
+ new_rows, new_cols = safe_rotate_enlarged_img_size(angle=angle, rows=old_rows, cols=old_cols)
+
+ col_diff = (new_cols - old_cols) / 2
+ row_diff = (new_rows - old_rows) / 2
+
+ # Shift keypoint
+ shifted_keypoint = (int(keypoint[0] + col_diff), int(keypoint[1] + row_diff), keypoint[2], keypoint[3])
+
+ # Rotate keypoint
+ rotated_keypoint = keypoint_rotate(shifted_keypoint, angle, rows=new_rows, cols=new_cols)
+
+ return rotated_keypoint
+
+
+class SafeRotate(A.SafeRotate):
+
+ def __init__(
+ self,
+ limit=90,
+ interpolation=cv2.INTER_LINEAR,
+ border_mode=cv2.BORDER_REFLECT_101,
+ value=None,
+ mask_value=None,
+ always_apply=False,
+ p=0.5,
+ ):
+ super(SafeRotate, self).__init__(
+ limit=limit,
+ interpolation=interpolation,
+ border_mode=border_mode,
+ value=value,
+ mask_value=mask_value,
+ always_apply=always_apply,
+ p=p)
+
+ def apply(self, img, angle=0, interpolation=cv2.INTER_LINEAR, **params):
+ return safe_rotate(
+ img=img, value=self.value, angle=angle, interpolation=interpolation, border_mode=self.border_mode)
+
+ def apply_to_keypoint(self, keypoint, angle=0, **params):
+ return keypoint_safe_rotate(keypoint, angle=angle, rows=params["rows"], cols=params["cols"])
+
+
+class CropWhite(A.DualTransform):
+
+ def __init__(self, value=(255, 255, 255), pad=0, p=1.0):
+ super(CropWhite, self).__init__(p=p)
+ self.value = value
+ self.pad = pad
+ assert pad >= 0
+
+ def update_params(self, params, **kwargs):
+ super().update_params(params, **kwargs)
+ assert "image" in kwargs
+ img = kwargs["image"]
+ height, width, _ = img.shape
+ x = (img != self.value).sum(axis=2)
+ if x.sum() == 0:
+ return params
+ row_sum = x.sum(axis=1)
+ top = 0
+ while row_sum[top] == 0 and top+1 < height:
+ top += 1
+ bottom = height
+ while row_sum[bottom-1] == 0 and bottom-1 > top:
+ bottom -= 1
+ col_sum = x.sum(axis=0)
+ left = 0
+ while col_sum[left] == 0 and left+1 < width:
+ left += 1
+ right = width
+ while col_sum[right-1] == 0 and right-1 > left:
+ right -= 1
+ # crop_top = max(0, top - self.pad)
+ # crop_bottom = max(0, height - bottom - self.pad)
+ # crop_left = max(0, left - self.pad)
+ # crop_right = max(0, width - right - self.pad)
+ # params.update({"crop_top": crop_top, "crop_bottom": crop_bottom,
+ # "crop_left": crop_left, "crop_right": crop_right})
+ params.update({"crop_top": top, "crop_bottom": height - bottom,
+ "crop_left": left, "crop_right": width - right})
+ return params
+
+ def apply(self, img, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params):
+ height, width, _ = img.shape
+ img = img[crop_top:height - crop_bottom, crop_left:width - crop_right]
+ img = A.augmentations.pad_with_params(
+ img, self.pad, self.pad, self.pad, self.pad, border_mode=cv2.BORDER_CONSTANT, value=self.value)
+ return img
+
+ def apply_to_keypoint(self, keypoint, crop_top=0, crop_bottom=0, crop_left=0, crop_right=0, **params):
+ x, y, angle, scale = keypoint[:4]
+ return x - crop_left + self.pad, y - crop_top + self.pad, angle, scale
+
+ def get_transform_init_args_names(self):
+ return ('value', 'pad')
+
+
+class PadWhite(A.DualTransform):
+
+ def __init__(self, pad_ratio=0.2, p=0.5, value=(255, 255, 255)):
+ super(PadWhite, self).__init__(p=p)
+ self.pad_ratio = pad_ratio
+ self.value = value
+
+ def update_params(self, params, **kwargs):
+ super().update_params(params, **kwargs)
+ assert "image" in kwargs
+ img = kwargs["image"]
+ height, width, _ = img.shape
+ side = random.randrange(4)
+ if side == 0:
+ params['pad_top'] = int(height * self.pad_ratio * random.random())
+ elif side == 1:
+ params['pad_bottom'] = int(height * self.pad_ratio * random.random())
+ elif side == 2:
+ params['pad_left'] = int(width * self.pad_ratio * random.random())
+ elif side == 3:
+ params['pad_right'] = int(width * self.pad_ratio * random.random())
+ return params
+
+ def apply(self, img, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params):
+ height, width, _ = img.shape
+ img = A.augmentations.pad_with_params(
+ img, pad_top, pad_bottom, pad_left, pad_right, border_mode=cv2.BORDER_CONSTANT, value=self.value)
+ return img
+
+ def apply_to_keypoint(self, keypoint, pad_top=0, pad_bottom=0, pad_left=0, pad_right=0, **params):
+ x, y, angle, scale = keypoint[:4]
+ return x + pad_left, y + pad_top, angle, scale
+
+ def get_transform_init_args_names(self):
+ return ('value', 'pad_ratio')
+
+
+class SaltAndPepperNoise(A.DualTransform):
+
+ def __init__(self, num_dots, value=(0, 0, 0), p=0.5):
+ super().__init__(p)
+ self.num_dots = num_dots
+ self.value = value
+
+ def apply(self, img, **params):
+ height, width, _ = img.shape
+ num_dots = random.randrange(self.num_dots + 1)
+ for i in range(num_dots):
+ x = random.randrange(height)
+ y = random.randrange(width)
+ img[x, y] = self.value
+ return img
+
+ def apply_to_keypoint(self, keypoint, **params):
+ return keypoint
+
+ def get_transform_init_args_names(self):
+ return ('value', 'num_dots')
+
+class ResizePad(A.DualTransform):
+
+ def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, value=(255, 255, 255)):
+ super(ResizePad, self).__init__(always_apply=True)
+ self.height = height
+ self.width = width
+ self.interpolation = interpolation
+ self.value = value
+
+ def apply(self, img, interpolation=cv2.INTER_LINEAR, **params):
+ h, w, _ = img.shape
+ img = A.augmentations.geometric.functional.resize(
+ img,
+ height=min(h, self.height),
+ width=min(w, self.width),
+ interpolation=interpolation
+ )
+ h, w, _ = img.shape
+ pad_top = (self.height - h) // 2
+ pad_bottom = (self.height - h) - pad_top
+ pad_left = (self.width - w) // 2
+ pad_right = (self.width - w) - pad_left
+ img = A.augmentations.pad_with_params(
+ img,
+ pad_top,
+ pad_bottom,
+ pad_left,
+ pad_right,
+ border_mode=cv2.BORDER_CONSTANT,
+ value=self.value,
+ )
+ return img
+
+
+def normalized_grid_distortion(
+ img,
+ num_steps=10,
+ xsteps=(),
+ ysteps=(),
+ *args,
+ **kwargs
+):
+ height, width = img.shape[:2]
+
+ # compensate for smaller last steps in source image.
+ x_step = width // num_steps
+ last_x_step = min(width, ((num_steps + 1) * x_step)) - (num_steps * x_step)
+ xsteps[-1] *= last_x_step / x_step
+
+ y_step = height // num_steps
+ last_y_step = min(height, ((num_steps + 1) * y_step)) - (num_steps * y_step)
+ ysteps[-1] *= last_y_step / y_step
+
+ # now normalize such that distortion never leaves image bounds.
+ tx = width / math.floor(width / num_steps)
+ ty = height / math.floor(height / num_steps)
+ xsteps = np.array(xsteps) * (tx / np.sum(xsteps))
+ ysteps = np.array(ysteps) * (ty / np.sum(ysteps))
+
+ # do actual distortion.
+ return A.augmentations.functional.grid_distortion(img, num_steps, xsteps, ysteps, *args, **kwargs)
+
+
+class NormalizedGridDistortion(A.augmentations.transforms.GridDistortion):
+ def apply(self, img, stepsx=(), stepsy=(), interpolation=cv2.INTER_LINEAR, **params):
+ return normalized_grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode,
+ self.value)
+
+ def apply_to_mask(self, img, stepsx=(), stepsy=(), **params):
+ return normalized_grid_distortion(
+ img, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
+
diff --git a/molscribe/chemistry.py b/molscribe/chemistry.py
new file mode 100644
index 0000000000000000000000000000000000000000..145ecc2e73772796100aabc3bf6491f6d6129569
--- /dev/null
+++ b/molscribe/chemistry.py
@@ -0,0 +1,649 @@
+import copy
+import traceback
+import numpy as np
+import multiprocessing
+
+import rdkit
+import rdkit.Chem as Chem
+
+rdkit.RDLogger.DisableLog('rdApp.*')
+
+from SmilesPE.pretokenizer import atomwise_tokenizer
+
+from .constants import RGROUP_SYMBOLS, ABBREVIATIONS, VALENCES, FORMULA_REGEX
+
+
+def is_valid_mol(s, format_='atomtok'):
+ if format_ == 'atomtok':
+ mol = Chem.MolFromSmiles(s)
+ elif format_ == 'inchi':
+ if not s.startswith('InChI=1S'):
+ s = f"InChI=1S/{s}"
+ mol = Chem.MolFromInchi(s)
+ else:
+ raise NotImplemented
+ return mol is not None
+
+
+def _convert_smiles_to_inchi(smiles):
+ try:
+ mol = Chem.MolFromSmiles(smiles)
+ inchi = Chem.MolToInchi(mol)
+ except:
+ inchi = None
+ return inchi
+
+
+def convert_smiles_to_inchi(smiles_list, num_workers=16):
+ with multiprocessing.Pool(num_workers) as p:
+ inchi_list = p.map(_convert_smiles_to_inchi, smiles_list, chunksize=128)
+ n_success = sum([x is not None for x in inchi_list])
+ r_success = n_success / len(inchi_list)
+ inchi_list = [x if x else 'InChI=1S/H2O/h1H2' for x in inchi_list]
+ return inchi_list, r_success
+
+
+def merge_inchi(inchi1, inchi2):
+ replaced = 0
+ inchi1 = copy.deepcopy(inchi1)
+ for i in range(len(inchi1)):
+ if inchi1[i] == 'InChI=1S/H2O/h1H2':
+ inchi1[i] = inchi2[i]
+ replaced += 1
+ return inchi1, replaced
+
+
+def _get_num_atoms(smiles):
+ try:
+ return Chem.MolFromSmiles(smiles).GetNumAtoms()
+ except:
+ return 0
+
+
+def get_num_atoms(smiles, num_workers=16):
+ if type(smiles) is str:
+ return _get_num_atoms(smiles)
+ with multiprocessing.Pool(num_workers) as p:
+ num_atoms = p.map(_get_num_atoms, smiles)
+ return num_atoms
+
+
+def normalize_nodes(nodes, flip_y=True):
+ x, y = nodes[:, 0], nodes[:, 1]
+ minx, maxx = min(x), max(x)
+ miny, maxy = min(y), max(y)
+ x = (x - minx) / max(maxx - minx, 1e-6)
+ if flip_y:
+ y = (maxy - y) / max(maxy - miny, 1e-6)
+ else:
+ y = (y - miny) / max(maxy - miny, 1e-6)
+ return np.stack([x, y], axis=1)
+
+
+def _verify_chirality(mol, coords, symbols, edges, debug=False):
+ try:
+ n = mol.GetNumAtoms()
+ # Make a temp mol to find chiral centers
+ mol_tmp = mol.GetMol()
+ Chem.SanitizeMol(mol_tmp)
+
+ chiral_centers = Chem.FindMolChiralCenters(
+ mol_tmp, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False)
+ chiral_center_ids = [idx for idx, _ in chiral_centers] # List[Tuple[int, any]] -> List[int]
+
+ # correction to clear pre-condition violation (for some corner cases)
+ for bond in mol.GetBonds():
+ if bond.GetBondType() == Chem.BondType.SINGLE:
+ bond.SetBondDir(Chem.BondDir.NONE)
+
+ # Create conformer from 2D coordinate
+ conf = Chem.Conformer(n)
+ conf.Set3D(True)
+ for i, (x, y) in enumerate(coords):
+ conf.SetAtomPosition(i, (x, 1 - y, 0))
+ mol.AddConformer(conf)
+ Chem.SanitizeMol(mol)
+ Chem.AssignStereochemistryFrom3D(mol)
+ # NOTE: seems that only AssignStereochemistryFrom3D can handle double bond E/Z
+ # So we do this first, remove the conformer and add back the 2D conformer for chiral correction
+
+ mol.RemoveAllConformers()
+ conf = Chem.Conformer(n)
+ conf.Set3D(False)
+ for i, (x, y) in enumerate(coords):
+ conf.SetAtomPosition(i, (x, 1 - y, 0))
+ mol.AddConformer(conf)
+
+ # Magic, inferring chirality from coordinates and BondDir. DO NOT CHANGE.
+ Chem.SanitizeMol(mol)
+ Chem.AssignChiralTypesFromBondDirs(mol)
+ Chem.AssignStereochemistry(mol, force=True)
+
+ # Second loop to reset any wedge/dash bond to be starting from the chiral center)
+ for i in chiral_center_ids:
+ for j in range(n):
+ if edges[i][j] == 5:
+ # assert edges[j][i] == 6
+ mol.RemoveBond(i, j)
+ mol.AddBond(i, j, Chem.BondType.SINGLE)
+ mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINWEDGE)
+ elif edges[i][j] == 6:
+ # assert edges[j][i] == 5
+ mol.RemoveBond(i, j)
+ mol.AddBond(i, j, Chem.BondType.SINGLE)
+ mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINDASH)
+ Chem.AssignChiralTypesFromBondDirs(mol)
+ Chem.AssignStereochemistry(mol, force=True)
+
+ # reset chiral tags for non-carbon atom
+ for atom in mol.GetAtoms():
+ if atom.GetSymbol() != "C":
+ atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
+ mol = mol.GetMol()
+
+ except Exception as e:
+ if debug:
+ raise e
+ pass
+ return mol
+
+
+def _parse_tokens(tokens: list):
+ """
+ Parse tokens of condensed formula into list of pairs `(elt, num)`
+ where `num` is the multiplicity of the atom (or nested condensed formula) `elt`
+ Used by `_parse_formula`, which does the same thing but takes a formula in string form as input
+ """
+ elements = []
+ i = 0
+ j = 0
+ while i < len(tokens):
+ if tokens[i] == '(':
+ while j < len(tokens) and tokens[j] != ')':
+ j += 1
+ elt = _parse_tokens(tokens[i + 1:j])
+ else:
+ elt = tokens[i]
+ j += 1
+ if j < len(tokens) and tokens[j].isnumeric():
+ num = int(tokens[j])
+ j += 1
+ else:
+ num = 1
+ elements.append((elt, num))
+ i = j
+ return elements
+
+
+def _parse_formula(formula: str):
+ """
+ Parse condensed formula into list of pairs `(elt, num)`
+ where `num` is the subscript to the atom (or nested condensed formula) `elt`
+ Example: "C2H4O" -> [('C', 2), ('H', 4), ('O', 1)]
+ """
+ tokens = FORMULA_REGEX.findall(formula)
+ # if ''.join(tokens) != formula:
+ # tokens = FORMULA_REGEX_BACKUP.findall(formula)
+ return _parse_tokens(tokens)
+
+
+def _expand_carbon(elements: list):
+ """
+ Given list of pairs `(elt, num)`, output single list of all atoms in order,
+ expanding carbon sequences (CaXb where a > 1 and X is halogen) if necessary
+ Example: [('C', 2), ('H', 4), ('O', 1)] -> ['C', 'H', 'H', 'C', 'H', 'H', 'O'])
+ """
+ expanded = []
+ i = 0
+ while i < len(elements):
+ elt, num = elements[i]
+ # expand carbon sequence
+ if elt == 'C' and num > 1 and i + 1 < len(elements):
+ next_elt, next_num = elements[i + 1]
+ quotient, remainder = next_num // num, next_num % num
+ for _ in range(num):
+ expanded.append('C')
+ for _ in range(quotient):
+ expanded.append(next_elt)
+ for _ in range(remainder):
+ expanded.append(next_elt)
+ i += 2
+ # recurse if `elt` itself is a list (nested formula)
+ elif isinstance(elt, list):
+ new_elt = _expand_carbon(elt)
+ for _ in range(num):
+ expanded.append(new_elt)
+ i += 1
+ # simplest case: simply append `elt` `num` times
+ else:
+ for _ in range(num):
+ expanded.append(elt)
+ i += 1
+ return expanded
+
+
+def _expand_abbreviation(abbrev):
+ """
+ Expand abbreviation into its SMILES; also converts [Rn] to [n*]
+ Used in `_condensed_formula_list_to_smiles` when encountering abbrev. in condensed formula
+ """
+ if abbrev in ABBREVIATIONS:
+ return ABBREVIATIONS[abbrev].smiles
+ if abbrev in RGROUP_SYMBOLS or (abbrev[0] == 'R' and abbrev[1:].isdigit()):
+ if abbrev[1:].isdigit():
+ return f'[{abbrev[1:]}*]'
+ return '*'
+ return f'[{abbrev}]'
+
+
+def _get_bond_symb(bond_num):
+ """
+ Get SMILES symbol for a bond given bond order
+ Used in `_condensed_formula_list_to_smiles` while writing the SMILES string
+ """
+ if bond_num == 0:
+ return '.'
+ if bond_num == 1:
+ return ''
+ if bond_num == 2:
+ return '='
+ if bond_num == 3:
+ return '#'
+ return ''
+
+
+def _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond=None, direction=None):
+ """
+ Converts condensed formula (in the form of a list of symbols) to smiles
+ Input:
+ `formula_list`: e.g. ['C', 'H', 'H', 'N', ['C', 'H', 'H', 'H'], ['C', 'H', 'H', 'H']] for CH2N(CH3)2
+ `start_bond`: # bonds attached to beginning of formula
+ `end_bond`: # bonds attached to end of formula (deduce automatically if None)
+ `direction` (1, -1, or None): direction in which to process the list (1: left to right; -1: right to left; None: deduce automatically)
+ Returns:
+ `smiles`: smiles corresponding to input condensed formula
+ `bonds_left`: bonds remaining at the end of the formula (for connecting back to main molecule); should equal `end_bond` if specified
+ `num_trials`: number of trials
+ `success` (bool): whether conversion was successful
+ """
+ # `direction` not specified: try left to right; if fails, try right to left
+ if direction is None:
+ num_trials = 1
+ for dir_choice in [1, -1]:
+ smiles, bonds_left, trials, success = _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond, dir_choice)
+ num_trials += trials
+ if success:
+ return smiles, bonds_left, num_trials, success
+ return None, None, num_trials, False
+ assert direction == 1 or direction == -1
+
+ def dfs(smiles, bonds_left, cur_idx, add_idx):
+ """
+ `smiles`: SMILES string so far
+ `cur_idx`: index (in list `formula`) of current atom (i.e. atom to which subsequent atoms are being attached)
+ `cur_flat_idx`: index of current atom in list of atom tokens of SMILES so far
+ `bonds_left`: bonds remaining on current atom for subsequent atoms to be attached to
+ `add_idx`: index (in list `formula`) of atom to be attached to current atom
+ `add_flat_idx`: index of atom to be added in list of atom tokens of SMILES so far
+ Note: "atom" could refer to nested condensed formula (e.g. CH3 in CH2N(CH3)2)
+ """
+ num_trials = 1
+ # end of formula: return result
+ if (direction == 1 and add_idx == len(formula_list)) or (direction == -1 and add_idx == -1):
+ if end_bond is not None and end_bond != bonds_left:
+ return smiles, bonds_left, num_trials, False
+ return smiles, bonds_left, num_trials, True
+
+ # no more bonds but there are atoms remaining: conversion failed
+ if bonds_left <= 0:
+ return smiles, bonds_left, num_trials, False
+ to_add = formula_list[add_idx] # atom to be added to current atom
+
+ if isinstance(to_add, list): # "atom" added is a list (i.e. nested condensed formula): assume valence of 1
+ if bonds_left > 1:
+ # "atom" added does not use up remaining bonds of current atom
+ # get smiles of "atom" (which is itself a condensed formula)
+ add_str, val, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
+ if val > 0:
+ add_str = _get_bond_symb(val + 1) + add_str
+ num_trials += trials
+ if not success:
+ return smiles, bonds_left, num_trials, False
+ # put smiles of "atom" in parentheses and append to smiles; go to next atom to add to current atom
+ result = dfs(smiles + f'({add_str})', bonds_left - 1, cur_idx, add_idx + direction)
+ else:
+ # "atom" added uses up remaining bonds of current atom
+ # get smiles of "atom" and bonds left on it
+ add_str, bonds_left, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
+ num_trials += trials
+ if not success:
+ return smiles, bonds_left, num_trials, False
+ # append smiles of "atom" (without parentheses) to smiles; it becomes new current atom
+ result = dfs(smiles + add_str, bonds_left, add_idx, add_idx + direction)
+ smiles, bonds_left, trials, success = result
+ num_trials += trials
+ return smiles, bonds_left, num_trials, success
+
+ # atom added is a single symbol (as opposed to nested condensed formula)
+ for val in VALENCES.get(to_add, [1]): # try all possible valences of atom added
+ add_str = _expand_abbreviation(to_add) # expand to smiles if symbol is abbreviation
+ if bonds_left > val: # atom added does not use up remaining bonds of current atom; go to next atom to add to current atom
+ if cur_idx >= 0:
+ add_str = _get_bond_symb(val) + add_str
+ result = dfs(smiles + f'({add_str})', bonds_left - val, cur_idx, add_idx + direction)
+ else: # atom added uses up remaining bonds of current atom; it becomes new current atom
+ if cur_idx >= 0:
+ add_str = _get_bond_symb(bonds_left) + add_str
+ result = dfs(smiles + add_str, val - bonds_left, add_idx, add_idx + direction)
+ trials, success = result[2:]
+ num_trials += trials
+ if success:
+ return result[0], result[1], num_trials, success
+ if num_trials > 10000:
+ break
+ return smiles, bonds_left, num_trials, False
+
+ cur_idx = -1 if direction == 1 else len(formula_list)
+ add_idx = 0 if direction == 1 else len(formula_list) - 1
+ return dfs('', start_bond, cur_idx, add_idx)
+
+
+def get_smiles_from_symbol(symbol, mol, atom, bonds):
+ """
+ Convert symbol (abbrev. or condensed formula) to smiles
+ If condensed formula, determine parsing direction and num. bonds on each side using coordinates
+ """
+ print(symbol)
+ if symbol in ABBREVIATIONS:
+ return ABBREVIATIONS[symbol].smiles
+ if len(symbol) > 20:
+ return None
+
+ #mol_check = Chem.MolFromSmiles(symbol)
+ #if mol_check:
+ # print(symbol) # Print the symbol to debug
+ # return symbol
+
+ total_bonds = int(sum([bond.GetBondTypeAsDouble() for bond in bonds]))
+ formula_list = _expand_carbon(_parse_formula(symbol))
+ smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None)
+ if success:
+ mol_check = Chem.MolFromSmiles(smiles) # Check if the SMILES is valid
+ if mol_check:
+ print(f"smiles:{smiles}") # Print the symbol to debug
+ return smiles
+
+
+ mol_check = Chem.MolFromSmiles(symbol)
+ if mol_check:
+ print(f"symbol:{symbol}") # Print the symbol to debug
+ return symbol
+
+ return None
+
+
+def _replace_functional_group(smiles):
+ smiles = smiles.replace('', 'C')
+ for i, r in enumerate(RGROUP_SYMBOLS):
+ symbol = f'[{r}]'
+ if symbol in smiles:
+ if r[0] == 'R' and r[1:].isdigit():
+ smiles = smiles.replace(symbol, f'[{int(r[1:])}*]')
+ else:
+ smiles = smiles.replace(symbol, '*')
+ # For unknown tokens (i.e. rdkit cannot parse), replace them with [{isotope}*], where isotope is an identifier.
+ tokens = atomwise_tokenizer(smiles)
+ new_tokens = []
+ mappings = {} # isotope : symbol
+ isotope = 50
+ for token in tokens:
+ if token[0] == '[':
+ if token[1:-1] in ABBREVIATIONS or Chem.AtomFromSmiles(token) is None:
+ while f'[{isotope}*]' in smiles or f'[{isotope}*]' in new_tokens:
+ isotope += 1
+ placeholder = f'[{isotope}*]'
+ mappings[isotope] = token[1:-1]
+ new_tokens.append(placeholder)
+ continue
+ new_tokens.append(token)
+ smiles = ''.join(new_tokens)
+ return smiles, mappings
+
+
+def convert_smiles_to_mol(smiles):
+ if smiles is None or smiles == '':
+ return None
+ try:
+ mol = Chem.MolFromSmiles(smiles)
+ except:
+ return None
+ return mol
+
+
+BOND_TYPES = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE, 3: Chem.rdchem.BondType.TRIPLE}
+
+
+def _expand_functional_group(mol, mappings, debug=False):
+ def _need_expand(mol, mappings):
+ return any([len(Chem.GetAtomAlias(atom)) > 0 for atom in mol.GetAtoms()]) or len(mappings) > 0
+
+ if _need_expand(mol, mappings):
+ mol_w = Chem.RWMol(mol)
+ num_atoms = mol_w.GetNumAtoms()
+ for i, atom in enumerate(mol_w.GetAtoms()): # reset radical electrons
+ atom.SetNumRadicalElectrons(0)
+
+ atoms_to_remove = []
+ for i in range(num_atoms):
+ atom = mol_w.GetAtomWithIdx(i)
+ if atom.GetSymbol() == '*':
+ symbol = Chem.GetAtomAlias(atom)
+ isotope = atom.GetIsotope()
+ if isotope > 0 and isotope in mappings:
+ symbol = mappings[isotope]
+ if not (isinstance(symbol, str) and len(symbol) > 0):
+ continue
+ # rgroups do not need to be expanded
+ if symbol in RGROUP_SYMBOLS:
+ continue
+
+ bonds = atom.GetBonds()
+ sub_smiles = get_smiles_from_symbol(symbol, mol_w, atom, bonds)
+
+ # create mol object for abbreviation/condensed formula from its SMILES
+ mol_r = convert_smiles_to_mol(sub_smiles)
+
+ if mol_r is None:
+ # atom.SetAtomicNum(6)
+ atom.SetIsotope(0)
+ continue
+
+ # remove bonds connected to abbreviation/condensed formula
+ adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds]
+ for adjacent_idx in adjacent_indices:
+ mol_w.RemoveBond(i, adjacent_idx)
+
+ adjacent_atoms = [mol_w.GetAtomWithIdx(adjacent_idx) for adjacent_idx in adjacent_indices]
+ for adjacent_atom, bond in zip(adjacent_atoms, bonds):
+ adjacent_atom.SetNumRadicalElectrons(int(bond.GetBondTypeAsDouble()))
+
+ # get indices of atoms of main body that connect to substituent
+ bonding_atoms_w = adjacent_indices
+ # assume indices are concated after combine mol_w and mol_r
+ bonding_atoms_r = [mol_w.GetNumAtoms()]
+ for atm in mol_r.GetAtoms():
+ if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0:
+ bonding_atoms_r.append(mol_w.GetNumAtoms() + atm.GetIdx())
+
+ # combine main body and substituent into a single molecule object
+ combo = Chem.CombineMols(mol_w, mol_r)
+
+ # connect substituent to main body with bonds
+ mol_w = Chem.RWMol(combo)
+ # if len(bonding_atoms_r) == 1: # substituent uses one atom to bond to main body
+ for atm in bonding_atoms_w:
+ bond_order = mol_w.GetAtomWithIdx(atm).GetNumRadicalElectrons()
+ mol_w.AddBond(atm, bonding_atoms_r[0], order=BOND_TYPES[bond_order])
+
+ # reset radical electrons
+ for atm in bonding_atoms_w:
+ mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
+ for atm in bonding_atoms_r:
+ mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
+ atoms_to_remove.append(i)
+
+ # Remove atom in the end, otherwise the id will change
+ # Reverse the order and remove atoms with larger id first
+ atoms_to_remove.sort(reverse=True)
+ for i in atoms_to_remove:
+ mol_w.RemoveAtom(i)
+ smiles = Chem.MolToSmiles(mol_w)
+ mol = mol_w.GetMol()
+ else:
+ smiles = Chem.MolToSmiles(mol)
+ return smiles, mol
+
+
+def _convert_graph_to_smiles(coords, symbols, edges, image=None, debug=False):
+ mol = Chem.RWMol()
+ n = len(symbols)
+ ids = []
+ for i in range(n):
+ symbol = symbols[i]
+ if symbol[0] == '[':
+ symbol = symbol[1:-1]
+ if symbol in RGROUP_SYMBOLS:
+ atom = Chem.Atom("*")
+ if symbol[0] == 'R' and symbol[1:].isdigit():
+ atom.SetIsotope(int(symbol[1:]))
+ Chem.SetAtomAlias(atom, symbol)
+ elif symbol in ABBREVIATIONS:
+ atom = Chem.Atom("*")
+ Chem.SetAtomAlias(atom, symbol)
+ else:
+ try: # try to get SMILES of atom
+ atom = Chem.AtomFromSmiles(symbols[i])
+ atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
+ except: # otherwise, abbreviation or condensed formula
+ atom = Chem.Atom("*")
+ Chem.SetAtomAlias(atom, symbol)
+
+ if atom.GetSymbol() == '*':
+ atom.SetProp('molFileAlias', symbol)
+
+ idx = mol.AddAtom(atom)
+ assert idx == i
+ ids.append(idx)
+
+ for i in range(n):
+ for j in range(i + 1, n):
+ if edges[i][j] == 1:
+ mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
+ elif edges[i][j] == 2:
+ mol.AddBond(ids[i], ids[j], Chem.BondType.DOUBLE)
+ elif edges[i][j] == 3:
+ mol.AddBond(ids[i], ids[j], Chem.BondType.TRIPLE)
+ elif edges[i][j] == 4:
+ mol.AddBond(ids[i], ids[j], Chem.BondType.AROMATIC)
+ elif edges[i][j] == 5:
+ mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
+ mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINWEDGE)
+ elif edges[i][j] == 6:
+ mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
+ mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINDASH)
+
+ pred_smiles = ''
+
+ try:
+ # TODO: move to an util function
+ if image is not None:
+ height, width, _ = image.shape
+ ratio = width / height
+ coords = [[x * ratio * 10, y * 10] for x, y in coords]
+ mol = _verify_chirality(mol, coords, symbols, edges, debug)
+ # molblock is obtained before expanding func groups, otherwise the expanded group won't have coordinates.
+ # TODO: make sure molblock has the abbreviation information
+ pred_molblock = Chem.MolToMolBlock(mol)
+ pred_smiles, mol = _expand_functional_group(mol, {}, debug)
+ success = True
+ except Exception as e:
+ if debug:
+ print(traceback.format_exc())
+ pred_molblock = ''
+ success = False
+
+ if debug:
+ return pred_smiles, pred_molblock, mol, success
+ return pred_smiles, pred_molblock, success
+
+
+def convert_graph_to_smiles(coords, symbols, edges, images=None, num_workers=16):
+ with multiprocessing.Pool(num_workers) as p:
+ if images is None:
+ results = p.starmap(_convert_graph_to_smiles, zip(coords, symbols, edges), chunksize=128)
+ else:
+ results = p.starmap(_convert_graph_to_smiles, zip(coords, symbols, edges, images), chunksize=128)
+ smiles_list, molblock_list, success = zip(*results)
+ r_success = np.mean(success)
+ return smiles_list, molblock_list, r_success
+
+
+def _postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, debug=False):
+ if type(smiles) is not str or smiles == '':
+ return '', False
+ mol = None
+ pred_molblock = ''
+ try:
+ pred_smiles = smiles
+ pred_smiles, mappings = _replace_functional_group(pred_smiles)
+ if coords is not None and symbols is not None and edges is not None:
+ pred_smiles = pred_smiles.replace('@', '').replace('/', '').replace('\\', '')
+ mol = Chem.RWMol(Chem.MolFromSmiles(pred_smiles, sanitize=False))
+ mol = _verify_chirality(mol, coords, symbols, edges, debug)
+ else:
+ mol = Chem.MolFromSmiles(pred_smiles, sanitize=False)
+ # pred_smiles = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
+ if molblock:
+ pred_molblock = Chem.MolToMolBlock(mol)
+ pred_smiles, mol = _expand_functional_group(mol, mappings)
+ success = True
+ except Exception as e:
+ if debug:
+ print(traceback.format_exc())
+ pred_smiles = smiles
+ pred_molblock = ''
+ success = False
+ if debug:
+ return pred_smiles, pred_molblock, mol, success
+ return pred_smiles, pred_molblock, success
+
+
+def postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, num_workers=16):
+ with multiprocessing.Pool(num_workers) as p:
+ if coords is not None and symbols is not None and edges is not None:
+ results = p.starmap(_postprocess_smiles, zip(smiles, coords, symbols, edges), chunksize=128)
+ else:
+ results = p.map(_postprocess_smiles, smiles, chunksize=128)
+ smiles_list, molblock_list, success = zip(*results)
+ r_success = np.mean(success)
+ return smiles_list, molblock_list, r_success
+
+
+def _keep_main_molecule(smiles, debug=False):
+ try:
+ mol = Chem.MolFromSmiles(smiles)
+ frags = Chem.GetMolFrags(mol, asMols=True)
+ if len(frags) > 1:
+ num_atoms = [m.GetNumAtoms() for m in frags]
+ main_mol = frags[np.argmax(num_atoms)]
+ smiles = Chem.MolToSmiles(main_mol)
+ except Exception as e:
+ if debug:
+ print(traceback.format_exc())
+ return smiles
+
+
+def keep_main_molecule(smiles, num_workers=16):
+ with multiprocessing.Pool(num_workers) as p:
+ results = p.map(_keep_main_molecule, smiles, chunksize=128)
+ return results
diff --git a/molscribe/constants.py b/molscribe/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b6670e07b5ffc5c35741d66376fec705dbec7cc
--- /dev/null
+++ b/molscribe/constants.py
@@ -0,0 +1,130 @@
+from typing import List
+import re
+
+ORGANIC_SET = {'B', 'C', 'N', 'O', 'P', 'S', 'F', 'Cl', 'Br', 'I'}
+
+RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12', "R'",
+ 'Ra', 'Rb', 'Rc', 'Rd', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar']
+
+PLACEHOLDER_ATOMS = ["Lv", "Lu", "Nd", "Yb", "At", "Fm", "Er"]
+
+
+class Substitution(object):
+ '''Define common substitutions for chemical shorthand'''
+ def __init__(self, abbrvs, smarts, smiles, probability):
+ assert type(abbrvs) is list
+ self.abbrvs = abbrvs
+ self.smarts = smarts
+ self.smiles = smiles
+ self.probability = probability
+
+
+SUBSTITUTIONS: List[Substitution] = [
+ Substitution(['NO2', 'O2N'], '[N+](=O)[O-]', "[N+](=O)[O-]", 0.5),
+ Substitution(['OCOCH3'], '[#8]-[#6](=[#8])-[#6]', "[O]C(=O)C]", 0.5),
+ Substitution(['CHO', 'OHC'], '[CH1](=O)', "[CH1](=O)", 0.5),
+ Substitution(['CO2Et', 'COOEt', 'EtO2C'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5),
+
+ Substitution(['OAc'], '[OH0;X2]C(=O)[CH3]', "[O]C(=O)C", 0.7),
+ Substitution(['NHAc'], '[NH1;D2]C(=O)[CH3]', "[NH]C(=O)C", 0.7),
+ Substitution(['Ac'], 'C(=O)[CH3]', "[C](=O)C", 0.1),
+
+ Substitution(['OBz'], '[OH0;D2]C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[O]C(=O)c1ccccc1", 0.7), # Benzoyl
+ Substitution(['Bz'], 'C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)c1ccccc1", 0.2), # Benzoyl
+
+ Substitution(['OBn'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[O]Cc1ccccc1", 0.7), # Benzyl
+ Substitution(['Bn'], '[CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[CH2]c1ccccc1", 0.2), # Benzyl
+
+ Substitution(['NHBoc'], '[NH1;D2]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6),
+ Substitution(['NBoc'], '[NH0;D3]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6),
+ Substitution(['Boc'], 'C(=O)OC([CH3])([CH3])[CH3]', "[C](=O)OC(C)(C)C", 0.2),
+
+ Substitution(['Cbm'], 'C(=O)[NH2;D1]', "[C](=O)N", 0.2),
+ Substitution(['Cbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[C](=O)OCc1ccccc1", 0.4),
+ Substitution(['Cy'], '[CH1;X3]1[CH2][CH2][CH2][CH2][CH2]1', "[CH1]1CCCCC1", 0.3),
+ Substitution(['Fmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3',
+ "[C](=O)OCC1c(cccc2)c2c3c1cccc3", 0.6),
+ Substitution(['Mes'], '[cH0]1c([CH3])cc([CH3])cc([CH3])1', "[c]1c(C)cc(C)cc(C)1", 0.5),
+ Substitution(['OMs'], '[OH0;D2]S(=O)(=O)[CH3]', "[O]S(=O)(=O)C", 0.7),
+ Substitution(['Ms'], 'S(=O)(=O)[CH3]', "[S](=O)(=O)C", 0.2),
+ Substitution(['Ph'], '[cH0]1[cH][cH][cH1][cH][cH]1', "[c]1ccccc1", 0.5),
+ Substitution(['PMB'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[CH2]c1ccc(OC)cc1", 0.2),
+ Substitution(['Py'], '[cH0]1[n;+0][cH1][cH1][cH1][cH1]1', "[c]1ncccc1", 0.1),
+ Substitution(['SEM'], '[CH2;D2][CH2][Si]([CH3])([CH3])[CH3]', "[CH2]CSi(C)(C)C", 0.2),
+ Substitution(['Suc'], 'C(=O)[CH2][CH2]C(=O)[OH]', "[C](=O)CCC(=O)O", 0.2),
+ Substitution(['TBS'], '[Si]([CH3])([CH3])C([CH3])([CH3])[CH3]', "[Si](C)(C)C(C)(C)C", 0.5),
+ Substitution(['TBZ'], 'C(=S)[cH]1[cH][cH][cH1][cH][cH]1', "[C](=S)c1ccccc1", 0.2),
+ Substitution(['OTf'], '[OH0;D2]S(=O)(=O)C(F)(F)F', "[O]S(=O)(=O)C(F)(F)F", 0.7),
+ Substitution(['Tf'], 'S(=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.2),
+ Substitution(['TFA'], 'C(=O)C(F)(F)F', "[C](=O)C(F)(F)F", 0.3),
+ Substitution(['TMS'], '[Si]([CH3])([CH3])[CH3]', "[Si](C)(C)C", 0.5),
+ Substitution(['Ts'], 'S(=O)(=O)c1[cH1][cH1][cH0]([CH3])[cH1][cH1]1', "[S](=O)(=O)c1ccc(C)cc1", 0.6), # Tos
+
+ # Alkyl chains
+ Substitution(['OMe', 'MeO'], '[OH0;D2][CH3;D1]', "[O]C", 0.3),
+ Substitution(['SMe', 'MeS'], '[SH0;D2][CH3;D1]', "[S]C", 0.3),
+ Substitution(['NMe', 'MeN'], '[N;X3][CH3;D1]', "[NH]C", 0.3),
+ Substitution(['Me'], '[CH3;D1]', "[CH3]", 0.1),
+ Substitution(['OEt', 'EtO'], '[OH0;D2][CH2;D2][CH3]', "[O]CC", 0.5),
+ Substitution(['Et', 'C2H5'], '[CH2;D2][CH3]', "[CH2]C", 0.3),
+ Substitution(['Pr', 'nPr', 'n-Pr'], '[CH2;D2][CH2;D2][CH3]', "[CH2]CC", 0.3),
+ Substitution(['Bu', 'nBu', 'n-Bu'], '[CH2;D2][CH2;D2][CH2;D2][CH3]', "[CH2]CCC", 0.3),
+
+ # Branched
+ Substitution(['iPr', 'i-Pr'], '[CH1;D3]([CH3])[CH3]', "[CH1](C)C", 0.2),
+ Substitution(['iBu', 'i-Bu'], '[CH2;D2][CH1;D3]([CH3])[CH3]', "[CH2]C(C)C", 0.2),
+ Substitution(['OiBu'], '[OH0;D2][CH2;D2][CH1;D3]([CH3])[CH3]', "[O]CC(C)C", 0.2),
+ Substitution(['OtBu'], '[OH0;D2][CH0]([CH3])([CH3])[CH3]', "[O]C(C)(C)C", 0.6),
+ Substitution(['tBu', 't-Bu'], '[CH0]([CH3])([CH3])[CH3]', "[C](C)(C)C", 0.3),
+
+ # Other shorthands (MIGHT NOT WANT ALL OF THESE)
+ Substitution(['CF3', 'F3C'], '[CH0;D4](F)(F)F', "[C](F)(F)F", 0.5),
+ Substitution(['NCF3', 'F3CN'], '[N;X3][CH0;D4](F)(F)F', "[NH]C(F)(F)F", 0.5),
+ Substitution(['OCF3', 'F3CO'], '[OH0;X2][CH0;D4](F)(F)F', "[O]C(F)(F)F", 0.5),
+ Substitution(['CCl3'], '[CH0;D4](Cl)(Cl)Cl', "[C](Cl)(Cl)Cl", 0.5),
+ Substitution(['CO2H', 'HO2C', 'COOH'], 'C(=O)[OH]', "[C](=O)O", 0.5), # COOH
+ Substitution(['CN', 'NC'], 'C#[ND1]', "[C]#N", 0.5),
+ Substitution(['OCH3', 'H3CO'], '[OH0;D2][CH3]', "[O]C", 0.4),
+ Substitution(['SO3H'], 'S(=O)(=O)[OH]', "[S](=O)(=O)O", 0.4),
+ Substitution(['CH3O'], '[OH0;D2][CH3]', "[O]C", 0),
+ Substitution(['PhCH2CH2'], '[OH0;D2][CH3]', "C1=CC=CC=C1CC", 0),
+ Substitution(['SO2ToI','SO2Tol'], '[OH0;D2][CH3]', "CS(=O)(=O)C1=CC=CC=C1", 0),
+
+
+
+
+
+
+]
+
+ABBREVIATIONS = {abbrv: sub for sub in SUBSTITUTIONS for abbrv in sub.abbrvs}
+
+VALENCES = {
+ "H": [1], "Li": [1], "Be": [2], "B": [3], "C": [4], "N": [3, 5], "O": [2], "F": [1],
+ "Na": [1], "Mg": [2], "Al": [3], "Si": [4], "P": [5, 3], "S": [6, 2, 4], "Cl": [1], "K": [1], "Ca": [2],
+ "Br": [1], "I": [1]
+}
+
+ELEMENTS = [
+ "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
+ "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
+ "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
+ "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
+ "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
+ "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
+ "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
+ "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
+ "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
+ "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
+ "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
+ "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
+]
+
+COLORS = {
+ u'c': '0.0,0.75,0.75', u'b': '0.0,0.0,1.0', u'g': '0.0,0.5,0.0', u'y': '0.75,0.75,0',
+ u'k': '0.0,0.0,0.0', u'r': '1.0,0.0,0.0', u'm': '0.75,0,0.75'
+}
+
+# tokens of condensed formula
+FORMULA_REGEX = re.compile(
+ '(' + '|'.join(list(ABBREVIATIONS.keys())) + '|R[0-9]*|[A-Z][a-z]+|[A-Z]|[0-9]+|\(|\))')
diff --git a/molscribe/dataset.py b/molscribe/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b34c0dd36dc7b8561dc5f44ba37edfbe6e3fe4
--- /dev/null
+++ b/molscribe/dataset.py
@@ -0,0 +1,594 @@
+import os
+import cv2
+import time
+import random
+import re
+import string
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+
+from .indigo import Indigo
+from .indigo.renderer import IndigoRenderer
+
+from .augment import SafeRotate, CropWhite, PadWhite, SaltAndPepperNoise
+from .utils import FORMAT_INFO
+from .tokenizer import PAD_ID
+from .chemistry import get_num_atoms, normalize_nodes
+from .constants import RGROUP_SYMBOLS, SUBSTITUTIONS, ELEMENTS, COLORS
+
+cv2.setNumThreads(1)
+
+INDIGO_HYGROGEN_PROB = 0.2
+INDIGO_FUNCTIONAL_GROUP_PROB = 0.8
+INDIGO_CONDENSED_PROB = 0.5
+INDIGO_RGROUP_PROB = 0.5
+INDIGO_COMMENT_PROB = 0.3
+INDIGO_DEARMOTIZE_PROB = 0.8
+INDIGO_COLOR_PROB = 0.2
+
+
+def get_transforms(input_size, augment=True, rotate=True, debug=False):
+ trans_list = []
+ if augment and rotate:
+ trans_list.append(SafeRotate(limit=90, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255)))
+ trans_list.append(CropWhite(pad=5))
+ if augment:
+ trans_list += [
+ # NormalizedGridDistortion(num_steps=10, distort_limit=0.3),
+ A.CropAndPad(percent=[-0.01, 0.00], keep_size=False, p=0.5),
+ PadWhite(pad_ratio=0.4, p=0.2),
+ A.Downscale(scale_min=0.2, scale_max=0.5, interpolation=3),
+ A.Blur(),
+ A.GaussNoise(),
+ SaltAndPepperNoise(num_dots=20, p=0.5)
+ ]
+ trans_list.append(A.Resize(input_size, input_size))
+ if not debug:
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ trans_list += [
+ A.ToGray(p=1),
+ A.Normalize(mean=mean, std=std),
+ ToTensorV2(),
+ ]
+ return A.Compose(trans_list, keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))
+
+
+def add_functional_group(indigo, mol, debug=False):
+ if random.random() > INDIGO_FUNCTIONAL_GROUP_PROB:
+ return mol
+ # Delete functional group and add a pseudo atom with its abbrv
+ substitutions = [sub for sub in SUBSTITUTIONS]
+ random.shuffle(substitutions)
+ for sub in substitutions:
+ query = indigo.loadSmarts(sub.smarts)
+ matcher = indigo.substructureMatcher(mol)
+ matched_atoms_ids = set()
+ for match in matcher.iterateMatches(query):
+ if random.random() < sub.probability or debug:
+ atoms = []
+ atoms_ids = set()
+ for item in query.iterateAtoms():
+ atom = match.mapAtom(item)
+ atoms.append(atom)
+ atoms_ids.add(atom.index())
+ if len(matched_atoms_ids.intersection(atoms_ids)) > 0:
+ continue
+ abbrv = random.choice(sub.abbrvs)
+ superatom = mol.addAtom(abbrv)
+ for atom in atoms:
+ for nei in atom.iterateNeighbors():
+ if nei.index() not in atoms_ids:
+ if nei.symbol() == 'H':
+ # indigo won't match explicit hydrogen, so remove them explicitly
+ atoms_ids.add(nei.index())
+ else:
+ superatom.addBond(nei, nei.bond().bondOrder())
+ for id in atoms_ids:
+ mol.getAtom(id).remove()
+ matched_atoms_ids = matched_atoms_ids.union(atoms_ids)
+ return mol
+
+
+def add_explicit_hydrogen(indigo, mol):
+ atoms = []
+ for atom in mol.iterateAtoms():
+ try:
+ hs = atom.countImplicitHydrogens()
+ if hs > 0:
+ atoms.append((atom, hs))
+ except:
+ continue
+ if len(atoms) > 0 and random.random() < INDIGO_HYGROGEN_PROB:
+ atom, hs = random.choice(atoms)
+ for i in range(hs):
+ h = mol.addAtom('H')
+ h.addBond(atom, 1)
+ return mol
+
+
+def add_rgroup(indigo, mol, smiles):
+ atoms = []
+ for atom in mol.iterateAtoms():
+ try:
+ hs = atom.countImplicitHydrogens()
+ if hs > 0:
+ atoms.append(atom)
+ except:
+ continue
+ if len(atoms) > 0 and '*' not in smiles:
+ if random.random() < INDIGO_RGROUP_PROB:
+ atom_idx = random.choice(range(len(atoms)))
+ atom = atoms[atom_idx]
+ atoms.pop(atom_idx)
+ symbol = random.choice(RGROUP_SYMBOLS)
+ r = mol.addAtom(symbol)
+ r.addBond(atom, 1)
+ return mol
+
+
+def get_rand_symb():
+ symb = random.choice(ELEMENTS)
+ if random.random() < 0.1:
+ symb += random.choice(string.ascii_lowercase)
+ if random.random() < 0.1:
+ symb += random.choice(string.ascii_uppercase)
+ if random.random() < 0.1:
+ symb = f'({gen_rand_condensed()})'
+ return symb
+
+
+def get_rand_num():
+ if random.random() < 0.9:
+ if random.random() < 0.8:
+ return ''
+ else:
+ return str(random.randint(2, 9))
+ else:
+ return '1' + str(random.randint(2, 9))
+
+
+def gen_rand_condensed():
+ tokens = []
+ for i in range(5):
+ if i >= 1 and random.random() < 0.8:
+ break
+ tokens.append(get_rand_symb())
+ tokens.append(get_rand_num())
+ return ''.join(tokens)
+
+
+def add_rand_condensed(indigo, mol):
+ atoms = []
+ for atom in mol.iterateAtoms():
+ try:
+ hs = atom.countImplicitHydrogens()
+ if hs > 0:
+ atoms.append(atom)
+ except:
+ continue
+ if len(atoms) > 0 and random.random() < INDIGO_CONDENSED_PROB:
+ atom = random.choice(atoms)
+ symbol = gen_rand_condensed()
+ r = mol.addAtom(symbol)
+ r.addBond(atom, 1)
+ return mol
+
+
+def generate_output_smiles(indigo, mol):
+ # TODO: if using mol.canonicalSmiles(), explicit H will be removed
+ smiles = mol.smiles()
+ mol = indigo.loadMolecule(smiles)
+ if '*' in smiles:
+ part_a, part_b = smiles.split(' ', maxsplit=1)
+ part_b = re.search(r'\$.*\$', part_b).group(0)[1:-1]
+ symbols = [t for t in part_b.split(';') if len(t) > 0]
+ output = ''
+ cnt = 0
+ for i, c in enumerate(part_a):
+ if c != '*':
+ output += c
+ else:
+ output += f'[{symbols[cnt]}]'
+ cnt += 1
+ return mol, output
+ else:
+ if ' ' in smiles:
+ # special cases with extension
+ smiles = smiles.split(' ')[0]
+ return mol, smiles
+
+
+def add_comment(indigo):
+ if random.random() < INDIGO_COMMENT_PROB:
+ indigo.setOption('render-comment', str(random.randint(1, 20)) + random.choice(string.ascii_letters))
+ indigo.setOption('render-comment-font-size', random.randint(40, 60))
+ indigo.setOption('render-comment-alignment', random.choice([0, 0.5, 1]))
+ indigo.setOption('render-comment-position', random.choice(['top', 'bottom']))
+ indigo.setOption('render-comment-offset', random.randint(2, 30))
+
+
+def add_color(indigo, mol):
+ if random.random() < INDIGO_COLOR_PROB:
+ indigo.setOption('render-coloring', True)
+ if random.random() < INDIGO_COLOR_PROB:
+ indigo.setOption('render-base-color', random.choice(list(COLORS.values())))
+ if random.random() < INDIGO_COLOR_PROB:
+ if random.random() < 0.5:
+ indigo.setOption('render-highlight-color-enabled', True)
+ indigo.setOption('render-highlight-color', random.choice(list(COLORS.values())))
+ if random.random() < 0.5:
+ indigo.setOption('render-highlight-thickness-enabled', True)
+ for atom in mol.iterateAtoms():
+ if random.random() < 0.1:
+ atom.highlight()
+ return mol
+
+
+def get_graph(mol, image, shuffle_nodes=False, pseudo_coords=False):
+ mol.layout()
+ coords, symbols = [], []
+ index_map = {}
+ atoms = [atom for atom in mol.iterateAtoms()]
+ if shuffle_nodes:
+ random.shuffle(atoms)
+ for i, atom in enumerate(atoms):
+ if pseudo_coords:
+ x, y, z = atom.xyz()
+ else:
+ x, y = atom.coords()
+ coords.append([x, y])
+ symbols.append(atom.symbol())
+ index_map[atom.index()] = i
+ if pseudo_coords:
+ coords = normalize_nodes(np.array(coords))
+ h, w, _ = image.shape
+ coords[:, 0] = coords[:, 0] * w
+ coords[:, 1] = coords[:, 1] * h
+ n = len(symbols)
+ edges = np.zeros((n, n), dtype=int)
+ for bond in mol.iterateBonds():
+ s = index_map[bond.source().index()]
+ t = index_map[bond.destination().index()]
+ # 1/2/3/4 : single/double/triple/aromatic
+ edges[s, t] = bond.bondOrder()
+ edges[t, s] = bond.bondOrder()
+ if bond.bondStereo() in [5, 6]:
+ edges[s, t] = bond.bondStereo()
+ edges[t, s] = 11 - bond.bondStereo()
+ graph = {
+ 'coords': coords,
+ 'symbols': symbols,
+ 'edges': edges,
+ 'num_atoms': len(symbols)
+ }
+ return graph
+
+
+def generate_indigo_image(smiles, mol_augment=True, default_option=False, shuffle_nodes=False, pseudo_coords=False,
+ include_condensed=True, debug=False):
+ indigo = Indigo()
+ renderer = IndigoRenderer(indigo)
+ indigo.setOption('render-output-format', 'png')
+ indigo.setOption('render-background-color', '1,1,1')
+ indigo.setOption('render-stereo-style', 'none')
+ indigo.setOption('render-label-mode', 'hetero')
+ indigo.setOption('render-font-family', 'Arial')
+ if not default_option:
+ thickness = random.uniform(0.5, 2) # limit the sum of the following two parameters to be smaller than 4
+ indigo.setOption('render-relative-thickness', thickness)
+ indigo.setOption('render-bond-line-width', random.uniform(1, 4 - thickness))
+ if random.random() < 0.5:
+ indigo.setOption('render-font-family', random.choice(['Arial', 'Times', 'Courier', 'Helvetica']))
+ indigo.setOption('render-label-mode', random.choice(['hetero', 'terminal-hetero']))
+ indigo.setOption('render-implicit-hydrogens-visible', random.choice([True, False]))
+ if random.random() < 0.1:
+ indigo.setOption('render-stereo-style', 'old')
+ if random.random() < 0.2:
+ indigo.setOption('render-atom-ids-visible', True)
+
+ try:
+ mol = indigo.loadMolecule(smiles)
+ if mol_augment:
+ if random.random() < INDIGO_DEARMOTIZE_PROB:
+ mol.dearomatize()
+ else:
+ mol.aromatize()
+ smiles = mol.canonicalSmiles()
+ add_comment(indigo)
+ mol = add_explicit_hydrogen(indigo, mol)
+ mol = add_rgroup(indigo, mol, smiles)
+ if include_condensed:
+ mol = add_rand_condensed(indigo, mol)
+ mol = add_functional_group(indigo, mol, debug)
+ mol = add_color(indigo, mol)
+ mol, smiles = generate_output_smiles(indigo, mol)
+
+ buf = renderer.renderToBuffer(mol)
+ img = cv2.imdecode(np.asarray(bytearray(buf), dtype=np.uint8), 1) # decode buffer to image
+ # img = np.repeat(np.expand_dims(img, 2), 3, axis=2) # expand to RGB
+ graph = get_graph(mol, img, shuffle_nodes, pseudo_coords)
+ success = True
+ except Exception:
+ if debug:
+ raise Exception
+ img = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
+ graph = {}
+ success = False
+ return img, smiles, graph, success
+
+
+class TrainDataset(Dataset):
+ def __init__(self, args, df, tokenizer, split='train', dynamic_indigo=False):
+ super().__init__()
+ self.df = df
+ self.args = args
+ self.tokenizer = tokenizer
+ if 'file_path' in df.columns:
+ self.file_paths = df['file_path'].values
+ if not self.file_paths[0].startswith(args.data_path):
+ self.file_paths = [os.path.join(args.data_path, path) for path in df['file_path']]
+ self.smiles = df['SMILES'].values if 'SMILES' in df.columns else None
+ self.formats = args.formats
+ self.labelled = (split == 'train')
+ if self.labelled:
+ self.labels = {}
+ for format_ in self.formats:
+ if format_ in ['atomtok', 'inchi']:
+ field = FORMAT_INFO[format_]['name']
+ if field in df.columns:
+ self.labels[format_] = df[field].values
+ self.transform = get_transforms(args.input_size,
+ augment=(self.labelled and args.augment))
+ # self.fix_transform = A.Compose([A.Transpose(p=1), A.VerticalFlip(p=1)])
+ self.dynamic_indigo = (dynamic_indigo and split == 'train')
+ if self.labelled and not dynamic_indigo and args.coords_file is not None:
+ if args.coords_file == 'aux_file':
+ self.coords_df = df
+ self.pseudo_coords = True
+ else:
+ self.coords_df = pd.read_csv(args.coords_file)
+ self.pseudo_coords = False
+ else:
+ self.coords_df = None
+ self.pseudo_coords = args.pseudo_coords
+
+ def __len__(self):
+ return len(self.df)
+
+ def image_transform(self, image, coords=[], renormalize=False):
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # .astype(np.float32)
+ augmented = self.transform(image=image, keypoints=coords)
+ image = augmented['image']
+ if len(coords) > 0:
+ coords = np.array(augmented['keypoints'])
+ if renormalize:
+ coords = normalize_nodes(coords, flip_y=False)
+ else:
+ _, height, width = image.shape
+ coords[:, 0] = coords[:, 0] / width
+ coords[:, 1] = coords[:, 1] / height
+ coords = np.array(coords).clip(0, 1)
+ return image, coords
+ return image
+
+ def __getitem__(self, idx):
+ try:
+ return self.getitem(idx)
+ except Exception as e:
+ with open(os.path.join(self.args.save_path, f'error_dataset_{int(time.time())}.log'), 'w') as f:
+ f.write(str(e))
+ raise e
+
+ def getitem(self, idx):
+ ref = {}
+ if self.dynamic_indigo:
+ begin = time.time()
+ image, smiles, graph, success = generate_indigo_image(
+ self.smiles[idx], mol_augment=self.args.mol_augment, default_option=self.args.default_option,
+ shuffle_nodes=self.args.shuffle_nodes, pseudo_coords=self.pseudo_coords,
+ include_condensed=self.args.include_condensed)
+ # raw_image = image
+ end = time.time()
+ if idx < 30 and self.args.save_image:
+ path = os.path.join(self.args.save_path, 'images')
+ os.makedirs(path, exist_ok=True)
+ cv2.imwrite(os.path.join(path, f'{idx}.png'), image)
+ if not success:
+ return idx, None, {}
+ image, coords = self.image_transform(image, graph['coords'], renormalize=self.pseudo_coords)
+ graph['coords'] = coords
+ ref['time'] = end - begin
+ if 'atomtok' in self.formats:
+ max_len = FORMAT_INFO['atomtok']['max_len']
+ label = self.tokenizer['atomtok'].text_to_sequence(smiles, tokenized=False)
+ ref['atomtok'] = torch.LongTensor(label[:max_len])
+ if 'edges' in self.formats and 'atomtok_coords' not in self.formats and 'chartok_coords' not in self.formats:
+ ref['edges'] = torch.tensor(graph['edges'])
+ if 'atomtok_coords' in self.formats:
+ self._process_atomtok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
+ mask_ratio=self.args.mask_ratio)
+ if 'chartok_coords' in self.formats:
+ self._process_chartok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
+ mask_ratio=self.args.mask_ratio)
+ return idx, image, ref
+ else:
+ file_path = self.file_paths[idx]
+ image = cv2.imread(file_path)
+ if image is None:
+ image = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
+ print(file_path, 'not found!')
+ if self.coords_df is not None:
+ h, w, _ = image.shape
+ coords = np.array(eval(self.coords_df.loc[idx, 'node_coords']))
+ if self.pseudo_coords:
+ coords = normalize_nodes(coords)
+ coords[:, 0] = coords[:, 0] * w
+ coords[:, 1] = coords[:, 1] * h
+ image, coords = self.image_transform(image, coords, renormalize=self.pseudo_coords)
+ else:
+ image = self.image_transform(image)
+ coords = None
+ if self.labelled:
+ smiles = self.smiles[idx]
+ if 'atomtok' in self.formats:
+ max_len = FORMAT_INFO['atomtok']['max_len']
+ label = self.tokenizer['atomtok'].text_to_sequence(smiles, False)
+ ref['atomtok'] = torch.LongTensor(label[:max_len])
+ if 'atomtok_coords' in self.formats:
+ if coords is not None:
+ self._process_atomtok_coords(idx, ref, smiles, coords, mask_ratio=0)
+ else:
+ self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1)
+ if 'chartok_coords' in self.formats:
+ if coords is not None:
+ self._process_chartok_coords(idx, ref, smiles, coords, mask_ratio=0)
+ else:
+ self._process_chartok_coords(idx, ref, smiles, mask_ratio=1)
+ if self.args.predict_coords and ('atomtok_coords' in self.formats or 'chartok_coords' in self.formats):
+ smiles = self.smiles[idx]
+ if 'atomtok_coords' in self.formats:
+ self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1)
+ if 'chartok_coords' in self.formats:
+ self._process_chartok_coords(idx, ref, smiles, mask_ratio=1)
+ return idx, image, ref
+
+ def _process_atomtok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
+ max_len = FORMAT_INFO['atomtok_coords']['max_len']
+ tokenizer = self.tokenizer['atomtok_coords']
+ if smiles is None or type(smiles) is not str:
+ smiles = ""
+ label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio)
+ ref['atomtok_coords'] = torch.LongTensor(label[:max_len])
+ indices = [i for i in indices if i < max_len]
+ ref['atom_indices'] = torch.LongTensor(indices)
+ if tokenizer.continuous_coords:
+ if coords is not None:
+ ref['coords'] = torch.tensor(coords)
+ else:
+ ref['coords'] = torch.ones(len(indices), 2) * -1.
+ if edges is not None:
+ ref['edges'] = torch.tensor(edges)[:len(indices), :len(indices)]
+ else:
+ if 'edges' in self.df.columns:
+ edge_list = eval(self.df.loc[idx, 'edges'])
+ n = len(indices)
+ edges = torch.zeros((n, n), dtype=torch.long)
+ for u, v, t in edge_list:
+ if u < n and v < n:
+ if t <= 4:
+ edges[u, v] = t
+ edges[v, u] = t
+ else:
+ edges[u, v] = t
+ edges[v, u] = 11 - t
+ ref['edges'] = edges
+ else:
+ ref['edges'] = torch.ones(len(indices), len(indices), dtype=torch.long) * (-100)
+
+ def _process_chartok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
+ max_len = FORMAT_INFO['chartok_coords']['max_len']
+ tokenizer = self.tokenizer['chartok_coords']
+ if smiles is None or type(smiles) is not str:
+ smiles = ""
+ label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio)
+ ref['chartok_coords'] = torch.LongTensor(label[:max_len])
+ indices = [i for i in indices if i < max_len]
+ ref['atom_indices'] = torch.LongTensor(indices)
+ if tokenizer.continuous_coords:
+ if coords is not None:
+ ref['coords'] = torch.tensor(coords)
+ else:
+ ref['coords'] = torch.ones(len(indices), 2) * -1.
+ if edges is not None:
+ ref['edges'] = torch.tensor(edges)[:len(indices), :len(indices)]
+ else:
+ if 'edges' in self.df.columns:
+ edge_list = eval(self.df.loc[idx, 'edges'])
+ n = len(indices)
+ edges = torch.zeros((n, n), dtype=torch.long)
+ for u, v, t in edge_list:
+ if u < n and v < n:
+ if t <= 4:
+ edges[u, v] = t
+ edges[v, u] = t
+ else:
+ edges[u, v] = t
+ edges[v, u] = 11 - t
+ ref['edges'] = edges
+ else:
+ ref['edges'] = torch.ones(len(indices), len(indices), dtype=torch.long) * (-100)
+
+
+class AuxTrainDataset(Dataset):
+
+ def __init__(self, args, train_df, aux_df, tokenizer):
+ super().__init__()
+ self.train_dataset = TrainDataset(args, train_df, tokenizer, dynamic_indigo=args.dynamic_indigo)
+ self.aux_dataset = TrainDataset(args, aux_df, tokenizer, dynamic_indigo=False)
+
+ def __len__(self):
+ return len(self.train_dataset) + len(self.aux_dataset)
+
+ def __getitem__(self, idx):
+ if idx < len(self.train_dataset):
+ return self.train_dataset[idx]
+ else:
+ return self.aux_dataset[idx - len(self.train_dataset)]
+
+
+def pad_images(imgs):
+ # B, C, H, W
+ max_shape = [0, 0]
+ for img in imgs:
+ for i in range(len(max_shape)):
+ max_shape[i] = max(max_shape[i], img.shape[-1 - i])
+ stack = []
+ for img in imgs:
+ pad = []
+ for i in range(len(max_shape)):
+ pad = pad + [0, max_shape[i] - img.shape[-1 - i]]
+ stack.append(F.pad(img, pad, value=0))
+ return torch.stack(stack)
+
+
+def bms_collate(batch):
+ ids = []
+ imgs = []
+ batch = [ex for ex in batch if ex[1] is not None]
+ formats = list(batch[0][2].keys())
+ seq_formats = [k for k in formats if
+ k in ['atomtok', 'inchi', 'nodes', 'atomtok_coords', 'chartok_coords', 'atom_indices']]
+ refs = {key: [[], []] for key in seq_formats}
+ for ex in batch:
+ ids.append(ex[0])
+ imgs.append(ex[1])
+ ref = ex[2]
+ for key in seq_formats:
+ refs[key][0].append(ref[key])
+ refs[key][1].append(torch.LongTensor([len(ref[key])]))
+ # Sequence
+ for key in seq_formats:
+ # this padding should work for atomtok_with_coords too, each of which has shape (length, 4)
+ refs[key][0] = pad_sequence(refs[key][0], batch_first=True, padding_value=PAD_ID)
+ refs[key][1] = torch.stack(refs[key][1]).reshape(-1, 1)
+ # Time
+ # if 'time' in formats:
+ # refs['time'] = [ex[2]['time'] for ex in batch]
+ # Coords
+ if 'coords' in formats:
+ refs['coords'] = pad_sequence([ex[2]['coords'] for ex in batch], batch_first=True, padding_value=-1.)
+ # Edges
+ if 'edges' in formats:
+ edges_list = [ex[2]['edges'] for ex in batch]
+ max_len = max([len(edges) for edges in edges_list])
+ refs['edges'] = torch.stack(
+ [F.pad(edges, (0, max_len - len(edges), 0, max_len - len(edges)), value=-100) for edges in edges_list],
+ dim=0)
+ return ids, pad_images(imgs), refs
diff --git a/molscribe/evaluate.py b/molscribe/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff28829c50df4746f2398ae30acdeecd00f49508
--- /dev/null
+++ b/molscribe/evaluate.py
@@ -0,0 +1,79 @@
+import numpy as np
+import multiprocessing
+
+import rdkit
+import rdkit.Chem as Chem
+rdkit.RDLogger.DisableLog('rdApp.*')
+from SmilesPE.pretokenizer import atomwise_tokenizer
+
+
+def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True):
+ if type(smiles) is not str or smiles == '':
+ return '', False
+ if ignore_cistrans:
+ smiles = smiles.replace('/', '').replace('\\', '')
+ if replace_rgroup:
+ tokens = atomwise_tokenizer(smiles)
+ for j, token in enumerate(tokens):
+ if token[0] == '[' and token[-1] == ']':
+ symbol = token[1:-1]
+ if symbol[0] == 'R' and symbol[1:].isdigit():
+ tokens[j] = f'[{symbol[1:]}*]'
+ elif Chem.AtomFromSmiles(token) is None:
+ tokens[j] = '*'
+ smiles = ''.join(tokens)
+ try:
+ canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral))
+ success = True
+ except:
+ canon_smiles = smiles
+ success = False
+ return canon_smiles, success
+
+
+def convert_smiles_to_canonsmiles(
+ smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16):
+ with multiprocessing.Pool(num_workers) as p:
+ results = p.starmap(canonicalize_smiles,
+ [(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list],
+ chunksize=128)
+ canon_smiles, success = zip(*results)
+ return list(canon_smiles), np.mean(success)
+
+
+class SmilesEvaluator(object):
+
+ def __init__(self, gold_smiles, num_workers=16):
+ self.gold_smiles = gold_smiles
+ self.gold_canon_smiles, self.gold_valid = convert_smiles_to_canonsmiles(gold_smiles, num_workers=num_workers)
+ self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(gold_smiles,
+ ignore_chiral=True, num_workers=num_workers)
+ self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(gold_smiles,
+ ignore_cistrans=True, num_workers=num_workers)
+ self.gold_canon_smiles = self._replace_empty(self.gold_canon_smiles)
+ self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral)
+ self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans)
+
+ def _replace_empty(self, smiles_list):
+ """Replace empty SMILES in the gold, otherwise it will be considered correct if both pred and gold is empty."""
+ return [smiles if smiles is not None and type(smiles) is str and smiles != "" else ""
+ for smiles in smiles_list]
+
+ def evaluate(self, pred_smiles):
+ results = {}
+ results['gold_valid'] = self.gold_valid
+ # Canon SMILES
+ pred_canon_smiles, pred_valid = convert_smiles_to_canonsmiles(pred_smiles)
+ results['canon_smiles_em'] = (np.array(self.gold_canon_smiles) == np.array(pred_canon_smiles)).mean()
+ results['pred_valid'] = pred_valid
+ # Ignore chirality (Graph exact match)
+ pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_chiral=True)
+ results['graph'] = (np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral)).mean()
+ # Ignore double bond cis/trans
+ pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_cistrans=True)
+ results['canon_smiles'] = (np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans)).mean()
+ # Evaluate on molecules with chiral centers
+ chiral = np.array([[g, p] for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans) if '@' in g])
+ results['chiral_ratio'] = len(chiral) / len(self.gold_smiles)
+ results['chiral'] = (chiral[:, 0] == chiral[:, 1]).mean() if len(chiral) > 0 else -1
+ return results
diff --git a/molscribe/indigo/__init__.py b/molscribe/indigo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fbf42753036b29b9fc5a726d72db3cecb588c20
--- /dev/null
+++ b/molscribe/indigo/__init__.py
@@ -0,0 +1,4164 @@
+#
+#
+# Copyright (C) from 2009 to Present EPAM Systems.
+#
+# This file is part of Indigo toolkit.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import platform
+import sys
+import warnings
+from array import array
+from ctypes import (CDLL, POINTER, RTLD_GLOBAL, c_byte, c_char_p, c_double,
+ c_float, c_int, c_ulonglong, pointer)
+
+DECODE_ENCODING = "utf-8"
+ENCODE_ENCODING = "utf-8"
+
+
+class IndigoException(Exception):
+ def __init__(self, value):
+ if sys.version_info > (3, 0) and not isinstance(value, str):
+ self.value = value.decode(DECODE_ENCODING)
+ else:
+ self.value = value
+
+ def __str__(self):
+ return self.value
+
+
+class IndigoObject(object):
+ """Docstring for class IndigoObject."""
+
+ def __init__(self, dispatcher, id, parent=None):
+ self.id = id
+ self.dispatcher = dispatcher
+ self.parent = parent
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.dispatcher._setSessionId()
+ self.dispatcher._lib.indigoClose(self.id)
+
+ def __del__(self):
+ self.dispose()
+
+ def dispose(self):
+ if self.id >= 0:
+ if getattr(Indigo, "_lib", None) is not None:
+ self.dispatcher._setSessionId()
+ Indigo._lib.indigoFree(self.id)
+ self.id = -1
+
+ def __iter__(self):
+ return self
+
+ def _next(self):
+ self.dispatcher._setSessionId()
+ newobj = self.dispatcher._checkResult(Indigo._lib.indigoNext(self.id))
+ if newobj == 0:
+ return None
+ else:
+ return self.dispatcher.IndigoObject(self.dispatcher, newobj, self)
+
+ def __next__(self):
+ obj = self._next()
+ if obj == None:
+ raise StopIteration
+ return obj
+
+ def next(self):
+ return self.__next__()
+
+ def oneBitsList(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoOneBitsList(self.id)
+ )
+
+ def mdlct(self):
+ buf = self.dispatcher.writeBuffer()
+ self.dispatcher._setSessionId()
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoSaveMDLCT(self.id, buf.id)
+ )
+ return buf.toBuffer()
+
+ def xyz(self):
+ self.dispatcher._setSessionId()
+ xyz = Indigo._lib.indigoXYZ(self.id)
+ if xyz is None:
+ raise IndigoException(Indigo._lib.indigoGetLastError())
+ return [xyz[0], xyz[1], xyz[2]]
+
+ def coords(self):
+ self.dispatcher._setSessionId()
+ xyz = Indigo._lib.indigoCoords(self.id)
+ if xyz is None:
+ raise IndigoException(Indigo._lib.indigoGetLastError())
+ return [xyz[0], xyz[1]]
+
+ def alignAtoms(self, atom_ids, desired_xyz):
+ if len(atom_ids) * 3 != len(desired_xyz):
+ raise IndigoException(
+ "alignAtoms(): desired_xyz[] must be exactly 3 times bigger than atom_ids[]"
+ )
+ atoms = (c_int * len(atom_ids))()
+ for i in range(len(atoms)):
+ atoms[i] = atom_ids[i]
+ xyz = (c_float * len(desired_xyz))()
+ for i in range(len(desired_xyz)):
+ xyz[i] = desired_xyz[i]
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultFloat(
+ self.dispatcher._lib.indigoAlignAtoms(
+ self.id, len(atoms), atoms, xyz
+ )
+ )
+
+ def addStereocenter(self, type, v1, v2, v3, v4=-1):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddStereocenter(self.id, type, v1, v2, v3, v4)
+ )
+
+ def clone(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(Indigo._lib.indigoClone(self.id)),
+ )
+
+ def check(self, props=""):
+ if props is None:
+ props = ""
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoCheck(self.id, props.encode(ENCODE_ENCODING))
+ )
+
+ def close(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoClose(self.id))
+
+ def hasNext(self):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(Indigo._lib.indigoHasNext(self.id))
+ )
+
+ def index(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoIndex(self.id))
+
+ def remove(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoRemove(self.id))
+
+ def saveMolfile(self, filename):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSaveMolfileToFile(
+ self.id, filename.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def molfile(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoMolfile(self.id)
+ )
+
+ def saveCml(self, filename):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSaveCmlToFile(
+ self.id, filename.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def cml(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoCml(self.id)
+ )
+
+ def saveCdxml(self, filename):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSaveCdxmlToFile(
+ self.id, filename.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def cdxml(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoCdxml(self.id)
+ )
+
+ def json(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoJson(self.id)
+ )
+
+ def saveMDLCT(self, output):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSaveMDLCT(self.id, output.id)
+ )
+
+ def addReactant(self, molecule):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddReactant(self.id, molecule.id)
+ )
+
+ def addProduct(self, molecule):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddProduct(self.id, molecule.id)
+ )
+
+ def addCatalyst(self, molecule):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddCatalyst(self.id, molecule.id)
+ )
+
+ def countReactants(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountReactants(self.id)
+ )
+
+ def countProducts(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountProducts(self.id)
+ )
+
+ def countCatalysts(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountCatalysts(self.id)
+ )
+
+ def countMolecules(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountMolecules(self.id)
+ )
+
+ def getMolecule(self, index):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoGetMolecule(self.id, index)
+ ),
+ )
+
+ def iterateReactants(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateReactants(self.id)
+ ),
+ )
+
+ def iterateProducts(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateProducts(self.id)
+ ),
+ )
+
+ def iterateCatalysts(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateCatalysts(self.id)
+ ),
+ )
+
+ def iterateMolecules(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateMolecules(self.id)
+ ),
+ )
+
+ def saveRxnfile(self, filename):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSaveRxnfileToFile(
+ self.id, filename.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def rxnfile(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoRxnfile(self.id)
+ )
+
+ def optimize(self, options=""):
+ if options is None:
+ options = ""
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoOptimize(
+ self.id, options.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def normalize(self, options=""):
+ if options is None:
+ options = ""
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoNormalize(
+ self.id, options.encode(ENCODE_ENCODING)
+ )
+ )
+ )
+
+ def standardize(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoStandardize(self.id)
+ )
+
+ def ionize(self, pH, pH_toll):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoIonize(self.id, pH, pH_toll)
+ )
+
+ def getAcidPkaValue(self, atom, level, min_level):
+ self.dispatcher._setSessionId()
+ result = self.dispatcher._checkResultPtr(
+ Indigo._lib.indigoGetAcidPkaValue(
+ self.id, atom.id, level, min_level
+ )
+ )
+ return result[0]
+
+ def getBasicPkaValue(self, atom, level, min_level):
+ self.dispatcher._setSessionId()
+ result = self.dispatcher._checkResultPtr(
+ Indigo._lib.indigoGetBasicPkaValue(
+ self.id, atom.id, level, min_level
+ )
+ )
+ return result[0]
+
+ def automap(self, mode=""):
+ if mode is None:
+ mode = ""
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAutomap(self.id, mode.encode(ENCODE_ENCODING))
+ )
+
+ def atomMappingNumber(self, reaction_atom):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetAtomMappingNumber(self.id, reaction_atom.id)
+ )
+
+ def setAtomMappingNumber(self, reaction_atom, number):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetAtomMappingNumber(
+ self.id, reaction_atom.id, number
+ )
+ )
+
+ def reactingCenter(self, reaction_bond):
+ value = c_int()
+ self.dispatcher._setSessionId()
+ res = self.dispatcher._checkResult(
+ Indigo._lib.indigoGetReactingCenter(
+ self.id, reaction_bond.id, pointer(value)
+ )
+ )
+ if res == 0:
+ return None
+ return value.value
+
+ def setReactingCenter(self, reaction_bond, rc):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetReactingCenter(self.id, reaction_bond.id, rc)
+ )
+
+ def clearAAM(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoClearAAM(self.id)
+ )
+
+ def correctReactingCenters(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCorrectReactingCenters(self.id)
+ )
+
+ def iterateAtoms(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateAtoms(self.id)
+ ),
+ )
+
+ def iteratePseudoatoms(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIteratePseudoatoms(self.id)
+ ),
+ )
+
+ def iterateRSites(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateRSites(self.id)
+ ),
+ )
+
+ def iterateStereocenters(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateStereocenters(self.id)
+ ),
+ )
+
+ def iterateAlleneCenters(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateAlleneCenters(self.id)
+ ),
+ )
+
+ def iterateRGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateRGroups(self.id)
+ ),
+ )
+
+ def countRGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountRGroups(self.id)
+ )
+
+ def isPseudoatom(self):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIsPseudoatom(self.id)
+ )
+ )
+
+ def isRSite(self):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(Indigo._lib.indigoIsRSite(self.id))
+ )
+
+ def isTemplateAtom(self):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIsTemplateAtom(self.id)
+ )
+ )
+
+ def stereocenterType(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoStereocenterType(self.id)
+ )
+
+ def stereocenterGroup(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoStereocenterGroup(self.id)
+ )
+
+ def setStereocenterGroup(self, group):
+ self.dispatcher._setSessionId()
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoSetStereocenterGroup(self.id, group)
+ )
+
+ def changeStereocenterType(self, type):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoChangeStereocenterType(self.id, type)
+ )
+
+ def validateChirality(self):
+ self.dispatcher._setSessionId()
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoValidateChirality(self.id)
+ )
+
+ def singleAllowedRGroup(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSingleAllowedRGroup(self.id)
+ )
+
+ def iterateRGroupFragments(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateRGroupFragments(self.id)
+ ),
+ )
+
+ def countAttachmentPoints(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountAttachmentPoints(self.id)
+ )
+
+ def iterateAttachmentPoints(self, order):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateAttachmentPoints(self.id, order)
+ ),
+ )
+
+ def symbol(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoSymbol(self.id)
+ )
+
+ def degree(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoDegree(self.id))
+
+ def charge(self):
+ value = c_int()
+ self.dispatcher._setSessionId()
+ res = self.dispatcher._checkResult(
+ Indigo._lib.indigoGetCharge(self.id, pointer(value))
+ )
+ if res == 0:
+ return None
+ return value.value
+
+ def getExplicitValence(self):
+ value = c_int()
+ self.dispatcher._setSessionId()
+ res = self.dispatcher._checkResult(
+ Indigo._lib.indigoGetExplicitValence(self.id, pointer(value))
+ )
+ if res == 0:
+ return None
+ return value.value
+
+ def setExplicitValence(self, valence):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetExplicitValence(self.id, valence)
+ )
+
+ def radicalElectrons(self):
+ value = c_int()
+ self.dispatcher._setSessionId()
+ res = self.dispatcher._checkResult(
+ Indigo._lib.indigoGetRadicalElectrons(self.id, pointer(value))
+ )
+ if res == 0:
+ return None
+ return value.value
+
+ def radical(self):
+ value = c_int()
+ self.dispatcher._setSessionId()
+ res = self.dispatcher._checkResult(
+ Indigo._lib.indigoGetRadical(self.id, pointer(value))
+ )
+ if res == 0:
+ return None
+ return value.value
+
+ def setRadical(self, radical):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetRadical(self.id, radical)
+ )
+
+ def atomicNumber(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAtomicNumber(self.id)
+ )
+
+ def isotope(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoIsotope(self.id))
+
+ def valence(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoValence(self.id))
+
+ def checkValence(self):
+
+ """
+ ::
+
+ Since version 1.3.0
+ """
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCheckValence(self.id)
+ )
+
+ def checkQuery(self):
+ """
+ ::
+
+ Since version 1.3.0
+ """
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCheckQuery(self.id)
+ )
+
+ def checkRGroups(self):
+ """
+ ::
+
+ Since version 1.3.0
+ """
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCheckRGroups(self.id)
+ )
+
+ def checkChirality(self):
+
+ """
+ ::
+
+ Since version 1.3.0
+ """
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCheckChirality(self.id)
+ )
+
+ def check3DStereo(self):
+
+ """
+ ::
+
+ Since version 1.3.0
+ """
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCheck3DStereo(self.id)
+ )
+
+ def checkStereo(self):
+
+ """
+ ::
+
+ Since version 1.3.0
+ """
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCheckStereo(self.id)
+ )
+
+ def countHydrogens(self):
+ value = c_int()
+ self.dispatcher._setSessionId()
+ res = self.dispatcher._checkResult(
+ Indigo._lib.indigoCountHydrogens(self.id, pointer(value))
+ )
+ if res == 0:
+ return None
+ return value.value
+
+ def countImplicitHydrogens(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountImplicitHydrogens(self.id)
+ )
+
+ def setXYZ(self, x, y, z):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetXYZ(self.id, x, y, z)
+ )
+
+ def countSuperatoms(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountSuperatoms(self.id)
+ )
+
+ def countDataSGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountDataSGroups(self.id)
+ )
+
+ def countRepeatingUnits(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountRepeatingUnits(self.id)
+ )
+
+ def countMultipleGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountMultipleGroups(self.id)
+ )
+
+ def countGenericSGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountGenericSGroups(self.id)
+ )
+
+ def iterateDataSGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateDataSGroups(self.id)
+ ),
+ )
+
+ def iterateSuperatoms(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateSuperatoms(self.id)
+ ),
+ )
+
+ def iterateGenericSGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateGenericSGroups(self.id)
+ ),
+ )
+
+ def iterateRepeatingUnits(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateRepeatingUnits(self.id)
+ ),
+ )
+
+ def iterateMultipleGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateMultipleGroups(self.id)
+ ),
+ )
+
+ def iterateSGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateSGroups(self.id)
+ ),
+ )
+
+ def iterateTGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateTGroups(self.id)
+ ),
+ )
+
+ def getSuperatom(self, index):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSuperatom(self.id, index)
+ ),
+ )
+
+ def getDataSGroup(self, index):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoGetDataSGroup(self.id, index)
+ ),
+ )
+
+ def getGenericSGroup(self, index):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoGetGenericSGroup(self.id, index)
+ ),
+ )
+
+ def getMultipleGroup(self, index):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoGetMultipleGroup(self.id, index)
+ ),
+ )
+
+ def getRepeatingUnit(self, index):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoGetRepeatingUnit(self.id, index)
+ ),
+ )
+
+ def description(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoDescription(self.id)
+ )
+
+ def data(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoData(self.id)
+ )
+
+ def addDataSGroup(self, atoms, bonds, description, data):
+ arr2 = (c_int * len(atoms))()
+ for i in range(len(atoms)):
+ arr2[i] = atoms[i]
+ arr4 = (c_int * len(bonds))()
+ for i in range(len(bonds)):
+ arr4[i] = bonds[i]
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoAddDataSGroup(
+ self.id,
+ len(arr2),
+ arr2,
+ len(arr4),
+ arr4,
+ description.encode(ENCODE_ENCODING),
+ data.encode(ENCODE_ENCODING),
+ )
+ ),
+ )
+
+ def addSuperatom(self, atoms, name):
+ arr2 = (c_int * len(atoms))()
+ for i in range(len(atoms)):
+ arr2[i] = atoms[i]
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoAddSuperatom(
+ self.id, len(arr2), arr2, name.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def setDataSGroupXY(self, x, y, options=""):
+ self.dispatcher._setSessionId()
+ if options is None:
+ options = ""
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetDataSGroupXY(
+ self.id, x, y, options.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupData(self, data):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupData(
+ self.id, data.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupCoords(self, x, y):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupCoords(self.id, x, y)
+ )
+
+ def setSGroupDescription(self, description):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupDescription(
+ self.id, description.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupFieldName(self, name):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupFieldName(
+ self.id, name.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupQueryCode(self, code):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupQueryCode(
+ self.id, code.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupQueryOper(self, oper):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupQueryOper(
+ self.id, oper.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupDisplay(self, option):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupDisplay(
+ self.id, option.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupLocation(self, option):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupLocation(
+ self.id, option.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupTag(self, tag):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupTag(
+ self.id, tag.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupTagAlign(self, tag_align):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupTagAlign(self.id, tag_align)
+ )
+
+ def setSGroupDataType(self, data_type):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupDataType(
+ self.id, data_type.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupXCoord(self, x):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupXCoord(self.id, x)
+ )
+
+ def setSGroupYCoord(self, y):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupYCoord(self.id, y)
+ )
+
+ def createSGroup(self, sgtype, mapping, name):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoCreateSGroup(
+ sgtype.encode(ENCODE_ENCODING),
+ mapping.id,
+ name.encode(ENCODE_ENCODING),
+ )
+ ),
+ )
+
+ def setSGroupClass(self, sgclass):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupClass(
+ self.id, sgclass.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setSGroupName(self, sgname):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupName(
+ self.id, sgname.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def getSGroupClass(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoGetSGroupClass(self.id)
+ )
+
+ def getSGroupName(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoGetSGroupName(self.id)
+ )
+
+ def getSGroupNumCrossBonds(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSGroupNumCrossBonds(self.id)
+ )
+
+ def addSGroupAttachmentPoint(self, aidx, lvidx, apid):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddSGroupAttachmentPoint(
+ self.id, aidx, lvidx, apid.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def deleteSGroupAttachmentPoint(self, apidx):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoDeleteSGroupAttachmentPoint(self.id, apidx)
+ )
+
+ def getSGroupDisplayOption(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSGroupDisplayOption(self.id)
+ )
+
+ def setSGroupDisplayOption(self, option):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupDisplayOption(self.id, option)
+ )
+
+ def getSGroupSeqId(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSGroupSeqId(self.id)
+ )
+
+ def getSGroupCoords(self):
+ """
+ Returns:
+ XY coordinates for Data sgroup
+ ::
+ Since 1.3.0
+ """
+ self.dispatcher._setSessionId()
+ xyz = Indigo._lib.indigoGetSGroupCoords(self.id)
+ if xyz is None:
+ raise IndigoException(Indigo._lib.indigoGetLastError())
+ return [xyz[0], xyz[1]]
+
+ def getRepeatingUnitSubscript(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoGetRepeatingUnitSubscript(self.id)
+ )
+
+ def getRepeatingUnitConnectivity(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetRepeatingUnitConnectivity(self.id)
+ )
+
+ def getSGroupMultiplier(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSGroupMultiplier(self.id)
+ )
+
+ def setSGroupMultiplier(self, mult):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupMultiplier(self.id, mult)
+ )
+
+ def setSGroupBrackets(self, style, x1, y1, x2, y2, x3, y3, x4, y4):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupBrackets(
+ self.id, style, x1, y1, x2, y2, x3, y3, x4, y4
+ )
+ )
+
+ def findSGroups(self, prop, val):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoFindSGroups(
+ self.id,
+ prop.encode(ENCODE_ENCODING),
+ val.encode(ENCODE_ENCODING),
+ )
+ ),
+ )
+
+ def getSGroupType(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSGroupType(self.id)
+ )
+
+ def getSGroupIndex(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSGroupIndex(self.id)
+ )
+
+ def getSGroupOriginalId(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSGroupOriginalId(self.id)
+ )
+
+ def setSGroupOriginalId(self, original):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupOriginalId(self.id, original)
+ )
+
+ def getSGroupParentId(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSGroupParentId(self.id)
+ )
+
+ def setSGroupParentId(self, parent):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetSGroupParentId(self.id, parent)
+ )
+
+ def addTemplate(self, templates, name):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddTemplate(
+ self.id, templates.id, name.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def removeTemplate(self, name):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoRemoveTemplate(
+ self.id, name.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def findTemplate(self, name):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoFindTemplate(
+ self.id, name.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def getTGroupClass(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoGetTGroupClass(self.id)
+ )
+
+ def getTGroupName(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoGetTGroupName(self.id)
+ )
+
+ def getTGroupAlias(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoGetTGroupAlias(self.id)
+ )
+
+ def transformSCSRtoCTAB(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoTransformSCSRtoCTAB(self.id)
+ )
+
+ def transformCTABtoSCSR(self, templates):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoTransformCTABtoSCSR(self.id, templates.id)
+ )
+
+ def getTemplateAtomClass(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoGetTemplateAtomClass(self.id)
+ )
+
+ def setTemplateAtomClass(self, name):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetTemplateAtomClass(
+ self.id, name.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def clean2d(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoClean2d(self.id))
+
+ def resetCharge(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoResetCharge(self.id)
+ )
+
+ def resetExplicitValence(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoResetExplicitValence(self.id)
+ )
+
+ def resetRadical(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoResetRadical(self.id)
+ )
+
+ def resetIsotope(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoResetIsotope(self.id)
+ )
+
+ def setAttachmentPoint(self, order):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetAttachmentPoint(self.id, order)
+ )
+
+ def clearAttachmentPoints(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoClearAttachmentPoints(self.id)
+ )
+
+ def removeConstraints(self, type):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoRemoveConstraints(
+ self.id, type.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def addConstraint(self, type, value):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddConstraint(
+ self.id,
+ type.encode(ENCODE_ENCODING),
+ value.encode(ENCODE_ENCODING),
+ )
+ )
+
+ def addConstraintNot(self, type, value):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddConstraintNot(
+ self.id,
+ type.encode(ENCODE_ENCODING),
+ value.encode(ENCODE_ENCODING),
+ )
+ )
+
+ def addConstraintOr(self, type, value):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddConstraintOr(
+ self.id,
+ type.encode(ENCODE_ENCODING),
+ value.encode(ENCODE_ENCODING),
+ )
+ )
+
+ def resetStereo(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoResetStereo(self.id)
+ )
+
+ def invertStereo(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoInvertStereo(self.id)
+ )
+
+ def countAtoms(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountAtoms(self.id)
+ )
+
+ def countBonds(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountBonds(self.id)
+ )
+
+ def countPseudoatoms(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountPseudoatoms(self.id)
+ )
+
+ def countRSites(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountRSites(self.id)
+ )
+
+ def iterateBonds(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateBonds(self.id)
+ ),
+ )
+
+ def bondOrder(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoBondOrder(self.id)
+ )
+
+ def bondStereo(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoBondStereo(self.id)
+ )
+
+ def topology(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoTopology(self.id)
+ )
+
+ def iterateNeighbors(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateNeighbors(self.id)
+ ),
+ )
+
+ def bond(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(Indigo._lib.indigoBond(self.id)),
+ )
+
+ def getAtom(self, idx):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoGetAtom(self.id, idx)
+ ),
+ )
+
+ def getBond(self, idx):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoGetBond(self.id, idx)
+ ),
+ )
+
+ def source(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(Indigo._lib.indigoSource(self.id)),
+ )
+
+ def destination(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoDestination(self.id)
+ ),
+ )
+
+ def clearCisTrans(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoClearCisTrans(self.id)
+ )
+
+ def clearStereocenters(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoClearStereocenters(self.id)
+ )
+
+ def countStereocenters(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountStereocenters(self.id)
+ )
+
+ def clearAlleneCenters(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoClearAlleneCenters(self.id)
+ )
+
+ def countAlleneCenters(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountAlleneCenters(self.id)
+ )
+
+ def resetSymmetricCisTrans(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoResetSymmetricCisTrans(self.id)
+ )
+
+ def resetSymmetricStereocenters(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoResetSymmetricStereocenters(self.id)
+ )
+
+ def markEitherCisTrans(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoMarkEitherCisTrans(self.id)
+ )
+
+ def markStereobonds(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoMarkStereobonds(self.id)
+ )
+
+ def addAtom(self, symbol):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoAddAtom(
+ self.id, symbol.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def resetAtom(self, symbol):
+ self.dispatcher._setSessionId()
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoResetAtom(
+ self.id, symbol.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def addRSite(self, name):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoAddRSite(
+ self.id, name.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def setRSite(self, name):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetRSite(self.id, name.encode(ENCODE_ENCODING))
+ )
+
+ def setCharge(self, charge):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetCharge(self.id, charge)
+ )
+
+ def setIsotope(self, isotope):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetIsotope(self.id, isotope)
+ )
+
+ def setImplicitHCount(self, impl_h):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetImplicitHCount(self.id, impl_h)
+ )
+
+ def addBond(self, destination, order):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoAddBond(self.id, destination.id, order)
+ ),
+ )
+
+ def setBondOrder(self, order):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoSetBondOrder(self.id, order)
+ ),
+ )
+
+ def merge(self, what):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoMerge(self.id, what.id)
+ ),
+ )
+
+ def highlight(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoHighlight(self.id)
+ )
+
+ def unhighlight(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoUnhighlight(self.id)
+ )
+
+ def isHighlighted(self):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIsHighlighted(self.id)
+ )
+ )
+
+ def countComponents(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountComponents(self.id)
+ )
+
+ def componentIndex(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoComponentIndex(self.id)
+ )
+
+ def iterateComponents(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateComponents(self.id)
+ ),
+ )
+
+ def component(self, index):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoComponent(self.id, index)
+ ),
+ )
+
+ def countSSSR(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountSSSR(self.id)
+ )
+
+ def iterateSSSR(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateSSSR(self.id)
+ ),
+ )
+
+ def iterateSubtrees(self, min_atoms, max_atoms):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateSubtrees(
+ self.id, min_atoms, max_atoms
+ )
+ ),
+ )
+
+ def iterateRings(self, min_atoms, max_atoms):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateRings(self.id, min_atoms, max_atoms)
+ ),
+ )
+
+ def iterateEdgeSubmolecules(self, min_bonds, max_bonds):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateEdgeSubmolecules(
+ self.id, min_bonds, max_bonds
+ )
+ ),
+ )
+
+ def countHeavyAtoms(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountHeavyAtoms(self.id)
+ )
+
+ def grossFormula(self):
+ self.dispatcher._setSessionId()
+ gfid = self.dispatcher._checkResult(
+ Indigo._lib.indigoGrossFormula(self.id)
+ )
+ gf = self.dispatcher.IndigoObject(self.dispatcher, gfid)
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoToString(gf.id)
+ )
+
+ def molecularWeight(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultFloat(
+ Indigo._lib.indigoMolecularWeight(self.id)
+ )
+
+ def mostAbundantMass(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultFloat(
+ Indigo._lib.indigoMostAbundantMass(self.id)
+ )
+
+ def monoisotopicMass(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultFloat(
+ Indigo._lib.indigoMonoisotopicMass(self.id)
+ )
+
+ def massComposition(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoMassComposition(self.id)
+ )
+
+ def canonicalSmiles(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoCanonicalSmiles(self.id)
+ )
+
+ def canonicalSmarts(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoCanonicalSmarts(self.id)
+ )
+
+ def layeredCode(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoLayeredCode(self.id)
+ )
+
+ def symmetryClasses(self):
+ c_size = c_int()
+ self.dispatcher._setSessionId()
+ c_buf = self.dispatcher._checkResultPtr(
+ Indigo._lib.indigoSymmetryClasses(self.id, pointer(c_size))
+ )
+ res = array("i")
+ for i in range(c_size.value):
+ res.append(c_buf[i])
+ return res
+
+ def hasCoord(self):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(Indigo._lib.indigoHasCoord(self.id))
+ )
+
+ def hasZCoord(self):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(Indigo._lib.indigoHasZCoord(self.id))
+ )
+
+ def isChiral(self):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(Indigo._lib.indigoIsChiral(self.id))
+ )
+
+ def isPossibleFischerProjection(self, options):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIsPossibleFischerProjection(
+ self.id, options.encode(ENCODE_ENCODING)
+ )
+ )
+ )
+
+ def createSubmolecule(self, vertices):
+ arr2 = (c_int * len(vertices))()
+ for i in range(len(vertices)):
+ arr2[i] = vertices[i]
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoCreateSubmolecule(self.id, len(arr2), arr2)
+ ),
+ )
+
+ def createEdgeSubmolecule(self, vertices, edges):
+ arr2 = (c_int * len(vertices))()
+ for i in range(len(vertices)):
+ arr2[i] = vertices[i]
+ arr4 = (c_int * len(edges))()
+ for i in range(len(edges)):
+ arr4[i] = edges[i]
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoCreateEdgeSubmolecule(
+ self.id, len(arr2), arr2, len(arr4), arr4
+ )
+ ),
+ )
+
+ def getSubmolecule(self, vertices):
+ arr2 = (c_int * len(vertices))()
+ for i in range(len(vertices)):
+ arr2[i] = vertices[i]
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoGetSubmolecule(self.id, len(arr2), arr2)
+ ),
+ self,
+ )
+
+ def removeAtoms(self, vertices):
+ arr2 = (c_int * len(vertices))()
+ for i in range(len(vertices)):
+ arr2[i] = vertices[i]
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoRemoveAtoms(self.id, len(arr2), arr2)
+ )
+
+ def removeBonds(self, bonds):
+ arr2 = (c_int * len(bonds))()
+ for i in range(len(bonds)):
+ arr2[i] = bonds[i]
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoRemoveBonds(self.id, len(arr2), arr2)
+ )
+
+ def aromatize(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAromatize(self.id)
+ )
+
+ def dearomatize(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoDearomatize(self.id)
+ )
+
+ def foldHydrogens(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoFoldHydrogens(self.id)
+ )
+
+ def unfoldHydrogens(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoUnfoldHydrogens(self.id)
+ )
+
+ def layout(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoLayout(self.id))
+
+ def smiles(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoSmiles(self.id)
+ )
+
+ def smarts(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoSmarts(self.id)
+ )
+
+ def name(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoName(self.id)
+ )
+
+ def setName(self, name):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetName(self.id, name.encode(ENCODE_ENCODING))
+ )
+
+ def serialize(self):
+ c_size = c_int()
+ c_buf = POINTER(c_byte)()
+ self.dispatcher._setSessionId()
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoSerialize(
+ self.id, pointer(c_buf), pointer(c_size)
+ )
+ )
+ res = array("b")
+ for i in range(c_size.value):
+ res.append(c_buf[i])
+ return res
+
+ def hasProperty(self, prop):
+ self.dispatcher._setSessionId()
+ return bool(
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoHasProperty(self.id, prop)
+ )
+ )
+
+ def getProperty(self, prop):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoGetProperty(
+ self.id, prop.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def setProperty(self, prop, value):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSetProperty(
+ self.id,
+ prop.encode(ENCODE_ENCODING),
+ value.encode(ENCODE_ENCODING),
+ )
+ )
+
+ def removeProperty(self, prop):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoRemoveProperty(
+ self.id, prop.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def iterateProperties(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateProperties(self.id)
+ ),
+ )
+
+ def clearProperties(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoClearProperties(self.id)
+ )
+
+ def checkBadValence(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoCheckBadValence(self.id)
+ )
+
+ def checkAmbiguousH(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoCheckAmbiguousH(self.id)
+ )
+
+ def fingerprint(self, type):
+ self.dispatcher._setSessionId()
+ newobj = self.dispatcher._checkResult(
+ Indigo._lib.indigoFingerprint(
+ self.id, type.encode(ENCODE_ENCODING)
+ )
+ )
+ if newobj == 0:
+ return None
+ return self.dispatcher.IndigoObject(self.dispatcher, newobj, self)
+
+ def countBits(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountBits(self.id)
+ )
+
+ def rawData(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoRawData(self.id)
+ )
+
+ def tell(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoTell(self.id))
+
+ def sdfAppend(self, item):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSdfAppend(self.id, item.id)
+ )
+
+ def smilesAppend(self, item):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoSmilesAppend(self.id, item.id)
+ )
+
+ def rdfHeader(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoRdfHeader(self.id)
+ )
+
+ def rdfAppend(self, item):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoRdfAppend(self.id, item.id)
+ )
+
+ def cmlHeader(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCmlHeader(self.id)
+ )
+
+ def cmlAppend(self, item):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCmlAppend(self.id, item.id)
+ )
+
+ def cmlFooter(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCmlFooter(self.id)
+ )
+
+ def append(self, object):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAppend(self.id, object.id)
+ )
+
+ def arrayAdd(self, object):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoArrayAdd(self.id, object.id)
+ )
+
+ def at(self, index):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(Indigo._lib.indigoAt(self.id, index)),
+ )
+
+ def count(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoCount(self.id))
+
+ def clear(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(Indigo._lib.indigoClear(self.id))
+
+ def iterateArray(self):
+ self.dispatcher._setSessionId()
+ newobj = self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateArray(self.id)
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.dispatcher.IndigoObject(self.dispatcher, newobj, self)
+
+ def ignoreAtom(self, atom_object):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoIgnoreAtom(self.id, atom_object.id)
+ )
+
+ def unignoreAtom(self, atom_object):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoUnignoreAtom(self.id, atom_object.id)
+ )
+
+ def unignoreAllAtoms(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoUnignoreAllAtoms(self.id)
+ )
+
+ def match(self, query):
+ self.dispatcher._setSessionId()
+ newobj = self.dispatcher._checkResult(
+ Indigo._lib.indigoMatch(self.id, query.id)
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.dispatcher.IndigoObject(self.dispatcher, newobj, self)
+
+ def countMatches(self, query):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountMatches(self.id, query.id)
+ )
+
+ def countMatchesWithLimit(self, query, embeddings_limit):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoCountMatchesWithLimit(
+ self.id, query.id, embeddings_limit
+ )
+ )
+
+ def iterateMatches(self, query):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateMatches(self.id, query.id)
+ ),
+ )
+
+ def highlightedTarget(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoHighlightedTarget(self.id)
+ ),
+ )
+
+ def mapAtom(self, atom):
+ self.dispatcher._setSessionId()
+ newobj = self.dispatcher._checkResult(
+ Indigo._lib.indigoMapAtom(self.id, atom.id)
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.dispatcher.IndigoObject(self.dispatcher, newobj, self)
+
+ def mapBond(self, bond):
+ self.dispatcher._setSessionId()
+ newobj = self.dispatcher._checkResult(
+ Indigo._lib.indigoMapBond(self.id, bond.id)
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.dispatcher.IndigoObject(self.dispatcher, newobj, self)
+
+ def mapMolecule(self, molecule):
+ self.dispatcher._setSessionId()
+ newobj = self.dispatcher._checkResult(
+ Indigo._lib.indigoMapMolecule(self.id, molecule.id)
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.dispatcher.IndigoObject(self.dispatcher, newobj, self)
+
+ def allScaffolds(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoAllScaffolds(self.id)
+ ),
+ )
+
+ def decomposedMoleculeScaffold(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoDecomposedMoleculeScaffold(self.id)
+ ),
+ )
+
+ def iterateDecomposedMolecules(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateDecomposedMolecules(self.id)
+ ),
+ )
+
+ def decomposedMoleculeHighlighted(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoDecomposedMoleculeHighlighted(self.id)
+ ),
+ )
+
+ def decomposedMoleculeWithRGroups(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoDecomposedMoleculeWithRGroups(self.id)
+ ),
+ )
+
+ def decomposeMolecule(self, mol):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoDecomposeMolecule(self.id, mol.id)
+ ),
+ )
+
+ def iterateDecompositions(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher.IndigoObject(
+ self.dispatcher,
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoIterateDecompositions(self.id)
+ ),
+ )
+
+ def addDecomposition(self, q_match):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoAddDecomposition(self.id, q_match.id)
+ )
+
+ def toString(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoToString(self.id)
+ )
+
+ def toBuffer(self):
+ c_size = c_int()
+ c_buf = POINTER(c_byte)()
+ self.dispatcher._setSessionId()
+ self.dispatcher._checkResult(
+ Indigo._lib.indigoToBuffer(
+ self.id, pointer(c_buf), pointer(c_size)
+ )
+ )
+ res = array("b")
+ for i in range(c_size.value):
+ res.append(c_buf[i])
+ return res
+
+ def stereocenterPyramid(self):
+ self.dispatcher._setSessionId()
+ ptr = self.dispatcher._checkResultPtr(
+ Indigo._lib.indigoStereocenterPyramid(self.id)
+ )
+ res = [0] * 4
+ for i in range(4):
+ res[i] = ptr[i]
+ return res
+
+ def expandAbbreviations(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResult(
+ Indigo._lib.indigoExpandAbbreviations(self.id)
+ )
+
+ def dbgInternalType(self):
+ self.dispatcher._setSessionId()
+ return self.dispatcher._checkResultString(
+ Indigo._lib.indigoDbgInternalType(self.id)
+ )
+
+
+class Indigo(object):
+ ABS = 1
+ OR = 2
+ AND = 3
+ EITHER = 4
+ UP = 5
+ DOWN = 6
+ CIS = 7
+ TRANS = 8
+ CHAIN = 9
+ RING = 10
+ ALLENE = 11
+
+ SINGLET = 101
+ DOUBLET = 102
+ TRIPLET = 103
+ RC_NOT_CENTER = -1
+ RC_UNMARKED = 0
+ RC_CENTER = 1
+ RC_UNCHANGED = 2
+ RC_MADE_OR_BROKEN = 4
+ RC_ORDER_CHANGED = 8
+
+ SG_TYPE_GEN = 0
+ SG_TYPE_DAT = 1
+ SG_TYPE_SUP = 2
+ SG_TYPE_SRU = 3
+ SG_TYPE_MUL = 4
+ SG_TYPE_MON = 5
+ SG_TYPE_MER = 6
+ SG_TYPE_COP = 7
+ SG_TYPE_CRO = 8
+ SG_TYPE_MOD = 9
+ SG_TYPE_GRA = 10
+ SG_TYPE_COM = 11
+ SG_TYPE_MIX = 12
+ SG_TYPE_FOR = 13
+ SG_TYPE_ANY = 14
+
+ _crt = None
+ _crtp = None
+ _lib = None
+
+ # Python embeds path into .pyc code if method is marked with @staticmethod
+ # This causes an error when Indigo is loaded from different places by relative path
+ def _initStatic(self, path=None):
+ def cdll_if_exists(cdll_path_):
+ if os.path.exists(cdll_path_):
+ return CDLL(cdll_path_)
+ return None
+
+ paths = []
+ if not path:
+ cur_file = os.path.abspath(__file__)
+ paths = [
+ os.path.join(os.path.dirname(cur_file), "lib"),
+ os.path.join(
+ os.path.dirname(os.path.dirname(cur_file)), "lib"
+ ),
+ ]
+ else:
+ paths.append(path)
+
+ indigoFound = False
+ for path in paths:
+ if (
+ os.name == "posix"
+ and not platform.mac_ver()[0]
+ and not platform.system().startswith("CYGWIN")
+ ):
+ arch = platform.architecture()[0]
+ path = os.path.join(path, "Linux")
+ if arch == "32bit":
+ path = os.path.join(path, "x86")
+ elif arch == "64bit":
+ path = os.path.join(path, "x64")
+ else:
+ raise IndigoException("unknown platform " + arch)
+ if os.path.exists(os.path.join(path, "libindigo.so")):
+ Indigo._lib = CDLL(
+ os.path.join(path, "libindigo.so"), mode=RTLD_GLOBAL
+ )
+ indigoFound = True
+ Indigo.dllpath = path
+ elif os.name == "nt" or platform.system().startswith("CYGWIN"):
+ arch = platform.architecture()[0]
+ path = os.path.join(path, "Win")
+ if arch == "32bit":
+ path = os.path.join(path, "x86")
+ elif arch == "64bit":
+ path = os.path.join(path, "x64")
+ else:
+ raise IndigoException("unknown platform " + arch)
+ if os.path.exists(os.path.join(path, "indigo.dll")):
+ Indigo._crt = cdll_if_exists(
+ os.path.join(path, "vcruntime140.dll")
+ )
+ Indigo._crt_1 = cdll_if_exists(
+ os.path.join(path, "vcruntime140_1.dll")
+ )
+ Indigo._crtp = cdll_if_exists(
+ os.path.join(path, "msvcp140.dll")
+ )
+ Indigo._crtc = cdll_if_exists(
+ os.path.join(path, "concrt140.dll")
+ )
+ Indigo._lib = CDLL(os.path.join(path, "indigo.dll"))
+ indigoFound = True
+ Indigo.dllpath = path
+ elif platform.mac_ver()[0]:
+ path = os.path.join(path, "Mac")
+ mac_ver = ".".join(platform.mac_ver()[0].split(".")[:2])
+ current_mac_ver = int(mac_ver.split(".")[1])
+ using_mac_ver = None
+ for version in reversed(range(5, current_mac_ver + 1)):
+ if os.path.exists(
+ os.path.join(path, "10." + str(version))
+ ):
+ using_mac_ver = str(version)
+ break
+ if using_mac_ver:
+ path = os.path.join(path, "10." + using_mac_ver)
+ Indigo._lib = CDLL(
+ os.path.join(path, "libindigo.dylib"), mode=RTLD_GLOBAL
+ )
+ indigoFound = True
+ Indigo.dllpath = path
+ else:
+ raise IndigoException("unsupported OS: " + os.name)
+ if not indigoFound:
+ raise IndigoException(
+ "Could not find native libraries for target OS in search directories: {}".format(
+ os.pathsep.join(paths)
+ )
+ )
+
+ def _setSessionId(self):
+ Indigo._lib.indigoSetSessionId(self._sid)
+
+ def __init__(self, path=None):
+ if Indigo._lib is None:
+ self._initStatic(path)
+ self._sid = Indigo._lib.indigoAllocSessionId()
+ # Capture a reference to the _lib to access it in the __del__ method because
+ # at interpreter shutdown, the module's global variables are set to None
+ self._lib = Indigo._lib
+ self._setSessionId()
+ self.IndigoObject = IndigoObject
+ Indigo._lib.indigoVersion.restype = c_char_p
+ Indigo._lib.indigoVersion.argtypes = None
+ Indigo._lib.indigoAllocSessionId.restype = c_ulonglong
+ Indigo._lib.indigoAllocSessionId.argtypes = None
+ Indigo._lib.indigoSetSessionId.restype = None
+ Indigo._lib.indigoSetSessionId.argtypes = [c_ulonglong]
+ Indigo._lib.indigoReleaseSessionId.restype = None
+ Indigo._lib.indigoReleaseSessionId.argtypes = [c_ulonglong]
+ Indigo._lib.indigoGetLastError.restype = c_char_p
+ Indigo._lib.indigoGetLastError.argtypes = None
+ Indigo._lib.indigoFree.restype = c_int
+ Indigo._lib.indigoFree.argtypes = [c_int]
+ Indigo._lib.indigoCountReferences.restype = c_int
+ Indigo._lib.indigoCountReferences.argtypes = None
+ Indigo._lib.indigoFreeAllObjects.restype = c_int
+ Indigo._lib.indigoFreeAllObjects.argtypes = None
+ Indigo._lib.indigoSetOption.restype = c_int
+ Indigo._lib.indigoSetOption.argtypes = [c_char_p, c_char_p]
+ Indigo._lib.indigoSetOptionInt.restype = c_int
+ Indigo._lib.indigoSetOptionInt.argtypes = [c_char_p, c_int]
+ Indigo._lib.indigoSetOptionBool.restype = c_int
+ Indigo._lib.indigoSetOptionBool.argtypes = [c_char_p, c_int]
+ Indigo._lib.indigoSetOptionFloat.restype = c_int
+ Indigo._lib.indigoSetOptionFloat.argtypes = [c_char_p, c_float]
+ Indigo._lib.indigoSetOptionColor.restype = c_int
+ Indigo._lib.indigoSetOptionColor.argtypes = [
+ c_char_p,
+ c_float,
+ c_float,
+ c_float,
+ ]
+ Indigo._lib.indigoSetOptionXY.restype = c_int
+ Indigo._lib.indigoSetOptionXY.argtypes = [c_char_p, c_int, c_int]
+ Indigo._lib.indigoGetOption.restype = c_char_p
+ Indigo._lib.indigoGetOption.argtypes = [c_char_p]
+ Indigo._lib.indigoGetOptionInt.restype = c_int
+ Indigo._lib.indigoGetOptionInt.argtypes = [c_char_p, POINTER(c_int)]
+ Indigo._lib.indigoGetOptionBool.argtypes = [c_char_p, POINTER(c_int)]
+ Indigo._lib.indigoGetOptionBool.restype = c_int
+ Indigo._lib.indigoGetOptionFloat.argtypes = [
+ c_char_p,
+ POINTER(c_float),
+ ]
+ Indigo._lib.indigoGetOptionFloat.restype = c_int
+ Indigo._lib.indigoGetOptionColor.argtypes = [
+ c_char_p,
+ POINTER(c_float),
+ POINTER(c_float),
+ POINTER(c_float),
+ ]
+ Indigo._lib.indigoGetOptionColor.restype = c_int
+ Indigo._lib.indigoGetOptionXY.argtypes = [
+ c_char_p,
+ POINTER(c_int),
+ POINTER(c_int),
+ ]
+ Indigo._lib.indigoGetOptionXY.restype = c_int
+ Indigo._lib.indigoGetOptionType.restype = c_char_p
+ Indigo._lib.indigoGetOptionType.argtypes = [c_char_p]
+ Indigo._lib.indigoReadFile.restype = c_int
+ Indigo._lib.indigoReadFile.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadString.restype = c_int
+ Indigo._lib.indigoLoadString.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadBuffer.restype = c_int
+ Indigo._lib.indigoLoadBuffer.argtypes = [POINTER(c_byte), c_int]
+ Indigo._lib.indigoWriteFile.restype = c_int
+ Indigo._lib.indigoWriteFile.argtypes = [c_char_p]
+ Indigo._lib.indigoWriteBuffer.restype = c_int
+ Indigo._lib.indigoWriteBuffer.argtypes = None
+ Indigo._lib.indigoCreateMolecule.restype = c_int
+ Indigo._lib.indigoCreateMolecule.argtypes = None
+ Indigo._lib.indigoCreateQueryMolecule.restype = c_int
+ Indigo._lib.indigoCreateQueryMolecule.argtypes = None
+ Indigo._lib.indigoLoadMoleculeFromString.restype = c_int
+ Indigo._lib.indigoLoadMoleculeFromString.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadMoleculeFromFile.restype = c_int
+ Indigo._lib.indigoLoadMoleculeFromFile.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadMoleculeFromBuffer.restype = c_int
+ Indigo._lib.indigoLoadMoleculeFromBuffer.argtypes = [
+ POINTER(c_byte),
+ c_int,
+ ]
+ Indigo._lib.indigoLoadQueryMoleculeFromString.restype = c_int
+ Indigo._lib.indigoLoadQueryMoleculeFromString.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadQueryMoleculeFromFile.restype = c_int
+ Indigo._lib.indigoLoadQueryMoleculeFromFile.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadSmartsFromString.restype = c_int
+ Indigo._lib.indigoLoadSmartsFromString.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadSmartsFromFile.restype = c_int
+ Indigo._lib.indigoLoadSmartsFromFile.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadReactionFromString.restype = c_int
+ Indigo._lib.indigoLoadReactionFromString.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadReactionFromFile.restype = c_int
+ Indigo._lib.indigoLoadReactionFromFile.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadQueryReactionFromString.restype = c_int
+ Indigo._lib.indigoLoadQueryReactionFromString.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadQueryReactionFromFile.restype = c_int
+ Indigo._lib.indigoLoadQueryReactionFromFile.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadReactionSmartsFromString.restype = c_int
+ Indigo._lib.indigoLoadReactionSmartsFromString.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadReactionSmartsFromFile.restype = c_int
+ Indigo._lib.indigoLoadReactionSmartsFromFile.argtypes = [c_char_p]
+ Indigo._lib.indigoLoadStructureFromString.restype = c_int
+ Indigo._lib.indigoLoadStructureFromString.argtypes = [
+ c_char_p,
+ c_char_p,
+ ]
+ Indigo._lib.indigoLoadStructureFromBuffer.restype = c_int
+ Indigo._lib.indigoLoadStructureFromBuffer.argtypes = [
+ POINTER(c_byte),
+ c_int,
+ c_char_p,
+ ]
+ Indigo._lib.indigoLoadStructureFromFile.restype = c_int
+ Indigo._lib.indigoLoadStructureFromFile.argtypes = [c_char_p, c_char_p]
+ Indigo._lib.indigoCreateReaction.restype = c_int
+ Indigo._lib.indigoCreateReaction.argtypes = None
+ Indigo._lib.indigoCreateQueryReaction.restype = c_int
+ Indigo._lib.indigoCreateQueryReaction.argtypes = None
+ Indigo._lib.indigoExactMatch.restype = c_int
+ Indigo._lib.indigoExactMatch.argtypes = [c_int, c_int, c_char_p]
+ Indigo._lib.indigoSetTautomerRule.restype = c_int
+ Indigo._lib.indigoSetTautomerRule.argtypes = [
+ c_int,
+ c_char_p,
+ c_char_p,
+ ]
+ Indigo._lib.indigoRemoveTautomerRule.restype = c_int
+ Indigo._lib.indigoRemoveTautomerRule.argtypes = [c_int]
+ Indigo._lib.indigoClearTautomerRules.restype = c_int
+ Indigo._lib.indigoClearTautomerRules.argtypes = None
+ Indigo._lib.indigoUnserialize.restype = c_int
+ Indigo._lib.indigoUnserialize.argtypes = [POINTER(c_byte), c_int]
+ Indigo._lib.indigoCommonBits.restype = c_int
+ Indigo._lib.indigoCommonBits.argtypes = [c_int, c_int]
+ Indigo._lib.indigoSimilarity.restype = c_float
+ Indigo._lib.indigoSimilarity.argtypes = [c_int, c_int, c_char_p]
+ Indigo._lib.indigoIterateSDF.restype = c_int
+ Indigo._lib.indigoIterateSDF.argtypes = [c_int]
+ Indigo._lib.indigoIterateRDF.restype = c_int
+ Indigo._lib.indigoIterateRDF.argtypes = [c_int]
+ Indigo._lib.indigoIterateSmiles.restype = c_int
+ Indigo._lib.indigoIterateSmiles.argtypes = [c_int]
+ Indigo._lib.indigoIterateCML.restype = c_int
+ Indigo._lib.indigoIterateCML.argtypes = [c_int]
+ Indigo._lib.indigoIterateCDX.restype = c_int
+ Indigo._lib.indigoIterateCDX.argtypes = [c_int]
+ Indigo._lib.indigoIterateSDFile.restype = c_int
+ Indigo._lib.indigoIterateSDFile.argtypes = [c_char_p]
+ Indigo._lib.indigoIterateRDFile.restype = c_int
+ Indigo._lib.indigoIterateRDFile.argtypes = [c_char_p]
+ Indigo._lib.indigoIterateSmilesFile.restype = c_int
+ Indigo._lib.indigoIterateSmilesFile.argtypes = [c_char_p]
+ Indigo._lib.indigoIterateCMLFile.restype = c_int
+ Indigo._lib.indigoIterateCMLFile.argtypes = [c_char_p]
+ Indigo._lib.indigoIterateCDXFile.restype = c_int
+ Indigo._lib.indigoIterateCDXFile.argtypes = [c_char_p]
+ Indigo._lib.indigoCreateSaver.restype = c_int
+ Indigo._lib.indigoCreateSaver.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoCreateFileSaver.restype = c_int
+ Indigo._lib.indigoCreateFileSaver.argtypes = [c_char_p, c_char_p]
+ Indigo._lib.indigoCreateArray.restype = c_int
+ Indigo._lib.indigoCreateArray.argtypes = None
+ Indigo._lib.indigoSubstructureMatcher.restype = c_int
+ Indigo._lib.indigoSubstructureMatcher.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoExtractCommonScaffold.restype = c_int
+ Indigo._lib.indigoExtractCommonScaffold.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoDecomposeMolecules.restype = c_int
+ Indigo._lib.indigoDecomposeMolecules.argtypes = [c_int, c_int]
+ Indigo._lib.indigoRGroupComposition.restype = c_int
+ Indigo._lib.indigoRGroupComposition.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoGetFragmentedMolecule.restype = c_int
+ Indigo._lib.indigoGetFragmentedMolecule.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoCreateDecomposer.restype = c_int
+ Indigo._lib.indigoCreateDecomposer.argtypes = [c_int]
+ Indigo._lib.indigoReactionProductEnumerate.restype = c_int
+ Indigo._lib.indigoReactionProductEnumerate.argtypes = [c_int, c_int]
+ Indigo._lib.indigoTransform.restype = c_int
+ Indigo._lib.indigoTransform.argtypes = [c_int, c_int]
+ Indigo._lib.indigoDbgBreakpoint.restype = None
+ Indigo._lib.indigoDbgBreakpoint.argtypes = None
+ Indigo._lib.indigoClone.restype = c_int
+ Indigo._lib.indigoClone.argtypes = [c_int]
+ Indigo._lib.indigoCheck.restype = c_char_p
+ Indigo._lib.indigoCheck.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoCheckStructure.restype = c_char_p
+ Indigo._lib.indigoCheckStructure.argtypes = [c_char_p, c_char_p]
+ Indigo._lib.indigoClose.restype = c_int
+ Indigo._lib.indigoClose.argtypes = [c_int]
+ Indigo._lib.indigoNext.restype = c_int
+ Indigo._lib.indigoNext.argtypes = [c_int]
+ Indigo._lib.indigoHasNext.restype = c_int
+ Indigo._lib.indigoHasNext.argtypes = [c_int]
+ Indigo._lib.indigoIndex.restype = c_int
+ Indigo._lib.indigoIndex.argtypes = [c_int]
+ Indigo._lib.indigoRemove.restype = c_int
+ Indigo._lib.indigoRemove.argtypes = [c_int]
+ Indigo._lib.indigoSaveMolfileToFile.restype = c_int
+ Indigo._lib.indigoSaveMolfileToFile.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoMolfile.restype = c_char_p
+ Indigo._lib.indigoMolfile.argtypes = [c_int]
+ Indigo._lib.indigoSaveCmlToFile.restype = c_int
+ Indigo._lib.indigoSaveCmlToFile.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoCml.restype = c_char_p
+ Indigo._lib.indigoCml.argtypes = [c_int]
+ Indigo._lib.indigoSaveCdxmlToFile.restype = c_int
+ Indigo._lib.indigoSaveCdxmlToFile.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoCdxml.restype = c_char_p
+ Indigo._lib.indigoCdxml.argtypes = [c_int]
+ Indigo._lib.indigoJson.restype = c_char_p
+ Indigo._lib.indigoJson.argtypes = [c_int]
+ Indigo._lib.indigoSaveMDLCT.restype = c_int
+ Indigo._lib.indigoSaveMDLCT.argtypes = [c_int, c_int]
+ Indigo._lib.indigoAddReactant.restype = c_int
+ Indigo._lib.indigoAddReactant.argtypes = [c_int, c_int]
+ Indigo._lib.indigoAddProduct.restype = c_int
+ Indigo._lib.indigoAddProduct.argtypes = [c_int, c_int]
+ Indigo._lib.indigoAddCatalyst.restype = c_int
+ Indigo._lib.indigoAddCatalyst.argtypes = [c_int, c_int]
+ Indigo._lib.indigoCountReactants.restype = c_int
+ Indigo._lib.indigoCountReactants.argtypes = [c_int]
+ Indigo._lib.indigoCountProducts.restype = c_int
+ Indigo._lib.indigoCountProducts.argtypes = [c_int]
+ Indigo._lib.indigoCountCatalysts.restype = c_int
+ Indigo._lib.indigoCountCatalysts.argtypes = [c_int]
+ Indigo._lib.indigoCountMolecules.restype = c_int
+ Indigo._lib.indigoCountMolecules.argtypes = [c_int]
+ Indigo._lib.indigoGetMolecule.restype = c_int
+ Indigo._lib.indigoGetMolecule.argtypes = [c_int, c_int]
+ Indigo._lib.indigoIterateReactants.restype = c_int
+ Indigo._lib.indigoIterateReactants.argtypes = [c_int]
+ Indigo._lib.indigoIterateProducts.restype = c_int
+ Indigo._lib.indigoIterateProducts.argtypes = [c_int]
+ Indigo._lib.indigoIterateCatalysts.restype = c_int
+ Indigo._lib.indigoIterateCatalysts.argtypes = [c_int]
+ Indigo._lib.indigoIterateMolecules.restype = c_int
+ Indigo._lib.indigoIterateMolecules.argtypes = [c_int]
+ Indigo._lib.indigoSaveRxnfileToFile.restype = c_int
+ Indigo._lib.indigoSaveRxnfileToFile.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoRxnfile.restype = c_char_p
+ Indigo._lib.indigoRxnfile.argtypes = [c_int]
+ Indigo._lib.indigoOptimize.restype = c_int
+ Indigo._lib.indigoOptimize.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoNormalize.restype = c_int
+ Indigo._lib.indigoNormalize.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoStandardize.restype = c_int
+ Indigo._lib.indigoStandardize.argtypes = [c_int]
+ Indigo._lib.indigoIonize.restype = c_int
+ Indigo._lib.indigoIonize.argtypes = [c_int, c_float, c_float]
+ Indigo._lib.indigoBuildPkaModel.restype = c_int
+ Indigo._lib.indigoBuildPkaModel.argtypes = [c_int, c_float, c_char_p]
+ Indigo._lib.indigoGetAcidPkaValue.restype = POINTER(c_float)
+ Indigo._lib.indigoGetAcidPkaValue.argtypes = [
+ c_int,
+ c_int,
+ c_int,
+ c_int,
+ ]
+ Indigo._lib.indigoGetBasicPkaValue.restype = POINTER(c_float)
+ Indigo._lib.indigoGetBasicPkaValue.argtypes = [
+ c_int,
+ c_int,
+ c_int,
+ c_int,
+ ]
+ Indigo._lib.indigoAutomap.restype = c_int
+ Indigo._lib.indigoAutomap.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoGetAtomMappingNumber.restype = c_int
+ Indigo._lib.indigoGetAtomMappingNumber.argtypes = [c_int, c_int]
+ Indigo._lib.indigoSetAtomMappingNumber.restype = c_int
+ Indigo._lib.indigoSetAtomMappingNumber.argtypes = [c_int, c_int, c_int]
+ Indigo._lib.indigoGetReactingCenter.restype = c_int
+ Indigo._lib.indigoGetReactingCenter.argtypes = [
+ c_int,
+ c_int,
+ POINTER(c_int),
+ ]
+ Indigo._lib.indigoSetReactingCenter.restype = c_int
+ Indigo._lib.indigoSetReactingCenter.argtypes = [c_int, c_int, c_int]
+ Indigo._lib.indigoClearAAM.restype = c_int
+ Indigo._lib.indigoClearAAM.argtypes = [c_int]
+ Indigo._lib.indigoCorrectReactingCenters.restype = c_int
+ Indigo._lib.indigoCorrectReactingCenters.argtypes = [c_int]
+ Indigo._lib.indigoIterateAtoms.restype = c_int
+ Indigo._lib.indigoIterateAtoms.argtypes = [c_int]
+ Indigo._lib.indigoIteratePseudoatoms.restype = c_int
+ Indigo._lib.indigoIteratePseudoatoms.argtypes = [c_int]
+ Indigo._lib.indigoIterateRSites.restype = c_int
+ Indigo._lib.indigoIterateRSites.argtypes = [c_int]
+ Indigo._lib.indigoIterateStereocenters.restype = c_int
+ Indigo._lib.indigoIterateStereocenters.argtypes = [c_int]
+ Indigo._lib.indigoIterateAlleneCenters.restype = c_int
+ Indigo._lib.indigoIterateAlleneCenters.argtypes = [c_int]
+ Indigo._lib.indigoIterateRGroups.restype = c_int
+ Indigo._lib.indigoIterateRGroups.argtypes = [c_int]
+ Indigo._lib.indigoCountRGroups.restype = c_int
+ Indigo._lib.indigoCountRGroups.argtypes = [c_int]
+ Indigo._lib.indigoIsPseudoatom.restype = c_int
+ Indigo._lib.indigoIsPseudoatom.argtypes = [c_int]
+ Indigo._lib.indigoIsRSite.restype = c_int
+ Indigo._lib.indigoIsRSite.argtypes = [c_int]
+ Indigo._lib.indigoIsTemplateAtom.restype = c_int
+ Indigo._lib.indigoIsTemplateAtom.argtypes = [c_int]
+ Indigo._lib.indigoStereocenterType.restype = c_int
+ Indigo._lib.indigoStereocenterType.argtypes = [c_int]
+ Indigo._lib.indigoStereocenterGroup.restype = c_int
+ Indigo._lib.indigoStereocenterGroup.argtypes = [c_int]
+ Indigo._lib.indigoSetStereocenterGroup.restype = c_int
+ Indigo._lib.indigoSetStereocenterGroup.argtypes = [c_int, c_int]
+ Indigo._lib.indigoChangeStereocenterType.restype = c_int
+ Indigo._lib.indigoChangeStereocenterType.argtypes = [c_int, c_int]
+ Indigo._lib.indigoValidateChirality.restype = c_int
+ Indigo._lib.indigoValidateChirality.argtypes = [c_int]
+ Indigo._lib.indigoSingleAllowedRGroup.restype = c_int
+ Indigo._lib.indigoSingleAllowedRGroup.argtypes = [c_int]
+ Indigo._lib.indigoAddStereocenter.restype = c_int
+ Indigo._lib.indigoAddStereocenter.argtypes = [
+ c_int,
+ c_int,
+ c_int,
+ c_int,
+ c_int,
+ c_int,
+ ]
+ Indigo._lib.indigoIterateRGroupFragments.restype = c_int
+ Indigo._lib.indigoIterateRGroupFragments.argtypes = [c_int]
+ Indigo._lib.indigoCountAttachmentPoints.restype = c_int
+ Indigo._lib.indigoCountAttachmentPoints.argtypes = [c_int]
+ Indigo._lib.indigoIterateAttachmentPoints.restype = c_int
+ Indigo._lib.indigoIterateAttachmentPoints.argtypes = [c_int, c_int]
+ Indigo._lib.indigoSymbol.restype = c_char_p
+ Indigo._lib.indigoSymbol.argtypes = [c_int]
+ Indigo._lib.indigoDegree.restype = c_int
+ Indigo._lib.indigoDegree.argtypes = [c_int]
+ Indigo._lib.indigoGetCharge.restype = c_int
+ Indigo._lib.indigoGetCharge.argtypes = [c_int, POINTER(c_int)]
+ Indigo._lib.indigoGetExplicitValence.restype = c_int
+ Indigo._lib.indigoGetExplicitValence.argtypes = [c_int, POINTER(c_int)]
+ Indigo._lib.indigoSetExplicitValence.restype = c_int
+ Indigo._lib.indigoSetExplicitValence.argtypes = [c_int, c_int]
+ Indigo._lib.indigoGetRadicalElectrons.restype = c_int
+ Indigo._lib.indigoGetRadicalElectrons.argtypes = [
+ c_int,
+ POINTER(c_int),
+ ]
+ Indigo._lib.indigoGetRadical.restype = c_int
+ Indigo._lib.indigoGetRadical.argtypes = [c_int, POINTER(c_int)]
+ Indigo._lib.indigoSetRadical.restype = c_int
+ Indigo._lib.indigoSetRadical.argtypes = [c_int, c_int]
+ Indigo._lib.indigoAtomicNumber.restype = c_int
+ Indigo._lib.indigoAtomicNumber.argtypes = [c_int]
+ Indigo._lib.indigoIsotope.restype = c_int
+ Indigo._lib.indigoIsotope.argtypes = [c_int]
+ Indigo._lib.indigoValence.restype = c_int
+ Indigo._lib.indigoValence.argtypes = [c_int]
+ Indigo._lib.indigoCheckValence.restype = c_int
+ Indigo._lib.indigoCheckValence.argtypes = [c_int]
+ Indigo._lib.indigoCheckQuery.restype = c_int
+ Indigo._lib.indigoCheckQuery.argtypes = [c_int]
+ Indigo._lib.indigoCheckRGroups.restype = c_int
+ Indigo._lib.indigoCheckRGroups.argtypes = [c_int]
+ Indigo._lib.indigoCountHydrogens.restype = c_int
+ Indigo._lib.indigoCountHydrogens.argtypes = [c_int, POINTER(c_int)]
+ Indigo._lib.indigoCountImplicitHydrogens.restype = c_int
+ Indigo._lib.indigoCountImplicitHydrogens.argtypes = [c_int]
+ Indigo._lib.indigoXYZ.restype = POINTER(c_float)
+ Indigo._lib.indigoXYZ.argtypes = [c_int]
+ Indigo._lib.indigoCoords.restype = POINTER(c_float)
+ Indigo._lib.indigoCoords.argtypes = [c_int]
+ Indigo._lib.indigoSetXYZ.restype = c_int
+ Indigo._lib.indigoSetXYZ.argtypes = [c_int, c_float, c_float, c_float]
+ Indigo._lib.indigoCountSuperatoms.restype = c_int
+ Indigo._lib.indigoCountSuperatoms.argtypes = [c_int]
+ Indigo._lib.indigoCountDataSGroups.restype = c_int
+ Indigo._lib.indigoCountDataSGroups.argtypes = [c_int]
+ Indigo._lib.indigoCountRepeatingUnits.restype = c_int
+ Indigo._lib.indigoCountRepeatingUnits.argtypes = [c_int]
+ Indigo._lib.indigoCountMultipleGroups.restype = c_int
+ Indigo._lib.indigoCountMultipleGroups.argtypes = [c_int]
+ Indigo._lib.indigoCountGenericSGroups.restype = c_int
+ Indigo._lib.indigoCountGenericSGroups.argtypes = [c_int]
+ Indigo._lib.indigoIterateDataSGroups.restype = c_int
+ Indigo._lib.indigoIterateDataSGroups.argtypes = [c_int]
+ Indigo._lib.indigoIterateSuperatoms.restype = c_int
+ Indigo._lib.indigoIterateSuperatoms.argtypes = [c_int]
+ Indigo._lib.indigoIterateGenericSGroups.restype = c_int
+ Indigo._lib.indigoIterateGenericSGroups.argtypes = [c_int]
+ Indigo._lib.indigoIterateRepeatingUnits.restype = c_int
+ Indigo._lib.indigoIterateRepeatingUnits.argtypes = [c_int]
+ Indigo._lib.indigoIterateMultipleGroups.restype = c_int
+ Indigo._lib.indigoIterateMultipleGroups.argtypes = [c_int]
+ Indigo._lib.indigoIterateSGroups.restype = c_int
+ Indigo._lib.indigoIterateSGroups.argtypes = [c_int]
+ Indigo._lib.indigoIterateTGroups.restype = c_int
+ Indigo._lib.indigoIterateTGroups.argtypes = [c_int]
+ Indigo._lib.indigoGetSuperatom.restype = c_int
+ Indigo._lib.indigoGetSuperatom.argtypes = [c_int, c_int]
+ Indigo._lib.indigoGetDataSGroup.restype = c_int
+ Indigo._lib.indigoGetDataSGroup.argtypes = [c_int, c_int]
+ Indigo._lib.indigoGetGenericSGroup.restype = c_int
+ Indigo._lib.indigoGetGenericSGroup.argtypes = [c_int, c_int]
+ Indigo._lib.indigoGetMultipleGroup.restype = c_int
+ Indigo._lib.indigoGetMultipleGroup.argtypes = [c_int, c_int]
+ Indigo._lib.indigoGetRepeatingUnit.restype = c_int
+ Indigo._lib.indigoGetRepeatingUnit.argtypes = [c_int, c_int]
+ Indigo._lib.indigoDescription.restype = c_char_p
+ Indigo._lib.indigoDescription.argtypes = [c_int]
+ Indigo._lib.indigoData.restype = c_char_p
+ Indigo._lib.indigoData.argtypes = [c_int]
+ Indigo._lib.indigoAddDataSGroup.restype = c_int
+ Indigo._lib.indigoAddDataSGroup.argtypes = [
+ c_int,
+ c_int,
+ POINTER(c_int),
+ c_int,
+ POINTER(c_int),
+ c_char_p,
+ c_char_p,
+ ]
+ Indigo._lib.indigoAddSuperatom.restype = c_int
+ Indigo._lib.indigoAddSuperatom.argtypes = [
+ c_int,
+ c_int,
+ POINTER(c_int),
+ c_char_p,
+ ]
+ Indigo._lib.indigoSetDataSGroupXY.restype = c_int
+ Indigo._lib.indigoSetDataSGroupXY.argtypes = [
+ c_int,
+ c_float,
+ c_float,
+ c_char_p,
+ ]
+ Indigo._lib.indigoSetSGroupData.restype = c_int
+ Indigo._lib.indigoSetSGroupData.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupCoords.restype = c_int
+ Indigo._lib.indigoSetSGroupCoords.argtypes = [c_int, c_float, c_float]
+ Indigo._lib.indigoSetSGroupDescription.restype = c_int
+ Indigo._lib.indigoSetSGroupDescription.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupFieldName.restype = c_int
+ Indigo._lib.indigoSetSGroupFieldName.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupQueryCode.restype = c_int
+ Indigo._lib.indigoSetSGroupQueryCode.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupQueryOper.restype = c_int
+ Indigo._lib.indigoSetSGroupQueryOper.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupDisplay.restype = c_int
+ Indigo._lib.indigoSetSGroupDisplay.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupLocation.restype = c_int
+ Indigo._lib.indigoSetSGroupLocation.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupTag.restype = c_int
+ Indigo._lib.indigoSetSGroupTag.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupTagAlign.restype = c_int
+ Indigo._lib.indigoSetSGroupTagAlign.argtypes = [c_int, c_int]
+ Indigo._lib.indigoSetSGroupDataType.restype = c_int
+ Indigo._lib.indigoSetSGroupDataType.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupXCoord.restype = c_int
+ Indigo._lib.indigoSetSGroupXCoord.argtypes = [c_int, c_float]
+ Indigo._lib.indigoSetSGroupYCoord.restype = c_int
+ Indigo._lib.indigoSetSGroupYCoord.argtypes = [c_int, c_float]
+ Indigo._lib.indigoCreateSGroup.restype = c_int
+ Indigo._lib.indigoCreateSGroup.argtypes = [c_char_p, c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupClass.restype = c_int
+ Indigo._lib.indigoSetSGroupClass.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetSGroupName.restype = c_int
+ Indigo._lib.indigoSetSGroupName.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoGetSGroupClass.restype = c_char_p
+ Indigo._lib.indigoGetSGroupClass.argtypes = [c_int]
+ Indigo._lib.indigoGetSGroupName.restype = c_char_p
+ Indigo._lib.indigoGetSGroupName.argtypes = [c_int]
+ Indigo._lib.indigoGetSGroupNumCrossBonds.restype = c_int
+ Indigo._lib.indigoGetSGroupNumCrossBonds.argtypes = [c_int]
+ Indigo._lib.indigoAddSGroupAttachmentPoint.restype = c_int
+ Indigo._lib.indigoAddSGroupAttachmentPoint.argtypes = [
+ c_int,
+ c_int,
+ c_int,
+ c_char_p,
+ ]
+ Indigo._lib.indigoDeleteSGroupAttachmentPoint.restype = c_int
+ Indigo._lib.indigoDeleteSGroupAttachmentPoint.argtypes = [c_int, c_int]
+ Indigo._lib.indigoGetSGroupDisplayOption.restype = c_int
+ Indigo._lib.indigoGetSGroupDisplayOption.argtypes = [c_int]
+ Indigo._lib.indigoSetSGroupDisplayOption.restype = c_int
+ Indigo._lib.indigoSetSGroupDisplayOption.argtypes = [c_int, c_int]
+ Indigo._lib.indigoGetSGroupSeqId.restype = c_int
+ Indigo._lib.indigoGetSGroupSeqId.argtypes = [c_int]
+ Indigo._lib.indigoGetSGroupCoords.restype = POINTER(c_float)
+ Indigo._lib.indigoGetSGroupCoords.argtypes = [c_int]
+ Indigo._lib.indigoGetRepeatingUnitSubscript.restype = c_char_p
+ Indigo._lib.indigoGetRepeatingUnitSubscript.argtypes = [c_int]
+ Indigo._lib.indigoGetRepeatingUnitConnectivity.restype = c_int
+ Indigo._lib.indigoGetRepeatingUnitConnectivity.argtypes = [c_int]
+ Indigo._lib.indigoGetSGroupMultiplier.restype = c_int
+ Indigo._lib.indigoGetSGroupMultiplier.argtypes = [c_int]
+ Indigo._lib.indigoSetSGroupMultiplier.restype = c_int
+ Indigo._lib.indigoSetSGroupMultiplier.argtypes = [c_int, c_int]
+ Indigo._lib.indigoSetSGroupBrackets.restype = c_int
+ Indigo._lib.indigoSetSGroupBrackets.argtypes = [
+ c_int,
+ c_int,
+ c_float,
+ c_float,
+ c_float,
+ c_float,
+ c_float,
+ c_float,
+ c_float,
+ c_float,
+ ]
+ Indigo._lib.indigoFindSGroups.restype = c_int
+ Indigo._lib.indigoFindSGroups.argtypes = [c_int, c_char_p, c_char_p]
+ Indigo._lib.indigoGetSGroupType.restype = c_int
+ Indigo._lib.indigoGetSGroupType.argtypes = [c_int]
+ Indigo._lib.indigoGetSGroupIndex.restype = c_int
+ Indigo._lib.indigoGetSGroupIndex.argtypes = [c_int]
+ Indigo._lib.indigoGetSGroupOriginalId.restype = c_int
+ Indigo._lib.indigoGetSGroupOriginalId.argtypes = [c_int]
+ Indigo._lib.indigoSetSGroupOriginalId.restype = c_int
+ Indigo._lib.indigoSetSGroupOriginalId.argtypes = [c_int, c_int]
+ Indigo._lib.indigoGetSGroupParentId.restype = c_int
+ Indigo._lib.indigoGetSGroupParentId.argtypes = [c_int]
+ Indigo._lib.indigoSetSGroupParentId.restype = c_int
+ Indigo._lib.indigoSetSGroupParentId.argtypes = [c_int, c_int]
+ Indigo._lib.indigoAddTemplate.restype = c_int
+ Indigo._lib.indigoAddTemplate.argtypes = [c_int, c_int, c_char_p]
+ Indigo._lib.indigoRemoveTemplate.restype = c_int
+ Indigo._lib.indigoRemoveTemplate.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoFindTemplate.restype = c_int
+ Indigo._lib.indigoFindTemplate.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoGetTGroupClass.restype = c_char_p
+ Indigo._lib.indigoGetTGroupClass.argtypes = [c_int]
+ Indigo._lib.indigoGetTGroupName.restype = c_char_p
+ Indigo._lib.indigoGetTGroupName.argtypes = [c_int]
+ Indigo._lib.indigoGetTGroupAlias.restype = c_char_p
+ Indigo._lib.indigoGetTGroupAlias.argtypes = [c_int]
+ Indigo._lib.indigoTransformSCSRtoCTAB.restype = c_int
+ Indigo._lib.indigoTransformSCSRtoCTAB.argtypes = [c_int]
+ Indigo._lib.indigoTransformCTABtoSCSR.restype = c_int
+ Indigo._lib.indigoTransformCTABtoSCSR.argtypes = [c_int, c_int]
+ Indigo._lib.indigoTransformHELMtoSCSR.restype = c_int
+ Indigo._lib.indigoTransformHELMtoSCSR.argtypes = [c_int]
+ Indigo._lib.indigoGetTemplateAtomClass.restype = c_char_p
+ Indigo._lib.indigoGetTemplateAtomClass.argtypes = [c_int]
+ Indigo._lib.indigoSetTemplateAtomClass.restype = c_int
+ Indigo._lib.indigoSetTemplateAtomClass.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoResetCharge.restype = c_int
+ Indigo._lib.indigoResetCharge.argtypes = [c_int]
+ Indigo._lib.indigoResetExplicitValence.restype = c_int
+ Indigo._lib.indigoResetExplicitValence.argtypes = [c_int]
+ Indigo._lib.indigoResetRadical.restype = c_int
+ Indigo._lib.indigoResetRadical.argtypes = [c_int]
+ Indigo._lib.indigoResetIsotope.restype = c_int
+ Indigo._lib.indigoResetIsotope.argtypes = [c_int]
+ Indigo._lib.indigoSetAttachmentPoint.restype = c_int
+ Indigo._lib.indigoSetAttachmentPoint.argtypes = [c_int, c_int]
+ Indigo._lib.indigoClearAttachmentPoints.restype = c_int
+ Indigo._lib.indigoClearAttachmentPoints.argtypes = [c_int]
+ Indigo._lib.indigoRemoveConstraints.restype = c_int
+ Indigo._lib.indigoRemoveConstraints.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoAddConstraint.restype = c_int
+ Indigo._lib.indigoAddConstraint.argtypes = [c_int, c_char_p, c_char_p]
+ Indigo._lib.indigoAddConstraintNot.restype = c_int
+ Indigo._lib.indigoAddConstraintNot.argtypes = [
+ c_int,
+ c_char_p,
+ c_char_p,
+ ]
+ Indigo._lib.indigoAddConstraintOr.restype = c_int
+ Indigo._lib.indigoAddConstraintOr.argtypes = [
+ c_int,
+ c_char_p,
+ c_char_p,
+ ]
+ Indigo._lib.indigoResetStereo.restype = c_int
+ Indigo._lib.indigoResetStereo.argtypes = [c_int]
+ Indigo._lib.indigoInvertStereo.restype = c_int
+ Indigo._lib.indigoInvertStereo.argtypes = [c_int]
+ Indigo._lib.indigoCountAtoms.restype = c_int
+ Indigo._lib.indigoCountAtoms.argtypes = [c_int]
+ Indigo._lib.indigoCountBonds.restype = c_int
+ Indigo._lib.indigoCountBonds.argtypes = [c_int]
+ Indigo._lib.indigoCountPseudoatoms.restype = c_int
+ Indigo._lib.indigoCountPseudoatoms.argtypes = [c_int]
+ Indigo._lib.indigoCountRSites.restype = c_int
+ Indigo._lib.indigoCountRSites.argtypes = [c_int]
+ Indigo._lib.indigoIterateBonds.restype = c_int
+ Indigo._lib.indigoIterateBonds.argtypes = [c_int]
+ Indigo._lib.indigoBondOrder.restype = c_int
+ Indigo._lib.indigoBondOrder.argtypes = [c_int]
+ Indigo._lib.indigoBondStereo.restype = c_int
+ Indigo._lib.indigoBondStereo.argtypes = [c_int]
+ Indigo._lib.indigoTopology.restype = c_int
+ Indigo._lib.indigoTopology.argtypes = [c_int]
+ Indigo._lib.indigoIterateNeighbors.restype = c_int
+ Indigo._lib.indigoIterateNeighbors.argtypes = [c_int]
+ Indigo._lib.indigoBond.restype = c_int
+ Indigo._lib.indigoBond.argtypes = [c_int]
+ Indigo._lib.indigoGetAtom.restype = c_int
+ Indigo._lib.indigoGetAtom.argtypes = [c_int, c_int]
+ Indigo._lib.indigoGetBond.restype = c_int
+ Indigo._lib.indigoGetBond.argtypes = [c_int, c_int]
+ Indigo._lib.indigoSource.restype = c_int
+ Indigo._lib.indigoSource.argtypes = [c_int]
+ Indigo._lib.indigoDestination.restype = c_int
+ Indigo._lib.indigoDestination.argtypes = [c_int]
+ Indigo._lib.indigoClearCisTrans.restype = c_int
+ Indigo._lib.indigoClearCisTrans.argtypes = [c_int]
+ Indigo._lib.indigoClearStereocenters.restype = c_int
+ Indigo._lib.indigoClearStereocenters.argtypes = [c_int]
+ Indigo._lib.indigoCountStereocenters.restype = c_int
+ Indigo._lib.indigoCountStereocenters.argtypes = [c_int]
+ Indigo._lib.indigoClearAlleneCenters.restype = c_int
+ Indigo._lib.indigoClearAlleneCenters.argtypes = [c_int]
+ Indigo._lib.indigoCountAlleneCenters.restype = c_int
+ Indigo._lib.indigoCountAlleneCenters.argtypes = [c_int]
+ Indigo._lib.indigoResetSymmetricCisTrans.restype = c_int
+ Indigo._lib.indigoResetSymmetricCisTrans.argtypes = [c_int]
+ Indigo._lib.indigoResetSymmetricStereocenters.restype = c_int
+ Indigo._lib.indigoResetSymmetricStereocenters.argtypes = [c_int]
+ Indigo._lib.indigoMarkEitherCisTrans.restype = c_int
+ Indigo._lib.indigoMarkEitherCisTrans.argtypes = [c_int]
+ Indigo._lib.indigoMarkStereobonds.restype = c_int
+ Indigo._lib.indigoMarkStereobonds.argtypes = [c_int]
+ Indigo._lib.indigoAddAtom.restype = c_int
+ Indigo._lib.indigoAddAtom.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoResetAtom.restype = c_int
+ Indigo._lib.indigoResetAtom.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoAddRSite.restype = c_int
+ Indigo._lib.indigoAddRSite.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetRSite.restype = c_int
+ Indigo._lib.indigoSetRSite.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetCharge.restype = c_int
+ Indigo._lib.indigoSetCharge.argtypes = [c_int, c_int]
+ Indigo._lib.indigoSetIsotope.restype = c_int
+ Indigo._lib.indigoSetIsotope.argtypes = [c_int, c_int]
+ Indigo._lib.indigoSetImplicitHCount.restype = c_int
+ Indigo._lib.indigoSetImplicitHCount.argtypes = [c_int, c_int]
+ Indigo._lib.indigoAddBond.restype = c_int
+ Indigo._lib.indigoAddBond.argtypes = [c_int, c_int, c_int]
+ Indigo._lib.indigoSetBondOrder.restype = c_int
+ Indigo._lib.indigoSetBondOrder.argtypes = [c_int, c_int]
+ Indigo._lib.indigoMerge.restype = c_int
+ Indigo._lib.indigoMerge.argtypes = [c_int, c_int]
+ Indigo._lib.indigoHighlight.restype = c_int
+ Indigo._lib.indigoHighlight.argtypes = [c_int]
+ Indigo._lib.indigoUnhighlight.restype = c_int
+ Indigo._lib.indigoUnhighlight.argtypes = [c_int]
+ Indigo._lib.indigoIsHighlighted.restype = c_int
+ Indigo._lib.indigoIsHighlighted.argtypes = [c_int]
+ Indigo._lib.indigoCountComponents.restype = c_int
+ Indigo._lib.indigoCountComponents.argtypes = [c_int]
+ Indigo._lib.indigoComponentIndex.restype = c_int
+ Indigo._lib.indigoComponentIndex.argtypes = [c_int]
+ Indigo._lib.indigoIterateComponents.restype = c_int
+ Indigo._lib.indigoIterateComponents.argtypes = [c_int]
+ Indigo._lib.indigoComponent.restype = c_int
+ Indigo._lib.indigoComponent.argtypes = [c_int, c_int]
+ Indigo._lib.indigoCountSSSR.restype = c_int
+ Indigo._lib.indigoCountSSSR.argtypes = [c_int]
+ Indigo._lib.indigoIterateSSSR.restype = c_int
+ Indigo._lib.indigoIterateSSSR.argtypes = [c_int]
+ Indigo._lib.indigoIterateSubtrees.restype = c_int
+ Indigo._lib.indigoIterateSubtrees.argtypes = [c_int, c_int, c_int]
+ Indigo._lib.indigoIterateRings.restype = c_int
+ Indigo._lib.indigoIterateRings.argtypes = [c_int, c_int, c_int]
+ Indigo._lib.indigoIterateEdgeSubmolecules.restype = c_int
+ Indigo._lib.indigoIterateEdgeSubmolecules.argtypes = [
+ c_int,
+ c_int,
+ c_int,
+ ]
+ Indigo._lib.indigoCountHeavyAtoms.restype = c_int
+ Indigo._lib.indigoCountHeavyAtoms.argtypes = [c_int]
+ Indigo._lib.indigoGrossFormula.restype = c_int
+ Indigo._lib.indigoGrossFormula.argtypes = [c_int]
+ Indigo._lib.indigoMolecularWeight.restype = c_double
+ Indigo._lib.indigoMolecularWeight.argtypes = [c_int]
+ Indigo._lib.indigoMostAbundantMass.restype = c_double
+ Indigo._lib.indigoMostAbundantMass.argtypes = [c_int]
+ Indigo._lib.indigoMonoisotopicMass.restype = c_double
+ Indigo._lib.indigoMonoisotopicMass.argtypes = [c_int]
+ Indigo._lib.indigoMassComposition.restype = c_char_p
+ Indigo._lib.indigoMassComposition.argtypes = [c_int]
+ Indigo._lib.indigoCanonicalSmiles.restype = c_char_p
+ Indigo._lib.indigoCanonicalSmiles.argtypes = [c_int]
+ Indigo._lib.indigoCanonicalSmarts.restype = c_char_p
+ Indigo._lib.indigoCanonicalSmarts.argtypes = [c_int]
+ Indigo._lib.indigoLayeredCode.restype = c_char_p
+ Indigo._lib.indigoLayeredCode.argtypes = [c_int]
+ Indigo._lib.indigoSymmetryClasses.restype = POINTER(c_int)
+ Indigo._lib.indigoSymmetryClasses.argtypes = [c_int, POINTER(c_int)]
+ Indigo._lib.indigoHasCoord.restype = c_int
+ Indigo._lib.indigoHasCoord.argtypes = [c_int]
+ Indigo._lib.indigoHasZCoord.restype = c_int
+ Indigo._lib.indigoHasZCoord.argtypes = [c_int]
+ Indigo._lib.indigoIsChiral.restype = c_int
+ Indigo._lib.indigoIsChiral.argtypes = [c_int]
+ Indigo._lib.indigoIsPossibleFischerProjection.restype = c_int
+ Indigo._lib.indigoIsPossibleFischerProjection.argtypes = [
+ c_int,
+ c_char_p,
+ ]
+ Indigo._lib.indigoCreateSubmolecule.restype = c_int
+ Indigo._lib.indigoCreateSubmolecule.argtypes = [
+ c_int,
+ c_int,
+ POINTER(c_int),
+ ]
+ Indigo._lib.indigoCreateEdgeSubmolecule.restype = c_int
+ Indigo._lib.indigoCreateEdgeSubmolecule.argtypes = [
+ c_int,
+ c_int,
+ POINTER(c_int),
+ c_int,
+ POINTER(c_int),
+ ]
+ Indigo._lib.indigoGetSubmolecule.restype = c_int
+ Indigo._lib.indigoGetSubmolecule.argtypes = [
+ c_int,
+ c_int,
+ POINTER(c_int),
+ ]
+ Indigo._lib.indigoRemoveAtoms.restype = c_int
+ Indigo._lib.indigoRemoveAtoms.argtypes = [c_int, c_int, POINTER(c_int)]
+ Indigo._lib.indigoRemoveBonds.restype = c_int
+ Indigo._lib.indigoRemoveBonds.argtypes = [c_int, c_int, POINTER(c_int)]
+ Indigo._lib.indigoAlignAtoms.restype = c_float
+ Indigo._lib.indigoAlignAtoms.argtypes = [
+ c_int,
+ c_int,
+ POINTER(c_int),
+ POINTER(c_float),
+ ]
+ Indigo._lib.indigoAromatize.restype = c_int
+ Indigo._lib.indigoAromatize.argtypes = [c_int]
+ Indigo._lib.indigoDearomatize.restype = c_int
+ Indigo._lib.indigoDearomatize.argtypes = [c_int]
+ Indigo._lib.indigoFoldHydrogens.restype = c_int
+ Indigo._lib.indigoFoldHydrogens.argtypes = [c_int]
+ Indigo._lib.indigoUnfoldHydrogens.restype = c_int
+ Indigo._lib.indigoUnfoldHydrogens.argtypes = [c_int]
+ Indigo._lib.indigoLayout.restype = c_int
+ Indigo._lib.indigoLayout.argtypes = [c_int]
+ Indigo._lib.indigoClean2d.restype = c_int
+ Indigo._lib.indigoClean2d.argtypes = [c_int]
+ Indigo._lib.indigoSmiles.restype = c_char_p
+ Indigo._lib.indigoSmiles.argtypes = [c_int]
+ Indigo._lib.indigoSmarts.restype = c_char_p
+ Indigo._lib.indigoSmarts.argtypes = [c_int]
+ Indigo._lib.indigoName.restype = c_char_p
+ Indigo._lib.indigoName.argtypes = [c_int]
+ Indigo._lib.indigoSetName.restype = c_int
+ Indigo._lib.indigoSetName.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSerialize.restype = c_int
+ Indigo._lib.indigoSerialize.argtypes = [
+ c_int,
+ POINTER(POINTER(c_byte)),
+ POINTER(c_int),
+ ]
+ Indigo._lib.indigoHasProperty.restype = c_int
+ Indigo._lib.indigoHasProperty.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoGetProperty.restype = c_char_p
+ Indigo._lib.indigoGetProperty.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoSetProperty.restype = c_int
+ Indigo._lib.indigoSetProperty.argtypes = [c_int, c_char_p, c_char_p]
+ Indigo._lib.indigoRemoveProperty.restype = c_int
+ Indigo._lib.indigoRemoveProperty.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoIterateProperties.restype = c_int
+ Indigo._lib.indigoIterateProperties.argtypes = [c_int]
+ Indigo._lib.indigoClearProperties.restype = c_int
+ Indigo._lib.indigoClearProperties.argtypes = [c_int]
+ Indigo._lib.indigoCheckBadValence.restype = c_char_p
+ Indigo._lib.indigoCheckBadValence.argtypes = [c_int]
+ Indigo._lib.indigoCheckAmbiguousH.restype = c_char_p
+ Indigo._lib.indigoCheckAmbiguousH.argtypes = [c_int]
+ Indigo._lib.indigoCheckChirality.restype = c_int
+ Indigo._lib.indigoCheckChirality.argtypes = [c_int]
+ Indigo._lib.indigoCheck3DStereo.restype = c_int
+ Indigo._lib.indigoCheck3DStereo.argtypes = [c_int]
+ Indigo._lib.indigoCheckStereo.restype = c_int
+ Indigo._lib.indigoCheckStereo.argtypes = [c_int]
+ Indigo._lib.indigoFingerprint.restype = c_int
+ Indigo._lib.indigoFingerprint.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoLoadFingerprintFromBuffer.restype = c_int
+ Indigo._lib.indigoLoadFingerprintFromBuffer.argtypes = [
+ POINTER(c_byte),
+ c_int,
+ ]
+ Indigo._lib.indigoLoadFingerprintFromDescriptors.restype = c_int
+ Indigo._lib.indigoLoadFingerprintFromDescriptors.argtypes = [
+ POINTER(c_double),
+ c_int,
+ c_int,
+ c_double,
+ ]
+ Indigo._lib.indigoCountBits.restype = c_int
+ Indigo._lib.indigoCountBits.argtypes = [c_int]
+ Indigo._lib.indigoRawData.restype = c_char_p
+ Indigo._lib.indigoRawData.argtypes = [c_int]
+ Indigo._lib.indigoTell.restype = c_int
+ Indigo._lib.indigoTell.argtypes = [c_int]
+ Indigo._lib.indigoSdfAppend.restype = c_int
+ Indigo._lib.indigoSdfAppend.argtypes = [c_int, c_int]
+ Indigo._lib.indigoSmilesAppend.restype = c_int
+ Indigo._lib.indigoSmilesAppend.argtypes = [c_int, c_int]
+ Indigo._lib.indigoRdfHeader.restype = c_int
+ Indigo._lib.indigoRdfHeader.argtypes = [c_int]
+ Indigo._lib.indigoRdfAppend.restype = c_int
+ Indigo._lib.indigoRdfAppend.argtypes = [c_int, c_int]
+ Indigo._lib.indigoCmlHeader.restype = c_int
+ Indigo._lib.indigoCmlHeader.argtypes = [c_int]
+ Indigo._lib.indigoCmlAppend.restype = c_int
+ Indigo._lib.indigoCmlAppend.argtypes = [c_int, c_int]
+ Indigo._lib.indigoCmlFooter.restype = c_int
+ Indigo._lib.indigoCmlFooter.argtypes = [c_int]
+ Indigo._lib.indigoAppend.restype = c_int
+ Indigo._lib.indigoAppend.argtypes = [c_int, c_int]
+ Indigo._lib.indigoArrayAdd.restype = c_int
+ Indigo._lib.indigoArrayAdd.argtypes = [c_int, c_int]
+ Indigo._lib.indigoAt.restype = c_int
+ Indigo._lib.indigoAt.argtypes = [c_int, c_int]
+ Indigo._lib.indigoCount.restype = c_int
+ Indigo._lib.indigoCount.argtypes = [c_int]
+ Indigo._lib.indigoClear.restype = c_int
+ Indigo._lib.indigoClear.argtypes = [c_int]
+ Indigo._lib.indigoIterateArray.restype = c_int
+ Indigo._lib.indigoIterateArray.argtypes = [c_int]
+ Indigo._lib.indigoIgnoreAtom.restype = c_int
+ Indigo._lib.indigoIgnoreAtom.argtypes = [c_int, c_int]
+ Indigo._lib.indigoUnignoreAtom.restype = c_int
+ Indigo._lib.indigoUnignoreAtom.argtypes = [c_int, c_int]
+ Indigo._lib.indigoUnignoreAllAtoms.restype = c_int
+ Indigo._lib.indigoUnignoreAllAtoms.argtypes = [c_int]
+ Indigo._lib.indigoMatch.restype = c_int
+ Indigo._lib.indigoMatch.argtypes = [c_int, c_int]
+ Indigo._lib.indigoCountMatches.restype = c_int
+ Indigo._lib.indigoCountMatches.argtypes = [c_int, c_int]
+ Indigo._lib.indigoCountMatchesWithLimit.restype = c_int
+ Indigo._lib.indigoCountMatchesWithLimit.argtypes = [
+ c_int,
+ c_int,
+ c_int,
+ ]
+ Indigo._lib.indigoIterateMatches.restype = c_int
+ Indigo._lib.indigoIterateMatches.argtypes = [c_int, c_int]
+ Indigo._lib.indigoHighlightedTarget.restype = c_int
+ Indigo._lib.indigoHighlightedTarget.argtypes = [c_int]
+ Indigo._lib.indigoMapAtom.restype = c_int
+ Indigo._lib.indigoMapAtom.argtypes = [c_int, c_int]
+ Indigo._lib.indigoMapBond.restype = c_int
+ Indigo._lib.indigoMapBond.argtypes = [c_int, c_int]
+ Indigo._lib.indigoMapMolecule.restype = c_int
+ Indigo._lib.indigoMapMolecule.argtypes = [c_int, c_int]
+ Indigo._lib.indigoIterateTautomers.restype = c_int
+ Indigo._lib.indigoIterateTautomers.argtypes = [c_int, c_char_p]
+ Indigo._lib.indigoAllScaffolds.restype = c_int
+ Indigo._lib.indigoAllScaffolds.argtypes = [c_int]
+ Indigo._lib.indigoDecomposedMoleculeScaffold.restype = c_int
+ Indigo._lib.indigoDecomposedMoleculeScaffold.argtypes = [c_int]
+ Indigo._lib.indigoIterateDecomposedMolecules.restype = c_int
+ Indigo._lib.indigoIterateDecomposedMolecules.argtypes = [c_int]
+ Indigo._lib.indigoDecomposedMoleculeHighlighted.restype = c_int
+ Indigo._lib.indigoDecomposedMoleculeHighlighted.argtypes = [c_int]
+ Indigo._lib.indigoDecomposedMoleculeWithRGroups.restype = c_int
+ Indigo._lib.indigoDecomposedMoleculeWithRGroups.argtypes = [c_int]
+ Indigo._lib.indigoDecomposeMolecule.restype = c_int
+ Indigo._lib.indigoDecomposeMolecule.argtypes = [c_int, c_int]
+ Indigo._lib.indigoIterateDecompositions.restype = c_int
+ Indigo._lib.indigoIterateDecompositions.argtypes = [c_int]
+ Indigo._lib.indigoAddDecomposition.restype = c_int
+ Indigo._lib.indigoAddDecomposition.argtypes = [c_int, c_int]
+ Indigo._lib.indigoToString.restype = c_char_p
+ Indigo._lib.indigoToString.argtypes = [c_int]
+ Indigo._lib.indigoOneBitsList.restype = c_char_p
+ Indigo._lib.indigoOneBitsList.argtypes = [c_int]
+ Indigo._lib.indigoToBuffer.restype = c_int
+ Indigo._lib.indigoToBuffer.argtypes = [
+ c_int,
+ POINTER(POINTER(c_byte)),
+ POINTER(c_int),
+ ]
+ Indigo._lib.indigoStereocenterPyramid.restype = POINTER(c_int)
+ Indigo._lib.indigoStereocenterPyramid.argtypes = [c_int]
+ Indigo._lib.indigoExpandAbbreviations.restype = c_int
+ Indigo._lib.indigoExpandAbbreviations.argtypes = [c_int]
+ Indigo._lib.indigoDbgInternalType.restype = c_char_p
+ Indigo._lib.indigoDbgInternalType.argtypes = [c_int]
+ Indigo._lib.indigoNameToStructure.restype = c_int
+ Indigo._lib.indigoNameToStructure.argtypes = [c_char_p, c_char_p]
+ Indigo._lib.indigoResetOptions.restype = c_int
+ Indigo._lib.indigoResetOptions.argtypes = None
+
+ def __del__(self):
+ if hasattr(self, "_lib"):
+ self._lib.indigoReleaseSessionId(self._sid)
+
+ def deserialize(self, arr):
+ values = (c_byte * len(arr))()
+ for i in range(len(arr)):
+ values[i] = arr[i]
+ self._setSessionId()
+ res = Indigo._lib.indigoUnserialize(values, len(arr))
+ return self.IndigoObject(self, self._checkResult(res))
+
+ def unserialize(self, arr):
+ warnings.warn(
+ "unserialize() is deprecated, use deserialize() instead",
+ DeprecationWarning,
+ )
+ return self.deserialize(arr)
+
+ def setOption(self, option, value1, value2=None, value3=None):
+ self._setSessionId()
+ if (
+ (
+ type(value1).__name__ == "str"
+ or type(value1).__name__ == "unicode"
+ )
+ and value2 is None
+ and value3 is None
+ ):
+ self._checkResult(
+ Indigo._lib.indigoSetOption(
+ option.encode(ENCODE_ENCODING),
+ value1.encode(ENCODE_ENCODING),
+ )
+ )
+ elif (
+ type(value1).__name__ == "int"
+ and value2 is None
+ and value3 is None
+ ):
+ self._checkResult(
+ Indigo._lib.indigoSetOptionInt(
+ option.encode(ENCODE_ENCODING), value1
+ )
+ )
+ elif (
+ type(value1).__name__ == "float"
+ and value2 is None
+ and value3 is None
+ ):
+ self._checkResult(
+ Indigo._lib.indigoSetOptionFloat(
+ option.encode(ENCODE_ENCODING), value1
+ )
+ )
+ elif (
+ type(value1).__name__ == "bool"
+ and value2 is None
+ and value3 is None
+ ):
+ value1_b = 0
+ if value1:
+ value1_b = 1
+ self._checkResult(
+ Indigo._lib.indigoSetOptionBool(
+ option.encode(ENCODE_ENCODING), value1_b
+ )
+ )
+ elif (
+ type(value1).__name__ == "int"
+ and value2
+ and type(value2).__name__ == "int"
+ and value3 is None
+ ):
+ self._checkResult(
+ Indigo._lib.indigoSetOptionXY(
+ option.encode(ENCODE_ENCODING), value1, value2
+ )
+ )
+ elif (
+ type(value1).__name__ == "float"
+ and value2
+ and type(value2).__name__ == "float"
+ and value3
+ and type(value3).__name__ == "float"
+ ):
+ self._checkResult(
+ Indigo._lib.indigoSetOptionColor(
+ option.encode(ENCODE_ENCODING), value1, value2, value3
+ )
+ )
+ else:
+ raise IndigoException("bad option")
+
+ def getOption(self, option):
+ self._setSessionId()
+ return self._checkResultString(
+ Indigo._lib.indigoGetOption(option.encode(ENCODE_ENCODING))
+ )
+
+ def getOptionInt(self, option):
+ self._setSessionId()
+ value = c_int()
+ self._checkResult(
+ Indigo._lib.indigoGetOptionInt(
+ option.encode(ENCODE_ENCODING), pointer(value)
+ )
+ )
+ return value.value
+
+ def getOptionBool(self, option):
+ self._setSessionId()
+ value = c_int()
+ self._checkResult(
+ Indigo._lib.indigoGetOptionBool(
+ option.encode(ENCODE_ENCODING), pointer(value)
+ )
+ )
+ if value.value == 1:
+ return True
+ return False
+
+ def getOptionFloat(self, option):
+ self._setSessionId()
+ value = c_float()
+ self._checkResult(
+ Indigo._lib.indigoGetOptionFloat(
+ option.encode(ENCODE_ENCODING), pointer(value)
+ )
+ )
+ return value.value
+
+ def getOptionType(self, option):
+ self._setSessionId()
+ return self._checkResultString(
+ Indigo._lib.indigoGetOptionType(option.encode(ENCODE_ENCODING))
+ )
+
+ def resetOptions(self):
+ self._setSessionId()
+ self._checkResult(Indigo._lib.indigoResetOptions())
+
+ def _checkResult(self, result):
+ if result < 0:
+ raise IndigoException(Indigo._lib.indigoGetLastError())
+ return result
+
+ def _checkResultFloat(self, result):
+ if result < -0.5:
+ raise IndigoException(Indigo._lib.indigoGetLastError())
+ return result
+
+ def _checkResultPtr(self, result):
+ if result is None:
+ raise IndigoException(Indigo._lib.indigoGetLastError())
+ return result
+
+ def _checkResultString(self, result):
+ return self._checkResultPtr(result).decode(DECODE_ENCODING)
+
+ def convertToArray(self, iteratable):
+ if isinstance(iteratable, IndigoObject):
+ return iteratable
+ try:
+ some_object_iterator = iter(iteratable)
+ res = self.createArray()
+ for obj in some_object_iterator:
+ res.arrayAdd(self.convertToArray(obj))
+ return res
+ except TypeError:
+ raise IndigoException(
+ "Cannot convert object %s to an array" % (iteratable)
+ )
+
+ def dbgBreakpoint(self):
+ self._setSessionId()
+ return Indigo._lib.indigoDbgBreakpoint()
+
+ def version(self):
+ self._setSessionId()
+ return self._checkResultString(Indigo._lib.indigoVersion())
+
+ def countReferences(self):
+ self._setSessionId()
+ return self._checkResult(Indigo._lib.indigoCountReferences())
+
+ def writeFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoWriteFile(filename.encode(ENCODE_ENCODING))
+ ),
+ )
+
+ def writeBuffer(self):
+ self._setSessionId()
+ return self.IndigoObject(
+ self, self._checkResult(Indigo._lib.indigoWriteBuffer())
+ )
+
+ def createMolecule(self):
+ self._setSessionId()
+ return self.IndigoObject(
+ self, self._checkResult(Indigo._lib.indigoCreateMolecule())
+ )
+
+ def createQueryMolecule(self):
+ self._setSessionId()
+ return self.IndigoObject(
+ self, self._checkResult(Indigo._lib.indigoCreateQueryMolecule())
+ )
+
+ def loadMolecule(self, string):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadMoleculeFromString(
+ string.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadMoleculeFromFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadMoleculeFromFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadMoleculeFromBuffer(self, data):
+ """
+ Loads molecule from given buffer. Automatically detects input format
+
+ Args:
+ * buf - byte array
+
+ Usage:
+ ```
+ with open (..), 'rb') as f:
+ m = indigo.loadMoleculeFromBuffer(f.read())
+ ```
+ Raises:
+ Exception if structure format is incorrect
+
+ ::
+
+ Since version 1.3.0
+ """
+ if sys.version_info[0] < 3:
+ buf = map(ord, data)
+ else:
+ buf = data
+ values = (c_byte * len(buf))()
+ for i in range(len(buf)):
+ values[i] = buf[i]
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadMoleculeFromBuffer(values, len(buf))
+ ),
+ )
+
+ def loadQueryMolecule(self, string):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadQueryMoleculeFromString(
+ string.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadQueryMoleculeFromFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadQueryMoleculeFromFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadSmarts(self, string):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadSmartsFromString(
+ string.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadSmartsFromFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadSmartsFromFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadReaction(self, string):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadReactionFromString(
+ string.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadReactionFromFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadReactionFromFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadQueryReaction(self, string):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadQueryReactionFromString(
+ string.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadQueryReactionFromFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadQueryReactionFromFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadReactionSmarts(self, string):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadReactionSmartsFromString(
+ string.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadReactionSmartsFromFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadReactionSmartsFromFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadStructure(self, structureStr, parameter=None):
+ self._setSessionId()
+ parameter = "" if parameter is None else parameter
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadStructureFromString(
+ structureStr.encode(ENCODE_ENCODING),
+ parameter.encode(ENCODE_ENCODING),
+ )
+ ),
+ )
+
+ def loadStructureFromBuffer(self, structureData, parameter=None):
+ if sys.version_info[0] < 3:
+ buf = map(ord, structureData)
+ else:
+ buf = structureData
+ values = (c_byte * len(buf))()
+ for i in range(len(buf)):
+ values[i] = buf[i]
+ self._setSessionId()
+ parameter = "" if parameter is None else parameter
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadStructureFromBuffer(
+ values, len(buf), parameter.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def loadStructureFromFile(self, filename, parameter=None):
+ self._setSessionId()
+ parameter = "" if parameter is None else parameter
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadStructureFromFile(
+ filename.encode(ENCODE_ENCODING),
+ parameter.encode(ENCODE_ENCODING),
+ )
+ ),
+ )
+
+ def checkStructure(self, structure, props=""):
+ if props is None:
+ props = ""
+ self._setSessionId()
+ return self._checkResultString(
+ Indigo._lib.indigoCheckStructure(
+ structure.encode(ENCODE_ENCODING),
+ props.encode(ENCODE_ENCODING),
+ )
+ )
+
+ def loadFingerprintFromBuffer(self, buffer):
+ """Creates a fingerprint from the supplied binary data
+
+ :param buffer: a list of bytes
+ :return: a fingerprint object
+
+ Since version 1.3.0
+ """
+ self._setSessionId()
+ length = len(buffer)
+
+ values = (c_byte * length)()
+ for i in range(length):
+ values[i] = buffer[i]
+
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadFingerprintFromBuffer(values, length)
+ ),
+ )
+
+ def loadFingerprintFromDescriptors(self, descriptors, size, density):
+ """Packs a list of molecule descriptors into a fingerprint object
+
+ :param descriptors: list of normalized numbers (roughly) between 0.0 and 1.0
+ :param size: size of the fingerprint in bytes
+ :param density: approximate density of '1's vs `0`s in the fingerprint
+ :return: a fingerprint object
+
+ Since version 1.3.0
+ """
+ self._setSessionId()
+ length = len(descriptors)
+
+ descr_arr = (c_double * length)()
+ for i in range(length):
+ descr_arr[i] = descriptors[i]
+
+ result = Indigo._lib.indigoLoadFingerprintFromDescriptors(
+ descr_arr, length, size, density
+ )
+ return self.IndigoObject(self, self._checkResult(result))
+
+ def createReaction(self):
+ self._setSessionId()
+ return self.IndigoObject(
+ self, self._checkResult(Indigo._lib.indigoCreateReaction())
+ )
+
+ def createQueryReaction(self):
+ self._setSessionId()
+ return self.IndigoObject(
+ self, self._checkResult(Indigo._lib.indigoCreateQueryReaction())
+ )
+
+ def exactMatch(self, item1, item2, flags=""):
+ if flags is None:
+ flags = ""
+ self._setSessionId()
+ newobj = self._checkResult(
+ Indigo._lib.indigoExactMatch(
+ item1.id, item2.id, flags.encode(ENCODE_ENCODING)
+ )
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.IndigoObject(self, newobj, [item1, item2, self])
+
+ def setTautomerRule(self, id, beg, end):
+ self._setSessionId()
+ return self._checkResult(
+ Indigo._lib.indigoSetTautomerRule(
+ id, beg.encode(ENCODE_ENCODING), end.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def removeTautomerRule(self, id):
+ self._setSessionId()
+ return self._checkResult(Indigo._lib.indigoRemoveTautomerRule(id))
+
+ def clearTautomerRules(self):
+ self._setSessionId()
+ return self._checkResult(Indigo._lib.indigoClearTautomerRules())
+
+ def commonBits(self, fingerprint1, fingerprint2):
+ self._setSessionId()
+ return self._checkResult(
+ Indigo._lib.indigoCommonBits(fingerprint1.id, fingerprint2.id)
+ )
+
+ def similarity(self, item1, item2, metrics=""):
+ if metrics is None:
+ metrics = ""
+ self._setSessionId()
+ return self._checkResultFloat(
+ Indigo._lib.indigoSimilarity(
+ item1.id, item2.id, metrics.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def iterateSDFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoIterateSDFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def iterateRDFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoIterateRDFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def iterateSmilesFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoIterateSmilesFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def iterateCMLFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoIterateCMLFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def iterateCDXFile(self, filename):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoIterateCDXFile(
+ filename.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def createFileSaver(self, filename, format):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoCreateFileSaver(
+ filename.encode(ENCODE_ENCODING),
+ format.encode(ENCODE_ENCODING),
+ )
+ ),
+ )
+
+ def createSaver(self, obj, format):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoCreateSaver(
+ obj.id, format.encode(ENCODE_ENCODING)
+ )
+ ),
+ )
+
+ def createArray(self):
+ self._setSessionId()
+ return self.IndigoObject(
+ self, self._checkResult(Indigo._lib.indigoCreateArray())
+ )
+
+ def substructureMatcher(self, target, mode=""):
+ if mode is None:
+ mode = ""
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoSubstructureMatcher(
+ target.id, mode.encode(ENCODE_ENCODING)
+ )
+ ),
+ target,
+ )
+
+ def extractCommonScaffold(self, structures, options=""):
+ structures = self.convertToArray(structures)
+ if options is None:
+ options = ""
+ self._setSessionId()
+ newobj = self._checkResult(
+ Indigo._lib.indigoExtractCommonScaffold(
+ structures.id, options.encode(ENCODE_ENCODING)
+ )
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.IndigoObject(self, newobj, self)
+
+ def decomposeMolecules(self, scaffold, structures):
+ structures = self.convertToArray(structures)
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoDecomposeMolecules(
+ scaffold.id, structures.id
+ )
+ ),
+ scaffold,
+ )
+
+ def rgroupComposition(self, molecule, options=""):
+ if options is None:
+ options = ""
+ self._setSessionId()
+ newobj = self._checkResult(
+ Indigo._lib.indigoRGroupComposition(
+ molecule.id, options.encode(ENCODE_ENCODING)
+ )
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.IndigoObject(self, newobj, self)
+
+ def getFragmentedMolecule(self, elem, options=""):
+ if options is None:
+ options = ""
+ self._setSessionId()
+ newobj = self._checkResult(
+ Indigo._lib.indigoGetFragmentedMolecule(
+ elem.id, options.encode(ENCODE_ENCODING)
+ )
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.IndigoObject(self, newobj, self)
+
+ def createDecomposer(self, scaffold):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(Indigo._lib.indigoCreateDecomposer(scaffold.id)),
+ scaffold,
+ )
+
+ def reactionProductEnumerate(self, replacedaction, monomers):
+ self._setSessionId()
+ monomers = self.convertToArray(monomers)
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoReactionProductEnumerate(
+ replacedaction.id, monomers.id
+ )
+ ),
+ replacedaction,
+ )
+
+ def transform(self, reaction, monomers):
+ self._setSessionId()
+ newobj = self._checkResult(
+ Indigo._lib.indigoTransform(reaction.id, monomers.id)
+ )
+ if newobj == 0:
+ return None
+ else:
+ return self.IndigoObject(self, newobj, self)
+
+ def loadBuffer(self, buf):
+ buf = list(buf)
+ values = (c_byte * len(buf))()
+ for i in range(len(buf)):
+ values[i] = buf[i]
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(Indigo._lib.indigoLoadBuffer(values, len(buf))),
+ )
+
+ def loadString(self, string):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoLoadString(string.encode(ENCODE_ENCODING))
+ ),
+ )
+
+ def iterateSDF(self, reader):
+ self._setSessionId()
+ result = self._checkResult(Indigo._lib.indigoIterateSDF(reader.id))
+ if not result:
+ return None
+ return self.IndigoObject(self, result, reader)
+
+ def iterateSmiles(self, reader):
+ self._setSessionId()
+ result = self._checkResult(Indigo._lib.indigoIterateSmiles(reader.id))
+ if not result:
+ return None
+ return self.IndigoObject(self, result, reader)
+
+ def iterateCML(self, reader):
+ self._setSessionId()
+ result = self._checkResult(Indigo._lib.indigoIterateCML(reader.id))
+ if not result:
+ return None
+ return self.IndigoObject(self, result, reader)
+
+ def iterateCDX(self, reader):
+ self._setSessionId()
+ result = self._checkResult(Indigo._lib.indigoIterateCDX(reader.id))
+ if not result:
+ return None
+ return self.IndigoObject(self, result, reader)
+
+ def iterateRDF(self, reader):
+ self._setSessionId()
+ result = self._checkResult(Indigo._lib.indigoIterateRDF(reader.id))
+ if not result:
+ return None
+ return self.IndigoObject(self, result, reader)
+
+ def iterateTautomers(self, molecule, params):
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoIterateTautomers(
+ molecule.id, params.encode(ENCODE_ENCODING)
+ )
+ ),
+ molecule,
+ )
+
+ def nameToStructure(self, name, params=None):
+ """
+ Converts a chemical name into a corresponding structure
+
+ Args:
+ * name - a name to parse
+ * params - a string containing parsing options or nullptr if no options are changed
+
+ Raises:
+ Exception if parsing fails or no structure is found
+
+ ::
+
+ Since version 1.3.0
+ """
+ if params is None:
+ params = ""
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(
+ Indigo._lib.indigoNameToStructure(
+ name.encode(ENCODE_ENCODING),
+ params.encode(ENCODE_ENCODING),
+ )
+ ),
+ )
+
+ def buildPkaModel(self, level, threshold, filename):
+ self._setSessionId()
+ return self._checkResult(
+ Indigo._lib.indigoBuildPkaModel(
+ level, threshold, filename.encode(ENCODE_ENCODING)
+ )
+ )
+
+ def transformHELMtoSCSR(self, item):
+ """
+ ::
+
+ Since version 1.3.0
+ """
+ self._setSessionId()
+ return self.IndigoObject(
+ self,
+ self._checkResult(Indigo._lib.indigoTransformHELMtoSCSR(item.id)),
+ )
diff --git a/molscribe/indigo/__pycache__/__init__.cpython-310.pyc b/molscribe/indigo/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..738de4beeea5aa655114bcf9edb3868b0f44923b
Binary files /dev/null and b/molscribe/indigo/__pycache__/__init__.cpython-310.pyc differ
diff --git a/molscribe/indigo/__pycache__/bingo.cpython-310.pyc b/molscribe/indigo/__pycache__/bingo.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29eeae0f22f7b7b1cf23a4654d77055974e48fc6
Binary files /dev/null and b/molscribe/indigo/__pycache__/bingo.cpython-310.pyc differ
diff --git a/molscribe/indigo/__pycache__/inchi.cpython-310.pyc b/molscribe/indigo/__pycache__/inchi.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54bc5aac5d7942ac8129b74020a6a932bc379f13
Binary files /dev/null and b/molscribe/indigo/__pycache__/inchi.cpython-310.pyc differ
diff --git a/molscribe/indigo/__pycache__/renderer.cpython-310.pyc b/molscribe/indigo/__pycache__/renderer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a2b20083c8224e48178a84e8d4c089bfbf98912
Binary files /dev/null and b/molscribe/indigo/__pycache__/renderer.cpython-310.pyc differ
diff --git a/molscribe/indigo/bingo.py b/molscribe/indigo/bingo.py
new file mode 100644
index 0000000000000000000000000000000000000000..07c1f2701f31387bd32b2b189fda34f17baadabd
--- /dev/null
+++ b/molscribe/indigo/bingo.py
@@ -0,0 +1,334 @@
+#
+# Copyright (C) from 2009 to Present EPAM Systems.
+#
+# This file is part of Indigo toolkit.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from . import *
+
+
+class BingoException(Exception):
+
+ def __init__(self, value):
+ self.value = value
+
+ def __str__(self):
+ if sys.version_info > (3, 0):
+ return repr(self.value.decode('ascii'))
+ else:
+ return repr(self.value)
+
+
+class Bingo(object):
+ def __init__(self, bingoId, indigo, lib):
+ self._id = bingoId
+ self._indigo = indigo
+ self._lib = lib
+ self._lib.bingoVersion.restype = c_char_p
+ self._lib.bingoVersion.argtypes = None
+ self._lib.bingoCreateDatabaseFile.restype = c_int
+ self._lib.bingoCreateDatabaseFile.argtypes = [c_char_p, c_char_p, c_char_p]
+ self._lib.bingoLoadDatabaseFile.restype = c_int
+ self._lib.bingoLoadDatabaseFile.argtypes = [c_char_p, c_char_p]
+ self._lib.bingoCloseDatabase.restype = c_int
+ self._lib.bingoCloseDatabase.argtypes = [c_int]
+ self._lib.bingoInsertRecordObj.restype = c_int
+ self._lib.bingoInsertRecordObj.argtypes = [c_int, c_int]
+ self._lib.bingoInsertRecordObjWithExtFP.restype = c_int
+ self._lib.bingoInsertRecordObjWithExtFP.argtypes = [c_int, c_int, c_int]
+ self._lib.bingoGetRecordObj.restype = c_int
+ self._lib.bingoGetRecordObj.argtypes = [c_int, c_int]
+ self._lib.bingoInsertRecordObjWithId.restype = c_int
+ self._lib.bingoInsertRecordObjWithId.argtypes = [c_int, c_int, c_int]
+ self._lib.bingoInsertRecordObjWithIdAndExtFP.restype = c_int
+ self._lib.bingoInsertRecordObjWithIdAndExtFP.argtypes = [c_int, c_int, c_int, c_int]
+ self._lib.bingoDeleteRecord.restype = c_int
+ self._lib.bingoDeleteRecord.argtypes = [c_int, c_int]
+ self._lib.bingoSearchSub.restype = c_int
+ self._lib.bingoSearchSub.argtypes = [c_int, c_int, c_char_p]
+ self._lib.bingoSearchExact.restype = c_int
+ self._lib.bingoSearchExact.argtypes = [c_int, c_int, c_char_p]
+ self._lib.bingoSearchMolFormula.restype = c_int
+ self._lib.bingoSearchMolFormula.argtypes = [c_int, c_char_p, c_char_p]
+ self._lib.bingoSearchSim.restype = c_int
+ self._lib.bingoSearchSim.argtypes = [c_int, c_int, c_float, c_float, c_char_p]
+ self._lib.bingoSearchSimWithExtFP.restype = c_int
+ self._lib.bingoSearchSimWithExtFP.argtypes = [c_int, c_int, c_float, c_float, c_int, c_char_p]
+ self._lib.bingoSearchSimTopN.restype = c_int
+ self._lib.bingoSearchSimTopN.argtypes = [c_int, c_int, c_int, c_float, c_char_p]
+ self._lib.bingoSearchSimTopNWithExtFP.restype = c_int
+ self._lib.bingoSearchSimTopNWithExtFP.argtypes = [c_int, c_int, c_int, c_float, c_int, c_char_p]
+ self._lib.bingoEnumerateId.restype = c_int
+ self._lib.bingoEnumerateId.argtypes = [c_int]
+ self._lib.bingoNext.restype = c_int
+ self._lib.bingoNext.argtypes = [c_int]
+ self._lib.bingoGetCurrentId.restype = c_int
+ self._lib.bingoGetCurrentId.argtypes = [c_int]
+ self._lib.bingoGetObject.restype = c_int
+ self._lib.bingoGetObject.argtypes = [c_int]
+ self._lib.bingoEndSearch.restype = c_int
+ self._lib.bingoEndSearch.argtypes = [c_int]
+ self._lib.bingoGetCurrentSimilarityValue.restype = c_float
+ self._lib.bingoGetCurrentSimilarityValue.argtypes = [c_int]
+ self._lib.bingoOptimize.restype = c_int
+ self._lib.bingoOptimize.argtypes = [c_int]
+ self._lib.bingoEstimateRemainingResultsCount.restype = c_int
+ self._lib.bingoEstimateRemainingResultsCount.argtypes = [c_int]
+ self._lib.bingoEstimateRemainingResultsCountError.restype = c_int
+ self._lib.bingoEstimateRemainingResultsCountError.argtypes = [c_int]
+ self._lib.bingoEstimateRemainingTime.restype = c_int
+ self._lib.bingoEstimateRemainingTime.argtypes = [c_int, POINTER(c_float)]
+ self._lib.bingoContainersCount.restype = c_int
+ self._lib.bingoContainersCount.argtypes = [c_int]
+ self._lib.bingoCellsCount.restype = c_int
+ self._lib.bingoCellsCount.argtypes = [c_int]
+ self._lib.bingoCurrentCell.restype = c_int
+ self._lib.bingoCurrentCell.argtypes = [c_int]
+ self._lib.bingoMinCell.restype = c_int
+ self._lib.bingoMinCell.argtypes = [c_int]
+ self._lib.bingoMaxCell.restype = c_int
+ self._lib.bingoMaxCell.argtypes = [c_int]
+
+ def __del__(self):
+ self.close()
+
+ def close(self):
+ self._indigo._setSessionId()
+ if self._id >= 0:
+ Bingo._checkResult(self._indigo, self._lib.bingoCloseDatabase(self._id))
+ self._id = -1
+
+ @staticmethod
+ def _checkResult(indigo, result):
+ if result < 0:
+ raise BingoException(indigo._lib.indigoGetLastError())
+ return result
+
+ @staticmethod
+ def _checkResultPtr (indigo, result):
+ if result is None:
+ raise BingoException(indigo._lib.indigoGetLastError())
+ return result
+
+ @staticmethod
+ def _checkResultString (indigo, result):
+ res = Bingo._checkResultPtr(indigo, result)
+ if sys.version_info >= (3, 0):
+ return res.decode('ascii')
+ else:
+ return res.encode('ascii')
+
+ @staticmethod
+ def _getLib(indigo):
+ if os.name == 'posix' and not platform.mac_ver()[0] and not platform.system().startswith("CYGWIN"):
+ _lib = CDLL(indigo.dllpath + "/libbingo.so")
+ elif os.name == 'nt' or platform.system().startswith("CYGWIN"):
+ _lib = CDLL(indigo.dllpath + "/bingo.dll")
+ elif platform.mac_ver()[0]:
+ _lib = CDLL(indigo.dllpath + "/libbingo.dylib")
+ else:
+ raise BingoException("unsupported OS: " + os.name)
+ return _lib
+
+ @staticmethod
+ def createDatabaseFile(indigo, path, databaseType, options=''):
+ indigo._setSessionId()
+ if not options:
+ options = ''
+ lib = Bingo._getLib(indigo)
+ lib.bingoCreateDatabaseFile.restype = c_int
+ lib.bingoCreateDatabaseFile.argtypes = [c_char_p, c_char_p, c_char_p]
+ return Bingo(Bingo._checkResult(indigo, lib.bingoCreateDatabaseFile(path.encode('ascii'), databaseType.encode('ascii'), options.encode('ascii'))), indigo, lib)
+
+ @staticmethod
+ def loadDatabaseFile(indigo, path, options=''):
+ indigo._setSessionId()
+ if not options:
+ options = ''
+ lib = Bingo._getLib(indigo)
+ lib.bingoLoadDatabaseFile.restype = c_int
+ lib.bingoLoadDatabaseFile.argtypes = [c_char_p, c_char_p]
+ return Bingo(Bingo._checkResult(indigo, lib.bingoLoadDatabaseFile(path.encode('ascii'), options.encode('ascii'))), indigo, lib)
+
+ def version(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResultString(self._indigo, self._lib.bingoVersion())
+
+ def insert(self, indigoObject, index=None):
+ self._indigo._setSessionId()
+ if not index:
+ return Bingo._checkResult(self._indigo, self._lib.bingoInsertRecordObj(self._id, indigoObject.id))
+ else:
+ return Bingo._checkResult(self._indigo,
+ self._lib.bingoInsertRecordObjWithId(self._id, indigoObject.id, index))
+
+ def insertWithExtFP(self, indigoObject, ext_fp, index=None):
+ self._indigo._setSessionId()
+ if not index:
+ return Bingo._checkResult(self._indigo, self._lib.bingoInsertRecordObjWithExtFP(self._id, indigoObject.id, ext_fp.id))
+ else:
+ return Bingo._checkResult(self._indigo,
+ self._lib.bingoInsertRecordObjWithIdAndExtFP(self._id, indigoObject.id, index, ext_fp.id))
+
+ def delete(self, index):
+ self._indigo._setSessionId()
+ Bingo._checkResult(self._indigo, self._lib.bingoDeleteRecord(self._id, index))
+
+ def searchSub(self, query, options=''):
+ self._indigo._setSessionId()
+ if not options:
+ options = ''
+ return BingoObject(Bingo._checkResult(self._indigo, self._lib.bingoSearchSub(self._id, query.id, options.encode('ascii'))),
+ self._indigo, self)
+
+ def searchExact(self, query, options=''):
+ self._indigo._setSessionId()
+ if not options:
+ options = ''
+ return BingoObject(Bingo._checkResult(self._indigo, self._lib.bingoSearchExact(self._id, query.id, options.encode('ascii'))),
+ self._indigo, self)
+
+ def searchSim(self, query, minSim, maxSim, metric='tanimoto'):
+ self._indigo._setSessionId()
+ if not metric:
+ metric = 'tanimoto'
+ return BingoObject(
+ Bingo._checkResult(self._indigo, self._lib.bingoSearchSim(self._id, query.id, minSim, maxSim, metric.encode('ascii'))),
+ self._indigo, self)
+
+ def searchSimWithExtFP(self, query, minSim, maxSim, ext_fp, metric='tanimoto'):
+ self._indigo._setSessionId()
+ if not metric:
+ metric = 'tanimoto'
+ return BingoObject(
+ Bingo._checkResult(self._indigo, self._lib.bingoSearchSimWithExtFP(self._id, query.id, minSim, maxSim, ext_fp.id, metric.encode('ascii'))),
+ self._indigo, self)
+
+ def searchSimTopN(self, query, limit, minSim, metric='tanimoto'):
+ self._indigo._setSessionId()
+ if not metric:
+ metric = 'tanimoto'
+ return BingoObject(
+ Bingo._checkResult(self._indigo, self._lib.bingoSearchSimTopN(self._id, query.id, limit, minSim, metric.encode('ascii'))),
+ self._indigo, self)
+
+ def searchSimTopNWithExtFP(self, query, limit, minSim, ext_fp, metric='tanimoto'):
+ self._indigo._setSessionId()
+ if not metric:
+ metric = 'tanimoto'
+ return BingoObject(
+ Bingo._checkResult(self._indigo, self._lib.bingoSearchSimTopNWithExtFP(self._id, query.id, limit, minSim, ext_fp.id, metric.encode('ascii'))),
+ self._indigo, self)
+
+ def enumerateId(self):
+ self._indigo._setSessionId()
+ e = self._lib.bingoEnumerateId(self._id)
+ result = Bingo._checkResult(self._indigo, e)
+ return BingoObject(result, self._indigo, self)
+
+ def searchMolFormula(self, query, options=''):
+ self._indigo._setSessionId()
+ if not options:
+ options = ''
+ return BingoObject(Bingo._checkResult(self._indigo, self._lib.bingoSearchMolFormula(self._id, query.encode('ascii'), options.encode('ascii'))),
+ self._indigo, self)
+
+ def optimize(self):
+ self._indigo._setSessionId()
+ Bingo._checkResult(self._indigo, self._lib.bingoOptimize(self._id))
+
+ def getRecordById (self, id):
+ self._indigo._setSessionId()
+ return IndigoObject(self._indigo, Bingo._checkResult(self._indigo, self._lib.bingoGetRecordObj(self._id, id)))
+
+class BingoObject(object):
+ def __init__(self, objId, indigo, bingo):
+ self._id = objId
+ self._indigo = indigo
+ self._bingo = bingo
+
+ def __del__(self):
+ self.close()
+
+ def close(self):
+ self._indigo._setSessionId()
+ if self._id >= 0:
+ Bingo._checkResult(self._indigo, self._bingo._lib.bingoEndSearch(self._id))
+ self._id = -1
+
+ def next(self):
+ self._indigo._setSessionId()
+ return (Bingo._checkResult(self._indigo, self._bingo._lib.bingoNext(self._id)) == 1)
+
+ def getCurrentId(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResult(self._indigo, self._bingo._lib.bingoGetCurrentId(self._id))
+
+ def getIndigoObject(self):
+ self._indigo._setSessionId()
+ return IndigoObject(self._indigo, Bingo._checkResult(self._indigo, self._bingo._lib.bingoGetObject(self._id)))
+
+ def getCurrentSimilarityValue(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResult(self._indigo, self._bingo._lib.bingoGetCurrentSimilarityValue(self._id))
+
+ def estimateRemainingResultsCount(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResult(self._indigo, self._bingo._lib.bingoEstimateRemainingResultsCount(self._id))
+
+ def estimateRemainingResultsCountError(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResult(self._indigo, self._bingo._lib.bingoEstimateRemainingResultsCountError(self._id))
+
+ def estimateRemainingTime(self):
+ self._indigo._setSessionId()
+ value = c_float()
+ Bingo._checkResult(self._indigo, self._bingo._lib.bingoEstimateRemainingTime(self._id, pointer(value)))
+ return value.value
+
+ def containersCount(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResult(self._indigo, self._bingo._lib.bingoContainersCount(self._id))
+
+ def cellsCount(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResult(self._indigo, self._bingo._lib.bingoCellsCount(self._id))
+
+ def currentCell(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResult(self._indigo, self._bingo._lib.bingoCurrentCell(self._id))
+
+ def minCell(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResult(self._indigo, self._bingo._lib.bingoMinCell(self._id))
+
+ def maxCell(self):
+ self._indigo._setSessionId()
+ return Bingo._checkResult(self._indigo, self._bingo._lib.bingoMaxCell(self._id))
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ next_item = self.next()
+ if next_item:
+ return self
+ raise StopIteration
diff --git a/molscribe/indigo/inchi.py b/molscribe/indigo/inchi.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a0cfd528ae05ed12af1963675217e1840468a0a
--- /dev/null
+++ b/molscribe/indigo/inchi.py
@@ -0,0 +1,84 @@
+#
+# Copyright (C) from 2009 to Present EPAM Systems.
+#
+# This file is part of Indigo toolkit.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import *
+
+
+class IndigoInchi(object):
+ def __init__(self, indigo):
+ self.indigo = indigo
+
+ if os.name == 'posix' and not platform.mac_ver()[0] and not platform.system().startswith("CYGWIN"):
+ self._lib = CDLL(indigo.dllpath + "/libindigo-inchi.so")
+ elif os.name == 'nt' or platform.system().startswith("CYGWIN"):
+ self._lib = CDLL(indigo.dllpath + "\indigo-inchi.dll")
+ elif platform.mac_ver()[0]:
+ self._lib = CDLL(indigo.dllpath + "/libindigo-inchi.dylib")
+ else:
+ raise IndigoException("unsupported OS: " + os.name)
+
+ self._lib.indigoInchiVersion.restype = c_char_p
+ self._lib.indigoInchiVersion.argtypes = []
+ self._lib.indigoInchiResetOptions.restype = c_int
+ self._lib.indigoInchiResetOptions.argtypes = []
+ self._lib.indigoInchiLoadMolecule.restype = c_int
+ self._lib.indigoInchiLoadMolecule.argtypes = [c_char_p]
+ self._lib.indigoInchiGetInchi.restype = c_char_p
+ self._lib.indigoInchiGetInchi.argtypes = [c_int]
+ self._lib.indigoInchiGetInchiKey.restype = c_char_p
+ self._lib.indigoInchiGetInchiKey.argtypes = [c_char_p]
+ self._lib.indigoInchiGetWarning.restype = c_char_p
+ self._lib.indigoInchiGetWarning.argtypes = []
+ self._lib.indigoInchiGetLog.restype = c_char_p
+ self._lib.indigoInchiGetLog.argtypes = []
+ self._lib.indigoInchiGetAuxInfo.restype = c_char_p
+ self._lib.indigoInchiGetAuxInfo.argtypes = []
+
+ def resetOptions(self):
+ self.indigo._setSessionId()
+ self.indigo._checkResult(self._lib.indigoInchiResetOptions())
+
+ def loadMolecule(self, inchi):
+ self.indigo._setSessionId()
+ res = self.indigo._checkResult(self._lib.indigoInchiLoadMolecule(inchi.encode('ascii')))
+ if res == 0:
+ return None
+ return self.indigo.IndigoObject(self.indigo, res)
+
+ def version(self):
+ self.indigo._setSessionId()
+ return self.indigo._checkResultString(self._lib.indigoInchiVersion())
+
+ def getInchi(self, molecule):
+ self.indigo._setSessionId()
+ return self.indigo._checkResultString(self._lib.indigoInchiGetInchi(molecule.id))
+
+ def getInchiKey(self, inchi):
+ self.indigo._setSessionId()
+ return self.indigo._checkResultString(self._lib.indigoInchiGetInchiKey(inchi.encode('ascii')))
+
+ def getWarning(self):
+ self.indigo._setSessionId()
+ return self.indigo._checkResultString(self._lib.indigoInchiGetWarning())
+
+ def getLog(self):
+ self.indigo._setSessionId()
+ return self.indigo._checkResultString(self._lib.indigoInchiGetLog())
+
+ def getAuxInfo(self):
+ self.indigo._setSessionId()
+ return self.indigo._checkResultString(self._lib.indigoInchiGetAuxInfo())
diff --git a/molscribe/indigo/renderer.py b/molscribe/indigo/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..699c80e905d3c307239cef3c8f6d8c8f9b7dc075
--- /dev/null
+++ b/molscribe/indigo/renderer.py
@@ -0,0 +1,113 @@
+#
+# Copyright (C) from 2009 to Present EPAM Systems.
+#
+# This file is part of Indigo toolkit.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import platform
+from ctypes import CDLL, POINTER, c_char_p, c_int
+
+from . import IndigoException
+
+
+class IndigoRenderer(object):
+ def __init__(self, indigo):
+ self.indigo = indigo
+
+ if (
+ os.name == "posix"
+ and not platform.mac_ver()[0]
+ and not platform.system().startswith("CYGWIN")
+ ):
+ self._lib = CDLL(indigo.dllpath + "/libindigo-renderer.so")
+ elif os.name == "nt" or platform.system().startswith("CYGWIN"):
+ self._lib = CDLL(indigo.dllpath + "\indigo-renderer.dll")
+ elif platform.mac_ver()[0]:
+ self._lib = CDLL(indigo.dllpath + "/libindigo-renderer.dylib")
+ else:
+ raise IndigoException("unsupported OS: " + os.name)
+
+ self._lib.indigoRender.restype = c_int
+ self._lib.indigoRender.argtypes = [c_int, c_int]
+ self._lib.indigoRenderToFile.restype = c_int
+ self._lib.indigoRenderToFile.argtypes = [c_int, c_char_p]
+ self._lib.indigoRenderGrid.restype = c_int
+ self._lib.indigoRenderGrid.argtypes = [
+ c_int,
+ POINTER(c_int),
+ c_int,
+ c_int,
+ ]
+ self._lib.indigoRenderGridToFile.restype = c_int
+ self._lib.indigoRenderGridToFile.argtypes = [
+ c_int,
+ POINTER(c_int),
+ c_int,
+ c_char_p,
+ ]
+ self._lib.indigoRenderReset.restype = c_int
+ self._lib.indigoRenderReset.argtypes = [c_int]
+
+ def renderToBuffer(self, obj):
+ self.indigo._setSessionId()
+ wb = self.indigo.writeBuffer()
+ try:
+ self.indigo._checkResult(self._lib.indigoRender(obj.id, wb.id))
+ return wb.toBuffer()
+ finally:
+ wb.dispose()
+
+ def renderToFile(self, obj, filename):
+ self.indigo._setSessionId()
+ self.indigo._checkResult(
+ self._lib.indigoRenderToFile(obj.id, filename.encode("ascii"))
+ )
+
+ def renderGridToFile(self, objects, refatoms, ncolumns, filename):
+ self.indigo._setSessionId()
+ arr = None
+ if refatoms:
+ if len(refatoms) != objects.count():
+ raise IndigoException(
+ "renderGridToFile(): refatoms[] size must be equal to the number of objects"
+ )
+ arr = (c_int * len(refatoms))()
+ for i in range(len(refatoms)):
+ arr[i] = refatoms[i]
+ self.indigo._checkResult(
+ self._lib.indigoRenderGridToFile(
+ objects.id, arr, ncolumns, filename.encode("ascii")
+ )
+ )
+
+ def renderGridToBuffer(self, objects, refatoms, ncolumns):
+ self.indigo._setSessionId()
+ arr = None
+ if refatoms:
+ if len(refatoms) != objects.count():
+ raise IndigoException(
+ "renderGridToBuffer(): refatoms[] size must be equal to the number of objects"
+ )
+ arr = (c_int * len(refatoms))()
+ for i in range(len(refatoms)):
+ arr[i] = refatoms[i]
+ wb = self.indigo.writeBuffer()
+ try:
+ self.indigo._checkResult(
+ self._lib.indigoRenderGrid(objects.id, arr, ncolumns, wb.id)
+ )
+ return wb.toBuffer()
+ finally:
+ wb.dispose()
diff --git a/molscribe/inference/__init__.py b/molscribe/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c61feef15b9a94dbb97126be593f0c445f1870c0
--- /dev/null
+++ b/molscribe/inference/__init__.py
@@ -0,0 +1,4 @@
+from .greedy_search import GreedySearch
+from .beam_search import BeamSearch
+
+__all__ = ["GreedySearch", "BeamSearch"]
diff --git a/molscribe/inference/__pycache__/__init__.cpython-310.pyc b/molscribe/inference/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8b43f18a51e763393520b8df253a57861579bcc
Binary files /dev/null and b/molscribe/inference/__pycache__/__init__.cpython-310.pyc differ
diff --git a/molscribe/inference/__pycache__/beam_search.cpython-310.pyc b/molscribe/inference/__pycache__/beam_search.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d80514dfb41e30a77c0a1582a169356bde691d9
Binary files /dev/null and b/molscribe/inference/__pycache__/beam_search.cpython-310.pyc differ
diff --git a/molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc b/molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e586198b03149acb340d38306ee4f971c38c9078
Binary files /dev/null and b/molscribe/inference/__pycache__/decode_strategy.cpython-310.pyc differ
diff --git a/molscribe/inference/__pycache__/greedy_search.cpython-310.pyc b/molscribe/inference/__pycache__/greedy_search.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1653da695283df3f74c8b4439eab570a25ad30d3
Binary files /dev/null and b/molscribe/inference/__pycache__/greedy_search.cpython-310.pyc differ
diff --git a/molscribe/inference/beam_search.py b/molscribe/inference/beam_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..60e7df337f92d51ddceceb7a51070aae92e3db08
--- /dev/null
+++ b/molscribe/inference/beam_search.py
@@ -0,0 +1,190 @@
+import torch
+from .decode_strategy import DecodeStrategy
+
+
+class BeamSearch(DecodeStrategy):
+ """Generation with beam search.
+ """
+
+ def __init__(self, pad, bos, eos, batch_size, beam_size, n_best, min_length,
+ return_attention, max_length):
+ super(BeamSearch, self).__init__(
+ pad, bos, eos, batch_size, beam_size, min_length, return_attention, max_length)
+ self.beam_size = beam_size
+ self.n_best = n_best
+
+ # result caching
+ self.hypotheses = [[] for _ in range(batch_size)]
+
+ # beam state
+ self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool)
+
+ self._batch_offset = torch.arange(batch_size, dtype=torch.long)
+
+ self.select_indices = None
+ self.done = False
+
+ def initialize(self, memory_bank, device=None):
+ """Repeat src objects `beam_size` times.
+ """
+
+ def fn_map_state(state, dim):
+ return torch.repeat_interleave(state, self.beam_size, dim=dim)
+
+ memory_bank = torch.repeat_interleave(memory_bank, self.beam_size, dim=0)
+ if device is None:
+ device = memory_bank.device
+
+ self.memory_length = memory_bank.size(1)
+ super().initialize(memory_bank, device)
+
+ self.best_scores = torch.full([self.batch_size], -1e10, dtype=torch.float, device=device)
+ self._beam_offset = torch.arange(
+ 0, self.batch_size * self.beam_size, step=self.beam_size, dtype=torch.long, device=device)
+ self.topk_log_probs = torch.tensor(
+ [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
+ ).repeat(self.batch_size)
+ # buffers for the topk scores and 'backpointer'
+ self.topk_scores = torch.empty((self.batch_size, self.beam_size), dtype=torch.float, device=device)
+ self.topk_ids = torch.empty((self.batch_size, self.beam_size), dtype=torch.long, device=device)
+ self._batch_index = torch.empty([self.batch_size, self.beam_size], dtype=torch.long, device=device)
+
+ return fn_map_state, memory_bank
+
+ @property
+ def current_predictions(self):
+ return self.alive_seq[:, -1]
+
+ @property
+ def current_backptr(self):
+ # for testing
+ return self.select_indices.view(self.batch_size, self.beam_size)
+
+ @property
+ def batch_offset(self):
+ return self._batch_offset
+
+ def _pick(self, log_probs):
+ """Return token decision for a step.
+
+ Args:
+ log_probs (FloatTensor): (B, vocab_size)
+
+ Returns:
+ topk_scores (FloatTensor): (B, beam_size)
+ topk_ids (LongTensor): (B, beam_size)
+ """
+ vocab_size = log_probs.size(-1)
+
+ # Flatten probs into a list of probabilities.
+ curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
+ topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
+ return topk_scores, topk_ids
+
+ def advance(self, log_probs, attn):
+ """
+ Args:
+ log_probs: (B * beam_size, vocab_size)
+ """
+ vocab_size = log_probs.size(-1)
+
+ # (non-finished) batch_size
+ _B = log_probs.shape[0] // self.beam_size
+
+ step = len(self) # alive_seq
+ self.ensure_min_length(log_probs)
+
+ # Multiply probs by the beam probability
+ log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)
+
+ curr_length = step + 1
+ curr_scores = log_probs / curr_length # avg log_prob
+ self.topk_scores, self.topk_ids = self._pick(curr_scores)
+ # topk_scores/topk_ids: (batch_size, beam_size)
+
+ # Recover log probs
+ torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs)
+
+ # Resolve beam origin and map to batch index flat representation.
+ self._batch_index = self.topk_ids // vocab_size
+ self._batch_index += self._beam_offset[:_B].unsqueeze(1)
+ self.select_indices = self._batch_index.view(_B * self.beam_size)
+ self.topk_ids.fmod_(vocab_size) # resolve true word ids
+
+ # Append last prediction.
+ self.alive_seq = torch.cat(
+ [self.alive_seq.index_select(0, self.select_indices),
+ self.topk_ids.view(_B * self.beam_size, 1)], -1)
+
+ if self.return_attention:
+ current_attn = attn.index_select(1, self.select_indices)
+ if step == 1:
+ self.alive_attn = current_attn
+ else:
+ self.alive_attn = self.alive_attn.index_select(
+ 1, self.select_indices)
+ self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)
+
+ self.is_finished = self.topk_ids.eq(self.eos)
+ self.ensure_max_length()
+
+ def update_finished(self):
+ _B_old = self.topk_log_probs.shape[0]
+ step = self.alive_seq.shape[-1] # len(self)
+ self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
+
+ self.is_finished = self.is_finished.to('cpu')
+ self.top_beam_finished |= self.is_finished[:, 0].eq(1)
+ predictions = self.alive_seq.view(_B_old, self.beam_size, step)
+ attention = (
+ self.alive_attn.view(
+ step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
+ if self.alive_attn is not None else None)
+ non_finished_batch = []
+ for i in range(self.is_finished.size(0)):
+ b = self._batch_offset[i]
+ finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
+ # Store finished hypothesis for this batch.
+ for j in finished_hyp: # Beam level: finished beam j in batch i
+ self.hypotheses[b].append((
+ self.topk_scores[i, j],
+ predictions[i, j, 1:], # Ignore start token
+ attention[:, i, j, :self.memory_length]
+ if attention is not None else None))
+ # End condition is the top beam finished and we can return
+ # n_best hypotheses.
+ finish_flag = self.top_beam_finished[i] != 0
+ if finish_flag and len(self.hypotheses[b]) >= self.n_best:
+ best_hyp = sorted(
+ self.hypotheses[b], key=lambda x: x[0], reverse=True)
+ for n, (score, pred, attn) in enumerate(best_hyp):
+ if n >= self.n_best:
+ break
+ self.scores[b].append(score.item())
+ self.predictions[b].append(pred)
+ self.attention[b].append(
+ attn if attn is not None else [])
+ else:
+ non_finished_batch.append(i)
+ non_finished = torch.tensor(non_finished_batch)
+
+ if len(non_finished) == 0:
+ self.done = True
+ return
+
+ _B_new = non_finished.shape[0]
+ # Remove finished batches for the next step
+ self.top_beam_finished = self.top_beam_finished.index_select(0, non_finished)
+ self._batch_offset = self._batch_offset.index_select(0, non_finished)
+ non_finished = non_finished.to(self.topk_ids.device)
+ self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
+ self._batch_index = self._batch_index.index_select(0, non_finished)
+ self.select_indices = self._batch_index.view(_B_new * self.beam_size)
+ self.alive_seq = predictions.index_select(0, non_finished).view(-1, self.alive_seq.size(-1))
+ self.topk_scores = self.topk_scores.index_select(0, non_finished)
+ self.topk_ids = self.topk_ids.index_select(0, non_finished)
+
+ if self.alive_attn is not None:
+ inp_seq_len = self.alive_attn.size(-1)
+ self.alive_attn = attention.index_select(1, non_finished) \
+ .view(step - 1, _B_new * self.beam_size, inp_seq_len)
diff --git a/molscribe/inference/decode_strategy.py b/molscribe/inference/decode_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a92a9d4f69f3f5cf7f5d15c74ccf178351eb48e
--- /dev/null
+++ b/molscribe/inference/decode_strategy.py
@@ -0,0 +1,63 @@
+import torch
+
+
+class DecodeStrategy(object):
+ def __init__(self, pad, bos, eos, batch_size, parallel_paths, min_length, max_length,
+ return_attention=False, return_hidden=False):
+ self.pad = pad
+ self.bos = bos
+ self.eos = eos
+
+ self.batch_size = batch_size
+ self.parallel_paths = parallel_paths
+ # result catching
+ self.predictions = [[] for _ in range(batch_size)]
+ self.scores = [[] for _ in range(batch_size)]
+ self.token_scores = [[] for _ in range(batch_size)]
+ self.attention = [[] for _ in range(batch_size)]
+ self.hidden = [[] for _ in range(batch_size)]
+
+ self.alive_attn = None
+ self.alive_hidden = None
+
+ self.min_length = min_length
+ self.max_length = max_length
+
+ n_paths = batch_size * parallel_paths
+ self.return_attention = return_attention
+ self.return_hidden = return_hidden
+
+ self.done = False
+
+ def initialize(self, memory_bank, device=None):
+ if device is None:
+ device = torch.device('cpu')
+ self.alive_seq = torch.full(
+ [self.batch_size * self.parallel_paths, 1], self.bos,
+ dtype=torch.long, device=device)
+ self.is_finished = torch.zeros(
+ [self.batch_size, self.parallel_paths],
+ dtype=torch.uint8, device=device)
+ self.alive_log_token_scores = torch.zeros(
+ [self.batch_size * self.parallel_paths, 0],
+ dtype=torch.float, device=device)
+
+ return None, memory_bank
+
+ def __len__(self):
+ return self.alive_seq.shape[1]
+
+ def ensure_min_length(self, log_probs):
+ if len(self) <= self.min_length:
+ log_probs[:, self.eos] = -1e20 # forced non-end
+
+ def ensure_max_length(self):
+ if len(self) == self.max_length + 1:
+ self.is_finished.fill_(1)
+
+ def advance(self, log_probs, attn):
+ raise NotImplementedError()
+
+ def update_finished(self):
+ raise NotImplementedError
+
diff --git a/molscribe/inference/greedy_search.py b/molscribe/inference/greedy_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..55593ea99643a1163429e3c8263fa8105d653166
--- /dev/null
+++ b/molscribe/inference/greedy_search.py
@@ -0,0 +1,128 @@
+import torch
+from .decode_strategy import DecodeStrategy
+
+
+def sample_with_temperature(logits, sampling_temp, keep_topk):
+ """Select next tokens randomly from the top k possible next tokens.
+
+ Samples from a categorical distribution over the ``keep_topk`` words using
+ the category probabilities ``logits / sampling_temp``.
+ """
+
+ if sampling_temp == 0.0 or keep_topk == 1:
+ # argmax
+ topk_scores, topk_ids = logits.topk(1, dim=-1)
+ if sampling_temp > 0:
+ topk_scores /= sampling_temp
+ else:
+ logits = torch.div(logits, sampling_temp)
+ if keep_topk > 0:
+ top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
+ kth_best = top_values[:, -1].view([-1, 1])
+ kth_best = kth_best.repeat([1, logits.shape[1]]).float()
+ ignore = torch.lt(logits, kth_best)
+ logits = logits.masked_fill(ignore, -10000)
+
+ dist = torch.distributions.Multinomial(logits=logits, total_count=1)
+ topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
+ topk_scores = logits.gather(dim=1, index=topk_ids)
+
+ return topk_ids, topk_scores
+
+
+class GreedySearch(DecodeStrategy):
+ """Select next tokens randomly from the top k possible next tokens.
+ """
+
+ def __init__(self, pad, bos, eos, batch_size, min_length, max_length,
+ return_attention=False, return_hidden=False, sampling_temp=1, keep_topk=1):
+ super().__init__(
+ pad, bos, eos, batch_size, 1, min_length, max_length, return_attention, return_hidden)
+ self.sampling_temp = sampling_temp
+ self.keep_topk = keep_topk
+ self.topk_scores = None
+
+ def initialize(self, memory_bank, device=None):
+ fn_map_state = None
+
+ if device is None:
+ device = memory_bank.device
+
+ self.memory_length = memory_bank.size(1)
+ super().initialize(memory_bank, device)
+
+ self.select_indices = torch.arange(
+ self.batch_size, dtype=torch.long, device=device)
+ self.original_batch_idx = torch.arange(
+ self.batch_size, dtype=torch.long, device=device)
+
+ return fn_map_state, memory_bank
+
+ @property
+ def current_predictions(self):
+ return self.alive_seq[:, -1]
+
+ @property
+ def batch_offset(self):
+ return self.select_indices
+
+ def _pick(self, log_probs):
+ """Function used to pick next tokens.
+ """
+ topk_ids, topk_scores = sample_with_temperature(
+ log_probs, self.sampling_temp, self.keep_topk)
+ return topk_ids, topk_scores
+
+ def advance(self, log_probs, attn=None, hidden=None, label=None):
+ """Select next tokens randomly from the top k possible next tokens.
+ """
+ self.ensure_min_length(log_probs)
+ topk_ids, self.topk_scores = self._pick(log_probs) # log_probs: b x v; topk_ids & self.topk_scores: b x (t=1)
+ self.is_finished = topk_ids.eq(self.eos)
+ if label is not None:
+ label = label.view_as(self.is_finished)
+ self.is_finished = label.eq(self.eos)
+ self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1) # b x (l+1) (first element is ; note l = len(self)-1)
+ self.alive_log_token_scores = torch.cat([self.alive_log_token_scores, self.topk_scores], -1)
+
+ if self.return_attention:
+ if self.alive_attn is None:
+ self.alive_attn = attn
+ else:
+ self.alive_attn = torch.cat([self.alive_attn, attn], 1)
+ if self.return_hidden:
+ if self.alive_hidden is None:
+ self.alive_hidden = hidden
+ else:
+ self.alive_hidden = torch.cat([self.alive_hidden, hidden], 1) # b x l x h
+ self.ensure_max_length()
+
+ def update_finished(self):
+ """Finalize scores and predictions."""
+ # is_finished indicates the decoder finished generating the sequence. Remove it from the batch and update
+ # the results.
+ finished_batches = self.is_finished.view(-1).nonzero()
+ for b in finished_batches.view(-1):
+ b_orig = self.original_batch_idx[b]
+ # scores/predictions/attention are lists,
+ # (to be compatible with beam-search)
+ self.scores[b_orig].append(torch.exp(torch.mean(self.alive_log_token_scores[b])).item())
+ self.token_scores[b_orig].append(torch.exp(self.alive_log_token_scores[b]).tolist())
+ self.predictions[b_orig].append(self.alive_seq[b, 1:]) # skip
+ self.attention[b_orig].append(
+ self.alive_attn[b, :, :self.memory_length] if self.alive_attn is not None else [])
+ self.hidden[b_orig].append(
+ self.alive_hidden[b, :] if self.alive_hidden is not None else [])
+ self.done = self.is_finished.all()
+ if self.done:
+ return
+ is_alive = ~self.is_finished.view(-1)
+ self.alive_seq = self.alive_seq[is_alive]
+ self.alive_log_token_scores = self.alive_log_token_scores[is_alive]
+ if self.alive_attn is not None:
+ self.alive_attn = self.alive_attn[is_alive]
+ if self.alive_hidden is not None:
+ self.alive_hidden = self.alive_hidden[is_alive]
+ self.select_indices = is_alive.nonzero().view(-1)
+ self.original_batch_idx = self.original_batch_idx[is_alive]
+ # select_indices is equal to original_batch_idx for greedy search?
diff --git a/molscribe/interface.py b/molscribe/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a3380dd8502693dca8e3c8b385c65cebbdd9484
--- /dev/null
+++ b/molscribe/interface.py
@@ -0,0 +1,223 @@
+import argparse
+from typing import List
+
+import cv2
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+from .dataset import get_transforms
+from .model import Encoder, Decoder
+from .chemistry import convert_graph_to_smiles
+from .tokenizer import get_tokenizer
+
+
+BOND_TYPES = ["", "single", "double", "triple", "aromatic", "solid wedge", "dashed wedge"]
+
+
+def safe_load(module, module_states):
+ def remove_prefix(state_dict):
+ return {k.replace('module.', ''): v for k, v in state_dict.items()}
+ missing_keys, unexpected_keys = module.load_state_dict(remove_prefix(module_states), strict=False)
+ return
+
+
+class MolScribe:
+
+ def __init__(self, model_path, device=None):
+ """
+ MolScribe Interface
+ :param model_path: path of the model checkpoint.
+ :param device: torch device, defaults to be CPU.
+ """
+ model_states = torch.load(model_path, map_location=torch.device('cpu'))
+ args = self._get_args(model_states['args'])
+ if device is None:
+ device = torch.device('cpu')
+ self.device = device
+ self.tokenizer = get_tokenizer(args)
+ self.encoder, self.decoder = self._get_model(args, self.tokenizer, self.device, model_states)
+ self.transform = get_transforms(args.input_size, augment=False)
+
+ def _get_args(self, args_states=None):
+ parser = argparse.ArgumentParser()
+ # Model
+ parser.add_argument('--encoder', type=str, default='swin_base')
+ parser.add_argument('--decoder', type=str, default='transformer')
+ parser.add_argument('--trunc_encoder', action='store_true') # use the hidden states before downsample
+ parser.add_argument('--no_pretrained', action='store_true')
+ parser.add_argument('--use_checkpoint', action='store_true', default=True)
+ parser.add_argument('--dropout', type=float, default=0.5)
+ parser.add_argument('--embed_dim', type=int, default=256)
+ parser.add_argument('--enc_pos_emb', action='store_true')
+ group = parser.add_argument_group("transformer_options")
+ group.add_argument("--dec_num_layers", help="No. of layers in transformer decoder", type=int, default=6)
+ group.add_argument("--dec_hidden_size", help="Decoder hidden size", type=int, default=256)
+ group.add_argument("--dec_attn_heads", help="Decoder no. of attention heads", type=int, default=8)
+ group.add_argument("--dec_num_queries", type=int, default=128)
+ group.add_argument("--hidden_dropout", help="Hidden dropout", type=float, default=0.1)
+ group.add_argument("--attn_dropout", help="Attention dropout", type=float, default=0.1)
+ group.add_argument("--max_relative_positions", help="Max relative positions", type=int, default=0)
+ parser.add_argument('--continuous_coords', action='store_true')
+ parser.add_argument('--compute_confidence', action='store_true')
+ # Data
+ parser.add_argument('--input_size', type=int, default=384)
+ parser.add_argument('--vocab_file', type=str, default=None)
+ parser.add_argument('--coord_bins', type=int, default=64)
+ parser.add_argument('--sep_xy', action='store_true', default=True)
+
+ args = parser.parse_args([])
+ if args_states:
+ for key, value in args_states.items():
+ args.__dict__[key] = value
+ return args
+
+ def _get_model(self, args, tokenizer, device, states):
+ encoder = Encoder(args, pretrained=False)
+ args.encoder_dim = encoder.n_features
+ decoder = Decoder(args, tokenizer)
+
+ safe_load(encoder, states['encoder'])
+ safe_load(decoder, states['decoder'])
+ # print(f"Model loaded from {load_path}")
+
+ encoder.to(device)
+ decoder.to(device)
+ encoder.eval()
+ decoder.eval()
+ return encoder, decoder
+
+ def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=16):
+ device = self.device
+ predictions = []
+ self.decoder.compute_confidence = return_confidence
+
+ for idx in range(0, len(input_images), batch_size):
+ batch_images = input_images[idx:idx+batch_size]
+ images = [self.transform(image=image, keypoints=[])['image'] for image in batch_images]
+ images = torch.stack(images, dim=0).to(device)
+ with torch.no_grad():
+ features, hiddens = self.encoder(images)
+ batch_predictions = self.decoder.decode(features, hiddens)
+ predictions += batch_predictions
+
+ return self.convert_graph_to_output(predictions, input_images, return_confidence, return_atoms_bonds)
+
+
+ def convert_graph_to_output(self, predictions, input_images, return_confidence=True, return_atoms_bonds=True):
+ node_coords = [pred['chartok_coords']['coords'] for pred in predictions]
+ node_symbols = [pred['chartok_coords']['symbols'] for pred in predictions]
+ edges = [pred['edges'] for pred in predictions]
+ # node_symbols = [r_groups[symbol] if symbol in r_groups else symbol for symbol in node_symbols]
+ smiles_list, molblock_list, r_success = convert_graph_to_smiles(
+ node_coords, node_symbols, edges, images=input_images)
+
+ outputs = []
+ for smiles, molblock, pred in zip(smiles_list, molblock_list, predictions):
+ pred_dict = {"smiles": smiles, "molfile": molblock, "oringinal_coords": pred['chartok_coords']['coords'], "original_symbols": pred['chartok_coords']['symbols'], "orignal_edges": pred['edges']}
+ if return_confidence:
+ pred_dict["confidence"] = pred["overall_score"]
+ if return_atoms_bonds:
+ coords = pred['chartok_coords']['coords']
+ symbols = pred['chartok_coords']['symbols']
+
+
+ # get atoms info
+ atom_list = []
+ for i, (symbol, coord) in enumerate(zip(symbols, coords)):
+ atom_dict = {"atom_symbol": symbol, "x": round(coord[0],3), "y": round(coord[1],3)}
+ if return_confidence:
+ atom_dict["confidence"] = pred['chartok_coords']['atom_scores'][i]
+ atom_list.append(atom_dict)
+ pred_dict["atoms"] = atom_list
+ # get bonds info
+ bond_list = []
+ num_atoms = len(symbols)
+ for i in range(num_atoms-1):
+ for j in range(i+1, num_atoms):
+ bond_type_int = pred['edges'][i][j]
+ if bond_type_int != 0:
+ bond_type_str = BOND_TYPES[bond_type_int]
+ bond_dict = {"bond_type": bond_type_str, "endpoint_atoms": (i, j)}
+ if return_confidence:
+ bond_dict["confidence"] = pred["edge_scores"][i][j]
+ bond_list.append(bond_dict)
+ pred_dict["bonds"] = bond_list
+ outputs.append(pred_dict)
+ return outputs
+
+ def predict_image(self, image, return_atoms_bonds=False, return_confidence=False):
+ return self.predict_images([
+ image], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0]
+
+ def predict_image_files(self, image_files: List, return_atoms_bonds=False, return_confidence=False):
+ input_images = []
+ for path in image_files:
+ image = cv2.imread(path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ input_images.append(image)
+ return self.predict_images(
+ input_images, return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)
+
+ def predict_image_file(self, image_file: str, return_atoms_bonds=False, return_confidence=False):
+ return self.predict_image_files(
+ [image_file], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0]
+
+ def draw_prediction(self, prediction, image, notebook=False):
+ if "atoms" not in prediction or "bonds" not in prediction:
+ raise ValueError("atoms and bonds information are not provided.")
+ h, w, _ = image.shape
+ h, w = np.array([h, w]) * 400 / max(h, w)
+ image = cv2.resize(image, (int(w), int(h)))
+ fig, ax = plt.subplots(1, 1)
+ ax.axis('off')
+ ax.set_xlim(-0.05 * w, w * 1.05)
+ ax.set_ylim(1.05 * h, -0.05 * h)
+ plt.imshow(image, alpha=0.)
+ x = [a['x'] * w for a in prediction['atoms']]
+ y = [a['y'] * h for a in prediction['atoms']]
+ markersize = min(w, h) / 3
+ plt.scatter(x, y, marker='o', s=markersize, color='lightskyblue', zorder=10)
+ for i, atom in enumerate(prediction['atoms']):
+ symbol = atom['atom_symbol'].lstrip('[').rstrip(']')
+ plt.annotate(symbol, xy=(x[i], y[i]), ha='center', va='center', color='black', zorder=100)
+ for bond in prediction['bonds']:
+ u, v = bond['endpoint_atoms']
+ x1, y1, x2, y2 = x[u], y[u], x[v], y[v]
+ bond_type = bond['bond_type']
+ if bond_type == 'single':
+ color = 'tab:green'
+ ax.plot([x1, x2], [y1, y2], color, linewidth=4)
+ elif bond_type == 'aromatic':
+ color = 'tab:purple'
+ ax.plot([x1, x2], [y1, y2], color, linewidth=4)
+ elif bond_type == 'double':
+ color = 'tab:green'
+ ax.plot([x1, x2], [y1, y2], color=color, linewidth=7)
+ ax.plot([x1, x2], [y1, y2], color='w', linewidth=1.5, zorder=2.1)
+ elif bond_type == 'triple':
+ color = 'tab:green'
+ x1s, x2s = 0.8 * x1 + 0.2 * x2, 0.2 * x1 + 0.8 * x2
+ y1s, y2s = 0.8 * y1 + 0.2 * y2, 0.2 * y1 + 0.8 * y2
+ ax.plot([x1s, x2s], [y1s, y2s], color=color, linewidth=9)
+ ax.plot([x1, x2], [y1, y2], color='w', linewidth=5, zorder=2.05)
+ ax.plot([x1, x2], [y1, y2], color=color, linewidth=2, zorder=2.1)
+ else:
+ length = 10
+ width = 10
+ color = 'tab:green'
+ if bond_type == 'solid wedge':
+ ax.annotate('', xy=(x1, y1), xytext=(x2, y2),
+ arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2)
+ else:
+ ax.annotate('', xy=(x2, y2), xytext=(x1, y1),
+ arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2)
+ fig.tight_layout()
+ if not notebook:
+ canvas = FigureCanvasAgg(fig)
+ canvas.draw()
+ buf = canvas.buffer_rgba()
+ result_image = np.asarray(buf)
+ plt.close(fig)
+ return result_image
diff --git a/molscribe/loss.py b/molscribe/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..24d7dc33c24e5be8d36c9a5b11f504cd9e64b389
--- /dev/null
+++ b/molscribe/loss.py
@@ -0,0 +1,125 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+from .tokenizer import PAD_ID, MASK, MASK_ID
+
+
+class LabelSmoothingLoss(nn.Module):
+ """
+ With label smoothing,
+ KL-divergence between q_{smoothed ground truth prob.}(w)
+ and p_{prob. computed by model}(w) is minimized.
+ """
+ def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
+ assert 0.0 < label_smoothing <= 1.0
+ self.ignore_index = ignore_index
+ super(LabelSmoothingLoss, self).__init__()
+
+ smoothing_value = label_smoothing / (tgt_vocab_size - 2)
+ one_hot = torch.full((tgt_vocab_size,), smoothing_value)
+ one_hot[self.ignore_index] = 0
+ self.register_buffer('one_hot', one_hot.unsqueeze(0))
+
+ self.confidence = 1.0 - label_smoothing
+
+ def forward(self, output, target):
+ """
+ output (FloatTensor): batch_size x n_classes
+ target (LongTensor): batch_size
+ """
+ # assuming output is raw logits
+ # convert to log_probs
+ log_probs = F.log_softmax(output, dim=-1)
+
+ model_prob = self.one_hot.repeat(target.size(0), 1)
+ model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
+ model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
+
+ # reduction mean or sum?
+ return F.kl_div(log_probs, model_prob, reduction='batchmean')
+
+
+class SequenceLoss(nn.Module):
+
+ def __init__(self, label_smoothing, vocab_size, ignore_index=-100, ignore_indices=[]):
+ super(SequenceLoss, self).__init__()
+ if ignore_indices:
+ ignore_index = ignore_indices[0]
+ self.ignore_index = ignore_index
+ self.ignore_indices = ignore_indices
+ if label_smoothing == 0:
+ self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean')
+ else:
+ self.criterion = LabelSmoothingLoss(label_smoothing, vocab_size, ignore_index)
+
+ def forward(self, output, target):
+ """
+ :param output: [batch, len, vocab]
+ :param target: [batch, len]
+ :return:
+ """
+ batch_size, max_len, vocab_size = output.size()
+ output = output.reshape(-1, vocab_size)
+ target = target.reshape(-1)
+ for idx in self.ignore_indices:
+ if idx != self.ignore_index:
+ target.masked_fill_((target == idx), self.ignore_index)
+ loss = self.criterion(output, target)
+ return loss
+
+
+class GraphLoss(nn.Module):
+
+ def __init__(self):
+ super(GraphLoss, self).__init__()
+ weight = torch.ones(7) * 10
+ weight[0] = 1
+ self.criterion = nn.CrossEntropyLoss(weight, ignore_index=-100)
+
+ def forward(self, outputs, targets):
+ results = {}
+ if 'coords' in outputs:
+ pred = outputs['coords']
+ max_len = pred.size(1)
+ target = targets['coords'][:, :max_len]
+ mask = target.ge(0)
+ loss = F.l1_loss(pred, target, reduction='none')
+ results['coords'] = (loss * mask).sum() / mask.sum()
+ if 'edges' in outputs:
+ pred = outputs['edges']
+ max_len = pred.size(-1)
+ target = targets['edges'][:, :max_len, :max_len]
+ results['edges'] = self.criterion(pred, target)
+ return results
+
+
+class Criterion(nn.Module):
+
+ def __init__(self, args, tokenizer):
+ super(Criterion, self).__init__()
+ criterion = {}
+ for format_ in args.formats:
+ if format_ == 'edges':
+ criterion['edges'] = GraphLoss()
+ else:
+ if MASK in tokenizer[format_].stoi:
+ ignore_indices = [PAD_ID, MASK_ID]
+ else:
+ ignore_indices = []
+ criterion[format_] = SequenceLoss(args.label_smoothing, len(tokenizer[format_]),
+ ignore_index=PAD_ID, ignore_indices=ignore_indices)
+ self.criterion = nn.ModuleDict(criterion)
+
+ def forward(self, results, refs):
+ losses = {}
+ for format_ in results:
+ predictions, targets, *_ = results[format_]
+ loss_ = self.criterion[format_](predictions, targets)
+ if type(loss_) is dict:
+ losses.update(loss_)
+ else:
+ if loss_.numel() > 1:
+ loss_ = loss_.mean()
+ losses[format_] = loss_
+ return losses
diff --git a/molscribe/model.py b/molscribe/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff85e087d679daf2b25df787950fb72a0c5fcf8
--- /dev/null
+++ b/molscribe/model.py
@@ -0,0 +1,397 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import timm
+
+from .utils import FORMAT_INFO, to_device
+from .tokenizer import SOS_ID, EOS_ID, PAD_ID, MASK_ID
+from .inference import GreedySearch, BeamSearch
+from .transformer import TransformerDecoder, Embeddings
+
+
+class Encoder(nn.Module):
+ def __init__(self, args, pretrained=False):
+ super().__init__()
+ model_name = args.encoder
+ self.model_name = model_name
+ if model_name.startswith('resnet'):
+ self.model_type = 'resnet'
+ self.cnn = timm.create_model(model_name, pretrained=pretrained)
+ self.n_features = self.cnn.num_features # encoder_dim
+ self.cnn.global_pool = nn.Identity()
+ self.cnn.fc = nn.Identity()
+ elif model_name.startswith('swin'):
+ self.model_type = 'swin'
+ self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False,
+ use_checkpoint=args.use_checkpoint)
+ self.n_features = self.transformer.num_features
+ self.transformer.head = nn.Identity()
+ elif 'efficientnet' in model_name:
+ self.model_type = 'efficientnet'
+ self.cnn = timm.create_model(model_name, pretrained=pretrained)
+ self.n_features = self.cnn.num_features
+ self.cnn.global_pool = nn.Identity()
+ self.cnn.classifier = nn.Identity()
+ else:
+ raise NotImplemented
+
+ def swin_forward(self, transformer, x):
+ x = transformer.patch_embed(x)
+ if transformer.absolute_pos_embed is not None:
+ x = x + transformer.absolute_pos_embed
+ x = transformer.pos_drop(x)
+
+ def layer_forward(layer, x, hiddens):
+ for blk in layer.blocks:
+ if not torch.jit.is_scripting() and layer.use_checkpoint:
+ x = torch.utils.checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ H, W = layer.input_resolution
+ B, L, C = x.shape
+ hiddens.append(x.view(B, H, W, C))
+ if layer.downsample is not None:
+ x = layer.downsample(x)
+ return x, hiddens
+
+ hiddens = []
+ for layer in transformer.layers:
+ x, hiddens = layer_forward(layer, x, hiddens)
+ x = transformer.norm(x) # B L C
+ hiddens[-1] = x.view_as(hiddens[-1])
+ return x, hiddens
+
+ def forward(self, x, refs=None):
+ if self.model_type in ['resnet', 'efficientnet']:
+ features = self.cnn(x)
+ features = features.permute(0, 2, 3, 1)
+ hiddens = []
+ elif self.model_type == 'swin':
+ if 'patch' in self.model_name:
+ features, hiddens = self.swin_forward(self.transformer, x)
+ else:
+ features, hiddens = self.transformer(x)
+ else:
+ raise NotImplemented
+ return features, hiddens
+
+
+class TransformerDecoderBase(nn.Module):
+
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+
+ self.enc_trans_layer = nn.Sequential(
+ nn.Linear(args.encoder_dim, args.dec_hidden_size)
+ # nn.LayerNorm(args.dec_hidden_size, eps=1e-6)
+ )
+ self.enc_pos_emb = nn.Embedding(144, args.encoder_dim) if args.enc_pos_emb else None
+
+ self.decoder = TransformerDecoder(
+ num_layers=args.dec_num_layers,
+ d_model=args.dec_hidden_size,
+ heads=args.dec_attn_heads,
+ d_ff=args.dec_hidden_size * 4,
+ copy_attn=False,
+ self_attn_type="scaled-dot",
+ dropout=args.hidden_dropout,
+ attention_dropout=args.attn_dropout,
+ max_relative_positions=args.max_relative_positions,
+ aan_useffn=False,
+ full_context_alignment=False,
+ alignment_layer=0,
+ alignment_heads=0,
+ pos_ffn_activation_fn='gelu'
+ )
+
+ def enc_transform(self, encoder_out):
+ batch_size = encoder_out.size(0)
+ encoder_dim = encoder_out.size(-1)
+ encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
+ max_len = encoder_out.size(1)
+ device = encoder_out.device
+ if self.enc_pos_emb:
+ pos_emb = self.enc_pos_emb(torch.arange(max_len, device=device)).unsqueeze(0)
+ encoder_out = encoder_out + pos_emb
+ encoder_out = self.enc_trans_layer(encoder_out)
+ return encoder_out
+
+
+class TransformerDecoderAR(TransformerDecoderBase):
+ """Autoregressive Transformer Decoder"""
+
+ def __init__(self, args, tokenizer):
+ super().__init__(args)
+ self.tokenizer = tokenizer
+ self.vocab_size = len(self.tokenizer)
+ self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True)
+ self.embeddings = Embeddings(
+ word_vec_size=args.dec_hidden_size,
+ word_vocab_size=self.vocab_size,
+ word_padding_idx=PAD_ID,
+ position_encoding=True,
+ dropout=args.hidden_dropout)
+
+ def dec_embedding(self, tgt, step=None):
+ pad_idx = self.embeddings.word_padding_idx
+ tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2) # [B, 1, T_tgt]
+ emb = self.embeddings(tgt, step=step)
+ assert emb.dim() == 3 # batch x len x embedding_dim
+ return emb, tgt_pad_mask
+
+ def forward(self, encoder_out, labels, label_lengths):
+ """Training mode"""
+ batch_size, max_len, _ = encoder_out.size()
+ memory_bank = self.enc_transform(encoder_out)
+
+ tgt = labels.unsqueeze(-1) # (b, t, 1)
+ tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
+ dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask)
+
+ logits = self.output_layer(dec_out) # (b, t, h) -> (b, t, v)
+ return logits[:, :-1], labels[:, 1:], dec_out
+
+ def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256,
+ labels=None):
+ """Inference mode. Autoregressively decode the sequence. Only greedy search is supported now. Beam search is
+ out-dated. The labels is used for partial prediction, i.e. part of the sequence is given. In standard decoding,
+ labels=None."""
+ batch_size, max_len, _ = encoder_out.size()
+ memory_bank = self.enc_transform(encoder_out)
+ orig_labels = labels
+
+ if beam_size == 1:
+ decode_strategy = GreedySearch(
+ sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length, max_length=max_length,
+ pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
+ return_attention=False, return_hidden=True)
+ else:
+ decode_strategy = BeamSearch(
+ beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length, max_length=max_length,
+ pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
+ return_attention=False)
+
+ # adapted from onmt.translate.translator
+ results = {
+ "predictions": None,
+ "scores": None,
+ "attention": None
+ }
+
+ # (2) prep decode_strategy. Possibly repeat src objects.
+ _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank)
+
+ # (3) Begin decoding step by step:
+ for step in range(decode_strategy.max_length):
+ tgt = decode_strategy.current_predictions.view(-1, 1, 1)
+ if labels is not None:
+ label = labels[:, step].view(-1, 1, 1)
+ mask = label.eq(MASK_ID).long()
+ tgt = tgt * mask + label * (1 - mask)
+ tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
+ dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank,
+ tgt_pad_mask=tgt_pad_mask, step=step)
+
+ attn = dec_attn.get("std", None)
+
+ dec_logits = self.output_layer(dec_out) # [b, t, h] => [b, t, v]
+ dec_logits = dec_logits.squeeze(1)
+ log_probs = F.log_softmax(dec_logits, dim=-1)
+
+ if self.tokenizer.output_constraint:
+ output_mask = [self.tokenizer.get_output_mask(id) for id in tgt.view(-1).tolist()]
+ output_mask = torch.tensor(output_mask, device=log_probs.device)
+ log_probs.masked_fill_(output_mask, -10000)
+
+ label = labels[:, step + 1] if labels is not None and step + 1 < labels.size(1) else None
+ decode_strategy.advance(log_probs, attn, dec_out, label)
+ any_finished = decode_strategy.is_finished.any()
+ if any_finished:
+ decode_strategy.update_finished()
+ if decode_strategy.done:
+ break
+
+ select_indices = decode_strategy.select_indices
+ if any_finished:
+ # Reorder states.
+ memory_bank = memory_bank.index_select(0, select_indices)
+ if labels is not None:
+ labels = labels.index_select(0, select_indices)
+ self.map_state(lambda state, dim: state.index_select(dim, select_indices))
+
+ results["scores"] = decode_strategy.scores # fixed to be average of token scores
+ results["token_scores"] = decode_strategy.token_scores
+ results["predictions"] = decode_strategy.predictions
+ results["attention"] = decode_strategy.attention
+ results["hidden"] = decode_strategy.hidden
+ if orig_labels is not None:
+ for i in range(batch_size):
+ pred = results["predictions"][i][0]
+ label = orig_labels[i][1:len(pred) + 1]
+ mask = label.eq(MASK_ID).long()
+ pred = pred[:len(label)]
+ results["predictions"][i][0] = pred * mask + label * (1 - mask)
+
+ return results["predictions"], results['scores'], results["token_scores"], results["hidden"]
+
+ # adapted from onmt.decoders.transformer
+ def map_state(self, fn):
+ def _recursive_map(struct, batch_dim=0):
+ for k, v in struct.items():
+ if v is not None:
+ if isinstance(v, dict):
+ _recursive_map(v)
+ else:
+ struct[k] = fn(v, batch_dim)
+
+ if self.decoder.state["cache"] is not None:
+ _recursive_map(self.decoder.state["cache"])
+
+
+class GraphPredictor(nn.Module):
+
+ def __init__(self, decoder_dim, coords=False):
+ super(GraphPredictor, self).__init__()
+ self.coords = coords
+ self.mlp = nn.Sequential(
+ nn.Linear(decoder_dim * 2, decoder_dim), nn.GELU(),
+ nn.Linear(decoder_dim, 7)
+ )
+ if coords:
+ self.coords_mlp = nn.Sequential(
+ nn.Linear(decoder_dim, decoder_dim), nn.GELU(),
+ nn.Linear(decoder_dim, 2)
+ )
+
+ def forward(self, hidden, indices=None):
+ b, l, dim = hidden.size()
+ if indices is None:
+ index = [i for i in range(3, l, 3)]
+ hidden = hidden[:, index]
+ else:
+ batch_id = torch.arange(b).unsqueeze(1).expand_as(indices).reshape(-1)
+ indices = indices.view(-1)
+ hidden = hidden[batch_id, indices].view(b, -1, dim)
+ b, l, dim = hidden.size()
+ results = {}
+ hh = torch.cat([hidden.unsqueeze(2).expand(b, l, l, dim), hidden.unsqueeze(1).expand(b, l, l, dim)], dim=3)
+ results['edges'] = self.mlp(hh).permute(0, 3, 1, 2)
+ if self.coords:
+ results['coords'] = self.coords_mlp(hidden)
+ return results
+
+
+def get_edge_prediction(edge_prob):
+ if not edge_prob:
+ return [], []
+ n = len(edge_prob)
+ if n == 0:
+ return [], []
+ for i in range(n):
+ for j in range(i + 1, n):
+ for k in range(5):
+ edge_prob[i][j][k] = (edge_prob[i][j][k] + edge_prob[j][i][k]) / 2
+ edge_prob[j][i][k] = edge_prob[i][j][k]
+ edge_prob[i][j][5] = (edge_prob[i][j][5] + edge_prob[j][i][6]) / 2
+ edge_prob[i][j][6] = (edge_prob[i][j][6] + edge_prob[j][i][5]) / 2
+ edge_prob[j][i][5] = edge_prob[i][j][6]
+ edge_prob[j][i][6] = edge_prob[i][j][5]
+ prediction = np.argmax(edge_prob, axis=2).tolist()
+ score = np.max(edge_prob, axis=2).tolist()
+ return prediction, score
+
+
+class Decoder(nn.Module):
+ """This class is a wrapper for different decoder architectures, and support multiple decoders."""
+
+ def __init__(self, args, tokenizer):
+ super(Decoder, self).__init__()
+ self.args = args
+ self.formats = args.formats
+ self.tokenizer = tokenizer
+ decoder = {}
+ for format_ in args.formats:
+ if format_ == 'edges':
+ decoder['edges'] = GraphPredictor(args.dec_hidden_size, coords=args.continuous_coords)
+ else:
+ decoder[format_] = TransformerDecoderAR(args, tokenizer[format_])
+ self.decoder = nn.ModuleDict(decoder)
+ self.compute_confidence = args.compute_confidence
+
+ def forward(self, encoder_out, hiddens, refs):
+ """Training mode. Compute the logits with teacher forcing."""
+ results = {}
+ refs = to_device(refs, encoder_out.device)
+ for format_ in self.formats:
+ if format_ == 'edges':
+ if 'atomtok_coords' in results:
+ dec_out = results['atomtok_coords'][2]
+ predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0])
+ elif 'chartok_coords' in results:
+ dec_out = results['chartok_coords'][2]
+ predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0])
+ else:
+ raise NotImplemented
+ targets = {'edges': refs['edges']}
+ if 'coords' in predictions:
+ targets['coords'] = refs['coords']
+ results['edges'] = (predictions, targets)
+ else:
+ labels, label_lengths = refs[format_]
+ results[format_] = self.decoder[format_](encoder_out, labels, label_lengths)
+ return results
+
+ def decode(self, encoder_out, hiddens=None, refs=None, beam_size=1, n_best=1):
+ """Inference mode. Call each decoder's decode method (if required), convert the output format (e.g. token to
+ sequence). Beam search is not supported yet."""
+ results = {}
+ predictions = []
+ for format_ in self.formats:
+ if format_ in ['atomtok', 'atomtok_coords', 'chartok_coords']:
+ max_len = FORMAT_INFO[format_]['max_len']
+ results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best, max_length=max_len)
+ outputs, scores, token_scores, *_ = results[format_]
+ beam_preds = [[self.tokenizer[format_].sequence_to_smiles(x.tolist()) for x in pred]
+ for pred in outputs]
+ predictions = [{format_: pred[0]} for pred in beam_preds]
+ if self.compute_confidence:
+ for i in range(len(predictions)):
+ # -1: y score, -2: x score, -3: symbol score
+ indices = np.array(predictions[i][format_]['indices']) - 3
+ if format_ == 'chartok_coords':
+ atom_scores = []
+ for symbol, index in zip(predictions[i][format_]['symbols'], indices):
+ atom_score = (np.prod(token_scores[i][0][index - len(symbol) + 1:index + 1])
+ ** (1 / len(symbol))).item()
+ atom_scores.append(atom_score)
+ else:
+ atom_scores = np.array(token_scores[i][0])[indices].tolist()
+ predictions[i][format_]['atom_scores'] = atom_scores
+ predictions[i][format_]['average_token_score'] = scores[i][0]
+ if format_ == 'edges':
+ if 'atomtok_coords' in results:
+ atom_format = 'atomtok_coords'
+ elif 'chartok_coords' in results:
+ atom_format = 'chartok_coords'
+ else:
+ raise NotImplemented
+ dec_out = results[atom_format][3] # batch x n_best x len x dim
+ for i in range(len(dec_out)):
+ hidden = dec_out[i][0].unsqueeze(0) # 1 * len * dim
+ indices = torch.LongTensor(predictions[i][atom_format]['indices']).unsqueeze(0) # 1 * k
+ pred = self.decoder['edges'](hidden, indices) # k * k
+ prob = F.softmax(pred['edges'].squeeze(0).permute(1, 2, 0), dim=2).tolist() # k * k * 7
+ edge_pred, edge_score = get_edge_prediction(prob)
+ predictions[i]['edges'] = edge_pred
+ if self.compute_confidence:
+ predictions[i]['edge_scores'] = edge_score
+ predictions[i]['edge_score_product'] = np.sqrt(np.prod(edge_score)).item()
+ predictions[i]['overall_score'] = predictions[i][atom_format]['average_token_score'] * \
+ predictions[i]['edge_score_product']
+ predictions[i][atom_format].pop('average_token_score')
+ predictions[i].pop('edge_score_product')
+ return predictions
diff --git a/molscribe/tokenizer.py b/molscribe/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6789ab09433ad13e6f61512cfe6a6f0f307ff75d
--- /dev/null
+++ b/molscribe/tokenizer.py
@@ -0,0 +1,524 @@
+import os
+import json
+import random
+import numpy as np
+from SmilesPE.pretokenizer import atomwise_tokenizer
+
+PAD = ''
+SOS = ''
+EOS = ''
+UNK = ''
+MASK = ''
+PAD_ID = 0
+SOS_ID = 1
+EOS_ID = 2
+UNK_ID = 3
+MASK_ID = 4
+
+
+class Tokenizer(object):
+
+ def __init__(self, path=None):
+ self.stoi = {}
+ self.itos = {}
+ if path:
+ self.load(path)
+
+ def __len__(self):
+ return len(self.stoi)
+
+ @property
+ def output_constraint(self):
+ return False
+
+ def save(self, path):
+ with open(path, 'w') as f:
+ json.dump(self.stoi, f)
+
+ def load(self, path):
+ with open(path) as f:
+ self.stoi = json.load(f)
+ self.itos = {item[1]: item[0] for item in self.stoi.items()}
+
+ def fit_on_texts(self, texts):
+ vocab = set()
+ for text in texts:
+ vocab.update(text.split(' '))
+ vocab = [PAD, SOS, EOS, UNK] + list(vocab)
+ for i, s in enumerate(vocab):
+ self.stoi[s] = i
+ self.itos = {item[1]: item[0] for item in self.stoi.items()}
+ assert self.stoi[PAD] == PAD_ID
+ assert self.stoi[SOS] == SOS_ID
+ assert self.stoi[EOS] == EOS_ID
+ assert self.stoi[UNK] == UNK_ID
+
+ def text_to_sequence(self, text, tokenized=True):
+ sequence = []
+ sequence.append(self.stoi[''])
+ if tokenized:
+ tokens = text.split(' ')
+ else:
+ tokens = atomwise_tokenizer(text)
+ for s in tokens:
+ if s not in self.stoi:
+ s = ''
+ sequence.append(self.stoi[s])
+ sequence.append(self.stoi[''])
+ return sequence
+
+ def texts_to_sequences(self, texts):
+ sequences = []
+ for text in texts:
+ sequence = self.text_to_sequence(text)
+ sequences.append(sequence)
+ return sequences
+
+ def sequence_to_text(self, sequence):
+ return ''.join(list(map(lambda i: self.itos[i], sequence)))
+
+ def sequences_to_texts(self, sequences):
+ texts = []
+ for sequence in sequences:
+ text = self.sequence_to_text(sequence)
+ texts.append(text)
+ return texts
+
+ def predict_caption(self, sequence):
+ caption = ''
+ for i in sequence:
+ if i == self.stoi[''] or i == self.stoi['']:
+ break
+ caption += self.itos[i]
+ return caption
+
+ def predict_captions(self, sequences):
+ captions = []
+ for sequence in sequences:
+ caption = self.predict_caption(sequence)
+ captions.append(caption)
+ return captions
+
+ def sequence_to_smiles(self, sequence):
+ return {'smiles': self.predict_caption(sequence)}
+
+
+class NodeTokenizer(Tokenizer):
+
+ def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
+ super().__init__(path)
+ self.maxx = input_size # height
+ self.maxy = input_size # width
+ self.sep_xy = sep_xy
+ self.special_tokens = [PAD, SOS, EOS, UNK, MASK]
+ self.continuous_coords = continuous_coords
+ self.debug = debug
+
+ def __len__(self):
+ if self.sep_xy:
+ return self.offset + self.maxx + self.maxy
+ else:
+ return self.offset + max(self.maxx, self.maxy)
+
+ @property
+ def offset(self):
+ return len(self.stoi)
+
+ @property
+ def output_constraint(self):
+ return not self.continuous_coords
+
+ def len_symbols(self):
+ return len(self.stoi)
+
+ def fit_atom_symbols(self, atoms):
+ vocab = self.special_tokens + list(set(atoms))
+ for i, s in enumerate(vocab):
+ self.stoi[s] = i
+ assert self.stoi[PAD] == PAD_ID
+ assert self.stoi[SOS] == SOS_ID
+ assert self.stoi[EOS] == EOS_ID
+ assert self.stoi[UNK] == UNK_ID
+ assert self.stoi[MASK] == MASK_ID
+ self.itos = {item[1]: item[0] for item in self.stoi.items()}
+
+ def is_x(self, x):
+ return self.offset <= x < self.offset + self.maxx
+
+ def is_y(self, y):
+ if self.sep_xy:
+ return self.offset + self.maxx <= y
+ return self.offset <= y
+
+ def is_symbol(self, s):
+ return len(self.special_tokens) <= s < self.offset or s == UNK_ID
+
+ def is_atom(self, id):
+ if self.is_symbol(id):
+ return self.is_atom_token(self.itos[id])
+ return False
+
+ def is_atom_token(self, token):
+ return token.isalpha() or token.startswith("[") or token == '*' or token == UNK
+
+ def x_to_id(self, x):
+ return self.offset + round(x * (self.maxx - 1))
+
+ def y_to_id(self, y):
+ if self.sep_xy:
+ return self.offset + self.maxx + round(y * (self.maxy - 1))
+ return self.offset + round(y * (self.maxy - 1))
+
+ def id_to_x(self, id):
+ return (id - self.offset) / (self.maxx - 1)
+
+ def id_to_y(self, id):
+ if self.sep_xy:
+ return (id - self.offset - self.maxx) / (self.maxy - 1)
+ return (id - self.offset) / (self.maxy - 1)
+
+ def get_output_mask(self, id):
+ mask = [False] * len(self)
+ if self.continuous_coords:
+ return mask
+ if self.is_atom(id):
+ return [True] * self.offset + [False] * self.maxx + [True] * self.maxy
+ if self.is_x(id):
+ return [True] * (self.offset + self.maxx) + [False] * self.maxy
+ if self.is_y(id):
+ return [False] * self.offset + [True] * (self.maxx + self.maxy)
+ return mask
+
+ def symbol_to_id(self, symbol):
+ if symbol not in self.stoi:
+ return UNK_ID
+ return self.stoi[symbol]
+
+ def symbols_to_labels(self, symbols):
+ labels = []
+ for symbol in symbols:
+ labels.append(self.symbol_to_id(symbol))
+ return labels
+
+ def labels_to_symbols(self, labels):
+ symbols = []
+ for label in labels:
+ symbols.append(self.itos[label])
+ return symbols
+
+ def nodes_to_grid(self, nodes):
+ coords, symbols = nodes['coords'], nodes['symbols']
+ grid = np.zeros((self.maxx, self.maxy), dtype=int)
+ for [x, y], symbol in zip(coords, symbols):
+ x = round(x * (self.maxx - 1))
+ y = round(y * (self.maxy - 1))
+ grid[x][y] = self.symbol_to_id(symbol)
+ return grid
+
+ def grid_to_nodes(self, grid):
+ coords, symbols, indices = [], [], []
+ for i in range(self.maxx):
+ for j in range(self.maxy):
+ if grid[i][j] != 0:
+ x = i / (self.maxx - 1)
+ y = j / (self.maxy - 1)
+ coords.append([x, y])
+ symbols.append(self.itos[grid[i][j]])
+ indices.append([i, j])
+ return {'coords': coords, 'symbols': symbols, 'indices': indices}
+
+ def nodes_to_sequence(self, nodes):
+ coords, symbols = nodes['coords'], nodes['symbols']
+ labels = [SOS_ID]
+ for (x, y), symbol in zip(coords, symbols):
+ assert 0 <= x <= 1
+ assert 0 <= y <= 1
+ labels.append(self.x_to_id(x))
+ labels.append(self.y_to_id(y))
+ labels.append(self.symbol_to_id(symbol))
+ labels.append(EOS_ID)
+ return labels
+
+ def sequence_to_nodes(self, sequence):
+ coords, symbols = [], []
+ i = 0
+ if sequence[0] == SOS_ID:
+ i += 1
+ while i + 2 < len(sequence):
+ if sequence[i] == EOS_ID:
+ break
+ if self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
+ x = self.id_to_x(sequence[i])
+ y = self.id_to_y(sequence[i+1])
+ symbol = self.itos[sequence[i+2]]
+ coords.append([x, y])
+ symbols.append(symbol)
+ i += 3
+ return {'coords': coords, 'symbols': symbols}
+
+ def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
+ tokens = atomwise_tokenizer(smiles)
+ labels = [SOS_ID]
+ indices = []
+ atom_idx = -1
+ for token in tokens:
+ if atom_only and not self.is_atom_token(token):
+ continue
+ if token in self.stoi:
+ labels.append(self.stoi[token])
+ else:
+ if self.debug:
+ print(f'{token} not in vocab')
+ labels.append(UNK_ID)
+ if self.is_atom_token(token):
+ atom_idx += 1
+ if not self.continuous_coords:
+ if mask_ratio > 0 and random.random() < mask_ratio:
+ labels.append(MASK_ID)
+ labels.append(MASK_ID)
+ elif coords is not None:
+ if atom_idx < len(coords):
+ x, y = coords[atom_idx]
+ assert 0 <= x <= 1
+ assert 0 <= y <= 1
+ else:
+ x = random.random()
+ y = random.random()
+ labels.append(self.x_to_id(x))
+ labels.append(self.y_to_id(y))
+ indices.append(len(labels) - 1)
+ labels.append(EOS_ID)
+ return labels, indices
+
+ def sequence_to_smiles(self, sequence):
+ has_coords = not self.continuous_coords
+ smiles = ''
+ coords, symbols, indices = [], [], []
+ for i, label in enumerate(sequence):
+ if label == EOS_ID or label == PAD_ID:
+ break
+ if self.is_x(label) or self.is_y(label):
+ continue
+ token = self.itos[label]
+ smiles += token
+ if self.is_atom_token(token):
+ if has_coords:
+ if i+3 < len(sequence) and self.is_x(sequence[i+1]) and self.is_y(sequence[i+2]):
+ x = self.id_to_x(sequence[i+1])
+ y = self.id_to_y(sequence[i+2])
+ coords.append([x, y])
+ symbols.append(token)
+ indices.append(i+3)
+ else:
+ if i+1 < len(sequence):
+ symbols.append(token)
+ indices.append(i+1)
+ results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
+ if has_coords:
+ results['coords'] = coords
+ return results
+
+
+class CharTokenizer(NodeTokenizer):
+
+ def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
+ super().__init__(input_size, path, sep_xy, continuous_coords, debug)
+
+ def fit_on_texts(self, texts):
+ vocab = set()
+ for text in texts:
+ vocab.update(list(text))
+ if ' ' in vocab:
+ vocab.remove(' ')
+ vocab = [PAD, SOS, EOS, UNK] + list(vocab)
+ for i, s in enumerate(vocab):
+ self.stoi[s] = i
+ self.itos = {item[1]: item[0] for item in self.stoi.items()}
+ assert self.stoi[PAD] == PAD_ID
+ assert self.stoi[SOS] == SOS_ID
+ assert self.stoi[EOS] == EOS_ID
+ assert self.stoi[UNK] == UNK_ID
+
+ def text_to_sequence(self, text, tokenized=True):
+ sequence = []
+ sequence.append(self.stoi[''])
+ if tokenized:
+ tokens = text.split(' ')
+ assert all(len(s) == 1 for s in tokens)
+ else:
+ tokens = list(text)
+ for s in tokens:
+ if s not in self.stoi:
+ s = ''
+ sequence.append(self.stoi[s])
+ sequence.append(self.stoi[''])
+ return sequence
+
+ def fit_atom_symbols(self, atoms):
+ atoms = list(set(atoms))
+ chars = []
+ for atom in atoms:
+ chars.extend(list(atom))
+ vocab = self.special_tokens + chars
+ for i, s in enumerate(vocab):
+ self.stoi[s] = i
+ assert self.stoi[PAD] == PAD_ID
+ assert self.stoi[SOS] == SOS_ID
+ assert self.stoi[EOS] == EOS_ID
+ assert self.stoi[UNK] == UNK_ID
+ assert self.stoi[MASK] == MASK_ID
+ self.itos = {item[1]: item[0] for item in self.stoi.items()}
+
+ def get_output_mask(self, id):
+ ''' TO FIX '''
+ mask = [False] * len(self)
+ if self.continuous_coords:
+ return mask
+ if self.is_x(id):
+ return [True] * (self.offset + self.maxx) + [False] * self.maxy
+ if self.is_y(id):
+ return [False] * self.offset + [True] * (self.maxx + self.maxy)
+ return mask
+
+ def nodes_to_sequence(self, nodes):
+ coords, symbols = nodes['coords'], nodes['symbols']
+ labels = [SOS_ID]
+ for (x, y), symbol in zip(coords, symbols):
+ assert 0 <= x <= 1
+ assert 0 <= y <= 1
+ labels.append(self.x_to_id(x))
+ labels.append(self.y_to_id(y))
+ for char in symbol:
+ labels.append(self.symbol_to_id(char))
+ labels.append(EOS_ID)
+ return labels
+
+ def sequence_to_nodes(self, sequence):
+ coords, symbols = [], []
+ i = 0
+ if sequence[0] == SOS_ID:
+ i += 1
+ while i < len(sequence):
+ if sequence[i] == EOS_ID:
+ break
+ if i+2 < len(sequence) and self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
+ x = self.id_to_x(sequence[i])
+ y = self.id_to_y(sequence[i+1])
+ for j in range(i+2, len(sequence)):
+ if not self.is_symbol(sequence[j]):
+ break
+ symbol = ''.join(self.itos(sequence[k]) for k in range(i+2, j))
+ coords.append([x, y])
+ symbols.append(symbol)
+ i = j
+ else:
+ i += 1
+ return {'coords': coords, 'symbols': symbols}
+
+ def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
+ tokens = atomwise_tokenizer(smiles)
+ labels = [SOS_ID]
+ indices = []
+ atom_idx = -1
+ for token in tokens:
+ if atom_only and not self.is_atom_token(token):
+ continue
+ for c in token:
+ if c in self.stoi:
+ labels.append(self.stoi[c])
+ else:
+ if self.debug:
+ print(f'{c} not in vocab')
+ labels.append(UNK_ID)
+ if self.is_atom_token(token):
+ atom_idx += 1
+ if not self.continuous_coords:
+ if mask_ratio > 0 and random.random() < mask_ratio:
+ labels.append(MASK_ID)
+ labels.append(MASK_ID)
+ elif coords is not None:
+ if atom_idx < len(coords):
+ x, y = coords[atom_idx]
+ assert 0 <= x <= 1
+ assert 0 <= y <= 1
+ else:
+ x = random.random()
+ y = random.random()
+ labels.append(self.x_to_id(x))
+ labels.append(self.y_to_id(y))
+ indices.append(len(labels) - 1)
+ labels.append(EOS_ID)
+ return labels, indices
+
+ def sequence_to_smiles(self, sequence):
+ has_coords = not self.continuous_coords
+ smiles = ''
+ coords, symbols, indices = [], [], []
+ i = 0
+ while i < len(sequence):
+ label = sequence[i]
+ if label == EOS_ID or label == PAD_ID:
+ break
+ if self.is_x(label) or self.is_y(label):
+ i += 1
+ continue
+ if not self.is_atom(label):
+ smiles += self.itos[label]
+ i += 1
+ continue
+ if self.itos[label] == '[':
+ j = i + 1
+ while j < len(sequence):
+ if not self.is_symbol(sequence[j]):
+ break
+ if self.itos[sequence[j]] == ']':
+ j += 1
+ break
+ j += 1
+ else:
+ if i+1 < len(sequence) and (self.itos[label] == 'C' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'l' \
+ or self.itos[label] == 'B' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'r'):
+ j = i+2
+ else:
+ j = i+1
+ token = ''.join(self.itos[sequence[k]] for k in range(i, j))
+ smiles += token
+ if has_coords:
+ if j+2 < len(sequence) and self.is_x(sequence[j]) and self.is_y(sequence[j+1]):
+ x = self.id_to_x(sequence[j])
+ y = self.id_to_y(sequence[j+1])
+ coords.append([x, y])
+ symbols.append(token)
+ indices.append(j+2)
+ i = j+2
+ else:
+ i = j
+ else:
+ if j < len(sequence):
+ symbols.append(token)
+ indices.append(j)
+ i = j
+ results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
+ if has_coords:
+ results['coords'] = coords
+ return results
+
+
+def get_tokenizer(args):
+ tokenizer = {}
+ for format_ in args.formats:
+ if format_ == 'atomtok':
+ if args.vocab_file is None:
+ args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
+ tokenizer['atomtok'] = Tokenizer(args.vocab_file)
+ elif format_ == "atomtok_coords":
+ if args.vocab_file is None:
+ args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
+ tokenizer["atomtok_coords"] = NodeTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
+ continuous_coords=args.continuous_coords)
+ elif format_ == "chartok_coords":
+ if args.vocab_file is None:
+ args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_chars.json')
+ tokenizer["chartok_coords"] = CharTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
+ continuous_coords=args.continuous_coords)
+ return tokenizer
diff --git a/molscribe/transformer/__init__.py b/molscribe/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9c953b157c20f4c8bd13345f0a846fd70e0815e
--- /dev/null
+++ b/molscribe/transformer/__init__.py
@@ -0,0 +1,3 @@
+from .decoder import TransformerDecoder
+from .embedding import Embeddings
+from .swin_transformer import swin_base, swin_large
diff --git a/molscribe/transformer/__pycache__/__init__.cpython-310.pyc b/molscribe/transformer/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47ce9e2fc212766b515f99c90dbe1b63e2c7e59e
Binary files /dev/null and b/molscribe/transformer/__pycache__/__init__.cpython-310.pyc differ
diff --git a/molscribe/transformer/__pycache__/decoder.cpython-310.pyc b/molscribe/transformer/__pycache__/decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73723791b9f0a10fbdb97881b250a0c7cf49e20e
Binary files /dev/null and b/molscribe/transformer/__pycache__/decoder.cpython-310.pyc differ
diff --git a/molscribe/transformer/__pycache__/embedding.cpython-310.pyc b/molscribe/transformer/__pycache__/embedding.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a79b6c699fcb371619c293667584178d38912b19
Binary files /dev/null and b/molscribe/transformer/__pycache__/embedding.cpython-310.pyc differ
diff --git a/molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc b/molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f099d220ec4f1f06c5c14522a5fbe8f22731463e
Binary files /dev/null and b/molscribe/transformer/__pycache__/swin_transformer.cpython-310.pyc differ
diff --git a/molscribe/transformer/decoder.py b/molscribe/transformer/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a04a96aa4c472fd00d3e8a9d470bd62d380b3202
--- /dev/null
+++ b/molscribe/transformer/decoder.py
@@ -0,0 +1,487 @@
+"""
+Implementation of "Attention is All You Need" and of
+subsequent transformer based architectures
+"""
+
+import torch
+import torch.nn as nn
+
+from onmt.decoders.decoder import DecoderBase
+from onmt.modules import MultiHeadedAttention, AverageAttention
+from onmt.modules.position_ffn import PositionwiseFeedForward
+from onmt.modules.position_ffn import ActivationFunction
+from onmt.utils.misc import sequence_mask
+
+
+class TransformerDecoderLayerBase(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type="scaled-dot",
+ max_relative_positions=0,
+ aan_useffn=False,
+ full_context_alignment=False,
+ alignment_heads=0,
+ pos_ffn_activation_fn=ActivationFunction.relu,
+ ):
+ """
+ Args:
+ d_model (int): the dimension of keys/values/queries in
+ :class:`MultiHeadedAttention`, also the input size of
+ the first-layer of the :class:`PositionwiseFeedForward`.
+ heads (int): the number of heads for MultiHeadedAttention.
+ d_ff (int): the second-layer of the
+ :class:`PositionwiseFeedForward`.
+ dropout (float): dropout in residual, self-attn(dot) and
+ feed-forward
+ attention_dropout (float): dropout in context_attn (and
+ self-attn(avg))
+ self_attn_type (string): type of self-attention scaled-dot,
+ average
+ max_relative_positions (int):
+ Max distance between inputs in relative positions
+ representations
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+ full_context_alignment (bool):
+ whether enable an extra full context decoder forward for
+ alignment
+ alignment_heads (int):
+ N. of cross attention heads to use for alignment guiding
+ pos_ffn_activation_fn (ActivationFunction):
+ activation function choice for PositionwiseFeedForward layer
+
+ """
+ super(TransformerDecoderLayerBase, self).__init__()
+
+ if self_attn_type == "scaled-dot":
+ self.self_attn = MultiHeadedAttention(
+ heads,
+ d_model,
+ dropout=attention_dropout,
+ max_relative_positions=max_relative_positions,
+ )
+ elif self_attn_type == "average":
+ self.self_attn = AverageAttention(
+ d_model, dropout=attention_dropout, aan_useffn=aan_useffn
+ )
+
+ self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout,
+ pos_ffn_activation_fn
+ )
+ self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
+ self.drop = nn.Dropout(dropout)
+ self.full_context_alignment = full_context_alignment
+ self.alignment_heads = alignment_heads
+
+ def forward(self, *args, **kwargs):
+ """Extend `_forward` for (possibly) multiple decoder pass:
+ Always a default (future masked) decoder forward pass,
+ Possibly a second future aware decoder pass for joint learn
+ full context alignement, :cite:`garg2019jointly`.
+
+ Args:
+ * All arguments of _forward.
+ with_align (bool): whether return alignment attention.
+
+ Returns:
+ (FloatTensor, FloatTensor, FloatTensor or None):
+
+ * output ``(batch_size, T, model_dim)``
+ * top_attn ``(batch_size, T, src_len)``
+ * attn_align ``(batch_size, T, src_len)`` or None
+ """
+ with_align = kwargs.pop("with_align", False)
+ output, attns = self._forward(*args, **kwargs)
+ top_attn = attns[:, 0, :, :].contiguous()
+ attn_align = None
+ if with_align:
+ if self.full_context_alignment:
+ # return _, (B, Q_len, K_len)
+ _, attns = self._forward(*args, **kwargs, future=True)
+
+ if self.alignment_heads > 0:
+ attns = attns[:, : self.alignment_heads, :, :].contiguous()
+ # layer average attention across heads, get ``(B, Q, K)``
+ # Case 1: no full_context, no align heads -> layer avg baseline
+ # Case 2: no full_context, 1 align heads -> guided align
+ # Case 3: full_context, 1 align heads -> full cte guided align
+ attn_align = attns.mean(dim=1)
+ return output, top_attn, attn_align
+
+ def update_dropout(self, dropout, attention_dropout):
+ self.self_attn.update_dropout(attention_dropout)
+ self.feed_forward.update_dropout(dropout)
+ self.drop.p = dropout
+
+ def _forward(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _compute_dec_mask(self, tgt_pad_mask, future):
+ tgt_len = tgt_pad_mask.size(-1)
+ if not future: # apply future_mask, result mask in (B, T, T)
+ future_mask = torch.ones(
+ [tgt_len, tgt_len],
+ device=tgt_pad_mask.device,
+ dtype=torch.uint8,
+ )
+ future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
+ # BoolTensor was introduced in pytorch 1.2
+ try:
+ future_mask = future_mask.bool()
+ except AttributeError:
+ pass
+ dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
+ else: # only mask padding, result mask in (B, 1, T)
+ dec_mask = tgt_pad_mask
+ return dec_mask
+
+ def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
+ if isinstance(self.self_attn, MultiHeadedAttention):
+ return self.self_attn(
+ inputs_norm,
+ inputs_norm,
+ inputs_norm,
+ mask=dec_mask,
+ layer_cache=layer_cache,
+ attn_type="self",
+ )
+ elif isinstance(self.self_attn, AverageAttention):
+ return self.self_attn(
+ inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
+ )
+ else:
+ raise ValueError(
+ f"self attention {type(self.self_attn)} not supported"
+ )
+
+
+class TransformerDecoderLayer(TransformerDecoderLayerBase):
+ """Transformer Decoder layer block in Pre-Norm style.
+ Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style,
+ providing better converge speed and performance. This is also the actual
+ implementation in tensor2tensor and also avalable in fairseq.
+ See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
+
+ .. mermaid::
+
+ graph LR
+ %% "*SubLayer" can be self-attn, src-attn or feed forward block
+ A(input) --> B[Norm]
+ B --> C["*SubLayer"]
+ C --> D[Drop]
+ D --> E((+))
+ A --> E
+ E --> F(out)
+
+ """
+
+ def __init__(
+ self,
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type="scaled-dot",
+ max_relative_positions=0,
+ aan_useffn=False,
+ full_context_alignment=False,
+ alignment_heads=0,
+ pos_ffn_activation_fn=ActivationFunction.relu,
+ ):
+ """
+ Args:
+ See TransformerDecoderLayerBase
+ """
+ super(TransformerDecoderLayer, self).__init__(
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type,
+ max_relative_positions,
+ aan_useffn,
+ full_context_alignment,
+ alignment_heads,
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
+ )
+ self.context_attn = MultiHeadedAttention(
+ heads, d_model, dropout=attention_dropout
+ )
+ self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
+
+ def update_dropout(self, dropout, attention_dropout):
+ super(TransformerDecoderLayer, self).update_dropout(
+ dropout, attention_dropout
+ )
+ self.context_attn.update_dropout(attention_dropout)
+
+ def _forward(
+ self,
+ inputs,
+ memory_bank,
+ src_pad_mask,
+ tgt_pad_mask,
+ layer_cache=None,
+ step=None,
+ future=False,
+ ):
+ """A naive forward pass for transformer decoder.
+
+ # T: could be 1 in the case of stepwise decoding or tgt_len
+
+ Args:
+ inputs (FloatTensor): ``(batch_size, T, model_dim)``
+ memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
+ src_pad_mask (bool): ``(batch_size, 1, src_len)``
+ tgt_pad_mask (bool): ``(batch_size, 1, T)``
+ layer_cache (dict or None): cached layer info when stepwise decode
+ step (int or None): stepwise decoding counter
+ future (bool): If set True, do not apply future_mask.
+
+ Returns:
+ (FloatTensor, FloatTensor):
+
+ * output ``(batch_size, T, model_dim)``
+ * attns ``(batch_size, head, T, src_len)``
+
+ """
+ dec_mask = None
+
+ if inputs.size(1) > 1:
+ # masking is necessary when sequence length is greater than one
+ dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
+
+ inputs_norm = self.layer_norm_1(inputs)
+
+ query, _ = self._forward_self_attn(
+ inputs_norm, dec_mask, layer_cache, step
+ )
+
+ query = self.drop(query) + inputs
+
+ query_norm = self.layer_norm_2(query)
+ mid, attns = self.context_attn(
+ memory_bank,
+ memory_bank,
+ query_norm,
+ mask=src_pad_mask,
+ layer_cache=layer_cache,
+ attn_type="context",
+ )
+ output = self.feed_forward(self.drop(mid) + query)
+
+ return output, attns
+
+
+class TransformerDecoderBase(DecoderBase):
+ def __init__(self, d_model, copy_attn, alignment_layer):
+ super(TransformerDecoderBase, self).__init__()
+
+ # Decoder State
+ self.state = {}
+
+ # previously, there was a GlobalAttention module here for copy
+ # attention. But it was never actually used -- the "copy" attention
+ # just reuses the context attention.
+ self._copy = copy_attn
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ self.alignment_layer = alignment_layer
+
+ @classmethod
+ def from_opt(cls, opt, embeddings):
+ """Alternate constructor."""
+ return cls(
+ opt.dec_layers,
+ opt.dec_rnn_size,
+ opt.heads,
+ opt.transformer_ff,
+ opt.copy_attn,
+ opt.self_attn_type,
+ opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
+ opt.attention_dropout[0] if type(opt.attention_dropout) is list else opt.attention_dropout,
+ embeddings,
+ opt.max_relative_positions,
+ opt.aan_useffn,
+ opt.full_context_alignment,
+ opt.alignment_layer,
+ alignment_heads=opt.alignment_heads,
+ pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+ )
+
+ def init_state(self, src, memory_bank, enc_hidden):
+ """Initialize decoder state."""
+ self.state["src"] = src
+ self.state["cache"] = None
+
+ def map_state(self, fn):
+ def _recursive_map(struct, batch_dim=0):
+ for k, v in struct.items():
+ if v is not None:
+ if isinstance(v, dict):
+ _recursive_map(v)
+ else:
+ struct[k] = fn(v, batch_dim)
+
+ if self.state["src"] is not None:
+ self.state["src"] = fn(self.state["src"], 1)
+ if self.state["cache"] is not None:
+ _recursive_map(self.state["cache"])
+
+ def detach_state(self):
+ raise NotImplementedError
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def update_dropout(self, dropout, attention_dropout):
+ self.embeddings.update_dropout(dropout)
+ for layer in self.transformer_layers:
+ layer.update_dropout(dropout, attention_dropout)
+
+
+class TransformerDecoder(TransformerDecoderBase):
+ """The Transformer decoder from "Attention is All You Need".
+ :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
+
+ .. mermaid::
+
+ graph BT
+ A[input]
+ B[multi-head self-attn]
+ BB[multi-head src-attn]
+ C[feed forward]
+ O[output]
+ A --> B
+ B --> BB
+ BB --> C
+ C --> O
+
+
+ Args:
+ num_layers (int): number of decoder layers.
+ d_model (int): size of the model
+ heads (int): number of heads
+ d_ff (int): size of the inner FF layer
+ copy_attn (bool): if using a separate copy attention
+ self_attn_type (str): type of self-attention scaled-dot, average
+ dropout (float): dropout in residual, self-attn(dot) and feed-forward
+ attention_dropout (float): dropout in context_attn (and self-attn(avg))
+ embeddings (onmt.modules.Embeddings):
+ embeddings to use, should have positional encodings
+ max_relative_positions (int):
+ Max distance between inputs in relative positions representations
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+ full_context_alignment (bool):
+ whether enable an extra full context decoder forward for alignment
+ alignment_layer (int): N° Layer to supervise with for alignment guiding
+ alignment_heads (int):
+ N. of cross attention heads to use for alignment guiding
+ """
+
+ def __init__(
+ self,
+ num_layers,
+ d_model,
+ heads,
+ d_ff,
+ copy_attn,
+ self_attn_type,
+ dropout,
+ attention_dropout,
+ max_relative_positions,
+ aan_useffn,
+ full_context_alignment,
+ alignment_layer,
+ alignment_heads,
+ pos_ffn_activation_fn=ActivationFunction.relu,
+ ):
+ super(TransformerDecoder, self).__init__(
+ d_model, copy_attn, alignment_layer
+ )
+
+ self.transformer_layers = nn.ModuleList(
+ [
+ TransformerDecoderLayer(
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type=self_attn_type,
+ max_relative_positions=max_relative_positions,
+ aan_useffn=aan_useffn,
+ full_context_alignment=full_context_alignment,
+ alignment_heads=alignment_heads,
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ def detach_state(self):
+ self.state["src"] = self.state["src"].detach()
+
+ def forward(self, tgt_emb, memory_bank, src_pad_mask=None, tgt_pad_mask=None, step=None, **kwargs):
+ """Decode, possibly stepwise."""
+ if step == 0:
+ self._init_cache(memory_bank)
+
+ batch_size, src_len, src_dim = memory_bank.size()
+ device = memory_bank.device
+ if src_pad_mask is None:
+ src_pad_mask = torch.zeros((batch_size, 1, src_len), dtype=torch.bool, device=device)
+ output = tgt_emb
+ batch_size, tgt_len, tgt_dim = tgt_emb.size()
+ if tgt_pad_mask is None:
+ tgt_pad_mask = torch.zeros((batch_size, 1, tgt_len), dtype=torch.bool, device=device)
+
+ future = kwargs.pop("future", False)
+ with_align = kwargs.pop("with_align", False)
+ attn_aligns = []
+ hiddens = []
+
+ for i, layer in enumerate(self.transformer_layers):
+ layer_cache = (
+ self.state["cache"]["layer_{}".format(i)]
+ if step is not None
+ else None
+ )
+ output, attn, attn_align = layer(
+ output,
+ memory_bank,
+ src_pad_mask,
+ tgt_pad_mask,
+ layer_cache=layer_cache,
+ step=step,
+ with_align=with_align,
+ future=future
+ )
+ hiddens.append(output)
+ if attn_align is not None:
+ attn_aligns.append(attn_align)
+
+ output = self.layer_norm(output) # (B, L, D)
+
+ attns = {"std": attn}
+ if self._copy:
+ attns["copy"] = attn
+ if with_align:
+ attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
+ # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
+
+ # TODO change the way attns is returned dict => list or tuple (onnx)
+ return output, attns, hiddens
+
+ def _init_cache(self, memory_bank):
+ self.state["cache"] = {}
+ for i, layer in enumerate(self.transformer_layers):
+ layer_cache = {"memory_keys": None, "memory_values": None, "self_keys": None, "self_values": None}
+ self.state["cache"]["layer_{}".format(i)] = layer_cache
+
diff --git a/molscribe/transformer/embedding.py b/molscribe/transformer/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..39647774d7f183690a3443bdf608479088fb0f5d
--- /dev/null
+++ b/molscribe/transformer/embedding.py
@@ -0,0 +1,260 @@
+""" Embeddings module """
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+
+from onmt.modules.util_class import Elementwise
+
+
+class SequenceTooLongError(Exception):
+ pass
+
+
+class PositionalEncoding(nn.Module):
+ """Sinusoidal positional encoding for non-recurrent neural networks.
+
+ Implementation based on "Attention Is All You Need"
+ :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
+
+ Args:
+ dropout (float): dropout parameter
+ dim (int): embedding size
+ """
+
+ def __init__(self, dropout, dim, max_len=5000):
+ if dim % 2 != 0:
+ raise ValueError("Cannot use sin/cos positional encoding with "
+ "odd dim (got dim={:d})".format(dim))
+ pe = torch.zeros(max_len, dim)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
+ -(math.log(10000.0) / dim)))
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
+ pe = pe.unsqueeze(1)
+ super(PositionalEncoding, self).__init__()
+ self.register_buffer('pe', pe)
+ self.dropout = nn.Dropout(p=dropout)
+ self.dim = dim
+
+ def forward(self, emb, step=None):
+ """Embed inputs.
+
+ Args:
+ emb (FloatTensor): Sequence of word vectors
+ ``(seq_len, batch_size, self.dim)``
+ step (int or NoneType): If stepwise (``seq_len = 1``), use
+ the encoding for this position.
+ """
+
+ emb = emb * math.sqrt(self.dim)
+ step = step or 0
+ if self.pe.size(0) < step + emb.size(0):
+ raise SequenceTooLongError(
+ f"Sequence is {emb.size(0) + step} but PositionalEncoding is"
+ f" limited to {self.pe.size(0)}. See max_len argument."
+ )
+ emb = emb + self.pe[step:emb.size(0)+step]
+ emb = self.dropout(emb)
+ return emb
+
+
+class Embeddings(nn.Module):
+ """Words embeddings for encoder/decoder.
+
+ Additionally includes ability to add sparse input features
+ based on "Linguistic Input Features Improve Neural Machine Translation"
+ :cite:`sennrich2016linguistic`.
+
+
+ .. mermaid::
+
+ graph LR
+ A[Input]
+ C[Feature 1 Lookup]
+ A-->B[Word Lookup]
+ A-->C
+ A-->D[Feature N Lookup]
+ B-->E[MLP/Concat]
+ C-->E
+ D-->E
+ E-->F[Output]
+
+ Args:
+ word_vec_size (int): size of the dictionary of embeddings.
+ word_padding_idx (int): padding index for words in the embeddings.
+ feat_padding_idx (List[int]): padding index for a list of features
+ in the embeddings.
+ word_vocab_size (int): size of dictionary of embeddings for words.
+ feat_vocab_sizes (List[int], optional): list of size of dictionary
+ of embeddings for each feature.
+ position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding`
+ feat_merge (string): merge action for the features embeddings:
+ concat, sum or mlp.
+ feat_vec_exponent (float): when using `-feat_merge concat`, feature
+ embedding size is N^feat_dim_exponent, where N is the
+ number of values the feature takes.
+ feat_vec_size (int): embedding dimension for features when using
+ `-feat_merge mlp`
+ dropout (float): dropout probability.
+ freeze_word_vecs (bool): freeze weights of word vectors.
+ """
+
+ def __init__(self, word_vec_size,
+ word_vocab_size,
+ word_padding_idx,
+ position_encoding=False,
+ feat_merge="concat",
+ feat_vec_exponent=0.7,
+ feat_vec_size=-1,
+ feat_padding_idx=[],
+ feat_vocab_sizes=[],
+ dropout=0,
+ sparse=False,
+ freeze_word_vecs=False):
+ self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent,
+ feat_vec_size, feat_padding_idx)
+
+ if feat_padding_idx is None:
+ feat_padding_idx = []
+ self.word_padding_idx = word_padding_idx
+
+ self.word_vec_size = word_vec_size
+
+ # Dimensions and padding for constructing the word embedding matrix
+ vocab_sizes = [word_vocab_size]
+ emb_dims = [word_vec_size]
+ pad_indices = [word_padding_idx]
+
+ # Dimensions and padding for feature embedding matrices
+ # (these have no effect if feat_vocab_sizes is empty)
+ if feat_merge == 'sum':
+ feat_dims = [word_vec_size] * len(feat_vocab_sizes)
+ elif feat_vec_size > 0:
+ feat_dims = [feat_vec_size] * len(feat_vocab_sizes)
+ else:
+ feat_dims = [int(vocab ** feat_vec_exponent)
+ for vocab in feat_vocab_sizes]
+ vocab_sizes.extend(feat_vocab_sizes)
+ emb_dims.extend(feat_dims)
+ pad_indices.extend(feat_padding_idx)
+
+ # The embedding matrix look-up tables. The first look-up table
+ # is for words. Subsequent ones are for features, if any exist.
+ emb_params = zip(vocab_sizes, emb_dims, pad_indices)
+ embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse)
+ for vocab, dim, pad in emb_params]
+ emb_luts = Elementwise(feat_merge, embeddings)
+
+ # The final output size of word + feature vectors. This can vary
+ # from the word vector size if and only if features are defined.
+ # This is the attribute you should access if you need to know
+ # how big your embeddings are going to be.
+ self.embedding_size = (sum(emb_dims) if feat_merge == 'concat'
+ else word_vec_size)
+
+ # The sequence of operations that converts the input sequence
+ # into a sequence of embeddings. At minimum this consists of
+ # looking up the embeddings for each word and feature in the
+ # input. Model parameters may require the sequence to contain
+ # additional operations as well.
+ super(Embeddings, self).__init__()
+ self.make_embedding = nn.Sequential()
+ self.make_embedding.add_module('emb_luts', emb_luts)
+
+ if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0:
+ in_dim = sum(emb_dims)
+ mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU())
+ self.make_embedding.add_module('mlp', mlp)
+
+ self.position_encoding = position_encoding
+
+ if self.position_encoding:
+ pe = PositionalEncoding(dropout, self.embedding_size)
+ self.make_embedding.add_module('pe', pe)
+
+ if freeze_word_vecs:
+ self.word_lut.weight.requires_grad = False
+
+ def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent,
+ feat_vec_size, feat_padding_idx):
+ if feat_merge == "sum":
+ # features must use word_vec_size
+ if feat_vec_exponent != 0.7:
+ warnings.warn("Merging with sum, but got non-default "
+ "feat_vec_exponent. It will be unused.")
+ if feat_vec_size != -1:
+ warnings.warn("Merging with sum, but got non-default "
+ "feat_vec_size. It will be unused.")
+ elif feat_vec_size > 0:
+ # features will use feat_vec_size
+ if feat_vec_exponent != -1:
+ warnings.warn("Not merging with sum and positive "
+ "feat_vec_size, but got non-default "
+ "feat_vec_exponent. It will be unused.")
+ else:
+ if feat_vec_exponent <= 0:
+ raise ValueError("Using feat_vec_exponent to determine "
+ "feature vec size, but got feat_vec_exponent "
+ "less than or equal to 0.")
+ n_feats = len(feat_vocab_sizes)
+ if n_feats != len(feat_padding_idx):
+ raise ValueError("Got unequal number of feat_vocab_sizes and "
+ "feat_padding_idx ({:d} != {:d})".format(
+ n_feats, len(feat_padding_idx)))
+
+ @property
+ def word_lut(self):
+ """Word look-up table."""
+ return self.make_embedding[0][0]
+
+ @property
+ def emb_luts(self):
+ """Embedding look-up table."""
+ return self.make_embedding[0]
+
+ def load_pretrained_vectors(self, emb_file):
+ """Load in pretrained embeddings.
+
+ Args:
+ emb_file (str) : path to torch serialized embeddings
+ """
+
+ if emb_file:
+ pretrained = torch.load(emb_file)
+ pretrained_vec_size = pretrained.size(1)
+ if self.word_vec_size > pretrained_vec_size:
+ self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained
+ elif self.word_vec_size < pretrained_vec_size:
+ self.word_lut.weight.data \
+ .copy_(pretrained[:, :self.word_vec_size])
+ else:
+ self.word_lut.weight.data.copy_(pretrained)
+
+ def forward(self, source, step=None):
+ """Computes the embeddings for words and features.
+
+ Args:
+ source (LongTensor): index tensor ``(len, batch, nfeat)``
+
+ Returns:
+ FloatTensor: Word embeddings ``(len, batch, embedding_size)``
+ """
+
+ if self.position_encoding:
+ for i, module in enumerate(self.make_embedding._modules.values()):
+ if i == len(self.make_embedding._modules.values()) - 1:
+ source = module(source, step=step)
+ else:
+ source = module(source)
+ else:
+ source = self.make_embedding(source)
+
+ return source
+
+ def update_dropout(self, dropout):
+ if self.position_encoding:
+ self._modules['make_embedding'][1].dropout.p = dropout
+
diff --git a/molscribe/transformer/swin_transformer.py b/molscribe/transformer/swin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea2580fb5f730e4f66e2d07fbe071c09c976fd80
--- /dev/null
+++ b/molscribe/transformer/swin_transformer.py
@@ -0,0 +1,677 @@
+""" Swin Transformer
+A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
+ - https://arxiv.org/pdf/2103.14030
+
+Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
+"""
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+import logging
+import math
+from copy import deepcopy
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
+from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+from timm.models.vision_transformer import checkpoint_filter_fn, _init_vit_weights
+
+_logger = logging.getLogger(__name__)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # patch models (my experiments)
+ 'swin_base_patch4_window12_384': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+
+ 'swin_base_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
+ ),
+
+ 'swin_large_patch4_window12_384': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+
+ 'swin_large_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
+ ),
+
+ 'swin_small_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
+ ),
+
+ 'swin_tiny_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
+ ),
+
+ 'swin_base_patch4_window12_384_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
+
+ 'swin_base_patch4_window7_224_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
+ num_classes=21841),
+
+ 'swin_large_patch4_window12_384_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
+
+ 'swin_large_patch4_window7_224_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
+ num_classes=21841),
+
+}
+
+
+def window_partition(x, window_size: int):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size: int, H: int, W: int):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask: Optional[torch.Tensor] = None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resulotion.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
+ attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def get_attn_mask(self, H, W, device):
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ img_mask = torch.zeros((1, H, W, 1), device=device) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ return attn_mask
+
+ def forward(self, x, H, W):
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_mask = self.get_attn_mask(Hp, Wp, x.device)
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """
+ x: B, H*W, C
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ # assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ H, W = x.shape[1:3]
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x, H, W
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self, dim, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim, num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W, hiddens):
+ for blk in self.blocks:
+ if not torch.jit.is_scripting() and self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, H, W)
+ else:
+ x = blk(x, H, W)
+ hiddens.append(x)
+ if self.downsample is not None:
+ x, H, W = self.downsample(x, H, W)
+ return x, H, W
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+
+class PatchEmbed(nn.Module):
+ """ 2D Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+ # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x)
+ H, W = x.shape[2:]
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x, H, W
+
+
+class SwinTransformer(nn.Module):
+ r""" Swin Transformer
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+ embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
+ window_size=7, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, weight_init='', **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ self.patch_grid = self.patch_embed.grid_size
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+ else:
+ self.absolute_pos_embed = None
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ layers = []
+ for i_layer in range(self.num_layers):
+ layers += [BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint)
+ ]
+ self.layers = nn.Sequential(*layers)
+
+ self.norm = norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
+ if weight_init.startswith('jax'):
+ for n, m in self.named_modules():
+ _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
+ else:
+ self.apply(_init_vit_weights)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward(self, x):
+ x, H, W = self.patch_embed(x)
+ if self.absolute_pos_embed is not None:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+ hiddens = []
+ for layer in self.layers:
+ x, H, W = layer(x, H, W, hiddens)
+ x = self.norm(x) # B L C
+ # x = self.avgpool(x.transpose(1, 2)) # B C 1
+ # x = torch.flatten(x, 1)
+ return x, hiddens
+
+ # def forward(self, x):
+ # x = self.forward_features(x)
+ # x = self.head(x)
+ # return x
+
+
+def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
+ if default_cfg is None:
+ default_cfg = deepcopy(default_cfgs[variant])
+ overlay_external_default_cfg(default_cfg, kwargs)
+ default_num_classes = default_cfg['num_classes']
+ default_img_size = default_cfg['input_size'][-2:]
+
+ num_classes = kwargs.pop('num_classes', default_num_classes)
+ img_size = kwargs.pop('img_size', default_img_size)
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(
+ SwinTransformer, variant, pretrained,
+ default_cfg=default_cfg,
+ img_size=img_size,
+ num_classes=num_classes,
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+
+ return model
+
+
+
+@register_model
+def swin_base(pretrained=False, **kwargs):
+ """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+ return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_large(pretrained=False, **kwargs):
+ """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+ return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+# @register_model
+# def swin_small_patch4_window7_224(pretrained=False, **kwargs):
+# """ Swin-S @ 224x224, trained ImageNet-1k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
+# return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+#
+#
+# @register_model
+# def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
+# """ Swin-T @ 224x224, trained ImageNet-1k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
+# return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+#
+#
+# @register_model
+# def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
+# """ Swin-B @ 384x384, trained ImageNet-22k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+# return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
+#
+#
+# @register_model
+# def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
+# """ Swin-B @ 224x224, trained ImageNet-22k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+# return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
+#
+#
+# @register_model
+# def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
+# """ Swin-L @ 384x384, trained ImageNet-22k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+# return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
+#
+#
+# @register_model
+# def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
+# """ Swin-L @ 224x224, trained ImageNet-22k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+# return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
diff --git a/molscribe/utils.py b/molscribe/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1dd8331835b2c7075fa0c0229b40cc5d9f7266b
--- /dev/null
+++ b/molscribe/utils.py
@@ -0,0 +1,163 @@
+import os
+import random
+import numpy as np
+import torch
+import math
+import time
+import datetime
+import json
+from json import encoder
+
+
+FORMAT_INFO = {
+ "inchi": {
+ "name": "InChI_text",
+ "tokenizer": "tokenizer_inchi.json",
+ "max_len": 300
+ },
+ "atomtok": {
+ "name": "SMILES_atomtok",
+ "tokenizer": "tokenizer_smiles_atomtok.json",
+ "max_len": 256
+ },
+ "nodes": {"max_len": 384},
+ "atomtok_coords": {"max_len": 480},
+ "chartok_coords": {"max_len": 480}
+}
+
+
+def init_logger(log_file='train.log'):
+ from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
+ logger = getLogger(__name__)
+ logger.setLevel(INFO)
+ handler1 = StreamHandler()
+ handler1.setFormatter(Formatter("%(message)s"))
+ handler2 = FileHandler(filename=log_file)
+ handler2.setFormatter(Formatter("%(message)s"))
+ logger.addHandler(handler1)
+ logger.addHandler(handler2)
+ return logger
+
+
+def init_summary_writer(save_path):
+ from tensorboardX import SummaryWriter
+ summary = SummaryWriter(save_path)
+ return summary
+
+
+def save_args(args):
+ dt = datetime.datetime.strftime(datetime.datetime.now(), "%y%m%d-%H%M")
+ path = os.path.join(args.save_path, f'train_{dt}.log')
+ with open(path, 'w') as f:
+ for k, v in vars(args).items():
+ f.write(f"**** {k} = *{v}*\n")
+ return
+
+
+def seed_torch(seed=42):
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+class EpochMeter(AverageMeter):
+ def __init__(self):
+ super().__init__()
+ self.epoch = AverageMeter()
+
+ def update(self, val, n=1):
+ super().update(val, n)
+ self.epoch.update(val, n)
+
+
+class LossMeter(EpochMeter):
+ def __init__(self):
+ self.subs = {}
+ super().__init__()
+
+ def reset(self):
+ super().reset()
+ for k in self.subs:
+ self.subs[k].reset()
+
+ def update(self, loss, losses, n=1):
+ loss = loss.item()
+ super().update(loss, n)
+ losses = {k: v.item() for k, v in losses.items()}
+ for k, v in losses.items():
+ if k not in self.subs:
+ self.subs[k] = EpochMeter()
+ self.subs[k].update(v, n)
+
+
+def asMinutes(s):
+ m = math.floor(s / 60)
+ s -= m * 60
+ return '%dm %ds' % (m, s)
+
+
+def timeSince(since, percent):
+ now = time.time()
+ s = now - since
+ es = s / (percent)
+ rs = es - s
+ return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))
+
+
+def print_rank_0(message):
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_rank() == 0:
+ print(message, flush=True)
+ else:
+ print(message, flush=True)
+
+
+def to_device(data, device):
+ if torch.is_tensor(data):
+ return data.to(device)
+ if type(data) is list:
+ return [to_device(v, device) for v in data]
+ if type(data) is dict:
+ return {k: to_device(v, device) for k, v in data.items()}
+
+
+def round_floats(o):
+ if isinstance(o, float):
+ return round(o, 3)
+ if isinstance(o, dict):
+ return {k: round_floats(v) for k, v in o.items()}
+ if isinstance(o, (list, tuple)):
+ return [round_floats(x) for x in o]
+ return o
+
+
+def format_df(df):
+ def _dumps(obj):
+ if obj is None:
+ return obj
+ return json.dumps(round_floats(obj)).replace(" ", "")
+ for field in ['node_coords', 'node_symbols', 'edges']:
+ if field in df.columns:
+ df[field] = [_dumps(obj) for obj in df[field]]
+ return df
diff --git a/molscribe/vocab/vocab_chars.json b/molscribe/vocab/vocab_chars.json
new file mode 100644
index 0000000000000000000000000000000000000000..daf380c9efda9b89dcdd07398e84609971f0c2bc
--- /dev/null
+++ b/molscribe/vocab/vocab_chars.json
@@ -0,0 +1 @@
+{"": 0, "": 1, "": 2, "": 3, "": 4, ".": 5, "-": 6, "=": 7, "#": 8, ":": 9, "/": 10, "\\": 11, "(": 12, ")": 13, "[": 14, "]": 15, "@": 16, "+": 17, "%": 18, "0": 19, "1": 20, "2": 21, "3": 22, "4": 23, "5": 24, "6": 25, "7": 26, "8": 27, "9": 28, "a": 29, "b": 30, "c": 31, "d": 32, "e": 33, "f": 34, "g": 35, "h": 36, "i": 37, "j": 38, "k": 39, "l": 40, "m": 41, "n": 42, "o": 43, "p": 44, "q": 45, "r": 46, "s": 47, "t": 48, "u": 49, "v": 50, "w": 51, "x": 52, "y": 53, "z": 54, "A": 55, "B": 56, "C": 57, "D": 58, "E": 59, "F": 60, "G": 61, "H": 62, "I": 63, "J": 64, "K": 65, "L": 66, "M": 67, "N": 68, "O": 69, "P": 70, "Q": 71, "R": 72, "S": 73, "T": 74, "U": 75, "V": 76, "W": 77, "X": 78, "Y": 79, "Z": 80, "*": 81, "~": 82, "\u000f": 83, "!": 84, "\"": 85, "$": 86, "&": 87, "'": 88, ",": 89, ";": 90, "<": 91, ">": 92, "?": 93, "^": 94, "_": 95, "`": 96, "{": 97, "|": 98, "}": 99, "\u0155": 100}
diff --git a/molscribe/vocab/vocab_uspto.json b/molscribe/vocab/vocab_uspto.json
new file mode 100644
index 0000000000000000000000000000000000000000..1916eaf0291909eee431678cfcfc390609c8479a
--- /dev/null
+++ b/molscribe/vocab/vocab_uspto.json
@@ -0,0 +1 @@
+{"": 0, "": 1, "": 2, "": 3, "": 4, "[OR12]": 5, "[Z;]": 6, "[LG]": 7, "[10*:0]": 8, "[R35]": 9, "[U1]": 10, "[CH2)]": 11, "[(XV)]": 12, "[fmoc]": 13, "[(Z)n]": 14, "[(L)m]": 15, "[24*]": 16, "[CN;]": 17, "[E,]": 18, "[OC4H9(n)]": 19, "[62*]": 20, "[NH2;]": 21, "[OR]": 22, "[Rb3]": 23, "[(Ra)p]": 24, "[Z13]": 25, "[Den1]": 26, "[NOMe]": 27, "[R1]": 28, "[(NH]": 29, "4": 30, "[R70]": 31, "[O*]": 32, "8": 33, "[Cl(O]": 34, "%15": 35, "[R4a]": 36, "[RK]": 37, "[Br-]": 38, "[)n]": 39, "[KH]": 40, "[12]": 41, "[K]": 42, "[ROOC]": 43, "[(CH2CH2O)x]": 44, "[Re1]": 45, "[Cl]": 46, "[i]": 47, "[X.]": 48, "[SO3-M+]": 49, "[37*]": 50, "[(CR3R4)n]": 51, "[Ry]": 52, "[(R2)m]": 53, "[(OH)n]": 54, "[Pg2]": 55, "[OH.]": 56, "[(A)p]": 57, "[b2]": 58, "[R1c]": 59, "[NO2.]": 60, "[22*]": 61, "[OR1A]": 62, "[R50]": 63, "[(L)]": 64, "[R52]": 65, "[Ce]": 66, "[Ra3]": 67, "[39*]": 68, "[(R2)p]": 69, "[(R10)t]": 70, "[H]": 71, "[NHAc]": 72, "[(n)C4H9]": 73, "[CONR1R2]": 74, "[AlH3]": 75, "[Gal]": 76, "[XVII]": 77, "[20]": 78, "[X;]": 79, "[HZ]": 80, "[C*]": 81, "[RA1]": 82, "[(CH2)n]": 83, "[R(6)]": 84, "[S@@]": 85, "[(R4)q]": 86, "[(R5)c]": 87, "[EtOOC]": 88, "[(R12)m]": 89, "[NR7R8]": 90, "[P4]": 91, "[ALOC]": 92, "[Ra]": 93, "[45*]": 94, "[Dye]": 95, "[se]": 96, "[W(1)(2)]": 97, "[LIGAND]": 98, "[4*]": 99, "[O,]": 100, "[FG]": 101, "[(SiO)n]": 102, "[*HN]": 103, "[(R8)r]": 104, "[CH2)3]": 105, "[)a]": 106, "[(R)m]": 107, "[NBoc]": 108, "[CH2OH.]": 109, "[AA3]": 110, "[3*+]": 111, "[OR15]": 112, "[(O)x]": 113, "[CO2R6]": 114, "[NH,]": 115, "[S@H]": 116, "[Bu-t]": 117, "[Bn]": 118, "[Rc4]": 119, "[IH]": 120, "[Z21]": 121, "[65*]": 122, "[IH+]": 123, "[42*]": 124, "[CO2But]": 125, "[R7]": 126, "[30*:0]": 127, "[N2]": 128, "[Thy]": 129, "[(CF2)n]": 130, "[(Rx)q]": 131, "[(R4b)]": 132, "[(t)C5H11]": 133, "[Rx1]": 134, "[Bz]": 135, "[X-Ph3P+CH2]": 136, "[N3]": 137, "[(Y)a]": 138, "[29]": 139, "[NH.]": 140, "[A0]": 141, "[G4]": 142, "[NR10R11]": 143, "[SH-]": 144, "[(X)]": 145, "[Z4]": 146, "[85%]": 147, "[COR2]": 148, "[nH+]": 149, "[t-BuO]": 150, "[Ac]": 151, "[(CR7R8)n]": 152, "[R9a]": 153, "[G]": 154, "[OR2]": 155, "[(CH3)m]": 156, "%20": 157, "[O)x]": 158, "[G7]": 159, "[(OCH2CH2)n]": 160, "2": 161, "[S;]": 162, "[R2b]": 163, "/": 164, "[Mc]": 165, "[CH+]": 166, "[(CH2)z]": 167, "[BocN]": 168, "[OTs]": 169, "[L]": 170, "[(a)]": 171, "[Aryl]": 172, "[O(CH2)z]": 173, "[(R9)p]": 174, "[RL]": 175, "[An]": 176, "[CH]": 177, "[Pb]": 178, "[X-]": 179, "[CO2R1]": 180, "[)C]": 181, "[Rii]": 182, "[N(R1)2]": 183, "[(CR2R3)n]": 184, "[NR5]": 185, "[O)CH3]": 186, "[55*]": 187, "[(Br)n]": 188, "[*N]": 189, "[Ib]": 190, "[OR9]": 191, "[Halo]": 192, "[hal]": 193, "[Qs]": 194, "[X14]": 195, "[C2F5]": 196, "[CH)m]": 197, "[(R4)a]": 198, "[halo]": 199, "[(MeO)3Si(CH2)3]": 200, "[OR20]": 201, "[(R13)b]": 202, "[X5]": 203, "[OR6,]": 204, "[Hf]": 205, "[21*]": 206, "S": 207, "[CR12]": 208, "[65535*:0]": 209, "[(R1]": 210, "[[A]": 211, "[L11]": 212, "[CO2X]": 213, "[Sx]": 214, "[OBut]": 215, "[[Z]": 216, "[(OCH2CH2)x]": 217, "[Z,]": 218, "[CR6]": 219, "[F,]": 220, "[(CR10R11)m]": 221, "[Ab]": 222, "[PG]": 223, "[d(H2C)]": 224, "F": 225, "[(OH)m]": 226, "[OPG1]": 227, "[Sb]": 228, "[43*]": 229, "[HO2C]": 230, "%24": 231, "[60*]": 232, "B": 233, "[OH)]": 234, "[SG]": 235, "[(t)H9C4]": 236, "[W5]": 237, "[NR3R4]": 238, "[Rsa2]": 239, "[(R2)n]": 240, "[SOmR4]": 241, "[CH2;]": 242, "[RB2]": 243, "[28*]": 244, "[)2]": 245, "[Rq]": 246, "[NH3+]": 247, "[CO2R7]": 248, "[base]": 249, "[SO2R]": 250, "[Rn]": 251, "[SMe]": 252, "[(Q2)n]": 253, "[NHZ]": 254, "[Ri]": 255, "[(XVII)]": 256, "[P@@]": 257, "[NH2]": 258, "[C2H4]": 259, "[W-]": 260, "[(O)n]": 261, "[(R4)p]": 262, "[R54]": 263, "[51*]": 264, "[M2]": 265, "[tBuO]": 266, "[NHMe]": 267, "[X15]": 268, "[16*]": 269, "[LG2]": 270, "[C@@]": 271, "[15]": 272, "[i)]": 273, "[*C]": 274, "[DMT]": 275, "[PG1]": 276, "[Rb2]": 277, "[(L1)a]": 278, "[S]": 279, "[-]": 280, "[b)]": 281, "[90*]": 282, "[RA3]": 283, "[RI]": 284, "[C12H25-n]": 285, "[R22]": 286, "[X21]": 287, "[(CH2)m]": 288, "[R14]": 289, "[(SO3H)n]": 290, "[IVa]": 291, "[NO2;]": 292, "[H3CO]": 293, "[C5]": 294, "[Rv]": 295, "[Ra5]": 296, "[(Z)]": 297, "[0]": 298, "[O]": 299, "[RHN]": 300, "[S@+]": 301, "[Se]": 302, "[C4H9-n]": 303, "[64]": 304, "[RO]": 305, "[Y2,]": 306, "[(O)p]": 307, "[A8]": 308, "[*:0]": 309, "[O)n]": 310, "[Ar7]": 311, "[R51]": 312, "[Os]": 313, "[NR3]": 314, "[J1]": 315, "[CN.]": 316, "[ORc]": 317, "[29*]": 318, "[L32]": 319, "[(R8)p]": 320, "[L31]": 321, "[X22]": 322, "[Example]": 323, "[30]": 324, "[CHO.]": 325, "[R31]": 326, "[CF2]": 327, "[13C]": 328, "[Ag]": 329, "[RN1]": 330, "[(CH2)a]": 331, "[(CR1R2)n]": 332, "[L2]": 333, "[K1]": 334, "[I-2]": 335, "[(C1]": 336, "[(R8)s]": 337, "[wherein]": 338, "[NH]": 339, "[S(O)m]": 340, "[r1]": 341, "[(2)]": 342, "[NMe]": 343, "[3.]": 344, "[H2,]": 345, "[M+]": 346, "[B(OR)2]": 347, "[70*]": 348, "[YR]": 349, "[OP2]": 350, "[(CH2)q]": 351, "[COOM]": 352, "[ButO2C]": 353, "[A5]": 354, "[(CH2)j]": 355, "[CR9]": 356, "[Cl;]": 357, "[ZH]": 358, "[[CH2CH2O]": 359, "[CR3R4]": 360, "[O-]": 361, "[CnH2n+1]": 362, "[2.]": 363, "[S.]": 364, "[b4]": 365, "[11]": 366, "[(C(R6)2)k]": 367, "[125I]": 368, "[1]": 369, "[(CH2)c]": 370, "[CH2)m]": 371, "[W4]": 372, "[SO2]": 373, "[G3]": 374, "[R12]": 375, "[(X)q]": 376, "[Formula]": 377, "[NHR5]": 378, "[HfLn]": 379, "[t-Bu]": 380, "[(CH2CH(CH3)O)y]": 381, "[CH2CO2Me]": 382, "[X2]": 383, "[(IX)]": 384, "[O(CH2)n]": 385, "[F1]": 386, "[44*]": 387, "[(L)r]": 388, "[NR2]": 389, "[14]": 390, "[CO2]": 391, "[SiR1R2R3]": 392, "[(x)]": 393, "[(N]": 394, "[Xn]": 395, "[OM]": 396, "[P(O)(OR1)2]": 397, "[(R5)i]": 398, "[C(]": 399, "[CH3.]": 400, "[Z1)n]": 401, "[MgH]": 402, "[PHV]": 403, "[Mn+]": 404, "[H;]": 405, "[(R2)b]": 406, "[n-Bu]": 407, "[B]": 408, "[HH]": 409, "[[R1]": 410, "[(R5)n]": 411, "[OMe.]": 412, "[SeH]": 413, "[(OH)y]": 414, "[I-]": 415, "[RA2]": 416, ")": 417, "[SR]": 418, "[OE]": 419, "[R;]": 420, "[R2a]": 421, "[}]": 422, "[(R3)c]": 423, "[t-Boc]": 424, "[Rh]": 425, "[S+]": 426, "9": 427, "[z]": 428, "[*-]": 429, "[Polymer]": 430, "[(CH2)1]": 431, "[Base]": 432, "[(X1)d]": 433, "[Me3Si]": 434, "[CO2R10]": 435, "[CH3SO2]": 436, "[RC1]": 437, "[Lv]": 438, "[RN]": 439, "[MO3S]": 440, "[Rb1]": 441, "[(CH2)l]": 442, "[(C)q]": 443, "[23]": 444, "[w]": 445, "[E5]": 446, "[R24]": 447, "[CmH2m]": 448, "[SiH2]": 449, "[OSiMe2tBu]": 450, "[(CF2]": 451, "[(AcG)GLYACHMGPIT(1-nal)VCQPLR(MeG)]": 452, "[P-]": 453, "[Rx]": 454, "[Ra6]": 455, "[R(4)]": 456, "[CH3;]": 457, "[*]": 458, "[R15]": 459, "[Xc]": 460, "[52*]": 461, "[C-]": 462, "[Bu]": 463, "[R21]": 464, "[Den]": 465, "[L10]": 466, "[mPEG]": 467, "[PGO]": 468, "[Cy2]": 469, "[NHFmoc]": 470, "[NA]": 471, "[C6H4]": 472, "[c]": 473, "[(L)n]": 474, "[(SO3M)n]": 475, "[Ra11]": 476, "[2*:0]": 477, "[PH2+]": 478, "[I+3]": 479, "[F-]": 480, "[N.]": 481, "[82*]": 482, "[YR1]": 483, "[SO3H.]": 484, "[OX1]": 485, "[Me.]": 486, "[R(2)]": 487, "[O+]": 488, "[Y]": 489, "[OH3+]": 490, "[CH2CH3.]": 491, "[X31]": 492, "[(CH2)4]": 493, "[Ax]": 494, "[Y2]": 495, "[COOR3]": 496, "[CH2)n]": 497, "[Z22]": 498, "[y(RYZ)]": 499, "[Rm]": 500, "[OCF3]": 501, "[50]": 502, "[2*]": 503, "[NR1R2]": 504, "[Q2]": 505, "3": 506, "[;]": 507, "[NHR]": 508, "[CnH2n]": 509, "[X13]": 510, "[Tr]": 511, "[SO2R1]": 512, "[D3]": 513, "[g]": 514, "[OTES]": 515, "[(R10)r]": 516, "[C4H8SO3]": 517, "[Pr]": 518, "[AR]": 519, "[NR4R5]": 520, "[(C)m]": 521, "[Y7]": 522, "[M]": 523, "[N;]": 524, "[OQ]": 525, "[C@@H]": 526, "[Rb]": 527, "[Rj]": 528, "[where]": 529, "[(R4)x]": 530, "[Z14]": 531, "[R41]": 532, "[)x]": 533, "[Ar12]": 534, "[(CH2)r]": 535, "[C@H]": 536, "[(R5)x]": 537, "[(R14)n]": 538, "o": 539, "[O-Cat+]": 540, "[C(R7)]": 541, "[M(X-Y)n]": 542, "[(T)n]": 543, "[Q3]": 544, "[CsH]": 545, "[(V)t]": 546, "[L,]": 547, "[(R4)s]": 548, "[CO2R5b]": 549, "[63*]": 550, "[E4]": 551, "[EO]": 552, "[Cr]": 553, "[(R1)s]": 554, "[(R11)p]": 555, "[OPg]": 556, "[R']": 557, "[18]": 558, "[(CH2)y]": 559, "[(R4)t]": 560, "[(OH)b]": 561, "[RD]": 562, "[Si]": 563, "[(R8)q]": 564, "[56*]": 565, "[Bc]": 566, "[R]": 567, "[)d]": 568, "[C4H9(t)]": 569, "[R39]": 570, "[H2NSO2]": 571, "[*,]": 572, "[XX]": 573, "[(CH2)f]": 574, "[(R2)q]": 575, "[R37]": 576, "[(H2C)n]": 577, "[48*]": 578, "[QL]": 579, "[(CH2)w]": 580, "[Rk]": 581, "[D5]": 582, "[(R3)r]": 583, "[Cl.]": 584, "[(Q)n]": 585, "[OMe]": 586, "[DE]": 587, "[C4H8]": 588, "[BH3-]": 589, "[V3]": 590, "[R34]": 591, "[68*]": 592, "[Linker]": 593, "[PCy3]": 594, "[(XVI)]": 595, "[OtBu]": 596, "[R16]": 597, "[K+]": 598, "[X8]": 599, "[RJ]": 600, "[RAn]": 601, "[u]": 602, "[NHEt]": 603, "%21": 604, "[(Rb2)n2]": 605, "[OP3]": 606, "[OiPr]": 607, "[P5]": 608, "[tBu]": 609, "[COORD1b]": 610, "[Xa]": 611, "[(R)q]": 612, "[27]": 613, "[Trt]": 614, "[RS]": 615, "[XV]": 616, "[OMe;]": 617, "[Nuc]": 618, "[Zn+2]": 619, "[ORx]": 620, "[R(7)]": 621, "[link]": 622, "[OZ]": 623, "[(XII)]": 624, "[78*]": 625, "C": 626, "[C5H11(t)]": 627, "[RP2]": 628, "[CuPc]": 629, "[Silica]": 630, "[Den2]": 631, "[HO3S]": 632, "\\": 633, "[RaH2]": 634, "6": 635, "[SiH4]": 636, "[nBu]": 637, "[(R2)]": 638, "[COOR1]": 639, "[X0]": 640, "[L13]": 641, "[1*:0]": 642, "[Rc1]": 643, "[(R5)m]": 644, "[RY]": 645, "[A11]": 646, "[OR7]": 647, "[*CH*]": 648, "[alk]": 649, "[Na]": 650, "[P2O]": 651, "[L9]": 652, "[(O)a]": 653, "[(X)p]": 654, "[I+]": 655, "[CR1]": 656, "[Y13]": 657, "[Me2N]": 658, "[C8H17(n)]": 659, "[CO2t-Bu]": 660, "[C2]": 661, "[BaH2]": 662, "1": 663, "[But]": 664, "I": 665, "[CO2Et.]": 666, "[Ti]": 667, "[Et]": 668, "[XH]": 669, "[Y8]": 670, "[C4H9-t]": 671, "[S(O)n]": 672, "[A22]": 673, "[34]": 674, "[B4]": 675, "[CO2R5]": 676, "[(R3)a]": 677, "[Ha]": 678, "[Ta]": 679, "[NHR8]": 680, "[phenyl]": 681, "[(R7)n]": 682, "[HET]": 683, "[Ar1SO2]": 684, "[(R4)b]": 685, "[PROTECTING]": 686, "[(CH2)i]": 687, "[[R]": 688, "[53*]": 689, "[40]": 690, "[3CF3CO2H]": 691, "[12c]": 692, "[R56]": 693, "[A9]": 694, "[COOR10]": 695, "[X18]": 696, "[COOEt]": 697, "[R5a]": 698, "[S3]": 699, "[NHBoc]": 700, "[61*]": 701, "[8]": 702, "[F2C]": 703, "[RSUB2]": 704, "[SO3-Et3NH+]": 705, "[PG2]": 706, "[X1,]": 707, "[R9]": 708, "[(R3)m]": 709, "[(XIV)]": 710, "[CHR3]": 711, "[OR24]": 712, "[B5]": 713, "[1/2]": 714, "[13CH]": 715, "[2Mana1]": 716, "[(CH)m]": 717, "%14": 718, "[66*]": 719, "[COOMe]": 720, "[COOtBu]": 721, "[(X)n]": 722, "[L1*]": 723, "[RB1]": 724, "[(CH2CH2O)m]": 725, "[#]": 726, "[NH4+]": 727, "[Cm]": 728, "[L7]": 729, "[10*]": 730, "[(R2)a]": 731, "[O)m]": 732, "[11*]": 733, "[Cn]": 734, "[6*:0]": 735, "%12": 736, "[[SEQ]": 737, "[(CH2)s]": 738, "[H-]": 739, "[20*]": 740, "Br": 741, "[(R13)m]": 742, "[(Ic)]": 743, "[X1b]": 744, "[W]": 745, "[PH-]": 746, "[IH-4]": 747, "[Al]": 748, "[p]": 749, "[RB4]": 750, "[f]": 751, "[74*]": 752, "[CH2.]": 753, "[(R3)q]": 754, "[FmocHN]": 755, "[OR4]": 756, "[R29]": 757, "[COOC4H9(n)]": 758, "[R32]": 759, "[ClH+]": 760, "[P1]": 761, "[(O]": 762, "[Sm]": 763, "[OL]": 764, "[R8,]": 765, "[MeS]": 766, "[MeLeu]": 767, "[C6H13(n)]": 768, "[S(]": 769, "[COOR7]": 770, "[(X1)m]": 771, "[b3]": 772, "[e]": 773, "[RC]": 774, "[OR13]": 775, "[NR13]": 776, "[E]": 777, "[R-link]": 778, "[Fmoc]": 779, "[C(CH3)]": 780, "[H.]": 781, "[F3CO]": 782, "[CO2R8]": 783, "[NHR2]": 784, "[NHR11]": 785, "[ZO]": 786, "[Alkyl]": 787, "[OR14]": 788, "[X1]": 789, "[SEM]": 790, "[J4]": 791, "[NO2]": 792, "[(Z]": 793, "[Gd+3]": 794, "[(R4)o]": 795, "[SO3M1]": 796, "[(R6)n]": 797, "[A2]": 798, "[CO2PG]": 799, "[Het1]": 800, "[OR8]": 801, "[O)2]": 802, "[SO2R3]": 803, "[23*]": 804, "[(R)k]": 805, "[T6]": 806, "[Q6]": 807, "[OX2]": 808, "[t-C4H9]": 809, "[RSUB1]": 810, "[75*]": 811, "[CH)n]": 812, "[OEt.]": 813, "[Qn]": 814, "[5*]": 815, "[R23]": 816, "[t-BuO2C]": 817, "[C@]": 818, "[R38]": 819, "[Pn]": 820, "[34*]": 821, "[M*]": 822, "[X6]": 823, "[Ar1]": 824, "[Pt]": 825, "[R*]": 826, "[3H]": 827, "[OG]": 828, "[NR7]": 829, "[R8]": 830, "[(CH2]": 831, "[(OH)p]": 832, "[(TR7)t]": 833, "%27": 834, "[OR6]": 835, "[(CH2)e]": 836, "[R25]": 837, "[J]": 838, "[MgH2]": 839, "[CF3.]": 840, "[(SO3H)m]": 841, "[ORd]": 842, "[51*:0]": 843, "[Cd]": 844, "[Ru]": 845, "[21]": 846, "[L14]": 847, "[W5a]": 848, "[Z3]": 849, "[Ir]": 850, "[Sc]": 851, "n": 852, "[Ge]": 853, "[V1]": 854, "[R3b]": 855, "[AlkO]": 856, "[Hg]": 857, "[(R10)0-2]": 858, "[Y10]": 859, "[N@]": 860, "[R10]": 861, "[V]": 862, "[NHSO2Me]": 863, "[Te]": 864, "[CR1R2]": 865, "[R33]": 866, "[2R]": 867, "[35*]": 868, "[alkyl]": 869, "[SO3M]": 870, "[OR3,]": 871, "[XII]": 872, "[CH3]": 873, "[(R4]": 874, "[NCO2t-Bu]": 875, "O": 876, "[s]": 877, "[(CH2)k]": 878, "[R28]": 879, "[Y)n]": 880, "[(R41P)0-2]": 881, "[67*]": 882, "[(ZRY)y]": 883, "%16": 884, "[58*]": 885, "[M3]": 886, "[2H]": 887, "[N-]": 888, "[R11]": 889, "[Ba]": 890, "[C3]": 891, "[(R)p]": 892, "[R:]": 893, "(": 894, "[Cu+]": 895, "[Q4]": 896, "[Nu1]": 897, "[(WRW)m]": 898, "[G1]": 899, "[NHR3]": 900, "[MSG1]": 901, "[Ch]": 902, "[(R4)d]": 903, "[N*]": 904, "[Ph3P]": 905, "[RaH]": 906, "[R.]": 907, "[R72]": 908, "[T4]": 909, "[R45]": 910, "[A]": 911, "[CR3]": 912, "b": 913, "[NHOH.]": 914, "[(A1]": 915, "[BPin]": 916, "[(R5)a]": 917, "[Ni+2]": 918, "[S@]": 919, "[A.]": 920, "[(R9)n]": 921, "[(Z)q]": 922, "[Y1,]": 923, "[n-]": 924, "[OBn]": 925, "[R0]": 926, "[9*:0]": 927, "[peptide]": 928, "[Ra12]": 929, "[32*]": 930, "[(R5)d]": 931, "[25]": 932, "[n-C3H7]": 933, "[COOR6]": 934, "[PGN]": 935, "[SiO3/2]": 936, "[O-M+]": 937, "[Cy]": 938, "[Mj]": 939, "[PgO]": 940, "[Z15]": 941, "[C5H11-t]": 942, "[(R)n]": 943, "[TG]": 944, "[R43]": 945, "[H2n+1Cn]": 946, "[)]": 947, "[(C)p]": 948, "[T3]": 949, "[(V)n-1]": 950, "[La]": 951, "[alkenyl]": 952, "[D4]": 953, "[Val]": 954, "[)b]": 955, "[R18]": 956, "[(R1)p]": 957, "[Rz]": 958, "[(C)r]": 959, "[D1]": 960, "[6*]": 961, "[COR]": 962, "[Ua]": 963, "[M(X]": 964, "%19": 965, "[R,]": 966, "[OR5]": 967, "[N+]": 968, "[Ar]": 969, "[(R1)3]": 970, "[Hf-]": 971, "[CH2OR]": 972, "[SnH]": 973, "[RA]": 974, "[boc]": 975, "[CH2+]": 976, "[(A]": 977, "[[X]": 978, "[n-Pr]": 979, "[Py]": 980, "[(R5)q]": 981, "[OPv]": 982, "[Z0]": 983, "[k]": 984, "[i-Bu]": 985, "[PhO]": 986, "[R100]": 987, "[Z12]": 988, "[15N]": 989, "[Drug]": 990, "[MeSO2]": 991, "[X3]": 992, "[Ar3]": 993, "[Y4]": 994, "[n-C6H13]": 995, "[NHCORB]": 996, "[(R6)m]": 997, "[S*]": 998, "[CF]": 999, "[(Rz)v]": 1000, "[(W)n]": 1001, "[B-]": 1002, "[26*]": 1003, "[Core]": 1004, "[(R1)g]": 1005, "[17*]": 1006, "[H,]": 1007, "[ROH]": 1008, "[Tc]": 1009, "[CF3]": 1010, "[Ar11]": 1011, "[tC4H9]": 1012, "[basic]": 1013, "[Mg]": 1014, "[G2]": 1015, "[7*:0]": 1016, "[CO2H;]": 1017, "[3H-]": 1018, "[[R5]": 1019, "[Cys]": 1020, "[P3]": 1021, "[(R9)q]": 1022, "[b]": 1023, "[O2]": 1024, "[DNA]": 1025, "[X2R3]": 1026, "[CH2*]": 1027, "[NEt2]": 1028, "[TsO]": 1029, "[(R1)x]": 1030, "[CH2R]": 1031, "[NR12]": 1032, "[OPiv]": 1033, "[(CH]": 1034, "[NH-]": 1035, "[L4]": 1036, "[NH2-]": 1037, "[Zb]": 1038, "[1)]": 1039, "[31*]": 1040, "[T1]": 1041, "[(R3)p]": 1042, "[NMe2]": 1043, "[(R2)m1]": 1044, "[2CF3CO2H]": 1045, "[F.]": 1046, "Cl": 1047, "[OEt]": 1048, "[NR48]": 1049, "[OR17]": 1050, "[Ar1p]": 1051, "[Rt]": 1052, "[Pd/C]": 1053, "[TES]": 1054, "[Qa]": 1055, "[(R2)s]": 1056, "[PH2]": 1057, "[BH4-]": 1058, "[(R6)b]": 1059, "[Z9]": 1060, "[C1]": 1061, "[CPG]": 1062, "[OA]": 1063, "[tBuOOC]": 1064, "[H+]": 1065, "[Cbm]": 1066, "[(HO)n]": 1067, "[Compound]": 1068, "[q]": 1069, "[NHPG]": 1070, "[PAG]": 1071, "[Rd]": 1072, "[NR11R12]": 1073, "[OHC]": 1074, "[CONHtBu]": 1075, "[A3]": 1076, "[R2]": 1077, "[19*]": 1078, "[CO2H.]": 1079, "[31]": 1080, "[(R1)2]": 1081, "[(1)]": 1082, "[CO2tBu]": 1083, "[72*]": 1084, "[(R1)]": 1085, "[C(R3)]": 1086, "[SiH]": 1087, "[ORN-a]": 1088, "[S1]": 1089, "[36]": 1090, "[O;]": 1091, "[SO3X]": 1092, "[[R6O2C]": 1093, "[(R2)r]": 1094, "[80*]": 1095, "[Cl,]": 1096, "[+]": 1097, "[S2]": 1098, "[Z.]": 1099, "[a1]": 1100, "[Lb]": 1101, "[*CH2]": 1102, "[PR1]": 1103, "[5*:0]": 1104, "[59*]": 1105, "[77*]": 1106, "[(Z)o]": 1107, "[Rb8]": 1108, "[NR14]": 1109, "[Li+]": 1110, "[L6]": 1111, "[PMP]": 1112, "[(R1)t]": 1113, "[Co-2]": 1114, "[o+]": 1115, "[2)]": 1116, "[98*]": 1117, "[81*]": 1118, "[C4H9(n)]": 1119, "[X16]": 1120, "[l]": 1121, "[85*]": 1122, "[R20]": 1123, "[RpO]": 1124, "[SO2Ph]": 1125, "[Mg+2]": 1126, "[COOX]": 1127, "[OR22]": 1128, "[(Y]": 1129, "[A31]": 1130, "[Rsa1]": 1131, "[OR11]": 1132, "[aryl]": 1133, "[O2N]": 1134, "[Ar4]": 1135, "[30*]": 1136, "[NaBH3CN]": 1137, "[Au]": 1138, "[14*]": 1139, "[CO2R]": 1140, "[Reaction]": 1141, "[(Si]": 1142, "[93*]": 1143, "[Ia]": 1144, "[COOR4]": 1145, "[n+]": 1146, "[Z1]": 1147, "[R17]": 1148, "[(XI)]": 1149, "[2H-]": 1150, "[CR2]": 1151, "[(Y)q]": 1152, "[PAG1]": 1153, "[C34H56+a]": 1154, "[22]": 1155, "[SO3]": 1156, "[S@@+]": 1157, "[CH2]": 1158, "[Pd(PPh3)4/NMM/HOAc/CHCl3]": 1159, "[R61]": 1160, "[Z100]": 1161, "[OPG2]": 1162, "[(XX)]": 1163, "[NR8]": 1164, "[W6]": 1165, "[O)b]": 1166, "[Q]": 1167, "[Pa]": 1168, "[16]": 1169, "[Ar5]": 1170, "[L16]": 1171, "[POLY]": 1172, "[CF3;]": 1173, "[NR6]": 1174, "[(IIIa)]": 1175, "[[Si]": 1176, "[NHR4]": 1177, "[Q,]": 1178, "[CHR1]": 1179, "[C6H4)]": 1180, "[Rc2]": 1181, "[CH3,]": 1182, "[Y12]": 1183, "[Y3]": 1184, "[MeO2C]": 1185, "[(R5)o]": 1186, "[Cb]": 1187, "[,]": 1188, "[Y,]": 1189, ".": 1190, "[(C]": 1191, "[Zc]": 1192, "c": 1193, "[Yp]": 1194, "%26": 1195, "[SH2+]": 1196, "[AM]": 1197, "[NaH]": 1198, "+": 1199, "[5]": 1200, "[R44]": 1201, "[COOR8]": 1202, "[[CH2]": 1203, "[V-2]": 1204, "[4GlcNAc]": 1205, "[91*]": 1206, "[(4)]": 1207, "[COOR2]": 1208, "[R13]": 1209, "[(CH2)x]": 1210, "[NCO]": 1211, "[X+]": 1212, "[Z7]": 1213, "[OP1]": 1214, "[*+]": 1215, "[CR]": 1216, "[(A2]": 1217, "[A21]": 1218, "[nC4H9]": 1219, "[(Bn)2]": 1220, "[(CH2)t]": 1221, "[(H2C]": 1222, "[S-]": 1223, "[W1]": 1224, "[Cbz]": 1225, "[(R12)n]": 1226, "[(XIII)]": 1227, "[Ca]": 1228, "[Ym]": 1229, "[Rg]": 1230, "[3*]": 1231, "%25": 1232, "[O2S]": 1233, "[SO2Me]": 1234, "[SiR]": 1235, "[P@]": 1236, "[(R)b]": 1237, "[(R1)a]": 1238, "[A+]": 1239, "[NH2.]": 1240, "[Zm]": 1241, "[(CH2)u]": 1242, "[NHR6]": 1243, "[Gd]": 1244, "[CH2OR1]": 1245, "[Ra1]": 1246, "[NHR1]": 1247, "[COR6a]": 1248, "[Antibody]": 1249, "[Si(OR)3]": 1250, "[XVI]": 1251, "[(CH2)nCH3]": 1252, "[(A)m]": 1253, "[IH-2]": 1254, "[28]": 1255, "[Y22]": 1256, "[(R2)x]": 1257, "[7]": 1258, "[(A)]": 1259, "[33*]": 1260, "p": 1261, "[J5]": 1262, "[Z8]": 1263, "[RG]": 1264, "[Rb6]": 1265, "[41*]": 1266, "[m-PEG]": 1267, "[OMs]": 1268, "[LQ]": 1269, "[(X)m]": 1270, "[SiH3]": 1271, "[79*]": 1272, "[CO2Bn]": 1273, "[A4]": 1274, "[12*:0]": 1275, "[P3O]": 1276, "[(Z)p]": 1277, "[(R8a)m]": 1278, "[L8]": 1279, "[Rc]": 1280, "[R7,]": 1281, "[mPEG(30K)]": 1282, "[(CH2)v]": 1283, "[OCH3;]": 1284, "[CH)x]": 1285, "[AR2]": 1286, "[69*]": 1287, "[IH-]": 1288, "[COR1]": 1289, "[(Z1)n]": 1290, "[Man]": 1291, "[linker]": 1292, "[R3a]": 1293, "[(OH)q]": 1294, "[M,]": 1295, "[(AO)n]": 1296, "[Y21]": 1297, "[{]": 1298, "[(R3]": 1299, "[n]": 1300, "[Prt]": 1301, "[OR3]": 1302, "[(NH)p]": 1303, "[COOH;]": 1304, "[Z6]": 1305, "[OH-]": 1306, "[Nu]": 1307, "[13CH2]": 1308, "[F3]": 1309, "[(R5)b]": 1310, "[EtO]": 1311, "[CO2Z1]": 1312, "[n-C12H25]": 1313, "[Q8]": 1314, "[(Y)t]": 1315, "[SO3-]": 1316, "[9*]": 1317, "[W3]": 1318, "[2*+]": 1319, "[C)n]": 1320, "[Ca+2]": 1321, "[A-]": 1322, "[Fe]": 1323, "[Q+]": 1324, "[SH]": 1325, "[ODMT]": 1326, "[(R0)q]": 1327, "[XIV]": 1328, "[NHCbz]": 1329, "[(CH)i]": 1330, "[NR6R7]": 1331, "[(Y)m]": 1332, "[MsO]": 1333, "[COOBn]": 1334, "[X7]": 1335, "[y]": 1336, "[Lg]": 1337, "[P2]": 1338, "[DMTO]": 1339, "[X,]": 1340, "[XXII]": 1341, "[Me;]": 1342, "[(Y)p]": 1343, "[COOR]": 1344, "[NR2R3]": 1345, "[P1O]": 1346, "[Sp]": 1347, "[O-2]": 1348, "[B0]": 1349, "[T2]": 1350, "[Hs]": 1351, "[20*:0]": 1352, "[NCS]": 1353, "[NR5R6]": 1354, "[d]": 1355, "[HG]": 1356, "[R40]": 1357, "[Sg]": 1358, "[(CH2)m-1]": 1359, "[NPht]": 1360, "[Si(R2)(R3)]": 1361, "%18": 1362, "[83*]": 1363, "[Aa]": 1364, "[N(iPr)2]": 1365, "[CH2,]": 1366, "[D]": 1367, "[(alkyl)]": 1368, "[F3C]": 1369, "[Pt+2]": 1370, "[3HH]": 1371, "[pH]": 1372, "[OR102]": 1373, "[h]": 1374, "[Z-]": 1375, "[Ep]": 1376, "[ORE]": 1377, "[[R3]": 1378, "[Rr]": 1379, "[OCF2CFHCF3]": 1380, "[)y]": 1381, "[[Y]": 1382, "[SH.]": 1383, "[Cs]": 1384, "[U]": 1385, "[Q7]": 1386, "[Ln]": 1387, "[R81]": 1388, "[(R12)a]": 1389, "[Riii]": 1390, "%17": 1391, "[(R8)n]": 1392, "[Hb]": 1393, "[E1]": 1394, "[(Y)n]": 1395, "[(RA)n]": 1396, "[Z1,]": 1397, "[nH]": 1398, "[Li]": 1399, "[PH+]": 1400, "[E3]": 1401, "[Heteroaryl]": 1402, "[(Ra)w]": 1403, "[(CH2)g]": 1404, "[E2]": 1405, "[ClSO2]": 1406, "[COOBut]": 1407, "N": 1408, "[RF]": 1409, "[No]": 1410, "[Co]": 1411, "[[C(O)]": 1412, "[C)]": 1413, "[[N]": 1414, "[R5]": 1415, "[C]": 1416, "[s+]": 1417, "[HA]": 1418, "[ii)]": 1419, "[R3,]": 1420, "[CmH2m+1]": 1421, "[SA]": 1422, "[a2]": 1423, "[S@@H]": 1424, "[38*]": 1425, "[SR2]": 1426, "[(CH2)n2]": 1427, "[n-C4H9]": 1428, "[(Ra)n]": 1429, "[BASE]": 1430, "[Z5]": 1431, "[RZ]": 1432, "[Ns]": 1433, "[RP]": 1434, "[Ara]": 1435, "[i-C3F7]": 1436, "[Cy1]": 1437, "[SR1]": 1438, "[a4]": 1439, "[Prot]": 1440, "[8*:0]": 1441, "[Zn]": 1442, "[Chiral]": 1443, "[[]": 1444, "[19]": 1445, "[(R4)m]": 1446, "[NHSO2]": 1447, "[A6]": 1448, "[or]": 1449, "[*;]": 1450, "[RM]": 1451, "[OR10]": 1452, "[LiH]": 1453, "[a)]": 1454, "[Abu]": 1455, "[NR9]": 1456, "[COO-t-Bu]": 1457, "[QO]": 1458, "[Rl]": 1459, "[(AcG)GLYACHMGPIT(1-nal)VCQPLR]": 1460, "[1*+]": 1461, "[(CH2CH2]": 1462, "[Lys]": 1463, "[R4,]": 1464, "[OH+]": 1465, "[3)]": 1466, "[MeLeu-MeVal-N]": 1467, "[a3]": 1468, "[OR16]": 1469, "[Cu]": 1470, "[X11]": 1471, "[(R3)n]": 1472, "[(R10)n]": 1473, "[(CH2)p]": 1474, "[P+]": 1475, "[OR7a]": 1476, ":": 1477, "%11": 1478, "[BnO]": 1479, "[76*]": 1480, "[O)y]": 1481, "[X4]": 1482, "[1*]": 1483, "[O1/2]": 1484, "[NR1]": 1485, "%13": 1486, "[Rb4]": 1487, "[Mo]": 1488, "[(R3aR3bC)n]": 1489, "[Cyc2]": 1490, "[R**]": 1491, "[**]": 1492, "[R30]": 1493, "[Hal]": 1494, "[Ala]": 1495, "[C4F9]": 1496, "[RN2]": 1497, "[pol]": 1498, "[36*]": 1499, "[47*]": 1500, "[is]": 1501, "*": 1502, "[R(5)]": 1503, "[27*]": 1504, "[SOmR4,]": 1505, "[Roxa]": 1506, "[86*]": 1507, "[CbzHN]": 1508, "[PH]": 1509, "[(CH2)n,]": 1510, "[8*]": 1511, "[(Z)m]": 1512, "[Rb7]": 1513, "[10]": 1514, "[R27]": 1515, "[SO2R6]": 1516, "[Fe+3]": 1517, "[(R5)p]": 1518, "[(R15)q]": 1519, "[NH2+]": 1520, "[(R12)r]": 1521, "[33]": 1522, "[OPG]": 1523, "[(R5)r]": 1524, "[OY1]": 1525, "[H5C2]": 1526, "[Pg0O]": 1527, "[OR21]": 1528, "[TESO]": 1529, "[Lc]": 1530, "[RP1]": 1531, "[[C]": 1532, "[54*]": 1533, "[(R10)m]": 1534, "[Y5]": 1535, "[B3]": 1536, "[Pd]": 1537, "[Arom]": 1538, "[(CH2)nO]": 1539, "[(CRn)a]": 1540, "[84*]": 1541, "[AA]": 1542, "[CH2X]": 1543, "[(E)]": 1544, "[(t)C4H9]": 1545, "[(R7)q]": 1546, "[(QRX)x]": 1547, "[(R4)n]": 1548, "[Y-]": 1549, "[CH2)2]": 1550, "[GlcNAc]": 1551, "[CH)]": 1552, "[(C(R6)2)p]": 1553, "[Mm+]": 1554, "[M4]": 1555, "[CH3(CH2)n]": 1556, "[and/or]": 1557, "[46*]": 1558, "[(CRVRVI)x]": 1559, "[(CH2)o]": 1560, "[DMTrO]": 1561, "[K2]": 1562, "[(CH2)h]": 1563, "[Y11]": 1564, "[(R1)q]": 1565, "[CO2R11]": 1566, "[RB]": 1567, "[(H,]": 1568, "[Ya]": 1569, "[OH2+]": 1570, "[NR10]": 1571, "[X17]": 1572, "[(R9)m]": 1573, "[X9]": 1574, "[TfO]": 1575, "[(R2)y]": 1576, "[Z32]": 1577, "[Y3,]": 1578, "[)3]": 1579, "[SiMe3]": 1580, "[(R4)r]": 1581, "[OCOR]": 1582, "[(R6)s]": 1583, "[Z1)m]": 1584, "[CO2R3]": 1585, "[(R4)c]": 1586, "[CO2R2]": 1587, "[SO2R2]": 1588, "[Sn]": 1589, "[DMF]": 1590, "[AlH]": 1591, "[R6]": 1592, "[CH2)x]": 1593, "[15*]": 1594, "[Ar6]": 1595, "[b1]": 1596, "[(R10)q]": 1597, "[(R6)q]": 1598, "[11*:0]": 1599, "[X10]": 1600, "[P]": 1601, "[NH+]": 1602, "[N1]": 1603, "[2]": 1604, "[Si@]": 1605, "[)m]": 1606, "[32]": 1607, "[~]": 1608, "[O)]": 1609, "[A1]": 1610, "[ZN]": 1611, "[MeLeu-D-Ala-Ala-MeLeu-Val-MeLeu]": 1612, "[(CR1b2)p]": 1613, "[OH;]": 1614, "[NR]": 1615, "[OH,]": 1616, "[Boc]": 1617, "[CO2Me]": 1618, "[CR5]": 1619, "[Rs]": 1620, "[Cl-]": 1621, "[PPh2]": 1622, "=": 1623, "[Bx]": 1624, "[Bzl]": 1625, "[Z2,]": 1626, "[TBZ]": 1627, "[iPr]": 1628, "[Y.]": 1629, "[49*]": 1630, "[(C(R13)H)r]": 1631, "[.]": 1632, "[a]": 1633, "[18F]": 1634, "[Synthesis]": 1635, "[OH]": 1636, "[COOR5]": 1637, "[)c]": 1638, "[R53]": 1639, "[Yb]": 1640, "[(R3)b]": 1641, "[Cx]": 1642, "[(]": 1643, "[4*:0]": 1644, "[71*]": 1645, "[(R)]": 1646, "[XIX]": 1647, "[L1]": 1648, "[12*]": 1649, "[EWG]": 1650, "[MeOH]": 1651, "[A32]": 1652, "[Dq]": 1653, "[m]": 1654, "%28": 1655, "#": 1656, "[(R3)s]": 1657, "[NH;]": 1658, "[v]": 1659, "[o]": 1660, "[92*]": 1661, "[HX]": 1662, "[BocNH]": 1663, "[(CH2)n-1]": 1664, "[R42]": 1665, "[(L)k]": 1666, "[[R2]": 1667, "[J3]": 1668, "[Alk1]": 1669, "[L22]": 1670, "[(CHR2)n]": 1671, "[C18H37]": 1672, "[x]": 1673, "[SO2R4]": 1674, "[64*]": 1675, "[(T)t]": 1676, "[MeOOC]": 1677, "[(CH)n]": 1678, "[C4]": 1679, "[Y1a]": 1680, "[OCH2Ph]": 1681, "[CO2R14]": 1682, "[(R1)n]": 1683, "[CO]": 1684, "[R46]": 1685, "[(R)2]": 1686, "[Me]": 1687, "[N,]": 1688, "[n-C5H11]": 1689, "[Zr]": 1690, "[Et2N]": 1691, "[SO2NMe2,]": 1692, "[R-link-P]": 1693, "[Y1]": 1694, "[RB3]": 1695, "[Het]": 1696, "[CH*]": 1697, "[Re]": 1698, "[3]": 1699, "[R62]": 1700, "[3R]": 1701, "[ZHN]": 1702, "[Sf]": 1703, "[E6]": 1704, "[A,]": 1705, "[CBz]": 1706, "[Xb]": 1707, "[CH2C6H5]": 1708, "[XVIII]": 1709, "[[CH]": 1710, "[STol]": 1711, "[(R8)m]": 1712, "[C8F17]": 1713, "[(CHR5)m]": 1714, "[73*]": 1715, "[EQ]": 1716, "[R66]": 1717, "[Abu-Sar]": 1718, "[(O)k]": 1719, "[(X]": 1720, "[O.]": 1721, "[Y6]": 1722, "[R26]": 1723, "[Z]": 1724, "[COOt-Bu]": 1725, "[(CR23]": 1726, "[(R7)m]": 1727, "s": 1728, "[2+]": 1729, "[SnBu3]": 1730, "[Rb5]": 1731, "[(R19)p]": 1732, "[Na+]": 1733, "[CCH2)y]": 1734, "[Br]": 1735, "[n-C8H17]": 1736, "[Protecting]": 1737, "[V-]": 1738, "[CHR2]": 1739, "[(CH2)n1]": 1740, "[Fc]": 1741, "[(R1)m]": 1742, "[NCF3]": 1743, "[Z11]": 1744, "[(C(R14)R20)n]": 1745, "[AA1]": 1746, "[C2F4]": 1747, "[R4]": 1748, "[Cycl2]": 1749, "[R1b]": 1750, "[F8]": 1751, "[(R1)r]": 1752, "[XR1]": 1753, "[J2]": 1754, "[OCH3.]": 1755, "[Y2b]": 1756, "[Peptide]": 1757, "[F3CN]": 1758, "[L15]": 1759, "[H2]": 1760, "[R10,]": 1761, "[(R)x]": 1762, "[R71]": 1763, "[A;]": 1764, "[Alk]": 1765, "[C+]": 1766, "[Z10]": 1767, "P": 1768, "[L3]": 1769, "%23": 1770, "[OTf]": 1771, "[H2/Pd]": 1772, "[GP]": 1773, "[tBuO2C]": 1774, "[t]": 1775, "[(Rv)r,]": 1776, "[(Q)d]": 1777, "[L5]": 1778, "[4GlcNAcb1]": 1779, "[polymer]": 1780, "[A23]": 1781, "[XI]": 1782, "[3*:0]": 1783, "[F(Cl,]": 1784, "[C6]": 1785, "[2HCl]": 1786, "[1.]": 1787, "[CO2Et]": 1788, "[G11]": 1789, "[G6]": 1790, "[(B)n]": 1791, "[Yn]": 1792, "[B.]": 1793, "[U2]": 1794, "[PEG]": 1795, "[OX]": 1796, "[CO2CH3.]": 1797, "[CH3-]": 1798, "[Qb]": 1799, "[Rc5]": 1800, "[(C)n]": 1801, "[BocHN]": 1802, "[XO]": 1803, "[(R5)t]": 1804, "[(L1)m]": 1805, "[B2]": 1806, "5": 1807, "[G5]": 1808, "[Ey]": 1809, "[C(H)p]": 1810, "[IH-3]": 1811, "[*CH]": 1812, "[1H]": 1813, "[MO]": 1814, "[X12]": 1815, "[TBS]": 1816, "[Rc3]": 1817, "[7*]": 1818, "[B1]": 1819, "[Y+]": 1820, "[Resin]": 1821, "[COX]": 1822, "[NHR7]": 1823, "[CBZ]": 1824, "[Rf1]": 1825, "[XIII]": 1826, "[(CH2)nRf]": 1827, "[Q5]": 1828, "[Pg]": 1829, "[Rw]": 1830, "[RX]": 1831, "[IX]": 1832, "[AR1]": 1833, "[25*]": 1834, "[SCHEME]": 1835, "[I-3]": 1836, "[F2]": 1837, "[L2*]": 1838, "[COOZ2]": 1839, "[CH2N2]": 1840, "[Mn]": 1841, "[Za]": 1842, "[CX3]": 1843, "[Cl+3]": 1844, "[iBu]": 1845, "[NR8R9]": 1846, "[Fm]": 1847, "[Cu+2]": 1848, "[NHR;]": 1849, "[C15H31-n]": 1850, "[R(1)]": 1851, "[OTr]": 1852, "[PG3]": 1853, "[C6H5]": 1854, "[CH2-]": 1855, "[alkyl,]": 1856, "[(R5]": 1857, "-": 1858, "[A10]": 1859, "[Step]": 1860, "[Q1]": 1861, "[A13]": 1862, "[(R2]": 1863, "[57*]": 1864, "[EE]": 1865, "[PhSO2]": 1866, "[AlH2]": 1867, "[A7]": 1868, "[(O)m]": 1869, "[COOH]": 1870, "[APG]": 1871, "[OR1]": 1872, "[A15]": 1873, "[Ni]": 1874, "[Z31]": 1875, "[(Ib)]": 1876, "[Yq]": 1877, "[Al+3]": 1878, "[As]": 1879, "[(Ia)]": 1880, "[OPh]": 1881, "[Rf2]": 1882, "[Br;]": 1883, "[W2]": 1884, "[Rf]": 1885, "[t-butyl]": 1886, "[R1,]": 1887, "[CaH2]": 1888, "[V2]": 1889, "[C8H17(t)]": 1890, "[CR5R6]": 1891, "[D2]": 1892, "[(O)q]": 1893, "[SH+]": 1894, "[(R0)n]": 1895, "[Cp]": 1896, "[R2,]": 1897, "[(A)n]": 1898, "[Tb]": 1899, "[A12]": 1900, "[nPr]": 1901, "[FULL]": 1902, "[R1a]": 1903, "[RbH]": 1904, "[X1a]": 1905, "[SO2NHtBu]": 1906, "~": 1907, "[Br.]": 1908, "[R(3)]": 1909, "[L0]": 1910, "[(IIa)]": 1911, "[6]": 1912, "[(CRR)m]": 1913, "[Mana1]": 1914, "[ORA]": 1915, "[Ra4]": 1916, "[24]": 1917, "[BH2-]": 1918, "[and]": 1919, "[(CH2)b]": 1920, "[Ra2]": 1921, "[I]": 1922, "[halogen]": 1923, "[R36]": 1924, "[R19]": 1925, "[r]": 1926, "[[O]": 1927, "[OR2(2-a)]": 1928, "[R3]": 1929, "[13*]": 1930, "[M1]": 1931, "[Ry1]": 1932, "[ED]": 1933, "[40*]": 1934, "[RO2C]": 1935, "[OAlk]": 1936, "[MeO]": 1937, "[18*]": 1938, "%22": 1939, "[Fe+2]": 1940, "[Nb]": 1941, "[Protein]": 1942, "[LINKER]": 1943, "[CHR4]": 1944, "[Z2]": 1945, "[Ph]": 1946, "[n(H2C)]": 1947, "[PivO]": 1948, "[N]": 1949, "[GO]": 1950, "[(3)]": 1951, "[(CH2CH]": 1952, "[i-Pr]": 1953, "[L12]": 1954, "[2HH]": 1955, "[1R]": 1956, "[(CH2CH)z]": 1957, "7": 1958, "[F;]": 1959, "[9]": 1960, "[BzO]": 1961, "[Scheme]": 1962, "[L21]": 1963, "[13]": 1964, "[(CH2)d]": 1965, "[Rp]": 1966, "[Y;]": 1967, "[(R1)k]": 1968, "%10": 1969, "[CONH-n-C12H25]": 1970, "[X]": 1971, "[TrO]": 1972, "[(CR4R5)n]": 1973, "[4]": 1974, "[(R6)p]": 1975, "[NH2,]": 1976, "[COOH.]": 1977, "[c+]": 1978, "[OEE]": 1979, "[OR,]": 1980, "[Cu-]": 1981, "[(M)n]": 1982, "[17]": 1983, "[Si@@]": 1984, "[Pg1]": 1985, "[RE]": 1986, "[pg]": 1987, "[0.5]": 1988, "[R6a]": 1989, "[CO2R4]": 1990, "[OBz]": 1991, "[Ts]": 1992, "[50*]": 1993, "[T5]": 1994, "[(O)b]": 1995, "[Xm]": 1996, "[Ar2]": 1997, "[Ro]": 1998, "[S-2]": 1999, "[MSG2]": 2000, "[(CH2CH2O)n]": 2001, "[AA2]": 2002, "[A14]": 2003, "[(Al]": 2004, "[ETA]": 2005, "[(R]": 2006, "[(C(R12)H)q]": 2007, "[=]": 2008, "[NR4]": 2009, "[EtO2C]": 2010, "[CR4]": 2011, "[CH-]": 2012}
\ No newline at end of file
diff --git a/prompts/2_RxnOCR.txt b/prompts/2_RxnOCR.txt
new file mode 100644
index 0000000000000000000000000000000000000000..658863fe1885a51d44183914385683791a3636de
--- /dev/null
+++ b/prompts/2_RxnOCR.txt
@@ -0,0 +1,3 @@
+You are a helpful assistant in identifying chemistry data in an image. In this reaction image, there is a chemistry reaction diagram with one step or multiple step reactions. Your task is to review both the reactions, and output an array of entries in a json format, which consists of the properly-substituted reactions and all items of the entry present in the table. Your output should be a list of reactions. Each reaction entry should contain its reactants (in SMILES format and with label when the label is provided such as "1a","2a","3b" ...,, or else use "label":"None"), its conditions (Note that molecular smiles or text can both appear in the conditions. First recheck the image carefully and correct the OCR errors and missings of the tool for the text content, and then identify the text or smiles condition role in"reagent","solvent","yield","time(such as "1 h", "24 h")","temperature (Note "rt" is temperature too)",if there is no then use "None"), its products (in SMILES format and with label when the label is provided such as "1a","2a","3b" ..., or else use "label":"None"). Make sure that the SMILES strings are correctly formatted.
+here is an example output:
+{"reactions":[{"reaction_id":"1","reactants":[{"smiles":"Oc1ccc2cccc(C(O)c3ccccc3)c2c1","label":"1a"},{"smiles":"c1ccc2c(c1)[nH]c1ccccc12","label":"None"},...],"conditions":[{"role":"reagent","smiles":"*c1cc2ccccc2c2c1OP(=O)(O)Oc1c(*)cc3ccccc3c1-2"}, {"role":"reagent","text":"DCM"},{"role": "solvent","text": "toluene"},,...],"products":[{"smiles":"Oc1ccc2cccc([C@H](c3ccccc3)n3c4ccccc4c4ccccc43)c2c1","label":"None"}]}]}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a8f5bce38aafdca79deb76c8646feedc80f6e0ee
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+numpy==1.26.4
+torch>=1.10.0,<2.0
+transformers>=4.6.0
+layoutparser[effdet]
+opencv-python==4.5.5.64
+opencv-python-headless==4.5.4.60
+Pillow==9.5.0
+ipython
+openai
+albumentations==1.1.0
+matplotlib>=3.5.3
+SmilesPE==0.0.3
+OpenNMT-py==2.2.0
+rdkit-pypi>=2021.03.2
+timm==0.4.12
+pandas>=1.2.4
+pycocotools>=2.0.4
+pytorch-lightning>=1.8.6
+huggingface-hub>=0.11.0
+easyocr>=1.6.2
diff --git a/rxn/main.py b/rxn/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..126c526aff932f8ccaf7a2498625e52b8bb42870
--- /dev/null
+++ b/rxn/main.py
@@ -0,0 +1,367 @@
+import os
+import math
+import json
+import random
+import argparse
+import numpy as np
+
+import torch
+import torch.distributed as dist
+import pytorch_lightning as pl
+from pytorch_lightning import LightningModule, LightningDataModule
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.strategies.ddp import DDPStrategy
+from transformers import get_scheduler
+
+from reaction.model import Encoder, Decoder
+from reaction.pix2seq import build_pix2seq_model
+from reaction.loss import Criterion
+from reaction.tokenizer import get_tokenizer
+from reaction.dataset import ReactionDataset, get_collate_fn
+from reaction.data import postprocess_reactions
+from reaction.evaluate import CocoEvaluator, ReactionEvaluator
+import reaction.utils as utils
+
+
+def get_args(notebook=False):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--do_train', action='store_true')
+ parser.add_argument('--do_valid', action='store_true')
+ parser.add_argument('--do_test', action='store_true')
+ parser.add_argument('--fp16', action='store_true')
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpus', type=int, default=1)
+ parser.add_argument('--print_freq', type=int, default=200)
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--no_eval', action='store_true')
+ # Model
+ parser.add_argument('--encoder', type=str, default='resnet34')
+ parser.add_argument('--decoder', type=str, default='lstm')
+ parser.add_argument('--trunc_encoder', action='store_true') # use the hidden states before downsample
+ parser.add_argument('--no_pretrained', action='store_true')
+ parser.add_argument('--use_checkpoint', action='store_true')
+ parser.add_argument('--lstm_dropout', type=float, default=0.5)
+ parser.add_argument('--embed_dim', type=int, default=256)
+ parser.add_argument('--enc_pos_emb', action='store_true')
+ group = parser.add_argument_group("lstm_options")
+ group.add_argument('--decoder_dim', type=int, default=512)
+ group.add_argument('--decoder_layer', type=int, default=1)
+ group.add_argument('--attention_dim', type=int, default=256)
+ group = parser.add_argument_group("transformer_options")
+ group.add_argument("--dec_num_layers", help="No. of layers in transformer decoder", type=int, default=6)
+ group.add_argument("--dec_hidden_size", help="Decoder hidden size", type=int, default=256)
+ group.add_argument("--dec_attn_heads", help="Decoder no. of attention heads", type=int, default=8)
+ group.add_argument("--dec_num_queries", type=int, default=128)
+ group.add_argument("--hidden_dropout", help="Hidden dropout", type=float, default=0.1)
+ group.add_argument("--attn_dropout", help="Attention dropout", type=float, default=0.1)
+ group.add_argument("--max_relative_positions", help="Max relative positions", type=int, default=0)
+ # Pix2Seq
+ parser.add_argument('--pix2seq', action='store_true', help="specify the model from playground")
+ parser.add_argument('--pix2seq_ckpt', type=str, default=None)
+ parser.add_argument('--large_scale_jitter', action='store_true', help='large scale jitter')
+ parser.add_argument('--pred_eos', action='store_true', help='use eos token instead of predicting 100 objects')
+ # * Backbone
+ parser.add_argument('--backbone', default='resnet50', type=str, help="Name of the convolutional backbone to use")
+ parser.add_argument('--dilation', action='store_true',
+ help="If true, we replace stride with dilation in the last convolutional block (DC5)")
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
+ help="Type of positional embedding to use on top of the image features")
+ # * Transformer
+ parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer")
+ parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer")
+ parser.add_argument('--dim_feedforward', default=1024, type=int,
+ help="Intermediate size of the feedforward layers in the transformer blocks")
+ parser.add_argument('--hidden_dim', default=256, type=int,
+ help="Size of the embeddings (dimension of the transformer)")
+ parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")
+ parser.add_argument('--nheads', default=8, type=int,
+ help="Number of attention heads inside the transformer's attentions")
+ parser.add_argument('--pre_norm', action='store_true')
+ # Data
+ parser.add_argument('--data_path', type=str, default=None)
+ parser.add_argument('--image_path', type=str, default=None)
+ parser.add_argument('--train_file', type=str, default=None)
+ parser.add_argument('--valid_file', type=str, default=None)
+ parser.add_argument('--test_file', type=str, default=None)
+ parser.add_argument('--vocab_file', type=str, default=None)
+ parser.add_argument('--format', type=str, default='reaction')
+ parser.add_argument('--num_workers', type=int, default=8)
+ parser.add_argument('--input_size', type=int, default=224)
+ parser.add_argument('--augment', action='store_true')
+ parser.add_argument('--composite_augment', action='store_true')
+ parser.add_argument('--coord_bins', type=int, default=100)
+ parser.add_argument('--sep_xy', action='store_true')
+ parser.add_argument('--rand_order', action='store_true', help="randomly permute the sequence of input targets")
+ parser.add_argument('--add_noise', action='store_true')
+ parser.add_argument('--mix_noise', action='store_true')
+ parser.add_argument('--shuffle_bbox', action='store_true')
+ parser.add_argument('--images', type=str, default='')
+ # Training
+ parser.add_argument('--epochs', type=int, default=8)
+ parser.add_argument('--batch_size', type=int, default=256)
+ parser.add_argument('--lr', type=float, default=1e-4)
+ parser.add_argument('--weight_decay', type=float, default=0.05)
+ parser.add_argument('--max_grad_norm', type=float, default=5.)
+ parser.add_argument('--scheduler', type=str, choices=['cosine', 'constant'], default='cosine')
+ parser.add_argument('--warmup_ratio', type=float, default=0)
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
+ parser.add_argument('--load_path', type=str, default=None)
+ parser.add_argument('--load_encoder_only', action='store_true')
+ parser.add_argument('--train_steps_per_epoch', type=int, default=-1)
+ parser.add_argument('--eval_per_epoch', type=int, default=10)
+ parser.add_argument('--save_path', type=str, default='output/')
+ parser.add_argument('--save_mode', type=str, default='best', choices=['best', 'all', 'last'])
+ parser.add_argument('--load_ckpt', type=str, default='best')
+ parser.add_argument('--resume', action='store_true')
+ parser.add_argument('--num_train_example', type=int, default=None)
+ parser.add_argument('--label_smoothing', type=float, default=0.0)
+ parser.add_argument('--save_image', action='store_true')
+ # Inference
+ parser.add_argument('--beam_size', type=int, default=1)
+ parser.add_argument('--n_best', type=int, default=1)
+ parser.add_argument('--molscribe', action='store_true')
+ args = parser.parse_args([]) if notebook else parser.parse_args()
+
+ args.images = args.images.split(',')
+
+ return args
+
+
+class ReactionExtractor(LightningModule):
+
+ def __init__(self, args, tokenizer):
+ super().__init__()
+ self.args = args
+ self.tokenizer = tokenizer
+ self.encoder = Encoder(args, pretrained=(not args.no_pretrained))
+ args.encoder_dim = self.encoder.n_features
+ self.decoder = Decoder(args, tokenizer)
+ self.criterion = Criterion(args, tokenizer)
+
+ def training_step(self, batch, batch_idx):
+ indices, images, refs = batch
+ features, hiddens = self.encoder(images, refs)
+ results = self.decoder(features, hiddens, refs)
+ losses = self.criterion(results, refs)
+ loss = sum(losses.values())
+ self.log('train/loss', loss)
+ self.log('lr', self.lr_schedulers().get_lr()[0], prog_bar=True, logger=False)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ indices, images, refs = batch
+ features, hiddens = self.encoder(images, refs)
+ batch_preds, batch_beam_preds = self.decoder.decode(
+ features, hiddens, refs,
+ beam_size=self.args.beam_size, n_best=self.args.n_best)
+ return indices, batch_preds
+
+ def validation_epoch_end(self, outputs, phase='val'):
+ if self.trainer.num_devices > 1:
+ gathered_outputs = [None for i in range(self.trainer.num_devices)]
+ dist.all_gather_object(gathered_outputs, outputs)
+ gathered_outputs = sum(gathered_outputs, [])
+ else:
+ gathered_outputs = outputs
+
+ format = self.args.format
+ predictions = utils.merge_predictions(gathered_outputs)
+
+ name = self.eval_dataset.name
+ scores = [0]
+
+ if self.trainer.is_global_zero:
+ if not self.args.no_eval:
+ if format == 'bbox':
+ coco_evaluator = CocoEvaluator(self.eval_dataset.coco)
+ stats = coco_evaluator.evaluate(predictions['bbox'])
+ scores = results = list(stats)
+ elif format == 'reaction':
+ epoch = self.trainer.current_epoch
+ evaluator = ReactionEvaluator()
+ results, *_ = evaluator.evaluate_summarize(self.eval_dataset.data, predictions['reaction'])
+ precision, recall, f1 = \
+ results['overall']['precision'], results['overall']['recall'], results['overall']['f1']
+ scores = [f1]
+ self.print(f'Epoch: {epoch:>3} Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}')
+ results['mol_only'], *_ = evaluator.evaluate_summarize(
+ self.eval_dataset.data, predictions['reaction'], mol_only=True, merge_condition=True)
+ else:
+ raise NotImplementedError
+ with open(os.path.join(self.trainer.default_root_dir, f'eval_{name}.json'), 'w') as f:
+ json.dump(results, f)
+ if phase == 'test':
+ self.print(json.dumps(results, indent=4))
+ with open(os.path.join(self.trainer.default_root_dir, f'prediction_{name}.json'), 'w') as f:
+ json.dump(predictions, f)
+
+ dist.broadcast_object_list(scores)
+ self.log(f'{phase}/score', scores[0], prog_bar=True, rank_zero_only=True)
+
+ def test_step(self, batch, batch_idx):
+ return self.validation_step(batch, batch_idx)
+
+ def test_epoch_end(self, outputs):
+ return self.validation_epoch_end(outputs, phase='test')
+
+ def predict_step(self, batch, batch_idx):
+ return self.validation_step(batch, batch_idx)
+
+ def configure_optimizers(self):
+ num_training_steps = self.trainer.num_training_steps
+ self.print(f'Num training steps: {num_training_steps}')
+ num_warmup_steps = int(num_training_steps * self.args.warmup_ratio)
+ # parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
+ scheduler = get_scheduler(self.args.scheduler, optimizer, num_warmup_steps, num_training_steps)
+ return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
+
+
+class ReactionExtractorPix2Seq(ReactionExtractor):
+
+ def __init__(self, args, tokenizer):
+ super(ReactionExtractor, self).__init__()
+ self.args = args
+ self.tokenizer = tokenizer
+ self.format = args.format
+ self.model = build_pix2seq_model(args, tokenizer[self.format])
+ self.criterion = Criterion(args, tokenizer)
+ self.molscribe = None
+
+ def training_step(self, batch, batch_idx):
+ indices, images, refs = batch
+ format = self.format
+ results = {format: (self.model(images, refs[format]), refs[format+'_out'][0][:, 1:])}
+ losses = self.criterion(results, refs)
+ loss = sum(losses.values())
+ self.log('train/loss', loss)
+ self.log('lr', self.lr_schedulers().get_lr()[0], prog_bar=True, logger=False)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ indices, images, refs = batch
+ format = self.format
+ batch_preds = {format: [], 'file_name': []}
+ pred_seqs, pred_scores = self.model(images, max_len=self.tokenizer[format].max_len)
+ for i, (seqs, scores) in enumerate(zip(pred_seqs, pred_scores)):
+ if format == 'reaction':
+ reactions = self.tokenizer[format].sequence_to_data(seqs.tolist(), scores.tolist(), scale=refs['scale'][i])
+ reactions = postprocess_reactions(reactions)
+ batch_preds[format].append(reactions)
+ if format == 'bbox':
+ bboxes = self.tokenizer[format].sequence_to_data(seqs.tolist(), scores.tolist(), scale=refs['scale'][i])
+ batch_preds[format].append(bboxes)
+ batch_preds['file_name'].append(refs['file_name'][i])
+ return indices, batch_preds
+
+
+class ReactionDataModule(LightningDataModule):
+
+ def __init__(self, args, tokenizer):
+ super().__init__()
+ self.args = args
+ self.tokenizer = tokenizer
+ self.collate_fn = get_collate_fn(self.pad_id)
+
+ @property
+ def pad_id(self):
+ return self.tokenizer[self.args.format].PAD_ID
+
+ def prepare_data(self):
+ args = self.args
+ if args.do_train:
+ self.train_dataset = ReactionDataset(args, self.tokenizer, args.train_file, split='train')
+ if self.args.do_train or self.args.do_valid:
+ self.val_dataset = ReactionDataset(args, self.tokenizer, args.valid_file, split='valid')
+ if self.args.do_test:
+ self.test_dataset = ReactionDataset(args, self.tokenizer, args.test_file, split='test')
+
+ def print_stats(self):
+ if self.args.do_train:
+ print(f'Train dataset: {len(self.train_dataset)}')
+ if self.args.do_train or self.args.do_valid:
+ print(f'Valid dataset: {len(self.val_dataset)}')
+ if self.args.do_test:
+ print(f'Test dataset: {len(self.test_dataset)}')
+
+ def train_dataloader(self):
+ return torch.utils.data.DataLoader(
+ self.train_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
+ collate_fn=self.collate_fn)
+
+ def val_dataloader(self):
+ return torch.utils.data.DataLoader(
+ self.val_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
+ collate_fn=self.collate_fn)
+
+ def test_dataloader(self):
+ return torch.utils.data.DataLoader(
+ self.test_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
+ collate_fn=self.collate_fn)
+
+
+class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
+ def _get_metric_interpolated_filepath_name(self, monitor_candidates, trainer, del_filepath=None) -> str:
+ filepath = self.format_checkpoint_name(monitor_candidates)
+ return filepath
+
+
+def main():
+
+ args = get_args()
+ pl.seed_everything(args.seed, workers=True)
+
+ if args.debug:
+ args.save_path = "output/debug"
+
+ tokenizer = get_tokenizer(args)
+
+ MODEL = ReactionExtractorPix2Seq if args.pix2seq else ReactionExtractor
+ if args.do_train:
+ model = MODEL(args, tokenizer)
+ else:
+ model = MODEL.load_from_checkpoint(os.path.join(args.save_path, 'checkpoints/best.ckpt'), strict=False,
+ args=args, tokenizer=tokenizer)
+
+ dm = ReactionDataModule(args, tokenizer)
+ dm.prepare_data()
+ dm.print_stats()
+
+ checkpoint = ModelCheckpoint(monitor='val/score', mode='max', save_top_k=1, filename='best', save_last=True)
+ # checkpoint = ModelCheckpoint(monitor=None, save_top_k=0, save_last=True)
+ lr_monitor = LearningRateMonitor(logging_interval='step')
+ logger = pl.loggers.TensorBoardLogger(args.save_path, name='', version='')
+
+ trainer = pl.Trainer(
+ strategy=DDPStrategy(find_unused_parameters=False),
+ accelerator='gpu',
+ devices=4,
+ logger=logger,
+ default_root_dir=args.save_path,
+ callbacks=[checkpoint, lr_monitor],
+ max_epochs=args.epochs,
+ gradient_clip_val=args.max_grad_norm,
+ accumulate_grad_batches=args.gradient_accumulation_steps,
+ check_val_every_n_epoch=args.eval_per_epoch,
+ log_every_n_steps=10,
+ deterministic=True)
+
+ if args.do_train:
+ trainer.num_training_steps = math.ceil(
+ len(dm.train_dataset) / (args.batch_size * args.gpus * args.gradient_accumulation_steps)) * args.epochs
+ model.eval_dataset = dm.val_dataset
+ ckpt_path = os.path.join(args.save_path, 'checkpoints/last.ckpt') if args.resume else None
+ trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)
+ model = MODEL.load_from_checkpoint(checkpoint.best_model_path, args=args, tokenizer=tokenizer)
+
+ if args.do_valid:
+ model.eval_dataset = dm.val_dataset
+ trainer.validate(model, datamodule=dm)
+
+ if args.do_test:
+ model.eval_dataset = dm.test_dataset
+ trainer.test(model, datamodule=dm)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/rxn/model/model.ckpt b/rxn/model/model.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..059bb7122c96e5dc0209614bfdd7a412da59285b
--- /dev/null
+++ b/rxn/model/model.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0020634f13fb3e1f588bddca97f68fd6483f0cdecd83e2ca31c2434ea4340fe
+size 432324497
diff --git a/rxn/reaction/__init__.py b/rxn/reaction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c201611dbec7f2eb80941ee7aeb9f79bda996d0c
--- /dev/null
+++ b/rxn/reaction/__init__.py
@@ -0,0 +1 @@
+from .interface import Reaction
diff --git a/rxn/reaction/__pycache__/__init__.cpython-310.pyc b/rxn/reaction/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9ce740377d14874a994bbb4b6354ab1a3a9672c
Binary files /dev/null and b/rxn/reaction/__pycache__/__init__.cpython-310.pyc differ
diff --git a/rxn/reaction/__pycache__/__init__.cpython-312.pyc b/rxn/reaction/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce3d55d68647daf137fc61387fa1759746348ee5
Binary files /dev/null and b/rxn/reaction/__pycache__/__init__.cpython-312.pyc differ
diff --git a/rxn/reaction/__pycache__/__init__.cpython-38.pyc b/rxn/reaction/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59a801c3eaf8b4ca92a8240e5a59e391b2c110dd
Binary files /dev/null and b/rxn/reaction/__pycache__/__init__.cpython-38.pyc differ
diff --git a/rxn/reaction/__pycache__/data.cpython-310.pyc b/rxn/reaction/__pycache__/data.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6bb3019a79b8baf407b7509037f62eafd38f78a0
Binary files /dev/null and b/rxn/reaction/__pycache__/data.cpython-310.pyc differ
diff --git a/rxn/reaction/__pycache__/data.cpython-38.pyc b/rxn/reaction/__pycache__/data.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5170c1ee838e3fb411ded1d2022bdab9e709020
Binary files /dev/null and b/rxn/reaction/__pycache__/data.cpython-38.pyc differ
diff --git a/rxn/reaction/__pycache__/dataset.cpython-310.pyc b/rxn/reaction/__pycache__/dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f4f7681208ba8555dc443c232966b7b83582e27
Binary files /dev/null and b/rxn/reaction/__pycache__/dataset.cpython-310.pyc differ
diff --git a/rxn/reaction/__pycache__/dataset.cpython-38.pyc b/rxn/reaction/__pycache__/dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07ab969faedfd01c11366a37cf9f9d5e01784730
Binary files /dev/null and b/rxn/reaction/__pycache__/dataset.cpython-38.pyc differ
diff --git a/rxn/reaction/__pycache__/interface.cpython-310.pyc b/rxn/reaction/__pycache__/interface.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f858e0a84e96a432ffee61f142447484df79bcaf
Binary files /dev/null and b/rxn/reaction/__pycache__/interface.cpython-310.pyc differ
diff --git a/rxn/reaction/__pycache__/interface.cpython-312.pyc b/rxn/reaction/__pycache__/interface.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f605c50693c4f03f4a0b68d28474f5c2b841b12b
Binary files /dev/null and b/rxn/reaction/__pycache__/interface.cpython-312.pyc differ
diff --git a/rxn/reaction/__pycache__/interface.cpython-38.pyc b/rxn/reaction/__pycache__/interface.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ef614e5df15dc06a2da43616e8c43154a202ac0
Binary files /dev/null and b/rxn/reaction/__pycache__/interface.cpython-38.pyc differ
diff --git a/rxn/reaction/__pycache__/tokenizer.cpython-310.pyc b/rxn/reaction/__pycache__/tokenizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12a48181adadfd2192e812fe9b7651d0751e0fb6
Binary files /dev/null and b/rxn/reaction/__pycache__/tokenizer.cpython-310.pyc differ
diff --git a/rxn/reaction/__pycache__/tokenizer.cpython-38.pyc b/rxn/reaction/__pycache__/tokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7013b5ae91606712e200b3e3484cbdb46e35515
Binary files /dev/null and b/rxn/reaction/__pycache__/tokenizer.cpython-38.pyc differ
diff --git a/rxn/reaction/__pycache__/transforms.cpython-310.pyc b/rxn/reaction/__pycache__/transforms.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50cc47bf280a8551a904d3eb8a1c4617b01e4128
Binary files /dev/null and b/rxn/reaction/__pycache__/transforms.cpython-310.pyc differ
diff --git a/rxn/reaction/__pycache__/transforms.cpython-38.pyc b/rxn/reaction/__pycache__/transforms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7e22ed3d22345fc8f609a5246318ce61244453e
Binary files /dev/null and b/rxn/reaction/__pycache__/transforms.cpython-38.pyc differ
diff --git a/rxn/reaction/data.py b/rxn/reaction/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..a51b34351df5305eccb9ad1e2b7d08f1ff4e31b6
--- /dev/null
+++ b/rxn/reaction/data.py
@@ -0,0 +1,367 @@
+import os
+import cv2
+import numpy as np
+import matplotlib.colors as colors
+import matplotlib.patches as patches
+
+
+class BBox(object):
+
+ def __init__(self, bbox, image_data=None, xyxy=False, normalized=False):
+ """
+ :param bbox: {'catrgory_id', 'bbox'}
+ :param input_image: ImageData object
+ :param xyxy:
+ :param normalized:
+ """
+ self.data = bbox
+ self.image_data = image_data
+ if image_data is not None:
+ self.width = image_data.width
+ self.height = image_data.height
+ self.category_id = bbox['category_id']
+ if xyxy:
+ x1, y1, x2, y2 = bbox['bbox']
+ else:
+ x1, y1, w, h = bbox['bbox']
+ x2, y2 = x1 + w, y1 + h
+ if not normalized:
+ x1, y1, x2, y2 = x1 / self.width, y1 / self.height, x2 / self.width, y2 / self.height
+ self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2
+
+ @property
+ def is_mol(self):
+ return self.category_id == 1
+
+ @property
+ def is_empty(self):
+ return abs(self.x2 - self.x1) <= 0.01 or abs(self.y2 - self.y1) <= 0.01
+
+ def unnormalize(self):
+ return self.x1 * self.width, self.y1 * self.height, self.x2 * self.width, self.y2 * self.height
+
+ def image(self):
+ x1, y1, x2, y2 = self.unnormalize()
+ x1, y1, x2, y2 = max(int(x1), 0), max(int(y1), 0), min(int(x2), self.width), min(int(y2), self.height)
+ return self.image_data.image[y1:y2, x1:x2]
+
+ COLOR = {1: 'purple', 2: 'orange', 3: 'cyan', 4: 'magenta'}
+ CATEGORY = {1: 'Str', 2: 'Txt', 3: 'Txt', 4: 'Sup'}
+
+
+ def draw(self, ax, color=None):
+ x1, y1, x2, y2 = self.unnormalize()
+ if color is None:
+ color = self.COLOR[self.category_id]
+ rect = patches.Rectangle(
+ (x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor=color, facecolor=colors.to_rgba(color, 0.2))
+ text = f'{self.CATEGORY[self.category_id]}'
+ ax.text(x2, y2, text, fontsize=20, bbox=dict(linewidth=0.5, facecolor='Lightgreen', alpha=0.5))
+ ax.add_patch(rect)
+ return
+
+ def set_smiles(self, smiles, coords, symbols, edges, molfile=None):
+ rounded_coords = [[round(coord[0], 3), round(coord[1], 3)] for coord in coords] #pred['oringinal_coords'],pred['original_symbols'],pred['orignal_edges']
+ self.data['smiles'] = smiles
+ self.data['coords'] = rounded_coords
+ self.data['symbols'] = symbols
+ self.data['edges'] = edges
+ if molfile:
+ self.data['molfile'] = molfile
+
+ def set_text(self, text):
+ self.data['text'] = text
+
+ def to_json(self):
+ return self.data
+
+
+class Reaction(object):
+
+ def __init__(self, reaction=None, bboxes=None, image_data=None):
+ '''
+ if image_data is None, create from prediction
+ if image_data is not None, create from groundtruth
+ '''
+ self.reactants = []
+ self.conditions = []
+ self.products = []
+ self.bboxes = []
+ if reaction is not None:
+ for x in reaction['reactants']:
+ bbox = bboxes[x] if type(x) is int else BBox(x, image_data, xyxy=True, normalized=True)
+ self.bboxes.append(bbox)
+ self.reactants.append(len(self.bboxes) - 1)
+ for x in reaction['conditions']:
+ bbox = bboxes[x] if type(x) is int else BBox(x, image_data, xyxy=True, normalized=True)
+ self.bboxes.append(bbox)
+ self.conditions.append(len(self.bboxes) - 1)
+ for x in reaction['products']:
+ bbox = bboxes[x] if type(x) is int else BBox(x, image_data, xyxy=True, normalized=True)
+ self.bboxes.append(bbox)
+ self.products.append(len(self.bboxes) - 1)
+
+ def to_json(self):
+ return {
+ 'reactants': [self.bboxes[i].to_json() for i in self.reactants],
+ 'conditions': [self.bboxes[i].to_json() for i in self.conditions],
+ 'products': [self.bboxes[i].to_json() for i in self.products]
+ }
+
+ def _deduplicate_bboxes(self, indices):
+ results = []
+ for i, idx_i in enumerate(indices):
+ duplicate = False
+ for j, idx_j in enumerate(indices[:i]):
+ if get_iou(self.bboxes[idx_i], self.bboxes[idx_j]) > 0.6:
+ duplicate = True
+ break
+ if not duplicate:
+ results.append(idx_i)
+ return results
+
+ def deduplicate(self):
+ flags = [False] * len(self.bboxes)
+ bbox_list = self.reactants + self.products + self.conditions
+ for i, idx_i in enumerate(bbox_list):
+ if self.bboxes[idx_i].is_empty:
+ flags[idx_i] = True
+ continue
+ for idx_j in bbox_list[:i]:
+ if flags[idx_j] is False and get_iou(self.bboxes[idx_i], self.bboxes[idx_j]) > 0.6:
+ flags[idx_i] = True
+ break
+ self.reactants = [i for i in self.reactants if not flags[i]]
+ self.conditions = [i for i in self.conditions if not flags[i]]
+ self.products = [i for i in self.products if not flags[i]]
+
+ def schema(self, mol_only=False):
+ # Return reactants, conditions, and products. If mol_only is True, only include bboxes that are mol structures.
+ if mol_only:
+ reactants, conditions, products = [[idx for idx in indices if self.bboxes[idx].is_mol]
+ for indices in [self.reactants, self.conditions, self.products]]
+ # It would be unfair to compare two reactions if their reactants or products are empty after filtering.
+ # Setting them to the original ones in this case.
+ if len(reactants) == 0:
+ reactants = self.reactants
+ if len(products) == 0:
+ products = self.products
+ return reactants, conditions, products
+ else:
+ return self.reactants, self.conditions, self.products
+
+ def compare(self, other, mol_only=False, merge_condition=False, debug=False):
+ reactants1, conditions1, products1 = self.schema(mol_only)
+ reactants2, conditions2, products2 = other.schema(mol_only)
+ if debug:
+ print(reactants1, conditions1, products1, ';', reactants2, conditions2, products2)
+ if len(reactants1) + len(conditions1) + len(products1) == 0:
+ # schema is empty, always return False
+ return False
+ if len(reactants1) + len(conditions1) + len(products1) != len(reactants2) + len(conditions2) + len(products2):
+ return False
+ # Match use original index
+ match1, match2, scores = get_bboxes_match(self.bboxes, other.bboxes, iou_thres=0.5)
+ m_reactants, m_conditions, m_products = [[match1[i] for i in x] for x in [reactants1, conditions1, products1]]
+ if any([m == -1 for m in m_reactants + m_conditions + m_products]):
+ return False
+ if debug:
+ print(m_reactants, m_conditions, m_products, ';', reactants2, conditions2, products2)
+ if merge_condition:
+ return sorted(m_reactants + m_conditions) == sorted(reactants2 + conditions2) \
+ and sorted(m_products) == sorted(products2)
+ else:
+ return sorted(m_reactants) == sorted(reactants2) and sorted(m_conditions) == sorted(conditions2) \
+ and sorted(m_products) == sorted(products2)
+
+ def __eq__(self, other):
+ # Exact matching of two reactions
+ return self.compare(other)
+
+ def draw(self, ax):
+ for i in self.reactants:
+ self.bboxes[i].draw(ax, color='cyan')
+ for i in self.conditions:
+ self.bboxes[i].draw(ax, color='red')
+ for i in self.products:
+ self.bboxes[i].draw(ax, color='orange')
+ return
+
+
+class ReactionSet(object):
+
+ def __init__(self, reactions, bboxes=None, image_data=None):
+ self.reactions = [Reaction(reaction, bboxes, image_data) for reaction in reactions]
+
+ def __len__(self):
+ return len(self.reactions)
+
+ def __iter__(self):
+ return iter(self.reactions)
+
+ def __getitem__(self, item):
+ return self.reactions[item]
+
+ def deduplicate(self):
+ results = []
+ for reaction in self.reactions:
+ if any(r == reaction for r in results):
+ continue
+ if len(reaction.reactants) < 1 or len(reaction.products) < 1:
+ continue
+ results.append(reaction)
+ self.reactions = results
+
+ def to_json(self):
+ return [r.to_json() for r in self.reactions]
+
+
+class ImageData(object):
+
+ def __init__(self, data=None, predictions=None, image_file=None, image=None):
+ self.width, self.height = None, None
+ if data:
+ self.file_name = data['file_name']
+ self.width = data['width']
+ self.height = data['height']
+ if image_file:
+ self.image = cv2.imread(image_file)
+ self.height, self.width, _ = self.image.shape
+ if image is not None:
+ if not isinstance(image, np.ndarray):
+ image = np.asarray(image)
+ self.image = image
+ self.height, self.width, _ = self.image.shape
+ if data and 'bboxes' in data:
+ self.gold_bboxes = [BBox(bbox, self, xyxy=False, normalized=False) for bbox in data['bboxes']]
+ if predictions is not None:
+ self.pred_bboxes = [BBox(bbox, self, xyxy=True, normalized=True) for bbox in predictions]
+
+ def draw_gold(self, ax, image=None):
+ if image is not None:
+ ax.imshow(image)
+ for b in self.gold_bboxes:
+ b.draw(ax)
+
+ def draw_prediction(self, ax, image=None):
+ if image is not None:
+ ax.imshow(image)
+ for b in self.pred_bboxes:
+ b.draw(ax)
+
+
+class ReactionImageData(ImageData):
+
+ def __init__(self, data=None, predictions=None, image_file=None, image=None):
+ super().__init__(data=data, image_file=image_file, image=image)
+ if data and 'reactions' in data:
+ self.gold_reactions = ReactionSet(data['reactions'], self.gold_bboxes, image_data=self)
+ if predictions is not None:
+ self.pred_reactions = ReactionSet(predictions, image_data=self)
+ self.pred_reactions.deduplicate()
+
+ def evaluate(self, mol_only=False, merge_condition=False, debug=False):
+ gold_total = len(self.gold_reactions)
+ gold_hit = [False] * gold_total
+ pred_total = len(self.pred_reactions)
+ pred_hit = [False] * pred_total
+ for i, ri in enumerate(self.gold_reactions):
+ for j, rj in enumerate(self.pred_reactions):
+ if gold_hit[i] and pred_hit[j]:
+ continue
+ if ri.compare(rj, mol_only, merge_condition, debug):
+ gold_hit[i] = True
+ pred_hit[j] = True
+ return gold_hit, pred_hit
+
+
+def get_iou(bb1, bb2):
+ """Calculate the Intersection over Union (IoU) of two bounding boxes."""
+ bb1 = {'x1': bb1.x1, 'y1': bb1.y1, 'x2': bb1.x2, 'y2': bb1.y2}
+ bb2 = {'x1': bb2.x1, 'y1': bb2.y1, 'x2': bb2.x2, 'y2': bb2.y2}
+
+ assert bb1['x1'] < bb1['x2']
+ assert bb1['y1'] < bb1['y2']
+ assert bb2['x1'] < bb2['x2']
+ assert bb2['y1'] < bb2['y2']
+
+ # determine the coordinates of the intersection rectangle
+ x_left = max(bb1['x1'], bb2['x1'])
+ y_top = max(bb1['y1'], bb2['y1'])
+ x_right = min(bb1['x2'], bb2['x2'])
+ y_bottom = min(bb1['y2'], bb2['y2'])
+
+ if x_right < x_left or y_bottom < y_top:
+ return 0.0
+
+ # The intersection of two axis-aligned bounding boxes is always an
+ # axis-aligned bounding box
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
+
+ # compute the area of both AABBs
+ bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1'])
+ bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1'])
+
+ # compute the intersection over union by taking the intersection
+ # area and dividing it by the sum of prediction + ground-truth
+ # areas - the interesection area
+ iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
+ assert iou >= 0.0
+ assert iou <= 1.0
+ return iou
+
+
+def get_bboxes_match(bboxes1, bboxes2, iou_thres=0.5, match_category=False):
+ """Find the match between two sets of bboxes. Each bbox is matched with a bbox with maximum overlap
+ (at least above iou_thres). -1 if a bbox does not have a match."""
+ scores = np.zeros((len(bboxes1), len(bboxes2)))
+ for i, bbox1 in enumerate(bboxes1):
+ for j, bbox2 in enumerate(bboxes2):
+ if match_category and bbox1.category_id != bbox2.category_id:
+ scores[i, j] = 0
+ else:
+ scores[i, j] = get_iou(bbox1, bbox2)
+ match1 = scores.argmax(axis=1)
+ for i in range(len(match1)):
+ if scores[i, match1[i]] < iou_thres:
+ match1[i] = -1
+ match2 = scores.argmax(axis=0)
+ for j in range(len(match2)):
+ if scores[match2[j], j] < iou_thres:
+ match2[j] = -1
+ return match1, match2, scores
+
+
+def deduplicate_reactions(reactions):
+ pred_reactions = ReactionSet(reactions)
+ for r in pred_reactions:
+ r.deduplicate()
+ pred_reactions.deduplicate()
+ return pred_reactions.to_json()
+
+
+def postprocess_reactions(reactions, image_file=None, image=None, molscribe=None, ocr=None, batch_size=32):
+ image_data = ReactionImageData(predictions=reactions, image_file=image_file, image=image)
+ pred_reactions = image_data.pred_reactions
+ for r in pred_reactions:
+ r.deduplicate()
+ pred_reactions.deduplicate()
+ if molscribe:
+ bbox_images, bbox_indices = [], []
+ for i, reaction in enumerate(pred_reactions):
+ for j, bbox in enumerate(reaction.bboxes):
+ if bbox.is_mol:
+ bbox_images.append(bbox.image())
+ bbox_indices.append((i, j))
+ if len(bbox_images) > 0:
+ predictions = molscribe.predict_images(bbox_images, batch_size=batch_size)
+ for (i, j), pred in zip(bbox_indices, predictions):
+ #pred_reactions[i].bboxes[j].set_smiles(pred['smiles'], pred['molfile'])
+ pred_reactions[i].bboxes[j].set_smiles(pred['smiles'],pred['oringinal_coords'],pred['original_symbols'],pred['orignal_edges'])
+ if ocr:
+ for reaction in pred_reactions:
+ for bbox in reaction.bboxes:
+ if not bbox.is_mol:
+ text = ocr.readtext(bbox.image(), detail=0)
+ bbox.set_text(text)
+ return pred_reactions.to_json()
diff --git a/rxn/reaction/dataset.py b/rxn/reaction/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab36efc869fa76e09c697bcb8d4b3922202ae11
--- /dev/null
+++ b/rxn/reaction/dataset.py
@@ -0,0 +1,244 @@
+import os
+import cv2
+import copy
+import random
+import json
+import contextlib
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
+
+from . import transforms as T
+
+from pycocotools.coco import COCO
+from PIL import Image
+
+
+class ReactionDataset(Dataset):
+ def __init__(self, args, tokenizer, data_file=None, image_files=None, split='train', debug=False):
+ super().__init__()
+ self.args = args
+ self.tokenizer = tokenizer
+ if data_file:
+ data_file = os.path.join(args.data_path, data_file)
+ with open(data_file) as f:
+ self.data = json.load(f)['images']
+ if split == 'train' and args.num_train_example is not None:
+ self.data = self.data[:args.num_train_example]
+ if split != 'train':
+ with open(os.devnull, 'w') as devnull:
+ with contextlib.redirect_stdout(devnull):
+ self.coco = COCO(data_file)
+ self.name = os.path.basename(data_file).split('.')[0]
+ if image_files:
+ self.data = [{'file_name': file} for file in image_files]
+ self.image_path = args.image_path
+ self.split = split
+ self.format = args.format
+ self.is_train = (split == 'train')
+ self.transform = make_transforms(split, args.augment, debug)
+ # self.reaction_transform = T.RandomReactionCrop()
+
+ def __len__(self):
+ return len(self.data)
+
+ @property
+ def pad_id(self):
+ return self.tokenizer[self.format].PAD_ID
+
+ def generate_sample(self, image, target):
+ ref = {}
+ # coordinates are normalized after transform
+ image, target = self.transform(image, target)
+ ref['scale'] = target['scale']
+ if self.is_train:
+ args = self.args
+ if self.format == 'reaction':
+ max_len = self.tokenizer['reaction'].max_len
+ label, label_out = self.tokenizer['reaction'].data_to_sequence(
+ target, rand_order=args.rand_order, shuffle_bbox=args.shuffle_bbox, add_noise=args.add_noise,
+ mix_noise=args.mix_noise)
+ ref['reaction'] = torch.LongTensor(label[:max_len])
+ ref['reaction_out'] = torch.LongTensor(label_out[:max_len])
+ if self.format == 'bbox':
+ max_len = self.tokenizer['bbox'].max_len
+ label, label_out = self.tokenizer['bbox'].data_to_sequence(
+ target, rand_order=args.rand_order, add_noise=args.add_noise)
+ ref['bbox'] = torch.LongTensor(label[:max_len])
+ ref['bbox_out'] = torch.LongTensor(label_out[:max_len])
+ return image, ref
+
+ def __getitem__(self, idx):
+ image, target = self.load_and_prepare(idx)
+ if self.is_train and self.args.composite_augment:
+ cnt = 0
+ while idx % 2 == random.randrange(2) and cnt < 5:
+ # Augment with probability 0.5
+ n = len(self)
+ idx2 = (idx + random.randrange(n)) % n
+ image2, target2 = self.load_and_prepare(idx2)
+ # if 'reaction' in self.formats:
+ # image, target = self.reaction_transform(image, target)
+ # image2, target2 = self.reaction_transform(image2, target2)
+ image, target = self.concat(image, target, image2, target2)
+ cnt += 1
+ if self.is_train and self.args.augment:
+ image1, ref1 = self.generate_sample(image, target)
+ image2, ref2 = self.generate_sample(image, target)
+ return [[idx, image1, ref1], [idx, image2, ref2]]
+ else:
+ image, ref = self.generate_sample(image, target)
+ ref['file_name'] = self.data[idx]['file_name']
+ return [[idx, image, ref]]
+
+ def load_and_prepare(self, idx):
+ target = self.data[idx]
+ path = os.path.join(self.image_path, target['file_name'])
+ if not os.path.exists(path):
+ print(path, "doesn't exists.", flush=True)
+ image = Image.open(path).convert("RGB")
+ if self.is_train:
+ image, target = self.prepare(image, target)
+ return image, target
+
+ def prepare(self, image, target):
+ w, h = target['width'], target['height']
+
+ image_id = target["id"]
+ image_id = torch.tensor([image_id])
+
+ anno = target["bboxes"]
+
+ boxes = [obj['bbox'] for obj in anno]
+ # guard against no boxes via resizing
+ boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2].clamp_(min=0, max=w)
+ boxes[:, 1::2].clamp_(min=0, max=h)
+
+ classes = [obj["category_id"] for obj in anno]
+ classes = torch.tensor(classes, dtype=torch.int64)
+
+ target = copy.deepcopy(target)
+ target["boxes"] = boxes
+ target["labels"] = classes
+ target["image_id"] = image_id
+
+ # for conversion to coco api
+ area = torch.tensor([obj["bbox"][2] * obj['bbox'][3] for obj in anno])
+ target["area"] = area
+ target["orig_size"] = torch.as_tensor([int(w), int(h)])
+ target["size"] = torch.as_tensor([int(w), int(h)])
+
+ return image, target
+
+ def concat(self, image1, target1, image2, target2):
+ color = (255, 255, 255)
+ if random.random() < 1:
+ # Vertically concat two images
+ w = max(image1.width, image2.width)
+ h = image1.height + image2.height
+ if image1.width > image2.width:
+ x1, y1 = 0, 0
+ x2, y2 = random.randint(0, image1.width - image2.width), image1.height
+ else:
+ x1, y1 = random.randint(0, image2.width - image1.width), 0
+ x2, y2 = 0, image1.height
+ else:
+ # Horizontally concat two images
+ w = image1.width + image2.width
+ h = max(image1.height, image2.height)
+ if image1.height > image2.height:
+ x1, y1 = 0, 0
+ x2, y2 = image1.width, random.randint(0, image1.height - image2.height)
+ else:
+ x1, y1 = 0, random.randint(0, image2.height - image1.height)
+ x2, y2 = image1.width, 0
+ image = Image.new('RGB', (w, h), color)
+ image.paste(image1, (x1, y1))
+ image.paste(image2, (x2, y2))
+ target = {
+ "image_id": target1["image_id"],
+ "orig_size": torch.as_tensor([int(w), int(h)]),
+ "size": torch.as_tensor([int(w), int(h)])
+ }
+ target1["boxes"][:, 0::2] += x1
+ target1["boxes"][:, 1::2] += y1
+ target2["boxes"][:, 0::2] += x2
+ target2["boxes"][:, 1::2] += y2
+ for key in ["boxes", "labels", "area"]:
+ target[key] = torch.cat([target1[key], target2[key]], dim=0)
+ if "reactions" in target1:
+ target["reactions"] = [r for r in target1["reactions"]]
+ nbox = len(target1["boxes"])
+ for r in target2["reactions"]:
+ newr = {}
+ for key, seq in r.items():
+ newr[key] = [x + nbox for x in seq]
+ target["reactions"].append(newr)
+ return image, target
+
+
+def make_transforms(image_set, augment=False, debug=False):
+ normalize = T.Compose([
+ # T.Resize((1333, 1333)),
+ T.ToTensor(),
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], debug)
+ ])
+
+ if image_set == 'train' and augment:
+ return T.Compose([
+ T.RandomRotate(),
+ T.RandomHorizontalFlip(),
+ T.LargeScaleJitter(output_size=1333, aug_scale_min=0.3, aug_scale_max=2.0),
+ T.RandomDistortion(0.5, 0.5, 0.5, 0.5),
+ normalize])
+ else:
+ return T.Compose([
+ T.LargeScaleJitter(output_size=1333, aug_scale_min=1.0, aug_scale_max=1.0),
+ normalize])
+
+
+def pad_images(imgs):
+ # B, C, H, W
+ max_shape = [0, 0]
+ for img in imgs:
+ for i in range(len(max_shape)):
+ max_shape[i] = max(max_shape[i], img.shape[-1-i])
+ stack = []
+ for img in imgs:
+ pad = []
+ for i in range(len(max_shape)):
+ pad = pad + [0, max_shape[i] - img.shape[-1-i]]
+ stack.append(F.pad(img, pad, value=0))
+ return torch.stack(stack)
+
+
+def get_collate_fn(pad_id):
+ def rxn_collate(batch):
+ ids = []
+ imgs = []
+ batch = [ex for seq in batch for ex in seq]
+ keys = list(batch[0][2].keys())
+ seq_formats = [key for key in keys if key in ['bbox', 'bbox_out', 'reaction', 'reaction_out']]
+ refs = {key: [[], []] for key in seq_formats}
+ for ex in batch:
+ ids.append(ex[0])
+ imgs.append(ex[1])
+ ref = ex[2]
+ for key in seq_formats:
+ refs[key][0].append(ref[key])
+ refs[key][1].append(torch.LongTensor([len(ref[key])]))
+ # Sequence
+ for key in keys:
+ if key in seq_formats:
+ refs[key][0] = pad_sequence(refs[key][0], batch_first=True, padding_value=pad_id)
+ refs[key][1] = torch.stack(refs[key][1]).reshape(-1, 1)
+ else:
+ refs[key] = [ex[2][key] for ex in batch]
+ return ids, pad_images(imgs), refs
+
+ return rxn_collate
diff --git a/rxn/reaction/evaluate.py b/rxn/reaction/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7ecab5cb9fd6fb4a44205440e54f18165974f0e
--- /dev/null
+++ b/rxn/reaction/evaluate.py
@@ -0,0 +1,145 @@
+import os
+import contextlib
+import copy
+import numpy as np
+
+from pycocotools.cocoeval import COCOeval
+from pycocotools.coco import COCO
+
+from .data import ImageData, ReactionImageData
+
+
+class CocoEvaluator(object):
+
+ def __init__(self, coco_gt):
+ coco_gt = copy.deepcopy(coco_gt)
+ self.coco_gt = coco_gt
+
+ def evaluate(self, predictions):
+ img_ids, results = self.prepare(predictions, 'bbox')
+ if len(results) == 0:
+ return np.zeros((12,))
+ coco_dt = self.coco_gt.loadRes(results)
+ cocoEval = COCOeval(self.coco_gt, coco_dt, 'bbox')
+ cocoEval.params.imgIds = img_ids
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ self.cocoEval = cocoEval
+ return cocoEval.stats
+
+ def prepare(self, predictions, iou_type):
+ if iou_type == "bbox":
+ return self.prepare_for_coco_detection(predictions)
+ else:
+ raise ValueError("Unknown iou type {}".format(iou_type))
+
+ def prepare_for_coco_detection(self, predictions):
+ img_ids = []
+ coco_results = []
+ for idx, prediction in enumerate(predictions):
+ if len(prediction) == 0:
+ continue
+
+ image = self.coco_gt.dataset['images'][idx]
+ img_ids.append(image['id'])
+ width = image['width']
+ height = image['height']
+
+ coco_results.extend(
+ [
+ {
+ "image_id": image['id'],
+ "category_id": pred['category_id'],
+ "bbox": convert_to_xywh(pred['bbox'], width, height),
+ "score": pred['score'],
+ }
+ for pred in prediction
+ ]
+ )
+ return img_ids, coco_results
+
+
+def convert_to_xywh(box, width, height):
+ xmin, ymin, xmax, ymax = box
+ return [xmin * width, ymin * height, (xmax - xmin) * width, (ymax - ymin) * height]
+
+
+EMPTY_STATS = {'gold_hits': 0, 'gold_total': 0, 'pred_hits': 0, 'pred_total': 0, 'image': 0}
+
+
+class ReactionEvaluator(object):
+
+ def evaluate_image(self, gold_image, pred_image, **kwargs):
+ data = ReactionImageData(gold_image, pred_image)
+ return data.evaluate(**kwargs)
+
+ def compute_metrics(self, gold_hits, gold_total, pred_hits, pred_total):
+ precision = pred_hits / max(pred_total, 1)
+ recall = gold_hits / max(gold_total, 1)
+ f1 = precision * recall * 2 / max(precision + recall, 1e-6)
+ return {'precision': precision, 'recall': recall, 'f1': f1}
+
+ def evaluate(self, groundtruths, predictions, **kwargs):
+ gold_hits, gold_total, pred_hits, pred_total = 0, 0, 0, 0
+ for gold_image, pred_image in zip(groundtruths, predictions):
+ gh, ph = self.evaluate_image(gold_image, pred_image, **kwargs)
+ gold_hits += sum(gh)
+ gold_total += len(gh)
+ pred_hits += sum(ph)
+ pred_total += len(ph)
+ return self.compute_metrics(gold_hits, gold_total, pred_hits, pred_total)
+
+ def evaluate_by_size(self, groundtruths, predictions, **kwargs):
+ group_stats = {}
+ for gold_image, pred_image in zip(groundtruths, predictions):
+ gh, ph = self.evaluate_image(gold_image, pred_image, **kwargs)
+ gtotal = len(gh)
+ if gtotal not in group_stats:
+ group_stats[gtotal] = copy.deepcopy(EMPTY_STATS)
+ group_stats[gtotal]['gold_hits'] += sum(gh)
+ group_stats[gtotal]['gold_total'] += len(gh)
+ group_stats[gtotal]['pred_hits'] += sum(ph)
+ group_stats[gtotal]['pred_total'] += len(ph)
+ group_stats[gtotal]['image'] += 1
+ group_scores = {}
+ for gtotal, stats in group_stats.items():
+ group_scores[gtotal] = self.compute_metrics(
+ stats['gold_hits'], stats['gold_total'], stats['pred_hits'], stats['pred_total'])
+ return group_scores, group_stats
+
+ def evaluate_by_group(self, groundtruths, predictions, **kwargs):
+ group_stats = {}
+ for gold_image, pred_image in zip(groundtruths, predictions):
+ gh, ph = self.evaluate_image(gold_image, pred_image, **kwargs)
+ diagram_type = gold_image['diagram_type']
+ if diagram_type not in group_stats:
+ group_stats[diagram_type] = copy.deepcopy(EMPTY_STATS)
+ group_stats[diagram_type]['gold_hits'] += sum(gh)
+ group_stats[diagram_type]['gold_total'] += len(gh)
+ group_stats[diagram_type]['pred_hits'] += sum(ph)
+ group_stats[diagram_type]['pred_total'] += len(ph)
+ group_stats[diagram_type]['image'] += 1
+ group_scores = {}
+ for group, stats in group_stats.items():
+ group_scores[group] = self.compute_metrics(
+ stats['gold_hits'], stats['gold_total'], stats['pred_hits'], stats['pred_total'])
+ return group_scores, group_stats
+
+ def evaluate_summarize(self, groundtruths, predictions, **kwargs):
+ size_scores, size_stats = self.evaluate_by_size(groundtruths, predictions, **kwargs)
+ summarize = {
+ 'overall': copy.deepcopy(EMPTY_STATS),
+ # 'single': copy.deepcopy(EMPTY_STATS),
+ # 'multiple': copy.deepcopy(EMPTY_STATS)
+ }
+ for size, stats in size_stats.items():
+ if type(size) is int:
+ # output = summarize['single'] if size <= 1 else summarize['multiple']
+ for key in stats:
+ # output[key] += stats[key]
+ summarize['overall'][key] += stats[key]
+ scores = {}
+ for key, val in summarize.items():
+ scores[key] = self.compute_metrics(val['gold_hits'], val['gold_total'], val['pred_hits'], val['pred_total'])
+ return scores, summarize, size_stats
diff --git a/rxn/reaction/inference/__init__.py b/rxn/reaction/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c61feef15b9a94dbb97126be593f0c445f1870c0
--- /dev/null
+++ b/rxn/reaction/inference/__init__.py
@@ -0,0 +1,4 @@
+from .greedy_search import GreedySearch
+from .beam_search import BeamSearch
+
+__all__ = ["GreedySearch", "BeamSearch"]
diff --git a/rxn/reaction/inference/beam_search.py b/rxn/reaction/inference/beam_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b2fe9361f564b796913c00fee294f0367af66d1
--- /dev/null
+++ b/rxn/reaction/inference/beam_search.py
@@ -0,0 +1,200 @@
+import torch
+from .decode_strategy import DecodeStrategy
+
+import warnings
+
+
+class BeamSearch(DecodeStrategy):
+ """Generation with beam search.
+ """
+
+ def __init__(self, pad, bos, eos, batch_size, beam_size, n_best, min_length,
+ return_attention, max_length):
+ super(BeamSearch, self).__init__(
+ pad, bos, eos, batch_size, beam_size, min_length, return_attention, max_length)
+ self.beam_size = beam_size
+ self.n_best = n_best
+
+ # result caching
+ self.hypotheses = [[] for _ in range(batch_size)]
+
+ # beam state
+ self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool)
+
+ self._batch_offset = torch.arange(batch_size, dtype=torch.long)
+
+ self.select_indices = None
+ self.done = False
+
+ def initialize(self, memory_bank, device=None):
+ """Repeat src objects `beam_size` times.
+ """
+ def fn_map_state(state, dim):
+ return torch.repeat_interleave(state, self.beam_size, dim=dim)
+
+ memory_bank = torch.repeat_interleave(memory_bank, self.beam_size, dim=0)
+ if device is None:
+ device = memory_bank.device
+
+ self.memory_length = memory_bank.size(1)
+ super().initialize(memory_bank, device)
+
+ self.best_scores = torch.full(
+ [self.batch_size], -1e10, dtype=torch.float, device=device)
+ self._beam_offset = torch.arange(
+ 0, self.batch_size * self.beam_size, step=self.beam_size,
+ dtype=torch.long, device=device)
+ self.topk_log_probs = torch.tensor(
+ [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
+ ).repeat(self.batch_size)
+ # buffers for the topk scores and 'backpointer'
+ self.topk_scores = torch.empty((self.batch_size, self.beam_size),
+ dtype=torch.float, device=device)
+ self.topk_ids = torch.empty((self.batch_size, self.beam_size),
+ dtype=torch.long, device=device)
+ self._batch_index = torch.empty([self.batch_size, self.beam_size],
+ dtype=torch.long, device=device)
+
+ return fn_map_state, memory_bank
+
+ @property
+ def current_predictions(self):
+ return self.alive_seq[:, -1]
+
+ @property
+ def current_backptr(self):
+ # for testing
+ return self.select_indices.view(self.batch_size, self.beam_size)
+
+ @property
+ def batch_offset(self):
+ return self._batch_offset
+
+ def _pick(self, log_probs):
+ """Return token decision for a step.
+
+ Args:
+ log_probs (FloatTensor): (B, vocab_size)
+
+ Returns:
+ topk_scores (FloatTensor): (B, beam_size)
+ topk_ids (LongTensor): (B, beam_size)
+ """
+ vocab_size = log_probs.size(-1)
+
+ # Flatten probs into a list of probabilities.
+ curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
+ topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
+ return topk_scores, topk_ids
+
+ def advance(self, log_probs, attn):
+ """
+ Args:
+ log_probs: (B * beam_size, vocab_size)
+ """
+ vocab_size = log_probs.size(-1)
+
+ # (non-finished) batch_size
+ _B = log_probs.shape[0] // self.beam_size
+
+ step = len(self) # alive_seq
+ self.ensure_min_length(log_probs)
+
+ # Multiply probs by the beam probability
+ log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)
+
+ curr_length = step + 1
+ curr_scores = log_probs / curr_length # avg log_prob
+ self.topk_scores, self.topk_ids = self._pick(curr_scores)
+ # topk_scores/topk_ids: (batch_size, beam_size)
+
+ # Recover log probs
+ torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs)
+
+ # Resolve beam origin and map to batch index flat representation.
+ self._batch_index = self.topk_ids // vocab_size
+ self._batch_index += self._beam_offset[:_B].unsqueeze(1)
+ self.select_indices = self._batch_index.view(_B * self.beam_size)
+ self.topk_ids.fmod_(vocab_size) # resolve true word ids
+
+ # Append last prediction.
+ self.alive_seq = torch.cat(
+ [self.alive_seq.index_select(0, self.select_indices),
+ self.topk_ids.view(_B * self.beam_size, 1)], -1)
+
+ if self.return_attention:
+ current_attn = attn.index_select(1, self.select_indices)
+ if step == 1:
+ self.alive_attn = current_attn
+ else:
+ self.alive_attn = self.alive_attn.index_select(
+ 1, self.select_indices)
+ self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)
+
+ self.is_finished = self.topk_ids.eq(self.eos)
+ self.ensure_max_length()
+
+ def update_finished(self):
+ _B_old = self.topk_log_probs.shape[0]
+ step = self.alive_seq.shape[-1] # len(self)
+ self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
+
+ self.is_finished = self.is_finished.to('cpu')
+ self.top_beam_finished |= self.is_finished[:, 0].eq(1)
+ predictions = self.alive_seq.view(_B_old, self.beam_size, step)
+ attention = (
+ self.alive_attn.view(
+ step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
+ if self.alive_attn is not None else None)
+ non_finished_batch = []
+ for i in range(self.is_finished.size(0)):
+ b = self._batch_offset[i]
+ finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
+ # Store finished hypothesis for this batch.
+ for j in finished_hyp: # Beam level: finished beam j in batch i
+ self.hypotheses[b].append((
+ self.topk_scores[i, j],
+ predictions[i, j, 1:], # Ignore start token
+ attention[:, i, j, :self.memory_length]
+ if attention is not None else None))
+ # End condition is the top beam finished and we can return
+ # n_best hypotheses.
+ finish_flag = self.top_beam_finished[i] != 0
+ if finish_flag and len(self.hypotheses[b]) >= self.n_best:
+ best_hyp = sorted(
+ self.hypotheses[b], key=lambda x: x[0], reverse=True)
+ for n, (score, pred, attn) in enumerate(best_hyp):
+ if n >= self.n_best:
+ break
+ self.scores[b].append(score.item())
+ self.predictions[b].append(pred)
+ self.attention[b].append(
+ attn if attn is not None else [])
+ else:
+ non_finished_batch.append(i)
+ non_finished = torch.tensor(non_finished_batch)
+
+ if len(non_finished) == 0:
+ self.done = True
+ return
+
+ _B_new = non_finished.shape[0]
+ # Remove finished batches for the next step
+ self.top_beam_finished = self.top_beam_finished.index_select(
+ 0, non_finished)
+ self._batch_offset = self._batch_offset.index_select(0, non_finished)
+ non_finished = non_finished.to(self.topk_ids.device)
+ self.topk_log_probs = self.topk_log_probs.index_select(
+ 0, non_finished)
+ self._batch_index = self._batch_index.index_select(0, non_finished)
+ self.select_indices = self._batch_index.view(_B_new * self.beam_size)
+ self.alive_seq = predictions.index_select(0, non_finished) \
+ .view(-1, self.alive_seq.size(-1))
+ self.topk_scores = self.topk_scores.index_select(0, non_finished)
+ self.topk_ids = self.topk_ids.index_select(0, non_finished)
+
+ if self.alive_attn is not None:
+ inp_seq_len = self.alive_attn.size(-1)
+ self.alive_attn = attention.index_select(1, non_finished) \
+ .view(step - 1, _B_new * self.beam_size, inp_seq_len)
+
diff --git a/rxn/reaction/inference/decode_strategy.py b/rxn/reaction/inference/decode_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..394c83e951c29ed9b19f41d16a22eef3e802ab35
--- /dev/null
+++ b/rxn/reaction/inference/decode_strategy.py
@@ -0,0 +1,60 @@
+import torch
+from copy import deepcopy
+
+
+class DecodeStrategy(object):
+ def __init__(self, pad, bos, eos, batch_size, parallel_paths, min_length, max_length,
+ return_attention=False, return_hidden=False):
+ self.pad = pad
+ self.bos = bos
+ self.eos = eos
+
+ self.batch_size = batch_size
+ self.parallel_paths = parallel_paths
+ # result catching
+ self.predictions = [[] for _ in range(batch_size)]
+ self.scores = [[] for _ in range(batch_size)]
+ self.attention = [[] for _ in range(batch_size)]
+ self.hidden = [[] for _ in range(batch_size)]
+
+ self.alive_attn = None
+ self.alive_hidden = None
+
+ self.min_length = min_length
+ self.max_length = max_length
+
+ n_paths = batch_size * parallel_paths
+ self.return_attention = return_attention
+ self.return_hidden = return_hidden
+
+ self.done = False
+
+ def initialize(self, memory_bank, device=None):
+ if device is None:
+ device = torch.device('cpu')
+ self.alive_seq = torch.full(
+ [self.batch_size * self.parallel_paths, 1], self.bos,
+ dtype=torch.long, device=device)
+ self.is_finished = torch.zeros(
+ [self.batch_size, self.parallel_paths],
+ dtype=torch.uint8, device=device)
+
+ return None, memory_bank
+
+ def __len__(self):
+ return self.alive_seq.shape[1]
+
+ def ensure_min_length(self, log_probs):
+ if len(self) <= self.min_length:
+ log_probs[:, self.eos] = -1e20 # forced non-end
+
+ def ensure_max_length(self):
+ if len(self) == self.max_length + 1:
+ self.is_finished.fill_(1)
+
+ def advance(self, log_probs, attn):
+ raise NotImplementedError()
+
+ def update_finished(self):
+ raise NotImplementedError
+
diff --git a/rxn/reaction/inference/greedy_search.py b/rxn/reaction/inference/greedy_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..3878dcc620caa71a298133fc35358abe4c34ad95
--- /dev/null
+++ b/rxn/reaction/inference/greedy_search.py
@@ -0,0 +1,123 @@
+import torch
+from .decode_strategy import DecodeStrategy
+
+
+def sample_with_temperature(logits, sampling_temp, keep_topk):
+ """Select next tokens randomly from the top k possible next tokens.
+
+ Samples from a categorical distribution over the ``keep_topk`` words using
+ the category probabilities ``logits / sampling_temp``.
+ """
+
+ if sampling_temp == 0.0 or keep_topk == 1:
+ # argmax
+ topk_scores, topk_ids = logits.topk(1, dim=-1)
+ if sampling_temp > 0:
+ topk_scores /= sampling_temp
+ else:
+ logits = torch.div(logits, sampling_temp)
+ if keep_topk > 0:
+ top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
+ kth_best = top_values[:, -1].view([-1, 1])
+ kth_best = kth_best.repeat([1, logits.shape[1]]).float()
+ ignore = torch.lt(logits, kth_best)
+ logits = logits.masked_fill(ignore, -10000)
+
+ dist = torch.distributions.Multinomial(logits=logits, total_count=1)
+ topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
+ topk_scores = logits.gather(dim=1, index=topk_ids)
+
+ return topk_ids, topk_scores
+
+
+class GreedySearch(DecodeStrategy):
+ """Select next tokens randomly from the top k possible next tokens.
+ """
+
+ def __init__(self, pad, bos, eos, batch_size, min_length, max_length,
+ return_attention=False, return_hidden=False, sampling_temp=1, keep_topk=1):
+ super().__init__(
+ pad, bos, eos, batch_size, 1, min_length, max_length, return_attention, return_hidden)
+ self.sampling_temp = sampling_temp
+ self.keep_topk = keep_topk
+ self.topk_scores = None
+
+ def initialize(self, memory_bank, device=None):
+ fn_map_state = None
+
+ if device is None:
+ device = memory_bank.device
+
+ self.memory_length = memory_bank.size(1)
+ super().initialize(memory_bank, device)
+
+ self.select_indices = torch.arange(
+ self.batch_size, dtype=torch.long, device=device)
+ self.original_batch_idx = torch.arange(
+ self.batch_size, dtype=torch.long, device=device)
+
+ return fn_map_state, memory_bank
+
+ @property
+ def current_predictions(self):
+ return self.alive_seq[:, -1]
+
+ @property
+ def batch_offset(self):
+ return self.select_indices
+
+ def _pick(self, log_probs):
+ """Function used to pick next tokens.
+ """
+ topk_ids, topk_scores = sample_with_temperature(
+ log_probs, self.sampling_temp, self.keep_topk)
+ return topk_ids, topk_scores
+
+ def advance(self, log_probs, attn=None, hidden=None, label=None):
+ """Select next tokens randomly from the top k possible next tokens.
+ """
+ self.ensure_min_length(log_probs)
+ topk_ids, self.topk_scores = self._pick(log_probs)
+ self.is_finished = topk_ids.eq(self.eos)
+ if label is not None:
+ label = label.view_as(self.is_finished)
+ self.is_finished = label.eq(self.eos)
+ self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1)
+
+ if self.return_attention:
+ if self.alive_attn is None:
+ self.alive_attn = attn
+ else:
+ self.alive_attn = torch.cat([self.alive_attn, attn], 1)
+ if self.return_hidden:
+ if self.alive_hidden is None:
+ self.alive_hidden = hidden
+ else:
+ self.alive_hidden = torch.cat([self.alive_hidden, hidden], 1)
+ self.ensure_max_length()
+
+ def update_finished(self):
+ """Finalize scores and predictions."""
+ finished_batches = self.is_finished.view(-1).nonzero()
+ for b in finished_batches.view(-1):
+ b_orig = self.original_batch_idx[b]
+ # scores/predictions/attention are lists,
+ # (to be compatible with beam-search)
+ self.scores[b_orig].append(self.topk_scores[b, 0].item())
+ self.predictions[b_orig].append(self.alive_seq[b, 1:])
+ self.attention[b_orig].append(
+ self.alive_attn[b, :, :self.memory_length] if self.alive_attn is not None else [])
+ self.hidden[b_orig].append(
+ self.alive_hidden[b, :] if self.alive_hidden is not None else [])
+ self.done = self.is_finished.all()
+ if self.done:
+ return
+ is_alive = ~self.is_finished.view(-1)
+ self.alive_seq = self.alive_seq[is_alive]
+ if self.alive_attn is not None:
+ self.alive_attn = self.alive_attn[is_alive]
+ if self.alive_hidden is not None:
+ self.alive_hidden = self.alive_hidden[is_alive]
+ self.select_indices = is_alive.nonzero().view(-1)
+ self.original_batch_idx = self.original_batch_idx[is_alive]
+ # select_indices is equal to original_batch_idx for greedy search?
diff --git a/rxn/reaction/interface.py b/rxn/reaction/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..c934a5ffda6859cf33626f3a06af7308539e86c8
--- /dev/null
+++ b/rxn/reaction/interface.py
@@ -0,0 +1,164 @@
+import os
+import argparse
+from typing import List
+import PIL
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+from .pix2seq import build_pix2seq_model
+from .tokenizer import get_tokenizer
+from .dataset import make_transforms
+from .data import postprocess_reactions, ReactionImageData
+
+from molscribe import MolScribe
+from huggingface_hub import hf_hub_download
+import easyocr
+
+
+class Reaction:
+
+ def __init__(self, model_path, device=None):
+ """
+ :param model_path: path of the model checkpoint.
+ :param device: torch device, defaults to be CPU.
+ """
+ args = self._get_args()
+ args.format = 'reaction'
+ states = torch.load(model_path, map_location=torch.device('cpu'))
+ if device is None:
+ device = torch.device('cpu')
+ self.device = device
+ self.tokenizer = get_tokenizer(args)
+ self.model = self.get_model(args, self.tokenizer, self.device, states['state_dict'])
+ self.transform = make_transforms('test', augment=False, debug=False)
+ self.molscribe = self.get_molscribe()
+ self.ocr_model = self.get_ocr_model()
+
+ def _get_args(self):
+ parser = argparse.ArgumentParser()
+ # * Backbone
+ parser.add_argument('--backbone', default='resnet50', type=str,
+ help="Name of the convolutional backbone to use")
+ parser.add_argument('--dilation', action='store_true',
+ help="If true, we replace stride with dilation in the last convolutional block (DC5)")
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
+ help="Type of positional embedding to use on top of the image features")
+ # * Transformer
+ parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer")
+ parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer")
+ parser.add_argument('--dim_feedforward', default=1024, type=int,
+ help="Intermediate size of the feedforward layers in the transformer blocks")
+ parser.add_argument('--hidden_dim', default=256, type=int,
+ help="Size of the embeddings (dimension of the transformer)")
+ parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")
+ parser.add_argument('--nheads', default=8, type=int,
+ help="Number of attention heads inside the transformer's attentions")
+ parser.add_argument('--pre_norm', action='store_true')
+ # Data
+ parser.add_argument('--format', type=str, default='reaction')
+ parser.add_argument('--input_size', type=int, default=1333)
+
+ args = parser.parse_args([])
+ args.pix2seq = True
+ args.pix2seq_ckpt = None
+ args.pred_eos = True
+ return args
+
+ def get_model(self, args, tokenizer, device, model_states):
+ def remove_prefix(state_dict):
+ return {k.replace('model.', ''): v for k, v in state_dict.items()}
+
+ model = build_pix2seq_model(args, tokenizer[args.format])
+ model.load_state_dict(remove_prefix(model_states), strict=False)
+ model.to(device)
+ model.eval()
+ return model
+
+ def get_molscribe(self):
+ ckpt_path = hf_hub_download("yujieq/MolScribe", "swin_base_char_aux_1m.pth")
+ molscribe = MolScribe(ckpt_path, device=self.device)
+ return molscribe
+
+ def get_ocr_model(self):
+ reader = easyocr.Reader(['en'], gpu=(self.device.type == 'cuda'))
+ return reader
+
+ def predict_images(self, input_images: List, batch_size=16, molscribe=False, ocr=False):
+ # images: a list of PIL images
+ device = self.device
+ tokenizer = self.tokenizer['reaction']
+ predictions = []
+ for idx in range(0, len(input_images), batch_size):
+ batch_images = input_images[idx:idx+batch_size]
+ images, refs = zip(*[self.transform(image) for image in batch_images])
+ images = torch.stack(images, dim=0).to(device)
+ with torch.no_grad():
+ pred_seqs, pred_scores = self.model(images, max_len=tokenizer.max_len)
+ for i, (seqs, scores) in enumerate(zip(pred_seqs, pred_scores)):
+ reactions = tokenizer.sequence_to_data(seqs.tolist(), scores.tolist(), scale=refs[i]['scale'])
+ reactions = postprocess_reactions(
+ reactions,
+ image=input_images[i],
+ molscribe=self.molscribe if molscribe else None,
+ ocr=self.ocr_model if ocr else None
+ )
+ predictions.append(reactions)
+ return predictions
+
+ def predict_image(self, image, **kwargs):
+ predictions = self.predict_images([image], **kwargs)
+ return predictions[0]
+
+ def predict_image_files(self, image_files: List, **kwargs):
+ input_images = []
+ for path in image_files:
+ image = PIL.Image.open(path).convert("RGB")
+ input_images.append(image)
+ return self.predict_images(input_images, **kwargs)
+
+ def predict_image_file(self, image_file: str, **kwargs):
+ predictions = self.predict_image_files([image_file], **kwargs)
+ return predictions[0]
+
+ def draw_predictions(self, predictions, image=None, image_file=None):
+ results = []
+ assert image or image_file
+ data = ReactionImageData(predictions=predictions, image=image, image_file=image_file)
+ h, w = np.array([data.height, data.width]) * 10 / max(data.height, data.width)
+ for r in data.pred_reactions:
+ fig, ax = plt.subplots(figsize=(w, h))
+ fig.tight_layout()
+ canvas = FigureCanvasAgg(fig)
+ ax.imshow(data.image)
+ ax.axis('off')
+ r.draw(ax)
+ canvas.draw()
+ buf = canvas.buffer_rgba()
+ results.append(np.asarray(buf))
+ plt.close(fig)
+ return results
+
+ def draw_predictions_combined(self, predictions, image=None, image_file=None):
+ assert image or image_file
+ data = ReactionImageData(predictions=predictions, image=image, image_file=image_file)
+ h, w = np.array([data.height, data.width]) * 10 / max(data.height, data.width)
+ n = len(data.pred_reactions)
+ fig, axes = plt.subplots(n, 1, figsize=(w, h * n))
+ if n == 1:
+ axes = [axes]
+ fig.tight_layout(rect=(0.02, 0.02, 0.99, 0.99))
+ canvas = FigureCanvasAgg(fig)
+ for i, r in enumerate(data.pred_reactions):
+ ax = axes[i]
+ ax.imshow(data.image)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_title(f'reaction # {i}', fontdict={'fontweight': 'bold', 'fontsize': 14})
+ r.draw(ax)
+ canvas.draw()
+ buf = canvas.buffer_rgba()
+ result_image = np.asarray(buf)
+ plt.close(fig)
+ return result_image
diff --git a/rxn/reaction/loss.py b/rxn/reaction/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9a86eabfa2e8c5d403774cb982cdcd1a925ab2
--- /dev/null
+++ b/rxn/reaction/loss.py
@@ -0,0 +1,92 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LabelSmoothingLoss(nn.Module):
+ """
+ With label smoothing,
+ KL-divergence between q_{smoothed ground truth prob.}(w)
+ and p_{prob. computed by model}(w) is minimized.
+ """
+ def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
+ assert 0.0 < label_smoothing <= 1.0
+ self.ignore_index = ignore_index
+ super(LabelSmoothingLoss, self).__init__()
+
+ smoothing_value = label_smoothing / (tgt_vocab_size - 2)
+ one_hot = torch.full((tgt_vocab_size,), smoothing_value)
+ one_hot[self.ignore_index] = 0
+ self.register_buffer('one_hot', one_hot.unsqueeze(0))
+
+ self.confidence = 1.0 - label_smoothing
+
+ def forward(self, output, target):
+ """
+ output (FloatTensor): batch_size x n_classes
+ target (LongTensor): batch_size
+ """
+ # assuming output is raw logits
+ # convert to log_probs
+ log_probs = F.log_softmax(output, dim=-1)
+
+ model_prob = self.one_hot.repeat(target.size(0), 1)
+ model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
+ model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
+
+ # reduction mean or sum?
+ return F.kl_div(log_probs, model_prob, reduction='batchmean')
+
+
+class SequenceLoss(nn.Module):
+
+ def __init__(self, label_smoothing, vocab_size, ignore_index=-100, ignore_indices=[]):
+ super(SequenceLoss, self).__init__()
+ if ignore_indices:
+ ignore_index = ignore_indices[0]
+ self.ignore_index = ignore_index
+ self.ignore_indices = ignore_indices
+ if label_smoothing == 0:
+ self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean')
+ # Cross entropy = KL divergence + constant
+ else:
+ self.criterion = LabelSmoothingLoss(label_smoothing, vocab_size, ignore_index)
+
+ def forward(self, output, target):
+ """
+ :param output: [batch, len, vocab]
+ :param target: [batch, len]
+ :return:
+ """
+ batch_size, max_len, vocab_size = output.size()
+ output = output.reshape(-1, vocab_size)
+ target = target.reshape(-1)
+ for idx in self.ignore_indices:
+ if idx != self.ignore_index:
+ target.masked_fill_((target == idx), self.ignore_index)
+ loss = self.criterion(output, target)
+ return loss
+
+
+class Criterion(nn.Module):
+
+ def __init__(self, args, tokenizer):
+ super(Criterion, self).__init__()
+ criterion = {}
+ format = args.format
+ tn = tokenizer[format]
+ criterion[format] = SequenceLoss(args.label_smoothing, len(tn), ignore_index=tn.PAD_ID)
+ self.criterion = nn.ModuleDict(criterion)
+
+ def forward(self, results, refs):
+ losses = {}
+ for format_ in results:
+ predictions, targets, *_ = results[format_]
+ loss_ = self.criterion[format_](predictions, targets)
+ if type(loss_) is dict:
+ losses.update(loss_)
+ else:
+ if loss_.numel() > 1:
+ loss_ = loss_.mean()
+ losses[format_] = loss_
+ return losses
diff --git a/rxn/reaction/model.py b/rxn/reaction/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..db197725949cd19ee82521399dc81d5c65e3c019
--- /dev/null
+++ b/rxn/reaction/model.py
@@ -0,0 +1,260 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import timm
+
+from .inference import GreedySearch, BeamSearch
+from .transformer import TransformerDecoder, Embeddings
+
+
+class Encoder(nn.Module):
+ def __init__(self, args, pretrained=False):
+ super().__init__()
+ model_name = args.encoder
+ self.model_name = model_name
+ if model_name.startswith('resnet'):
+ self.model_type = 'resnet'
+ self.cnn = timm.create_model(model_name, pretrained=pretrained)
+ self.n_features = self.cnn.num_features # encoder_dim
+ self.cnn.global_pool = nn.Identity()
+ self.cnn.fc = nn.Identity()
+ elif model_name.startswith('swin'):
+ self.model_type = 'swin'
+ self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False,
+ use_checkpoint=args.use_checkpoint)
+ self.n_features = self.transformer.num_features
+ self.transformer.head = nn.Identity()
+ elif 'efficientnet' in model_name:
+ self.model_type = 'efficientnet'
+ self.cnn = timm.create_model(model_name, pretrained=pretrained)
+ self.n_features = self.cnn.num_features
+ self.cnn.global_pool = nn.Identity()
+ self.cnn.classifier = nn.Identity()
+ else:
+ raise NotImplemented
+
+ def swin_forward(self, transformer, x):
+ x = transformer.patch_embed(x)
+ if transformer.absolute_pos_embed is not None:
+ x = x + transformer.absolute_pos_embed
+ x = transformer.pos_drop(x)
+
+ def layer_forward(layer, x, hiddens):
+ for blk in layer.blocks:
+ if not torch.jit.is_scripting() and layer.use_checkpoint:
+ x = torch.utils.checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ H, W = layer.input_resolution
+ B, L, C = x.shape
+ hiddens.append(x.view(B, H, W, C))
+ if layer.downsample is not None:
+ x = layer.downsample(x)
+ return x, hiddens
+
+ hiddens = []
+ for layer in transformer.layers:
+ x, hiddens = layer_forward(layer, x, hiddens)
+ x = transformer.norm(x) # B L C
+ hiddens[-1] = x.view_as(hiddens[-1])
+ return x, hiddens
+
+ def forward(self, x, refs=None):
+ if self.model_type in ['resnet', 'efficientnet']:
+ features = self.cnn(x)
+ features = features.permute(0, 2, 3, 1)
+ hiddens = []
+ elif self.model_type == 'swin':
+ if 'patch' in self.model_name:
+ features, hiddens = self.swin_forward(self.transformer, x)
+ else:
+ features, hiddens = self.transformer(x)
+ else:
+ raise NotImplemented
+ return features, hiddens
+
+
+class TransformerDecoderBase(nn.Module):
+
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+
+ self.enc_trans_layer = nn.Sequential(
+ nn.Linear(args.encoder_dim, args.dec_hidden_size)
+ # nn.LayerNorm(args.dec_hidden_size, eps=1e-6)
+ )
+ self.enc_pos_emb = nn.Embedding(144, args.encoder_dim) if args.enc_pos_emb else None
+
+ self.decoder = TransformerDecoder(
+ num_layers=args.dec_num_layers,
+ d_model=args.dec_hidden_size,
+ heads=args.dec_attn_heads,
+ d_ff=args.dec_hidden_size * 4,
+ copy_attn=False,
+ self_attn_type="scaled-dot",
+ dropout=args.hidden_dropout,
+ attention_dropout=args.attn_dropout,
+ max_relative_positions=args.max_relative_positions,
+ aan_useffn=False,
+ full_context_alignment=False,
+ alignment_layer=0,
+ alignment_heads=0,
+ pos_ffn_activation_fn='gelu'
+ )
+
+ def enc_transform(self, encoder_out):
+ batch_size = encoder_out.size(0)
+ encoder_dim = encoder_out.size(-1)
+ encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
+ max_len = encoder_out.size(1)
+ device = encoder_out.device
+ if self.enc_pos_emb:
+ pos_emb = self.enc_pos_emb(torch.arange(max_len, device=device)).unsqueeze(0)
+ encoder_out = encoder_out + pos_emb
+ encoder_out = self.enc_trans_layer(encoder_out)
+ return encoder_out
+
+
+class TransformerDecoderAR(TransformerDecoderBase):
+
+ def __init__(self, args, tokenizer):
+ super().__init__(args)
+ self.tokenizer = tokenizer
+ self.vocab_size = len(self.tokenizer)
+ self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True)
+ self.embeddings = Embeddings(
+ word_vec_size=args.dec_hidden_size,
+ word_vocab_size=self.vocab_size,
+ word_padding_idx=tokenizer.PAD_ID,
+ position_encoding=True,
+ dropout=args.hidden_dropout)
+
+ def dec_embedding(self, tgt, step=None):
+ pad_idx = self.embeddings.word_padding_idx
+ tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2) # [B, 1, T_tgt]
+ emb = self.embeddings(tgt, step=step)
+ assert emb.dim() == 3 # batch x len x embedding_dim
+ return emb, tgt_pad_mask
+
+ def forward(self, encoder_out, labels, label_lengths):
+ batch_size, max_len, _ = encoder_out.size()
+ memory_bank = self.enc_transform(encoder_out)
+
+ tgt = labels.unsqueeze(-1) # (b, t, 1)
+ tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
+ dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask)
+
+ logits = self.output_layer(dec_out) # (b, t, h) -> (b, t, v)
+ return logits[:, :-1], labels[:, 1:], dec_out
+
+ def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256):
+ batch_size, max_len, _ = encoder_out.size()
+ memory_bank = self.enc_transform(encoder_out)
+
+ if beam_size == 1:
+ decode_strategy = GreedySearch(
+ sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length, max_length=max_length,
+ pad=self.tokenizer.PAD_ID, bos=self.tokenizer.SOS_ID, eos=self.tokenizer.EOS_ID,
+ return_attention=False, return_hidden=True)
+ else:
+ decode_strategy = BeamSearch(
+ beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length, max_length=max_length,
+ pad=self.tokenizer.PAD_ID, bos=self.tokenizer.SOS_ID, eos=self.tokenizer.EOS_ID,
+ return_attention=False)
+
+ # adapted from onmt.translate.translator
+ results = {
+ "predictions": None,
+ "scores": None,
+ "attention": None
+ }
+
+ # (2) prep decode_strategy. Possibly repeat src objects.
+ _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank)
+
+ # (3) Begin decoding step by step:
+ for step in range(decode_strategy.max_length):
+ tgt = decode_strategy.current_predictions.view(-1, 1, 1)
+ tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
+ dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank,
+ tgt_pad_mask=tgt_pad_mask, step=step)
+
+ attn = dec_attn.get("std", None)
+
+ dec_logits = self.output_layer(dec_out) # [b, t, h] => [b, t, v]
+ dec_logits = dec_logits.squeeze(1)
+ log_probs = F.log_softmax(dec_logits, dim=-1)
+
+ if self.tokenizer.output_constraint:
+ output_mask = [self.tokenizer.get_output_mask(id) for id in tgt.view(-1).tolist()]
+ output_mask = torch.tensor(output_mask, device=log_probs.device)
+ log_probs.masked_fill_(output_mask, -10000)
+
+ decode_strategy.advance(log_probs, attn, dec_out)
+ any_finished = decode_strategy.is_finished.any()
+ if any_finished:
+ decode_strategy.update_finished()
+ if decode_strategy.done:
+ break
+
+ select_indices = decode_strategy.select_indices
+ if any_finished:
+ # Reorder states.
+ memory_bank = memory_bank.index_select(0, select_indices)
+ self.map_state(lambda state, dim: state.index_select(dim, select_indices))
+
+ results["scores"] = decode_strategy.scores
+ results["predictions"] = decode_strategy.predictions
+ results["attention"] = decode_strategy.attention
+ results["hidden"] = decode_strategy.hidden
+
+ return results["predictions"], results['scores'], results["hidden"]
+
+ # adapted from onmt.decoders.transformer
+ def map_state(self, fn):
+ def _recursive_map(struct, batch_dim=0):
+ for k, v in struct.items():
+ if v is not None:
+ if isinstance(v, dict):
+ _recursive_map(v)
+ else:
+ struct[k] = fn(v, batch_dim)
+ if self.decoder.state["cache"] is not None:
+ _recursive_map(self.decoder.state["cache"])
+
+
+class Decoder(nn.Module):
+
+ def __init__(self, args, tokenizer):
+ super(Decoder, self).__init__()
+ self.args = args
+ self.formats = args.formats
+ self.tokenizer = tokenizer
+ decoder = {}
+ for format_ in args.formats:
+ decoder[format_] = TransformerDecoderAR(args, tokenizer[format_])
+ self.decoder = nn.ModuleDict(decoder)
+
+ def forward(self, encoder_out, hiddens, refs):
+ results = {}
+ for format_ in self.formats:
+ labels, label_lengths = refs[format_]
+ results[format_] = self.decoder[format_](encoder_out, labels, label_lengths)
+ return results
+
+ def decode(self, encoder_out, hiddens, refs=None, beam_size=1, n_best=1):
+ results = {}
+ predictions = {}
+ beam_predictions = {}
+ for format_ in self.formats:
+ max_len = self.tokenizer[format_].max_len
+ results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best, max_length=max_len)
+ outputs, scores, *_ = results[format_]
+ beam_preds = [[self.tokenizer[format_].sequence_to_data(x.tolist()) for x in pred] for pred in outputs]
+ beam_predictions[format_] = (beam_preds, scores)
+ predictions[format_] = [preds[0] for preds in beam_preds]
+ return predictions, beam_predictions
diff --git a/rxn/reaction/pix2seq/__init__.py b/rxn/reaction/pix2seq/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fcc52bb49f98bdcc0fd2e0088c338bbbe65c2a8
--- /dev/null
+++ b/rxn/reaction/pix2seq/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from .pix2seq import build_pix2seq_model
diff --git a/rxn/reaction/pix2seq/__pycache__/__init__.cpython-310.pyc b/rxn/reaction/pix2seq/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4ff8acdffb8df34bf34e615716aede7b13fb727
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/__init__.cpython-310.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/__init__.cpython-312.pyc b/rxn/reaction/pix2seq/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb8260526bc1bb7d782f8a29a9c96092b1f186e7
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/__init__.cpython-312.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/__init__.cpython-38.pyc b/rxn/reaction/pix2seq/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e2854f9d383c839d19fb569cc9555dea8a2a8a3
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/__init__.cpython-38.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/attention_layer.cpython-310.pyc b/rxn/reaction/pix2seq/__pycache__/attention_layer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4542991a2e35304c1dba35d4366fb20e1494f87
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/attention_layer.cpython-310.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/attention_layer.cpython-38.pyc b/rxn/reaction/pix2seq/__pycache__/attention_layer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2cb68c2db1d8747c85ca126831f9fe5f9dde073
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/attention_layer.cpython-38.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/backbone.cpython-310.pyc b/rxn/reaction/pix2seq/__pycache__/backbone.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07d79eb214ef646e5462e462330067eabbd0485c
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/backbone.cpython-310.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/backbone.cpython-38.pyc b/rxn/reaction/pix2seq/__pycache__/backbone.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f491fc8c03cd2b7ca6332e3a14f4cb0a37b83d8
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/backbone.cpython-38.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/misc.cpython-310.pyc b/rxn/reaction/pix2seq/__pycache__/misc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b35ebae55b713aa658b68474a0d204dbbf157fc
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/misc.cpython-310.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/misc.cpython-312.pyc b/rxn/reaction/pix2seq/__pycache__/misc.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc88a0bf36e2633ab3e63101e5d283fb42ccb9a5
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/misc.cpython-312.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/misc.cpython-38.pyc b/rxn/reaction/pix2seq/__pycache__/misc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f1897975b9c70d2a167a7da361e8d34c5987704
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/misc.cpython-38.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/pix2seq.cpython-310.pyc b/rxn/reaction/pix2seq/__pycache__/pix2seq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e6bb35e4f11d42cab6a7446655048461b897476
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/pix2seq.cpython-310.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/pix2seq.cpython-312.pyc b/rxn/reaction/pix2seq/__pycache__/pix2seq.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a9e9b34b3c978e83a7a0037896dea917247fb95e
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/pix2seq.cpython-312.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/pix2seq.cpython-38.pyc b/rxn/reaction/pix2seq/__pycache__/pix2seq.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86d06daeee89ed3fcc7aea3960e8d8e5cae871a8
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/pix2seq.cpython-38.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/position_encoding.cpython-310.pyc b/rxn/reaction/pix2seq/__pycache__/position_encoding.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b8e04097eb31b4b07b8c95c000d9db8712c5b20
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/position_encoding.cpython-310.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/position_encoding.cpython-38.pyc b/rxn/reaction/pix2seq/__pycache__/position_encoding.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf0ed05e347023fbc9bccb12e9ae27562243f688
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/position_encoding.cpython-38.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/transformer.cpython-310.pyc b/rxn/reaction/pix2seq/__pycache__/transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..319b56432c7974d37b97d05768d4baf2c093237f
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/transformer.cpython-310.pyc differ
diff --git a/rxn/reaction/pix2seq/__pycache__/transformer.cpython-38.pyc b/rxn/reaction/pix2seq/__pycache__/transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65bb40d194126c25cf6efc23c5f88fcc1da6137a
Binary files /dev/null and b/rxn/reaction/pix2seq/__pycache__/transformer.cpython-38.pyc differ
diff --git a/rxn/reaction/pix2seq/attention_layer.py b/rxn/reaction/pix2seq/attention_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ddf98672787801d999645ced04887d0c13de31c
--- /dev/null
+++ b/rxn/reaction/pix2seq/attention_layer.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, dropout=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3)
+ self.attn_drop = nn.Dropout(dropout)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x, pre_kv=None, attn_mask=None):
+ N, B, C = x.shape
+ qkv = self.qkv(x).reshape(N, B, 3, self.num_heads, C // self.num_heads).permute(2, 1, 3, 0, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ if not self.training:
+ k = torch.cat([pre_kv[0], k], dim=2)
+ v = torch.cat([pre_kv[1], v], dim=2)
+ pre_kv = torch.stack([k, v], dim=0)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+
+ if attn_mask is not None:
+ attn.masked_fill_(attn_mask, float('-inf'))
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).permute(2, 0, 1, 3).reshape(N, B, C)
+ x = self.proj(x)
+ return x, pre_kv
diff --git a/rxn/reaction/pix2seq/backbone.py b/rxn/reaction/pix2seq/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..f489ff23abeac88a624b95506c18ed2ebd1221d4
--- /dev/null
+++ b/rxn/reaction/pix2seq/backbone.py
@@ -0,0 +1,119 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Backbone modules.
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
+
+from .misc import NestedTensor, is_main_process
+from .position_encoding import build_position_encoding
+
+
+class FrozenBatchNorm2d(torch.nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
+ without which any other models than torchvision.models.resnet[18,34,50,101]
+ produce nans.
+ """
+
+ def __init__(self, n):
+ super(FrozenBatchNorm2d, self).__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it fuser-friendly
+ w = self.weight.reshape(1, -1, 1, 1)
+ b = self.bias.reshape(1, -1, 1, 1)
+ rv = self.running_var.reshape(1, -1, 1, 1)
+ rm = self.running_mean.reshape(1, -1, 1, 1)
+ eps = 1e-5
+ scale = w * (rv + eps).rsqrt()
+ bias = b - rm * scale
+ return x * scale + bias
+
+
+class BackboneBase(nn.Module):
+
+ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
+ super().__init__()
+ for name, parameter in backbone.named_parameters():
+ if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+ parameter.requires_grad_(False)
+ if return_interm_layers:
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+ else:
+ return_layers = {'layer4': "0"}
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ self.num_channels = num_channels
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self.body(tensor_list.tensors)
+ out: Dict[str, NestedTensor] = {}
+ for name, x in xs.items():
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ out[name] = NestedTensor(x, mask)
+ return out
+
+
+class Backbone(BackboneBase):
+ """ResNet backbone with frozen BatchNorm."""
+ def __init__(self, name: str,
+ train_backbone: bool,
+ return_interm_layers: bool,
+ dilation: bool):
+ backbone = getattr(torchvision.models, name)(
+ replace_stride_with_dilation=[False, False, dilation],
+ pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
+ # weights="IMAGENET1K_V1"
+ num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
+ super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
+
+
+class Joiner(nn.Sequential):
+ def __init__(self, backbone, position_embedding):
+ super().__init__(backbone, position_embedding)
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self[0](tensor_list)
+ out: List[NestedTensor] = []
+ pos = []
+ for name, x in xs.items():
+ out.append(x)
+ # position encoding
+ pos.append(self[1](x).to(x.tensors.dtype))
+
+ return out, pos
+
+
+def build_backbone(args):
+ position_embedding = build_position_encoding(args)
+ train_backbone = True
+ return_interm_layers = False
+ backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
+ model = Joiner(backbone, position_embedding)
+ model.num_channels = backbone.num_channels
+ return model
diff --git a/rxn/reaction/pix2seq/misc.py b/rxn/reaction/pix2seq/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc430d31d4c91c29d4caa10452015364f79d1548
--- /dev/null
+++ b/rxn/reaction/pix2seq/misc.py
@@ -0,0 +1,604 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import os
+import subprocess
+import time
+from collections import defaultdict, deque
+import datetime
+import pickle
+from typing import Optional, List
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from bisect import bisect_right
+from torch.optim.lr_scheduler import _LRScheduler
+
+# needed due to empty tensor bug in pytorch and torchvision 0.5
+import torchvision
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to("cuda")
+
+ # obtain Tensor size of each rank
+ local_size = torch.tensor([tensor.numel()], device="cuda")
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+ if local_size != max_size:
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that all processes
+ have the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.all_reduce(values)
+ if average:
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}',
+ 'max mem: {memory:.0f}'
+ ])
+ else:
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ])
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
+ sha = 'N/A'
+ diff = "clean"
+ branch = 'N/A'
+ try:
+ sha = _run(['git', 'rev-parse', 'HEAD'])
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
+ diff = _run(['git', 'diff-index', 'HEAD'])
+ diff = "has uncommited changes" if diff else "clean"
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+def collate_fn(batch):
+ batch = list(zip(*batch))
+ if len(batch) > 2:
+ batch[0] = nested_tensor_from_tensor_list(batch[0] + batch[1], batch[2] + batch[3])
+ return tuple([batch[0], batch[2] + batch[3]])
+ else:
+ batch[0] = nested_tensor_from_tensor_list(batch[0], batch[1])
+ return tuple(batch)
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor], target_list=None):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ if target_list is not None:
+ for img, pad_img, m, target in zip(tensor_list, tensor, mask, target_list):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ size = target["size"]
+ m[:size[0], :size[1]] = False
+ else:
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], :img.shape[2]] = False
+ else:
+ raise ValueError('not supported')
+ return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_local_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return int(os.environ['LOCAL_SIZE'])
+
+
+def get_local_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return int(os.environ['LOCAL_RANK'])
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ args.dist_url = 'env://'
+ os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ addr = subprocess.getoutput(
+ 'scontrol show hostname {} | head -n1'.format(node_list))
+ os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['RANK'] = str(proc_id)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['LOCAL_SIZE'] = str(num_gpus)
+ args.dist_url = 'env://'
+ args.world_size = ntasks
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}'.format(
+ args.rank, args.dist_url), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ if target.numel() == 0:
+ return [torch.zeros([], device=output.device)]
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
+ """
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
+ This will eventually be supported natively by PyTorch, and this
+ class can go away.
+ """
+ if float(torchvision.__version__.split('+')[0][2:]) < 7.0:
+ if input.numel() > 0:
+ return torch.nn.functional.interpolate(
+ input, size, scale_factor, mode, align_corners
+ )
+
+ output_shape = _output_size(2, input, size, scale_factor)
+ output_shape = list(input.shape[:-2]) + list(output_shape)
+ return _new_empty_tensor(input, output_shape)
+ else:
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class NoScaler:
+ state_dict_key = "no_scaler"
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
+ loss.backward()
+ if clip_grad is not None and clip_grad > 0:
+ assert parameters is not None
+ torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ optimizer.step()
+
+
+class WarmupLinearDecayLR(_LRScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_factor: float = 0.001,
+ warmup_iters: int = 10,
+ warmup_method: str = "linear",
+ end_epoch: int = 300,
+ final_lr_factor: float = 0.003,
+ last_epoch: int = -1,
+ ):
+ """
+ Multi Step LR with warmup
+
+ Args:
+ optimizer (torch.optim.Optimizer): optimizer used.
+ warmup_factor (float): lr = warmup_factor * base_lr
+ warmup_iters (int): iters to warmup
+ warmup_method (str): warmup method in ["constant", "linear", "burnin"]
+ last_epoch(int): The index of last epoch. Default: -1.
+ """
+ self.warmup_factor = warmup_factor
+ self.warmup_iters = warmup_iters
+ self.warmup_method = warmup_method
+ self.end_epoch = end_epoch
+ assert 0 < final_lr_factor < 1
+ self.final_lr_factor = final_lr_factor
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self) -> List[float]:
+ warmup_factor = _get_warmup_factor_at_iter(
+ self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor)
+ linear_decay_factor = _get_lr_linear_decay_factor_at_iter(
+ self.last_epoch, self.warmup_iters, self.end_epoch, self.final_lr_factor)
+ return [
+ base_lr * warmup_factor * linear_decay_factor for base_lr in self.base_lrs
+ ]
+
+ def _get_closed_form_lr(self):
+ warmup_factor = _get_warmup_factor_at_iter(
+ self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor)
+ linear_decay_factor = _get_lr_linear_decay_factor_at_iter(
+ self.last_epoch, self.warmup_iters, self.end_epoch, self.final_lr_factor)
+ return [
+ base_lr * warmup_factor * linear_decay_factor for base_lr in self.base_lrs
+ ]
+
+
+def _get_lr_linear_decay_factor_at_iter(iter: int, start_epoch: int, end_epoch: int,
+ final_lr_factor: float):
+ assert iter <= end_epoch
+ if iter <= start_epoch:
+ return 1.0
+ alpha = (iter - start_epoch) / (end_epoch - start_epoch)
+ lr_step = final_lr_factor * alpha + 1 - alpha
+
+ return lr_step
+
+
+def _get_warmup_factor_at_iter(method: str, iter: int, warmup_iters: int,
+ warmup_factor: float) -> float:
+ """
+ Return the learning rate warmup factor at a specific iteration.
+ See https://arxiv.org/abs/1706.02677 for more details.
+
+ Args:
+ method (str): warmup method; either "constant" or "linear".
+ iter (int): iteration at which to calculate the warmup factor.
+ warmup_iters (int): the number of warmup iterations.
+ warmup_factor (float): the base warmup factor (the meaning changes according
+ to the method used).
+
+ Returns:
+ float: the effective warmup factor at the given iteration.
+ """
+ if iter >= warmup_iters:
+ return 1.0
+
+ if method == "constant":
+ return warmup_factor
+ elif method == "linear":
+ alpha = iter / warmup_iters
+ return warmup_factor * (1 - alpha) + alpha
+ elif method == "burnin":
+ return (iter / warmup_iters)**4
+ else:
+ raise ValueError("Unknown warmup method: {}".format(method))
diff --git a/rxn/reaction/pix2seq/pix2seq.py b/rxn/reaction/pix2seq/pix2seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eb424fe0891ede52bf34d9deb4f029968748950
--- /dev/null
+++ b/rxn/reaction/pix2seq/pix2seq.py
@@ -0,0 +1,79 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Pix2Seq model and criterion classes.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .misc import nested_tensor_from_tensor_list
+from .backbone import build_backbone
+from .transformer import build_transformer
+
+
+class Pix2Seq(nn.Module):
+ """ This is the Pix2Seq module that performs object detection """
+ def __init__(self, backbone, transformer):
+ """ Initializes the model.
+ Parameters:
+ backbone: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ num_classes: number of object classes
+ num_bins: number of bins for each side of the input image
+ """
+ super().__init__()
+ self.transformer = transformer
+ hidden_dim = transformer.d_model
+ self.input_proj = nn.Sequential(
+ nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=(1, 1)),
+ nn.GroupNorm(32, hidden_dim))
+ self.backbone = backbone
+
+ def forward(self, image_tensor, targets=None, max_len=500):
+ """
+ image_tensor:
+ The forward expects a NestedTensor, which consists of:
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+ It returns a dict with the following elements:
+ - "pred_logits": the classification logits (including no-object) for all vocabulary.
+ Shape= [batch_size, num_sequence, num_vocal]
+ """
+ if isinstance(image_tensor, (list, torch.Tensor)):
+ image_tensor = nested_tensor_from_tensor_list(image_tensor)
+ features, pos = self.backbone(image_tensor)
+
+ src, mask = features[-1].decompose()
+ assert mask is not None
+ mask = torch.zeros_like(mask).bool()
+
+ src = self.input_proj(src)
+ if targets is not None:
+ input_seq, input_len = targets
+ output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1])
+ return output_logits[:, :-1]
+ else:
+ output_seqs, output_scores = self.transformer(src, None, mask, pos[-1], max_len=max_len)
+ return output_seqs, output_scores
+
+
+def build_pix2seq_model(args, tokenizer):
+ # the `num_classes` naming here is somewhat misleading.
+ # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
+ # is the maximum id for a class in your dataset. For example,
+ # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
+ # As another example, for a dataset that has a single class with id 1,
+ # you should pass `num_classes` to be 2 (max_obj_id + 1).
+ # For more details on this, check the following discussion
+ # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
+
+ backbone = build_backbone(args)
+ transformer = build_transformer(args, tokenizer)
+
+ model = Pix2Seq(backbone, transformer)
+
+ if args.pix2seq_ckpt is not None:
+ checkpoint = torch.load(args.pix2seq_ckpt, map_location='cpu')
+ model.load_state_dict(checkpoint['model'])
+
+ return model
diff --git a/rxn/reaction/pix2seq/position_encoding.py b/rxn/reaction/pix2seq/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..189461a0f158e45bc2733a0c73b4ed1608ccd6f5
--- /dev/null
+++ b/rxn/reaction/pix2seq/position_encoding.py
@@ -0,0 +1,89 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+from .misc import NestedTensor
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = torch.ones_like(mask, dtype=torch.bool)
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+ def __init__(self, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(50, num_pos_feats)
+ self.col_embed = nn.Embedding(50, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = torch.cat([
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+ return pos
+
+
+def build_position_encoding(args):
+ N_steps = args.hidden_dim // 2
+ if args.position_embedding in ('v2', 'sine'):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSine(N_steps)
+ elif args.position_embedding in ('v3', 'learned'):
+ position_embedding = PositionEmbeddingLearned(N_steps)
+ else:
+ raise ValueError(f"not supported {args.position_embedding}")
+
+ return position_embedding
diff --git a/rxn/reaction/pix2seq/transformer.py b/rxn/reaction/pix2seq/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3c17eebd11a00795fe2d2b6e5bcd4e308bea12f
--- /dev/null
+++ b/rxn/reaction/pix2seq/transformer.py
@@ -0,0 +1,372 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Pix2Seq Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+ * positional encodings are passed in MHattention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import Optional, List
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+from .attention_layer import Attention
+
+
+class Transformer(nn.Module):
+
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
+ num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
+ activation="relu", normalize_before=False, num_vocal=2094,
+ pred_eos=False, tokenizer=None):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
+ self._reset_parameters()
+
+ self.num_vocal = num_vocal
+ self.vocal_classifier = nn.Linear(d_model, num_vocal)
+ self.det_embed = nn.Embedding(1, d_model)
+ self.vocal_embed = nn.Embedding(self.num_vocal - 2, d_model)
+ self.pred_eos = pred_eos
+
+ self.d_model = d_model
+ self.nhead = nhead
+ self.num_decoder_layers = num_decoder_layers
+ self.tokenizer = tokenizer
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, input_seq, mask, pos_embed, max_len=500):
+ """
+ Args:
+ src: shape[B, C, H, W]
+ input_seq: shape[B, 501, C] for training and shape[B, 1, C] for inference
+ mask: shape[B, H, W]
+ pos_embed: shape[B, C, H, W]
+ """
+ # flatten NxCxHxW to HWxNxC
+ bs = src.shape[0]
+ src = src.flatten(2).permute(2, 0, 1)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+ pre_kv = [torch.as_tensor([[], []], device=memory.device)
+ for _ in range(self.num_decoder_layers)]
+
+ if self.training:
+ input_seq = input_seq.clamp(max=self.num_vocal - 3)
+ input_embed = torch.cat(
+ [self.det_embed.weight.unsqueeze(0).repeat(bs, 1, 1),
+ self.vocal_embed(input_seq)], dim=1)
+ input_embed = input_embed.transpose(0, 1)
+ num_seq = input_embed.shape[0]
+ self_attn_mask = torch.triu(torch.ones((num_seq, num_seq)), diagonal=1).bool().to(input_embed.device)
+ hs, pre_kv = self.decoder(
+ input_embed,
+ memory,
+ memory_key_padding_mask=mask,
+ pos=pos_embed,
+ pre_kv_list=pre_kv,
+ self_attn_mask=self_attn_mask)
+ # hs: N x B x D
+ pred_seq_logits = self.vocal_classifier(hs.transpose(0, 1))
+ return pred_seq_logits
+ else:
+ end = torch.zeros(bs).bool().to(memory.device)
+ end_lens = torch.zeros(bs).long().to(memory.device)
+ input_embed = self.det_embed.weight.unsqueeze(0).repeat(bs, 1, 1).transpose(0, 1)
+ states, pred_token = [None] * bs, [None] * bs
+ pred_seq, pred_scores = [], []
+ for seq_i in range(max_len):
+ hs, pre_kv = self.decoder(
+ input_embed,
+ memory,
+ memory_key_padding_mask=mask,
+ pos=pos_embed,
+ pre_kv_list=pre_kv)
+ # hs: N x B x D
+ logits = self.vocal_classifier(hs.transpose(0, 1))
+ log_probs = F.log_softmax(logits, dim=-1)
+ if self.tokenizer.output_constraint:
+ states, output_masks = self.tokenizer.update_states_and_masks(states, pred_token)
+ output_masks = torch.tensor(output_masks, device=logits.device).unsqueeze(1)
+ log_probs.masked_fill_(output_masks, -10000)
+ if not self.pred_eos:
+ log_probs[:, :, self.tokenizer.EOS_ID] = -10000
+
+ score, pred_token = log_probs.max(dim=-1)
+ pred_seq.append(pred_token)
+ pred_scores.append(score)
+
+ if self.pred_eos:
+ stop_state = pred_token.squeeze(1).eq(self.tokenizer.EOS_ID)
+ end_lens += seq_i * (~end * stop_state)
+ end = (stop_state + end).bool()
+ if end.all() and seq_i > 4:
+ break
+
+ token = log_probs[:, :, :self.num_vocal - 2].argmax(dim=-1)
+ input_embed = self.vocal_embed(token.transpose(0, 1))
+
+ if not self.pred_eos:
+ end_lens = end_lens.fill_(max_len)
+ pred_seq = torch.cat(pred_seq, dim=1)
+ pred_seq = [seq[:end_idx] for end_idx, seq in zip(end_lens, pred_seq)]
+ pred_scores = torch.cat(pred_scores, dim=1)
+ pred_scores = [scores[:end_idx] for end_idx, scores in zip(end_lens, pred_scores)]
+ return pred_seq, pred_scores
+
+
+class TransformerEncoder(nn.Module):
+
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self, src,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ output = src
+
+ for layer in self.layers:
+ output = layer(output, src_key_padding_mask=src_key_padding_mask, pos=pos)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+
+ def __init__(self, decoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self, tgt, memory, memory_key_padding_mask, pos, pre_kv_list=None, self_attn_mask=None):
+ output = tgt
+ cur_kv_list = []
+ for layer, pre_kv in zip(self.layers, pre_kv_list):
+ output, cur_kv = layer(
+ output,
+ memory,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ self_attn_mask=self_attn_mask,
+ pre_kv=pre_kv)
+ cur_kv_list.append(cur_kv)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output, cur_kv_list
+
+
+class TransformerEncoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self,
+ src,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(q, k, value=src, key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(self, src,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(q, k, value=src2, key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(self, src,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ if self.normalize_before:
+ return self.forward_pre(src, src_key_padding_mask, pos)
+ return self.forward_post(src, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = Attention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ tgt,
+ memory,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ self_attn_mask: Optional[Tensor] = None,
+ pre_kv=None,
+ ):
+ tgt2, pre_kv = self.self_attn(tgt, pre_kv=pre_kv, attn_mask=self_attn_mask)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ query=tgt,
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt, pre_kv
+
+ def forward_pre(
+ self,
+ tgt,
+ memory,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ self_attn_mask: Optional[Tensor] = None,
+ pre_kv=None,
+ ):
+ tgt2 = self.norm1(tgt)
+ tgt2, pre_kv = self.self_attn(tgt2, pre_kv=pre_kv, attn_mask=self_attn_mask)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(
+ query=tgt2,
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt, pre_kv
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ self_attn_mask: Optional[Tensor] = None,
+ pre_kv=None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(tgt, memory, memory_key_padding_mask, pos, self_attn_mask, pre_kv)
+ return self.forward_post(tgt, memory, memory_key_padding_mask, pos, self_attn_mask, pre_kv)
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def build_transformer(args, tokenizer):
+ num_vocal = len(tokenizer)
+ return Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ num_vocal=num_vocal,
+ pred_eos=args.pred_eos,
+ tokenizer=tokenizer
+ )
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
diff --git a/rxn/reaction/tokenizer.py b/rxn/reaction/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..072c2322dc0d82ba82ef37b52ed2d7dfc3f65bfe
--- /dev/null
+++ b/rxn/reaction/tokenizer.py
@@ -0,0 +1,468 @@
+import json
+import copy
+import random
+import numpy as np
+
+
+PAD = ''
+SOS = ''
+EOS = ''
+UNK = ''
+MASK = ''
+
+Rxn = '[Rxn]' # Reaction
+Rct = '[Rct]' # Reactant
+Prd = '[Prd]' # Product
+Cnd = '[Cnd]' # Condition
+Idt = '[Idt]' # Identifier
+Mol = '[Mol]' # Molecule
+Txt = '[Txt]' # Text
+Sup = '[Sup]' # Supplement
+Noise = '[Nos]'
+
+
+class ReactionTokenizer(object):
+
+ def __init__(self, input_size=100, sep_xy=True, pix2seq=False):
+ self.stoi = {}
+ self.itos = {}
+ self.pix2seq = pix2seq
+ self.maxx = input_size # height
+ self.maxy = input_size # width
+ self.sep_xy = sep_xy
+ self.special_tokens = [PAD, SOS, EOS, UNK, MASK]
+ self.tokens = [Rxn, Rct, Prd, Cnd, Idt, Mol, Txt, Sup, Noise]
+ self.fit_tokens(self.tokens)
+
+ def __len__(self):
+ if self.pix2seq:
+ return 2094
+ if self.sep_xy:
+ return self.offset + self.maxx + self.maxy
+ else:
+ return self.offset + max(self.maxx, self.maxy)
+
+ @property
+ def max_len(self):
+ return 256
+
+ @property
+ def PAD_ID(self):
+ return self.stoi[PAD]
+
+ @property
+ def SOS_ID(self):
+ return self.stoi[SOS]
+
+ @property
+ def EOS_ID(self):
+ return self.stoi[EOS]
+
+ @property
+ def UNK_ID(self):
+ return self.stoi[UNK]
+
+ @property
+ def NOISE_ID(self):
+ return self.stoi[Noise]
+
+ @property
+ def offset(self):
+ return 0 if self.pix2seq else len(self.stoi)
+
+ @property
+ def output_constraint(self):
+ return True
+
+ def fit_tokens(self, tokens):
+ vocab = self.special_tokens + tokens
+ if self.pix2seq:
+ for i, s in enumerate(vocab):
+ self.stoi[s] = 2001 + i
+ self.stoi[EOS] = len(self) - 2
+ # self.stoi[Noise] = len(self) - 1
+ else:
+ for i, s in enumerate(vocab):
+ self.stoi[s] = i
+ self.itos = {item[1]: item[0] for item in self.stoi.items()}
+ self.bbox_category_to_token = {1: Mol, 2: Txt, 3: Idt, 4: Sup}
+ self.token_to_bbox_category = {item[1]: item[0] for item in self.bbox_category_to_token.items()}
+
+ def is_x(self, x):
+ return 0 <= x - self.offset < self.maxx
+
+ def is_y(self, y):
+ if self.sep_xy:
+ return self.maxx <= y - self.offset < self.maxx + self.maxy
+ return 0 <= y - self.offset < self.maxy
+
+ def x_to_id(self, x):
+ if x < -0.001 or x > 1.001:
+ print(x)
+ else:
+ x = min(max(x, 0), 1)
+ assert 0 <= x <= 1
+ return self.offset + round(x * (self.maxx - 1))
+
+ def y_to_id(self, y):
+ if y < -0.001 or y > 1.001:
+ print(y)
+ else:
+ y = min(max(y, 0), 1)
+ assert 0 <= y <= 1
+ if self.sep_xy:
+ return self.offset + self.maxx + round(y * (self.maxy - 1))
+ return self.offset + round(y * (self.maxy - 1))
+
+ def id_to_x(self, id, scale=1):
+ if not self.is_x(id):
+ return -1
+ return (id - self.offset) / (self.maxx - 1) / scale
+
+ def id_to_y(self, id, scale=1):
+ if not self.is_y(id):
+ return -1
+ if self.sep_xy:
+ return (id - self.offset - self.maxx) / (self.maxy - 1) * scale
+ return (id - self.offset) / (self.maxy - 1) / scale
+
+ def update_state(self, state, idx):
+ if state is None:
+ new_state = (Rxn, 'e')
+ else:
+ if state[1] == 'x1':
+ new_state = (state[0], 'y1')
+ elif state[1] == 'y1':
+ new_state = (state[0], 'x2')
+ elif state[1] == 'x2':
+ new_state = (state[0], 'y2')
+ elif state[1] == 'y2':
+ new_state = (state[0], 'c')
+ elif state[1] == 'c':
+ if self.is_x(idx):
+ new_state = (state[0], 'x1')
+ else:
+ new_state = (state[0], 'e')
+ else:
+ if state[0] == Rct:
+ if self.is_x(idx):
+ new_state = (Cnd, 'x1')
+ else:
+ new_state = (Cnd, 'e')
+ elif state[0] == Cnd:
+ new_state = (Prd, 'x1')
+ elif state[0] == Prd:
+ new_state = (Rxn, 'e')
+ elif state[0] == Rxn:
+ if self.is_x(idx):
+ new_state = (Rct, 'x1')
+ else:
+ new_state = (EOS, 'e')
+ else:
+ new_state = (EOS, 'e')
+ return new_state
+
+ def output_mask(self, state):
+ # mask: True means forbidden
+ mask = np.array([True] * len(self))
+ if state[1] in ['y1', 'c']:
+ mask[self.offset:self.offset+self.maxx] = False
+ if state[1] in ['x1', 'x2']:
+ if self.sep_xy:
+ mask[self.offset+self.maxx:self.offset+self.maxx+self.maxy] = False
+ else:
+ mask[self.offset:self.offset+self.maxy] = False
+ if state[1] == 'y2':
+ for token in [Idt, Mol, Txt, Sup]:
+ mask[self.stoi[token]] = False
+ if state[1] == 'c':
+ mask[self.stoi[state[0]]] = False
+ if state[1] == 'e':
+ if state[0] in [Rct, Cnd, Rxn]:
+ mask[self.offset:self.offset + self.maxx] = False
+ if state[0] == Rct:
+ mask[self.stoi[Cnd]] = False
+ if state[0] == Prd:
+ mask[self.stoi[Rxn]] = False
+ mask[self.stoi[Noise]] = False
+ if state[0] in [Rxn, EOS]:
+ mask[self.EOS_ID] = False
+ return mask
+
+ def update_states_and_masks(self, states, ids):
+ new_states = [self.update_state(state, idx) for state, idx in zip(states, ids)]
+ masks = np.array([self.output_mask(state) for state in new_states])
+ return new_states, masks
+
+ def bbox_to_sequence(self, bbox, category):
+ sequence = []
+ x1, y1, x2, y2 = bbox
+ if x1 >= x2 or y1 >= y2:
+ return []
+ sequence.append(self.x_to_id(x1))
+ sequence.append(self.y_to_id(y1))
+ sequence.append(self.x_to_id(x2))
+ sequence.append(self.y_to_id(y2))
+ if category in self.bbox_category_to_token:
+ sequence.append(self.stoi[self.bbox_category_to_token[category]])
+ else:
+ sequence.append(self.stoi[Noise])
+ return sequence
+
+ def sequence_to_bbox(self, sequence, scale=[1, 1]):
+ if len(sequence) < 5:
+ return None
+ x1, y1 = self.id_to_x(sequence[0], scale[0]), self.id_to_y(sequence[1], scale[1])
+ x2, y2 = self.id_to_x(sequence[2], scale[0]), self.id_to_y(sequence[3], scale[1])
+ if x1 == -1 or y1 == -1 or x2 == -1 or y2 == -1 or x1 >= x2 or y1 >= y2 or sequence[4] not in self.itos:
+ return None
+ category = self.itos[sequence[4]]
+ if category not in [Mol, Txt, Idt, Sup]:
+ return None
+ return {'category': category, 'bbox': (x1, y1, x2, y2), 'category_id': self.token_to_bbox_category[category]}
+
+ def perturb_reaction(self, reaction, boxes):
+ reaction = copy.deepcopy(reaction)
+ options = []
+ options.append(0) # Option 0: add
+ if not(len(reaction['reactants']) == 1 and len(reaction['conditions']) == 0 and len(reaction['products']) == 1):
+ options.append(1) # Option 1: delete
+ options.append(2) # Option 2: move
+ choice = random.choice(options)
+ if choice == 0:
+ key = random.choice(['reactants', 'conditions', 'products'])
+ # TODO: insert to a random position
+ # We simply add a random box, which may be a duplicate box in this reaction
+ reaction[key].append(random.randrange(len(boxes)))
+ if choice == 1 or choice == 2:
+ options = []
+ for key, val in [('reactants', 1), ('conditions', 0), ('products', 1)]:
+ if len(reaction[key]) > val:
+ options.append(key)
+ key = random.choice(options)
+ idx = random.randrange(len(reaction[key]))
+ del_box = reaction[key][idx]
+ reaction[key] = reaction[key][:idx] + reaction[key][idx+1:]
+ if choice == 2:
+ options = ['reactants', 'conditions', 'products']
+ options.remove(key)
+ newkey = random.choice(options)
+ reaction[newkey].append(del_box)
+ return reaction
+
+ def augment_reaction(self, reactions, data):
+ area, boxes, labels = data['area'], data['boxes'], data['labels']
+ nonempty_boxes = [i for i in range(len(area)) if area[i] > 0]
+ if len(nonempty_boxes) == 0:
+ return None
+ if len(reactions) == 0 or random.randrange(100) < 20:
+ num_reactants = random.randint(1, 3)
+ num_conditions = random.randint(0, 3)
+ num_products = random.randint(1, 3)
+ reaction = {
+ 'reactants': random.choices(nonempty_boxes, k=num_reactants),
+ 'conditions': random.choices(nonempty_boxes, k=num_conditions),
+ 'products': random.choices(nonempty_boxes, k=num_products)
+ }
+ else:
+ assert len(reactions) > 0
+ reaction = self.perturb_reaction(random.choice(reactions), boxes)
+ return reaction
+
+ def reaction_to_sequence(self, reaction, data, shuffle_bbox=False):
+ reaction = copy.deepcopy(reaction)
+ area, boxes, labels = data['area'], data['boxes'], data['labels']
+ # If reactants or products are empty (because of image cropping), skip the reaction
+ if all([area[i] == 0 for i in reaction['reactants']]) or all([area[i] == 0 for i in reaction['products']]):
+ return []
+ if shuffle_bbox:
+ random.shuffle(reaction['reactants'])
+ random.shuffle(reaction['conditions'])
+ random.shuffle(reaction['products'])
+ sequence = []
+ for idx in reaction['reactants']:
+ if area[idx] == 0:
+ continue
+ sequence += self.bbox_to_sequence(boxes[idx].tolist(), labels[idx].item())
+ sequence.append(self.stoi[Rct])
+ for idx in reaction['conditions']:
+ if area[idx] == 0:
+ continue
+ sequence += self.bbox_to_sequence(boxes[idx].tolist(), labels[idx].item())
+ sequence.append(self.stoi[Cnd])
+ for idx in reaction['products']:
+ if area[idx] == 0:
+ continue
+ sequence += self.bbox_to_sequence(boxes[idx].tolist(), labels[idx].item())
+ sequence.append(self.stoi[Prd])
+ sequence.append(self.stoi[Rxn])
+ return sequence
+
+ def data_to_sequence(self, data, rand_order=False, shuffle_bbox=False, add_noise=False, mix_noise=False):
+ sequence = [self.SOS_ID]
+ sequence_out = [self.SOS_ID]
+ reactions = copy.deepcopy(data['reactions'])
+ reactions_seqs = []
+ for reaction in reactions:
+ seq = self.reaction_to_sequence(reaction, data, shuffle_bbox=shuffle_bbox)
+ reactions_seqs.append([seq, seq])
+ noise_seqs = []
+ if add_noise:
+ total_len = sum(len(seq) for seq, seq_out in reactions_seqs)
+ while total_len < self.max_len:
+ reaction = self.augment_reaction(reactions, data)
+ if reaction is None:
+ break
+ seq = self.reaction_to_sequence(reaction, data)
+ if len(seq) == 0:
+ continue
+ if mix_noise:
+ seq[-1] = self.NOISE_ID
+ seq_out = [self.PAD_ID] * (len(seq) - 1) + [self.NOISE_ID]
+ else:
+ seq_out = [self.PAD_ID] * (len(seq) - 1) + [self.NOISE_ID]
+ noise_seqs.append([seq, seq_out])
+ total_len += len(seq)
+ if rand_order:
+ random.shuffle(reactions_seqs)
+ reactions_seqs += noise_seqs
+ if mix_noise:
+ random.shuffle(reactions_seqs)
+ for seq, seq_out in reactions_seqs:
+ sequence += seq
+ sequence_out += seq_out
+ sequence.append(self.EOS_ID)
+ sequence_out.append(self.EOS_ID)
+ return sequence, sequence_out
+
+ def sequence_to_data(self, sequence, scores=None, scale=None):
+ reactions = []
+ i = 0
+ cur_reaction = {'reactants': [], 'conditions': [], 'products': []}
+ flag = 'reactants'
+ if len(sequence) > 0 and sequence[0] == self.SOS_ID:
+ i += 1
+ while i < len(sequence):
+ if sequence[i] == self.EOS_ID:
+ break
+ if sequence[i] in self.itos:
+ if self.itos[sequence[i]] in [Rxn, Noise]:
+ cur_reaction['label'] = self.itos[sequence[i]]
+ if len(cur_reaction['reactants']) > 0 and len(cur_reaction['products']) > 0:
+ reactions.append(cur_reaction)
+ cur_reaction = {'reactants': [], 'conditions': [], 'products': []}
+ flag = 'reactants'
+ elif self.itos[sequence[i]] == Rct:
+ flag = 'conditions'
+ elif self.itos[sequence[i]] == Cnd:
+ flag = 'products'
+ elif self.itos[sequence[i]] == Prd:
+ flag = None
+ elif i+5 <= len(sequence) and flag is not None:
+ bbox = self.sequence_to_bbox(sequence[i:i+5], scale)
+ if bbox is not None:
+ cur_reaction[flag].append(bbox)
+ i += 4
+ i += 1
+ return reactions
+
+ def sequence_to_tokens(self, sequence):
+ return [self.itos[x] if x in self.itos else x for x in sequence]
+
+
+class BboxTokenizer(ReactionTokenizer):
+
+ def __init__(self, input_size=100, sep_xy=True, pix2seq=False):
+ super(BboxTokenizer, self).__init__(input_size, sep_xy, pix2seq)
+
+ @property
+ def max_len(self):
+ return 500
+
+ @property
+ def output_constraint(self):
+ return False
+
+ def random_category(self):
+ return random.choice(list(self.bbox_category_to_token.keys()))
+ # return random.choice([random.choice(list(self.bbox_category_to_token.keys())), self.NOISE_ID])
+
+ def random_bbox(self):
+ _x1, _y1, _x2, _y2 = random.random(), random.random(), random.random(), random.random()
+ x1, y1, x2, y2 = min(_x1, _x2), min(_y1, _y2), max(_x1, _x2), max(_y1, _y2)
+ category = self.random_category()
+ return [x1, y1, x2, y2], category
+
+ def jitter_bbox(self, bbox, ratio=0.2):
+ x1, y1, x2, y2 = bbox
+ w, h = x2 - x1, y2 - y1
+ _x1 = x1 + random.uniform(-w*ratio, w*ratio)
+ _y1 = y1 + random.uniform(-h*ratio, h*ratio)
+ _x2 = x2 + random.uniform(-w * ratio, w * ratio)
+ _y2 = y2 + random.uniform(-h * ratio, h * ratio)
+ x1, y1, x2, y2 = min(_x1, _x2), min(_y1, _y2), max(_x1, _x2), max(_y1, _y2)
+ category = self.random_category()
+ return np.clip([x1, y1, x2, y2], 0, 1), category
+
+ def augment_box(self, bboxes):
+ if len(bboxes) == 0:
+ return self.random_bbox()
+ if random.random() < 0.5:
+ return self.random_bbox()
+ else:
+ return self.jitter_bbox(random.choice(bboxes))
+
+ def data_to_sequence(self, data, add_noise=False, rand_order=False):
+ sequence = [self.SOS_ID]
+ sequence_out = [self.SOS_ID]
+ if rand_order:
+ perm = np.random.permutation(len(data['boxes']))
+ boxes = data['boxes'][perm].tolist()
+ labels = data['labels'][perm].tolist()
+ else:
+ boxes = data['boxes'].tolist()
+ labels = data['labels'].tolist()
+ for bbox, category in zip(boxes, labels):
+ seq = self.bbox_to_sequence(bbox, category)
+ sequence += seq
+ # sequence[-1] = self.random_category()
+ sequence_out += seq
+ if add_noise:
+ while len(sequence) < self.max_len:
+ bbox, category = self.augment_box(boxes)
+ sequence += self.bbox_to_sequence(bbox, category)
+ sequence_out += [self.PAD_ID] * 4 + [self.NOISE_ID]
+ sequence.append(self.EOS_ID)
+ sequence_out.append(self.EOS_ID)
+ return sequence, sequence_out
+
+ def sequence_to_data(self, sequence, scores=None, scale=None):
+ bboxes = []
+ i = 0
+ if len(sequence) > 0 and sequence[0] == self.SOS_ID:
+ i += 1
+ while i < len(sequence):
+ if sequence[i] == self.EOS_ID:
+ break
+ if i+4 < len(sequence):
+ bbox = self.sequence_to_bbox(sequence[i:i+5], scale)
+ if bbox is not None:
+ if scores is not None:
+ bbox['score'] = scores[i + 4]
+ bboxes.append(bbox)
+ i += 4
+ i += 1
+ return bboxes
+
+
+def get_tokenizer(args):
+ tokenizer = {}
+ if args.pix2seq:
+ args.coord_bins = 2000
+ args.sep_xy = False
+ format = args.format
+ if format == 'reaction':
+ tokenizer[format] = ReactionTokenizer(args.coord_bins, args.sep_xy, args.pix2seq)
+ if format == 'bbox':
+ tokenizer[format] = BboxTokenizer(args.coord_bins, args.sep_xy, args.pix2seq)
+ return tokenizer
diff --git a/rxn/reaction/transformer/__init__.py b/rxn/reaction/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9c953b157c20f4c8bd13345f0a846fd70e0815e
--- /dev/null
+++ b/rxn/reaction/transformer/__init__.py
@@ -0,0 +1,3 @@
+from .decoder import TransformerDecoder
+from .embedding import Embeddings
+from .swin_transformer import swin_base, swin_large
diff --git a/rxn/reaction/transformer/decoder.py b/rxn/reaction/transformer/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a04a96aa4c472fd00d3e8a9d470bd62d380b3202
--- /dev/null
+++ b/rxn/reaction/transformer/decoder.py
@@ -0,0 +1,487 @@
+"""
+Implementation of "Attention is All You Need" and of
+subsequent transformer based architectures
+"""
+
+import torch
+import torch.nn as nn
+
+from onmt.decoders.decoder import DecoderBase
+from onmt.modules import MultiHeadedAttention, AverageAttention
+from onmt.modules.position_ffn import PositionwiseFeedForward
+from onmt.modules.position_ffn import ActivationFunction
+from onmt.utils.misc import sequence_mask
+
+
+class TransformerDecoderLayerBase(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type="scaled-dot",
+ max_relative_positions=0,
+ aan_useffn=False,
+ full_context_alignment=False,
+ alignment_heads=0,
+ pos_ffn_activation_fn=ActivationFunction.relu,
+ ):
+ """
+ Args:
+ d_model (int): the dimension of keys/values/queries in
+ :class:`MultiHeadedAttention`, also the input size of
+ the first-layer of the :class:`PositionwiseFeedForward`.
+ heads (int): the number of heads for MultiHeadedAttention.
+ d_ff (int): the second-layer of the
+ :class:`PositionwiseFeedForward`.
+ dropout (float): dropout in residual, self-attn(dot) and
+ feed-forward
+ attention_dropout (float): dropout in context_attn (and
+ self-attn(avg))
+ self_attn_type (string): type of self-attention scaled-dot,
+ average
+ max_relative_positions (int):
+ Max distance between inputs in relative positions
+ representations
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+ full_context_alignment (bool):
+ whether enable an extra full context decoder forward for
+ alignment
+ alignment_heads (int):
+ N. of cross attention heads to use for alignment guiding
+ pos_ffn_activation_fn (ActivationFunction):
+ activation function choice for PositionwiseFeedForward layer
+
+ """
+ super(TransformerDecoderLayerBase, self).__init__()
+
+ if self_attn_type == "scaled-dot":
+ self.self_attn = MultiHeadedAttention(
+ heads,
+ d_model,
+ dropout=attention_dropout,
+ max_relative_positions=max_relative_positions,
+ )
+ elif self_attn_type == "average":
+ self.self_attn = AverageAttention(
+ d_model, dropout=attention_dropout, aan_useffn=aan_useffn
+ )
+
+ self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout,
+ pos_ffn_activation_fn
+ )
+ self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
+ self.drop = nn.Dropout(dropout)
+ self.full_context_alignment = full_context_alignment
+ self.alignment_heads = alignment_heads
+
+ def forward(self, *args, **kwargs):
+ """Extend `_forward` for (possibly) multiple decoder pass:
+ Always a default (future masked) decoder forward pass,
+ Possibly a second future aware decoder pass for joint learn
+ full context alignement, :cite:`garg2019jointly`.
+
+ Args:
+ * All arguments of _forward.
+ with_align (bool): whether return alignment attention.
+
+ Returns:
+ (FloatTensor, FloatTensor, FloatTensor or None):
+
+ * output ``(batch_size, T, model_dim)``
+ * top_attn ``(batch_size, T, src_len)``
+ * attn_align ``(batch_size, T, src_len)`` or None
+ """
+ with_align = kwargs.pop("with_align", False)
+ output, attns = self._forward(*args, **kwargs)
+ top_attn = attns[:, 0, :, :].contiguous()
+ attn_align = None
+ if with_align:
+ if self.full_context_alignment:
+ # return _, (B, Q_len, K_len)
+ _, attns = self._forward(*args, **kwargs, future=True)
+
+ if self.alignment_heads > 0:
+ attns = attns[:, : self.alignment_heads, :, :].contiguous()
+ # layer average attention across heads, get ``(B, Q, K)``
+ # Case 1: no full_context, no align heads -> layer avg baseline
+ # Case 2: no full_context, 1 align heads -> guided align
+ # Case 3: full_context, 1 align heads -> full cte guided align
+ attn_align = attns.mean(dim=1)
+ return output, top_attn, attn_align
+
+ def update_dropout(self, dropout, attention_dropout):
+ self.self_attn.update_dropout(attention_dropout)
+ self.feed_forward.update_dropout(dropout)
+ self.drop.p = dropout
+
+ def _forward(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _compute_dec_mask(self, tgt_pad_mask, future):
+ tgt_len = tgt_pad_mask.size(-1)
+ if not future: # apply future_mask, result mask in (B, T, T)
+ future_mask = torch.ones(
+ [tgt_len, tgt_len],
+ device=tgt_pad_mask.device,
+ dtype=torch.uint8,
+ )
+ future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
+ # BoolTensor was introduced in pytorch 1.2
+ try:
+ future_mask = future_mask.bool()
+ except AttributeError:
+ pass
+ dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
+ else: # only mask padding, result mask in (B, 1, T)
+ dec_mask = tgt_pad_mask
+ return dec_mask
+
+ def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
+ if isinstance(self.self_attn, MultiHeadedAttention):
+ return self.self_attn(
+ inputs_norm,
+ inputs_norm,
+ inputs_norm,
+ mask=dec_mask,
+ layer_cache=layer_cache,
+ attn_type="self",
+ )
+ elif isinstance(self.self_attn, AverageAttention):
+ return self.self_attn(
+ inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
+ )
+ else:
+ raise ValueError(
+ f"self attention {type(self.self_attn)} not supported"
+ )
+
+
+class TransformerDecoderLayer(TransformerDecoderLayerBase):
+ """Transformer Decoder layer block in Pre-Norm style.
+ Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style,
+ providing better converge speed and performance. This is also the actual
+ implementation in tensor2tensor and also avalable in fairseq.
+ See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
+
+ .. mermaid::
+
+ graph LR
+ %% "*SubLayer" can be self-attn, src-attn or feed forward block
+ A(input) --> B[Norm]
+ B --> C["*SubLayer"]
+ C --> D[Drop]
+ D --> E((+))
+ A --> E
+ E --> F(out)
+
+ """
+
+ def __init__(
+ self,
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type="scaled-dot",
+ max_relative_positions=0,
+ aan_useffn=False,
+ full_context_alignment=False,
+ alignment_heads=0,
+ pos_ffn_activation_fn=ActivationFunction.relu,
+ ):
+ """
+ Args:
+ See TransformerDecoderLayerBase
+ """
+ super(TransformerDecoderLayer, self).__init__(
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type,
+ max_relative_positions,
+ aan_useffn,
+ full_context_alignment,
+ alignment_heads,
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
+ )
+ self.context_attn = MultiHeadedAttention(
+ heads, d_model, dropout=attention_dropout
+ )
+ self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
+
+ def update_dropout(self, dropout, attention_dropout):
+ super(TransformerDecoderLayer, self).update_dropout(
+ dropout, attention_dropout
+ )
+ self.context_attn.update_dropout(attention_dropout)
+
+ def _forward(
+ self,
+ inputs,
+ memory_bank,
+ src_pad_mask,
+ tgt_pad_mask,
+ layer_cache=None,
+ step=None,
+ future=False,
+ ):
+ """A naive forward pass for transformer decoder.
+
+ # T: could be 1 in the case of stepwise decoding or tgt_len
+
+ Args:
+ inputs (FloatTensor): ``(batch_size, T, model_dim)``
+ memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
+ src_pad_mask (bool): ``(batch_size, 1, src_len)``
+ tgt_pad_mask (bool): ``(batch_size, 1, T)``
+ layer_cache (dict or None): cached layer info when stepwise decode
+ step (int or None): stepwise decoding counter
+ future (bool): If set True, do not apply future_mask.
+
+ Returns:
+ (FloatTensor, FloatTensor):
+
+ * output ``(batch_size, T, model_dim)``
+ * attns ``(batch_size, head, T, src_len)``
+
+ """
+ dec_mask = None
+
+ if inputs.size(1) > 1:
+ # masking is necessary when sequence length is greater than one
+ dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
+
+ inputs_norm = self.layer_norm_1(inputs)
+
+ query, _ = self._forward_self_attn(
+ inputs_norm, dec_mask, layer_cache, step
+ )
+
+ query = self.drop(query) + inputs
+
+ query_norm = self.layer_norm_2(query)
+ mid, attns = self.context_attn(
+ memory_bank,
+ memory_bank,
+ query_norm,
+ mask=src_pad_mask,
+ layer_cache=layer_cache,
+ attn_type="context",
+ )
+ output = self.feed_forward(self.drop(mid) + query)
+
+ return output, attns
+
+
+class TransformerDecoderBase(DecoderBase):
+ def __init__(self, d_model, copy_attn, alignment_layer):
+ super(TransformerDecoderBase, self).__init__()
+
+ # Decoder State
+ self.state = {}
+
+ # previously, there was a GlobalAttention module here for copy
+ # attention. But it was never actually used -- the "copy" attention
+ # just reuses the context attention.
+ self._copy = copy_attn
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ self.alignment_layer = alignment_layer
+
+ @classmethod
+ def from_opt(cls, opt, embeddings):
+ """Alternate constructor."""
+ return cls(
+ opt.dec_layers,
+ opt.dec_rnn_size,
+ opt.heads,
+ opt.transformer_ff,
+ opt.copy_attn,
+ opt.self_attn_type,
+ opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
+ opt.attention_dropout[0] if type(opt.attention_dropout) is list else opt.attention_dropout,
+ embeddings,
+ opt.max_relative_positions,
+ opt.aan_useffn,
+ opt.full_context_alignment,
+ opt.alignment_layer,
+ alignment_heads=opt.alignment_heads,
+ pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+ )
+
+ def init_state(self, src, memory_bank, enc_hidden):
+ """Initialize decoder state."""
+ self.state["src"] = src
+ self.state["cache"] = None
+
+ def map_state(self, fn):
+ def _recursive_map(struct, batch_dim=0):
+ for k, v in struct.items():
+ if v is not None:
+ if isinstance(v, dict):
+ _recursive_map(v)
+ else:
+ struct[k] = fn(v, batch_dim)
+
+ if self.state["src"] is not None:
+ self.state["src"] = fn(self.state["src"], 1)
+ if self.state["cache"] is not None:
+ _recursive_map(self.state["cache"])
+
+ def detach_state(self):
+ raise NotImplementedError
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def update_dropout(self, dropout, attention_dropout):
+ self.embeddings.update_dropout(dropout)
+ for layer in self.transformer_layers:
+ layer.update_dropout(dropout, attention_dropout)
+
+
+class TransformerDecoder(TransformerDecoderBase):
+ """The Transformer decoder from "Attention is All You Need".
+ :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
+
+ .. mermaid::
+
+ graph BT
+ A[input]
+ B[multi-head self-attn]
+ BB[multi-head src-attn]
+ C[feed forward]
+ O[output]
+ A --> B
+ B --> BB
+ BB --> C
+ C --> O
+
+
+ Args:
+ num_layers (int): number of decoder layers.
+ d_model (int): size of the model
+ heads (int): number of heads
+ d_ff (int): size of the inner FF layer
+ copy_attn (bool): if using a separate copy attention
+ self_attn_type (str): type of self-attention scaled-dot, average
+ dropout (float): dropout in residual, self-attn(dot) and feed-forward
+ attention_dropout (float): dropout in context_attn (and self-attn(avg))
+ embeddings (onmt.modules.Embeddings):
+ embeddings to use, should have positional encodings
+ max_relative_positions (int):
+ Max distance between inputs in relative positions representations
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+ full_context_alignment (bool):
+ whether enable an extra full context decoder forward for alignment
+ alignment_layer (int): N° Layer to supervise with for alignment guiding
+ alignment_heads (int):
+ N. of cross attention heads to use for alignment guiding
+ """
+
+ def __init__(
+ self,
+ num_layers,
+ d_model,
+ heads,
+ d_ff,
+ copy_attn,
+ self_attn_type,
+ dropout,
+ attention_dropout,
+ max_relative_positions,
+ aan_useffn,
+ full_context_alignment,
+ alignment_layer,
+ alignment_heads,
+ pos_ffn_activation_fn=ActivationFunction.relu,
+ ):
+ super(TransformerDecoder, self).__init__(
+ d_model, copy_attn, alignment_layer
+ )
+
+ self.transformer_layers = nn.ModuleList(
+ [
+ TransformerDecoderLayer(
+ d_model,
+ heads,
+ d_ff,
+ dropout,
+ attention_dropout,
+ self_attn_type=self_attn_type,
+ max_relative_positions=max_relative_positions,
+ aan_useffn=aan_useffn,
+ full_context_alignment=full_context_alignment,
+ alignment_heads=alignment_heads,
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ def detach_state(self):
+ self.state["src"] = self.state["src"].detach()
+
+ def forward(self, tgt_emb, memory_bank, src_pad_mask=None, tgt_pad_mask=None, step=None, **kwargs):
+ """Decode, possibly stepwise."""
+ if step == 0:
+ self._init_cache(memory_bank)
+
+ batch_size, src_len, src_dim = memory_bank.size()
+ device = memory_bank.device
+ if src_pad_mask is None:
+ src_pad_mask = torch.zeros((batch_size, 1, src_len), dtype=torch.bool, device=device)
+ output = tgt_emb
+ batch_size, tgt_len, tgt_dim = tgt_emb.size()
+ if tgt_pad_mask is None:
+ tgt_pad_mask = torch.zeros((batch_size, 1, tgt_len), dtype=torch.bool, device=device)
+
+ future = kwargs.pop("future", False)
+ with_align = kwargs.pop("with_align", False)
+ attn_aligns = []
+ hiddens = []
+
+ for i, layer in enumerate(self.transformer_layers):
+ layer_cache = (
+ self.state["cache"]["layer_{}".format(i)]
+ if step is not None
+ else None
+ )
+ output, attn, attn_align = layer(
+ output,
+ memory_bank,
+ src_pad_mask,
+ tgt_pad_mask,
+ layer_cache=layer_cache,
+ step=step,
+ with_align=with_align,
+ future=future
+ )
+ hiddens.append(output)
+ if attn_align is not None:
+ attn_aligns.append(attn_align)
+
+ output = self.layer_norm(output) # (B, L, D)
+
+ attns = {"std": attn}
+ if self._copy:
+ attns["copy"] = attn
+ if with_align:
+ attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
+ # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
+
+ # TODO change the way attns is returned dict => list or tuple (onnx)
+ return output, attns, hiddens
+
+ def _init_cache(self, memory_bank):
+ self.state["cache"] = {}
+ for i, layer in enumerate(self.transformer_layers):
+ layer_cache = {"memory_keys": None, "memory_values": None, "self_keys": None, "self_values": None}
+ self.state["cache"]["layer_{}".format(i)] = layer_cache
+
diff --git a/rxn/reaction/transformer/embedding.py b/rxn/reaction/transformer/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..39647774d7f183690a3443bdf608479088fb0f5d
--- /dev/null
+++ b/rxn/reaction/transformer/embedding.py
@@ -0,0 +1,260 @@
+""" Embeddings module """
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+
+from onmt.modules.util_class import Elementwise
+
+
+class SequenceTooLongError(Exception):
+ pass
+
+
+class PositionalEncoding(nn.Module):
+ """Sinusoidal positional encoding for non-recurrent neural networks.
+
+ Implementation based on "Attention Is All You Need"
+ :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
+
+ Args:
+ dropout (float): dropout parameter
+ dim (int): embedding size
+ """
+
+ def __init__(self, dropout, dim, max_len=5000):
+ if dim % 2 != 0:
+ raise ValueError("Cannot use sin/cos positional encoding with "
+ "odd dim (got dim={:d})".format(dim))
+ pe = torch.zeros(max_len, dim)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
+ -(math.log(10000.0) / dim)))
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
+ pe = pe.unsqueeze(1)
+ super(PositionalEncoding, self).__init__()
+ self.register_buffer('pe', pe)
+ self.dropout = nn.Dropout(p=dropout)
+ self.dim = dim
+
+ def forward(self, emb, step=None):
+ """Embed inputs.
+
+ Args:
+ emb (FloatTensor): Sequence of word vectors
+ ``(seq_len, batch_size, self.dim)``
+ step (int or NoneType): If stepwise (``seq_len = 1``), use
+ the encoding for this position.
+ """
+
+ emb = emb * math.sqrt(self.dim)
+ step = step or 0
+ if self.pe.size(0) < step + emb.size(0):
+ raise SequenceTooLongError(
+ f"Sequence is {emb.size(0) + step} but PositionalEncoding is"
+ f" limited to {self.pe.size(0)}. See max_len argument."
+ )
+ emb = emb + self.pe[step:emb.size(0)+step]
+ emb = self.dropout(emb)
+ return emb
+
+
+class Embeddings(nn.Module):
+ """Words embeddings for encoder/decoder.
+
+ Additionally includes ability to add sparse input features
+ based on "Linguistic Input Features Improve Neural Machine Translation"
+ :cite:`sennrich2016linguistic`.
+
+
+ .. mermaid::
+
+ graph LR
+ A[Input]
+ C[Feature 1 Lookup]
+ A-->B[Word Lookup]
+ A-->C
+ A-->D[Feature N Lookup]
+ B-->E[MLP/Concat]
+ C-->E
+ D-->E
+ E-->F[Output]
+
+ Args:
+ word_vec_size (int): size of the dictionary of embeddings.
+ word_padding_idx (int): padding index for words in the embeddings.
+ feat_padding_idx (List[int]): padding index for a list of features
+ in the embeddings.
+ word_vocab_size (int): size of dictionary of embeddings for words.
+ feat_vocab_sizes (List[int], optional): list of size of dictionary
+ of embeddings for each feature.
+ position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding`
+ feat_merge (string): merge action for the features embeddings:
+ concat, sum or mlp.
+ feat_vec_exponent (float): when using `-feat_merge concat`, feature
+ embedding size is N^feat_dim_exponent, where N is the
+ number of values the feature takes.
+ feat_vec_size (int): embedding dimension for features when using
+ `-feat_merge mlp`
+ dropout (float): dropout probability.
+ freeze_word_vecs (bool): freeze weights of word vectors.
+ """
+
+ def __init__(self, word_vec_size,
+ word_vocab_size,
+ word_padding_idx,
+ position_encoding=False,
+ feat_merge="concat",
+ feat_vec_exponent=0.7,
+ feat_vec_size=-1,
+ feat_padding_idx=[],
+ feat_vocab_sizes=[],
+ dropout=0,
+ sparse=False,
+ freeze_word_vecs=False):
+ self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent,
+ feat_vec_size, feat_padding_idx)
+
+ if feat_padding_idx is None:
+ feat_padding_idx = []
+ self.word_padding_idx = word_padding_idx
+
+ self.word_vec_size = word_vec_size
+
+ # Dimensions and padding for constructing the word embedding matrix
+ vocab_sizes = [word_vocab_size]
+ emb_dims = [word_vec_size]
+ pad_indices = [word_padding_idx]
+
+ # Dimensions and padding for feature embedding matrices
+ # (these have no effect if feat_vocab_sizes is empty)
+ if feat_merge == 'sum':
+ feat_dims = [word_vec_size] * len(feat_vocab_sizes)
+ elif feat_vec_size > 0:
+ feat_dims = [feat_vec_size] * len(feat_vocab_sizes)
+ else:
+ feat_dims = [int(vocab ** feat_vec_exponent)
+ for vocab in feat_vocab_sizes]
+ vocab_sizes.extend(feat_vocab_sizes)
+ emb_dims.extend(feat_dims)
+ pad_indices.extend(feat_padding_idx)
+
+ # The embedding matrix look-up tables. The first look-up table
+ # is for words. Subsequent ones are for features, if any exist.
+ emb_params = zip(vocab_sizes, emb_dims, pad_indices)
+ embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse)
+ for vocab, dim, pad in emb_params]
+ emb_luts = Elementwise(feat_merge, embeddings)
+
+ # The final output size of word + feature vectors. This can vary
+ # from the word vector size if and only if features are defined.
+ # This is the attribute you should access if you need to know
+ # how big your embeddings are going to be.
+ self.embedding_size = (sum(emb_dims) if feat_merge == 'concat'
+ else word_vec_size)
+
+ # The sequence of operations that converts the input sequence
+ # into a sequence of embeddings. At minimum this consists of
+ # looking up the embeddings for each word and feature in the
+ # input. Model parameters may require the sequence to contain
+ # additional operations as well.
+ super(Embeddings, self).__init__()
+ self.make_embedding = nn.Sequential()
+ self.make_embedding.add_module('emb_luts', emb_luts)
+
+ if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0:
+ in_dim = sum(emb_dims)
+ mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU())
+ self.make_embedding.add_module('mlp', mlp)
+
+ self.position_encoding = position_encoding
+
+ if self.position_encoding:
+ pe = PositionalEncoding(dropout, self.embedding_size)
+ self.make_embedding.add_module('pe', pe)
+
+ if freeze_word_vecs:
+ self.word_lut.weight.requires_grad = False
+
+ def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent,
+ feat_vec_size, feat_padding_idx):
+ if feat_merge == "sum":
+ # features must use word_vec_size
+ if feat_vec_exponent != 0.7:
+ warnings.warn("Merging with sum, but got non-default "
+ "feat_vec_exponent. It will be unused.")
+ if feat_vec_size != -1:
+ warnings.warn("Merging with sum, but got non-default "
+ "feat_vec_size. It will be unused.")
+ elif feat_vec_size > 0:
+ # features will use feat_vec_size
+ if feat_vec_exponent != -1:
+ warnings.warn("Not merging with sum and positive "
+ "feat_vec_size, but got non-default "
+ "feat_vec_exponent. It will be unused.")
+ else:
+ if feat_vec_exponent <= 0:
+ raise ValueError("Using feat_vec_exponent to determine "
+ "feature vec size, but got feat_vec_exponent "
+ "less than or equal to 0.")
+ n_feats = len(feat_vocab_sizes)
+ if n_feats != len(feat_padding_idx):
+ raise ValueError("Got unequal number of feat_vocab_sizes and "
+ "feat_padding_idx ({:d} != {:d})".format(
+ n_feats, len(feat_padding_idx)))
+
+ @property
+ def word_lut(self):
+ """Word look-up table."""
+ return self.make_embedding[0][0]
+
+ @property
+ def emb_luts(self):
+ """Embedding look-up table."""
+ return self.make_embedding[0]
+
+ def load_pretrained_vectors(self, emb_file):
+ """Load in pretrained embeddings.
+
+ Args:
+ emb_file (str) : path to torch serialized embeddings
+ """
+
+ if emb_file:
+ pretrained = torch.load(emb_file)
+ pretrained_vec_size = pretrained.size(1)
+ if self.word_vec_size > pretrained_vec_size:
+ self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained
+ elif self.word_vec_size < pretrained_vec_size:
+ self.word_lut.weight.data \
+ .copy_(pretrained[:, :self.word_vec_size])
+ else:
+ self.word_lut.weight.data.copy_(pretrained)
+
+ def forward(self, source, step=None):
+ """Computes the embeddings for words and features.
+
+ Args:
+ source (LongTensor): index tensor ``(len, batch, nfeat)``
+
+ Returns:
+ FloatTensor: Word embeddings ``(len, batch, embedding_size)``
+ """
+
+ if self.position_encoding:
+ for i, module in enumerate(self.make_embedding._modules.values()):
+ if i == len(self.make_embedding._modules.values()) - 1:
+ source = module(source, step=step)
+ else:
+ source = module(source)
+ else:
+ source = self.make_embedding(source)
+
+ return source
+
+ def update_dropout(self, dropout):
+ if self.position_encoding:
+ self._modules['make_embedding'][1].dropout.p = dropout
+
diff --git a/rxn/reaction/transformer/swin_transformer.py b/rxn/reaction/transformer/swin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f144d1042799562c14833f90657697338a1ec33
--- /dev/null
+++ b/rxn/reaction/transformer/swin_transformer.py
@@ -0,0 +1,677 @@
+""" Swin Transformer
+A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
+ - https://arxiv.org/pdf/2103.14030
+
+Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
+"""
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+import logging
+import math
+from copy import deepcopy
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
+from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+from timm.models.vision_transformer import checkpoint_filter_fn, _init_vit_weights
+
+_logger = logging.getLogger(__name__)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # patch models (my experiments)
+ 'swin_base_patch4_window12_384': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+
+ 'swin_base_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
+ ),
+
+ 'swin_large_patch4_window12_384': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+
+ 'swin_large_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
+ ),
+
+ 'swin_small_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
+ ),
+
+ 'swin_tiny_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
+ ),
+
+ 'swin_base_patch4_window12_384_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
+
+ 'swin_base_patch4_window7_224_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
+ num_classes=21841),
+
+ 'swin_large_patch4_window12_384_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
+
+ 'swin_large_patch4_window7_224_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
+ num_classes=21841),
+
+}
+
+
+def window_partition(x, window_size: int):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size: int, H: int, W: int):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask: Optional[torch.Tensor] = None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resulotion.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
+ attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def get_attn_mask(self, H, W, device):
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ img_mask = torch.zeros((1, H, W, 1), device=device) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ return attn_mask
+
+ def forward(self, x, H, W):
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_mask = self.get_attn_mask(Hp, Wp, x.device)
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """
+ x: B, H*W, C
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ # assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ H, W = x.shape[1:3]
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x, H, W
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self, dim, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim, num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W, hiddens):
+ for blk in self.blocks:
+ if not torch.jit.is_scripting() and self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, H, W)
+ else:
+ x = blk(x, H, W)
+ hiddens.append(x)
+ if self.downsample is not None:
+ x, H, W = self.downsample(x, H, W)
+ return x, H, W
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+
+class PatchEmbed(nn.Module):
+ """ 2D Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+ # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x)
+ H, W = x.shape[2:]
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x, H, W
+
+
+class SwinTransformer(nn.Module):
+ r""" Swin Transformer
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+ embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
+ window_size=7, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, weight_init='', **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ self.patch_grid = self.patch_embed.grid_size
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+ else:
+ self.absolute_pos_embed = None
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ layers = []
+ for i_layer in range(self.num_layers):
+ layers += [BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint)
+ ]
+ self.layers = nn.Sequential(*layers)
+
+ self.norm = norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
+ if weight_init.startswith('jax'):
+ for n, m in self.named_modules():
+ _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
+ else:
+ self.apply(_init_vit_weights)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward(self, x):
+ x, H, W = self.patch_embed(x)
+ if self.absolute_pos_embed is not None:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+ hiddens = []
+ for layer in self.layers:
+ x, H, W = layer(x, H, W, hiddens)
+ x = self.norm(x) # B L C
+ # x = self.avgpool(x.transpose(1, 2)) # B C 1
+ # x = torch.flatten(x, 1)
+ return x, hiddens
+
+ # def forward(self, x):
+ # x = self.forward_features(x)
+ # x = self.head(x)
+ # return x
+
+
+def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
+ if default_cfg is None:
+ default_cfg = deepcopy(default_cfgs[variant])
+ overlay_external_default_cfg(default_cfg, kwargs)
+ default_num_classes = default_cfg['num_classes']
+ default_img_size = default_cfg['input_size'][-2:]
+
+ num_classes = kwargs.pop('num_classes', default_num_classes)
+ img_size = kwargs.pop('img_size', default_img_size)
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(
+ SwinTransformer, variant, pretrained,
+ default_cfg=default_cfg,
+ img_size=img_size,
+ num_classes=num_classes,
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+
+ return model
+
+
+
+@register_model
+def swin_base(pretrained=False, **kwargs):
+ """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+ return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_large(pretrained=False, **kwargs):
+ """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+ return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_small(pretrained=False, **kwargs):
+ """ Swin-S @ 224x224, trained ImageNet-1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
+ return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+# @register_model
+# def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
+# """ Swin-T @ 224x224, trained ImageNet-1k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
+# return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+#
+#
+# @register_model
+# def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
+# """ Swin-B @ 384x384, trained ImageNet-22k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+# return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
+#
+#
+# @register_model
+# def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
+# """ Swin-B @ 224x224, trained ImageNet-22k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+# return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
+#
+#
+# @register_model
+# def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
+# """ Swin-L @ 384x384, trained ImageNet-22k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+# return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
+#
+#
+# @register_model
+# def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
+# """ Swin-L @ 224x224, trained ImageNet-22k
+# """
+# model_kwargs = dict(
+# patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+# return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
diff --git a/rxn/reaction/transforms.py b/rxn/reaction/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f359c5501a20208a433cc73e3b0f87e44b434dc
--- /dev/null
+++ b/rxn/reaction/transforms.py
@@ -0,0 +1,498 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Transforms and data augmentation for both image + bbox.
+"""
+import random
+import math
+
+import PIL
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as F
+
+import numpy as np
+
+
+def box_cxcywh_to_xyxy(x):
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+ x0, y0, x1, y1 = x.unbind(-1)
+ b = [(x0 + x1) / 2, (y0 + y1) / 2,
+ (x1 - x0), (y1 - y0)]
+ return torch.stack(b, dim=-1)
+
+def crop(image, target, region):
+ cropped_image = F.crop(image, *region)
+
+ target = target.copy()
+ i, j, h, w = region
+
+ # should we do something wrt the original size?
+ # target["size"] = torch.tensor([h, w])
+
+ fields = ["labels", "area"]
+
+ if "boxes" in target:
+ boxes = target["boxes"]
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
+ cropped_boxes = cropped_boxes.clamp(min=0)
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
+ target["area"] = area
+ fields.append("boxes")
+
+ # remove elements for which the boxes or masks that have zero area
+ # if "boxes" in target or "masks" in target:
+ # # favor boxes selection when defining which elements to keep
+ # # this is compatible with previous implementation
+ # if "boxes" in target:
+ # cropped_boxes = target['boxes'].reshape(-1, 2, 2)
+ # keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
+ # else:
+ # keep = target['masks'].flatten(1).any(1)
+ #
+ # for field in fields:
+ # target[field] = target[field][keep]
+
+ return cropped_image, target
+
+
+def hflip(image, target):
+ flipped_image = F.hflip(image)
+
+ w, h = image.size
+
+ target = target.copy()
+ if "boxes" in target:
+ boxes = target["boxes"]
+ boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
+ target["boxes"] = boxes
+
+ return flipped_image, target
+
+
+def rotate90(image, target):
+ rotated_image = image.rotate(90, expand=1)
+
+ w, h = rotated_image.size
+
+ target = target.copy()
+ if "boxes" in target:
+ boxes = target["boxes"]
+ boxes = boxes[:, [1, 2, 3, 0]] * torch.as_tensor([1, -1, 1, -1]) + torch.as_tensor([0, h, 0, h])
+ target["boxes"] = boxes
+
+ return rotated_image, target
+
+
+def resize(image, target, size, max_size=None):
+ # size can be min_size (scalar) or (w, h) tuple
+
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
+ w, h = image_size
+ if max_size is not None:
+ min_original_size = float(min((w, h)))
+ max_original_size = float(max((w, h)))
+ if max_original_size / min_original_size * size > max_size:
+ size = int(round(max_size * min_original_size / max_original_size))
+
+ if (w <= h and w == size) or (h <= w and h == size):
+ return (h, w)
+
+ if w < h:
+ ow = size
+ oh = int(size * h / w)
+ else:
+ oh = size
+ ow = int(size * w / h)
+
+ return (oh, ow)
+
+ def get_size(image_size, size, max_size=None):
+ if isinstance(size, (list, tuple)):
+ return size[::-1]
+ else:
+ return get_size_with_aspect_ratio(image_size, size, max_size)
+
+ size = get_size(image.size, size, max_size)
+ rescaled_image = F.resize(image, size)
+
+ if target is None:
+ return rescaled_image, None
+
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
+ ratio_width, ratio_height = ratios
+
+ target = target.copy()
+ if "boxes" in target:
+ boxes = target["boxes"]
+ scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
+ target["boxes"] = scaled_boxes
+
+ if "area" in target:
+ area = target["area"]
+ scaled_area = area * (ratio_width * ratio_height)
+ target["area"] = scaled_area
+
+ return rescaled_image, target
+
+
+def pad(image, target, padding):
+ # assumes that we only pad on the bottom right corners
+ padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
+ if target is None:
+ return padded_image, None
+ target = target.copy()
+ # should we do something wrt the original size?
+ target["size"] = torch.tensor(padded_image.size[::-1])
+ if "masks" in target:
+ target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
+ return padded_image, target
+
+
+class RandomCrop(object):
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, img, target):
+ region = T.RandomCrop.get_params(img, self.size)
+ return crop(img, target, region)
+
+
+class RandomSizeCrop(object):
+ def __init__(self, min_size: int, max_size: int):
+ self.min_size = min_size
+ self.max_size = max_size
+
+ def __call__(self, img: PIL.Image.Image, target: dict):
+ w = random.randint(self.min_size, min(img.width, self.max_size))
+ h = random.randint(self.min_size, min(img.height, self.max_size))
+ region = T.RandomCrop.get_params(img, [h, w])
+ return crop(img, target, region)
+
+
+class CenterCrop(object):
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, img, target):
+ image_width, image_height = img.size
+ crop_height, crop_width = self.size
+ crop_top = int(round((image_height - crop_height) / 2.))
+ crop_left = int(round((image_width - crop_width) / 2.))
+ return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
+
+
+class RandomReactionCrop(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, img, target):
+ w, h = img.size
+ boxes = target["boxes"]
+ x_avail = [1] * w
+ y_avail = [1] * h
+ for reaction in target['reactions']:
+ ids = reaction['reactants'] + reaction['conditions'] + reaction['products']
+ rboxes = boxes[ids].round().int()
+ rmin, _ = rboxes.min(dim=0)
+ rmax, _ = rboxes.max(dim=0)
+ x1, x2 = (rmin[0].item(), rmax[2].item())
+ for i in range(x1, x2):
+ x_avail[i] = 0
+ y1, y2 = (rmin[1].item(), rmax[3].item())
+ for i in range(y1, y2):
+ y_avail[i] = 0
+
+ def sample_from_avail(w):
+ spans = []
+ left, right = 0, 0
+ while right < len(w):
+ while right < len(w) and w[left] == w[right]:
+ right += 1
+ if w[left] == 1:
+ spans.append((left, right))
+ left, right = right + 1, right + 1
+ if w[0] == 0:
+ spans = [(0, 0)] + spans
+ if w[-1] == 0:
+ spans = spans + [(len(w), len(w))]
+ if len(spans) < 2:
+ w1 = random.randint(0, len(w))
+ w2 = random.randint(0, len(w))
+ else:
+ spans = random.sample(spans, 2)
+ w1 = random.randint(*spans[0])
+ w2 = random.randint(*spans[1])
+ return min(w1, w2), max(w1, w2)
+
+ x1, x2 = sample_from_avail(x_avail)
+ y1, y2 = sample_from_avail(y_avail)
+ region = (y1, x1, y2-y1, x2-x1)
+ if x2-x1 < 30 or y2-y1 < 30:
+ # Cropped region too small
+ return img, target
+ else:
+ return crop(img, target, region)
+
+
+class RandomHorizontalFlip(object):
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, target):
+ if random.random() < self.p:
+ return hflip(img, target)
+ return img, target
+
+
+class RandomRotate(object):
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, target):
+ if random.random() < self.p:
+ return rotate90(img, target)
+ return img, target
+
+
+class RandomResize(object):
+ def __init__(self, sizes, max_size=None):
+ assert isinstance(sizes, (list, tuple))
+ self.sizes = sizes
+ self.max_size = max_size
+
+ def __call__(self, img, target=None):
+ size = random.choice(self.sizes)
+ return resize(img, target, size, self.max_size)
+
+
+class RandomPad(object):
+ def __init__(self, max_pad):
+ self.max_pad = max_pad
+
+ def __call__(self, img, target):
+ pad_x = random.randint(0, self.max_pad)
+ pad_y = random.randint(0, self.max_pad)
+ return pad(img, target, (pad_x, pad_y))
+
+
+class RandomSelect(object):
+ """
+ Randomly selects between transforms1 and transforms2,
+ with probability p for transforms1 and (1 - p) for transforms2
+ """
+ def __init__(self, transforms1, transforms2, p=0.5):
+ self.transforms1 = transforms1
+ self.transforms2 = transforms2
+ self.p = p
+
+ def __call__(self, img, target):
+ if random.random() < self.p:
+ return self.transforms1(img, target)
+ return self.transforms2(img, target)
+
+
+class Resize(object):
+ def __init__(self, size):
+ assert isinstance(size, (list, tuple))
+ self.size = size
+
+ def __call__(self, img, target=None):
+ return resize(img, target, self.size)
+
+
+class ToTensor(object):
+ def __call__(self, img, target):
+ return F.to_tensor(img), target
+
+
+class RandomErasing(object):
+
+ def __init__(self, *args, **kwargs):
+ self.eraser = T.RandomErasing(*args, **kwargs)
+
+ def __call__(self, img, target):
+ return self.eraser(img), target
+
+
+class Normalize(object):
+ def __init__(self, mean, std, debug=False):
+ self.mean = mean
+ self.std = std
+ self.debug = debug
+
+ def __call__(self, image, target=None):
+ if not self.debug:
+ image = F.normalize(image, mean=self.mean, std=self.std)
+ if target is None:
+ return image, None
+ target = target.copy()
+ h, w = image.shape[-2:]
+ if "boxes" in target:
+ boxes = target["boxes"]
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
+ target["boxes"] = boxes.clamp(min=0, max=1)
+ return image, target
+
+
+class Compose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, image, target=None):
+ for t in self.transforms:
+ image, target = t(image, target)
+ return image, target
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
+
+
+class LargeScaleJitter(object):
+ """
+ implementation of large scale jitter from copy_paste
+ """
+
+ def __init__(self, output_size=1333, aug_scale_min=0.3, aug_scale_max=2.0):
+ self.desired_size = output_size
+ self.aug_scale_min = aug_scale_min
+ self.aug_scale_max = aug_scale_max
+ self.random = (aug_scale_min != 1) or (aug_scale_max != 1)
+
+ def rescale_target(self, scaled_size, image_size, target):
+ # compute rescaled targets
+ image_scale = scaled_size / image_size
+ ratio_height, ratio_width = image_scale
+
+ target = target.copy()
+
+ if "boxes" in target:
+ boxes = target["boxes"]
+ scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
+ target["boxes"] = scaled_boxes
+
+ if "area" in target:
+ area = target["area"]
+ scaled_area = area * (ratio_width * ratio_height)
+ target["area"] = scaled_area
+
+ return target
+
+ def crop_target(self, region, target):
+ i, j, h, w = region
+ fields = ["labels", "area"]
+
+ target = target.copy()
+
+ if "boxes" in target:
+ boxes = target["boxes"]
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
+ cropped_boxes = cropped_boxes.clamp(min=0)
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
+ target["area"] = area
+ fields.append("boxes")
+
+ # Do not remove the boxes with zero area. Tokenizer does it instead.
+ # if "boxes" in target:
+ # # favor boxes selection when defining which elements to keep
+ # # this is compatible with previous implementation
+ # cropped_boxes = target['boxes'].reshape(-1, 2, 2)
+ # keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
+ # for field in fields:
+ # target[field] = target[field][keep]
+ return target
+
+ def pad_target(self, padding, target):
+ # padding: left, top, right, bottom
+ target = target.copy()
+ if "boxes" in target:
+ left, top, right, bottom = padding
+ target["boxes"][:, 0::2] += left
+ target["boxes"][:, 1::2] += top
+ return target
+
+ def __call__(self, image, target=None):
+ image_size = image.size
+ image_size = torch.tensor(image_size[::-1])
+ if target is None:
+ target = {}
+
+ # out_desired_size = (self.desired_size * image_size / max(image_size)).round().int()
+ out_desired_size = torch.tensor([self.desired_size, self.desired_size])
+
+ random_scale = torch.rand(1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min
+ scaled_size = (random_scale * self.desired_size).round()
+
+ scale = torch.minimum(scaled_size / image_size[0], scaled_size / image_size[1])
+ scaled_size = (image_size * scale).round().int().clamp(min=1)
+
+ scaled_image = F.resize(image, scaled_size.tolist())
+
+ if target is not None:
+ target = self.rescale_target(scaled_size, image_size, target)
+
+ # randomly crop or pad images
+ delta = scaled_size - out_desired_size
+ output_image = scaled_image
+
+ w, h = scaled_image.size
+ target["scale"] = [w / self.desired_size, h / self.desired_size]
+
+ if delta.lt(0).any():
+ padding = torch.clamp(-delta, min=0)
+ if self.random:
+ padding1 = (torch.rand(1) * padding).round().int()
+ padding2 = padding - padding1
+ padding = padding1.tolist()[::-1] + padding2.tolist()[::-1]
+ else:
+ padding = [0, 0] + padding.tolist()[::-1]
+ output_image = F.pad(output_image, padding, 255)
+ # output_image = F.pad(scaled_image, [0, 0, padding[1].item(), padding[0].item()])
+ if target is not None:
+ target = self.pad_target(padding, target)
+
+ if delta.gt(0).any():
+ # Selects non-zero random offset (x, y) if scaled image is larger than desired_size.
+ max_offset = torch.clamp(delta, min=0)
+ if self.random:
+ offset = (max_offset * torch.rand(2)).floor().int()
+ else:
+ offset = torch.zeros(2)
+ region = (offset[0].item(), offset[1].item(), out_desired_size[0].item(), out_desired_size[1].item())
+ output_image = F.crop(output_image, *region)
+ if target is not None:
+ target = self.crop_target(region, target)
+
+ return output_image, target
+
+
+class RandomDistortion(object):
+ """
+ Distort image w.r.t hue, saturation and exposure.
+ """
+
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, prob=0.5):
+ self.prob = prob
+ self.tfm = T.ColorJitter(brightness, contrast, saturation, hue)
+
+ def __call__(self, img, target=None):
+ if np.random.random() < self.prob:
+ return self.tfm(img), target
+ else:
+ return img, target
diff --git a/rxn/reaction/utils.py b/rxn/reaction/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e91b42d249e88a50458f4ee2109a1ff9b171ce16
--- /dev/null
+++ b/rxn/reaction/utils.py
@@ -0,0 +1,14 @@
+import json
+
+
+def merge_predictions(results):
+ if len(results) == 0:
+ return {}
+ formats = results[0][1].keys()
+ predictions = {format_: {} for format_ in formats}
+ for format_ in formats:
+ for indices, batch_preds in results:
+ for idx, preds in zip(indices, batch_preds[format_]):
+ predictions[format_][idx] = preds
+ predictions[format_] = [predictions[format_][i] for i in range(len(predictions[format_]))]
+ return predictions
diff --git a/rxnim.py b/rxnim.py
new file mode 100644
index 0000000000000000000000000000000000000000..30ccb42bb937bb6bee6bfbe711fbe1c16423dfd5
--- /dev/null
+++ b/rxnim.py
@@ -0,0 +1,162 @@
+import base64
+import json
+from openai import AzureOpenAI
+import os
+import sys
+sys.path.append('./rxn/')
+import torch
+import json
+from getReaction import get_reaction
+
+
+
+class RXNIM:
+ def __init__(self, api_key_file='api_key.txt', api_version='2024-06-01', azure_endpoint='https://hkust.azure-api.net'):
+ # Read API key
+ with open(api_key_file, 'r') as api_key_file_handle:
+ self.API_KEY = api_key_file_handle.read().strip()
+
+ # def __init__(self, api_version='2024-06-01', azure_endpoint='https://hkust.azure-api.net'):
+ # # 从环境变量读取 API Key
+ # self.API_KEY = os.environ.get('key')
+ # if not self.API_KEY:
+ # raise ValueError("Environment variable 'KEY' not set.")
+
+ # Set up client
+ self.client = AzureOpenAI(
+ api_key=self.API_KEY,
+ api_version=api_version,
+ azure_endpoint=azure_endpoint,
+ )
+
+ # Define tools
+ self.tools = [
+ {
+ 'type': 'function',
+ 'function': {
+ 'name': 'get_reaction',
+ 'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'image_path': {
+ 'type': 'string',
+ 'description': 'The path to the reaction image.',
+ },
+ },
+ 'required': ['image_path'],
+ 'additionalProperties': False,
+ },
+ },
+ },
+ ]
+
+ # Define tool mapping
+ self.TOOL_MAP = {
+ 'get_reaction': get_reaction,
+ }
+
+ def encode_image(self, image_path: str):
+ '''Returns a base64 string of the input image.'''
+ with open(image_path, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode('utf-8')
+
+ def process(self, image_path: str, prompt_path: str):
+ # Encode image
+ base64_image = self.encode_image(image_path)
+
+ # Read prompt
+ with open(prompt_path, 'r') as prompt_file:
+ prompt = prompt_file.read()
+
+ # Build initial messages
+ messages = [
+ {'role': 'system', 'content': 'You are a helpful assistant. Before providing the final answer, consider if any additional information or tool usage is needed to improve your response.'},
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': prompt
+ },
+ {
+ 'type': 'image_url',
+ 'image_url': {
+ 'url': f'data:image/png;base64,{base64_image}'
+ }
+ }
+ ]
+ },
+ ]
+
+ MAX_ITERATIONS = 5
+ iterations = 0
+
+ while iterations < MAX_ITERATIONS:
+ iterations += 1
+ print(f'Iteration {iterations}')
+
+ # Call the model
+ response = self.client.chat.completions.create(
+ model='gpt-4o',
+ temperature=0,
+ response_format={'type': 'json_object'},
+ messages=messages,
+ tools=self.tools,
+ )
+
+ # Get assistant's message
+ assistant_message = response.choices[0].message
+
+ # Add assistant's message to messages
+ messages.append(assistant_message)
+
+ # Check for tool calls
+ if hasattr(assistant_message, 'tool_calls') and assistant_message.tool_calls:
+ tool_calls = assistant_message.tool_calls
+ results = []
+
+ for tool_call in tool_calls:
+ tool_name = tool_call.function.name
+ tool_arguments = tool_call.function.arguments
+ tool_call_id = tool_call.id
+
+ tool_args = json.loads(tool_arguments)
+
+ if tool_name in self.TOOL_MAP:
+ try:
+ # Call the tool function
+ tool_result = self.TOOL_MAP[tool_name](image_path)
+ print(f'{tool_name} result: {tool_result}')
+ except Exception as e:
+ tool_result = {'error': str(e)}
+ else:
+ tool_result = {'error': f"Unknown tool called: {tool_name}"}
+
+ # Append tool result to messages
+ results.append({
+ 'role': 'tool',
+ 'content': json.dumps({
+ 'image_path': image_path,
+ f'{tool_name}': tool_result,
+ }),
+ 'tool_call_id': tool_call_id,
+ })
+ print(results)
+
+ # Add tool results to messages
+ messages.extend(results)
+ else:
+ # No more tool calls, assume task is completed
+ break
+
+ else:
+ # Exceeded maximum iterations
+ return "The assistant could not complete the task within the maximum number of iterations."
+
+ # Return the final assistant message
+ final_content = assistant_message.content
+ return final_content
+
+
+