import numpy as np
import torch
import torch.nn.functional as F
from sat.generation.sampling_strategies.base_strategy import top_k_logits
from sat.mpu.initialize import get_model_parallel_world_size, get_model_parallel_src_rank, get_model_parallel_group


class AdvancedBaseStrategy:
    """Top-k / top-p sampling over (batch, num_beams, seq) token tensors, with optional
    no-repeat-ngram banning and per-sample end-token tracking."""

    def __init__(self, batch_size, invalid_slices=[], temperature=1., no_repeat_ngram_size=0, top_k=200,
                 eps=1e-4, top_p=0.0, min_gen_length=1, end_tokens=None):
        self.batch_size = batch_size
        self.invalid_slices = invalid_slices
        self.temperature = temperature
        self.topk = top_k
        self.top_p = top_p
        self.eps = eps
        self.min_gen_length = min_gen_length
        self.ngram = no_repeat_ngram_size
        if end_tokens is None:
            end_tokens = []
        self.end_tokens = end_tokens
        self._init_cache()

    @property
    def is_done(self) -> bool:
        return self._is_done.all()

    def _init_cache(self):
        self.length_generated = 0
        self.cached_beam_ngram_bans = [[{}] for _ in range(self.batch_size)]
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)

    def forward(self, logits, tokens, mems, is_first=False, temperature=None):
        batch_size, num_beam, seq_len = tokens.shape
        if temperature is None:
            temperature = self.temperature
        logits = logits / temperature
        # -65504 is the most negative finite fp16 value, used as "minus infinity" for masking.
        if self.min_gen_length > self.length_generated:
            for end_token in self.end_tokens:
                logits[..., end_token] = -65504
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504
        if self.ngram > 0 and seq_len > self.ngram:
            for batch_idx in range(batch_size):
                for i in range(num_beam):
                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1):].tolist()  # TODO ngram=1
                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
                        logits[batch_idx, i, banned_index] = -65504

        logits = logits.view(-1, logits.size(-1))
        logits = top_k_logits(logits, self.topk, self.top_p)
        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
        pred = torch.multinomial(probs, num_samples=1)
        for i in range(self.batch_size):
            if i >= batch_size:
                self._is_done[i] = True
            elif self._is_done[i]:
                pred[i] = -1
            elif pred[i].item() in self.end_tokens:
                self._is_done[i] = True
        if self.ngram > 0:
            for batch_idx in range(batch_size):
                bans_continue = []
                for i in range(num_beam):
                    bans = self.cached_beam_ngram_bans[batch_idx][i].copy()
                    ngram_prefix = tuple(tokens[batch_idx, i, -(self.ngram - 1):].tolist())
                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (pred[batch_idx],)
                    bans_continue.append(bans)
                self.cached_beam_ngram_bans[batch_idx] = bans_continue
        tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
        self.length_generated += 1
        return tokens, mems

    def finalize(self, tokens, mems):
        self._init_cache()
        return tokens, mems
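

# Minimal usage sketch (not part of the original code): drives AdvancedBaseStrategy with random
# logits standing in for a real model's output, so it exercises shapes and bookkeeping only.
# The tensor layouts below are assumptions inferred from how forward() indexes its inputs:
# logits and tokens both carry a (batch, num_beams, ...) layout, and mems is threaded through unchanged.
def _example_advanced_base_strategy(max_new_tokens=8):
    torch.manual_seed(0)
    batch_size, num_beams, vocab_size = 2, 1, 100
    strategy = AdvancedBaseStrategy(batch_size, temperature=0.9, top_k=50, end_tokens=[0])
    tokens = torch.randint(1, vocab_size, (batch_size, num_beams, 4))  # fake prompt tokens
    mems = None
    while not strategy.is_done and strategy.length_generated < max_new_tokens:
        logits = torch.randn(batch_size, num_beams, vocab_size)  # stand-in for the model forward pass
        tokens, mems = strategy.forward(logits, tokens, mems)
    return strategy.finalize(tokens, mems)[0]
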

class BeamSearchStrategy:
    """Beam search over (batch, num_beams, seq) token tensors. With deterministic=True the
    candidate continuations are the top-k scores; otherwise they are sampled without replacement.
    Beams that emit an end token are moved to end_beams with an OpenNMT-style length penalty."""

    def __init__(
        self,
        batch_size,
        num_beams,
        length_penalty=1.0,
        consider_end=False,
        end_tokens=[],
        invalid_slices=[],
        no_repeat_ngram_size=0,
        min_gen_length=0,
        deterministic=False,
    ):
        self.batch_size = batch_size
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.end_tokens = end_tokens
        self.ngram = no_repeat_ngram_size
        self.min_gen_length = min_gen_length
        self.invalid_slices = invalid_slices
        self.consider_end = consider_end
        self.deterministic = deterministic
        self._init_cache()

    def _init_cache(self):
        self.end_beams = [[] for _ in range(self.batch_size)]  # list of LongTensors per sample
        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # list of scores per sample
        self.cached_beam_scores = 0  # 0 until the first step, then a [batch_size, num_beams] tensor
        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
        self.length_generated = 0
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)

    def _add_end_beams(self, score, beam, batch_idx):
        score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty  # magic numbers follow OpenNMT
        for i in range(len(self.end_beams[batch_idx]), -1, -1):
            if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
                break
        self.end_beams[batch_idx].insert(i, beam)
        self.end_beams_penalized_scores[batch_idx].insert(i, score)

        self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
        self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]

    @property
    def is_done(self) -> bool:
        return self._is_done.all()

    def forward(self, logits, tokens, mems):
        batch_size, num_beams, vocab_size = logits.shape
        seq_len = tokens.shape[-1]
        logits = logits.float()
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504
        if self.min_gen_length > self.length_generated:
            for end_token in self.end_tokens:
                logits[..., end_token] = -65504
        if self.ngram > 0 and seq_len > self.ngram:
            for batch_idx in range(batch_size):
                for i in range(num_beams):
                    ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1):].tolist()  # TODO ngram=1
                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
                        logits[batch_idx, i, banned_index] = -65504

        next_token_scores = F.log_softmax(logits, dim=-1)  # [batch_size, num_beams, vocab_size]
        prev_scores = self.cached_beam_scores
        if isinstance(prev_scores, torch.Tensor):
            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
        next_token_scores = next_token_scores + prev_scores
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        probs = F.softmax(next_token_scores, dim=-1)
        if num_beams < self.num_beams:  # first token: only one input beam is populated
            probs = probs[..., :vocab_size]
        if self.deterministic:
            next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices  # [batch_size, k]
        else:
            next_tokens = torch.multinomial(
                probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
            )  # [batch_size, k]
        next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]

        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
        next_tokens = next_tokens % vocab_size

        # select out end beams or continue beams
        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
        for batch_idx in range(batch_size):
            beam_continue = []
            scores_continue = []
            bans_continue = []
            mems_continue = []
            for i in range(len(next_tokens[batch_idx])):
                beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i: i + 1]))
                if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
                elif len(beam_continue) < self.num_beams:
                    beam_continue.append(beam)
                    mems_continue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
                    # update caches
                    scores_continue.append(next_token_scores[batch_idx, i])
                    if self.ngram > 0:
                        bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()  # TODO ngram=1
                        ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
                        bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
                        bans_continue.append(bans)
                else:
                    break
            beam_continue_batch.append(torch.stack(beam_continue))
            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
            score_continue_batch.append(scores_continue)
            self.cached_beam_ngram_bans[batch_idx] = bans_continue
        tokens = torch.stack(beam_continue_batch)
        mems = torch.stack(mems_continue_batch, dim=1)
        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
        self.length_generated += 1
        for batch_idx in range(self.batch_size):
            if batch_idx >= batch_size:
                self._is_done[batch_idx] = True
            elif (
                len(self.end_beams[batch_idx]) == self.num_beams
                and self.end_beams_penalized_scores[batch_idx][-1]
                >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
            ):  # we're done if none of the current beams can beat the worst finished beam
                self._is_done[batch_idx] = True
        return tokens, mems

    def finalize(self, tokens, mems):
        if self.consider_end:
            batch_size, num_beams = tokens.shape[:2]
            for batch_idx in range(batch_size):
                if not self._is_done[batch_idx]:
                    for i in range(num_beams):
                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
            mems = None
            ret = self.end_beams[:batch_size]
        else:
            ret = tokens
        self._init_cache()
        return ret, mems
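

if __name__ == "__main__":
    # Minimal smoke test (not part of the original code): random logits stand in for a real model,
    # so this only exercises tensor shapes and beam bookkeeping, not generation quality.
    # Assumed layouts, inferred from how forward() indexes its inputs:
    #   logits: (batch, current_num_beams, vocab), tokens: (batch, current_num_beams, seq_len),
    #   mems:   (num_layers, batch, current_num_beams, mem_len, hidden).
    torch.manual_seed(0)
    batch_size, num_beams, vocab_size = 2, 4, 100
    strategy = BeamSearchStrategy(
        batch_size, num_beams, end_tokens=[0], consider_end=True, deterministic=True
    )
    tokens = torch.randint(1, vocab_size, (batch_size, 1, 5))  # start from a single beam per sample
    mems = torch.zeros(2, batch_size, 1, 5, 8)  # toy sizes; a real caller passes the model's cache here
    for _ in range(3):
        cur_beams = tokens.shape[1]
        logits = torch.randn(batch_size, cur_beams, vocab_size)  # stand-in for the model forward pass
        tokens, mems = strategy.forward(logits, tokens, mems)
        if strategy.is_done:
            break
    beams, _ = strategy.finalize(tokens, mems)
    print("finished beams per sample:", [len(b) for b in beams])
    print("AdvancedBaseStrategy sample shape:", _example_advanced_base_strategy().shape)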