File size: 5,608 Bytes
954caab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange

from .permutations import get_inv_perm
from .view_base import BaseView


class PatchPermuteView(BaseView):
    '''
    View that cuts an image into a `num_patches` x `num_patches` grid of
        square patches and shuffles them with a fixed random permutation,
        drawn once when the view is constructed
    '''

    def __init__(self, num_patches=8):
        '''
        Implements random patch permutations, with `num_patches`
            patches per side

        num_patches (int) :
            Number of patches in one dimension. Total number
            of patches will be num_patches**2. Must divide both 64 and
            256, i.e. for positive values a power of 2 no larger than 64.
        '''

        assert 64 % num_patches == 0 and 256 % num_patches == 0, \
            "`num_patches` must divide image side lengths of 64 and 256"

        self.num_patches = num_patches

        # Get random permutation and inverse permutation
        self.perm = torch.randperm(self.num_patches**2)
        self.perm_inv = get_inv_perm(self.perm)

    def _permute_patches(self, im, perm):
        '''
        Split `im` into patches, reorder them with `perm`, and stitch the
            result back into an image of the same shape

        im :
            Array of shape (c, side, side), where `side` is a multiple
            of `num_patches`
        perm :
            Index tensor of length num_patches**2. Grid cell `j` of the
            output (row-major order) receives patch `perm[j]` of the input.
        '''
        # Get number of pixels on one side of a patch
        patch_size = im.shape[-1] // self.num_patches

        # Reshape into patches of size (c, patch_size, patch_size),
        # ordered row-major over the grid
        patches = rearrange(im,
                            'c (h p1) (w p2) -> (h w) c p1 p2',
                            p1=patch_size,
                            p2=patch_size)

        # Reorder the patches
        patches = patches[perm]

        # Reshape back into image
        return rearrange(patches,
                         '(h w) c p1 p2 -> c (h p1) (w p2)',
                         h=self.num_patches,
                         w=self.num_patches,
                         p1=patch_size,
                         p2=patch_size)

    def view(self, im):
        '''
        Apply the patch permutation to `im`, an array of shape
            (c, side, side), and return the shuffled image
        '''
        return self._permute_patches(im, self.perm)

    def inverse_view(self, noise):
        '''
        Apply the inverse patch permutation to `noise`, an array of shape
            (c, side, side), undoing the shuffle performed by `view`
        '''
        return self._permute_patches(noise, self.perm_inv)

    def make_frame(self, im, t, canvas_size=384, scale=4, knot_seed=0):
        '''
        Render one frame of the animation that moves every patch of `im`
            from its original grid cell to the cell it occupies in
            `view(im)`. Each patch follows a quadratic Bezier curve
            through a randomly offset knot.

        Scale is a hack, because PIL for some reason doesn't support pasting
            at floating point coordinates. So just render at larger scale
            and resize by 1/scale

        im (PIL.Image) :
            Square image to animate
        t (float) :
            Animation parameter. At t=0 the patches sit at their original
            locations, at t=1 at their permuted locations.
        canvas_size (int) :
            Side length, in pixels, of the returned frame. The image
            is centered on the canvas.
        scale (int) :
            Supersampling factor used while pasting patches
        knot_seed (int) :
            Seed for the random knot offsets. Use the same seed for all
            frames of one animation so the paths stay fixed.

        Returns an RGBA PIL.Image of size (canvas_size, canvas_size)
        '''
        # Get useful info
        im_size = im.size[0]
        offset = (canvas_size - im_size) // 2  # offset to center animation

        canvas_size = canvas_size * scale
        offset = offset * scale

        im = TF.to_tensor(im)

        # Get number of pixels on one side of a patch
        patch_size = im_size // self.num_patches

        # Extract patches, ordered row-major over the grid
        patches = rearrange(im,
                            'c (h p1) (w p2) -> (h w) c p1 p2',
                            p1=patch_size,
                            p2=patch_size)

        # Get start locations (top left corner of patch), as (x, y) pairs
        # in the same row-major order as `patches`. 'ij' is the historical
        # default; passing it explicitly avoids the deprecation warning.
        yy, xx = torch.meshgrid(
                        torch.arange(self.num_patches),
                        torch.arange(self.num_patches),
                        indexing='ij'
                    )
        xx = xx.flatten()
        yy = yy.flatten()
        start_locs = torch.stack([xx, yy], dim=1) * patch_size * scale
        start_locs = start_locs + offset

        # Get end locations. `view` places patch `perm[j]` in cell `j`,
        # so patch `k` ends up in cell `perm_inv[k]`; index with the
        # inverse permutation so the last frame matches `view(im)`.
        end_locs = start_locs[self.perm_inv]

        # Get random anchor locations from a private generator, so the
        # global numpy random state is never touched
        rng = np.random.RandomState(knot_seed)
        rand_offsets = rng.rand(self.num_patches**2, 1) * 2 - 1
        rand_offsets = rand_offsets * 2 * scale
        # Patches that do not move have start == end, which would give a
        # zero-length direction; perturb to avoid dividing by zero
        eps = rng.randn(*start_locs.shape)

        # Make spline knots by taking average of start and end,
        # and offsetting by some amount normal from the line
        # NOTE(review): the arithmetic below mixes torch tensors with numpy
        # arrays and `.to(int)` further down expects a tensor result;
        # confirm this still holds when upgrading torch or numpy
        avg_locs = (start_locs + end_locs) / 2.
        norm = (end_locs - start_locs)
        norm = norm + eps
        norm = norm / np.linalg.norm(norm, axis=1, keepdims=True)
        rot_mat = np.array([[0,1], [-1,0]])     # 90 degree rotation
        norm = norm @ rot_mat
        rand_offsets = rand_offsets * (im_size / 4)
        knot_locs = avg_locs + norm * rand_offsets

        # Get paste locations, by evaluating the quadratic Bezier curve
        # start -> knot -> end at parameter t
        spline_0 = start_locs * (1 - t) + knot_locs * t
        spline_1 = knot_locs * (1 - t) + end_locs * t
        paste_locs = spline_0 * (1 - t) + spline_1 * t
        paste_locs = paste_locs.to(int)

        # Paste patches onto canvas
        canvas = Image.new("RGBA", (canvas_size, canvas_size), (255,255,255,255))
        for patch, (x, y) in zip(patches, paste_locs.tolist()):
            patch = TF.to_pil_image(patch).convert('RGBA')
            patch = patch.resize((patch_size * scale, patch_size * scale))
            canvas.paste(patch, (x, y), patch)

        # Undo the supersampling
        if scale != 1.0:
            canvas = canvas.resize((canvas_size // scale, canvas_size // scale))

        return canvas