ntt123 committed on
Commit
ec02525
·
1 Parent(s): 6434756

Support generating long clips

Browse files

Generate short clips at paragraph level and then combine them.

Files changed (1) hide show
  1. app.py +35 -23
app.py CHANGED
@@ -134,7 +134,7 @@ def text_to_phone_idx(text):
134
  return tokens
135
 
136
 
137
- def text_to_speech(text):
138
  # prevent too long text
139
  if len(text) > 500:
140
  text = text[:500]
@@ -146,9 +146,6 @@ def text_to_speech(text):
146
  }
147
 
148
  # predict phoneme duration
149
- duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
150
- duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
151
- duration_net = duration_net.eval()
152
  phone_length = torch.from_numpy(batch["phone_length"].copy()).long().to(device)
153
  phone_idx = torch.from_numpy(batch["phone_idx"].copy()).long().to(device)
154
  with torch.inference_mode():
@@ -158,24 +155,7 @@ def text_to_speech(text):
158
  )
159
  phone_duration = torch.where(phone_idx == 0, 0, phone_duration)
160
 
161
- generator = SynthesizerTrn(
162
- hps.data.vocab_size,
163
- hps.data.filter_length // 2 + 1,
164
- hps.train.segment_size // hps.data.hop_length,
165
- **vars(hps.model),
166
- ).to(device)
167
- del generator.enc_q
168
- ckpt = torch.load(lightspeed_model_path, map_location=device)
169
- params = {}
170
- for k, v in ckpt["net_g"].items():
171
- k = k[7:] if k.startswith("module.") else k
172
- params[k] = v
173
- generator.load_state_dict(params, strict=False)
174
- del ckpt, params
175
- generator = generator.eval()
176
- # minimum 1 frame for each phone
177
- # phone_duration = torch.clamp_min(phone_duration, hps.data.hop_length * 1000 / hps.data.sampling_rate)
178
- # phone_duration = torch.where(phone_idx == 0, 0, phone_duration)
179
  end_time = torch.cumsum(phone_duration, dim=-1)
180
  start_time = end_time - phone_duration
181
  start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
@@ -194,8 +174,40 @@ def text_to_speech(text):
194
  return (wave * (2**15)).astype(np.int16)
195
 
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def speak(text):
198
- y = text_to_speech(text)
 
 
 
 
 
 
 
 
 
 
199
  return hps.data.sampling_rate, y
200
 
201
 
 
134
  return tokens
135
 
136
 
137
+ def text_to_speech(duration_net, generator, text):
138
  # prevent too long text
139
  if len(text) > 500:
140
  text = text[:500]
 
146
  }
147
 
148
  # predict phoneme duration
 
 
 
149
  phone_length = torch.from_numpy(batch["phone_length"].copy()).long().to(device)
150
  phone_idx = torch.from_numpy(batch["phone_idx"].copy()).long().to(device)
151
  with torch.inference_mode():
 
155
  )
156
  phone_duration = torch.where(phone_idx == 0, 0, phone_duration)
157
 
158
+ # generate waveform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  end_time = torch.cumsum(phone_duration, dim=-1)
160
  start_time = end_time - phone_duration
161
  start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
 
174
  return (wave * (2**15)).astype(np.int16)
175
 
176
 
177
def load_models():
    """Build the duration predictor and the waveform generator for inference.

    The checkpoints at ``duration_model_path`` and ``lightspeed_model_path``
    are read only on the first call. The resulting pair is stored on the
    function object and returned as-is on later calls, so repeated requests
    do not rebuild the networks or reload the weights from disk.

    Returns:
        A ``(duration_net, generator)`` tuple. Both modules have been moved
        to ``device`` and switched to eval mode.
    """
    cached = getattr(load_models, "_cache", None)
    if cached is not None:
        return cached

    duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
    duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
    duration_net = duration_net.eval()

    generator = SynthesizerTrn(
        hps.data.vocab_size,
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **vars(hps.model),
    ).to(device)
    # enc_q is dropped before the weights are loaded; strict=False below
    # tolerates checkpoint keys that no longer have a matching parameter.
    # NOTE(review): presumably enc_q is only needed for training -- confirm.
    del generator.enc_q

    ckpt = torch.load(lightspeed_model_path, map_location=device)
    # Strip a leading "module." from parameter names (7 characters);
    # presumably left by a DataParallel-style wrapper at training time.
    params = {
        (k[7:] if k.startswith("module.") else k): v
        for k, v in ckpt["net_g"].items()
    }
    generator.load_state_dict(params, strict=False)
    del ckpt, params  # release the raw checkpoint once it has been copied in
    generator = generator.eval()

    load_models._cache = (duration_net, generator)
    return load_models._cache
197
+
198
+
199
def speak(text):
    """Synthesize speech for ``text``, one clip per non-empty line.

    The text is split on newlines. Each line is stripped, blank lines are
    skipped, and every remaining paragraph is passed to ``text_to_speech``.
    The resulting clips are concatenated in order.

    Args:
        text: Input text; paragraphs are separated by ``"\\n"``.

    Returns:
        A ``(sampling_rate, samples)`` tuple, where ``sampling_rate`` is
        ``hps.data.sampling_rate`` and ``samples`` is the concatenated audio.
        If ``text`` has no non-blank line, ``samples`` is an empty int16
        array, matching the int16 clips ``text_to_speech`` returns.
    """
    paragraphs = [line.strip() for line in text.split("\n")]
    paragraphs = [paragraph for paragraph in paragraphs if paragraph]
    if not paragraphs:
        # np.concatenate raises ValueError on an empty list, so return an
        # empty clip instead and skip loading the models altogether.
        return hps.data.sampling_rate, np.zeros(0, dtype=np.int16)

    duration_net, generator = load_models()
    clips = [
        text_to_speech(duration_net, generator, paragraph)
        for paragraph in paragraphs
    ]
    y = np.concatenate(clips)
    return hps.data.sampling_rate, y
212
 
213