wenkai committed on
Commit
7409b0a
·
verified ·
1 Parent(s): f18092c

Upload 31 files

Browse files
esm/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .version import version as __version__ # noqa
7
+
8
+ from .data import Alphabet, BatchConverter, FastaBatchedDataset # noqa
9
+ from .model.esm1 import ProteinBertModel # noqa
10
+ from .model.esm2 import ESM2 # noqa
11
+ from .model.msa_transformer import MSATransformer #noqa
12
+ from . import pretrained # noqa
esm/axial_attention.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
class RowSelfAttention(nn.Module):
    """Compute self-attention over rows of a 2D input.

    The input is unpacked as ``(num_rows, num_cols, batch_size, embed_dim)``.
    Attention logits are summed over the row axis, so one attention map of
    layout ``hnij`` (heads, batch, cols, cols) is shared by every row.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        # einsum output layout of the attention map.
        self.attn_shape = "hnij"

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        """Return the query scale, additionally divided by sqrt(number of rows)."""
        return self.scaling / math.sqrt(q.size(0))

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Memory-bounded forward pass that processes the rows in chunks.

        Logits are accumulated chunk by chunk, normalised once, and then the
        value update is computed chunk by chunk with the shared probabilities.
        """
        num_rows, num_cols, _, _ = x.size()
        rows_per_chunk = max(1, self.max_tokens_per_msa // num_cols)
        chunk_starts = range(0, num_rows, rows_per_chunk)
        # The scale uses the full row count, not the chunk's.
        scaling = self.align_scaling(x)

        attns = 0
        for start in chunk_starts:
            stop = start + rows_per_chunk
            if self_attn_padding_mask is None:
                chunk_mask = None
            else:
                chunk_mask = self_attn_padding_mask[:, start:stop]
            attns += self.compute_attention_weights(
                x[start:stop],
                scaling,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=chunk_mask,
            )
        attn_probs = self.dropout_module(attns.softmax(-1))

        updates = [
            self.compute_attention_update(x[start : start + rows_per_chunk], attn_probs)
            for start in chunk_starts
        ]
        return torch.cat(updates, 0), attn_probs

    def compute_attention_weights(
        self,
        x,
        scaling: float,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return un-normalised attention logits summed over the rows of ``x``.

        Raises:
            NotImplementedError: if ``self_attn_mask`` is given.
        """
        num_rows, num_cols, batch_size, _ = x.size()
        per_head = (num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*per_head)
        k = self.k_proj(x).view(*per_head)
        q *= scaling
        if self_attn_padding_mask is not None:
            # Padded positions must contribute nothing, because the logits are
            # summed across the alignment (row) axis.
            keep = 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
            q *= keep

        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

        if self_attn_mask is not None:
            # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
            raise NotImplementedError

        if self_attn_padding_mask is not None:
            # The first row's padding pattern is used to mask key columns.
            key_padding = self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(key_padding, -10000)

        return attn_weights

    def compute_attention_update(
        self,
        x,
        attn_probs,
    ):
        """Apply ``attn_probs`` to the values of ``x`` and project the result."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        merged = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        return self.out_proj(merged)

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)``.

        The chunked path is taken only when gradients are disabled and the
        input holds more than ``max_tokens_per_msa`` row*column positions.
        """
        num_rows, num_cols, _, _ = x.size()
        oversized = num_rows * num_cols > self.max_tokens_per_msa
        if oversized and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)

        scaling = self.align_scaling(x)
        attn_weights = self.compute_attention_weights(
            x, scaling, self_attn_mask, self_attn_padding_mask
        )
        attn_probs = self.dropout_module(attn_weights.softmax(-1))
        output = self.compute_attention_update(x, attn_probs)
        return output, attn_probs
131
+
132
+
133
class ColumnSelfAttention(nn.Module):
    """Compute self-attention over columns of a 2D input.

    The input is unpacked as ``(num_rows, num_cols, batch_size, embed_dim)``.
    Within each column, every row attends to every other row; the attention
    map has layout ``hcnij`` (heads, cols, batch, rows, rows).
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Memory-bounded forward pass that processes the columns in chunks.

        Columns are independent of each other, so each chunk is attended
        separately and the results are concatenated along the column axis.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            stop = start + max_cols
            # Call the attention kernel directly rather than self(...): when
            # num_rows alone exceeds max_tokens_per_msa, even a one-column
            # chunk is still "too large" and re-entering forward() would
            # recurse without end.
            output, attn = self.compute_attention_update(
                x[:, start:stop],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, :, start:stop]
                if self_attn_padding_mask is not None
                else None,
            )
            outputs.append(output)
            attns.append(attn)
        output = torch.cat(outputs, 1)
        attns = torch.cat(attns, 1)
        return output, attns

    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Run column attention on ``x`` and return ``(output, attn_probs)``.

        Raises:
            NotImplementedError: if ``self_attn_mask`` is given and ``x`` has
                more than one row.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # if there is only 1 position, this is equivalent and doesn't break with padding
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
        else:
            per_head = (num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            q = self.q_proj(x).view(*per_head)
            k = self.k_proj(x).view(*per_head)
            v = self.v_proj(x).view(*per_head)
            q *= self.scaling

            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

            if self_attn_mask is not None:
                raise NotImplementedError
            if self_attn_padding_mask is not None:
                # The mask is permuted to (cols, batch, rows) and broadcast
                # over heads and query rows, so padded key rows are masked.
                attn_weights = attn_weights.masked_fill(
                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                    -10000,
                )

            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
            context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
            output = self.out_proj(context)
        return output, attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)``.

        The chunked path is taken only when gradients are disabled and the
        input holds more than ``max_tokens_per_msa`` row*column positions.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
