#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# NOTE(review): the original header lines were garbled in extraction
# ("Spaces:" / "Configuration error"); reconstructed as a standard
# shebang + encoding header — confirm against the original file.
# Module metadata
__author__ = "Jérôme Louradour"
__credits__ = ["Jérôme Louradour"]
__license__ = "GPLv3"
import unittest | |
import sys | |
import os | |
import subprocess | |
import shutil | |
import tempfile | |
import json | |
import torch | |
import jsonschema | |
# Fail a test when its expected reference file is missing
# (when False, the missing reference is created on the fly instead).
FAIL_IF_REFERENCE_NOT_FOUND = True
# Generate only the reference files that do not exist yet
GENERATE_NEW_ONLY = False
# Regenerate all reference files
GENERATE_ALL = False
# Regenerate only the device-dependent reference files
GENERATE_DEVICE_DEPENDENT = False
# Skip the long-running tests when no CUDA device is available
SKIP_LONG_TEST_IF_CPU = True
# Extra command-line options appended to every CLI test invocation
# (also encoded in the name of the expected-output folder)
CMD_OPTIONS = []
class TestHelper(unittest.TestCase):
    """Base class for non-regression tests.

    Provides:
    - path resolution for test data, generated outputs and expected
      (reference) files,
    - a runner that executes a CLI command and asserts success,
    - assertions comparing generated files/folders against references,
      with optional (re)generation of references depending on the
      module-level GENERATE_* flags.
    """

    def skipLongTests(self):
        # Long tests are skipped on CPU-only machines (too slow).
        return SKIP_LONG_TEST_IF_CPU and not torch.cuda.is_available()

    def setUp(self):
        self.maxDiff = None
        # Reference files created while running this test.
        # Must remain empty unless reference generation is enabled.
        self.createdReferences = []

    def tearDown(self):
        if GENERATE_ALL or GENERATE_NEW_ONLY or not FAIL_IF_REFERENCE_NOT_FOUND or GENERATE_DEVICE_DEPENDENT:
            if len(self.createdReferences) > 0:
                print("WARNING: Created references: " +
                      ", ".join(self.createdReferences).replace(self.get_data_path()+"/", ""))
        else:
            # Creating a reference outside of generation mode means an
            # expected file was missing: fail loudly.
            self.assertEqual(self.createdReferences, [], "Created references: " +
                             ", ".join(self.createdReferences).replace(self.get_data_path()+"/", ""))

    def get_main_path(self, fn=None, check=False):
        """Return a path inside the whisper_timestamped package folder."""
        return self._get_path("whisper_timestamped", fn, check=check)

    def get_output_path(self, fn=None):
        """Return a path in the temporary folder where outputs are generated."""
        if fn is None:  # fixed: was `fn == None`
            return tempfile.gettempdir()
        return os.path.join(tempfile.gettempdir(), fn + self._extra_cmd_options())

    def get_expected_path(self, fn=None, check=False):
        """Return a path in the expected (reference) outputs folder."""
        return self._get_path("tests/expected" + self._extra_cmd_options(), fn, check=check)

    def _extra_cmd_options(self):
        # Suffix encoding the global extra CLI options, so that different
        # option sets use distinct output/reference folders.
        s = "".join([f.replace("-", "").strip() for f in CMD_OPTIONS])
        if s:
            return "." + s
        return ""

    def get_data_files(self, files=None, excluded_by_default=["apollo11.mp3", "music.mp4", "arabic.mp3", "japanese.mp3", "empty.wav", "words.wav"]):
        """Return absolute paths of the input audio files to process.

        When files is None, list the data folder, excluding the (long or
        optional) files in excluded_by_default and any json file.
        NOTE: the mutable default is never mutated here.
        """
        if files is None:  # fixed: was `files == None`
            files = os.listdir(self.get_data_path())
            files = [f for f in files if f not in excluded_by_default and not f.endswith("json")]
            files = sorted(files)
        return [self.get_data_path(fn) for fn in files]

    def get_generated_files(self, input_filename, output_path, extensions):
        """Yield the paths of the output files generated for an input file."""
        for ext in extensions:
            yield os.path.join(output_path, os.path.basename(input_filename) + "." + ext.lstrip("."))

    def main_script(self, pyscript="transcribe.py", exename="whisper_timestamped"):
        """Return the python script path if present, else the installed executable name."""
        main_script = self.get_main_path(pyscript, check=False)
        if not os.path.exists(main_script):
            main_script = exename
        return main_script

    def assertRun(self, cmd):
        """Run a command (list or string), assert it exits with 0, return (stdout, stderr) as str."""
        if isinstance(cmd, str):
            return self.assertRun(cmd.split())
        curdir = os.getcwd()
        # Run from the temporary folder so relative outputs land there
        os.chdir(tempfile.gettempdir())
        if cmd[0].endswith(".py"):
            cmd = [sys.executable] + cmd
        print("Running:", " ".join(cmd))
        p = subprocess.Popen(cmd,
                             # Otherwise ".local" path might be missing
                             env=dict(
                                 os.environ, PYTHONPATH=os.pathsep.join(sys.path)),
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE
                             )
        os.chdir(curdir)
        (stdout, stderr) = p.communicate()
        self.assertEqual(p.returncode, 0, msg=stderr.decode("utf-8"))
        return (stdout.decode("utf-8"), stderr.decode("utf-8"))

    def assertNonRegression(self, content, reference, string_is_file=True):
        """
        Check that a file/folder is the same as a reference file/folder.

        content may be a dict (dumped to a temp JSON file), a raw string
        (written to a temp file when string_is_file is False), or a path
        to an existing file/folder.
        """
        if isinstance(content, dict):
            # Make a temporary file
            with tempfile.NamedTemporaryFile(mode="w", suffix=".json", encoding="utf8", delete=False) as f:
                json.dump(content, f, indent=2, ensure_ascii=False)
            content = f.name
            res = self.assertNonRegression(f.name, reference)
            os.remove(f.name)
            return res
        elif not isinstance(content, str):
            raise ValueError(f"Invalid content type: {type(content)}")
        if not string_is_file:
            # Write the raw string to a temporary file and compare that file
            with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", encoding="utf8", delete=False) as f:
                f.write(content)
            content = f.name
            res = self.assertNonRegression(f.name, reference)
            os.remove(f.name)
            return res

        self.assertTrue(os.path.exists(content), f"Missing file: {content}")
        is_file = os.path.isfile(reference) if os.path.exists(reference) else os.path.isfile(content)
        reference = self.get_expected_path(
            reference, check=FAIL_IF_REFERENCE_NOT_FOUND)
        if not os.path.exists(reference) or ((GENERATE_ALL or GENERATE_DEVICE_DEPENDENT) and reference not in self.createdReferences):
            # Create or regenerate the reference from the current output
            dirname = os.path.dirname(reference)
            if not os.path.isdir(dirname):
                os.makedirs(dirname)
            if is_file:
                shutil.copyfile(content, reference)
            else:
                shutil.copytree(content, reference)
            self.createdReferences.append(reference)

        if is_file:
            self.assertTrue(os.path.isfile(content))
            self._check_file_non_regression(content, reference)
        else:
            self.assertTrue(os.path.isdir(content))
            # Each generated file must match its reference...
            for root, dirs, files in os.walk(content):
                for f in files:
                    f_ref = os.path.join(reference, f)
                    self.assertTrue(os.path.isfile(f_ref),
                                    f"Additional file: {f}")
                    self._check_file_non_regression(
                        os.path.join(root, f), f_ref)
            # ... and each reference file must have been generated
            for root, dirs, files in os.walk(reference):
                for f in files:
                    f = os.path.join(content, f)
                    self.assertTrue(os.path.isfile(f), f"Missing file: {f}")

    def get_data_path(self, fn=None, check=True):
        """Return a path in the test data folder."""
        return self._get_path("tests/data", fn, check)

    def _get_path(self, prefix, fn=None, check=True):
        # Resolve prefix (and optional file name) relative to the repo root
        # (parent of the folder containing this file).
        path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            prefix
        )
        if fn:
            path = os.path.join(path, fn)
        if check:
            self.assertTrue(os.path.exists(path), f"Cannot find {path}")
        return path

    def _check_file_non_regression(self, file, reference):
        # JSON files are compared with tolerance (floats rounded,
        # language codes normalized); other files must match line by line.
        if file.endswith(".json"):
            with open(file) as f:
                content = json.load(f)
            with open(reference) as f:
                reference_content = json.load(f)
            if "language" in content and "language" in reference_content:
                content["language"] = self.norm_language(content["language"])
                reference_content["language"] = self.norm_language(reference_content["language"])
            self.assertClose(content, reference_content,
                             msg=f"File {file} does not match reference {reference}")
            return
        with open(file) as f:
            content = f.readlines()
        with open(reference) as f:
            reference_content = f.readlines()
        self.assertEqual(content, reference_content,
                         msg=f"File {file} does not match reference {reference}")

    def assertClose(self, obj1, obj2, msg=None):
        """Assert equality up to rounding of floats (1 decimal)."""
        return self.assertEqual(self.loose(obj1), self.loose(obj2), msg=msg)

    def loose(self, obj):
        # Return an approximative value of an object:
        # floats rounded to 1 decimal, -0.0 normalized to 0.0,
        # containers processed recursively.
        if isinstance(obj, list):
            return [self.loose(a) for a in obj]
        if isinstance(obj, float):
            f = round(obj, 1)
            return 0.0 if f == -0.0 else f
        if isinstance(obj, dict):
            return {k: self.loose(v) for k, v in obj.items()}
        if isinstance(obj, tuple):
            return tuple(self.loose(list(obj)))
        if isinstance(obj, set):
            # BUGFIX: loose() takes a single argument; the previous call
            # self.loose(list(obj), "set") raised TypeError on any set.
            return self.loose(list(obj))
        return obj

    def get_audio_duration(self, audio_file):
        # Get the duration in sec *without introducing additional dependencies*
        import whisper
        return len(whisper.load_audio(audio_file)) / whisper.audio.SAMPLE_RATE

    def get_device_str(self):
        import torch
        return "cpu" if not torch.cuda.is_available() else "cuda"

    def norm_language(self, language):
        # Cheap custom stuff to avoid importing everything
        return {
            "japanese": "ja",
        }.get(language.lower(), language)
