Spaces:
Runtime error
Runtime error
File size: 6,566 Bytes
40e4418 7bed91f 2db50b5 d34229c 2db50b5 46247a1 40e4418 7bed91f 40e4418 2db50b5 3145888 e43f483 7bed91f cc2a0eb cded25d 7bed91f 46247a1 701a70b 2db50b5 cc2a0eb 2db50b5 cc2a0eb 2db50b5 cded25d 2db50b5 cded25d 2db50b5 46247a1 2db50b5 cded25d cc2a0eb 7bed91f 754548a cded25d 7bed91f 06ff1ff 7bed91f 9b95f10 cded25d 8d6b3c3 2db50b5 40e4418 891f199 117c8b1 40e4418 937de7f 40e4418 937de7f 40e4418 937de7f 40e4418 117c8b1 2db50b5 40e4418 2db50b5 4351e0d 1fa2af5 4351e0d caad63d 2db50b5 caad63d cded25d caad63d 2db50b5 614af9a 6f16285 2db50b5 9b95f10 2db50b5 46247a1 2db50b5 9b95f10 2db50b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import gradio as gr
import os
import csv
import json
import uuid
import random
import pickle
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from googleapiclient.discovery import build
from google.oauth2 import service_account
# Identifier written to the sheet with every vote. It is generated once, when
# this module is loaded.
# NOTE(review): being module-level, the same id is presumably shared by every
# visitor served by this process -- confirm whether a per-session id is wanted.
USER_ID = uuid.uuid4()

# The Google service-account key is supplied as a JSON string through the
# environment. Fail fast with a readable message instead of letting
# json.loads(None) raise an opaque TypeError.
SERVICE_ACCOUNT_JSON = os.environ.get('GOOGLE_SHEET_CREDENTIALS')
if not SERVICE_ACCOUNT_JSON:
    raise RuntimeError(
        "The GOOGLE_SHEET_CREDENTIALS environment variable is not set; it must "
        "contain the service-account JSON used to record votes."
    )
creds = service_account.Credentials.from_service_account_info(
    json.loads(SERVICE_ACCOUNT_JSON))
SPREADSHEET_ID = '1o0iKPxWYKYKEPjqB2YwrTgrLzvGyb9ULj9tnw_cfJb0'
service = build('sheets', 'v4', credentials=creds)

# Mutable module state describing the comparison currently on screen.
LEFT_MODEL = None
RIGHT_MODEL = None
PROMPT = None

# SECURITY: pickle.load can execute arbitrary code while deserialising.
# article_list.pkl must only ever be a trusted artifact shipped with the app.
with open("article_list.pkl", "rb") as articles:
    article_list = tuple(pickle.load(articles))