esm/constants.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # fmt: off
7
# Vocabulary of single-character sequence tokens. The 'toks' list is what
# Alphabet.from_architecture (esm/data.py) passes as `standard_toks`, so the
# order here determines the relative order of the resulting token indices.
# NOTE(review): presumably one-letter amino-acid codes (including ambiguity
# codes) followed by the '.' and '-' alignment symbols -- confirm.
proteinseq_toks = {
    'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
}
10
+ # fmt: on
esm/data.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import os
8
+ from typing import Sequence, Tuple, List, Union
9
+ import pickle
10
+ import re
11
+ import shutil
12
+ import torch
13
+ from pathlib import Path
14
+ from esm.constants import proteinseq_toks
15
+
16
# A raw MSA: a sequence of (label, sequence string) pairs.
RawMSA = Sequence[Tuple[str, str]]
17
+
18
+
19
class FastaBatchedDataset(object):
    """In-memory list of labelled sequences with token-budgeted batching."""

    def __init__(self, sequence_labels, sequence_strs):
        """Store parallel iterables of labels and sequence strings as lists."""
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)

    @classmethod
    def from_file(cls, fasta_file):
        """Build a dataset from a FASTA file.

        Lines starting with ``>`` open a new record; all following lines up to
        the next header are stripped and concatenated into its sequence. A
        header with no text gets the label ``seqnum<line index>``. Any text
        before the first header is discarded.

        Raises:
            AssertionError: if two records share the same label.
        """
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq():
            nonlocal cur_seq_label, buf
            if cur_seq_label is not None:
                sequence_labels.append(cur_seq_label)
                sequence_strs.append("".join(buf))
            # Always reset the buffer, even when no record is open; otherwise
            # text found before the first header would be prepended to the
            # first record's sequence.
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):  # label line
                    _flush_current_seq()
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line
                    else:
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:  # sequence line
                    buf.append(line.strip())

        _flush_current_seq()

        assert len(set(sequence_labels)) == len(
            sequence_labels
        ), "Found duplicate sequence labels"

        return cls(sequence_labels, sequence_strs)

    def __len__(self):
        return len(self.sequence_labels)

    def __getitem__(self, idx):
        """Return the ``(label, sequence)`` pair at ``idx``."""
        return self.sequence_labels[idx], self.sequence_strs[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        """Group sequence indices into batches bounded by a padded token budget.

        Sequences are visited from shortest to longest. A batch is closed
        before adding a sequence that would push ``longest length * batch size``
        above ``toks_per_batch``. A sequence that exceeds the budget on its
        own still gets a batch to itself.

        Args:
            toks_per_batch: maximum padded token count per batch.
            extra_toks_per_seq: tokens added to every sequence length.

        Returns:
            A list of batches, each a list of indices into the dataset.
        """
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches
89
+
90
+
91
class Alphabet(object):
    """Token vocabulary with index lookup, tokenization and encoding helpers."""

    def __init__(
        self,
        standard_toks: Sequence[str],
        prepend_toks: Sequence[str] = ("<null_0>", "<pad>", "<eos>", "<unk>"),
        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
        prepend_bos: bool = True,
        append_eos: bool = False,
        use_msa: bool = False,
    ):
        """Build the vocabulary as prepend + standard + null padding + append.

        ``"<unk>"`` must be present among the tokens; other special tokens that
        are missing resolve to the ``<unk>`` index.
        """
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.use_msa = use_msa

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        # Pad with <null_i> tokens so the appended tokens start at a multiple of 8.
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")
        self.all_special_tokens = ['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']
        # Same list object as all_toks: every vocabulary token is atomic.
        self.unique_no_split_tokens = self.all_toks

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        """Return the index of ``tok``, or the ``<unk>`` index if it is unknown."""
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        """Return the token stored at index ``ind``."""
        return self.all_toks[ind]

    def to_dict(self):
        """Return a copy of the token-to-index mapping."""
        return self.tok_to_idx.copy()

    def get_batch_converter(self, truncation_seq_length: int = None):
        """Return an MSA or plain batch converter, depending on ``use_msa``."""
        if self.use_msa:
            return MSABatchConverter(self, truncation_seq_length)
        else:
            return BatchConverter(self, truncation_seq_length)

    @classmethod
    def from_architecture(cls, name: str) -> "Alphabet":
        """Create the alphabet matching a named model architecture.

        Raises:
            ValueError: if ``name`` is not a recognised architecture.
        """
        if name in ("ESM-1", "protein_bert_base"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        elif name in ("ESM-1b", "roberta_large"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = True
            use_msa = False
        elif name in ("MSA Transformer", "msa_transformer"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = False
            use_msa = True
        elif "invariant_gvp" in name.lower():
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>", "<cath>", "<af2>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        else:
            raise ValueError("Unknown architecture selected")
        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)

    def _tokenize(self, text) -> List[str]:
        """Split text that contains no vocabulary token on whitespace."""
        return text.split()

    def tokenize(self, text, **kwargs) -> List[str]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
        Converts a string in a sequence of tokens, using the tokenizer.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """
        # Set copy for O(1) membership tests; the list keeps the split order.
        no_split = set(self.unique_no_split_tokens)

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            last = len(split_text) - 1
            for i, sub_text in enumerate(split_text):
                # Whitespace adjacent to a token is stripped on both sides.
                if i < last:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == last:
                    if sub_text:
                        result.append(sub_text)
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in no_split:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            # Leftover pieces that are not vocabulary tokens are split on
            # whitespace; known tokens pass through unchanged.
            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token) if token not in no_split else [token]
                        for token in tokenized_text
                    )
                )
            )

        return split_on_tokens(self.unique_no_split_tokens, text)

    def encode(self, text):
        """Tokenize ``text`` and map each token to its index.

        Raises:
            KeyError: if tokenization yields a piece that is not in the vocabulary.
        """
        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
251
+
252
+
253
class BatchConverter(object):
    """Callable that turns a raw (labels + strings) batch into a processed
    (labels + strings + padded token tensor) batch.
    """

    def __init__(self, alphabet, truncation_seq_length: int = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        """Encode ``raw_batch`` and return ``(labels, strs, tokens)``.

        ``tokens`` is an int64 tensor of shape ``(batch, longest + bos + eos)``
        pre-filled with the padding index. Encoded sequences are cut to
        ``truncation_seq_length`` when that is set (and non-zero).
        """
        # RoBERTa uses an eos token, while ESM-1 does not.
        alphabet = self.alphabet
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)

        encoded = [alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            encoded = [ids[: self.truncation_seq_length] for ids in encoded]

        bos = int(alphabet.prepend_bos)
        eos = int(alphabet.append_eos)
        width = max(len(ids) for ids in encoded) + bos + eos
        tokens = torch.full((batch_size, width), alphabet.padding_idx, dtype=torch.int64)

        for row, ids in enumerate(encoded):
            if alphabet.prepend_bos:
                tokens[row, 0] = alphabet.cls_idx
            end = bos + len(ids)
            tokens[row, bos:end] = torch.tensor(ids, dtype=torch.int64)
            if alphabet.append_eos:
                tokens[row, end] = alphabet.eos_idx

        return list(batch_labels), list(seq_str_list), tokens
298
+
299
+
300
class MSABatchConverter(BatchConverter):
    """Batch converter for MSAs: each batch item is itself a list of
    (label, sequence) pairs and is encoded with the parent converter.
    """

    def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
        """Encode one MSA or a batch of MSAs.

        Returns:
            ``(labels, strs, tokens)`` where ``labels`` and ``strs`` hold one
            list per MSA and ``tokens`` is an int64 tensor of shape
            ``(batch, max alignments, max length + bos + eos)`` pre-filled
            with the padding index.

        Raises:
            RuntimeError: if the sequences inside one MSA differ in length.
        """
        # A str at inputs[0][0] means `inputs` is a single MSA, not a batch.
        raw_batch: Sequence[RawMSA] = (
            [inputs] if isinstance(inputs[0][0], str) else inputs  # type: ignore
        )

        alphabet = self.alphabet
        depth = max(len(msa) for msa in raw_batch)
        # Only the first sequence of each MSA is measured.
        width = (
            max(len(msa[0][1]) for msa in raw_batch)
            + int(alphabet.prepend_bos)
            + int(alphabet.append_eos)
        )
        tokens = torch.full(
            (len(raw_batch), depth, width), alphabet.padding_idx, dtype=torch.int64
        )

        labels, strs = [], []
        for i, msa in enumerate(raw_batch):
            distinct_lengths = {len(seq) for _, seq in msa}
            if len(distinct_lengths) != 1:
                raise RuntimeError(
                    "Received unaligned sequences for input to MSA, all sequence "
                    "lengths must be equal."
                )
            msa_labels, msa_strs, msa_tokens = super().__call__(msa)
            labels.append(msa_labels)
            strs.append(msa_strs)
            tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens

        return labels, strs, tokens
337
+
338
+
339
def read_fasta(
    path,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    """Lazily yield ``(description, sequence)`` pairs from the file at ``path``.

    The file is opened when iteration starts and stays open until the
    generator is exhausted or closed. The keyword options are forwarded
    unchanged to ``read_alignment_lines``.
    """
    with open(path, "r") as handle:
        yield from read_alignment_lines(
            handle,
            keep_gaps=keep_gaps,
            keep_insertions=keep_insertions,
            to_upper=to_upper,
        )
350
+
351
+
352
def read_alignment_lines(
    lines,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    """Yield ``(description, sequence)`` pairs from FASTA-style lines.

    Args:
        lines: iterable of text lines; a line starting with ``>`` opens a record.
        keep_gaps: when false, remove ``-`` characters from each sequence.
        keep_insertions: when false, remove lowercase ``a-z`` characters.
        to_upper: when true, upper-case the sequence after filtering.

    Raises:
        AssertionError: if sequence text precedes the first header, or if
            ``lines`` contains no header at all.
    """
    seq = desc = None

    def parse(s):
        if not keep_gaps:
            s = re.sub("-", "", s)
        if not keep_insertions:
            s = re.sub("[a-z]", "", s)
        if to_upper:
            return s.upper()
        return s

    for line in lines:
        # Line may be empty if seq % file_line_width == 0
        is_header = len(line) > 0 and line[0] == ">"
        if not is_header:
            assert isinstance(seq, str)
            seq += line.strip()
            continue
        # A new header closes the record that was being accumulated.
        if seq is not None:
            yield desc, parse(seq)
        desc = line.strip().lstrip(">")
        seq = ""

    assert isinstance(seq, str) and isinstance(desc, str)
    yield desc, parse(seq)
379
+
380
+
381
class ESMStructuralSplitDataset(torch.utils.data.Dataset):
    """
    Structural Split Dataset as described in section A.10 of the supplement of our paper.
    https://doi.org/10.1101/622803

    We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
    generated on January 23, 2020.

    For each SCOPe domain:
        - We extract the sequence from the corresponding PDB file
        - We extract the 3D coordinates of the Carbon beta atoms, aligning them
          to the sequence. We put NaN where Cb atoms are missing.
        - From the 3D coordinates, we calculate a pairwise distance map, based
          on L2 distance
        - We use DSSP to generate secondary structure labels for the corresponding
          PDB file. This is also aligned to the sequence. We put - where SSP
          labels are missing.

    For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
    we have split the data into 5 partitions for cross validation. These are provided
    in a downloaded splits folder, in the format:
            splits/{split_level}/{cv_partition}/{train|valid}.txt
    where train is the partition and valid is the concatenation of the remaining 4.

    For each SCOPe domain, we provide a pkl dump that contains:
        - seq    : The domain sequence, stored as an L-length string
        - ssp    : The secondary structure labels, stored as an L-length string
        - dist   : The distance map, stored as an LxL numpy array
        - coords : The 3D coordinates, stored as an Lx3 numpy array

    """

    base_folder = "structural-data"
    file_list = [
        #  url  tar filename   filename      MD5 Hash
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
            "splits.tar.gz",
            "splits",
            "456fe1c7f22c9d3d8dfe9735da52411d",
        ),
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
            "pkl.tar.gz",
            "pkl",
            "644ea91e56066c750cd50101d390f5db",
        ),
    ]

    def __init__(
        self,
        split_level,
        cv_partition,
        split,
        root_path=os.path.expanduser("~/.cache/torch/data/esm"),
        download=False,
    ):
        """Load the domain names of one split.

        Reads ``splits/{split_level}/{cv_partition}/{split}.txt`` below
        ``root_path/structural-data``; with ``download=True`` the archives are
        fetched first when they are not already unpacked.

        Raises:
            AssertionError: if ``split`` is neither ``"train"`` nor ``"valid"``.
        """
        super().__init__()
        assert split in [
            "train",
            "valid",
        ], "train_valid must be 'train' or 'valid'"
        self.root_path = root_path
        self.base_path = os.path.join(self.root_path, self.base_folder)

        # check if root path has what you need or else download it
        if download:
            self.download()

        self.pkl_dir = os.path.join(self.base_path, "pkl")
        self.split_file = os.path.join(
            self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
        )
        with open(self.split_file) as split_handle:
            self.names = split_handle.read().splitlines()

    def __len__(self):
        return len(self.names)

    def _check_exists(self) -> bool:
        """Return True when every unpacked data directory is present."""
        # isdir() is already False for a missing path, so it covers both checks.
        return all(
            os.path.isdir(os.path.join(self.base_path, dirname))
            for (_, _, dirname, _) in self.file_list
        )

    def download(self):
        """Download and unpack every archive in ``file_list`` unless present."""
        if self._check_exists():
            print("Files already downloaded and verified")
            return

        # Imported lazily so torchvision is needed only when downloading.
        from torchvision.datasets.utils import download_url

        for url, tar_filename, _, md5_hash in self.file_list:
            download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
            archive_path = os.path.join(self.base_path, tar_filename)
            shutil.unpack_archive(archive_path, self.base_path)

    def __getitem__(self, idx):
        """
        Returns a dict with the following entries
         - seq : Str (domain sequence)
         - ssp : Str (SSP labels)
         - dist : np.array (distance map)
         - coords : np.array (3D coordinates)
        """
        name = self.names[idx]
        # Pickles are sharded into sub-folders named after characters 1-2 of the name.
        pkl_path = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
        # NOTE(security): pickle.load runs arbitrary code from the file; only
        # point root_path at data from a trusted source.
        with open(pkl_path, "rb") as pkl_handle:
            return pickle.load(pkl_handle)
esm/esmfold/v1/__init__.py ADDED
File without changes
esm/esmfold/v1/categorical_mixture.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import torch
6
+
7
+
8
class CategoricalMixture:
    """Categorical distribution over equally spaced bins on ``[start, end]``."""

    def __init__(self, param, bins=50, start=0, end=1):
        """Store the logits and precompute the bin centres.

        Args:
            param: logits whose last dimension indexes the bins.
            bins: number of bins.
            start: lower edge of the first bin.
            end: upper edge of the last bin.
        """
        # All tensors are of shape ..., bins.
        self.logits = param
        edges = torch.linspace(
            start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype
        )
        # Midpoint of each pair of consecutive edges.
        self.v_bins = (edges[:-1] + edges[1:]) / 2

    def log_prob(self, true):
        """Return the log-probability of the bin whose centre is nearest ``true``."""
        # Shapes are:
        #     self.probs: ... x bins
        #     true      : ...
        # Index with a *tuple* of None to prepend one broadcast axis per
        # dimension of `true`. A list index here is the deprecated non-tuple
        # sequence form and selects nothing when `true` is 0-dimensional.
        centers = self.v_bins[(None,) * true.ndim]
        true_index = (true.unsqueeze(-1) - centers).abs().argmin(-1)
        log_probs = self.logits.log_softmax(-1)
        return torch.take_along_dim(log_probs, true_index.unsqueeze(-1), dim=-1).squeeze(-1)

    def mean(self):
        """Return the expected value: softmax probabilities times bin centres."""
        return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
39
+
40
+
41
def categorical_lddt(logits, bins=50):
    """Return the mean of a ``CategoricalMixture`` built from ``logits``.

    The mixture uses its default value range with ``bins`` bins.
    """
    # Logits are ..., 37, bins.
    mixture = CategoricalMixture(logits, bins=bins)
    return mixture.mean()
esm/esmfold/v1/esmfold.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
import typing as T
from dataclasses import dataclass, field
from functools import partial

import torch
import torch.nn as nn
from torch import nn
from torch.nn import LayerNorm

import esm
from esm import Alphabet
from esm.esmfold.v1.categorical_mixture import categorical_lddt
from esm.esmfold.v1.misc import (
    batch_encode_sequences,
    collate_dense_tensors,
    output_to_pdb,
)
from esm.esmfold.v1.trunk import FoldingTrunk, FoldingTrunkConfig
from openfold.data.data_transforms import make_atom14_masks
from openfold.np import residue_constants
from openfold.utils.loss import compute_predicted_aligned_error, compute_tm
26
+
27
+
28
@dataclass
class ESMFoldConfig:
    """Top-level configuration for ESMFold."""

    # Folding-trunk configuration. Built per instance through default_factory:
    # a `FoldingTrunkConfig()` default would be one object shared by every
    # config, and dataclasses on Python 3.11+ reject unhashable default
    # instances as mutable defaults (raising ValueError at class creation).
    trunk: T.Any = field(default_factory=FoldingTrunkConfig)
    # Hidden width of the linear layers in the lDDT head.
    lddt_head_hid_dim: int = 128
32
+
33
+
34
# Shorthand for the generic loader; bound below to checkpoint names via partial.
load_fn = esm.pretrained.load_model_and_alphabet
# Maps a language-model key to a zero-argument callable. ESMFold.__init__
# looks the key up with cfg.esm_type, calls the result, and unpacks it into
# (self.esm, self.esm_dict).
# NOTE(review): some keys and checkpoint names look inconsistent -- e.g.
# "esm2_3B_270K" loads a name ending in "_500K", and "esm2_650M_270K" orders
# the name parts differently from its siblings. Confirm these are intended.
esm_registry = {
    "esm2_8M": partial(load_fn, "esm2_t6_8M_UR50D_500K"),
    "esm2_8M_270K": esm.pretrained.esm2_t6_8M_UR50D,
    "esm2_35M": partial(load_fn, "esm2_t12_35M_UR50D_500K"),
    "esm2_35M_270K": esm.pretrained.esm2_t12_35M_UR50D,
    "esm2_150M": partial(load_fn, "esm2_t30_150M_UR50D_500K"),
    "esm2_150M_270K": partial(load_fn, "esm2_t30_150M_UR50D_270K"),
    "esm2_650M": esm.pretrained.esm2_t33_650M_UR50D,
    "esm2_650M_270K": partial(load_fn, "esm2_t33_650M_270K_UR50D"),
    "esm2_3B": esm.pretrained.esm2_t36_3B_UR50D,
    "esm2_3B_270K": partial(load_fn, "esm2_t36_3B_UR50D_500K"),
    "esm2_15B": esm.pretrained.esm2_t48_15B_UR50D,
}
48
+
49
+
50
class ESMFold(nn.Module):
    """Structure predictor combining a frozen ESM language model, a folding trunk and output heads.

    The language model produces per-residue (and optionally pairwise attention)
    features, which are projected into the trunk's sequence / pairwise states.
    The trunk output is then decorated with distogram, LM, pLDDT and pTM heads.
    """

    # Keys of the trunk output that are kept in the result dictionary of `forward`.
    _TRUNK_OUTPUT_KEYS = (
        "s_z",
        "s_s",
        "frames",
        "sidechain_frames",
        "unnormalized_angles",
        "angles",
        "positions",
        "states",
    )

    def __init__(self, esmfold_config=None, **kwargs):
        """Build the language model, input projections, trunk and output heads.

        Args:
            esmfold_config: Ready-made configuration. When falsy, a new
                ``ESMFoldConfig`` is created from ``kwargs``.
            **kwargs: Forwarded to ``ESMFoldConfig`` only when no config is given.

        Raises:
            ValueError: If ``cfg.esm_type`` is not a key of ``esm_registry``.
        """
        super().__init__()

        self.cfg = esmfold_config if esmfold_config else ESMFoldConfig(**kwargs)
        cfg = self.cfg

        self.distogram_bins = 64

        # Fail with a readable message instead of "'NoneType' object is not callable".
        esm_loader = esm_registry.get(cfg.esm_type)
        if esm_loader is None:
            raise ValueError(
                f"Unknown esm_type {cfg.esm_type!r}; expected one of {sorted(esm_registry)}"
            )
        self.esm, self.esm_dict = esm_loader()

        # The language model is used as a frozen feature extractor in half precision.
        self.esm.requires_grad_(False)
        self.esm.half()

        self.esm_feats = self.esm.embed_dim
        self.esm_attns = self.esm.num_layers * self.esm.attention_heads
        self.register_buffer("af2_to_esm", ESMFold._af2_to_esm(self.esm_dict))
        # Learned (softmax-normalised in `forward`) weights over the
        # num_layers + 1 language model representations.
        self.esm_s_combine = nn.Parameter(torch.zeros(self.esm.num_layers + 1))

        c_s = cfg.trunk.sequence_state_dim
        c_z = cfg.trunk.pairwise_state_dim

        self.esm_s_mlp = nn.Sequential(
            LayerNorm(self.esm_feats),
            nn.Linear(self.esm_feats, c_s),
            nn.ReLU(),
            nn.Linear(c_s, c_s),
        )
        if cfg.use_esm_attn_map:
            self.esm_z_mlp = nn.Sequential(
                LayerNorm(self.esm_attns),
                nn.Linear(self.esm_attns, c_z),
                nn.ReLU(),
                nn.Linear(c_z, c_z),
            )

        # 0 is padding, N is unknown residues, N + 1 is mask.
        self.n_tokens_embed = residue_constants.restype_num + 3
        self.pad_idx = 0
        self.unk_idx = self.n_tokens_embed - 2
        self.mask_idx = self.n_tokens_embed - 1
        self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)

        self.trunk = FoldingTrunk(**cfg.trunk)

        self.distogram_head = nn.Linear(c_z, self.distogram_bins)
        self.ptm_head = nn.Linear(c_z, self.distogram_bins)
        self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
        self.lddt_bins = 50
        # The last layer emits 37 * lddt_bins logits per residue; `forward`
        # reshapes them to (..., 37, lddt_bins).
        self.lddt_head = nn.Sequential(
            nn.LayerNorm(cfg.trunk.structure_module.c_s),
            nn.Linear(cfg.trunk.structure_module.c_s, cfg.lddt_head_hid_dim),
            nn.Linear(cfg.lddt_head_hid_dim, cfg.lddt_head_hid_dim),
            nn.Linear(cfg.lddt_head_hid_dim, 37 * self.lddt_bins),
        )

    @staticmethod
    def _af2_to_esm(d: Alphabet):
        """Return a lookup tensor mapping (AF2 residue index + 1) to the ESM token index.

        Entry 0 is the ESM padding index; the remaining entries follow
        ``residue_constants.restypes_with_x``.
        """
        # Remember that t is shifted from residue_constants by 1 (0 is padding).
        esm_reorder = [d.padding_idx] + [
            d.get_idx(v) for v in residue_constants.restypes_with_x
        ]
        return torch.tensor(esm_reorder)

    def _af2_idx_to_esm_idx(self, aa, mask):
        """Convert AF2 residue indices to ESM token indices; positions where ``mask != 1`` become padding."""
        aa = (aa + 1).masked_fill(mask != 1, 0)
        return self.af2_to_esm[aa]

    def _compute_language_model_representations(
        self, esmaa: torch.Tensor
    ) -> T.Tuple[torch.Tensor, T.Optional[torch.Tensor]]:
        """Adds bos/eos tokens for the language model, since the structure module doesn't use these.

        Args:
            esmaa: ESM token indices, one row per sequence (padded with the
                alphabet's padding index).

        Returns:
            ``(esm_s, esm_z)``: the stacked per-layer representations with the
            bos/eos positions stripped (B, L, nLayers, C), and the flattened
            attention maps with bos/eos stripped, or ``None`` when
            ``cfg.use_esm_attn_map`` is false.
        """
        batch_size = esmaa.size(0)
        padding_idx = self.esm_dict.padding_idx

        bosi, eosi = self.esm_dict.cls_idx, self.esm_dict.eos_idx
        bos = esmaa.new_full((batch_size, 1), bosi)
        eos = esmaa.new_full((batch_size, 1), padding_idx)
        esmaa = torch.cat([bos, esmaa, eos], dim=1)
        # Use the first padding index as eos during inference. The number of
        # non-padding tokens in a row is the column of its first padding token
        # (assumes padding only appears at the end of a row).
        esmaa[range(batch_size), (esmaa != padding_idx).sum(1)] = eosi

        res = self.esm(
            esmaa,
            repr_layers=range(self.esm.num_layers + 1),
            need_head_weights=self.cfg.use_esm_attn_map,
        )
        esm_s = torch.stack(
            [v for _, v in sorted(res["representations"].items())], dim=2
        )
        esm_s = esm_s[:, 1:-1]  # B, L, nLayers, C
        esm_z = (
            res["attentions"].permute(0, 4, 3, 1, 2).flatten(3, 4)[:, 1:-1, 1:-1, :]
            if self.cfg.use_esm_attn_map
            else None
        )
        return esm_s, esm_z

    def _mask_inputs_to_esm(self, esmaa, pattern):
        """Return a copy of ``esmaa`` with the ESM mask token wherever ``pattern == 1``."""
        new_esmaa = esmaa.clone()
        new_esmaa[pattern == 1] = self.esm_dict.mask_idx
        return new_esmaa

    def forward(
        self,
        aa: torch.Tensor,
        mask: T.Optional[torch.Tensor] = None,
        residx: T.Optional[torch.Tensor] = None,
        masking_pattern: T.Optional[torch.Tensor] = None,
        num_recycles: T.Optional[int] = None,
    ):
        """Runs a forward pass given input tokens. Use `model.infer` to
        run inference from a sequence.

        Args:
            aa (torch.Tensor): Tensor containing indices corresponding to amino acids. Indices match
                openfold.np.residue_constants.restype_order_with_x.
            mask (torch.Tensor): Binary tensor with 1 meaning position is unmasked and 0 meaning position is masked.
                Defaults to all ones.
            residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
            masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
                as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
                different masks are provided.
            num_recycles (int): How many recycle iterations to perform. Passed to the trunk as `no_recycles`;
                if None, the trunk falls back to its own default (see cfg.trunk.max_recycles).

        Returns:
            dict: The selected trunk outputs plus `distogram_logits`, `lm_logits`, `aatype`, the atom
            existence masks, `residue_index`, `lddt_head`, `plddt`, `ptm_logits`, `ptm` and the
            predicted aligned error entries.
        """

        if mask is None:
            mask = torch.ones_like(aa)

        B = aa.shape[0]
        L = aa.shape[1]
        device = aa.device

        if residx is None:
            residx = torch.arange(L, device=device).expand_as(aa)

        # === ESM ===
        esmaa = self._af2_idx_to_esm_idx(aa, mask)

        if masking_pattern is not None:
            esmaa = self._mask_inputs_to_esm(esmaa, masking_pattern)

        esm_s, esm_z = self._compute_language_model_representations(esmaa)

        # Convert esm_s to the precision used by the trunk and
        # the structure module. These tensors may be a lower precision if, for example,
        # we're running the language model in fp16 precision.
        esm_s = esm_s.to(self.esm_s_combine.dtype)
        esm_s = esm_s.detach()

        # === preprocessing ===
        # Softmax-weighted sum over the layer axis of the stacked representations.
        esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)

        s_s_0 = self.esm_s_mlp(esm_s)
        if self.cfg.use_esm_attn_map:
            esm_z = esm_z.to(self.esm_s_combine.dtype)
            esm_z = esm_z.detach()
            s_z_0 = self.esm_z_mlp(esm_z)
        else:
            s_z_0 = s_s_0.new_zeros(B, L, L, self.cfg.trunk.pairwise_state_dim)

        s_s_0 += self.embedding(aa)

        structure: dict = self.trunk(
            s_s_0, s_z_0, aa, residx, mask, no_recycles=num_recycles
        )
        # Documenting what we expect:
        structure = {
            k: v for k, v in structure.items() if k in self._TRUNK_OUTPUT_KEYS
        }

        # Symmetrise the distogram logits over the two residue axes.
        disto_logits = self.distogram_head(structure["s_z"])
        disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
        structure["distogram_logits"] = disto_logits

        lm_logits = self.lm_head(structure["s_s"])
        structure["lm_logits"] = lm_logits

        structure["aatype"] = aa
        make_atom14_masks(structure)

        # Zero the atom existence masks at masked-out residues.
        for k in [
            "atom14_atom_exists",
            "atom37_atom_exists",
        ]:
            structure[k] *= mask.unsqueeze(-1)
        structure["residue_index"] = residx

        lddt_head = self.lddt_head(structure["states"]).reshape(
            structure["states"].shape[0], B, L, -1, self.lddt_bins
        )
        structure["lddt_head"] = lddt_head
        # Only the last entry along the leading `states` axis is used for pLDDT.
        plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
        structure["plddt"] = (
            100 * plddt
        )  # we predict plDDT between 0 and 1, scale to be between 0 and 100.

        ptm_logits = self.ptm_head(structure["s_z"])

        seqlen = mask.type(torch.int64).sum(1)
        structure["ptm_logits"] = ptm_logits
        # pTM is computed per sample on the top-left seqlen x seqlen crop.
        structure["ptm"] = torch.stack(
            [
                compute_tm(
                    batch_ptm_logits[None, :sl, :sl],
                    max_bins=31,
                    no_bins=self.distogram_bins,
                )
                for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
            ]
        )
        structure.update(
            compute_predicted_aligned_error(
                ptm_logits, max_bin=31, no_bins=self.distogram_bins
            )
        )

        return structure

    @torch.no_grad()
    def infer(
        self,
        sequences: T.Union[str, T.List[str]],
        residx=None,
        masking_pattern: T.Optional[torch.Tensor] = None,
        num_recycles: T.Optional[int] = None,
        residue_index_offset: T.Optional[int] = 512,
        chain_linker: T.Optional[str] = "G" * 25,
    ):
        """Runs a forward pass given input sequences.

        Args:
            sequences (Union[str, List[str]]): A list of sequences to make predictions for. Multimers can also be passed in,
                each chain should be separated by a ':' token (e.g. "<chain1>:<chain2>:<chain3>").
            residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
            masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
                as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
                different masks are provided.
            num_recycles (int): How many recycle iterations to perform. Passed to the trunk as `no_recycles`;
                if None, the trunk falls back to its own default (see cfg.trunk.max_recycles).
            residue_index_offset (int): Residue index separation between chains if predicting a multimer. Has no effect on
                single chain predictions. Default: 512.
            chain_linker (str): Linker to use between chains if predicting a multimer. Has no effect on single chain
                predictions. Default: length-25 poly-G ("G" * 25).

        Returns:
            dict: The output of `forward`, with linker positions removed from `atom37_atom_exists`
            and the extra entries `mean_plddt` and `chain_index`.
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        aatype, mask, _residx, linker_mask, chain_index = batch_encode_sequences(
            sequences, residue_index_offset, chain_linker
        )

        if residx is None:
            residx = _residx
        elif not isinstance(residx, torch.Tensor):
            residx = collate_dense_tensors(residx)

        aatype, mask, residx, linker_mask = (
            t.to(self.device) for t in (aatype, mask, residx, linker_mask)
        )

        output = self.forward(
            aatype,
            mask=mask,
            residx=residx,
            masking_pattern=masking_pattern,
            num_recycles=num_recycles,
        )

        # Linker residues are not part of the prediction: drop their atoms.
        output["atom37_atom_exists"] = output[
            "atom37_atom_exists"
        ] * linker_mask.unsqueeze(2)

        # Mean pLDDT over the atoms that exist.
        output["mean_plddt"] = (output["plddt"] * output["atom37_atom_exists"]).sum(
            dim=(1, 2)
        ) / output["atom37_atom_exists"].sum(dim=(1, 2))
        output["chain_index"] = chain_index

        return output

    def output_to_pdb(self, output: T.Dict) -> T.List[str]:
        """Returns the pdb (file) strings from the model given the model output."""
        return output_to_pdb(output)

    def infer_pdbs(self, seqs: T.List[str], *args, **kwargs) -> T.List[str]:
        """Returns list of pdb (files) strings from the model given a list of input sequences."""
        output = self.infer(seqs, *args, **kwargs)
        return self.output_to_pdb(output)

    def infer_pdb(self, sequence: str, *args, **kwargs) -> str:
        """Returns the pdb (file) string from the model given an input sequence."""
        return self.infer_pdbs([sequence], *args, **kwargs)[0]

    def set_chunk_size(self, chunk_size: T.Optional[int]):
        """Forward the chunk size used for chunked axial attention to the trunk."""
        # This parameter means the axial attention will be computed
        # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
        # It's equivalent to running a for loop over chunks of the dimension we're iterating over,
        # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks.
        # Setting the value to None will return to default behavior, disable chunking.
        self.trunk.set_chunk_size(chunk_size)

    @property
    def device(self):
        """Device of the model, read from the `esm_s_combine` parameter."""
        return self.esm_s_combine.device
esm/esmfold/v1/misc.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import typing as T
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from einops import rearrange, repeat
11
+ from torch import nn
12
+ from openfold.np import residue_constants
13
+ from openfold.np.protein import Protein as OFProtein
14
+ from openfold.np.protein import to_pdb
15
+ from openfold.utils.feats import atom14_to_atom37
16
+
17
+
18
def encode_sequence(
    seq: str,
    residue_index_offset: T.Optional[int] = 512,
    chain_linker: T.Optional[str] = "G" * 25,
) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode one (possibly multi-chain) sequence into index tensors.

    Chains are separated by ':' in ``seq`` and are joined with ``chain_linker``
    before encoding.

    Args:
        seq: Amino-acid string; ':' separates chains.
        residue_index_offset: Extra residue-index gap added per chain. ``None``
            is treated as 0 (no gap).
        chain_linker: String inserted between chains. ``None`` is treated as "".

    Returns:
        ``(encoded, residx, linker_mask, chain_index)``: residue type indices
        (unknown letters map to "X"), residue indices, a float mask that is 0
        on linker positions and 1 elsewhere, and the per-position chain index
        (a linker carries the index of the chain preceding it).
    """
    linker = "" if chain_linker is None else chain_linker
    index_gap = 0 if residue_index_offset is None else residue_index_offset
    linker_len = len(linker)

    chains = seq.split(":")
    joined = linker.join(chains)

    restype_order = residue_constants.restype_order_with_x
    unk_idx = restype_order["X"]
    encoded = torch.tensor([restype_order.get(letter, unk_idx) for letter in joined])
    residx = torch.arange(len(encoded))

    if index_gap > 0:
        # Shift each chain (together with the linker that follows it) by
        # chain_number * index_gap.
        begin = 0
        for chain_no, chain in enumerate(chains):
            span = len(chain) + linker_len
            residx[begin : begin + span] += chain_no * index_gap
            begin += span

    linker_mask = torch.ones_like(encoded, dtype=torch.float32)
    chain_ids = []
    cursor = 0
    for chain_no, chain in enumerate(chains):
        if chain_no > 0:
            # The linker in front of this chain is labelled with the previous chain.
            chain_ids.extend([chain_no - 1] * linker_len)
        chain_ids.extend([chain_no] * len(chain))
        cursor += len(chain)
        # After the last chain this slice is empty, so nothing is zeroed.
        linker_mask[cursor : cursor + linker_len] = 0
        cursor += linker_len

    chain_index = torch.tensor(chain_ids, dtype=torch.int64)

    return encoded, residx, linker_mask, chain_index
59
+
60
+
61
def batch_encode_sequences(
    sequences: T.Sequence[str],
    residue_index_offset: T.Optional[int] = 512,
    chain_linker: T.Optional[str] = "G" * 25,
) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode several sequences with `encode_sequence` and pad them into batch tensors.

    Args:
        sequences: Sequences to encode (':' separates chains).
        residue_index_offset: Forwarded to `encode_sequence`.
        chain_linker: Forwarded to `encode_sequence`.

    Returns:
        ``(aatype, mask, residx, linker_mask, chain_index)``. All are padded
        with 0 except ``chain_index``, which is padded with -1. ``mask`` is 1
        on real positions and 0 on padding.
    """
    encodings = [
        encode_sequence(
            seq,
            residue_index_offset=residue_index_offset,
            chain_linker=chain_linker,
        )
        for seq in sequences
    ]
    aatype_list = [enc[0] for enc in encodings]
    residx_list = [enc[1] for enc in encodings]
    linker_mask_list = [enc[2] for enc in encodings]
    chain_index_list = [enc[3] for enc in encodings]

    aatype = collate_dense_tensors(aatype_list)
    # One all-ones vector per sequence; the collation padding supplies the zeros.
    valid_rows = [aatype.new_ones(len(row)) for row in aatype_list]
    mask = collate_dense_tensors(valid_rows)
    residx = collate_dense_tensors(residx_list)
    linker_mask = collate_dense_tensors(linker_mask_list)
    chain_index = collate_dense_tensors(chain_index_list, -1)

    return aatype, mask, residx, linker_mask, chain_index
91
+
92
+
93
def output_to_pdb(output: T.Dict) -> T.List[str]:
    """Returns one pdb (file) string per batch element of the model output.

    Every value of ``output`` is moved to CPU and converted to numpy, so all
    values are expected to be tensors.
    """
    # atom14_to_atom37 must be called first, as it fails on latest numpy if the
    # input is a numpy array. It will work if the input is a torch tensor.
    atom37_positions = atom14_to_atom37(output["positions"][-1], output)
    output = {key: value.to("cpu").numpy() for key, value in output.items()}
    atom37_positions = atom37_positions.cpu().numpy()
    atom37_mask = output["atom37_atom_exists"]

    def _protein_at(i):
        # Residue indices are shifted by one for the written file.
        return OFProtein(
            aatype=output["aatype"][i],
            atom_positions=atom37_positions[i],
            atom_mask=atom37_mask[i],
            residue_index=output["residue_index"][i] + 1,
            b_factors=output["plddt"][i],
            chain_index=output["chain_index"][i] if "chain_index" in output else None,
        )

    batch_size = output["aatype"].shape[0]
    return [to_pdb(_protein_at(i)) for i in range(batch_size)]
117
+
118
+
119
def collate_dense_tensors(
    samples: T.List[torch.Tensor], pad_v: float = 0
) -> torch.Tensor:
    """
    Takes a list of tensors with the following dimensions:
        [(d_11,       ...,           d_1K),
         (d_21,       ...,           d_2K),
         ...,
         (d_N1,       ...,           d_NK)]
    and stack + pads them into a single tensor of:
    (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})

    Args:
        samples: Tensors with the same number of dimensions, on the same device.
        pad_v: Value written to every position not covered by a sample.

    Returns:
        The padded batch, with the dtype of the first sample. An empty
        ``torch.Tensor()`` is returned for an empty ``samples`` list.

    Raises:
        RuntimeError: If the samples do not all have the same number of dimensions.
        ValueError: If the samples live on more than one device.
    """
    if len(samples) == 0:
        return torch.Tensor()
    if len({x.dim() for x in samples}) != 1:
        raise RuntimeError(
            f"Samples has varying dimensions: {[x.dim() for x in samples]}"
        )
    devices = {x.device for x in samples}
    if len(devices) != 1:
        # Previously surfaced as an opaque tuple-unpacking ValueError.
        raise ValueError(
            f"Samples must all be on the same device, got: {sorted(str(d) for d in devices)}"
        )
    (device,) = devices
    max_shape = [max(sizes) for sizes in zip(*(x.shape for x in samples))]
    result = torch.empty(
        len(samples), *max_shape, dtype=samples[0].dtype, device=device
    )
    result.fill_(pad_v)
    for i, sample in enumerate(samples):
        # Copy each sample into the top-left corner of its padded slot.
        result[i][tuple(slice(0, k) for k in sample.shape)] = sample
    return result
148
+
149
+
150
class Attention(nn.Module):
    """Multi-head self-attention with an optional key padding mask, an optional
    additive pairwise bias and optional sigmoid output gating."""

    def __init__(self, embed_dim, num_heads, head_width, gated=False):
        """
        Args:
            embed_dim: Channel size of the input; must equal ``num_heads * head_width``.
            num_heads: Number of attention heads.
            head_width: Channel size of each head.
            gated: When true, the attention output is multiplied by
                ``sigmoid(g_proj(x))`` before the output projection.
        """
        super().__init__()
        assert embed_dim == num_heads * head_width

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_width = head_width

        # A single projection produces q, k and v for every head.
        self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.gated = gated
        if gated:
            self.g_proj = nn.Linear(embed_dim, embed_dim)
            # Zero weight and unit bias: the gate starts as the constant sigmoid(1).
            torch.nn.init.zeros_(self.g_proj.weight)
            torch.nn.init.ones_(self.g_proj.bias)

        self.rescale_factor = self.head_width**-0.5

        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, bias=None, indices=None):
        """
        Basic self attention with optional mask and external pairwise bias.
        To handle sequences of different lengths, use mask.

        Inputs:
          x: batch of input sequences (.. x L x C)
          mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k). optional.
          bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads). optional.
          indices: unused; kept for interface compatibility.

        Outputs:
          sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)

        NOTE(review): the attention map is returned through a rearrange whose
        axis names do not match the tensor layout (see the return statement),
        so its actual shape is (.. x Lk x num_heads x Lq), not the documented
        one. The runtime behaviour is left unchanged here; confirm with callers
        before fixing.
        """

        t = rearrange(self.proj(x), "... l (h c) -> ... h l c", h=self.num_heads)
        q, k, v = t.chunk(3, dim=-1)

        q = self.rescale_factor * q
        a = torch.einsum("...qc,...kc->...qk", q, k)

        # Add external attention bias.
        if bias is not None:
            a = a + rearrange(bias, "... lq lk h -> ... h lq lk")

        # Do not attend to padding tokens.
        if mask is not None:
            mask = repeat(
                mask, "... lk -> ... h lq lk", h=self.num_heads, lq=q.shape[-2]
            )
            a = a.masked_fill(~mask.bool(), -np.inf)

        a = F.softmax(a, dim=-1)

        y = torch.einsum("...hqk,...hkc->...qhc", a, v)
        y = rearrange(y, "... h c -> ... (h c)", h=self.num_heads)

        if self.gated:
            y = self.g_proj(x).sigmoid() * y
        y = self.o_proj(y)

        # `a` is laid out as (.. x h x lq x lk); the pattern below therefore
        # binds its names to the wrong axes (see NOTE in the docstring).
        return y, rearrange(a, "... lq lk h -> ... h lq lk")
212
+
213
+
214
class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.
    """

    def __init__(self, r: float, batch_dim: T.Union[int, T.List[int]]):
        """
        Args:
            r: Dropout probability passed to ``nn.Dropout``.
            batch_dim: Dimension (or list of dimensions) along which one
                dropout mask is shared.
        """
        super().__init__()

        self.r = r
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply dropout to ``x`` with the mask broadcast over the shared dimensions."""
        # The mask gets size 1 on every shared dimension, so multiplying
        # broadcasts the same keep/drop decision along those dimensions.
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        return x * self.dropout(x.new_ones(shape))
235
+
236
+
237
class SequenceToPair(nn.Module):
    """Turn a per-residue state into a pairwise state from products and differences of two projections."""

    def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
        super().__init__()

        self.layernorm = nn.LayerNorm(sequence_state_dim)
        self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
        self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)

        # Both projections start with zero bias.
        for layer in (self.proj, self.o_proj):
            torch.nn.init.zeros_(layer.bias)

    def forward(self, sequence_state):
        """
        Inputs:
          sequence_state: B x L x sequence_state_dim

        Output:
          pairwise_state: B x L x L x pairwise_state_dim

        Intermediate state:
          B x L x L x 2*inner_dim
        """

        assert sequence_state.dim() == 3

        projected = self.proj(self.layernorm(sequence_state))
        q, k = torch.chunk(projected, 2, dim=-1)

        # Broadcast q along axis 1 and k along axis 2 to form every pair.
        q_pairs = q[:, None, :, :]
        k_pairs = k[:, :, None, :]
        pair_features = torch.cat([q_pairs * k_pairs, q_pairs - k_pairs], dim=-1)

        return self.o_proj(pair_features)
273
+
274
+
275
class PairToSequence(nn.Module):
    """Project a pairwise state to one scalar per attention head (layer norm, then a bias-free linear)."""

    def __init__(self, pairwise_state_dim, num_heads):
        super().__init__()

        self.layernorm = nn.LayerNorm(pairwise_state_dim)
        self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)

    def forward(self, pairwise_state):
        """
        Inputs:
          pairwise_state: B x L x L x pairwise_state_dim

        Output:
          pairwise_bias: B x L x L x num_heads
        """
        assert pairwise_state.dim() == 4
        return self.linear(self.layernorm(pairwise_state))
294
+
295
+
296
class ResidueMLP(nn.Module):
    """Residual feed-forward block: norm -> linear -> ReLU -> linear -> dropout, added to the input."""

    def __init__(self, embed_dim, inner_dim, norm=nn.LayerNorm, dropout=0):
        super().__init__()

        # The layer order is part of the contract: other code addresses the
        # second linear layer as `mlp[-2]`.
        stages = [
            norm(embed_dim),
            nn.Linear(embed_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the MLP update of ``x``."""
        update = self.mlp(x)
        return x + update
esm/esmfold/v1/pretrained.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+ from esm.esmfold.v1.esmfold import ESMFold
11
+
12
+
13
def _load_model(model_name):
    """Build an ESMFold model and load its weights from a checkpoint.

    Args:
        model_name: Either a local path ending in ".pt", or the name of a
            checkpoint hosted under the fair-esm download URL.

    Returns:
        The ESMFold model with the checkpoint weights loaded.

    Raises:
        RuntimeError: If the checkpoint lacks any parameter the model expects,
            other than those under the "esm." prefix.
    """
    if not model_name.endswith(".pt"):  # load from hub
        url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
        model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    else:  # local, treat as filepath
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        model_data = torch.load(str(Path(model_name)), map_location="cpu")

    cfg = model_data["cfg"]["model"]
    model_state = model_data["model"]
    model = ESMFold(esmfold_config=cfg)

    expected_keys = set(model.state_dict().keys())
    found_keys = set(model_state.keys())

    # Keys under "esm." are allowed to be absent from the checkpoint.
    missing_essential_keys = [
        key for key in expected_keys - found_keys if not key.startswith("esm.")
    ]
    if missing_essential_keys:
        raise RuntimeError(f"Keys '{', '.join(missing_essential_keys)}' are missing.")

    model.load_state_dict(model_state, strict=False)

    return model
39
+
40
+
41
def esmfold_v0():
    """
    ESMFold v0 model with 3B ESM-2, 48 folding blocks.
    This version was used for the paper (Lin et al, 2022). It was trained
    on all PDB chains until 2020-05, to ensure temporal holdout with CASP14
    and the CAMEO validation and test set reported there.

    Returns the model built by `_load_model` from the "esmfold_3B_v0" checkpoint.
    """
    return _load_model("esmfold_3B_v0")
49
+
50
+
51
def esmfold_v1():
    """
    ESMFold v1 model using 3B ESM-2, 48 folding blocks.
    ESMFold provides fast high accuracy atomic level structure prediction
    directly from the individual sequence of a protein. ESMFold uses the ESM2
    protein language model to extract meaningful representations from the
    protein sequence.

    Returns the model built by `_load_model` from the "esmfold_3B_v1" checkpoint.
    """
    return _load_model("esmfold_3B_v1")
60
+
61
+
62
def esmfold_structure_module_only_8M():
    """
    ESMFold baseline model using 8M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_8M" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_8M")
71
+
72
+
73
def esmfold_structure_module_only_8M_270K():
    """
    ESMFold baseline model using 8M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_8M_270K" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_8M_270K")
82
+
83
+
84
def esmfold_structure_module_only_35M():
    """
    ESMFold baseline model using 35M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_35M" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_35M")
93
+
94
+
95
def esmfold_structure_module_only_35M_270K():
    """
    ESMFold baseline model using 35M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_35M_270K" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_35M_270K")
104
+
105
+
106
def esmfold_structure_module_only_150M():
    """
    ESMFold baseline model using 150M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_150M" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_150M")
