"""Convert official TensorFlow StyleGAN2 checkpoints to this project's
PyTorch ``stylegan2.models`` classes.

Run as a script: supply either ``--input`` (a local pickled TF model) or
``--download`` (the name of a known pretrained model), plus ``--output``
paths. See ``--help`` for the list of downloadable models.
"""
import os
import re
import pickle
import argparse
import io

import requests
import torch

import stylegan2
from stylegan2 import utils

# URLs of the official NVIDIA pretrained StyleGAN2 checkpoints.
pretrained_model_urls = {
    'car-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-e.pkl',
    'car-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl',
    'cat-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl',
    'church-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl',
    'ffhq-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-e.pkl',
    'ffhq-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl',
    'horse-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-horse-config-f.pkl',
    'car-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dorig.pkl',
    'car-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dresnet.pkl',
    'car-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dskip.pkl',
    'car-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dorig.pkl',
    'car-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dresnet.pkl',
    'car-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dskip.pkl',
    'car-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dorig.pkl',
    'car-config-e-Gskip-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dresnet.pkl',
    'car-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dskip.pkl',
    'ffhq-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dorig.pkl',
    'ffhq-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dresnet.pkl',
    'ffhq-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dskip.pkl',
    'ffhq-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dorig.pkl',
    'ffhq-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dresnet.pkl',
    'ffhq-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dskip.pkl',
    'ffhq-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dorig.pkl',
    'ffhq-config-e-Gskip-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dresnet.pkl',
    'ffhq-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl',
}


class Unpickler(pickle.Unpickler):
    """Unpickler that maps ``dnnlib.tflib.network.Network`` objects onto
    plain ``utils.AttributeDict`` instances so the TF checkpoints can be
    loaded without the original TensorFlow/dnnlib code.

    SECURITY NOTE(review): only the one class is remapped — every other
    class is resolved normally, so unpickling data from an untrusted
    source can execute arbitrary code. Only load checkpoints you trust.
    """

    def find_class(self, module, name):
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return utils.AttributeDict
        return super(Unpickler, self).find_class(module, name)


def load_tf_models_file(fpath):
    """Load a pickled TF model (or tuple of models) from a local file."""
    with open(fpath, 'rb') as fp:
        return Unpickler(fp).load()


def load_tf_models_url(url):
    """Download a pickled TF model from `url` and unpickle it in memory."""
    print('Downloading file {}...'.format(url))
    with requests.Session() as session:
        with session.get(url) as ret:
            fp = io.BytesIO(ret.content)
            return Unpickler(fp).load()


def convert_kwargs(static_kwargs, kwargs_mapping):
    """Translate TF ``static_kwargs`` names into this project's kwarg names.

    Arguments:
        static_kwargs (dict): Keyword args stored in the TF checkpoint.
        kwargs_mapping (dict): Maps a TF kwarg name to one PyTorch kwarg
            name or a list of names that all receive the same value.

    Returns:
        utils.AttributeDict: The translated keyword arguments.
    """
    kwargs = utils.AttributeDict()
    for key, value in static_kwargs.items():
        if key in kwargs_mapping:
            # TF names the activation 'lrelu'; this project spells it
            # 'leaky:0.2' (leaky relu with slope 0.2).
            if value == 'lrelu':
                value = 'leaky:0.2'
            for k in utils.to_list(kwargs_mapping[key]):
                kwargs[k] = value
    return kwargs


# TF build functions this converter knows how to translate.
_PERMITTED_MODELS = ['G_main', 'G_mapping', 'G_synthesis_stylegan2',
                     'D_stylegan2', 'D_main', 'G_synthesis']


