# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator

# Splits one phonemized word into runs of word characters and single
# non-word, non-space characters (phone separators and punctuation).
# Compiled once here instead of on every `to_list` call.
_PHONE_TOKEN_RE = re.compile(r"\w+|[^\w\s]", re.UNICODE)


class TextTokenizer:
    """Phonemize text with espeak and split the result into token lists."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        """Build the espeak phonemizer backend.

        Args:
            language: espeak language code passed to ``EspeakBackend``.
            backend: accepted for interface compatibility but ignored; an
                ``EspeakBackend`` is always constructed.
            separator: word/syllable/phone separators used when phonemizing
                and when splitting the result in ``to_list``. Note the
                default instance is created once and shared across calls.
            preserve_punctuation, punctuation_marks, with_stress, tie,
            language_switch, words_mismatch: forwarded unchanged to
                ``EspeakBackend``.
        """
        phonemizer = EspeakBackend(
            language,
            punctuation_marks=punctuation_marks,
            preserve_punctuation=preserve_punctuation,
            with_stress=with_stress,
            tie=tie,
            language_switch=language_switch,
            words_mismatch=words_mismatch,
        )
        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        """Split one phonemized string into a flat list of tokens.

        Phone separators are dropped; phones and punctuation become tokens,
        and ``self.separator.word`` is inserted as a token between words.
        """
        fields = []
        for word in phonemized.split(self.separator.word):
            # e.g. "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pieces = _PHONE_TOKEN_RE.findall(word)
            fields.extend(
                [p for p in pieces if p != self.separator.phone]
                + [self.separator.word]
            )
        # Sanity check: the tokens must account for every character of the
        # input except the phone separators that were intentionally dropped.
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        # Drop the trailing word separator appended after the last word.
        return fields[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        """Phonemize ``text`` (a string or list of strings).

        Returns one token list (see ``to_list``) per input string.
        """
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]


def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    """Return the phoneme tokens of a single, whitespace-stripped string."""
    phonemes = tokenizer([text.strip()])
    return phonemes[0]  # k2symbols


def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Convert a waveform to ``target_channels`` channels and ``target_sr`` Hz.

    ``wav`` is expected to have the channel axis first (mono or stereo);
    presumably ``(channels, samples)`` — confirm with callers.

    - ``target_channels == 1``: channels are averaged (downmix).
    - ``target_channels == 2``: the channel axis is broadcast to 2 via
      ``expand`` (a mono input is duplicated; stereo is unchanged).
    - other targets: a mono input is expanded to ``target_channels``.
    """
    assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
    if target_channels == 1:
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        *shape, _, length = wav.shape
        wav = wav.expand(*shape, target_channels, length)
    elif wav.shape[0] == 1:
        wav = wav.expand(target_channels, -1)
    wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav


class AudioTokenizer:
    """Audio codec wrapper around an audiocraft ``WMCompressionSolver`` model."""

    def __init__(
        self,
        device: Any = None,
        signature=None,
    ) -> None:
        """Load the codec checkpoint identified by ``signature``.

        Args:
            device: target device; when falsy, ``cuda:0`` is used if CUDA is
                available, otherwise CPU.
            signature: checkpoint identifier passed to
                ``WMCompressionSolver.model_from_checkpoint``.
        """
        # Imported lazily so the text-only parts of this module do not
        # require audiocraft.
        from audiocraft.solvers import WMCompressionSolver

        model = WMCompressionSolver.model_from_checkpoint(signature).eval()
        self.sample_rate = model.sample_rate
        self.channels = model.channels

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")

        self._device = device

        self.codec = model.to(device)

    @property
    def device(self):
        """Device the codec model lives on."""
        return self._device

    def _to_device(self, value):
        """Move ``value`` to the codec device if it is a tensor.

        Non-tensor values (e.g. a ``None`` scale) are returned unchanged.
        """
        return value.to(self.device) if isinstance(value, torch.Tensor) else value

    def encode(self, wav: torch.Tensor) -> Tuple[Any, Any, Any]:
        """Encode ``wav`` with the codec.

        Returns the ``(codes, scale, emb)`` triple produced by
        ``self.codec.encode``.
        """
        codes, scale, emb = self.codec.encode(wav.to(self.device))
        return codes, scale, emb

    def decode(self, frames: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Decode codec ``frames`` (with their ``scale``) back to audio.

        Inputs are moved to the codec device, matching ``encode``/``wmdecode``.
        """
        return self.codec.decode(self._to_device(frames), self._to_device(scale))

    def wmdecode(
        self,
        frames: torch.Tensor,
        marks: torch.Tensor,
        wav: torch.Tensor,
        scale: torch.Tensor,
    ):
        """Decode ``frames`` via ``self.codec.wmdecode`` with watermark ``marks``.

        Returns only the first element of the codec's ``wmdecode`` output.
        """
        out, _ = self.codec.wmdecode(
            frames.to(self.device),
            marks.to(self.device),
            wav.to(self.device),
            self._to_device(scale),
        )
        return out

    def detect_watermark(self, wav: torch.Tensor):
        """Run the codec's watermark detector on ``wav`` and return its output."""
        marks = self.codec.detect_watermark(wav.to(self.device))
        return marks


def tokenize_audio(
    tokenizer: AudioTokenizer,
    audio_path: str,
    offset=-1,
    num_frames=-1,
    multiple=320,
):
    """Load an audio file and encode it with ``tokenizer``.

    Args:
        tokenizer: codec wrapper used for channel/sample-rate targets and
            encoding.
        audio_path: path readable by ``torchaudio.load``.
        offset, num_frames: forwarded to ``torchaudio.load`` as
            ``frame_offset``/``num_frames``, but only when BOTH differ from
            -1; otherwise the whole file is loaded.
        multiple: the waveform is zero-padded at the end so its length, at
            the codec's sample rate, is a multiple of this value.

    Returns:
        The ``(codes, scale, emb)`` triple from ``tokenizer.encode``.
    """
    # Load and pre-process the audio waveform
    if offset != -1 and num_frames != -1:
        wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
    else:
        wav, sr = torchaudio.load(audio_path)
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)

    # Pad AFTER resampling: padding at the source rate would not keep the
    # length a multiple of `multiple` once the rate changes.
    current_length = wav.shape[-1]
    padding_length = (multiple - (current_length % multiple)) % multiple
    if padding_length > 0:
        wav = F.pad(wav, (0, padding_length), "constant", 0)

    # Add a leading batch axis for the codec.
    wav = wav.unsqueeze(0)

    # Extract discrete codes from the codec
    with torch.no_grad():
        encoded_frames, scale, emb = tokenizer.encode(wav)
    return encoded_frames, scale, emb