import torch
import torch.nn as nn


class OD_model(nn.Module):
    """Order-discrimination head: predicts whether `order_num` encoded
    segments appear in the correct order (binary classification)."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.order_num = args.order_num
        if args.od_type == 'linear_cat':
            # Fuse the order_num segment embeddings (concatenated along the
            # feature dimension in encode) back down to hidden_size.
            self.order_dense_1 = nn.Linear(args.hidden_size * self.order_num, args.hidden_size)
            if self.args.num_od_layer > 0:
                self.layer = nn.ModuleList([OD_Layer_linear(args) for _ in range(args.num_od_layer)])
        self.order_dense_2 = nn.Linear(args.hidden_size, 1)
        self.activation = nn.LeakyReLU()
        self.bn = nn.BatchNorm1d(args.hidden_size)
        self.dp = nn.Dropout(p=args.hidden_dropout_prob)
        self.loss_func = nn.BCEWithLogitsLoss()

    def forward(self, input, labels):
        # Compute the binary order loss and also return a dict of plain
        # floats for logging.
        loss_dic = {}
        pre = self.predict(input)
        # predict() returns raw logits; BCEWithLogitsLoss applies the
        # sigmoid internally.
        loss = self.loss_func(pre, labels.unsqueeze(1))
        loss_dic['order_loss'] = loss.item()
        return loss, loss_dic

    def encode(self, input):
        if self.args.num_od_layer > 0:
            for layer_module in self.layer:
                input = layer_module(input)
        # `input` stacks the order_num segment embeddings along the batch
        # dimension; split them apart and concatenate along the feature
        # dimension so order_dense_1 sees hidden_size * order_num features.
        inputs = torch.chunk(input, self.order_num, dim=0)
        emb = torch.cat(inputs, dim=1)
        return self.activation(self.order_dense_1(self.dp(emb)))

    def predict(self, input):
        # Batch-normalize the fused embedding and project it to one logit.
        return self.order_dense_2(self.bn(self.encode(input)))

    def right_calculate(self, input, labels, threshold=0.5):
        # Count predictions (probabilities in `input`, shape (batch, 1))
        # that fall on the same side of the threshold as their 0/1 labels.
        input = input.squeeze(1).tolist()
        labels = labels.tolist()
        right = 0
        for i in range(len(input)):
            if (input[i] >= threshold and labels[i] >= 0.5) or (input[i] < threshold and labels[i] < 0.5):
                right += 1
        return right


class OD_Layer_linear(nn.Module):
    """A single dropout -> linear -> batch-norm -> LeakyReLU block."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.dense = nn.Linear(args.hidden_size, args.hidden_size)
        self.activation = nn.LeakyReLU()
        self.bn = nn.BatchNorm1d(args.hidden_size)
        self.dropout = nn.Dropout(p=args.hidden_dropout_prob)

    def forward(self, input):
        return self.activation(self.bn(self.dense(self.dropout(input))))
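

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only; not part of the original module).
# The `args` fields and their values below are assumptions -- the real
# training script supplies its own configuration object.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        od_type='linear_cat',     # the only branch implemented above
        order_num=2,              # segments per example
        num_od_layer=1,           # extra OD_Layer_linear blocks
        hidden_size=768,          # assumed encoder width
        hidden_dropout_prob=0.1,
    )
    model = OD_model(args)
    model.eval()  # disable dropout / use BN running stats for a stable demo

    batch_size = 4
    # encode() expects the order_num segment embeddings stacked along the
    # batch dimension: shape (order_num * batch_size, hidden_size).
    segments = torch.randn(args.order_num * batch_size, args.hidden_size)
    labels = torch.randint(0, 2, (batch_size,)).float()

    loss, loss_dic = model(segments, labels)
    print('order_loss:', loss_dic['order_loss'])

    with torch.no_grad():
        probs = torch.sigmoid(model.predict(segments))  # logits -> probabilities
    print('correct out of %d:' % batch_size, model.right_calculate(probs, labels))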