import random
import re
from contextlib import contextmanager
from copy import deepcopy
from functools import partial


def _split_tokenizer(x):
    return x.split()


def _spacy_tokenize(x, spacy):
    return [tok.text for tok in spacy.tokenizer(x)]


_patterns = [r'\'',
             r'\"',
             r'\.',
             r'<br \/>',
             r',',
             r'\(',
             r'\)',
             r'\!',
             r'\?',
             r'\;',
             r'\:',
             r'\s+']

_replacements = [' \' ',
                 '',
                 ' . ',
                 ' ',
                 ' , ',
                 ' ( ',
                 ' ) ',
                 ' ! ',
                 ' ? ',
                 ' ',
                 ' ',
                 ' ']

# (compiled pattern, replacement) pairs, applied in order by
# _basic_english_normalize below.
_patterns_dict = list((re.compile(p), r) for p, r in zip(_patterns, _replacements))
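
# Illustrative check (not part of the original module): the first pair pads
# apostrophes with spaces.
#
#   >>> _patterns_dict[0][0].sub(_patterns_dict[0][1], "don't")
#   "don ' t"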


def _basic_english_normalize(line):
    r"""Basic normalization for a line of text.

    Normalization includes
    - lowercasing
    - adding spaces before and after '\''
    - removing '\"'
    - adding spaces before and after '.'
    - replacing '<br \/>' with a single space
    - adding spaces before and after ','
    - adding spaces before and after '('
    - adding spaces before and after ')'
    - adding spaces before and after '!'
    - adding spaces before and after '?'
    - replacing ';' with a single space
    - replacing ':' with a single space
    - collapsing runs of whitespace into a single space

    Returns a list of tokens after splitting on whitespace.
    """
    line = line.lower()
    for pattern_re, replaced_str in _patterns_dict:
        line = pattern_re.sub(replaced_str, line)
    return line.split()
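
# Illustrative behavior (follows from the patterns above):
#
#   >>> _basic_english_normalize("You can't, right?")
#   ['you', 'can', "'", 't', ',', 'right', '?']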


def get_tokenizer(tokenizer, language='en'):
    r"""Generate a tokenizer function for a string sentence.

    Arguments:
        tokenizer: the name of the tokenizer function. If None, returns
            split(), which splits the string sentence on whitespace.
            If "basic_english", returns _basic_english_normalize(), which
            normalizes the string first and then splits on whitespace. If a
            callable, returns the callable unchanged. If the name of a
            tokenizer library (e.g. spacy, moses, toktok, revtok, subword),
            returns the corresponding tokenize function from that library.
        language: the language for the tokenizer (default: 'en').

    Examples:
        >>> import torchtext
        >>> from torchtext.data import get_tokenizer
        >>> tokenizer = get_tokenizer("basic_english")
        >>> tokens = tokenizer("You can now install TorchText using pip!")
        >>> tokens
        ['you', 'can', 'now', 'install', 'torchtext', 'using', 'pip', '!']
    """
    if tokenizer is None:
        return _split_tokenizer

    if tokenizer == "basic_english":
        if language != 'en':
            raise ValueError("Basic normalization is only available for English (en)")
        return _basic_english_normalize

    if callable(tokenizer):
        return tokenizer

    if tokenizer == "spacy":
        try:
            import spacy
            spacy = spacy.load(language)
            return partial(_spacy_tokenize, spacy=spacy)
        except ImportError:
            print("Please install SpaCy. "
                  "See the docs at https://spacy.io for more information.")
            raise
        except AttributeError:
            print("Please install SpaCy and the SpaCy {} tokenizer. "
                  "See the docs at https://spacy.io for more "
                  "information.".format(language))
            raise
    elif tokenizer == "moses":
        try:
            from sacremoses import MosesTokenizer
            moses_tokenizer = MosesTokenizer()
            return moses_tokenizer.tokenize
        except ImportError:
            print("Please install SacreMoses. "
                  "See the docs at https://github.com/alvations/sacremoses "
                  "for more information.")
            raise
    elif tokenizer == "toktok":
        try:
            from nltk.tokenize.toktok import ToktokTokenizer
            toktok = ToktokTokenizer()
            return toktok.tokenize
        except ImportError:
            print("Please install NLTK. "
                  "See the docs at https://nltk.org for more information.")
            raise
    elif tokenizer == 'revtok':
        try:
            import revtok
            return revtok.tokenize
        except ImportError:
            print("Please install revtok.")
            raise
    elif tokenizer == 'subword':
        try:
            import revtok
            return partial(revtok.tokenize, decap=True)
        except ImportError:
            print("Please install revtok.")
            raise
    raise ValueError("Requested tokenizer {}, valid choices are a "
                     "callable that takes a single string as input, "
                     "\"basic_english\" for the basic English normalizer, "
                     "\"revtok\" for the revtok reversible tokenizer, "
                     "\"subword\" for the revtok caps-aware tokenizer, "
                     "\"spacy\" for the SpaCy tokenizer, "
                     "\"toktok\" for the NLTK ToktokTokenizer, or "
                     "\"moses\" for the SacreMoses port of the Moses "
                     "tokenization script.".format(tokenizer))


def is_tokenizer_serializable(tokenizer, language):
    """Return False for tokenizers that are known not to be serializable.

    Extend this with other tokenizers as they are found not to be
    serializable.
    """
    if tokenizer == 'spacy':
        return False
    return True


def interleave_keys(a, b):
    """Interleave bits from two sort keys to form a joint sort key.

    Examples that are similar in both of the provided keys will have similar
    values for the key defined by this function. Useful for tasks with two
    text fields like machine translation or natural language inference.
    """
    def interleave(args):
        return ''.join([x for t in zip(*args) for x in t])
    return int(''.join(interleave(format(x, '016b') for x in (a, b))), base=2)
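
# Illustrative: the keys are rendered as 16-bit binary strings and their bits
# interleaved, so interleave_keys(1, 2) merges 0b01 and 0b10 into 0b0110.
#
#   >>> interleave_keys(1, 2)
#   6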


def get_torch_version():
    import torch
    v = torch.__version__
    version_substrings = v.split('.')
    major, minor = version_substrings[0], version_substrings[1]
    return int(major), int(minor)
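
# Illustrative: under torch 1.5.0 this returns the tuple (1, 5).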


def dtype_to_attr(dtype):
    _, dtype = str(dtype).split('.')
    return dtype
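
# Illustrative round trip: str(torch.float32) is 'torch.float32', so the
# attribute name can be restored with getattr.
#
#   >>> import torch
#   >>> dtype_to_attr(torch.float32)
#   'float32'
#   >>> getattr(torch, dtype_to_attr(torch.float32)) is torch.float32
#   True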


def ngrams_iterator(token_list, ngrams):
    """Return an iterator that yields the given tokens and their ngrams.

    Arguments:
        token_list: A list of tokens
        ngrams: the highest ngram order to yield (e.g. 2 yields unigrams
            and bigrams).

    Examples:
        >>> token_list = ['here', 'we', 'are']
        >>> list(ngrams_iterator(token_list, 2))
        ['here', 'here we', 'we', 'we are', 'are']
    """
    def _get_ngrams(n):
        return zip(*[token_list[i:] for i in range(n)])

    for x in token_list:
        yield x
    for n in range(2, ngrams + 1):
        for x in _get_ngrams(n):
            yield ' '.join(x)
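
# Illustrative: higher orders stack, so ngrams=3 also yields trigrams.
#
#   >>> list(ngrams_iterator(['a', 'b', 'c'], 3))
#   ['a', 'b', 'c', 'a b', 'b c', 'a b c']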


class RandomShuffler(object):
    """Use random functions while keeping track of the random state to make it
    reproducible and deterministic."""

    def __init__(self, random_state=None):
        self._random_state = random_state
        if self._random_state is None:
            self._random_state = random.getstate()

    @contextmanager
    def use_internal_state(self):
        """Use a specific RNG state."""
        old_state = random.getstate()
        random.setstate(self._random_state)
        yield
        self._random_state = random.getstate()
        random.setstate(old_state)

    @property
    def random_state(self):
        return deepcopy(self._random_state)

    @random_state.setter
    def random_state(self, s):
        self._random_state = s

    def __call__(self, data):
        """Shuffle and return a new list."""
        with self.use_internal_state():
            return random.sample(data, len(data))
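
# Illustrative: two shufflers seeded with the same captured state produce the
# same permutation, and the global random state is restored afterwards.
#
#   >>> import random
#   >>> random.seed(0)
#   >>> state = random.getstate()
#   >>> RandomShuffler(state)([1, 2, 3]) == RandomShuffler(state)([1, 2, 3])
#   True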