File size: 3,684 Bytes
f50f696 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class _PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
self.d_model = d_model
self.device_test_tensor = nn.Parameter(torch.tensor(1.))
def forward(self, x):# T x B x num_features
assert self.d_model % x.shape[-1]*2 == 0
d_per_feature = self.d_model // x.shape[-1]
pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
#position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
interval_size = 10
div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()*math.log(math.sqrt(2)))
#print(div_term/2/math.pi)
pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
class EmbeddingEncoder(nn.Module):
    """Encode continuous features by bucketing them into learned embeddings.

    The value range ``min_max`` is split into ``num_embs`` equal buckets.
    Every (feature, bucket) pair owns one row of a shared ``nn.Embedding``
    table, and the looked-up per-feature embeddings are averaged.

    Args:
        num_features: Number of input features (last dim of ``x``).
        em_size: Embedding width.
        num_embs: Number of buckets per feature.
    """

    def __init__(self, num_features, em_size, num_embs=100):
        super().__init__()
        self.num_embs = num_embs
        # One block of `num_embs` rows per feature. max_norm=True (== 1.0)
        # renormalizes looked-up rows whose norm exceeds 1.
        self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
        self.init_weights(.1)
        # Values outside this range are clamped into the first/last bucket.
        self.min_max = (-2, +2)

    @property
    def width(self):
        """Length of the bucketed value range."""
        return self.min_max[1] - self.min_max[0]

    def init_weights(self, initrange):
        """Fill the embedding table uniformly from [-initrange, initrange]."""
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def discretize(self, x):
        """Map values to int32 bucket indices in [0, num_embs - 1]."""
        split_size = self.width / self.num_embs
        # Shift by the lower bound *before* dividing. Without the parentheses
        # `//` binds tighter than `-`, and with the default range every
        # in-range value collapsed into buckets 48-52.
        return ((x - self.min_max[0]) // split_size).int().clamp(0, self.num_embs - 1)

    def forward(self, x):  # T x B x num_features
        """Return the feature-averaged bucket embedding, shape (T, B, em_size)."""
        x_idxs = self.discretize(x)
        # Offset each feature's bucket indices into its own block of the table.
        x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
        return self.embeddings(x_idxs).mean(-2)
# Plain linear encoder factory. It is the nn.Linear class itself, presumably
# called as Linear(num_features, emsize) like the other factories here —
# confirm against callers.
Linear = torch.nn.Linear
def MLP(num_features, emsize):
    """Build a two-layer ReLU MLP encoder.

    Maps ``num_features + 1`` inputs to ``emsize`` outputs through a hidden
    layer of width ``2 * emsize``.
    """
    hidden = emsize * 2
    layers = [
        nn.Linear(num_features + 1, hidden),
        nn.ReLU(),
        nn.Linear(hidden, emsize),
    ]
    return nn.Sequential(*layers)
class Conv(nn.Module):
    """Convolutional encoder for inputs whose last dim is a flattened square image.

    The last dimension (``size * size`` values) is reshaped to a 1-channel
    ``size x size`` image. Up to five 3x3 convolutions with ReLU are applied
    (stopping once the spatial size drops below 4), followed by global
    average pooling and a linear projection to ``emsize``.

    Args:
        input_size: Accepted for factory-signature compatibility; unused.
        emsize: Output width.

    NOTE(review): an image smaller than 4x4 skips every convolution, reaches
    ``self.linear`` with 1 channel instead of 64, and fails there.
    """

    def __init__(self, input_size, emsize):
        super().__init__()
        self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
        self.linear = nn.Linear(64, emsize)

    def forward(self, x):
        """Encode ``x`` of shape (..., size*size) into (..., emsize)."""
        size = math.isqrt(x.shape[-1])
        assert size * size == x.shape[-1]
        lead_shape = x.shape[:-1]
        # nn.Conv2d only accepts 3-D or 4-D input, so fold every leading dim
        # (e.g. T x B) into a single batch dim. numel() keeps empty batches
        # working, where reshape(-1, ...) would be ambiguous.
        x = x.reshape(lead_shape.numel(), 1, size, size)
        for conv in self.convs:
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        x = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1)
        out = self.linear(x)
        return out.reshape(*lead_shape, out.shape[-1])
def Positional(_, emsize):
    """Return a sinusoidal positional encoder of width ``emsize``.

    The first argument (the feature count in the other factories) is ignored.
    """
    encoder = _PositionalEncoding(d_model=emsize)
    return encoder
class CanEmb(nn.Embedding):
    """Per-feature categorical embedding whose outputs are concatenated.

    Each of ``num_features`` integer inputs is embedded into
    ``embedding_dim // num_features`` channels. The last two output dims
    (feature, channel) are then flattened, so the total width is
    ``embedding_dim``.
    """

    def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
        per_feature_dim, remainder = divmod(embedding_dim, num_features)
        assert remainder == 0
        super().__init__(num_embeddings, per_feature_dim, *args, **kwargs)

    def forward(self, x):
        embedded = super().forward(x)
        # Merge the (feature, channel) trailing dims into one.
        lead_shape = embedded.shape[:-2]
        return embedded.view(*lead_shape, -1)
def get_Canonical(num_classes):
    """Return a ``(num_features, emsize)`` factory building a ``CanEmb``.

    The returned factory builds a ``CanEmb`` with ``num_classes`` categories
    per feature.
    """
    def make_encoder(num_features, emsize):
        return CanEmb(num_features, num_classes, emsize)

    return make_encoder
def get_Embedding(num_embs_per_feature=100):
    """Return a ``(num_features, emsize)`` factory building an ``EmbeddingEncoder``.

    The returned factory builds an ``EmbeddingEncoder`` that uses
    ``num_embs_per_feature`` buckets per feature.
    """
    def make_encoder(num_features, emsize):
        return EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)

    return make_encoder
|