import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
def continue_training(checkpoint_path, model: DDP, optimizer: optim.Optimizer) -> int:
"""load the latest checkpoints and optimizers"""
model_dict = {}
optimizer_dict = {}
# globt all the checkpoints in the directory
for file in os.listdir(checkpoint_path):
if file.endswith(".pt") and '_' in file:
name, epoch_str = file.rsplit('_', 1)
epoch = int(epoch_str.split('.')[0])
if name.startswith("checkpoint"):
model_dict[epoch] = file
elif name.startswith("optimizer"):
optimizer_dict[epoch] = file
# get the largest epoch
common_epochs = set(model_dict.keys()) & set(optimizer_dict.keys())
if common_epochs:
max_epoch = max(common_epochs)
model_path = os.path.join(checkpoint_path, model_dict[max_epoch])
optimizer_path = os.path.join(checkpoint_path, optimizer_dict[max_epoch])
# load model and optimizer
model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
optimizer.load_state_dict(torch.load(optimizer_path, map_location='cpu'))
print(f'resume model and optimizer from {max_epoch} epoch')
return max_epoch + 1
else:
# load pretrained checkpoint
if model_dict:
model_path = os.path.join(checkpoint_path, model_dict[max(model_dict.keys())])
model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
return 0 |
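

# Hypothetical usage sketch (not part of the original file): one way continue_training
# might be wired into a DDP training loop. The model class (MyModel), checkpoint
# directory, train_one_epoch helper, and num_epochs below are illustrative assumptions,
# not names from this repository. Saving as f'checkpoint_{epoch}.pt' and
# f'optimizer_{epoch}.pt' matches the filename pattern the loader above parses.
#
# import torch.distributed as dist
#
# dist.init_process_group(backend='nccl')
# local_rank = int(os.environ['LOCAL_RANK'])
# torch.cuda.set_device(local_rank)
#
# model = DDP(MyModel().cuda(local_rank), device_ids=[local_rank])
# optimizer = optim.AdamW(model.parameters(), lr=1e-4)
#
# start_epoch = continue_training('./checkpoints', model, optimizer)
# for epoch in range(start_epoch, num_epochs):
#     train_one_epoch(model, optimizer, epoch)
#     if dist.get_rank() == 0:
#         torch.save(model.module.state_dict(),
#                    os.path.join('./checkpoints', f'checkpoint_{epoch}.pt'))
#         torch.save(optimizer.state_dict(),
#                    os.path.join('./checkpoints', f'optimizer_{epoch}.pt'))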