# Page header captured together with this file (author: kpriyanshu256,
# commit b1aad3c "Added app files", original size 679 bytes).
import torch
import torch.nn as nn
import transformers
class BertAD(nn.Module):
    """BERT encoder followed by a linear head that emits two scores per token.

    The two head outputs are split into ``start_logits`` and ``end_logits``.
    NOTE(review): the naming suggests span extraction (e.g. extractive QA);
    confirm against the training/inference code.

    Args:
        model_path: Model name or local directory passed to
            ``transformers.BertModel.from_pretrained``. Defaults to
            ``'model'``, the path this class has always loaded from.
    """

    def __init__(self, model_path='model'):
        super().__init__()
        # output_hidden_states=True is preserved for backward compatibility with
        # any caller that invokes ``self.bert`` directly; ``forward`` below only
        # reads the first output, so the extra hidden states are unused there.
        self.bert = transformers.BertModel.from_pretrained(
            model_path, output_hidden_states=True
        )
        # Size the head from the loaded checkpoint rather than hard-coding 768,
        # so checkpoints with a different hidden size don't fail with a shape
        # mismatch. For a 768-d checkpoint this is identical to the original.
        self.layer = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, ids, mask, token_type=None):
        """Run the encoder and return per-token start and end logits.

        Args:
            ids: Token ids, forwarded as ``input_ids``.
            mask: Attention mask, forwarded as ``attention_mask``.
            token_type: Segment ids, forwarded as ``token_type_ids``. Optional;
                when ``None`` the BERT model applies its own default.

        Returns:
            ``(start_logits, end_logits)``: the two head outputs with the
            trailing size-1 dimension squeezed away. Assuming the encoder's
            first output is (batch, seq_len, hidden) — TODO confirm — each is
            (batch, seq_len).
        """
        output = self.bert(
            input_ids=ids,
            attention_mask=mask,
            token_type_ids=token_type,
        )
        # Index rather than use ``.last_hidden_state`` so this works whether
        # the model returns a plain tuple or a ModelOutput.
        logits = self.layer(output[0])
        # Last dim has size 2: channel 0 -> start, channel 1 -> end.
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)