class TestHelperCli(TestHelper):
    """Helper that runs the whisper_timestamped CLI and checks its outputs
    against reference files (or the captured stdout when extensions is None).
    """

    # JSON schema used to validate output words.json files (loaded lazily)
    json_schema = None

    def _test_cli_(self, opts, name, files=None, extensions=["words.json"], prefix=None, one_per_call=True, device_specific=None):
        """
        Test command line
        opts: list of options
        name: name of the test
        files: list of files to process
        extensions: list of extensions to check, or None to test the stdout
        prefix: prefix to add to the reference files
        one_per_call: if True, each file is processed separately, otherwise all files are processed by a single process
        device_specific: if not None, force whether the expected result is device (cpu/cuda) dependent
        """
        opts = opts + CMD_OPTIONS
        output_dir = self.get_output_path(name)
        input_filenames = self.get_data_files(files)

        for i, input_filename in enumerate(input_filenames):
            # Butterfly effect: Results are different depending on the device for long files
            duration = self.get_audio_duration(input_filename)
            if device_specific is None:
                device_dependent = duration > 60 or (duration > 30 and "tiny_fr" in name) or ("empty" in input_filename and "medium_auto" in name)
            else:
                device_dependent = device_specific
            name_ = name
            # Device-dependent references are suffixed with the device name
            # when not running on cuda
            if device_dependent and self.get_device_str() != "cuda":
                name_ += f".{self.get_device_str()}"

            def ref_name(output_filename):
                # Relative name of the reference file for a given output file
                return name_ + "/" + (f"{prefix}_" if prefix else "") + os.path.basename(output_filename)

            generic_name = ref_name(input_filename + ".*")
            if GENERATE_DEVICE_DEPENDENT and not device_dependent:
                print("Skipping non-regression test", generic_name)
                continue
            if GENERATE_NEW_ONLY and min([os.path.exists(self.get_expected_path(ref_name(output_filename)))
                                          for output_filename in self.get_generated_files(input_filename, output_dir, extensions=extensions)]
                                         ):
                print("Skipping non-regression test", generic_name)
                continue
            print("Running non-regression test", generic_name)

            if one_per_call or i == 0:
                if one_per_call:
                    (stdout, stderr) = self.assertRun([self.main_script(), input_filename, "--output_dir", output_dir, *opts])
                else:
                    # Single process for all the input files
                    (stdout, stderr) = self.assertRun([self.main_script(), *input_filenames, "--output_dir", output_dir, *opts])
                print(stdout)
                print(stderr)

            # Validate the generated words.json against the schema, if any
            output_json = self.get_generated_files(input_filename, output_dir, extensions=["words.json"]).__next__()
            if os.path.isfile(output_json):
                self.check_json(output_json)

            if extensions is None:
                # Compare the captured stdout against the reference
                output_filename = list(self.get_generated_files(input_filename, output_dir, extensions=["stdout"]))[0]
                self.assertNonRegression(stdout, ref_name(output_filename), string_is_file=False)
            else:
                for output_filename in self.get_generated_files(input_filename, output_dir, extensions=extensions):
                    self.assertNonRegression(output_filename, ref_name(output_filename))

        shutil.rmtree(output_dir, ignore_errors=True)

    def check_json(self, json_file):
        """Validate a generated JSON file against the expected JSON schema."""
        with open(json_file) as f:
            content = json.load(f)
        if self.json_schema is None:
            schema_file = os.path.join(os.path.dirname(__file__), "json_schema.json")
            self.assertTrue(os.path.isfile(schema_file), msg=f"Schema file {schema_file} not found")
            # BUGFIX: close the schema file after loading
            # (was json.load(open(schema_file)), leaking the handle)
            with open(schema_file) as f:
                self.json_schema = json.load(f)
        jsonschema.validate(instance=content, schema=self.json_schema)
