lhallee committed (verified)
Commit 8fd1c21 · 1 Parent(s): 155f33d

Upload README.md with huggingface_hub

Files changed (1): README.md (+118, -118)

README.md

---
library_name: transformers
tags: []
---

# ESM++
[ESM++](https://github.com/Synthyra/ESMplusplus) is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) ([license](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement)) that supports batching and standard Hugging Face compatibility without requiring the ESM Python package.
The small version corresponds to the 300 million parameter version of ESMC.

## Use with 🤗 transformers
```python
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')

# For MLM training, mask input_ids and set the unmasked positions of labels to -100, e.g.:
# tokenized['labels'] = tokenized['input_ids'].clone()

output = model(**tokenized) # add output_hidden_states=True to also get all hidden states
print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960)
print(output.loss) # language modeling loss if you passed labels
#print(output.hidden_states) # tuple of all hidden states if you passed output_hidden_states=True
```
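The commented `labels` line above is a hint for MLM training. Below is a minimal masking sketch, not part of the package: it assumes the attached tokenizer behaves like a standard Hugging Face tokenizer (exposing `mask_token_id` and `all_special_ids`) and uses the simple variant that replaces every selected token with the mask token (no 80/10/10 split).

```python
import torch

def mask_for_mlm(input_ids, tokenizer, mask_prob=0.15):
    """Randomly mask tokens for MLM; labels are -100 everywhere except masked positions."""
    labels = input_ids.clone()
    special = torch.isin(input_ids, torch.tensor(tokenizer.all_special_ids))  # skip CLS/EOS/PAD
    masked = (torch.rand(input_ids.shape) < mask_prob) & ~special
    labels[~masked] = -100                            # loss is only computed on masked positions
    masked_inputs = input_ids.clone()
    masked_inputs[masked] = tokenizer.mask_token_id   # replace chosen tokens with the mask token
    return masked_inputs, labels

mlm_batch = {k: v for k, v in tokenized.items()}
mlm_batch['input_ids'], mlm_batch['labels'] = mask_for_mlm(tokenized['input_ids'], tokenizer)
loss = model(**mlm_batch).loss  # MLM loss over the masked positions
```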

ESM++ also supports sequence- and token-level classification tasks, like ESM2. Simply pass the number of labels during initialization (a token-level sketch follows the sequence-level example below).

```python
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification

model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape) # (batch_size, num_labels), (2, 2)
```
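Token-level classification works the same way through `AutoModelForTokenClassification`; here is a brief sketch (the `num_labels=3` per-residue head is illustrative, not a released checkpoint):

```python
from transformers import AutoModelForTokenClassification

# Illustrative 3-class per-residue head (e.g. for secondary-structure-style labels).
token_model = AutoModelForTokenClassification.from_pretrained(
    'Synthyra/ESMplusplus_small', num_labels=3, trust_remote_code=True
)
token_logits = token_model(**tokenized).logits
print(token_logits.shape) # (batch_size, seq_len, num_labels), (2, 11, 3)
```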

ESM++ weights are fp32 by default. You can load them in fp16 or bf16 like this:
```python
import torch
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True, torch_dtype=torch.float16) # or torch.bfloat16
```

## Embed entire datasets with no new code
To embed a list of protein sequences **fast**, just call `embed_dataset`. Sequences are sorted to reduce padding tokens, so the initial progress-bar estimate is usually much longer than the actual runtime.
```python
embeddings = model.embed_dataset(
    sequences=sequences, # list of protein strings
    batch_size=16, # embedding batch size
    max_len=2048, # truncate sequences to max_len
    full_embeddings=True, # return residue-wise embeddings instead of pooled ones
    full_precision=False, # store in half precision (set True for float32)
    pooling_type='mean', # mean pooling when returning protein-wise embeddings
    num_workers=0, # number of data-loading workers
    sql=False, # return a dictionary mapping sequences to embeddings
)

_ = model.embed_dataset(
    sequences=sequences, # list of protein strings
    batch_size=16, # embedding batch size
    max_len=2048, # truncate sequences to max_len
    full_embeddings=True, # return residue-wise embeddings instead of pooled ones
    full_precision=False, # store in half precision (set True for float32)
    pooling_type='mean', # mean pooling when returning protein-wise embeddings
    num_workers=0, # number of data-loading workers
    sql=True, # stream results into a local SQL database instead of returning them
    sql_db_path='embeddings.db', # path to .db file of choice
)
```
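The `sql=False` call returns a dictionary mapping each input sequence to its embedding. The sketch below shows one way to pool and save these; it assumes the values are CPU tensors of shape (seq_len, hidden_size) when `full_embeddings=True`, which is an assumption about the return format rather than documented behavior.

```python
import torch

# Assumption: embeddings[sequence] is a per-residue tensor of shape (seq_len, hidden_size).
pooled = {seq: emb.float().mean(dim=0) for seq, emb in embeddings.items()}  # one vector per protein
matrix = torch.stack([pooled[seq] for seq in sequences])                    # (num_sequences, hidden_size)
print(matrix.shape)  # e.g. (2, 960) for ESM++ small

torch.save(pooled, 'pooled_embeddings.pt')  # reload later with torch.load('pooled_embeddings.pt')
```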

## Returning attention maps
By default, `F.scaled_dot_product_attention` is used for the attention calculations, which is much faster than a manual PyTorch implementation. However, it cannot return attention maps.
ESM++ has the option `output_attentions`, which computes attention manually. This is much slower, so enable it only when you need the attention maps.

```python
output = model(**tokenized, output_attentions=True)
att = output.attentions
len(att) # 30, one per layer, each of size (batch_size, num_heads, seq_len, seq_len)
```
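As an illustrative follow-up (not part of the model's API), the per-layer maps can be stacked and averaged over layers and heads to obtain a single residue-residue attention matrix per sequence:

```python
import torch

stacked = torch.stack(att)          # (num_layers, batch_size, num_heads, seq_len, seq_len)
avg_map = stacked.mean(dim=(0, 2))  # average over layers and heads -> (batch_size, seq_len, seq_len)
print(avg_map.shape)                # (2, 11, 11) for the example batch
```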

## Comparison across floating-point precision and implementations
We measured the difference between the last hidden states of the fp32 weights and the fp16 or bf16 versions. We find that fp16 is closer to the fp32 outputs, so we recommend loading in fp16.
Please note that the ESM package also loads ESMC in fp32 but casts to bf16 by default, which has its share of advantages and disadvantages for inference and training - so load whichever half precision you like.

Average MSE FP32 vs. FP16: 0.00000003

Average MSE FP32 vs. BF16: 0.00000140
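A comparison like this can be reproduced along the following lines; this is a sketch, not the exact evaluation script used for the numbers above.

```python
import torch
from transformers import AutoModelForMaskedLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # half precision is best exercised on GPU
model_fp32 = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True).to(device).eval()
model_fp16 = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True,
                                                  torch_dtype=torch.float16).to(device).eval()

batch = {k: v.to(device) for k, v in tokenized.items()}
with torch.no_grad():
    h32 = model_fp32(**batch).last_hidden_state
    h16 = model_fp16(**batch).last_hidden_state

print(torch.mean((h32 - h16.float()) ** 2).item())  # MSE of the last hidden states
```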

We also measured the difference between the outputs of ESM++ and ESMC (both in bfloat16) on 1000 random sequences to verify agreement with the ESM package.

Average MSE of last hidden state: 7.74e-10

You can load the weights from the ESM package instead of transformers by replacing `.from_pretrained(...)` with `.from_pretrained_esm('esmc_300m')`.
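A sketch of that alternative loading path is below. It assumes `from_pretrained_esm` is exposed as a classmethod on the same model class that `trust_remote_code` resolves to, and it requires the ESM Python package to be installed; both are assumptions here rather than documented guarantees.

```python
from transformers import AutoModelForMaskedLM

hf_model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
# Assumption: the concrete remote-code class exposes from_pretrained_esm, as described above.
esm_model = type(hf_model).from_pretrained_esm('esmc_300m')  # pulls weights via the ESM package instead
```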

## Model probes
We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. ESMC (and thus ESM++) performs very well.

The plot below shows performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression task scores are averaged between Spearman rho and R2.
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/2zyUZeHyOgCR_twvPF2Wy.png)
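For illustration, a linear probe on mean-pooled ESM++ embeddings can be set up as below with scikit-learn. This is a generic sketch, not the probing protocol or datasets from the paper; `train_seqs`, `train_labels`, `test_seqs`, and `test_labels` are placeholders, and it assumes `embed_dataset` returns CPU tensors keyed by sequence.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import matthews_corrcoef, f1_score

# Placeholders: substitute your own sequences and binary labels.
train_seqs, train_labels = ['MPRTEIN', 'MSEQWENCE'], [0, 1]
test_seqs, test_labels = ['MSEQWENCE', 'MPRTEIN'], [1, 0]

def pooled_embeddings(seqs):
    # Mean-pooled, protein-level embeddings via embed_dataset (full_embeddings=False).
    emb = model.embed_dataset(sequences=seqs, batch_size=16, max_len=2048,
                              full_embeddings=False, full_precision=False,
                              pooling_type='mean', num_workers=0, sql=False)
    return np.stack([np.asarray(emb[s], dtype=np.float32) for s in seqs])

probe = LogisticRegression(max_iter=1000).fit(pooled_embeddings(train_seqs), train_labels)
preds = probe.predict(pooled_embeddings(test_seqs))
print(matthews_corrcoef(test_labels, preds), f1_score(test_labels, preds))
```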

## Inference speeds
We look at various ESM models and their throughput on an H100. The efficient batching in ESM++ significantly improves throughput over ESMC, and ESM++ is also faster than ESMC at batch size one. ESM++ small is even faster than ESM2-35M with long sequences!
The largest gains will be seen with PyTorch > 2.5 on Linux machines.
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/RfLRSchFivdsqJrWMh4bo.png)
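A simple way to measure throughput on your own hardware is sketched below (illustrative only; this is not the benchmarking script behind the plot):

```python
import time
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device).eval()
batch = {k: v.to(device) for k, v in tokenized.items()}

with torch.no_grad():
    for _ in range(3):  # warmup
        model(**batch)
    if device == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    n_iters = 20
    for _ in range(n_iters):
        model(**batch)
    if device == 'cuda':
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

print(f"{n_iters * batch['input_ids'].shape[0] / elapsed:.1f} sequences/second")
```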

### Citation
If you use any of this implementation or work, please cite it (as well as the ESMC preprint).

```
@misc{ESMPlusPlus,
  author = { Hallee, L. and Bichara, D. and Gleghorn, J. P. },
  title = { ESMPlusPlus },
  year = 2024,
  url = { https://huggingface.co/Synthyra/ESMplusplus_small },
  doi = { 10.57967/hf/3725 },
  publisher = { Hugging Face }
}
```
 