# SSR-Speech / inference_scale.py
# Source: OpenSound (commit f5b4ff2, 5.09 kB)
# @ [email protected]
import argparse, pickle
import logging
import os, random
import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from data.tokenizer import (
tokenize_audio,
tokenize_text
)
import time
@torch.no_grad()
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn,
                         prompt_text, target_text, mask_interval, cfg_coef, aug_text, aug_context,
                         use_watermark, tts, device, decode_config, hop_length=320):
    """Run model inference for one audio file and decode the result to a waveform.

    Steps:
    1. Phonemize ``target_text`` and ``prompt_text``. Phonemes with no entry
       in ``phn2num`` are dropped.
    2. Encode ``audio_fn`` into codec tokens with ``audio_tokenizer``.
    3. Call ``model.inference`` with ``mask_interval`` and the sampling
       settings from ``decode_config``.
    4. Decode the generated codec tokens. With ``use_watermark``, decoding goes
       through ``audio_tokenizer.wmdecode`` and a reference waveform built from
       the original audio.

    Args:
        model: Model exposing ``inference(...)``. It returns
            ``(encoded_frames, marks, masks, ori_masks)``.
        model_args: Must provide ``n_codebooks``.
        phn2num: Mapping from phoneme symbol to integer id.
        text_tokenizer: Passed to ``tokenize_text``.
        audio_tokenizer: Passed to ``tokenize_audio``. Must also provide
            ``decode`` and ``wmdecode``.
        audio_fn: Path of the input audio file.
        prompt_text: Text passed to the model as the prompt transcript.
        target_text: Text passed to the model as the target transcript.
        mask_interval: Tensor. It is batched with ``unsqueeze(0)`` before being
            passed to the model. Presumably given in codec-frame units; confirm.
        cfg_coef: Forwarded to ``model.inference``.
        aug_text: Forwarded to ``model.inference``.
        aug_context: Accepted but currently unused (see NOTE below).
        use_watermark: Decode with ``wmdecode`` instead of ``decode``.
        tts: If true, keep only samples from ``masks[0][1] * hop_length``
            onward in the decoded output.
        device: Device the model inputs are moved to.
        decode_config: Dict read for the keys ``'codec_sr'``, ``'top_k'``,
            ``'top_p'``, ``'temperature'``, ``'stop_repetition'`` and
            ``'kvcache'``.
        hop_length: Number of waveform samples per codec frame. Used to map
            the frame indices in the returned masks onto sample indices.
            Defaults to 320, the value previously hard-coded.

    Returns:
        The decoded audio from ``audio_tokenizer.decode``/``wmdecode``. In tts
        mode it is sliced on its last axis.
    """
    # NOTE(review): `aug_context` is never forwarded to model.inference
    # (only `aug_text` is). Confirm whether that is intended.

    # phonemize
    text_tokens = _phonemize(text_tokenizer, phn2num, target_text)
    text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
    prompt_text_tokens = _phonemize(text_tokenizer, phn2num, prompt_text)
    prompt_text_tokens_lens = torch.LongTensor([prompt_text_tokens.shape[-1]])

    # encode the reference audio into codec tokens
    encoded_frames, scale, _emb = tokenize_audio(audio_tokenizer, audio_fn)
    original_audio = encoded_frames.transpose(2, 1)  # [1,T,K]
    assert original_audio.ndim == 3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
    logging.info(f"with direct encodec encoding before input, original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")

    # forward
    stime = time.time()
    encoded_frames, marks, masks, ori_masks = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        prompt_text_tokens.to(device),
        prompt_text_tokens_lens.to(device),
        original_audio[..., :model_args.n_codebooks].to(device),  # [1,T,8]
        original_audio[..., :model_args.n_codebooks].to(device),  # [1,T,8]
        mask_interval=mask_interval.unsqueeze(0).to(device),
        top_k=decode_config['top_k'],
        top_p=decode_config['top_p'],
        temperature=decode_config['temperature'],
        stop_repetition=decode_config['stop_repetition'],
        kvcache=decode_config['kvcache'],
        cfg_coef=cfg_coef,
        aug_text=aug_text,
    )  # output is [1,K,T]
    logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
    if isinstance(encoded_frames, tuple):
        encoded_frames = encoded_frames[0]
    logging.info(f"generated encoded_frames.shape: {encoded_frames.shape}, which is {encoded_frames.shape[-1]/decode_config['codec_sr']} sec.")

    # decode
    if use_watermark:
        reference_wav = _build_watermark_reference(
            audio_fn, ori_masks, masks, encoded_frames.shape[-1], hop_length
        )
        generated_sample = audio_tokenizer.wmdecode(
            encoded_frames,
            marks.to(encoded_frames.device),
            reference_wav.unsqueeze(0).to(encoded_frames.device),
            scale,
        )
    else:
        generated_sample = audio_tokenizer.decode(encoded_frames, scale)

    if tts:
        # Keep only samples from the end of the first interval in `masks`
        # onward (presumably the newly generated speech after the prompt).
        # The audio file is not reloaded here: the previous load was unused.
        generated_sample = generated_sample[:, :, masks[0][1] * hop_length:]
    return generated_sample


