import os
import torch
import stat
import re
import time
import argparse
import numpy as np

from functools import partial
from typing import List, Tuple

import torch.distributed as dist
from sat.helpers import print_rank0
from sat import mpu, get_args, get_tokenizer
from sat.generation.utils import timed_name, generate_continually
from sat.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default

from .utils import move_cursor_up, move_cursor_down


def get_masks_and_position_ids(seq, msa_len, max_gen_length, gmask=False):
    # Pad the context so the generation window holds a whole number of MSA blocks,
    # and build the causal attention mask together with the stacked 2-D
    # (within-MSA position, MSA block index) position ids.
    context_length = seq.shape[1]
    query_len = msa_len
    max_msa_num = (max_gen_length - 2) // query_len
    max_gen_length = max_msa_num * query_len + 2
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length - context_length), mode="constant", value=-1)
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()  # invert: True on the (future) positions above the diagonal

    position_ids = np.zeros(max_gen_length, dtype=int)
    block_position_ids = np.zeros(max_gen_length, dtype=int)
    pre = 0
    for msa_idx in range(max_msa_num):
        position_ids[(1 + pre): (1 + pre + query_len)] = np.arange(query_len, dtype=int)
        block_position_ids[(1 + pre): (1 + pre + query_len)] = msa_idx
        pre += query_len
    position_ids = np.stack((position_ids, block_position_ids), axis=0)
    position_ids = torch.from_numpy(position_ids).to(tokens.device)
    position_ids = position_ids.unsqueeze(0)
    return tokens, attention_mask, position_ids
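

# Worked example (illustrative only): with msa_len=3 and max_gen_length=8,
# max_msa_num = (8 - 2) // 3 = 2, so indices 1..6 hold two MSA blocks of
# length 3 while indices 0 and 7 keep position 0. The stacked position ids are
#   row 0 (within-MSA position): [0, 0, 1, 2, 0, 1, 2, 0]
#   row 1 (MSA block index):     [0, 0, 0, 0, 1, 1, 1, 0]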


def generation_sequence(
        model,
        seqs,
        strategy,
        max_memory_length=100000,
        get_masks_and_position_ids=get_masks_and_position_ids,
        stream=False,
        mems=None,
        **kw_args
):
    '''
    seqs: [2, 3, 5, ..., -1 (to be generated), -1, ...]
    mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
        cache; should correspond to the first mems.shape[2] tokens of context_tokens.
        mems are first-class citizens here, but we do not assume what is memorized.
        Input mems are used for multi-phase generation.
    '''
    assert len(seqs.shape) == 2

    batch_size, context_length = seqs.shape
    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
    tokens = seqs[..., :context_length]

    counter = context_length
    index = 0 if mems is None else mems.shape[2]
    num_beams = 1

    while counter < seqs.shape[1] - 1:
        tokens = tokens.reshape(batch_size * num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
        model.eval()
        with torch.no_grad():
            logits, *output_per_layers = model(
                tokens[:, index:],
                position_ids[..., index: counter],
                attention_mask[..., index: counter, :counter],
                mems=mems,
                **kw_args
            )
            mem_kv = [o['mem_kv'] for o in output_per_layers]
            mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
        logits = logits[:, -1]
        index = counter
        counter += 1
        logits = logits.reshape(batch_size, num_beams, -1)
        tokens = tokens.reshape(batch_size, num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
        tokens, mems = strategy.forward(logits, tokens, mems)
        if len(tokens.shape) == 3 and num_beams == 1:
            # The strategy switched to beam search: broadcast position ids and the
            # attention mask across the new beam dimension.
            num_beams = tokens.shape[1]
            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, 2, -1).reshape(batch_size * num_beams, 2, -1)
            attention_mask_shape = attention_mask.shape[-3:]
            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
                batch_size * num_beams, *attention_mask_shape)
        if strategy.is_done:
            break
    return strategy.finalize(tokens, mems)
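

# Minimal usage sketch (illustrative, mirroring the call made in
# autoregressive_sampling below; `model`, `strategy`, `input_seq`, `msa_len` and
# `max_gen_length` are assumed to be prepared by the caller):
#
#     output, mems = generation_sequence(
#         model=model,
#         seqs=input_seq,              # LongTensor of shape [1, context_length]
#         strategy=strategy,           # a sat sampling strategy (base or beam search)
#         get_masks_and_position_ids=partial(
#             get_masks_and_position_ids,
#             msa_len=msa_len,
#             max_gen_length=max_gen_length,
#             gmask=True,
#         ),
#     )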


def stream_generation_sequence(
        model,
        seqs,
        strategy,
        max_memory_length=100000,
        get_masks_and_position_ids=get_masks_and_position_ids,
        stream=False,
        mems=None,
        **kw_args
):
    '''
    Streaming variant of `generation_sequence`: same decoding loop, but the
    partially decoded (tokens, mems) are yielded after every step.

    seqs: [2, 3, 5, ..., -1 (to be generated), -1, ...]
    mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
        cache; should correspond to the first mems.shape[2] tokens of context_tokens.
        mems are first-class citizens here, but we do not assume what is memorized.
        Input mems are used for multi-phase generation.
    '''
    assert len(seqs.shape) == 2

    batch_size, context_length = seqs.shape
    seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
    tokens = seqs[..., :context_length]

    counter = context_length
    index = 0 if mems is None else mems.shape[2]
    num_beams = 1

    while counter < seqs.shape[1] - 1:
        tokens = tokens.reshape(batch_size * num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
        model.eval()
        with torch.no_grad():
            logits, *output_per_layers = model(
                tokens[:, index:],
                position_ids[..., index: counter],
                attention_mask[..., index: counter, :counter],
                mems=mems,
                **kw_args
            )
            mem_kv = [o['mem_kv'] for o in output_per_layers]
            mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
        logits = logits[:, -1]
        index = counter
        counter += 1
        logits = logits.reshape(batch_size, num_beams, -1)
        tokens = tokens.reshape(batch_size, num_beams, -1)
        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
        tokens, mems = strategy.forward(logits, tokens, mems, is_first=False)
        if len(tokens.shape) == 3 and num_beams == 1:
            # The strategy switched to beam search: broadcast position ids and the
            # attention mask across the new beam dimension.
            num_beams = tokens.shape[1]
            position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, 2, -1).reshape(batch_size * num_beams, 2, -1)
            attention_mask_shape = attention_mask.shape[-3:]
            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
                batch_size * num_beams, *attention_mask_shape)
        yield tokens, mems
        if strategy.is_done:
            break
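

# Consumption sketch (illustrative; `builder` stands for the partial-bound
# get_masks_and_position_ids shown above):
#
#     stream = stream_generation_sequence(model=model, seqs=input_seq,
#                                         strategy=strategy,
#                                         get_masks_and_position_ids=builder)
#     for tokens, mems in stream:
#         ...                                   # render the partial decode
#     output = strategy.finalize(tokens, mems)[0]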


def autoregressive_sampling(args, raw_text: List[str], model, tokenizer, strategy, stream=False) -> List[str]:
    '''
    Build the [gMASK]/sop prompt from the '<M>'-split sequences in `raw_text`,
    run (optionally streaming) generation, and return one detokenized answer
    per output beam.
    '''
    generation_mask = "[gMASK]"
    seq = []
    msa_len = len(raw_text[0]) + 1
    seq += [tokenizer.get_command(generation_mask)] + [tokenizer.get_command("sop")]
    for each in raw_text:
        seq += tokenizer.tokenize(each) + [tokenizer.get_command('<M>')]

    output_list = [seq]
    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
    seq = output_list[0]

    mask_token = tokenizer.get_command(generation_mask)
    mask_position = seq.index(mask_token)

    last_pos, answers, blanks, output_list = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
        []
    )
    icl_msas = len(raw_text)
    input_seq = torch.tensor(
        [seq],
        dtype=torch.long,
        device=args.device,
    )
    if args.stream_chat:
        if args.chinese:
            print(f"{'生成的MSA'.center(20, '*')}", flush=True)
        else:
            print(f"{'Virtual MSA'.center(20, '*')}", flush=True)
        output_stream = stream_generation_sequence(
            model=model,
            seqs=input_seq,
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                msa_len=msa_len,
                max_gen_length=args.max_gen_length,
                gmask=True
            )
        )
        offset = -1
        print_len = 0  # guards move_cursor_down below when the single-line stream path is used
        for tmp_res, mems in output_stream:
            if isinstance(tmp_res, torch.Tensor):
                output = tmp_res.tolist()
                output_list = output[0]
            for i in range(len(output_list)):
                output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
                bog = output.index(tokenizer.get_command("sop"))
                try:
                    unfinished = output.index(-1)
                except ValueError:
                    unfinished = len(output)
                output_list[i] = output[:mask_position] + output[bog + 1: unfinished]
            for i, output in enumerate(output_list):
                if output[-1] == tokenizer.get_command("eos"):
                    output = output[:-1]
                answers[i] = tokenizer.detokenize(output)
            tmp_ret = answers[0]
            if mpu.get_model_parallel_rank() == 0:
                if not args.multiline_stream:
                    # Print only the characters generated since the last step.
                    vit_msa = tmp_ret[offset if offset > 0 else -1:]
                    print(vit_msa, end='', flush=True)
                    offset = len(tmp_ret)
                else:
                    # Reprint all generated MSA rows, then move the cursor back up
                    # so the next step overwrites them in place.
                    print_len = 0
                    vit_msa = tmp_ret.split('[<M>]')[icl_msas:]
                    vit_msa = [_ for _ in vit_msa if len(_) > 0]
                    for _ in vit_msa:
                        print(_)
                        print_len += 1
                    move_cursor_up(print_len)

        move_cursor_down(print_len)
        print('\n')
        output = strategy.finalize(tmp_res, mems)[0]
    else:
        output, _ = generation_sequence(
            model=model,
            seqs=input_seq,
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                msa_len=msa_len,
                max_gen_length=args.max_gen_length,
                gmask=True
            )
        )
    last_pos, answers, blanks, output_list = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
        []
    )
    if isinstance(output, torch.Tensor):
        output = output.tolist()
    output = output[0]
    output_list.extend(output)

    for i in range(len(output_list)):
        output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
        try:
            unfinished = output.index(-1)
        except ValueError:
            unfinished = len(output)

        bog = output.index(tokenizer.get_command("sop"))

        prefix = tokenizer.detokenize(output[last_pos[i]: mask_position])
        blank = tokenizer.detokenize(output[bog + 1: unfinished])
        blanks[i].append(blank)
        last_pos[i] = mask_position + unfinished - (bog + 1)
        output_list[i] = output[:mask_position] + output[bog + 1: unfinished]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command("eos"):
            output = output[:-1]
        answers[i] = tokenizer.detokenize(output)
    return answers
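

# Input format sketch (illustrative; the protein sequences are made up): callers
# pass `raw_text` as the '<M>'-split prompt, i.e. the main (query) sequence
# followed by zero or more few-shot MSA sequences:
#
#     raw_text = "MKTAYIAK<M>MKTTYIDR".split('<M>')
#     answers = autoregressive_sampling(args, raw_text, model, tokenizer, strategy)
#
# Each answer is a detokenized string in which prompt and generated sequences are
# separated by '[<M>]', which online/offline_generation below split and re-join
# with '<M>'.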


def offline_generation(args, temp, top_p, top_k, func):
    os.makedirs(args.output_path, exist_ok=True)
    with open(args.input_source, 'r', encoding="utf-8") as fin:
        inputs = fin.readlines()
    output_path = os.path.join(args.output_path, f"tmp_{temp}_p_{top_p}_k_{top_k}")
    fout = open(output_path, 'w')
    start_time = time.time()
    for line_no, raw_text in enumerate(inputs):
        # Shard input lines across data-parallel ranks.
        if line_no % mpu.get_data_parallel_world_size() != mpu.get_data_parallel_rank():
            continue
        rk = dist.get_rank()
        raw_text = raw_text.strip()
        raw_text = raw_text.split('<M>')
        main_seq = raw_text[0]

        msa_len = len(main_seq) + 1
        icl_msas = len(raw_text)
        require_min_gen_length = msa_len * (icl_msas + 1) + 2
        if args.max_gen_length < require_min_gen_length:
            args.max_gen_length = require_min_gen_length

        if mpu.get_model_parallel_rank() == 0:
            print(f'Processing No. {line_no} on model group {rk} input main seq: "{main_seq}" few-shot prompt: "{"<M>".join(raw_text[1:])}"')
        if len(raw_text) == 0:
            continue
        ret = func(raw_text)
        if mpu.get_model_parallel_rank() == 0:
            if args.print_all_beams:
                for idx, vit_msa in enumerate(ret):
                    vit_msa = vit_msa.split('[<M>]')[icl_msas:]
                    vit_msa = [_ for _ in vit_msa if len(_) > 0]
                    vit_msa_len = len(vit_msa)
                    vit_msa_str = '<M>'.join(vit_msa)
                    print('Beam: {} #Virtual Length: {} | MSA: "{}" | (Temp, P, K)=({}, {}, {}) | Taken time {:.2f}'.format(idx, vit_msa_len, vit_msa_str, temp, top_p, top_k, time.time() - start_time), flush=True)
            else:
                vit_msa = ret[0]
                vit_msa = vit_msa.split('[<M>]')[icl_msas:]
                vit_msa = [_ for _ in vit_msa if len(_) > 0]
                vit_msa_len = len(vit_msa)
                vit_msa_str = '<M>'.join(vit_msa)
                fout.write(f"{vit_msa_str}" + '\n')
                print('#Virtual Length: {} | MSA: "{}" | (Temp, P, K)=({}, {}, {}) | Taken time {:.2f}'.format(vit_msa_len, vit_msa_str, temp, top_p, top_k, time.time() - start_time), flush=True)
            print()
            fout.flush()
    dist.barrier()
    fout.close()


def online_generation(args, query, temp, top_p, top_k, func):
    raw_text = query.strip()
    raw_text = raw_text.split('<M>')
    main_seq = raw_text[0]
    msa_len = len(main_seq) + 1
    icl_msas = len(raw_text)
    require_min_gen_length = msa_len * (icl_msas + 1) + 2
    if args.max_gen_length < require_min_gen_length:
        args.max_gen_length = require_min_gen_length
    ret = func(raw_text)
    response = []
    if mpu.get_model_parallel_rank() == 0:
        for idx, vit_msa in enumerate(ret):
            vit_msa = vit_msa.split('[<M>]')[icl_msas:]
            vit_msa = [_ for _ in vit_msa if len(_) > 0]
            response.append(vit_msa)
    return response
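

# Return-shape note (illustrative): `response` holds one entry per returned beam;
# each entry is the list of generated MSA sequences for that beam, with the
# few-shot prompt sequences stripped by the `[icl_msas:]` slice above.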


def chat_api(args, model, tokenizer, strategy, query=None):
    if args.input_source == 'chat':
        assert query is not None
        ret = online_generation(args, query, temp=args.temperature, top_p=args.top_p, top_k=args.top_k,
                                func=partial(autoregressive_sampling, args, model=model, tokenizer=tokenizer, strategy=strategy))
        return ret
    else:
        assert not args.stream_chat, "Offline generation does not support streaming output."
        offline_generation(args, temp=args.temperature, top_p=args.top_p, top_k=args.top_k,
                           func=partial(autoregressive_sampling, args, model=model, tokenizer=tokenizer, strategy=strategy))
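

# Entry-point sketch (illustrative; building `args`, `model`, `tokenizer` and
# `strategy` is assumed to happen in the caller's setup code):
#
#     # interactive: args.input_source == 'chat'
#     msas = chat_api(args, model, tokenizer, strategy, query="MKTAYIAK<M>MKTTYIDR")
#
#     # batch: reads prompts from args.input_source and writes results
#     # under args.output_path
#     chat_api(args, model, tokenizer, strategy)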