File size: 6,566 Bytes
40e4418
7bed91f
2db50b5
d34229c
 
2db50b5
46247a1
40e4418
 
7bed91f
 
40e4418
2db50b5
3145888
e43f483
7bed91f
 
cc2a0eb
 
cded25d
7bed91f
46247a1
701a70b
2db50b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2a0eb
2db50b5
 
cc2a0eb
2db50b5
 
 
cded25d
2db50b5
cded25d
 
2db50b5
 
46247a1
 
2db50b5
 
 
cded25d
cc2a0eb
7bed91f
 
 
754548a
 
 
 
cded25d
7bed91f
 
06ff1ff
7bed91f
9b95f10
cded25d
8d6b3c3
2db50b5
40e4418
 
 
 
 
891f199
117c8b1
40e4418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937de7f
 
40e4418
 
 
 
 
 
 
 
937de7f
40e4418
 
 
 
 
937de7f
 
40e4418
 
 
 
117c8b1
2db50b5
40e4418
 
 
2db50b5
4351e0d
 
1fa2af5
4351e0d
caad63d
2db50b5
caad63d
cded25d
caad63d
2db50b5
614af9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f16285
 
2db50b5
 
 
 
 
 
 
 
9b95f10
2db50b5
46247a1
2db50b5
9b95f10
2db50b5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import gradio as gr
import os
import csv
import json
import uuid
import random
import pickle
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from googleapiclient.discovery import build
from google.oauth2 import service_account

# Identifier written into every vote row. It is generated once at import time,
# NOTE(review): so every browser session served by this process records the same USER_ID.
USER_ID = uuid.uuid4()
# Service-account credentials for the results spreadsheet, supplied as a JSON
# string in the GOOGLE_SHEET_CREDENTIALS environment variable. If the variable
# is unset, os.environ.get returns None and json.loads(None) raises TypeError.
SERVICE_ACCOUNT_JSON = os.environ.get('GOOGLE_SHEET_CREDENTIALS')
creds = service_account.Credentials.from_service_account_info(json.loads(SERVICE_ACCOUNT_JSON))
# Spreadsheet that send_result() appends vote rows to.
SPREADSHEET_ID = '1o0iKPxWYKYKEPjqB2YwrTgrLzvGyb9ULj9tnw_cfJb0'
service = build('sheets', 'v4', credentials=creds)
# Current matchup (index names) and current abstract. These are module globals
# rewritten by get_matchup()/send_result().
# NOTE(review): presumably shared by all concurrent users of the app, so one
# user's click can change what another user's vote is attributed to.
LEFT_MODEL = None
RIGHT_MODEL = None
PROMPT = None

# Pool of abstracts shown to voters, frozen into a tuple.
# NOTE(review): pickle.load can execute arbitrary code; only ship a trusted article_list.pkl.
with open("article_list.pkl","rb") as articles:
    article_list = tuple(pickle.load(articles))
# Local FAISS index folder names; INDEXES[i] is paired with MODELS[i] via zip() below.
INDEXES = ["miread_large", "miread_contrastive", "scibert_contrastive"]
# Hugging Face model ids used to embed queries for the matching index.
MODELS = [
    "biodatlab/MIReAD-Neuro-Large",
    "biodatlab/MIReAD-Neuro-Contrastive",
    "biodatlab/SciBERT-Neuro-Contrastive",
]
# Run the embedding models on CPU.
model_kwargs = {'device': 'cpu'}
# Passed through to the embedder's encode call; embeddings are not normalised.
encode_kwargs = {'normalize_embeddings': False}
faiss_embedders = [HuggingFaceEmbeddings(
    model_name=name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs) for name in MODELS]

# One vector store per index, each loaded from the local folder named in INDEXES
# and queried with the embedder at the same position in MODELS.
vecdbs = [FAISS.load_local(index_name, faiss_embedder)
          for index_name, faiss_embedder in zip(INDEXES, faiss_embedders)]

def get_matchup():
    """
    Draw two distinct index names at random for a head-to-head comparison.

    The drawn pair is also stored in the LEFT_MODEL / RIGHT_MODEL globals so
    that send_result() can attribute a vote to the right index later.

    Returns a ``(left, right)`` tuple of index names.
    """
    global LEFT_MODEL, RIGHT_MODEL
    pair = random.sample(INDEXES, 2)
    LEFT_MODEL, RIGHT_MODEL = pair
    return LEFT_MODEL, RIGHT_MODEL

def get_comp(prompt):
    """
    Draw a fresh matchup and run both sides on the current abstract.

    NOTE(review): the ``prompt`` argument (the session's gr.State value) is
    not used; the query is the module-level PROMPT global.

    Returns ``(left_output, right_output)``, the Dataframe updates produced by
    inference() for the left and right index respectively.
    """
    left, right = get_matchup()
    # Left runs first, then right, matching the output order of the UI.
    return tuple(inference(PROMPT, side) for side in (left, right))

def get_article():
    """Return one abstract chosen uniformly at random from ``article_list``."""
    pool = article_list
    return random.choice(pool)


