English
Protein_Language_Model
MSA Generation
MSAGPT / utils /strategies.py
Bo1015's picture
Upload 27 files
0dce0bd verified
raw
history blame
10.8 kB
import numpy as np
import torch
import torch.nn.functional as F
from sat.generation.sampling_strategies.base_strategy import top_k_logits
from sat.mpu.initialize import get_model_parallel_world_size, get_model_parallel_src_rank, get_model_parallel_group
class AdvancedBaseStrategy:
    """Sampling strategy with temperature, top-k/top-p filtering and n-gram blocking.

    Each :meth:`forward` call samples one next token per row and appends it to
    ``tokens``. Rows that emit one of ``end_tokens`` are marked done; on later
    steps their sampled token is replaced by ``-1``.

    NOTE(review): ``forward`` maps sampled rows back to batch items one-to-one
    (``pred[i]`` for ``i`` in ``range(batch_size)``) and the ban cache starts
    with a single dict per batch item, so this class appears to assume
    ``num_beam == 1`` — confirm before using it with several beams.
    """

    def __init__(self, batch_size, invalid_slices=None, temperature=1., no_repeat_ngram_size=0,
                 top_k=200, eps=1e-4, top_p=0.0, min_gen_length=1, end_tokens=None):
        """
        Args:
            batch_size: number of sequences whose completion is tracked.
            invalid_slices: vocabulary indices/slices whose logits are always
                masked. Defaults to none (a fresh list per instance).
            temperature: default logit divisor; ``forward`` may override it.
            no_repeat_ngram_size: if > 0, forbid repeating any n-gram of this size.
            top_k, top_p: forwarded to ``top_k_logits``.
            eps: stored on the instance; not used by this class.
            min_gen_length: end tokens are masked until this many tokens have
                been generated.
            end_tokens: token ids that finish a sequence. Defaults to none.
        """
        self.batch_size = batch_size
        self.invalid_slices = [] if invalid_slices is None else invalid_slices
        self.temperature = temperature
        self.topk = top_k
        self.top_p = top_p
        self.eps = eps
        self.min_gen_length = min_gen_length
        self.ngram = no_repeat_ngram_size
        self.end_tokens = [] if end_tokens is None else end_tokens
        self._init_cache()

    @property
    def is_done(self) -> bool:
        """True once every tracked sequence has finished."""
        return self._is_done.all()

    def _init_cache(self):
        """Reset all per-generation state."""
        self.length_generated = 0
        # Per batch item: a list of {ngram_prefix_tuple: banned_token_ids_tuple},
        # one dict per beam (starts with a single beam).
        self.cached_beam_ngram_bans = [[{}] for _ in range(self.batch_size)]
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)

    def _ngram_prefix(self, seq):
        """Return the last ``ngram - 1`` ids of the 1-D tensor ``seq`` as a tuple.

        The start index is computed explicitly instead of slicing with
        ``-(ngram - 1)``: for ``ngram == 1`` that would be ``seq[-0:]``, i.e.
        the whole sequence, instead of the empty prefix.
        """
        start = max(seq.shape[-1] - (self.ngram - 1), 0)
        return tuple(seq[start:].tolist())

    def forward(self, logits, tokens, mems, is_first=False, temperature=None):
        """Sample one token per row and append it to ``tokens``.

        Args:
            logits: next-token logits indexed as ``[batch, beam, vocab]``.
            tokens: ids generated so far, shape ``[batch, beam, seq_len]``.
            mems: returned unchanged.
            is_first: unused; kept for interface compatibility.
            temperature: overrides ``self.temperature`` for this call.

        Returns:
            ``(tokens, mems)`` with one new column appended to ``tokens``.
        """
        batch_size, num_beam, seq_len = tokens.shape
        if temperature is None:
            temperature = self.temperature
        # Division allocates a new tensor, so the in-place masking below never
        # writes into the caller's logits.
        logits = logits / temperature
        # -65504 is the most negative finite float16 value.
        if self.min_gen_length > self.length_generated:
            for end_token in self.end_tokens:
                logits[..., end_token] = -65504
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504
        if self.ngram > 0 and seq_len > self.ngram:
            for batch_idx in range(batch_size):
                for i in range(num_beam):
                    prefix = self._ngram_prefix(tokens[batch_idx, i])
                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(prefix, ()):
                        logits[batch_idx, i, banned_index] = -65504

        logits = top_k_logits(logits.view(-1, logits.size(-1)), self.topk, self.top_p)
        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
        pred = torch.multinomial(probs, num_samples=1)

        for i in range(self.batch_size):
            if i >= batch_size:
                self._is_done[i] = True
            elif self._is_done[i]:
                pred[i] = -1
            elif pred[i].item() in self.end_tokens:
                self._is_done[i] = True

        if self.ngram > 0:
            for batch_idx in range(batch_size):
                bans_continue = []
                for i in range(num_beam):
                    bans = self.cached_beam_ngram_bans[batch_idx][i].copy()
                    prefix = self._ngram_prefix(tokens[batch_idx, i])
                    # Ban this continuation the next time the same prefix appears.
                    bans[prefix] = bans.get(prefix, ()) + (int(pred[batch_idx]),)
                    bans_continue.append(bans)
                self.cached_beam_ngram_bans[batch_idx] = bans_continue

        tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
        self.length_generated += 1
        return tokens, mems

    def finalize(self, tokens, mems):
        """Reset internal state and return ``tokens`` and ``mems`` unchanged."""
        self._init_cache()
        return tokens, mems
