Spaces: Runtime error
import torch | |
import torch.nn as nn | |
import transformers | |
class BertAD(nn.Module):
    """BERT encoder with a linear head that scores every token twice.

    The head maps each token's final hidden vector to two scores, which are
    returned separately as ``start_logits`` and ``end_logits`` (a
    span-extraction style output; the exact task behind "AD" is not visible
    here — confirm against the caller).
    """

    def __init__(self, pretrained_path="model"):
        """Load the BERT backbone and build the 2-way per-token head.

        Args:
            pretrained_path: Checkpoint name or local directory passed to
                ``transformers.BertModel.from_pretrained``. Defaults to
                ``"model"``, the path the original code hard-coded.
        """
        super().__init__()
        # output_hidden_states=True requests every layer's hidden states, yet
        # forward() only consumes output[0].
        # NOTE(review): drop the flag if nothing else reads the extra states.
        self.bert = transformers.BertModel.from_pretrained(
            pretrained_path, output_hidden_states=True
        )
        # Size the head from the loaded checkpoint rather than hard-coding 768,
        # so checkpoints with a different hidden size do not crash.
        self.layer = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor, token_type: torch.Tensor):
        """Return per-token start and end logits.

        Args:
            ids: Token ids, forwarded as ``input_ids``.
            mask: Forwarded as ``attention_mask``.
            token_type: Forwarded as ``token_type_ids``.

        Returns:
            ``(start_logits, end_logits)``, each with the shape of the encoder
            output minus its last (hidden) dimension — presumably
            ``(batch, seq_len)``; confirm against the caller.
        """
        output = self.bert(
            input_ids=ids,
            attention_mask=mask,
            token_type_ids=token_type,
        )
        # Index 0 works whether BERT returns a plain tuple or a ModelOutput.
        logits = self.layer(output[0])  # last dim has size 2
        # unbind drops the size-2 dim, same result as split(1) + squeeze(-1).
        start_logits, end_logits = logits.unbind(dim=-1)
        return start_logits, end_logits