File size: 15,782 Bytes
2b26389
72b0e49
2b26389
 
 
 
 
181fc76
2b26389
aad9fe1
9b993cf
f3ed046
 
 
aad9fe1
d0dd902
 
 
 
 
 
 
aad9fe1
d0dd902
aad9fe1
 
 
 
d0dd902
 
07ac117
aad9fe1
 
d0dd902
 
0b30831
aad9fe1
 
d0dd902
 
07ac117
aad9fe1
 
 
 
 
 
eb615db
aad9fe1
 
 
c34048a
 
aad9fe1
 
3daa625
 
 
 
f3ed046
 
 
 
 
 
 
 
 
 
 
 
eb615db
 
 
 
 
 
 
 
 
 
 
 
1a0324b
72b0e49
eb615db
3daa625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507babf
3daa625
507babf
3daa625
507babf
c8e59d5
3daa625
 
 
507babf
3daa625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3705c34
77b966b
3705c34
77b966b
3705c34
cdf31f1
507babf
2b26389
2bc812b
 
 
 
2b26389
 
 
 
 
eb615db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b16660
aad9fe1
6b16660
aad9fe1
eb615db
e95deab
eb615db
 
507babf
ab4d8af
eb615db
507babf
ab4d8af
eb615db
507babf
ab4d8af
1167137
2b26389
 
8df133a
491d478
8df133a
 
491d478
48629ab
8df133a
d638ffc
 
 
 
3e43b5a
d638ffc
 
 
 
4ce5cb4
3e43b5a
d638ffc
 
 
4ce5cb4
3e43b5a
48629ab
2b26389
 
222ff4a
 
d87c732
2b26389
c3846ee
 
 
 
 
 
 
 
 
a8642e4
 
 
 
 
 
da19b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106f1cd
 
 
 
 
 
 
 
6789f7b
106f1cd
 
 
 
 
 
 
6789f7b
2592450
 
fb28f39
 
2592450
106f1cd
a8642e4
 
c3846ee
 
8e8d15c
c3846ee
 
 
 
892748f
da19b49
c3846ee
507babf
ab4d8af
2ad7e51
da19b49
 
 
106f1cd
da19b49
 
 
 
 
 
106f1cd
 
da19b49
7a5fd5c
2b1fc35
 
848fbff
aa8d8d9
 
 
 
 
c3846ee
 
bf58ca4
 
 
 
 
 
c3846ee
eb615db
507babf
c3846ee
 
 
 
507babf
8e8d15c
 
d638ffc
bece41e
48629ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
# from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset
from data.evaluate_data.utils import Ontology
import difflib
import re
from transformers import MistralForCausalLM
from huggingface_hub import hf_hub_download
# FAPM checkpoints for the three GO ontologies, fetched from the Hugging Face Hub
# repo "wenkai/FAPM" in the order biological process, molecular function,
# cellular component. The matching Q-Former BERT weights are not downloaded
# here; get_model reads them from local 'model/*_bert.pth' files.
bp_param, mf_param, cc_param = (
    hf_hub_download(repo_id="wenkai/FAPM", filename=f"model/checkpoint_{tag}.pth")
    for tag in ("bp1", "mf2", "cc2")
)

# Load the trained model
def get_model(type='Molecule Function'):
    """Build a FAPM model and load the weights for one GO ontology.

    Args:
        type: Ontology name: 'Molecule Function', 'Biological Process' or
            'Cellular Component'. The historical misspelling
            'Cellar Component' is still accepted for backward compatibility.

    Returns:
        A ``Blip2ProteinMistral`` (ESM size '3b') with the ontology's
        checkpoint loaded, its ``Qformer.bert`` replaced by the matching
        locally stored BERT module, and moved to CUDA.

    Raises:
        ValueError: if ``type`` is not one of the names above. Previously an
            unknown name silently returned an untrained model left on the CPU.
    """
    # ontology name -> (downloaded checkpoint path, local Q-Former BERT path)
    weights = {
        'Molecule Function': (mf_param, 'model/mf2_bert.pth'),
        'Biological Process': (bp_param, 'model/bp1_bert.pth'),
        'Cellular Component': (cc_param, 'model/cc2_bert.pth'),
    }
    weights['Cellar Component'] = weights['Cellular Component']
    if type not in weights:
        raise ValueError(
            f"Unknown ontology type {type!r}; expected one of "
            "'Molecule Function', 'Biological Process', 'Cellular Component'"
        )
    checkpoint_path, bert_path = weights[type]

    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
    model.load_checkpoint(checkpoint_path)
    # The .pth file holds a whole module (it is assigned directly, not via
    # load_state_dict); it is loaded onto the CPU first, then the full model
    # is moved to the GPU.
    model.Qformer.bert = torch.load(bert_path, map_location=torch.device('cpu'))
    model.to('cuda')
    return model

