saeki committed on
Commit
ed57756
·
1 Parent(s): 7683907

initial commit

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import yaml
3
+ import torch
4
+ import torchaudio
5
+ import numpy as np
6
+ from lightning_module import SSLDualLightningModule
7
+ import gradio as gr
8
+
9
def normalize_waveform(wav, sr, db=-3):
    """Normalize a waveform with the SoX ``norm`` effect.

    Args:
        wav: Waveform tensor. A leading dimension is added before the effect
            chain runs and removed again from the result.
        sr: Sampling rate handed to the SoX effect chain.
        db: Level argument passed to the ``norm`` effect (default ``-3``).

    Returns:
        The processed waveform with its leading dimension squeezed out.
    """
    batched = wav.unsqueeze(0)
    effects = [["norm", f"{db}"]]
    # The second return value (the sampling rate) is not needed here.
    normalized, _ = torchaudio.sox_effects.apply_effects_tensor(
        batched, sr, effects
    )
    return normalized.squeeze(0)
16
+
17
def calc_spectrogram(wav, config):
    """Compute a log-compressed mel spectrogram of ``wav``.

    Args:
        wav: Waveform tensor fed to ``torchaudio.transforms.MelSpectrogram``.
        config: Mapping whose ``"preprocess"`` entry supplies
            ``sampling_rate``, ``fft_length``, ``frame_length``,
            ``frame_shift``, ``fmin``, ``fmax``, ``n_mels``,
            ``min_magnitude`` and ``comp_factor``.

    Returns:
        A ``float32`` tensor equal to
        ``log(clamp_min(mel, min_magnitude) * comp_factor)``.
    """
    prep = config["preprocess"]
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=prep["sampling_rate"],
        n_fft=prep["fft_length"],
        win_length=prep["frame_length"],
        hop_length=prep["frame_shift"],
        f_min=prep["fmin"],
        f_max=prep["fmax"],
        n_mels=prep["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    mel = mel_transform(wav)
    # Floor the values before taking the logarithm so log() never sees zero.
    floored = torch.clamp_min(mel, prep["min_magnitude"])
    compressed = floored * prep["comp_factor"]
    return torch.log(compressed).to(torch.float32)
37
+
38
def transfer(audio):
    """Apply the channel features of the bundled source recording to ``audio``.

    The source recording (``aet_sample/src.wav``) is encoded into channel
    features by the pretrained model, and those features are then applied to
    the user-supplied target audio.

    Args:
        audio: ``(sample_rate, samples)`` pair as delivered by the Gradio
            audio input, where ``samples`` is a NumPy array.

    Returns:
        ``(sample_rate, waveform)`` where ``sample_rate`` is the rate of the
        source recording and ``waveform`` is a NumPy array.

    Raises:
        ValueError: If no audio, or an empty clip, was supplied.
    """
    if audio is None:
        raise ValueError("No input audio was provided.")

    wp_src = pathlib.Path("aet_sample/src.wav")
    wav_src, sr = torchaudio.load(wp_src)

    sr_inp, wav_tar = audio
    # Convert to float *before* taking abs(): on integer PCM the most
    # negative sample overflows under abs() and corrupts the peak estimate.
    wav_tar = np.asarray(wav_tar, dtype=np.float64)
    if wav_tar.size == 0:
        raise ValueError("Input audio is empty.")
    if wav_tar.ndim == 2:
        # Multi-channel input is assumed to be (samples, channels), as in
        # Gradio's numpy audio format; downmix so that resampling below
        # operates along the time axis.
        wav_tar = wav_tar.mean(axis=1)
    peak = np.max(np.abs(wav_tar))
    if peak > 0:
        # Skip scaling for an all-zero clip; dividing by zero would fill the
        # waveform with NaN.
        wav_tar = wav_tar / (peak * 1.1)
    wav_tar = torch.from_numpy(wav_tar.astype(np.float32))

    resampler = torchaudio.transforms.Resample(
        orig_freq=sr_inp,
        new_freq=sr,
    )
    wav_tar = resampler(wav_tar)

    config_path = pathlib.Path("configs/test/melspec/ssl_tono.yaml")
    # Use a context manager so the config file handle is closed promptly.
    with open(config_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    melspec_src = calc_spectrogram(
        normalize_waveform(wav_src.squeeze(0), sr), config
    )
    wav_tar = normalize_waveform(wav_tar.squeeze(0), sr)

    ckpt_path = pathlib.Path("aet_sample/tono_melspec_aet.ckpt")
    # load_from_checkpoint is a classmethod: call it on the class instead of
    # building a throwaway instance first (newer Lightning releases reject
    # calling it on an instance).
    src_model = SSLDualLightningModule.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        config=config,
    )
    # Inference only: use evaluation behaviour for any dropout/batch-norm
    # layers and do not build an autograd graph.
    src_model.eval()

    encoder_src = src_model.encoder
    channelfeats_src = src_model.channelfeats
    channel_src = src_model.channel

    with torch.no_grad():
        _, enc_hidden_src = encoder_src(
            melspec_src.unsqueeze(0).unsqueeze(1).transpose(2, 3)
        )
        chfeats_src = channelfeats_src(enc_hidden_src)
        wav_transfer = channel_src(wav_tar.unsqueeze(0), chfeats_src)

    # Drop the leading batch dimension added by unsqueeze(0) above.
    wav_transfer = wav_transfer.detach().cpu().numpy()[0, :]
    return sr, wav_transfer
73
+
74
if __name__ == "__main__":
    # Build the Gradio demo: a single audio input routed through ``transfer``,
    # a numpy-typed audio output, and one bundled example clip.
    example_inputs = [["aet_sample/tar.wav"]]
    audio_output = gr.outputs.Audio(type="numpy")
    demo_title = "Audio effect transfer demo"
    demo_description = (
        "Add channel feature of Japanese old audio recording "
        "to any high-quality audio"
    )

    iface = gr.Interface(
        transfer,
        "audio",
        audio_output,
        examples=example_inputs,
        title=demo_title,
        description=demo_description,
    )

    # Start the web server for the demo.
    iface.launch()