import os
import argparse
import warnings

import numpy as np
import torch

import stylegan2
from stylegan2 import utils

#----------------------------------------------------------------------------

_description = """StyleGAN2 projector.
Run 'python %(prog)s <subcommand> --help' for subcommand help."""

#----------------------------------------------------------------------------

_examples = """examples:
  # Train a network or convert a pretrained one.
  # Example of converting pretrained ffhq model:
  python run_convert_from_tf.py --download ffhq-config-f --output G.pth D.pth Gs.pth

  # Project generated images
  python %(prog)s project_generated_images --network=Gs.pth --seeds=0,1,5

  # Project real images
  python %(prog)s project_real_images --network=Gs.pth --data_dir=path/to/image_folder
"""

#----------------------------------------------------------------------------

def _add_shared_arguments(parser):
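    """Add the arguments shared by both projection subcommands
    (network path, optimization hyperparameters, output settings
    and device selection) to the given argparse parser."""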
    parser.add_argument(
        '--network',
        help='Network file path',
        required=True,
        metavar='FILE'
    )
    parser.add_argument(
        '--num_steps',
        type=int,
        help='Number of steps to use for projection. ' + \
            'Default: %(default)s',
        default=1000,
        metavar='VALUE'
    )
    parser.add_argument(
        '--batch_size',
        help='Batch size. Default: %(default)s',
        type=int,
        default=1,
        metavar='VALUE'
    )
    parser.add_argument(
        '--label',
        help='Label to use for dlatent statistics gathering ' + \
            '(should be integer index of class). Default: no label.',
        type=int,
        default=None,
        metavar='CLASS_INDEX'
    )
    parser.add_argument(
        '--initial_learning_rate',
        help='Initial learning rate of projection. Default: %(default)s',
        default=0.1,
        type=float,
        metavar='VALUE'
    )
    parser.add_argument(
        '--initial_noise_factor',
        help='Initial noise factor of projection. Default: %(default)s',
        default=0.05,
        type=float,
        metavar='VALUE'
    )
    parser.add_argument(
        '--lr_rampdown_length',
        help='Learning rate rampdown length for projection. ' + \
            'Should be in range [0, 1]. Default: %(default)s',
        default=0.25,
        type=float,
        metavar='VALUE'
    )
    parser.add_argument(
        '--lr_rampup_length',
        help='Learning rate rampup length for projection. ' + \
            'Should be in range [0, 1]. Default: %(default)s',
        default=0.05,
        type=float,
        metavar='VALUE'
    )
    parser.add_argument(
        '--noise_ramp_length',
        help='Noise ramp length for projection. ' + \
            'Should be in range [0, 1]. Default: %(default)s',
        default=0.75,
        type=float,
        metavar='VALUE'
    )
    parser.add_argument(
        '--regularize_noise_weight',
        help='The weight for noise regularization. Default: %(default)s',
        default=1e5,
        type=float,
        metavar='VALUE'
    )
    parser.add_argument(
        '--output',
        help='Root directory for run results. Default: %(default)s',
        type=str,
        default='./results',
        metavar='DIR'
    )
    parser.add_argument(
        '--num_snapshots',
        help='Number of snapshots. Default: %(default)s',
        type=int,
        default=5,
        metavar='VALUE'
    )
    parser.add_argument(
        '--pixel_min',
        help='Minimum of the value range of pixels in generated images. ' + \
            'Default: %(default)s',
        default=-1,
        type=float,
        metavar='VALUE'
    )
    parser.add_argument(
        '--pixel_max',
        help='Maximum of the value range of pixels in generated images. ' + \
            'Default: %(default)s',
        default=1,
        type=float,
        metavar='VALUE'
    )
    parser.add_argument(
        '--gpu',
        help='CUDA device indices (given as separate ' + \
            'values if multiple, i.e. "--gpu 0 1"). Default: Use CPU',
        type=int,
        default=[],
        nargs='*',
        metavar='INDEX'
    )

#----------------------------------------------------------------------------

def get_arg_parser():
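    """Build the argument parser with the 'project_generated_images'
    and 'project_real_images' subcommands."""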
    parser = argparse.ArgumentParser(
        description=_description,
        epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    range_desc = 'NOTE: This is a single argument, where list ' + \
        'elements are separated by "," and ranges are defined as "a-b". ' + \
        'Only integers are allowed.'

    subparsers = parser.add_subparsers(help='Sub-commands', dest='command')

    project_generated_images_parser = subparsers.add_parser(
        'project_generated_images', help='Project generated images')
    project_generated_images_parser.add_argument(
        '--seeds',
        help='List of random seeds for generating images. ' + \
            'Default: 66,230,389,1518. ' + range_desc,
        type=utils.range_type,
        default=[66, 230, 389, 1518],
        metavar='RANGE'
    )
    project_generated_images_parser.add_argument(
        '--truncation_psi',
        help='Truncation psi. Default: %(default)s',
        type=float,
        default=1.0,
        metavar='VALUE'
    )
    _add_shared_arguments(project_generated_images_parser)

    project_real_images_parser = subparsers.add_parser(
        'project_real_images', help='Project real images')
    project_real_images_parser.add_argument(
        '--data_dir',
        help='Dataset root directory',
        type=str,
        required=True,
        metavar='DIR'
    )
    project_real_images_parser.add_argument(
        '--seed',
        help='When there are more images available than ' + \
            'the number that is going to be projected this ' + \
            'seed is used for picking samples. Default: %(default)s',
        type=int,
        default=1234,
        metavar='VALUE'
    )
    project_real_images_parser.add_argument(
        '--num_images',
        type=int,
        help='Number of images to project. Default: %(default)s',
        default=3,
        metavar='VALUE'
    )
    _add_shared_arguments(project_real_images_parser)

    return parser

#----------------------------------------------------------------------------

def project_images(G, images, name_prefix, args):
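    """Project a batch of target images into the latent space of the
    generator G, saving each target and periodic snapshots of the
    projection to args.output. `name_prefix` holds one filename prefix
    per image."""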
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    if len(args.gpu) > 1:
        warnings.warn(
            'Multi GPU is not available for projection. ' + \
            'Using device {}'.format(device)
        )
    G = utils.unwrap_module(G).to(device)

    lpips_model = stylegan2.external_models.lpips.LPIPS_VGG16(
        pixel_min=args.pixel_min, pixel_max=args.pixel_max)

    proj = stylegan2.project.Projector(
        G=G,
        dlatent_avg_samples=10000,
        dlatent_avg_label=args.label,
        dlatent_device=device,
        dlatent_batch_size=1024,
        lpips_model=lpips_model,
        lpips_size=256
    )

    for i in range(0, len(images), args.batch_size):
        target = images[i: i + args.batch_size]
        proj.start(
            target=target,
            num_steps=args.num_steps,
            initial_learning_rate=args.initial_learning_rate,
            initial_noise_factor=args.initial_noise_factor,
            lr_rampdown_length=args.lr_rampdown_length,
            lr_rampup_length=args.lr_rampup_length,
            noise_ramp_length=args.noise_ramp_length,
            regularize_noise_weight=args.regularize_noise_weight,
            verbose=True,
            # `i` already advances by `args.batch_size`, so the number of
            # processed images is `i + len(target)`.
            verbose_prefix='Projecting image(s) {}/{}'.format(
                i + len(target), len(images))
        )
        # Take args.num_snapshots snapshots, evenly spaced over the
        # optimization and always including the final step.
        snapshot_steps = set(
            args.num_steps - np.linspace(
                0, args.num_steps, args.num_snapshots, endpoint=False, dtype=int))
        for k, image in enumerate(
                utils.tensor_to_PIL(target, pixel_min=args.pixel_min, pixel_max=args.pixel_max)):
            image.save(os.path.join(args.output, name_prefix[i + k] + 'target.png'))
        for j in range(args.num_steps):
            proj.step()
            # Snapshot filenames are 1-indexed; compare against j + 1 so that
            # the final step (j + 1 == args.num_steps) is actually saved.
            if (j + 1) in snapshot_steps:
                generated = utils.tensor_to_PIL(
                    proj.generate(), pixel_min=args.pixel_min, pixel_max=args.pixel_max)
                for k, image in enumerate(generated):
                    image.save(os.path.join(
                        args.output, name_prefix[i + k] + 'step%04d.png' % (j + 1)))

#----------------------------------------------------------------------------

def project_generated_images(G, args):
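    """Generate images from random seeds with G and then project
    them back into its latent space."""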
    latent_size, label_size = G.latent_size, G.label_size
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)
    if len(args.gpu) > 1:
        warnings.warn(
            'Noise cannot be randomized based on the seed ' + \
            'when using more than 1 GPU device. Noise will ' + \
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=args.gpu)
    else:
        noise_reference = G.static_noise()
    def get_batch(seeds):
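        """Build a batch of latents, labels and per-layer noise tensors,
        each derived deterministically from the given seeds."""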
        latents = []
        labels = []
        if len(args.gpu) <= 1:
            noise_tensors = [[] for _ in noise_reference]
        for seed in seeds:
            rnd = np.random.RandomState(seed)
            latents.append(torch.from_numpy(rnd.randn(latent_size)))
            if len(args.gpu) <= 1:
                for i, ref in enumerate(noise_reference):
                    noise_tensors[i].append(
                        torch.from_numpy(rnd.randn(*ref.size()[1:])))
            if label_size:
                labels.append(torch.tensor([rnd.randint(0, label_size)]))
        latents = torch.stack(latents, dim=0).to(device=device, dtype=torch.float32)
        if labels:
            labels = torch.cat(labels, dim=0).to(device=device, dtype=torch.int64)
        else:
            labels = None
        if len(args.gpu) <= 1:
            noise_tensors = [
                torch.stack(noise, dim=0).to(device=device, dtype=torch.float32)
                for noise in noise_tensors
            ]
        else:
            noise_tensors = None
        return latents, labels, noise_tensors

    images = []
    progress = utils.ProgressWriter(len(args.seeds))
    progress.write('Generating images...', step=False)
    for i in range(0, len(args.seeds), args.batch_size):
        latents, labels, noise_tensors = get_batch(args.seeds[i: i + args.batch_size])
        if noise_tensors is not None:
            G.static_noise(noise_tensors=noise_tensors)
        with torch.no_grad():
            images.append(G(latents, labels=labels))
        progress.step()
    images = torch.cat(images, dim=0)
    progress.write('Done!', step=False)
    progress.close()

    name_prefix = ['seed%04d-' % seed for seed in args.seeds]
    project_images(G, images, name_prefix, args)

#----------------------------------------------------------------------------

def project_real_images(G, args):
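    """Load up to args.num_images images from args.data_dir and
    project them into the latent space of G."""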
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    print('Loading images from "%s"...' % args.data_dir)
    dataset = utils.ImageFolder(
        args.data_dir, pixel_min=args.pixel_min, pixel_max=args.pixel_max)

    rnd = np.random.RandomState(args.seed)
    indices = rnd.choice(
        len(dataset), size=min(args.num_images, len(dataset)), replace=False)
    images = []
    for i in indices:
        data = dataset[i]
        if isinstance(data, (tuple, list)):
            data = data[0]
        images.append(data)
    images = torch.stack(images).to(device)
    name_prefix = ['image%04d-' % i for i in indices]
    print('Done!')

    project_images(G, images, name_prefix, args)

#----------------------------------------------------------------------------

def main():
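    """Parse the command line and run the requested projection subcommand."""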
    args = get_arg_parser().parse_args()
    assert args.command, 'Missing subcommand.'
    assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
        '--output argument should specify a directory, not a file.'
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    G = stylegan2.models.load(args.network)
    assert isinstance(G, stylegan2.models.Generator), 'Model type has to be ' + \
        'stylegan2.models.Generator. Found {}.'.format(type(G))
    if args.command == 'project_generated_images':
        project_generated_images(G, args)
    elif args.command == 'project_real_images':
        project_real_images(G, args)
    else:
        raise ValueError('Unknown command {}'.format(args.command))

#----------------------------------------------------------------------------

if __name__ == '__main__':
    main()