def send_result(l_output, r_output, prompt, pick):
    """
    Record a vote in the results spreadsheet and move on to a new abstract.

    l_output, r_output, prompt -- values of the Gradio inputs wired to the vote
        buttons. They are currently unused: the recorded row is built from the
        PROMPT / LEFT_MODEL / RIGHT_MODEL globals.
    pick -- 'left' credits LEFT_MODEL; any other value credits RIGHT_MODEL.

    Appends the row ``[USER_ID, PROMPT, LEFT_MODEL, RIGHT_MODEL, winner]``
    (all as strings) to the sheet, then draws a new abstract into PROMPT.

    Returns the new abstract twice: once for the Abstract textbox and once for
    the ``prompt`` gr.State.
    """
    global PROMPT
    winner = LEFT_MODEL if pick == 'left' else RIGHT_MODEL
    row = [str(x) for x in (USER_ID, PROMPT, LEFT_MODEL, RIGHT_MODEL, winner)]
    result = service.spreadsheets().values().append(
        spreadsheetId=SPREADSHEET_ID,
        range='A1:E1',
        valueInputOption='RAW',
        body={'values': [row]},
    ).execute()
    print(f"Appended {result['updates']['updatedCells']} cells.")
    new_prompt = get_article()
    PROMPT = new_prompt
    # A gr.State output stores the plain value the handler returns, so return
    # the string itself rather than an update dict.
    return new_prompt, new_prompt


def get_matches(query, db_name="miread_contrastive"):
    """
    Run a top-30 similarity search for ``query`` against the named index.

    ``db_name`` must be one of INDEXES; it selects the vector store at the same
    position in ``vecdbs``. An unknown name raises ValueError from list.index.

    Returns the store's ``(document, score)`` pairs.
    """
    store = vecdbs[INDEXES.index(db_name)]
    hits = store.similarity_search_with_score(query, k=30)
    return hits


def inference(query, model="miread_contrastive"):
    """
    Search the ``model`` index for ``query`` and format the hits as a table.

    Each kept hit becomes a row ``[No., Score, Name, Title, Link, Date]``, where
    Name is the first listed author (title-cased) and Score is the raw search
    score rescaled so that the lowest raw score maps to 1.0. At most two rows
    per author are kept, and at most ten rows are returned. Missing ``link`` /
    ``date`` metadata is shown as the string 'None'.

    Returns - a ``gr.Dataframe.update`` carrying the rows, made visible.
    """
    max_rows = 10
    per_author_limit = 2

    matches = get_matches(query, model)
    if not matches:
        # Nothing to rank (e.g. an empty index); min()/max() would raise below.
        return gr.Dataframe.update(value=[], visible=True)

    raw_scores = [round(match[1].item(), 3) for match in matches]
    min_score = min(raw_scores)
    max_score = max(raw_scores)

    def normaliser(x):
        # NOTE(review): this divides by max_score, not (max_score - min_score),
        # so scores span [min/max, 1] rather than [0, 1]. Confirm that is intended.
        if max_score == 0:
            # All scores are 0 (every hit ties); avoid ZeroDivisionError.
            return 1.0
        return round(1 - (x - min_score) / max_score, 3)

    auth_counts = {}
    n_table = []
    for (doc, _), raw in zip(matches, raw_scores):
        author = doc.metadata['authors'][0].title()
        seen = auth_counts.get(author, 0)
        if seen >= per_author_limit:
            continue
        auth_counts[author] = seen + 1
        n_table.append([
            len(n_table) + 1,
            normaliser(raw),
            author,
            doc.metadata['title'],
            doc.metadata.get('link', 'None'),
            doc.metadata.get('date', 'None'),
        ])
        if len(n_table) == max_rows:
            # Only the first ten rows are displayed; stop building more.
            break
    return gr.Dataframe.update(value=n_table, visible=True)


def _results_table(label):
    """Create one recommendations table (No., Score, Name, Title, Link, Date)."""
    return gr.Dataframe(
        headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
        datatype=['number', 'number', 'str', 'str', 'str', 'str'],
        col_count=(6, "fixed"),
        wrap=True,
        visible=True,
        label=label,
        show_label = True,
        overflow_row_behaviour='paginate',
        scale=1
        )


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine Arena")
    gr.Markdown("NBDT Recommendation Engine Arena is a tool designed to compare neuroscience abstract recommendations by our models. \
    We will use this data to compare the performance of models and their preference by various neuroscientists.\
    Click on the 'Get comparison' button to run two random models on the displayed prompt. Then use the correct 'Model X is Better' button to give your vote.\
    All models were trained on data provided to us by the NBDT Journal.")
    article = get_article()
    # Not read by any handler, but calling get_matchup() here seeds the
    # LEFT_MODEL / RIGHT_MODEL globals before the first comparison.
    models = gr.State(value=get_matchup())
    prompt = gr.State(value=article)
    PROMPT = article
    abst = gr.Textbox(value = article, label="Abstract", lines=10)
    action_btn = gr.Button(value="Get comparison")
    with gr.Group():
        with gr.Row().style(equal_height=True):
            with gr.Column(scale=1):
                l_output = _results_table('Model A')
            with gr.Column(scale=1):
                r_output = _results_table('Model B')
    with gr.Row().style(equal_height=True):
        l_btn = gr.Button(value="Model A is better",scale=1)
        r_btn = gr.Button(value="Model B is better",scale=1)

    action_btn.click(fn=get_comp,
        inputs=[prompt,],
        outputs=[l_output, r_output],
        api_name="arena")
    # send_result returns (new abstract, new abstract). Both vote buttons must
    # route the second value into the `prompt` State, so it stays in sync
    # with the displayed abstract.
    l_btn.click(fn=lambda x,y,z: send_result(x,y,z,'left'),
                inputs=[l_output,r_output,prompt],
                outputs=[abst,prompt],
                api_name="feedleft")
    r_btn.click(fn=lambda x,y,z: send_result(x,y,z,'right'),
                inputs=[l_output,r_output,prompt],
                outputs=[abst,prompt],
                api_name="feedright")

demo.launch(debug=True)