def convert_from_tf(tf_state):
    """Convert one unpickled TF network state into a PyTorch model.

    Dispatches on ``tf_state.build_func_name``: the full generator
    ('G_main', converted recursively), the mapping network, the synthesis
    network, or the discriminator.

    Arguments:
        tf_state (dict): Unpickled TF network state (becomes an
            AttributeDict with `variables`, `static_kwargs` etc.).

    Returns:
        A `stylegan2.models` module with the TF weights copied in.

    Raises:
        AssertionError: For unknown model types or unsupported
            checkpoint options (`fused_modconv=False`).
    """
    tf_state = utils.AttributeDict.convert_dict_recursive(tf_state)
    model_type = tf_state.build_func_name
    assert model_type in _PERMITTED_MODELS, \
        'Found model type {}. '.format(model_type) + \
        'Allowed model types are: {}'.format(_PERMITTED_MODELS)
    if model_type == 'G_main':
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'dlatent_avg_beta': 'dlatent_avg_beta'
            }
        )
        # Sub-networks are stored as components and converted recursively.
        kwargs.G_mapping = convert_from_tf(tf_state.components.mapping)
        kwargs.G_synthesis = convert_from_tf(tf_state.components.synthesis)
        G = stylegan2.models.Generator(**kwargs)
        for name, var in tf_state.variables:
            if name == 'dlatent_avg':
                G.dlatent_avg.data.copy_(torch.from_numpy(var))
        # Both spellings of the truncation kwargs appear in checkpoints.
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'truncation_psi': 'truncation_psi',
                'truncation_cutoff': 'truncation_cutoff',
                'truncation_psi_val': 'truncation_psi',
                'truncation_cutoff_val': 'truncation_cutoff'
            }
        )
        G.set_truncation(**kwargs)
        return G
    if model_type == 'G_mapping':
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'mapping_nonlinearity': 'activation',
                'normalize_latents': 'normalize_input',
                'mapping_lr_mul': 'lr_mul'
            }
        )
        # Infer the architecture from the stored variable names/shapes.
        kwargs.num_layers = sum(
            1 for var_name, _ in tf_state.variables
            if re.match(r'Dense[0-9]+/weight', var_name)
        )
        for var_name, var in tf_state.variables:
            if var_name == 'LabelConcat/weight':
                kwargs.label_size = var.shape[0]
            if var_name == 'Dense0/weight':
                kwargs.latent_size = var.shape[0]
                kwargs.hidden = var.shape[1]
            if var_name == 'Dense{}/bias'.format(kwargs.num_layers - 1):
                kwargs.out_size = var.shape[0]
        G_mapping = stylegan2.models.GeneratorMapping(**kwargs)
        for var_name, var in tf_state.variables:
            if re.match(r'Dense[0-9]+/[a-zA-Z]*', var_name):
                layer_idx = int(
                    re.search(r'Dense(\d+)/[a-zA-Z]*', var_name).groups()[0])
                if var_name.endswith('weight'):
                    # TF stores dense weights as (in, out); torch wants
                    # (out, in), hence the transpose.
                    G_mapping.main[layer_idx].layer.weight.data.copy_(
                        torch.from_numpy(var.T).contiguous())
                elif var_name.endswith('bias'):
                    G_mapping.main[layer_idx].bias.data.copy_(
                        torch.from_numpy(var))
            if var_name == 'LabelConcat/weight':
                G_mapping.embedding.weight.data.copy_(torch.from_numpy(var))
        return G_mapping
    if model_type == 'G_synthesis_stylegan2' or model_type == 'G_synthesis':
        assert tf_state.static_kwargs.get('fused_modconv', True), \
            'Can not load TF networks that use `fused_modconv=False`'
        # Split variables into per-resolution conv weights and noise maps.
        noise_tensors = []
        conv_vars = {}
        for var_name, var in tf_state.variables:
            if var_name.startswith('noise'):
                noise_tensors.append(torch.from_numpy(var))
            else:
                layer_size = int(
                    re.search(r'(\d+)x[0-9]+/*', var_name).groups()[0])
                if layer_size not in conv_vars:
                    conv_vars[layer_size] = {}
                var_name = var_name.replace(
                    '{}x{}/'.format(layer_size, layer_size), '')
                conv_vars[layer_size][var_name] = var
        # Order the noise maps from lowest to highest resolution.
        noise_tensors = sorted(noise_tensors, key=lambda x: x.size(-1))
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'nonlinearity': 'activation',
                'resample_filter': ['conv_filter', 'skip_filter']
            }
        )
        kwargs.skip = False
        kwargs.resnet = True
        kwargs.channels = []
        for size in sorted(conv_vars.keys(), reverse=True):
            if size == 4:
                # A ToRGB at the 4x4 block implies skip architecture;
                # otherwise this is the resnet config.
                if 'ToRGB/weight' in conv_vars[size]:
                    kwargs.skip = True
                    kwargs.resnet = False
                kwargs.latent_size = conv_vars[size]['Conv/mod_weight'].shape[0]
                kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0])
            else:
                kwargs.channels.append(conv_vars[size]['Conv1/bias'].shape[0])
            if 'ToRGB/bias' in conv_vars[size]:
                kwargs.data_channels = conv_vars[size]['ToRGB/bias'].shape[0]
        G_synthesis = stylegan2.models.GeneratorSynthesis(**kwargs)
        G_synthesis.const.data.copy_(
            torch.from_numpy(conv_vars[4]['Const/const']).squeeze(0))

        def assign_weights(layer, weight, bias, mod_weight, mod_bias,
                           noise_strength, transposed=False):
            # Copy one styled conv layer. NOTE(review): assumes the wrapper
            # nesting bias(noise(modulated-conv)) used by
            # stylegan2.models — verify against that module.
            layer.bias.data.copy_(torch.from_numpy(bias))
            # Noise injection strength (scalar).
            layer.layer.weight.data.copy_(torch.tensor(noise_strength))
            layer.layer.layer.dense.layer.weight.data.copy_(
                torch.from_numpy(mod_weight.T).contiguous())
            # TF applies the +1 style bias at runtime; bake it in here.
            layer.layer.layer.dense.bias.data.copy_(
                torch.from_numpy(mod_bias + 1))
            # TF conv weights are (kH, kW, in, out); torch wants
            # (out, in, kH, kW).
            weight = torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous()
            if transposed:
                # Transposed convs additionally need spatially flipped kernels.
                weight = weight.flip(dims=[2, 3])
            layer.layer.layer.weight.data.copy_(weight)

        conv_blocks = G_synthesis.conv_blocks
        for i, size in enumerate(sorted(conv_vars.keys())):
            block = conv_blocks[i]
            if size == 4:
                # The 4x4 block has a single conv after the const input.
                assign_weights(
                    layer=block.conv_block[0],
                    weight=conv_vars[size]['Conv/weight'],
                    bias=conv_vars[size]['Conv/bias'],
                    mod_weight=conv_vars[size]['Conv/mod_weight'],
                    mod_bias=conv_vars[size]['Conv/mod_bias'],
                    noise_strength=conv_vars[size]['Conv/noise_strength'],
                )
            else:
                assign_weights(
                    layer=block.conv_block[0],
                    weight=conv_vars[size]['Conv0_up/weight'],
                    bias=conv_vars[size]['Conv0_up/bias'],
                    mod_weight=conv_vars[size]['Conv0_up/mod_weight'],
                    mod_bias=conv_vars[size]['Conv0_up/mod_bias'],
                    noise_strength=conv_vars[size]['Conv0_up/noise_strength'],
                    transposed=True
                )
                assign_weights(
                    layer=block.conv_block[1],
                    weight=conv_vars[size]['Conv1/weight'],
                    bias=conv_vars[size]['Conv1/bias'],
                    mod_weight=conv_vars[size]['Conv1/mod_weight'],
                    mod_bias=conv_vars[size]['Conv1/mod_bias'],
                    noise_strength=conv_vars[size]['Conv1/noise_strength'],
                )
                if 'Skip/weight' in conv_vars[size]:
                    # Resnet shortcut projection.
                    block.projection.weight.data.copy_(torch.from_numpy(
                        conv_vars[size]['Skip/weight']
                    ).permute((3, 2, 0, 1)).contiguous())
            to_RGB = G_synthesis.to_data_layers[i]
            if to_RGB is not None:
                to_RGB.bias.data.copy_(
                    torch.from_numpy(conv_vars[size]['ToRGB/bias']))
                to_RGB.layer.weight.data.copy_(torch.from_numpy(
                    conv_vars[size]['ToRGB/weight']
                ).permute((3, 2, 0, 1)).contiguous())
                to_RGB.layer.dense.bias.data.copy_(
                    torch.from_numpy(conv_vars[size]['ToRGB/mod_bias'] + 1))
                to_RGB.layer.dense.layer.weight.data.copy_(
                    torch.from_numpy(
                        conv_vars[size]['ToRGB/mod_weight'].T).contiguous())
        if not tf_state.static_kwargs.get('randomize_noise', True):
            # Checkpoint was trained/evaluated with fixed noise buffers.
            G_synthesis.static_noise(noise_tensors=noise_tensors)
        return G_synthesis
    if model_type == 'D_stylegan2' or model_type == 'D_main':
        output_vars = {}
        conv_vars = {}
        for var_name, var in tf_state.variables:
            if var_name.startswith('Output'):
                output_vars[var_name.replace('Output/', '')] = var
            else:
                layer_size = int(
                    re.search(r'(\d+)x[0-9]+/*', var_name).groups()[0])
                if layer_size not in conv_vars:
                    conv_vars[layer_size] = {}
                var_name = var_name.replace(
                    '{}x{}/'.format(layer_size, layer_size), '')
                conv_vars[layer_size][var_name] = var
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'nonlinearity': 'activation',
                'resample_filter': ['conv_filter', 'skip_filter'],
                'mbstd_group_size': 'mbstd_group_size'
            }
        )
        kwargs.skip = False
        kwargs.resnet = True
        kwargs.channels = []
        for size in sorted(conv_vars.keys(), reverse=True):
            if size == 4:
                # A FromRGB at the 4x4 block implies skip architecture.
                if 'FromRGB/weight' in conv_vars[size]:
                    kwargs.skip = True
                    kwargs.resnet = False
                kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0])
                kwargs.dense_hidden = conv_vars[size]['Dense0/bias'].shape[0]
            else:
                kwargs.channels.append(conv_vars[size]['Conv0/bias'].shape[0])
            if 'FromRGB/weight' in conv_vars[size]:
                kwargs.data_channels = conv_vars[size]['FromRGB/weight'].shape[-2]
        # Output size > 1 means a conditional (label-projected) D.
        output_size = output_vars['bias'].shape[0]
        if output_size > 1:
            kwargs.label_size = output_size
        D = stylegan2.models.Discriminator(**kwargs)

        def assign_weights(layer, weight, bias):
            # Copy one plain conv layer (no style modulation in D).
            layer.bias.data.copy_(torch.from_numpy(bias))
            layer.layer.weight.data.copy_(
                torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous())

        conv_blocks = D.conv_blocks
        for i, size in enumerate(sorted(conv_vars.keys())):
            # D runs high-to-low resolution, so index blocks from the end.
            block = conv_blocks[-i - 1]
            if size == 4:
                assign_weights(
                    layer=block[-1].conv_block[0],
                    weight=conv_vars[size]['Conv/weight'],
                    bias=conv_vars[size]['Conv/bias'],
                )
            else:
                assign_weights(
                    layer=block.conv_block[0],
                    weight=conv_vars[size]['Conv0/weight'],
                    bias=conv_vars[size]['Conv0/bias'],
                )
                assign_weights(
                    layer=block.conv_block[1],
                    weight=conv_vars[size]['Conv1_down/weight'],
                    bias=conv_vars[size]['Conv1_down/bias'],
                )
                if 'Skip/weight' in conv_vars[size]:
                    block.projection.weight.data.copy_(torch.from_numpy(
                        conv_vars[size]['Skip/weight']
                    ).permute((3, 2, 0, 1)).contiguous())
            from_RGB = D.from_data_layers[-i - 1]
            if from_RGB is not None:
                from_RGB.bias.data.copy_(
                    torch.from_numpy(conv_vars[size]['FromRGB/bias']))
                from_RGB.layer.weight.data.copy_(torch.from_numpy(
                    conv_vars[size]['FromRGB/weight']
                ).permute((3, 2, 0, 1)).contiguous())
        return D


