import torch |
|
from .decode_strategy import DecodeStrategy


class BeamSearch(DecodeStrategy):
    """Generation with beam search.

    Keeps ``beam_size`` partial hypotheses alive for each batch example and,
    as examples finish, records their ``n_best`` highest-scoring hypotheses.
    """

    def __init__(self, pad, bos, eos, batch_size, beam_size, n_best,
                 min_length, return_attention, max_length):
        super(BeamSearch, self).__init__(
            pad, bos, eos, batch_size, beam_size, min_length,
            return_attention, max_length)
|
        self.beam_size = beam_size
        self.n_best = n_best

        # Finished hypotheses per batch example, as (score, prediction,
        # attention) triples appended by ``update_finished``.
        self.hypotheses = [[] for _ in range(batch_size)]

        # ``top_beam_finished[b]`` flips to True once the top-ranked beam of
        # batch example ``b`` has produced EOS.
        self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool)

        # Maps positions in the shrinking alive batch back to indices in the
        # original batch.
        self._batch_offset = torch.arange(batch_size, dtype=torch.long)

        # Flat row indices chosen at the most recent step; used to reorder
        # beam-aligned state such as the memory bank and decoder state.
        self.select_indices = None
|
self.done = False |
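
        # Layout note: step state is kept flat as (B' * beam_size, ...) rows,
        # where B' starts at ``batch_size`` and shrinks as examples finish.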
    def initialize(self, memory_bank, device=None):
        """Repeat src objects ``beam_size`` times and initialize search state.

        Returns a function that tiles decoder state the same way, along with
        the tiled memory bank.
        """

        def fn_map_state(state, dim):
            return torch.repeat_interleave(state, self.beam_size, dim=dim)

        # Tile the memory bank so every beam sees its own copy of the source.
        memory_bank = torch.repeat_interleave(
            memory_bank, self.beam_size, dim=0)

        if device is None:
            device = memory_bank.device

        # Source length, given the batch-first (rows, src_len, ...) layout.
        self.memory_length = memory_bank.size(1)
|
super().initialize(memory_bank, device) |

        # Best finished score seen so far for each batch example.
        self.best_scores = torch.full(
            [self.batch_size], -1e10, dtype=torch.float, device=device)

        # Offset of each example's first row in the flat (B * beam) layout.
        self._beam_offset = torch.arange(
            0, self.batch_size * self.beam_size, step=self.beam_size,
            dtype=torch.long, device=device)

        # Start with all probability mass on beam 0 of each example so the
        # first step expands one hypothesis rather than ``beam_size``
        # identical copies of BOS.
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
        ).repeat(self.batch_size)

        # Reusable result buffers for each step.
        self.topk_scores = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.float, device=device)
        self.topk_ids = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.long, device=device)
        self._batch_index = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.long, device=device)
|
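        # Callers tile any decoder state with the returned ``fn_map_state``;
        # e.g. a hypothetical state of shape (layers, B, hidden) becomes
        # (layers, B * beam_size, hidden) via ``fn_map_state(state, dim=1)``.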
return fn_map_state, memory_bank |
    @property
    def current_predictions(self):
        # Last token of each alive hypothesis.
        return self.alive_seq[:, -1]
    @property
    def current_backptr(self):
        # Flat row indices selected at the last step, viewed per example.
        return self.select_indices.view(self.batch_size, self.beam_size)
    @property
    def batch_offset(self):
        return self._batch_offset
    def _pick(self, log_probs):
        """Return token decisions for a step.

        Args:
            log_probs (FloatTensor): (B * beam_size, vocab_size)

        Returns:
            topk_scores (FloatTensor): (B, beam_size)
            topk_ids (LongTensor): (B, beam_size)
        """
        vocab_size = log_probs.size(-1)

        # Flatten each example's beams so topk ranks all
        # beam_size * vocab_size candidate continuations against each other.
        curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
        topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
        return topk_scores, topk_ids
    def advance(self, log_probs, attn):
        """Advance all beams by one decoding step.

        Args:
            log_probs (FloatTensor): (B * beam_size, vocab_size)
            attn (FloatTensor): attention for this step, indexed by beam
                on dimension 1.
        """
        vocab_size = log_probs.size(-1)

        # Number of still-alive batch examples (shrinks as examples finish).
        _B = log_probs.shape[0] // self.beam_size

        step = len(self)
        # Mask EOS until the minimum length has been reached.
        self.ensure_min_length(log_probs)

        # Accumulate path scores: previous beam score plus this step's
        # log probability.
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)

        # Rank candidates by length-normalized score so that shorter
        # hypotheses are not favored merely for having fewer terms.
        curr_length = step + 1
        curr_scores = log_probs / curr_length
        self.topk_scores, self.topk_ids = self._pick(curr_scores)

        # Recover the un-normalized log probs of the selected candidates
        # (``out=`` resizes ``topk_log_probs`` to (B, beam_size)).
        torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs)

        # Each flat id in ``topk_ids`` encodes (beam of origin, token):
        # id = beam_origin * vocab_size + token_id. E.g. with beam_size=2,
        # vocab_size=5, flat id 7 means beam 1 (7 // 5), token 2 (7 % 5).
        self._batch_index = self.topk_ids // vocab_size
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)
        self.topk_ids.fmod_(vocab_size)  # resolve the true token ids

        # Reorder alive hypotheses to follow the surviving beams, then append
        # the chosen tokens.
        self.alive_seq = torch.cat(
            [self.alive_seq.index_select(0, self.select_indices),
             self.topk_ids.view(_B * self.beam_size, 1)], -1)

        if self.return_attention:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                # First step: start the attention history.
                self.alive_attn = current_attn
            else:
                # Reorder the history to follow the surviving beams, then
                # append this step's attention.
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)

        self.is_finished = self.topk_ids.eq(self.eos)
        # Force all beams to finish once the maximum length is reached.
        self.ensure_max_length()
    def update_finished(self):
        """Collect finished hypotheses and drop fully finished examples."""
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # includes BOS
        # Ensure finished beams are never extended again.
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)

        # Finished-hypothesis bookkeeping happens on CPU.
        self.is_finished = self.is_finished.to('cpu')
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)

        attention = (
            self.alive_attn.view(
                step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
            if self.alive_attn is not None else None)

        non_finished_batch = []
        for i in range(self.is_finished.size(0)):
            b = self._batch_offset[i]
            finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
            # Store finished hypotheses for this example, dropping BOS.
            for j in finished_hyp:
                self.hypotheses[b].append((
                    self.topk_scores[i, j],
                    predictions[i, j, 1:],
                    attention[:, i, j, :self.memory_length]
                    if attention is not None else None))

            # An example is done once its top-ranked beam has finished and at
            # least ``n_best`` hypotheses have been collected for it.
            finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                best_hyp = sorted(
                    self.hypotheses[b], key=lambda x: x[0], reverse=True)
                for n, (score, pred, attn) in enumerate(best_hyp):
                    if n >= self.n_best:
                        break
                    self.scores[b].append(score.item())
                    self.predictions[b].append(pred)
                    self.attention[b].append(
                        attn if attn is not None else [])
            else:
                non_finished_batch.append(i)

        non_finished = torch.tensor(non_finished_batch)

        # If every example is done, the whole search is finished.
        if len(non_finished) == 0:
            self.done = True
            return

        _B_new = non_finished.shape[0]

        # Drop finished examples from all per-batch state; CPU tensors first.
        self.top_beam_finished = self.top_beam_finished.index_select(
            0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        # Then device-side tensors.
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished) \
            .view(-1, self.alive_seq.size(-1))
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)

        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(1, non_finished) \
                .view(step - 1, _B_new * self.beam_size, inp_seq_len)
|
|
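# Minimal usage sketch (illustrative only): ``decode_fn`` is a hypothetical
# callable returning per-step ``(log_probs, attn)`` for the tiled,
# batch-first memory bank; the names and shapes here are assumptions, not
# part of this module's API.
#
#     search = BeamSearch(pad=1, bos=2, eos=3, batch_size=4, beam_size=5,
#                         n_best=2, min_length=0, return_attention=False,
#                         max_length=100)
#     fn_map_state, memory_bank = search.initialize(memory_bank)
#     while not search.done:
#         log_probs, attn = decode_fn(search.current_predictions, memory_bank)
#         search.advance(log_probs, attn)
#         if search.is_finished.any():
#             search.update_finished()
#             if search.done:
#                 break
#         # ``select_indices`` indexes rows of the pre-update flat batch, so
#         # it reorders beam-aligned tensors after either path above.
#         memory_bank = memory_bank.index_select(0, search.select_indices)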