115
+
116
+
117
def esmfold_structure_module_only_150M_270K():
    """
    ESMFold baseline model using 150M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_150M_270K" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_150M_270K")
126
+
127
+
128
def esmfold_structure_module_only_650M():
    """
    ESMFold baseline model using 650M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_650M" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_650M")
137
+
138
+
139
def esmfold_structure_module_only_650M_270K():
    """
    ESMFold baseline model using 650M ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_650M_270K" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_650M_270K")
148
+
149
+
150
def esmfold_structure_module_only_3B():
    """
    ESMFold baseline model using 3B ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_3B" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_3B")
159
+
160
+
161
def esmfold_structure_module_only_3B_270K():
    """
    ESMFold baseline model using 3B ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_3B_270K" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_3B_270K")
170
+
171
+
172
def esmfold_structure_module_only_15B():
    """
    ESMFold baseline model using 15B ESM-2, 0 folding blocks.
    ESM-2 here is trained out to 270K updates.
    The 15B parameter ESM-2 was not trained out to 500K updates.
    This is a model designed to test the capabilities of the language model
    when ablated for number of parameters in the language model.
    See table S1 in (Lin et al, 2022).

    Returns the model built by `_load_model` from the
    "esmfold_structure_module_only_15B" checkpoint.
    """
    return _load_model("esmfold_structure_module_only_15B")
esm/esmfold/v1/tri_self_attn_block.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import torch
6
+ from openfold.model.triangular_attention import (
7
+ TriangleAttentionEndingNode,
8
+ TriangleAttentionStartingNode,
9
+ )
10
+ from openfold.model.triangular_multiplicative_update import (
11
+ TriangleMultiplicationIncoming,
12
+ TriangleMultiplicationOutgoing,
13
+ )
14
+ from torch import nn
15
+
16
+ from esm.esmfold.v1.misc import (
17
+ Attention,
18
+ Dropout,
19
+ PairToSequence,
20
+ ResidueMLP,
21
+ SequenceToPair,
22
+ )
23
+
24
+
25
class TriangularSelfAttentionBlock(nn.Module):
    """One trunk block that jointly updates a sequence state and a pairwise state.

    The sequence state is updated by biased self-attention and an MLP; the
    pairwise state is then updated from the new sequence state, by triangular
    multiplicative updates, triangular attention and an MLP.
    """

    def __init__(
        self,
        sequence_state_dim,
        pairwise_state_dim,
        sequence_head_width,
        pairwise_head_width,
        dropout=0,
        **__kwargs,
    ):
        """
        Args:
            sequence_state_dim: Channel size of the sequence state; must be a
                multiple of ``sequence_head_width``.
            pairwise_state_dim: Channel size of the pairwise state; must be even
                and a multiple of ``pairwise_head_width``.
            sequence_head_width: Per-head width of the sequence attention.
            pairwise_head_width: Per-head width of the triangular attention.
            dropout: Base dropout rate; must be below 0.4.
            **__kwargs: Accepted and ignored.
        """
        super().__init__()

        assert sequence_state_dim % sequence_head_width == 0
        assert pairwise_state_dim % pairwise_head_width == 0
        sequence_num_heads = sequence_state_dim // sequence_head_width
        pairwise_num_heads = pairwise_state_dim // pairwise_head_width
        assert sequence_state_dim == sequence_num_heads * sequence_head_width
        assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
        # SequenceToPair below is built with an inner dim of pairwise_state_dim // 2.
        assert pairwise_state_dim % 2 == 0

        self.sequence_state_dim = sequence_state_dim
        self.pairwise_state_dim = pairwise_state_dim

        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = SequenceToPair(
            sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim
        )
        self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads)

        self.seq_attention = Attention(
            sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True
        )
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            pairwise_state_dim,
            pairwise_state_dim,
        )
        self.tri_att_start = TriangleAttentionStartingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore
        self.tri_att_end = TriangleAttentionEndingNode(
            pairwise_state_dim,
            pairwise_head_width,
            pairwise_num_heads,
            inf=1e9,
        )  # type: ignore

        self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
        self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)

        # The row/column dropouts below use twice this rate, so this bound
        # keeps their rate below 0.8.
        assert dropout < 0.4
        self.drop = nn.Dropout(dropout)
        # Dropout masks shared along dim 2 and dim 1 respectively.
        self.row_drop = Dropout(dropout * 2, 2)
        self.col_drop = Dropout(dropout * 2, 1)

        # Zero-initialise layers of the openfold triangular modules.
        # NOTE(review): presumably these are the final projections of each
        # residual branch, so the branches start at zero — confirm against openfold.
        torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight)
        torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.bias)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.weight)
        torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.bias)

        # Zero-initialise the output projections of the local modules;
        # mlp[-2] is the second linear layer of each ResidueMLP.
        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight)
        torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias)
        torch.nn.init.zeros_(self.pair_to_sequence.linear.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.weight)
        torch.nn.init.zeros_(self.seq_attention.o_proj.bias)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight)
        torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
          mask: B x L boolean tensor of valid positions
          chunk_size: forwarded to the two triangular attention modules
          **__kwargs: accepted and ignored

        Output:
          sequence_state: B x L x sequence_state_dim
          pairwise_state: B x L x L x pairwise_state_dim
        """
        assert len(sequence_state.shape) == 3
        assert len(pairwise_state.shape) == 4
        if mask is not None:
            assert len(mask.shape) == 2

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pairwise_state_dim = pairwise_state.shape[3]
        assert sequence_state_dim == self.sequence_state_dim
        assert pairwise_state_dim == self.pairwise_state_dim
        assert batch_dim == pairwise_state.shape[0]
        assert seq_dim == pairwise_state.shape[1]
        assert seq_dim == pairwise_state.shape[2]

        # Update sequence state
        # The pairwise state supplies one additive attention bias per head.
        bias = self.pair_to_sequence(pairwise_state)

        # Self attention with bias + mlp.
        y = self.layernorm_1(sequence_state)
        # The attention maps returned by seq_attention are not used.
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        # ResidueMLP adds its own residual connection.
        sequence_state = self.mlp_seq(sequence_state)

        # Update pairwise state
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Axial attention with triangular bias.
        # A pair is valid only when both of its residues are valid.
        tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_mul_out(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_mul_in(pairwise_state, mask=tri_mask)
        )
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state
esm/esmfold/v1/trunk.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
import typing as T
from contextlib import ExitStack
from dataclasses import asdict, dataclass, field, is_dataclass

import torch
import torch.nn as nn
from openfold.model.structure_module import StructureModule

from esm.esmfold.v1.tri_self_attn_block import TriangularSelfAttentionBlock
14
+
15
+
16
@dataclass
class StructureModuleConfig:
    """Default keyword arguments for the structure module.

    ``FoldingTrunk`` unpacks this configuration into the ``StructureModule``
    constructor imported from openfold, so the field names are presumably
    that constructor's parameter names (NOTE(review): confirm against the
    installed openfold version).
    """

    c_s: int = 384
    c_z: int = 128
    c_ipa: int = 16
    c_resnet: int = 128
    no_heads_ipa: int = 12
    no_qk_points: int = 4
    no_v_points: int = 8
    dropout_rate: float = 0.1
    no_blocks: int = 8
    no_transition_layers: int = 1
    no_resnet_blocks: int = 2
    no_angles: int = 7
    trans_scale_factor: int = 10
    epsilon: float = 1e-8
    inf: float = 1e5


@dataclass
class FoldingTrunkConfig:
    """Configuration consumed by ``FoldingTrunk``.

    ``FoldingTrunk.__init__`` builds this object from its keyword arguments.
    """

    _name: str = "FoldingTrunkConfig"
    # Number of trunk blocks stacked by FoldingTrunk.
    num_blocks: int = 48
    # Channel widths of the sequence and pairwise states.
    sequence_state_dim: int = 1024
    pairwise_state_dim: int = 128
    # Head widths; FoldingTrunk asserts they divide the state dims evenly.
    sequence_head_width: int = 32
    pairwise_head_width: int = 32
    # Clamp range (in residues) for the relative position embedding.
    position_bins: int = 32
    # Dropout rate forwarded to every trunk block.
    dropout: float = 0
    # NOTE(review): not read by the trunk code visible here.
    layer_drop: float = 0
    cpu_grad_checkpoint: bool = False

    # Number of trunk passes used when forward() is given no_recycles=None.
    max_recycles: int = 4
    # Initial chunk size for chunked attention (None disables chunking).
    chunk_size: T.Optional[int] = None

    # A default_factory is required here: a dataclass instance used directly
    # as a default is shared by every config and is rejected as a mutable
    # default by Python >= 3.11.
    structure_module: StructureModuleConfig = field(default_factory=StructureModuleConfig)
52
+
53
+
54
def get_axial_mask(mask):
    """
    Helper to convert B x L mask of valid positions to axial mask used
    in row column attentions.

    Input:
      mask: B x L tensor of booleans (or None)

    Output:
      mask: (B * L) x L tensor in which every row of the input mask is
        repeated L consecutive times, or None when the input is None
    """
    if mask is None:
        return None

    assert mask.dim() == 2
    n_batch, n_seq = mask.shape
    # Broadcast each batch row along a new axis, then fold that axis into the batch.
    tiled = mask[:, None, :].expand(n_batch, n_seq, n_seq)
    return tiled.reshape(n_batch * n_seq, n_seq)
73
+
74
+
75
class RelativePosition(nn.Module):
    """Embed clamped pairwise residue-index offsets."""

    def __init__(self, bins, pairwise_state_dim):
        super().__init__()
        self.bins = bins

        # 2 * bins + 1 possible clamped offsets, shifted up by one so that
        # embedding index 0 stays reserved for masked pairs.
        self.embedding = torch.nn.Embedding(2 * bins + 2, pairwise_state_dim)

    def forward(self, residue_index, mask=None):
        """
        Input:
          residue_index: B x L tensor of indices (dtype=torch.long)
          mask: B x L tensor of booleans

        Output:
          pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
        """
        assert residue_index.dtype == torch.long
        if mask is not None:
            assert residue_index.shape == mask.shape

        # offsets[b, i, j] = residue_index[b, j] - residue_index[b, i]
        offsets = residue_index.unsqueeze(1) - residue_index.unsqueeze(2)
        offsets = torch.clamp(offsets, -self.bins, self.bins)
        offsets = offsets + (self.bins + 1)  # Shift past the reserved index 0.

        if mask is not None:
            pair_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
            # Pairs with at least one invalid residue use the reserved index.
            offsets[pair_mask == 0] = 0

        return self.embedding(offsets)
108
+
109
+
110
class FoldingTrunk(nn.Module):
    """Trunk that refines sequence/pair states with recycling and runs the
    structure module on the result of every pass."""

    def __init__(self, **kwargs):
        """Build the trunk from ``FoldingTrunkConfig`` keyword arguments."""
        super().__init__()
        self.cfg = FoldingTrunkConfig(**kwargs)
        assert self.cfg.max_recycles > 0

        c_s = self.cfg.sequence_state_dim
        c_z = self.cfg.pairwise_state_dim

        assert c_s % self.cfg.sequence_head_width == 0
        assert c_z % self.cfg.pairwise_head_width == 0
        block = TriangularSelfAttentionBlock

        self.pairwise_positional_embedding = RelativePosition(self.cfg.position_bins, c_z)

        self.blocks = nn.ModuleList(
            [
                block(
                    sequence_state_dim=c_s,
                    pairwise_state_dim=c_z,
                    sequence_head_width=self.cfg.sequence_head_width,
                    pairwise_head_width=self.cfg.pairwise_head_width,
                    dropout=self.cfg.dropout,
                )
                for _ in range(self.cfg.num_blocks)
            ]
        )

        self.recycle_bins = 15
        self.recycle_s_norm = nn.LayerNorm(c_s)
        self.recycle_z_norm = nn.LayerNorm(c_z)
        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
        # Zero the embedding row for bin 0 in place, outside of autograd.
        self.recycle_disto.weight[0].detach().zero_()

        # The config may hold either a mapping (unpackable as-is) or a
        # StructureModuleConfig-style dataclass instance, which "**" cannot
        # unpack and therefore has to be converted to a dict first.
        sm_cfg = self.cfg.structure_module
        if is_dataclass(sm_cfg) and not isinstance(sm_cfg, type):
            sm_cfg = asdict(sm_cfg)
        self.structure_module = StructureModule(**sm_cfg)  # type: ignore
        self.trunk2sm_s = nn.Linear(c_s, self.structure_module.c_s)
        self.trunk2sm_z = nn.Linear(c_z, self.structure_module.c_z)

        self.chunk_size = self.cfg.chunk_size

    def set_chunk_size(self, chunk_size):
        # This parameter means the axial attention will be computed
        # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
        # It's equivalent to running a for loop over chunks of the dimension we're iterating over,
        # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks.
        self.chunk_size = chunk_size

    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles: T.Optional[int] = None):
        """
        Inputs:
          seq_feats:   B x L x C            tensor of sequence features
          pair_feats:  B x L x L x C        tensor of pair features
          true_aa:     passed unchanged to the structure module as its second
                       argument (presumably B x L amino-acid indices; confirm
                       against the caller)
          residx:      B x L                long tensor giving the position in the sequence
          mask:        B x L                boolean tensor indicating valid residues
          no_recycles: number of recycling passes to run in addition to the
                       first pass; when None, ``cfg.max_recycles`` is used as
                       the total number of passes.

        Output:
          The dict returned by the structure module for the final pass, with
          the final trunk states added under the keys "s_s" and "s_z".
        """

        device = seq_feats.device
        s_s_0 = seq_feats
        s_z_0 = pair_feats

        if no_recycles is None:
            no_recycles = self.cfg.max_recycles
        else:
            assert no_recycles >= 0, "Number of recycles must not be negative."
            no_recycles += 1  # First 'recycle' is just the standard forward pass through the model.

        def trunk_iter(s, z, residx, mask):
            # One full pass through the trunk blocks.
            z = z + self.pairwise_positional_embedding(residx, mask=mask)

            for block in self.blocks:
                s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            return s, z

        s_s = s_s_0
        s_z = s_z_0
        recycle_s = torch.zeros_like(s_s)
        recycle_z = torch.zeros_like(s_z)
        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

        assert no_recycles > 0
        for recycle_idx in range(no_recycles):
            # Only the last pass runs with autograd enabled.
            with ExitStack() if recycle_idx == no_recycles - 1 else torch.no_grad():
                # === Recycling ===
                recycle_s = self.recycle_s_norm(recycle_s.detach())
                recycle_z = self.recycle_z_norm(recycle_z.detach())
                recycle_z += self.recycle_disto(recycle_bins.detach())

                s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

                # === Structure module ===
                structure = self.structure_module(
                    {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                    true_aa,
                    mask.float(),
                )

                recycle_s = s_s
                recycle_z = s_z
                # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
                recycle_bins = FoldingTrunk.distogram(
                    structure["positions"][-1][:, :, :3],
                    3.375,
                    21.375,
                    self.recycle_bins,
                )

        assert isinstance(structure, dict)  # type: ignore
        structure["s_s"] = s_s
        structure["s_z"] = s_z

        return structure

    @staticmethod
    def distogram(coords, min_bin, max_bin, num_bins):
        """Bin pairwise distances between inferred CB atoms.

        Input:
          coords: [... L x 3 x 3] tensor, where it's [N, CA, C] x 3 coordinates.
          min_bin, max_bin: first and last bin boundary (same unit as coords).
          num_bins: number of bins; ``num_bins - 1`` boundaries are used.

        Output:
          [..., L, L] integer tensor holding, for every residue pair, the
          number of boundaries the CB-CB distance exceeds.
        """
        boundaries = torch.linspace(
            min_bin,
            max_bin,
            num_bins - 1,
            device=coords.device,
        )
        # Compare squared distances against squared boundaries to avoid a sqrt.
        boundaries = boundaries**2
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
        # Infer CB coordinates.
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
        bins = torch.sum(dists > boundaries, dim=-1)  # [..., L, L]
        return bins
esm/inverse_folding/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import gvp_transformer
7
+ from . import util
8
+ from . import multichain_util
esm/inverse_folding/features.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Portions of this file were adapted from the open source code for the following
7
+ # two papers:
8
+ #
9
+ # Ingraham, J., Garg, V., Barzilay, R., & Jaakkola, T. (2019). Generative
10
+ # models for graph-based protein design. Advances in Neural Information
11
+ # Processing Systems, 32.
12
+ #
13
+ # Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020).
14
+ # Learning from Protein Structure with Geometric Vector Perceptrons. In
15
+ # International Conference on Learning Representations.
16
+ #
17
+ # MIT License
18
+ #
19
+ # Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+ #
39
+ # ================================================================
40
+ # The below license applies to the portions of the code (parts of
41
+ # src/datasets.py and src/models.py) adapted from Ingraham, et al.
42
+ # ================================================================
43
+ #
44
+ # MIT License
45
+ #
46
+ # Copyright (c) 2019 John Ingraham, Vikas Garg, Regina Barzilay, Tommi Jaakkola
47
+ #
48
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
49
+ # of this software and associated documentation files (the "Software"), to deal
50
+ # in the Software without restriction, including without limitation the rights
51
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
52
+ # copies of the Software, and to permit persons to whom the Software is
53
+ # furnished to do so, subject to the following conditions:
54
+ #
55
+ # The above copyright notice and this permission notice shall be included in all
56
+ # copies or substantial portions of the Software.
57
+ #
58
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
59
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
60
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
61
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
62
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
63
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
64
+ # SOFTWARE.
65
+
66
+ import math
67
+ import numpy as np
68
+ import torch
69
+ import torch.nn as nn
70
+ import torch.nn.functional as F
71
+
72
+ from .gvp_utils import flatten_graph
73
+ from .gvp_modules import GVP, LayerNorm
74
+ from .util import normalize, norm, nan_to_num, rbf
75
+
76
+
77
class GVPInputFeaturizer(nn.Module):
    """Static helpers that turn backbone coordinates into GVP input features."""

    @staticmethod
    def get_node_features(coords, coord_mask, with_coord_mask=True):
        """Return ``(scalar_features, vector_features)`` for every residue.

        Scalars are the dihedral features, optionally followed by the
        coordinate mask as one extra channel. Vectors are the two orientation
        vectors followed by the side-chain direction.
        """
        # scalar features
        node_scalar_features = GVPInputFeaturizer._dihedrals(coords)
        if with_coord_mask:
            node_scalar_features = torch.cat([
                node_scalar_features,
                coord_mask.float().unsqueeze(-1)
            ], dim=-1)
        # vector features
        X_ca = coords[:, :, 1]
        orientations = GVPInputFeaturizer._orientations(X_ca)
        sidechains = GVPInputFeaturizer._sidechains(coords)
        node_vector_features = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
        return node_scalar_features, node_vector_features

    @staticmethod
    def _orientations(X):
        """Normalized vectors to the next and previous position along dim 1.

        The last position has no successor and the first no predecessor;
        those entries are zero-padded.
        """
        forward = normalize(X[:, 1:] - X[:, :-1])
        backward = normalize(X[:, :-1] - X[:, 1:])
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    @staticmethod
    def _sidechains(X):
        """Direction built from atoms 0, 1, 2 of every residue.

        Atom 1 is used as the origin; the result combines the bisector of the
        two normalized bond directions with their normalized cross product.
        """
        n, origin, c = X[:, :, 0], X[:, :, 1], X[:, :, 2]
        c, n = normalize(c - origin), normalize(n - origin)
        bisector = normalize(c + n)
        perp = normalize(torch.cross(c, n, dim=-1))
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec

    @staticmethod
    def _dihedrals(X, eps=1e-7):
        """Return cos/sin features of three backbone dihedrals per residue.

        Only the first three atoms of every residue are used. The output has
        six channels per residue: cosines of the three angles followed by
        their sines.
        """
        X = torch.flatten(X[:, :, :3], 1, 2)
        bsz = X.shape[0]
        dX = X[:, 1:] - X[:, :-1]
        U = normalize(dX, dim=-1)
        u_2 = U[:, :-2]
        u_1 = U[:, 1:-1]
        u_0 = U[:, 2:]

        # Backbone normals
        n_2 = normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = torch.sum(n_2 * n_1, -1)
        # Keep acos away from +-1.
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [bsz, -1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], -1)
        return D_features

    @staticmethod
    def _positional_embeddings(edge_index,
                               num_embeddings=None,
                               num_positional_embeddings=16,
                               period_range=(2, 1000)):
        """Sinusoidal embedding of the offset ``edge_index[0] - edge_index[1]``.

        Returns a tensor with one trailing axis holding the cosines followed
        by the sines of the offset scaled by each frequency.
        ``period_range`` is accepted for interface compatibility but is not
        read; its default is an immutable tuple rather than a shared list.
        """
        # From https://github.com/jingraham/neurips19-graph-protein-design
        num_embeddings = num_embeddings or num_positional_embeddings
        d = edge_index[0] - edge_index[1]

        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32,
                         device=edge_index.device)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    @staticmethod
    def _dist(X, coord_mask, padding_mask, top_k_neighbors, eps=1e-8):
        """ Pairwise euclidean distances and nearest-neighbor selection.

        Returns ``(D_neighbors, E_idx, coord_mask_neighbors,
        residue_mask_neighbors)``. With ``top_k_neighbors == -1`` every
        position is kept as a neighbor; otherwise the ``k`` smallest adjusted
        distances are selected, with ``k`` capped at the sequence length.
        ``eps`` is not read by this function.
        """
        bsz, maxlen = X.size(0), X.size(1)
        coord_mask_2D = torch.unsqueeze(coord_mask, 1) * torch.unsqueeze(coord_mask, 2)
        residue_mask = ~padding_mask
        residue_mask_2D = torch.unsqueeze(residue_mask, 1) * torch.unsqueeze(residue_mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = coord_mask_2D * norm(dX, dim=-1)

        # sorting preference: first those with coords, then among the residues that
        # exist but are masked use distance in sequence as tie breaker, and then the
        # residues that came from padding are last
        seqpos = torch.arange(maxlen, device=X.device)
        Dseq = torch.abs(seqpos.unsqueeze(1) - seqpos.unsqueeze(0)).repeat(bsz, 1, 1)
        D_adjust = nan_to_num(D) + (~coord_mask_2D) * (1e8 + Dseq*1e6) + (
            ~residue_mask_2D) * (1e10)

        if top_k_neighbors == -1:
            D_neighbors = D_adjust
            E_idx = seqpos.repeat(
                *D_neighbors.shape[:-1], 1)
        else:
            # Identify k nearest neighbors (including self)
            k = min(top_k_neighbors, X.size(1))
            D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)

        # The thresholds sit below the 1e8 / 1e10 penalties added above, so they
        # recover which neighbors have coordinates / are real residues.
        coord_mask_neighbors = (D_neighbors < 5e7)
        residue_mask_neighbors = (D_neighbors < 5e9)
        return D_neighbors, E_idx, coord_mask_neighbors, residue_mask_neighbors