class TestTranscribeTiny(TestHelperCli):
    """Non-regression tests of the CLI with the "tiny" model."""

    def test_cli_tiny_auto(self):
        # Automatic language detection
        self._test_cli_(["--model", "tiny"], "tiny_auto")

    def test_cli_tiny_fr(self):
        # Language forced to French
        self._test_cli_(["--model", "tiny", "--language", "fr"], "tiny_fr")
class TestTranscribeMedium(TestHelperCli):
    """Non-regression tests of the CLI with the "medium" model."""

    def test_cli_medium_auto(self):
        # Automatic language detection
        self._test_cli_(["--model", "medium"], "medium_auto")

    def test_cli_medium_fr(self):
        # Language forced to French
        self._test_cli_(["--model", "medium", "--language", "fr"], "medium_fr")
class TestTranscribeNaive(TestHelperCli):
    """Tests comparing the naive decoding against the accurate decoding."""

    def test_naive(self):
        audio = ["apollo11.mp3"]
        # Naive (efficient) decoding...
        self._test_cli_(
            ["--model", "small", "--language", "en", "--efficient", "--naive"],
            "naive", files=audio, prefix="naive",
        )
        # ... versus accurate decoding of the same file
        self._test_cli_(
            ["--model", "small", "--language", "en", "--accurate"],
            "naive", files=audio, prefix="accurate",
        )

    def test_stucked_segments(self):
        self._test_cli_(
            ["--model", "tiny"],
            "corner_cases", files=["apollo11.mp3"], prefix="accurate.tiny",
        )
