import random

import torch
from torch import nn
from torch.nn import functional as F
|
|
class ScaledDecoder(nn.Module):
    """Decoder head that divides its logits by a learned, input-dependent temperature."""

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        self.linear = nn.Linear(ninp, nhid)
        self.linear1 = nn.Linear(nhid, nout)
        # Predicts a categorical distribution over 10 candidate temperatures.
        self.linear2 = nn.Linear(nhid, 10)

    def forward(self, x):
        x = F.gelu(self.linear(x))
        # Per-position temperature: a softmax-weighted average over a fixed grid
        # ranging from 1 (leave the logits as-is) to 160 (flatten them strongly).
        temps = self.linear2(x).softmax(-1) @ torch.tensor(
            [1., 1.4, 1.7, 2., 5., 10., 20., 40., 80., 160.], device=x.device
        )
        if random.random() > .99:  # occasionally log temperatures for debugging
            print(temps.shape, temps[:, :2])
        return self.linear1(x) / temps.unsqueeze(-1)
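
# A minimal usage sketch (illustrative only; the dimensions below are made up,
# not taken from the original source). The input is assumed to be shaped
# (seq_len, batch_size, ninp), which matches the temps[:, :2] debug print:
#
#     decoder = ScaledDecoder(ninp=512, nhid=1024, nout=100)
#     logits = decoder(torch.randn(10, 4, 512))  # -> shape (10, 4, 100)
#
# Each position gets its own temperature, so the model can leave its logits
# unchanged (T = 1) or flatten its predictive distribution (up to T = 160).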
|
|
class FixedScaledDecoder(nn.Module):
    """Decoder head that divides its logits by a single, globally learned temperature."""

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        self.mapper = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout))
        # A single global temperature, parameterized as 10,000 values of which
        # only the sum is used; initialized so the sum starts at exactly 1.
        self.T = nn.Parameter(torch.ones(10000) / 10000)

    def forward(self, x):
        return self.mapper(x) / self.T.sum()
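

# A quick smoke test added for illustration (not part of the original module);
# dimensions are hypothetical. Runs both decoders on random inputs and checks
# the output shapes.
if __name__ == "__main__":
    x = torch.randn(10, 4, 512)  # (seq_len, batch_size, ninp)

    scaled = ScaledDecoder(ninp=512, nhid=1024, nout=100)
    assert scaled(x).shape == (10, 4, 100)

    fixed = FixedScaledDecoder(ninp=512, nhid=1024, nout=100)
    assert fixed(x).shape == (10, 4, 100)
    print("shape checks passed")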