def _phonemize(text_tokenizer, phn2num, text):
    # Phonemize `text` into a [1, L] LongTensor of ids, dropping phonemes missing from `phn2num`.
    ids = [
        phn2num[phn]
        for phn in tokenize_text(text_tokenizer, text=text.strip())
        if phn in phn2num
    ]
    return torch.LongTensor(ids).unsqueeze(0)


def _build_watermark_reference(audio_fn, ori_masks, masks, num_frames, hop_length):
    # Build a [1, num_frames * hop_length] waveform: original audio copied span-by-span from ori_masks into masks, zeros elsewhere.
    wav, _sr = torchaudio.load(audio_fn)
    # NOTE(review): wav is used at its native sample rate and channel count.
    # The copy below assumes mono audio whose rate matches the codec; confirm.
    remainder = wav.shape[-1] % hop_length
    if remainder:
        # Right-pad so every codec frame maps onto a full hop of samples.
        wav = F.pad(wav, (0, hop_length - remainder), "constant", 0)
    reference = torch.zeros(1, num_frames * hop_length)
    for i, ori in enumerate(ori_masks):
        # Interval starts are clamped at 0. Indexing masks[i] (not zip) keeps
        # the IndexError if `masks` is shorter than `ori_masks`.
        ori_start, ori_end = max(ori[0], 0), ori[1]
        new_start, new_end = max(masks[i][0], 0), masks[i][1]
        reference[:, new_start * hop_length:new_end * hop_length] = (
            wav[:, ori_start * hop_length:ori_end * hop_length]
        )
    return reference
def get_mask_interval(ali_fn, word_span):
    """Convert a word-index span into a ``(start, end)`` time interval to mask.

    ``ali_fn`` is a CSV alignment file. Its first row is skipped as a header.
    Of the remaining rows, only those whose 4th column is ``"words"`` are used.
    Column 0 is parsed as the word's start time and column 1 as its end time,
    in the file's own time units (presumably seconds).

    ``word_span`` is ``(s, e)``, a half-open range of word indices with
    ``0 <= s <= e <= number_of_words``:

    * ``e == 0``: from 0 to the start of the first word.
    * ``s == number_of_words``: a zero-length span at the end of the last word.
    * ``s == e`` (otherwise): the gap between word ``s-1`` and word ``s``.
    * otherwise: from the end of word ``s-1`` (or the start of word 0 when
      ``s == 0``) to the start of word ``e`` (or the end of the last word when
      ``e == number_of_words``).

    Returns:
        ``(start, end)`` as floats.

    Raises:
        AssertionError: if ``word_span`` is out of range.
    """
    import csv

    # csv.reader honours quoting, so a label containing a comma (e.g. "well,")
    # no longer shifts the Type column as str.split(",") did.
    with open(ali_fn, "r", newline="", encoding="utf-8") as rf:
        rows = list(csv.reader(rf))
    # Skip the header row. Blank or short rows (e.g. a trailing empty line)
    # are ignored instead of raising IndexError on the Type column.
    data = [row for row in rows[1:] if len(row) > 3 and row[3].strip() == "words"]

    s, e = word_span[0], word_span[1]
    assert s <= e, f"s:{s}, e:{e}"
    assert s >= 0, f"s:{s}"
    assert e <= len(data), f"e:{e}"

    if e == 0:  # start
        return 0.0, float(data[0][0])
    if s == len(data):  # end
        last_end = float(data[-1][1])
        return last_end, last_end  # don't know the end yet
    if s == e:  # insert
        return float(data[s - 1][1]), float(data[s][0])
    start = float(data[s - 1][1]) if s > 0 else float(data[s][0])
    end = float(data[e][0]) if e < len(data) else float(data[-1][1])
    return start, end
if __name__ == "__main__":
    # This module defines no command-line behaviour; running it directly
    # does nothing.
    ...