""" Embeddings module """ import math import warnings import torch import torch.nn as nn from onmt.modules.util_class import Elementwise class SequenceTooLongError(Exception): pass class PositionalEncoding(nn.Module): """Sinusoidal positional encoding for non-recurrent neural networks. Implementation based on "Attention Is All You Need" :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` Args: dropout (float): dropout parameter dim (int): embedding size """ def __init__(self, dropout, dim, max_len=5000): if dim % 2 != 0: raise ValueError("Cannot use sin/cos positional encoding with " "odd dim (got dim={:d})".format(dim)) pe = torch.zeros(max_len, dim) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))) pe[:, 0::2] = torch.sin(position.float() * div_term) pe[:, 1::2] = torch.cos(position.float() * div_term) pe = pe.unsqueeze(1) super(PositionalEncoding, self).__init__() self.register_buffer('pe', pe) self.dropout = nn.Dropout(p=dropout) self.dim = dim def forward(self, emb, step=None): """Embed inputs. Args: emb (FloatTensor): Sequence of word vectors ``(seq_len, batch_size, self.dim)`` step (int or NoneType): If stepwise (``seq_len = 1``), use the encoding for this position. """ emb = emb * math.sqrt(self.dim) step = step or 0 if self.pe.size(0) < step + emb.size(0): raise SequenceTooLongError( f"Sequence is {emb.size(0) + step} but PositionalEncoding is" f" limited to {self.pe.size(0)}. See max_len argument." ) emb = emb + self.pe[step:emb.size(0)+step] emb = self.dropout(emb) return emb class Embeddings(nn.Module): """Words embeddings for encoder/decoder. Additionally includes ability to add sparse input features based on "Linguistic Input Features Improve Neural Machine Translation" :cite:`sennrich2016linguistic`. .. mermaid:: graph LR A[Input] C[Feature 1 Lookup] A-->B[Word Lookup] A-->C A-->D[Feature N Lookup] B-->E[MLP/Concat] C-->E D-->E E-->F[Output] Args: word_vec_size (int): size of the dictionary of embeddings. word_padding_idx (int): padding index for words in the embeddings. feat_padding_idx (List[int]): padding index for a list of features in the embeddings. word_vocab_size (int): size of dictionary of embeddings for words. feat_vocab_sizes (List[int], optional): list of size of dictionary of embeddings for each feature. position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding` feat_merge (string): merge action for the features embeddings: concat, sum or mlp. feat_vec_exponent (float): when using `-feat_merge concat`, feature embedding size is N^feat_dim_exponent, where N is the number of values the feature takes. feat_vec_size (int): embedding dimension for features when using `-feat_merge mlp` dropout (float): dropout probability. freeze_word_vecs (bool): freeze weights of word vectors. """ def __init__(self, word_vec_size, word_vocab_size, word_padding_idx, position_encoding=False, feat_merge="concat", feat_vec_exponent=0.7, feat_vec_size=-1, feat_padding_idx=[], feat_vocab_sizes=[], dropout=0, sparse=False, freeze_word_vecs=False): self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent, feat_vec_size, feat_padding_idx) if feat_padding_idx is None: feat_padding_idx = [] self.word_padding_idx = word_padding_idx self.word_vec_size = word_vec_size # Dimensions and padding for constructing the word embedding matrix vocab_sizes = [word_vocab_size] emb_dims = [word_vec_size] pad_indices = [word_padding_idx] # Dimensions and padding for feature embedding matrices # (these have no effect if feat_vocab_sizes is empty) if feat_merge == 'sum': feat_dims = [word_vec_size] * len(feat_vocab_sizes) elif feat_vec_size > 0: feat_dims = [feat_vec_size] * len(feat_vocab_sizes) else: feat_dims = [int(vocab ** feat_vec_exponent) for vocab in feat_vocab_sizes] vocab_sizes.extend(feat_vocab_sizes) emb_dims.extend(feat_dims) pad_indices.extend(feat_padding_idx) # The embedding matrix look-up tables. The first look-up table # is for words. Subsequent ones are for features, if any exist. emb_params = zip(vocab_sizes, emb_dims, pad_indices) embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse) for vocab, dim, pad in emb_params] emb_luts = Elementwise(feat_merge, embeddings) # The final output size of word + feature vectors. This can vary # from the word vector size if and only if features are defined. # This is the attribute you should access if you need to know # how big your embeddings are going to be. self.embedding_size = (sum(emb_dims) if feat_merge == 'concat' else word_vec_size) # The sequence of operations that converts the input sequence # into a sequence of embeddings. At minimum this consists of # looking up the embeddings for each word and feature in the # input. Model parameters may require the sequence to contain # additional operations as well. super(Embeddings, self).__init__() self.make_embedding = nn.Sequential() self.make_embedding.add_module('emb_luts', emb_luts) if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0: in_dim = sum(emb_dims) mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU()) self.make_embedding.add_module('mlp', mlp) self.position_encoding = position_encoding if self.position_encoding: pe = PositionalEncoding(dropout, self.embedding_size) self.make_embedding.add_module('pe', pe) if freeze_word_vecs: self.word_lut.weight.requires_grad = False def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent, feat_vec_size, feat_padding_idx): if feat_merge == "sum": # features must use word_vec_size if feat_vec_exponent != 0.7: warnings.warn("Merging with sum, but got non-default " "feat_vec_exponent. It will be unused.") if feat_vec_size != -1: warnings.warn("Merging with sum, but got non-default " "feat_vec_size. It will be unused.") elif feat_vec_size > 0: # features will use feat_vec_size if feat_vec_exponent != -1: warnings.warn("Not merging with sum and positive " "feat_vec_size, but got non-default " "feat_vec_exponent. It will be unused.") else: if feat_vec_exponent <= 0: raise ValueError("Using feat_vec_exponent to determine " "feature vec size, but got feat_vec_exponent " "less than or equal to 0.") n_feats = len(feat_vocab_sizes) if n_feats != len(feat_padding_idx): raise ValueError("Got unequal number of feat_vocab_sizes and " "feat_padding_idx ({:d} != {:d})".format( n_feats, len(feat_padding_idx))) @property def word_lut(self): """Word look-up table.""" return self.make_embedding[0][0] @property def emb_luts(self): """Embedding look-up table.""" return self.make_embedding[0] def load_pretrained_vectors(self, emb_file): """Load in pretrained embeddings. Args: emb_file (str) : path to torch serialized embeddings """ if emb_file: pretrained = torch.load(emb_file) pretrained_vec_size = pretrained.size(1) if self.word_vec_size > pretrained_vec_size: self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained elif self.word_vec_size < pretrained_vec_size: self.word_lut.weight.data \ .copy_(pretrained[:, :self.word_vec_size]) else: self.word_lut.weight.data.copy_(pretrained) def forward(self, source, step=None): """Computes the embeddings for words and features. Args: source (LongTensor): index tensor ``(len, batch, nfeat)`` Returns: FloatTensor: Word embeddings ``(len, batch, embedding_size)`` """ if self.position_encoding: for i, module in enumerate(self.make_embedding._modules.values()): if i == len(self.make_embedding._modules.values()) - 1: source = module(source, step=step) else: source = module(source) else: source = self.make_embedding(source) return source def update_dropout(self, dropout): if self.position_encoding: self._modules['make_embedding'][1].dropout.p = dropout