from typing import Sequence, Tuple, List, Union

import itertools


class ResidueLevelTokenizer:
    """
    Tokenizer for protein residue-level tokenization.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Vocabulary layout: '[pad]' at index 0, then the 27 residue symbols,
        # then the special tokens appended at the end.
        self.pad_tok = ['[pad]']
        self.all_toks = list(self.pad_tok)
        self._tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
        self.all_toks.extend(self._tokens)
        self._special_tokens = ['MASK', 'gMASK', 'sMASK', 'eod', 'sop', 'eop', '</s>', '<M>']
        self.set_special_tokens(self._special_tokens)
        # 'eos' and 'tMASK' are aliases that reuse the ids of '</s>' and 'MASK'.
        self.special_tokens['eos'] = self.special_tokens['</s>']
        self.special_tokens['tMASK'] = self.special_tokens['MASK']

        self.all_toks.extend(self._special_tokens)
        self._vocab = {t: i for i, t in enumerate(self.all_toks)}
        self.command_token = {'[tMASK]': 'tMASK', '[MASK]': 'MASK', '[gMASK]': 'gMASK', '[sMASK]': 'sMASK'}

    def pad_id(self):
        return self._vocab['[pad]']

    def set_special_tokens(self, special_tokens):
        """Add a list of additional tokens to the encoder.

        The additional tokens are indexed starting from the last index of the
        current vocabulary in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = {tok: len(self.all_toks) + i for i, tok in enumerate(special_tokens)}
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    def __len__(self):
        return len(self._vocab)

    def EncodeAsIds(self, text, process_fn=None):
        """Convert a sequence to a list of token ids, one id per character."""
        processed_text = text
        if process_fn is not None:
            processed_text = process_fn(processed_text)
        processed_text = str(processed_text)
        tokens = [self.TokenToId(c) for c in processed_text]
        return tokens

    def IdToToken(self, idx):
        if idx == 0:
            return '[pad]'
        elif idx in self.special_tokens_decoder:
            return f"[{self.special_tokens_decoder[idx]}]"
        else:
            try:
                tok = self.all_toks[idx]
            except IndexError:
                # Out-of-range ids decode to a placeholder character.
                tok = '*'
            return tok

    def TokenToId(self, token):
        if token == '[pad]':
            return 0
        elif token in self.special_tokens:
            return self.special_tokens[token]
        else:
            return self._vocab[token]

    def DecodeIds(self, Ids):
        return ''.join(self.IdToToken(idx) for idx in Ids)

    def _tokenize(self, text) -> List[str]:
        return text.split()

    def tokenize(self, text, **kwargs) -> List[int]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py

        Converts a string into a sequence of tokens and then into token ids, using the tokenizer.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.

        Returns:
            :obj:`List[int]`: The list of token ids.
        """

        def split_on_token(tok, text):
            # Split `text` around every occurrence of `tok`, keeping `tok` itself
            # as a separate element and stripping the whitespace next to it.
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                if i < len(split_text) - 1:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []

            # Repeatedly split the text on every known token; pieces that are
            # already single residue tokens are passed through unchanged.
            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self._tokens:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token)
                        if token not in self.all_toks
                        else [token]
                        for token in tokenized_text
                    )
                )
            )

        no_split_token = self.all_toks
        tokenized_text = split_on_tokens(no_split_token, text)
        return self.convert_tokens_to_ids(tokenized_text)

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            ids.append(self.TokenToId(token))
        return ids


class proteinglm_tokenizer:
    """
    Protein tokenizer based on the residue-level tokenizer.
    """

    def __init__(self):
        self.name = 'ProteinTokenizer'
        self.tokenizer = ResidueLevelTokenizer()
        self.special_tokens = self.tokenizer.special_tokens

    def IdToToken(self, idx):
        return self.tokenizer.IdToToken(idx)

    def TokenToId(self, token):
        return self.tokenizer.TokenToId(token)

    @property
    def vocab_size(self):
        return len(self.tokenizer)

    def decode(self, token_ids):
        # Decodes a single token id; use `detokenize` for a list of ids.
        return self.tokenizer.DecodeIds([token_ids])

    @property
    def eod(self):
        return self.tokenizer.special_tokens['eos']

    def detokenize(self, Ids, type_token=False):
        new_tokens = self.tokenizer.DecodeIds(Ids)
        return new_tokens

    def tokenize(self, text):
        ids = self.tokenizer.tokenize(text)
        return ids

    @property
    def vocab(self):
        return self.tokenizer._vocab

    @property
    def inv_vocab(self):
        return {v: k for k, v in self.tokenizer._vocab.items()}

    @property
    def get_pad_id(self):
        return self.tokenizer.pad_id()

    def get_command(self, token):
        # Map bracketed command tokens such as '[gMASK]' to their special-token ids.
        tok = token
        if token in self.tokenizer.command_token:
            tok = self.tokenizer.command_token[token]
        return self.tokenizer.special_tokens[tok]
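

# Minimal usage sketch: exercises the tokenize -> detokenize round trip and the
# special-token lookups defined above.
if __name__ == '__main__':
    tokenizer = proteinglm_tokenizer()
    seq = 'MKTAYIAKQR'
    ids = tokenizer.tokenize(seq)
    print('ids:', ids)
    print('decoded:', tokenizer.detokenize(ids))          # round-trips back to `seq`
    print('vocab size:', tokenizer.vocab_size)            # 36 = 1 pad + 27 residues + 8 specials
    print('pad id:', tokenizer.get_pad_id)                # 0
    print('gMASK id:', tokenizer.get_command('[gMASK]'))  # id assigned to the 'gMASK' special token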