Spaces:
Sleeping
Sleeping
import io
import os
import pickle
import subprocess
import tempfile
import zipfile

import Bio
import gradio as gr
import matplotlib

matplotlib.use('Agg')  # Use the non-interactive Agg backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from Bio import SeqIO
from Bio.Blast import NCBIXML
from IPython.display import clear_output
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from transformers import EsmForSequenceClassification, AdamW, AutoTokenizer
# Class index -> family label for the GT-A family classifier ("GTA_fam.pth").
# The position of each name in the tuple is the class index the model emits.
# NOTE(review): the order presumably mirrors the training-time label
# encoding -- confirm before adding or reordering entries.
GTA_fam_dict = dict(enumerate((
    "GT116",
    "GT12",
    "GT13",
    "GT14",
    "GT15",
    "GT16",
    "GT17",
    "GT2-clade1",
    "GT2-clade2",
    "GT2-clade3",
    "GT2-clade4",
    "GT2-clade5",
    "GT2-related",
    "GT21",
    "GT24",
    "GT25",
    "GT27",
    "GT31",
    "GT32",
    "GT34",
    "GT40",
    "GT43",
    "GT45",
    "GT49",
    "GT54",
    "GT55",
    "GT6",
    "GT60",
    "GT62",
    "GT64",
    "GT67",
    "GT7",
    "GT75",
    "GT77",
    "GT78",
    "GT8",
    "GT81",
    "GT82",
    "GT84",
    "GT88",
    "GT92",
)))
# Class index -> donor sugar name for the GT-A donor classifier
# ("GTA_don.pth"). Tuple position is the class index the model emits.
GTA_don_dict = dict(enumerate((
    "N-Acetyl Galactosamine",
    "N-Acetyl Glucosamine",
    "Arabinose",
    "Galactose",
    "Galacturonic Acid",
    "Glucose",
    "Glucuronic Acid",
    "Mannose",
    "Rhamnose",
    "Xylose",
)))
# Class index -> family label for the GT-B family classifier ("GTB_fam.pth").
# Tuple position is the class index the model emits.
# NOTE(review): the order presumably mirrors the training-time label
# encoding -- confirm before adding or reordering entries.
GTB_fam_dict = dict(enumerate((
    "GT1",
    "GT10",
    "GT104",
    "GT11",
    "GT18",
    "GT19",
    "GT20",
    "GT23",
    "GT28",
    "GT3",
    "GT30",
    "GT35",
    "GT37",
    "GT38",
    "GT4",
    "GT41",
    "GT5",
    "GT52",
    "GT63",
    "GT65",
    "GT68",
    "GT70",
    "GT72",
    "GT80",
    "GT9",
    "GT90",
    "GT99",
)))
# Class index -> donor sugar name for the GT-B donor classifier
# ("GTB_don.pth"). Tuple position is the class index the model emits.
GTB_don_dict = dict(enumerate((
    "Fucose",
    "Galactose",
    "N-Acetyl Galactosamine",
    "Glucuronic Acid",
    "N-Acetyl Glucosamine",
    "Glucose",
    "Mannose",
    "Other",
    "Xylose",
)))
# Shared ESM-2 tokenizer used by both the family and the donor classifiers.
# The checkpoint name must stay in sync with the one the models are built
# from (facebook/esm2_t33_650M_UR50D is the larger alternative).
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="facebook/esm2_t12_35M_UR50D",
)
# (family, mechanism, clade) rows for the families that have metadata.
# Clade strings are kept verbatim, including their trailing padding.
_GT_FAMILY_ROWS = (
    ("GT40", "Inverting", "N/A"),
    ("GT16", "Inverting", "6 "),
    ("GT27", "Retaining", "5 "),
    ("GT55", "Retaining", "2 "),
    ("GT25", "Inverting", "6 "),
    ("GT2", "Inverting", "2 "),
    ("GT84", "Retaining", "1 "),
    ("GT13", "Inverting", "6 "),
    ("GT67", "Inverting", "8 "),
    ("GT82", "Inverting", "7 "),
    ("GT24", "Retaining", "9 "),
    ("GT81", "Retaining", "2 "),
    ("GT49", "Inverting", "N/A"),
    ("GT34", "Retaining", "N/A"),
    ("GT45", "Retaining", "N/A"),
    ("GT32", "Retaining", "N/A"),
    ("GT88", "Retaining", "9 "),
    ("GT21", "Inverting", "1 "),
    ("GT54", "Inverting", "6 "),
    ("GT6", "Retaining", "N/A"),
    ("GT7", "Inverting", "5 "),
    ("GT64", "Retaining", "N/A"),
    ("GT78", "Retaining", "2 "),
    ("GT12", "Inverting", "N/A"),
    ("GT31", "Inverting", "8 "),
    ("GT62", "Retaining", "3 "),
    ("GT8", "Retaining", "N/A"),
    ("GT15", "Retaining", "8 "),
    ("GT43", "Inverting", "N/A"),
    ("GT60", "Retaining", "5 "),
    ("GT14", "Inverting", "7 "),
    ("GT17", "Inverting", "7 "),
    ("GT77", "Retaining", "9 "),
    ("GT75", "Inverting", "N/A"),
)

