# StableTTS1.1 / utils / load.py
# KdaiP's picture
# Upload 80 files
# 3dd84f8 verified
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
def continue_training(checkpoint_path, model: DDP, optimizer: optim.Optimizer) -> int:
    """Restore the most recent training state found in ``checkpoint_path``.

    The directory is scanned for ``.pt`` files whose name contains an
    underscore. The text after the last underscore (up to the first ``.``)
    is parsed as the epoch number; files whose prefix starts with
    ``checkpoint`` are treated as model weights and files whose prefix starts
    with ``optimizer`` as optimizer state. Files whose epoch part is not an
    integer (for example ``checkpoint_best.pt``) are ignored.

    If at least one epoch has both a model file and an optimizer file, the
    largest such epoch is loaded into both ``model`` and ``optimizer``.
    Otherwise, if any model file exists, the one with the largest epoch is
    loaded into ``model`` only and ``optimizer`` is left untouched.

    Args:
        checkpoint_path: Directory containing the saved ``.pt`` files.
        model: Model to load weights into. When it exposes a ``module``
            attribute (as a DDP wrapper does) the weights are loaded into
            ``model.module``; otherwise they are loaded into ``model`` itself.
        optimizer: Optimizer whose state is restored when a matching
            optimizer file exists.

    Returns:
        ``max_epoch + 1`` when both model and optimizer were restored from
        epoch ``max_epoch``; ``0`` in every other case.

    Raises:
        OSError: Propagated from ``os.listdir`` if the directory cannot be
            read. Errors from ``torch.load`` / ``load_state_dict`` are not
            caught either.
    """
    model_files = {}
    optimizer_files = {}

    # Collect all checkpoint / optimizer files in the directory, keyed by epoch.
    for file_name in os.listdir(checkpoint_path):
        if not file_name.endswith(".pt") or '_' not in file_name:
            continue

        prefix, epoch_part = file_name.rsplit('_', 1)
        try:
            epoch = int(epoch_part.split('.')[0])
        except ValueError:
            # Suffix is not an epoch number (e.g. "checkpoint_best.pt"):
            # skip the file instead of aborting the whole resume.
            continue

        if prefix.startswith("checkpoint"):
            model_files[epoch] = file_name
        elif prefix.startswith("optimizer"):
            optimizer_files[epoch] = file_name

    # Load into the wrapped module when there is one, else into the model itself.
    target = getattr(model, "module", model)

    # NOTE: torch.load unpickles the file contents; only point this function
    # at checkpoint directories you trust.
    def load_cpu(file_name):
        return torch.load(os.path.join(checkpoint_path, file_name), map_location='cpu')

    # Epochs for which both the model and the optimizer were saved.
    common_epochs = model_files.keys() & optimizer_files.keys()
    if common_epochs:
        max_epoch = max(common_epochs)
        target.load_state_dict(load_cpu(model_files[max_epoch]))
        optimizer.load_state_dict(load_cpu(optimizer_files[max_epoch]))

        print(f'resume model and optimizer from {max_epoch} epoch')
        return max_epoch + 1

    # No resumable pair: fall back to the newest model-only (pretrained) checkpoint.
    if model_files:
        target.load_state_dict(load_cpu(model_files[max(model_files)]))

    return 0