# One loaded FAPM model per ontology, keyed by the display name that is also
# used for `choices`. The loader key for the cellular-component model keeps the
# spelling get_model has always been called with.
_MODEL_LOADERS = (
    ('Molecule Function', 'Molecule Function'),
    ('Biological Process', 'Biological Process'),
    ('Cellular Component', 'Cellar Component'),
)
models = {display_name: get_model(loader_key) for display_name, loader_key in _MODEL_LOADERS}

# Mistral decoder (OpenHermes-2.5-Mistral-7B) in float16, shared by every FAPM
# model through generate_caption.
mistral_model = MistralForCausalLM.from_pretrained(
    "teknium/OpenHermes-2.5-Mistral-7B",
    torch_dtype=torch.float16,
)
mistral_model.to('cuda')

# ESM-2 protein language model ('esm2_t36_3B_UR50D') plus its alphabet, used to
# embed input sequences; nn.Module.to/.eval return the module itself.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm = model_esm.to('cuda').eval()

# GO ontology graph plus lookup tables between GO ids and their descriptions.
godb = Ontology('data/go1.4-basic.obo', with_rels=True)

go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
go_des.columns = ['id', 'text']
go_des = go_des.dropna()
# Ids in the descriptions file use '_' where GO ids use ':' (presumably
# GO_0003674 -> GO:0003674); normalise them.
go_des['id'] = go_des['id'].apply(lambda go_id: go_id.replace('_', ':'))
go_obo_set = set(go_des['id'].tolist())
go_des['text'] = go_des['text'].apply(lambda desc: desc.lower())

# lowercased description -> GO id, and GO id -> lowercased description
GO_dict = {desc: go_id for desc, go_id in zip(go_des['text'], go_des['id'])}
Func_dict = {go_id: desc for go_id, desc in zip(go_des['id'], go_des['text'])}

def _term_choices(terms):
    """Map each distinct GO id in ``terms['gos']`` to its description.

    Returns a dict ``{description.lower(): description}`` used as the
    candidate set for fuzzy-matching generated captions.
    """
    descriptions = [Func_dict[go_id] for go_id in set(terms['gos'])]
    return {desc.lower(): desc for desc in descriptions}


terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
choices_mf = _term_choices(terms_mf)
terms_bp = pd.read_pickle('data/terms/bp_terms.pkl')
choices_bp = _term_choices(terms_bp)
terms_cc = pd.read_pickle('data/terms/cc_terms.pkl')
choices_cc = _term_choices(terms_cc)

# Candidate descriptions per ontology, keyed like `models`.
choices = {
    'Molecule Function': choices_mf,
    'Biological Process': choices_bp,
    'Cellular Component': choices_cc,
    }

def _esm_embedding(protein_seq, repr_layer=36, max_len=1024):
    """Embed one protein sequence with ESM and zero-pad it to ``max_len`` rows.

    Args:
        protein_seq: Amino-acid sequence (already cleaned of whitespace).
        repr_layer: ESM layer whose per-token representations are returned.
        max_len: Truncation length handed to the ESM batch converter; the
            result is also zero-padded to this many rows.

    Returns:
        A CUDA tensor with ``max_len`` rows: one per residue (the leading
        start-token row is dropped), followed by zero rows.
    """
    batch_converter = alphabet.get_batch_converter(max_len)
    _, strs, toks = batch_converter([('protein_name', protein_seq)])
    if torch.cuda.is_available():
        toks = toks.to(device="cuda", non_blocking=True)
    with torch.no_grad():
        out = model_esm(toks, repr_layers=[repr_layer], return_contacts=False)
    seq_len = min(max_len, len(strs[0]))
    # Row 0 is the prepended start token; keep only the residue rows.
    emb = out["representations"][repr_layer][0, 1:seq_len + 1].to(device="cpu")
    # Pad along the sequence axis so every input has exactly `max_len` rows.
    return F.pad(emb.t(), (0, max_len - len(emb))).t().to('cuda')


