---
license: mit
language:
- en
library_name: transformers
tags:
- esm
- esm2
- protein language model
- biology
---

# ESM-2 (`esm2_t6_8M_UR50D`)

This is a fine-tuned version of [ESM-2](https://huggingface.co/facebook/esm2_t6_8M_UR50D) for sequence classification. It categorizes protein sequences into one of two classes: "cytosolic" or "membrane".

## Training and Accuracy

The model was trained using [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_language_modeling.ipynb) and achieved an evaluation accuracy of 94.83%.

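For readers unfamiliar with the metric, accuracy is conventionally the fraction of evaluation sequences whose predicted class matches the reference class. The sketch below illustrates that definition in plain Python; the `accuracy` helper and the label lists are hypothetical and are not taken from the training notebook, which may compute the metric differently.

```python
def accuracy(predictions, references):
    """Return the fraction of predictions that equal their reference label."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        raise ValueError("at least one example is required")
    # Count positions where the predicted class matches the reference class.
    correct = sum(p == r for p, r in zip(predictions, references))
    return correct / len(references)


# Hypothetical labels for illustration only: 0 = cytosolic, 1 = membrane.
predicted = [0, 1, 1, 0, 1]
reference = [0, 1, 0, 0, 1]
print(f"Accuracy: {accuracy(predicted, reference):.2%}")
```
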
## Using the Model

To use the model, try running:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Initialize the tokenizer and model
model_path_directory = "AmelieSchreiber/esm2_t6_8M_UR50D-finetuned-localization"
tokenizer = AutoTokenizer.from_pretrained(model_path_directory)
model = AutoModelForSequenceClassification.from_pretrained(model_path_directory)

# Define a function to predict the category of a protein sequence
def predict_category(sequence):
    # Tokenize the sequence and convert it to tensor format
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=512, padding="max_length")

    # Make prediction
    with torch.no_grad():
        logits = model(**inputs).logits

    # Determine the category with the highest score
    predicted_class = torch.argmax(logits, dim=1).item()

    # Return the category: 0 for cytosolic, 1 for membrane
    return "cytosolic" if predicted_class == 0 else "membrane"

# Example sequence
new_protein_sequence = "MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYEEKYKTFNKYWCRQPCLPIWHEMVETGGSEGVVRSDQVIITDHPGDLTFTVTLENLTADDAGKYRCGIATILQEDGLSGFLPDPFFQVQVLVSSASSTENSVKTPASPTRPSQCQGSLPSSTCFLLLPLLKVPLLLSILGAILWVNRPWRTPWTES"

# Predict the category
category = predict_category(new_protein_sequence)
print(f"The predicted category for the sequence is: {category}")
```
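
The function above returns only a label. If a rough confidence is also wanted, one option is to apply a softmax to the two logits. The sketch below does this in plain Python so it runs without downloading the model. The helper names `softmax` and `label_with_confidence` and the example logits are hypothetical. The label order (0 for cytosolic, 1 for membrane) follows the snippet above.

```python
import math

# Label order as stated in the model card: index 0 is cytosolic, index 1 is membrane.
LABELS = ("cytosolic", "membrane")


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    # Subtract the largest logit first for numerical stability.
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_with_confidence(logits):
    """Return the most likely label and its softmax probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs[best]


# Hypothetical logits for illustration only (not real model output).
example_logits = [-1.2, 2.3]
label, confidence = label_with_confidence(example_logits)
print(f"{label} ({confidence:.3f})")
```

With the real model, the input list could be obtained from the `logits` tensor via `logits[0].tolist()`. Softmax scores are not calibrated probabilities, so treat the value as a relative indication rather than a guarantee.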