""" Embeddings module """
import math
import warnings

import torch
import torch.nn as nn

from onmt.modules.util_class import Elementwise


class SequenceTooLongError(Exception):
    pass


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.

    Implementation based on "Attention Is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        dropout (float): dropout probability applied after adding the encoding
        dim (int): embedding size (must be even)
        max_len (int): maximum sequence length the encoding table supports
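
    A minimal usage sketch (the hyperparameters and shapes below are
    illustrative only)::

        >>> pe = PositionalEncoding(dropout=0.1, dim=512)
        >>> emb = torch.zeros(20, 4, 512)  # (seq_len, batch_size, dim)
        >>> pe(emb).shape
        torch.Size([20, 4, 512])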

    """

    def __init__(self, dropout, dim, max_len=5000):
        if dim % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(dim))
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                             -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        # shape (max_len, 1, dim) so the table broadcasts over the batch dim
        pe = pe.unsqueeze(1)
        super(PositionalEncoding, self).__init__()
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        """Embed inputs.

        Args:
            emb (FloatTensor): Sequence of word vectors
                ``(seq_len, batch_size, self.dim)``
            step (int or NoneType): If stepwise (``seq_len = 1``), use
                the encoding for this position.

        Returns:
            FloatTensor: positionally encoded (and dropout-regularized)
            word vectors ``(seq_len, batch_size, self.dim)``
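
        A stepwise sketch (values are illustrative)::

            >>> pe = PositionalEncoding(dropout=0.0, dim=512)
            >>> one_step = torch.zeros(1, 4, 512)
            >>> pe(one_step, step=3).shape  # encoding for position 3
            torch.Size([1, 4, 512])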

        """

        # scale the embeddings before adding the positional table
        emb = emb * math.sqrt(self.dim)
        step = step or 0
        if self.pe.size(0) < step + emb.size(0):
            raise SequenceTooLongError(
                f"Sequence is {emb.size(0) + step} but PositionalEncoding is"
                f" limited to {self.pe.size(0)}. See max_len argument."
            )
        emb = emb + self.pe[step:emb.size(0)+step]
        emb = self.dropout(emb)
        return emb


class Embeddings(nn.Module):
    """Word embeddings for encoder/decoder.

    Additionally includes ability to add sparse input features
    based on "Linguistic Input Features Improve Neural Machine Translation"
    :cite:`sennrich2016linguistic`.


    .. mermaid::

       graph LR
          A[Input]
          C[Feature 1 Lookup]
          A-->B[Word Lookup]
          A-->C
          A-->D[Feature N Lookup]
          B-->E[MLP/Concat]
          C-->E
          D-->E
          E-->F[Output]

    Args:
        word_vec_size (int): size of the word embedding vectors.
        word_padding_idx (int): padding index for words in the embeddings.
        feat_padding_idx (List[int]): padding index for a list of features
            in the embeddings.
        word_vocab_size (int): size of dictionary of embeddings for words.
        feat_vocab_sizes (List[int], optional): list of size of dictionary
            of embeddings for each feature.
        position_encoding (bool): see
            :class:`~onmt.modules.PositionalEncoding`
        feat_merge (string): merge action for the features embeddings:
            concat, sum or mlp.
        feat_vec_exponent (float): when features are not merged with sum and
            no explicit ``feat_vec_size`` is given, each feature embedding
            size is N^feat_vec_exponent, where N is the number of values the
            feature takes.
        feat_vec_size (int): embedding dimension for features when set to a
            positive value (e.g. with ``-feat_merge mlp``).
        dropout (float): dropout probability.
        sparse (bool): use sparse gradient updates for the embedding tables.
        freeze_word_vecs (bool): freeze weights of word vectors.
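
    A minimal usage sketch (sizes are illustrative; no extra features)::

        >>> emb = Embeddings(word_vec_size=8, word_vocab_size=100,
        ...                  word_padding_idx=1)
        >>> src = torch.randint(0, 100, (5, 2, 1))  # (len, batch, nfeat)
        >>> emb(src).shape
        torch.Size([5, 2, 8])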

    """

    def __init__(self, word_vec_size,
                 word_vocab_size,
                 word_padding_idx,
                 position_encoding=False,
                 feat_merge="concat",
                 feat_vec_exponent=0.7,
                 feat_vec_size=-1,
                 feat_padding_idx=[],
                 feat_vocab_sizes=[],
                 dropout=0,
                 sparse=False,
                 freeze_word_vecs=False):
        self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent,
                            feat_vec_size, feat_padding_idx)

        if feat_padding_idx is None:
            feat_padding_idx = []
        self.word_padding_idx = word_padding_idx

        self.word_vec_size = word_vec_size

        # Dimensions and padding index for the word embedding matrix
        vocab_sizes = [word_vocab_size]
        emb_dims = [word_vec_size]
        pad_indices = [word_padding_idx]

        # Dimensions and padding for feature embedding matrices
        # (these have no effect if feat_vocab_sizes is empty)
        if feat_merge == 'sum':
            feat_dims = [word_vec_size] * len(feat_vocab_sizes)
        elif feat_vec_size > 0:
            feat_dims = [feat_vec_size] * len(feat_vocab_sizes)
        else:
            feat_dims = [int(vocab ** feat_vec_exponent)
                         for vocab in feat_vocab_sizes]
        vocab_sizes.extend(feat_vocab_sizes)
        emb_dims.extend(feat_dims)
        pad_indices.extend(feat_padding_idx)

        # The embedding look-up tables: the first is for words, any
        # subsequent ones are for features.
        emb_params = zip(vocab_sizes, emb_dims, pad_indices)
        embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse)
                      for vocab, dim, pad in emb_params]
        emb_luts = Elementwise(feat_merge, embeddings)

        # The final output size of word + feature vectors; this differs from
        # word_vec_size only when features are defined and merged by concat.
        self.embedding_size = (sum(emb_dims) if feat_merge == 'concat'
                               else word_vec_size)

        # The sequence of operations that turns input indices into vectors:
        # embedding look-ups, then optionally an MLP merge of the feature
        # embeddings and a positional encoding.
        super(Embeddings, self).__init__()
        self.make_embedding = nn.Sequential()
        self.make_embedding.add_module('emb_luts', emb_luts)

        if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0:
            in_dim = sum(emb_dims)
            mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU())
            self.make_embedding.add_module('mlp', mlp)

        self.position_encoding = position_encoding

        if self.position_encoding:
            pe = PositionalEncoding(dropout, self.embedding_size)
            self.make_embedding.add_module('pe', pe)

        if freeze_word_vecs:
            self.word_lut.weight.requires_grad = False

    def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent,
                       feat_vec_size, feat_padding_idx):
        if feat_merge == "sum":
            # features are summed into the word vectors, so the feature
            # embedding sizes are fixed to word_vec_size
            if feat_vec_exponent != 0.7:
                warnings.warn("Merging with sum, but got non-default "
                              "feat_vec_exponent. It will be unused.")
            if feat_vec_size != -1:
                warnings.warn("Merging with sum, but got non-default "
                              "feat_vec_size. It will be unused.")
        elif feat_vec_size > 0:
            # an explicit feat_vec_size overrides feat_vec_exponent
            if feat_vec_exponent != -1:
                warnings.warn("Not merging with sum and positive "
                              "feat_vec_size, but got non-default "
                              "feat_vec_exponent. It will be unused.")
        else:
            if feat_vec_exponent <= 0:
                raise ValueError("Using feat_vec_exponent to determine "
                                 "feature vec size, but got feat_vec_exponent "
                                 "less than or equal to 0.")
        n_feats = len(feat_vocab_sizes)
        if n_feats != len(feat_padding_idx):
            raise ValueError("Got unequal number of feat_vocab_sizes and "
                             "feat_padding_idx ({:d} != {:d})".format(
                                 n_feats, len(feat_padding_idx)))

    @property
    def word_lut(self):
        """Word look-up table."""
        return self.make_embedding[0][0]

    @property
    def emb_luts(self):
        """Embedding look-up tables (the word table first, then features)."""
        return self.make_embedding[0]

    def load_pretrained_vectors(self, emb_file):
        """Load in pretrained embeddings.

        Args:
            emb_file (str): path to torch serialized embeddings
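
        The file is expected to hold a 2-D ``FloatTensor`` with one row per
        word in the vocabulary, e.g. (file name and sizes hypothetical)::

            torch.save(torch.randn(100, 8), "pretrained.pt")  # (vocab, dim)
            embeddings.load_pretrained_vectors("pretrained.pt")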

        """

        if emb_file:
            pretrained = torch.load(emb_file)
            pretrained_vec_size = pretrained.size(1)
            if self.word_vec_size > pretrained_vec_size:
                # pretrained vectors are narrower: fill the leading dimensions
                self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained
            elif self.word_vec_size < pretrained_vec_size:
                # pretrained vectors are wider: truncate them to fit
                self.word_lut.weight.data \
                    .copy_(pretrained[:, :self.word_vec_size])
            else:
                self.word_lut.weight.data.copy_(pretrained)

    def forward(self, source, step=None):
        """Computes the embeddings for words and features.

        Args:
            source (LongTensor): index tensor ``(len, batch, nfeat)``
            step (int or NoneType): forwarded to the positional encoding,
                if any, for stepwise decoding.

        Returns:
            FloatTensor: Word embeddings ``(len, batch, embedding_size)``
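
        A sketch with one extra feature column merged through an MLP (sizes
        are illustrative)::

            >>> emb = Embeddings(word_vec_size=8, word_vocab_size=100,
            ...                  word_padding_idx=1, feat_merge='mlp',
            ...                  feat_vocab_sizes=[10], feat_padding_idx=[0])
            >>> src = torch.randint(0, 10, (5, 2, 2))  # words + one feature
            >>> emb(src).shape
            torch.Size([5, 2, 8])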

        """

        if self.position_encoding:
            # pass `step` only to the last module (the positional encoding)
            for i, module in enumerate(self.make_embedding._modules.values()):
                if i == len(self.make_embedding._modules.values()) - 1:
                    source = module(source, step=step)
                else:
                    source = module(source)
        else:
            source = self.make_embedding(source)

        return source

    def update_dropout(self, dropout):
        if self.position_encoding:
            # look the positional encoding up by name rather than by index,
            # so this also works when an 'mlp' merge module precedes it
            self.make_embedding.pe.dropout.p = dropout