# MSAGPT/utils/tokenization.py
from typing import Sequence, Tuple, List, Union
import itertools
class ResidueLevelTokenizer:
    """
    Tokenizer for Protein Residue Level Tokenization.
    """
    def __init__(self, **kwargs):
        super(ResidueLevelTokenizer, self).__init__()
        self.pad_tok = ['[pad]']
        # Copy the list so that extending all_toks does not also grow pad_tok.
        self.all_toks = list(self.pad_tok)
        self._tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
        self.all_toks.extend(self._tokens)
        self._special_tokens = ['MASK', 'gMASK', 'sMASK', 'eod', 'sop', 'eop', '</s>', '<M>']
        self.set_special_tokens(self._special_tokens)
        # 'eos' and 'tMASK' are aliases that share ids with '</s>' and 'MASK'.
        self.special_tokens['eos'] = self.special_tokens['</s>']
        self.special_tokens['tMASK'] = self.special_tokens['MASK']
        self.all_toks.extend(self._special_tokens)
        self._vocab = {t: i for i, t in enumerate(self.all_toks)}
        self.command_token = {'[tMASK]': 'tMASK', '[MASK]': 'MASK', '[gMASK]': 'gMASK', '[sMASK]': 'sMASK'}
        # print('Building vocab.: {}'.format(self._vocab))
        # print('Special_tokens: {}'.format(self.special_tokens))
        # print('All tokens: {}'.format(self.all_toks))
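
    # For reference (derived from the lists above, not from the original file):
    # index 0 is '[pad]', indices 1-27 are the residue/symbol tokens in the order
    # of self._tokens, and indices 28-35 are the special tokens 'MASK' ... '<M>'.
    # The aliases 'eos' (34) and 'tMASK' (28) live only in self.special_tokens and
    # add no vocabulary entries, so len(self) == 36.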

    def pad_id(self):
        return self._vocab['[pad]']

    def set_special_tokens(self, special_tokens):
        """Add a list of additional tokens to the encoder.
        The additional tokens are indexed starting from the last index of the
        current vocabulary, in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict((tok, len(self.all_toks) + i) for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    def __len__(self):
        return len(self._vocab)

    def EncodeAsIds(self, text, process_fn=None):
        """Convert a sequence into a list of token ids."""
        processed_text = text
        if process_fn is not None:
            processed_text = process_fn(processed_text)
        processed_text = str(processed_text)
        tokens = [self.TokenToId(c) for c in processed_text]
        return tokens
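    # For example (derived from the vocabulary built in __init__, not from the
    # original file): EncodeAsIds("MAG") maps each character through TokenToId
    # and returns [17, 2, 3].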

    def IdToToken(self, idx):
        if idx == 0:
            return '[pad]'
        elif idx in self.special_tokens_decoder:
            return f"[{self.special_tokens_decoder[idx]}]"
        else:
            try:
                tok = self.all_toks[idx]
            except IndexError:
                # Out-of-range ids decode to a placeholder character.
                tok = '*'
            return tok

    def TokenToId(self, token):
        if token == '[pad]':
            return 0
        elif token in self.special_tokens:
            return self.special_tokens[token]
        else:
            return self._vocab[token]

    def DecodeIds(self, Ids):
        return ''.join([self.IdToToken(tok) for tok in Ids])
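    # Round trip (derived from the methods above, not from the original file):
    # DecodeIds(EncodeAsIds("MAG")) returns "MAG"; special-token ids decode to
    # bracketed names such as "[gMASK]".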

    def _tokenize(self, text) -> List[str]:
        return text.split()

    def tokenize(self, text, **kwargs) -> List[int]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
        Converts a string into a sequence of tokens, using the tokenizer, and then
        maps the tokens to their ids.
        Args:
            text (:obj:`str`):
                The sequence to be encoded.
        Returns:
            :obj:`List[int]`: The list of token ids.
        """
        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                # AddedToken can control whitespace stripping around them.
                # We use them for GPT2 and Roberta to have different behavior depending on the special token
                # Cf. https://github.com/huggingface/transformers/pull/2778
                # and https://github.com/huggingface/transformers/issues/3788
                # We strip left and right by default
                if i < len(split_text) - 1:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()
                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                    else:
                        pass
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result
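
        # For example (derived from the loop above, not from the original file):
        # split_on_token('A', 'MAG') returns ['M', 'A', 'G'], keeping the matched
        # token while preserving the surrounding sub-strings.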

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []
            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self._tokens:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text
            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token)
                        if token not in self.all_toks
                        else [token]
                        for token in tokenized_text
                    )
                )
            )

        no_split_token = self.all_toks
        tokenized_text = split_on_tokens(no_split_token, text)
        return self.convert_tokens_to_ids(tokenized_text)
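    # Note (an observation about the code above, not a statement from the original
    # file): because single-letter residue tokens appear in all_toks before the
    # multi-character special tokens, raw text such as '<M>' or '</s>' is split
    # character by character here; in practice, plain residue strings are tokenized
    # and special-token ids are obtained separately via special_tokens / get_command.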

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        # print_rank_0(tokens)
        # print_rank_0(self.vocab)
        for token in tokens:
            ids.append(self.TokenToId(token))
        return ids


class proteinglm_tokenizer:
    """
    Protein tokenizer based on the residue-level tokenizer.
    """
    def __init__(self):
        self.name = 'ProteinTokenizer'
        self.tokenizer = ResidueLevelTokenizer()
        self.special_tokens = self.tokenizer.special_tokens

    def IdToToken(self, idx):
        return self.tokenizer.IdToToken(idx)

    def TokenToId(self, token):
        return self.tokenizer.TokenToId(token)

    @property
    def vocab_size(self):
        return len(self.tokenizer)

    def decode(self, token_ids):
        # Decodes a single token id (the id is wrapped in a list before DecodeIds).
        return self.tokenizer.DecodeIds([token_ids])

    @property
    def eod(self):
        # Id of the end-of-sequence token ('</s>', aliased as 'eos').
        return self.tokenizer.special_tokens['eos']

    def detokenize(self, Ids, type_token=False):
        new_tokens = self.tokenizer.DecodeIds(Ids)
        return new_tokens

    def tokenize(self, text):
        ids = self.tokenizer.tokenize(text)
        return ids

    @property
    def vocab(self):
        return self.tokenizer._vocab

    @property
    def inv_vocab(self):
        return {v: k for k, v in self.tokenizer._vocab.items()}

    @property
    def get_pad_id(self):
        # Returns the underlying tokenizer's pad_id method (call it to get the id).
        return self.tokenizer.pad_id

    def get_command(self, token):
        # Map bracketed command tokens such as '[gMASK]' to their special-token ids.
        tok = token
        if token in self.tokenizer.command_token:
            tok = self.tokenizer.command_token[token]
        return self.tokenizer.special_tokens[tok]
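

# A minimal usage sketch (added for illustration; not part of the original file).
# It only exercises the classes defined above and assumes this module is run directly.
if __name__ == "__main__":
    tokenizer = proteinglm_tokenizer()
    seq = "MKTAYIAK"                           # a plain residue string
    ids = tokenizer.tokenize(seq)              # residue-level ids, e.g. [17, 12, 8, ...]
    print(ids)
    print(tokenizer.detokenize(ids))           # round-trips back to "MKTAYIAK"
    print(tokenizer.vocab_size)                # 36 = [pad] + 27 residue/symbol tokens + 8 specials
    print(tokenizer.get_command('[gMASK]'))    # command tokens resolve to special-token ids
    print(tokenizer.eod)                       # id of the '</s>' / 'eos' token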