File size: 5,608 Bytes
954caab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
from .permutations import get_inv_perm
from .view_base import BaseView
class PatchPermuteView(BaseView):
    '''
    View that cuts an image into a `num_patches` x `num_patches` grid of
    square patches and shuffles them with a fixed random permutation.

    Patches are indexed in row-major order (index = row * num_patches + col),
    which is the order produced by the `(h w)` axis in `rearrange`.
    Images are assumed square: only the last dimension is used to size
    patches.
    '''

    def __init__(self, num_patches=8):
        '''
        Implements random patch permutations, with `num_patches`
        patches per side

        num_patches (int) :
            Number of patches in one dimension. Total number
            of patches will be num_patches**2. Should be a power of 2.

        The permutation is drawn once here with `torch.randperm`, so it
        depends on torch's global RNG state at construction time.
        '''
        # Check positivity first so that `64 % 0` cannot raise
        # ZeroDivisionError before the intended assertion message is shown.
        assert (num_patches > 0 and 64 % num_patches == 0
                and 256 % num_patches == 0), \
            "`num_patches` must divide image side lengths of 64 and 256"
        self.num_patches = num_patches

        # Get random permutation and inverse permutation
        self.perm = torch.randperm(self.num_patches**2)
        self.perm_inv = get_inv_perm(self.perm)

    def _patchify(self, im):
        '''
        Split a (c, H, W) image into a (num_patches**2, c, p, p) stack of
        patches in row-major order, where p = W // num_patches.
        '''
        patch_size = im.shape[-1] // self.num_patches
        return rearrange(im,
                         'c (h p1) (w p2) -> (h w) c p1 p2',
                         p1=patch_size,
                         p2=patch_size)

    def _unpatchify(self, patches):
        '''
        Inverse of `_patchify`: tile a row-major (num_patches**2, c, p, p)
        stack of patches back into a (c, H, W) image.
        '''
        return rearrange(patches,
                         '(h w) c p1 p2 -> c (h p1) (w p2)',
                         h=self.num_patches,
                         w=self.num_patches)

    def _permute_patches(self, im, perm):
        '''
        Return an image whose patch cell i holds patch `perm[i]` of `im`.
        '''
        return self._unpatchify(self._patchify(im)[perm])

    def view(self, im):
        '''
        Shuffle the patches of `im`: cell i of the result holds patch
        `self.perm[i]` of the input.

        im :
            Image of shape (c, H, W). The side length is assumed to be
            divisible by `num_patches`; `rearrange` raises otherwise.
        '''
        return self._permute_patches(im, self.perm)

    def inverse_view(self, noise):
        '''
        Undo `view`: cell j of the result holds patch `self.perm_inv[j]`
        of the input, so `inverse_view(view(x))` recovers `x`.

        noise :
            Image-shaped array of shape (c, H, W), same constraints as in
            `view`.
        '''
        return self._permute_patches(noise, self.perm_inv)

    def make_frame(self, im, t, canvas_size=384, scale=4, knot_seed=0):
        '''
        Render one frame of the patch-shuffling animation.

        Each patch travels along a quadratic Bezier curve. It starts at its
        own grid cell (t=0) and ends at the cell it occupies in `view(im)`
        (t=1). The control point ("knot") of each curve is the midpoint of
        the straight path, pushed sideways by a random amount. Patches
        therefore arc instead of sliding in straight lines.

        im (PIL.Image) :
            Input image, assumed square (only its width is read). It is
            centered on the canvas.
        t (float) :
            Animation time; 0 gives the original layout, 1 the permuted one.
        canvas_size (int) :
            Side length in pixels of the returned canvas.
        scale (int) :
            Scale is a hack, because PIL for some reason doesn't support
            pasting at floating point coordinates. So just render at larger
            scale and resize by 1/scale.
        knot_seed (int) :
            Seed for the random curve offsets. Reusing the same seed for
            every frame keeps the curves consistent across an animation.

        Returns an RGBA PIL image of side `canvas_size` with a white
        background.
        '''
        n = self.num_patches
        t = float(t)

        # Offset that centers the image on the canvas, then scale everything
        # up for supersampled rendering.
        im_size = im.size[0]
        offset = (canvas_size - im_size) // 2
        canvas_size = canvas_size * scale
        offset = offset * scale

        # Extract patches in the same row-major order that `view` uses
        im = TF.to_tensor(im)
        im_size = im.shape[-1]
        patch_size = im_size // n
        patches = self._patchify(im)

        # Top-left corner of every patch as (x, y), which is PIL's order
        rows, cols = np.meshgrid(np.arange(n), np.arange(n), indexing='ij')
        start_locs = np.stack([cols.ravel(), rows.ravel()], axis=1)
        start_locs = (start_locs * patch_size * scale + offset).astype(np.float64)

        # `view` moves patch perm[i] into cell i. That patch must therefore
        # end at cell i's corner for the final frame to match `view(im)`.
        end_locs = np.empty_like(start_locs)
        end_locs[np.asarray(self.perm)] = start_locs

        # A private RandomState seeded with `knot_seed` yields the same
        # numbers as seeding the global NumPy RNG, but leaves global state
        # untouched even if something below raises.
        rng = np.random.RandomState(knot_seed)
        rand_offsets = rng.rand(n**2, 1) * 2 - 1
        rand_offsets = rand_offsets * 2 * scale
        # Random jitter so the direction is defined (no divide by zero) for
        # patches whose start and end coincide
        eps = rng.randn(*start_locs.shape)

        # Make spline knots by taking the average of start and end, and
        # offsetting by some amount normal to the line
        avg_locs = (start_locs + end_locs) / 2.
        direction = end_locs - start_locs + eps
        direction = direction / np.linalg.norm(direction, axis=1, keepdims=True)
        # Row vector (x, y) @ rot_mat == (-y, x): a 90 degree rotation
        rot_mat = np.array([[0, 1], [-1, 0]])
        normal = direction @ rot_mat
        rand_offsets = rand_offsets * (im_size / 4)
        knot_locs = avg_locs + normal * rand_offsets

        # Quadratic Bezier via de Casteljau: start -> knot -> end
        spline_0 = start_locs * (1 - t) + knot_locs * t
        spline_1 = knot_locs * (1 - t) + end_locs * t
        # astype(int) truncates toward zero
        paste_locs = (spline_0 * (1 - t) + spline_1 * t).astype(int)

        # Paste patches onto a white canvas
        canvas = Image.new("RGBA", (canvas_size, canvas_size), (255,255,255,255))
        tile_size = patch_size * scale
        for patch, (x, y) in zip(patches, paste_locs):
            tile = TF.to_pil_image(patch).convert('RGBA')
            tile = tile.resize((tile_size, tile_size))
            # The tile's own alpha channel serves as the paste mask
            canvas.paste(tile, (int(x), int(y)), tile)

        if scale != 1.0:
            canvas = canvas.resize((canvas_size // scale, canvas_size // scale))
        return canvas
|