import os

import torch
from torch import nn
from transformers import AutoModel
from huggingface_hub import hf_hub_download

# Hugging Face access token (optional, read from the environment) and the repo
# that hosts the fine-tuned checkpoints.
token = os.getenv("HF_TOKEN")
repo_id = "Siyunb323/CreativityEvaluation"

# Japanese BERT encoder shared by the regression models defined below.
model = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese")
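# Note: the encoder expects inputs produced by the matching tokenizer, e.g.
# AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese"); the Japanese
# tokenizer additionally requires the fugashi and ipadic packages.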


class BERTregressor(nn.Module):
    """One-phase model: BERT encoder with a small regression head that predicts a single score."""

    def __init__(self, bert, hidden_size=768, num_linear=1, dropout=0.1,
                 o_type='cls', t_type='C', use_sigmoid=False):
        super(BERTregressor, self).__init__()
        self.encoder = bert
        self.o_type = o_type      # 'cls': use the [CLS] hidden state, 'pooler': use the pooler output
        self.t_type = t_type      # 'C': creativity-style target, rescaled to 1-5 when use_sigmoid is True
        self.sigmoid = use_sigmoid

        if num_linear == 2:
            layers = [nn.Linear(hidden_size, 128),
                      nn.ReLU(),
                      nn.Dropout(dropout),
                      nn.Linear(128, 1)]
        elif num_linear == 1:
            layers = [nn.Dropout(dropout),
                      nn.Linear(hidden_size, 1)]
        else:
            raise ValueError("num_linear must be 1 or 2")

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.output = nn.Sequential(*layers)

    def forward(self, inputs, return_attention=False):
        X = {'input_ids': inputs['input_ids'],
             'token_type_ids': inputs['token_type_ids'],
             'attention_mask': inputs['attention_mask'],
             'output_attentions': return_attention}
        encoded_X = self.encoder(**X)

        if self.o_type == 'cls':
            output = self.output(encoded_X.last_hidden_state[:, 0, :])
        elif self.o_type == 'pooler':
            output = self.output(encoded_X.pooler_output)

        # With a sigmoid head, map the (0, 1) output onto the 1-5 rating scale.
        output = 4 * output.squeeze(-1) + 1 if self.sigmoid and self.t_type == 'C' else output.squeeze(-1)

        return output if not return_attention else (output, encoded_X.attentions)
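# Expected input format (illustrative only; variable names are hypothetical): `inputs` is
# the mapping returned by the tokenizer, containing input_ids, token_type_ids and
# attention_mask, e.g.
#   enc = tokenizer("サンプル文です。", return_tensors="pt", padding=True, truncation=True)
#   score = BERTregressor(model, use_sigmoid=True)(enc)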


class Effectiveness(nn.Module):
    """Regression head for effectiveness scores, applied to a BERT sentence embedding."""

    def __init__(self, num_layers, hidden_size=768, use_sigmoid=True, dropout=0.2, **kwargs):
        super(Effectiveness, self).__init__(**kwargs)
        self.sigmoid = use_sigmoid

        if num_layers == 2:
            layers = [
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(128, 1)
            ]
        else:
            layers = [
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 1)
            ]

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.output = nn.Sequential(*layers)

    def forward(self, X):
        output = self.output(X)

        # Map the sigmoid output from (0, 1) to the 1-5 rating scale.
        if self.sigmoid:
            return 4 * output.squeeze(-1) + 1
        else:
            return output.squeeze(-1)


class Creativity(nn.Module):
    """Regression head for creativity scores, applied to a BERT sentence embedding."""

    def __init__(self, num_layers, hidden_size=768, use_sigmoid=True, dropout=0.2, **kwargs):
        super(Creativity, self).__init__(**kwargs)
        self.sigmoid = use_sigmoid

        if num_layers == 2:
            layers = [
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(128, 1)
            ]
        else:
            layers = [
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 1)
            ]

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.output = nn.Sequential(*layers)

    def forward(self, X):
        output = self.output(X)

        # Map the sigmoid output from (0, 1) to the 1-5 rating scale.
        if self.sigmoid:
            return 4 * output.squeeze(-1) + 1
        else:
            return output.squeeze(-1)


class BERT2Phase(nn.Module):
    """Two-phase model: a shared BERT encoder with separate creativity and effectiveness heads."""

    def __init__(self, bert, hidden_size=768, type='cls',
                 num_linear=1, dropout=0.1, use_sigmoid=False):
        super(BERT2Phase, self).__init__()
        self.encoder = bert
        self.type = type          # 'cls': use the [CLS] hidden state, 'pooler': use the pooler output
        self.sigmoid = use_sigmoid

        self.effectiveness = Effectiveness(num_linear, hidden_size, use_sigmoid, dropout)
        self.creativity = Creativity(num_linear, hidden_size, use_sigmoid, dropout)

    def forward(self, inputs, return_attention=False):
        X = {'input_ids': inputs['input_ids'],
             'token_type_ids': inputs['token_type_ids'],
             'attention_mask': inputs['attention_mask'],
             'output_attentions': return_attention}
        encoded_X = self.encoder(**X)

        if self.type == 'cls':
            e_pred = self.effectiveness(encoded_X.last_hidden_state[:, 0, :])
            c_pred = self.creativity(encoded_X.last_hidden_state[:, 0, :])
        elif self.type == 'pooler':
            e_pred = self.effectiveness(encoded_X.pooler_output)
            c_pred = self.creativity(encoded_X.pooler_output)

        return (c_pred, e_pred) if not return_attention else (c_pred, e_pred, encoded_X.attentions)


def load_model(model_name, pooling_method):
    # The heads were trained on either the [CLS] hidden state ('cls') or the pooler output.
    pooling = pooling_method if pooling_method == 'cls' else 'pooler'

    if model_name == "One-phase Fine-tuned BERT":
        loaded_net = BERTregressor(model, hidden_size=768, num_linear=1, dropout=0.1,
                                   o_type=pooling, t_type='C', use_sigmoid=True)
        filename = f"model/OnePhase_BERT_{pooling_method}.pth"
    elif model_name == "Two-phase Fine-tuned BERT":
        loaded_net = BERT2Phase(model, hidden_size=768, num_linear=1, dropout=0.1,
                                type=pooling, use_sigmoid=True)
        filename = f"model/TwoPhase_BERT_{pooling_method}.pth"
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    # Download the fine-tuned weights from the Hub and load them on CPU.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, use_auth_token=token)
    loaded_net.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    loaded_net.eval()

    return loaded_net
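

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original app): assumes the checkpoint files
    # referenced in load_model() are available in the repo and that the Japanese
    # tokenizer dependencies (fugashi, ipadic) are installed. The sample sentence is
    # only an illustration.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
    enc = tokenizer("このコピーはとても独創的だ。", return_tensors="pt",
                    padding=True, truncation=True)

    net = load_model("Two-phase Fine-tuned BERT", "cls")
    with torch.no_grad():
        creativity, effectiveness = net(enc)
    print(f"creativity={creativity.item():.3f}, effectiveness={effectiveness.item():.3f}")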