186
+
187
+
188
class Normalize(nn.Module):
    """Normalization over one axis with a learned per-feature gain and bias."""

    def __init__(self, features, epsilon=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        """Normalize ``x`` along ``dim`` and apply the gain and bias.

        ``epsilon`` is added both inside the square root and to the
        resulting standard deviation.
        """
        mean = x.mean(dim, keepdim=True)
        std = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
        gain, bias = self.gain, self.bias
        if dim != -1:
            # Place the feature axis at ``dim`` so gain and bias broadcast.
            broadcast_shape = [1] * mean.dim()
            broadcast_shape[dim] = self.gain.size(0)
            gain = gain.view(broadcast_shape)
            bias = bias.view(broadcast_shape)
        return gain * (x - mean) / (std + self.epsilon) + bias
207
+
208
+
209
class DihedralFeatures(nn.Module):
    """Linear embedding of backbone dihedral-angle features."""

    def __init__(self, node_embed_dim):
        """ Embed dihedral angle features. """
        super().__init__()
        # Three dihedral angles per residue, each given as a cosine and a sine.
        n_inputs = 6
        self.node_embedding = nn.Linear(n_inputs, node_embed_dim, bias=True)
        self.norm_nodes = Normalize(node_embed_dim)

    def forward(self, X):
        """ Featurize coordinates as an attributed graph """
        features = self._dihedrals(X)
        embedded = self.node_embedding(features)
        return self.norm_nodes(embedded)

    @staticmethod
    def _dihedrals(X, eps=1e-7, return_angles=False):
        """Compute backbone dihedral angles from the first three atoms.

        Returns ``(phi, psi, omega)`` when ``return_angles`` is true;
        otherwise a tensor whose last axis holds the three cosines followed
        by the three sines.
        """
        # First 3 coordinates are N, CA, C; lay them out as one atom chain.
        n_batch, n_res = X.shape[0], X.shape[1]
        chain = X[:, :, :3, :].reshape(n_batch, 3 * n_res, 3)

        # Unit vectors along consecutive bonds, as three shifted slices.
        bonds = F.normalize(chain[:, 1:, :] - chain[:, :-1, :], dim=-1)
        u_2 = bonds[:, :-2, :]
        u_1 = bonds[:, 1:-1, :]
        u_0 = bonds[:, 2:, :]

        # Normals of the planes spanned by consecutive bond pairs.
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Signed angle between the normals; the clamp keeps acos away from +-1.
        cos_angle = torch.clamp((n_2 * n_1).sum(-1), -1 + eps, 1 - eps)
        sign = torch.sign((u_2 * n_1).sum(-1))
        D = sign * torch.acos(cos_angle)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, (1, 2), 'constant', 0)
        D = D.view((D.size(0), int(D.size(1) / 3), 3))

        if return_angles:
            phi, psi, omega = torch.unbind(D, -1)
            return phi, psi, omega

        # Lift angle representations to the circle
        return torch.cat((torch.cos(D), torch.sin(D)), 2)
257
+
258
+
259
class GVPGraphEmbedding(GVPInputFeaturizer):
    """Embed coordinates and confidence as GVP node and edge embeddings of a
    nearest-neighbor graph."""

    def __init__(self, args):
        super().__init__()
        self.top_k_neighbors = args.top_k_neighbors
        self.num_positional_embeddings = 16
        self.remove_edges_without_coords = True

        # (scalar, vector) channel counts of the raw and hidden features.
        node_in_dims = (7, 3)
        edge_in_dims = (34, 1)
        node_hidden_dims = (args.node_hidden_dim_scalar, args.node_hidden_dim_vector)
        edge_hidden_dims = (args.edge_hidden_dim_scalar, args.edge_hidden_dim_vector)

        self.embed_node = nn.Sequential(
            GVP(node_in_dims, node_hidden_dims, activations=(None, None)),
            LayerNorm(node_hidden_dims, eps=1e-4),
        )
        self.embed_edge = nn.Sequential(
            GVP(edge_in_dims, edge_hidden_dims, activations=(None, None)),
            LayerNorm(edge_hidden_dims, eps=1e-4),
        )
        self.embed_confidence = nn.Linear(16, args.node_hidden_dim_scalar)

    def forward(self, coords, coord_mask, padding_mask, confidence):
        """Return flattened ``(node_embeddings, edge_embeddings, edge_index)``."""
        # The raw geometric features are computed without gradient tracking.
        with torch.no_grad():
            node_features = self.get_node_features(coords, coord_mask)
            edge_features, edge_index = self.get_edge_features(
                coords, coord_mask, padding_mask)

        node_scalar, node_vector = self.embed_node(node_features)
        edge_embeddings = self.embed_edge(edge_features)

        # Confidence is expanded with radial basis functions and added to the
        # scalar node channels.
        confidence_rbf = rbf(confidence, 0., 1.)
        node_embeddings = (
            node_scalar + self.embed_confidence(confidence_rbf),
            node_vector,
        )

        return flatten_graph(node_embeddings, edge_embeddings, edge_index)

    def get_edge_features(self, coords, coord_mask, padding_mask):
        """Return ``((edge_scalars, edge_vectors), edge_index)`` for the
        neighbor graph built on atom 1 of every residue."""
        X_ca = coords[:, :, 1]
        # Get distances to the top k neighbors
        nbr_dist, nbr_idx, nbr_coord_mask, nbr_residue_mask = GVPInputFeaturizer._dist(
            X_ca, coord_mask, padding_mask, self.top_k_neighbors)

        # Flatten the graph to be batch size 1 for torch_geometric package
        B, L, k = nbr_idx.shape[:3]
        src = torch.arange(L, device=nbr_idx.device).view([1, L, 1]).expand(B, L, k)
        dest = nbr_idx
        # After flattening, [2, B, E]
        edge_index = torch.stack([src, dest], dim=0).flatten(2, 3)
        # After flattening, [B, E]
        E_dist = nbr_dist.flatten(1, 2)
        E_coord_mask = nbr_coord_mask.flatten(1, 2).unsqueeze(-1)
        E_residue_mask = nbr_residue_mask.flatten(1, 2)

        # Calculate relative positional embeddings and distance RBF
        pos_embeddings = GVPInputFeaturizer._positional_embeddings(
            edge_index,
            num_positional_embeddings=self.num_positional_embeddings,
        )
        D_rbf = rbf(E_dist, 0., 20.)

        # Calculate relative orientation
        dest_index = edge_index[1, :, :]
        X_src = X_ca.unsqueeze(2).expand(-1, -1, k, -1).flatten(1, 2)
        X_dest = torch.gather(
            X_ca,
            1,
            dest_index.unsqueeze(-1).expand([B, L*k, 3]),
        )
        coord_mask_src = coord_mask.unsqueeze(2).expand(-1, -1, k).flatten(1, 2)
        coord_mask_dest = torch.gather(
            coord_mask,
            1,
            dest_index.expand([B, L*k]),
        )
        E_vectors = X_src - X_dest

        # For the ones without coordinates, substitute in the average vector
        masked_sum = torch.sum(E_vectors * E_coord_mask, dim=1, keepdims=True)
        masked_count = torch.sum(E_coord_mask, dim=1, keepdims=True)
        E_vector_mean = masked_sum / masked_count
        E_vectors = E_vectors * E_coord_mask + E_vector_mean * ~(E_coord_mask)

        # Normalize and remove nans
        edge_s = torch.cat([D_rbf, pos_embeddings], dim=-1)
        edge_v = normalize(E_vectors).unsqueeze(-2)
        edge_s, edge_v = map(nan_to_num, (edge_s, edge_v))

        # Also add indications of whether the coordinates are present
        missing_src = (~coord_mask_src).float().unsqueeze(-1)
        missing_dest = (~coord_mask_dest).float().unsqueeze(-1)
        edge_s = torch.cat([edge_s, missing_src, missing_dest], dim=-1)

        # Edges touching padding (and, optionally, edges without coordinates)
        # are marked by overwriting both endpoints with -1.
        edge_index[:, ~E_residue_mask] = -1
        if self.remove_edges_without_coords:
            edge_index[:, ~E_coord_mask.squeeze(-1)] = -1
        return (edge_s, edge_v), edge_index.transpose(0, 1)
