import torch
import torch.nn as nn
import numpy as np

TO_WARN_USER_ONCE = True

class AbstractPermuter(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, reverse=False):
        raise NotImplementedError

class Identity(AbstractPermuter):
    def __init__(self):
        super().__init__()

    def forward(self, x, reverse=False):
        return x

class ColumnMajor(AbstractPermuter):
    '''Useful for spectrograms, which are laid out as (features, time) and
    should be traversed column by column, i.e. one time frame at a time.'''

    def __init__(self, H, W):
        super().__init__()
        self.H = H
        self.W = W
        idx = self.make_idx(H, W)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        B, L = x.shape
        L_idx = len(self.forward_shuffle_idx)
        if L > L_idx:
            # Patch for "infinite" sampling: the incoming sequence is longer
            # than the cached indices, so guess the temporal scale W_scale and
            # rebuild the indices for a proportionally wider W. This branch is
            # only triggered during sampling, never during training.
            assert L % L_idx == 0, f'L: {L}, L_idx: {L_idx}'
            W_scale = L // L_idx
            idx = self.make_idx(self.H, self.W * W_scale)
            if not reverse:
                return x[:, idx]
            else:
                return x[:, torch.argsort(idx)]
        else:
            if not reverse:
                return x[:, self.forward_shuffle_idx]
            else:
                return x[:, self.backward_shuffle_idx]

    def make_idx(self, H, W):
        # transposing the (H, W) grid of row-major positions and flattening
        # it yields the column-major visiting order
        idx = np.arange(H * W).reshape(H, W)
        idx = idx.T
        idx = torch.tensor(idx.ravel())
        return idx
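
# A minimal illustration (added; the helper name is hypothetical, not part of
# the original module): for a 2x3 grid, row-major positions 0..5 are visited
# column by column, so the forward index is [0, 3, 1, 4, 2, 5].
def _demo_column_major():
    perm = ColumnMajor(2, 3)
    print(perm.forward_shuffle_idx)   # tensor([0, 3, 1, 4, 2, 5])
    x = torch.arange(6).unsqueeze(0)  # (B=1, L=6)
    assert torch.equal(perm(perm(x), reverse=True), x)  # round trip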

class Subsample(AbstractPermuter):
    '''Nested-block subsampling order: cells are grouped by their relative
    position inside recursively halved 2x2 blocks, so the sequence covers
    strided subgrids of the image before refining. Requires H == W to be a
    power of two (enforced by the assert below).'''

    def __init__(self, H, W):
        super().__init__()
        C = 1
        indices = np.arange(H*W).reshape(C, H, W)
        while min(H, W) > 1:
            indices = indices.reshape(C, H//2, 2, W//2, 2)
            indices = indices.transpose(0, 2, 4, 1, 3)
            indices = indices.reshape(C*4, H//2, W//2)
            H = H//2
            W = W//2
            C = C*4
        assert H == W == 1
        idx = torch.tensor(indices.ravel())
        # register plain tensors; wrapping them in nn.Parameter is unnecessary
        # for buffers
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
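
# A visualization sketch (added; _demo_subsample is a hypothetical helper).
# Entry (i, j) of the printed map is the step at which that cell is visited.
# Hand-tracing a 4x4 grid, the stride-2 subgrid at offset (0, 0) comes first,
# then the subgrids at offsets (0, 1), (1, 0) and (1, 1):
#     [[ 0,  4,  1,  5],
#      [ 8, 12,  9, 13],
#      [ 2,  6,  3,  7],
#      [10, 14, 11, 15]]
def _demo_subsample():
    perm = Subsample(4, 4)
    rank = torch.empty(16, dtype=torch.long)
    rank[perm.forward_shuffle_idx] = torch.arange(16)
    print(rank.reshape(4, 4))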

def mortonify(i, j):
    """Interleave the bits of (i, j) into a linear Morton code: bit `pos` of
    j lands at even position 2*pos, bit `pos` of i at odd position 2*pos+1."""
    i = np.uint64(i)
    j = np.uint64(j)
    z = np.uint64(0)
    for pos in range(32):
        z = (z |
             ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
             ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
             )
    return z
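
# A worked example (added): for i=2 (binary 10) and j=3 (binary 11), j's bits
# land at positions 0 and 2 (contributing 1 + 4) and i's single bit lands at
# position 3 (contributing 8), so the Morton code is 13.
def _demo_mortonify():
    assert mortonify(2, 3) == 13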

class ZCurve(AbstractPermuter):
    '''Z-order (Morton) traversal. Assumes H == W is a power of two, so that
    the Morton codes of the grid form a permutation of range(H*W).'''

    def __init__(self, H, W):
        super().__init__()
        reverseidx = [np.int64(mortonify(i, j)) for i in range(H) for j in range(W)]
        idx = np.argsort(reverseidx)
        idx = torch.tensor(idx)
        reverseidx = torch.tensor(reverseidx)
        # for a permutation p, argsort(argsort(p)) == p, so reverseidx is
        # exactly the inverse of idx
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', reverseidx)

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
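
# A quick round-trip sketch (added; the helper is hypothetical). On a 4x4 grid
# the Z-curve visits the four 2x2 blocks one after another, each block itself
# traversed in Z order.
def _demo_zcurve():
    perm = ZCurve(4, 4)
    x = torch.arange(16).unsqueeze(0)
    assert torch.equal(perm(perm(x), reverse=True), x)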

class SpiralOut(AbstractPermuter):
    '''Starts just left of the centre of the (square) grid and spirals
    outwards clockwise.'''

    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        indices = np.arange(size*size).reshape(size, size)
        i0 = size//2
        j0 = size//2-1
        i = i0
        j = j0
        idx = [indices[i0, j0]]
        step_mult = 0
        for c in range(1, size//2+1):
            step_mult += 1
            # step up (i indexes rows, j indexes columns)
            for k in range(step_mult):
                i = i - 1
                idx.append(indices[i, j])
            # step right
            for k in range(step_mult):
                j = j + 1
                idx.append(indices[i, j])
            step_mult += 1
            if c < size//2:
                # step down
                for k in range(step_mult):
                    i = i + 1
                    idx.append(indices[i, j])
                # step left
                for k in range(step_mult):
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # end reached: finish the last column downwards
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])
        assert len(idx) == size*size
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
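
# Visualization sketch (added; the helper is hypothetical): print the step at
# which each cell of a 4x4 grid is visited. Hand-tracing gives
#     [[ 9, 10, 11, 12],
#      [ 8,  1,  2, 13],
#      [ 7,  0,  3, 14],
#      [ 6,  5,  4, 15]]
# i.e. a clockwise spiral starting one cell left of (H//2, W//2).
def _demo_spiral_out():
    perm = SpiralOut(4, 4)
    rank = torch.empty(16, dtype=torch.long)
    rank[perm.forward_shuffle_idx] = torch.arange(16)
    print(rank.reshape(4, 4))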

class SpiralIn(AbstractPermuter):
    '''The same traversal as SpiralOut, but with the visiting order reversed:
    starts at an outer corner and spirals into the centre.'''

    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        indices = np.arange(size*size).reshape(size, size)
        i0 = size//2
        j0 = size//2-1
        i = i0
        j = j0
        idx = [indices[i0, j0]]
        step_mult = 0
        for c in range(1, size//2+1):
            step_mult += 1
            # step up (i indexes rows, j indexes columns)
            for k in range(step_mult):
                i = i - 1
                idx.append(indices[i, j])
            # step right
            for k in range(step_mult):
                j = j + 1
                idx.append(indices[i, j])
            step_mult += 1
            if c < size//2:
                # step down
                for k in range(step_mult):
                    i = i + 1
                    idx.append(indices[i, j])
                # step left
                for k in range(step_mult):
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # end reached: finish the last column downwards
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])
        assert len(idx) == size*size
        # reverse the spiral-out order to spiral inwards
        idx = idx[::-1]
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
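
# A small consistency check (added; the helper is hypothetical): SpiralIn is
# exactly SpiralOut with the visiting order flipped.
def _demo_spiral_in():
    out_idx = SpiralOut(4, 4).forward_shuffle_idx
    in_idx = SpiralIn(4, 4).forward_shuffle_idx
    assert torch.equal(in_idx, out_idx.flip(0))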

class Random(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        # fixed seed so that the permutation is reproducible across runs
        indices = np.random.RandomState(1).permutation(H*W)
        idx = torch.tensor(indices.ravel())
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]

class AlternateParsing(AbstractPermuter):
    '''Boustrophedon ("snake") scan: even rows left to right, odd rows right
    to left.'''

    def __init__(self, H, W):
        super().__init__()
        indices = np.arange(W*H).reshape(H, W)
        for i in range(1, H, 2):
            indices[i, :] = indices[i, ::-1]
        idx = indices.flatten()
        assert len(idx) == H*W
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
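
# Illustration (added; the helper is hypothetical): for a 2x3 grid the snake
# scan keeps row 0 as-is and reverses row 1.
def _demo_alternate_parsing():
    perm = AlternateParsing(2, 3)
    print(perm.forward_shuffle_idx)  # tensor([0, 1, 2, 5, 4, 3])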

if __name__ == "__main__":
    p0 = AlternateParsing(16, 16)
    print(p0.forward_shuffle_idx)
    print(p0.backward_shuffle_idx)
    x = torch.randint(0, 768, size=(11, 256))
    y = p0(x)
    xre = p0(y, reverse=True)
    assert torch.equal(x, xre)

    p1 = SpiralOut(2, 2)
    print(p1.forward_shuffle_idx)
    print(p1.backward_shuffle_idx)
    x = torch.randint(0, 768, size=(11, 2*2))
    y = p1(x)
    xre = p1(y, reverse=True)
    assert torch.equal(x, xre)

    p2 = ColumnMajor(5, 53)
    print(p2.forward_shuffle_idx)
    print(p2.backward_shuffle_idx)
    x = torch.randint(0, 768, size=(11, 5*53))
    xre = p2(p2(x), reverse=True)
    assert torch.equal(x, xre)
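
    # Additional round-trip checks (added, following the same pattern as the
    # checks above). ZCurve and Subsample assume a power-of-two square grid.
    for p in [Identity(), Subsample(8, 8), ZCurve(8, 8),
              SpiralIn(8, 8), Random(8, 8)]:
        x = torch.randint(0, 768, size=(11, 8*8))
        xre = p(p(x), reverse=True)
        assert torch.equal(x, xre)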