import torch
import torch.nn as nn
import transformers


class BertAD(nn.Module):
    """BERT encoder with a linear head producing per-token start/end logits.

    The head maps every token's final hidden state to two scores, which are
    returned separately as ``start_logits`` and ``end_logits`` (presumably
    for span extraction — confirm against the training code).
    """

    def __init__(self, model_path: str = "model", output_hidden_states: bool = True):
        """Load the pretrained encoder and build the 2-way token head.

        Args:
            model_path: Name or directory passed to
                ``transformers.BertModel.from_pretrained``. Defaults to the
                original hard-coded ``'model'``.
            output_hidden_states: Forwarded to ``from_pretrained``. Kept
                ``True`` by default for backward compatibility; ``forward``
                below only uses the final layer, so callers that do not read
                the per-layer states elsewhere can pass ``False`` to save memory.
        """
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(
            model_path, output_hidden_states=output_hidden_states
        )
        # Size the head from the loaded config instead of hard-coding 768,
        # so checkpoints with a different hidden size (e.g. BERT-large) work.
        self.layer = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(
        self,
        ids: torch.Tensor,
        mask: torch.Tensor = None,
        token_type: torch.Tensor = None,
    ):
        """Return ``(start_logits, end_logits)`` for every input token.

        Args:
            ids: Token ids, presumably shaped (batch, seq_len) — TODO confirm.
            mask: Attention mask passed as ``attention_mask``; ``None`` lets
                ``BertModel`` apply its own default.
            token_type: Segment ids passed as ``token_type_ids``; ``None`` lets
                ``BertModel`` apply its own default.

        Returns:
            Two tensors: the head output with its trailing size-2 dimension
            separated into channel 0 (start) and channel 1 (end).
        """
        output = self.bert(input_ids=ids, attention_mask=mask, token_type_ids=token_type)
        # Element 0 of the BERT output is the last layer's hidden states.
        logits = self.layer(output[0])
        # unbind(-1) == split(1, dim=-1) followed by squeeze(-1) on each half.
        start_logits, end_logits = logits.unbind(dim=-1)
        return start_logits, end_logits