import torch
from .decode_strategy import DecodeStrategy

class BeamSearch(DecodeStrategy):
    """Generation with beam search.

    Keeps ``beam_size`` partial hypotheses per batch example alive at each
    decoding step and moves them to ``hypotheses`` once they emit ``eos``.
    """

    def __init__(self, pad, bos, eos, batch_size, beam_size, n_best,
                 min_length, return_attention, max_length):
        super(BeamSearch, self).__init__(
            pad, bos, eos, batch_size, beam_size, min_length,
            return_attention, max_length)
        self.beam_size = beam_size
        self.n_best = n_best
        # Result caching: finished hypotheses per batch example.
        self.hypotheses = [[] for _ in range(batch_size)]
        # Beam state.
        self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool)
        self._batch_offset = torch.arange(batch_size, dtype=torch.long)
        self.select_indices = None
        self.done = False
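
    # Tensor shape conventions used below, with B = number of unfinished
    # batch examples and step = current decoded length:
    #   alive_seq:      (B * beam_size, step)
    #   topk_log_probs: (B, beam_size)  (flat (B * beam_size,) until the
    #                   first `advance` resizes it via `torch.mul(..., out=)`)
    #   topk_scores, topk_ids, _batch_index: (B, beam_size)
    #   alive_attn:     (step - 1, B * beam_size, memory_length)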
    def initialize(self, memory_bank, device=None):
        """Repeat src objects ``beam_size`` times."""
        def fn_map_state(state, dim):
            return torch.repeat_interleave(state, self.beam_size, dim=dim)

        memory_bank = torch.repeat_interleave(memory_bank, self.beam_size,
                                              dim=0)
        if device is None:
            device = memory_bank.device
        self.memory_length = memory_bank.size(1)
        super().initialize(memory_bank, device)
        self.best_scores = torch.full(
            [self.batch_size], -1e10, dtype=torch.float, device=device)
        self._beam_offset = torch.arange(
            0, self.batch_size * self.beam_size, step=self.beam_size,
            dtype=torch.long, device=device)
        # Start with only the first beam of each example "live": the other
        # beams get -inf so the first top-k cannot select duplicates of the
        # same (identical) initial hypothesis.
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
        ).repeat(self.batch_size)
        # Buffers for the top-k scores and the 'backpointer'.
        self.topk_scores = torch.empty((self.batch_size, self.beam_size),
                                       dtype=torch.float, device=device)
        self.topk_ids = torch.empty((self.batch_size, self.beam_size),
                                    dtype=torch.long, device=device)
        self._batch_index = torch.empty([self.batch_size, self.beam_size],
                                        dtype=torch.long, device=device)
        return fn_map_state, memory_bank
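
    # `initialize` returns `fn_map_state` so the caller can expand any other
    # decoder state the same way the memory bank was expanded. A sketch,
    # assuming a hidden state laid out with batch on dim 1 (the name `hidden`
    # is illustrative, not part of this module):
    #
    #     fn_map_state, memory_bank = search.initialize(memory_bank)
    #     hidden = fn_map_state(hidden, dim=1)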
    @property
    def current_predictions(self):
        return self.alive_seq[:, -1]

    @property
    def current_backptr(self):
        # for testing
        return self.select_indices.view(self.batch_size, self.beam_size)

    @property
    def batch_offset(self):
        return self._batch_offset
    def _pick(self, log_probs):
        """Return token decisions for a step.

        Args:
            log_probs (FloatTensor): ``(B * beam_size, vocab_size)``

        Returns:
            topk_scores (FloatTensor): ``(B, beam_size)``
            topk_ids (LongTensor): ``(B, beam_size)``
        """
        vocab_size = log_probs.size(-1)
        # Flatten each example's beams into one row of beam_size * vocab_size
        # candidate scores, so a single top-k ranges over all of them at once.
        curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
        topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
        return topk_scores, topk_ids
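
    # Worked example of the flattened top-k (illustrative numbers, not from
    # this module): with beam_size=2 and vocab_size=5, each row holds
    # 2 * 5 = 10 candidates; a top-k id of 7 means beam 7 // 5 = 1 and
    # token 7 % 5 = 2. `advance` below recovers exactly these two values
    # from `topk_ids` via floor division and `fmod_`.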
    def advance(self, log_probs, attn):
        """
        Args:
            log_probs (FloatTensor): ``(B * beam_size, vocab_size)``
            attn (FloatTensor): attention for the last step,
                ``(1, B * beam_size, memory_length)``
        """
        vocab_size = log_probs.size(-1)
        # (non-finished) batch_size
        _B = log_probs.shape[0] // self.beam_size
        step = len(self)  # length of alive_seq
        self.ensure_min_length(log_probs)
        # Add the running beam log-probs to this step's token log-probs.
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)
        curr_length = step + 1
        curr_scores = log_probs / curr_length  # avg log_prob
        self.topk_scores, self.topk_ids = self._pick(curr_scores)
        # topk_scores/topk_ids: (B, beam_size)
        # Recover the unnormalized sums of log probs.
        torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs)
        # Resolve beam origin and map to flat batch index representation.
        self._batch_index = torch.div(self.topk_ids, vocab_size,
                                      rounding_mode="floor")
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)
        self.topk_ids.fmod_(vocab_size)  # resolve true word ids
        # Append last prediction.
        self.alive_seq = torch.cat(
            [self.alive_seq.index_select(0, self.select_indices),
             self.topk_ids.view(_B * self.beam_size, 1)], -1)
        if self.return_attention:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
            else:
                # Reorder the accumulated attention to follow the surviving
                # beams, then append this step's attention.
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)
        self.is_finished = self.topk_ids.eq(self.eos)
        self.ensure_max_length()
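
    # Beams that just produced `eos` are handled in `update_finished`: their
    # running scores are masked to -1e10 there, so `_pick` can never extend
    # a finished hypothesis on a later step.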
    def update_finished(self):
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # len(self)
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
        self.is_finished = self.is_finished.to('cpu')
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
        attention = (
            self.alive_attn.view(
                step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
            if self.alive_attn is not None else None)
        non_finished_batch = []
        for i in range(self.is_finished.size(0)):
            b = self._batch_offset[i]
            finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
            # Store finished hypotheses for this batch example.
            for j in finished_hyp:  # Beam level: finished beam j in batch i
                self.hypotheses[b].append((
                    self.topk_scores[i, j],
                    predictions[i, j, 1:],  # Ignore start token.
                    attention[:, i, j, :self.memory_length]
                    if attention is not None else None))
            # End condition: the top beam is finished and we can return
            # n_best hypotheses.
            finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                best_hyp = sorted(
                    self.hypotheses[b], key=lambda x: x[0], reverse=True)
                for n, (score, pred, attn) in enumerate(best_hyp):
                    if n >= self.n_best:
                        break
                    self.scores[b].append(score.item())
                    self.predictions[b].append(pred)
                    self.attention[b].append(
                        attn if attn is not None else [])
            else:
                non_finished_batch.append(i)
        non_finished = torch.tensor(non_finished_batch)
        if len(non_finished) == 0:
            self.done = True
            return
        _B_new = non_finished.shape[0]
        # Remove finished batch examples for the next step.
        self.top_beam_finished = self.top_beam_finished.index_select(
            0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished) \
            .view(-1, self.alive_seq.size(-1))
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)
        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(1, non_finished) \
                .view(step - 1, _B_new * self.beam_size, inp_seq_len)
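

# Usage sketch of the decoding loop (illustrative only; `decode_step`,
# `hidden`, and the argument values are assumptions, not part of this
# module):
#
#     search = BeamSearch(pad=1, bos=2, eos=3, batch_size=4, beam_size=5,
#                         n_best=2, min_length=0, return_attention=False,
#                         max_length=100)
#     fn_map_state, memory_bank = search.initialize(memory_bank)
#     hidden = fn_map_state(hidden, dim=1)
#     for _ in range(max_length):
#         log_probs, attn = decode_step(search.current_predictions,
#                                       memory_bank, hidden)
#         search.advance(log_probs, attn)
#         if search.is_finished.any():
#             search.update_finished()
#             if search.done:
#                 break
#         # Reorder / shrink per-step state to follow the surviving beams.
#         select = search.select_indices
#         memory_bank = memory_bank.index_select(0, select)
#         hidden = hidden.index_select(1, select)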