leafspark commited on
Commit
d9e7e2f
·
verified ·
1 Parent(s): d372864

fix: use model name model.safetensors

Browse files
Files changed (1) hide show
  1. inference.py +1 -2
inference.py CHANGED
@@ -18,7 +18,6 @@ from litgpt.generate.base import (
18
  )
19
  import soundfile as sf
20
  from litgpt.model import GPT, Config
21
- from lightning.fabric.utilities.load import _lazy_load as lazy_load
22
  from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
23
  from utils.snac_utils import get_snac, generate_audio_data
24
  import whisper
@@ -359,7 +358,7 @@ def load_model(ckpt_dir, device):
359
  model = GPT(config)
360
 
361
  model = fabric.setup(model)
362
- state_dict = load_file(ckpt_dir + "/lit_model.safetensors")
363
  model.load_state_dict(state_dict, strict=True)
364
  model.to(device).eval()
365
 
 
18
  )
19
  import soundfile as sf
20
  from litgpt.model import GPT, Config
 
21
  from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
22
  from utils.snac_utils import get_snac, generate_audio_data
23
  import whisper
 
358
  model = GPT(config)
359
 
360
  model = fabric.setup(model)
361
+ state_dict = load_file(ckpt_dir + "/model.safetensors")
362
  model.load_state_dict(state_dict, strict=True)
363
  model.to(device).eval()
364