mrfakename
commited on
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- finetune_gradio.py +150 -7
finetune_gradio.py
CHANGED
@@ -1,9 +1,12 @@
|
|
1 |
import os
|
2 |
import sys
|
3 |
|
|
|
|
|
4 |
from transformers import pipeline
|
5 |
import gradio as gr
|
6 |
import torch
|
|
|
7 |
import click
|
8 |
import torchaudio
|
9 |
from glob import glob
|
@@ -20,11 +23,16 @@ import psutil
|
|
20 |
import platform
|
21 |
import subprocess
|
22 |
from datasets.arrow_writer import ArrowWriter
|
|
|
|
|
23 |
|
24 |
|
25 |
training_process = None
|
26 |
system = platform.system()
|
27 |
python_executable = sys.executable or "python"
|
|
|
|
|
|
|
28 |
|
29 |
path_data = "data"
|
30 |
|
@@ -240,7 +248,12 @@ def start_training(
|
|
240 |
last_per_steps=800,
|
241 |
finetune=True,
|
242 |
):
|
243 |
-
global training_process
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
path_project = os.path.join(path_data, dataset_name + "_pinyin")
|
246 |
|
@@ -288,7 +301,7 @@ def start_training(
|
|
288 |
training_process = subprocess.Popen(cmd, shell=True)
|
289 |
|
290 |
time.sleep(5)
|
291 |
-
yield "
|
292 |
|
293 |
# Wait for the training process to finish
|
294 |
training_process.wait()
|
@@ -519,6 +532,17 @@ def calculate_train(
|
|
519 |
path_project = os.path.join(path_data, name_project)
|
520 |
file_duraction = os.path.join(path_project, "duration.json")
|
521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
with open(file_duraction, "r") as file:
|
523 |
data = json.load(file)
|
524 |
|
@@ -549,8 +573,8 @@ def calculate_train(
|
|
549 |
else:
|
550 |
max_samples = 64
|
551 |
|
552 |
-
num_warmup_updates = int(samples * 0.
|
553 |
-
save_per_updates = int(samples * 0.
|
554 |
last_per_steps = int(save_per_updates * 5)
|
555 |
|
556 |
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
|
@@ -559,7 +583,7 @@ def calculate_train(
|
|
559 |
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
|
560 |
|
561 |
if finetune:
|
562 |
-
learning_rate = 1e-
|
563 |
else:
|
564 |
learning_rate = 7.5e-5
|
565 |
|
@@ -611,6 +635,7 @@ def vocab_check(project_name):
|
|
611 |
sp = item.split("|")
|
612 |
if len(sp) != 2:
|
613 |
continue
|
|
|
614 |
text = sp[1].lower().strip()
|
615 |
|
616 |
for t in text:
|
@@ -625,6 +650,80 @@ def vocab_check(project_name):
|
|
625 |
return info
|
626 |
|
627 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
628 |
with gr.Blocks() as app:
|
629 |
with gr.Row():
|
630 |
project_name = gr.Textbox(label="project name", value="my_speak")
|
@@ -661,6 +760,18 @@ with gr.Blocks() as app:
|
|
661 |
)
|
662 |
ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
|
663 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
664 |
with gr.TabItem("prepare Data"):
|
665 |
gr.Markdown(
|
666 |
"""```plaintext
|
@@ -687,6 +798,16 @@ with gr.Blocks() as app:
|
|
687 |
txt_info_prepare = gr.Text(label="info", value="")
|
688 |
bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
|
689 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
690 |
with gr.TabItem("train Data"):
|
691 |
with gr.Row():
|
692 |
bt_calculate = bt_create = gr.Button("Auto Settings")
|
@@ -696,11 +817,11 @@ with gr.Blocks() as app:
|
|
696 |
|
697 |
with gr.Row():
|
698 |
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
|
699 |
-
learning_rate = gr.Number(label="Learning Rate", value=1e-
|
700 |
|
701 |
with gr.Row():
|
702 |
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
|
703 |
-
max_samples = gr.Number(label="Max Samples", value=
|
704 |
|
705 |
with gr.Row():
|
706 |
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
|
@@ -778,6 +899,28 @@ with gr.Blocks() as app:
|
|
778 |
txt_info_check = gr.Text(label="info", value="")
|
779 |
check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
|
780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
781 |
|
782 |
@click.command()
|
783 |
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
|
4 |
+
import tempfile
|
5 |
+
import random
|
6 |
from transformers import pipeline
|
7 |
import gradio as gr
|
8 |
import torch
|
9 |
+
import gc
|
10 |
import click
|
11 |
import torchaudio
|
12 |
from glob import glob
|
|
|
23 |
import platform
|
24 |
import subprocess
|
25 |
from datasets.arrow_writer import ArrowWriter
|
26 |
+
from datasets import Dataset as Dataset_
|
27 |
+
from api import F5TTS
|
28 |
|
29 |
|
30 |
training_process = None
|
31 |
system = platform.system()
|
32 |
python_executable = sys.executable or "python"
|
33 |
+
tts_api = None
|
34 |
+
last_checkpoint = ""
|
35 |
+
last_device = ""
|
36 |
|
37 |
path_data = "data"
|
38 |
|
|
|
248 |
last_per_steps=800,
|
249 |
finetune=True,
|
250 |
):
|
251 |
+
global training_process, tts_api
|
252 |
+
|
253 |
+
if tts_api is not None:
|
254 |
+
del tts_api
|
255 |
+
gc.collect()
|
256 |
+
torch.cuda.empty_cache()
|
257 |
|
258 |
path_project = os.path.join(path_data, dataset_name + "_pinyin")
|
259 |
|
|
|
301 |
training_process = subprocess.Popen(cmd, shell=True)
|
302 |
|
303 |
time.sleep(5)
|
304 |
+
yield "train start", gr.update(interactive=False), gr.update(interactive=True)
|
305 |
|
306 |
# Wait for the training process to finish
|
307 |
training_process.wait()
|
|
|
532 |
path_project = os.path.join(path_data, name_project)
|
533 |
file_duraction = os.path.join(path_project, "duration.json")
|
534 |
|
535 |
+
if not os.path.isfile(file_duraction):
|
536 |
+
return (
|
537 |
+
1000,
|
538 |
+
max_samples,
|
539 |
+
num_warmup_updates,
|
540 |
+
save_per_updates,
|
541 |
+
last_per_steps,
|
542 |
+
"project not found !",
|
543 |
+
learning_rate,
|
544 |
+
)
|
545 |
+
|
546 |
with open(file_duraction, "r") as file:
|
547 |
data = json.load(file)
|
548 |
|
|
|
573 |
else:
|
574 |
max_samples = 64
|
575 |
|
576 |
+
num_warmup_updates = int(samples * 0.05)
|
577 |
+
save_per_updates = int(samples * 0.10)
|
578 |
last_per_steps = int(save_per_updates * 5)
|
579 |
|
580 |
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
|
|
|
583 |
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
|
584 |
|
585 |
if finetune:
|
586 |
+
learning_rate = 1e-5
|
587 |
else:
|
588 |
learning_rate = 7.5e-5
|
589 |
|
|
|
635 |
sp = item.split("|")
|
636 |
if len(sp) != 2:
|
637 |
continue
|
638 |
+
|
639 |
text = sp[1].lower().strip()
|
640 |
|
641 |
for t in text:
|
|
|
650 |
return info
|
651 |
|
652 |
|
653 |
+
def get_random_sample_prepare(project_name):
|
654 |
+
name_project = project_name + "_pinyin"
|
655 |
+
path_project = os.path.join(path_data, name_project)
|
656 |
+
file_arrow = os.path.join(path_project, "raw.arrow")
|
657 |
+
if not os.path.isfile(file_arrow):
|
658 |
+
return "", None
|
659 |
+
dataset = Dataset_.from_file(file_arrow)
|
660 |
+
random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
|
661 |
+
text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
|
662 |
+
audio_path = random_sample["audio_path"][0]
|
663 |
+
return text, audio_path
|
664 |
+
|
665 |
+
|
666 |
+
def get_random_sample_transcribe(project_name):
|
667 |
+
name_project = project_name + "_pinyin"
|
668 |
+
path_project = os.path.join(path_data, name_project)
|
669 |
+
file_metadata = os.path.join(path_project, "metadata.csv")
|
670 |
+
if not os.path.isfile(file_metadata):
|
671 |
+
return "", None
|
672 |
+
|
673 |
+
data = ""
|
674 |
+
with open(file_metadata, "r", encoding="utf-8") as f:
|
675 |
+
data = f.read()
|
676 |
+
|
677 |
+
list_data = []
|
678 |
+
for item in data.split("\n"):
|
679 |
+
sp = item.split("|")
|
680 |
+
if len(sp) != 2:
|
681 |
+
continue
|
682 |
+
list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
|
683 |
+
|
684 |
+
if list_data == []:
|
685 |
+
return "", None
|
686 |
+
|
687 |
+
random_item = random.choice(list_data)
|
688 |
+
|
689 |
+
return random_item[1], random_item[0]
|
690 |
+
|
691 |
+
|
692 |
+
def get_random_sample_infer(project_name):
|
693 |
+
text, audio = get_random_sample_transcribe(project_name)
|
694 |
+
return (
|
695 |
+
text,
|
696 |
+
text,
|
697 |
+
audio,
|
698 |
+
)
|
699 |
+
|
700 |
+
|
701 |
+
def infer(project_name, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
|
702 |
+
global last_checkpoint, last_device, tts_api
|
703 |
+
|
704 |
+
if not os.path.isfile(file_checkpoint):
|
705 |
+
return None
|
706 |
+
|
707 |
+
if training_process is not None:
|
708 |
+
device_test = "cpu"
|
709 |
+
else:
|
710 |
+
device_test = None
|
711 |
+
|
712 |
+
if last_checkpoint != file_checkpoint or last_device != device_test:
|
713 |
+
if last_checkpoint != file_checkpoint:
|
714 |
+
last_checkpoint = file_checkpoint
|
715 |
+
if last_device != device_test:
|
716 |
+
last_device = device_test
|
717 |
+
|
718 |
+
tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
|
719 |
+
|
720 |
+
print("update", device_test, file_checkpoint)
|
721 |
+
|
722 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
723 |
+
tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
|
724 |
+
return f.name
|
725 |
+
|
726 |
+
|
727 |
with gr.Blocks() as app:
|
728 |
with gr.Row():
|
729 |
project_name = gr.Textbox(label="project name", value="my_speak")
|
|
|
760 |
)
|
761 |
ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
|
762 |
|
763 |
+
random_sample_transcribe = gr.Button("random sample")
|
764 |
+
|
765 |
+
with gr.Row():
|
766 |
+
random_text_transcribe = gr.Text(label="Text")
|
767 |
+
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
|
768 |
+
|
769 |
+
random_sample_transcribe.click(
|
770 |
+
fn=get_random_sample_transcribe,
|
771 |
+
inputs=[project_name],
|
772 |
+
outputs=[random_text_transcribe, random_audio_transcribe],
|
773 |
+
)
|
774 |
+
|
775 |
with gr.TabItem("prepare Data"):
|
776 |
gr.Markdown(
|
777 |
"""```plaintext
|
|
|
798 |
txt_info_prepare = gr.Text(label="info", value="")
|
799 |
bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
|
800 |
|
801 |
+
random_sample_prepare = gr.Button("random sample")
|
802 |
+
|
803 |
+
with gr.Row():
|
804 |
+
random_text_prepare = gr.Text(label="Pinyin")
|
805 |
+
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
|
806 |
+
|
807 |
+
random_sample_prepare.click(
|
808 |
+
fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
|
809 |
+
)
|
810 |
+
|
811 |
with gr.TabItem("train Data"):
|
812 |
with gr.Row():
|
813 |
bt_calculate = bt_create = gr.Button("Auto Settings")
|
|
|
817 |
|
818 |
with gr.Row():
|
819 |
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
|
820 |
+
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
|
821 |
|
822 |
with gr.Row():
|
823 |
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
|
824 |
+
max_samples = gr.Number(label="Max Samples", value=64)
|
825 |
|
826 |
with gr.Row():
|
827 |
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
|
|
|
899 |
txt_info_check = gr.Text(label="info", value="")
|
900 |
check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
|
901 |
|
902 |
+
with gr.TabItem("test model"):
|
903 |
+
exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
|
904 |
+
nfe_step = gr.Number(label="n_step", value=32)
|
905 |
+
file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
|
906 |
+
|
907 |
+
random_sample_infer = gr.Button("random sample")
|
908 |
+
|
909 |
+
ref_text = gr.Textbox(label="ref text")
|
910 |
+
ref_audio = gr.Audio(label="audio ref", type="filepath")
|
911 |
+
gen_text = gr.Textbox(label="gen text")
|
912 |
+
random_sample_infer.click(
|
913 |
+
fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
|
914 |
+
)
|
915 |
+
check_button_infer = gr.Button("infer")
|
916 |
+
gen_audio = gr.Audio(label="audio gen", type="filepath")
|
917 |
+
|
918 |
+
check_button_infer.click(
|
919 |
+
fn=infer,
|
920 |
+
inputs=[project_name, file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
|
921 |
+
outputs=[gen_audio],
|
922 |
+
)
|
923 |
+
|
924 |
|
925 |
@click.command()
|
926 |
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|