|
|
|
|
|
import os |
|
import sys |
|
from urllib import request |
|
|
|
import gradio as gr |
|
import requests |
|
from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmModel, AutoModel |
|
import torch |
|
import progres as pg |
|
|
|
|
|
# Hugging Face checkpoint identifiers, one per embedding model used below.
_NT_CHECKPOINT = "InstaDeepAI/nucleotide-transformer-500m-1000g"
_AA_CHECKPOINT = "facebook/esm2_t12_35M_UR50D"
_SE_CHECKPOINT = "sentence-transformers/all-mpnet-base-v2"

# Nucleotide sequences: tokenizer + masked-LM model, switched to eval mode.
tokenizer_nt = AutoTokenizer.from_pretrained(_NT_CHECKPOINT)
model_nt = AutoModelForMaskedLM.from_pretrained(_NT_CHECKPOINT)
model_nt.eval()

# Amino-acid sequences: tokenizer + ESM model, switched to eval mode.
tokenizer_aa = AutoTokenizer.from_pretrained(_AA_CHECKPOINT)
model_aa = EsmModel.from_pretrained(_AA_CHECKPOINT)
model_aa.eval()

# Free-text sentences: tokenizer + generic AutoModel, switched to eval mode.
tokenizer_se = AutoTokenizer.from_pretrained(_SE_CHECKPOINT)
model_se = AutoModel.from_pretrained(_SE_CHECKPOINT)
model_se.eval()
|
|
|
|
|
def nt_embed(sequence: str):
    """Embed one nucleotide sequence with the module-level nucleotide model.

    The sequence is tokenized as a batch of one, run through ``model_nt``
    with hidden states enabled and gradients disabled, and the vector at
    token position 0 of the last hidden layer is returned (the original
    code names this the CLS vector).

    Args:
        sequence: Nucleotide sequence text.

    Returns:
        A detached 1-D tensor taken from the last hidden layer, batch item 0,
        token position 0.
    """
    encoded = tokenizer_nt.batch_encode_plus([sequence], return_tensors="pt")
    input_ids = encoded["input_ids"]
    # Every position that is not the pad token takes part in attention.
    not_padding = input_ids != tokenizer_nt.pad_token_id

    with torch.no_grad():
        outputs = model_nt(
            input_ids,
            attention_mask=not_padding,
            output_hidden_states=True,
        )

    final_layer = outputs.hidden_states[-1].detach()
    # Batch item 0, token position 0, all hidden features.
    return final_layer[:, 0, :][0]
|
|
|
|
|
def aa_embed(sequence: str):
    """Embed one amino-acid sequence with the module-level ESM model.

    Args:
        sequence: Amino-acid sequence text.

    Returns:
        The first element of the ``model_aa`` output for a batch of one,
        computed with gradients disabled.
        NOTE(review): presumably the per-token last hidden state - confirm
        against the transformers ``EsmModel`` output documentation.
    """
    encoded = tokenizer_aa([sequence], return_tensors="pt")
    with torch.no_grad():
        outputs = model_aa(**encoded)
    return outputs[0]
|
|
|
|
|
def se_embed(sentence: str):
    """Embed one sentence with the module-level sentence model.

    Args:
        sentence: Free text to embed.

    Returns:
        The first element of the ``model_se`` output for a batch of one,
        computed with gradients disabled. No pooling is applied here.
        NOTE(review): presumably per-token hidden states - confirm against
        the transformers model output documentation.
    """
    batch = tokenizer_se([sentence], return_tensors='pt')
    with torch.no_grad():
        outputs = model_se(**batch)
    return outputs[0]
|
|
|
|
|
def download_data_if_required():
    """Download the progres trained model on first use and sanity-check it.

    For every expected file that is missing, the file is fetched from the
    Zenodo record named by ``pg.zenodo_record``, loaded with ``torch.load``
    and checked for the key it must contain. Progress and failures are
    reported on stderr.

    Side effects:
        Creates ``pg.trained_model_dir`` when absent and writes the
        downloaded file(s). A file that fails to download or validate is
        removed so that the next run retries instead of trusting a partial
        file.

    Raises:
        SystemExit: With exit code 1 when a download or validation fails.
    """
    url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
    fps = [pg.trained_model_fp]
    urls = [f"{url_base}/trained_model.pt"]

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(pg.trained_model_dir, exist_ok=True)

    printed = False
    for fp, url in zip(fps, urls):
        if os.path.isfile(fp):
            continue

        if not printed:
            print("Downloading data as first time setup (~340 MB) to ", pg.progres_dir,
                  ", internet connection required, this can take a few minutes",
                  sep="", file=sys.stderr)
            printed = True

        verified = False
        try:
            request.urlretrieve(url, fp)
            # NOTE(security): torch.load unpickles the downloaded file, which
            # can execute arbitrary code; this trusts the Zenodo record.
            d = torch.load(fp, map_location="cpu")
            expected_key = "model" if fp == pg.trained_model_fp else "embeddings"
            # Explicit check rather than assert: asserts vanish under -O.
            if expected_key not in d:
                raise ValueError(f"missing key {expected_key!r} in downloaded file")
            verified = True
        except Exception as exc:
            # Exception (not a bare except) lets Ctrl-C / SystemExit propagate
            # instead of being reported as a failed download.
            print("Failed to download from", url, "and save to", fp, file=sys.stderr)
            print("Reason:", repr(exc), file=sys.stderr)
            print("Exiting", file=sys.stderr)
            sys.exit(1)
        finally:
            # Runs for failures and interrupts alike: never leave a partial
            # or invalid file that a later run would mistake for a good one.
            if not verified and os.path.isfile(fp):
                os.remove(fp)

    if printed:
        print("Data downloaded successfully", file=sys.stderr)
