# coding: utf-8
"""
Every NLP task needs a Vocabulary
Every Vocabulary is built from Instances
Every Instance is a collection of Fields
"""

__all__ = ['DefaultLookupDict', 'Vocabulary']

PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
BOS_TOKEN = '<bos>'
EOS_TOKEN = '<eos>'

PAD_IDX = 0
UNK_IDX = 1


class DefaultLookupDict(dict):
    """A dict that returns a fixed default value for missing keys instead of raising KeyError."""

    def __init__(self, default):
        super(DefaultLookupDict, self).__init__()
        self._default = default

    def __getitem__(self, item):
        return self.get(item, self._default)


class Vocabulary:
    """
    Define a vocabulary object that will be used to numericalize a field.

    Attributes:
        token2id: A DefaultLookupDict mapping token strings to numerical
            identifiers; unknown tokens map to the index of '<unk>'.
        id2token: A list of token strings indexed by their numerical
            identifiers.
        embedding: Pretrained vectors. Default: None.

    Examples:
        >>> from torchlight.vocab import Vocabulary
        >>> from collections import Counter
        >>> text_data = ['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world']
        >>> vocab = Vocabulary(Counter(text_data))
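        >>> # indices 0 and 1 are reserved for '<pad>' and '<unk>';
        >>> # corpus tokens follow, most frequent first
        >>> len(vocab)
        6
        >>> vocab['world']
        2
        >>> vocab['never-seen']  # unknown tokens fall back to UNK_IDX
        1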
"""
def __init__(self, counter, max_size=None, min_freq=1, specials=None):
"""
Create a Vocabulary given Counter.
Args:
counter: collections.Counter object holding the frequencies of
each value found in the data.
max_size: The maximum size of the vocabulary, or None for no
maximum. Default: None.
min_freq: The minimum frequency needed to include a token in the
vocabulary. Values less than 1 will be set to 1. Default: 1.
specials: The list of special tokens except ['<pad>', '<unk>'].
Possible choices: [CLS] [MASK] [SEP] in BERT or <bos> <eos>
in Machine Translation.
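
        Examples:
            >>> # a small sketch: min_freq filters rare tokens, and extra
            >>> # specials are appended after '<pad>' and '<unk>'
            >>> from collections import Counter
            >>> v = Vocabulary(Counter(['a', 'a', 'b', 'c']),
            ...                min_freq=2, specials=[BOS_TOKEN, EOS_TOKEN])
            >>> v.id2token
            ['<pad>', '<unk>', '<bos>', '<eos>', 'a']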
"""
min_freq = max(min_freq, 1) # must be positive
if specials is None:
self.specials = [PAD_TOKEN, UNK_TOKEN]
else:
assert isinstance(specials, list), "'specials' is of type list"
self.specials = [PAD_TOKEN, UNK_TOKEN] + specials
assert len(set(self.specials)) == len(self.specials), \
"specials can not contain duplicates."
if max_size is not None:
max_size = len(self.specials) + max_size
self.id2token = self.specials[:]
self.token2id = DefaultLookupDict(UNK_IDX)
self.token2id.update({tok: i for i, tok in enumerate(self.id2token)})
# sort by frequency, then alphabetically
token_freqs = sorted(counter.items(), key=lambda tup: tup[0])
token_freqs.sort(key=lambda tup: tup[1], reverse=True)
for token, freq in token_freqs:
if freq < min_freq or len(self.id2token) == max_size:
break
if token not in self.specials:
self.id2token.append(token)
self.token2id[token] = len(self.id2token) - 1
# TODO
self.embedding = None

    def __len__(self):
        return len(self.id2token)

    def __repr__(self):
        return 'Vocab(size={}, specials="{}")'.format(len(self), self.specials)

    def __getitem__(self, tokens):
        """Looks up indices of text tokens according to the vocabulary.

        Unknown tokens are mapped to the index of '<unk>'.

        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.

        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the
            vocabulary.
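
        Examples
        --------
        >>> # assuming `vocab` was built as in the class docstring example
        >>> vocab['world']
        2
        >>> vocab[['hello', 'oov-token']]  # lists are looked up element-wise
        [3, 1]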
"""
if not isinstance(tokens, (list, tuple)):
return self.token2id[tokens]
else:
return [self.token2id[token] for token in tokens]

    def __call__(self, tokens):
        """Looks up indices of text tokens according to the vocabulary.

        Equivalent to `self[tokens]`.

        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.

        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the
            vocabulary.
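
        Examples
        --------
        >>> # same `vocab` as in the class docstring example
        >>> vocab(['hello', 'world'])
        [3, 2]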
"""
return self[tokens]

    @classmethod
    def from_json(cls, json_str):
        # TODO: rebuild a Vocabulary from its JSON representation
        pass

    def to_json(self):
        # TODO: serialize the vocabulary (e.g. its token list) to JSON
        pass

    def set_embedding(self):
        # TODO: attach pretrained vectors to self.embedding
        pass