Spaces:
Sleeping
Sleeping
from utils.dist import * | |
from parse import * | |
from utils.util import find_free_port | |
import torch.multiprocessing as mp | |
import torch.distributed | |
from importlib import import_module | |
from flow_inputs import args_parser | |
def main_worker(rank, opt):
    """Prepare one training process and run the configured network's training.

    Args:
        rank: Process index supplied by the launcher. Used as both the local
            and the global rank when ``opt`` has no ``'local_rank'`` entry,
            and always forwarded to the network constructor.
        opt: Mutable options dict. Reads ``'distributed'`` and ``'network'``
            (plus ``'init_method'``, ``'world_size'`` and ``'global_rank'``
            when distributed); writes ``'device'`` and, if absent,
            ``'local_rank'`` / ``'global_rank'``.
    """
    # Ranks are only derived from ``rank`` when the caller has not already
    # filled them in.
    if 'local_rank' not in opt:
        opt['local_rank'] = rank
        opt['global_rank'] = rank

    if opt['distributed']:
        local_rank = int(opt['local_rank'])
        global_rank = opt['global_rank']
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=opt['init_method'],
            world_size=opt['world_size'],
            rank=global_rank,
            group_name='mtorch',
        )
        print('using GPU {}-{} for training'.format(
            int(global_rank), local_rank))

    # A torch.device is stored when CUDA is usable; otherwise the plain
    # string 'cpu' is stored.
    opt['device'] = (
        torch.device("cuda:{}".format(opt['local_rank']))
        if torch.cuda.is_available()
        else 'cpu'
    )

    # The network implementation is looked up by name under the ``networks``
    # package and must expose a ``Network`` class with a ``train()`` method.
    network_module = import_module('networks.{}'.format(opt['network']))
    network_module.Network(opt, rank).train()
def main(args_obj):
    """Build the run options from parsed arguments and launch training.

    Args:
        args_obj: Dict of command-line arguments (e.g. ``vars(namespace)``).
            It is passed to ``parse`` to obtain the options dict and must
            contain a ``'finetune'`` entry.

    The options dict is extended with ``'world_size'``, ``'init_method'``,
    ``'distributed'`` and ``'finetune'`` before ``main_worker`` is started,
    either through ``mp.spawn`` (master IP is localhost) or directly in this
    process (externally launched processes).
    """
    opt = parse(args_obj)
    opt['world_size'] = get_world_size()
    free_port = find_free_port()
    master_ip = get_master_ip()
    opt['init_method'] = "tcp://{}:{}".format(master_ip, free_port)
    opt['distributed'] = opt['world_size'] > 1
    print(f'World size is: {opt["world_size"]}, and init_method is: {opt["init_method"]}')
    print('Import network module: ', opt['network'])

    # A non-empty 'gen_state' / 'opt_state' option overrides the entry of the
    # same name under opt['path'].
    for state_key in ('gen_state', 'opt_state'):
        if opt[state_key] != '':
            opt['path'][state_key] = opt[state_key]

    # Read the flag from the function's own argument rather than from a
    # module-level ``args`` global, which only exists when this file is run
    # as a script.
    opt['finetune'] = args_obj['finetune'] == 1
    print(f'model is: {opt["model"]}')

    # Reuse the address already embedded in init_method so the launch mode
    # and the rendezvous address cannot disagree.
    if master_ip == "127.0.0.1":
        # localhost: spawn one worker process per rank from here
        mp.spawn(main_worker, nprocs=opt['world_size'], args=(opt,))
    else:
        # multiple processes should be launched by openmpi
        opt['local_rank'] = get_local_rank()
        opt['global_rank'] = get_global_rank()
        main_worker(-1, opt)
if __name__ == '__main__':
    # Parse the command line, then hand the arguments to main() as a dict
    # view of the parsed namespace.
    args = args_parser()
    main(vars(args))