class TestTranscribeCornerCases(TestHelperCli):
    """Non-regression tests on audio / option combinations that
    previously triggered issues."""

    def test_stucked_lm(self):
        # Long test: only run when a GPU is available
        if self.skipLongTests():
            return
        self._test_cli_(
            ["--model", "small", "--language", "en", "--efficient"],
            "corner_cases", files=["apollo11.mp3"], prefix="stucked_lm",
        )

    def test_punctuation_only(self):
        # When there is only a punctuation detected in a segment, it could cause issue #24
        self._test_cli_(
            ["--model", "medium.en", "--efficient", "--punctuations", "False"],
            "corner_cases", files=["empty.wav"], prefix="issue24",
        )

    def test_temperature(self):
        self._test_cli_(
            ["--model", "small", "--language", "English",
             "--condition", "False", "--temperature", "0.1", "--efficient"],
            "corner_cases", files=["apollo11.mp3"], prefix="random.nocond",
        )
        if self.skipLongTests():
            return
        self._test_cli_(
            ["--model", "small", "--language", "en", "--temperature", "0.2", "--efficient"],
            "corner_cases", files=["apollo11.mp3"], prefix="random",
        )

    def test_not_conditioned(self):
        # Skipped when the optional music file is not available locally
        if not os.path.exists(self.get_data_path("music.mp4", check=False)):
            return
        if self.skipLongTests():
            return
        self._test_cli_(
            ["--model", "medium", "--language", "en", "--condition", "False", "--efficient"],
            "corner_cases", files=["music.mp4"], prefix="nocond",
        )
        self._test_cli_(
            ["--model", "medium", "--language", "en",
             "--condition", "False", "--temperature", "0.4", "--efficient"],
            "corner_cases", files=["music.mp4"], prefix="nocond.random",
        )

    def test_large(self):
        if self.skipLongTests():
            return
        self._test_cli_(
            ["--model", "large-v2", "--language", "en",
             "--condition", "False", "--temperature", "0.4", "--efficient"],
            "corner_cases", files=["apollo11.mp3"], prefix="large",
        )
        # Arabic audio is optional test data
        if os.path.exists(self.get_data_path("arabic.mp3", check=False)):
            self._test_cli_(
                ["--model", "large-v2", "--language", "Arabic", "--efficient"],
                "corner_cases", files=["arabic.mp3"]
            )

    def test_gloria(self):
        # Cross product of models and decoding strategies on the same file
        for model in ["medium", "large-v2"]:
            for dec in ["efficient", "accurate"]:
                self._test_cli_(
                    ["--model", model, "--language", "en", "--" + dec],
                    "corner_cases", files=["gloria.mp3"], prefix=model + "." + dec,
                )
