File size: 2,724 Bytes
954caab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import pickle
from pathlib import Path

import torch
from torchvision.utils import save_image


def add_args(parser):
    """
    Register the sampling command-line options on `parser`

    parser :
        Object exposing `add_argument` (presumably an
        `argparse.ArgumentParser`); it is extended in place

    Returns the same `parser` object that was passed in.
    """

    # (flag, keyword arguments) pairs. They are registered in exactly this
    # order so the generated `--help` listing keeps its layout.
    option_specs = [
        ("--name", dict(required=True, type=str)),
        (
            "--save_dir",
            dict(
                type=str,
                default="results",
                help="Location to samples and metadata",
            ),
        ),
        (
            "--prompts",
            dict(
                required=True,
                type=str,
                nargs="+",
                help="Prompts to use, corresponding to each view.",
            ),
        ),
        (
            "--views",
            dict(
                required=True,
                type=str,
                nargs="+",
                help="Name of views to use. See `get_views` in `views.py`.",
            ),
        ),
        (
            "--style",
            dict(default="", type=str, help="Optional string to prepend prompt with"),
        ),
        ("--num_inference_steps", dict(type=int, default=100)),
        ("--num_samples", dict(type=int, default=100)),
        ("--reduction", dict(type=str, default="mean")),
        ("--seed", dict(type=int, default=0)),
        ("--guidance_scale", dict(type=float, default=7.0)),
        (
            "--noise_level",
            dict(type=int, default=50, help="Noise level for stage 2"),
        ),
        ("--device", dict(type=str, default="cpu")),
        (
            "--save_metadata",
            dict(
                action="store_true",
                help="If true, save metadata about the views. May use lots of disk space, particular for permutation views.",
            ),
        ),
    ]

    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser


def save_illusion(image, views, sample_dir):
    """
    Writes the illusion and a grid of its views as PNG files

    image (torch.tensor) :
        Tensor of shape (1,3,H,W) representing the image. Pixel values
        are presumably in [-1, 1]; they are rescaled with `x / 2 + 0.5`
        before saving

    views (iterable of views.BaseView) :
        View objects; each one's `view` method is applied to `image[0]`

    sample_dir (pathlib.Path) :
        Directory to write into. Two files are produced, named after the
        size of the last image dimension:
        `sample_{size}.png` and `sample_{size}.views.png`
    """

    size = image.shape[-1]
    illusion_path = sample_dir / f"sample_{size}.png"
    views_path = sample_dir / f"sample_{size}.views.png"

    # The illusion itself is written first, before any view is rendered
    save_image(image / 2.0 + 0.5, illusion_path, padding=0)

    # Render every view of the single (unbatched) image, then stack the
    # results into one batch so they are saved together in one file
    rendered = []
    for view in views:
        rendered.append(view.view(image[0]))
    im_views = torch.stack(rendered)

    save_image(im_views / 2.0 + 0.5, views_path, padding=0)


def save_metadata(views, args, save_dir):
    """
    Pickles the sampling metadata to `save_dir / "metadata.pkl"`

    The file contains a single dict with two entries:
        "views" : the `views` object as passed in
        "args"  : the `args` object as passed in

    save_dir must support the `/` operator with a string
    (e.g. a pathlib.Path). An existing metadata file is overwritten.
    """

    target = save_dir / "metadata.pkl"
    payload = dict(views=views, args=args)

    with open(target, "wb") as handle:
        pickle.dump(payload, handle)


def get_courier_font_path():
    """
    Returns, as a string, the path to `assets/CourierPrime-Regular.ttf`
    located next to this source file. The file's existence is not checked.
    """

    module_dir = Path(__file__).parent
    return str(module_dir.joinpath("assets", "CourierPrime-Regular.ttf"))