File size: 4,399 Bytes
2d3bbc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from collections import namedtuple
from typing import NamedTuple, Optional, Tuple
import torch
from torch import nn


def _run_kernel(x: torch.Tensor, mean: torch.Tensor, tx: torch.Tensor):
    if x.ndim <= 3:
        x = x - mean
        x = x @ tx.T
    elif x.ndim == 4:
        x = x - mean.reshape(1, -1, 1, 1)
        kernel = tx.reshape(*tx.shape, 1, 1)
        x = torch.nn.functional.conv2d(x, weight=kernel, bias=None, stride=1, padding=0)
    else:
        raise ValueError(f'Unsupported input dimension: {x.ndim}, shape: {x.shape}')
    return x


class FeatureNormalizer(nn.Module):
    """Mean-center features and apply a fixed linear transform via ``_run_kernel``.

    The mean vector and the ``embed_dim x embed_dim`` transform are stored as buffers
    (``mean`` and ``tx``). They start as the identity transform (zero mean, identity
    matrix); presumably real values are loaded from a checkpoint — TODO confirm.
    """

    def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32):
        super().__init__()

        zero_mean = torch.zeros(embed_dim, dtype=dtype)
        identity = torch.eye(embed_dim, dtype=dtype)
        # Registration order (mean, then tx) is preserved so state_dict keys keep their order.
        self.register_buffer('mean', zero_mean)
        self.register_buffer('tx', identity)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``_run_kernel(x, mean, tx)`` for rank <= 4 input."""
        return _run_kernel(x, self.mean, self.tx)


class InterFeatState(NamedTuple):
    """Result pair produced by an intermediate-feature normalizer.

    Fields:
        y: the output feature tensor.
        alpha: a weight tensor returned alongside ``y``.
    """

    # Output feature tensor.
    y: torch.Tensor
    # Weight tensor paired with ``y``.
    alpha: torch.Tensor


class IntermediateFeatureNormalizerBase(nn.Module):
    """Common interface for intermediate-feature normalizers.

    Subclasses override ``forward`` and return an ``InterFeatState`` that pairs the
    output tensor ``y`` with a weight tensor ``alpha``.
    """

    def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
        """Not implemented here; subclasses must override. Always raises ``NotImplementedError``."""
        raise NotImplementedError


class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
    """Per-intermediate-layer normalizer with a mean, a rotation and a scalar alpha.

    Buffers:
        alphas:   ``(num_intermediates,)`` scalar weight per intermediate index.
        rotation: ``(embed_dim, embed_dim)`` shared rotation, or
                  ``(num_intermediates, embed_dim, embed_dim)`` when ``rot_per_layer``.
        means:    ``(num_intermediates, embed_dim)`` mean per intermediate index.

    All buffers start as the identity transform (unit alphas, identity rotations, zero
    means); presumably real values are loaded from a checkpoint — TODO confirm.
    """

    def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype))

        rot = torch.eye(embed_dim, dtype=dtype)
        if rot_per_layer:
            # One independent rotation matrix per intermediate index.
            rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1)

        self.register_buffer('rotation', rot.contiguous())
        self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype))

    def forward(self, x: torch.Tensor, index: int, rot_index: Optional[int] = None, skip: Optional[int] = None) -> InterFeatState:
        """Normalize ``x`` with the statistics of intermediate ``index``.

        Args:
            x: rank-3 or rank-4 feature tensor. For rank 3 the features are on the last
                axis and tokens on axis 1; for rank 4 the features are on axis 1 (see
                ``_run_kernel``).
            index: selects the mean and alpha.
            rot_index: selects the rotation when ``rotation`` is per-layer; defaults to
                ``index``.
            skip: if truthy, the first ``skip`` tokens (axis 1) of a rank-3 ``x`` pass
                through unnormalized with an alpha of 1.

        Returns:
            ``InterFeatState(y, alpha)``. ``alpha`` is ``(1, T, 1)`` for rank-3 input and
            ``(1, 1, H, W)`` for rank-4 input, so it broadcasts against ``y``.

        Raises:
            ValueError: if ``skip`` is used with non-rank-3 input, or if ``x`` is neither
                rank 3 nor rank 4.
        """
        if rot_index is None:
            rot_index = index

        # Validate the rank before doing any work. An explicit raise (unlike ``assert``)
        # is not stripped under ``python -O``.
        if skip:
            if x.ndim != 3:
                raise ValueError('Cannot use the `skip` parameter when the `x` tensor isn\'t 3-dimensional.')
            prefix, x = x[:, :skip], x[:, skip:]
        elif x.ndim not in (3, 4):
            raise ValueError(f'Unsupported input dimension: {x.ndim}, shape: {x.shape}')

        rotation = self._get_rotation(rot_index)
        y = _run_kernel(x, self.means[index], rotation)

        alpha = self.alphas[index]
        if skip:
            # Size the unit-alpha prefix from the actual prefix length, not ``skip``.
            # When skip exceeds the sequence length the slice is shorter than ``skip``.
            num_prefix = prefix.shape[1]
            alpha = torch.cat([
                torch.ones(num_prefix, dtype=alpha.dtype, device=alpha.device),
                alpha[None].expand(y.shape[1]),
            ]).reshape(1, -1, 1)
            y = torch.cat([prefix, y], dim=1)
        elif x.ndim == 3:
            alpha = alpha.reshape(1, 1, 1).expand(1, y.shape[1], 1)
        else:  # x.ndim == 4, guaranteed by the validation above
            alpha = alpha.reshape(1, 1, 1, 1).expand(1, 1, *y.shape[2:])

        return InterFeatState(y, alpha)

    def _get_rotation(self, rot_index: int) -> torch.Tensor:
        """Return the shared rotation, or the ``rot_index``-th one when stored per layer."""
        if self.rotation.ndim == 2:
            return self.rotation
        return self.rotation[rot_index]


class NullIntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
    """Pass-through normalizer: returns ``x`` unchanged with a scalar alpha of 1.

    ``get_instance`` memoizes one instance per ``(dtype, device)`` pair in the
    class-level ``instances`` dict, so repeated calls share the same module.
    """

    # Class-level cache: (dtype, device) -> NullIntermediateFeatureNormalizer.
    instances = {}

    def __init__(self, dtype: torch.dtype, device: torch.device):
        super().__init__()
        unit = torch.tensor(1, dtype=dtype, device=device)
        self.register_buffer('alpha', unit)

    @staticmethod
    def get_instance(dtype: torch.dtype, device: torch.device):
        """Return the cached instance for ``(dtype, device)``, creating it on first use."""
        cache = NullIntermediateFeatureNormalizer.instances
        key = (dtype, device)
        if key not in cache:
            cache[key] = NullIntermediateFeatureNormalizer(dtype, device)
        return cache[key]

    def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
        """Return ``x`` untouched, paired with the scalar ``alpha`` buffer."""
        return InterFeatState(y=x, alpha=self.alpha)