import os
import json
import random

import numpy as np
from SmilesPE.pretokenizer import atomwise_tokenizer

# Special tokens shared by all tokenizers.
PAD = '<pad>'
SOS = '<sos>'
EOS = '<eos>'
UNK = '<unk>'
MASK = '<mask>'

PAD_ID = 0
SOS_ID = 1
EOS_ID = 2
UNK_ID = 3
MASK_ID = 4


class Tokenizer(object):

    def __init__(self, path=None):
        self.stoi = {}
        self.itos = {}
        if path:
            self.load(path)

    def __len__(self):
        return len(self.stoi)

    @property
    def output_constraint(self):
        return False

    def save(self, path):
        with open(path, 'w') as f:
            json.dump(self.stoi, f)

    def load(self, path):
        with open(path) as f:
            self.stoi = json.load(f)
        self.itos = {item[1]: item[0] for item in self.stoi.items()}

    def fit_on_texts(self, texts):
        vocab = set()
        for text in texts:
            vocab.update(text.split(' '))
        vocab = [PAD, SOS, EOS, UNK] + list(vocab)
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        self.itos = {item[1]: item[0] for item in self.stoi.items()}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID

    def text_to_sequence(self, text, tokenized=True):
        sequence = []
        sequence.append(self.stoi[SOS])
        if tokenized:
            tokens = text.split(' ')
        else:
            tokens = atomwise_tokenizer(text)
        for s in tokens:
            if s not in self.stoi:
                s = UNK
            sequence.append(self.stoi[s])
        sequence.append(self.stoi[EOS])
        return sequence

    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            sequence = self.text_to_sequence(text)
            sequences.append(sequence)
        return sequences

    def sequence_to_text(self, sequence):
        return ''.join(list(map(lambda i: self.itos[i], sequence)))

    def sequences_to_texts(self, sequences):
        texts = []
        for sequence in sequences:
            text = self.sequence_to_text(sequence)
            texts.append(text)
        return texts

    def predict_caption(self, sequence):
        caption = ''
        for i in sequence:
            if i == self.stoi[EOS] or i == self.stoi[PAD]:
                break
            caption += self.itos[i]
        return caption

    def predict_captions(self, sequences):
        captions = []
        for sequence in sequences:
            caption = self.predict_caption(sequence)
            captions.append(caption)
        return captions

    def sequence_to_smiles(self, sequence):
        return {'smiles': self.predict_caption(sequence)}


class NodeTokenizer(Tokenizer):

    def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
        super().__init__(path)
        self.maxx = input_size  # height
        self.maxy = input_size  # width
        self.sep_xy = sep_xy
        self.special_tokens = [PAD, SOS, EOS, UNK, MASK]
        self.continuous_coords = continuous_coords
        self.debug = debug

    def __len__(self):
        if self.sep_xy:
            return self.offset + self.maxx + self.maxy
        else:
            return self.offset + max(self.maxx, self.maxy)

    @property
    def offset(self):
        return len(self.stoi)

    @property
    def output_constraint(self):
        return not self.continuous_coords

    def len_symbols(self):
        return len(self.stoi)

    def fit_atom_symbols(self, atoms):
        vocab = self.special_tokens + list(set(atoms))
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID
        assert self.stoi[MASK] == MASK_ID
        self.itos = {item[1]: item[0] for item in self.stoi.items()}

    def is_x(self, x):
        return self.offset <= x < self.offset + self.maxx

    def is_y(self, y):
        if self.sep_xy:
            return self.offset + self.maxx <= y
        return self.offset <= y

    def is_symbol(self, s):
        return len(self.special_tokens) <= s < self.offset or s == UNK_ID

    def is_atom(self, id):
        if self.is_symbol(id):
            return self.is_atom_token(self.itos[id])
        return False

    def is_atom_token(self, token):
        return token.isalpha() or token.startswith("[") or token == '*' or token == UNK
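    # Vocabulary layout assumed by the coordinate helpers below: ids in
    # [0, offset) are symbol tokens from self.stoi; ids in [offset, offset + maxx)
    # are discretized x coordinates; y coordinates either reuse the same bins
    # (sep_xy=False) or occupy the next maxy ids (sep_xy=True).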
    def x_to_id(self, x):
        return self.offset + round(x * (self.maxx - 1))

    def y_to_id(self, y):
        if self.sep_xy:
            return self.offset + self.maxx + round(y * (self.maxy - 1))
        return self.offset + round(y * (self.maxy - 1))

    def id_to_x(self, id):
        return (id - self.offset) / (self.maxx - 1)

    def id_to_y(self, id):
        if self.sep_xy:
            return (id - self.offset - self.maxx) / (self.maxy - 1)
        return (id - self.offset) / (self.maxy - 1)

    def get_output_mask(self, id):
        # True marks ids that should not be generated after `id` (constrained decoding).
        mask = [False] * len(self)
        if self.continuous_coords:
            return mask
        if self.is_atom(id):
            return [True] * self.offset + [False] * self.maxx + [True] * self.maxy
        if self.is_x(id):
            return [True] * (self.offset + self.maxx) + [False] * self.maxy
        if self.is_y(id):
            return [False] * self.offset + [True] * (self.maxx + self.maxy)
        return mask

    def symbol_to_id(self, symbol):
        if symbol not in self.stoi:
            return UNK_ID
        return self.stoi[symbol]

    def symbols_to_labels(self, symbols):
        labels = []
        for symbol in symbols:
            labels.append(self.symbol_to_id(symbol))
        return labels

    def labels_to_symbols(self, labels):
        symbols = []
        for label in labels:
            symbols.append(self.itos[label])
        return symbols

    def nodes_to_grid(self, nodes):
        coords, symbols = nodes['coords'], nodes['symbols']
        grid = np.zeros((self.maxx, self.maxy), dtype=int)
        for [x, y], symbol in zip(coords, symbols):
            x = round(x * (self.maxx - 1))
            y = round(y * (self.maxy - 1))
            grid[x][y] = self.symbol_to_id(symbol)
        return grid

    def grid_to_nodes(self, grid):
        coords, symbols, indices = [], [], []
        for i in range(self.maxx):
            for j in range(self.maxy):
                if grid[i][j] != 0:
                    x = i / (self.maxx - 1)
                    y = j / (self.maxy - 1)
                    coords.append([x, y])
                    symbols.append(self.itos[grid[i][j]])
                    indices.append([i, j])
        return {'coords': coords, 'symbols': symbols, 'indices': indices}

    def nodes_to_sequence(self, nodes):
        coords, symbols = nodes['coords'], nodes['symbols']
        labels = [SOS_ID]
        for (x, y), symbol in zip(coords, symbols):
            assert 0 <= x <= 1
            assert 0 <= y <= 1
            labels.append(self.x_to_id(x))
            labels.append(self.y_to_id(y))
            labels.append(self.symbol_to_id(symbol))
        labels.append(EOS_ID)
        return labels

    def sequence_to_nodes(self, sequence):
        coords, symbols = [], []
        i = 0
        if sequence[0] == SOS_ID:
            i += 1
        while i + 2 < len(sequence):
            if sequence[i] == EOS_ID:
                break
            if self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
                x = self.id_to_x(sequence[i])
                y = self.id_to_y(sequence[i+1])
                symbol = self.itos[sequence[i+2]]
                coords.append([x, y])
                symbols.append(symbol)
            i += 3
        return {'coords': coords, 'symbols': symbols}

    def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
        tokens = atomwise_tokenizer(smiles)
        labels = [SOS_ID]
        indices = []
        atom_idx = -1
        for token in tokens:
            if atom_only and not self.is_atom_token(token):
                continue
            if token in self.stoi:
                labels.append(self.stoi[token])
            else:
                if self.debug:
                    print(f'{token} not in vocab')
                labels.append(UNK_ID)
            if self.is_atom_token(token):
                atom_idx += 1
                if not self.continuous_coords:
                    if mask_ratio > 0 and random.random() < mask_ratio:
                        labels.append(MASK_ID)
                        labels.append(MASK_ID)
                    elif coords is not None:
                        if atom_idx < len(coords):
                            x, y = coords[atom_idx]
                            assert 0 <= x <= 1
                            assert 0 <= y <= 1
                        else:
                            x = random.random()
                            y = random.random()
                        labels.append(self.x_to_id(x))
                        labels.append(self.y_to_id(y))
                indices.append(len(labels) - 1)
        labels.append(EOS_ID)
        return labels, indices
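    # smiles_to_sequence produces [SOS, tok, ..., atom, x_id, y_id, ..., EOS]:
    # every atom token is followed by its discretized x and y ids (unless
    # continuous_coords is set), and `indices` records the position of the last
    # label emitted for each atom. sequence_to_smiles below inverts this layout.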
    def sequence_to_smiles(self, sequence):
        has_coords = not self.continuous_coords
        smiles = ''
        coords, symbols, indices = [], [], []
        for i, label in enumerate(sequence):
            if label == EOS_ID or label == PAD_ID:
                break
            if self.is_x(label) or self.is_y(label):
                continue
            token = self.itos[label]
            smiles += token
            if self.is_atom_token(token):
                if has_coords:
                    if i+3 < len(sequence) and self.is_x(sequence[i+1]) and self.is_y(sequence[i+2]):
                        x = self.id_to_x(sequence[i+1])
                        y = self.id_to_y(sequence[i+2])
                        coords.append([x, y])
                        symbols.append(token)
                        indices.append(i+3)
                else:
                    if i+1 < len(sequence):
                        symbols.append(token)
                        indices.append(i+1)
        results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
        if has_coords:
            results['coords'] = coords
        return results


class CharTokenizer(NodeTokenizer):

    def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
        super().__init__(input_size, path, sep_xy, continuous_coords, debug)

    def fit_on_texts(self, texts):
        vocab = set()
        for text in texts:
            vocab.update(list(text))
        if ' ' in vocab:
            vocab.remove(' ')
        vocab = [PAD, SOS, EOS, UNK] + list(vocab)
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        self.itos = {item[1]: item[0] for item in self.stoi.items()}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID

    def text_to_sequence(self, text, tokenized=True):
        sequence = []
        sequence.append(self.stoi[SOS])
        if tokenized:
            tokens = text.split(' ')
            assert all(len(s) == 1 for s in tokens)
        else:
            tokens = list(text)
        for s in tokens:
            if s not in self.stoi:
                s = UNK
            sequence.append(self.stoi[s])
        sequence.append(self.stoi[EOS])
        return sequence

    def fit_atom_symbols(self, atoms):
        atoms = list(set(atoms))
        chars = []
        for atom in atoms:
            chars.extend(list(atom))
        vocab = self.special_tokens + chars
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID
        assert self.stoi[MASK] == MASK_ID
        self.itos = {item[1]: item[0] for item in self.stoi.items()}

    def get_output_mask(self, id):
        ''' TO FIX '''
        mask = [False] * len(self)
        if self.continuous_coords:
            return mask
        if self.is_x(id):
            return [True] * (self.offset + self.maxx) + [False] * self.maxy
        if self.is_y(id):
            return [False] * self.offset + [True] * (self.maxx + self.maxy)
        return mask

    def nodes_to_sequence(self, nodes):
        coords, symbols = nodes['coords'], nodes['symbols']
        labels = [SOS_ID]
        for (x, y), symbol in zip(coords, symbols):
            assert 0 <= x <= 1
            assert 0 <= y <= 1
            labels.append(self.x_to_id(x))
            labels.append(self.y_to_id(y))
            for char in symbol:
                labels.append(self.symbol_to_id(char))
        labels.append(EOS_ID)
        return labels

    def sequence_to_nodes(self, sequence):
        coords, symbols = [], []
        i = 0
        if sequence[0] == SOS_ID:
            i += 1
        while i < len(sequence):
            if sequence[i] == EOS_ID:
                break
            if i+2 < len(sequence) and self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
                x = self.id_to_x(sequence[i])
                y = self.id_to_y(sequence[i+1])
                for j in range(i+2, len(sequence)):
                    if not self.is_symbol(sequence[j]):
                        break
                symbol = ''.join(self.itos[sequence[k]] for k in range(i+2, j))
                coords.append([x, y])
                symbols.append(symbol)
                i = j
            else:
                i += 1
        return {'coords': coords, 'symbols': symbols}
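    # CharTokenizer encodes each SMILES token character by character and, as in
    # NodeTokenizer, appends the discretized x/y ids after the last character of
    # every atom token. Decoding therefore has to stitch multi-character atoms
    # (e.g. 'Cl', 'Br', bracket atoms such as '[nH]') back together before
    # reading the coordinates that follow them.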
    def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
        tokens = atomwise_tokenizer(smiles)
        labels = [SOS_ID]
        indices = []
        atom_idx = -1
        for token in tokens:
            if atom_only and not self.is_atom_token(token):
                continue
            for c in token:
                if c in self.stoi:
                    labels.append(self.stoi[c])
                else:
                    if self.debug:
                        print(f'{c} not in vocab')
                    labels.append(UNK_ID)
            if self.is_atom_token(token):
                atom_idx += 1
                if not self.continuous_coords:
                    if mask_ratio > 0 and random.random() < mask_ratio:
                        labels.append(MASK_ID)
                        labels.append(MASK_ID)
                    elif coords is not None:
                        if atom_idx < len(coords):
                            x, y = coords[atom_idx]
                            assert 0 <= x <= 1
                            assert 0 <= y <= 1
                        else:
                            x = random.random()
                            y = random.random()
                        labels.append(self.x_to_id(x))
                        labels.append(self.y_to_id(y))
                indices.append(len(labels) - 1)
        labels.append(EOS_ID)
        return labels, indices

    def sequence_to_smiles(self, sequence):
        has_coords = not self.continuous_coords
        smiles = ''
        coords, symbols, indices = [], [], []
        i = 0
        while i < len(sequence):
            label = sequence[i]
            if label == EOS_ID or label == PAD_ID:
                break
            if self.is_x(label) or self.is_y(label):
                i += 1
                continue
            if not self.is_atom(label):
                smiles += self.itos[label]
                i += 1
                continue
            if self.itos[label] == '[':
                j = i + 1
                while j < len(sequence):
                    if not self.is_symbol(sequence[j]):
                        break
                    if self.itos[sequence[j]] == ']':
                        j += 1
                        break
                    j += 1
            else:
                if i+1 < len(sequence) and (self.itos[label] == 'C' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'l'
                                            or self.itos[label] == 'B' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'r'):
                    j = i+2
                else:
                    j = i+1
            token = ''.join(self.itos[sequence[k]] for k in range(i, j))
            smiles += token
            if has_coords:
                if j+2 < len(sequence) and self.is_x(sequence[j]) and self.is_y(sequence[j+1]):
                    x = self.id_to_x(sequence[j])
                    y = self.id_to_y(sequence[j+1])
                    coords.append([x, y])
                    symbols.append(token)
                    indices.append(j+2)
                    i = j+2
                else:
                    i = j
            else:
                if j < len(sequence):
                    symbols.append(token)
                    indices.append(j)
                i = j
        results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
        if has_coords:
            results['coords'] = coords
        return results


def get_tokenizer(args):
    tokenizer = {}
    for format_ in args.formats:
        if format_ == 'atomtok':
            if args.vocab_file is None:
                args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
            tokenizer['atomtok'] = Tokenizer(args.vocab_file)
        elif format_ == "atomtok_coords":
            if args.vocab_file is None:
                args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
            tokenizer["atomtok_coords"] = NodeTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
                                                        continuous_coords=args.continuous_coords)
        elif format_ == "chartok_coords":
            if args.vocab_file is None:
                args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_chars.json')
            tokenizer["chartok_coords"] = CharTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
                                                        continuous_coords=args.continuous_coords)
    return tokenizer
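

if __name__ == '__main__':
    # Minimal usage sketch (not part of the library): instead of loading one of
    # the bundled vocab files, fit a toy atom-level vocabulary on the fly and
    # round-trip a space-tokenized SMILES string.
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(['C C ( = O ) O', 'c 1 c c c c c 1'])
    seq = tokenizer.text_to_sequence('C C ( = O ) O')   # [SOS_ID, ..., EOS_ID]
    print(seq)
    print(tokenizer.predict_caption(seq[1:]))           # skips SOS; prints 'CC(=O)O'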