File size: 5,087 Bytes
f5b4ff2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# @ [email protected]

import argparse, pickle
import logging
import os, random
import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from data.tokenizer import (
    tokenize_audio,
    tokenize_text
)
import time


@torch.no_grad()
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, prompt_text, target_text, mask_interval, cfg_coef, aug_text, aug_context, use_watermark, tts, device, decode_config):
    """Run the model on one (audio, prompt text, target text) sample and decode the result.

    Steps: phonemize both texts, encode ``audio_fn`` into codec frames, call
    ``model.inference`` and decode the generated frames with ``audio_tokenizer``.

    Args:
        model: object exposing ``inference(...)``; called with the tensors and
            decoding options built below.
        model_args: object with an ``n_codebooks`` attribute.
        phn2num: mapping from phoneme to integer id. Phonemes that are not in
            the mapping are silently dropped.
        text_tokenizer: forwarded to ``tokenize_text``.
        audio_tokenizer: forwarded to ``tokenize_audio``; its ``decode`` (or
            ``wmdecode`` when ``use_watermark`` is set) produces the output.
        audio_fn: path of the original audio file.
        prompt_text: text of the prompt; stripped before phonemization.
        target_text: text to generate; stripped before phonemization.
        mask_interval: tensor of mask intervals; a leading batch dimension is
            added with ``unsqueeze(0)`` before it is handed to the model.
        cfg_coef: forwarded to ``model.inference``.
        aug_text: forwarded to ``model.inference``.
        aug_context: accepted for interface compatibility; not used here.
        use_watermark: when true, decode through ``audio_tokenizer.wmdecode``
            with a waveform assembled from the original audio.
        tts: when true, drop everything before frame ``masks[0][1]`` from the
            decoded output.
        device: device the model inputs are moved to.
        decode_config: dict read for the keys ``codec_sr``, ``top_k``,
            ``top_p``, ``temperature``, ``stop_repetition`` and ``kvcache``.

    Returns:
        The decoded sample as returned by ``audio_tokenizer`` (sliced along its
        last of three dimensions when ``tts`` is set).
    """
    # Number of waveform samples per codec frame used for every frame<->sample
    # conversion below.
    # NOTE(review): presumably the codec hop size (sample_rate / codec_sr) --
    # confirm it matches the audio tokenizer in use.
    samples_per_frame = 320

    # phonemize
    text_tokens = [
        phn2num[phn]
        for phn in tokenize_text(text_tokenizer, text=target_text.strip())
        if phn in phn2num
    ]
    text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
    text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])

    prompt_text_tokens = [
        phn2num[phn]
        for phn in tokenize_text(text_tokenizer, text=prompt_text.strip())
        if phn in phn2num
    ]
    prompt_text_tokens = torch.LongTensor(prompt_text_tokens).unsqueeze(0)
    prompt_text_tokens_lens = torch.LongTensor([prompt_text_tokens.shape[-1]])

    # The third return value (embeddings) is not needed: the watermark path
    # below works from the raw waveform instead.
    encoded_frames, scale, _emb = tokenize_audio(audio_tokenizer, audio_fn)
    original_audio = encoded_frames.transpose(2,1) # [1,T,K]
    assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
    logging.info(f"with direct encodec encoding before input, original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")

    # forward
    audio_codes = original_audio[...,:model_args.n_codebooks]
    stime = time.time()
    encoded_frames, marks, masks, ori_masks = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        prompt_text_tokens.to(device),
        prompt_text_tokens_lens.to(device),
        # The same codes are passed for both audio arguments; `.to(device)` is
        # called once per argument, exactly as before.
        audio_codes.to(device), # [1,T,8]
        audio_codes.to(device), # [1,T,8]
        mask_interval=mask_interval.unsqueeze(0).to(device),
        top_k=decode_config['top_k'],
        top_p=decode_config['top_p'],
        temperature=decode_config['temperature'],
        stop_repetition=decode_config['stop_repetition'],
        kvcache=decode_config['kvcache'],
        cfg_coef=cfg_coef,
        aug_text=aug_text,
    ) # output is [1,K,T]
    logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
    if isinstance(encoded_frames, tuple):
        encoded_frames = encoded_frames[0]
    logging.info(f"generated encoded_frames.shape: {encoded_frames.shape}, which is {encoded_frames.shape[-1]/decode_config['codec_sr']} sec.")

    # decode
    if use_watermark:
        # NOTE(review): the file's sample rate is ignored; this assumes the
        # audio is already at the rate implied by `samples_per_frame`.
        wav, _sr = torchaudio.load(audio_fn)
        # Right-pad with zeros so the waveform length is a whole number of frames.
        padding_length = (-wav.shape[-1]) % samples_per_frame
        if padding_length > 0:
            wav = F.pad(wav, (0, padding_length), "constant", 0)
        new_wav = torch.zeros(1, encoded_frames.shape[-1] * samples_per_frame)

        # Clamp negative interval starts to 0 before converting to samples.
        src_intervals = [(max(item[0], 0), item[1]) for item in ori_masks]
        dst_intervals = [(max(item[0], 0), item[1]) for item in masks]
        # Copy each interval of the original waveform to its new position in
        # the output-length buffer. Indexing `dst_intervals` (rather than
        # zipping) keeps a length mismatch a loud IndexError.
        for i, (src_start, src_end) in enumerate(src_intervals):
            dst_start, dst_end = dst_intervals[i]
            new_wav[:, dst_start * samples_per_frame:dst_end * samples_per_frame] = (
                wav[:, src_start * samples_per_frame:src_end * samples_per_frame]
            )

        generated_sample = audio_tokenizer.wmdecode(
            encoded_frames,
            marks.to(encoded_frames.device),
            new_wav.unsqueeze(0).to(encoded_frames.device),
            scale,
        )
    else:
        generated_sample = audio_tokenizer.decode(encoded_frames, scale)

    if tts:
        # Keep only the audio after the end of the first interval in `masks`.
        generated_sample = generated_sample[:, :, masks[0][1] * samples_per_frame:]

    return generated_sample



