---
widget:
- text: "MEPLDDLDLLLLEEDSGAEAVPRMEILQKKADAFFAETVLSRGVDNRYLVLAVETKLNERGAEEKHLLITVSQEGEQEVLCILRNGWSSVPVEPGDIIHIEGDCTSEPWIVDDDFGYFILSPDMLISGTSVASSIRCLRRAVLSETFRVSDTATRQMLIGTILHEVFQKAISESFAPEKLQELALQTLREVRHLKEMYRLNLSQDEVRCEVEEYLPSFSKWADEFMHKGTKAEFPQMHLSLPSDSSDRSSPCNIEVVKSLDIEESIWSPRFGLKGKIDVTVGVKIHRDCKTKYKIMPLELKTGKESNSIEHRGQVILYTLLSQERREDPEAGWLLYLKTGQMYPVPANHLDKRELLKLRNQLAFSLLHRVSRAAAGEEARLLALPQIIEEEKTCKYCSQMGNCALYSRAVEQVHDTSIPEGMRSKIQEGTQHLTRAHLKYFSLWCLMLTLESQSKDTKKSHQSIWLTPASKLEESGNCIGSLVRTEPVKRVCDGHYLHNFQRKNGPMPATNLMAGDRIILSGEERKLFALSKGYVKRIDTAAVTCLLDRNLSTLPETTLFRLDREEKHGDINTPLGNLSKLMENTDSSKRLRELIIDFKEPQFIAYLSSVLPHDAKDTVANILKGLNKPQRQAMKKVLLSKDYTLIVGMPGTGKTTTICALVRILSACGFSVLLTSYTHSAVDNILLKLAKFKIGFLRLGQSHKVHPDIQKFTEEEMCRLRSIASLAHLEELYNSHPVVATTCMGISHPMFSRKTFDFCIVDEASQISQPICLGPLFFSRRFVLVGDHKQLPPLVLNREARALGMSESLFKRLERNESAVVQLTIQYRMNRKIMSLSNKLTYEGKLECGSDRVANAVITLPNLKDVRLEFYADYSDNPWLAGVFEPDNPVCFLNTDKVPAPEQIENGGVSNVTEARLIVFLTSTFIKAGCSPSDIGIIAPYRQQLRTITDLLARSSVGMVEVNTVDKYQGRDKSLILVSFVRSNEDGTLGELLKDWRRLNVAITRAKHKLILLGSVSSLKRF"
  example_title: "Protein Sequence 1"
- text: "MNSVTVSHAPYYIVYHDDWEPVMSQLVEFYNEVASWLLRDETSPIPPKFFIQLKQMLRNKRVCVCGILPYPIDGTGVPFESPNFTKKSIKEIASSISRLTGVIDYKGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPAARDRQFEKDRSFEIINELLELDNKVPINWAQGFIY"
  example_title: "Protein Sequence 2"
- text: "MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY"
  example_title: "Protein Sequence 3"
license: mit
datasets:
- AmelieSchreiber/general_binding_sites
language:
- en
metrics:
- precision
- recall
- f1
library_name: transformers
tags:
- biology
- esm
- esm2
- ESM-2
- protein language model
---

# ESM-2 for General Protein Binding Site Prediction

This model is trained to predict general binding sites of proteins using only the sequence.
This is a fine-tuned version of `esm2_t6_8M_UR50D` ([see here](https://huggingface.co/facebook/esm2_t6_8M_UR50D) and [also here](https://huggingface.co/docs/transformers/model_doc/esm) for more info on the base model), trained on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/general_binding_sites). The data is not filtered by family, so the model may be overfit to some degree.

The Hugging Face Inference API widget to the right contains three example protein sequences. The first is a DNA-binding protein truncated to its first 1022 amino acid residues ([see UniProt entry here](https://www.uniprot.org/uniprotkb/D3ZG52/entry)). The second and third were obtained using [EvoProtGrad](https://github.com/Amelie-Schreiber/sampling_protein_language_models/blob/main/EvoProtGrad_copy.ipynb), a Markov chain Monte Carlo method for *in silico* directed evolution of proteins based on a form of Gibbs sampling. In theory, the mutant sequences should have binding sites similar to those of the wild-type sequence, perhaps with higher binding affinity. When we run them through the model, the two proteins do receive the same predicted binding sites. This suggests, to some degree, that the model has learned to predict binding sites well (and that EvoProtGrad works as intended).

## Training

This model was trained on approximately 70,000 proteins that carry binding site and active site annotations in UniProt. For this version the data was split randomly 85/15, with no regard for protein family or sequence similarity. Newer iterations of the model have been trained on larger datasets (over 200,000 proteins) using a split that keeps protein families from overlapping. However, these newer models seem to overfit much earlier and perform significantly worse on the training metrics (precision, recall, and F1). To address this, we plan to implement LoRA (and hopefully QLoRA).
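The difference between the random 85/15 split used for this version and a family-aware split can be sketched as follows. This is a minimal illustration, not the code used to build the dataset. It assumes each protein record is a dict with a hypothetical `"family"` key, and the dataset's actual schema may differ. The function names `random_split` and `family_split` are likewise hypothetical.

```python
import random
from collections import defaultdict


def random_split(records, test_fraction=0.15, seed=42):
    """Random per-protein split, ignoring family (like the 85/15 split above)."""
    rng = random.Random(seed)
    shuffled = list(records)
    rng.shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]


def family_split(records, test_fraction=0.15, seed=42):
    """Split so that no family appears in both the train and test sets.

    Assumes each record is a dict with a hypothetical "family" key.
    Whole families are moved at once, so the test fraction is approximate.
    """
    by_family = defaultdict(list)
    for rec in records:
        by_family[rec["family"]].append(rec)

    families = sorted(by_family)
    random.Random(seed).shuffle(families)

    target = test_fraction * len(records)
    train, test = [], []
    for fam in families:
        # Fill the test set with whole families until it reaches the target size
        if len(test) < target:
            test.extend(by_family[fam])
        else:
            train.extend(by_family[fam])
    return train, test


# Toy example: 100 proteins in 10 hypothetical families of 10 members each
records = [{"id": f"P{i}", "family": f"F{i % 10}"} for i in range(100)]
train, test = family_split(records)
print(len(train), len(test))
```

Because whole families are assigned to one side, the realized test fraction can overshoot the target. In the toy example above the test set ends up with 20 proteins rather than 15.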
Training metrics for the model, in the form of the `trainer_state.json`, can be [found here](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2/blob/main/trainer_state.json).

```
epoch 3:
Training Loss   Validation Loss   Precision   Recall     F1         AUC
0.031100        0.074720          0.684798    0.966856   0.801743   0.980853
```

The hyperparameters are:

```
wandb: lr: 0.0004977045729600779
wandb: lr_scheduler_type: cosine
wandb: max_grad_norm: 0.5
wandb: num_train_epochs: 3
wandb: per_device_train_batch_size: 8
wandb: weight_decay: 0.025
```

## Using the Model

To use the model, try running:

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer


def predict_binding_sites(model_path, protein_sequences):
    """
    Predict binding sites for a collection of protein sequences.

    Parameters:
    - model_path (str): Hugging Face Hub ID or local path of the saved model.
    - protein_sequences (List[str]): List of protein sequences.

    Returns:
    - List[List[str]]: Predicted label for each residue of each sequence.
    """
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Ensure model is in evaluation mode
    model.eval()

    # Tokenize sequences; truncate to 1024 tokens (1022 residues plus <cls> and <eos>)
    inputs = tokenizer(protein_sequences, return_tensors="pt", padding=True,
                       truncation=True, max_length=1024)

    # Move inputs to the same device as the model and obtain logits
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits

    # Obtain predicted label IDs for every token position
    predicted_labels = torch.argmax(logits, dim=-1).cpu().numpy()
    attention_mask = inputs["attention_mask"].cpu().numpy()

    # Convert label IDs to human-readable labels, dropping <cls>, <eos>, and padding positions
    id2label = model.config.id2label
    human_readable_labels = []
    for label_ids, mask in zip(predicted_labels, attention_mask):
        n_tokens = int(mask.sum())
        residue_label_ids = label_ids[1:n_tokens - 1]
        human_readable_labels.append([id2label[int(label_id)] for label_id in residue_label_ids])
    return human_readable_labels


# Usage:
model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2"  # Replace with your model's path
unseen_proteins = [
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD",
]  # Replace with your protein sequences

predictions = predict_binding_sites(model_path, unseen_proteins)
predictions
```
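The per-residue predictions can be post-processed to locate binding sites, or to check whether two variants (such as the EvoProtGrad mutants described above) share the same predicted sites. The sketch below is one possible approach, not part of the model or of `transformers`. It makes three assumptions:

- Each list holds exactly one label per residue, with the `<cls>`, `<eos>` and padding positions removed.
- The two variants being compared are equal-length, substitution-only mutants, so position *i* refers to the same residue in both.
- `positive_label` is the label string that marks a binding residue. This card does not state the label names, so read them from `model.config.id2label`.

The agreement score uses the Jaccard index, which is an illustrative choice rather than the author's method.

```python
def binding_site_positions(labels, positive_label):
    """Return 1-based residue positions whose predicted label equals positive_label."""
    return [i + 1 for i, label in enumerate(labels) if label == positive_label]


def to_intervals(positions):
    """Collapse sorted 1-based positions into inclusive (start, end) intervals."""
    intervals = []
    for pos in positions:
        if intervals and pos == intervals[-1][1] + 1:
            # Extend the current run of consecutive positions
            intervals[-1] = (intervals[-1][0], pos)
        else:
            intervals.append((pos, pos))
    return intervals


def site_agreement(labels_a, labels_b, positive_label):
    """Jaccard index of predicted binding-site positions for two equal-length variants."""
    if len(labels_a) != len(labels_b):
        raise ValueError("variants must have the same length to compare positions")
    a = set(binding_site_positions(labels_a, positive_label))
    b = set(binding_site_positions(labels_b, positive_label))
    if not a and not b:
        return 1.0  # neither variant has predicted sites: treat as full agreement
    return len(a & b) / len(a | b)


# Toy labels; "site" and "none" are hypothetical label names
wild_type = ["none", "site", "site", "none", "site"]
mutant = ["none", "site", "none", "none", "site"]
print(to_intervals(binding_site_positions(wild_type, "site")))
print(site_agreement(wild_type, mutant, "site"))
```

With real predictions, the same functions could be applied to, for example, `predictions[0]` and `predictions[1]`. An agreement of 1.0 would mean that both variants have exactly the same predicted binding residues.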