# INDEXES[i] is the on-disk FAISS index that is loaded with the embedding
# model MODELS[i]; the two lists are paired positionally by zip() below.
INDEXES = ["miread_large", "miread_contrastive", "scibert_contrastive"]
MODELS = [
    "biodatlab/MIReAD-Neuro-Large",
    "biodatlab/MIReAD-Neuro-Contrastive",
    "biodatlab/SciBERT-Neuro-Contrastive",
]
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
faiss_embedders = [
    HuggingFaceEmbeddings(
        model_name=name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    for name in MODELS
]
vecdbs = [
    FAISS.load_local(index_name, faiss_embedder)
    for index_name, faiss_embedder in zip(INDEXES, faiss_embedders)
]
def get_matchup():
    """Draw two different index names at random and store them as the current pair.

    Returns:
        A ``(left, right)`` tuple holding two entries sampled without
        replacement from ``INDEXES``. The same two names are also written to
        the module globals ``LEFT_MODEL`` and ``RIGHT_MODEL``.
    """
    global LEFT_MODEL, RIGHT_MODEL
    pair = random.sample(INDEXES, 2)
    LEFT_MODEL, RIGHT_MODEL = pair
    return LEFT_MODEL, RIGHT_MODEL
def get_comp(prompt):
    """Pick a fresh random matchup and run both models on the current prompt.

    Args:
        prompt: Value handed in by the caller.
            NOTE(review): this argument is not used; the query text is taken
            from the module global ``PROMPT`` instead.

    Returns:
        A 2-tuple ``(left_output, right_output)`` with the result of
        ``inference`` for the left and the right model, in that order.
    """
    left_index, right_index = get_matchup()
    outputs = tuple(
        inference(PROMPT, index_name)
        for index_name in (left_index, right_index)
    )
    return outputs
def get_article():
    """Return one entry of ``article_list`` chosen uniformly at random."""
    chosen = random.choice(article_list)
    return chosen
def send_result(l_output, r_output, prompt, pick):
    """Record a vote in the Google Sheet and move on to a new abstract.

    One row ``[USER_ID, PROMPT, LEFT_MODEL, RIGHT_MODEL, winner]`` (all values
    converted to ``str``) is appended to the spreadsheet, then a new article is
    drawn and stored in the module global ``PROMPT``.

    Args:
        l_output: Left results table (unused; part of the handler signature).
        r_output: Right results table (unused; part of the handler signature).
        prompt: Prompt passed by the caller (unused; the global ``PROMPT`` is
            what gets logged).
        pick: ``'left'`` to vote for ``LEFT_MODEL``; any other value counts as
            a vote for ``RIGHT_MODEL``.

    Returns:
        A 2-tuple ``(new_prompt, new_prompt)``: the first element is meant for
        the abstract textbox, the second for the prompt ``gr.State``. A State
        output stores whatever object the handler returns, so the plain string
        is returned for it rather than an update payload.
    """
    global PROMPT
    chosen_model = LEFT_MODEL if pick == 'left' else RIGHT_MODEL
    row = [
        str(cell)
        for cell in (USER_ID, PROMPT, LEFT_MODEL, RIGHT_MODEL, chosen_model)
    ]
    result = service.spreadsheets().values().append(
        spreadsheetId=SPREADSHEET_ID,
        range='A1:E1',
        valueInputOption='RAW',
        body={'values': [row]},
    ).execute()
    print(f"Appended {result['updates']['updatedCells']} cells.")
    new_prompt = get_article()
    PROMPT = new_prompt
    return new_prompt, new_prompt
def get_matches(query, db_name="miread_contrastive"):
    """
    Run a scored similarity search for ``query`` on the index named ``db_name``.

    ``db_name`` is looked up in ``INDEXES`` (a ``ValueError`` is raised by
    ``list.index`` when it is not listed) and the vector store at the same
    position in ``vecdbs`` is queried for its 30 best matches.
    """
    position = INDEXES.index(db_name)
    store = vecdbs[position]
    return store.similarity_search_with_score(query, k=30)
def inference(query, model="miread_contrastive"):
    """
    Turn the matches returned by get_matches() into a ranked results table.

    Each match is a ``(document, score)`` pair. Raw scores are rounded to three
    decimals and rescaled with ``1 - (score - min_score) / max_score`` so that
    the lowest raw score is displayed as 1.0. At most two rows are kept per
    first author and at most ten rows overall.

    Args:
        query: Text to search for.
        model: Name of the index to query; must be an entry of ``INDEXES``.

    Returns:
        A ``gr.Dataframe.update`` payload whose rows are
        ``[rank, score, first_author, title, link, date]``. The table is empty
        when the search yields no matches.
    """
    max_rows = 10
    matches = get_matches(query, model)
    if not matches:
        # min()/max() below would raise on an empty result set.
        return gr.Dataframe.update(value=[], visible=True)

    scores = [round(match[1].item(), 3) for match in matches]
    min_score = min(scores)
    max_score = max(scores)

    def normaliser(x):
        if max_score == 0:
            # Every rounded score is zero; avoid dividing by zero.
            return 1.0
        return round(1 - (x - min_score) / max_score, 3)

    auth_counts = {}
    n_table = []
    for match, raw_score in zip(matches, scores):
        doc = match[0]
        # NOTE(review): assumes metadata['authors'] is a non-empty sequence of
        # author names -- confirm against the index builder.
        author = doc.metadata['authors'][0].title()
        seen = auth_counts.get(author, 0)
        auth_counts[author] = seen + 1
        if seen >= 2:
            # Cap each first author at two rows so one lab cannot fill the table.
            continue
        n_table.append([
            len(n_table) + 1,
            normaliser(raw_score),
            author,
            doc.metadata['title'],
            doc.metadata.get('link', 'None'),
            doc.metadata.get('date', 'None'),
        ])
        if len(n_table) == max_rows:
            # Only the first ten rows are shown; no need to scan further.
            break
    n_output = gr.Dataframe.update(value=n_table, visible=True)
    return n_output
# Build the arena page: an abstract, two side-by-side result tables, and the
# buttons that trigger a comparison or record a vote.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine Arena")
    gr.Markdown(
        "NBDT Recommendation Engine Arena is a tool designed to compare neuroscience abstract recommendations by our models. "
        "We will use this data to compare the performance of models and their preference by various neuroscientists."
        "Click on the 'Get Comparision' button to run two random models on the displayed prompt. Then use the correct 'Model X is Better' button to give your vote."
        "All models were trained on data provided to us by the NBDT Journal."
    )
    # Seed the page with a random article; get_matchup() also initialises the
    # LEFT_MODEL / RIGHT_MODEL globals as a side effect.
    article = get_article()
    models = gr.State(value=get_matchup())
    prompt = gr.State(value=article)
    PROMPT = article
    abst = gr.Textbox(value=article, label="Abstract", lines=10)
    action_btn = gr.Button(value="Get comparison")
    with gr.Group():
        with gr.Row().style(equal_height=True):
            with gr.Column(scale=1):
                l_output = gr.Dataframe(
                    headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                    datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                    col_count=(6, "fixed"),
                    wrap=True,
                    visible=True,
                    label='Model A',
                    show_label=True,
                    overflow_row_behaviour='paginate',
                    scale=1
                )
            with gr.Column(scale=1):
                r_output = gr.Dataframe(
                    headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                    datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                    col_count=(6, "fixed"),
                    wrap=True,
                    visible=True,
                    label='Model B',
                    show_label=True,
                    overflow_row_behaviour='paginate',
                    scale=1
                )
    with gr.Row().style(equal_height=True):
        l_btn = gr.Button(value="Model A is better", scale=1)
        r_btn = gr.Button(value="Model B is better", scale=1)
    action_btn.click(fn=get_comp,
                     inputs=[prompt,],
                     outputs=[l_output, r_output],
                     api_name="arena")
    # send_result returns two values (new abstract text, new prompt state), so
    # both vote buttons must declare both outputs. With a single output the
    # whole returned tuple would be sent to the textbox.
    l_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'left'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedleft")
    r_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'right'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedright")
demo.launch(debug=True)