callum-canavan committed · Commit 954caab · Parent(s): b4209f3
Add helpers, change to hot dog example
Browse files
- .gitignore +3 -0
- app.py +20 -4
- diffuse.py +34 -0
- generate.py +79 -0
- requirements.txt +165 -0
- visual_anagrams/__init__.py +0 -0
- visual_anagrams/samplers.py +232 -0
- visual_anagrams/utils.py +93 -0
- visual_anagrams/views/__init__.py +46 -0
- visual_anagrams/views/jigsaw_helpers.py +35 -0
- visual_anagrams/views/permutations.py +242 -0
- visual_anagrams/views/view_base.py +49 -0
- visual_anagrams/views/view_flip.py +30 -0
- visual_anagrams/views/view_identity.py +11 -0
- visual_anagrams/views/view_inner_circle.py +56 -0
- visual_anagrams/views/view_jigsaw.py +222 -0
- visual_anagrams/views/view_negate.py +41 -0
- visual_anagrams/views/view_patch_permute.py +154 -0
- visual_anagrams/views/view_permute.py +91 -0
- visual_anagrams/views/view_rotate.py +87 -0
- visual_anagrams/views/view_skew.py +55 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+env/
+__pycache__/
+assets/
app.py
CHANGED
@@ -1,9 +1,25 @@
 import gradio as gr
+from transformers import pipeline
 
-def greet(name):
-    return "Hello " + name + "!!"
+pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
 
 
+def predict(input_img):
+    predictions = pipeline(input_img)
+    return input_img, {p["label"]: p["score"] for p in predictions}
+
+
+gradio_app = gr.Interface(
+    predict,
+    inputs=gr.Image(
+        label="Select hot dog candidate", sources=["upload", "webcam"], type="pil"
+    ),
+    outputs=[
+        gr.Image(label="Processed Image"),
+        gr.Label(label="Result", num_top_classes=2),
+    ],
+    title="Hot Dog? Or Not?",
+)
+
+if __name__ == "__main__":
+    gradio_app.launch()
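For context on the dict comprehension in `predict`: a `transformers` image-classification pipeline returns a list of label/score dicts per image. A minimal sketch of the shape it assumes (the scores below are invented for illustration, not output from this model):

predictions = [
    {"label": "hot dog", "score": 0.97},      # made-up illustration values
    {"label": "not hot dog", "score": 0.03},
]
# predict() reshapes this into the mapping gr.Label expects:
result = {p["label"]: p["score"] for p in predictions}
# {"hot dog": 0.97, "not hot dog": 0.03}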
diffuse.py
ADDED
@@ -0,0 +1,34 @@
+from diffusers import DiffusionPipeline
+from diffusers.utils import pt_to_pil
+import torch
+
+# stage 1
+stage_1 = DiffusionPipeline.from_pretrained(
+    "DeepFloyd/IF-I-M-v1.0", variant="fp16", torch_dtype=torch.float16
+)
+stage_1.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
+stage_1.enable_model_cpu_offload()
+
+# stage 2
+stage_2 = DiffusionPipeline.from_pretrained(
+    "DeepFloyd/IF-II-M-v1.0",
+    text_encoder=None,
+    variant="fp16",
+    torch_dtype=torch.float16,
+)
+stage_2.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
+stage_2.enable_model_cpu_offload()
+
+# stage 3
+safety_modules = {
+    "feature_extractor": stage_1.feature_extractor,
+    "safety_checker": stage_1.safety_checker,
+    "watermarker": stage_1.watermarker,
+}
+stage_3 = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-x4-upscaler",
+    **safety_modules,
+    torch_dtype=torch.float16
+)
+stage_3.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
+stage_3.enable_model_cpu_offload()
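Note that diffuse.py only constructs the three DeepFloyd IF stages (and imports pt_to_pil) without sampling from them. A rough sketch of how these stages are usually chained with the standard diffusers API follows; the prompt and output filenames are assumptions for illustration, not part of this commit:

prompt = "a photo of a campfire"  # assumed example prompt
generator = torch.manual_seed(0)
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)

# Stage 1: 64x64 base image, returned as a torch tensor
image = stage_1(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=generator,
    output_type="pt",
).images

# Stage 2: 64x64 -> 256x256 super-resolution
image = stage_2(
    image=image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=generator,
    output_type="pt",
).images
pt_to_pil(image)[0].save("if_stage_II.png")  # pt_to_pil converts the tensor output to PIL

# Stage 3: x4 upscaling to 1024x1024 (this pipeline returns PIL images directly)
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
image[0].save("if_stage_III.png")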
generate.py
ADDED
@@ -0,0 +1,79 @@
+import argparse
+from pathlib import Path
+
+import torch
+from diffusers import DiffusionPipeline
+
+from visual_anagrams.views import get_views
+from visual_anagrams.samplers import sample_stage_1, sample_stage_2
+from visual_anagrams.utils import add_args, save_illusion, save_metadata
+
+
+
+# Parse args
+parser = argparse.ArgumentParser()
+parser = add_args(parser)
+args = parser.parse_args()
+
+# Do admin stuff
+save_dir = Path(args.save_dir) / args.name
+save_dir.mkdir(exist_ok=True, parents=True)
+
+# Make models
+stage_1 = DiffusionPipeline.from_pretrained(
+    "DeepFloyd/IF-I-M-v1.0",
+    variant="fp16",
+    torch_dtype=torch.float16)
+stage_2 = DiffusionPipeline.from_pretrained(
+    "DeepFloyd/IF-II-M-v1.0",
+    text_encoder=None,
+    variant="fp16",
+    torch_dtype=torch.float16,
+)
+stage_1.enable_model_cpu_offload()
+stage_2.enable_model_cpu_offload()
+stage_1 = stage_1.to(args.device)
+stage_2 = stage_2.to(args.device)
+
+# Get prompt embeddings
+prompt_embeds = [stage_1.encode_prompt(f'{args.style} {p}'.strip()) for p in args.prompts]
+prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds)
+prompt_embeds = torch.cat(prompt_embeds)
+negative_prompt_embeds = torch.cat(negative_prompt_embeds)  # These are just null embeds
+
+# Get views
+views = get_views(args.views)
+
+# Save metadata
+save_metadata(views, args, save_dir)
+
+# Sample illusions
+for i in range(args.num_samples):
+    # Admin stuff
+    generator = torch.manual_seed(args.seed + i)
+    sample_dir = save_dir / f'{i:04}'
+    sample_dir.mkdir(exist_ok=True, parents=True)
+
+    # Sample 64x64 image
+    image = sample_stage_1(stage_1,
+                           prompt_embeds,
+                           negative_prompt_embeds,
+                           views,
+                           num_inference_steps=args.num_inference_steps,
+                           guidance_scale=args.guidance_scale,
+                           reduction=args.reduction,
+                           generator=generator)
+    save_illusion(image, views, sample_dir)
+
+    # Sample 256x256 image, by upsampling 64x64 image
+    image = sample_stage_2(stage_2,
+                           image,
+                           prompt_embeds,
+                           negative_prompt_embeds,
+                           views,
+                           num_inference_steps=args.num_inference_steps,
+                           guidance_scale=args.guidance_scale,
+                           reduction=args.reduction,
+                           noise_level=args.noise_level,
+                           generator=generator)
+    save_illusion(image, views, sample_dir)
requirements.txt
ADDED
@@ -0,0 +1,165 @@
+absl-py==1.4.0
+aiohttp==3.8.5
+aiosignal==1.3.1
+annotated-types==0.5.0
+anyio==3.7.1
+argcomplete @ file:///private/tmp/python-argcomplete-20231112-5493-8o8e4p/argcomplete-3.1.6
+arrow==1.2.3
+astroid==2.15.6
+astunparse==1.6.3
+async-timeout==4.0.3
+attrs==23.1.0
+aws-cdk-lib==2.104.0
+aws-cdk.asset-awscli-v1==2.2.201
+aws-cdk.asset-kubectl-v20==2.1.2
+aws-cdk.asset-node-proxy-agent-v6==2.0.1
+backoff==2.2.1
+beautifulsoup4==4.12.2
+black==23.9.1
+blessed==1.20.0
+cachetools==5.3.1
+cattrs==23.1.2
+certifi==2022.12.7
+charset-normalizer==3.1.0
+click==8.1.7
+constructs==10.3.0
+contourpy==1.0.7
+croniter==1.4.1
+cycler==0.11.0
+dataclasses-json==0.6.1
+dateutils==0.6.12
+deepdiff==6.5.0
+diffusers==0.24.0
+dill==0.3.7
+distlib==0.3.6
+easydict==1.10
+fastapi==0.103.1
+filelock==3.9.0
+flatbuffers==23.5.26
+fonttools==4.39.3
+frozenlist==1.4.0
+fsspec==2023.9.0
+gast==0.4.0
+gitdb==4.0.10
+GitPython==3.1.36
+google-auth==2.22.0
+google-auth-oauthlib==1.0.0
+google-pasta==0.2.0
+grpcio==1.57.0
+h11==0.14.0
+h5py==3.9.0
+huggingface-hub==0.19.4
+idna==3.4
+importlib-metadata==6.9.0
+importlib-resources==6.1.0
+iniconfig==2.0.0
+inquirer==3.1.3
+isort==5.12.0
+itsdangerous==2.1.2
+Jinja2==3.1.2
+joblib==1.3.2
+jsii==1.91.0
+jsonpatch==1.33
+jsonpointer==2.4
+keras==2.13.1
+kiwisolver==1.4.4
+langchain==0.0.330
+langsmith==0.0.57
+lazy-object-proxy==1.9.0
+libclang==16.0.6
+lightning==2.0.8
+lightning-cloud==0.5.38
+lightning-utilities==0.9.0
+Markdown==3.4.4
+markdown-it-py==3.0.0
+MarkupSafe==2.1.2
+marshmallow==3.20.1
+matplotlib==3.7.2
+mccabe==0.7.0
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.0.4
+mypy-extensions==1.0.0
+networkx==3.1
+numpy==1.24.2
+oauthlib==3.2.2
+opencv-python==4.7.0.72
+opt-einsum==3.3.0
+ordered-set==4.1.0
+packaging==23.1
+pandas==2.0.3
+pathspec==0.11.2
+Pillow==9.5.0
+platformdirs==3.1.0
+pluggy==1.3.0
+protobuf==4.24.0
+psutil==5.9.5
+publication==0.0.3
+py-cpuinfo==9.0.0
+pyasn1==0.5.0
+pyasn1-modules==0.3.0
+pybind11==2.11.1
+pydantic==2.1.1
+pydantic_core==2.4.0
+Pygments==2.16.1
+PyJWT==2.8.0
+pylint==2.17.5
+pyparsing==3.0.9
+pytest==7.4.2
+python-dateutil==2.8.2
+python-dotenv==1.0.0
+python-editor==1.0.4
+python-multipart==0.0.6
+pytorch-lightning==2.0.8
+pytz==2023.3
+PyYAML==6.0.1
+readchar==4.0.5
+regex==2023.10.3
+requests==2.28.2
+requests-oauthlib==1.3.1
+rich==13.5.2
+rsa==4.9
+safetensors==0.4.1
+scikit-learn==1.3.0
+seaborn==0.12.2
+six==1.16.0
+smmap==5.0.0
+sniffio==1.3.0
+soupsieve==2.5
+SQLAlchemy==2.0.23
+starlette==0.27.0
+starsessions==1.3.0
+sympy==1.11.1
+tenacity==8.2.3
+tensorboard==2.13.0
+tensorboard-data-server==0.7.1
+tensorflow==2.13.0
+tensorflow-estimator==2.13.0
+termcolor==2.3.0
+threadpoolctl==3.2.0
+tokenizers==0.15.0
+tomlkit==0.12.1
+torch==2.0.1
+torchaudio==2.0.2
+torchmetrics==1.1.2
+torchvision==0.15.2
+tqdm==4.65.0
+traitlets==5.10.0
+transformers==4.35.2
+typeguard==2.13.3
+typing-inspect==0.9.0
+typing_extensions==4.6.1
+tzdata==2023.3
+ultralytics==8.0.178
+urllib3==1.26.15
+uvicorn==0.23.2
+virtualenv==20.20.0
+wcwidth==0.2.6
+websocket-client==1.6.3
+websockets==11.0.3
+Werkzeug==2.3.7
+wrapt==1.15.0
+yacs==0.1.8
+yarl==1.9.2
+yolov4==2.0.3
+zipp==3.17.0
visual_anagrams/__init__.py
ADDED
File without changes
visual_anagrams/samplers.py
ADDED
@@ -0,0 +1,232 @@
+from tqdm import tqdm
+
+import torch
+import torch.nn.functional as F
+
+from diffusers.utils.torch_utils import randn_tensor
+
+@torch.no_grad()
+def sample_stage_1(model,
+                   prompt_embeds,
+                   negative_prompt_embeds,
+                   views,
+                   num_inference_steps=100,
+                   guidance_scale=7.0,
+                   reduction='mean',
+                   generator=None):
+
+    # Params
+    num_images_per_prompt = 1
+    device = model.device
+    height = model.unet.config.sample_size
+    width = model.unet.config.sample_size
+    batch_size = 1  # TODO: Support larger batch sizes, maybe
+    num_prompts = prompt_embeds.shape[0]
+    assert num_prompts == len(views), \
+        "Number of prompts must match number of views!"
+
+    # For CFG
+    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+    # Setup timesteps
+    model.scheduler.set_timesteps(num_inference_steps, device=device)
+    timesteps = model.scheduler.timesteps
+
+    # Make intermediate_images
+    noisy_images = model.prepare_intermediate_images(
+        batch_size * num_images_per_prompt,
+        model.unet.config.in_channels,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+    )
+
+    for i, t in enumerate(tqdm(timesteps)):
+        # Apply views to noisy_image
+        viewed_noisy_images = []
+        for view_fn in views:
+            viewed_noisy_images.append(view_fn.view(noisy_images[0]))
+        viewed_noisy_images = torch.stack(viewed_noisy_images)
+
+        # Duplicate inputs for CFG
+        # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
+        model_input = torch.cat([viewed_noisy_images] * 2)
+        model_input = model.scheduler.scale_model_input(model_input, t)
+
+        # Predict noise estimate
+        noise_pred = model.unet(
+            model_input,
+            t,
+            encoder_hidden_states=prompt_embeds,
+            cross_attention_kwargs=None,
+            return_dict=False,
+        )[0]
+
+        # Extract uncond (neg) and cond noise estimates
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+
+        # Invert the unconditional (negative) estimates
+        inverted_preds = []
+        for pred, view in zip(noise_pred_uncond, views):
+            inverted_pred = view.inverse_view(pred)
+            inverted_preds.append(inverted_pred)
+        noise_pred_uncond = torch.stack(inverted_preds)
+
+        # Invert the conditional estimates
+        inverted_preds = []
+        for pred, view in zip(noise_pred_text, views):
+            inverted_pred = view.inverse_view(pred)
+            inverted_preds.append(inverted_pred)
+        noise_pred_text = torch.stack(inverted_preds)
+
+        # Split into noise estimate and variance estimates
+        noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
+        noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
+        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+        # Reduce predicted noise and variances
+        noise_pred = noise_pred.view(-1,num_prompts,3,64,64)
+        predicted_variance = predicted_variance.view(-1,num_prompts,3,64,64)
+        if reduction == 'mean':
+            noise_pred = noise_pred.mean(1)
+            predicted_variance = predicted_variance.mean(1)
+        elif reduction == 'alternate':
+            noise_pred = noise_pred[:,i%num_prompts]
+            predicted_variance = predicted_variance[:,i%num_prompts]
+        else:
+            raise ValueError('Reduction must be either `mean` or `alternate`')
+        noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+        # compute the previous noisy sample x_t -> x_t-1
+        noisy_images = model.scheduler.step(
+            noise_pred, t, noisy_images, generator=generator, return_dict=False
+        )[0]
+
+    # Return denoised images
+    return noisy_images
+
+
+
+
+
+
+
+@torch.no_grad()
+def sample_stage_2(model,
+                   image,
+                   prompt_embeds,
+                   negative_prompt_embeds,
+                   views,
+                   num_inference_steps=100,
+                   guidance_scale=7.0,
+                   reduction='mean',
+                   noise_level=50,
+                   generator=None):
+
+    # Params
+    batch_size = 1  # TODO: Support larger batch sizes, maybe
+    num_prompts = prompt_embeds.shape[0]
+    height = model.unet.config.sample_size
+    width = model.unet.config.sample_size
+    device = model.device
+    num_images_per_prompt = 1
+
+    # For CFG
+    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+    # Get timesteps
+    model.scheduler.set_timesteps(num_inference_steps, device=device)
+    timesteps = model.scheduler.timesteps
+
+    num_channels = model.unet.config.in_channels // 2
+    noisy_images = model.prepare_intermediate_images(
+        batch_size * num_images_per_prompt,
+        num_channels,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+    )
+
+    # Prepare upscaled image and noise level
+    image = model.preprocess_image(image, num_images_per_prompt, device)
+    upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
+
+    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
+    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
+    upscaled = model.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
+
+    # Condition on noise level, for each model input
+    noise_level = torch.cat([noise_level] * num_prompts * 2)
+
+    # Denoising Loop
+    for i, t in enumerate(tqdm(timesteps)):
+        # Cat noisy image with upscaled conditioning image
+        model_input = torch.cat([noisy_images, upscaled], dim=1)
+
+        # Apply views to noisy_image
+        viewed_inputs = []
+        for view_fn in views:
+            viewed_inputs.append(view_fn.view(model_input[0]))
+        viewed_inputs = torch.stack(viewed_inputs)
+
+        # Duplicate inputs for CFG
+        # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
+        model_input = torch.cat([viewed_inputs] * 2)
+        model_input = model.scheduler.scale_model_input(model_input, t)
+
+        # predict the noise residual
+        noise_pred = model.unet(
+            model_input,
+            t,
+            encoder_hidden_states=prompt_embeds,
+            class_labels=noise_level,
+            cross_attention_kwargs=None,
+            return_dict=False,
+        )[0]
+
+        # Extract uncond (neg) and cond noise estimates
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+
+        # Invert the unconditional (negative) estimates
+        # TODO: pretty sure you can combine these into one loop
+        inverted_preds = []
+        for pred, view in zip(noise_pred_uncond, views):
+            inverted_pred = view.inverse_view(pred)
+            inverted_preds.append(inverted_pred)
+        noise_pred_uncond = torch.stack(inverted_preds)
+
+        # Invert the conditional estimates
+        inverted_preds = []
+        for pred, view in zip(noise_pred_text, views):
+            inverted_pred = view.inverse_view(pred)
+            inverted_preds.append(inverted_pred)
+        noise_pred_text = torch.stack(inverted_preds)
+
+        # Split predicted noise and predicted variances
+        noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
+        noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
+        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+        # Combine noise estimates (and variance estimates)
+        noise_pred = noise_pred.view(-1,num_prompts,3,256,256)
+        predicted_variance = predicted_variance.view(-1,num_prompts,3,256,256)
+        if reduction == 'mean':
+            noise_pred = noise_pred.mean(1)
+            predicted_variance = predicted_variance.mean(1)
+        elif reduction == 'alternate':
+            noise_pred = noise_pred[:,i%num_prompts]
+            predicted_variance = predicted_variance[:,i%num_prompts]
+
+        noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+        # compute the previous noisy sample x_t -> x_t-1
+        noisy_images = model.scheduler.step(
+            noise_pred, t, noisy_images, generator=generator, return_dict=False
+        )[0]
+
+    # Return denoised images
+    return noisy_images
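Both samplers assume each view's inverse_view exactly undoes view, so that noise estimates predicted in view space can be stacked and reduced back in the canonical frame. A quick sanity-check sketch (assumes the visual_anagrams package in this commit is importable; the random tensor stands in for a stage-1 noise estimate):

import torch
from visual_anagrams.views import get_views

views = get_views(["identity", "flip"])
x = torch.randn(3, 64, 64)  # stand-in for a 64x64, 3-channel estimate
for v in views:
    assert torch.allclose(v.inverse_view(v.view(x)), x)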
visual_anagrams/utils.py
ADDED
@@ -0,0 +1,93 @@
+import pickle
+from pathlib import Path
+
+import torch
+from torchvision.utils import save_image
+
+
+def add_args(parser):
+    """
+    Add arguments for sampling to a parser
+    """
+
+    parser.add_argument("--name", required=True, type=str)
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="results",
+        help="Location to samples and metadata",
+    )
+    parser.add_argument(
+        "--prompts",
+        required=True,
+        type=str,
+        nargs="+",
+        help="Prompts to use, corresponding to each view.",
+    )
+    parser.add_argument(
+        "--views",
+        required=True,
+        type=str,
+        nargs="+",
+        help="Name of views to use. See `get_views` in `views.py`.",
+    )
+    parser.add_argument(
+        "--style", default="", type=str, help="Optional string to prepend prompt with"
+    )
+    parser.add_argument("--num_inference_steps", type=int, default=100)
+    parser.add_argument("--num_samples", type=int, default=100)
+    parser.add_argument("--reduction", type=str, default="mean")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--guidance_scale", type=float, default=7.0)
+    parser.add_argument(
+        "--noise_level", type=int, default=50, help="Noise level for stage 2"
+    )
+    parser.add_argument("--device", type=str, default="cpu")
+    parser.add_argument(
+        "--save_metadata",
+        action="store_true",
+        help="If true, save metadata about the views. May use lots of disk space, particular for permutation views.",
+    )
+
+    return parser
+
+
+def save_illusion(image, views, sample_dir):
+    """
+    Saves the illusion (`image`), as well as all views of the illusion
+
+    image (torch.tensor) :
+        Tensor of shape (1,3,H,W) representing the image
+
+    views (views.BaseView) :
+        Represents the view, inherits from BaseView
+
+    sample_dir (pathlib.Path) :
+        pathlib Path object, representing the directory to save to
+    """
+
+    size = image.shape[-1]
+
+    # Save illusion
+    save_image(image / 2.0 + 0.5, sample_dir / f"sample_{size}.png", padding=0)
+
+    # Save views of the illusion
+    im_views = torch.stack([view.view(image[0]) for view in views])
+    save_image(im_views / 2.0 + 0.5, sample_dir / f"sample_{size}.views.png", padding=0)
+
+
+def save_metadata(views, args, save_dir):
+    """
+    Saves the following the sample_dir
+        1) pickled view object
+        2) args for the illusion
+    """
+
+    metadata = {"views": views, "args": args}
+    with open(save_dir / "metadata.pkl", "wb") as f:
+        pickle.dump(metadata, f)
+
+
+def get_courier_font_path():
+    font_path = Path(__file__).parent / "assets" / "CourierPrime-Regular.ttf"
+    return str(font_path)
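save_metadata pickles the view objects together with the argparse namespace. A small sketch of reading that file back later (the results path is an assumption matching the default --save_dir plus an example --name; unpickling requires visual_anagrams to be importable):

import pickle
from pathlib import Path

save_dir = Path("results") / "my_illusion"  # assumed --save_dir / --name
with open(save_dir / "metadata.pkl", "rb") as f:
    metadata = pickle.load(f)

views = metadata["views"]  # the view objects used for sampling
args = metadata["args"]    # the argparse.Namespace with prompts, seed, etc.
print(args.prompts, [type(v).__name__ for v in views])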
visual_anagrams/views/__init__.py
ADDED
@@ -0,0 +1,46 @@
+from pathlib import Path
+from PIL import Image
+import numpy as np
+
+from .view_identity import IdentityView
+from .view_flip import FlipView
+from .view_rotate import Rotate180View, Rotate90CCWView, Rotate90CWView
+from .view_negate import NegateView
+from .view_skew import SkewView
+from .view_patch_permute import PatchPermuteView
+from .view_jigsaw import JigsawView
+from .view_inner_circle import InnerCircleView
+
+VIEW_MAP = {
+    'identity': IdentityView,
+    'flip': FlipView,
+    'rotate_cw': Rotate90CWView,
+    'rotate_ccw': Rotate90CCWView,
+    'rotate_180': Rotate180View,
+    'negate': NegateView,
+    'skew': SkewView,
+    'patch_permute': PatchPermuteView,
+    'pixel_permute': PatchPermuteView,
+    'jigsaw': JigsawView,
+    'inner_circle': InnerCircleView,
+}
+
+def get_views(view_names):
+    '''
+    Bespoke function to get views (just to make command line usage easier)
+    '''
+    views = []
+    for view_name in view_names:
+        if view_name == 'patch_permute':
+            args = [8]
+        elif view_name == 'pixel_permute':
+            args = [64]
+        elif view_name == 'skew':
+            args = [1.5]
+        else:
+            args = []
+
+        view = VIEW_MAP[view_name](*args)
+        views.append(view)
+
+    return views
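A small usage sketch of get_views, mirroring how generate.py calls it (the view names here are just examples): 'flip' takes no constructor arguments, while 'patch_permute' and 'skew' get the hard-coded arguments above.

from visual_anagrams.views import get_views

views = get_views(["identity", "flip"])
print([type(v).__name__ for v in views])  # ['IdentityView', 'FlipView']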
visual_anagrams/views/jigsaw_helpers.py
ADDED
@@ -0,0 +1,35 @@
+from pathlib import Path
+from PIL import Image
+import numpy as np
+
+def get_jigsaw_pieces(size):
+    '''
+    Load all pieces of the 4x4 jigsaw puzzle.
+
+    size (int) :
+        Should be 64 or 256, indicating side length of jigsaw puzzle
+    '''
+
+    # Location of pieces
+    piece_dir = Path(__file__).parent / 'assets'
+
+    # Helper function to load pieces as np arrays
+    def load_pieces(path):
+        '''
+        Load a piece, from the given path, as a binary numpy array.
+        Return a list of the "base" piece, and all four of its rotations.
+        '''
+        piece = Image.open(path)
+        piece = np.array(piece)[:,:,0] // 255
+        pieces = np.stack([np.rot90(piece, k=-i) for i in range(4)])
+        return pieces
+
+    # Load pieces and rotate to get 16 pieces, and cat
+    pieces_corner = load_pieces(piece_dir / f'4x4/4x4_corner_{size}.png')
+    pieces_inner = load_pieces(piece_dir / f'4x4/4x4_inner_{size}.png')
+    pieces_edge1 = load_pieces(piece_dir / f'4x4/4x4_edge1_{size}.png')
+    pieces_edge2 = load_pieces(piece_dir / f'4x4/4x4_edge2_{size}.png')
+    pieces = np.concatenate([pieces_corner, pieces_inner, pieces_edge1, pieces_edge2])
+
+    return pieces
+
visual_anagrams/views/permutations.py
ADDED
@@ -0,0 +1,242 @@
+from pathlib import Path
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from einops import rearrange, repeat
+
+from .jigsaw_helpers import get_jigsaw_pieces
+
+def get_inv_perm(perm):
+    '''
+    Get the inverse permutation of a permutation. That is, the array such that
+    perm[perm_inv] = perm_inv[perm] = arange(len(perm))
+
+    perm (torch.tensor) :
+        A 1-dimensional integer array, representing a permutation. Indicates
+        that element i should move to index perm[i]
+    '''
+    perm_inv = torch.empty_like(perm)
+    perm_inv[perm] = torch.arange(len(perm))
+    return perm_inv
+
+
+def make_inner_circle_perm(im_size=64, r=24):
+    '''
+    Makes permutations for "inner circle" view. Given size of image, and
+    `r`, the radius of the circle. We do this by iterating through every
+    pixel and figuring out where it should go.
+    '''
+    perm = []  # Permutation array
+
+    # Iterate through all positions, in order
+    for iy in range(im_size):
+        for ix in range(im_size):
+            # Get coordinates, with origin at (0, 0)
+            x = ix - im_size // 2 + 0.5
+            y = iy - im_size // 2 + 0.5
+
+            # Do 180 deg rotation if in circle
+            if x**2 + y**2 < r**2:
+                x = -x
+                y = -y
+
+            # Convert back to integer coordinates
+            x = int(x + im_size // 2 - 0.5)
+            y = int(y + im_size // 2 - 0.5)
+
+            # Append destination pixel index to permutation
+            perm.append(x + y * im_size)
+    perm = torch.tensor(perm)
+
+    return perm
+
+
+
+
+def make_jigsaw_perm(size, seed=0):
+    '''
+    Returns a permutation of pixels that is a jigsaw permutation
+
+    There are 3 types of pieces: corner, edge, and inner pieces. These were
+    created in MS Paint. They are all identical and laid out like:
+
+        c0 e0 f0 c1
+        f3 i0 i1 e1
+        e3 i3 i2 f1
+        c3 f2 e2 c2
+
+    where c is "corner," i is "inner," and "e" and "f" are "edges."
+    "e" and "f" pieces are identical, but labeled differently such that
+    to move any piece to the next index you can apply a 90 deg rotation.
+
+    Pieces c0, e0, f0, and i0 are defined by pngs, and will be loaded in. All
+    other pieces are obtained by 90 deg rotations of these "base" pieces.
+
+    Permutations are defined by:
+        1. permutation of corner (c) pieces (length 4 perm list)
+        2. permutation of inner (i) pieces (length 4 perm list)
+        3. permutation of edge (e) pieces (length 4 perm list)
+        4. permutation of edge (f) pieces (length 4 perm list)
+        5. list of four swaps, indicating swaps between e and f
+            edge pieces along the same edge (length 4 bit list)
+
+    Note these perm indexes will just be a "rotation index" indicating
+    how many 90 deg rotations to apply to the base pieces. The swaps
+    ensure that any edge piece can go to any edge piece, and are indexed
+    by the indexes of the "e" and "f" pieces on the edge.
+
+    Also note, order of indexes in permutation array is raster scan order. So,
+    go along x's first, then y's. This means y * size + x gives us the
+    1-D location in the permutation array. And image arrays are in
+    (y,x) order.
+
+    Plan of attack for making a pixel permutation array that represents
+    a jigsaw permutation:
+
+    1. Iterate through all pixels (in raster scan order)
+    2. Figure out which puzzle piece it is in initially
+    3. Look at the permutations, and see where it should go
+    4. Additionally, see if it's an edge piece, and needs to be swapped
+    5. Add the new (1-D) index to the permutation array
+
+    '''
+    np.random.seed(seed)
+
+    # Get location of puzzle pieces
+    piece_dir = Path(__file__).parent / 'assets'
+
+    # Get random permutations of groups of 4, and cat
+    identity = np.arange(4)
+    perm_corner = np.random.permutation(identity)
+    perm_inner = np.random.permutation(identity)
+    perm_edge1 = np.random.permutation(identity)
+    perm_edge2 = np.random.permutation(identity)
+    edge_swaps = np.random.randint(2, size=4)
+    piece_perms = np.concatenate([perm_corner, perm_inner, perm_edge1, perm_edge2])
+
+    # Get all 16 jigsaw pieces (in the order above)
+    pieces = get_jigsaw_pieces(size)
+
+    # Make permutation array to fill
+    perm = []
+
+    # For each pixel, figure out where it should go
+    for y in range(size):
+        for x in range(size):
+            # Figure out which piece (x,y) is in:
+            piece_idx = pieces[:,y,x].argmax()
+
+            # Figure out how many 90 deg rotations are on the piece
+            rot_idx = piece_idx % 4
+
+            # The perms tells us how many 90 deg rotations to apply to
+            # arrive at new pixel location
+            dest_rot_idx = piece_perms[piece_idx]
+            angle = (dest_rot_idx - rot_idx) * 90 / 180 * np.pi
+
+            # Center coordinates on origin
+            cx = x - (size - 1) / 2.
+            cy = y - (size - 1) / 2.
+
+            # Perform rotation
+            nx = np.cos(angle) * cx - np.sin(angle) * cy
+            ny = np.sin(angle) * cx + np.cos(angle) * cy
+
+            # Translate back and round coordinates to _nearest_ integer
+            nx = nx + (size - 1) / 2.
+            ny = ny + (size - 1) / 2.
+            nx = int(np.rint(nx))
+            ny = int(np.rint(ny))
+
+            # Perform swap if piece is an edge, and swap == 1 at NEW location
+            new_piece_idx = pieces[:,ny,nx].argmax()
+            edge_idx = new_piece_idx % 4
+            if new_piece_idx >= 8 and edge_swaps[edge_idx] == 1:
+                is_f_edge = (new_piece_idx - 8) // 4  # 1 if f, 0 if e edge
+                edge_type_parity = 1 - 2 * is_f_edge
+                rotation_parity = 1 - 2 * (edge_idx // 2)
+                swap_dist = size // 4
+
+                # if edge_idx is even, swap in x direction, else y
+                if edge_idx % 2 == 0:
+                    nx = nx + swap_dist * edge_type_parity * rotation_parity
+                else:
+                    ny = ny + swap_dist * edge_type_parity * rotation_parity
+
+            # append new index to permutation array
+            new_idx = int(ny * size + nx)
+            perm.append(new_idx)
+
+    # sanity check
+    #import matplotlib.pyplot as plt
+    #missing = sorted(set(range(size*size)).difference(set(perm)))
+    #asdf = np.zeros(size*size)
+    #asdf[missing] = 1
+    #plt.imshow(asdf.reshape(size,size))
+    #plt.savefig('tmp.png')
+    #plt.show()
+    #print(np.sum(asdf))
+
+    #viz = np.zeros((64,64))
+    #for idx in perm:
+    #    y, x = idx // 64, idx % 64
+    #    viz[y,x] = 1
+    #plt.imshow(viz)
+    #plt.savefig('tmp.png')
+    #Image.fromarray(viz * 255).convert('RGB').save('tmp.png')
+    #Image.fromarray(pieces_edge1[0] * 255).convert('RGB').save('tmp.png')
+
+    # sanity check on test image
+    #im = Image.open('results/flip.campfire.man/0000/sample_64.png')
+    #im = Image.open('results/flip.campfire.man/0000/sample_256.png')
+    #im = np.array(im)
+    #Image.fromarray(im.reshape(-1, 3)[perm].reshape(size,size,3)).save('test.png')
+
+    return torch.tensor(perm), (piece_perms, edge_swaps)
+
+#for i in range(100):
+    #make_jigsaw_perm(64, seed=i)
+#make_jigsaw_perm(256, seed=11)
+
+
+def recover_patch_permute(im_0, im_1, patch_size):
+    '''
+    Given two views of a patch permutation illusion, recover the patch
+    permutation used.
+
+    im_0 (PIL.Image) :
+        Identity view of the illusion
+
+    im_1 (PIL.Image) :
+        Patch permuted view of the illusion
+
+    patch_size (int) :
+        Size of the patches in the image
+    '''
+
+    # Convert to tensors
+    im_0 = TF.to_tensor(im_0)
+    im_1 = TF.to_tensor(im_1)
+
+    # Extract patches
+    patches_0 = rearrange(im_0,
+                          'c (h p1) (w p2) -> (h w) c p1 p2',
+                          p1=patch_size,
+                          p2=patch_size)
+    patches_1 = rearrange(im_1,
+                          'c (h p1) (w p2) -> (h w) c p1 p2',
+                          p1=patch_size,
+                          p2=patch_size)
+
+    # Repeat patches_1 for each patch in patches_0
+    patches_1_repeated = repeat(patches_1,
+                                'np c p1 p2 -> np1 np c p1 p2',
+                                np=patches_1.shape[0],
+                                np1=patches_1.shape[0],
+                                p1=patch_size,
+                                p2=patch_size)
+
+    # Find closest patch in other image by L1 dist, and return indexes
+    perm = (patches_1_repeated - patches_0[:,None]).abs().sum((2,3,4)).argmin(1)
+
+    return perm
ADDED
@@ -0,0 +1,49 @@
|
+class BaseView:
+    '''
+    BaseView class, from which all views inherit. Implements the
+    following functions:
+    '''
+
+    def __init__(self):
+        pass
+
+    def view(self, im):
+        '''
+        Apply transform to an image.
+
+        im (`torch.tensor`):
+            For stage 1: Tensor of shape (3, H, W) representing a noisy image
+                OR
+            For stage 2: Tensor of shape (6, H, W) representing a noisy image
+                concatenated with an upsampled conditioning image from stage 1
+        '''
+        raise NotImplementedError()
+
+    def inverse_view(self, noise):
+        '''
+        Apply inverse transform to noise estimates.
+        Because DeepFloyd estimates the variance in addition to
+        the noise, this function must apply the inverse to the
+        variance as well.
+
+        im (`torch.tensor`):
+            Tensor of shape (6, H, W) representing the noise estimate
+            (first three channel dims) and variance estimates (last
+            three channel dims)
+        '''
+        raise NotImplementedError()
+
+    def make_frame(self, im, t):
+        '''
+        Make a frame, transitioning linearly from the identity view (t=0)
+        to this view (t=1)
+
+        im (`PIL.Image`) :
+            A PIL Image of the illusion
+
+        t (float) :
+            A float in [0,1] indicating time in the animation. Should start
+            at the identity view at t=0, and continuously transition to the
+            view at t=1.
+        '''
+        raise NotImplementedError()
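BaseView's docstrings define the contract a view must satisfy: view transforms the (3, H, W) stage-1 input or (6, H, W) stage-2 input, and inverse_view must undo it on the stacked noise-and-variance estimate. A minimal sketch of a new view following that contract (this MirrorView class is a hypothetical illustration, not part of the commit):

import torch
from visual_anagrams.views.view_base import BaseView


class MirrorView(BaseView):
    '''Hypothetical example view: mirror the image left-to-right.'''

    def view(self, im):
        # im has shape (3, H, W) at stage 1 or (6, H, W) at stage 2;
        # flipping the last (width) dimension works for both.
        return torch.flip(im, [-1])

    def inverse_view(self, noise):
        # Mirroring is its own inverse, and it applies uniformly to the
        # noise channels (first 3) and variance channels (last 3).
        return torch.flip(noise, [-1])

make_frame is only needed for producing animations; left unimplemented, it would raise NotImplementedError as in the base class.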
visual_anagrams/views/view_flip.py
ADDED
@@ -0,0 +1,30 @@
+from PIL import Image
+
+import torch
+
+from .view_base import BaseView
+
+class FlipView(BaseView):
+    def __init__(self):
+        pass
+
+    def view(self, im):
+        return torch.flip(im, [1])
+
+    def inverse_view(self, noise):
+        return torch.flip(noise, [1])
+
+    def make_frame(self, im, t):
+        im_size = im.size[0]
+        frame_size = int(im_size * 1.5)
+        theta = t * 180
+
+        # TODO: Technically not a flip, change this to a homography later
+        frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
+        frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
+        frame = frame.rotate(theta,
+                             resample=Image.Resampling.BILINEAR,
+                             expand=False,
+                             fillcolor=(255,255,255))
+
+        return frame
ADDED
@@ -0,0 +1,11 @@
|
+from .view_base import BaseView
+
+class IdentityView(BaseView):
+    def __init__(self):
+        pass
+
+    def view(self, im):
+        return im
+
+    def inverse_view(self, noise):
+        return noise
visual_anagrams/views/view_inner_circle.py
ADDED
@@ -0,0 +1,56 @@
+from PIL import Image
+import numpy as np
+
+import torch
+import torchvision.transforms.functional as TF
+
+from .permutations import make_inner_circle_perm
+from .view_permute import PermuteView
+
+class InnerCircleView(PermuteView):
+    '''
+    Implements an "inner circle" view, where a circle inside the image spins
+    but the border stays still. Inherits from `PermuteView`, which implements
+    the `view` and `inverse_view` functions as permutations. We just make
+    the correct permutation here, and implement the `make_frame` method
+    for animation
+    '''
+    def __init__(self):
+        '''
+        Make the correct "inner circle" permutations and pass it to the
+        parent class constructor.
+        '''
+        self.perm_64 = make_inner_circle_perm(im_size=64, r=24)
+        self.perm_256 = make_inner_circle_perm(im_size=256, r=96)
+
+        super().__init__(self.perm_64, self.perm_256)
+
+    def make_frame(self, im, t):
+        im_size = im.size[0]
+        frame_size = int(im_size * 1.5)
+        theta = -t * 180
+
+        # Convert to tensor
+        im = torch.tensor(np.array(im) / 255.).permute(2,0,1)
+
+        # Get mask of circle (TODO: assuming size 256)
+        coords = torch.arange(0, 256) - 127.5
+        xx, yy = torch.meshgrid(coords, coords)
+        mask = xx**2 + yy**2 < (24*4)**2
+        mask = torch.stack([mask]*3).float()
+
+        # Get rotate image
+        im_rotated = TF.rotate(im, theta)
+
+        # Composite rotated circle + border together
+        im = im * (1 - mask) + im_rotated * mask
+
+        # Convert back to PIL
+        im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8))
+
+        # Paste on to canvas
+        frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
+        frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
+
+        return frame
+
visual_anagrams/views/view_jigsaw.py
ADDED
@@ -0,0 +1,222 @@
+import numpy as np
+from PIL import Image
+import torch
+from einops import einsum, rearrange
+
+from .permutations import make_jigsaw_perm, get_inv_perm
+from .view_permute import PermuteView
+from .jigsaw_helpers import get_jigsaw_pieces
+
+class JigsawView(PermuteView):
+    '''
+    Implements a 4x4 jigsaw puzzle view...
+    '''
+    def __init__(self, seed=11):
+        '''
+        '''
+        # Get pixel permutations, corresponding to jigsaw permutations
+        self.perm_64, _ = make_jigsaw_perm(64, seed=seed)
+        self.perm_256, (jigsaw_perm) = make_jigsaw_perm(256, seed=seed)
+
+        # keep track of jigsaw permutation as well
+        self.piece_perms, self.edge_swaps = jigsaw_perm
+
+        # Init parent PermuteView, with above pixel perms
+        super().__init__(self.perm_64, self.perm_256)
+
+    def extract_pieces(self, im):
+        '''
+        Given an image, extract jigsaw puzzle pieces from it
+
+        im (PIL.Image) :
+            PIL Image of the jigsaw illusion
+        '''
+        im = np.array(im)
+        size = im.shape[0]
+        pieces = []
+
+        # Get jigsaw pieces
+        piece_masks = get_jigsaw_pieces(size)
+
+        # Save pieces
+        for piece_mask in piece_masks:
+            # Add mask as alpha mask to image
+            im_piece = np.concatenate([im, piece_mask[:,:,None] * 255], axis=2)
+
+            # Get extents of piece, and crop
+            x_min = np.nonzero(im_piece[:,:,-1].sum(0))[0].min()
+            x_max = np.nonzero(im_piece[:,:,-1].sum(0))[0].max()
+            y_min = np.nonzero(im_piece[:,:,-1].sum(1))[0].min()
+            y_max = np.nonzero(im_piece[:,:,-1].sum(1))[0].max()
+            im_piece = im_piece[y_min:y_max+1, x_min:x_max+1]
+
+            pieces.append(Image.fromarray(im_piece))
+
+        return pieces
+
+
+    def paste_piece(self, piece, x, y, theta, xc, yc, canvas_size=384):
+        '''
+        Given a PIL Image of a piece, place it so that it's center is at
+        (x,y) and it's rotate about that center at theta degrees
+
+        x (float) : x coordinate to place piece at
+        y (float) : y coordinate to place piece at
+        theta (float) : degrees to rotate piece about center
+        xc (float) : x coordinate of center of piece
+        yc (float) : y coordinate of center of piece
+        '''
+
+        # Make canvas
+        canvas = Image.new("RGBA",
+                           (canvas_size, canvas_size),
+                           (255, 255, 255, 0))
+
+        # Past piece so center is at (x, y)
+        canvas.paste(piece, (x-xc,y-yc), piece)
+
+        # Rotate about (x, y)
+        canvas = canvas.rotate(theta, resample=Image.BILINEAR, center=(x, y))
+        return canvas
+
+
+    def make_frame(self, im, t, canvas_size=384, knot_seed=0):
+        '''
+        This function returns a PIL image of a frame animating a jigsaw
+        permutation. Pieces move and rotate from the identity view
+        (t = 0) to the rearranged view (t = 1) along splines.
+
+        The approach is as follows:
+
+        1. Extract all 16 pieces
+        2. Figure out start locations for each of these pieces (t=0)
+        3. Figure out how these pieces permute
+        4. Using these permutations, figure out end locations (t=1)
+        5. Make knots for splines, randomly offset normally from the
+            midpoint of the start and end locations
+        6. Paste pieces into correct locations, determined by
+            spline interpolation
+
+        im (PIL.Image) :
+            PIL image representing the jigsaw illusion
+
+        t (float) :
+            Interpolation parameter in [0,1] indicating what frame of the
+            animation to generate
+
+        canvas_size (int) :
+            Side length of the frame
+
+        knot_seed (int) :
+            Seed for random offsets for the knots
+        '''
+        im_size = im.size[0]
+
+        # Extract 16 jigsaw pieces
+        pieces = self.extract_pieces(im)
+
+        # Rotate all pieces to "base" piece orientation
+        pieces = [p.rotate(90 * (i % 4),
+                           resample=Image.BILINEAR,
+                           expand=1) for i, p in enumerate(pieces)]
+
+        # Get (hardcoded) start locations for each base piece, on a
+        # 4x4 grid centered on the origin.
+        corner_start_loc = np.array([-1.5, -1.5])
+        inner_start_loc = np.array([-0.5, -0.5])
+        edge_e_start_loc = np.array([-1.5, -0.5])
+        edge_f_start_loc = np.array([-1.5, 0.5])
+        base_start_locs = np.stack([corner_start_loc,
+                                    inner_start_loc,
+                                    edge_e_start_loc,
+                                    edge_f_start_loc])
+
+        # Construct all start locations by rotating around (0,0)
+        # by 90 degrees, 4 times, and concatenating the results
+        rot_mats = []
+        for theta in -np.arange(4) * 90 / 180 * np.pi:
+            rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
+                                [np.sin(theta), np.cos(theta)]])
+            rot_mats.append(rot_mat)
+        rot_mats = np.stack(rot_mats)
+        start_locs = einsum(base_start_locs, rot_mats,
+                            'start i, rot j i -> start rot j')
+        start_locs = rearrange(start_locs,
+                               'start rot j -> (start rot) j')
+
+        # Add rotation information to start locations
+        thetas = np.tile(np.arange(4) * -90, 4)[:, None]
+        start_locs = np.concatenate([start_locs, thetas], axis=1)
+
+        # Get explicit permutation of pieces from permutation metadata
+        perm = self.piece_perms + np.repeat(np.arange(4), 4) * 4
+        for edge_idx, to_swap in enumerate(self.edge_swaps):
+            if to_swap:
+                # Make swap permutation array
+                swap_perm = np.arange(16)
+                swap_perm[8 + edge_idx], swap_perm[12 + edge_idx] = \
+                    swap_perm[12 + edge_idx], swap_perm[8 + edge_idx]
+
+                # Apply swap permutation after perm
+                perm = np.array([swap_perm[perm[i]] for i in range(16)])
+
+        # Get inverse perm (the actual permutation needed)...
+        perm_inv = get_inv_perm(torch.tensor(perm))
+
+        # ...and use it to get the final locations of pieces
+        end_locs = start_locs[perm_inv]
+
+        # Convert start and end locations to pixel coordinate system
+        start_locs[:,:2] = (start_locs[:,:2] + 2) * 64
+        end_locs[:,:2] = (end_locs[:,:2] + 2) * 64
+
+        # Add offset so pieces are centered on canvas
+        start_locs[:,:2] = start_locs[:,:2] + (canvas_size - im_size) // 2
+        end_locs[:,:2] = end_locs[:,:2] + (canvas_size - im_size) // 2
+
+        # Get random offsets from middle for spline knot (so path is pretty)
+        # Wrapped in a set seed
+        original_state = np.random.get_state()
+        np.random.seed(knot_seed)
+        rand_offsets = np.random.rand(16, 1) * 2 - 1
+        rand_offsets = rand_offsets * 2
+        eps = np.random.randn(16, 2)  # Add epsilon for divide by zero
+        np.random.set_state(original_state)
+
+        # Make spline knots by taking average of start and end,
+        # and offsetting by some amount normal from the line
+        avg_locs = (start_locs[:, :2] + end_locs[:, :2]) / 2.
+        norm = (end_locs[:, :2] - start_locs[:, :2])
+        norm = norm + eps
+        norm = norm / np.linalg.norm(norm, axis=1, keepdims=True)
+        rot_mat = np.array([[0,1], [-1,0]])
+        norm = norm @ rot_mat
+        rand_offsets = rand_offsets * (im_size / 4)
+        knot_locs = avg_locs + norm * rand_offsets
+
+        # Paste pieces on to a canvas
+        canvas = Image.new("RGBA", (canvas_size, canvas_size), (255,255,255,255))
+        for i in range(16):
+            # Get start and end coords
+            y_0, x_0, theta_0 = start_locs[i]
+            y_1, x_1, theta_1 = end_locs[i]
+            y_k, x_k = knot_locs[i]
+
+            # Take spline interpolation for x and y
+            x_int_0 = x_0 * (1-t) + x_k * t
+            y_int_0 = y_0 * (1-t) + y_k * t
+            x_int_1 = x_k * (1-t) + x_1 * t
+            y_int_1 = y_k * (1-t) + y_1 * t
+            x = int(np.round(x_int_0 * (1-t) + x_int_1 * t))
+            y = int(np.round(y_int_0 * (1-t) + y_int_1 * t))
+
+            # Just take normal interpolation for theta
+            theta = int(np.round(theta_0 * (1-t) + theta_1 * t))
+
+            # Get piece in location and rotation
+            xc = yc = im_size // 4 // 2
+            pasted_piece = self.paste_piece(pieces[i], x, y, theta, xc, yc)
+
+            canvas.paste(pasted_piece, (0,0), pasted_piece)
+
+        return canvas
visual_anagrams/views/view_negate.py
ADDED
@@ -0,0 +1,41 @@
+from PIL import Image
+import numpy as np
+
+import torch
+
+from .view_base import BaseView
+
+class NegateView(BaseView):
+    def __init__(self):
+        pass
+
+    def view(self, im):
+        return -im
+
+    def inverse_view(self, noise):
+        '''
+        Negating the variance estimate is "weird" so just don't do it.
+        This hack seems to work just fine
+        '''
+        invert_mask = torch.ones_like(noise)
+        invert_mask[:3] = -1
+        return noise * invert_mask
+
+    def make_frame(self, im, t):
+        im_size = im.size[0]
+        frame_size = int(im_size * 1.5)
+
+        # map t from [0, 1] -> [1, -1]
+        t = 1 - t
+        t = t * 2 - 1
+
+        # Interpolate from pixels from [0, 1] to [1, 0]
+        im = np.array(im) / 255.
+        im = ((2 * im - 1) * t + 1) / 2.
+        im = Image.fromarray((im * 255.).astype(np.uint8))
+
+        # Paste on to canvas
+        frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
+        frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
+
+        return frame
visual_anagrams/views/view_patch_permute.py
ADDED
@@ -0,0 +1,154 @@
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from einops import rearrange
+
+from .permutations import get_inv_perm
+from .view_base import BaseView
+
+
+class PatchPermuteView(BaseView):
+    def __init__(self, num_patches=8):
+        '''
+        Implements random patch permutations, with `num_patches`
+        patches per side.
+
+        num_patches (int) :
+            Number of patches in one dimension. Total number
+            of patches will be num_patches**2. Should be a power of 2.
+        '''
+
+        assert 64 % num_patches == 0 and 256 % num_patches == 0, \
+            "`num_patches` must divide image side lengths of 64 and 256"
+
+        self.num_patches = num_patches
+
+        # Get random permutation and inverse permutation
+        self.perm = torch.randperm(self.num_patches**2)
+        self.perm_inv = get_inv_perm(self.perm)
+
+    def view(self, im):
+        im_size = im.shape[-1]
+
+        # Get number of pixels on one side of a patch
+        patch_size = int(im_size / self.num_patches)
+
+        # Reshape into patches of size (c, patch_size, patch_size)
+        patches = rearrange(im,
+                            'c (h p1) (w p2) -> (h w) c p1 p2',
+                            p1=patch_size,
+                            p2=patch_size)
+
+        # Permute
+        patches = patches[self.perm]
+
+        # Reshape back into image
+        im_rearr = rearrange(patches,
+                             '(h w) c p1 p2 -> c (h p1) (w p2)',
+                             h=self.num_patches,
+                             w=self.num_patches,
+                             p1=patch_size,
+                             p2=patch_size)
+        return im_rearr
+
+    def inverse_view(self, noise):
+        im_size = noise.shape[-1]
+
+        # Get number of pixels on one side of a patch
+        patch_size = int(im_size / self.num_patches)
+
+        # Reshape into patches of size (c, patch_size, patch_size)
+        patches = rearrange(noise,
+                            'c (h p1) (w p2) -> (h w) c p1 p2',
+                            p1=patch_size,
+                            p2=patch_size)
+
+        # Apply inverse permutation
+        patches = patches[self.perm_inv]
+
+        # Reshape back into image
+        im_rearr = rearrange(patches,
+                             '(h w) c p1 p2 -> c (h p1) (w p2)',
+                             h=self.num_patches,
+                             w=self.num_patches,
+                             p1=patch_size,
+                             p2=patch_size)
+        return im_rearr
+
+    def make_frame(self, im, t, canvas_size=384, scale=4, knot_seed=0):
+        '''
+        `scale` is a hack: PIL does not support pasting at floating point
+        coordinates, so render at a larger scale and then resize
+        by 1/scale at the end.
+        '''
+        # Get useful info
+        im_size = im.size[0]
+        offset = (canvas_size - im_size) // 2  # offset to center animation
+
+        canvas_size = canvas_size * scale
+        offset = offset * scale
+
+        im = TF.to_tensor(im)
+
+        # Get number of pixels on one side of a patch
+        im_size = im.shape[-1]
+        patch_size = int(im_size / self.num_patches)
+
+        # Extract patches
+        patches = rearrange(im,
+                            'c (h p1) (w p2) -> (h w) c p1 p2',
+                            p1=patch_size,
+                            p2=patch_size)
+
+        # Get start locations (top left corner of patch)
+        yy, xx = torch.meshgrid(
+            torch.arange(self.num_patches),
+            torch.arange(self.num_patches)
+        )
+        xx = xx.flatten()
+        yy = yy.flatten()
+        start_locs = torch.stack([xx, yy], dim=1) * patch_size * scale
+        start_locs = start_locs + offset
+
+        # Get end locations by permuting
+        end_locs = start_locs[self.perm]
+
+        # Get random anchor locations
+        original_state = np.random.get_state()
+        np.random.seed(knot_seed)
+        rand_offsets = np.random.rand(self.num_patches**2, 1) * 2 - 1
+        rand_offsets = rand_offsets * 2 * scale
+        eps = np.random.randn(*start_locs.shape)  # Add epsilon to avoid division by zero
+        np.random.set_state(original_state)
+
+        # Make spline knots by taking average of start and end,
+        # and offsetting by some amount normal from the line
+        avg_locs = (start_locs + end_locs) / 2.
+        norm = (end_locs - start_locs)
+        norm = norm + eps
+        norm = norm / np.linalg.norm(norm, axis=1, keepdims=True)
+        rot_mat = np.array([[0, 1], [-1, 0]])
+        norm = norm @ rot_mat
+        rand_offsets = rand_offsets * (im_size / 4)
+        knot_locs = avg_locs + norm * rand_offsets
+
+        # Get paste locations
+        spline_0 = start_locs * (1 - t) + knot_locs * t
+        spline_1 = knot_locs * (1 - t) + end_locs * t
+        paste_locs = spline_0 * (1 - t) + spline_1 * t
+        paste_locs = paste_locs.to(int)
+
+        # Paste patches onto canvas
+        canvas = Image.new("RGBA", (canvas_size, canvas_size), (255, 255, 255, 255))
+        for patch, paste_loc in zip(patches, paste_locs):
+            patch = TF.to_pil_image(patch).convert('RGBA')
+            patch = patch.resize((patch_size * scale, patch_size * scale))
+            paste_loc = (paste_loc[0].item(), paste_loc[1].item())
+            canvas.paste(patch, paste_loc, patch)
+
+        if scale != 1.0:
+            canvas = canvas.resize((canvas_size // scale, canvas_size // scale))
+
+        return canvas
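Because `view` and `inverse_view` above only index patches with a permutation and its inverse, the round trip is exact. A self-contained check of that property (illustrative, not from this commit; `torch.argsort` stands in for `get_inv_perm` here):

import torch
from einops import rearrange

num_patches, patch_size = 8, 8                 # 64x64 image, 8x8 patch grid
im = torch.randn(3, 64, 64)
perm = torch.randperm(num_patches**2)
perm_inv = torch.argsort(perm)                 # inverse permutation

# Shuffle patches, then un-shuffle them
patches = rearrange(im, 'c (h p1) (w p2) -> (h w) c p1 p2', p1=patch_size, p2=patch_size)
shuffled = patches[perm]
restored = rearrange(shuffled[perm_inv], '(h w) c p1 p2 -> c (h p1) (w p2)',
                     h=num_patches, w=num_patches)
assert torch.equal(restored, im)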
visual_anagrams/views/view_permute.py
ADDED
@@ -0,0 +1,91 @@
+import torch
+from einops import rearrange
+
+from .permutations import get_inv_perm
+from .view_base import BaseView
+
+class PermuteView(BaseView):
+    def __init__(self, perm_64, perm_256):
+        '''
+        Implements arbitrary pixel permutations for a given permutation.
+        Two permutations are needed: one for the 64x64 images of stage 1,
+        and one for the 256x256 images of stage 2.
+
+        perm_64 (torch.tensor) :
+            Tensor of integer indexes, defining a permutation, of size 64*64
+
+        perm_256 (torch.tensor) :
+            Tensor of integer indexes, defining a permutation, of size 256*256
+        '''
+
+        assert perm_64.shape == torch.Size([64*64]), \
+            "`perm_64` must be a permutation tensor of size 64*64"
+
+        assert perm_256.shape == torch.Size([256*256]), \
+            "`perm_256` must be a permutation tensor of size 256*256"
+
+        # Get permutation and inverse permutation for stage 1
+        self.perm_64 = perm_64
+        self.perm_64_inv = get_inv_perm(self.perm_64)
+
+        # Get permutation and inverse permutation for stage 2
+        self.perm_256 = perm_256
+        self.perm_256_inv = get_inv_perm(self.perm_256)
+
+    def view(self, im):
+        im_size = im.shape[-1]
+        perm = self.perm_64 if im_size == 64 else self.perm_256
+        num_patches = im_size
+
+        # Permute every pixel in the image
+        patch_size = 1
+
+        # Reshape into patches of size (c, patch_size, patch_size)
+        patches = rearrange(im,
+                            'c (h p1) (w p2) -> (h w) c p1 p2',
+                            p1=patch_size,
+                            p2=patch_size)
+
+        # Permute
+        patches = patches[perm]
+
+        # Reshape back into image
+        im_rearr = rearrange(patches,
+                             '(h w) c p1 p2 -> c (h p1) (w p2)',
+                             h=num_patches,
+                             w=num_patches,
+                             p1=patch_size,
+                             p2=patch_size)
+        return im_rearr
+
+    def inverse_view(self, noise):
+        im_size = noise.shape[-1]
+        perm_inv = self.perm_64_inv if im_size == 64 else self.perm_256_inv
+        num_patches = im_size
+
+        # Permute every pixel in the image
+        patch_size = 1
+
+        # Reshape into patches of size (c, patch_size, patch_size)
+        patches = rearrange(noise,
+                            'c (h p1) (w p2) -> (h w) c p1 p2',
+                            p1=patch_size,
+                            p2=patch_size)
+
+        # Apply inverse permutation
+        patches = patches[perm_inv]
+
+        # Reshape back into image
+        im_rearr = rearrange(patches,
+                             '(h w) c p1 p2 -> c (h p1) (w p2)',
+                             h=num_patches,
+                             w=num_patches,
+                             p1=patch_size,
+                             p2=patch_size)
+        return im_rearr
+
+    def make_frame(self, im, t):
+        # TODO: Implement this, as just moving pixels around
+        raise NotImplementedError()
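`PermuteView` expects one flat index tensor per resolution used by the two IF stages. A hypothetical construction (not part of this commit) showing how such permutations and their inverses can be produced:

import torch

perm_64 = torch.randperm(64 * 64)      # pixel permutation for 64x64 stage 1 images
perm_256 = torch.randperm(256 * 256)   # pixel permutation for 256x256 stage 2 images
# view = PermuteView(perm_64, perm_256)   # hypothetical usage

# The inverse of a permutation can always be recovered with argsort
perm_64_inv = torch.argsort(perm_64)
assert torch.equal(perm_64[perm_64_inv], torch.arange(64 * 64))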
visual_anagrams/views/view_rotate.py
ADDED
@@ -0,0 +1,87 @@
+from PIL import Image
+
+import torchvision.transforms.functional as TF
+from torchvision.transforms import InterpolationMode
+
+from .view_base import BaseView
+
+
+class Rotate90CWView(BaseView):
+    def __init__(self):
+        pass
+
+    def view(self, im):
+        # TODO: Is nearest-exact better?
+        return TF.rotate(im, -90, interpolation=InterpolationMode.NEAREST)
+
+    def inverse_view(self, noise):
+        return TF.rotate(noise, 90, interpolation=InterpolationMode.NEAREST)
+
+    def make_frame(self, im, t):
+        im_size = im.size[0]
+        frame_size = int(im_size * 1.5)
+        theta = t * -90
+
+        frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
+        centered_loc = (frame_size - im_size) // 2
+        frame.paste(im, (centered_loc, centered_loc))
+        frame = frame.rotate(theta,
+                             resample=Image.Resampling.BILINEAR,
+                             expand=False,
+                             fillcolor=(255, 255, 255))
+
+        return frame
+
+
+class Rotate90CCWView(BaseView):
+    def __init__(self):
+        pass
+
+    def view(self, im):
+        # TODO: Is nearest-exact better?
+        return TF.rotate(im, 90, interpolation=InterpolationMode.NEAREST)
+
+    def inverse_view(self, noise):
+        return TF.rotate(noise, -90, interpolation=InterpolationMode.NEAREST)
+
+    def make_frame(self, im, t):
+        im_size = im.size[0]
+        frame_size = int(im_size * 1.5)
+        theta = t * 90
+
+        frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
+        centered_loc = (frame_size - im_size) // 2
+        frame.paste(im, (centered_loc, centered_loc))
+        frame = frame.rotate(theta,
+                             resample=Image.Resampling.BILINEAR,
+                             expand=False,
+                             fillcolor=(255, 255, 255))
+
+        return frame
+
+
+class Rotate180View(BaseView):
+    def __init__(self):
+        pass
+
+    def view(self, im):
+        # TODO: Is nearest-exact better?
+        return TF.rotate(im, 180, interpolation=InterpolationMode.NEAREST)
+
+    def inverse_view(self, noise):
+        return TF.rotate(noise, -180, interpolation=InterpolationMode.NEAREST)
+
+    def make_frame(self, im, t):
+        im_size = im.size[0]
+        frame_size = int(im_size * 1.5)
+        theta = t * 180
+
+        frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
+        centered_loc = (frame_size - im_size) // 2
+        frame.paste(im, (centered_loc, centered_loc))
+        frame = frame.rotate(theta,
+                             resample=Image.Resampling.BILINEAR,
+                             expand=False,
+                             fillcolor=(255, 255, 255))
+
+        return frame
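For square images, a 90 or 180 degree rotation simply moves pixels around, which is why `inverse_view` can undo `view` without loss. A quick illustration of that property, using `torch.rot90` rather than the `TF.rotate` call above (sketch only, not from this commit):

import torch

im = torch.randn(3, 64, 64)
rotated = torch.rot90(im, k=-1, dims=(-2, -1))     # 90 degrees clockwise
restored = torch.rot90(rotated, k=1, dims=(-2, -1))
assert torch.equal(restored, im)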
visual_anagrams/views/view_skew.py
ADDED
@@ -0,0 +1,55 @@
+from PIL import Image
+import numpy as np
+
+import torch
+
+from .view_base import BaseView
+
+
+class SkewView(BaseView):
+    def __init__(self, skew_factor=1.5):
+        self.skew_factor = skew_factor
+
+    def skew_image(self, im, skew_factor):
+        '''
+        Roll each column of the image by increasing displacements.
+        This is a permutation of pixels.
+        '''
+
+        # Params
+        c, h, w = im.shape
+        h_center = h // 2
+
+        # Roll columns
+        cols = []
+        for i in range(w):
+            d = int(skew_factor * (i - h_center))  # Displacement
+            col = im[:, :, i]
+            cols.append(col.roll(d, dims=1))
+
+        # Stack rolled columns
+        skewed = torch.stack(cols, dim=2)
+        return skewed
+
+    def view(self, im):
+        return self.skew_image(im, self.skew_factor)
+
+    def inverse_view(self, noise):
+        return self.skew_image(noise, -self.skew_factor)
+
+    def make_frame(self, im, t):
+        im_size = im.size[0]
+        frame_size = int(im_size * 1.5)
+        skew_factor = t * self.skew_factor
+
+        # Convert to tensor, skew, then convert back to PIL
+        im = torch.tensor(np.array(im) / 255.).permute(2, 0, 1)
+        im = self.skew_image(im, skew_factor)
+        im = Image.fromarray((np.array(im.permute(1, 2, 0)) * 255.).astype(np.uint8))
+
+        # Paste onto canvas
+        frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
+        frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
+
+        return frame
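Since `int()` truncates toward zero, `int(-s * (i - c)) == -int(s * (i - c))`, so skewing with `-skew_factor` rolls every column back by exactly the displacement that `view` applied. A small round-trip check (illustrative only, mirroring the `skew_image` logic above):

import torch

def skew(im, skew_factor):
    # Roll each column by a displacement proportional to its distance from center
    c, h, w = im.shape
    cols = [im[:, :, i].roll(int(skew_factor * (i - h // 2)), dims=1) for i in range(w)]
    return torch.stack(cols, dim=2)

im = torch.randn(3, 64, 64)
assert torch.equal(skew(skew(im, 1.5), -1.5), im)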