def get_arg_parser():
    """Build the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description='Convert tensorflow stylegan2 model to pytorch.',
        # Raw formatter so the newline-separated model list is not
        # re-wrapped into a single paragraph by argparse.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='Pretrained models that can be downloaded:\n{}'.format(
            '\n'.join(pretrained_model_urls.keys()))
    )
    parser.add_argument(
        '-i',
        '--input',
        help='File path to pickled tensorflow models.',
        type=str,
        default=None,
    )
    parser.add_argument(
        '-d',
        '--download',
        help='Download the specified pretrained model. ' + \
            'Use --help for info on available models.',
        type=str,
        default=None,
    )
    parser.add_argument(
        '-o',
        '--output',
        help='One or more output file paths. Alternatively a directory path ' + \
            'where all models will be saved. Default: current directory',
        type=str,
        nargs='*',
        default=['.'],
    )
    return parser


def main():
    """Entry point: load TF model(s), convert, and save as ``.pth``."""
    args = get_arg_parser().parse_args()
    assert bool(args.input) != bool(args.download), \
        'Incorrect input format. Can only take either one ' + \
        'input filepath to a pickled tensorflow model or ' + \
        'a model name to download, but not both at the same ' + \
        'time or none at all.'
    if args.input:
        unpickled = load_tf_models_file(args.input)
    else:
        assert args.download in pretrained_model_urls.keys(), \
            'Unknown model {}. Use --help for list of models.'.format(
                args.download)
        unpickled = load_tf_models_url(pretrained_model_urls[args.download])
    # A checkpoint may hold one model or a (G, D, Gs) tuple.
    if not isinstance(unpickled, (tuple, list)):
        unpickled = [unpickled]
    print('Converting tensorflow models and saving them...')
    converted = [convert_from_tf(tf_state) for tf_state in unpickled]
    # A single extensionless/existing-dir output path means "save all
    # models into this directory, named after the TF models".
    if len(args.output) == 1 and (os.path.isdir(args.output[0])
                                  or not os.path.splitext(args.output[0])[-1]):
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(args.output[0], exist_ok=True)
        for tf_state, torch_model in zip(unpickled, converted):
            torch_model.save(os.path.join(
                args.output[0], tf_state['name'] + '.pth'))
    else:
        assert len(args.output) == len(converted), \
            'Found {} models '.format(len(converted)) + \
            'in pickled file but only {} output paths were given.'.format(
                len(args.output))
        for out_path, torch_model in zip(args.output, converted):
            torch_model.save(out_path)
    print('Done!')


if __name__ == '__main__':
    main()