Jiqing commited on
Commit
ea455b4
·
1 Parent(s): b67f175

Update modeling_protst.py

Browse files
Files changed (1) hide show
  1. modeling_protst.py +5 -1
modeling_protst.py CHANGED
@@ -55,6 +55,8 @@ class BertForPubMed(BertPreTrainedModel):
55
  self.text_mlp = ProtSTHead(config)
56
  self.word_mlp = ProtSTHead(config)
57
 
 
 
58
  def forward(
59
  self,
60
  input_ids: Optional[torch.Tensor] = None,
@@ -111,7 +113,7 @@ class EsmForProteinRepresentation(EsmPreTrainedModel):
111
  self.protein_mlp = ProtSTHead(config)
112
  self.residue_mlp = ProtSTHead(config)
113
 
114
- self.init_weights()
115
 
116
  def forward(
117
  self,
@@ -163,6 +165,8 @@ class EsmForProteinPropertyPrediction(EsmPreTrainedModel):
163
  self.model = EsmForProteinRepresentation(config)
164
  self.classifier = ProtSTHead(config, out_dim=config.num_labels)
165
 
 
 
166
  def forward(
167
  self,
168
  input_ids: Optional[torch.LongTensor] = None,
 
55
  self.text_mlp = ProtSTHead(config)
56
  self.word_mlp = ProtSTHead(config)
57
 
58
+ self.post_init() # NOTE
59
+
60
  def forward(
61
  self,
62
  input_ids: Optional[torch.Tensor] = None,
 
113
  self.protein_mlp = ProtSTHead(config)
114
  self.residue_mlp = ProtSTHead(config)
115
 
116
+ self.post_init() # NOTE
117
 
118
  def forward(
119
  self,
 
165
  self.model = EsmForProteinRepresentation(config)
166
  self.classifier = ProtSTHead(config, out_dim=config.num_labels)
167
 
168
+ self.post_init() # NOTE
169
+
170
  def forward(
171
  self,
172
  input_ids: Optional[torch.LongTensor] = None,