class BeamSearchStrategy:
    """Beam search (optionally sampled) with length penalty and n-gram blocking.

    On every :meth:`forward` step each batch item draws candidate extensions of
    its live beams. Candidates ending in one of ``end_tokens`` are moved into
    ``end_beams`` (sorted by length-penalized score, at most ``num_beams``
    kept). The best remaining candidates become the next live beams.
    """

    def __init__(
        self,
        batch_size,
        num_beams,
        length_penalty=1.0,
        consider_end=False,
        end_tokens=None,
        invalid_slices=None,
        no_repeat_ngram_size=0,
        min_gen_length=0,
        deterministic=False,
    ):
        """
        Args:
            batch_size: number of independent batch items.
            num_beams: number of live beams kept per batch item.
            length_penalty: exponent of the OpenNMT-style length normalization.
            consider_end: if True, :meth:`finalize` returns the ranked finished
                beams instead of the live ``tokens``.
            end_tokens: ids that finish a beam. Defaults to none (fresh list).
            invalid_slices: vocabulary indices/slices that are always masked.
                Defaults to none (fresh list).
            no_repeat_ngram_size: if > 0, forbid repeating an n-gram of this size.
            min_gen_length: end tokens are masked until this many steps have run.
            deterministic: pick candidates with ``torch.topk`` instead of
                sampling them with ``torch.multinomial``.
        """
        self.batch_size = batch_size
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.end_tokens = [] if end_tokens is None else end_tokens
        self.ngram = no_repeat_ngram_size
        self.min_gen_length = min_gen_length
        self.invalid_slices = [] if invalid_slices is None else invalid_slices
        self.consider_end = consider_end
        self.deterministic = deterministic
        self._init_cache()

    def _init_cache(self):
        """Reset all per-generation state."""
        self.end_beams = [[] for _ in range(self.batch_size)]  # finished beams (1-D token tensors)
        self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)]  # their penalized scores, descending
        self.cached_beam_scores = 0  # becomes a [batch_size, num_beams] tensor after the first forward
        # Per batch item and beam: {ngram_prefix_tuple: banned_token_ids_tuple}.
        self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
        self.length_generated = 0
        self._is_done = np.zeros(self.batch_size, dtype=np.bool_)

    def _penalize(self, score, length):
        """Apply the OpenNMT length normalization ``((5 + length) / 6) ** length_penalty``."""
        return score / ((5.0 + length) / 6) ** self.length_penalty

    def _ngram_prefix(self, seq):
        """Return the last ``ngram - 1`` ids of the 1-D tensor ``seq`` as a tuple.

        The start index is computed explicitly instead of slicing with
        ``-(ngram - 1)``: for ``ngram == 1`` that would be ``seq[-0:]``, i.e.
        the whole sequence, instead of the empty prefix.
        """
        start = max(seq.shape[-1] - (self.ngram - 1), 0)
        return tuple(seq[start:].tolist())

    def _add_end_beams(self, score, beam, batch_idx):
        """Record a finished ``beam`` for ``batch_idx`` with cumulative log-prob ``score``.

        The score is length-penalized, inserted so the list stays in descending
        order, and only the best ``num_beams`` finished beams are kept.
        """
        score = self._penalize(score, len(beam))
        beams = self.end_beams[batch_idx]
        scores = self.end_beams_penalized_scores[batch_idx]
        # Walk left until the previous entry is strictly better; ties are
        # therefore placed before existing equal scores.
        pos = len(scores)
        while pos > 0 and not (score < scores[pos - 1]):
            pos -= 1
        beams.insert(pos, beam)
        scores.insert(pos, score)
        self.end_beams[batch_idx] = beams[: self.num_beams]
        self.end_beams_penalized_scores[batch_idx] = scores[: self.num_beams]

    @property
    def is_done(self) -> bool:
        """True once every batch item has finished."""
        return self._is_done.all()

    def _mask_logits(self, logits, tokens):
        """Return a float32 copy of ``logits`` with forbidden tokens set to -65504.

        -65504 is the most negative finite float16 value.
        """
        # copy=True: ``.float()`` returns the input itself when it is already
        # float32, and the in-place masking would then modify the caller's tensor.
        logits = logits.to(dtype=torch.float32, copy=True)
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504
        if self.min_gen_length > self.length_generated:
            for end_token in self.end_tokens:
                logits[..., end_token] = -65504
        if self.ngram > 0 and tokens.shape[-1] > self.ngram:
            batch_size, num_beams = logits.shape[:2]
            for batch_idx in range(batch_size):
                for i in range(num_beams):
                    prefix = self._ngram_prefix(tokens[batch_idx, i])
                    for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(prefix, ()):
                        logits[batch_idx, i, banned_index] = -65504
        return logits

    def _update_done(self, batch_size, seq_len):
        """Mark batch items as done.

        A batch item is done when it lies beyond the current ``batch_size``, or
        when it already has ``num_beams`` finished beams and even its best live
        beam, penalized at the next length, cannot beat the worst finished one.
        """
        for batch_idx in range(self.batch_size):
            if batch_idx >= batch_size:
                self._is_done[batch_idx] = True
            elif (
                len(self.end_beams[batch_idx]) == self.num_beams
                and self.end_beams_penalized_scores[batch_idx][-1]
                >= self._penalize(self.cached_beam_scores[batch_idx].max(), seq_len + 1)
            ):
                self._is_done[batch_idx] = True

    def forward(self, logits, tokens, mems):
        """Advance every beam by one token.

        Args:
            logits: next-token logits, shape ``[batch_size, num_beams, vocab_size]``.
                When ``num_beams`` is smaller than ``self.num_beams`` (first
                step), only the first beam is expanded.
            tokens: current beams, shape ``[batch_size, num_beams, seq_len]``.
            mems: cached states indexed as ``mems[:, batch, beam, ...]``. They are
                reordered to follow the surviving beams.

        Returns:
            ``(tokens, mems)`` for the ``self.num_beams`` continuing beams, with
            one new column appended to ``tokens``.
        """
        batch_size, num_beams, vocab_size = logits.shape
        seq_len = tokens.shape[-1]
        logits = self._mask_logits(logits, tokens)

        next_token_scores = F.log_softmax(logits, dim=-1)
        prev_scores = self.cached_beam_scores
        if isinstance(prev_scores, torch.Tensor):
            prev_scores = prev_scores[..., None].expand_as(next_token_scores)
        next_token_scores = (next_token_scores + prev_scores).view(batch_size, num_beams * vocab_size)

        probs = F.softmax(next_token_scores, dim=-1)
        if num_beams < self.num_beams:  # first token: expand only the first beam
            probs = probs[..., :vocab_size]
        # With k end tokens and B beams, at most k * B distinct candidates can be
        # end tokens, so (max(1, k) + 1) * B draws leave at least B to continue.
        num_candidates = (max(1, len(self.end_tokens)) + 1) * self.num_beams
        if self.deterministic:
            next_tokens = torch.topk(probs, k=num_candidates).indices
        else:
            next_tokens = torch.multinomial(probs, num_samples=num_candidates)
        rows = torch.arange(batch_size).unsqueeze(1)
        next_token_scores, order = torch.sort(next_token_scores[rows, next_tokens], descending=True, dim=1)
        next_tokens = next_tokens[rows, order]
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")  # source beam
        next_tokens = next_tokens % vocab_size  # token id within the vocabulary
        # One host transfer instead of an int() sync per candidate.
        candidate_beams = next_indices.tolist()
        candidate_tokens = next_tokens.tolist()

        # Select end beams and continuing beams.
        beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
        for batch_idx in range(batch_size):
            beam_continue, scores_continue, bans_continue, mems_continue = [], [], [], []
            for i, (src, token) in enumerate(zip(candidate_beams[batch_idx], candidate_tokens[batch_idx])):
                beam = torch.cat((tokens[batch_idx, src], next_tokens[batch_idx, i : i + 1]))
                if not self._is_done[batch_idx] and token in self.end_tokens:
                    self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
                elif len(beam_continue) < self.num_beams:
                    beam_continue.append(beam)
                    mems_continue.append(mems[:, batch_idx, src])
                    scores_continue.append(next_token_scores[batch_idx, i])
                    if self.ngram > 0:
                        bans = self.cached_beam_ngram_bans[batch_idx][src].copy()
                        prefix = self._ngram_prefix(tokens[batch_idx, src])
                        bans[prefix] = bans.get(prefix, ()) + (token,)
                        bans_continue.append(bans)
                else:
                    break
            beam_continue_batch.append(torch.stack(beam_continue))
            mems_continue_batch.append(torch.stack(mems_continue, dim=1))
            score_continue_batch.append(scores_continue)
            self.cached_beam_ngram_bans[batch_idx] = bans_continue

        tokens = torch.stack(beam_continue_batch)
        mems = torch.stack(mems_continue_batch, dim=1)
        self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
        self.length_generated += 1
        self._update_done(batch_size, seq_len)
        return tokens, mems

    def finalize(self, tokens, mems):
        """Return the generation result and reset internal state.

        With ``consider_end``, the live beams of unfinished batch items are
        scored as if they had ended. The ranked ``end_beams`` (one list of 1-D
        token tensors per batch item) are returned with ``mems`` set to
        ``None``. Otherwise ``tokens`` and ``mems`` are returned unchanged.
        """
        if self.consider_end:
            batch_size, num_beams = tokens.shape[:2]
            for batch_idx in range(batch_size):
                if not self._is_done[batch_idx]:
                    for i in range(num_beams):
                        self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
            mems = None
            ret = self.end_beams[:batch_size]
        else:
            ret = tokens
        self._init_cache()
        return ret, mems