|
|
|
|
|
def get_pdb(pdb_code="", filepath=""):
    """Return PDB text for a PDB code or, failing that, an uploaded file.

    Args:
        pdb_code: PDB identifier. When ``None``, empty or whitespace-only,
            the uploaded file is used instead.
        filepath: Object exposing the uploaded file's path as ``.name``.
            Anything without a ``.name`` attribute (``None``, ``""``) means
            "no upload".

    Returns:
        The PDB file contents as a string, or ``None`` when no code was
        given and there is no usable upload.

    Raises:
        requests.HTTPError: The RCSB server answered with an error status
            (for example an unknown PDB code).
        requests.RequestException: Other network failures, including the
            30 second timeout.
    """
    code = (pdb_code or "").strip()
    if not code:
        # Keep the try body to the one expression that signals "no upload".
        try:
            path = filepath.name
        except AttributeError:
            return None
        with open(path) as handle:
            return handle.read()

    # A timeout stops a stalled connection from hanging the request forever.
    response = requests.get(f"https://files.rcsb.org/view/{code}.pdb", timeout=30)
    # Without this, an HTML error page would be returned as if it were PDB text.
    response.raise_for_status()
    return response.content.decode()
|
|
|
|
|
def molecule(pdb):
    """Build an ``<iframe>`` that renders PDB text with 3Dmol.js.

    The PDB text is embedded in a small HTML document (jQuery + 3Dmol.js
    loaded from CDNs) which is placed in the iframe's ``srcdoc`` attribute.

    The text is untrusted (user upload or remote download), so it is escaped
    at both levels: as a JSON/JavaScript string literal inside the script,
    and as an HTML attribute value for ``srcdoc``. Previously a single quote
    in the PDB text (e.g. nucleotide atom names such as ``O5'``) ended the
    ``srcdoc`` attribute early, and a backtick or ``</script>`` broke out of
    the script.

    Args:
        pdb: PDB file contents as a string. ``None`` is treated as empty.

    Returns:
        HTML markup for the iframe as a string.
    """
    import html
    import json

    pdb_text = "" if pdb is None else pdb
    # A JSON string is a valid JavaScript string literal. Escaping <, > and &
    # as well keeps "</script>" or "<!--" in the data from ending the script.
    pdb_literal = (
        json.dumps(pdb_text)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )

    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
    font-family:sans-serif
}
.mol-container {
    width: 100%;
    height: 600px;
    position: relative;
}
.mol-container select{
    background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>

<script>
let pdb = """
        + pdb_literal
        + """;

$(document).ready(function () {
    let element = $("#container");
    let config = { backgroundColor: "black" };
    let viewer = $3Dmol.createViewer(element, config);
    viewer.addModel(pdb, "pdb");
    viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
    viewer.addSurface("MS", { opacity: .5, color: "white" });
    viewer.zoomTo();
    viewer.render();
    viewer.zoom(0.8, 2000);
})
</script>
</body></html>"""
    )

    # The browser decodes character references in attribute values, so the
    # escaped document is restored before the iframe parses it.
    srcdoc = html.escape(x, quote=True)

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
|
|
|
|
|
def str2coords(s):
    """Extract alpha-carbon coordinates from PDB-format text.

    Lines are read until the first ``ENDMDL`` record, so only the first
    model of a multi-model file is used. A record contributes when it starts
    with ``ATOM `` or ``HETATM`` and its atom-name field (columns 13-16) is
    ``CA``.

    Calcium ions are also written as ``HETATM`` records with atom name
    ``CA``; those are recognised by their residue name ``CA`` (columns
    18-20) and skipped so they are not mistaken for alpha carbons.

    Args:
        s: PDB file contents as a string. ``None`` or ``""`` yields ``[]``.

    Returns:
        A list of ``[x, y, z]`` float triples in file order.

    Raises:
        ValueError: A matching record has non-numeric coordinate columns.
    """
    coords = []
    if not s:
        return coords

    for line in s.split('\n'):
        if line.startswith("ENDMDL"):
            break
        if not line.startswith(("ATOM ", "HETATM")):
            continue
        if line[12:16].strip() != "CA":
            continue
        if line.startswith("HETATM") and line[17:20].strip() == "CA":
            # Calcium ion, not an alpha carbon.
            continue
        coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
    return coords
|
|
|
|
|
def update_st(inp, file):
    """Gradio callback for the structure tab: viewer HTML plus embedding.

    Args:
        inp: Contents of the PDB-code textbox.
        file: Value of the file-upload component.

    Returns:
        A 2-tuple ``(viewer_html, embedding)``. When no structure could be
        obtained, a short HTML notice and an empty embedding string are
        returned instead.
    """
    pdb = get_pdb(inp, file)
    if pdb is None:
        # get_pdb returns None when there is neither a code nor an upload;
        # passing that on would fail in molecule() with a TypeError.
        return ("<p>No structure provided: enter a PDB code or upload a PDB file.</p>", "")
    return (molecule(pdb), pg.embed_coords(str2coords(pdb)))
|
|
|
|
|
def update_nt(inp):
    """Gradio callback: printable nucleotide embedding for the textbox value.

    A falsy value (``None`` or ``""``) is embedded as the empty string.
    """
    sequence = inp if inp else ''
    embedding = nt_embed(sequence)
    return str(embedding)
|
|
|
|
|
def update_aa(inp):
    """Gradio callback: printable amino-acid embedding for the textbox value.

    A falsy value (``None`` or ``""``) is embedded as the empty string,
    matching ``update_nt``, so ``None`` never reaches the tokenizer.
    """
    return str(aa_embed(inp or ''))
|
|
|
|
|
def update_se(inp):
    """Gradio callback: printable sentence embedding for the textbox value.

    A falsy value (``None`` or ``""``) is embedded as the empty string,
    matching ``update_nt``, so ``None`` never reaches the tokenizer.
    """
    return str(se_embed(inp or ''))
|
|
|
|
|
# Gradio UI: one Blocks app with four tabs (structure, nucleotide, amino acid,
# sentence). Each tab wires a button click to the matching update_* callback.
# NOTE(review): indentation was lost in this copy of the file, so how the
# widgets nest under the Row/Box containers cannot be confirmed from here.
# The names inp/btn/emb are rebound in every tab; each click handler is
# registered before the next rebinding.
demo = gr.Blocks() |
|
|
|
with demo: |
|
with gr.Tabs(): |
|
# Tab 1: PDB code or uploaded file -> 3Dmol viewer HTML and embedding.
with gr.TabItem("PDB Structural Embeddings"): |
|
with gr.Row(): |
|
with gr.Box(): |
|
inp = gr.Textbox( |
|
placeholder="PDB Code or upload file below", label="Input structure" |
|
) |
|
file = gr.File(file_count="single") |
|
gr.Examples(["2CBA", "6VXX"], inp) |
|
btn = gr.Button("View structure") |
|
gr.Markdown("# PDB viewer using 3Dmol.js") |
|
mol = gr.HTML() |
|
emb = gr.Textbox(interactive=False) |
|
btn.click(fn=update_st, inputs=[inp, file], outputs=[mol, emb]) |
|
# Tab 2: nucleotide sequence -> embedding text via update_nt.
with gr.TabItem("Nucleotide Sequence Embeddings"): |
|
with gr.Box(): |
|
inp = gr.Textbox( |
|
placeholder="ATCGCTGCCCGTAGATAATAAGAGACACTGAGGCC", label="Input Nucleotide Sequence" |
|
) |
|
btn = gr.Button("View embeddings") |
|
emb = gr.Textbox(interactive=False) |
|
btn.click(fn=update_nt, inputs=[inp], outputs=emb) |
|
# Tab 3: amino-acid sequence -> embedding text via update_aa.
with gr.TabItem("Amino Acid Sequence Embeddings"): |
|
with gr.Box(): |
|
inp = gr.Textbox( |
|
placeholder="AAGQCYRGRCSGGLCCSKYGYCGSGPAYCG", label="Input Amino Acid Sequence" |
|
) |
|
btn = gr.Button("View embeddings") |
|
emb = gr.Textbox(interactive=False) |
|
btn.click(fn=update_aa, inputs=[inp], outputs=emb) |
|
# Tab 4: free-text sentence -> embedding text via update_se.
with gr.TabItem("Sentence Embeddings"): |
|
with gr.Box(): |
|
inp = gr.Textbox( |
|
placeholder="Your text here", label="Input Sentence" |
|
) |
|
btn = gr.Button("View embeddings") |
|
emb = gr.Textbox(interactive=False) |
|
btn.click(fn=update_se, inputs=[inp], outputs=emb) |
|
|
|
# Script entry point: fetch the progres model if needed, then serve the UI.
# NOTE(review): whether demo.launch() sits inside this guard cannot be
# confirmed because the indentation is missing - verify against the original.
if __name__ == "__main__": |
|
download_data_if_required() |
|
demo.launch() |