# Family name -> descriptive metadata shown next to a family prediction.
# Every entry has fold 'A' and an empty alternative name; 'CAZy Name' is the
# family name padded to four characters, exactly as in the literal table
# this replaces. Labels without a row here (e.g. the "GT2-clade*" classes)
# simply have no metadata.
glycosyltransferase_db = {
    family: {
        'CAZy Name': family.ljust(4),
        'Alternative Name': '',
        'Fold': 'A',
        'Mechanism': mechanism,
        'Clade': clade,
        'More Info': 'http://www.cazy.org/{}.html'.format(family),
    }
    for family, mechanism, clade in _GT_FAMILY_ROWS
}
def parse_blast_output_for_best_evalue(output_file):
    """Return the E-value of the top hit in a single-query BLAST XML report.

    Args:
        output_file: Path to a BLAST result written in XML (``-outfmt 5``).

    Returns:
        The ``expect`` value of the first HSP of the first alignment, or
        ``None`` when the report contains no alignments.
    """
    with open(output_file) as xml_handle:
        record = NCBIXML.read(xml_handle)

    hits = record.alignments
    if not hits:
        # No alignment at all: the caller decides how to treat "no match".
        return None

    top_evalue = hits[0].hsps[0].expect
    print(top_evalue)
    return top_evalue
def run_local_blast(sequence, database):
    """Run ``blastp`` for one sequence and return the best E-value found.

    The query and the XML report live in a private temporary directory, so
    concurrent calls cannot overwrite each other's files and nothing is left
    behind when ``blastp`` fails.

    Args:
        sequence: Amino-acid sequence (no FASTA header) to search with.
        database: Path/name of the BLAST protein database (``-db``).

    Returns:
        The best E-value as reported by
        ``parse_blast_output_for_best_evalue``, or ``None`` when there is no
        hit at or below the ``1e-2`` E-value cut-off.

    Raises:
        subprocess.CalledProcessError: ``blastp`` exited with a non-zero code.
        FileNotFoundError: the ``blastp`` executable is not on ``PATH``.
    """
    with tempfile.TemporaryDirectory(prefix="blast_") as workdir:
        query_file = os.path.join(workdir, "temp_query.fasta")
        output_file = os.path.join(workdir, "blast_results.xml")

        with open(query_file, "w") as handle:
            handle.write(">Query\n" + sequence + "\n")

        blast_cmd = [
            "blastp",
            "-query", query_file,
            "-db", database,
            "-out", output_file,
            "-outfmt", "5",  # Output format 5 is XML
            "-evalue", "1e-2",  # Hits above this E-value are not reported
        ]
        # Argument list (no shell), so the sequence is never shell-interpreted.
        subprocess.run(blast_cmd, check=True)

        # Parse inside the ``with`` block: the directory is removed on exit.
        return parse_blast_output_for_best_evalue(output_file)
def get_family_info(family_name, family_db=None):
    """Format the metadata of a family as a one-line Markdown snippet.

    Each field is rendered as ``**Label:** value `` and the pieces are
    concatenated. The "More Info" field is rendered as Markdown link(s);
    it may hold a single URL string or an iterable of URLs.

    Args:
        family_name: Key to look up (e.g. ``"GT31"``).
        family_db: Optional mapping ``family -> {field: value}``. Defaults to
            the module-level ``glycosyltransferase_db``.

    Returns:
        The Markdown string, or ``""`` when the family is unknown.
    """
    db = glycosyltransferase_db if family_db is None else family_db
    family_info = db.get(family_name, {})

    parts = []
    for key, value in family_info.items():
        # Keys may be display names ("CAZy Name") or snake_case ("more_info").
        # Only title-case the latter so mixed-case names are not mangled
        # (str.title() would turn "CAZy" into "Cazy").
        label = key.replace("_", " ")
        if label.islower():
            label = label.title()

        if label.lower() == "more info":
            # A bare string must not be iterated character by character.
            links = [value] if isinstance(value, str) else list(value)
            rendered = " ".join("[{0}]({0})".format(link) for link in links)
        else:
            # Stored values carry alignment padding ("GT2 ", "6 ").
            rendered = str(value).strip()
        parts.append("**{}:** {} ".format(label, rendered))
    return "".join(parts)
