RxnIM / molscribe /inference /beam_search.py
CYF200127's picture
Upload 116 files
5e9bd47 verified
import torch
from .decode_strategy import DecodeStrategy
class BeamSearch(DecodeStrategy):
    """Generation with beam search.

    Keeps ``beam_size`` partial hypotheses alive per batch entry, ranks
    candidates by length-averaged log-probability, and collects hypotheses
    that end in ``eos`` until the ``n_best`` results for each batch entry
    can be returned.

    Several attributes used here (``batch_size``, ``alive_seq``,
    ``alive_attn``, ``is_finished``, ``scores``, ``predictions``,
    ``attention``, ``return_attention``, ``eos``) are not defined in this
    class; they are presumably provided by ``DecodeStrategy`` -- confirm
    against that base class.
    """

    def __init__(self, pad, bos, eos, batch_size, beam_size, n_best, min_length,
                 return_attention, max_length):
        """Set up per-batch bookkeeping.

        Args:
            pad, bos, eos, min_length, return_attention, max_length:
                forwarded unchanged to ``DecodeStrategy.__init__``.
            batch_size (int): number of source examples decoded together.
            beam_size (int): number of hypotheses kept alive per example.
            n_best (int): number of finished hypotheses to return per example.
        """
        super().__init__(
            pad, bos, eos, batch_size, beam_size, min_length, return_attention, max_length)
        self.beam_size = beam_size
        self.n_best = n_best
        # result caching: finished (score, tokens, attention) tuples per
        # original batch entry
        self.hypotheses = [[] for _ in range(batch_size)]
        # beam state (both tensors are created on the default device; they are
        # only combined with the CPU copy of ``is_finished``)
        self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool)
        # maps a row of the (shrinking) active batch to its original index
        self._batch_offset = torch.arange(batch_size, dtype=torch.long)
        self.select_indices = None
        self.done = False

    def initialize(self, memory_bank, device=None):
        """Repeat src objects `beam_size` times.

        Args:
            memory_bank (Tensor): encoder output; repeated ``beam_size`` times
                along dim 0. ``size(1)`` is recorded as ``memory_length``
                (presumably (batch, src_len, ...) -- confirm with the caller).
            device (torch.device, optional): device for the beam buffers;
                defaults to ``memory_bank.device``.

        Returns:
            tuple: ``(fn_map_state, memory_bank)`` where ``fn_map_state(state,
            dim)`` repeats any other state tensor ``beam_size`` times along
            ``dim``, and ``memory_bank`` is the repeated tensor.
        """
        def fn_map_state(state, dim):
            return torch.repeat_interleave(state, self.beam_size, dim=dim)

        memory_bank = torch.repeat_interleave(memory_bank, self.beam_size, dim=0)
        if device is None:
            device = memory_bank.device
        self.memory_length = memory_bank.size(1)
        super().initialize(memory_bank, device)

        self.best_scores = torch.full(
            [self.batch_size], -1e10, dtype=torch.float, device=device)
        # index of the first beam of every batch entry in the flat layout
        self._beam_offset = torch.arange(
            0, self.batch_size * self.beam_size, step=self.beam_size,
            dtype=torch.long, device=device)
        # Only the first beam of each entry starts "alive" (score 0); the
        # others start at -inf so the first step does not pick duplicates.
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
        ).repeat(self.batch_size)
        # buffers for the topk scores and 'backpointer'
        self.topk_scores = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.float, device=device)
        self.topk_ids = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.long, device=device)
        self._batch_index = torch.empty(
            [self.batch_size, self.beam_size], dtype=torch.long, device=device)
        return fn_map_state, memory_bank

    @property
    def current_predictions(self):
        """Last token of every alive sequence (last column of ``alive_seq``)."""
        return self.alive_seq[:, -1]

    @property
    def current_backptr(self):
        """``select_indices`` viewed as (batch_size, beam_size); for testing."""
        return self.select_indices.view(self.batch_size, self.beam_size)

    @property
    def batch_offset(self):
        """Original batch index of every still-active batch row."""
        return self._batch_offset

    def _pick(self, log_probs):
        """Return token decision for a step.

        Args:
            log_probs (FloatTensor): (B * beam_size, vocab_size)

        Returns:
            topk_scores (FloatTensor): (B, beam_size)
            topk_ids (LongTensor): (B, beam_size), indices into the flattened
                ``beam_size * vocab_size`` candidates of each batch entry
        """
        vocab_size = log_probs.size(-1)
        # Flatten probs so that all beams of one batch entry compete together.
        curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
        topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
        return topk_scores, topk_ids

    def advance(self, log_probs, attn):
        """Extend every alive hypothesis by one token.

        Args:
            log_probs: (B * beam_size, vocab_size). NOTE: modified in place
                (the accumulated beam scores are added to it).
            attn: attention for this step; only read when
                ``return_attention`` is set. It is indexed along dim 1 with
                the flat beam indices and concatenated along dim 0 across
                steps (presumably (1, B * beam_size, src_len) -- confirm).
        """
        vocab_size = log_probs.size(-1)
        # (non-finished) batch_size
        _B = log_probs.shape[0] // self.beam_size
        step = len(self)  # alive_seq
        self.ensure_min_length(log_probs)

        # Multiply probs by the beam probability (add in log space)
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)
        curr_length = step + 1
        curr_scores = log_probs / curr_length  # avg log_prob
        self.topk_scores, self.topk_ids = self._pick(curr_scores)
        # topk_scores/topk_ids: (batch_size, beam_size)

        # Recover log probs. Plain assignment instead of ``out=``: on the
        # first step ``topk_log_probs`` is still flat, and writing a
        # (B, beam_size) result into it relied on PyTorch's deprecated
        # implicit resize of a non-empty ``out`` tensor.
        self.topk_log_probs = self.topk_scores * curr_length

        # Resolve beam origin and map to batch index flat representation.
        # Must happen before ``topk_ids`` is reduced modulo vocab_size below.
        self._batch_index = self.topk_ids // vocab_size
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)
        self.topk_ids.fmod_(vocab_size)  # resolve true word ids

        # Append last prediction.
        self.alive_seq = torch.cat(
            [self.alive_seq.index_select(0, self.select_indices),
             self.topk_ids.view(_B * self.beam_size, 1)], -1)

        if self.return_attention:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
            else:
                # Reorder the attention history to follow the surviving beams.
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)

        self.is_finished = self.topk_ids.eq(self.eos)
        self.ensure_max_length()

    def update_finished(self):
        """Harvest finished beams and drop batch entries that are complete.

        Every beam flagged in ``is_finished`` is stored in ``hypotheses``.
        A batch entry is complete once its top beam has finished at some step
        and at least ``n_best`` hypotheses were collected; its best
        ``n_best`` hypotheses (by score) are then appended to ``scores``,
        ``predictions`` and ``attention``. Remaining entries are compacted so
        the next step only processes unfinished ones. Sets ``done`` when no
        entry remains.
        """
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # len(self)
        # Give finished beams a very low accumulated score for later steps.
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
        self.is_finished = self.is_finished.to('cpu')
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
        attention = (
            self.alive_attn.view(
                step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
            if self.alive_attn is not None else None)

        non_finished_batch = []
        for i in range(self.is_finished.size(0)):
            # Original batch index of active row i (0-d tensor, used as index).
            b = self._batch_offset[i]
            finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
            # Store finished hypothesis for this batch.
            for j in finished_hyp:  # Beam level: finished beam j in batch i
                self.hypotheses[b].append((
                    self.topk_scores[i, j],
                    predictions[i, j, 1:],  # Ignore start token
                    attention[:, i, j, :self.memory_length]
                    if attention is not None else None))
            # End condition is the top beam finished and we can return
            # n_best hypotheses.
            finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                best_hyp = sorted(
                    self.hypotheses[b], key=lambda x: x[0], reverse=True)
                for score, pred, attn in best_hyp[:self.n_best]:
                    self.scores[b].append(score.item())
                    self.predictions[b].append(pred)
                    self.attention[b].append(
                        attn if attn is not None else [])
            else:
                non_finished_batch.append(i)

        if not non_finished_batch:
            self.done = True
            return

        non_finished = torch.tensor(non_finished_batch, dtype=torch.long)
        _B_new = non_finished.shape[0]
        # Remove finished batches for the next step
        self.top_beam_finished = self.top_beam_finished.index_select(0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        # The remaining buffers live on the decoding device.
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished) \
            .view(-1, self.alive_seq.size(-1))
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)
        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(1, non_finished) \
                .view(step - 1, _B_new * self.beam_size, inp_seq_len)