# Uploaded by DuyTa using huggingface_hub (commit aeda668, verified; ~10 kB).
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
from argparse import ArgumentParser
from collections import OrderedDict
from typing import List
from text_processing.data_loader_utils import post_process_punctuation, pre_process
from text_processing.token_parser import PRESERVE_ORDER_KEY, TokenParser
from tqdm import tqdm
try:
import pynini
PYNINI_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
PYNINI_AVAILABLE = False
try:
from text_processing.moses_tokenizers import MosesProcessor
NLP_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
NLP_AVAILABLE = False
class Normalizer:
    """
    Normalizer class that converts text from written to spoken form.
    Useful for TTS preprocessing.

    Args:
        input_case: expected input capitalization
        lang: language specifying the TN rules, by default: English
        deterministic: forwarded to the tagger/verbalizer grammars (see the disabled
            reference construction in ``__init__``)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements

    The remaining methods expect the following instance attributes, which ``__init__``
    is meant to set up: ``tagger`` and ``verbalizer`` (objects exposing an ``fst``),
    ``parser`` (a ``TokenParser``) and ``processor`` (a ``MosesProcessor`` or None).
    """

    def __init__(
        self,
        input_case: str,
        lang: str = 'en',
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
    ):
        # The grammar classes (ClassifyFst / VerbalizeFinalFst) are not imported in this
        # module, so construction is disabled for now. Reference implementation:
        #
        # self.tagger = ClassifyFst(
        #     input_case=input_case,
        #     deterministic=deterministic,
        #     cache_dir=cache_dir,
        #     overwrite_cache=overwrite_cache,
        #     whitelist=whitelist,
        # )
        # self.verbalizer = VerbalizeFinalFst(deterministic=deterministic)
        # self.parser = TokenParser()
        # self.lang = lang
        # if NLP_AVAILABLE:
        #     self.processor = MosesProcessor(lang_id=lang)
        # else:
        #     self.processor = None
        #     print("NeMo NLP is not available. Moses de-tokenization will be skipped.")
        raise NotImplementedError(
            "Normalizer grammars (tagger/verbalizer) are not wired up in this module yet."
        )

    def normalize_list(self, texts: List[str], verbose: bool = False) -> List[str]:
        """
        NeMo text normalizer applied to each string of a list.

        Args:
            texts: list of input strings
            verbose: whether to print intermediate meta information

        Returns:
            list of normalized strings, in the same order as ``texts``

        Raises:
            Whatever ``normalize`` raises. The offending input is printed first so that
            it can be located in large batches.
        """
        res = []
        for text in tqdm(texts):
            try:
                normalized = self.normalize(text, verbose=verbose)
            except Exception:
                print(text)
                # Re-raise the original error so its type, message and traceback survive.
                raise
            res.append(normalized)
        return res

    def normalize(
        self, text: str, verbose: bool = False, punct_pre_process: bool = False, punct_post_process: bool = False
    ) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
        e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            verbose: whether to print intermediate meta information
            punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
            punct_post_process: whether to normalize punctuation

        Returns:
            spoken form

        Raises:
            ValueError: if no serialization of the tagged tokens is accepted by the verbalizer.
        """
        if punct_pre_process:
            text = pre_process(text)
        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text

        text = pynini.escape(text)
        tagged_lattice = self.find_tags(text)
        tagged_text = self.select_tag(tagged_lattice)
        if verbose:
            print(tagged_text)

        self.parser(tagged_text)
        tokens = self.parser.parse()

        # Token fields may be serialized in several orders; use the first order that
        # the verbalizer grammar accepts (i.e. yields a non-empty lattice).
        for candidate in self.generate_permutations(tokens):
            verbalizer_lattice = self.find_verbalizer(pynini.escape(candidate))
            if verbalizer_lattice.num_states() == 0:
                continue
            output = self.select_verbalizer(verbalizer_lattice)
            if punct_post_process:
                output = post_process_punctuation(output)
                # do post-processing based on Moses detokenizer
                if self.processor:
                    output = self.processor.detokenize([output])
            return output
        raise ValueError(f"No verbalization found for input: {text!r}")

    def _permute(self, d: OrderedDict) -> List[str]:
        """
        Creates reorderings of dictionary elements and serializes them as strings.

        Args:
            d: (nested) dictionary of key value pairs

        Returns:
            all string serializations of the key value pairs. Field order is permuted
            unless ``d`` contains ``PRESERVE_ORDER_KEY``. Nested dictionaries contribute
            all of their own serializations.

        Raises:
            ValueError: if a value is not a str, OrderedDict or bool.
        """
        # Serialize each field once. Only the order of the fields differs between
        # permutations, so there is no need to redo (possibly recursive) work per order.
        field_options = []
        for key, value in d.items():
            if isinstance(value, str):
                field_options.append([f"{key}: \"{value}\" "])
            elif isinstance(value, OrderedDict):
                field_options.append([f" {key} {{ {inner} }} " for inner in self._permute(value)])
            elif isinstance(value, bool):
                # NOTE(review): any bool (including False) is emitted as "true"; presumably
                # flags only ever appear as True -- confirm against TokenParser.
                field_options.append([f"{key}: true "])
            else:
                raise ValueError(f"Unsupported value type {type(value).__name__} for key {key!r}")

        if PRESERVE_ORDER_KEY in d:
            orders = [field_options]
        else:
            orders = itertools.permutations(field_options)

        result = []
        for order in orders:
            # Cartesian product across fields: one option per field, concatenated in order.
            result.extend("".join(parts) for parts in itertools.product(*order))
        return result

    def generate_permutations(self, tokens: List[dict]):
        """
        Generates permutations of string serializations of a list of dictionaries.

        Args:
            tokens: list of dictionaries

        Returns:
            a lazy generator over every concatenation formed by picking one
            serialization per token, in token order
        """

        def _helper(prefix: str, tokens: List[dict], idx: int):
            """
            Recursively extends ``prefix`` with each serialization of ``tokens[idx]``.

            Args:
                prefix: serialization of tokens[:idx] built so far
                tokens: list of dictionaries
                idx: index of next dictionary

            Yields:
                complete serializations of all tokens
            """
            if idx == len(tokens):
                yield prefix
                return
            for token_option in self._permute(tokens[idx]):
                yield from _helper(prefix + token_option, tokens, idx + 1)

        return _helper("", tokens, 0)

    def find_tags(self, text: str) -> 'pynini.FstLike':
        """
        Given text use tagger Fst to tag text

        Args:
            text: sentence

        Returns:
            tagged lattice (composition of ``text`` with ``self.tagger.fst``)
        """
        lattice = text @ self.tagger.fst
        return lattice

    def select_tag(self, lattice: 'pynini.FstLike') -> str:
        """
        Given tagged lattice return shortest path

        Args:
            lattice: tagging lattice produced by ``find_tags``

        Returns:
            string of the shortest path
        """
        tagged_text = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
        return tagged_text

    def find_verbalizer(self, tagged_text: str) -> 'pynini.FstLike':
        """
        Given tagged text creates verbalization lattice
        This is context-independent.

        Args:
            tagged_text: input text

        Returns:
            verbalized lattice (composition of ``tagged_text`` with ``self.verbalizer.fst``)
        """
        lattice = tagged_text @ self.verbalizer.fst
        return lattice

    def select_verbalizer(self, lattice: 'pynini.FstLike') -> str:
        """
        Given verbalized lattice return shortest path

        Args:
            lattice: verbalization lattice

        Returns:
            string of the shortest path
        """
        output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
        return output
def parse_args(argv=None):
    """
    Parse command-line options for the normalizer script.

    Args:
        argv: optional list of argument strings. When None (the default), argparse reads
            ``sys.argv[1:]``, as before. Passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with the parsed options
    """
    parser = ArgumentParser()
    parser.add_argument("input_string", help="input string", type=str)
    parser.add_argument("--language", help="language", choices=["en"], default="en", type=str)
    parser.add_argument(
        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
    )
    parser.add_argument("--verbose", help="print info for debugging", action='store_true')
    parser.add_argument(
        "--punct_post_process", help="set to True to enable punctuation post processing", action="store_true"
    )
    parser.add_argument(
        "--punct_pre_process", help="set to True to enable punctuation pre processing", action="store_true"
    )
    parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
    parser.add_argument("--whitelist", help="path to a file with whitelist", default=None, type=str)
    parser.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Command-line entry point: normalize a single input string and print the result.
    args = parse_args()
    # Resolve the whitelist path up front so it does not depend on later working-directory changes.
    whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
    normalizer = Normalizer(
        input_case=args.input_case,
        lang=args.language,  # previously parsed but never forwarded
        cache_dir=args.cache_dir,
        overwrite_cache=args.overwrite_cache,
        whitelist=whitelist,
    )
    print(
        normalizer.normalize(
            args.input_string,
            verbose=args.verbose,
            punct_pre_process=args.punct_pre_process,
            punct_post_process=args.punct_post_process,
        )
    )