# credit: https://huggingface.co/spaces/simonduerr/3dmol.js/blob/main/app.py
import os
import sys
from urllib import request
import gradio as gr
import requests
from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmModel, AutoModel
import torch
import progres as pg
tokenizer_nt = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model_nt = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model_nt.eval()
tokenizer_aa = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
model_aa = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
model_aa.eval()
tokenizer_se = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model_se = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model_se.eval()
def nt_embed(sequence: str):
    tokens_ids = tokenizer_nt.batch_encode_plus([sequence], return_tensors="pt")["input_ids"]
    attention_mask = tokens_ids != tokenizer_nt.pad_token_id
    with torch.no_grad():
        torch_outs = model_nt(
            tokens_ids,  # .to('cuda')
            attention_mask=attention_mask,  # .to('cuda')
            output_hidden_states=True
        )
    # CLS token of the last hidden layer, used as a fixed-length sequence embedding
    last_layer_CLS = torch_outs.hidden_states[-1].detach()[:, 0, :][0]
    return last_layer_CLS
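# Optional sketch (not wired into the app): embedding several nucleotide sequences in one
# forward pass. This assumes the tokenizer is asked to pad to the longest sequence in the
# batch; as above, the CLS vector of the last hidden layer is taken per sequence.
def nt_embed_batch(sequences):
    enc = tokenizer_nt.batch_encode_plus(sequences, return_tensors="pt", padding=True)
    tokens_ids = enc["input_ids"]
    attention_mask = tokens_ids != tokenizer_nt.pad_token_id
    with torch.no_grad():
        torch_outs = model_nt(tokens_ids, attention_mask=attention_mask, output_hidden_states=True)
    return torch_outs.hidden_states[-1].detach()[:, 0, :]  # one CLS vector per input sequence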
def aa_embed(sequence: str):
    tokens = tokenizer_aa([sequence], return_tensors="pt")
    with torch.no_grad():
        torch_outs = model_aa(**tokens)
    # torch_outs[0] is last_hidden_state: one embedding per residue/special token
    return torch_outs[0]
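# Optional sketch (not wired into the app): aa_embed returns one vector per residue; a common
# way to get a single fixed-length vector per protein is a mask-weighted mean over positions.
# The pooling choice here is an assumption, not something the app relies on.
def aa_mean_embed(sequence: str):
    tokens = tokenizer_aa([sequence], return_tensors="pt")
    with torch.no_grad():
        torch_outs = model_aa(**tokens)
    mask = tokens["attention_mask"].unsqueeze(-1).float()       # (1, seq_len, 1)
    summed = (torch_outs.last_hidden_state * mask).sum(dim=1)   # sum over real tokens only
    return summed / mask.sum(dim=1)                             # (1, hidden_size)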
def se_embed(sentence: str):
    encoded_input = tokenizer_se([sentence], return_tensors='pt')
    with torch.no_grad():
        model_output = model_se(**encoded_input)
    # model_output[0] is last_hidden_state: one embedding per input token
    return model_output[0]
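# Optional sketch (not wired into the app): the sentence-transformers recipe for
# all-mpnet-base-v2 is mean pooling over token embeddings followed by L2 normalisation;
# se_embed above returns the raw per-token hidden states instead.
def se_mean_embed(sentence: str):
    encoded_input = tokenizer_se([sentence], padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model_se(**encoded_input)
    mask = encoded_input["attention_mask"].unsqueeze(-1).float()
    pooled = (model_output.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    return torch.nn.functional.normalize(pooled, p=2, dim=1)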
def download_data_if_required():
    url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
    fps = [pg.trained_model_fp]
    urls = [f"{url_base}/trained_model.pt"]
    # for targetdb in pre_embedded_dbs:
    #     fps.append(os.path.join(database_dir, targetdb + ".pt"))
    #     urls.append(f"{url_base}/{targetdb}.pt")
    if not os.path.isdir(pg.trained_model_dir):
        os.makedirs(pg.trained_model_dir)
    # if not os.path.isdir(database_dir):
    #     os.makedirs(database_dir)
    printed = False
    for fp, url in zip(fps, urls):
        if not os.path.isfile(fp):
            if not printed:
                print("Downloading data as first time setup (~340 MB) to ", pg.progres_dir,
                      ", internet connection required, this can take a few minutes",
                      sep="", file=sys.stderr)
                printed = True
            try:
                request.urlretrieve(url, fp)
                d = torch.load(fp, map_location="cpu")
                if fp == pg.trained_model_fp:
                    assert "model" in d
                else:
                    assert "embeddings" in d
            except Exception:
                if os.path.isfile(fp):
                    os.remove(fp)
                print("Failed to download from", url, "and save to", fp, file=sys.stderr)
                print("Exiting", file=sys.stderr)
                sys.exit(1)
    if printed:
        print("Data downloaded successfully", file=sys.stderr)
def get_pdb(pdb_code="", filepath=""):
    if pdb_code is None or pdb_code == "":
        try:
            with open(filepath.name) as f:
                return f.read()
        except AttributeError:
            # no PDB code and no file uploaded
            return None
    else:
        return requests.get(f"https://files.rcsb.org/view/{pdb_code}.pdb").content.decode()
def molecule(pdb):
    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
    font-family:sans-serif
}
.mol-container {
    width: 100%;
    height: 600px;
    position: relative;
}
.mol-container select{
    background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = `"""
        + pdb
        + """`
$(document).ready(function () {
    let element = $("#container");
    let config = { backgroundColor: "black" };
    let viewer = $3Dmol.createViewer(element, config);
    viewer.addModel(pdb, "pdb");
    viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
    viewer.addSurface("MS", { opacity: .5, color: "white" });
    viewer.zoomTo();
    viewer.render();
    viewer.zoom(0.8, 2000);
})
</script>
</body></html>"""
    )
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
def str2coords(s):
    # collect Cα coordinates from ATOM/HETATM records of the first model only
    coords = []
    for line in s.split('\n'):
        if (line.startswith("ATOM ") or line.startswith("HETATM")) and line[12:16].strip() == "CA":
            coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
        elif line.startswith("ENDMDL"):
            break
    return coords
def update_st(inp, file):
    pdb = get_pdb(inp, file)
    return (molecule(pdb), pg.embed_coords(str2coords(pdb)))
def update_nt(inp):
    return str(nt_embed(inp or ''))
def update_aa(inp):
    return str(aa_embed(inp or ''))
def update_se(inp):
    return str(se_embed(inp or ''))
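# Optional sketch (not wired into the app): progres structure embeddings are reportedly meant
# to be compared with cosine similarity, so two structures could be scored roughly like this.
# pg.embed_coords returning a 1-D torch tensor is an assumption based on how it is used above.
def structure_similarity(pdb_a: str, pdb_b: str) -> float:
    emb_a = pg.embed_coords(str2coords(pdb_a))
    emb_b = pg.embed_coords(str2coords(pdb_b))
    return torch.nn.functional.cosine_similarity(emb_a.unsqueeze(0), emb_b.unsqueeze(0)).item()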
demo = gr.Blocks()
with demo:
    with gr.Tabs():
        with gr.TabItem("PDB Structural Embeddings"):
            with gr.Row():
                with gr.Box():
                    inp = gr.Textbox(
                        placeholder="PDB Code or upload file below", label="Input structure"
                    )
                    file = gr.File(file_count="single")
                    gr.Examples(["2CBA", "6VXX"], inp)
                    btn = gr.Button("View structure")
            gr.Markdown("# PDB viewer using 3Dmol.js")
            mol = gr.HTML()
            emb = gr.Textbox(interactive=False)
            btn.click(fn=update_st, inputs=[inp, file], outputs=[mol, emb])
        with gr.TabItem("Nucleotide Sequence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="ATCGCTGCCCGTAGATAATAAGAGACACTGAGGCC", label="Input Nucleotide Sequence"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
            btn.click(fn=update_nt, inputs=[inp], outputs=emb)
        with gr.TabItem("Amino Acid Sequence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="AAGQCYRGRCSGGLCCSKYGYCGSGPAYCG", label="Input Amino Acid Sequence"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
            btn.click(fn=update_aa, inputs=[inp], outputs=emb)
        with gr.TabItem("Sentence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="Your text here", label="Input Sentence"
                )
                btn = gr.Button("View embeddings")
                emb = gr.Textbox(interactive=False)
            btn.click(fn=update_se, inputs=[inp], outputs=emb)
if __name__ == "__main__":
    download_data_if_required()
    demo.launch()