|
|
|
from collections import Counter, OrderedDict |
|
from itertools import chain |
|
import six |
|
import torch |
|
|
|
from .dataset import Dataset
from .pipeline import Pipeline
|
from .utils import get_tokenizer, dtype_to_attr, is_tokenizer_serializable |
|
from .vocab import Vocab |
|
|
|
|
|
class RawField(object): |
|
""" Defines a general datatype. |
|
|
|
Every dataset consists of one or more types of data. For instance, a text |
|
classification dataset contains sentences and their classes, while a |
|
machine translation dataset contains paired examples of text in two |
|
languages. Each of these types of data is represented by a RawField object. |
|
    A RawField object makes no assumptions about the data type; it only
    holds parameters describing how the datatype should be processed.
|
|
|
Attributes: |
|
preprocessing: The Pipeline that will be applied to examples |
|
using this field before creating an example. |
|
Default: None. |
|
postprocessing: A Pipeline that will be applied to a list of examples |
|
using this field before assigning to a batch. |
|
Function signature: (batch(list)) -> object |
|
Default: None. |
|
is_target: Whether this field is a target variable. |
|
Affects iteration over batches. Default: False |
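
    Example (a minimal sketch; ``Pipeline`` wraps a plain callable, and the
    data shown is illustrative):
        >>> upper = RawField(preprocessing=Pipeline(str.upper))
        >>> upper.preprocess("hello")
        'HELLO'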
|
""" |
|
|
|
def __init__(self, preprocessing=None, postprocessing=None, is_target=False): |
|
self.preprocessing = preprocessing |
|
self.postprocessing = postprocessing |
|
self.is_target = is_target |
|
|
|
def preprocess(self, x): |
|
""" Preprocess an example if the `preprocessing` Pipeline is provided. """ |
|
if hasattr(self, "preprocessing") and self.preprocessing is not None: |
|
return self.preprocessing(x) |
|
else: |
|
return x |
|
|
|
def process(self, batch, *args, **kwargs): |
|
""" Process a list of examples to create a batch. |
|
|
|
Postprocess the batch with user-provided Pipeline. |
|
|
|
Args: |
|
            batch (list(object)): A list of objects from a batch of examples.
|
Returns: |
|
object: Processed object given the input and custom |
|
postprocessing Pipeline. |
|
""" |
|
if self.postprocessing is not None: |
|
batch = self.postprocessing(batch) |
|
return batch |
|
|
|
|
|
class Field(RawField): |
|
"""Defines a datatype together with instructions for converting to Tensor. |
|
|
|
Field class models common text processing datatypes that can be represented |
|
by tensors. It holds a Vocab object that defines the set of possible values |
|
for elements of the field and their corresponding numerical representations. |
|
The Field object also holds other parameters relating to how a datatype |
|
should be numericalized, such as a tokenization method and the kind of |
|
Tensor that should be produced. |
|
|
|
If a Field is shared between two columns in a dataset (e.g., question and |
|
answer in a QA dataset), then they will have a shared vocabulary. |
|
|
|
Attributes: |
|
sequential: Whether the datatype represents sequential data. If False, |
|
no tokenization is applied. Default: True. |
|
use_vocab: Whether to use a Vocab object. If False, the data in this |
|
field should already be numerical. Default: True. |
|
init_token: A token that will be prepended to every example using this |
|
field, or None for no initial token. Default: None. |
|
eos_token: A token that will be appended to every example using this |
|
field, or None for no end-of-sentence token. Default: None. |
|
fix_length: A fixed length that all examples using this field will be |
|
padded to, or None for flexible sequence lengths. Default: None. |
|
dtype: The torch.dtype class that represents a batch of examples |
|
of this kind of data. Default: torch.long. |
|
preprocessing: The Pipeline that will be applied to examples |
|
using this field after tokenizing but before numericalizing. Many |
|
Datasets replace this attribute with a custom preprocessor. |
|
Default: None. |
|
postprocessing: A Pipeline that will be applied to examples using |
|
this field after numericalizing but before the numbers are turned |
|
into a Tensor. The pipeline function takes the batch as a list, and |
|
the field's Vocab. |
|
Default: None. |
|
lower: Whether to lowercase the text in this field. Default: False. |
|
tokenize: The function used to tokenize strings using this field into |
|
sequential examples. If "spacy", the SpaCy tokenizer is |
|
used. If a non-serializable function is passed as an argument, |
|
the field will not be able to be serialized. Default: string.split. |
|
tokenizer_language: The language of the tokenizer to be constructed. |
|
            Languages other than English are currently supported only by SpaCy.
|
include_lengths: Whether to return a tuple of a padded minibatch and |
|
            a list containing the length of each example, or just a padded
|
minibatch. Default: False. |
|
batch_first: Whether to produce tensors with the batch dimension first. |
|
Default: False. |
|
pad_token: The string token used as padding. Default: "<pad>". |
|
unk_token: The string token used to represent OOV words. Default: "<unk>". |
|
pad_first: Do the padding of the sequence at the beginning. Default: False. |
|
truncate_first: Do the truncating of the sequence at the beginning. Default: False |
|
stop_words: Tokens to discard during the preprocessing step. Default: None |
|
is_target: Whether this field is a target variable. |
|
Affects iteration over batches. Default: False |
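
    Example (a minimal end-to-end sketch using the default whitespace
    tokenizer; the data shown is illustrative):
        >>> TEXT = Field(lower=True, batch_first=True)
        >>> examples = [TEXT.preprocess("Hello world"), TEXT.preprocess("A test")]
        >>> TEXT.build_vocab(examples)
        >>> TEXT.process(examples).shape
        torch.Size([2, 2])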
|
""" |
|
|
|
vocab_cls = Vocab |
|
|
|
|
|
dtypes = { |
|
torch.float32: float, |
|
torch.float: float, |
|
torch.float64: float, |
|
torch.double: float, |
|
torch.float16: float, |
|
torch.half: float, |
|
|
|
torch.uint8: int, |
|
torch.int8: int, |
|
torch.int16: int, |
|
torch.short: int, |
|
torch.int32: int, |
|
torch.int: int, |
|
torch.int64: int, |
|
torch.long: int, |
|
} |
|
|
|
    # Attributes handled specially when pickling; see __getstate__/__setstate__.
    ignore = ['dtype', 'tokenize']
|
|
|
def __init__(self, sequential=True, use_vocab=True, init_token=None, |
|
eos_token=None, fix_length=None, dtype=torch.long, |
|
preprocessing=None, postprocessing=None, lower=False, |
|
tokenize=None, tokenizer_language='en', include_lengths=False, |
|
batch_first=False, pad_token="<pad>", unk_token="<unk>", |
|
pad_first=False, truncate_first=False, stop_words=None, |
|
is_target=False): |
|
self.sequential = sequential |
|
self.use_vocab = use_vocab |
|
self.init_token = init_token |
|
self.eos_token = eos_token |
|
self.unk_token = unk_token |
|
self.fix_length = fix_length |
|
self.dtype = dtype |
|
self.preprocessing = preprocessing |
|
self.postprocessing = postprocessing |
|
self.lower = lower |
|
|
|
|
|
self.tokenizer_args = (tokenize, tokenizer_language) |
|
self.tokenize = get_tokenizer(tokenize, tokenizer_language) |
|
self.include_lengths = include_lengths |
|
self.batch_first = batch_first |
|
self.pad_token = pad_token if self.sequential else None |
|
self.pad_first = pad_first |
|
self.truncate_first = truncate_first |
|
try: |
|
self.stop_words = set(stop_words) if stop_words is not None else None |
|
except TypeError: |
|
raise ValueError("Stop words must be convertible to a set") |
|
self.is_target = is_target |
|
|
|
def __getstate__(self): |
|
str_type = dtype_to_attr(self.dtype) |
|
if is_tokenizer_serializable(*self.tokenizer_args): |
|
tokenize = self.tokenize |
|
else: |
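            # Leave tokenize unset so __setstate__ rebuilds it from tokenizer_args.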
|
|
|
tokenize = None |
|
attrs = {k: v for k, v in self.__dict__.items() if k not in self.ignore} |
|
attrs['dtype'] = str_type |
|
attrs['tokenize'] = tokenize |
|
|
|
return attrs |
|
|
|
def __setstate__(self, state): |
|
state['dtype'] = getattr(torch, state['dtype']) |
|
if not state['tokenize']: |
|
state['tokenize'] = get_tokenizer(*state['tokenizer_args']) |
|
self.__dict__.update(state) |
|
|
|
def __hash__(self): |
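        # Fields are mutable, so use a constant hash; __eq__ still compares state.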
|
|
|
return 42 |
|
|
|
def __eq__(self, other): |
|
if not isinstance(other, RawField): |
|
return False |
|
|
|
return self.__dict__ == other.__dict__ |
|
|
|
def preprocess(self, x): |
|
"""Load a single example using this field, tokenizing if necessary. |
|
|
|
If the input is a Python 2 `str`, it will be converted to Unicode |
|
first. If `sequential=True`, it will be tokenized. Then the input |
|
will be optionally lowercased and passed to the user-provided |
|
`preprocessing` Pipeline.""" |
|
if (six.PY2 and isinstance(x, six.string_types) |
|
and not isinstance(x, six.text_type)): |
|
x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x) |
|
if self.sequential and isinstance(x, six.text_type): |
|
x = self.tokenize(x.rstrip('\n')) |
|
if self.lower: |
|
x = Pipeline(six.text_type.lower)(x) |
|
if self.sequential and self.use_vocab and self.stop_words is not None: |
|
x = [w for w in x if w not in self.stop_words] |
|
if hasattr(self, "preprocessing") and self.preprocessing is not None: |
|
return self.preprocessing(x) |
|
else: |
|
return x |
|
|
|
def process(self, batch, device=None): |
|
""" Process a list of examples to create a torch.Tensor. |
|
|
|
Pad, numericalize, and postprocess a batch and create a tensor. |
|
|
|
Args: |
|
            batch (list(object)): A list of objects from a batch of examples.
|
Returns: |
|
            torch.Tensor: Processed object given the input
                and custom postprocessing Pipeline.
|
""" |
|
padded = self.pad(batch) |
|
tensor = self.numericalize(padded, device=device) |
|
return tensor |
|
|
|
def pad(self, minibatch): |
|
"""Pad a batch of examples using this field. |
|
|
|
Pads to self.fix_length if provided, otherwise pads to the length of |
|
the longest example in the batch. Prepends self.init_token and appends |
|
self.eos_token if those attributes are not None. Returns a tuple of the |
|
padded list and a list containing lengths of each example if |
|
`self.include_lengths` is `True` and `self.sequential` is `True`, else just |
|
returns the padded list. If `self.sequential` is `False`, no padding is applied. |
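
        Example (a sketch using the default ``"<pad>"`` token; data is illustrative):
            >>> f = Field(init_token='<s>', eos_token='</s>')
            >>> f.pad([['a', 'b'], ['c']])
            [['<s>', 'a', 'b', '</s>'], ['<s>', 'c', '</s>', '<pad>']]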
|
""" |
|
minibatch = list(minibatch) |
|
if not self.sequential: |
|
return minibatch |
|
if self.fix_length is None: |
|
max_len = max(len(x) for x in minibatch) |
|
else: |
|
max_len = self.fix_length + ( |
|
self.init_token, self.eos_token).count(None) - 2 |
|
padded, lengths = [], [] |
|
for x in minibatch: |
|
if self.pad_first: |
|
padded.append( |
|
[self.pad_token] * max(0, max_len - len(x)) |
|
+ ([] if self.init_token is None else [self.init_token]) |
|
+ list(x[-max_len:] if self.truncate_first else x[:max_len]) |
|
+ ([] if self.eos_token is None else [self.eos_token])) |
|
else: |
|
padded.append( |
|
([] if self.init_token is None else [self.init_token]) |
|
+ list(x[-max_len:] if self.truncate_first else x[:max_len]) |
|
+ ([] if self.eos_token is None else [self.eos_token]) |
|
+ [self.pad_token] * max(0, max_len - len(x))) |
|
lengths.append(len(padded[-1]) - max(0, max_len - len(x))) |
|
if self.include_lengths: |
|
return (padded, lengths) |
|
return padded |
|
|
|
def build_vocab(self, *args, **kwargs): |
|
"""Construct the Vocab object for this field from one or more datasets. |
|
|
|
Arguments: |
|
Positional arguments: Dataset objects or other iterable data |
|
sources from which to construct the Vocab object that |
|
represents the set of possible values for this field. If |
|
a Dataset object is provided, all columns corresponding |
|
to this field are used; individual columns can also be |
|
provided directly. |
|
Remaining keyword arguments: Passed to the constructor of Vocab. |
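
        Example (a sketch using plain token lists as data sources):
            >>> f = Field()
            >>> f.build_vocab([['hello', 'world'], ['hello']])
            >>> len(f.vocab)  # <unk>, <pad>, 'hello', 'world'
            4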
|
""" |
|
counter = Counter() |
|
sources = [] |
|
        for arg in args:
            if isinstance(arg, Dataset):
                # A Dataset provides all of its columns that use this field.
                sources += [getattr(arg, name) for name, field in
                            arg.fields.items() if field is self]
            else:
                sources.append(arg)
|
for data in sources: |
|
for x in data: |
|
if not self.sequential: |
|
x = [x] |
|
try: |
|
counter.update(x) |
|
except TypeError: |
|
counter.update(chain.from_iterable(x)) |
|
specials = list(OrderedDict.fromkeys( |
|
tok for tok in [self.unk_token, self.pad_token, self.init_token, |
|
self.eos_token] + kwargs.pop('specials', []) |
|
if tok is not None)) |
|
self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) |
|
|
|
def numericalize(self, arr, device=None): |
|
"""Turn a batch of examples that use this field into a Variable. |
|
|
|
If the field has include_lengths=True, a tensor of lengths will be |
|
included in the return value. |
|
|
|
Arguments: |
|
arr (List[List[str]], or tuple of (List[List[str]], List[int])): |
|
List of tokenized and padded examples, or tuple of List of |
|
tokenized and padded examples and List of lengths of each |
|
example if self.include_lengths is True. |
|
            device (str or torch.device): A string or instance of `torch.device`
                specifying the device on which the tensors will be created.
                If left as default, the tensors will be created on the CPU. Default: None.
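
        Example (a sketch; assumes the vocab has already been built and the
        input is already padded):
            >>> f = Field(batch_first=True)
            >>> f.build_vocab([['a', 'b'], ['b']])
            >>> f.numericalize([['a', 'b'], ['b', '<pad>']]).shape
            torch.Size([2, 2])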
|
""" |
|
if self.include_lengths and not isinstance(arr, tuple): |
|
raise ValueError("Field has include_lengths set to True, but " |
|
"input data is not a tuple of " |
|
"(data batch, batch lengths).") |
|
if isinstance(arr, tuple): |
|
arr, lengths = arr |
|
lengths = torch.tensor(lengths, dtype=self.dtype, device=device) |
|
|
|
if self.use_vocab: |
|
if self.sequential: |
|
arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] |
|
else: |
|
arr = [self.vocab.stoi[x] for x in arr] |
|
|
|
if self.postprocessing is not None: |
|
arr = self.postprocessing(arr, self.vocab) |
|
else: |
|
if self.dtype not in self.dtypes: |
|
raise ValueError( |
|
"Specified Field dtype {} can not be used with " |
|
"use_vocab=False because we do not know how to numericalize it. " |
|
"Please raise an issue at " |
|
"https://github.com/pytorch/text/issues".format(self.dtype)) |
|
numericalization_func = self.dtypes[self.dtype] |
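            # Coercing to a numeric type only makes sense for non-sequential
            # data; there is no obvious way to coerce padding tokens to a number.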
|
|
|
|
|
|
|
if not self.sequential: |
|
arr = [numericalization_func(x) if isinstance(x, six.string_types) |
|
else x for x in arr] |
|
if self.postprocessing is not None: |
|
arr = self.postprocessing(arr, None) |
|
|
|
var = torch.tensor(arr, dtype=self.dtype, device=device) |
|
|
|
if self.sequential and not self.batch_first: |
|
var.t_() |
|
if self.sequential: |
|
var = var.contiguous() |
|
|
|
if self.include_lengths: |
|
return var, lengths |
|
return var |
|
|
|
|
|
class NestedField(Field): |
|
"""A nested field. |
|
|
|
    A nested field holds another field (called the *nesting field*), accepts an
    untokenized string or a list of string tokens, and groups and treats them as
    one field as described by the nesting field. Every token will be preprocessed,
    padded, etc. in the manner specified by the nesting field. Note that this means
    a nested field always has ``sequential=True``. The two fields' vocabularies will
    be shared, and their numericalization results will be stacked into a single
    tensor. A NestedField shares ``include_lengths`` with its nesting field, so
    ``include_lengths`` must not be set on the nesting field. This field is
|
primarily used to implement character embeddings. See ``tests/data/test_field.py`` |
|
for examples on how to use this field. |
|
|
|
Arguments: |
|
nesting_field (Field): A field contained in this nested field. |
|
use_vocab (bool): Whether to use a Vocab object. If False, the data in this |
|
field should already be numerical. Default: ``True``. |
|
init_token (str): A token that will be prepended to every example using this |
|
field, or None for no initial token. Default: ``None``. |
|
eos_token (str): A token that will be appended to every example using this |
|
field, or None for no end-of-sentence token. Default: ``None``. |
|
fix_length (int): A fixed length that all examples using this field will be |
|
padded to, or ``None`` for flexible sequence lengths. Default: ``None``. |
|
dtype: The torch.dtype class that represents a batch of examples |
|
of this kind of data. Default: ``torch.long``. |
|
preprocessing (Pipeline): The Pipeline that will be applied to examples |
|
using this field after tokenizing but before numericalizing. Many |
|
Datasets replace this attribute with a custom preprocessor. |
|
Default: ``None``. |
|
postprocessing (Pipeline): A Pipeline that will be applied to examples using |
|
this field after numericalizing but before the numbers are turned |
|
into a Tensor. The pipeline function takes the batch as a list, and |
|
the field's Vocab. Default: ``None``. |
|
include_lengths: Whether to return a tuple of a padded minibatch and |
|
            a list containing the length of each example, or just a padded
|
minibatch. Default: False. |
|
tokenize: The function used to tokenize strings using this field into |
|
sequential examples. If "spacy", the SpaCy tokenizer is |
|
used. If a non-serializable function is passed as an argument, |
|
the field will not be able to be serialized. Default: string.split. |
|
tokenizer_language: The language of the tokenizer to be constructed. |
|
            Languages other than English are currently supported only by SpaCy.
|
pad_token (str): The string token used as padding. If ``nesting_field`` is |
|
sequential, this will be set to its ``pad_token``. Default: ``"<pad>"``. |
|
pad_first (bool): Do the padding of the sequence at the beginning. Default: |
|
``False``. |
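
    Example (a sketch of the typical character-level use case; names are
    illustrative):
        >>> CHAR = Field(tokenize=list)
        >>> SENT = NestedField(CHAR)
        >>> SENT.preprocess("ab c")
        [['a', 'b'], ['c']]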
|
""" |
|
|
|
def __init__(self, nesting_field, use_vocab=True, init_token=None, eos_token=None, |
|
fix_length=None, dtype=torch.long, preprocessing=None, |
|
postprocessing=None, tokenize=None, tokenizer_language='en', |
|
include_lengths=False, pad_token='<pad>', |
|
pad_first=False, truncate_first=False): |
|
if isinstance(nesting_field, NestedField): |
|
raise ValueError('nesting field must not be another NestedField') |
|
if nesting_field.include_lengths: |
|
raise ValueError('nesting field cannot have include_lengths=True') |
|
|
|
if nesting_field.sequential: |
|
pad_token = nesting_field.pad_token |
|
super(NestedField, self).__init__( |
|
use_vocab=use_vocab, |
|
init_token=init_token, |
|
eos_token=eos_token, |
|
fix_length=fix_length, |
|
dtype=dtype, |
|
preprocessing=preprocessing, |
|
postprocessing=postprocessing, |
|
lower=nesting_field.lower, |
|
tokenize=tokenize, |
|
tokenizer_language=tokenizer_language, |
|
batch_first=True, |
|
pad_token=pad_token, |
|
unk_token=nesting_field.unk_token, |
|
pad_first=pad_first, |
|
truncate_first=truncate_first, |
|
include_lengths=include_lengths |
|
) |
|
self.nesting_field = nesting_field |
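        # Stacking in numericalize() requires batch-first output from the nesting field.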
|
|
|
self.nesting_field.batch_first = True |
|
|
|
def preprocess(self, xs): |
|
"""Preprocess a single example. |
|
|
|
        First, tokenization and the supplied preprocessing pipeline are applied. Since
|
this field is always sequential, the result is a list. Then, each element of |
|
the list is preprocessed using ``self.nesting_field.preprocess`` and the resulting |
|
list is returned. |
|
|
|
Arguments: |
|
xs (list or str): The input to preprocess. |
|
|
|
Returns: |
|
list: The preprocessed list. |
|
""" |
|
return [self.nesting_field.preprocess(x) |
|
for x in super(NestedField, self).preprocess(xs)] |
|
|
|
def pad(self, minibatch): |
|
"""Pad a batch of examples using this field. |
|
|
|
If ``self.nesting_field.sequential`` is ``False``, each example in the batch must |
|
be a list of string tokens, and pads them as if by a ``Field`` with |
|
``sequential=True``. Otherwise, each example must be a list of list of tokens. |
|
Using ``self.nesting_field``, pads the list of tokens to |
|
``self.nesting_field.fix_length`` if provided, or otherwise to the length of the |
|
longest list of tokens in the batch. Next, using this field, pads the result by |
|
filling short examples with ``self.nesting_field.pad_token``. |
|
|
|
Example: |
|
>>> import pprint |
|
>>> pp = pprint.PrettyPrinter(indent=4) |
|
>>> |
|
>>> nesting_field = Field(pad_token='<c>', init_token='<w>', eos_token='</w>') |
|
>>> field = NestedField(nesting_field, init_token='<s>', eos_token='</s>') |
|
>>> minibatch = [ |
|
... [list('john'), list('loves'), list('mary')], |
|
... [list('mary'), list('cries')], |
|
... ] |
|
>>> padded = field.pad(minibatch) |
|
>>> pp.pprint(padded) |
|
[ [ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'], |
|
['<w>', 'j', 'o', 'h', 'n', '</w>', '<c>'], |
|
['<w>', 'l', 'o', 'v', 'e', 's', '</w>'], |
|
['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'], |
|
['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>']], |
|
[ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'], |
|
['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'], |
|
['<w>', 'c', 'r', 'i', 'e', 's', '</w>'], |
|
['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>'], |
|
['<c>', '<c>', '<c>', '<c>', '<c>', '<c>', '<c>']]] |
|
|
|
Arguments: |
|
            minibatch (list): Each element is a list of strings if
                ``self.nesting_field.sequential`` is ``False``, and a list of
                lists of strings otherwise.
|
|
|
Returns: |
|
            list: The padded minibatch, or ``(padded, sentence_lengths, word_lengths)``
                if ``self.include_lengths`` is ``True``.
|
""" |
|
minibatch = list(minibatch) |
|
if not self.nesting_field.sequential: |
|
return super(NestedField, self).pad(minibatch) |
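        # Save the attribute values that are temporarily overridden below.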
|
|
|
|
|
old_pad_token = self.pad_token |
|
old_init_token = self.init_token |
|
old_eos_token = self.eos_token |
|
old_fix_len = self.nesting_field.fix_length |
|
|
|
if self.nesting_field.fix_length is None: |
|
max_len = max(len(xs) for ex in minibatch for xs in ex) |
|
fix_len = max_len + 2 - (self.nesting_field.init_token, |
|
self.nesting_field.eos_token).count(None) |
|
self.nesting_field.fix_length = fix_len |
|
self.pad_token = [self.pad_token] * self.nesting_field.fix_length |
|
if self.init_token is not None: |
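            # Wrap in a list so the outer pad() treats it as a single nested element.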
|
|
|
self.init_token = [self.init_token] |
|
if self.eos_token is not None: |
|
|
|
self.eos_token = [self.eos_token] |
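        # Do the padding itself, with include_lengths enabled on both fields.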
|
|
|
old_include_lengths = self.include_lengths |
|
self.include_lengths = True |
|
self.nesting_field.include_lengths = True |
|
padded, sentence_lengths = super(NestedField, self).pad(minibatch) |
|
padded_with_lengths = [self.nesting_field.pad(ex) for ex in padded] |
|
word_lengths = [] |
|
final_padded = [] |
|
max_sen_len = len(padded[0]) |
|
        for (pad, lens), sentence_len in zip(padded_with_lengths, sentence_lengths):
            n_missing = max_sen_len - sentence_len
            if n_missing > 0:
                # Zero out the lengths and overwrite the word entries at the
                # sentence positions filled in by the outer padding.
                if self.pad_first:
                    lens[:n_missing] = [0] * n_missing
                    pad[:n_missing] = [self.pad_token] * n_missing
                else:
                    lens[-n_missing:] = [0] * n_missing
                    pad[-n_missing:] = [self.pad_token] * n_missing
|
word_lengths.append(lens) |
|
final_padded.append(pad) |
|
padded = final_padded |
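        # Restore the attribute values that were overridden above.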
|
|
|
|
|
self.nesting_field.fix_length = old_fix_len |
|
self.pad_token = old_pad_token |
|
self.init_token = old_init_token |
|
self.eos_token = old_eos_token |
|
self.include_lengths = old_include_lengths |
|
if self.include_lengths: |
|
return padded, sentence_lengths, word_lengths |
|
return padded |
|
|
|
def build_vocab(self, *args, **kwargs): |
|
"""Construct the Vocab object for nesting field and combine it with this field's vocab. |
|
|
|
Arguments: |
|
Positional arguments: Dataset objects or other iterable data |
|
sources from which to construct the Vocab object that |
|
represents the set of possible values for the nesting field. If |
|
a Dataset object is provided, all columns corresponding |
|
to this field are used; individual columns can also be |
|
provided directly. |
|
Remaining keyword arguments: Passed to the constructor of Vocab. |
|
""" |
|
sources = [] |
|
        for arg in args:
            if isinstance(arg, Dataset):
                # A Dataset provides all of its columns that use this field.
                sources.extend(
                    getattr(arg, name) for name, field in arg.fields.items()
                    if field is self)
            else:
                sources.append(arg)
|
|
|
flattened = [] |
|
for source in sources: |
|
flattened.extend(source) |
|
old_vectors = None |
|
old_unk_init = None |
|
old_vectors_cache = None |
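        # Strip vector-related kwargs for now; vectors are loaded after the vocabs are merged.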
|
if "vectors" in kwargs.keys(): |
|
old_vectors = kwargs["vectors"] |
|
kwargs["vectors"] = None |
|
if "unk_init" in kwargs.keys(): |
|
old_unk_init = kwargs["unk_init"] |
|
kwargs["unk_init"] = None |
|
if "vectors_cache" in kwargs.keys(): |
|
old_vectors_cache = kwargs["vectors_cache"] |
|
kwargs["vectors_cache"] = None |
|
|
|
self.nesting_field.build_vocab(*flattened, **kwargs) |
|
super(NestedField, self).build_vocab() |
|
self.vocab.extend(self.nesting_field.vocab) |
|
self.vocab.freqs = self.nesting_field.vocab.freqs.copy() |
|
if old_vectors is not None: |
|
self.vocab.load_vectors(old_vectors, |
|
unk_init=old_unk_init, cache=old_vectors_cache) |
|
|
|
self.nesting_field.vocab = self.vocab |
|
|
|
def numericalize(self, arrs, device=None): |
|
"""Convert a padded minibatch into a variable tensor. |
|
|
|
Each item in the minibatch will be numericalized independently and the resulting |
|
        tensors will be stacked along the first dimension.
|
|
|
Arguments: |
|
            arrs (List[List[List[str]]]): List of tokenized and padded examples,
                where each example is itself a list of padded token lists. If
                ``self.include_lengths`` is ``True``, a tuple of
                ``(examples, sentence_lengths, word_lengths)``.
            device (str or torch.device): A string or instance of `torch.device`
                specifying the device on which the tensors will be created.
                If left as default, the tensors will be created on the CPU. Default: None.
|
""" |
|
numericalized = [] |
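        # Temporarily disable lengths on the nesting field so it returns bare tensors.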
|
self.nesting_field.include_lengths = False |
|
if self.include_lengths: |
|
arrs, sentence_lengths, word_lengths = arrs |
|
|
|
for arr in arrs: |
|
numericalized_ex = self.nesting_field.numericalize( |
|
arr, device=device) |
|
numericalized.append(numericalized_ex) |
|
padded_batch = torch.stack(numericalized) |
|
|
|
self.nesting_field.include_lengths = True |
|
if self.include_lengths: |
|
sentence_lengths = \ |
|
torch.tensor(sentence_lengths, dtype=self.dtype, device=device) |
|
word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device) |
|
return (padded_batch, sentence_lengths, word_lengths) |
|
return padded_batch |