File size: 209 Bytes
f50f696
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from torch import nn


class ScaledSoftmaxCE(nn.Module):
    """Cross-entropy loss on per-class scaled logits.

    The input packs two things along its last dimension: the class logits,
    followed by ``NUM_SCALES`` per-class scale factors. The logits are
    multiplied element-wise by the scales (acting as an inverse temperature)
    before the log-softmax. The mean negative log-likelihood of ``label`` is
    returned.

    NOTE(review): the original draft sliced the scales out but never used
    them and returned nothing. Multiplicative per-class scaling is the
    interpretation chosen here; confirm against the intended design.
    """

    # Width of the trailing block of scale factors in ``x``.
    NUM_SCALES = 10

    def forward(self, x, label):
        """Return the mean scaled-softmax cross-entropy.

        Args:
            x: Tensor of shape ``(..., C + NUM_SCALES)``. The first ``C``
                features are logits and the last ``NUM_SCALES`` are scales.
                ``C`` must equal ``NUM_SCALES``, giving one scale per class.
            label: Integer class indices of shape ``x.shape[:-1]``.

        Returns:
            A 0-dim tensor holding the loss averaged over all leading
            positions.

        Raises:
            ValueError: if the logit width does not match the scale width.
        """
        n = self.NUM_SCALES
        logits = x[..., :-n]
        temp_scales = x[..., -n:]
        if logits.shape[-1] != temp_scales.shape[-1]:
            raise ValueError(
                f"expected {2 * n} features (logits + per-class scales), "
                f"got {x.shape[-1]}"
            )
        # log_softmax is numerically stable, unlike softmax followed by log.
        logprobs = (logits * temp_scales).log_softmax(-1)
        # Pick the target-class log-probability at every leading position;
        # gather needs int64 indices.
        target_logprobs = logprobs.gather(-1, label.long().unsqueeze(-1))
        return -target_logprobs.squeeze(-1).mean()