lhallee committed
Commit 155f33d · verified · 1 Parent(s): bcdaf7e

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1): modeling_esm_plusplus.py (+1076 -1076)
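The commit message reflects the standard huggingface_hub upload flow. A minimal sketch of how a file like this is typically pushed (repo id and token are placeholders, not taken from this commit):

    from huggingface_hub import HfApi

    api = HfApi(token="hf_...")  # placeholder token
    api.upload_file(
        path_or_fileobj="modeling_esm_plusplus.py",
        path_in_repo="modeling_esm_plusplus.py",
        repo_id="user/repo",  # placeholder repo id
        commit_message="Upload modeling_esm_plusplus.py with huggingface_hub",
    )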
modeling_esm_plusplus.py CHANGED
@@ -1,1076 +1,1076 @@
"""
ESM++ model implementation.

ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility
The ESM Python package is not required

Modified from https://github.com/evolutionaryscale/esm
License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
"""

import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from functools import cache, partial
from pathlib import Path
from typing import Optional, Tuple, Union
from einops import rearrange, repeat
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import ModelOutput


class ESMplusplusConfig(PretrainedConfig):
    """Configuration class for ESM++ model.

    Args:
        vocab_size: Size of the vocabulary
        hidden_size: Dimension of hidden layers
        num_attention_heads: Number of attention heads
        num_hidden_layers: Number of transformer layers
        num_labels: Number of output labels for classification
        problem_type: Type of problem - regression, single/multi label classification
    """
    model_type = "ESMplusplus"
    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 960,
        num_attention_heads: int = 15,
        num_hidden_layers: int = 30,
        num_labels: int = 2,
        problem_type: str | None = None,
        dropout: float = 0.0,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_labels = num_labels
        self.problem_type = problem_type
        self.dropout = dropout
        self.initializer_range = initializer_range


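# Illustrative note (not part of the uploaded module): the defaults above match the 300M ESMC
# configuration; the 600M variant uses hidden_size=1152, num_attention_heads=18,
# num_hidden_layers=36 (see ESMplusplus_600M below).
#
#   config = ESMplusplusConfig()  # 300M-sized encoder over a 64-token vocabulary
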
### Rotary Embeddings
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotates half the hidden dims of the input."""
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: bool = False,
    _inplace: bool = False,
) -> torch.Tensor:
    """Apply rotary embeddings to input based on cos and sin."""
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.size(1)
    cos = cos[:seqlen]
    sin = sin[:seqlen]
    cos = repeat(cos, "s d -> s 1 (2 d)")
    sin = repeat(sin, "s d -> s 1 (2 d)")
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


class RotaryEmbedding(torch.nn.Module):
    """Rotary position embeddings.

    Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"

    Args:
        dim: Dimension of the embedding
        base: Base for computing angular frequencies
        interleaved: Whether to use interleaved rotations
        scale_base: Base for scaling
        scaling_factor: Factor for scaling positions
        pos_idx_in_fp32: Whether to compute position indices in fp32
        device: Computation device
    """
    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        interleaved: bool = False,
        scale_base: Optional[float] = None,
        scaling_factor: float = 1.0,
        pos_idx_in_fp32: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of the embedding."""
        inv_freq = self._compute_inv_freq(self.device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Compute inverse frequency bands."""
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """Update the cached cosine and sine values."""
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)

            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys.

        Args:
            q: Query tensor of shape (batch, seqlen, nheads, headdim)
            k: Key tensor of shape (batch, seqlen, nheads, headdim)

        Returns:
            Tuple of rotated query and key tensors
        """
        self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is None:
            return (
                apply_rotary_emb_torch(
                    q,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
                apply_rotary_emb_torch(
                    k,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
            )  # type: ignore
        else:
            assert False


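# Worked example (illustrative): with interleaved=False, rotate_half splits the last dimension in
# half and rotates it, e.g. [1, 2, 3, 4] -> [-3, -4, 1, 2]; apply_rotary_emb_torch then forms
# x * cos + rotate_half(x) * sin over the first ro_dim channels and passes the rest through.
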
238
- ### Feedforward Network Components
239
- def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
240
- """Compute corrected dimension for SwiGLU."""
241
- return int(((expansion_ratio * d_model) + 255) // 256 * 256)
242
-
243
-
244
- class SwiGLU(nn.Module):
245
- """SwiGLU activation function."""
246
- def __init__(self):
247
- super(SwiGLU, self).__init__()
248
-
249
- def forward(self, x: torch.Tensor) -> torch.Tensor:
250
- x1, x2 = x.chunk(2, dim=-1)
251
- return F.silu(x1) * x2
252
-
253
-
254
- def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
255
- """Create SwiGLU feedforward network with layer normalization."""
256
- return nn.Sequential(
257
- nn.LayerNorm(d_model),
258
- nn.Linear(
259
- d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
260
- ),
261
- SwiGLU(),
262
- nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
263
- )
264
-
265
-
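# Worked example (illustrative): swiglu_correction_fn rounds the expanded width up to a multiple
# of 256. For d_model=960 and expansion_ratio=8/3: 8/3 * 960 = 2560 and (2560 + 255) // 256 * 256
# = 2560, so swiglu_ln_ffn's first Linear maps 960 -> 2 * 2560 = 5120 (gate and value halves for
# SwiGLU) and the second Linear maps 2560 -> 960.
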
### Attention
class MultiHeadAttention(nn.Module):
    """Multi-head attention with rotary embeddings.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.q_ln = nn.LayerNorm(d_model, bias=False)
        self.k_ln = nn.LayerNorm(d_model, bias=False)
        self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key."""
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor after self attention, and optionally attention weights
        """
        attn_weights = None
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = (
            self.q_ln(query_BLD).to(query_BLD.dtype),
            self.k_ln(key_BLD).to(query_BLD.dtype),
        )
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
        query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))

        if output_attentions:  # Manual attention computation
            L, S = query_BLD.size(-2), key_BLD.size(-2)
            scale = 1 / math.sqrt(query_BLD.size(-1))
            attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
            if attention_mask is not None:
                if attention_mask.dtype == torch.bool:
                    # Convert a boolean mask (True = attend) into an additive float bias instead of
                    # filling the boolean mask itself with -inf, which is invalid and would mutate
                    # the caller's tensor in place.
                    attn_bias = torch.zeros_like(attention_mask, dtype=query_BLD.dtype)
                    attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
                else:
                    attn_bias += attention_mask

            attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
            attn_weights += attn_bias
            attn_weights = F.softmax(attn_weights, dim=-1)
            context_BHLD = torch.matmul(attn_weights, value_BHLD)
        else:
            context_BHLD = F.scaled_dot_product_attention(
                query_BHLD, key_BHLD, value_BHLD, attention_mask
            )

        context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
        output = self.out_proj(context_BLD)
        return output, attn_weights


### Regression Head
def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
    """Create a regression head with optional hidden dimension.

    Args:
        d_model: Input dimension
        output_dim: Output dimension
        hidden_dim: Optional hidden dimension (defaults to d_model)
    """
    hidden_dim = hidden_dim if hidden_dim is not None else d_model
    return nn.Sequential(
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, output_dim),
    )


### Transformer Block
class UnifiedTransformerBlock(nn.Module):
    """Transformer block with attention and feedforward layers.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        residue_scaling_factor: Factor for scaling residual connections
        expansion_ratio: Expansion ratio for feedforward network
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        residue_scaling_factor: float = 1,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
        self.scaling_factor = residue_scaling_factor
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor after transformer block, and optionally attention weights
        """
        attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
        x = x + self.dropout(attn_output) / self.scaling_factor
        x = x + self.dropout(self.ffn(x)) / self.scaling_factor
        return x, attn_weights


### Model Outputs
@dataclass
class TransformerOutput(ModelOutput):
    """Output type for transformer encoder."""
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None


@dataclass
class ESMplusplusOutput(ModelOutput):
    """Output type for ESM++ models."""
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None


### Transformer Stack
class TransformerStack(nn.Module):
    """Stack of transformer blocks.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        n_layers: Number of transformer layers
        dropout: Dropout rate
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_layers: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    residue_scaling_factor=math.sqrt(n_layers / 36),
                    dropout=dropout,
                )
                for i in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(d_model, bias=False)
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TransformerOutput:
        """
        Args:
            x: Input tensor
            attention_mask: Optional attention mask
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
        """
        batch_size, seq_len, _ = x.shape
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None

        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()

        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x, attn_weights = self._gradient_checkpointing_func(
                    block.__call__,
                    x,
                    attention_mask,
                    output_attentions,
                )
            else:
                x, attn_weights = block(x, attention_mask, output_attentions)

            if attentions is not None:
                attentions += (attn_weights,)

            if output_hidden_states:
                assert hidden_states is not None
                hidden_states += (x,)

        return TransformerOutput(
            last_hidden_state=self.norm(x),
            hidden_states=hidden_states,
            attentions=attentions
        )


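# Illustrative note (not part of the uploaded module): residue_scaling_factor = sqrt(n_layers / 36),
# so the 30-layer (300M) stack divides each residual update by sqrt(30 / 36) ~= 0.913 and the
# 36-layer (600M) stack divides by exactly 1.0.
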
### Dataset for Embedding
class ProteinDataset(Dataset):
    """Simple dataset for protein sequences."""
    def __init__(self, sequences: list[str]):
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> str:
        return self.sequences[idx]


class PreTrainedESMplusplusModel(PreTrainedModel):
    """
    Base class for ESM++ models: handles weight initialization plus shared pooling and embedding utilities.
    """
    config_class = ESMplusplusConfig
    base_model_prefix = "esm++"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:  # several LayerNorms here are created with bias=False
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    @classmethod
    def from_pretrained_esm(cls, model_name: str):
        """Load a pretrained ESM++ model."""
        if '300' in model_name:
            return ESMplusplus_300M()
        elif '600' in model_name:
            return ESMplusplus_600M()
        else:
            raise ValueError(f"Invalid model name: {model_name}")

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply mean pooling to sequence outputs."""
        if attention_mask is None:
            return x.mean(dim=1)
        else:
            attention_mask = attention_mask.unsqueeze(-1)
            return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def max_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply max pooling to sequence outputs."""
        if attention_mask is None:
            return x.max(dim=1).values
        else:
            attention_mask = attention_mask.unsqueeze(-1)
            return (x * attention_mask).max(dim=1).values

    def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply cls pooling to sequence outputs."""
        return x[:, 0, :]

    def _collate_fn(self, sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate function for batching sequences."""
        return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read sequences from SQLite database."""
        import sqlite3
        sequences = []
        with sqlite3.connect(db_path) as conn:
            c = conn.cursor()
            c.execute("SELECT sequence FROM embeddings")
            while True:
                row = c.fetchone()
                if row is None:
                    break
                sequences.append(row[0])
        return set(sequences)

    def embed_dataset(
        self,
        sequences: list[str],
        batch_size: int = 2,
        max_len: int = 512,
        full_embeddings: bool = False,
        full_precision: bool = False,
        pooling_type: str = 'mean',
        num_workers: int = 0,
        sql: bool = False,
        sql_db_path: str = 'embeddings.db',
    ) -> Optional[dict[str, torch.Tensor]]:
        """Embed a dataset of protein sequences.

        Args:
            sequences: List of protein sequences
            batch_size: Batch size for processing
            max_len: Maximum sequence length
            full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
            full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
            pooling_type: Type of pooling ('mean', 'max', or 'cls')
            num_workers: Number of workers for data loading, 0 for the main process
            sql: Whether to store embeddings in SQLite database - will be stored in float32
            sql_db_path: Path to SQLite database

        Returns:
            Dictionary mapping sequences to embeddings, or None if sql=True
        """
        sequences = list(set([seq[:max_len] for seq in sequences]))
        sequences = sorted(sequences, key=len, reverse=True)
        dataset = ProteinDataset(sequences)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn)
        device = self.device

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            if full_embeddings:
                return residue_embeddings
            elif pooling_type == 'mean':
                return self.mean_pooling(residue_embeddings, attention_mask)
            elif pooling_type == 'max':
                return self.max_pooling(residue_embeddings, attention_mask)
            elif pooling_type == 'cls':
                return self.cls_pooling(residue_embeddings, attention_mask)
            else:
                raise ValueError(f"Invalid pooling type: {pooling_type}")

        if sql:
            import sqlite3
            conn = sqlite3.connect(sql_db_path)
            c = conn.cursor()
            c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
            already_embedded = self._read_sequences_from_db(sql_db_path)
            to_embed = [seq for seq in sequences if seq not in already_embedded]
            print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
            print(f"Embedding {len(to_embed)} new sequences")
            if len(to_embed) > 0:
                with torch.no_grad():
                    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                        seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                        input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                        x = self.embed(input_ids)
                        residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float()  # required for sql
                        embeddings = get_embeddings(residue_embeddings, attention_mask)

                        for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                            if full_embeddings:
                                emb = emb[mask.bool()]
                            c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                                      (seq, emb.cpu().numpy().tobytes()))

                        if (i + 1) % 100 == 0:
                            conn.commit()

            conn.commit()
            conn.close()
            return None

        embeddings_dict = {}
        with torch.no_grad():
            for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                seqs = sequences[i * batch_size:(i + 1) * batch_size]
                input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                x = self.embed(input_ids)
                residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach()
                if full_precision:
                    residue_embeddings = residue_embeddings.float()
                embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
                for seq, emb in zip(seqs, embeddings):
                    embeddings_dict[seq] = emb

        return embeddings_dict


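# Usage sketch (illustrative, not part of the uploaded module); `model` is any ESM++ model
# exposing embed_dataset, e.g. ESMplusplusForMaskedLM:
#
#   sequences = ["MKTAYIAKQR", "MSEQNNTEMTFQIQRIYTKDIS"]  # placeholder sequences
#   embeddings = model.embed_dataset(sequences, batch_size=2, pooling_type='mean')
#   # returns a dict mapping each (deduplicated, truncated) sequence to a pooled tensor;
#   # with sql=True the vectors are written to sql_db_path instead and None is returned.
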
### ESM++ Models
class ESMplusplusModel(PreTrainedESMplusplusModel):
    """
    ESM++ base model: the transformer encoder with no task heads.
    """
    config_class = ESMplusplusConfig
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout)
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> TransformerOutput:
        """Forward pass for the bare transformer model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
        """
        if inputs_embeds is None:
            x = self.embed(input_ids)
        else:
            x = inputs_embeds
        return self.transformer(x, attention_mask, output_hidden_states, output_attentions)


class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
    """
    ESM++ model for masked language modeling.
    Implements the base ESM++ architecture with a masked language modeling head.
    """
    config_class = ESMplusplusConfig
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout)
        self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
        self.ce_loss = nn.CrossEntropyLoss()
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        return self.sequence_head[-1]

    def set_output_embeddings(self, new_embeddings):
        self.sequence_head[-1] = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> ESMplusplusOutput:
        """Forward pass for masked language modeling.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for masked tokens
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            ESMplusplusOutput containing loss, logits, hidden states and attention weights
        """
        if inputs_embeds is None:
            x = self.embed(input_ids)
        else:
            x = inputs_embeds
        output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
        x = output.last_hidden_state
        logits = self.sequence_head(x)
        loss = None
        if labels is not None:
            loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )


class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
    """
    ESM++ model for sequence classification.
    Extends the base ESM++ model with a classification head.
    """
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
        # Large intermediate projections help with sequence classification tasks (*4)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> ESMplusplusOutput:
        """Forward pass for sequence classification.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for classification
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
        cls_features = x[:, 0, :]
        mean_features = self.mean_pooling(x, attention_mask)
        # we include mean pooling features to help with early convergence, the cost of this is basically zero
        features = torch.cat([cls_features, mean_features], dim=-1)
        logits = self.classifier(features)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.flatten(), labels.flatten())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
        )


class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
    """
    ESM++ model for token classification.
    Extends the base ESM++ model with a token classification head.
    """
    def __init__(self, config: ESMplusplusConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
        # Large intermediate projections help with sequence classification tasks (*4)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> ESMplusplusOutput:
        """Forward pass for token classification.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for token classification
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
        )


### Loading from EvolutionaryScale
@cache
def data_root(model: str):
    if "INFRA_PROVIDER" in os.environ:
        return Path("")
    # Try to download from huggingface if it doesn't exist
    if model.startswith("esmc-300"):
        path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
    elif model.startswith("esmc-600"):
        path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
    else:
        raise ValueError(f"{model=} is an invalid model name.")
    return path


def ESMplusplus_300M(device: torch.device | str = "cpu"):
    with torch.device(device):
        config = ESMplusplusConfig(
            hidden_size=960,
            num_attention_heads=15,
            num_hidden_layers=30,
        )
        model = ESMplusplusForMaskedLM(config)
        state_dict = torch.load(
            data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
            map_location=device,
        )
        model.load_state_dict(state_dict)
        return model


def ESMplusplus_600M(device: torch.device | str = "cpu"):
    with torch.device(device):
        config = ESMplusplusConfig(
            hidden_size=1152,
            num_attention_heads=18,
            num_hidden_layers=36,
        )
        model = ESMplusplusForMaskedLM(config)
        state_dict = torch.load(
            data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
            map_location=device,
        )
        model.load_state_dict(state_dict)
        return model


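# Usage sketch (illustrative): the two loaders above build an ESMplusplusForMaskedLM and load the
# ESMC weights fetched by data_root via snapshot_download, so the first call needs network access
# to the EvolutionaryScale repos:
#
#   model = ESMplusplus_300M(device="cpu")
#   model = PreTrainedESMplusplusModel.from_pretrained_esm("esmc-600")
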
### Tokenization
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
    "O", ".", "-", "|",
    "<mask>",
]

class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
 
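For reference, a minimal usage sketch of the tokenizer and model defined above (assumes the module is importable and the ESMC weights can be downloaded; the commented values follow from the code above, they were not measured here):

    tokenizer = EsmSequenceTokenizer()
    enc = tokenizer("MKT")
    # enc["input_ids"] == [0, 20, 15, 11, 2]  -> <cls> M K T <eos> under SEQUENCE_VOCAB
    model = ESMplusplus_300M()
    out = model(**tokenizer("MKT", return_tensors="pt"))
    # out.logits: shape (1, 5, 64); out.last_hidden_state: shape (1, 5, 960)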
1
+ """
2
+ ESM++ model implementation.
3
+
4
+ ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility
5
+ The ESM Python package is not required
6
+
7
+ Modified from https://github.com/evolutionaryscale/esm
8
+ License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from dataclasses import dataclass
17
+ from functools import cache, partial
18
+ from pathlib import Path
19
+ from typing import Optional, Tuple, Union
20
+ from einops import rearrange, repeat
21
+ from huggingface_hub import snapshot_download
22
+ from tokenizers import Tokenizer
23
+ from tokenizers.models import BPE
24
+ from tokenizers.processors import TemplateProcessing
25
+ from torch.utils.data import Dataset, DataLoader
26
+ from tqdm.auto import tqdm
27
+ from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
28
+ from transformers.modeling_outputs import ModelOutput
29
+
30
+
31
+ class ESMplusplusConfig(PretrainedConfig):
32
+ """Configuration class for ESM++ model.
33
+
34
+ Args:
35
+ vocab_size: Size of the vocabulary
36
+ hidden_size: Dimension of hidden layers
37
+ num_attention_heads: Number of attention heads
38
+ num_hidden_layers: Number of transformer layers
39
+ num_labels: Number of output labels for classification
40
+ problem_type: Type of problem - regression, single/multi label classification
41
+ """
42
+ model_type = "ESMplusplus"
43
+ def __init__(
44
+ self,
45
+ vocab_size: int = 64,
46
+ hidden_size: int = 960,
47
+ num_attention_heads: int = 15,
48
+ num_hidden_layers: int = 30,
49
+ num_labels: int = 2,
50
+ problem_type: str | None = None,
51
+ dropout: float = 0.0,
52
+ initializer_range: float = 0.02,
53
+ **kwargs,
54
+ ):
55
+ super().__init__(**kwargs)
56
+ self.vocab_size = vocab_size
57
+ self.hidden_size = hidden_size
58
+ self.num_attention_heads = num_attention_heads
59
+ self.num_hidden_layers = num_hidden_layers
60
+ self.num_labels = num_labels
61
+ self.problem_type = problem_type
62
+ self.dropout = dropout
63
+ self.initializer_range = initializer_range
64
+
65
+
66
+ ### Rotary Embeddings
67
+ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
68
+ """Rotates half the hidden dims of the input."""
69
+ if not interleaved:
70
+ x1, x2 = x.chunk(2, dim=-1)
71
+ return torch.cat((-x2, x1), dim=-1)
72
+ else:
73
+ x1, x2 = x[..., ::2], x[..., 1::2]
74
+ return rearrange(
75
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
76
+ )
77
+
78
+
79
+ def apply_rotary_emb_torch(
80
+ x: torch.Tensor,
81
+ cos: torch.Tensor,
82
+ sin: torch.Tensor,
83
+ interleaved: bool = False,
84
+ _inplace: bool = False,
85
+ ) -> torch.Tensor:
86
+ """Apply rotary embeddings to input based on cos and sin."""
87
+ ro_dim = cos.shape[-1] * 2
88
+ assert ro_dim <= x.shape[-1]
89
+ seqlen = x.size(1)
90
+ cos = cos[:seqlen]
91
+ sin = sin[:seqlen]
92
+ cos = repeat(cos, "s d -> s 1 (2 d)")
93
+ sin = repeat(sin, "s d -> s 1 (2 d)")
94
+ return torch.cat(
95
+ [
96
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
97
+ x[..., ro_dim:],
98
+ ],
99
+ dim=-1,
100
+ )
101
+
102
+
103
+ class RotaryEmbedding(torch.nn.Module):
104
+ """Rotary position embeddings.
105
+
106
+ Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"
107
+
108
+ Args:
109
+ dim: Dimension of the embedding
110
+ base: Base for computing angular frequencies
111
+ interleaved: Whether to use interleaved rotations
112
+ scale_base: Base for scaling
113
+ scaling_factor: Factor for scaling positions
114
+ pos_idx_in_fp32: Whether to compute position indices in fp32
115
+ device: Computation device
116
+ """
117
+ def __init__(
118
+ self,
119
+ dim: int,
120
+ base: float = 10000.0,
121
+ interleaved: bool = False,
122
+ scale_base: Optional[float] = None,
123
+ scaling_factor: float = 1.0,
124
+ pos_idx_in_fp32: bool = True,
125
+ device: Optional[torch.device] = None,
126
+ ):
127
+ super().__init__()
128
+ self.dim = dim
129
+ self.base = float(base)
130
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
131
+ self.interleaved = interleaved
132
+ self.scale_base = scale_base
133
+ self.scaling_factor = scaling_factor
134
+ self.device = device
135
+
136
+ self._seq_len_cached = 0
137
+ self._cos_cached = None
138
+ self._sin_cached = None
139
+ self._cos_k_cached = None
140
+ self._sin_k_cached = None
141
+ self.reset_parameters()
142
+
143
+ def reset_parameters(self):
144
+ """Reset the parameters of the embedding."""
145
+ inv_freq = self._compute_inv_freq(self.device)
146
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
147
+ arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
148
+ scale = (
149
+ (arange + 0.4 * self.dim) / (1.4 * self.dim)
150
+ if self.scale_base is not None
151
+ else None
152
+ )
153
+ self.register_buffer("scale", scale)
154
+
155
+ def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
156
+ """Compute inverse frequency bands."""
157
+ return 1 / (
158
+ self.base
159
+ ** (
160
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
161
+ / self.dim
162
+ )
163
+ )
164
+
165
+ def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
166
+ """Update the cached cosine and sine values."""
167
+ if (
168
+ seqlen > self._seq_len_cached
169
+ or self._cos_cached is None
170
+ or self._cos_cached.device != device
171
+ or self._cos_cached.dtype != dtype
172
+ or (self.training and self._cos_cached.is_inference())
173
+ ):
174
+ self._seq_len_cached = seqlen
175
+ if self.pos_idx_in_fp32:
176
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
177
+ t /= self.scaling_factor
178
+ if self.inv_freq.dtype != torch.float32:
179
+ inv_freq = self.inv_freq.to(torch.float32)
180
+ else:
181
+ inv_freq = self.inv_freq
182
+ else:
183
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
184
+ t /= self.scaling_factor
185
+ inv_freq = self.inv_freq
186
+ freqs = torch.outer(t, inv_freq)
187
+
188
+ if self.scale is None:
189
+ self._cos_cached = torch.cos(freqs).to(dtype)
190
+ self._sin_cached = torch.sin(freqs).to(dtype)
191
+ else:
192
+ power = (
193
+ torch.arange(
194
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
195
+ )
196
+ - seqlen // 2
197
+ ) / self.scale_base
198
+ scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
199
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
200
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
201
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
202
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
203
+
204
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
205
+ """Apply rotary embeddings to queries and keys.
206
+
207
+ Args:
208
+ q: Query tensor of shape (batch, seqlen, nheads, headdim)
209
+ k: Key tensor of shape (batch, seqlen, nheads, headdim)
210
+
211
+ Returns:
212
+ Tuple of rotated query and key tensors
213
+ """
214
+ self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
215
+ assert self._cos_cached is not None
216
+ assert self._sin_cached is not None
217
+ if self.scale is None:
218
+ return (
219
+ apply_rotary_emb_torch(
220
+ q,
221
+ self._cos_cached,
222
+ self._sin_cached,
223
+ self.interleaved,
224
+ True, # inplace=True
225
+ ),
226
+ apply_rotary_emb_torch(
227
+ k,
228
+ self._cos_cached,
229
+ self._sin_cached,
230
+ self.interleaved,
231
+ True, # inplace=True
232
+ ),
233
+ ) # type: ignore
234
+ else:
235
+ assert False
236
+
237
+
238
+ ### Feedforward Network Components
239
+ def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
240
+ """Compute corrected dimension for SwiGLU."""
241
+ return int(((expansion_ratio * d_model) + 255) // 256 * 256)
242
+
243
+
244
+ class SwiGLU(nn.Module):
245
+ """SwiGLU activation function."""
246
+ def __init__(self):
247
+ super(SwiGLU, self).__init__()
248
+
249
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
250
+ x1, x2 = x.chunk(2, dim=-1)
251
+ return F.silu(x1) * x2
252
+
253
+
254
+ def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
255
+ """Create SwiGLU feedforward network with layer normalization."""
256
+ return nn.Sequential(
257
+ nn.LayerNorm(d_model),
258
+ nn.Linear(
259
+ d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
260
+ ),
261
+ SwiGLU(),
262
+ nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
263
+ )
264
+
265
+
266
+ ### Attention
267
+ class MultiHeadAttention(nn.Module):
268
+ """Multi-head attention with rotary embeddings.
269
+
270
+ Args:
271
+ d_model: Model dimension
272
+ n_heads: Number of attention heads
273
+ """
274
+ def __init__(self, d_model: int, n_heads: int):
275
+ super().__init__()
276
+ self.d_model = d_model
277
+ self.n_heads = n_heads
278
+ self.d_head = self.d_model // self.n_heads
279
+ self.layernorm_qkv = nn.Sequential(
280
+ nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
281
+ )
282
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
283
+ self.q_ln = nn.LayerNorm(d_model, bias=False)
284
+ self.k_ln = nn.LayerNorm(d_model, bias=False)
285
+ self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
286
+ self.rotary = RotaryEmbedding(d_model // n_heads)
287
+
288
+ def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
289
+ """Apply rotary embeddings to query and key."""
290
+ q = q.unflatten(-1, (self.n_heads, self.d_head))
291
+ k = k.unflatten(-1, (self.n_heads, self.d_head))
292
+ q, k = self.rotary(q, k)
293
+ q = q.flatten(-2, -1)
294
+ k = k.flatten(-2, -1)
295
+ return q, k
296
+
297
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
298
+ """
299
+ Args:
300
+ x: Input tensor
301
+ attention_mask: Optional attention mask
302
+ output_attentions: Whether to return attention weights
303
+
304
+ Returns:
305
+ Output tensor after self attention, and optionally attention weights
306
+ """
307
+ attn_weights = None
308
+ qkv_BLD3 = self.layernorm_qkv(x)
309
+ query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
310
+ query_BLD, key_BLD = (
311
+ self.q_ln(query_BLD).to(query_BLD.dtype),
312
+ self.k_ln(key_BLD).to(query_BLD.dtype),
313
+ )
314
+ query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
315
+ query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
316
+
317
+ if output_attentions: # Manual attention computation
318
+ L, S = query_BLD.size(-2), key_BLD.size(-2)
319
+ scale = 1 / math.sqrt(query_BLD.size(-1))
320
+ attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
321
+ if attention_mask is not None:
322
+ if attention_mask.dtype == torch.bool:
323
+ attention_mask.masked_fill_(attention_mask.logical_not(), float('-inf'))
324
+ else:
325
+ attn_bias += attention_mask
326
+
327
+ attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
328
+ attn_weights += attn_bias
329
+ attn_weights = F.softmax(attn_weights, dim=-1)
330
+ context_BHLD = torch.matmul(attn_weights, value_BHLD)
331
+ else:
332
+ context_BHLD = F.scaled_dot_product_attention(
333
+ query_BHLD, key_BHLD, value_BHLD, attention_mask
334
+ )
335
+
336
+ context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
337
+ output = self.out_proj(context_BLD)
338
+ return output, attn_weights
339
+
340
+
341
+ ### Regression Head
342
+ def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
343
+ """Create a regression head with optional hidden dimension.
344
+
345
+ Args:
346
+ d_model: Input dimension
347
+ output_dim: Output dimension
348
+ hidden_dim: Optional hidden dimension (defaults to d_model)
349
+ """
350
+ hidden_dim = hidden_dim if hidden_dim is not None else d_model
351
+ return nn.Sequential(
352
+ nn.Linear(d_model, hidden_dim),
353
+ nn.GELU(),
354
+ nn.LayerNorm(hidden_dim),
355
+ nn.Linear(hidden_dim, output_dim),
356
+ )
357
+
358
+
359
+ ### Transformer Block
360
+ class UnifiedTransformerBlock(nn.Module):
361
+ """Transformer block with attention and feedforward layers.
362
+
363
+ Args:
364
+ d_model: Model dimension
365
+ n_heads: Number of attention heads
366
+ residue_scaling_factor: Factor for scaling residual connections
367
+ expansion_ratio: Expansion ratio for feedforward network
368
+ """
369
+ def __init__(
370
+ self,
371
+ d_model: int,
372
+ n_heads: int,
373
+ residue_scaling_factor: float = 1,
374
+ expansion_ratio: float = 8 / 3,
375
+ dropout: float = 0.0,
376
+ ):
377
+ super().__init__()
378
+ self.attn = MultiHeadAttention(d_model, n_heads)
379
+ self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
380
+ self.scaling_factor = residue_scaling_factor
381
+ self.dropout = nn.Dropout(dropout)
382
+
383
+ def forward(
384
+ self,
385
+ x: torch.Tensor,
386
+ attention_mask: Optional[torch.Tensor] = None,
387
+ output_attentions: bool = False,
388
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
389
+ """
390
+ Args:
391
+ x: Input tensor
392
+ attention_mask: Optional attention mask
393
+ output_attentions: Whether to return attention weights
394
+
395
+ Returns:
396
+ Output tensor after transformer block, and optionally attention weights
397
+ """
398
+ attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
399
+ x = x + self.dropout(attn_output) / self.scaling_factor
400
+ x = x + self.dropout(self.ffn(x)) / self.scaling_factor
401
+ return x, attn_weights
402
+
403
+
404
+ ### Model Outputs
405
+ @dataclass
406
+ class TransformerOutput(ModelOutput):
407
+ """Output type for transformer encoder."""
408
+ last_hidden_state: Optional[torch.Tensor] = None
409
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
410
+ attentions: Optional[Tuple[torch.Tensor]] = None
411
+
412
+
413
+ @dataclass
414
+ class ESMplusplusOutput(ModelOutput):
415
+ """Output type for ESM++ models."""
416
+ loss: Optional[torch.Tensor] = None
417
+ logits: Optional[torch.Tensor] = None
418
+ last_hidden_state: Optional[torch.Tensor] = None
419
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
420
+ attentions: Optional[Tuple[torch.Tensor]] = None
421
+
422
+
423
+ ### Transformer Stack
424
+ class TransformerStack(nn.Module):
425
+ """Stack of transformer blocks.
426
+
427
+ Args:
428
+ d_model: Model dimension
429
+ n_heads: Number of attention heads
430
+ n_layers: Number of transformer layers
431
+ dropout: Dropout rate
432
+ """
433
+ def __init__(
434
+ self,
435
+ d_model: int,
436
+ n_heads: int,
437
+ n_layers: int,
438
+ dropout: float = 0.0,
439
+ ):
440
+ super().__init__()
441
+ self.blocks = nn.ModuleList(
442
+ [
443
+ UnifiedTransformerBlock(
444
+ d_model,
445
+ n_heads,
446
+ residue_scaling_factor=math.sqrt(n_layers / 36),
447
+ dropout=dropout,
448
+ )
449
+ for i in range(n_layers)
450
+ ]
451
+ )
452
+ self.norm = nn.LayerNorm(d_model, bias=False)
453
+ self.gradient_checkpointing = False
454
+
455
+ def forward(
456
+ self,
457
+ x: torch.Tensor,
458
+ attention_mask: Optional[torch.Tensor] = None,
459
+ output_hidden_states: bool = False,
460
+ output_attentions: bool = False,
461
+ ) -> TransformerOutput:
462
+ """
463
+ Args:
464
+ x: Input tensor
465
+ attention_mask: Optional attention mask
466
+ output_hidden_states: Whether to return all hidden states
467
+ output_attentions: Whether to return attention weights
468
+
469
+ Returns:
470
+ TransformerOutput containing last hidden state and optionally all hidden states and attention weights
471
+ """
472
+ batch_size, seq_len, _ = x.shape
473
+ hidden_states = () if output_hidden_states else None
474
+ attentions = () if output_attentions else None
475
+
476
+ if attention_mask is not None:
477
+ attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
478
+
479
+ for block in self.blocks:
480
+ if self.gradient_checkpointing and self.training:
481
+ x, attn_weights = self._gradient_checkpointing_func(
482
+ block.__call__,
483
+ x,
484
+ attention_mask,
485
+ output_attentions,
486
+ )
487
+ else:
488
+ x, attn_weights = block(x, attention_mask, output_attentions)
489
+
490
+ if attentions is not None:
491
+ attentions += (attn_weights,)
492
+
493
+ if output_hidden_states:
494
+ assert hidden_states is not None
495
+ hidden_states += (x,)
496
+
497
+ return TransformerOutput(
498
+ last_hidden_state=self.norm(x),
499
+ hidden_states=hidden_states,
500
+ attentions=attentions
501
+ )
502
+
503
+
504
+ ### Dataset for Embedding
505
+ class ProteinDataset(Dataset):
506
+ """Simple dataset for protein sequences."""
507
+ def __init__(self, sequences: list[str]):
508
+ self.sequences = sequences
509
+
510
+ def __len__(self) -> int:
511
+ return len(self.sequences)
512
+
513
+ def __getitem__(self, idx: int) -> str:
514
+ return self.sequences[idx]
515
+
516
+
517
+ class PreTrainedESMplusplusModel(PreTrainedModel):
518
+ """
519
+ init weights for ESM++ models
520
+ """
521
+ config_class = ESMplusplusConfig
522
+ base_model_prefix = "esm++"
523
+ supports_gradient_checkpointing = True
524
+
525
+ def _init_weights(self, module):
526
+ """Initialize the weights"""
527
+ if isinstance(module, nn.Linear):
528
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
529
+ if module.bias is not None:
530
+ module.bias.data.zero_()
531
+ elif isinstance(module, nn.Embedding):
532
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
533
+ if module.padding_idx is not None:
534
+ module.weight.data[module.padding_idx].zero_()
535
+ elif isinstance(module, nn.LayerNorm):
536
+ module.bias.data.zero_()
537
+ module.weight.data.fill_(1.0)
538
+
539
+ @classmethod
540
+ def from_pretrained_esm(cls, model_name: str):
541
+ """Load a pretrained ESM++ model."""
542
+ if '300' in model_name:
543
+ return ESMplusplus_300M()
544
+ elif '600' in model_name:
545
+ return ESMplusplus_600M()
546
+ else:
547
+ raise ValueError(f"Invalid model name: {model_name}")
548
+
549
+ @property
550
+ def device(self) -> torch.device:
551
+ """Get the device of the model."""
552
+ return next(self.parameters()).device
553
+
554
+ def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
555
+ """Apply mean pooling to sequence outputs."""
556
+ if attention_mask is None:
557
+ return x.mean(dim=1)
558
+ else:
559
+ attention_mask = attention_mask.unsqueeze(-1)
560
+ return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
561
+
562
+ def max_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
563
+ """Apply max pooling to sequence outputs."""
564
+ if attention_mask is None:
565
+ return x.max(dim=1).values
566
+ else:
+ # Zeroing out padded positions would bias the max toward 0 whenever a feature dimension
+ # is negative for every real token, so mask padding with -inf before reducing.
+ attention_mask = attention_mask.unsqueeze(-1).bool()
+ return x.masked_fill(~attention_mask, float('-inf')).max(dim=1).values
569
+
570
+ def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
571
+ """Apply cls pooling to sequence outputs."""
572
+ return x[:, 0, :]
573
+
574
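+ # Shape reference for the pooling helpers above: x is (batch, seq_len, hidden_size),
+ # attention_mask is (batch, seq_len) with 1 for real tokens and 0 for padding, and each
+ # helper returns a (batch, hidden_size) tensor.
+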
+ def _collate_fn(self, sequences: list[str]) -> dict[str, torch.Tensor]:
575
+ """Collate function for batching sequences."""
576
+ return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
577
+
578
+ def _read_sequences_from_db(self, db_path: str) -> set[str]:
579
+ """Read sequences from SQLite database."""
580
+ import sqlite3
581
+ sequences = []
582
+ with sqlite3.connect(db_path) as conn:
583
+ c = conn.cursor()
584
+ c.execute("SELECT sequence FROM embeddings")
585
+ while True:
586
+ row = c.fetchone()
587
+ if row is None:
588
+ break
589
+ sequences.append(row[0])
590
+ return set(sequences)
591
+
592
+ def embed_dataset(
593
+ self,
594
+ sequences: list[str],
595
+ batch_size: int = 2,
596
+ max_len: int = 512,
597
+ full_embeddings: bool = False,
598
+ full_precision: bool = False,
599
+ pooling_type: str = 'mean',
600
+ num_workers: int = 0,
601
+ sql: bool = False,
602
+ sql_db_path: str = 'embeddings.db',
603
+ ) -> Optional[dict[str, torch.Tensor]]:
604
+ """Embed a dataset of protein sequences.
605
+
606
+ Args:
607
+ sequences: List of protein sequences
608
+ batch_size: Batch size for processing
609
+ max_len: Maximum sequence length
610
+ full_embeddings: Whether to return per-residue embeddings (True) or pooled per-sequence embeddings (False)
611
+ full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
612
+ pooling_type: Type of pooling ('mean', 'max', or 'cls')
613
+ num_workers: Number of workers for data loading, 0 for the main process
614
+ sql: Whether to store embeddings in SQLite database - will be stored in float32
615
+ sql_db_path: Path to SQLite database
616
+
617
+ Returns:
618
+ Dictionary mapping sequences to embeddings, or None if sql=True
619
+ """
620
+ sequences = list(set([seq[:max_len] for seq in sequences]))
621
+ sequences = sorted(sequences, key=len, reverse=True)
622
+ dataset = ProteinDataset(sequences)
623
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn)
624
+ device = self.device
625
+
626
+ def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
627
+ if full_embeddings:
628
+ return residue_embeddings
629
+ elif pooling_type == 'mean':
630
+ return self.mean_pooling(residue_embeddings, attention_mask)
631
+ elif pooling_type == 'max':
632
+ return self.max_pooling(residue_embeddings, attention_mask)
633
+ elif pooling_type == 'cls':
634
+ return self.cls_pooling(residue_embeddings, attention_mask)
635
+ else:
636
+ raise ValueError(f"Invalid pooling type: {pooling_type}")
637
+
638
+ if sql:
639
+ import sqlite3
640
+ conn = sqlite3.connect(sql_db_path)
641
+ c = conn.cursor()
642
+ c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
643
+ already_embedded = self._read_sequences_from_db(sql_db_path)
644
+ to_embed = [seq for seq in sequences if seq not in already_embedded]
645
+ print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
646
+ print(f"Embedding {len(to_embed)} new sequences")
647
+ if len(to_embed) > 0:
+ # Rebuild the dataloader over only the new sequences so that each batch lines up with
+ # the corresponding slice of `to_embed` taken below.
+ dataset = ProteinDataset(to_embed)
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn)
+ with torch.no_grad():
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
650
+ seqs = to_embed[i * batch_size:(i + 1) * batch_size]
651
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
652
+ x = self.embed(input_ids)
653
+ residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
654
+ embeddings = get_embeddings(residue_embeddings, attention_mask)
655
+
656
+ for seq, emb, mask in zip(seqs, embeddings, attention_mask):
657
+ if full_embeddings:
658
+ emb = emb[mask.bool()]
659
+ c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
660
+ (seq, emb.cpu().numpy().tobytes()))
661
+
662
+ if (i + 1) % 100 == 0:
663
+ conn.commit()
664
+
665
+ conn.commit()
666
+ conn.close()
667
+ return None
668
+
669
+ embeddings_dict = {}
670
+ with torch.no_grad():
671
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
672
+ seqs = sequences[i * batch_size:(i + 1) * batch_size]
673
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
674
+ x = self.embed(input_ids)
675
+ residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach()
676
+ if full_precision:
677
+ residue_embeddings = residue_embeddings.float()
678
+ embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
679
+ for seq, emb in zip(seqs, embeddings):
680
+ embeddings_dict[seq] = emb
681
+
682
+ return embeddings_dict
683
+
684
+
685
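+ # Example (sketch): the embed_dataset helper with pooled outputs. A randomly initialised
+ # ESMplusplusForMaskedLM (defined below) is used purely to illustrate the API; in practice
+ # load pretrained weights first. With full_embeddings=False the result is a dict mapping
+ # each sequence to a (hidden_size,) tensor.
+ def _example_embed_dataset():
+     model = ESMplusplusForMaskedLM(ESMplusplusConfig())
+     embeddings = model.embed_dataset(
+         sequences=["MKTAYIAKQR", "GSHMSEQWENCE"],
+         batch_size=2,
+         full_embeddings=False,
+         pooling_type='mean',
+     )
+     return {seq: emb.shape for seq, emb in embeddings.items()}
+
+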
+ ### ESM++ Models
686
+ class ESMplusplusModel(PreTrainedESMplusplusModel):
687
+ """
688
+ ESM++ base model: the bare transformer encoder without any task-specific heads.
689
+ """
690
+ config_class = ESMplusplusConfig
691
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
692
+ super().__init__(config, **kwargs)
693
+ self.config = config
694
+ self.vocab_size = config.vocab_size
695
+ self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
696
+ self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout)
697
+ self.tokenizer = EsmSequenceTokenizer()
698
+ self.init_weights()
699
+
700
+ def get_input_embeddings(self):
701
+ return self.embed
702
+
703
+ def set_input_embeddings(self, value):
704
+ self.embed = value
705
+
706
+ def forward(
707
+ self,
708
+ input_ids: Optional[torch.Tensor] = None,
709
+ attention_mask: Optional[torch.Tensor] = None,
710
+ inputs_embeds: Optional[torch.Tensor] = None,
711
+ output_attentions: Optional[bool] = None,
712
+ output_hidden_states: Optional[bool] = None,
713
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
714
+ ) -> TransformerOutput:
715
+ """Forward pass for masked language modeling.
716
+
717
+ Args:
718
+ input_ids: Input token IDs
719
+ attention_mask: Attention mask
720
+ inputs_embeds: Optional precomputed embeddings
721
+ output_hidden_states: Whether to return all hidden states
722
+ output_attentions: Whether to return attention weights
723
+
724
+ Returns:
725
+ TransformerOutput containing last hidden state and optionally all hidden states and attention weights
726
+ """
727
+ if inputs_embeds is None:
728
+ x = self.embed(input_ids)
729
+ else:
730
+ x = inputs_embeds
731
+ return self.transformer(x, attention_mask, output_hidden_states, output_attentions)
732
+
733
+
734
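+ # Example (sketch): extracting per-residue hidden states from the bare encoder. The default
+ # config matches the 300M architecture; the weights here are random, so this only demonstrates
+ # the API and output shapes.
+ def _example_hidden_states():
+     model = ESMplusplusModel(ESMplusplusConfig())
+     batch = model.tokenizer(["MKTAYIAKQR"], return_tensors="pt")
+     with torch.no_grad():
+         out = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], output_hidden_states=True)
+     # last_hidden_state: (1, seq_len + 2, hidden_size) including <cls> and <eos>
+     return out.last_hidden_state.shape, len(out.hidden_states)
+
+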
+ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
735
+ """
736
+ ESM++ model for masked language modeling.
737
+ Implements the base ESM++ architecture with a masked language modeling head.
738
+ """
739
+ config_class = ESMplusplusConfig
740
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
741
+ super().__init__(config, **kwargs)
742
+ self.config = config
743
+ self.vocab_size = config.vocab_size
744
+ self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
745
+ self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout)
746
+ self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
747
+ self.ce_loss = nn.CrossEntropyLoss()
748
+ self.tokenizer = EsmSequenceTokenizer()
749
+ self.init_weights()
750
+
751
+ def get_input_embeddings(self):
752
+ return self.embed
753
+
754
+ def set_input_embeddings(self, value):
755
+ self.embed = value
756
+
757
+ def get_output_embeddings(self):
758
+ return self.sequence_head[-1]
759
+
760
+ def set_output_embeddings(self, new_embeddings):
761
+ self.sequence_head[-1] = new_embeddings
762
+
763
+ def forward(
764
+ self,
765
+ input_ids: Optional[torch.Tensor] = None,
766
+ attention_mask: Optional[torch.Tensor] = None,
767
+ inputs_embeds: Optional[torch.Tensor] = None,
768
+ labels: Optional[torch.Tensor] = None,
769
+ output_attentions: Optional[bool] = None,
770
+ output_hidden_states: Optional[bool] = None,
771
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
772
+ ) -> ESMplusplusOutput:
773
+ """Forward pass for masked language modeling.
774
+
775
+ Args:
776
+ input_ids: Input token IDs
777
+ attention_mask: Attention mask
778
+ inputs_embeds: Optional precomputed embeddings
779
+ labels: Optional labels for masked tokens
780
+ output_hidden_states: Whether to return all hidden states
781
+ output_attentions: Whether to return attention weights
782
+
783
+ Returns:
784
+ ESMplusplusOutput containing loss, logits, hidden states and attention weights
785
+ """
786
+ if inputs_embeds is None:
787
+ x = self.embed(input_ids)
788
+ else:
789
+ x = inputs_embeds
790
+ output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
791
+ x = output.last_hidden_state
792
+ logits = self.sequence_head(x)
793
+ loss = None
794
+ if labels is not None:
795
+ loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
796
+ return ESMplusplusOutput(
797
+ loss=loss,
798
+ logits=logits,
799
+ last_hidden_state=x,
800
+ hidden_states=output.hidden_states,
801
+ attentions=output.attentions,
802
+ )
803
+
804
+
805
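+ # Example (sketch): masked-token scoring with the MLM head. Positions labelled -100 are
+ # ignored by nn.CrossEntropyLoss (its default ignore_index), so only the single masked
+ # position contributes to the loss. Random weights; API demonstration only.
+ def _example_masked_lm():
+     model = ESMplusplusForMaskedLM(ESMplusplusConfig())
+     tok = model.tokenizer
+     batch = tok(["MKTAYIAKQR"], return_tensors="pt")
+     input_ids = batch['input_ids'].clone()
+     labels = torch.full_like(input_ids, -100)
+     labels[0, 3] = input_ids[0, 3]        # keep the true residue as the target
+     input_ids[0, 3] = tok.mask_token_id   # mask it in the input
+     out = model(input_ids=input_ids, attention_mask=batch['attention_mask'], labels=labels)
+     return out.loss, out.logits.shape     # logits: (1, seq_len, vocab_size)
+
+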
+ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
806
+ """
807
+ ESM++ model for sequence classification.
808
+ Extends the base ESM++ model with a classification head.
809
+ """
810
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
811
+ super().__init__(config, **kwargs)
812
+ self.config = config
813
+ self.num_labels = config.num_labels
814
+ self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
815
+ # Large intermediate projections help with sequence classification tasks (*4)
816
+ self.mse = nn.MSELoss()
817
+ self.ce = nn.CrossEntropyLoss()
818
+ self.bce = nn.BCEWithLogitsLoss()
819
+ self.init_weights()
820
+
821
+ def forward(
822
+ self,
823
+ input_ids: Optional[torch.Tensor] = None,
824
+ attention_mask: Optional[torch.Tensor] = None,
825
+ inputs_embeds: Optional[torch.Tensor] = None,
826
+ labels: Optional[torch.Tensor] = None,
827
+ output_attentions: Optional[bool] = None,
828
+ output_hidden_states: Optional[bool] = None,
829
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
830
+ ) -> ESMplusplusOutput:
831
+ """Forward pass for sequence classification.
832
+
833
+ Args:
834
+ input_ids: Input token IDs
835
+ attention_mask: Attention mask
836
+ inputs_embeds: Optional precomputed embeddings
837
+ labels: Optional labels for classification
838
+ output_hidden_states: Whether to return all hidden states
839
+ output_attentions: Whether to return attention weights
840
+
841
+ Returns:
842
+ ESMplusplusOutput containing loss, logits, and hidden states
843
+ """
844
+ output = super().forward(
845
+ input_ids=input_ids,
846
+ attention_mask=attention_mask,
847
+ inputs_embeds=inputs_embeds,
848
+ labels=None,
849
+ output_attentions=output_attentions,
850
+ output_hidden_states=output_hidden_states
851
+ )
852
+ x = output.last_hidden_state
853
+ cls_features = x[:, 0, :]
854
+ mean_features = self.mean_pooling(x, attention_mask)
855
+ # Mean-pooled features are concatenated with the CLS representation to help early convergence; the extra cost is negligible.
856
+ features = torch.cat([cls_features, mean_features], dim=-1)
857
+ logits = self.classifier(features)
858
+ loss = None
859
+ if labels is not None:
860
+ labels = labels.to(logits.device)
861
+ if self.config.problem_type is None:
862
+ if self.num_labels == 1:
863
+ self.config.problem_type = "regression"
864
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
865
+ self.config.problem_type = "single_label_classification"
866
+ else:
867
+ self.config.problem_type = "multi_label_classification"
868
+
869
+ if self.config.problem_type == "regression":
870
+ if self.num_labels == 1:
871
+ loss = self.mse(logits.flatten(), labels.flatten())
872
+ else:
873
+ loss = self.mse(logits, labels)
874
+ elif self.config.problem_type == "single_label_classification":
875
+ loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
876
+ elif self.config.problem_type == "multi_label_classification":
877
+ loss = self.bce(logits, labels)
878
+ return ESMplusplusOutput(
879
+ loss=loss,
880
+ logits=logits,
881
+ last_hidden_state=x,
882
+ hidden_states=output.hidden_states,
883
+ )
884
+
885
+
886
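+ # Example (sketch): sequence-level classification. Integer labels with num_labels > 1 route the
+ # loss to the single-label (cross-entropy) branch; float labels of shape (batch, num_labels)
+ # would select the multi-label (BCE) branch instead. Random weights; API demonstration only.
+ def _example_sequence_classification():
+     model = ESMplusplusForSequenceClassification(ESMplusplusConfig(num_labels=3))
+     batch = model.tokenizer(["MKTAYIAKQR", "GSHMSEQWENCE"], return_tensors="pt", padding=True)
+     labels = torch.tensor([0, 2])         # dtype long -> single_label_classification
+     out = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=labels)
+     return out.loss, out.logits.shape     # logits: (2, 3)
+
+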
+ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
887
+ """
888
+ ESM++ model for token classification.
889
+ Extends the base ESM++ model with a token classification head.
890
+ """
891
+ def __init__(self, config: ESMplusplusConfig):
892
+ super().__init__(config)
893
+ self.config = config
894
+ self.num_labels = config.num_labels
895
+ self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
896
+ # Large intermediate projection (*4), mirroring the sequence classification head
897
+ self.loss_fct = nn.CrossEntropyLoss()
898
+ self.init_weights()
899
+
900
+ def forward(
901
+ self,
902
+ input_ids: Optional[torch.Tensor] = None,
903
+ attention_mask: Optional[torch.Tensor] = None,
904
+ inputs_embeds: Optional[torch.Tensor] = None,
905
+ labels: Optional[torch.Tensor] = None,
906
+ output_attentions: Optional[bool] = None,
907
+ output_hidden_states: Optional[bool] = None,
908
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
909
+ ) -> ESMplusplusOutput:
910
+ """Forward pass for token classification.
911
+
912
+ Args:
913
+ input_ids: Input token IDs
914
+ attention_mask: Attention mask
915
+ inputs_embeds: Optional precomputed embeddings
916
+ labels: Optional labels for token classification
917
+ output_hidden_states: Whether to return all hidden states
918
+ output_attentions: Whether to return attention weights
919
+
920
+ Returns:
921
+ ESMplusplusOutput containing loss, logits, and hidden states
922
+ """
923
+ output = super().forward(
924
+ input_ids=input_ids,
925
+ attention_mask=attention_mask,
926
+ inputs_embeds=inputs_embeds,
927
+ labels=None,
928
+ output_attentions=output_attentions,
929
+ output_hidden_states=output_hidden_states
930
+ )
931
+ x = output.last_hidden_state
932
+ logits = self.classifier(x)
933
+ loss = None
934
+ if labels is not None:
935
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
936
+ return ESMplusplusOutput(
937
+ loss=loss,
938
+ logits=logits,
939
+ last_hidden_state=x,
940
+ hidden_states=output.hidden_states,
941
+ )
942
+
943
+
944
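+ # Example (sketch): per-residue labelling. The <cls> and <eos> positions are set to -100 so the
+ # cross-entropy loss ignores them; padding positions would be handled the same way.
+ # Random weights; API demonstration only.
+ def _example_token_classification():
+     model = ESMplusplusForTokenClassification(ESMplusplusConfig(num_labels=2))
+     batch = model.tokenizer(["MKTAYIAKQR"], return_tensors="pt")
+     labels = torch.zeros_like(batch['input_ids'])
+     labels[0, 0] = -100                   # <cls>
+     labels[0, -1] = -100                  # <eos>
+     out = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=labels)
+     return out.loss, out.logits.shape     # logits: (1, seq_len, num_labels)
+
+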
+ ### Loading from EvolutionaryScale
945
+ @cache
947
+ def data_root(model: str):
948
+ if "INFRA_PROVIDER" in os.environ:
949
+ return Path("")
950
+ # Try to download from Hugging Face if the weights are not already present locally
951
+ if model.startswith("esmc-300"):
952
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
953
+ elif model.startswith("esmc-600"):
954
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
955
+ else:
956
+ raise ValueError(f"{model=} is an invalid model name.")
957
+ return path
958
+
959
+
960
+ def ESMplusplus_300M(device: torch.device | str = "cpu"):
961
+ with torch.device(device):
962
+ config = ESMplusplusConfig(
963
+ hidden_size=960,
964
+ num_attention_heads=15,
965
+ num_hidden_layers=30,
966
+ )
967
+ model = ESMplusplusForMaskedLM(config)
968
+ state_dict = torch.load(
969
+ data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
970
+ map_location=device,
971
+ )
972
+ model.load_state_dict(state_dict)
973
+ return model
974
+
975
+
976
+ def ESMplusplus_600M(device: torch.device | str = "cpu"):
977
+ with torch.device(device):
978
+ config = ESMplusplusConfig(
979
+ hidden_size=1152,
980
+ num_attention_heads=18,
981
+ num_hidden_layers=36,
982
+ )
983
+ model = ESMplusplusForMaskedLM(config)
984
+ state_dict = torch.load(
985
+ data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
986
+ map_location=device,
987
+ )
988
+ model.load_state_dict(state_dict)
989
+ return model
990
+
991
+
992
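+ # Example (sketch): building ESM++ with the original ESMC weights. ESMplusplus_300M downloads
+ # the EvolutionaryScale checkpoint via snapshot_download, so Hugging Face Hub access to that
+ # (non-commercially licensed) repository is required.
+ def _example_load_esmc_weights():
+     model = ESMplusplus_300M(device="cpu")
+     batch = model.tokenizer(["MKTAYIAKQR"], return_tensors="pt")
+     with torch.no_grad():
+         out = model(**batch)
+     return out.logits.shape               # (1, seq_len, vocab_size)
+
+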
+ ### Tokenization
993
+ SEQUENCE_VOCAB = [
994
+ "<cls>", "<pad>", "<eos>", "<unk>",
995
+ "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
996
+ "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
997
+ "O", ".", "-", "|",
998
+ "<mask>",
999
+ ]
1000
+
1001
+ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
1002
+ model_input_names = ["input_ids", "attention_mask"]
1003
+
1004
+ def __init__(
1005
+ self,
1006
+ unk_token="<unk>",
1007
+ cls_token="<cls>",
1008
+ pad_token="<pad>",
1009
+ mask_token="<mask>",
1010
+ eos_token="<eos>",
1011
+ chain_break_token="|",
1012
+ **kwargs,
1013
+ ):
1014
+ all_tokens = SEQUENCE_VOCAB
1015
+ token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
1016
+
1017
+ # a character-level tokenizer is the same as BPE with no token merges
1018
+ bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
1019
+ tokenizer = Tokenizer(bpe)
1020
+ special_tokens = [
1021
+ cls_token,
1022
+ pad_token,
1023
+ mask_token,
1024
+ eos_token,
1025
+ chain_break_token,
1026
+ ]
1027
+ self.cb_token = chain_break_token
1028
+ additional_special_tokens = [chain_break_token]
1029
+
1030
+ tokenizer.add_special_tokens(special_tokens)
1031
+
1032
+ # This is where we configure the automatic addition of special tokens when we call
1033
+ # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
1034
+ # sequences are merged if you want.
1035
+ tokenizer.post_processor = TemplateProcessing( # type: ignore
1036
+ single="<cls> $A <eos>",
1037
+ special_tokens=[
1038
+ ("<cls>", tokenizer.token_to_id("<cls>")),
1039
+ ("<eos>", tokenizer.token_to_id("<eos>")),
1040
+ ],
1041
+ )
1042
+ super().__init__(
1043
+ tokenizer_object=tokenizer,
1044
+ unk_token=unk_token,
1045
+ cls_token=cls_token,
1046
+ pad_token=pad_token,
1047
+ mask_token=mask_token,
1048
+ eos_token=eos_token,
1049
+ additional_special_tokens=additional_special_tokens,
1050
+ **kwargs,
1051
+ )
1052
+
1053
+ # These are a footgun: the `bos` token is never used anywhere, so it is simply aliased to the `cls` token here.
1054
+ @property
1055
+ def bos_token(self):
1056
+ return self.cls_token
1057
+
1058
+ @property
1059
+ def bos_token_id(self):
1060
+ return self.cls_token_id
1061
+
1062
+ @property
1063
+ def chain_break_token(self):
1064
+ return self.cb_token
1065
+
1066
+ @property
1067
+ def chain_break_token_id(self):
1068
+ return self.convert_tokens_to_ids(self.chain_break_token)
1069
+
1070
+ @property
1071
+ def all_token_ids(self):
1072
+ return list(range(self.vocab_size))
1073
+
1074
+ @property
1075
+ def special_token_ids(self):
1076
+ return self.all_special_ids
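+
+
+ # Example (sketch): tokenising a single chain. <cls> and <eos> are added automatically by the
+ # template post-processor, and "|" is the reserved chain-break token for multi-chain inputs.
+ def _example_tokenizer():
+     tok = EsmSequenceTokenizer()
+     enc = tok("MKTAYIAKQR")
+     # enc['input_ids'] starts with tok.cls_token_id and ends with tok.eos_token_id
+     return tok.convert_ids_to_tokens(enc['input_ids'])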