class TestTranscribeMonolingual(TestHelperCli):
    """Tests with English-only (monolingual) Whisper models."""

    def test_monolingual_tiny(self):
        audio = ["bonjour_vous_allez_bien.mp3"]
        # Same (French) audio, three decoding configurations
        for extra_opts, ref_prefix in (
            (["--efficient"], "efficient"),
            (["--accurate"], "accurate"),
            (["--condition", "False", "--efficient"], "nocond"),
        ):
            self._test_cli_(
                ["--model", "tiny.en"] + extra_opts,
                "tiny.en", files=audio, prefix=ref_prefix,
            )

    def test_monolingual_small(self):
        # Arabic audio with an English-only model (optional test data)
        if os.path.exists(self.get_data_path("arabic.mp3", check=False)):
            self._test_cli_(
                ["--model", "small.en", "--condition", "True", "--efficient"],
                "small.en", files=["arabic.mp3"], device_specific=True,
            )
class TestTranscribeWithVad(TestHelperCli):
    """Tests of the --vad (voice activity detection) options."""

    def _run_vad(self, vad_option, ref_prefix, accurate=True):
        # All VAD tests transcribe the same short file in verbose mode
        # and compare the standard output (extensions=None).
        opts = ["--model", "tiny"]
        if accurate:
            opts.append("--accurate")
        opts += ["--language", "en", "--vad", vad_option, "--verbose", "True"]
        self._test_cli_(
            opts,
            "verbose", files=["words.wav"], prefix=ref_prefix, extensions=None,
        )

    def test_vad_default(self):
        self._run_vad("True", "vad")

    def test_vad_custom_silero(self):
        self._run_vad("silero:v3.1", "vad_silero3.1")
        self._run_vad("silero:v3.0", "vad_silero3.0")

    def test_vad_custom_auditok(self):
        self._run_vad("auditok", "vad_auditok", accurate=False)
class TestTranscribeUnspacedLanguage(TestHelperCli):
    """Tests on a language written without spaces between words (Japanese)."""

    def test_japanese(self):
        audio = ["japanese.mp3"]
        # Decoding strategy x (detected vs forced) language
        for extra_opts, ref_prefix in (
            (["--efficient"], None),
            (["--language", "Japanese", "--efficient"], "jp"),
            (["--accurate"], "accurate"),
            (["--language", "Japanese", "--accurate"], "accurate_jp"),
        ):
            self._test_cli_(
                ["--model", "tiny"] + extra_opts,
                "tiny_auto", files=audio, prefix=ref_prefix, device_specific=True,
            )
