import warnings
import argparse
import os
from PIL import Image
import numpy as np
import torch
import stylegan2
from stylegan2 import utils
from huggingface_hub import hf_hub_download
import gradio as gr
import re
from types import SimpleNamespace


# Edited from run_generator.py to return PIL images instead of saving them
# to the disk.
def generate_images(G, args):
    """Generate one PIL image per seed in ``args.seeds`` using generator ``G``.

    Args:
        G: A stylegan2 generator (e.g. as returned by ``stylegan2.models.load``);
            it must expose ``latent_size``, ``label_size``, ``set_truncation``,
            ``static_noise`` and ``random_noise``.
        args: Namespace with the attributes ``gpu`` (list of devices; empty or
            None means CPU), ``truncation_psi``, ``seeds``, ``batch_size``,
            ``pixel_min`` and ``pixel_max``.

    Returns:
        A list of PIL images, in the same order as ``args.seeds``.
    """
    # Normalize once: the original guarded ``args.gpu[0]`` against a falsy
    # value but then called ``len(args.gpu)``, which crashed on None.
    gpus = list(args.gpu) if args.gpu else []
    latent_size, label_size = G.latent_size, G.label_size
    device = torch.device(gpus[0] if gpus else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)
    if args.truncation_psi != 1:
        G.set_truncation(truncation_psi=args.truncation_psi)

    # Noise can only be seeded per image when running on at most one device.
    seeded_noise = len(gpus) <= 1
    if seeded_noise:
        # Only the sizes of these tensors are used below, to draw seeded
        # noise of matching shape.
        noise_reference = G.static_noise()
    else:
        warnings.warn(
            'Noise can not be randomized based on the seed ' +
            'when using more than 1 GPU device. Noise will ' +
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=gpus)

    def get_batch(seeds):
        """Build (latents, labels, noise_tensors) for a batch of seeds.

        Each seed gets its own PCG64 generator. Values are drawn in a fixed
        order (latent, then each noise layer, then label) so that a seed
        always reproduces the same image.
        """
        latents = []
        labels = []
        noise_tensors = [[] for _ in noise_reference] if seeded_noise else None
        for seed in seeds:
            rnd = np.random.default_rng(seed)
            latents.append(torch.from_numpy(rnd.standard_normal(latent_size)))
            if seeded_noise:
                for layer_noise, ref in zip(noise_tensors, noise_reference):
                    # Drop the leading (batch) dimension of the reference.
                    layer_noise.append(
                        torch.from_numpy(rnd.standard_normal(tuple(ref.size()[1:])))
                    )
            if label_size:
                labels.append(torch.tensor([rnd.integers(0, label_size)]))

        latents = torch.stack(latents, dim=0).to(device=device, dtype=torch.float32)
        if labels:
            labels = torch.cat(labels, dim=0).to(device=device, dtype=torch.int64)
        else:
            labels = None
        if noise_tensors is not None:
            noise_tensors = [
                torch.stack(noise, dim=0).to(device=device, dtype=torch.float32)
                for noise in noise_tensors
            ]
        return latents, labels, noise_tensors

    return_images = []
    for start in range(0, len(args.seeds), args.batch_size):
        latents, labels, noise_tensors = get_batch(
            args.seeds[start:start + args.batch_size])
        if noise_tensors is not None:
            G.static_noise(noise_tensors=noise_tensors)
        with torch.no_grad():
            generated = G(latents, labels=labels)
        images = utils.tensor_to_PIL(
            generated, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        return_images.extend(images)
    return return_images

#----------------------------------------------------------------------------

interface_modelversion_labels = [
    "TWDNEv3 iteration 24664 (best and current version on TWDNE)",
    "TWDNEv3 iteration 18528 (the most used version on the Internet)",
    "TWDNEv3 iteration 17325"
]

# Raw string: "\d" in a plain string literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+).
_MODEL_ITERATION_PATTERN = re.compile(r"TWDNEv3 iteration (\d{5})")


def inference(seed, truncation_psi, modelversion_label):
    """Download the selected TWDNEv3 checkpoint and render one image.

    Args:
        seed: Seed for numpy's ``default_rng`` (drives latent, noise and label).
        truncation_psi: Truncation value forwarded to ``generate_images``.
        modelversion_label: One of ``interface_modelversion_labels``; its
            5-digit iteration number selects the checkpoint.

    Returns:
        A single PIL image.

    Raises:
        ValueError: If the label does not contain a 5-digit iteration number.
        KeyError: If the ``MODEL_READING_TOKEN`` environment variable is unset.
    """
    match = _MODEL_ITERATION_PATTERN.search(modelversion_label)
    if match is None:
        raise ValueError(f"Unrecognized model version: {modelversion_label!r}")
    model_iteration = match.group(1)
    G = stylegan2.models.load(
        hf_hub_download(
            "hr16/Gwern-TWDNEv3-pytorch_ckpt",
            f"iteration-{model_iteration}/Gs.pth",
            use_auth_token=os.environ['MODEL_READING_TOKEN'],
        )
    )
    G.eval()
    # Replaces the ArgumentParser used in run_generator.py.
    generator_args = SimpleNamespace(
        truncation_psi=truncation_psi,
        seeds=[seed],
        batch_size=1,
        pixel_min=-1,
        pixel_max=1,
        gpu=[],
    )
    return generate_images(G, generator_args)[0]


title = "TWDNEv3 CPU Generator"
description = "Gradio Demo for TWDNEv3 CPU Generator (stylegan2_pytorch port)"
article = ""

gr.Interface(
    inference,
    [
        gr.Number(precision=0, label="PCG64 PRNG Seed (any-bit-size unsigned int, note that it may different from the original site)"),
        gr.Slider(0, 2, step=0.1, value=0.7, label='Truncation psi (aka creative level, between 0 and 2)'),
        gr.Radio(
            interface_modelversion_labels,
            value=interface_modelversion_labels[0],
            type="value",
            label="Model versions",
        ),
    ],
    gr.outputs.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    allow_screenshot=False,
).launch()