|
import pickle |
|
from pathlib import Path |
|
|
|
import torch |
|
from torchvision.utils import save_image |
|
|
|
|
|
def add_args(parser):
    """
    Add arguments for sampling to a parser.

    parser (argparse.ArgumentParser) :
        Parser (or any object with a compatible ``add_argument``) to which
        the sampling arguments are added. It is modified in place.

    Returns the same ``parser`` object, so calls can be chained.
    """

    # Run identification and output location.
    parser.add_argument("--name", required=True, type=str)
    parser.add_argument(
        "--save_dir",
        type=str,
        default="results",
        help="Location to save samples and metadata",
    )

    # What to generate: the prompts are paired with the views.
    parser.add_argument(
        "--prompts",
        required=True,
        type=str,
        nargs="+",
        help="Prompts to use, corresponding to each view.",
    )
    parser.add_argument(
        "--views",
        required=True,
        type=str,
        nargs="+",
        help="Name of views to use. See `get_views` in `views.py`.",
    )
    parser.add_argument(
        "--style", default="", type=str, help="Optional string to prepend prompt with"
    )

    # Sampling hyper-parameters.
    parser.add_argument("--num_inference_steps", type=int, default=100)
    parser.add_argument("--num_samples", type=int, default=100)
    parser.add_argument("--reduction", type=str, default="mean")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    parser.add_argument(
        "--noise_level", type=int, default=50, help="Noise level for stage 2"
    )

    # Execution and output options.
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument(
        "--save_metadata",
        action="store_true",
        help="If true, save metadata about the views. May use lots of disk space, particularly for permutation views.",
    )

    return parser
|
|
|
|
|
def save_illusion(image, views, sample_dir):
    """
    Saves the illusion (`image`), as well as all views of the illusion.

    image (torch.tensor) :
        Tensor of shape (1,3,H,W) representing the image. Pixel values are
        mapped with ``x / 2 + 0.5`` before saving, so they are presumably in
        [-1, 1]. Only ``image[0]`` is passed to the views.

    views (iterable of views.BaseView) :
        The views of the illusion. Each one must provide ``view(im)``, which
        takes a single image ``image[0]``. The results are stacked, so every
        view must return a tensor of the same shape.

    sample_dir (pathlib.Path or str) :
        Directory to save to. Two files are written into it:
        ``sample_{size}.png`` (the illusion) and ``sample_{size}.views.png``
        (all views side by side). Here ``size`` is the last dimension of
        ``image``.
    """
    # Accept plain strings as well as Path objects.
    sample_dir = Path(sample_dir)
    size = image.shape[-1]

    # Save the illusion itself, rescaled from [-1, 1] to [0, 1] for save_image.
    save_image(image / 2.0 + 0.5, sample_dir / f"sample_{size}.png", padding=0)

    # Apply every view to the (single) image and save them as one grid.
    im_views = torch.stack([view.view(image[0]) for view in views])
    save_image(im_views / 2.0 + 0.5, sample_dir / f"sample_{size}.views.png", padding=0)
|
|
|
|
|
def save_metadata(views, args, save_dir):
    """
    Pickle the views and the run arguments to ``save_dir / "metadata.pkl"``.

    The pickled object is a dict with two keys:

    1) ``"views"`` : the view object(s) used for the illusion
    2) ``"args"``  : the arguments of the illusion

    views :
        View object(s) used for the illusion; must be picklable.

    args :
        Arguments for the illusion (e.g. an ``argparse.Namespace``); must be
        picklable.

    save_dir (pathlib.Path or str) :
        Existing directory to write ``metadata.pkl`` into. An existing
        ``metadata.pkl`` there is overwritten.

    Note: reading the file back requires ``pickle.load``, which can execute
    arbitrary code, so only load metadata files from trusted sources.
    """
    metadata = {"views": views, "args": args}
    # Accept plain strings as well as Path objects.
    metadata_path = Path(save_dir) / "metadata.pkl"
    with open(metadata_path, "wb") as f:
        pickle.dump(metadata, f)
|
|
|
|
|
def get_courier_font_path():
    """
    Return the path, as a string, of the bundled Courier Prime font.

    The font is expected at ``assets/CourierPrime-Regular.ttf``, relative
    to the directory containing this module. The file's existence is not
    checked.
    """
    module_dir = Path(__file__).parent
    return str(module_dir.joinpath("assets", "CourierPrime-Regular.ttf"))
|
|