class TestTranscribeFormats(TestHelperCli):
    """Tests of the output formats (txt, srt, vtt, csv, tsv ...) and of the verbose output."""

    def test_cli_outputs(self):
        files = ["punctuations.mp3", "bonjour.wav"]
        extensions = ["txt", "srt", "vtt", "words.srt", "words.vtt",
                      "words.json", "csv", "words.csv", "tsv", "words.tsv"]
        opts = ["--model", "medium", "--language", "fr"]
        # An audio / model combination that produces coma
        self._test_cli_(
            opts,
            "punctuations_yes", files=files, extensions=extensions, one_per_call=False,
        )
        self._test_cli_(
            opts + ["--punctuations", "False"],
            "punctuations_no", files=files, extensions=extensions, one_per_call=False,
        )

    def test_verbose(self):
        files = ["bonjour_vous_allez_bien.mp3"]
        opts = ["--model", "tiny", "--verbose", "True"]
        # Decoding strategy x (detected vs forced) language,
        # comparing the verbose standard output
        for extra_opts, ref_prefix in (
            (["--efficient"], "efficient.auto"),
            (["--language", "fr", "--efficient"], "efficient.fr"),
            ([], "accurate.auto"),
            (["--language", "fr"], "accurate.fr"),
        ):
            self._test_cli_(
                extra_opts + opts,
                "verbose", files=files, extensions=None,
                prefix=ref_prefix, device_specific=True,
            )
class TestMakeSubtitles(TestHelper):
    """Tests of the make_subtitles script that converts words.json files
    into srt/vtt subtitles with a maximum line length."""

    def test_make_subtitles(self):
        main_script = self.main_script("make_subtitles.py", "whisper_timestamped_make_subtitles")
        inputs = [
            self.get_data_path("smartphone.mp3.words.json"),
            self.get_data_path("no_punctuations.mp3.words.json", check=True),
            self.get_data_path("yes_punctuations.mp3.words.json", check=True),
        ]
        # Locals renamed: `input`, `len` and `format` shadowed builtins
        for i, input_file in enumerate(inputs):
            filename = os.path.basename(input_file).replace(".words.json", "")
            for max_length in 6, 20, 50:
                output_dir = self.get_output_path()
                # First call: output to a folder (the first input also
                # exercises the "process a whole folder" mode)
                self.assertRun([main_script,
                                input_file if i > 0 else self.get_data_path(), output_dir,
                                "--max_length", str(max_length),
                                ])
                for fmt in "vtt", "srt":
                    # BUGFIX(extraction): restore {filename} in the output
                    # file name (was garbled to a literal "(unknown)")
                    output_file = os.path.join(output_dir, f"{filename}.{fmt}")
                    self.assertTrue(os.path.isfile(output_file), msg=f"File {output_file} not found")
                    expected_file = f"split_subtitles/{filename.split('_')[-1]}_{max_length}.{fmt}"
                    self.assertNonRegression(output_file, expected_file)
                    os.remove(output_file)
                    # Second call: output to a single file
                    self.assertRun([main_script,
                                    input_file, output_file,
                                    "--max_length", str(max_length),
                                    ])
                    self.assertTrue(os.path.isfile(output_file), msg=f"File {output_file} not found")
                    self.assertNonRegression(output_file, expected_file)
class TestHuggingFaceModel(TestHelperCli):
    """Test loading a fine-tuned Whisper model from the Hugging Face hub."""

    def test_hugging_face_model(self):
        model_name = "qanastek/whisper-tiny-french-cased"
        self._test_cli_(
            ["--model", model_name, "--verbose", "True"],
            "verbose",
            files=["bonjour.wav"],
            extensions=None,
            prefix="hf",
            device_specific=True,
        )
