File size: 209 Bytes
f50f696 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import torch
from torch import nn
class ScaledSoftmaxCE(nn.Module):
    """Softmax cross-entropy over inputs that carry trailing scale entries.

    The last dimension of ``x`` is split into class logits (all but the final
    ``num_scales`` entries) followed by ``num_scales`` entries that the
    original code named ``temp_scales``.

    NOTE(review): the original ``forward`` was an unfinished stub:
    - It stored ``logits.softmax(-1)`` (probabilities, not
      log-probabilities) in a variable named ``logprobs``.
    - It never used ``label`` or ``temp_scales``.
    - It returned ``None``.

    This version computes genuine log-probabilities and returns the mean
    cross-entropy of ``label``. How ``temp_scales`` is meant to rescale the
    logits cannot be recovered from this file, so the trailing entries are
    currently ignored. TODO confirm the intended scaling and apply it.

    Args:
        num_scales: number of trailing entries on the last dimension of ``x``
            that are not logits. Defaults to 10, the value previously
            hard-coded.
    """

    def __init__(self, num_scales: int = 10) -> None:
        super().__init__()
        if num_scales < 0:
            raise ValueError(f"num_scales must be non-negative, got {num_scales}")
        self.num_scales = num_scales

    def forward(self, x: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Return the mean negative log-likelihood of ``label``.

        Args:
            x: tensor of shape ``(..., num_classes + num_scales)``.
            label: int64 class indices shaped like ``x`` without its last
                dimension. (Assumed from the "CE" in the class name. TODO
                confirm against callers.)

        Returns:
            Scalar tensor: the cross-entropy averaged over all leading
            positions.

        Raises:
            ValueError: if no logit entries remain after removing the
                trailing ``num_scales`` entries.
        """
        num_classes = x.shape[-1] - self.num_scales
        if num_classes <= 0:
            raise ValueError(
                f"last dimension of x ({x.shape[-1]}) must exceed "
                f"num_scales ({self.num_scales})"
            )
        # Use an explicit split index rather than x[..., :-n], which would be
        # empty for n == 0.
        logits = x[..., :num_classes]
        # x[..., num_classes:] holds the temperature scales (not applied yet).
        # log_softmax is the numerically stable form of softmax(...).log().
        logprobs = logits.log_softmax(-1)
        nll = -logprobs.gather(-1, label.unsqueeze(-1)).squeeze(-1)
        return nll.mean()
|