import torch | |
from torch import nn | |
class ScaledSoftmaxCE(nn.Module):
    """Split the input's last axis into logits and 10 trailing scale entries.

    NOTE(review): as written, ``forward`` computes a softmax over the logits
    and then discards every intermediate, so this module always returns
    ``None`` and never reads ``label``. The class name suggests a scaled
    softmax cross-entropy loss was intended — TODO confirm and finish.
    This class is also re-declared later in the same file, and that later
    ``class`` statement rebinds the name when the module executes.
    """

    def forward(self, x, label):
        """Run the (unfinished) forward pass; currently returns ``None``.

        Args:
            x: Tensor whose last axis holds the logits followed by 10
                trailing entries (the "temperature scales").
            label: Accepted but not used by the current implementation.

        Returns:
            None.
        """
        num_scales = 10
        # All but the final ``num_scales`` entries along the last axis.
        scores = x[..., :-num_scales]
        # The final ``num_scales`` entries along the last axis (unused so far).
        scales = x[..., -num_scales:]
        # NOTE(review): softmax yields probabilities, although the original
        # local was named ``logprobs``; switch to ``log_softmax`` if
        # log-probabilities are what the loss needs.
        probs = scores.softmax(dim=-1)
        return None
import torch | |
from torch import nn | |
class ScaledSoftmaxCE(nn.Module):
    """Module that splits its input into logits and 10 trailing scale entries.

    NOTE(review): this re-declares a ``ScaledSoftmaxCE`` class defined earlier
    in the same file (along with duplicate ``torch``/``nn`` imports); this
    later definition is the one bound to the name. Confirm the duplication is
    intentional.
    """
    def forward(self, x, label):
        # presumably x is a tensor whose last axis packs the class logits
        # followed by 10 scale values — TODO confirm against callers.
        # NOTE(review): ``label`` is not read in the lines shown here.
        # Everything except the final 10 entries along the last axis.
        logits = x[..., :-10]
        # The final 10 entries along the last axis.
        temp_scales = x[..., -10:]
        # NOTE(review): this is a softmax (probabilities), not log-probabilities,
        # despite the variable name; ``log_softmax`` would give log-probs.
        logprobs = logits.softmax(-1)