# "ZZZ" to run this test at last (because it will fill the CUDA with some memory) | |
class TestZZZPythonImport(TestHelper):
    """Tests of the Python API (import, transcribe, token splitting).

    Named with "ZZZ" so it runs last (it fills the CUDA device with some
    memory).
    """

    def test_python_import(self):
        try:
            import whisper_timestamped
        except ModuleNotFoundError:
            # Fall back on the local source tree when the package is not installed
            sys.path.append(os.path.realpath(
                os.path.dirname(os.path.dirname(__file__))))
            import whisper_timestamped
        # Test version
        version = whisper_timestamped.__version__
        self.assertTrue(isinstance(version, str))
        (stdout, stderr) = self.assertRun([self.main_script(), "-v"])  # fixed: was `sterr`
        self.assertEqual(stdout.strip(), version)

        model = whisper_timestamped.load_model("tiny")
        # Check processing of different files
        for filename in "bonjour.wav", "laugh1.mp3", "laugh2.mp3":
            res = whisper_timestamped.transcribe(
                model, self.get_data_path(filename))
            if self._can_generate_reference():
                # BUGFIX(extraction): restore {filename} in the reference
                # name (the f-string had a literal "(unknown)")
                self.assertNonRegression(res, f"tiny_auto/{filename}.words.json")
        for filename in "bonjour.wav", "laugh1.mp3", "laugh2.mp3":
            res = whisper_timestamped.transcribe(
                model, self.get_data_path(filename), language="fr")
            if self._can_generate_reference():
                self.assertNonRegression(res, f"tiny_fr/{filename}.words.json")

    def _can_generate_reference(self):
        # On CPU, device-dependent references must not be (re)generated
        return not GENERATE_DEVICE_DEPENDENT or self.get_device_str() != "cpu"

    def test_split_tokens(self):
        import whisper
        whisperversion = whisper.__version__
        import whisper_timestamped as whisper
        from whisper_timestamped.transcribe import split_tokens_on_spaces
        tokenizer = whisper.tokenizer.get_tokenizer(True, language=None)
        # 220 means space
        tokens = [50364, 220, 6455, 11, 2232, 11, 286, 2041, 11, 2232, 11, 8660,
                  291, 808, 493, 220, 365, 11, 220, 445, 718, 505, 458, 13, 220, 50714]
        self.assertEqual(
            split_tokens_on_spaces(tokens, tokenizer),
            (['<|0.00|>', 'So,', 'uh,', 'I', 'guess,', 'uh,', 'wherever', 'you', 'come', 'up', 'with,', 'just', 'let', 'us', 'know.', '<|7.00|>'],
             [['<|0.00|>'],
              [' ', 'So', ','],
              [' uh', ','],
              [' I'],
              [' guess', ','],
              [' uh', ','],
              [' wherever'],
              [' you'],
              [' come'],
              [' up'],
              [' ', ' with', ','],
              [' ', ' just'],
              [' let'],
              [' us'],
              [' know', '.', ' '],
              ['<|7.00|>']],
             [[50364],
              [220, 6455, 11],
              [2232, 11],
              [286],
              [2041, 11],
              [2232, 11],
              [8660],
              [291],
              [808],
              [493],
              [220, 365, 11],
              [220, 445],
              [718],
              [505],
              [458, 13, 220],
              [50714]
              ])
        )
        tokens = [50366, 314, 6, 11771, 17134, 11, 4666, 11, 1022, 220, 875, 2557, 68, 11, 6992, 631, 269, 6, 377, 220, 409, 7282, 1956, 871, 566, 2707, 394, 1956, 256, 622, 8208, 631, 8208, 871, 517, 7282, 1956, 5977, 7418, 371, 1004, 306, 580, 11, 5977, 12, 9498, 9505, 84, 6, 50416]
        self.assertEqual(
            split_tokens_on_spaces(tokens, tokenizer),
            (
                ['<|0.04|>', "T'façon,", 'nous,', 'sur', 'la', 'touche,', 'parce', 'que', "c'est", 'un', 'sport', 'qui', 'est', 'important', 'qui', 'tue', 'deux', 'que', 'deux', 'est', 'un', 'sport', 'qui', 'peut', 'être', 'violent,', 'peut-être', "qu'", '<|1.04|>'],
                [['<|0.04|>'],
                 [' T', "'", 'fa', 'çon', ','],
                 [' nous', ','],
                 [' sur'],
                 [' ', 'la'],
                 [' touch', 'e', ','],
                 [' parce'],
                 [' que'],
                 [' c', "'", 'est'],
                 [' ', 'un'],
                 [' sport'],
                 [' qui'],
                 [' est'],
                 [' im', 'port', 'ant'],
                 [' qui'],
                 [' t', 'ue'],
                 [' deux'],
                 [' que'],
                 [' deux'],
                 [' est'],
                 [' un'],
                 [' sport'],
                 [' qui'],
                 [' peut'],
                 [' être'],
                 [' v', 'io', 'le', 'nt', ','],
                 [' peut', '-', 'être'],
                 [' q', 'u', "'"],
                 ['<|1.04|>']],
                [[50366],
                 [314, 6, 11771, 17134, 11],
                 [4666, 11],
                 [1022],
                 [220, 875],
                 [2557, 68, 11],
                 [6992],
                 [631],
                 [269, 6, 377],
                 [220, 409],
                 [7282],
                 [1956],
                 [871],
                 [566, 2707, 394],
                 [1956],
                 [256, 622],
                 [8208],
                 [631],
                 [8208],
                 [871],
                 [517],
                 [7282],
                 [1956],
                 [5977],
                 [7418],
                 [371, 1004, 306, 580, 11],
                 [5977, 12, 9498],
                 [9505, 84, 6],
                 [50416]]
            )
        )
        tokens = [50364, 220, 220, 6455, 11, 220, 220, 2232, 220, 220, 11, 220, 50714]
        self.assertEqual(
            split_tokens_on_spaces(tokens, tokenizer),
            (['<|0.00|>', 'So,', 'uh', ',', '<|7.00|>'],
             [['<|0.00|>'],
              [' ', ' ', 'So', ','],
              [' ', ' ', ' uh'],
              [' ', ' ', ',', ' '],
              ['<|7.00|>']],
             [[50364], [220, 220, 6455, 11], [220, 220, 2232], [220, 220, 11, 220], [50714]]
             )
        )
        # Careful with the double spaces at the end...
        tokens = [50364, 220, 220, 6455, 11, 220, 220, 2232, 220, 220, 11, 220, 220, 50714]
        self.assertEqual(
            split_tokens_on_spaces(tokens, tokenizer),
            (['<|0.00|>', 'So,', 'uh', ',', '', '<|7.00|>'],
             [['<|0.00|>'],
              [' ', ' ', 'So', ','],
              [' ', ' ', ' uh'],
              [' ', ' ', ','],
              [' ', ' '],
              ['<|7.00|>']],
             [[50364], [220, 220, 6455, 11], [220, 220, 2232], [220, 220, 11], [220, 220], [50714]]
             )
        )
        # Tokens that could be removed
        tokens = [50364, 6024, 95, 8848, 7649, 8717, 38251, 11703, 3224, 51864]
        self.assertEqual(
            split_tokens_on_spaces(tokens, tokenizer),
            (['<|0.00|>', 'الآذان', 'نسمّه', '<|30.00|>'],
             [['<|0.00|>'], ['', ' الآ', 'ذ', 'ان'], [' ن', 'سم', 'ّ', 'ه'], ['<|30.00|>']],
             [[50364], [6024, 95, 8848, 7649], [8717, 38251, 11703, 3224], [51864]]
             )
        )
        # issue #61
        # Special tokens that are not timestamps
        tokens = [50414, 805, 12, 17, 50299, 11, 568, 12, 18, 12, 21, 11, 502, 12, 17, 12, 51464]
        # 50299 is "<|te|>" and appears as ""
        te = ""
        self.assertEqual(
            split_tokens_on_spaces(tokens, tokenizer),
            (['<|1.00|>', f'3-2{te},', '2-3-6,', '1-2-', '<|22.00|>'],
             [['<|1.00|>'], [' 3', '-', '2', f'{te}', ','], [' 2', '-', '3', '-', '6', ','], [' 1', '-', '2', '-'], ['<|22.00|>']],
             [[50414], [805, 12, 17, 50299, 11], [568, 12, 18, 12, 21, 11], [502, 12, 17, 12], [51464]])
        )
        tokenizer = whisper.tokenizer.get_tokenizer(False, language="en")
        # Just a punctuation character
        tokens = [50363, 764, 51813]
        # The tokenizer changed its behavior on leading spaces in whisper 20230314
        _dot = "." if whisperversion < "20230314" else " ."
        self.assertEqual(
            split_tokens_on_spaces(tokens, tokenizer),
            (['<|0.00|>', ".", '<|29.00|>'],
             [['<|0.00|>'], [_dot], ['<|29.00|>']],
             [[50363], [764], [51813]]
             )
        )