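"""Sample images from a pre-trained IC-GAN generator, conditioned on stored
instance features, and save the resulting grid of generations as a .png figure.

Example invocation (the script name and paths below are placeholders):
    python generate_images.py --root_path /path/to/pretrained_models \
        --model icgan --model_backbone biggan --resolution 256
"""
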
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], ".."))

import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import torch

import data_utils.utils as data_utils
import inference.utils as inference_utils
import BigGAN_PyTorch.utils as biggan_utils
from data_utils.datasets_common import pil_loader
import torchvision.transforms as transforms


def get_data(root_path, model, resolution, which_dataset, visualize_instance_images):
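    """Load the stored instance features for `which_dataset` at `resolution`.

    Returns the loaded data dict (instance features, labels and image paths)
    and, if `visualize_instance_images` is set, the transform used to display
    the ground-truth image of each conditioning instance.
    """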
    data_path = os.path.join(root_path, "stored_instances")
    if model == "cc_icgan":
        feature_extractor = "classification"
    else:
        feature_extractor = "selfsupervised"
    filename = "%s_res%i_rn50_%s_kmeans_k1000_instance_features.npy" % (
        which_dataset,
        resolution,
        feature_extractor,
    )

    data = np.load(os.path.join(data_path, filename), allow_pickle=True).item()

    transform_list = None
    if visualize_instance_images:
        transform_list = transforms.Compose(
            [data_utils.CenterCropLongEdge(), transforms.Resize(resolution)]
        )
    return data, transform_list


def get_model(exp_name, root_path, backbone, device="cuda"):
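    """Build the inference config for the given experiment name and load the
    pre-trained generator on `device`, set to evaluation mode.
    """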
    parser = biggan_utils.prepare_parser()
    parser = biggan_utils.add_sample_parser(parser)
    parser = inference_utils.add_backbone_parser(parser)

    args = ["--experiment_name", exp_name]
    args += ["--base_root", root_path]
    args += ["--model_backbone", backbone]

    config = vars(parser.parse_args(args=args))

    config = biggan_utils.update_config_roots(config, change_weight_folder=False)
    generator, config = inference_utils.load_model_inference(config, device=device)
    biggan_utils.count_parameters(generator)
    generator.eval()

    return generator


def get_conditionings(test_config, generator, data):
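    """Sample noise vectors and collect the instance features (and, where
    applicable, class labels) for each conditioning to generate from.
    """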
    # One noise vector per generated image, for every conditioning.
    z = torch.empty(
        test_config["num_imgs_gen"] * test_config["num_conditionings_gen"],
        generator.z_dim
        if test_config["model_backbone"] == "stylegan2"
        else generator.dim_z,
    ).normal_(mean=0, std=test_config["z_var"])

    # Pre-sample random instance indexes, so one is available even when
    # num_conditionings_gen == 1 and no --index was given.
    total_idxs = np.random.choice(
        range(1000), test_config["num_conditionings_gen"], replace=False
    )

    all_feats, all_img_paths, all_labels = [], [], []
    for counter in range(test_config["num_conditionings_gen"]):
        # Use the user-selected instance if given, else a randomly sampled one.
        if test_config["index"] is not None:
            idx = test_config["index"]
        else:
            idx = total_idxs[counter]

        if test_config["visualize_instance_images"]:
            all_img_paths.append(data["image_path"][idx])

        # Repeat the instance features once per image to generate.
        all_feats.append(
            torch.FloatTensor(data["instance_features"][idx : idx + 1]).repeat(
                test_config["num_imgs_gen"], 1
            )
        )

        # Either swap in a user-provided class label or keep the instance's own.
        if test_config["swap_target"] is not None:
            label_int = test_config["swap_target"]
        else:
            label_int = int(data["labels"][idx])

        labels = None
        if test_config["model_backbone"] == "stylegan2":
            # The StyleGAN2 backbone expects one-hot encoded class labels.
            dim_labels = 1000
            labels = torch.eye(dim_labels)[torch.LongTensor([label_int])].repeat(
                test_config["num_imgs_gen"], 1
            )
        else:
            if test_config["model"] == "cc_icgan":
                # The BigGAN backbone expects integer class labels.
                labels = torch.LongTensor([label_int]).repeat(
                    test_config["num_imgs_gen"]
                )
        all_labels.append(labels)

    all_feats = torch.cat(all_feats)
    if all_labels[0] is not None:
        all_labels = torch.cat(all_labels)
    else:
        all_labels = None
    return z, all_feats, all_labels, all_img_paths


def main(test_config):
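    """Load the model and stored instances, generate `num_imgs_gen` images for
    each of the `num_conditionings_gen` conditionings, and save them as a grid.
    """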
    suffix = (
        "_nofeataug"
        if test_config["resolution"] == 256
        and test_config["trained_dataset"] == "imagenet"
        else ""
    )
    exp_name = "%s_%s_%s_res%i%s" % (
        test_config["model"],
        test_config["model_backbone"],
        test_config["trained_dataset"],
        test_config["resolution"],
        suffix,
    )
    device = "cuda"

    data, transform_list = get_data(
        test_config["root_path"],
        test_config["model"],
        test_config["resolution"],
        test_config["which_dataset"],
        test_config["visualize_instance_images"],
    )

    generator = get_model(
        exp_name, test_config["root_path"], test_config["model_backbone"], device=device
    )

    z, all_feats, all_labels, all_img_paths = get_conditionings(
        test_config, generator, data
    )

    # Generate all images in mini-batches.
    all_generated_images = []
    with torch.no_grad():
        batch_size = test_config["batch_size"]
        # Ceiling division: cover a final smaller batch without ever
        # producing an empty one.
        num_batches = (z.shape[0] + batch_size - 1) // batch_size
        for i in range(num_batches):
            start = batch_size * i
            end = min(start + batch_size, z.shape[0])
            if all_labels is not None:
                labels_ = all_labels[start:end].to(device)
            else:
                labels_ = None
            gen_img = generator(
                z[start:end].to(device), labels_, all_feats[start:end].to(device)
            )
            # Map the generator output from [-1, 1] to integer [0, 255] images.
            if test_config["model_backbone"] == "biggan":
                gen_img = ((gen_img * 0.5 + 0.5) * 255).int()
            elif test_config["model_backbone"] == "stylegan2":
                gen_img = torch.clamp((gen_img * 127.5 + 128), 0, 255).int()
            all_generated_images.append(gen_img.cpu())
    all_generated_images = torch.cat(all_generated_images)
    all_generated_images = all_generated_images.permute(0, 2, 3, 1).numpy()
    # Arrange the generations in a grid: one row per conditioning,
    # num_imgs_gen images per row.
    big_plot = []
    for i in range(test_config["num_conditionings_gen"]):
        row = []
        for j in range(test_config["num_imgs_gen"]):
            subplot_idx = (i * test_config["num_imgs_gen"]) + j
            row.append(all_generated_images[subplot_idx])
        row = np.concatenate(row, axis=1)
        big_plot.append(row)
    big_plot = np.concatenate(big_plot, axis=0)

    # Optionally prepend the ground-truth instance images, separated from
    # the generations by a white margin.
    if test_config["visualize_instance_images"]:
        all_gt_imgs = []
        for i in range(len(all_img_paths)):
            all_gt_imgs.append(
                np.array(
                    transform_list(
                        pil_loader(
                            os.path.join(test_config["dataset_path"], all_img_paths[i])
                        )
                    )
                ).astype(np.uint8)
            )
        all_gt_imgs = np.concatenate(all_gt_imgs, axis=0)
        white_space = (
            np.ones((all_gt_imgs.shape[0], 20, all_gt_imgs.shape[2])) * 255
        ).astype(np.uint8)
        big_plot = np.concatenate([all_gt_imgs, white_space, big_plot], axis=1)

    plt.figure(
        figsize=(
            5 * test_config["num_imgs_gen"],
            5 * test_config["num_conditionings_gen"],
        )
    )
    plt.imshow(big_plot)
    plt.axis("off")

    fig_path = "%s_Generations_with_InstanceDataset_%s%s%s_zvar%0.2f.png" % (
        exp_name,
        test_config["which_dataset"],
        "_index" + str(test_config["index"])
        if test_config["index"] is not None
        else "",
        "_class_idx" + str(test_config["swap_target"])
        if test_config["swap_target"] is not None
        else "",
        test_config["z_var"],
    )
    plt.savefig(fig_path, dpi=600, bbox_inches="tight", pad_inches=0)

    print("Done! Figure saved as %s" % fig_path)


if __name__ == "__main__":
    parser = ArgumentParser(
        description="Generate and save images using pre-trained models"
    )

    parser.add_argument(
        "--root_path",
        type=str,
        required=True,
        help="Path where pre-trained models and instance features have been downloaded.",
    )
    parser.add_argument(
        "--which_dataset",
        type=str,
        default="imagenet",
        choices=["imagenet", "coco"],
        help="Dataset to sample instances from.",
    )
    parser.add_argument(
        "--trained_dataset",
        type=str,
        default="imagenet",
        choices=["imagenet", "coco"],
        help="Dataset the model was trained on.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="icgan",
        choices=["icgan", "cc_icgan"],
        help="Model type.",
    )
    parser.add_argument(
        "--model_backbone",
        type=str,
        default="biggan",
        choices=["biggan", "stylegan2"],
        help="Model backbone type.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help="Resolution to generate images at (default: %(default)s).",
    )
    parser.add_argument(
        "--z_var",
        type=float,
        default=1.0,
        help="Standard deviation of the sampled noise z (default: %(default)s).",
    )
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size.")
    parser.add_argument(
        "--num_imgs_gen",
        type=int,
        default=5,
        help="Number of images to generate with different noise vectors, "
        "given an input conditioning.",
    )
    parser.add_argument(
        "--num_conditionings_gen",
        type=int,
        default=5,
        help="Number of conditionings to generate with."
        " Use `num_imgs_gen` to control the number of generated samples per conditioning.",
    )
    parser.add_argument(
        "--index",
        type=int,
        default=None,
        help="Index of the stored instance to use as conditioning, in [0, 1000)."
        " Requires `num_conditionings_gen` to be set to 1.",
    )
    parser.add_argument(
        "--swap_target",
        type=int,
        default=None,
        help="For class-conditional IC-GAN, the conditioning class can be swapped"
        " for a different one. If swap_target is None, the original label of the"
        " instance is used; if it is in [0, 1000), that ImageNet class is used instead.",
    )

    parser.add_argument(
        "--visualize_instance_images",
        action="store_true",
        default=False,
        help="Also visualize the ground-truth image corresponding to each instance"
        " conditioning (requires a path to the ImageNet dataset).",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Only needed if visualize_instance_images=True."
        " Folder containing the dataset's ground-truth images.",
    )
    config = vars(parser.parse_args())

    if config["index"] is not None and config["num_conditionings_gen"] != 1:
        raise ValueError(
            "If a specific instance feature vector is selected with --index,"
            " num_conditionings_gen must be set to 1."
        )
    if config["swap_target"] is not None and config["model"] == "icgan":
        raise ValueError(
            'Cannot specify a class label for IC-GAN! Only use "swap_target"'
            " with --model=cc_icgan."
        )
    main(config)