import torch class DecodeStrategy(object): def __init__(self, pad, bos, eos, batch_size, parallel_paths, min_length, max_length, return_attention=False, return_hidden=False): self.pad = pad self.bos = bos self.eos = eos self.batch_size = batch_size self.parallel_paths = parallel_paths # result catching self.predictions = [[] for _ in range(batch_size)] self.scores = [[] for _ in range(batch_size)] self.token_scores = [[] for _ in range(batch_size)] self.attention = [[] for _ in range(batch_size)] self.hidden = [[] for _ in range(batch_size)] self.alive_attn = None self.alive_hidden = None self.min_length = min_length self.max_length = max_length n_paths = batch_size * parallel_paths self.return_attention = return_attention self.return_hidden = return_hidden self.done = False def initialize(self, memory_bank, device=None): if device is None: device = torch.device('cpu') self.alive_seq = torch.full( [self.batch_size * self.parallel_paths, 1], self.bos, dtype=torch.long, device=device) self.is_finished = torch.zeros( [self.batch_size, self.parallel_paths], dtype=torch.uint8, device=device) self.alive_log_token_scores = torch.zeros( [self.batch_size * self.parallel_paths, 0], dtype=torch.float, device=device) return None, memory_bank def __len__(self): return self.alive_seq.shape[1] def ensure_min_length(self, log_probs): if len(self) <= self.min_length: log_probs[:, self.eos] = -1e20 # forced non-end def ensure_max_length(self): if len(self) == self.max_length + 1: self.is_finished.fill_(1) def advance(self, log_probs, attn): raise NotImplementedError() def update_finished(self): raise NotImplementedError