def fig_to_img(fig):
    """Render a matplotlib figure to an in-memory PNG and open it with PIL.

    Args:
        fig: The matplotlib figure to render.

    Returns:
        A PIL image backed by the in-memory PNG buffer.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight')
    # Rewind so PIL reads the PNG from its first byte.
    png_buffer.seek(0)
    return Image.open(png_buffer)
def preprocess_protein_sequence(protein_fasta):
    """Validate a pasted sequence and choose the GT-A or GT-B model pair.

    The input may be a bare sequence or a single-record FASTA. Header lines
    are dropped, all whitespace (including Windows line endings) is removed
    and the sequence is upper-cased before validation, so the downstream
    tokenizer always sees canonical residues. The cleaned sequence is then
    searched with BLAST against the GT-A and GT-B databases; the fold with
    the lower best E-value wins (a tie selects GT-B).

    Args:
        protein_fasta: Text entered by the user.

    Returns:
        ``(sequence, family_model_path, donor_model_path, None)`` on success,
        or ``(None, None, None, message)`` when the input is rejected.
    """
    lines = [line.strip() for line in (protein_fasta or "").splitlines()]

    headers = [line for line in lines if line.startswith('>')]
    if len(headers) > 1:
        return None, None, None, "Multiple fasta sequences detected. Please upload a fasta file with only one sequence."

    protein_sequence = ''.join(
        ''.join(line.split()) for line in lines if not line.startswith('>')
    ).upper()

    if not protein_sequence:
        # Do not hand an empty query to BLAST.
        return None, None, None, "No protein sequence detected. Please enter a protein sequence."

    valid_characters = frozenset("ACDEFGHIKLMNPQRSTVWY")
    if not valid_characters.issuperset(protein_sequence):
        return None, None, None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids."

    print("Running Blast.")
    gta_db_path = "blast_data/GTA/GTA.db"
    gtb_db_path = "blast_data/GTB/GTB.db"
    no_hit = 1e+100  # Stand-in E-value when BLAST reports no alignment.

    evalue_gta = run_local_blast(protein_sequence, gta_db_path)
    if evalue_gta is None:
        evalue_gta = no_hit
    evalue_gtb = run_local_blast(protein_sequence, gtb_db_path)
    if evalue_gtb is None:
        evalue_gtb = no_hit
    print("E-value GT-A:", evalue_gta, "E-value GT-B:", evalue_gtb)
    print("Blast finished running. Checking sequence against known data.")

    # Neither database produced a hit under the threshold: reject the input
    # before picking a model.
    if evalue_gta > 1e-2 and evalue_gtb > 1e-2:
        return None, None, None, "**Warning:** The sequence does not appear to be a GT-A or GT-B. Please ensure you are submitting a sequence from these families."

    # Pick the model pair of the fold with the better (lower) E-value.
    if evalue_gta < evalue_gtb:
        model_fam, model_don = "GTA_fam.pth", "GTA_don.pth"
    else:
        model_fam, model_don = "GTB_fam.pth", "GTB_don.pth"
    print("Selected model for family:", model_fam, "and donor:", model_don)

    return protein_sequence, model_fam, model_don, None
def process_family_sequence(protein_sequence, modelfam, label_dict):
    """Predict the family of a sequence and plot the top class probabilities.

    Args:
        protein_sequence: Amino-acid sequence (tokenised with the shared
            module-level ``tokenizer``, truncated to 512 tokens).
        modelfam: Sequence-classification model whose output has ``.logits``.
        label_dict: Mapping class index -> family label.

    Returns:
        ``(label, image, None, info)`` where ``label`` is the decoded top
        class, ``image`` is a PIL bar chart of the top (up to five)
        probabilities and ``info`` is the Markdown family description,
        prefixed with a warning when the sequence is shorter than 100
        residues.
    """
    encoded_input = tokenizer(
        [protein_sequence],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = modelfam(
            encoded_input["input_ids"],
            attention_mask=encoded_input["attention_mask"],
        ).logits
    probabilities = F.softmax(logits, dim=1)

    # Single-sample prediction: one row of logits.
    predicted_index = torch.argmax(logits, dim=1).item()
    decoded_label_fam = label_dict.get(predicted_index, "Unknown Label")
    family_info = get_family_info(decoded_label_fam)

    # Never ask for more classes than the model has.
    top_k = min(5, probabilities.shape[1])
    top_probs, top_indices = torch.topk(probabilities, top_k)
    top_names = [
        label_dict.get(index, "Unknown")
        for index in top_indices.squeeze(0).tolist()
    ]
    top_percent = [prob * 100 for prob in top_probs.squeeze(0).tolist()]

    # Draw on an explicit figure/axes instead of pyplot's global "current
    # figure" so overlapping requests cannot draw into each other's plot.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        positions = np.arange(len(top_names))
        ax.barh(positions, top_percent, align='center', alpha=0.5)
        ax.set_yticks(positions)
        ax.set_yticklabels(top_names)
        ax.set_xlabel('Probability (%)')
        ax.set_title('Top {} Family Class Probabilities'.format(top_k))
        ax.set_xlim(0, 100)
        img = fig_to_img(fig)
    finally:
        # Release the figure only after it has been rendered.
        plt.close(fig)

    if len(protein_sequence) < 100:
        warning = (
            "**Warning:** The sequence is relatively short. Fragmentary and "
            "partial sequences may result in incorrect predictions. \n\n "
        )
        return decoded_label_fam, img, None, warning + family_info
    return decoded_label_fam, img, None, family_info
def process_donor_sequence(protein_sequence, modeldon, label_dict):
    """Predict the donor sugar of a sequence and plot the top probabilities.

    Args:
        protein_sequence: Amino-acid sequence (tokenised with the shared
            module-level ``tokenizer``, truncated to 512 tokens).
        modeldon: Sequence-classification model whose output has ``.logits``.
        label_dict: Mapping class index -> donor name.

    Returns:
        Always a 3-tuple ``(label, image, None)``: the decoded top class and
        a PIL bar chart of the top (up to three) probabilities. The result
        shape does not depend on sequence length; the short-sequence warning
        is reported by the family prediction instead.
    """
    encoded_input = tokenizer(
        [protein_sequence],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = modeldon(
            encoded_input["input_ids"],
            attention_mask=encoded_input["attention_mask"],
        ).logits
    probabilities = F.softmax(logits, dim=1)

    # Single-sample prediction: one row of logits.
    predicted_index = torch.argmax(logits, dim=1).item()
    decoded_label_don = label_dict.get(predicted_index, "Unknown Label")

    # Never ask for more classes than the model has.
    top_k = min(3, probabilities.shape[1])
    top_probs, top_indices = torch.topk(probabilities, top_k)
    top_names = [
        label_dict.get(index, "Unknown")
        for index in top_indices.squeeze(0).tolist()
    ]
    top_percent = [prob * 100 for prob in top_probs.squeeze(0).tolist()]

    # Explicit figure/axes: no reliance on pyplot's global current figure.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        positions = np.arange(len(top_names))
        ax.barh(positions, top_percent, align='center', alpha=0.5)
        ax.set_yticks(positions)
        ax.set_yticklabels(top_names)
        ax.set_xlabel('Probability (%)')
        ax.set_title('Top {} Donor Class Probabilities'.format(top_k))
        ax.set_xlim(0, 100)
        img = fig_to_img(fig)
    finally:
        # Release the figure only after it has been rendered.
        plt.close(fig)

    return decoded_label_don, img, None
# Classifiers already loaded in this process, keyed by weights path, so the
# ESM backbone is built and the checkpoint read only once per model.
_LOADED_CLASSIFIERS = {}


def _load_classifier(weights_path, num_labels):
    """Return the CPU, eval-mode classifier for ``weights_path`` (cached)."""
    model = _LOADED_CLASSIFIERS.get(weights_path)
    if model is None:
        model = EsmForSequenceClassification.from_pretrained(
            "facebook/esm2_t12_35M_UR50D", num_labels=num_labels
        )
        # NOTE(review): torch.load unpickles the file -- only load trusted,
        # locally shipped checkpoints here.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        # NOTE(review): strict=False ignores missing/unexpected keys, so a
        # mismatched checkpoint would load without an error -- confirm that
        # leniency is actually required.
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        model.to('cpu')
        _LOADED_CLASSIFIERS[weights_path] = model
    return model


def main_function_single(sequence):
    """Run the full single-sequence prediction used by the "Predict" button.

    Args:
        sequence: Raw text from the sequence textbox (bare sequence or a
            single-record FASTA).

    Returns:
        ``(family_label, family_image, family_info, donor_label,
        donor_image)``. When preprocessing rejects the input the tuple is
        ``(None, None, error_message, None, None)`` so the message is shown
        in the info panel.
    """
    # Initial preprocessing including BLAST-based model selection
    protein_sequence, model_fam_path, model_don_path, error_msg = preprocess_protein_sequence(sequence)
    if error_msg:
        print(error_msg)
        return None, None, error_msg, None, None

    model_config = {
        "GTA_fam.pth": {"num_labels": 41, "label_dict": GTA_fam_dict},
        "GTB_fam.pth": {"num_labels": 27, "label_dict": GTB_fam_dict},
        "GTA_don.pth": {"num_labels": 10, "label_dict": GTA_don_dict},
        "GTB_don.pth": {"num_labels": 9, "label_dict": GTB_don_dict},
    }
    config_fam = model_config[model_fam_path]
    config_don = model_config[model_don_path]

    model_fam = _load_classifier(model_fam_path, config_fam["num_labels"])
    model_don = _load_classifier(model_don_path, config_don["num_labels"])

    # Pass the label dictionary along with the model to the processing functions
    family_label, family_img, _, family_info = process_family_sequence(
        protein_sequence, model_fam, config_fam["label_dict"]
    )
    donor_label, donor_img, _ = process_donor_sequence(
        protein_sequence, model_don, config_don["label_dict"]
    )
    return family_label, family_img, family_info, donor_label, donor_img
# The two chart components are created outside the Blocks context so the
# click handler can list them as outputs before they are placed in the
# layout; they are attached further down with .render().
prediction_imagefam = gr.outputs.Image(type='pil', label="Family prediction graph")
prediction_imagedonor = gr.outputs.Image(type='pil', label="Donor prediction graph")

# NOTE(review): the nesting below is a best-effort reconstruction of the
# layout -- confirm the row/column placement against the deployed app.
with gr.Blocks() as app:
    gr.Markdown("# Glydentify (alpha v0.5)")
    with gr.Tab("Single Sequence Prediction"):
        # Input row: sequence box on the left, example and results on the right.
        with gr.Row().style(equal_height=True):
            with gr.Column():
                sequence = gr.inputs.Textbox(lines=16, placeholder='Enter Protein Sequence Here...', label="Protein Sequence")
                # explanation_checkbox = gr.inputs.Checkbox(label="Show Explanation", default=False)
            with gr.Column():
                with gr.Accordion("Example:"):
                    gr.Markdown("""
