# coding: utf-8
"""
Every NLP task needs a Vocabulary.
Every Vocabulary is built from Instances.
Every Instance is a collection of Fields.
"""

__all__ = ['DefaultLookupDict', 'Vocabulary']

PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
BOS_TOKEN = '<bos>'
EOS_TOKEN = '<eos>'
PAD_IDX = 0
UNK_IDX = 1


class DefaultLookupDict(dict):
    """A dict that returns a fixed default value for missing keys.

    Unlike ``collections.defaultdict``, the missing key is not inserted,
    so failed lookups never grow the dictionary.
    """

    def __init__(self, default):
        super().__init__()
        self._default = default

    def __getitem__(self, item):
        return self.get(item, self._default)


class Vocabulary:
    """
    Define a vocabulary object that will be used to numericalize a field.

    Attributes:
        token2id: A DefaultLookupDict mapping token strings to numerical
            identifiers; unknown tokens map to ``UNK_IDX``.
        id2token: A list of token strings indexed by their numerical
            identifiers.
        embedding: Pretrained embedding vectors, or None if unset.

    Examples:
    >>> from torchlight.vocab import Vocabulary
    >>> from collections import Counter
    >>> text_data = ['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world']
    >>> vocab = Vocabulary(Counter(text_data))
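    >>> len(vocab)  # 2 default specials + 4 corpus tokens
    6
    >>> vocab['world']  # most frequent corpus token gets the first free id
    2
    >>> vocab(['hello', 'hi', 'unseen'])  # unknown tokens map to UNK_IDX
    [3, 4, 1]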
    """
    def __init__(self, counter, max_size=None, min_freq=1, specials=None):
        """
        Create a Vocabulary given a Counter.
        Args:
            counter: collections.Counter object holding the frequencies of
                each value found in the data.
            max_size: The maximum size of the vocabulary, or None for no
                maximum. Default: None.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 will be set to 1. Default: 1.
            specials: A list of additional special tokens, beyond the
                always-included '<pad>' and '<unk>'. Possible choices:
                [CLS], [MASK], [SEP] in BERT, or <bos>, <eos> in machine
                translation. Default: None.
        """
        min_freq = max(min_freq, 1)  # must be positive

        if specials is None:
            self.specials = [PAD_TOKEN, UNK_TOKEN]
        else:
            assert isinstance(specials, list), "'specials' must be a list"
            self.specials = [PAD_TOKEN, UNK_TOKEN] + specials

        assert len(set(self.specials)) == len(self.specials), \
            "specials can not contain duplicates."

        if max_size is not None:
            max_size = len(self.specials) + max_size

        self.id2token = self.specials[:]
        self.token2id = DefaultLookupDict(UNK_IDX)
        self.token2id.update({tok: i for i, tok in enumerate(self.id2token)})

        # Sort alphabetically first, then stably by frequency (descending),
        # so that frequency ties are broken alphabetically.
        token_freqs = sorted(counter.items(), key=lambda tup: tup[0])
        token_freqs.sort(key=lambda tup: tup[1], reverse=True)

        for token, freq in token_freqs:
            if freq < min_freq or len(self.id2token) == max_size:
                break
            if token not in self.specials:  # specials were added above
                self.id2token.append(token)
                self.token2id[token] = len(self.id2token) - 1

        # TODO
        self.embedding = None

    def __len__(self):
        return len(self.id2token)

    def __repr__(self):
        return 'Vocabulary(size={}, specials={})'.format(len(self), self.specials)

    def __getitem__(self, tokens):
        """Looks up indices of text tokens according to the vocabulary.
        If `unknown_token` of the vocabulary is None, looking up unknown tokens
        results in KeyError.
        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.
        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the vocabulary.
        """

        if not isinstance(tokens, (list, tuple)):
            return self.token2id[tokens]
        else:
            return [self.token2id[token] for token in tokens]

    def __call__(self, tokens):
        """Looks up indices of text tokens according to the vocabulary.
        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.
        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the
            vocabulary.
        """

        return self[tokens]

    @classmethod
    def from_json(cls, json_str):
        """Deserialize a Vocabulary from a JSON string. Not implemented yet."""
        pass

    def to_json(self):
        """Serialize the Vocabulary to a JSON string. Not implemented yet."""
        pass

    def set_embedding(self):
        """Attach pretrained embedding vectors. Not implemented yet."""
        pass
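

# A minimal usage sketch (illustrative only, not part of the module's API):
# build a Vocabulary from token counts, then numericalize a sentence. After
# the special tokens, ids follow descending corpus frequency.
if __name__ == '__main__':
    from collections import Counter

    corpus = ['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world']
    vocab = Vocabulary(Counter(corpus), specials=[BOS_TOKEN, EOS_TOKEN])

    # Specials first, then corpus tokens by descending frequency:
    # ['<pad>', '<unk>', '<bos>', '<eos>', 'world', 'hello', 'hi', 'nice']
    print(vocab.id2token)

    # Out-of-vocabulary tokens ('there') fall back to UNK_IDX.
    sentence = [BOS_TOKEN, 'hello', 'there', EOS_TOKEN]
    print(vocab(sentence))  # [2, 5, 1, 3]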