|
from __future__ import unicode_literals |
|
from collections import defaultdict |
|
import logging |
|
|
|
# Module-level logger (standard ``logging.getLogger(__name__)`` pattern);
# configuration of handlers/levels is left to the application.
logger = logging.getLogger(__name__)
|
|
|
|
|
class Vocab(object): |
|
"""Defines a vocabulary object that will be used to numericalize a field. |
|
|
|
Attributes: |
|
freqs: A collections.Counter object holding the frequencies of tokens |
|
in the data used to build the Vocab. |
|
stoi: A collections.defaultdict instance mapping token strings to |
|
numerical identifiers. |
|
itos: A list of token strings indexed by their numerical identifiers. |
|
""" |
|
|
|
|
|
UNK = '<unk>' |
|
|
|
def __init__(self, counter, max_size=None, min_freq=1, specials=['<unk>', '<pad>'], specials_first=True): |
|
"""Create a Vocab object from a collections.Counter. |
|
|
|
Arguments: |
|
counter: collections.Counter object holding the frequencies of |
|
each value found in the data. |
|
max_size: The maximum size of the vocabulary, or None for no |
|
maximum. Default: None. |
|
min_freq: The minimum frequency needed to include a token in the |
|
vocabulary. Values less than 1 will be set to 1. Default: 1. |
|
specials: The list of special tokens (e.g., padding or eos) that |
|
will be prepended to the vocabulary. Default: ['<unk'>, '<pad>'] |
|
specials_first: Whether to add special tokens into the vocabulary at first. |
|
If it is False, they are added into the vocabulary at last. |
|
Default: True. |
|
""" |
|
self.freqs = counter |
|
counter = counter.copy() |
|
min_freq = max(min_freq, 1) |
|
|
|
self.itos = list() |
|
self.unk_index = None |
|
if specials_first: |
|
self.itos = list(specials) |
|
|
|
max_size = None if max_size is None else max_size + len(specials) |
|
|
|
|
|
|
|
for tok in specials: |
|
del counter[tok] |
|
|
|
|
|
words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) |
|
words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) |
|
|
|
for word, freq in words_and_frequencies: |
|
if freq < min_freq or len(self.itos) == max_size: |
|
break |
|
self.itos.append(word) |
|
|
|
if Vocab.UNK in specials: |
|
unk_index = specials.index(Vocab.UNK) |
|
|
|
self.unk_index = unk_index if specials_first else len(self.itos) + unk_index |
|
self.stoi = defaultdict(self._default_unk_index) |
|
else: |
|
self.stoi = defaultdict() |
|
|
|
if not specials_first: |
|
self.itos.extend(list(specials)) |
|
|
|
|
|
self.stoi.update({tok: i for i, tok in enumerate(self.itos)}) |
|
|
|
def _default_unk_index(self): |
|
return self.unk_index |
|
|
|
def __getitem__(self, token): |
|
return self.stoi.get(token, self.stoi.get(Vocab.UNK)) |
|
|
|
def __getstate__(self): |
|
|
|
attrs = dict(self.__dict__) |
|
|
|
attrs['stoi'] = dict(self.stoi) |
|
return attrs |
|
|
|
def __setstate__(self, state): |
|
if state.get("unk_index", None) is None: |
|
stoi = defaultdict() |
|
else: |
|
stoi = defaultdict(self._default_unk_index) |
|
stoi.update(state['stoi']) |
|
state['stoi'] = stoi |
|
self.__dict__.update(state) |
|
|
|
def __eq__(self, other): |
|
if self.freqs != other.freqs: |
|
return False |
|
if self.stoi != other.stoi: |
|
return False |
|
if self.itos != other.itos: |
|
return False |
|
return True |
|
|
|
def __len__(self): |
|
return len(self.itos) |
|
|
|
def extend(self, v, sort=False): |
|
words = sorted(v.itos) if sort else v.itos |
|
for w in words: |
|
if w not in self.stoi: |
|
self.itos.append(w) |
|
self.stoi[w] = len(self.itos) - 1 |
|
|