\>sp|Q9Y5Z6|B3GT1_HUMAN Beta-1,3-galactosyltransferase 1 OS=Homo sapiens OX=9606 GN=B3GALT1 PE=1 SV=1
MASKVSCLYVLTVVCWASALWYLSITRPTSSYTGSKPFSHLTVARKNFTFGNIRTRPINPHSFEFLINEPNKCEKNIPFLVILIST
THKEFDARQAIRETWGDENNFKGIKIATLFLLGKNADPVLNQMVEQESQIFHDIIVEDFIDSYHNLTLKTLMGMRWVATFCSK
AKYVMKTDSDIFVNMDNLIYKLLKPSTKPRRRYFTGYVINGGPIRDVRSKWYMPRDLYPDSNYPPFCSGTGYIFSADVAELIYK
TSLHTRLLHLEDVYVGLCLRKLGIHPFQNSGFNHWKMAYSLCRYRRVITVHQISPEEMHRIWNDMSSKKHLRC
""")
                family_prediction = gr.outputs.Textbox(label="Predicted family")
                donor_prediction = gr.outputs.Textbox(label="Predicted donor")
                # Shows the family description, or the error/warning message.
                info_markdown = gr.Markdown()
        # Predict button row
        with gr.Row().style(equal_height=True):
            with gr.Column():
                predict_button = gr.Button("Predict")
                # Output order must match the tuple returned by
                # main_function_single.
                predict_button.click(main_function_single, inputs=[sequence],
                                     outputs=[family_prediction, prediction_imagefam, info_markdown,
                                              donor_prediction, prediction_imagedonor])
        # Family & Donor Section
        with gr.Row().style(equal_height=True):
            with gr.Column():
                with gr.Accordion("Family Prediction:"):
                    prediction_imagefam.render()  # place the pre-built family chart here
            with gr.Column():
                with gr.Accordion("Donor Prediction:"):
                    prediction_imagedonor.render()  # place the pre-built donor chart here

# show_error=True makes exceptions raised by the handler visible in the UI.
app.launch(show_error=True)