from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, AutoConfig
from encodec import EncodecModel
import torch
import torchaudio


class GPTTTS(PreTrainedModel):
    """Text-to-speech model: a fine-tuned distilgpt2 generates EnCodec audio tokens,
    which are then decoded back into a waveform with the 24 kHz EnCodec decoder."""

    def __init__(self, *model_args, **model_kwargs):
        super().__init__(
            AutoConfig.from_pretrained("Ekgren/distilgpt2-finetuned-common-voice"),
            *model_args,
            **model_kwargs,
        )
        self.model = AutoModelForCausalLM.from_pretrained("Ekgren/distilgpt2-finetuned-common-voice")
        # 24 kHz EnCodec at a 1.5 kbps target bandwidth uses 2 codebooks per frame.
        self.encodec_model = EncodecModel.encodec_model_24khz()
        self.encodec_model.set_target_bandwidth(1.5)
        self.sample_rate = self.encodec_model.sample_rate

    def forward(self, input_ids):
        # Sample a continuation of the "text: ...\nsound:" prompt.
        tokens = self.model.generate(
            input_ids,
            do_sample=True,
            max_length=1024,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
        )[0]

        # Token ids above the GPT-2 text vocabulary (50257) are audio codes;
        # shift them back into the EnCodec codebook index range.
        audio_tokens = [token - 50257 for token in tokens if token > 50257]

        # Re-interleave the flat token stream into (batch, codebook, frame) codes.
        number_of_codebooks = 2
        number_of_samples = len(audio_tokens) // number_of_codebooks
        frame = torch.zeros(1, number_of_codebooks, number_of_samples, dtype=torch.long)
        for sample in range(number_of_samples):
            for codebook in range(number_of_codebooks):
                frame[0, codebook, sample] = audio_tokens[sample * number_of_codebooks + codebook]

        # EnCodec expects a list of (codes, scale) frames; decode them to a waveform.
        frames = [(frame, None)]
        with torch.no_grad():
            wav = self.encodec_model.decode(frames)
        return wav[0, :, :]


class GPTTTSTokenizer(PreTrainedTokenizer):
    """Wraps the fine-tuned distilgpt2 tokenizer and builds the TTS prompt format."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained("anforsm/distilgpt2-finetuned-common-voice")
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def tokenize(self, text, *args, **kwargs):
        prompt = f"text: {text}\nsound:"
        return self.tokenizer(prompt, return_tensors="pt")

    def _tokenize(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def convert_tokens_to_ids(self, tokens):
        return tokens["input_ids"]

    def convert_ids_to_tokens(self, ids):
        return self.tokenizer.decode(ids[0], skip_special_tokens=True)

    def _batch_encode_plus(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def save_vocabulary(self, *args, **kwargs):
        return self.tokenizer.save_vocabulary(*args, **kwargs)
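
# Usage sketch (an assumption about how these classes are meant to be called,
# based only on the code above; the checkpoint names and prompt format come from
# the constructors, while the output filename here is purely illustrative).
if __name__ == "__main__":
    tokenizer = GPTTTSTokenizer()
    model = GPTTTS()
    model.eval()

    # tokenize() wraps the text in the "text: ...\nsound:" prompt and returns a
    # BatchEncoding holding an input_ids tensor.
    encoding = tokenizer.tokenize("Hello, how are you today?")

    # forward() samples audio tokens from the language model and decodes them with
    # EnCodec into a (channels, samples) waveform at model.sample_rate (24 kHz).
    waveform = model(encoding["input_ids"])
    torchaudio.save("generated_speech.wav", waveform.cpu(), model.sample_rate)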