# Spaces: Running on Zero
# adapted from https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
# Copyright (c) 2022 OpenAI
# MIT License for this file
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import base64 | |
import os | |
import string | |
from dataclasses import dataclass, field | |
from functools import cached_property, lru_cache | |
from typing import Dict, List, Optional, Tuple | |
import tiktoken | |
# Supported language codes mapped to their English names.
# NOTE: insertion order is significant. get_encoding registers one "[<code>]"
# special token per entry, in this order, for the first `num_languages`
# entries, and Tokenizer.__post_init__ computes a language's token id from its
# position in this dict. Reordering entries therefore changes token ids.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}
# Reverse lookup from an English language name to its code. Start by inverting
# LANGUAGES, then layer alternate names (aliases) for languages it already
# contains, so callers may pass e.g. "mandarin" or "castilian".
TO_LANGUAGE_CODE = {name: code for code, name in LANGUAGES.items()}
TO_LANGUAGE_CODE.update(
    burmese="my",
    valencian="ca",
    flemish="nl",
    haitian="ht",
    letzeburgesch="lb",
    pushto="ps",
    panjabi="pa",
    moldavian="ro",
    moldovan="ro",
    sinhalese="si",
    castilian="es",
    mandarin="zh",
)
@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens.

    Special tokens in this vocabulary are written with square brackets
    (``[startoftranscript]``, ``[en]``, ``[transcribe]``, ...), matching the
    names registered by ``get_encoding``.

    Attributes:
        encoding: the underlying ``tiktoken.Encoding``.
        num_languages: how many leading entries of ``LANGUAGES`` are treated
            as having language tokens.
        language: optional language code; when set, its token is appended to
            ``sot_sequence``.
        task: optional task; ``"transcribe"`` selects ``[transcribe]`` and any
            other non-None value selects ``[translate]``.
        sot_sequence: filled in by ``__post_init__``: ``[startoftranscript]``
            followed by the optional language and task tokens.
        special_tokens: filled in by ``__post_init__``: special-token string
            -> token id.
    """

    encoding: tiktoken.Encoding
    num_languages: int
    language: Optional[str] = None
    task: Optional[str] = None
    sot_sequence: Tuple[int] = ()
    special_tokens: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self):
        # Record the id of every special token registered on the encoding.
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot: int = self.special_tokens["[startoftranscript]"]
        translate: int = self.special_tokens["[translate]"]
        transcribe: int = self.special_tokens["[transcribe]"]

        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            # Language tokens are registered right after [startoftranscript],
            # in LANGUAGES order (see the specials list in get_encoding).
            # Raises ValueError if the language is outside the first
            # `num_languages` entries.
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)
        self.sot_sequence = tuple(sot_sequence)

    def get_vocab_size(self):
        """Return the total vocabulary size (mergeable ranks + special tokens)."""
        return self.encoding.n_vocab

    def encode(self, text):
        """Encode ``text``; special-token strings in it are encoded as special tokens."""
        return self.encoding.encode(text, allowed_special="all")

    def decode(self, token_ids: List[int], **kwargs) -> str:
        """Decode ``token_ids``; extra kwargs are forwarded to ``Encoding.decode``."""
        return self.encoding.decode(token_ids, **kwargs)

    @property
    def eot(self) -> int:
        """Id of the ``[endoftext]`` token.

        Read from ``special_tokens`` rather than ``encoding.eot_token``:
        tiktoken's ``eot_token`` looks up the literal ``"<|endoftext|>"``,
        which this vocabulary does not register.
        """
        return self.special_tokens["[endoftext]"]

    @cached_property
    def stop(self) -> int:
        return self.special_tokens["[STOP]"]

    @cached_property
    def start(self) -> int:
        return self.special_tokens["[START]"]

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["[transcribe]"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["[translate]"]

    @cached_property
    def sot(self) -> int:
        return self.special_tokens["[startoftranscript]"]

    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["[startoflm]"]

    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["[startofprev]"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["[nospeech]"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field.

        Raises:
            ValueError: if no language is configured.
            KeyError: if the language has no token in this vocabulary.
        """
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")
        return self.to_language_token(self.language)

    def to_language_token(self, language):
        """Return the id of the ``[<language>]`` token, or raise ``KeyError``."""
        # Compare against None explicitly: a valid id must never be rejected
        # for being falsy.
        token = self.special_tokens.get(f"[{language}]")
        if token is not None:
            return token
        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int]:
        """Ids of the language special tokens, limited to ``num_languages``."""
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("[]") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self) -> Tuple[str]:
        """Language codes (bracket-stripped) of ``all_language_tokens``."""
        return tuple(self.decode([_l]).strip("[]") for _l in self.all_language_tokens)

    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,
        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens: List[int]):
        """Split ``tokens`` into words; returns ``(words, word_tokens)``."""
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        """Group tokens into the shortest runs that decode to complete characters."""
        decoded_full = self.decode(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode(current_tokens)

            # Close the run when it decodes cleanly, or when the U+FFFD it
            # contains is really present in the full decoded text (i.e. not
            # an artifact of a multi-byte character split across tokens).
            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        """Merge unicode-level subwords into space-delimited words.

        A subword starts a new word if it is a special token, begins with a
        space, or is a single punctuation mark; otherwise it is appended to the
        previous word.
        """
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        # Membership test instead of `>= self.eot`: in this vocabulary many
        # special tokens ([STOP], [laugh], ...) have ids below [endoftext].
        special_ids = set(self.special_tokens.values())

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] in special_ids
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens
@lru_cache(maxsize=None)
def get_encoding(name: str = "multilingual", num_languages: int = 100):
    """Build the tiktoken ``Encoding`` for ``assets/<name>.tiktoken`` beside this file.

    Each non-blank line of the vocabulary file is ``<base64 token> <rank>``.
    Special tokens are appended after the mergeable ranks with consecutive ids
    starting at ``len(ranks)``, in the order listed below. The language tokens
    (``[<code>]`` for the first ``num_languages`` entries of ``LANGUAGES``)
    therefore immediately follow ``[startoftranscript]``, which
    ``Tokenizer.__post_init__`` relies on.

    Results are memoized per ``(name, num_languages)`` so building several
    tokenizers does not re-read the vocabulary file.

    Raises:
        OSError: if the vocabulary file cannot be opened.
        ValueError: if a non-blank line is not exactly two fields.
    """
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")

    ranks = {}
    # `with` closes the file even if a line fails to parse.
    with open(vocab_path, encoding="utf-8") as vocab_file:
        for line in vocab_file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing empty line
            token, rank = fields
            ranks[base64.b64decode(token)] = int(rank)

    specials = [
        "[STOP]",
        "[UNK]",
        "[SPACE]",
        "[START]",
        "[nospk]",
        "[spkemb]",
        "[emotionemb]",
        "[contextemb]",
        "[sbreak]",
        "[pbreak]",
        "[uvbreak]",
        "[bsing]",
        "[esing]",
        "[sing]",
        "[hum]",
        "[laugh]",
        "[break]",
        "[breath]",
        "[oralsii]",
        "[oralze]",
        "[prolong]",
        "[stress]",
        "[bstrong]",
        "[estrong]",
        "[hiccup]",
        "[inhale]",
        "[exhale]",
        "[emounknown]",
        "[happy]",
        "[neutral]",
        "[sad]",
        "[surprise]",
        "[angry]",
        "[disgust]",
        "[emo]",
        "[laugha]",
        "[laughb]",
        "[laughc]",
        "[orala]",
        "[oralb]",
        "[oralc]",
        "[orald]",
        "[orale]",
        "[breaka]",
        "[breakb]",
        "[breakc]",
        "[breakd]",
        "[breake]",
        "[breakf]",
        "[endoftext]",
        "[startoftranscript]",
        *[f"[{lang}]" for lang in list(LANGUAGES.keys())[:num_languages]],
        "[translate]",
        "[transcribe]",
        "[startoflm]",
        "[startofprev]",
        "[nospeech]",
    ]

    # Special ids continue right after the mergeable ranks.
    first_special_id = len(ranks)
    special_tokens = {
        token: first_special_id + offset for offset, token in enumerate(specials)
    }
    n_vocab = first_special_id + len(specials)

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d|\[[A-Z]+\]|\[[a-z]+\]|[\x{4e00}-\x{9df5}]| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 100,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    """Create a ``Tokenizer`` for the multilingual or the ``gpt2`` vocabulary.

    Args:
        multilingual: if true, use the ``multilingual`` vocabulary and default
            ``language`` to ``"en"`` and ``task`` to ``"transcribe"``;
            otherwise use ``gpt2`` and discard ``language`` and ``task``.
        num_languages: number of leading ``LANGUAGES`` entries given tokens.
        language: a language code or English name (case-insensitive); names
            and aliases are resolved through ``TO_LANGUAGE_CODE``.
        task: ``"transcribe"``, ``"translate"`` or ``None``.

    Raises:
        ValueError: if ``language`` or ``task`` is not supported.
    """
    # Tokenizer maps every task other than "transcribe" to the translate
    # token, so reject typos (e.g. "Transcribe") instead of silently translating.
    if task not in (None, "transcribe", "translate"):
        raise ValueError(f"Unsupported task: {task}")

    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )