|
from utils.dist import * |
|
from parse import * |
|
from utils.util import find_free_port |
|
import torch.multiprocessing as mp |
|
import torch.distributed |
|
from importlib import import_module |
|
|
|
from flow_inputs import args_parser |
|
|
|
|
|
def main_worker(rank, opt):
    """Set up one training process and run the configured network's trainer.

    Args:
        rank: Process index. When ``opt`` has no ``'local_rank'`` entry it is
            used as both the local and the global rank. It is always forwarded
            unchanged to ``Network``.
        opt: Mutable config dict. Reads ``'distributed'`` and ``'network'``,
            plus ``'init_method'``/``'world_size'`` when distributed. Writes
            ``'local_rank'``/``'global_rank'`` (only if absent) and ``'device'``.
    """
    if 'local_rank' not in opt:
        opt['local_rank'] = opt['global_rank'] = rank

    if opt['distributed']:
        local_rank = int(opt['local_rank'])
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=opt['init_method'],
            world_size=opt['world_size'],
            rank=opt['global_rank'],
            group_name='mtorch',
        )
        print('using GPU {}-{} for training'.format(int(opt['global_rank']), local_rank))

    # Bind to this process's CUDA device when CUDA exists, otherwise use CPU.
    opt['device'] = (
        torch.device("cuda:{}".format(opt['local_rank']))
        if torch.cuda.is_available()
        else 'cpu'
    )

    # The trainer class is looked up dynamically in networks.<opt['network']>.
    network_module = import_module('networks.{}'.format(opt['network']))
    network_module.Network(opt, rank).train()
|
|
|
|
|
def main(args_obj):
    """Build the training config from parsed CLI arguments and launch training.

    Args:
        args_obj: Dict of command-line arguments; the script entry point passes
            ``vars(args_parser())``. It is handed to ``parse`` to build the
            config dict ``opt``, and its ``'finetune'`` entry is read directly.

    Launch mode depends on ``get_master_ip()``. When it is ``127.0.0.1``,
    ``opt['world_size']`` worker processes are spawned locally with
    ``torch.multiprocessing.spawn``. Otherwise the ranks are taken from
    ``get_local_rank()``/``get_global_rank()`` and ``main_worker`` runs in the
    current process.
    """
    opt = parse(args_obj)
    opt['world_size'] = get_world_size()
    free_port = find_free_port()
    # Query the master address once; it feeds both the init URL and the
    # launch-mode decision below.
    master_ip = get_master_ip()
    opt['init_method'] = "tcp://{}:{}".format(master_ip, free_port)
    opt['distributed'] = opt['world_size'] > 1
    print(f'World size is: {opt["world_size"]}, and init_method is: {opt["init_method"]}')
    print('Import network module: ', opt['network'])

    # Non-empty 'gen_state' / 'opt_state' values are copied into opt['path'].
    if opt['gen_state'] != '':
        opt['path']['gen_state'] = opt['gen_state']
    if opt['opt_state'] != '':
        opt['path']['opt_state'] = opt['opt_state']

    # Read the flag from the argument dict rather than the module-level `args`,
    # which only exists when this file is executed as a script.
    opt['finetune'] = args_obj['finetune'] == 1

    print(f'model is: {opt["model"]}')

    if master_ip == "127.0.0.1":
        # Single-host launch: mp.spawn calls main_worker(i, opt) for each
        # process index i in range(world_size).
        mp.spawn(main_worker, nprocs=opt['world_size'], args=(opt,))
    else:
        # Ranks come from utils.dist helpers. NOTE(review): presumably the
        # processes were started by an external launcher (e.g. MPI) -- confirm.
        opt['local_rank'] = get_local_rank()
        opt['global_rank'] = get_global_rank()
        # NOTE(review): rank -1 is forwarded to Network(opt, rank); confirm
        # that Network relies on opt['global_rank'] in this mode.
        main_worker(-1, opt)
|
|
|
|
|
if __name__ == '__main__':
    # Parse the command line and hand the arguments to main() as a plain dict.
    # `args` is kept as a module-level name on purpose.
    args = args_parser()
    main(vars(args))
|
|