Spaces:
Runtime error
Runtime error
fix: use model name model.safetensors
Browse files- inference.py +1 -2
inference.py
CHANGED
@@ -18,7 +18,6 @@ from litgpt.generate.base import (
|
|
18 |
)
|
19 |
import soundfile as sf
|
20 |
from litgpt.model import GPT, Config
|
21 |
-
from lightning.fabric.utilities.load import _lazy_load as lazy_load
|
22 |
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
|
23 |
from utils.snac_utils import get_snac, generate_audio_data
|
24 |
import whisper
|
@@ -359,7 +358,7 @@ def load_model(ckpt_dir, device):
|
|
359 |
model = GPT(config)
|
360 |
|
361 |
model = fabric.setup(model)
|
362 |
-
state_dict = load_file(ckpt_dir + "/
|
363 |
model.load_state_dict(state_dict, strict=True)
|
364 |
model.to(device).eval()
|
365 |
|
|
|
18 |
)
|
19 |
import soundfile as sf
|
20 |
from litgpt.model import GPT, Config
|
|
|
21 |
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
|
22 |
from utils.snac_utils import get_snac, generate_audio_data
|
23 |
import whisper
|
|
|
358 |
model = GPT(config)
|
359 |
|
360 |
model = fabric.setup(model)
|
361 |
+
state_dict = load_file(ckpt_dir + "/model.safetensors")
|
362 |
model.load_state_dict(state_dict, strict=True)
|
363 |
model.to(device).eval()
|
364 |
|