def _match_terms(prediction, candidates):
    """Map one model's generated captions onto known GO descriptions.

    Args:
        prediction: Generated text: '; '-separated Python literals, each
            expected to be a ``(text, probability)`` pair.
        candidates: Allowed lowercase descriptions (iterating a dict yields keys).

    Returns:
        Deduplicated ``'description(probability)'`` strings in generation
        order. Entries that fail to parse, or have no close match
        (difflib ratio >= 0.9), are skipped.
    """
    import ast

    matched, seen = [], set()
    for entry in prediction.split('; '):
        # literal_eval only accepts literals, so untrusted model output can
        # never execute code (the previous eval() could).
        try:
            parsed = ast.literal_eval(entry)
            txt, prob = parsed[0], parsed[1]
        except (ValueError, SyntaxError, TypeError, IndexError, KeyError):
            continue
        if not isinstance(txt, str):
            continue
        close = difflib.get_close_matches(txt.lower(), candidates, n=1, cutoff=0.9)
        if close and close[0] not in seen:
            seen.add(close[0])
            matched.append(f'{close[0]}({prob})')
    return matched


@spaces.GPU
def generate_caption(protein, prompt):
    """Predict GO terms for a protein sequence with all three FAPM models.

    Args:
        protein: Amino-acid sequence. All whitespace is removed and FASTA
            header lines (starting with '>') are ignored, so a pasted FASTA
            record also works.
        prompt: Optional taxonomy prompt (e.g. 'Homo'); lowercased before
            being passed to the models. ``None`` is sent as 'none'.

    Returns:
        Markdown listing the matched molecular-function, biological-process
        and cellular-component terms with probabilities, or a hint message
        when the input is empty or no model produced a usable prediction.
    """
    # The embedding slice length is taken from len(sequence), so characters
    # that are not residues (spaces, newlines, header text) must not be counted.
    protein_seq = ''.join(
        ''.join(line.split())
        for line in (protein or '').splitlines()
        if not line.lstrip().startswith('>')
    )
    if not protein_seq:
        return "Please input a protein sequence."

    esm_emb = _esm_embedding(protein_seq)
    prompt = 'none' if prompt is None else prompt.lower()
    samples = {'name': ['protein_name'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}

    # One list of matched terms per model, in `models` order (MF, BP, CC).
    union_pred_terms = []
    for model_id, model in models.items():
        prediction = model.generate(mistral_model, samples, length_penalty=0., num_beams=15, num_captions=10,
                                    temperature=1., repetition_penalty=1.0)
        union_pred_terms.append(_match_terms(prediction[0], choices[model_id]))

    if not any(union_pred_terms):
        # Only suggest removing the prompt when the user actually supplied one.
        if prompt.strip() not in ('', 'none'):
            return "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
        return "No available predictions for this protein, you can use other two types of model or try another sequence!"

    def bullets(terms):
        return ''.join('- ' + term + '\n' for term in terms)

    mf_terms, bp_terms, cc_terms = union_pred_terms
    res_str = ''
    if mf_terms:
        res_str += f"Based on the given amino acid sequence, the protein appears to have a primary function of \n{bullets(mf_terms)} \n"
    if bp_terms:
        res_str += f"It is likely involved in the following process: \n{bullets(bp_terms)} \n"
    if cc_terms:
        res_str += f"It's subcellular localization is within the: \n{bullets(cc_terms)}"
    return res_str


def save_feedback(inputs):
    """Append one piece of user feedback as a line in ``feedback.txt``.

    Args:
        inputs: Feedback text from the UI textbox. ``None`` or blank text is
            acknowledged but not written, so the file holds no empty entries.

    Returns:
        The thank-you message shown in the UI.
    """
    print(inputs)
    if inputs and inputs.strip():
        # Explicit UTF-8 so non-ASCII feedback does not depend on the host locale.
        with open('feedback.txt', 'a', encoding='utf-8') as f:
            f.write(inputs + '\n')
    return "Thanks your advice!"
        

# In-memory log of (vote_id, "upvote" | "downvote") tuples recorded by the vote handlers.
feedback_data = []
def chatbot_respond(message, history=None):
    """Placeholder chat handler that always answers "yes".

    Args:
        message: The user's message.
        history: Previous (message, response) pairs; never modified. Defaults
            to an empty history (``None`` avoids a shared mutable default).

    Returns:
        Tuple ``(response, new_history)``, where ``new_history`` is a new list
        with this exchange appended.
    """
    response = "yes"
    history = [] if history is None else list(history)
    return response, history + [(message, response)]

# Like/dislike handlers for the vote buttons. In this file the UI passes the
# input protein sequence as `vote_id`.
def _record_vote(vote_id, kind):
    """Append ``(vote_id, kind)`` to feedback_data and print the whole log."""
    feedback_data.append((vote_id, kind))
    print(f"Current feedback data: {feedback_data}")


def upvote(vote_id):
    """Record a like for ``vote_id`` and return the confirmation text."""
    _record_vote(vote_id, "upvote")
    return "You liked this prediction"


def downvote(vote_id):
    """Record a dislike for ``vote_id`` and return the confirmation text."""
    _record_vote(vote_id, "downvote")
    return "You disliked this prediction"

        
# Markdown shown at the top of the FAPM interface.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
Our paper is available at [BioRxiv](https://www.biorxiv.org/content/10.1101/2024.05.07.593067v1)
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main).
Thanks for the support from ProtonUnfold Tech.  Co., Ltd (https://www.protonunfold.com/)."""

# iface = gr.Interface(
#     fn=generate_caption,
#     inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
#     outputs=gr.Textbox(label="Generated description"),
#     description=description
# )
# # Launch the interface
# iface.launch()

# Stylesheet for the Blocks UI. Only .submit-btn, .vote-container and .vote-btn
# are attached via elem_classes in this file; #output and the other rules are
# not referenced yet and are kept for future styling.
css = """
    #output {
        height: 500px;
        overflow: auto;
        border: 1px solid #ccc;
    }
    .submit-btn {
        display: flex;
        width: 100%;
        gap: 10px;
    }

    .vote-container {
        display: flex;
        width: 100%;
        gap: 10px;
    }

    .vote-buttons {
        display: flex;
        width: 100%;
        gap: 10px;
    }

    .vote-btn {
        display: flex;
        width: 100%;
        gap: 10px;
    }

    .vote-content {
        flex: 3;
    }
    /* Style for the upvote button */
    .upvote-button {
        width: 50px; /* Set button width (matches .downvote-button) */
        height: 50px; /* Set button height */
        font-size: 20px; /* Set font size */
        background-color: #d4edda; /* Set background color */
        border-radius: 5px; /* Rounded corners */
    }

    /* Style for the downvote button */
    .downvote-button {
        width: 50px; /* Set button width */
        height: 50px; /* Set button height */
        font-size: 20px; /* Set font size */
        background-color: #f8d7da; /* Set background color */
        border-radius: 5px; /* Rounded corners */
    }
    .feedback {
        width: 40px; /* Set button width */
        height: 40px; /* Set button height */
        font-size: 16px; /* Set font size */
        background-color: #f8d7da; /* Set background color */
    }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    # vote_id = gr.State(0)
    with gr.Tab(label="Protein caption"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
                submit_btn = gr.Button(value="Submit", elem_classes=["submit-btn"])
            with gr.Column():
                # output_text = gr.Textbox(label="Output Text")
                with gr.Accordion('Prediction:', open=True):
                    output_markdown = gr.Markdown(label="Output")
        # Container for the vote buttons and the vote result text
        with gr.Row(elem_classes=["vote-container"]):
            # Vote button group (labels were mis-encoded UTF-8; restored to 👍/👎)
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        upvote_button = gr.Button("👍", elem_classes=["vote-btn"])
                    with gr.Column():
                        downvote_button = gr.Button("👎", elem_classes=["vote-btn"])

            with gr.Column():
                vote_markdown = gr.Markdown(label="Output")

        with gr.Row():
            inputs = gr.Textbox(type="text", label="Your feedback")
            feedback_markdown = gr.Markdown(label="Output")
        with gr.Row():
            with gr.Column():
                feedback_btn = gr.Button(value="Feedback")
                # feedback_temp1 = gr.Markdown(label="Output")
            with gr.Column():
                feedback_temp2 = gr.Markdown(label="Output")
        # cache_examples=True runs generate_caption on every example when the
        # app is built, so example clicks return instantly.
        gr.Examples(
            examples=[
                ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
                ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
                ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
                ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
                ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
                ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
            ],
            inputs=[input_protein, prompt],
            outputs=[output_markdown],
            fn=generate_caption,
            cache_examples=True,
            label='Try examples'
        )
        submit_btn.click(generate_caption, [input_protein, prompt], [output_markdown])
        # The vote handlers receive the current input sequence as their vote id.
        upvote_button.click(upvote, input_protein, vote_markdown)
        downvote_button.click(downvote, input_protein, vote_markdown)
        feedback_btn.click(save_feedback, [inputs], [feedback_markdown])

demo.launch(debug=True)