L0SG commited on
Commit
8097450
·
1 Parent(s): 12a1102
Files changed (2) hide show
  1. README.md +2 -2
  2. bigvgan.py +34 -20
README.md CHANGED
@@ -59,10 +59,10 @@ from meldataset import get_mel_spectrogram
59
 
60
  # load wav file and compute mel spectrogram
61
  wav, sr = librosa.load('/path/to/your/audio.wav', sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
62
- wav = torch.FloatTensor(wav).to(device).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
63
 
64
  # compute mel spectrogram from the ground truth audio
65
- mel = get_mel_spectrogram(wav, model.h) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
66
 
67
  # generate waveform from mel
68
  with torch.inference_mode():
 
59
 
60
  # load wav file and compute mel spectrogram
61
  wav, sr = librosa.load('/path/to/your/audio.wav', sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
62
+ wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
63
 
64
  # compute mel spectrogram from the ground truth audio
65
+ mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
66
 
67
  # generate waveform from mel
68
  with torch.inference_mode():
bigvgan.py CHANGED
@@ -257,14 +257,18 @@ class BigVGAN(
257
  return x
258
 
259
  def remove_weight_norm(self):
260
- print('Removing weight norm...')
261
- for l in self.ups:
262
- for l_i in l:
263
- remove_weight_norm(l_i)
264
- for l in self.resblocks:
265
- l.remove_weight_norm()
266
- remove_weight_norm(self.conv_pre)
267
- remove_weight_norm(self.conv_post)
 
 
 
 
268
 
269
  ##################################################################
270
  # additional methods for huggingface_hub support
@@ -304,17 +308,21 @@ class BigVGAN(
304
  ##################################################################
305
  # download and load hyperparameters (h) used by BigVGAN
306
  ##################################################################
307
- config_file = hf_hub_download(
308
- repo_id=model_id,
309
- filename='config.json',
310
- revision=revision,
311
- cache_dir=cache_dir,
312
- force_download=force_download,
313
- proxies=proxies,
314
- resume_download=resume_download,
315
- token=token,
316
- local_files_only=local_files_only,
317
- )
 
 
 
 
318
  h = load_hparams_from_json(config_file)
319
 
320
  ##################################################################
@@ -347,6 +355,12 @@ class BigVGAN(
347
  )
348
 
349
  checkpoint_dict = torch.load(model_file, map_location=map_location)
350
- model.load_state_dict(checkpoint_dict['generator'])
 
 
 
 
 
 
351
 
352
  return model
 
257
  return x
258
 
259
  def remove_weight_norm(self):
260
+ try:
261
+ print('Removing weight norm...')
262
+ for l in self.ups:
263
+ for l_i in l:
264
+ remove_weight_norm(l_i)
265
+ for l in self.resblocks:
266
+ l.remove_weight_norm()
267
+ remove_weight_norm(self.conv_pre)
268
+ remove_weight_norm(self.conv_post)
269
+ except ValueError:
270
+ print('[INFO] Model already removed weight norm. Skipping!')
271
+ pass
272
 
273
  ##################################################################
274
  # additional methods for huggingface_hub support
 
308
  ##################################################################
309
  # download and load hyperparameters (h) used by BigVGAN
310
  ##################################################################
311
+ if os.path.isdir(model_id):
312
+ print("Loading config.json from local directory")
313
+ config_file = os.path.join(model_id, 'config.json')
314
+ else:
315
+ config_file = hf_hub_download(
316
+ repo_id=model_id,
317
+ filename='config.json',
318
+ revision=revision,
319
+ cache_dir=cache_dir,
320
+ force_download=force_download,
321
+ proxies=proxies,
322
+ resume_download=resume_download,
323
+ token=token,
324
+ local_files_only=local_files_only,
325
+ )
326
  h = load_hparams_from_json(config_file)
327
 
328
  ##################################################################
 
355
  )
356
 
357
  checkpoint_dict = torch.load(model_file, map_location=map_location)
358
+
359
+ try:
360
+ model.load_state_dict(checkpoint_dict['generator'])
361
+ except RuntimeError:
362
+ print(f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!")
363
+ model.remove_weight_norm()
364
+ model.load_state_dict(checkpoint_dict['generator'])
365
 
366
  return model