def get_mask_interval(ali_fn, word_span):
    """Translate a word-index span into a (start, end) time interval.

    The alignment file is read as comma-separated text. The first line is
    treated as a header and skipped; only rows whose fourth column equals
    ``'words'`` are kept. For those rows column 0 and column 1 are parsed as
    floats and used as the word's begin and end time (presumably an MFA-style
    ``Begin,End,Label,Type,...`` CSV -- confirm against the aligner in use).
    Blank lines and rows with fewer than four columns are ignored.

    Args:
        ali_fn: path of the alignment CSV file.
        word_span: indexable pair ``(s, e)`` of word indices with
            ``0 <= s <= e <= number_of_words``.

    Returns:
        ``(start, end)`` as floats:

        * ``e == 0``: from ``0.0`` to the begin of the first word.
        * ``s == number_of_words``: the end of the last word for both values
          (the real end is not known at this point).
        * ``s == e``: the gap between word ``s - 1`` and word ``s``.
        * otherwise: from the end of word ``s - 1`` (or the begin of word 0
          when ``s == 0``) to the begin of word ``e`` (or the end of the last
          word when ``e == number_of_words``).

    Raises:
        AssertionError: if ``word_span`` violates the bounds above.
    """
    with open(ali_fn, "r") as rf:
        next(rf, None)  # skip the header line
        rows = [line.strip().split(",") for line in rf]
    # The length guard keeps a stray blank or truncated line from raising
    # IndexError on the column lookup.
    words = [row for row in rows if len(row) > 3 and row[3] == 'words']
    n_words = len(words)

    s, e = word_span[0], word_span[1]
    # Kept as asserts so callers observe the same exception type as before.
    assert s <= e, f"s:{s}, e:{e}"
    assert s >= 0, f"s:{s}"
    assert e <= n_words, f"e:{e}"

    if e == 0: # start
        start = 0.
        end = float(words[0][0])
    elif s == n_words: # end
        start = float(words[-1][1])
        end = float(words[-1][1]) # don't know the end yet
    elif s == e: # insert
        start = float(words[s-1][1])
        end = float(words[s][0])
    else:
        start = float(words[s-1][1]) if s > 0 else float(words[s][0])
        end = float(words[e][0]) if e < n_words else float(words[-1][1])

    return (start, end)

# Script entry point: intentionally a no-op. This module only provides helper
# functions and does nothing when executed directly.
if __name__ == "__main__":
    pass