lhallee committed (verified)
Commit 76bc2e4 · 1 Parent(s): 4000b5d

Update modeling_esm_plusplus.py

Files changed (1)
  1. modeling_esm_plusplus.py +635 -634
modeling_esm_plusplus.py CHANGED
@@ -1,634 +1,635 @@
1
- ### Modified from https://github.com/evolutionaryscale/esm
2
- ### License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- import math
7
- from dataclasses import dataclass
8
- from transformers import PreTrainedModel, PretrainedConfig
9
- from einops import rearrange, repeat
10
- from functools import partial
11
- from typing import Optional, Tuple
12
- from transformers.modeling_outputs import ModelOutput
13
-
14
-
15
- class ESMplusplusConfig(PretrainedConfig):
16
- model_type = "ESMplusplus"
17
- def __init__(
18
- self,
19
- vocab_size: int = 64,
20
- hidden_size: int = 960,
21
- num_attention_heads: int = 15,
22
- num_hidden_layers: int = 30,
23
- num_labels: int = 2,
24
- problem_type: str | None = None,
25
- **kwargs,
26
- ):
27
- super().__init__(**kwargs)
28
- self.vocab_size = vocab_size
29
- self.hidden_size = hidden_size
30
- self.num_attention_heads = num_attention_heads
31
- self.num_hidden_layers = num_hidden_layers
32
- self.num_labels = num_labels
33
- self.problem_type = problem_type
34
-
35
-
36
- ### Rotary
37
- # https://github.com/evolutionaryscale/esm/blob/main/esm/layers/rotary.py
38
- # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
39
- # Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
40
- def rotate_half(x, interleaved=False):
41
- if not interleaved:
42
- x1, x2 = x.chunk(2, dim=-1)
43
- return torch.cat((-x2, x1), dim=-1)
44
- else:
45
- x1, x2 = x[..., ::2], x[..., 1::2]
46
- return rearrange(
47
- torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
48
- )
49
-
50
-
51
- def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
52
- """
53
- x: (batch_size, seqlen, nheads, headdim)
54
- cos, sin: (seqlen, rotary_dim / 2)
55
- """
56
- ro_dim = cos.shape[-1] * 2
57
- assert ro_dim <= x.shape[-1]
58
- seqlen = x.size(1)
59
- cos = cos[:seqlen]
60
- sin = sin[:seqlen]
61
- cos = repeat(cos, "s d -> s 1 (2 d)")
62
- sin = repeat(sin, "s d -> s 1 (2 d)")
63
- return torch.cat(
64
- [
65
- x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
66
- x[..., ro_dim:],
67
- ],
68
- dim=-1,
69
- )
70
-
71
-
72
- class RotaryEmbedding(torch.nn.Module):
73
- def __init__(
74
- self,
75
- dim: int,
76
- base=10000.0,
77
- interleaved=False,
78
- scale_base=None,
79
- scaling_factor=1.0,
80
- pos_idx_in_fp32=True,
81
- device=None,
82
- ):
83
- super().__init__()
84
- self.dim = dim
85
- self.base = float(base)
86
- self.pos_idx_in_fp32 = pos_idx_in_fp32
87
- # Generate and save the inverse frequency buffer (non trainable)
88
- self.interleaved = interleaved
89
- self.scale_base = scale_base
90
- self.scaling_factor = scaling_factor
91
- self.device = device
92
-
93
- self._seq_len_cached = 0
94
- self._cos_cached = None
95
- self._sin_cached = None
96
- self._cos_k_cached = None
97
- self._sin_k_cached = None
98
- self.reset_parameters()
99
-
100
- def reset_parameters(self):
101
- inv_freq = self._compute_inv_freq(self.device)
102
- self.register_buffer("inv_freq", inv_freq, persistent=False)
103
- arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
104
- scale = (
105
- (arange + 0.4 * self.dim) / (1.4 * self.dim)
106
- if self.scale_base is not None
107
- else None
108
- )
109
- self.register_buffer("scale", scale)
110
-
111
- def _compute_inv_freq(self, device=None):
112
- return 1 / (
113
- self.base
114
- ** (
115
- torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
116
- / self.dim
117
- )
118
- )
119
-
120
- def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
121
- if (
122
- seqlen > self._seq_len_cached
123
- or self._cos_cached is None
124
- or self._cos_cached.device != device
125
- or self._cos_cached.dtype != dtype
126
- or (self.training and self._cos_cached.is_inference())
127
- ):
128
- self._seq_len_cached = seqlen
129
- if self.pos_idx_in_fp32:
130
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
131
- t /= self.scaling_factor
132
- if self.inv_freq.dtype != torch.float32:
133
- inv_freq = self.inv_freq.to(torch.float32)
134
- else:
135
- inv_freq = self.inv_freq
136
- else:
137
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
138
- t /= self.scaling_factor
139
- inv_freq = self.inv_freq
140
- freqs = torch.outer(t, inv_freq)
141
-
142
- if self.scale is None:
143
- self._cos_cached = torch.cos(freqs).to(dtype)
144
- self._sin_cached = torch.sin(freqs).to(dtype)
145
- else:
146
- power = (
147
- torch.arange(
148
- seqlen, dtype=self.scale.dtype, device=self.scale.device
149
- )
150
- - seqlen // 2
151
- ) / self.scale_base
152
- scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
153
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
154
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
155
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
156
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
157
-
158
- def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
159
- """
160
- q: (batch, seqlen, nheads, headdim)
161
- k: (batch, seqlen, nheads, headdim)
162
- """
163
- self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
164
- assert self._cos_cached is not None
165
- assert self._sin_cached is not None
166
- if self.scale is None:
167
- return (
168
- apply_rotary_emb_torch(
169
- q,
170
- self._cos_cached,
171
- self._sin_cached,
172
- self.interleaved,
173
- True, # inplace=True
174
- ),
175
- apply_rotary_emb_torch(
176
- k,
177
- self._cos_cached,
178
- self._sin_cached,
179
- self.interleaved,
180
- True, # inplace=True
181
- ),
182
- ) # type: ignore
183
- else:
184
- assert False
185
-
186
-
187
- ### Feedforward
188
- def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
189
- return int(((expansion_ratio * d_model) + 255) // 256 * 256)
190
-
191
-
192
- class SwiGLU(nn.Module):
193
- def __init__(self):
194
- super(SwiGLU, self).__init__()
195
-
196
- def forward(self, x: torch.Tensor) -> torch.Tensor:
197
- x1, x2 = x.chunk(2, dim=-1)
198
- return F.silu(x1) * x2
199
-
200
-
201
- def swiglu_ln_ffn(d_model: int, expansion_ratio: float):
202
- return nn.Sequential(
203
- nn.LayerNorm(d_model),
204
- nn.Linear(
205
- d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
206
- ),
207
- SwiGLU(),
208
- nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
209
- )
210
-
211
-
212
- ### Attention
213
- class MultiHeadAttention(nn.Module):
214
- def __init__(self, d_model: int, n_heads: int):
215
- super().__init__()
216
- self.d_model = d_model
217
- self.n_heads = n_heads
218
- self.d_head = self.d_model // self.n_heads
219
- self.layernorm_qkv = nn.Sequential(
220
- nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
221
- )
222
- self.out_proj = nn.Linear(d_model, d_model, bias=False)
223
- self.q_ln = nn.LayerNorm(d_model, bias=False)
224
- self.k_ln = nn.LayerNorm(d_model, bias=False)
225
- self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
226
- self.rotary = RotaryEmbedding(d_model // n_heads)
227
-
228
- def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
229
- q = q.unflatten(-1, (self.n_heads, self.d_head))
230
- k = k.unflatten(-1, (self.n_heads, self.d_head))
231
- q, k = self.rotary(q, k)
232
- q = q.flatten(-2, -1)
233
- k = k.flatten(-2, -1)
234
- return q, k
235
-
236
- def forward(self, x, attention_mask=None):
237
- qkv_BLD3 = self.layernorm_qkv(x)
238
- query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
239
- query_BLD, key_BLD = (
240
- self.q_ln(query_BLD).to(query_BLD.dtype),
241
- self.k_ln(key_BLD).to(query_BLD.dtype),
242
- )
243
- query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
244
- query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
245
- context_BHLD = F.scaled_dot_product_attention(
246
- query_BHLD, key_BHLD, value_BHLD, attention_mask
247
- )
248
- context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
249
- return self.out_proj(context_BLD)
250
-
251
-
252
- ### LM Head
253
- def RegressionHead(
254
- d_model: int, output_dim: int, hidden_dim: int | None = None
255
- ) -> nn.Module:
256
- hidden_dim = hidden_dim if hidden_dim is not None else d_model
257
- return nn.Sequential(
258
- nn.Linear(d_model, hidden_dim),
259
- nn.GELU(),
260
- nn.LayerNorm(hidden_dim),
261
- nn.Linear(hidden_dim, output_dim),
262
- )
263
-
264
-
265
- ### Transformer Block
266
- class UnifiedTransformerBlock(nn.Module):
267
- def __init__(
268
- self,
269
- d_model: int,
270
- n_heads: int,
271
- residue_scaling_factor: float = 1,
272
- expansion_ratio: float = 8 / 3,
273
- ):
274
- super().__init__()
275
- self.attn = MultiHeadAttention(d_model, n_heads)
276
- self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
277
- self.scaling_factor = residue_scaling_factor
278
-
279
- def forward(
280
- self,
281
- x: torch.Tensor,
282
- attention_mask: Optional[torch.Tensor] = None,
283
- ) -> torch.Tensor:
284
- r1 = self.attn(x, attention_mask)
285
- x = x + r1 / self.scaling_factor
286
- r3 = self.ffn(x) / self.scaling_factor
287
- x = x + r3
288
- return x
289
-
290
-
291
- ### Outputs
292
- @dataclass
293
- class TransformerOutput(ModelOutput):
294
- last_hidden_state: torch.Tensor | None = None
295
- hidden_states: tuple[torch.Tensor] | None = None
296
-
297
-
298
- @dataclass
299
- class ESMplusplusOutput(ModelOutput):
300
- loss: torch.Tensor | None = None
301
- logits: torch.Tensor | None = None
302
- last_hidden_state: torch.Tensor | None = None
303
- hidden_states: tuple[torch.Tensor] | None = None
304
-
305
-
306
- ### Transformer
307
- class TransformerStack(nn.Module):
308
- def __init__(
309
- self,
310
- d_model: int,
311
- n_heads: int,
312
- n_layers: int,
313
- ):
314
- super().__init__()
315
- self.blocks = nn.ModuleList(
316
- [
317
- UnifiedTransformerBlock(
318
- d_model,
319
- n_heads,
320
- residue_scaling_factor=math.sqrt(n_layers / 36),
321
- )
322
- for i in range(n_layers)
323
- ]
324
- )
325
- self.norm = nn.LayerNorm(d_model, bias=False)
326
-
327
- def forward(
328
- self,
329
- x: torch.Tensor,
330
- attention_mask: Optional[torch.Tensor] = None,
331
- output_hidden_states: bool = False,
332
- ) -> TransformerOutput:
333
- batch_size, seq_len, _ = x.shape
334
- hidden_states = ()
335
- if attention_mask is not None:
336
- attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
337
- for block in self.blocks:
338
- x = block(x, attention_mask)
339
- if output_hidden_states:
340
- hidden_states += (x,)
341
- return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)
342
-
343
-
344
- ### Full model
345
- class ESMplusplusForMaskedLM(PreTrainedModel):
346
- """
347
- ESM++ for masked language modeling.
348
- """
349
- def __init__(self, config: ESMplusplusConfig):
350
- super().__init__(config)
351
- self.config = config
352
- self.vocab_size = config.vocab_size
353
- self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
354
- self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
355
- self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
356
- self.ce_loss = nn.CrossEntropyLoss()
357
- self.tokenizer = EsmSequenceTokenizer()
358
-
359
- @classmethod
360
- def from_pretrained_esm(cls, model_name: str):
361
- if '300' in model_name:
362
- return ESMplusplus_300M()
363
- elif '600' in model_name:
364
- return ESMplusplus_600M()
365
- else:
366
- raise ValueError(f"Invalid model name: {model_name}")
367
-
368
- @property
369
- def device(self):
370
- return next(self.parameters()).device
371
-
372
- def forward(
373
- self,
374
- input_ids: torch.Tensor | None = None,
375
- attention_mask: Optional[torch.Tensor] = None,
376
- labels: Optional[torch.Tensor] = None,
377
- output_hidden_states: bool = False,
378
- ) -> ESMplusplusOutput:
379
- x = self.embed(input_ids)
380
- output = self.transformer(x, attention_mask, output_hidden_states)
381
- x = output.last_hidden_state
382
- logits = self.sequence_head(x)
383
- loss = None
384
- if labels is not None:
385
- loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
386
- return ESMplusplusOutput(
387
- loss=loss,
388
- logits=logits,
389
- last_hidden_state=x,
390
- hidden_states=output.hidden_states,
391
- )
392
-
393
-
394
- class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
395
- """
396
- ESM++ for sequence classification.
397
- """
398
- def __init__(self, config: ESMplusplusConfig):
399
- super().__init__(config)
400
- self.config = config
401
- self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
402
- # we find that large intermediate projections help with sequence classification tasks (*4)
403
- self.mse = nn.MSELoss()
404
- self.ce = nn.CrossEntropyLoss()
405
- self.bce = nn.BCEWithLogitsLoss()
406
-
407
- def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
408
- # x: (batch_size, seq_len, hidden_size)
409
- # attention_mask: (batch_size, seq_len)
410
- if attention_mask is None:
411
- return x.mean(dim=1)
412
- else:
413
- return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
414
-
415
- def forward(
416
- self,
417
- input_ids: torch.Tensor | None = None,
418
- attention_mask: Optional[torch.Tensor] = None,
419
- labels: Optional[torch.Tensor] = None,
420
- output_hidden_states: bool = False,
421
- ) -> ESMplusplusOutput:
422
- output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
423
- x = output.last_hidden_state
424
- cls_features = x[:, 0, :]
425
- mean_features = self.mean_pooling(x, attention_mask)
426
- # we include mean pooling features to help with early convergence, the cost of this is basically zero
427
- features = torch.cat([cls_features, mean_features], dim=-1)
428
- logits = self.classifier(features)
429
- loss = None
430
- if labels is not None:
431
- labels = labels.to(logits.device)
432
- if self.config.problem_type is None:
433
- if self.num_labels == 1:
434
- self.config.problem_type = "regression"
435
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
436
- self.config.problem_type = "single_label_classification"
437
- else:
438
- self.config.problem_type = "multi_label_classification"
439
-
440
- if self.config.problem_type == "regression":
441
- if self.num_labels == 1:
442
- loss = self.mse(logits.squeeze(), labels.squeeze())
443
- else:
444
- loss = self.mse(logits, labels)
445
- elif self.config.problem_type == "single_label_classification":
446
- loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
447
- elif self.config.problem_type == "multi_label_classification":
448
- loss = self.bce(logits, labels)
449
- return ESMplusplusOutput(
450
- loss=loss,
451
- logits=logits,
452
- last_hidden_state=x,
453
- hidden_states=output.hidden_states,
454
- )
455
-
456
-
457
- class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
458
- """
459
- ESM++ for token classification.
460
- """
461
- def __init__(self, config: ESMplusplusConfig):
462
- super().__init__(config)
463
- self.config = config
464
- self.num_labels = config.num_labels
465
- self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
466
- # we find that large intermediate projections help with sequence classification tasks (*4)
467
- self.loss_fct = nn.CrossEntropyLoss()
468
-
469
- def forward(
470
- self,
471
- input_ids: torch.Tensor | None = None,
472
- attention_mask: Optional[torch.Tensor] = None,
473
- labels: Optional[torch.Tensor] = None,
474
- output_hidden_states: bool = False,
475
- ) -> ESMplusplusOutput:
476
- output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
477
- x = output.last_hidden_state
478
- logits = self.classifier(x)
479
- loss = None
480
- if labels is not None:
481
- loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
482
- return ESMplusplusOutput(
483
- loss=loss,
484
- logits=logits,
485
- last_hidden_state=x,
486
- hidden_states=output.hidden_states,
487
- )
488
-
489
-
490
- ### Loading
491
- import os
492
- from functools import cache
493
- from pathlib import Path
494
- from huggingface_hub import snapshot_download
495
-
496
-
497
- @staticmethod
498
- @cache
499
- def data_root(model: str):
500
- if "INFRA_PROVIDER" in os.environ:
501
- return Path("")
502
- # Try to download from hugginface if it doesn't exist
503
- if model.startswith("esmc-300"):
504
- path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
505
- elif model.startswith("esmc-600"):
506
- path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
507
- else:
508
- raise ValueError(f"{model=} is an invalid model name.")
509
- return path
510
-
511
-
512
- def ESMplusplus_300M(device: torch.device | str = "cpu"):
513
- with torch.device(device):
514
- config = ESMplusplusConfig(
515
- hidden_size=960,
516
- num_attention_heads=15,
517
- num_hidden_layers=30,
518
- )
519
- model = ESMplusplusForMaskedLM(config)
520
- state_dict = torch.load(
521
- data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
522
- map_location=device,
523
- )
524
- model.load_state_dict(state_dict)
525
- return model
526
-
527
-
528
- def ESMplusplus_600M(device: torch.device | str = "cpu"):
529
- with torch.device(device):
530
- config = ESMplusplusConfig(
531
- hidden_size=1152,
532
- num_attention_heads=18,
533
- num_hidden_layers=36,
534
- )
535
- model = ESMplusplusForMaskedLM(config)
536
- state_dict = torch.load(
537
- data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
538
- map_location=device,
539
- )
540
- model.load_state_dict(state_dict)
541
- return model
542
-
543
-
544
- ### Tokenization
545
- from tokenizers import Tokenizer
546
- from tokenizers.models import BPE
547
- from tokenizers.processors import TemplateProcessing
548
- from transformers import PreTrainedTokenizerFast
549
-
550
-
551
- SEQUENCE_VOCAB = [
552
- "<cls>", "<pad>", "<eos>", "<unk>",
553
- "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
554
- "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
555
- "O", ".", "-", "|",
556
- "<mask>",
557
- ]
558
-
559
- class EsmSequenceTokenizer(PreTrainedTokenizerFast):
560
- model_input_names = ["input_ids", "attention_mask"]
561
-
562
- def __init__(
563
- self,
564
- unk_token="<unk>",
565
- cls_token="<cls>",
566
- pad_token="<pad>",
567
- mask_token="<mask>",
568
- eos_token="<eos>",
569
- chain_break_token="|",
570
- **kwargs,
571
- ):
572
- all_tokens = SEQUENCE_VOCAB
573
- token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
574
-
575
- # a character-level tokenizer is the same as BPE with no token merges
576
- bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
577
- tokenizer = Tokenizer(bpe)
578
- special_tokens = [
579
- cls_token,
580
- pad_token,
581
- mask_token,
582
- eos_token,
583
- chain_break_token,
584
- ]
585
- self.cb_token = chain_break_token
586
- additional_special_tokens = [chain_break_token]
587
-
588
- tokenizer.add_special_tokens(special_tokens)
589
-
590
- # This is where we configure the automatic addition of special tokens when we call
591
- # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
592
- # sequences are merged if you want.
593
- tokenizer.post_processor = TemplateProcessing( # type: ignore
594
- single="<cls> $A <eos>",
595
- special_tokens=[
596
- ("<cls>", tokenizer.token_to_id("<cls>")),
597
- ("<eos>", tokenizer.token_to_id("<eos>")),
598
- ],
599
- )
600
- super().__init__(
601
- tokenizer_object=tokenizer,
602
- unk_token=unk_token,
603
- cls_token=cls_token,
604
- pad_token=pad_token,
605
- mask_token=mask_token,
606
- eos_token=eos_token,
607
- additional_special_tokens=additional_special_tokens,
608
- **kwargs,
609
- )
610
-
611
- # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
612
- @property
613
- def bos_token(self):
614
- return self.cls_token
615
-
616
- @property
617
- def bos_token_id(self):
618
- return self.cls_token_id
619
-
620
- @property
621
- def chain_break_token(self):
622
- return self.cb_token
623
-
624
- @property
625
- def chain_break_token_id(self):
626
- return self.convert_tokens_to_ids(self.chain_break_token)
627
-
628
- @property
629
- def all_token_ids(self):
630
- return list(range(self.vocab_size))
631
-
632
- @property
633
- def special_token_ids(self):
634
- return self.all_special_ids
 
 
1
+ ### Modified from https://github.com/evolutionaryscale/esm
2
+ ### License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from dataclasses import dataclass
8
+ from transformers import PreTrainedModel, PretrainedConfig
9
+ from einops import rearrange, repeat
10
+ from functools import partial
11
+ from typing import Optional, Tuple
12
+ from transformers.modeling_outputs import ModelOutput
13
+
14
+
15
+ class ESMplusplusConfig(PretrainedConfig):
16
+ model_type = "ESMplusplus"
17
+ def __init__(
18
+ self,
19
+ vocab_size: int = 64,
20
+ hidden_size: int = 960,
21
+ num_attention_heads: int = 15,
22
+ num_hidden_layers: int = 30,
23
+ num_labels: int = 2,
24
+ problem_type: str | None = None,
25
+ **kwargs,
26
+ ):
27
+ super().__init__(**kwargs)
28
+ self.vocab_size = vocab_size
29
+ self.hidden_size = hidden_size
30
+ self.num_attention_heads = num_attention_heads
31
+ self.num_hidden_layers = num_hidden_layers
32
+ self.num_labels = num_labels
33
+ self.problem_type = problem_type
34
+
35
+
36
+ ### Rotary
37
+ # https://github.com/evolutionaryscale/esm/blob/main/esm/layers/rotary.py
38
+ # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
39
+ # Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
40
+ def rotate_half(x, interleaved=False):
41
+ if not interleaved:
42
+ x1, x2 = x.chunk(2, dim=-1)
43
+ return torch.cat((-x2, x1), dim=-1)
44
+ else:
45
+ x1, x2 = x[..., ::2], x[..., 1::2]
46
+ return rearrange(
47
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
48
+ )
49
+
50
+
51
+ def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
52
+ """
53
+ x: (batch_size, seqlen, nheads, headdim)
54
+ cos, sin: (seqlen, rotary_dim / 2)
55
+ """
56
+ ro_dim = cos.shape[-1] * 2
57
+ assert ro_dim <= x.shape[-1]
58
+ seqlen = x.size(1)
59
+ cos = cos[:seqlen]
60
+ sin = sin[:seqlen]
61
+ cos = repeat(cos, "s d -> s 1 (2 d)")
62
+ sin = repeat(sin, "s d -> s 1 (2 d)")
63
+ return torch.cat(
64
+ [
65
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
66
+ x[..., ro_dim:],
67
+ ],
68
+ dim=-1,
69
+ )
70
+
71
+
72
+ class RotaryEmbedding(torch.nn.Module):
73
+ def __init__(
74
+ self,
75
+ dim: int,
76
+ base=10000.0,
77
+ interleaved=False,
78
+ scale_base=None,
79
+ scaling_factor=1.0,
80
+ pos_idx_in_fp32=True,
81
+ device=None,
82
+ ):
83
+ super().__init__()
84
+ self.dim = dim
85
+ self.base = float(base)
86
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
87
+ # Generate and save the inverse frequency buffer (non trainable)
88
+ self.interleaved = interleaved
89
+ self.scale_base = scale_base
90
+ self.scaling_factor = scaling_factor
91
+ self.device = device
92
+
93
+ self._seq_len_cached = 0
94
+ self._cos_cached = None
95
+ self._sin_cached = None
96
+ self._cos_k_cached = None
97
+ self._sin_k_cached = None
98
+ self.reset_parameters()
99
+
100
+ def reset_parameters(self):
101
+ inv_freq = self._compute_inv_freq(self.device)
102
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
103
+ arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
104
+ scale = (
105
+ (arange + 0.4 * self.dim) / (1.4 * self.dim)
106
+ if self.scale_base is not None
107
+ else None
108
+ )
109
+ self.register_buffer("scale", scale)
110
+
111
+ def _compute_inv_freq(self, device=None):
112
+ return 1 / (
113
+ self.base
114
+ ** (
115
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
116
+ / self.dim
117
+ )
118
+ )
119
+
120
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
121
+ if (
122
+ seqlen > self._seq_len_cached
123
+ or self._cos_cached is None
124
+ or self._cos_cached.device != device
125
+ or self._cos_cached.dtype != dtype
126
+ or (self.training and self._cos_cached.is_inference())
127
+ ):
128
+ self._seq_len_cached = seqlen
129
+ if self.pos_idx_in_fp32:
130
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
131
+ t /= self.scaling_factor
132
+ if self.inv_freq.dtype != torch.float32:
133
+ inv_freq = self.inv_freq.to(torch.float32)
134
+ else:
135
+ inv_freq = self.inv_freq
136
+ else:
137
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
138
+ t /= self.scaling_factor
139
+ inv_freq = self.inv_freq
140
+ freqs = torch.outer(t, inv_freq)
141
+
142
+ if self.scale is None:
143
+ self._cos_cached = torch.cos(freqs).to(dtype)
144
+ self._sin_cached = torch.sin(freqs).to(dtype)
145
+ else:
146
+ power = (
147
+ torch.arange(
148
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
149
+ )
150
+ - seqlen // 2
151
+ ) / self.scale_base
152
+ scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
153
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
154
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
155
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
156
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
157
+
158
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
159
+ """
160
+ q: (batch, seqlen, nheads, headdim)
161
+ k: (batch, seqlen, nheads, headdim)
162
+ """
163
+ self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
164
+ assert self._cos_cached is not None
165
+ assert self._sin_cached is not None
166
+ if self.scale is None:
167
+ return (
168
+ apply_rotary_emb_torch(
169
+ q,
170
+ self._cos_cached,
171
+ self._sin_cached,
172
+ self.interleaved,
173
+ True, # inplace=True
174
+ ),
175
+ apply_rotary_emb_torch(
176
+ k,
177
+ self._cos_cached,
178
+ self._sin_cached,
179
+ self.interleaved,
180
+ True, # inplace=True
181
+ ),
182
+ ) # type: ignore
183
+ else:
184
+ assert False
185
+
186
+
187
+ ### Feedforward
188
+ def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
189
+ return int(((expansion_ratio * d_model) + 255) // 256 * 256)
190
+
191
+
192
+ class SwiGLU(nn.Module):
193
+ def __init__(self):
194
+ super(SwiGLU, self).__init__()
195
+
196
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
197
+ x1, x2 = x.chunk(2, dim=-1)
198
+ return F.silu(x1) * x2
199
+
200
+
201
+ def swiglu_ln_ffn(d_model: int, expansion_ratio: float):
202
+ return nn.Sequential(
203
+ nn.LayerNorm(d_model),
204
+ nn.Linear(
205
+ d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
206
+ ),
207
+ SwiGLU(),
208
+ nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
209
+ )
210
+
211
+
212
+ ### Attention
213
+ class MultiHeadAttention(nn.Module):
214
+ def __init__(self, d_model: int, n_heads: int):
215
+ super().__init__()
216
+ self.d_model = d_model
217
+ self.n_heads = n_heads
218
+ self.d_head = self.d_model // self.n_heads
219
+ self.layernorm_qkv = nn.Sequential(
220
+ nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
221
+ )
222
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
223
+ self.q_ln = nn.LayerNorm(d_model, bias=False)
224
+ self.k_ln = nn.LayerNorm(d_model, bias=False)
225
+ self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
226
+ self.rotary = RotaryEmbedding(d_model // n_heads)
227
+
228
+ def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
229
+ q = q.unflatten(-1, (self.n_heads, self.d_head))
230
+ k = k.unflatten(-1, (self.n_heads, self.d_head))
231
+ q, k = self.rotary(q, k)
232
+ q = q.flatten(-2, -1)
233
+ k = k.flatten(-2, -1)
234
+ return q, k
235
+
236
+ def forward(self, x, attention_mask=None):
237
+ qkv_BLD3 = self.layernorm_qkv(x)
238
+ query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
239
+ query_BLD, key_BLD = (
240
+ self.q_ln(query_BLD).to(query_BLD.dtype),
241
+ self.k_ln(key_BLD).to(query_BLD.dtype),
242
+ )
243
+ query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
244
+ query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
245
+ context_BHLD = F.scaled_dot_product_attention(
246
+ query_BHLD, key_BHLD, value_BHLD, attention_mask
247
+ )
248
+ context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
249
+ return self.out_proj(context_BLD)
250
+
251
+
252
+ ### LM Head
253
+ def RegressionHead(
254
+ d_model: int, output_dim: int, hidden_dim: int | None = None
255
+ ) -> nn.Module:
256
+ hidden_dim = hidden_dim if hidden_dim is not None else d_model
257
+ return nn.Sequential(
258
+ nn.Linear(d_model, hidden_dim),
259
+ nn.GELU(),
260
+ nn.LayerNorm(hidden_dim),
261
+ nn.Linear(hidden_dim, output_dim),
262
+ )
263
+
264
+
265
+ ### Transformer Block
266
+ class UnifiedTransformerBlock(nn.Module):
267
+ def __init__(
268
+ self,
269
+ d_model: int,
270
+ n_heads: int,
271
+ residue_scaling_factor: float = 1,
272
+ expansion_ratio: float = 8 / 3,
273
+ ):
274
+ super().__init__()
275
+ self.attn = MultiHeadAttention(d_model, n_heads)
276
+ self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
277
+ self.scaling_factor = residue_scaling_factor
278
+
279
+ def forward(
280
+ self,
281
+ x: torch.Tensor,
282
+ attention_mask: Optional[torch.Tensor] = None,
283
+ ) -> torch.Tensor:
284
+ r1 = self.attn(x, attention_mask)
285
+ x = x + r1 / self.scaling_factor
286
+ r3 = self.ffn(x) / self.scaling_factor
287
+ x = x + r3
288
+ return x
289
+
290
+
291
+ ### Outputs
292
+ @dataclass
293
+ class TransformerOutput(ModelOutput):
294
+ last_hidden_state: torch.Tensor | None = None
295
+ hidden_states: tuple[torch.Tensor] | None = None
296
+
297
+
298
+ @dataclass
299
+ class ESMplusplusOutput(ModelOutput):
300
+ loss: torch.Tensor | None = None
301
+ logits: torch.Tensor | None = None
302
+ last_hidden_state: torch.Tensor | None = None
303
+ hidden_states: tuple[torch.Tensor] | None = None
304
+
305
+
306
+ ### Transformer
307
+ class TransformerStack(nn.Module):
308
+ def __init__(
309
+ self,
310
+ d_model: int,
311
+ n_heads: int,
312
+ n_layers: int,
313
+ ):
314
+ super().__init__()
315
+ self.blocks = nn.ModuleList(
316
+ [
317
+ UnifiedTransformerBlock(
318
+ d_model,
319
+ n_heads,
320
+ residue_scaling_factor=math.sqrt(n_layers / 36),
321
+ )
322
+ for i in range(n_layers)
323
+ ]
324
+ )
325
+ self.norm = nn.LayerNorm(d_model, bias=False)
326
+
327
+ def forward(
328
+ self,
329
+ x: torch.Tensor,
330
+ attention_mask: Optional[torch.Tensor] = None,
331
+ output_hidden_states: bool = False,
332
+ ) -> TransformerOutput:
333
+ batch_size, seq_len, _ = x.shape
334
+ hidden_states = ()
335
+ if attention_mask is not None:
336
+ attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
337
+ for block in self.blocks:
338
+ x = block(x, attention_mask)
339
+ if output_hidden_states:
340
+ hidden_states += (x,)
341
+ return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)
342
+
343
+
344
+ ### Full model
345
+ class ESMplusplusForMaskedLM(PreTrainedModel):
346
+ """
347
+ ESM++ for masked language modeling.
348
+ """
349
+ config_class = ESMplusplusConfig
350
+ def __init__(self, config: ESMplusplusConfig):
351
+ super().__init__(config)
352
+ self.config = config
353
+ self.vocab_size = config.vocab_size
354
+ self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
355
+ self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
356
+ self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
357
+ self.ce_loss = nn.CrossEntropyLoss()
358
+ self.tokenizer = EsmSequenceTokenizer()
359
+
360
+ @classmethod
361
+ def from_pretrained_esm(cls, model_name: str):
362
+ if '300' in model_name:
363
+ return ESMplusplus_300M()
364
+ elif '600' in model_name:
365
+ return ESMplusplus_600M()
366
+ else:
367
+ raise ValueError(f"Invalid model name: {model_name}")
368
+
369
+ @property
370
+ def device(self):
371
+ return next(self.parameters()).device
372
+
373
+ def forward(
374
+ self,
375
+ input_ids: torch.Tensor | None = None,
376
+ attention_mask: Optional[torch.Tensor] = None,
377
+ labels: Optional[torch.Tensor] = None,
378
+ output_hidden_states: bool = False,
379
+ ) -> ESMplusplusOutput:
380
+ x = self.embed(input_ids)
381
+ output = self.transformer(x, attention_mask, output_hidden_states)
382
+ x = output.last_hidden_state
383
+ logits = self.sequence_head(x)
384
+ loss = None
385
+ if labels is not None:
386
+ loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
387
+ return ESMplusplusOutput(
388
+ loss=loss,
389
+ logits=logits,
390
+ last_hidden_state=x,
391
+ hidden_states=output.hidden_states,
392
+ )
393
+
394
+
395
+ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
396
+ """
397
+ ESM++ for sequence classification.
398
+ """
399
+ def __init__(self, config: ESMplusplusConfig):
400
+ super().__init__(config)
401
+ self.config = config
402
+ self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
403
+ # we find that large intermediate projections help with sequence classification tasks (*4)
404
+ self.mse = nn.MSELoss()
405
+ self.ce = nn.CrossEntropyLoss()
406
+ self.bce = nn.BCEWithLogitsLoss()
407
+
408
+ def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
409
+ # x: (batch_size, seq_len, hidden_size)
410
+ # attention_mask: (batch_size, seq_len)
411
+ if attention_mask is None:
412
+ return x.mean(dim=1)
413
+ else:
414
+ return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
415
+
416
+ def forward(
417
+ self,
418
+ input_ids: torch.Tensor | None = None,
419
+ attention_mask: Optional[torch.Tensor] = None,
420
+ labels: Optional[torch.Tensor] = None,
421
+ output_hidden_states: bool = False,
422
+ ) -> ESMplusplusOutput:
423
+ output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
424
+ x = output.last_hidden_state
425
+ cls_features = x[:, 0, :]
426
+ mean_features = self.mean_pooling(x, attention_mask)
427
+ # we include mean pooling features to help with early convergence, the cost of this is basically zero
428
+ features = torch.cat([cls_features, mean_features], dim=-1)
429
+ logits = self.classifier(features)
430
+ loss = None
431
+ if labels is not None:
432
+ labels = labels.to(logits.device)
433
+ if self.config.problem_type is None:
434
+ if self.num_labels == 1:
435
+ self.config.problem_type = "regression"
436
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
437
+ self.config.problem_type = "single_label_classification"
438
+ else:
439
+ self.config.problem_type = "multi_label_classification"
440
+
441
+ if self.config.problem_type == "regression":
442
+ if self.num_labels == 1:
443
+ loss = self.mse(logits.squeeze(), labels.squeeze())
444
+ else:
445
+ loss = self.mse(logits, labels)
446
+ elif self.config.problem_type == "single_label_classification":
447
+ loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
448
+ elif self.config.problem_type == "multi_label_classification":
449
+ loss = self.bce(logits, labels)
450
+ return ESMplusplusOutput(
451
+ loss=loss,
452
+ logits=logits,
453
+ last_hidden_state=x,
454
+ hidden_states=output.hidden_states,
455
+ )
456
+
457
+
458
+ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
459
+ """
460
+ ESM++ for token classification.
461
+ """
462
+ def __init__(self, config: ESMplusplusConfig):
463
+ super().__init__(config)
464
+ self.config = config
465
+ self.num_labels = config.num_labels
466
+ self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
467
+ # we find that large intermediate projections help with sequence classification tasks (*4)
468
+ self.loss_fct = nn.CrossEntropyLoss()
469
+
470
+ def forward(
471
+ self,
472
+ input_ids: torch.Tensor | None = None,
473
+ attention_mask: Optional[torch.Tensor] = None,
474
+ labels: Optional[torch.Tensor] = None,
475
+ output_hidden_states: bool = False,
476
+ ) -> ESMplusplusOutput:
477
+ output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
478
+ x = output.last_hidden_state
479
+ logits = self.classifier(x)
480
+ loss = None
481
+ if labels is not None:
482
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
483
+ return ESMplusplusOutput(
484
+ loss=loss,
485
+ logits=logits,
486
+ last_hidden_state=x,
487
+ hidden_states=output.hidden_states,
488
+ )
489
+
490
+
491
+ ### Loading
492
+ import os
493
+ from functools import cache
494
+ from pathlib import Path
495
+ from huggingface_hub import snapshot_download
496
+
497
+
498
+ @staticmethod
499
+ @cache
500
+ def data_root(model: str):
501
+ if "INFRA_PROVIDER" in os.environ:
502
+ return Path("")
503
+ # Try to download from hugginface if it doesn't exist
504
+ if model.startswith("esmc-300"):
505
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
506
+ elif model.startswith("esmc-600"):
507
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
508
+ else:
509
+ raise ValueError(f"{model=} is an invalid model name.")
510
+ return path
511
+
512
+
513
+ def ESMplusplus_300M(device: torch.device | str = "cpu"):
514
+ with torch.device(device):
515
+ config = ESMplusplusConfig(
516
+ hidden_size=960,
517
+ num_attention_heads=15,
518
+ num_hidden_layers=30,
519
+ )
520
+ model = ESMplusplusForMaskedLM(config)
521
+ state_dict = torch.load(
522
+ data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
523
+ map_location=device,
524
+ )
525
+ model.load_state_dict(state_dict)
526
+ return model
527
+
528
+
529
+ def ESMplusplus_600M(device: torch.device | str = "cpu"):
530
+ with torch.device(device):
531
+ config = ESMplusplusConfig(
532
+ hidden_size=1152,
533
+ num_attention_heads=18,
534
+ num_hidden_layers=36,
535
+ )
536
+ model = ESMplusplusForMaskedLM(config)
537
+ state_dict = torch.load(
538
+ data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
539
+ map_location=device,
540
+ )
541
+ model.load_state_dict(state_dict)
542
+ return model
543
+
544
+
545
+ ### Tokenization
546
+ from tokenizers import Tokenizer
547
+ from tokenizers.models import BPE
548
+ from tokenizers.processors import TemplateProcessing
549
+ from transformers import PreTrainedTokenizerFast
550
+
551
+
552
+ SEQUENCE_VOCAB = [
553
+ "<cls>", "<pad>", "<eos>", "<unk>",
554
+ "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
555
+ "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
556
+ "O", ".", "-", "|",
557
+ "<mask>",
558
+ ]
559
+
560
+ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
561
+ model_input_names = ["input_ids", "attention_mask"]
562
+
563
+ def __init__(
564
+ self,
565
+ unk_token="<unk>",
566
+ cls_token="<cls>",
567
+ pad_token="<pad>",
568
+ mask_token="<mask>",
569
+ eos_token="<eos>",
570
+ chain_break_token="|",
571
+ **kwargs,
572
+ ):
573
+ all_tokens = SEQUENCE_VOCAB
574
+ token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
575
+
576
+ # a character-level tokenizer is the same as BPE with no token merges
577
+ bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
578
+ tokenizer = Tokenizer(bpe)
579
+ special_tokens = [
580
+ cls_token,
581
+ pad_token,
582
+ mask_token,
583
+ eos_token,
584
+ chain_break_token,
585
+ ]
586
+ self.cb_token = chain_break_token
587
+ additional_special_tokens = [chain_break_token]
588
+
589
+ tokenizer.add_special_tokens(special_tokens)
590
+
591
+ # This is where we configure the automatic addition of special tokens when we call
592
+ # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
593
+ # sequences are merged if you want.
594
+ tokenizer.post_processor = TemplateProcessing( # type: ignore
595
+ single="<cls> $A <eos>",
596
+ special_tokens=[
597
+ ("<cls>", tokenizer.token_to_id("<cls>")),
598
+ ("<eos>", tokenizer.token_to_id("<eos>")),
599
+ ],
600
+ )
601
+ super().__init__(
602
+ tokenizer_object=tokenizer,
603
+ unk_token=unk_token,
604
+ cls_token=cls_token,
605
+ pad_token=pad_token,
606
+ mask_token=mask_token,
607
+ eos_token=eos_token,
608
+ additional_special_tokens=additional_special_tokens,
609
+ **kwargs,
610
+ )
611
+
612
+ # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
613
+ @property
614
+ def bos_token(self):
615
+ return self.cls_token
616
+
617
+ @property
618
+ def bos_token_id(self):
619
+ return self.cls_token_id
620
+
621
+ @property
622
+ def chain_break_token(self):
623
+ return self.cb_token
624
+
625
+ @property
626
+ def chain_break_token_id(self):
627
+ return self.convert_tokens_to_ids(self.chain_break_token)
628
+
629
+ @property
630
+ def all_token_ids(self):
631
+ return list(range(self.vocab_size))
632
+
633
+ @property
634
+ def special_token_ids(self):
635
+ return self.all_special_ids
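
# Usage sketch for context (not part of the commit): the `config_class = ESMplusplusConfig`
# attribute added in this change is what lets `from_pretrained` rebuild the custom config
# when the model is loaded with remote code. The repo id below and the assumption that the
# checkpoint's config.json maps this file for remote code are hypothetical, not from this diff.
import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_small",   # assumed repo hosting this modeling file
    trust_remote_code=True,         # ESMplusplus is not a built-in transformers architecture
)
tokenizer = model.tokenizer         # EsmSequenceTokenizer is attached in __init__ above

inputs = tokenizer("MPRTEIN", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)         # (1, sequence length + 2 special tokens, vocab_size)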