litagin commited on
Commit
bf9a094
·
1 Parent(s): 93ba194
Files changed (2) hide show
  1. app.py +1 -1
  2. style_gen.py +66 -0
app.py CHANGED
@@ -469,7 +469,7 @@ if __name__ == "__main__":
469
  style_weight = gr.Slider(
470
  minimum=0,
471
  maximum=50,
472
- value=1,
473
  step=0.1,
474
  label="スタイルの強さ",
475
  )
 
469
  style_weight = gr.Slider(
470
  minimum=0,
471
  maximum=50,
472
+ value=5,
473
  step=0.1,
474
  label="スタイルの強さ",
475
  )
style_gen.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import concurrent.futures
3
+ import sys
4
+ import warnings
5
+
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ import utils
11
+ from config import config
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+ from pyannote.audio import Inference, Model
15
+
16
+ model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
17
+ inference = Inference(model, window="whole")
18
+ device = torch.device(config.style_gen_config.device)
19
+ inference.to(device)
20
+
21
+
22
+ def extract_style_vector(wav_path):
23
+ return inference(wav_path)
24
+
25
+
26
+ def save_style_vector(wav_path):
27
+ style_vec = extract_style_vector(wav_path)
28
+ # `test.wav` -> `test.wav.npy`
29
+ np.save(f"{wav_path}.npy", style_vec)
30
+
31
+
32
+ if __name__ == "__main__":
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument(
35
+ "-c", "--config", type=str, default=config.style_gen_config.config_path
36
+ )
37
+ parser.add_argument(
38
+ "--num_processes", type=int, default=config.style_gen_config.num_processes
39
+ )
40
+ args, _ = parser.parse_known_args()
41
+ config_path = args.config
42
+ num_processes = args.num_processes
43
+
44
+ hps = utils.get_hparams_from_file(config_path)
45
+
46
+ device = config.style_gen_config.device
47
+
48
+ lines = []
49
+ with open(hps.data.training_files, encoding="utf-8") as f:
50
+ lines.extend(f.readlines())
51
+
52
+ with open(hps.data.validation_files, encoding="utf-8") as f:
53
+ lines.extend(f.readlines())
54
+
55
+ wavnames = [line.split("|")[0] for line in lines]
56
+
57
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_processes) as executor:
58
+ list(
59
+ tqdm(
60
+ executor.map(save_style_vector, wavnames),
61
+ total=len(wavnames),
62
+ file=sys.stdout,
63
+ )
64
+ )
65
+
66
+ print(f"Finished generating style vectors! total: {len(wavnames)} npy files.")