import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

import gradio as gr

from einops import rearrange, repeat

import torchvision.transforms.functional as ttf
from timm.models.convmixer import ConvMixer
import functorch

def img_to_patches(im, patch_h, patch_w):
    "B, C, H, W -> B, C, D, h_patch, w_patch"
    bs, c, h, w = im.shape
    # Unfold height (dim 2) and width (dim 3) into non-overlapping patches,
    # then flatten the patch grid into a single patch dimension D (row-major).
    im = im.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)
    im = im.contiguous().view(bs, c, -1, patch_h, patch_w)
    return im

def patches_to_img(patches, num_patch_h, num_patch_w):
    "B, C, D, h_patch, w_patch -> B, C, H, W"
    bs, c, d, h, w = patches.shape
    patches = patches.view(bs, c, num_patch_h, num_patch_w, h, w)
    # Fold the patch grid back into an image (inverse of `img_to_patches`).
    x = rearrange(patches, 'b c nh nw h w -> b c (nh h) (nw w)')
    return x
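
# A minimal round-trip sanity check (illustrative sketch, kept as comments so it
# does not run on import; sizes match the 100x100 / 2x2 demo configuration below):
#   x = torch.randn(1, 3, 100, 100)
#   p = img_to_patches(x, patch_h=2, patch_w=2)            # (1, 3, 2500, 2, 2)
#   y = patches_to_img(p, num_patch_h=50, num_patch_w=50)  # (1, 3, 100, 100)
#   assert torch.equal(x, y)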

def vmapped_rotate(x, angle, in_dims=1):
    "B, D, C, H, W -> B, D, C, H, W (rotates each (B, C, H, W) slice along the mapped dim)"
    rotate_ = functorch.vmap(ttf.rotate, in_dims=in_dims, out_dims=in_dims)
    return rotate_(x, angle=angle)
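
# Example (illustrative, commented out): rotate every patch of a (B, D, C, H, W)
# stack by 90 degrees, vmapping over the patch dimension D, as done in
# `CollageOperator2d.generate_candidates`:
#   patches = torch.randn(2, 4, 3, 8, 8)
#   rotated = vmapped_rotate(patches, angle=90)  # (2, 4, 3, 8, 8)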

class CollageOperator2d(nn.Module):

    def __init__(self, res, rh, rw, dh=None, dw=None, use_augmentations=False):
        """Collage Operator for two-dimensional data. Given a fractal code, it outputs the corresponding fixed-point.

        Args:
            res (int): Spatial resolutions of input (and output) data.
            rh (int): Height of range (target) square patches.
            rw (int): Width of range (target) square patches.
            dh (int, optional): Height of domain (source) patches. Defaults to `res`.
            dw (int, optional): Width of domain (source) patches. Defaults to `res`.
            use_augmentations (bool, optional): Use augmentations of domain square patches at each decoding iteration. Defaults to `False`.
        """
        super().__init__()
        self.dh, self.dw = dh, dw
        if self.dh is None: self.dh = res
        if self.dw is None: self.dw = res

        # `9` counts the extra copies of each domain patch generated by the current choice of augmentations:
        # 3 rotations (90, 180, 270), horizontal flip, vertical flip, and
        # brightness / contrast / hue / saturation shifts (see `generate_candidates`).
        self.n_aug_transforms = 9 if use_augmentations else 0

        # precompute useful quantities related to the partitioning scheme into patches, given
        # the desired `dh`, `dw`, `rh`, `rw`. 
        partition_info = self.collage_partition_info(res, self.dh, self.dw, rh, rw)
        self.n_dh, self.n_dw, self.n_rh, self.n_rw, self.h_factors, self.w_factors, self.n_domains, self.n_ranges = partition_info
        
        # At each step of the collage, all (source) domain patches are pooled down to the size of range (target) patches.
        # Notice how the pooling factors do not change if one decodes at higher resolutions, since both domain and range 
        # patch sizes are multiplied by the same integer.
        self.pool = nn.AvgPool3d(kernel_size=(1, self.h_factors, self.w_factors), stride=(1, self.h_factors, self.w_factors))
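
        # Example: with res=100, dh=dw=100, rh=rw=2 (the demo configuration below,
        # without augmentations), there is a single domain patch, 2500 range patches,
        # and 50x50 pooling factors. At superres_factor=2 the domains become 200x200
        # and the ranges 4x4, so the 50x50 pooling factors are unchanged.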

    def collage_operator(self, z, collage_weight, collage_bias):
        """Collage Operator (decoding). Performs the steps described in Def. 3.1, Figure 2.
        Equivalent to `decode_step` with `superres_factor=1`."""

        # Given the current iterate `z`, we split it into domain patches according to the partitioning scheme.
        domains = img_to_patches(z, patch_h=self.dh, patch_w=self.dw)

        # Pool domains (pre augmentation) to range patch sizes.
        pooled_domains = self.pool(domains) 

        # If needed, produce additional candidate domain patches as augmentations of existing domains.
        # Auxiliary learned feature maps / patches are also introduced here.
        if self.n_aug_transforms > 0:
            pooled_domains = self.generate_candidates(pooled_domains)

        pooled_domains = repeat(pooled_domains, 'b c d h w -> b c d r h w', r=self.n_ranges)

        # Apply the affine maps to domain patches
        range_domains = torch.einsum('bcdrhw, bcdr -> bcrhw', pooled_domains, collage_weight)
        range_domains = range_domains + collage_bias[..., None, None]

        # Reconstruct data by "composing" the output patches back together (collage!).
        z = patches_to_img(range_domains, self.n_rh, self.n_rw)

        return z

    def decode_step(self, z, weight, bias, superres_factor, return_patches=False):
        """Single Collage Operator step. Performs the steps described in:
        https://arxiv.org/pdf/2204.07673.pdf (Def. 3.1, Figure 2).
        """

        # Given the current iterate `z`, we split it into `n_domains` domain patches.
        domains = img_to_patches(z, patch_h=self.dh * superres_factor, patch_w=self.dw * superres_factor)

        # Pool domains (pre augmentation) for compatibility with range patches.
        pooled_domains = self.pool(domains) 

        # If needed, produce additional candidate domain patches as augmentations of existing domains.
        if self.n_aug_transforms > 0:
            pooled_domains = self.generate_candidates(pooled_domains)

        pooled_domains = repeat(pooled_domains, 'b c d h w -> b c d r h w', r=self.n_ranges)

        # Apply the affine maps to domain patches
        range_domains = torch.einsum('bcdrhw, bcdr -> bcrhw', pooled_domains, weight)
        range_domains = range_domains + bias[:, :, :, None, None]

        # Reconstruct data by "composing" the output patches back together (collage!).
        z = patches_to_img(range_domains, self.n_rh, self.n_rw)
        if return_patches: return z, (domains, pooled_domains, range_domains)
        return z
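
    # Shape walk-through of `decode_step` for the demo configuration
    # (res=100, dh=dw=100, rh=rw=2, no augmentations, superres_factor=1):
    #   z:              (B, 3, 100, 100)
    #   domains:        (B, 3, 1, 100, 100)
    #   pooled_domains: (B, 3, 1, 2, 2) -> repeated to (B, 3, 1, 2500, 2, 2)
    #   range_domains:  (B, 3, 2500, 2, 2)
    #   z (output):     (B, 3, 100, 100)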

    def generate_candidates(self, domains):
        # Move the patch dimension next to the batch so `vmapped_rotate` maps over
        # patches: (B, C, D, h, w) -> (B, D, C, h, w).
        domains = domains.permute(0, 2, 1, 3, 4)
        rotations = [vmapped_rotate(domains, angle=angle) for angle in (90, 180, 270)]
        hflips = ttf.hflip(domains)
        vflips = ttf.vflip(domains)
        br_shift = ttf.adjust_brightness(domains, 0.5)
        cr_shift = ttf.adjust_contrast(domains, 0.5)
        hue_shift = ttf.adjust_hue(domains, 0.5)
        sat_shift = ttf.adjust_saturation(domains, 0.5)
        # Concatenate the original patches with their 9 augmented copies along the patch dimension.
        domains = torch.cat([domains, *rotations, hflips, vflips, br_shift, cr_shift, hue_shift, sat_shift], dim=1)
        return domains.permute(0, 2, 1, 3, 4)

    def forward(self, x, co_w, co_bias, decode_steps=20, superres_factor=1):
        B, C, H, W = x.shape
        # It does not matter which initial condition is chosen, so long as the dimensions match.
        # The fixed-point of a Collage Operator is uniquely determined* by the fractal code
        # *: and auxiliary learned patches, if any.
        z = torch.randn(B, C, H * superres_factor, W * superres_factor).to(x.device)
        for _ in range(decode_steps):
            z = self.decode_step(z, co_w, co_bias, superres_factor)
        return z

    def collage_partition_info(self, input_res, dh, dw, rh, rw):
        """
        Computes auxiliary information for the collage (number of source and target domains, and relative size factors)
        """
        height = width = input_res
        n_dh, n_dw = height // dh, width // dw
        n_domains = n_dh * n_dw

        # Adjust the number of domain patches to include the augmented copies
        # (3 rotations, hflip, vflip, and brightness/contrast/hue/saturation shifts).
        n_domains = n_domains + n_domains * self.n_aug_transforms

        h_factors, w_factors = dh // rh, dw // rw
        n_rh, n_rw = input_res // rh, input_res // rw    
        n_ranges = n_rh * n_rw
        return n_dh, n_dw, n_rh, n_rw, h_factors, w_factors, n_domains, n_ranges
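
    # Example: collage_partition_info(100, 100, 100, 2, 2) with use_augmentations=True
    # (so self.n_aug_transforms == 9) yields n_dh = n_dw = 1, n_rh = n_rw = 50,
    # h_factors = w_factors = 50, n_domains = 10 (1 raw domain + 9 augmented copies)
    # and n_ranges = 2500.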


class NeuralCollageOperator2d(nn.Module):
    def __init__(self, out_res, out_channels, rh, rw, dh=None, dw=None, net=None, use_augmentations=False):
        super().__init__()
        self.co = CollageOperator2d(out_res, rh, rw, dh, dw, use_augmentations)
        # In a Collage Operator, the affine map requires a single scalar weight 
        # for each pair of domain and range patches, and a single scalar bias for each range.
        # `net` learns to output these weights based on the objective.
        self.co_w_dim = self.co.n_domains * self.co.n_ranges * out_channels
        self.co_bias_dim = self.co.n_ranges * out_channels
        tot_out_dim = self.co_w_dim + self.co_bias_dim

        # `net` does not need to be a ConvMixer: for deep generative Neural Collages it can be, e.g., a VDVAE.
        if net is None:
            net = ConvMixer(dim=32, depth=8, kernel_size=9, patch_size=7, num_classes=tot_out_dim)
        self.net = net

        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

    def forward(self, x, decode_steps=10, superres_factor=1, return_co_code=False):
        B, C, H, W = x.shape
        co_code = self.net(x)  # (B, co_w_dim + co_bias_dim)
        co_w, co_bias = torch.split(co_code, [self.co_w_dim, self.co_bias_dim], dim=-1)

        co_w = co_w.view(B, C, self.co.n_domains, self.co.n_ranges)
        # No restrictions on co_w, thus no guarantee of contractiveness.
        # In the full JAX version of Neural Collages we enforce the constraint |co_w| < 1 (elementwise).
        co_bias = co_bias.view(B, C, self.co.n_ranges)
        co_bias = self.tanh(co_bias)

        z = self.co(x, co_w, co_bias, decode_steps=decode_steps, superres_factor=superres_factor)
        
        if return_co_code: return z, co_w, co_bias
        else: return z
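
# Usage sketch (illustrative, kept as comments; `fractalize` below does the same
# and additionally fits the fractal code with gradient descent):
#   op = NeuralCollageOperator2d(out_res=100, out_channels=3, rh=2, rw=2, dh=100, dw=100)
#   x = torch.rand(1, 3, 100, 100)
#   z, co_w, co_bias = op(x, decode_steps=10, return_co_code=True)  # z: (1, 3, 100, 100)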



def fractalize(img, superresolution_factor=1):
    superresolution_factor = int(superresolution_factor)
    
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    im = np.asarray(img)

    im = torch.from_numpy(im).permute(2,0,1).to(device)
    co = NeuralCollageOperator2d(out_res=100, out_channels=3, rh=2, rw=2, dh=100, dw=100).to(device)

    opt = torch.optim.Adam(co.parameters(), lr=1e-2)
    objective = nn.MSELoss()
    norm_im = im.float().unsqueeze(0) / 255

    for _ in range(200):
        recon = co(norm_im, decode_steps=10, return_co_code=False)
        loss = objective(recon, norm_im)

        opt.zero_grad()
        loss.backward()
        opt.step()

    # Decode once more at the requested resolution; clamp to the valid [0, 1] image range.
    with torch.no_grad():
        fractal_img = co(norm_im, decode_steps=10, superres_factor=superresolution_factor)[0].permute(1, 2, 0).clamp(0, 1)
    return fractal_img.cpu().numpy()

demo = gr.Interface(
    fn=fractalize,
    inputs=[gr.Image(shape=(100, 100), image_mode='RGB'), gr.Slider(1, 40, step=1)],
    outputs="image"
)
demo.launch()