Jiqing committed
Commit 3907684 · 1 Parent(s): 5d57a51

Upload 2 files

Files changed (2)
  1. configuration_protst.py +53 -0
  2. modeling_protst.py +278 -0
configuration_protst.py ADDED
@@ -0,0 +1,53 @@
+ from transformers import PretrainedConfig
+ from transformers.utils import logging
+ from transformers.models.esm import EsmConfig
+ from transformers.models.bert import BertConfig
+
+ logger = logging.get_logger(__name__)
+
+
+ class ProtSTConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`ProtSTModel`].
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         protein_config (`dict`, *optional*):
+             Dictionary of configuration options used to initialize [`EsmForProteinRepresentation`].
+         text_config (`dict`, *optional*):
+             Dictionary of configuration options used to initialize [`BertForPubMed`].
+     """
+
+     # model_type = "protst"
+
+     def __init__(
+         self,
+         protein_config=None,
+         text_config=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         if protein_config is None:
+             protein_config = {}
+             logger.info("`protein_config` is `None`. Initializing the protein config (`EsmConfig`) with default values.")
+
+         if text_config is None:
+             text_config = {}
+             logger.info("`text_config` is `None`. Initializing the text config (`BertConfig`) with default values.")
+
+         self.protein_config = EsmConfig(**protein_config)
+         self.text_config = BertConfig(**text_config)
+
+     @classmethod
+     def from_protein_text_configs(
+         cls, protein_config: EsmConfig, text_config: BertConfig, **kwargs
+     ):
+         r"""
+         Instantiate a [`ProtSTConfig`] (or a derived class) from a protein model configuration and a text model
+         configuration.
+
+         Returns:
+             [`ProtSTConfig`]: An instance of a configuration object
+         """
+         return cls(protein_config=protein_config.to_dict(), text_config=text_config.to_dict(), **kwargs)
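For reference, a minimal usage sketch of the configuration class (not part of the committed files). It assumes configuration_protst.py is importable from the working directory; the hyperparameters are toy values chosen only for illustration.

from transformers.models.esm import EsmConfig
from transformers.models.bert import BertConfig
from configuration_protst import ProtSTConfig

# Toy sub-configurations; real checkpoints ship their own values.
protein_config = EsmConfig(vocab_size=33, hidden_size=320, num_hidden_layers=2, num_attention_heads=4)
text_config = BertConfig(hidden_size=320, num_hidden_layers=2, num_attention_heads=4)

# Compose the joint config from the two sub-configs.
config = ProtSTConfig.from_protein_text_configs(protein_config, text_config)
print(type(config.protein_config).__name__, type(config.text_config).__name__)  # EsmConfig BertConfig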
modeling_protst.py ADDED
@@ -0,0 +1,278 @@
+ import math
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Tuple, Union
+ from dataclasses import dataclass
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import ModelOutput
+ from transformers.models.esm import EsmPreTrainedModel, EsmModel
+ from transformers.models.bert import BertPreTrainedModel, BertModel
+ from configuration_protst import ProtSTConfig
+
+
+ @dataclass
+ class EsmProteinRepresentationOutput(ModelOutput):
+     """Pooled protein feature and per-residue features from the ESM encoder."""
+
+     protein_feature: torch.FloatTensor = None
+     residue_feature: torch.FloatTensor = None
+
+
+ @dataclass
+ class BertTextRepresentationOutput(ModelOutput):
+     """Pooled text feature and per-word features from the BERT encoder."""
+
+     text_feature: torch.FloatTensor = None
+     word_feature: torch.FloatTensor = None
+
+
+ @dataclass
+ class EsmProteinClassificationOutput(ModelOutput):
+     """Classification loss and logits for protein property prediction."""
+
+     loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+
+
+ class ProtSTHead(nn.Module):
+     def __init__(self, config, out_dim=512):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.out_proj = nn.Linear(config.hidden_size, out_dim)
+
+     def forward(self, x):
+         x = self.dense(x)
+         x = nn.functional.relu(x)
+         x = self.out_proj(x)
+         return x
+
+
+ class BertForPubMed(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.pad_token_id = config.pad_token_id
+         self.cls_token_id = config.cls_token_id
+         self.sep_token_id = config.sep_token_id
+
+         self.bert = BertModel(config, add_pooling_layer=False)
+         self.text_mlp = ProtSTHead(config)
+         self.word_mlp = ProtSTHead(config)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], BertTextRepresentationOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         word_feature = outputs.last_hidden_state
+         # Mean readout over non-special tokens (CLS, SEP and PAD are masked out).
+         is_special = (input_ids == self.cls_token_id) | (input_ids == self.sep_token_id) | (input_ids == self.pad_token_id)
+         special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
+         pooled_feature = ((word_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(word_feature.dtype)
+         pooled_feature = self.text_mlp(pooled_feature)
+         word_feature = self.word_mlp(word_feature)
+
+         if not return_dict:
+             return (pooled_feature, word_feature)
+
+         return BertTextRepresentationOutput(text_feature=pooled_feature, word_feature=word_feature)
+
+
+ class EsmForProteinRepresentation(EsmPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.cls_token_id = config.cls_token_id
+         self.pad_token_id = config.pad_token_id
+         self.eos_token_id = config.eos_token_id
+
+         self.esm = EsmModel(config, add_pooling_layer=False)
+         self.protein_mlp = ProtSTHead(config)
+         self.residue_mlp = ProtSTHead(config)
+
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, EsmProteinRepresentationOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.esm(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         residue_feature = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
+
+         # Mean readout over non-special residues (CLS, EOS and PAD are masked out).
+         is_special = (
+             (input_ids == self.cls_token_id) | (input_ids == self.eos_token_id) | (input_ids == self.pad_token_id)
+         )
+         special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
+         protein_feature = ((residue_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(residue_feature.dtype)
+
+         # For ProtST pretraining and zero-shot prediction
+         protein_feature = self.protein_mlp(protein_feature)
+         residue_feature = self.residue_mlp(residue_feature)
+
+         if not return_dict:
+             return (protein_feature, residue_feature)
+
+         return EsmProteinRepresentationOutput(
+             protein_feature=protein_feature, residue_feature=residue_feature
+         )
+
+
+ class EsmForProteinPropertyPrediction(EsmPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = EsmForProteinRepresentation(config)
+         self.classifier = ProtSTHead(config, out_dim=config.num_labels)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, EsmProteinClassificationOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the protein classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         logits = self.classifier(outputs.protein_feature)  # [batch_size, feature_dim] -> [batch_size, num_labels]
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             labels = labels.to(logits.device)
+             loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,)
+             return ((loss,) + output) if loss is not None else output
+
+         return EsmProteinClassificationOutput(loss=loss, logits=logits)
+
+
+ class ProtSTPreTrainedModel(PreTrainedModel):
+     config_class = ProtSTConfig
+
+     def _compute_protein_feature(
+         self,
+         protein_input_ids,
+         protein_attention_mask,
+         protein_position_ids,
+         output_attentions,
+         output_hidden_states,
+     ):
+         protein_outputs = self.protein_model(
+             protein_input_ids,
+             attention_mask=protein_attention_mask,
+             position_ids=protein_position_ids,
+             head_mask=None,
+             inputs_embeds=None,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=None,
+         )
+
+         return protein_outputs
+
+     def _compute_text_feature(
+         self,
+         text_input_ids,
+         text_attention_mask,
+         text_position_ids,
+         output_attentions,
+         output_hidden_states,
+     ):
+         text_outputs = self.text_model(
+             text_input_ids,
+             attention_mask=text_attention_mask,
+             position_ids=text_position_ids,
+             head_mask=None,
+             inputs_embeds=None,
+             encoder_hidden_states=None,
+             encoder_attention_mask=None,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=None,
+         )
+
+         return text_outputs
+
+
+ class ProtSTModel(ProtSTPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.config = config
+         self.protein_model = EsmForProteinRepresentation(config.protein_config)
+         self.text_model = BertForPubMed(config.text_config)
+         self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
+
+         self.post_init()
+
+     def forward(
+         self,
+         protein_input_ids: Optional[torch.LongTensor] = None,
+         text_input_ids: Optional[torch.LongTensor] = None,
+         protein_attention_mask: Optional[torch.Tensor] = None,
+         text_attention_mask: Optional[torch.Tensor] = None,
+         protein_position_ids: Optional[torch.LongTensor] = None,
+         text_position_ids: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+     ):
+         # Not implemented yet
+         return None
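For reference, a hedged end-to-end sketch (not part of the committed files). Because ProtSTModel.forward is still a stub, the sketch calls the two encoders directly; the toy hyperparameters and special-token ids below are illustrative assumptions, not values from this commit, and both modules are assumed importable from the working directory.

import torch
from configuration_protst import ProtSTConfig
from modeling_protst import ProtSTModel

# Tiny, randomly initialized sub-models; the token ids must match whatever
# tokenizers you actually pair with the protein and text encoders.
config = ProtSTConfig(
    protein_config={
        "vocab_size": 33, "mask_token_id": 32, "hidden_size": 320, "num_hidden_layers": 2,
        "num_attention_heads": 4, "intermediate_size": 640,
        "cls_token_id": 0, "pad_token_id": 1, "eos_token_id": 2,
    },
    text_config={
        "hidden_size": 320, "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 640,
        "cls_token_id": 101, "sep_token_id": 102, "pad_token_id": 0,
    },
)
model = ProtSTModel(config).eval()

protein_ids = torch.randint(4, 24, (1, 16))    # dummy residue tokens (no special tokens)
text_ids = torch.randint(1000, 2000, (1, 12))  # dummy wordpiece tokens (no special tokens)

with torch.no_grad():
    protein_out = model.protein_model(protein_ids, attention_mask=torch.ones_like(protein_ids))
    text_out = model.text_model(text_ids, attention_mask=torch.ones_like(text_ids))

# Both pooled features live in the shared 512-dimensional ProtST space.
print(protein_out.protein_feature.shape, text_out.text_feature.shape)  # torch.Size([1, 512]) torch.Size([1, 512])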