ltg
/

larkkin's picture
Add code and readme
c45d283
raw
history blame
4.26 kB
from __future__ import unicode_literals
from collections import defaultdict
import logging
# Logger named after this module, per the stdlib logging convention.
logger = logging.getLogger(name=__name__)
class Vocab(object):
    """Defines a vocabulary object that will be used to numericalize a field.

    Attributes:
        freqs: A collections.Counter object holding the frequencies of tokens
            in the data used to build the Vocab.
        stoi: A collections.defaultdict instance mapping token strings to
            numerical identifiers.
        itos: A list of token strings indexed by their numerical identifiers.
        unk_index: Index of ``Vocab.UNK`` in ``itos``, or None when
            ``Vocab.UNK`` is not one of the special tokens.
    """

    # TODO (@mttk): Populate class with default values of special symbols
    UNK = '<unk>'

    def __init__(self, counter, max_size=None, min_freq=1,
                 specials=('<unk>', '<pad>'), specials_first=True):
        """Create a Vocab object from a collections.Counter.

        Arguments:
            counter: collections.Counter object (or any mapping from token
                to count) holding the frequencies of each value found in the
                data. It is stored as ``freqs`` as-is; only an internal copy
                is modified, so the caller's object is left untouched.
            max_size: The maximum number of non-special tokens in the
                vocabulary, or None for no maximum. Default: None.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 will be set to 1. Default: 1.
            specials: The sequence of special tokens (e.g., padding or eos)
                that will be added to the vocabulary.
                Default: ('<unk>', '<pad>')
            specials_first: Whether to add special tokens into the vocabulary
                at first. If it is False, they are added into the vocabulary
                at last. Default: True.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list()
        self.unk_index = None

        if specials_first:
            self.itos = list(specials)
            # only extend max size if specials are prepended, so that
            # max_size always bounds the number of non-special tokens
            max_size = None if max_size is None else max_size + len(specials)

        # frequencies of special tokens are not counted when building the
        # vocabulary in frequency order; pop() (unlike del) also tolerates
        # plain dicts that do not contain the special token
        for tok in specials:
            counter.pop(tok, None)

        # sort by frequency, then alphabetically (two stable sorts)
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, freq in words_and_frequencies:
            # the list is in descending frequency, so stop at the first miss
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        if Vocab.UNK in specials:  # hard-coded for now
            unk_index = specials.index(Vocab.UNK)  # position in list
            # account for ordering of specials, set variable; when specials
            # are appended, they start right after the current itos
            self.unk_index = unk_index if specials_first else len(self.itos) + unk_index
            self.stoi = defaultdict(self._default_unk_index)
        else:
            # no default factory: unknown tokens raise KeyError on stoi[...]
            self.stoi = defaultdict()

        if not specials_first:
            self.itos.extend(list(specials))

        # stoi is simply a reverse dict for itos
        self.stoi.update({tok: i for i, tok in enumerate(self.itos)})

    def _default_unk_index(self):
        """Default factory for ``stoi``: map unseen tokens to ``unk_index``."""
        return self.unk_index

    def __getitem__(self, token):
        """Return the index of ``token``.

        Unknown tokens map to the index of ``Vocab.UNK``, or to None when the
        vocabulary contains no ``Vocab.UNK`` entry. Unlike ``stoi[token]``,
        this never inserts the token into ``stoi``.
        """
        return self.stoi.get(token, self.stoi.get(Vocab.UNK))

    def __getstate__(self):
        """Return picklable state with ``stoi`` cast to a plain dict."""
        # avoid pickling the defaultdict (its factory is a bound method)
        attrs = dict(self.__dict__)
        # cast to regular dict
        attrs['stoi'] = dict(self.stoi)
        return attrs

    def __setstate__(self, state):
        """Restore state, rebuilding ``stoi`` as a defaultdict.

        The rebuilt ``stoi`` falls back to ``unk_index`` only when one is set.
        """
        if state.get("unk_index", None) is None:
            stoi = defaultdict()
        else:
            stoi = defaultdict(self._default_unk_index)
        stoi.update(state['stoi'])
        state['stoi'] = stoi
        self.__dict__.update(state)

    def __eq__(self, other):
        """Compare ``freqs``, ``stoi`` and ``itos`` with another Vocab."""
        if not isinstance(other, Vocab):
            # let Python try the reflected operation / fall back to identity
            return NotImplemented
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        return True

    def __len__(self):
        """Return the number of tokens, specials included."""
        return len(self.itos)

    def extend(self, v, sort=False):
        """Append the tokens of Vocab ``v`` that are not already present.

        Arguments:
            v: Vocab whose ``itos`` tokens are added, in ``v``'s index order.
            sort: If True, add ``v``'s tokens in sorted order instead.

        Existing indices are unchanged; ``freqs`` is not updated.
        """
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1