File size: 679 Bytes
b1aad3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn as nn
import transformers

class BertAD(nn.Module):
    """BERT encoder with a linear head producing per-token start/end logits.

    The head maps each token's final hidden state to two scores, which are
    split into ``start_logits`` and ``end_logits`` (span-extraction style).

    Submodule names ``bert`` and ``layer`` are part of the ``state_dict`` key
    layout and are kept stable so existing checkpoints keep loading.
    """

    def __init__(self, pretrained_path="model", *, output_hidden_states=True):
        """Load the encoder and build the span head.

        Args:
            pretrained_path: Name or local path passed to
                ``transformers.BertModel.from_pretrained``. Defaults to the
                original hard-coded ``'model'``.
            output_hidden_states: Forwarded to ``from_pretrained``. ``forward``
                never reads the per-layer hidden states, so passing ``False``
                saves memory; the default ``True`` preserves the previous
                configuration of ``self.bert``.
        """
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(
            pretrained_path, output_hidden_states=output_hidden_states
        )
        # Size the head from the loaded config instead of assuming BERT-base
        # (768); a hard-coded width breaks on e.g. bert-large (1024).
        self.layer = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(
        self,
        ids: torch.Tensor,
        mask: torch.Tensor,
        token_type: "torch.Tensor | None" = None,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Compute start and end logits for every token.

        Args:
            ids: Token ids, passed to the encoder as ``input_ids``.
            mask: Attention mask, passed as ``attention_mask``.
            token_type: Segment ids, passed as ``token_type_ids``; ``None``
                lets the encoder use its default.

        Returns:
            ``(start_logits, end_logits)``, each the head output with the
            trailing size-1 dimension squeezed away — presumably
            ``(batch, seq_len)`` for ``(batch, seq_len)`` inputs (confirm
            against the caller).
        """
        output = self.bert(
            input_ids=ids,
            attention_mask=mask,
            token_type_ids=token_type,
        )

        # Element 0 of the encoder output is the final-layer token states;
        # indexing works for both tuple and ModelOutput return types.
        logits = self.layer(output[0])
        start_logits, end_logits = logits.split(1, dim=-1)

        # squeeze(-1) yields strided views; make them contiguous so that
        # downstream ops which require contiguous memory work directly.
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        return start_logits, end_logits