Sifal commited on
Commit
718e1b0
·
1 Parent(s): 26655b6

add base model option

Browse files
Files changed (1) hide show
  1. model.py +7 -3
model.py CHANGED
@@ -4,10 +4,14 @@ from transformers.modeling_outputs import TokenClassifierOutput
4
 
5
 
6
  class BertClassifier(nn.Module):
7
- def __init__(self, num_labels=2, dropout=0.1):
8
  super().__init__()
9
- config = BertConfig(vocab_size=34688, max_position_embeddings=512)
10
- self.bert = BertModel(config=config)
 
 
 
 
11
  self.num_labels = num_labels
12
  self.classifier = nn.Sequential(
13
  nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),
 
4
 
5
 
6
  class BertClassifier(nn.Module):
7
+ def __init__(self, num_labels=2, dropout=0.1,bert_model=None):
8
  super().__init__()
9
+ if bert_model:
10
+ self.bert = BertModel.from_pretrained(bert_model)
11
+ else:
12
+ config = BertConfig(vocab_size=34688, max_position_embeddings=512)
13
+ self.bert = BertModel(config=config)
14
+
15
  self.num_labels = num_labels
16
  self.classifier = nn.Sequential(
17
  nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),