# credit: https://huggingface.co/spaces/simonduerr/3dmol.js/blob/main/app.py
"""Gradio demo exposing four embedding tabs.

Tabs: PDB structure (via ``progres``), nucleotide sequence, amino-acid
sequence and sentence embeddings (via Hugging Face ``transformers`` models).
"""
import os
import sys
from urllib import request

import gradio as gr
import progres as pg
import requests
import torch
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, EsmModel

# Seconds to wait for the RCSB file server before giving up.
_REQUEST_TIMEOUT_S = 30

# Models are loaded once at import time and put in eval mode.
tokenizer_nt = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model_nt = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
model_nt.eval()

tokenizer_aa = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
model_aa = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
model_aa.eval()

tokenizer_se = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model_se = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model_se.eval()


def nt_embed(sequence: str):
    """Embed a nucleotide sequence with the nucleotide-transformer model.

    Returns the final hidden layer's vector at token position 0 for the single
    input sequence (the original code names this the CLS vector).
    """
    tokens_ids = tokenizer_nt.batch_encode_plus([sequence], return_tensors="pt")["input_ids"]
    # Attend to every position that is not padding.
    attention_mask = tokens_ids != tokenizer_nt.pad_token_id
    with torch.no_grad():
        torch_outs = model_nt(
            tokens_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
    last_layer_CLS = torch_outs.hidden_states[-1].detach()[:, 0, :][0]
    return last_layer_CLS


def aa_embed(sequence: str):
    """Run the ESM-2 model on an amino-acid sequence.

    Returns the first element of the model output (presumably the last hidden
    state -- confirm against the transformers model output documentation).
    """
    tokens = tokenizer_aa([sequence], return_tensors="pt")
    with torch.no_grad():
        torch_outs = model_aa(**tokens)
    return torch_outs[0]


def se_embed(sentence: str):
    """Run the sentence-transformers model on a sentence.

    Returns the first element of the model output (presumably per-token
    embeddings, not a pooled sentence vector -- confirm).
    """
    encoded_input = tokenizer_se([sentence], return_tensors='pt')
    with torch.no_grad():
        model_output = model_se(**encoded_input)
    return model_output[0]


def _remove_if_exists(fp):
    """Delete ``fp`` if it is an existing file, so no partial download is kept."""
    if os.path.isfile(fp):
        os.remove(fp)


def download_data_if_required():
    """Download the progres trained model on first use.

    For each missing file, fetches it from the Zenodo record, loads it to check
    that it holds the expected top-level key, and on any failure removes the
    file, reports to stderr and exits the process with status 1.
    """
    url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
    # Only the trained model is fetched; downloading pre-embedded databases is
    # disabled in this app.
    fps = [pg.trained_model_fp]
    urls = [f"{url_base}/trained_model.pt"]

    os.makedirs(pg.trained_model_dir, exist_ok=True)

    printed = False
    for fp, url in zip(fps, urls):
        if os.path.isfile(fp):
            continue
        if not printed:
            print("Downloading data as first time setup (~340 MB) to ", pg.progres_dir,
                  ", internet connection required, this can take a few minutes",
                  sep="", file=sys.stderr)
            printed = True
        try:
            request.urlretrieve(url, fp)
            # NOTE(review): torch.load unpickles the downloaded file; this
            # trusts the contents of the Zenodo record.
            d = torch.load(fp, map_location="cpu")
            expected_key = "model" if fp == pg.trained_model_fp else "embeddings"
            # Explicit check instead of `assert`, which is stripped under -O.
            if expected_key not in d:
                raise ValueError(f"downloaded file {fp} has no {expected_key!r} entry")
        except Exception:
            _remove_if_exists(fp)
            print("Failed to download from", url, "and save to", fp, file=sys.stderr)
            print("Exiting", file=sys.stderr)
            sys.exit(1)
        except BaseException:
            # Ctrl-C or interpreter exit: still discard the partial file, but
            # let the interruption propagate instead of reporting a failure.
            _remove_if_exists(fp)
            raise

    if printed:
        print("Data downloaded successfully", file=sys.stderr)


def get_pdb(pdb_code="", filepath=""):
    """Return PDB text for a PDB code or an uploaded file.

    With a non-empty ``pdb_code`` the file is fetched from files.rcsb.org;
    otherwise ``filepath.name`` is opened and read.

    Returns:
        The PDB text, or ``None`` when no code is given and ``filepath`` has no
        ``name`` attribute (e.g. nothing was uploaded).

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status (e.g. unknown PDB code).
    """
    if not pdb_code:
        try:
            with open(filepath.name) as f:
                return f.read()
        except AttributeError:
            return None

    response = requests.get(
        f"https://files.rcsb.org/view/{pdb_code}.pdb", timeout=_REQUEST_TIMEOUT_S
    )
    # Without this an error page would be returned as if it were PDB text.
    response.raise_for_status()
    return response.content.decode()


def molecule(pdb):
    """Return the HTML snippet for the structure viewer.

    NOTE(review): the HTML template (3Dmol.js page and iframe wrapper) appears
    to have been stripped from this copy of the file; as written the function
    ignores ``pdb`` and returns an empty string. Restore from the credited
    upstream app.
    """
    x = (
        """
"""
    )
    return f""""""


def str2coords(s):
    """Extract CA atom coordinates from PDB text.

    Scans ATOM/HETATM records whose atom-name field (columns 13-16) is "CA" and
    reads x, y, z from the fixed-width columns; stops at the first ENDMDL
    record so only the first model is used.

    Returns:
        A list of ``[x, y, z]`` float lists.
    """
    coords = []
    for line in s.split('\n'):
        if line.startswith(("ATOM ", "HETATM")) and line[12:16].strip() == "CA":
            coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
        elif line.startswith("ENDMDL"):
            break
    return coords


def update_st(inp, file):
    """Gradio handler: return (viewer HTML, embedding text) for a structure."""
    pdb = get_pdb(inp, file)
    if pdb is None:
        # Previously this fell through to str2coords(None) and crashed with an
        # opaque AttributeError.
        raise gr.Error("Enter a PDB code or upload a PDB file.")
    # str() keeps the output consistent with the other tab handlers.
    return (molecule(pdb), str(pg.embed_coords(str2coords(pdb))))


def update_nt(inp):
    """Gradio handler: return the nucleotide embedding as text."""
    return str(nt_embed(inp or ''))


def update_aa(inp):
    """Gradio handler: return the amino-acid embedding as text."""
    return str(aa_embed(inp or ''))


def update_se(inp):
    """Gradio handler: return the sentence embedding as text."""
    return str(se_embed(inp or ''))


demo = gr.Blocks()

# NOTE(review): the layout nesting below was reconstructed from a
# whitespace-collapsed source -- confirm it against the deployed app.
with demo:
    with gr.Tabs():
        with gr.TabItem("PDB Structural Embeddings"):
            with gr.Row():
                with gr.Box():
                    inp = gr.Textbox(
                        placeholder="PDB Code or upload file below", label="Input structure"
                    )
                    file = gr.File(file_count="single")
                    gr.Examples(["2CBA", "6VXX"], inp)
                    btn = gr.Button("View structure")
            gr.Markdown("# PDB viewer using 3Dmol.js")
            mol = gr.HTML()
            emb = gr.Textbox(interactive=False)
            btn.click(fn=update_st, inputs=[inp, file], outputs=[mol, emb])
        with gr.TabItem("Nucleotide Sequence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="ATCGCTGCCCGTAGATAATAAGAGACACTGAGGCC",
                    label="Input Nucleotide Sequence"
                )
                btn = gr.Button("View embeddings")
            emb = gr.Textbox(interactive=False)
            btn.click(fn=update_nt, inputs=[inp], outputs=emb)
        with gr.TabItem("Amino Acid Sequence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="AAGQCYRGRCSGGLCCSKYGYCGSGPAYCG",
                    label="Input Amino Acid Sequence"
                )
                btn = gr.Button("View embeddings")
            emb = gr.Textbox(interactive=False)
            btn.click(fn=update_aa, inputs=[inp], outputs=emb)
        with gr.TabItem("Sentence Embeddings"):
            with gr.Box():
                inp = gr.Textbox(
                    placeholder="Your text here", label="Input Sentence"
                )
                btn = gr.Button("View embeddings")
            emb = gr.Textbox(interactive=False)
            btn.click(fn=update_se, inputs=[inp], outputs=emb)


if __name__ == "__main__":
    download_data_if_required()
    demo.launch()