esm/inverse_folding/gvp_encoder.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from argparse import Namespace
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from .features import GVPGraphEmbedding
13
+ from .gvp_modules import GVPConvLayer, LayerNorm
14
+ from .gvp_utils import unflatten_graph
15
+
16
+
17
+
18
class GVPEncoder(nn.Module):
    """Graph embedding followed by a stack of GVP convolution layers."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embed_graph = GVPGraphEmbedding(args)

        # (scalar, vector) channel counts of the hidden node / edge features.
        node_dims = (args.node_hidden_dim_scalar, args.node_hidden_dim_vector)
        edge_dims = (args.edge_hidden_dim_scalar, args.edge_hidden_dim_vector)

        layer_options = dict(
            drop_rate=args.dropout,
            vector_gate=True,
            attention_heads=0,
            n_message=3,
            conv_activations=(F.relu, torch.sigmoid),
            n_edge_gvps=0,
            eps=1e-4,
            layernorm=True,
        )
        self.encoder_layers = nn.ModuleList(
            GVPConvLayer(node_dims, edge_dims, **layer_options)
            for _ in range(args.num_encoder_layers)
        )

    def forward(self, coords, coord_mask, padding_mask, confidence):
        """Encode the inputs and return the unflattened node embeddings."""
        node_embeddings, edge_embeddings, edge_index = self.embed_graph(
            coords, coord_mask, padding_mask, confidence)

        # Every layer updates both the node and the edge embeddings.
        for layer in self.encoder_layers:
            node_embeddings, edge_embeddings = layer(
                node_embeddings, edge_index, edge_embeddings)

        return unflatten_graph(node_embeddings, coords.shape[0])
esm/inverse_folding/gvp_modules.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contents of this file are from the open source code for
2
+ #
3
+ # Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020).
4
+ # Learning from Protein Structure with Geometric Vector Perceptrons. In
5
+ # International Conference on Learning Representations.
6
+ #
7
+ # MIT License
8
+ #
9
+ # Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror
10
+ #
11
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
12
+ # of this software and associated documentation files (the "Software"), to deal
13
+ # in the Software without restriction, including without limitation the rights
14
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ # copies of the Software, and to permit persons to whom the Software is
16
+ # furnished to do so, subject to the following conditions:
17
+ #
18
+ # The above copyright notice and this permission notice shall be included in all
19
+ # copies or substantial portions of the Software.
20
+ #
21
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
+ # SOFTWARE.
28
+
29
+ import typing as T
30
+ import torch
31
+ from torch import nn
32
+ import torch.nn.functional as F
33
+ from torch_geometric.nn import MessagePassing
34
+
35
def tuple_size(tp):
    '''
    Returns the `.size()` of every element of `tp`, with `0` standing in
    for elements that are `None`.
    '''
    return tuple(0 if item is None else item.size() for item in tp)
37
+
38
def tuple_sum(tp1, tp2):
    '''
    Sums two tuples (s, V) elementwise.

    :param tp1: tuple (s, V); V may be `None`
    :param tp2: tuple (s, V); V may be `None`
    :return: tuple (s1 + s2, V1 + V2); the vector part is `None` whenever
             either operand has no vector part
    '''
    s1, v1 = tp1
    s2, v2 = tp2
    # Both operands are checked: testing only the second one made
    # `(s, None) + (s, V)` fail while `(s, V) + (s, None)` succeeded.
    if v1 is None or v2 is None:
        return (s1 + s2, None)
    return (s1 + s2, v1 + v2)
44
+
45
def tuple_cat(*args, dim=-1):
    '''
    Concatenates any number of tuples (s, V) elementwise.

    :param dim: dimension along which to concatenate when viewed
                as the `dim` index for the scalar-channel tensors.
                This means that `dim=-1` will be applied as
                `dim=-2` for the vector-channel tensors.
    '''
    # Resolve a negative dim against the scalar tensor's rank so the same
    # non-negative index can be reused for the vector tensors.
    dim %= len(args[0][0].shape)
    scalar_parts, vector_parts = zip(*args)
    scalars = torch.cat(scalar_parts, dim=dim)
    vectors = torch.cat(vector_parts, dim=dim)
    return scalars, vectors
57
+
58
def tuple_index(x, idx):
    '''
    Indexes into a tuple (s, V) along the first dimension.

    :param idx: any object which can be used to index into a `torch.Tensor`
    '''
    scalars = x[0]
    vectors = x[1]
    return scalars[idx], vectors[idx]
65
+
66
def randn(n, dims, device="cpu"):
    '''
    Returns random tuples (s, V) drawn elementwise from a normal distribution.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)

    :return: (s, V) with s.shape = (n, n_scalar) and
             V.shape = (n, n_vector, 3)
    '''
    n_scalar, n_vector = dims[0], dims[1]
    # The scalars are sampled before the vectors.
    scalars = torch.randn(n, n_scalar, device=device)
    vectors = torch.randn(n, n_vector, 3, device=device)
    return scalars, vectors
78
+
79
def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
    '''
    L2 norm of tensor with `eps` added to the squared norm, which keeps the
    result strictly positive.

    :param sqrt: if `False`, returns the square of the L2 norm
    '''
    # `eps` is added instead of clamping because clamp is slow.
    squared = torch.square(x).sum(axis, keepdims) + eps
    if not sqrt:
        return squared
    return torch.sqrt(squared)
89
+
90
def _split(x, nv):
    '''
    Splits a merged representation of (s, V) back into a tuple.
    Should be used only with `_merge(s, V)` and only if the tuple
    representation cannot be used.

    :param x: the `torch.Tensor` returned from `_merge`
    :param nv: the number of vector channels in the input to `_merge`
    '''
    n_flat = 3 * nv
    # The last 3 * nv entries are the flattened vectors; the rest are scalars.
    s = x[..., :-n_flat]
    v = x[..., -n_flat:].reshape(x.shape[:-1] + (nv, 3))
    return s, v
102
+
103
def _merge(s, v):
    '''
    Merges a tuple (s, V) into a single `torch.Tensor`, where the
    vector channels are flattened and appended to the scalar channels.
    Should be used only if the tuple representation cannot be used.
    Use `_split(x, nv)` to reverse.
    '''
    n_vector = v.shape[-2]
    # Collapse the trailing (n_vector, 3) axes into one axis of 3 * n_vector.
    flat_v = v.reshape(v.shape[:-2] + (3 * n_vector,))
    return torch.cat([s, flat_v], -1)
112
+
113
class GVP(nn.Module):
    '''
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional
    :param vector_gate: if true, the vector gate is computed from the scalar
                        output instead of from the vector norms
    :param activations: tuple of functions (scalar_act, vector_act)
    :param tuple_io: whether to keep accepting tuple inputs and outputs when vi
                     or vo = 0
    :param eps: value added to squared norms before the square root
    '''
    def __init__(self, in_dims, out_dims, h_dim=None, vector_gate=False,
                 activations=(F.relu, torch.sigmoid), tuple_io=True,
                 eps=1e-8):
        super().__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.tuple_io = tuple_io

        if not self.vi:
            # Without vector inputs the layer is a plain linear map on scalars.
            self.ws = nn.Linear(self.si, self.so)
        else:
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if vector_gate:
                    self.wg = nn.Linear(self.so, self.vo)

        self.vector_gate = vector_gate
        self.scalar_act, self.vector_act = activations
        self.eps = eps

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if vectors_in is 0), a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if vectors_out is 0), a single `torch.Tensor`
        '''
        if self.vi:
            s, v = x
            # Mix vector channels; their norms join the scalar features.
            vh = self.wh(torch.transpose(v, -1, -2))
            vn = _norm_no_nan(vh, axis=-2, eps=self.eps)
            s = self.ws(torch.cat([s, vn], -1))
            if self.scalar_act:
                s = self.scalar_act(s)
            if self.vo:
                v = torch.transpose(self.wv(vh), -1, -2)
                if self.vector_gate:
                    gate = self.wg(s).unsqueeze(-1)
                else:
                    gate = _norm_no_nan(v, axis=-1, keepdims=True, eps=self.eps)
                # The gate is only applied when a vector activation is set.
                if self.vector_act:
                    gate = self.vector_act(gate)
                    v = v * gate
        else:
            if self.tuple_io:
                assert x[1] is None
                x = x[0]
            s = self.ws(x)
            if self.scalar_act:
                s = self.scalar_act(s)
            if self.vo:
                # No vector inputs: the vector outputs are all zeros.
                v = torch.zeros(list(s.shape)[:-1] + [self.vo, 3],
                                device=s.device)

        if self.vo:
            return (s, v)
        if self.tuple_io:
            return (s, None)
        return s
189
+
190
+
191
class _VDropout(nn.Module):
    '''
    Dropout for vector channels: each vector (the last dimension of the
    input) is either kept whole or zeroed whole, and kept vectors are
    rescaled by 1 / (1 - drop_rate).
    '''
    def __init__(self, drop_rate):
        super().__init__()
        self.drop_rate = drop_rate

    def forward(self, x):
        '''
        :param x: `torch.Tensor` corresponding to vector channels,
                  or `None` (returned unchanged)
        '''
        if x is None:
            return None
        if not self.training:
            return x
        keep_prob = 1 - self.drop_rate
        # One Bernoulli draw per vector, broadcast over its components.
        keep_probs = keep_prob * torch.ones(x.shape[:-1], device=x.device)
        mask = torch.bernoulli(keep_probs).unsqueeze(-1)
        return mask * x / keep_prob
214
+
215
class Dropout(nn.Module):
    '''
    Dropout for (s, V) tuples: standard dropout on the scalar channels
    and whole-vector dropout on the vector channels.
    Takes tuples (s, V) as input and as output.
    '''
    def __init__(self, drop_rate):
        super().__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        if type(x) is not torch.Tensor:
            scalars, vectors = x
            return self.sdropout(scalars), self.vdropout(vectors)
        return self.sdropout(x)
235
+
236
class LayerNorm(nn.Module):
    '''
    Layer normalization for (s, V) tuples: `nn.LayerNorm` on the scalar
    channels, and a rescaling of the vector channels by the root mean
    squared norm of the non-zero vectors.
    Takes tuples (s, V) as input and as output.
    '''
    def __init__(self, dims, tuple_io=True, eps=1e-8):
        super().__init__()
        self.tuple_io = tuple_io
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)
        self.eps = eps

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        if self.v:
            s, v = x
            sq_norm = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False,
                                   eps=self.eps)
            # Vectors whose squared norm is essentially just eps are treated
            # as zero: excluded from the average and zeroed in the output.
            nonzero_mask = (sq_norm > 2 * self.eps)
            masked_total = torch.sum(sq_norm * nonzero_mask, dim=-2,
                                     keepdim=True)
            n_nonzero = torch.sum(nonzero_mask, dim=-2, keepdim=True)
            mean_sq = masked_total / (self.eps + n_nonzero)
            rms = torch.sqrt(mean_sq + self.eps)
            v = nonzero_mask * (v / rms)
            return self.scalar_norm(s), v
        if self.tuple_io:
            return self.scalar_norm(x[0]), None
        return self.scalar_norm(x)
266
+
267
class GVPConv(MessagePassing):
    '''
    Message passing / graph convolution built from Geometric Vector
    Perceptrons. Consumes node and edge embeddings of a graph and
    produces new node embeddings.

    Residual updates and pointwise feedforward layers are NOT part of
    this module ---see `GVPConvLayer`.

    :param in_dims: input node embedding dimensions (n_scalar, n_vector)
    :param out_dims: output node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs in the message function
    :param module_list: preconstructed message function, overrides n_layers
    :param aggr: should be "add" if some incoming edges are masked, as in
                 a masked autoregressive decoder architecture
    '''
    def __init__(self, in_dims, out_dims, edge_dims, n_layers=3,
                 vector_gate=False, module_list=None, aggr="mean", eps=1e-8,
                 activations=(F.relu, torch.sigmoid)):
        super().__init__(aggr=aggr)
        self.eps = eps
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims

        if not module_list:
            # A message sees source node, edge and destination node features.
            message_in_dims = (2*self.si + self.se, 2*self.vi + self.ve)
            if n_layers == 1:
                module_list = [
                    GVP(message_in_dims, (self.so, self.vo),
                        activations=(None, None))
                ]
            else:
                module_list = [
                    GVP(message_in_dims, out_dims,
                        vector_gate=vector_gate, activations=activations)
                ]
                for _ in range(n_layers - 2):
                    module_list.append(
                        GVP(out_dims, out_dims, vector_gate=vector_gate))
                module_list.append(
                    GVP(out_dims, out_dims, activations=(None, None)))
        self.message_func = nn.Sequential(*module_list)

    def forward(self, x, edge_index, edge_attr):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        '''
        node_s, node_v = x
        # Vector features are flattened to 2D for `propagate` and restored
        # inside `message`.
        flat_v = node_v.reshape(node_v.shape[0], 3*node_v.shape[1])
        aggregated = self.propagate(edge_index, s=node_s, v=flat_v,
                                    edge_attr=edge_attr)
        return _split(aggregated, self.vo)

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
        combined = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        return _merge(*self.message_func(combined))
329
+
330
+
331
class GVPConvLayer(nn.Module):
    '''
    Full graph convolution / message passing layer with
    Geometric Vector Perceptrons. Residually updates node embeddings with
    aggregated incoming messages, applies a pointwise feedforward
    network to node embeddings, and returns updated node embeddings.

    To only compute the aggregated messages, see `GVPConv`.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param vector_gate: forwarded to the non-final GVPs built by this layer
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param autoregressive: if `True`, this `GVPConvLayer` will be used
            with a different set of input node embeddings for messages
            where src >= dst
    :param attention_heads: must be 0; any other value raises
            `NotImplementedError`
    :param conv_activations: activations passed to the `GVPConv` message
            function
    :param n_edge_gvps: number of GVPs used to update edge embeddings;
            0 disables the edge update
    :param layernorm: if `False`, `nn.Identity` is used in place of the
            `LayerNorm` modules
    :param eps: passed to `GVPConv` and `LayerNorm`
    '''
    def __init__(self, node_dims, edge_dims, vector_gate=False,
                 n_message=3, n_feedforward=2, drop_rate=.1,
                 autoregressive=False, attention_heads=0,
                 conv_activations=(F.relu, torch.sigmoid),
                 n_edge_gvps=0, layernorm=True, eps=1e-8):

        super(GVPConvLayer, self).__init__()
        if attention_heads == 0:
            self.conv = GVPConv(
                node_dims, node_dims, edge_dims, n_layers=n_message,
                vector_gate=vector_gate,
                aggr="add" if autoregressive else "mean",
                activations=conv_activations,
                eps=eps,
            )
        else:
            raise NotImplementedError
        # Two norm/dropout pairs: one after the message update, one after
        # the feedforward update.
        if layernorm:
            self.norm = nn.ModuleList([LayerNorm(node_dims, eps=eps) for _ in range(2)])
        else:
            self.norm = nn.ModuleList([nn.Identity() for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        # Pointwise feedforward network; hidden layers widen to
        # (4 * n_scalar, 2 * n_vector).
        ff_func = []
        if n_feedforward == 1:
            ff_func.append(GVP(node_dims, node_dims, activations=(None, None)))
        else:
            hid_dims = 4*node_dims[0], 2*node_dims[1]
            ff_func.append(GVP(node_dims, hid_dims, vector_gate=vector_gate))
            for i in range(n_feedforward-2):
                ff_func.append(GVP(hid_dims, hid_dims, vector_gate=vector_gate))
            ff_func.append(GVP(hid_dims, node_dims, activations=(None, None)))
        self.ff_func = nn.Sequential(*ff_func)

        # Optional edge update network, fed with (src node, edge, dst node).
        self.edge_message_func = None
        if n_edge_gvps > 0:
            si, vi = node_dims
            se, ve = edge_dims
            module_list = [
                GVP((2*si + se, 2*vi + ve), edge_dims, vector_gate=vector_gate)
            ]
            for i in range(n_edge_gvps - 2):
                module_list.append(GVP(edge_dims, edge_dims,
                                       vector_gate=vector_gate))
            if n_edge_gvps > 1:
                module_list.append(GVP(edge_dims, edge_dims,
                                       activations=(None, None)))
            self.edge_message_func = nn.Sequential(*module_list)
            if layernorm:
                self.edge_norm = LayerNorm(edge_dims, eps=eps)
            else:
                self.edge_norm = nn.Identity()
            self.edge_dropout = Dropout(drop_rate)

    def forward(self, x, edge_index, edge_attr,
                autoregressive_x=None, node_mask=None):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
                If not `None`, will be used as src node embeddings
                for forming messages where src >= dst. The current node
                embeddings `x` will still be the base of the update and the
                pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
                dim of node embeddings (s, V). If not `None`, only
                these nodes will be updated.
        :return: tuple (x, edge_attr) of updated node embeddings and the
                (possibly updated) edge embeddings
        '''
        if self.edge_message_func:
            src, dst = edge_index
            if autoregressive_x is None:
                x_src = x[0][src], x[1][src]
            else:
                # Edges with src < dst read the source from `x`; all other
                # edges read it from `autoregressive_x`.
                mask = (src < dst).unsqueeze(-1)
                x_src = (
                    torch.where(mask, x[0][src], autoregressive_x[0][src]),
                    torch.where(mask.unsqueeze(-1), x[1][src],
                                autoregressive_x[1][src])
                )
            x_dst = x[0][dst], x[1][dst]
            x_edge = (
                torch.cat([x_src[0], edge_attr[0], x_dst[0]], dim=-1),
                torch.cat([x_src[1], edge_attr[1], x_dst[1]], dim=-2)
            )
            edge_attr_dh = self.edge_message_func(x_edge)
            # NOTE(review): `tuple_sum` is defined elsewhere; presumably an
            # element-wise sum of (s, V) tuples, i.e. a residual update.
            edge_attr = self.edge_norm(tuple_sum(edge_attr,
                                                 self.edge_dropout(edge_attr_dh)))

        if autoregressive_x is not None:
            # Guarding this import here to remove the dependency on torch_scatter, since this isn't used
            # in ESM-IF1
            from torch_scatter import scatter_add
            src, dst = edge_index
            mask = src < dst
            edge_index_forward = edge_index[:, mask]
            edge_index_backward = edge_index[:, ~mask]
            edge_attr_forward = tuple_index(edge_attr, mask)
            edge_attr_backward = tuple_index(edge_attr, ~mask)

            # Messages over src < dst edges come from `x`, the rest from
            # `autoregressive_x`; the two aggregates are summed.
            dh = tuple_sum(
                self.conv(x, edge_index_forward, edge_attr_forward),
                self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
            )

            # Number of incoming edges per destination node, at least 1.
            count = scatter_add(torch.ones_like(dst), dst,
                                dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)

            # Divide the summed messages by the edge count.
            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)

        else:
            dh = self.conv(x, edge_index, edge_attr)

        if node_mask is not None:
            # Keep the full embeddings and update only the selected nodes.
            x_ = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)

        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))

        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

        if node_mask is not None:
            # Writes the updated rows back into the caller's tensors in place.
            x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
            x = x_

        return x, edge_attr
esm/inverse_folding/gvp_transformer.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ from typing import Any, Dict, List, Optional, Tuple, NamedTuple
8
+ import torch
9
+ from torch import nn
10
+ from torch import Tensor
11
+ import torch.nn.functional as F
12
+ from scipy.spatial import transform
13
+
14
+ from esm.data import Alphabet
15
+
16
+ from .features import DihedralFeatures
17
+ from .gvp_encoder import GVPEncoder
18
+ from .gvp_utils import unflatten_graph
19
+ from .gvp_transformer_encoder import GVPTransformerEncoder
20
+ from .transformer_decoder import TransformerDecoder
21
+ from .util import rotate, CoordBatchConverter
22
+
23
+
24
class GVPTransformerModel(nn.Module):
    """
    Inverse folding model combining a GVP-GNN with a Transformer.

    The encoder applies geometric GVP-GNN layers followed by Transformer
    encoder layers; a Transformer decoder then produces the sequence.
    """

    def __init__(self, args, alphabet):
        super().__init__()
        # Both embeddings are created before the encoder/decoder so that
        # parameter initialisation happens in the same order as before.
        encoder_embed_tokens = self.build_embedding(
            args, alphabet, args.encoder_embed_dim)
        decoder_embed_tokens = self.build_embedding(
            args, alphabet, args.decoder_embed_dim)
        encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
        decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
        self.args = args
        self.encoder = encoder
        self.decoder = decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Construct the GVP-Transformer encoder."""
        return GVPTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Construct the Transformer decoder."""
        return TransformerDecoder(args, tgt_dict, embed_tokens)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim):
        """
        Create a token embedding with one row per dictionary entry.

        Weights are drawn from N(0, embed_dim ** -0.5) and the padding row
        is set to zero.
        """
        pad = dictionary.padding_idx
        embedding = nn.Embedding(len(dictionary), embed_dim, pad)
        nn.init.normal_(embedding.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.constant_(embedding.weight[pad], 0)
        return embedding

    def forward(
        self,
        coords,
        padding_mask,
        confidence,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
    ):
        """Encode the structure, then decode; returns `(logits, extra)`."""
        encoder_out = self.encoder(
            coords,
            padding_mask,
            confidence,
            return_all_hiddens=return_all_hiddens,
        )
        return self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )

    def sample(self, coords, partial_seq=None, temperature=1.0, confidence=None, device=None):
        """
        Samples sequences based on multinomial sampling (no beam search).

        Args:
            coords: L x 3 x 3 list representing one backbone
            partial_seq: Optional, partial sequence with mask tokens if part of
                the sequence is known
            temperature: sampling temperature, use low temperature for higher
                sequence recovery and high temperature for higher diversity
            confidence: optional length L list of confidence scores for coordinates
            device: optional device handed to the batch converter; the token
                buffer is moved to it when it is truthy
        """
        seq_len = len(coords)
        dictionary = self.decoder.dictionary

        # Convert to batch format
        batch_converter = CoordBatchConverter(dictionary)
        batch_coords, confidence, _, _, padding_mask = batch_converter(
            [(coords, confidence, None)], device=device)

        # Position 0 holds the prepend token; everything else starts masked.
        mask_idx = dictionary.get_idx('<mask>')
        sampled_tokens = torch.full((1, 1 + seq_len), mask_idx, dtype=int)
        sampled_tokens[0, 0] = dictionary.get_idx('<cath>')
        if partial_seq is not None:
            for position, token in enumerate(partial_seq, start=1):
                sampled_tokens[0, position] = dictionary.get_idx(token)

        # Cache passed to the decoder across steps for faster sampling.
        incremental_state = {}

        # The encoder is run only once.
        encoder_out = self.encoder(batch_coords, padding_mask, confidence)

        # Make sure all tensors are on the same device if a GPU is present
        if device:
            sampled_tokens = sampled_tokens.to(device)

        # Decode one token at a time; only masked positions are sampled.
        for step in range(1, seq_len + 1):
            logits, _ = self.decoder(
                sampled_tokens[:, :step],
                encoder_out,
                incremental_state=incremental_state,
            )
            logits = logits[0].transpose(0, 1)
            logits /= temperature
            probs = F.softmax(logits, dim=-1)
            if sampled_tokens[0, step] == mask_idx:
                sampled_tokens[:, step] = torch.multinomial(probs, 1).squeeze(-1)

        # Drop the prepend token and map indices back to characters.
        return ''.join(dictionary.get_tok(a) for a in sampled_tokens[0, 1:])
esm/inverse_folding/gvp_transformer_encoder.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Contents of this file were adapted from the open source fairseq repository.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import argparse
9
+ import math
10
+ from typing import Dict, List, Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch import Tensor
15
+
16
+ from esm.modules import SinusoidalPositionalEmbedding
17
+ from .features import GVPInputFeaturizer, DihedralFeatures
18
+ from .gvp_encoder import GVPEncoder
19
+ from .transformer_layer import TransformerEncoderLayer
20
+ from .util import nan_to_num, get_rotation_frames, rotate, rbf
21
+
22
+
23
class GVPTransformerEncoder(nn.Module):
    """
    Transformer encoder with *args.encoder_layers* layers, each a
    :class:`TransformerEncoderLayer`, fed with embeddings derived from
    backbone coordinates (including the output of a GVP encoder).

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__()
        self.args = args
        self.dictionary = dictionary

        self.dropout_module = nn.Dropout(args.dropout)

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = SinusoidalPositionalEmbedding(
            embed_dim, self.padding_idx)
        self.embed_gvp_input_features = nn.Linear(15, embed_dim)
        self.embed_confidence = nn.Linear(16, embed_dim)
        self.embed_dihedrals = DihedralFeatures(embed_dim)

        # Arguments prefixed with "gvp_" configure the GVP encoder; the
        # prefix is stripped before they are handed over.
        gvp_args = argparse.Namespace()
        for name, value in vars(args).items():
            if name.startswith("gvp_"):
                setattr(gvp_args, name[4:], value)
        self.gvp_encoder = GVPEncoder(gvp_args)
        # Each vector channel contributes three values once flattened.
        gvp_out_dim = (gvp_args.node_hidden_dim_scalar
                       + 3 * gvp_args.node_hidden_dim_vector)
        self.embed_gvp_output = nn.Linear(gvp_out_dim, embed_dim)

        self.layers = nn.ModuleList(
            [self.build_encoder_layer(args) for _ in range(args.encoder_layers)]
        )
        self.num_layers = len(self.layers)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def build_encoder_layer(self, args):
        """Create a single Transformer encoder layer."""
        return TransformerEncoderLayer(args)

    def forward_embedding(self, coords, padding_mask, confidence):
        """
        Build the input embedding as a sum of several components.

        Args:
            coords: N, CA, C backbone coordinates in shape length x 3 (atoms) x 3
            padding_mask: boolean Tensor (true for padding) of shape length
            confidence: confidence scores between 0 and 1 of shape length

        Returns:
            Tuple (x, components): the embedding after positional encoding
            and dropout, and the dictionary of individual components.
        """
        # Residues with any non-finite coordinate are flagged, then the
        # coordinates are cleaned with `nan_to_num`.
        coord_mask = torch.all(torch.all(torch.isfinite(coords), dim=-1), dim=-1)
        coords = nan_to_num(coords)

        # Padding positions get the padding token, all others the mask token.
        pad_idx = self.dictionary.padding_idx
        mask_idx = self.dictionary.get_idx("<mask>")
        mask_tokens = padding_mask * pad_idx + ~padding_mask * mask_idx

        # Insertion order is kept as-is because the components are summed
        # in iteration order below.
        components = {}
        components["tokens"] = self.embed_tokens(mask_tokens) * self.embed_scale
        components["diherals"] = self.embed_dihedrals(coords)

        # GVP encoder
        gvp_out_scalars, gvp_out_vectors = self.gvp_encoder(
            coords, coord_mask, padding_mask, confidence)
        R = get_rotation_frames(coords)
        local_frame = R.transpose(-2, -1)
        # Rotate to local rotation frame for rotation-invariance
        gvp_out_features = torch.cat(
            [
                gvp_out_scalars,
                rotate(gvp_out_vectors, local_frame).flatten(-2, -1),
            ],
            dim=-1,
        )
        components["gvp_out"] = self.embed_gvp_output(gvp_out_features)

        components["confidence"] = self.embed_confidence(
            rbf(confidence, 0., 1.))

        # In addition to GVP encoder outputs, also directly embed GVP input
        # node features to the Transformer
        scalar_features, vector_features = GVPInputFeaturizer.get_node_features(
            coords, coord_mask, with_coord_mask=False)
        features = torch.cat(
            [
                scalar_features,
                rotate(vector_features, R.transpose(-2, -1)).flatten(-2, -1),
            ],
            dim=-1,
        )
        components["gvp_input_features"] = self.embed_gvp_input_features(features)

        x = sum(components.values())
        x = x + self.embed_positions(mask_tokens)
        x = self.dropout_module(x)
        return x, components

    def forward(
        self,
        coords,
        encoder_padding_mask,
        confidence,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            coords (Tensor): backbone coordinates
                shape batch_size x num_residues x num_atoms (3 for N, CA, C) x 3
            encoder_padding_mask (ByteTensor): the positions of
                  padding elements of shape `(batch_size x num_residues)`
            confidence (Tensor): the confidence score of shape (batch_size x
                num_residues). The value is between 0. and 1. for each residue
                coordinate, or -1. if no coordinate is given
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(num_residues, batch_size, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch_size, num_residues)`
                - **encoder_embedding** (dict): the individual embedding
                  components returned by `forward_embedding`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(num_residues, batch_size, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        x, encoder_embedding = self.forward_embedding(
            coords, encoder_padding_mask, confidence)
        # Zero out padded positions before running the layers.
        keep = 1 - encoder_padding_mask.unsqueeze(-1).type_as(x)
        x = x * keep

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = [x] if return_all_hiddens else []

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=encoder_padding_mask)
            if return_all_hiddens:
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # dictionary
            "encoder_states": encoder_states,  # List[T x B x C]
        }
esm/inverse_folding/gvp_utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
def flatten_graph(node_embeddings, edge_embeddings, edge_index):
    """
    Collapse a batch of graphs into a single graph of disconnected
    subgraphs (batch size one), as expected by pytorch-geometric.

    Edges whose source and target indices are both -1 are treated as
    padding and removed.

    Args:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch size x nodes x node_embed_dim
            - vector: shape batch size x nodes x node_embed_dim x 3
        edge_embeddings: edge embeddings of in tuple form (scalar, vector)
            - scalar: shape batch size x edges x edge_embed_dim
            - vector: shape batch size x edges x edge_embed_dim x 3
        edge_index: shape batch_size x 2 (source node and target node) x edges
    Returns:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch total_nodes x node_embed_dim
            - vector: shape batch total_nodes x node_embed_dim x 3
        edge_embeddings: edge embeddings of in tuple form (scalar, vector)
            - scalar: shape batch total_edges x edge_embed_dim
            - vector: shape batch total_edges x edge_embed_dim x 3
        edge_index: shape 2 x total_edges
    """
    node_s, node_v = node_embeddings
    edge_s, edge_v = edge_embeddings
    batch_size, nodes_per_graph = node_s.shape[0], node_s.shape[1]

    flat_nodes = (node_s.flatten(0, 1), node_v.flatten(0, 1))

    # Real edges are those where source or target differs from -1.
    valid_edges = (edge_index != -1).any(dim=1).flatten()

    # Shift node ids of graph b by b * nodes_per_graph so they stay unique.
    offsets = torch.arange(batch_size, device=edge_index.device) * nodes_per_graph
    shifted_index = edge_index + offsets[:, None, None]
    flat_index = shifted_index.permute(1, 0, 2).flatten(1, 2)[:, valid_edges]

    flat_edges = (
        edge_s.flatten(0, 1)[valid_edges, :],
        edge_v.flatten(0, 1)[valid_edges, :],
    )
    return flat_nodes, flat_edges, flat_index
48
+
49
+
50
def unflatten_graph(node_embeddings, batch_size):
    """
    Restore the batch dimension of flattened node embeddings.

    Args:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch total_nodes x node_embed_dim
            - vector: shape batch total_nodes x node_embed_dim x 3
        batch_size: int
    Returns:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch size x nodes x node_embed_dim
            - vector: shape batch size x nodes x node_embed_dim x 3
    """
    flat_s, flat_v = node_embeddings
    # The number of nodes per graph is inferred from the flattened length.
    batched_s = flat_s.reshape(batch_size, -1, flat_s.shape[1])
    batched_v = flat_v.reshape(batch_size, -1, flat_v.shape[1], flat_v.shape[2])
    return (batched_s, batched_v)
67
+
68
+
esm/inverse_folding/multichain_util.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import biotite.structure
7
+ import numpy as np
8
+ import torch
9
+ from typing import Sequence, Tuple, List
10
+
11
+ from esm.inverse_folding.util import (
12
+ load_structure,
13
+ extract_coords_from_structure,
14
+ load_coords,
15
+ get_sequence_loss,
16
+ get_encoder_output,
17
+ )
18
+
19
+
20
def extract_coords_from_complex(structure: biotite.structure.AtomArray):
    """
    Extract backbone coordinates and sequence for every chain.

    Args:
        structure: biotite AtomArray
    Returns:
        Tuple (coords, seqs)
            - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
              coordinates representing the backbone of each chain
            - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    coords, seqs = {}, {}
    for chain_id in biotite.structure.get_chains(structure):
        chain_atoms = structure[structure.chain_id == chain_id]
        chain_coords, chain_seq = extract_coords_from_structure(chain_atoms)
        coords[chain_id] = chain_coords
        seqs[chain_id] = chain_seq
    return coords, seqs
37
+
38
+
39
def load_complex_coords(fpath, chains):
    """
    Load a structure file and extract per-chain coordinates and sequences.

    Args:
        fpath: filepath to either pdb or cif file
        chains: the chain ids (the order matters for autoregressive model)
    Returns:
        Tuple (coords, seqs)
            - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
              coordinates representing the backbone of each chain
            - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    return extract_coords_from_complex(load_structure(fpath, chains))
52
+
53
+
54
def _concatenate_coords(coords, target_chain_id, padding_length=10):
    """
    Join all chains into one coordinate array, target chain first.

    Args:
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        padding_length: Length of padding between concatenated chains
    Returns:
        An L x 3 x 3 array for N, CA, C coordinates: the target chain
        followed by every other chain (in dictionary order), each preceded
        by `padding_length` rows of NaN.
    """
    gap = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
    # For best performance, put the target chain first in concatenation.
    segments = [coords[target_chain_id]]
    for chain_id, chain_coords in coords.items():
        if chain_id != target_chain_id:
            segments.extend((gap, chain_coords))
    return np.concatenate(segments, axis=0)
78
+
79
+
80
def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
        padding_length=10):
    """
    Samples sequence for one chain in a complex.
    Args:
        model: An instance of the GVPTransformer model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        temperature: sampling temperature passed to `model.sample`
        padding_length: padding length in between chains
    Returns:
        Sampled sequence for the target chain
    """
    target_chain_len = coords[target_chain_id].shape[0]
    # Forward padding_length; it was previously accepted but ignored, so
    # the gap between chains was always the helper's default.
    all_coords = _concatenate_coords(coords, target_chain_id,
                                     padding_length=padding_length)
    device = next(model.parameters()).device

    # Supply padding tokens for other chains to avoid unused sampling for speed
    n_other = all_coords.shape[0] - target_chain_len
    padding_pattern = ['<mask>'] * target_chain_len + ['<pad>'] * n_other
    sampled = model.sample(all_coords, partial_seq=padding_pattern,
                           temperature=temperature, device=device)
    # The target chain comes first in the concatenation.
    return sampled[:target_chain_len]
105
+
106
+
107
def score_sequence_in_complex(model, alphabet, coords, target_chain_id,
        target_seq, padding_length=10):
    """
    Scores sequence for one chain in a complex.
    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        target_seq: Target sequence for the target chain for scoring.
        padding_length: padding length in between chains
    Returns:
        Tuple (ll_fullseq, ll_withcoord)
        - ll_fullseq: Average log-likelihood over the full target chain
        - ll_withcoord: Average log-likelihood in target chain excluding those
            residues without coordinates
    """
    # Forward padding_length; it was previously accepted but ignored, so
    # the gap between chains was always the helper's default.
    all_coords = _concatenate_coords(coords, target_chain_id,
                                     padding_length=padding_length)

    loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords,
                                                  target_seq)
    not_padding = ~target_padding_mask
    ll_fullseq = -np.sum(loss * not_padding) / np.sum(not_padding)

    # Also calculate average when excluding masked portions
    coord_mask = np.all(np.isfinite(coords[target_chain_id]), axis=(-1, -2))
    ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
    return ll_fullseq, ll_withcoord
136
+
137
+
138
def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id):
    """
    Encode the whole complex and keep the part for the target chain.

    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
    Returns:
        The leading `target_chain_len` entries of the encoder output, i.e.
        the positions of the target chain (which is concatenated first).
    """
    target_chain_len = coords[target_chain_id].shape[0]
    complex_coords = _concatenate_coords(coords, target_chain_id)
    complex_rep = get_encoder_output(model, alphabet, complex_coords)
    return complex_rep[:target_chain_len]
esm/inverse_folding/transformer_decoder.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Contents of this file were adapted from the open source fairseq repository.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import math
9
+ from typing import Any, Dict, List, Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch import Tensor
14
+
15
+ from esm.modules import SinusoidalPositionalEmbedding
16
+ from .transformer_layer import TransformerDecoderLayer
17
+
18
+
19
def fill_with_neg_inf(t):
    """Return a tensor of the same dtype as *t* filled with -inf (FP16-safe).

    The fill happens on a float32 version of *t*. When *t* is already
    float32, ``t.float()`` is *t* itself, so *t* is overwritten in place.
    """
    as_fp32 = t.float()
    as_fp32.fill_(float("-inf"))
    return as_fp32.type_as(t)
22
+
23
+
24
class TransformerDecoder(nn.Module):
    """
    Transformer decoder made of ``args.decoder_layers`` stacked
    :class:`TransformerDecoderLayer` blocks, followed by a final layer norm
    and a bias-free linear projection onto the dictionary.

    Args:
        args (argparse.Namespace): parsed command-line arguments; this class
            reads ``dropout``, ``decoder_embed_dim`` and ``decoder_layers``
        dictionary (~fairseq.data.Dictionary): decoding dictionary; only its
            length is used here
        embed_tokens (torch.nn.Embedding): token embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__()
        self.args = args
        self.dictionary = dictionary
        # Cached causal mask, (re)built on demand by buffered_future_mask.
        self._future_mask = torch.empty(0)

        self.dropout_module = nn.Dropout(args.dropout)

        self.embed_dim = args.decoder_embed_dim
        self.padding_idx = embed_tokens.padding_idx
        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(self.embed_dim)

        # Only project the token embeddings when their width differs from
        # the decoder width.
        token_dim = embed_tokens.embedding_dim
        if token_dim == self.embed_dim:
            self.project_in_dim = None
        else:
            self.project_in_dim = nn.Linear(token_dim, self.embed_dim, bias=False)

        self.embed_positions = SinusoidalPositionalEmbedding(
            self.embed_dim,
            self.padding_idx,
        )

        self.layers = nn.ModuleList(
            self.build_decoder_layer(args) for _ in range(args.decoder_layers)
        )
        self.num_layers = len(self.layers)
        self.layer_norm = nn.LayerNorm(self.embed_dim)

        self.build_output_projection(args, dictionary)

    def build_output_projection(self, args, dictionary):
        """Create ``self.output_projection`` (embed_dim -> len(dictionary), no bias)."""
        width = args.decoder_embed_dim
        self.output_projection = nn.Linear(width, len(dictionary), bias=False)
        nn.init.normal_(self.output_projection.weight, mean=0, std=width ** -0.5)

    def build_decoder_layer(self, args):
        """Factory for a single decoder layer."""
        return TransformerDecoderLayer(args)

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention, should be of size T x B x C
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            return_all_hiddens (bool, optional): accepted but not used by
                this implementation.

        Returns:
            tuple:
                - when ``features_only`` is False, the projected output
                  transposed to `(batch, vocab, tgt_len)`; otherwise the
                  features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        features, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
        )
        if features_only:
            return features, extra

        logits = self.output_layer(features)
        # B x T x C -> B x C x T
        return logits.transpose(1, 2), extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    ):
        """
        Run the decoder stack without the output projection.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary ``{"inner_states": [...]}`` holding the layer
                  input and the output of every layer, each still in
                  T x B x C layout
        """
        bs, slen = prev_output_tokens.size()

        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None:
            if len(encoder_out["encoder_out"]) > 0:
                enc = encoder_out["encoder_out"][0]
                assert (
                    enc.size()[1] == bs
                ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
            if len(encoder_out["encoder_padding_mask"]) > 0:
                padding_mask = encoder_out["encoder_padding_mask"][0]

        # Positions are computed on the full prefix so that, in incremental
        # mode, the last token still receives its true position.
        positions = self.embed_positions(prev_output_tokens)

        if incremental_state is not None:
            # Only the newest token is fed through the layers.
            prev_output_tokens = prev_output_tokens[:, -1:]
            positions = positions[:, -1:]

        # Scaled token embeddings (+ optional projection) plus positions.
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if self.project_in_dim is not None:
            x = self.project_in_dim(x)
        x += positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # A key padding mask is only passed on when padding is present.
        is_pad = prev_output_tokens.eq(self.padding_idx)
        self_attn_padding_mask: Optional[Tensor] = is_pad if is_pad.any() else None

        inner_states: List[Optional[Tensor]] = [x]
        for layer in self.layers:
            # Incremental decoding sees one step at a time, so no causal
            # mask is needed there.
            causal_mask = (
                self.buffered_future_mask(x) if incremental_state is None else None
            )
            x, _, _ = layer(
                x,
                enc,
                padding_mask,
                incremental_state,
                self_attn_mask=causal_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=False,
                need_head_weights=False,
            )
            inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"inner_states": inner_states}

    def output_layer(self, features):
        """Project features to the vocabulary size."""
        return self.output_projection(features)

    def buffered_future_mask(self, tensor):
        """Return a `(dim, dim)` causal mask (-inf strictly above the diagonal).

        ``dim`` is ``tensor.size(0)``. The mask is cached on the module and
        rebuilt when it is empty, too small, or on a different device; it is
        converted to match *tensor* on every call.
        """
        dim = tensor.size(0)
        cached = self._future_mask
        # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
        needs_rebuild = (
            cached.size(0) == 0
            or (not cached.device == tensor.device)
            or cached.size(0) < dim
        )
        if needs_rebuild:
            self._future_mask = torch.triu(
                fill_with_neg_inf(torch.zeros([dim, dim])), 1
            )
        self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]
esm/inverse_folding/transformer_layer.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Contents of this file were adapted from the open source fairseq repository.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from typing import Dict, List, Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from esm.multihead_attention import MultiheadAttention
14
+ from torch import Tensor
15
+
16
+
17
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    Each of the two sub-blocks (self-attention, feed-forward) applies
    layer norm first, then the sub-layer, then dropout, and finally adds the
    residual.

    Args:
        args (argparse.Namespace): parsed command-line arguments; this class
            reads ``encoder_embed_dim``, ``encoder_ffn_embed_dim``,
            ``encoder_attention_heads``, ``dropout`` and ``attention_dropout``
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        dim = args.encoder_embed_dim
        ffn_dim = args.encoder_ffn_embed_dim

        self.embed_dim = dim
        self.self_attn = self.build_self_attention(dim, args)
        self.self_attn_layer_norm = torch.nn.LayerNorm(dim)
        self.dropout_module = nn.Dropout(args.dropout)
        self.activation_fn = F.relu
        self.fc1 = self.build_fc1(dim, ffn_dim)
        self.fc2 = self.build_fc2(ffn_dim, dim)

        self.final_layer_norm = nn.LayerNorm(dim)

    def build_fc1(self, input_dim, output_dim):
        """First feed-forward projection (embed_dim -> ffn_dim)."""
        return nn.Linear(input_dim, output_dim)

    def build_fc2(self, input_dim, output_dim):
        """Second feed-forward projection (ffn_dim -> embed_dim)."""
        return nn.Linear(input_dim, output_dim)

    def build_self_attention(self, embed_dim, args):
        """Construct the self-attention module for this layer."""
        return MultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )

    def residual_connection(self, x, residual):
        """Add the skip connection."""
        return residual + x

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if attn_mask is not None:
            # Non-zero mask entries become a large negative bias (-1e8 for
            # float32 inputs, -1e4 otherwise); zero entries are left as-is.
            # -inf is deliberately avoided: in some edge cases a padded query
            # would end up with all -inf scores, which yields NaNs.
            large_negative = -1e8 if x.dtype == torch.float32 else -1e4
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), large_negative)

        # --- self-attention sub-block ---
        skip = x
        normed = self.self_attn_layer_norm(x)
        attended, _ = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
        )
        x = self.residual_connection(self.dropout_module(attended), skip)

        # --- feed-forward sub-block ---
        skip = x
        hidden = self.activation_fn(self.fc1(self.final_layer_norm(x)))
        hidden = self.dropout_module(self.fc2(hidden))
        return self.residual_connection(hidden, skip)
112
+
113
+
114
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    Each sub-block (self-attention, optional encoder attention, feed-forward)
    applies layer norm first, then the sub-layer, then dropout, and finally
    adds the residual.

    Args:
        args (argparse.Namespace): parsed command-line arguments; this class
            reads ``decoder_embed_dim``, ``decoder_ffn_embed_dim``,
            ``decoder_attention_heads``, ``encoder_embed_dim``, ``dropout``,
            ``attention_dropout`` and the optional flags ``scale_fc`` and
            ``scale_resids``
        no_encoder_attn (bool, optional): when True, no encoder-attention
            sub-block is created (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module.
        add_zero_attn (bool, optional): forwarded to the self-attention module.
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = nn.Dropout(args.dropout)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim

        self.activation_fn = F.relu

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Optional layer norm on the FFN hidden activations. This previously
        # referenced a bare ``LayerNorm`` name that this module never
        # imports, so enabling ``scale_fc`` raised NameError.
        self.ffn_layernorm = (
            nn.LayerNorm(args.decoder_ffn_embed_dim)
            if getattr(args, "scale_fc", False)
            else None
        )
        # Optional learned per-channel scale applied to the FFN residual.
        self.w_resid = (
            nn.Parameter(
                torch.ones(
                    self.embed_dim,
                ),
                requires_grad=True,
            )
            if getattr(args, "scale_resids", False)
            else None
        )

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        self.need_attn = True

    def build_fc1(self, input_dim, output_dim):
        """First feed-forward projection (embed_dim -> ffn_dim)."""
        return nn.Linear(input_dim, output_dim)

    def build_fc2(self, input_dim, output_dim):
        """Second feed-forward projection (ffn_dim -> embed_dim)."""
        return nn.Linear(input_dim, output_dim)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        """Construct the decoder self-attention module."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )

    def build_encoder_attention(self, embed_dim, args):
        """Construct the encoder-decoder attention module."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=args.encoder_embed_dim,
            vdim=args.encoder_embed_dim,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

    def residual_connection(self, x, residual):
        """Add the skip connection."""
        return residual + x

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder output used as key/value
                of the encoder-attention sub-block; that sub-block is skipped
                when this is None.
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): state for incremental decoding,
                passed through to the attention modules.
            prev_self_attn_state (list, optional): ``[prev_key, prev_value]``
                (optionally followed by a key padding mask) to seed the
                self-attention buffer; requires ``incremental_state``.
            prev_attn_state (list, optional): same, for encoder attention.
            self_attn_mask (Tensor, optional): attention mask for
                self-attention.
            self_attn_padding_mask (Tensor, optional): key padding mask for
                self-attention.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            tuple ``(x, attn, None)`` where ``x`` is the encoded output of
            shape `(seq_len, batch, embed_dim)` and ``attn`` is whatever the
            last executed attention module returned as its second output.
        """
        if need_head_weights:
            need_attn = True

        # --- self-attention sub-block ---
        residual = x
        x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        # NOTE(review): the original bound this result to an unused local; the
        # call is retained in case _get_input_buffer has side effects.
        self.self_attn._get_input_buffer(incremental_state)

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)

        # --- encoder-attention sub-block (optional) ---
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)

        # --- feed-forward sub-block ---
        residual = x
        x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        return x, attn, None
esm/inverse_folding/util.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import math
8
+
9
+ import biotite.structure
10
+ from biotite.structure.io import pdbx, pdb
11
+ from biotite.structure.residues import get_residues
12
+ from biotite.structure import filter_backbone
13
+ from biotite.structure import get_chains
14
+ from biotite.sequence import ProteinSequence
15
+ import numpy as np
16
+ from scipy.spatial import transform
17
+ from scipy.stats import special_ortho_group
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.utils.data as data
22
+ from typing import Sequence, Tuple, List
23
+
24
+ from esm.data import BatchConverter
25
+
26
+
27
def load_structure(fpath, chain=None):
    """
    Load the backbone atoms of selected chains from a structure file.

    Args:
        fpath: filepath to either pdb or cif file (the format is chosen from
            the path suffix: ``'cif'`` or ``'pdb'``)
        chain: the chain id or list of chain ids to load; ``None`` keeps all
            chains
    Returns:
        biotite.structure.AtomArray restricted to backbone atoms of the
        requested chains (first model only)
    Raises:
        ValueError: if the path ends in neither ``'cif'`` nor ``'pdb'``, if
            the file contains no chains, or if a requested chain is absent.
    """
    if fpath.endswith('cif'):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith('pdb'):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on `structure`.
        raise ValueError(
            f"Unsupported structure file {fpath!r}: "
            "expected a path ending in 'cif' or 'pdb'."
        )

    # Keep backbone atoms only.
    structure = structure[filter_backbone(structure)]

    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')

    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, list):
        chain_ids = chain
    else:
        chain_ids = [chain]

    # A distinct loop name avoids clobbering the `chain` argument.
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f'Chain {chain_id} not found in input file')

    chain_filter = [a.chain_id in chain_ids for a in structure]
    return structure[chain_filter]
60
+
61
+
62
def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Pull backbone coordinates and the one-letter sequence out of a structure.

    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    backbone_atoms = ["N", "CA", "C"]
    coords = get_atom_coords_residuewise(backbone_atoms, structure)
    # get_residues returns (ids, names); the three-letter names are needed.
    three_letter_names = get_residues(structure)[1]
    seq = ''.join(map(ProteinSequence.convert_letter_3to1, three_letter_names))
    return coords, seq
75
+
76
+
77
def load_coords(fpath, chain):
    """
    Convenience wrapper: load a structure file and extract its backbone.

    Args:
        fpath: filepath to either pdb or cif file
        chain: the chain id
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    return extract_coords_from_structure(load_structure(fpath, chain))
89
+
90
+
91
def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Collect, per residue, the coordinates of the named atoms.

    Example for atoms argument: ["N", "CA", "C"]

    An atom name missing from a residue yields NaN coordinates for that slot;
    a name occurring more than once in a residue raises RuntimeError.
    """
    def filterfn(s, axis=None):
        # Column j flags which atoms of this residue are named atoms[j].
        name_hits = np.stack([s.atom_name == wanted for wanted in atoms], axis=1)
        hit_counts = name_hits.sum(0)
        if not np.all(hit_counts <= np.ones(name_hits.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        first_hit = name_hits.argmax(0)
        residue_coords = s[first_hit].coord
        # argmax returns 0 when nothing matched, so blank those rows out.
        residue_coords[hit_counts == 0] = float("nan")
        return residue_coords

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)
106
+
107
+
108
def get_sequence_loss(model, alphabet, coords, seq):
    """
    Per-position cross-entropy of `seq` given backbone `coords`.

    A single-item batch is built with CoordBatchConverter, the model is run
    with the tokens shifted right as decoder input, and the unreduced
    cross-entropy against the tokens shifted left is computed.

    Returns:
        Tuple (loss, target_padding_mask), both numpy arrays for the single
        batch item: the per-target-position loss, and a boolean mask that is
        True where the target token equals ``alphabet.padding_idx``.
    """
    device = next(model.parameters()).device
    converter = CoordBatchConverter(alphabet)
    coords, confidence, strs, tokens, padding_mask = converter(
        [(coords, None, seq)], device=device)

    # Teacher forcing: predict token t+1 from tokens up to t.
    decoder_input = tokens[:, :-1].to(device)
    target = tokens[:, 1:]
    is_target_pad = (target == alphabet.padding_idx)

    logits, _ = model.forward(coords, padding_mask, confidence, decoder_input)
    per_token_loss = F.cross_entropy(logits, target, reduction='none')

    loss = per_token_loss[0].cpu().detach().numpy()
    target_padding_mask = is_target_pad[0].cpu().numpy()
    return loss, target_padding_mask
123
+
124
+
125
def score_sequence(model, alphabet, coords, seq):
    """
    Average log-likelihood of `seq` given backbone `coords`.

    Returns:
        Tuple (ll_fullseq, ll_withcoord)
            - ll_fullseq: mean log-likelihood over all non-padding targets
            - ll_withcoord: mean log-likelihood over residues whose
              coordinates are all finite
    """
    loss, target_padding_mask = get_sequence_loss(model, alphabet, coords, seq)

    is_real_target = ~target_padding_mask
    ll_fullseq = -np.sum(loss * is_real_target) / np.sum(is_real_target)

    # Also calculate average when excluding masked portions
    has_coords = np.all(np.isfinite(coords), axis=(-1, -2))
    ll_withcoord = -np.sum(loss * has_coords) / np.sum(has_coords)
    return ll_fullseq, ll_withcoord
132
+
133
+
134
def get_encoder_output(model, alphabet, coords):
    """
    Run only the encoder of `model` on a single backbone.

    Returns:
        The encoder output for the single batch item with the first and last
        positions (bos and eos tokens) removed.
    """
    device = next(model.parameters()).device
    converter = CoordBatchConverter(alphabet)
    # No confidence and no sequence: the converter fills in its defaults.
    coords, confidence, strs, tokens, padding_mask = converter(
        [(coords, None, None)], device=device)
    encoder_out = model.encoder.forward(
        coords, padding_mask, confidence, return_all_hiddens=False)
    hidden = encoder_out['encoder_out'][0]
    # remove beginning and end (bos and eos tokens)
    return hidden[1:-1, 0]
144
+
145
+
146
def rotate(v, R):
    """
    Rotates a vector by a rotation matrix.

    Computes ``out[..., c, j] = sum_i v[..., c, i] * R[..., i, j]``, with the
    same matrix broadcast over the channel dimension.

    Args:
        v: 3D vector, tensor of shape (length x batch_size x channels x 3)
        R: rotation matrix, tensor of shape (length x batch_size x 3 x 3)

    Returns:
        Rotated version of v by rotation matrix R.
    """
    per_channel_R = R.unsqueeze(-3)   # insert a broadcastable channel axis
    column_v = v.unsqueeze(-1)        # trailing axis to broadcast against R's columns
    return (column_v * per_channel_R).sum(dim=-2)
160
+
161
+
162
def get_rotation_frames(coords):
    """
    Returns a local rotation frame defined by N, CA, C positions.

    The frame is built by Gram-Schmidt: the first axis is the normalized
    CA->C direction, the second is the CA->N direction with its component
    along the first axis removed, and the third is their cross product.

    Args:
        coords: coordinates, tensor of shape (batch_size x length x 3 x 3)
        where the third dimension is in order of N, CA, C

    Returns:
        Local relative rotation frames in shape (batch_size x length x 3 x 3)
    """
    n_pos = coords[:, :, 0]
    ca_pos = coords[:, :, 1]
    c_pos = coords[:, :, 2]

    ca_to_c = c_pos - ca_pos
    ca_to_n = n_pos - ca_pos

    e1 = normalize(ca_to_c, dim=-1)
    # Remove the part of CA->N that lies along e1.
    along_e1 = torch.sum(e1 * ca_to_n, dim=-1, keepdim=True)
    e2 = normalize(ca_to_n - e1 * along_e1, dim=-1)
    e3 = torch.cross(e1, e2, dim=-1)
    return torch.stack([e1, e2, e3], dim=-2)
181
+
182
+
183
def nan_to_num(ts, val=0.0):
    """
    Replaces every non-finite entry of a tensor (NaN as well as +/-inf)
    with a fixed value.
    """
    fill = torch.tensor(val, dtype=ts.dtype, device=ts.device)
    keep = torch.isfinite(ts)
    return torch.where(keep, ts, fill)
189
+
190
+
191
def rbf(values, v_min, v_max, n_bins=16):
    """
    Returns RBF encodings in a new dimension at the end.

    The ``n_bins`` centers are evenly spaced over [v_min, v_max] and the
    width is ``(v_max - v_min) / n_bins``; each output entry is
    ``exp(-((value - center) / width) ** 2)``.
    """
    centers = torch.linspace(v_min, v_max, n_bins, device=values.device)
    # Shape (1, ..., 1, n_bins) so the centers broadcast against `values`.
    centers = centers.view([1] * len(values.shape) + [-1])
    width = (v_max - v_min) / n_bins
    v_expand = torch.unsqueeze(values, -1)
    z = (v_expand - centers) / width
    return torch.exp(-z ** 2)
201
+
202
+
203
def norm(tensor, dim, eps=1e-8, keepdim=False):
    """
    Returns L2 norm along a dimension.

    `eps` is added to the sum of squares before the square root, so an
    all-zero slice gives ``sqrt(eps)`` rather than 0.
    """
    sum_of_squares = torch.sum(torch.square(tensor), dim=dim, keepdim=keepdim)
    return torch.sqrt(sum_of_squares + eps)
209
+
210
+
211
def normalize(tensor, dim=-1):
    """
    Divides a tensor by its L2 norm along `dim`, then replaces any
    non-finite result with the default fill value of nan_to_num.
    """
    length = norm(tensor, dim=dim, keepdim=True)
    unit = torch.div(tensor, length)
    return nan_to_num(unit)
218
+
219
+
220
class CoordBatchConverter(BatchConverter):
    """Batch converter for (coords, confidence, seq) items.

    Adds padded coordinate and confidence tensors, plus a padding mask, to
    the sequence batching done by the base BatchConverter.
    """

    def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None):
        """
        Args:
            raw_batch: List of tuples (coords, confidence, seq)
            In each tuple,
                coords: list of floats, shape L x 3 x 3
                confidence: list of floats, shape L; or scalar float; or None
                seq: string of length L
            device: optional device that coords, confidence and tokens are
                moved to
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: ByteTensor of shape batch_size x L

            Here L for coords, confidence and padding_mask includes one
            extra position at each end of every item (see padding below).

        Note: this sets ``self.alphabet.cls_idx`` to the index of ``"<cath>"``
        on every call.
        """
        self.alphabet.cls_idx = self.alphabet.get_idx("<cath>")

        batch = []
        for coords, confidence, seq in raw_batch:
            # Missing confidence means full confidence.
            if confidence is None:
                confidence = 1.
            # A scalar confidence is broadcast over the residues.
            if isinstance(confidence, (float, int)):
                confidence = [float(confidence)] * len(coords)
            # Missing sequence becomes all-'X'.
            if seq is None:
                seq = 'X' * len(coords)
            batch.append(((coords, confidence), seq))

        coords_and_confidence, strs, tokens = super().__call__(batch)

        # pad beginning and end of each protein due to legacy reasons:
        # coords get an inf entry, confidence a -1 entry, on both sides.
        padded_coords = []
        padded_confidence = []
        for cd, cf in coords_and_confidence:
            padded_coords.append(
                F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf))
        for cd, cf in coords_and_confidence:
            padded_confidence.append(
                F.pad(torch.tensor(cf), (1, 1), value=-1.))

        # Batch padding uses NaN for coords, which is what marks padding below.
        coords = self.collate_dense_tensors(padded_coords, pad_v=np.nan)
        confidence = self.collate_dense_tensors(padded_confidence, pad_v=-1.)
        if device is not None:
            coords = coords.to(device)
            confidence = confidence.to(device)
            tokens = tokens.to(device)

        padding_mask = torch.isnan(coords[:, :, 0, 0])
        coord_mask = torch.isfinite(coords.sum(-2).sum(-1))
        # Zero the confidence where coords are not finite, and additionally
        # subtract 1 at batch-padding positions.
        confidence = confidence * coord_mask + (-1.) * padding_mask
        return coords, confidence, strs, tokens, padding_mask

    def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None):
        """
        Args:
            coords_list: list of length batch_size, each item is a list of
            floats in shape L x 3 x 3 to describe a backbone
            confidence_list: one of
                - None, default to highest confidence
                - list of length batch_size, each item is a scalar
                - list of length batch_size, each item is a list of floats of
                length L to describe the confidence scores for the backbone
                with values between 0. and 1.
            seq_list: either None or a list of strings
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: ByteTensor of shape batch_size x L
        """
        n_items = len(coords_list)
        confidence_list = [None] * n_items if confidence_list is None else confidence_list
        seq_list = [None] * n_items if seq_list is None else seq_list
        raw_batch = zip(coords_list, confidence_list, seq_list)
        return self.__call__(raw_batch, device)

    @staticmethod
    def collate_dense_tensors(samples, pad_v):
        """
        Takes a list of tensors with the following dimensions:
            [(d_11,       ...,           d_1K),
             (d_21,       ...,           d_2K),
             ...,
             (d_N1,       ...,           d_NK)]
        and stack + pads them into a single tensor of:
        (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})

        Every sample is written into the leading corner of its slot; the
        remainder keeps the fill value `pad_v`. All samples must have the
        same number of dimensions and live on the same device.
        """
        if len(samples) == 0:
            return torch.Tensor()

        ranks = [x.dim() for x in samples]
        if len(set(ranks)) != 1:
            raise RuntimeError(
                f"Samples has varying dimensions: {[x.dim() for x in samples]}"
            )

        (device,) = tuple({x.device for x in samples})  # assumes all on same device
        max_shape = [max(sizes) for sizes in zip(*[x.shape for x in samples])]

        result = torch.empty(
            len(samples), *max_shape, dtype=samples[0].dtype, device=device
        )
        result.fill_(pad_v)
        for slot, t in zip(result, samples):
            slot[tuple(slice(0, k) for k in t.shape)] = t
        return result
esm/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
esm/model/esm1.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from ..modules import (
13
+ TransformerLayer,
14
+ LearnedPositionalEmbedding,
15
+ SinusoidalPositionalEmbedding,
16
+ RobertaLMHead,
17
+ ESM1bLayerNorm,
18
+ ContactPredictionHead,
19
+ )
20
+
21
+
22
class ProteinBertModel(nn.Module):
    """Protein language model covering the ESM-1 and ESM-1b variants.

    The variant is selected from ``args.arch``: ``"roberta_large"`` builds the
    ESM-1b submodules (learned positional embedding, embedding layer norms and
    a ``RobertaLMHead``); any other value builds the ESM-1 submodules
    (sinusoidal positional embedding and a separate output projection).
    """

    @classmethod
    def add_args(cls, parser):
        """Register this model's hyper-parameter options on ``parser``."""
        parser.add_argument(
            "--num_layers", default=36, type=int, metavar="N", help="number of layers"
        )
        parser.add_argument(
            "--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
        )
        parser.add_argument(
            "--logit_bias", action="store_true", help="whether to apply bias to logits"
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=5120,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_heads",
            default=20,
            type=int,
            metavar="N",
            help="number of attention heads",
        )

    def __init__(self, args, alphabet):
        """Build the model from an ``args`` namespace and an alphabet object.

        Reads the special-token indices and the bos/eos flags from
        ``alphabet`` and the architecture hyper-parameters from ``args``.
        """
        super().__init__()
        self.args = args
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        # Starts out as a bool flag; _init_submodules_esm1b replaces it with a
        # layer-norm module (or None).
        self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False)
        if self.args.arch == "roberta_large":
            self.model_version = "ESM-1b"
            self._init_submodules_esm1b()
        else:
            self.model_version = "ESM-1"
            self._init_submodules_esm1()

    def _init_submodules_common(self):
        """Create the submodules shared by both variants."""
        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.args.embed_dim,
                    self.args.ffn_embed_dim,
                    self.args.attention_heads,
                    add_bias_kv=(self.model_version != "ESM-1b"),
                    use_esm1b_layer_norm=(self.model_version == "ESM-1b"),
                )
                for _ in range(self.args.layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.args.layers * self.args.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )

    def _init_submodules_esm1b(self):
        """Create the ESM-1b specific submodules."""
        self._init_submodules_common()
        self.embed_scale = 1
        self.embed_positions = LearnedPositionalEmbedding(
            self.args.max_positions, self.args.embed_dim, self.padding_idx
        )
        self.emb_layer_norm_before = (
            ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
        # The LM head shares its projection weight with the token embedding.
        self.lm_head = RobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def _init_submodules_esm1(self):
        """Create the ESM-1 specific submodules."""
        self._init_submodules_common()
        self.embed_scale = math.sqrt(self.args.embed_dim)
        self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx)
        self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim)))
        self.embed_out_bias = None
        if self.args.final_bias:
            self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size))

    def forward(self, tokens, repr_layers=(), need_head_weights=False, return_contacts=False):
        """Run the model on a 2-D batch of token indices.

        Args:
            tokens: 2-D integer tensor of token indices (B x T).
            repr_layers: iterable of layer indices whose hidden states are
                returned; ``0`` is the embedding output.
            need_head_weights: also return the per-head attention maps.
            return_contacts: also return contact predictions (forces
                ``need_head_weights``).

        Returns:
            dict with ``"logits"`` and ``"representations"`` and, when
            requested, ``"attentions"`` and ``"contacts"``.
        """
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)  # B, T

        x = self.embed_scale * self.embed_tokens(tokens)

        if getattr(self.args, "token_dropout", False):
            x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
            # x: B x T x C
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            # Keep the ratio in x's dtype: a float32 ratio would promote
            # half/bfloat16 activations to float32 in the rescale below.
            mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        x = x + self.embed_positions(tokens)

        if self.model_version == "ESM-1b":
            if self.emb_layer_norm_before:
                x = self.emb_layer_norm_before(x)
            if padding_mask is not None:
                # Zero the embeddings at padded positions.
                x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            attn_weights = []

        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)

        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers):
            x, attn = layer(
                x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        if self.model_version == "ESM-1b":
            x = self.emb_layer_norm_after(x)
            x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

            # last hidden representation should have layer norm applied
            # (index taken from the layer count rather than the leaked loop
            # variable, which is undefined when there are no layers)
            last_layer = len(self.layers)
            if last_layer in repr_layers:
                hidden_representations[last_layer] = x
            x = self.lm_head(x)
        else:
            x = F.linear(x, self.embed_out, bias=self.embed_out_bias)
            x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if self.model_version == "ESM-1":
                # ESM-1 models have an additional null-token for attention, which we remove
                attentions = attentions[..., :-1]
            if padding_mask is not None:
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
            if return_contacts:
                contacts = self.contact_head(tokens, attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        """Return only the contact predictions for ``tokens``."""
        return self(tokens, return_contacts=True)["contacts"]

    @property
    def num_layers(self):
        """Number of transformer layers, as configured in ``args.layers``."""
        return self.args.layers
esm/model/esm2.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ import esm
11
+ from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
12
+
13
+
14
class ESM2(nn.Module):
    """ESM-2 style protein language model.

    Built from a token embedding, a stack of ``TransformerLayer`` blocks using
    rotary embeddings, a final layer norm, a ``RobertaLMHead`` that shares the
    token-embedding weight, and a ``ContactPredictionHead``.
    """

    def __init__(
        self,
        num_layers: int = 33,
        embed_dim: int = 1280,
        attention_heads: int = 20,
        alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
        token_dropout: bool = True,
    ):
        """Store the hyper-parameters and build the submodules.

        Args:
            num_layers: number of transformer layers.
            embed_dim: embedding width; the FFN width is ``4 * embed_dim``.
            attention_heads: number of attention heads per layer.
            alphabet: an ``esm.data.Alphabet`` or an architecture name that is
                resolved through ``Alphabet.from_architecture``.
            token_dropout: whether to zero mask-token embeddings and rescale
                the remaining ones in ``forward``.
        """
        super().__init__()
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        if not isinstance(alphabet, esm.data.Alphabet):
            alphabet = esm.data.Alphabet.from_architecture(alphabet)
        self.alphabet = alphabet
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        self.token_dropout = token_dropout

        self._init_submodules()

    def _init_submodules(self):
        """Create the embedding, transformer stack, heads and final norm."""
        self.embed_scale = 1
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )

        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.embed_dim,
                    4 * self.embed_dim,
                    self.attention_heads,
                    add_bias_kv=False,
                    use_esm1b_layer_norm=True,
                    use_rotary_embeddings=True,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.num_layers * self.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)

        # The LM head shares its projection weight with the token embedding.
        self.lm_head = RobertaLMHead(
            embed_dim=self.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def forward(self, tokens, repr_layers=(), need_head_weights=False, return_contacts=False):
        """Run the model on a 2-D batch of token indices.

        Args:
            tokens: 2-D integer tensor of token indices (B x T).
            repr_layers: iterable of layer indices whose hidden states are
                returned; ``0`` is the embedding output.
            need_head_weights: also return the per-head attention maps.
            return_contacts: also return contact predictions (forces
                ``need_head_weights``).

        Returns:
            dict with ``"logits"`` and ``"representations"`` and, when
            requested, ``"attentions"`` and ``"contacts"``.
        """
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)  # B, T

        x = self.embed_scale * self.embed_tokens(tokens)

        if self.token_dropout:
            x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
            # x: B x T x C
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        if padding_mask is not None:
            # Zero the embeddings at padded positions.
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            attn_weights = []

        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)

        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers):
            x, attn = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        # last hidden representation should have layer norm applied
        # (index taken from the layer count rather than the leaked loop
        # variable, which is undefined when there are no layers)
        last_layer = len(self.layers)
        if last_layer in repr_layers:
            hidden_representations[last_layer] = x
        x = self.lm_head(x)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if padding_mask is not None:
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
            if return_contacts:
                contacts = self.contact_head(tokens, attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        """Return only the contact predictions for ``tokens``."""
        return self(tokens, return_contacts=True)["contacts"]
esm/model/msa_transformer.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ..modules import (
10
+ AxialTransformerLayer,
11
+ LearnedPositionalEmbedding,
12
+ RobertaLMHead,
13
+ ESM1bLayerNorm,
14
+ ContactPredictionHead,
15
+ )
16
+
17
+ from ..axial_attention import RowSelfAttention, ColumnSelfAttention
18
+
19
+
20
+
21
class MSATransformer(nn.Module):
    """Transformer over multiple sequence alignments built from axial layers.

    Consists of a token embedding, a learned positional embedding, an optional
    learned per-row (MSA position) embedding, a stack of
    ``AxialTransformerLayer`` blocks, a ``RobertaLMHead`` sharing the
    token-embedding weight, and a ``ContactPredictionHead``.
    """

    @classmethod
    def add_args(cls, parser):
        """Register this model's hyper-parameter options on ``parser``."""
        # fmt: off
        parser.add_argument(
            "--num_layers",
            default=12,
            type=int,
            metavar="N",
            help="number of layers"
        )
        parser.add_argument(
            "--embed_dim",
            default=768,
            type=int,
            metavar="N",
            help="embedding dimension"
        )
        parser.add_argument(
            "--logit_bias",
            action="store_true",
            help="whether to apply bias to logits"
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=3072,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_heads",
            default=12,
            type=int,
            metavar="N",
            help="number of attention heads",
        )
        parser.add_argument(
            "--dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--attention_dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--activation_dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--max_tokens_per_msa",
            default=2 ** 14,
            type=int,
            help=(
                "Used during inference to batch attention computations in a single "
                "forward pass. This allows increased input sizes with less memory."
            ),
        )
        # fmt: on

    def __init__(self, args, alphabet):
        """Build the model from an ``args`` namespace and an alphabet object."""
        super().__init__()
        self.args = args
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos

        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
        )

        if getattr(self.args, "embed_positions_msa", False):
            emb_dim = getattr(self.args, "embed_positions_msa_dim", self.args.embed_dim)
            self.msa_position_embedding = nn.Parameter(
                0.01 * torch.randn(1, 1024, 1, emb_dim),
                requires_grad=True,
            )
        else:
            self.register_parameter("msa_position_embedding", None)

        # Only fall back to ``args.max_tokens`` when ``max_tokens_per_msa`` is
        # absent. Passing it as the ``getattr`` default would evaluate it
        # eagerly and fail on namespaces that only define the new name.
        if hasattr(self.args, "max_tokens_per_msa"):
            max_tokens_per_msa = self.args.max_tokens_per_msa
        else:
            max_tokens_per_msa = self.args.max_tokens

        self.dropout_module = nn.Dropout(self.args.dropout)
        self.layers = nn.ModuleList(
            [
                AxialTransformerLayer(
                    self.args.embed_dim,
                    self.args.ffn_embed_dim,
                    self.args.attention_heads,
                    self.args.dropout,
                    self.args.attention_dropout,
                    self.args.activation_dropout,
                    max_tokens_per_msa,
                )
                for _ in range(self.args.layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.args.layers * self.args.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.embed_positions = LearnedPositionalEmbedding(
            self.args.max_positions,
            self.args.embed_dim,
            self.padding_idx,
        )
        self.emb_layer_norm_before = ESM1bLayerNorm(self.args.embed_dim)
        self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
        # The LM head shares its projection weight with the token embedding.
        self.lm_head = RobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def forward(self, tokens, repr_layers=(), need_head_weights=False, return_contacts=False):
        """Run the model on a 3-D batch of MSA token indices.

        Args:
            tokens: 3-D integer tensor (batch x alignments x sequence length).
            repr_layers: iterable of layer indices whose hidden states are
                returned; ``0`` is the embedding output.
            need_head_weights: also return row and column attention maps.
            return_contacts: also return contact predictions (forces
                ``need_head_weights``).

        Returns:
            dict with ``"logits"`` and ``"representations"`` and, when
            requested, ``"col_attentions"``, ``"row_attentions"`` and
            ``"contacts"``.

        Raises:
            RuntimeError: if the MSA position embedding is enabled and more
                than 1024 alignments are passed.
        """
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 3
        batch_size, num_alignments, seqlen = tokens.size()
        padding_mask = tokens.eq(self.padding_idx)  # B, R, C
        if not padding_mask.any():
            padding_mask = None

        x = self.embed_tokens(tokens)
        # Positions are computed per row, so rows are flattened into the batch.
        x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
        if self.msa_position_embedding is not None:
            if x.size(1) > 1024:
                raise RuntimeError(
                    "Using model with MSA position embedding trained on maximum MSA "
                    f"depth of 1024, but received {x.size(1)} alignments."
                )
            x += self.msa_position_embedding[:, :num_alignments]

        x = self.emb_layer_norm_before(x)

        x = self.dropout_module(x)

        if padding_mask is not None:
            # Zero the embeddings at padded positions.
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            row_attn_weights = []
            col_attn_weights = []

        # B x R x C x D -> R x C x B x D
        x = x.permute(1, 2, 0, 3)

        for layer_idx, layer in enumerate(self.layers):
            x = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if need_head_weights:
                x, col_attn, row_attn = x
                # H x C x B x R x R -> B x H x C x R x R
                col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4))
                # H x B x C x C -> B x H x C x C
                row_attn_weights.append(row_attn.permute(1, 0, 2, 3))
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3)

        x = self.emb_layer_norm_after(x)
        x = x.permute(2, 0, 1, 3)  # R x C x B x D -> B x R x C x D

        # last hidden representation should have layer norm applied
        # (index taken from the layer count rather than the leaked loop
        # variable, which is undefined when there are no layers)
        last_layer = len(self.layers)
        if last_layer in repr_layers:
            hidden_representations[last_layer] = x
        x = self.lm_head(x)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # col_attentions: B x L x H x C x R x R
            col_attentions = torch.stack(col_attn_weights, 1)
            # row_attentions: B x L x H x C x C
            row_attentions = torch.stack(row_attn_weights, 1)
            result["col_attentions"] = col_attentions
            result["row_attentions"] = row_attentions
            if return_contacts:
                contacts = self.contact_head(tokens, row_attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        """Return only the contact predictions for ``tokens``."""
        return self(tokens, return_contacts=True)["contacts"]

    @property
    def num_layers(self):
        """Number of axial layers, as configured in ``args.layers``."""
        return self.args.layers

    def max_tokens_per_msa_(self, value: int) -> None:
        """The MSA Transformer automatically batches attention computations when
        gradients are disabled to allow you to pass in larger MSAs at test time than
        you can fit in GPU memory. By default this occurs when more than 2^14 tokens
        are passed in the input MSA. You can set this value to infinity to disable
        this behavior.
        """
        for module in self.modules():
            if isinstance(module, (RowSelfAttention, ColumnSelfAttention)):
                module.max_tokens_per_msa = value
esm/modules.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .multihead_attention import MultiheadAttention # noqa
14
+ from .axial_attention import ColumnSelfAttention, RowSelfAttention
15
+
16
+
17
def gelu(x):
    """Apply the erf-based GELU activation element-wise.

    Note that OpenAI GPT uses a tanh approximation instead, which gives
    slightly different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)
25
+
26
+
27
def symmetrize(x):
    """Add ``x`` to its transpose over the last two dimensions.

    Used for contact prediction to make the final two dimensions symmetric.
    """
    mirrored = x.transpose(-1, -2)
    return x + mirrored
30
+
31
+
32
def apc(x):
    """Apply the average product correction over the last two dimensions.

    Subtracts ``row_sum * col_sum / total_sum`` from every entry; used for
    contact prediction.
    """
    row_sums = x.sum(-1, keepdim=True)
    col_sums = x.sum(-2, keepdim=True)
    total = x.sum((-1, -2), keepdim=True)

    correction = row_sums * col_sums
    correction.div_(total)  # in-place to avoid another full-size temporary
    return x - correction
42
+
43
+
44
class ESM1LayerNorm(nn.Module):
    """Layer normalization in the TF style (eps inside the sqrt)."""

    def __init__(self, hidden_size, eps=1e-12, affine=True):
        """Set up the normalized shape and the optional affine parameters.

        Args:
            hidden_size: an int or an iterable of ints giving the trailing
                dimensions to normalize over.
            eps: value added to the variance before the square root.
            affine: whether to learn a per-element scale and shift.
        """
        super().__init__()
        if isinstance(hidden_size, int):
            self.hidden_size = (hidden_size,)
        else:
            self.hidden_size = tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if not self.affine:
            self.weight, self.bias = None, None
        else:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        """Normalize ``x`` over its trailing ``len(hidden_size)`` dimensions."""
        # (-1, -2, ..., -n) for n normalized dimensions.
        reduce_dims = tuple(range(-1, -len(self.hidden_size) - 1, -1))
        centered = x - x.mean(reduce_dims, keepdim=True)
        variance = centered.pow(2).mean(reduce_dims, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        if not self.affine:
            return normed
        return (self.weight * normed) + self.bias
66
+
67
+
68
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class ESM1bLayerNorm(_FusedLayerNorm):
        """apex ``FusedLayerNorm`` that runs under the input's CUDA device."""

        @torch.jit.unused
        def forward(self, x):
            if x.is_cuda:
                # Make the input's device current for the fused call.
                with torch.cuda.device(x.device):
                    return super().forward(x)
            return super().forward(x)

except ImportError:
    # apex is not installed: use the stock PyTorch layer norm under the same name.
    from torch.nn import LayerNorm as ESM1bLayerNorm
82
+
83
+
84
class TransformerLayer(nn.Module):
    """Pre-norm transformer block: self-attention followed by a feed-forward net."""

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
    ):
        """Record the layer dimensions and build the submodules.

        Args:
            embed_dim: width of the input and output embeddings.
            ffn_embed_dim: width of the hidden feed-forward layer.
            attention_heads: number of attention heads.
            add_bias_kv: forwarded to ``MultiheadAttention``.
            use_esm1b_layer_norm: use ``ESM1bLayerNorm`` rather than
                ``ESM1LayerNorm`` for both layer norms.
            use_rotary_embeddings: forwarded to ``MultiheadAttention``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
        """Create the attention module, the two linear layers and both norms."""
        if use_esm1b_layer_norm:
            norm_cls = ESM1bLayerNorm
        else:
            norm_cls = ESM1LayerNorm

        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.self_attn_layer_norm = norm_cls(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = norm_cls(self.embed_dim)

    def forward(
        self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        """Apply the block and return ``(output, attention)``.

        ``attention`` is whatever ``MultiheadAttention`` returns as its second
        value for the given ``need_head_weights`` flag.
        """
        # Self-attention sub-block with residual connection.
        normed = self.self_attn_layer_norm(x)
        attn_out, attn = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = x + attn_out

        # Feed-forward sub-block with residual connection.
        hidden = self.fc1(self.final_layer_norm(x))
        ffn_out = self.fc2(gelu(hidden))
        x = x + ffn_out

        return x, attn
143
+
144
+
145
class AxialTransformerLayer(nn.Module):
    """Axial MSA Transformer block: row attention, column attention, feed-forward."""

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ) -> None:
        """Build the three sub-layers, each wrapped by ``build_residual``.

        NOTE(review): ``attention_dropout`` is accepted but not used in this
        block; both attention modules are given ``dropout``.
        """
        super().__init__()

        # build_residual reads these two attributes.
        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout

        self.row_self_attention = self.build_residual(
            RowSelfAttention(
                embedding_dim,
                num_attention_heads,
                dropout=dropout,
                max_tokens_per_msa=max_tokens_per_msa,
            )
        )
        self.column_self_attention = self.build_residual(
            ColumnSelfAttention(
                embedding_dim,
                num_attention_heads,
                dropout=dropout,
                max_tokens_per_msa=max_tokens_per_msa,
            )
        )
        self.feed_forward_layer = self.build_residual(
            FeedForwardNetwork(
                embedding_dim,
                ffn_embedding_dim,
                activation_dropout=activation_dropout,
                max_tokens_per_msa=max_tokens_per_msa,
            )
        )

    def build_residual(self, layer: nn.Module):
        """Wrap ``layer`` in a ``NormalizedResidualBlock`` using this layer's settings."""
        return NormalizedResidualBlock(
            layer,
            self.embedding_dim,
            self.dropout_prob,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """Run row attention, column attention and the feed-forward block in order.

        Returns ``(x, column_attn, row_attn)`` when ``need_head_weights`` is
        true, otherwise just ``x``.
        """
        attn_kwargs = dict(
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x, row_attn = self.row_self_attention(x, **attn_kwargs)
        x, column_attn = self.column_self_attention(x, **attn_kwargs)
        x = self.feed_forward_layer(x)
        if not need_head_weights:
            return x
        return x, column_attn, row_attn
222
+
223
+
224
class LearnedPositionalEmbedding(nn.Embedding):
    """
    Learned positional embeddings up to a fixed maximum length.

    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        """Allocate the table, reserving ``padding_idx + 1`` extra leading rows."""
        table_size = num_embeddings
        if padding_idx is not None:
            table_size = num_embeddings + padding_idx + 1
        super().__init__(table_size, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen]."""
        if input.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f" sequence length of {self.max_positions}"
            )
        not_pad = input.ne(self.padding_idx).int()
        # Running count of non-padding tokens, zeroed again at padding slots,
        # then shifted so that padding maps to padding_idx itself.
        running_count = torch.cumsum(not_pad, dim=1).type_as(not_pad)
        positions = (running_count * not_pad).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
258
+
259
+
260
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sine/cosine positional embeddings with a zeroed padding row."""

    def __init__(self, embed_dim, padding_idx, learned=False):
        """Store the sizes; the table itself is built lazily in ``forward``.

        ``learned`` is accepted but not used by this class.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx
        # Buffer whose dtype/device the table is cast to on each forward call.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.weights = None

    def forward(self, x):
        """Return embeddings of shape (bsz, seq_len, -1) for the 2-D token tensor ``x``."""
        bsz, seq_len = x.shape
        rows_needed = self.padding_idx + 1 + seq_len
        # (Re)build the table only when it is missing or too short.
        if self.weights is None or rows_needed > self.weights.size(0):
            self.weights = self.get_embedding(rows_needed)
        self.weights = self.weights.type_as(self._float_tensor)

        flat_positions = self.make_positions(x).view(-1)
        selected = self.weights.index_select(0, flat_positions)
        return selected.view(bsz, seq_len, -1).detach()

    def make_positions(self, x):
        """Map tokens to table rows: padding -> ``padding_idx``, others count up from ``padding_idx + 1``."""
        keep = x.ne(self.padding_idx).long()
        offsets = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
        return offsets * keep + self.padding_idx * (1 - keep)

    def get_embedding(self, num_embeddings):
        """Build a ``num_embeddings`` x ``embed_dim`` table of sines then cosines."""
        half_dim = self.embed_dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -log_step)
        angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * inv_freq.unsqueeze(
            0
        )
        table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1).view(num_embeddings, -1)
        if self.embed_dim % 2 == 1:
            # zero pad
            table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)
        if self.padding_idx is not None:
            table[self.padding_idx, :] = 0
        return table
296
+
297
+
298
class RobertaLMHead(nn.Module):
    """Masked-language-modeling head: dense -> gelu -> layer norm -> vocabulary projection."""

    def __init__(self, embed_dim, output_dim, weight):
        """Create the head.

        Args:
            embed_dim: width of the incoming features.
            output_dim: size of the output bias (one entry per output class).
            weight: projection matrix used as-is, so the caller may share it
                with another module.
        """
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        """Return the output logits for ``features``."""
        hidden = self.layer_norm(gelu(self.dense(features)))
        # project back to size of vocabulary with bias
        return F.linear(hidden, self.weight) + self.bias
315
+
316
+
317
class ContactPredictionHead(nn.Module):
    """Symmetrizes attention maps, applies APC, and runs a logistic regression on them."""

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        """Create the regression head.

        Args:
            in_features: number of input channels for the regression.
            prepend_bos: whether the first row/column is stripped in ``forward``.
            append_eos: whether eos positions are masked and the last
                row/column stripped in ``forward``.
            bias: whether the regression has a bias term.
            eos_idx: eos token index; required when ``append_eos`` is true.

        Raises:
            ValueError: if ``append_eos`` is true but ``eos_idx`` is None.
        """
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if eos_idx is None and append_eos:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        """Return sigmoid contact scores computed from 5-D ``attentions``."""
        if self.append_eos:
            # remove eos token attentions: zero rows/columns at eos positions,
            # then drop the last row and column
            not_eos = tokens.ne(self.eos_idx).to(attentions)
            pair_mask = not_eos.unsqueeze(1) * not_eos.unsqueeze(2)
            attentions = attentions * pair_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        if self.prepend_bos:
            # remove cls token attentions
            attentions = attentions[..., 1:, 1:]

        # Fold the layer and head axes into one channel axis.
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: B x C x T x T
        # Moved to the regression weight's device; the dtype is left unchanged here.
        attentions = attentions.to(self.regression.weight.device)
        features = apc(symmetrize(attentions)).permute(0, 2, 3, 1)
        logits = self.regression(features).squeeze(3)
        return self.activation(logits)
358
+
359
+
360
class NormalizedResidualBlock(nn.Module):
    """Wraps a layer as ``x + dropout(layer(layer_norm(x)))``."""

    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        """Store ``layer`` and create the dropout and layer-norm modules."""
        super().__init__()
        self.embedding_dim = embedding_dim

        self.layer = layer
        self.dropout_module = nn.Dropout(
            dropout,
        )
        self.layer_norm = ESM1bLayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        """Apply the wrapped layer with pre-norm, dropout and a residual connection.

        Extra positional and keyword arguments go to the wrapped layer. If the
        layer returns a tuple, only its first element gets dropout and the
        residual; the remaining elements are passed through in the returned
        tuple. Otherwise a single tensor is returned.
        """
        residual = x
        outputs = self.layer(self.layer_norm(x), *args, **kwargs)
        if isinstance(outputs, tuple):
            main, *extras = outputs
            main = residual + self.dropout_module(main)
            return (main,) + tuple(extras)
        return residual + self.dropout_module(outputs)
393
+
394
+
395
class FeedForwardNetwork(nn.Module):
    """Two-layer feed-forward network with GELU and dropout in between."""

    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ):
        """Create the layers.

        Args:
            embedding_dim: width of the input and output.
            ffn_embedding_dim: width of the hidden layer.
            activation_dropout: dropout probability applied after the activation.
            max_tokens_per_msa: stored on the module; not read by this class.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(
            activation_dropout,
        )
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        """Return ``fc2(dropout(gelu(fc1(x))))``."""
        hidden = self.fc1(x)
        hidden = self.activation_fn(hidden)
        hidden = self.activation_dropout_module(hidden)
        return self.fc2(hidden)
esm/multihead_attention.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import Tensor, nn
12
+ from torch.nn import Parameter
13
+ from esm.rotary_embedding import RotaryEmbedding
14
+
15
+ import uuid
16
+
17
+
18
def utils_softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax over ``dim`` computed in float32.

    With ``onnx_trace`` the input is cast with ``x.float()`` before the
    softmax; otherwise the cast is requested through the ``dtype`` argument of
    ``F.softmax``.
    """
    if not onnx_trace:
        return F.softmax(x, dim=dim, dtype=torch.float32)
    return F.softmax(x.float(), dim=dim)
23
+
24
+
25
class FairseqIncrementalState(object):
    """Mixin that namespaces entries of a shared incremental-state dict.

    Each instance gets a random UUID; keys stored through this mixin are
    prefixed with that UUID.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign a fresh UUID string used to prefix this instance's keys."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``"<instance id>.<key>"``."""
        return f"{self._incremental_state_id}.{key}"

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module.

        Returns None when no state dict is given or the key has not been set.
        """
        if incremental_state is None:
            return None
        namespaced_key = self._get_full_incremental_state_key(key)
        if namespaced_key in incremental_state:
            return incremental_state[namespaced_key]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module.

        Stores ``value`` in place when a dict is given and returns the same
        ``incremental_state`` object (or None when None was passed).
        """
        if incremental_state is None:
            return incremental_state
        incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state
58
+
59
+
60
def with_incremental_state(cls):
    """Class decorator that makes ``FairseqIncrementalState`` the first base.

    The class is modified in place (its ``__bases__`` are rewritten) and then
    returned; an existing ``FairseqIncrementalState`` base is not duplicated.
    """
    remaining_bases = [base for base in cls.__bases__ if base != FairseqIncrementalState]
    cls.__bases__ = (FairseqIncrementalState, *remaining_bases)
    return cls
65
+
66
+
67
@with_incremental_state
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    Args:
        embed_dim: query/output feature size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        kdim: key input feature size (defaults to ``embed_dim``).
        vdim: value input feature size (defaults to ``embed_dim``).
        dropout: dropout probability applied to the attention probabilities.
        bias: whether the q/k/v/out projections carry a bias term.
        add_bias_kv: append a learned bias vector to the key and value sequences.
        add_zero_attn: append an all-zero entry to the key and value sequences.
        self_attention: project q, k and v all from ``query``.
        encoder_decoder_attention: project both k and v from ``key``.
        use_rotary_embeddings: apply ``RotaryEmbedding`` to q and k.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = RotaryEmbedding(dim=self.head_dim)

        # Use the fused torch implementation whenever this torch build has it.
        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")

    def prepare_for_onnx_export_(self):
        """Switch the module to its ONNX-tracing code paths."""
        self.onnx_trace = True

    def reset_parameters(self):
        """(Re)initialise projection weights, the output bias and bias_k/bias_v."""
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            incremental_state (dict, optional): cache of previously computed
                keys/values; when given, it is read and updated in place.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            static_kv (bool, optional): reuse cached keys/values unchanged
                instead of appending to them (encoder-decoder attention only).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.

        Returns:
            ``(attn, attn_weights)``. ``attn`` has shape
            `(tgt_len, bsz, embed_dim)`. On the non-fused path ``attn_weights``
            is None when ``need_weights`` is False, `(bsz, tgt_len, src_len)`
            when averaged over heads, or `(num_heads, bsz, tgt_len, src_len)`
            with ``need_head_weights``. With ``before_softmax`` the pair
            ``(raw attention scores, v)`` is returned instead.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (
            not self.rot_emb
            and self.enable_torch_version
            and not self.onnx_trace
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and not need_head_weights
        ):
            assert key is not None and value is not None
            # With bias=False the projection biases are None and cannot be
            # concatenated; the fused kernel accepts in_proj_bias=None.
            if self.q_proj.bias is None:
                in_proj_bias = None
            else:
                in_proj_bias = torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
                )
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            # The extra key position must also exist in both masks (unmasked).
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # Fold heads into the batch dimension: (bsz * num_heads, len, head_dim).
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        if self.rot_emb:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        # Softmax runs in float32; cast back to the working dtype afterwards.
        attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights,
            p=self.dropout,
            training=self.training,
        )
        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        """Combine the cached and the current key padding masks.

        Whichever mask is missing is replaced by zeros so the result spans
        ``src_len`` columns; returns None when both masks are None.
        """
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - prev_key_padding_mask.size(1)),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - key_padding_mask.size(1)),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
                        0
                    ):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        """Return this module's cached "attn_state" dict, or a new empty dict."""
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        """Store ``buffer`` as this module's "attn_state" entry."""
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    # Declared static so it works both as MultiheadAttention.apply_sparse_mask(...)
    # and when called on an instance.
    @staticmethod
    def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
        """Return ``attn_weights`` unchanged (no sparse mask is applied here)."""
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Split legacy fused ``in_proj_weight``/``in_proj_bias`` entries in place.

        Each fused tensor is cut into three equal slices along dim 0 and stored
        as the ``q_proj``/``k_proj``/``v_proj`` weight or bias under ``name``;
        the fused entries are removed from ``state_dict``.
        """
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        # Mutate only after iteration so the dict is not resized while looping.
        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
esm/pretrained.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ import urllib
8
+ import warnings
9
+ from argparse import Namespace
10
+ from pathlib import Path
11
+
12
+ import torch
13
+
14
+ import esm
15
+ from esm.model.esm2 import ESM2
16
+
17
+
18
+ def _has_regression_weights(model_name):
19
+ """Return whether we expect / require regression weights;
20
+ Right now that is all models except ESM-1v, ESM-IF, and partially trained ESM2 models"""
21
+ return not ("esm1v" in model_name or "esm_if" in model_name or "270K" in model_name or "500K" in model_name)
22
+
23
+
24
def load_model_and_alphabet(model_name):
    """Load a model either from a local checkpoint or from the hub.

    A name ending in ``.pt`` is treated as a file path; anything else is
    treated as a hub model name.
    """
    looks_like_path = model_name.endswith(".pt")  # treat as filepath
    loader = load_model_and_alphabet_local if looks_like_path else load_model_and_alphabet_hub
    return loader(model_name)
29
+
30
+
31
def load_hub_workaround(url):
    """Download (or reuse the cached copy of) a checkpoint and load it on CPU.

    Args:
        url: checkpoint URL; its last path component is the cached file name.

    Returns:
        The object loaded from the checkpoint.

    Raises:
        Exception: if the download fails with an HTTP error (chained to the
            original ``HTTPError``).
    """
    # ``import urllib`` alone does not guarantee that the ``error`` submodule
    # has been loaded, so import it explicitly before naming HTTPError.
    import urllib.error

    try:
        data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    except RuntimeError:
        # Pytorch version issue - see https://github.com/pytorch/pytorch/issues/43106
        # Fall back to loading the file already present in the hub cache.
        fn = Path(url).name
        data = torch.load(
            f"{torch.hub.get_dir()}/checkpoints/{fn}",
            map_location="cpu",
        )
    except urllib.error.HTTPError as e:
        raise Exception(
            f"Could not load {url}, check if you specified a correct model name?"
        ) from e
    return data
44
+
45
+
46
def load_regression_hub(model_name):
    """Fetch the contact-regression checkpoint published for ``model_name``."""
    return load_hub_workaround(
        f"https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt"
    )
50
+
51
+
52
def _download_model_and_regression_data(model_name):
    """Fetch the model checkpoint and, when expected, its regression weights.

    Returns:
        ``(model_data, regression_data)``; ``regression_data`` is None for
        models that ``_has_regression_weights`` reports as having none.
    """
    model_data = load_hub_workaround(
        f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
    )
    regression_data = (
        load_regression_hub(model_name) if _has_regression_weights(model_name) else None
    )
    return model_data, regression_data
60
+
61
+
62
def load_model_and_alphabet_hub(model_name):
    """Download the checkpoint(s) for ``model_name`` and build the model.

    Returns whatever ``load_model_and_alphabet_core`` returns for them.
    """
    downloaded = _download_model_and_regression_data(model_name)
    return load_model_and_alphabet_core(model_name, *downloaded)
65
+
66
+
67
def load_model_and_alphabet_local(model_location):
    """Load from local path. The regression weights need to be co-located.

    The model name is the checkpoint's file stem; when regression weights are
    expected they are read from ``<path without suffix>-contact-regression.pt``.
    """
    checkpoint_path = Path(model_location)
    model_data = torch.load(str(checkpoint_path), map_location="cpu")
    model_name = checkpoint_path.stem

    regression_data = None
    if _has_regression_weights(model_name):
        regression_path = str(checkpoint_path.with_suffix("")) + "-contact-regression.pt"
        regression_data = torch.load(regression_path, map_location="cpu")

    return load_model_and_alphabet_core(model_name, model_data, regression_data)
78
+
79
+
80
def has_emb_layer_norm_before(model_state):
    """Determine whether layer norm needs to be applied before the encoder.

    True when any key of ``model_state`` starts with ``emb_layer_norm_before``.
    """
    for param_name in model_state:
        if param_name.startswith("emb_layer_norm_before"):
            return True
    return False
83
+
84
+
85
def _load_model_and_alphabet_core_v1(model_data):
    """Instantiate a model from a v1-style checkpoint dict.

    The architecture name in ``model_data["args"].arch`` selects the model
    class and how argument names and state-dict keys are renamed.

    Returns:
        ``(model, alphabet, model_state)`` where ``model_state`` is the renamed
        state dict (it is not loaded into ``model`` here).

    Raises:
        ValueError: if the architecture is not one of the handled ones.
    """
    import esm  # since esm.inverse_folding is imported below, you actually have to re-import esm here

    def strip_through(name, separator, marker):
        # When `marker` occurs in `name`: split on `separator`, drop the first
        # piece and glue the rest together. Otherwise return `name` untouched.
        if marker not in name:
            return name
        return "".join(name.split(separator)[1:])

    def swap_row_column(name):
        # Rename "row" -> "column" when "row" is present, else "column" -> "row".
        if "row" in name:
            return name.replace("row", "column")
        return name.replace("column", "row")

    args = model_data["args"]
    arch = args.arch
    alphabet = esm.Alphabet.from_architecture(arch)

    if arch == "roberta_large":
        # upgrade state dict
        model_args = {
            strip_through(arg_name, "encoder_", "encoder"): arg_value
            for arg_name, arg_value in vars(args).items()
        }
        model_state = {}
        for param_name, param in model_data["model"].items():
            param_name = strip_through(param_name, "sentence_encoder.", "sentence_encoder")
            model_state[strip_through(param_name, "encoder.", "encoder")] = param
        model_state["embed_tokens.weight"][alphabet.mask_idx].zero_()  # For token drop
        model_args["emb_layer_norm_before"] = has_emb_layer_norm_before(model_state)
        model_type = esm.ProteinBertModel

    elif arch == "protein_bert_base":
        # upgrade state dict
        model_args = {
            strip_through(arg_name, "decoder_", "decoder"): arg_value
            for arg_name, arg_value in vars(args).items()
        }
        model_state = {
            strip_through(param_name, "decoder.", "decoder"): param
            for param_name, param in model_data["model"].items()
        }
        model_type = esm.ProteinBertModel

    elif arch == "msa_transformer":
        # upgrade state dict
        model_args = {
            strip_through(arg_name, "encoder_", "encoder"): arg_value
            for arg_name, arg_value in vars(args).items()
        }
        model_state = {}
        for param_name, param in model_data["model"].items():
            param_name = swap_row_column(param_name)
            param_name = strip_through(param_name, "sentence_encoder.", "sentence_encoder")
            model_state[strip_through(param_name, "encoder.", "encoder")] = param
        if model_args.get("embed_positions_msa", False):
            emb_dim = model_state["msa_position_embedding"].size(-1)
            model_args["embed_positions_msa_dim"] = emb_dim  # initial release, bug: emb_dim==1

        model_type = esm.MSATransformer

    elif "invariant_gvp" in arch:
        import esm.inverse_folding

        model_type = esm.inverse_folding.gvp_transformer.GVPTransformerModel
        model_args = vars(args)  # convert Namespace -> dict

        # Map the module names in checkpoints trained with internal code to
        # the updated module names in open source code (applied in this order).
        renames = (
            ("W_v", "embed_graph.embed_node"),
            ("W_e", "embed_graph.embed_edge"),
            ("embed_scores.0", "embed_confidence"),
            ("embed_score.", "embed_graph.embed_confidence."),
            ("seq_logits_projection.", ""),
            ("embed_ingraham_features", "embed_dihedrals"),
            ("embed_gvp_in_local_frame.0", "embed_gvp_output"),
            ("embed_features_in_local_frame.0", "embed_gvp_input_features"),
        )

        def update_name(s):
            for old, new in renames:
                s = s.replace(old, new)
            return s

        model_state = {}
        for sname, svalue in model_data["model"].items():
            if "version" in sname:
                continue
            model_state[update_name(sname)] = svalue

    else:
        raise ValueError("Unknown architecture selected")

    model = model_type(
        Namespace(**model_args),
        alphabet,
    )

    return model, alphabet, model_state
162
+
163
+
164
def _load_model_and_alphabet_core_v2(model_data):
    """Instantiate an ESM2 model from a v2-style checkpoint dict.

    Reads the model hyper-parameters from ``model_data["cfg"]["model"]`` and
    renames the keys of ``model_data["model"]``.

    Returns:
        ``(model, alphabet, state_dict)`` where ``state_dict`` is the renamed
        state dict (it is not loaded into ``model`` here).
    """

    def upgrade_state_dict(state_dict):
        """Removes leading prefixes 'encoder.sentence_encoder.' and 'encoder.'."""
        prefixes = ["encoder.sentence_encoder.", "encoder."]
        # Group the alternatives so `^` anchors every one of them, and escape
        # them so "." only matches a literal dot. Without this, "encoder."
        # would be removed anywhere inside a key, not just at its start.
        pattern = re.compile("^(?:" + "|".join(re.escape(p) for p in prefixes) + ")")
        return {pattern.sub("", name): param for name, param in state_dict.items()}

    cfg = model_data["cfg"]["model"]
    state_dict = model_data["model"]
    state_dict = upgrade_state_dict(state_dict)
    alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
    model = ESM2(
        num_layers=cfg.encoder_layers,
        embed_dim=cfg.encoder_embed_dim,
        attention_heads=cfg.encoder_attention_heads,
        alphabet=alphabet,
        token_dropout=cfg.token_dropout,
    )
    return model, alphabet, state_dict
184
+
185
+
186
def load_model_and_alphabet_core(model_name, model_data, regression_data=None):
    """Build the model for ``model_name`` and load its weights.

    When ``regression_data`` is given, its ``"model"`` entries are merged into
    ``model_data["model"]`` (in place) and the state dict is loaded strictly.
    Without it, the two contact-head regression keys may be absent: any other
    missing or unexpected key raises, and a warning is issued if the
    regression keys are indeed missing.

    Returns:
        ``(model, alphabet)``.

    Raises:
        RuntimeError: on missing/unexpected keys when no regression data is given.
    """
    has_regression = regression_data is not None
    if has_regression:
        model_data["model"].update(regression_data["model"])

    # Names starting with "esm2" use the v2 checkpoint layout.
    if model_name.startswith("esm2"):
        build = _load_model_and_alphabet_core_v2
    else:
        build = _load_model_and_alphabet_core_v1
    model, alphabet, model_state = build(model_data)

    expected_keys = set(model.state_dict().keys())
    found_keys = set(model_state.keys())

    if not has_regression:
        expected_missing = {"contact_head.regression.weight", "contact_head.regression.bias"}
        missing = (expected_keys - found_keys) - expected_missing
        unexpected = found_keys - expected_keys

        error_msgs = []
        if missing:
            error_msgs.append(f"Missing key(s) in state_dict: {missing}.")
        if unexpected:
            error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected}.")

        if error_msgs:
            details = "\n\t".join(error_msgs)
            raise RuntimeError(
                f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{details}"
            )
        if expected_missing - found_keys:
            warnings.warn(
                "Regression weights not found, predicting contacts will not produce correct results."
            )

    model.load_state_dict(model_state, strict=has_regression)

    return model, alphabet
222
+
223
+
224
def esm1_t34_670M_UR50S():
    """Load the 34-layer, 670M-parameter transformer trained on Uniref50 Sparse.

    Returns:
        A tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub(model_name="esm1_t34_670M_UR50S")
230
+
231
+
232
def esm1_t34_670M_UR50D():
    """Load the 34-layer, 670M-parameter transformer trained on Uniref50 Dense.

    Returns:
        A tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub(model_name="esm1_t34_670M_UR50D")
238
+
239
+
240
def esm1_t34_670M_UR100():
    """Load the 34-layer, 670M-parameter transformer trained on Uniref100.

    Returns:
        A tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub(model_name="esm1_t34_670M_UR100")
246
+
247
+
248
def esm1_t12_85M_UR50S():
    """Load the 12-layer, 85M-parameter transformer trained on Uniref50 Sparse.

    Returns:
        A tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub(model_name="esm1_t12_85M_UR50S")
254
+
255
+
256
def esm1_t6_43M_UR50S():
    """Load the 6-layer, 43M-parameter transformer trained on Uniref50 Sparse.

    Returns:
        A tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub(model_name="esm1_t6_43M_UR50S")
262
+
263
+
264
def esm1b_t33_650M_UR50S():
    """Load the 33-layer, 650M-parameter transformer trained on Uniref50 Sparse.

    This is our best performing model, which will be described in a future
    publication.

    Returns:
        A tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub(model_name="esm1b_t33_650M_UR50S")
271
+
272
+
273
def esm_msa1_t12_100M_UR50S():
    """Load the ``esm_msa1_t12_100M_UR50S`` checkpoint from the hub.

    Emits a warning first: this release had a minor positional-embedding bug
    and ``esm_msa1b_t12_100M_UR50S`` should be used instead.

    Returns:
        A tuple of (Model, Alphabet).
    """
    deprecation_note = (
        "This model had a minor bug in the positional embeddings, "
        "please use ESM-MSA-1b: esm.pretrained.esm_msa1b_t12_100M_UR50S()"
    )
    warnings.warn(deprecation_note)
    return load_model_and_alphabet_hub(model_name="esm_msa1_t12_100M_UR50S")
279
+
280
+
281
def esm_msa1b_t12_100M_UR50S():
    """Load the ``esm_msa1b_t12_100M_UR50S`` checkpoint from the hub.

    Returns:
        A tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub(model_name="esm_msa1b_t12_100M_UR50S")
283
+
284
+
285
def esm1v_t33_650M_UR90S():
    """Load the 33-layer, 650M-parameter transformer trained on Uniref90.

    This loads model 1 of the 5-model ensemble (``esm1v_t33_650M_UR90S_1``).

    Returns:
        A tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub(model_name="esm1v_t33_650M_UR90S_1")
292
+
293
+
294
def esm1v_t33_650M_UR90S_1():
    """Load the 33-layer, 650M-parameter transformer trained on Uniref90.

    This is model 1 of a 5 model ensemble.

    Returns:
        A tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub(model_name="esm1v_t33_650M_UR90S_1")
301
+
302
+
303
def esm1v_t33_650M_UR90S_2():
    """Load model 2 of the 5-model ESM-1v ensemble.

    33-layer transformer with 650M parameters, trained on UniRef90.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm1v_t33_650M_UR90S_2"
    return load_model_and_alphabet_hub(checkpoint_name)
310
+
311
+
312
def esm1v_t33_650M_UR90S_3():
    """Load model 3 of the 5-model ESM-1v ensemble.

    33-layer transformer with 650M parameters, trained on UniRef90.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm1v_t33_650M_UR90S_3"
    return load_model_and_alphabet_hub(checkpoint_name)
319
+
320
+
321
def esm1v_t33_650M_UR90S_4():
    """Load model 4 of the 5-model ESM-1v ensemble.

    33-layer transformer with 650M parameters, trained on UniRef90.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm1v_t33_650M_UR90S_4"
    return load_model_and_alphabet_hub(checkpoint_name)
328
+
329
+
330
def esm1v_t33_650M_UR90S_5():
    """Load model 5 of the 5-model ESM-1v ensemble.

    33-layer transformer with 650M parameters, trained on UniRef90.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm1v_t33_650M_UR90S_5"
    return load_model_and_alphabet_hub(checkpoint_name)
337
+
338
+
339
def esm_if1_gvp4_t16_142M_UR50():
    """Load the 142M-parameter inverse folding model.

    The architecture combines 4 GVP-GNN layers with 8 Transformer encoder
    layers and 8 Transformer decoder layers. Training data: CATH structures
    plus 12 million AlphaFold2-predicted structures for UniRef50 sequences.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm_if1_gvp4_t16_142M_UR50"
    return load_model_and_alphabet_hub(checkpoint_name)
348
+
349
+
350
def esm2_t6_8M_UR50D():
    """Load the 6-layer, 8M-parameter ESM-2 model trained on UniRef50.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm2_t6_8M_UR50D"
    return load_model_and_alphabet_hub(checkpoint_name)
356
+
357
+
358
def esm2_t12_35M_UR50D():
    """Load the 12-layer, 35M-parameter ESM-2 model trained on UniRef50.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm2_t12_35M_UR50D"
    return load_model_and_alphabet_hub(checkpoint_name)
364
+
365
+
366
def esm2_t30_150M_UR50D():
    """Load the 30-layer, 150M-parameter ESM-2 model trained on UniRef50.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm2_t30_150M_UR50D"
    return load_model_and_alphabet_hub(checkpoint_name)
372
+
373
+
374
def esm2_t33_650M_UR50D():
    """Load the 33-layer, 650M-parameter ESM-2 model trained on UniRef50.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm2_t33_650M_UR50D"
    return load_model_and_alphabet_hub(checkpoint_name)
380
+
381
+
382
def esm2_t36_3B_UR50D():
    """Load the 36-layer, 3B-parameter ESM-2 model trained on UniRef50.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm2_t36_3B_UR50D"
    return load_model_and_alphabet_hub(checkpoint_name)
388
+
389
+
390
def esm2_t48_15B_UR50D():
    """Load the 48-layer, 15B-parameter ESM-2 model trained on UniRef50.

    If loading this model runs out of memory, the README explains how to
    use FSDP and ZeRO CPU offloading.

    Returns:
        A ``(Model, Alphabet)`` tuple.
    """
    checkpoint_name = "esm2_t48_15B_UR50D"
    return load_model_and_alphabet_hub(checkpoint_name)
398
+
399
+
400
def esmfold_v0():
    """
    Build ESMFold v0: 3B ESM-2 with 48 folding blocks.

    This is the version used for the paper (Lin et al, 2022). Its training
    set covers all PDB chains until 2020-05, which keeps a temporal holdout
    with CASP14 and with the CAMEO validation and test set reported there.

    Returns the result of ``esm.esmfold.v1.pretrained.esmfold_v0()``.
    """
    # Imported lazily, inside the function, as in the sibling ESMFold loaders.
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_v0
    return factory()
409
+
410
+
411
def esmfold_v1():
    """
    Build ESMFold v1: 3B ESM-2 with 48 folding blocks.

    ESMFold predicts atomic-level structure quickly and with high accuracy
    directly from a single protein sequence, using the ESM2 protein language
    model to extract meaningful representations from that sequence.

    Returns the result of ``esm.esmfold.v1.pretrained.esmfold_v1()``.
    """
    # Imported lazily, inside the function, as in the sibling ESMFold loaders.
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_v1
    return factory()
421
+
422
def esmfold_structure_module_only_8M():
    """
    Build the ESMFold baseline using 8M ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 500K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_8M()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_8M
    return factory()
432
+
433
+
434
def esmfold_structure_module_only_8M_270K():
    """
    Build the ESMFold baseline using 8M ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 270K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_8M_270K()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_8M_270K
    return factory()
444
+
445
+
446
def esmfold_structure_module_only_35M():
    """
    Build the ESMFold baseline using 35M ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 500K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_35M()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_35M
    return factory()
456
+
457
+
458
def esmfold_structure_module_only_35M_270K():
    """
    Build the ESMFold baseline using 35M ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 270K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_35M_270K()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_35M_270K
    return factory()
468
+
469
+
470
def esmfold_structure_module_only_150M():
    """
    Build the ESMFold baseline using 150M ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 500K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_150M()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_150M
    return factory()
480
+
481
+
482
def esmfold_structure_module_only_150M_270K():
    """
    Build the ESMFold baseline using 150M ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 270K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_150M_270K()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_150M_270K
    return factory()
492
+
493
+
494
def esmfold_structure_module_only_650M():
    """
    Build the ESMFold baseline using 650M ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 500K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_650M()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_650M
    return factory()
504
+
505
+
506
def esmfold_structure_module_only_650M_270K():
    """
    Build the ESMFold baseline using 650M ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 270K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_650M_270K()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_650M_270K
    return factory()
516
+
517
+
518
def esmfold_structure_module_only_3B():
    """
    Build the ESMFold baseline using 3B ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 500K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_3B()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_3B
    return factory()
528
+
529
+
530
def esmfold_structure_module_only_3B_270K():
    """
    Build the ESMFold baseline using 3B ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 270K updates. The baseline is an
    ablation over the number of language-model parameters, designed to test
    the capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_3B_270K()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_3B_270K
    return factory()
540
+
541
+
542
def esmfold_structure_module_only_15B():
    """
    Build the ESMFold baseline using 15B ESM-2 and 0 folding blocks.

    The ESM-2 model here is trained out to 270K updates; the 15B parameter
    ESM-2 was not trained out to 500K updates. The baseline is an ablation
    over the number of language-model parameters, designed to test the
    capabilities of the language model; see table S1 in (Lin et al, 2022).

    Returns the result of
    ``esm.esmfold.v1.pretrained.esmfold_structure_module_only_15B()``.
    """
    import esm.esmfold.v1.pretrained

    factory = esm.esmfold.v1.pretrained.esmfold_structure_module_only_15B
    return factory()
esm/rotary_embedding.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Tuple
7
+
8
+ import torch
9
+
10
+
11
def rotate_half(x):
    """Split the last dimension of ``x`` into two chunks ``(a, b)`` and return ``cat(-b, a)``."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
14
+
15
+
16
def apply_rotary_pos_emb(x, cos, sin):
    """Combine ``x`` with the cos/sin tables: ``x * cos + rotate_half(x) * sin``.

    Both tables are first sliced along their second axis to
    ``x.shape[-2]`` entries.
    """
    length = x.shape[-2]
    cos_table = cos[:, :length, :]
    sin_table = sin[:, :length, :]

    direct_term = x * cos_table
    rotated_term = rotate_half(x) * sin_table
    return direct_term + rotated_term
21
+
22
+
23
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings, following RoFormer_ (Su et al.).

    The key idea of the method is that queries and keys are transformed by
    rotation matrices which depend on the relative positions.

    Other implementations can be found in the Rotary Transformer repo_ and
    in GPT-NeoX_; GPT-NeoX was an inspiration for this one.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    .. warning::
        This embedding is deliberately not registered: it is transformative
        (it does not create the embedding dimension) and will likely be
        picked up (imported) on an ad-hoc basis.
    """

    def __init__(self, dim: int, *_, **__):
        super().__init__()
        # Non-trainable inverse frequencies, one per even index in [0, dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

        # cos/sin tables are built on first use and reused while still valid.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return the ``(cos, sin)`` tables for ``x``, rebuilding them if stale.

        The cache is stale when the length of ``x`` along ``seq_dimension``
        differs from the cached length, or when ``x`` is on another device
        than the cached tables (possibly due to tracing, for instance).
        """
        seq_len = x.shape[seq_dimension]

        # The length comparison comes first so the device lookup is skipped
        # while the cache is still empty (``_cos_cached`` is None).
        cache_is_fresh = (
            seq_len == self._seq_len_cached and self._cos_cached.device == x.device
        )
        if cache_is_fresh:
            return self._cos_cached, self._sin_cached

        self._seq_len_cached = seq_len
        positions = torch.arange(x.shape[seq_dimension], device=x.device)
        positions = positions.type_as(self.inv_freq)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1).to(x.device)

        # A leading singleton axis is added to both tables.
        self._cos_cached = table.cos()[None, :, :]
        self._sin_cached = table.sin()[None, :, :]
        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to ``q`` and ``k``.

        The tables are sized from the second-to-last dimension of ``k``.
        """
        cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)
        self._cos_cached, self._sin_cached = cos, sin

        rotated_q = apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)
        rotated_k = apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)
        return rotated_q, rotated_k
esm/version.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ version = "2.0.1"