VALL-E-X / test.py
ITS-C4SF733\Administrator
all resource
324bf29
import os
import torch
import logging
from data.dataset import create_dataloader
from macros import *
from data.tokenizer import (
AudioTokenizer,
tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
# Module-level default device: the first CUDA GPU when one is available,
# otherwise the CPU. Binding `device` on both branches avoids a NameError on
# CPU-only machines for any later module-level use of this name.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
from vocos import Vocos
from pathlib import Path
import platform
import pathlib
# Report the host OS; `plt` keeps its original module-level name in case
# other code in this file reads it.
plt = platform.system()
print("Operating System:", plt)
# On non-Windows hosts, alias WindowsPath to PosixPath. Presumably this lets
# torch.load unpickle checkpoints that contain WindowsPath objects (saved on
# Windows), which cannot be instantiated on POSIX systems — TODO confirm.
# This was previously applied only on Linux; macOS and other POSIX platforms
# have the same limitation, so the check now covers every non-Windows system.
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
def get_model(device):
    """Load the VALL-E X model with checkpoint weights, plus the codec and vocoder.

    The checkpoint is read from ``./checkpoints/vallex-checkpoint_modified.pt``.
    If that file does not exist, it is downloaded with ``wget`` from ``url``
    and saved under exactly that path, so the subsequent load finds it.

    Args:
        device: torch device that the model, the ``AudioTokenizer`` and the
            Vocos vocoder are moved to.

    Returns:
        Tuple ``(model, codec, vocos)``: the VALLE model with the checkpoint's
        ``"model"`` state dict loaded, an ``AudioTokenizer`` built for
        ``device``, and the ``charactr/vocos-encodec-24khz`` Vocos model.

    Raises:
        Exception: if the checkpoint download fails (original error chained).
        RuntimeError: from ``load_state_dict`` when the checkpoint keys do not
            match the model (``strict=True``).
    """
    url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'
    checkpoints_dir = "./checkpoints"
    model_checkpoint_name = "vallex-checkpoint_modified.pt"
    checkpoint_path = Path(checkpoints_dir) / model_checkpoint_name

    os.makedirs(checkpoints_dir, exist_ok=True)
    if not checkpoint_path.exists():
        import wget  # only needed when the checkpoint has to be fetched
        try:
            logging.info("Downloading model from %s ...", url)
            # Write to the same path that is checked above and loaded below;
            # saving to vallex-checkpoint.pt made the load fail right after
            # a successful download.
            # NOTE(review): this stores the upstream checkpoint under the
            # "_modified" name — confirm that is acceptable for this fork.
            wget.download(url, out=str(checkpoint_path), bar=wget.bar_adaptive)
        except Exception as e:
            logging.error(e)
            raise Exception(
                "\n Model weights download failed, please go to '{}'"
                "\n manually download model weights and put it to {} .".format(
                    url, checkpoint_path.resolve()
                )
            ) from e

    # VALL-E: built by the shared constructor so hyper-parameters stay in one place.
    model = get_valle_model(device)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # strict=True already raises on missing or unexpected keys, so no separate
    # assert on the returned key lists is needed (asserts vanish under -O).
    model.load_state_dict(checkpoint["model"], strict=True)

    # Encodec tokenizer and Vocos vocoder.
    codec = AudioTokenizer(device)
    vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
    return model, codec, vocos
def get_valle_model(device):
    """Construct a VALL-E X model with the project's fixed hyper-parameters.

    No checkpoint is loaded here; the freshly constructed module is moved to
    ``device`` and returned.

    Args:
        device: torch device the model is moved to.

    Returns:
        The ``VALLE`` instance on ``device``.
    """
    # Keyword configuration shared by every VALL-E X instance in this file.
    valle_options = dict(
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
    )
    network = VALLE(N_DIM, NUM_HEAD, NUM_LAYERS, **valle_options)
    return network.to(device)