from typing import List
import itertools


class ResidueLevelTokenizer:
    """
    Tokenizer for protein residue-level tokenization.
    """

    def __init__(self, **kwargs):
        super(ResidueLevelTokenizer, self).__init__()
        self.pad_tok = ['[pad]']
        # Copy so that extending all_toks does not also mutate pad_tok.
        self.all_toks = list(self.pad_tok)
        # 20 standard amino acids plus ambiguous/rare residue codes and gap symbols.
        self._tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K',
                        'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z',
                        'O', '.', '-']
        self.all_toks.extend(self._tokens)
        # NOTE: the last two entries are empty strings in this source (their
        # angle-bracketed names appear to have been stripped, likely an
        # end-of-sequence-style marker); 'eos' is aliased to the last one below.
        self._special_tokens = ['MASK', 'gMASK', 'sMASK', 'eod', 'sop', 'eop', '', '']
        self.set_special_tokens(self._special_tokens)
        self.special_tokens['eos'] = self.special_tokens['']
        self.special_tokens['tMASK'] = self.special_tokens['MASK']
        self.all_toks.extend(self._special_tokens)
        self._vocab = {t: i for i, t in enumerate(self.all_toks)}
        self.command_token = {'[tMASK]': 'tMASK', '[MASK]': 'MASK',
                              '[gMASK]': 'gMASK', '[sMASK]': 'sMASK'}
        # print('Building vocab.: {}'.format(self._vocab))
        # print('Special_tokens: {}'.format(self.special_tokens))
        # print('All tokens: {}'.format(self.all_toks))

    def pad_id(self):
        return self._vocab['[pad]']

    def set_special_tokens(self, special_tokens):
        """Add a list of additional tokens to the encoder.

        The additional tokens are indexed starting from the last index of the
        current vocabulary, in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict(
            (tok, len(self.all_toks) + i) for i, tok in enumerate(special_tokens)
        )
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    def __len__(self):
        return len(self._vocab)

    def EncodeAsIds(self, text, process_fn=None):
        """Convert a sequence to a list of token ids."""
        processed_text = text
        if process_fn is not None:
            processed_text = process_fn(processed_text)
        processed_text = str(processed_text)
        tokens = [self.TokenToId(c) for c in processed_text]
        return tokens

    def IdToToken(self, idx):
        if idx == 0:
            return '[pad]'
        elif idx in self.special_tokens_decoder:
            return f"[{self.special_tokens_decoder[idx]}]"
        else:
            try:
                tok = self.all_toks[idx]
            except IndexError:
                tok = '*'
            return tok

    def TokenToId(self, token):
        if token == '[pad]':
            return 0
        elif token in self.special_tokens:
            return self.special_tokens[token]
        else:
            return self._vocab[token]

    def DecodeIds(self, Ids):
        return ''.join([self.IdToToken(tok) for tok in Ids])

    def _tokenize(self, text) -> List[str]:
        return text.split()

    def tokenize(self, text, **kwargs) -> List[str]:
        """
        Inspired by
        https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py

        Converts a string into a sequence of tokens, using the tokenizer.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                # AddedToken can control whitespace stripping around them.
                # We use them for GPT2 and Roberta to have different behavior
                # depending on the special token.
                # Cf. https://github.com/huggingface/transformers/pull/2778
                # and https://github.com/huggingface/transformers/issues/3788
                # We strip left and right by default.
                if i < len(split_text) - 1:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                    else:
                        pass
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []
            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self._tokens:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token) if token not in self.all_toks else [token]
                        for token in tokenized_text
                    )
                )
            )

        no_split_token = self.all_toks
        tokenized_text = split_on_tokens(no_split_token, text)
        return self.convert_tokens_to_ids(tokenized_text)

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        # print_rank_0(tokens)
        # print_rank_0(self.vocab)
        for token in tokens:
            ids.append(self.TokenToId(token))
        return ids


class proteinglm_tokenizer:
    """
    Protein tokenizer based on the residue-level tokenizer.
    """

    def __init__(self):
        self.name = 'ProteinTokenizer'
        self.tokenizer = ResidueLevelTokenizer()
        self.special_tokens = self.tokenizer.special_tokens

    def IdToToken(self, idx):
        return self.tokenizer.IdToToken(idx)

    def TokenToId(self, token):
        return self.tokenizer.TokenToId(token)

    @property
    def vocab_size(self):
        return len(self.tokenizer)

    def decode(self, token_ids):
        # Decodes a single token id (wrapped in a list for DecodeIds).
        return self.tokenizer.DecodeIds([token_ids])

    @property
    def eod(self):
        # Id of the end-of-document / end-of-sequence token.
        return self.tokenizer.special_tokens['eos']

    def detokenize(self, Ids, type_token=False):
        new_tokens = self.tokenizer.DecodeIds(Ids)
        return new_tokens

    def tokenize(self, text):
        ids = self.tokenizer.tokenize(text)
        return ids

    @property
    def vocab(self):
        return self.tokenizer._vocab

    @property
    def inv_vocab(self):
        return {v: k for k, v in self.tokenizer._vocab.items()}

    @property
    def get_pad_id(self):
        return self.tokenizer.pad_id()

    def get_command(self, token):
        tok = token
        if token in self.tokenizer.command_token:
            tok = self.tokenizer.command_token[token]
        return self.tokenizer.special_tokens[tok]
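

# Minimal usage sketch (not part of the original module): it exercises the
# public wrapper roughly the way Megatron-style training code would. The
# residue string "MKTAYIAK" and the '[gMASK]' command lookup are illustrative
# assumptions only.
if __name__ == "__main__":
    tokenizer = proteinglm_tokenizer()
    ids = tokenizer.tokenize("MKTAYIAK")       # residue string -> list of ids
    print(ids)
    print(tokenizer.detokenize(ids))           # ids -> residue string
    print(tokenizer.vocab_size)                # pad + residues + special tokens
    print(tokenizer.get_pad_id)                # id of '[pad]'
    print(tokenizer.get_command('[gMASK]'))    # id of the gMASK command token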