# callum-canavan's picture
# Add helpers, change to hot dog example
# 954caab
# raw
# history blame
# 5.61 kB
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
from .permutations import get_inv_perm
from .view_base import BaseView
class PatchPermuteView(BaseView):
    '''
    View that cuts an image into a `num_patches` x `num_patches` grid of
    square patches and reorders them with a fixed random permutation.

    The rearrange patterns below take the patch size from the image width
    and reassemble exactly `num_patches` patches per side, so images are
    expected to be square with a side divisible by `num_patches`.
    '''

    def __init__(self, num_patches=8):
        '''
        Implements random patch permutations, with `num_patches`
        patches per side

        num_patches (int) :
            Number of patches in one dimension. Total number
            of patches will be num_patches**2. Must evenly divide both
            64 and 256 (the image side lengths this view is used with),
            i.e. be one of 1, 2, 4, 8, 16, 32 or 64.
        '''
        assert 64 % num_patches == 0 and 256 % num_patches == 0, \
            "`num_patches` must divide image side lengths of 64 and 256"
        self.num_patches = num_patches

        # Random permutation of the row-major patch indices, drawn once here,
        # plus its inverse (used to undo the view).
        self.perm = torch.randperm(self.num_patches**2)
        self.perm_inv = get_inv_perm(self.perm)

    def _to_patches(self, im):
        '''
        Split `im` into its grid of square patches.

        im (torch.Tensor) :
            Image of shape (c, H, W). The patch size is derived from W.

        Returns (patches, patch_size): `patches` has shape
        (H // patch_size * W // patch_size, c, patch_size, patch_size), in
        row-major grid order (grid row outer, grid column inner).
        '''
        patch_size = im.shape[-1] // self.num_patches
        patches = rearrange(im,
                            'c (h p1) (w p2) -> (h w) c p1 p2',
                            p1=patch_size,
                            p2=patch_size)
        return patches, patch_size

    def _from_patches(self, patches, patch_size):
        '''
        Inverse of `_to_patches`: tile row-major patches of size
        (c, patch_size, patch_size) back into a single (c, H, W) image with
        `num_patches` patches per side.
        '''
        return rearrange(patches,
                         '(h w) c p1 p2 -> c (h p1) (w p2)',
                         h=self.num_patches,
                         w=self.num_patches,
                         p1=patch_size,
                         p2=patch_size)

    def _permute_patches(self, im, perm):
        '''
        Reorder the patches of `im` so that output patch j is input patch
        perm[j], and return the reassembled image.
        '''
        patches, patch_size = self._to_patches(im)
        return self._from_patches(patches[perm], patch_size)

    def view(self, im):
        '''
        Shuffle the patches of `im` with `self.perm`.

        im (torch.Tensor) :
            Image of shape (c, H, W).

        Returns a tensor of the same shape.
        '''
        return self._permute_patches(im, self.perm)

    def inverse_view(self, noise):
        '''
        Undo `view`: reorder the patches of `noise` with `self.perm_inv`.

        noise (torch.Tensor) :
            Tensor of shape (c, H, W).

        Returns a tensor of the same shape.
        '''
        return self._permute_patches(noise, self.perm_inv)

    def make_frame(self, im, t, canvas_size=384, scale=4, knot_seed=0):
        '''
        Render one frame of an animation in which every patch of `im` flies
        along a quadratic Bezier curve from its original grid cell to the
        cell it occupies in `self.view(im)`.

        Scale is a hack, because PIL for some reason doesn't support pasting
        at floating point coordinates. So just render at larger scale
        and resize by 1/scale

        im (PIL.Image) :
            Source image. Its width is used as the side length.
        t (float) :
            Animation time. t=0 places every patch at its original cell,
            t=1 at its permuted cell.
        canvas_size (int) :
            Side length in pixels of the returned square canvas.
        scale (int) :
            Supersampling factor used while pasting (see above).
        knot_seed (int) :
            Seed for the per-patch curve offsets, so that every frame of one
            animation follows the same paths.

        Returns a `canvas_size` x `canvas_size` RGBA PIL image.
        '''
        # Centre the image on the canvas. Everything below is computed in
        # supersampled (x `scale`) pixel coordinates.
        offset = (canvas_size - im.size[0]) // 2
        canvas_size = canvas_size * scale
        offset = offset * scale

        im = TF.to_tensor(im)
        im_size = im.shape[-1]
        patches, patch_size = self._to_patches(im)
        tile = patch_size * scale

        # Top-left corner of every patch, in the row-major order produced by
        # `_to_patches`: patch k sits at grid row k // n, column k % n.
        n = self.num_patches
        rows = torch.arange(n).repeat_interleave(n)
        cols = torch.arange(n).repeat(n)
        start_locs = torch.stack([cols, rows], dim=1) * tile + offset

        # `view` puts input patch perm[j] into cell j, so input patch k has
        # to finish in cell perm_inv[k] for the last frame to match
        # `self.view(im)`.
        end_locs = start_locs[self.perm_inv]

        # Per-patch random sideways offsets for the curve control points, and
        # a small jitter `eps` that keeps the direction vector non-zero for
        # patches whose start and end coincide. Drawn from a fixed seed; the
        # global NumPy RNG state is restored afterwards.
        original_state = np.random.get_state()
        try:
            np.random.seed(knot_seed)
            rand_offsets = np.random.rand(n**2, 1) * 2 - 1
            eps = np.random.randn(n**2, 2)
        finally:
            np.random.set_state(original_state)
        rand_offsets = torch.from_numpy(rand_offsets * 2 * scale * (im_size / 4))
        eps = torch.from_numpy(eps)

        # Control point ("knot"): the midpoint of start and end, pushed
        # perpendicular to the straight start->end line by the random offset.
        # Done explicitly in float64 torch instead of mixing tensors/ndarrays.
        start = start_locs.to(torch.float64)
        end = end_locs.to(torch.float64)
        avg_locs = (start + end) / 2
        direction = end - start + eps
        direction = direction / direction.norm(dim=1, keepdim=True)
        rot_mat = torch.tensor([[0., 1.], [-1., 0.]], dtype=torch.float64)
        normal = direction @ rot_mat
        knot_locs = avg_locs + normal * rand_offsets

        # Quadratic Bezier via de Casteljau: start -> knot -> end.
        spline_0 = start * (1 - t) + knot_locs * t
        spline_1 = knot_locs * (1 - t) + end * t
        paste_locs = (spline_0 * (1 - t) + spline_1 * t).to(int)

        # Paste patches onto a white canvas, using each patch's alpha as mask.
        canvas = Image.new("RGBA", (canvas_size, canvas_size), (255, 255, 255, 255))
        for patch, paste_loc in zip(patches, paste_locs):
            patch = TF.to_pil_image(patch).convert('RGBA')
            patch = patch.resize((tile, tile))
            paste_loc = (paste_loc[0].item(), paste_loc[1].item())
            canvas.paste(patch, paste_loc, patch)

        # Downsample from the supersampled canvas to the requested size.
        if scale != 1.0:
            canvas = canvas.resize((canvas_size // scale, canvas_size // scale))
        return canvas