ONNX-converted model for https://huggingface.co/proxectonos/Nos_TTS-celtia-vits-graphemes

For inference:

```python
# Minimal ONNX inference extracted from coqui-tts
import json
import re
from typing import Callable, List

import numpy as np
import onnxruntime as ort
import scipy.io.wavfile

# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")


class Graphemes:
    def __init__(
            self,
            characters: str = None,
            punctuations: str = None,
            pad: str = None,
            eos: str = None,
            bos: str = None,
            blank: str = "<BLNK>",
            is_unique: bool = False,
            is_sorted: bool = True,
    ) -> None:
        self._characters = characters
        self._punctuations = punctuations
        self._pad = pad
        self._eos = eos
        self._bos = bos
        self._blank = blank
        self.is_unique = is_unique
        self.is_sorted = is_sorted
        self._create_vocab()

    @property
    def pad_id(self) -> int:
        return self.char_to_id(self.pad) if self.pad else len(self.vocab)

    @property
    def blank_id(self) -> int:
        return self.char_to_id(self.blank) if self.blank else len(self.vocab)

    @property
    def eos_id(self) -> int:
        return self.char_to_id(self.eos) if self.eos else len(self.vocab)

    @property
    def bos_id(self) -> int:
        return self.char_to_id(self.bos) if self.bos else len(self.vocab)

    @property
    def characters(self):
        return self._characters

    @characters.setter
    def characters(self, characters):
        self._characters = characters
        self._create_vocab()

    @property
    def punctuations(self):
        return self._punctuations

    @punctuations.setter
    def punctuations(self, punctuations):
        self._punctuations = punctuations
        self._create_vocab()

    @property
    def pad(self):
        return self._pad

    @pad.setter
    def pad(self, pad):
        self._pad = pad
        self._create_vocab()

    @property
    def eos(self):
        return self._eos

    @eos.setter
    def eos(self, eos):
        self._eos = eos
        self._create_vocab()

    @property
    def bos(self):
        return self._bos

    @bos.setter
    def bos(self, bos):
        self._bos = bos
        self._create_vocab()

    @property
    def blank(self):
        return self._blank

    @blank.setter
    def blank(self, blank):
        self._blank = blank
        self._create_vocab()

    @property
    def vocab(self):
        return self._vocab

    @vocab.setter
    def vocab(self, vocab):
        self._vocab = vocab
        self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
        self._id_to_char = {
            idx: char for idx, char in enumerate(self.vocab)  # pylint: disable=unnecessary-comprehension
        }

    @property
    def num_chars(self):
        return len(self._vocab)

    def _create_vocab(self):
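        # Vocabulary layout: [pad] + punctuations + characters + [blank]; token IDs are assigned in this order.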
        self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
        self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
        # pylint: disable=unnecessary-comprehension
        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}

    def char_to_id(self, char: str) -> int:
        try:
            return self._char_to_id[char]
        except KeyError as e:
            raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e

    def id_to_char(self, idx: int) -> str:
        return self._id_to_char[idx]


class TTSTokenizer:
    """🐸TTS tokenizer to convert input characters to token IDs and back.

    OOV characters are discarded from the output, but recorded in `self.not_found_characters` for later inspection.

    Args:
        characters (Characters):
            A Characters object to use for character-to-ID and ID-to-character mappings.

        text_cleaner (callable):
            A function to pre-process the text before tokenization and phonemization. Defaults to None.
    """

    def __init__(
            self,
            text_cleaner: Callable = None,
            characters: Graphemes = None,
            add_blank: bool = False,
            use_eos_bos: bool = False,
    ):
        self.text_cleaner = text_cleaner
        self.add_blank = add_blank
        self.use_eos_bos = use_eos_bos
        self.characters = characters
        self.not_found_characters = []

    @property
    def characters(self):
        return self._characters

    @characters.setter
    def characters(self, new_characters):
        self._characters = new_characters
        self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None
        self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None

    def encode(self, text: str) -> List[int]:
        """Encodes a string of text as a sequence of IDs."""
        token_ids = []
        for char in text:
            try:
                idx = self.characters.char_to_id(char)
                token_ids.append(idx)
            except KeyError:
                # discard but store not found characters
                if char not in self.not_found_characters:
                    self.not_found_characters.append(char)
                    print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it. Input text: {text}")
        return token_ids

    def text_to_ids(self, text: str) -> List[int]:
        """Converts a string of text to a sequence of token IDs.

        Args:
            text(str):
                The text to convert to token IDs.

        1. Text normalization
        2. Text to token IDs
        3. Add blank tokens between tokens
        4. Add BOS and EOS tokens
        """
        if self.text_cleaner is not None:
            text = self.text_cleaner(text)
        text = self.encode(text)
        if self.add_blank:
            text = self.intersperse_blank_char(text, True)
        if self.use_eos_bos:
            text = self.pad_with_bos_eos(text)
        return text

    def pad_with_bos_eos(self, char_sequence: List[int]):
        """Pads a sequence with the special BOS and EOS characters."""
        return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id]

    def intersperse_blank_char(self, char_sequence: List[int], use_blank_char: bool = False):
        """Intersperses the blank token between tokens in a sequence.

        Uses the ```blank``` token ID if ``use_blank_char`` is set, else the ```pad``` token ID.
        """
        char_to_use = self.characters.blank_id if use_blank_char else self.characters.pad_id
        result = [char_to_use] * (len(char_sequence) * 2 + 1)
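        # Overwrite the odd indices with the original tokens; blanks remain at the even positions.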
        result[1::2] = char_sequence
        return result


class VitsOnnxInference:
    def __init__(self, onnx_model_path: str, config_path: str, cuda=False):
        self.config = {}
        if config_path:
            with open(config_path) as f:
                self.config = json.load(f)
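        # Run on CPU by default; request the CUDA provider when cuda=True.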
        providers = [
            "CPUExecutionProvider"
            if cuda is False
            else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
        ]
        sess_options = ort.SessionOptions()
        self.onnx_sess = ort.InferenceSession(
            onnx_model_path,
            sess_options=sess_options,
            providers=providers,
        )

        _pad = self.config.get("characters", {}).get("pad", "_")
        _punctuations = self.config.get("characters", {}).get("punctuations", "!\"(),-.:;?\u00a1\u00bf ")
        _letters = self.config.get("characters", {}).get(
            "characters",
            "ABCDEFGHIJKLMNOPQRSTUVXYZabcdefghijklmnopqrstuvwxyz\u00c1\u00c9\u00cd\u00d3\u00da\u00e1\u00e9\u00ed\u00f1\u00f3\u00fa\u00fc",
        )

        vocab = Graphemes(characters=_letters,
                          punctuations=_punctuations,
                          pad=_pad)

        self.tokenizer = TTSTokenizer(
            text_cleaner=self.normalize_text,
            characters=vocab,
            add_blank=self.config.get("add_blank", True),
            use_eos_bos=False,
        )

    @staticmethod
    def normalize_text(text: str) -> str:
        """Basic pipeline that lowercases and collapses whitespace without transliteration."""
        text = text.lower()
        text = text.replace(";", ",")
        text = text.replace("-", " ")
        text = text.replace(":", ",")
        text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text)
        text = re.sub(_whitespace_re, " ", text).strip()
        return text

    def inference_onnx(self, text: str):
        """ONNX inference"""
        x = np.asarray(
            self.tokenizer.text_to_ids(text),
            dtype=np.int64,
        )[None, :]

        x_lengths = np.array([x.shape[1]], dtype=np.int64)

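        # The "scales" input packs three scalars: flow noise scale, length scale (speech rate), and duration-predictor noise scale.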
        scales = np.array(
            [self.config.get("inference_noise_scale", 0.667),
             self.config.get("length_scale", 1.0),
             self.config.get("inference_noise_scale_dp", 1.0), ],
            dtype=np.float32,
        )
        input_params = {"input": x, "input_lengths": x_lengths, "scales": scales}

        audio = self.onnx_sess.run(
            ["output"],
            input_params,
        )
        return audio[0][0]

    @staticmethod
    def save_wav(wav: np.ndarray, path: str, sample_rate: int = 16000) -> None:
        """Save a float waveform to a file using SciPy.

        Args:
            wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
            path (str): Path to the output file.
            sample_rate (int): Output sample rate. Defaults to 16000.
        """
        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
        wav_norm = wav_norm.astype(np.int16)
        scipy.io.wavfile.write(path, sample_rate, wav_norm)

    def synth(self, text: str, path: str):
        wavs = self.inference_onnx(text)
        self.save_wav(wavs[0], path, self.config.get("sample_rate", 16000))
```
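
A minimal usage sketch (the file names `model.onnx` and `config.json` are placeholders; point them at the files downloaded from this repo):

```python
# Hypothetical local paths; substitute the actual model and config files.
tts = VitsOnnxInference("model.onnx", "config.json")

# Synthesize a sample Galician sentence and write it to a WAV file.
tts.synth("Ola, isto é unha proba.", "output.wav")
```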