ArenaTester / app.py
atrytone's picture
Update app.py
754548a
raw
history blame
6.65 kB
import gradio as gr
import os
import csv
import json
import uuid
import random
import pickle
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from googleapiclient.discovery import build
from google.oauth2 import service_account
# NOTE(review): USER_ID is generated once when this module is imported, so every
# visitor served by this process records the same id; it identifies the server run,
# not an individual user.
USER_ID = uuid.uuid4()

# Service-account credentials (JSON text) used to append arena votes to the sheet.
SERVICE_ACCOUNT_JSON = os.environ.get('GOOGLE_SHEET_CREDENTIALS')
if not SERVICE_ACCOUNT_JSON:
    # Fail fast with an actionable message instead of json.loads(None)'s TypeError.
    raise RuntimeError(
        "GOOGLE_SHEET_CREDENTIALS is not set; it must contain the Google service-account "
        "JSON used to append arena votes to the results spreadsheet."
    )
creds = service_account.Credentials.from_service_account_info(json.loads(SERVICE_ACCOUNT_JSON))
SPREADSHEET_ID = '1o0iKPxWYKYKEPjqB2YwrTgrLzvGyb9ULj9tnw_cfJb0'
# Google Sheets v4 client; send_result() appends one row per vote through it.
service = build('sheets', 'v4', credentials=creds)
# Index names currently shown as "Model A" / "Model B"; written by get_matchup()
# and read by send_result() when a vote is cast.
LEFT_MODEL = RIGHT_MODEL = None

# Pool of candidate articles the arena serves as prompts, frozen into a tuple.
# NOTE(review): pickle.load can execute arbitrary code; this is acceptable only because
# article_list.pkl ships with the app -- never load user-supplied files this way.
with open("article_list.pkl", mode="rb") as articles:
    article_list = tuple(pickle.load(articles))
# Retrieval indexes under comparison and the embedding models used to load them.
# The two lists are paired by position: INDEXES[i] is loaded with MODELS[i].
INDEXES = ["miread_large", "miread_contrastive", "scibert_contrastive"]
MODELS = [
    "biodatlab/MIReAD-Neuro-Large",
    "biodatlab/MIReAD-Neuro-Contrastive",
    "biodatlab/SciBERT-Neuro-Contrastive",
]

# Shared embedding settings: run on CPU and keep raw (unnormalised) vectors.
model_kwargs = dict(device='cpu')
encode_kwargs = dict(normalize_embeddings=False)

# One embedder per model name, built in MODELS order.
faiss_embedders = [
    HuggingFaceEmbeddings(
        model_name=hf_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    for hf_name in MODELS
]

# Load each local FAISS index folder with its matching embedder.
vecdbs = [
    FAISS.load_local(folder, embedder)
    for folder, embedder in zip(INDEXES, faiss_embedders)
]
def get_matchup():
    """Draw two distinct index names from INDEXES for a head-to-head round.

    Side effect: stores the pair in the module globals LEFT_MODEL / RIGHT_MODEL,
    which send_result() reads to translate a 'left'/'right' vote into an index name.
    NOTE(review): those globals are shared by every session served by this process.

    Returns:
        (left, right): the two sampled index names.
    """
    global LEFT_MODEL, RIGHT_MODEL
    LEFT_MODEL, RIGHT_MODEL = random.sample(INDEXES, 2)
    return LEFT_MODEL, RIGHT_MODEL
def get_comp(prompt):
    """Run one arena round: pick two indexes, then query each with *prompt*.

    Returns:
        (left_output, right_output): inference() results, in the left/right order
        chosen by get_matchup().
    """
    # The iterable is evaluated first, so get_matchup() runs before either query.
    return tuple(inference(prompt, side) for side in get_matchup())
def get_article():
    """Return one element of ``article_list`` chosen uniformly at random."""
    chosen = random.choice(article_list)
    return chosen
def send_result(l_output, r_output, prompt, pick):
    """Record one arena vote in the results sheet and serve the next article.

    Appends the row [USER_ID, prompt, LEFT_MODEL, RIGHT_MODEL, winner] (every field
    stringified) to range A1:E1 of SPREADSHEET_ID.

    Args:
        l_output, r_output: the two result tables on screen. They are unused here and
            are accepted only because the UI passes them as inputs.
        prompt: the article text both models were queried with.
        pick: 'left' credits LEFT_MODEL; any other value credits RIGHT_MODEL.

    Returns:
        (new_prompt, new_prompt): the next article, once for the Abstract textbox and
        once for the prompt State. A State output is set from the plain returned value.

    NOTE(review): LEFT_MODEL / RIGHT_MODEL are process-wide globals, so with concurrent
    sessions a vote can be attributed to another session's matchup.
    """
    winner = LEFT_MODEL if pick == 'left' else RIGHT_MODEL
    row = [str(field) for field in (USER_ID, prompt, LEFT_MODEL, RIGHT_MODEL, winner)]
    result = service.spreadsheets().values().append(
        spreadsheetId=SPREADSHEET_ID,
        range='A1:E1',
        valueInputOption='RAW',
        body={'values': [row]},
    ).execute()
    # Tolerate a response without an 'updates' section instead of crashing after the
    # vote has already been written.
    updated = result.get('updates', {}).get('updatedCells', 0)
    print(f"Appended {updated} cells.")
    new_prompt = get_article()
    return new_prompt, new_prompt
def get_matches(query, db_name="miread_contrastive"):
    """Run a similarity search for *query* on the vector store named *db_name*.

    The store is looked up by the position of *db_name* in INDEXES, so an unknown
    name raises ValueError. Returns the store's (document, score) results for the
    top 30 hits.
    """
    position = INDEXES.index(db_name)
    store = vecdbs[position]
    return store.similarity_search_with_score(query, k=30)
def inference(query, model="miread_contrastive"):
    """Query one index and format its hits as the arena's results table.

    Each raw score is rounded to 3 decimals and then min-max scaled within this batch.
    The hit with the lowest raw score (presumably the closest, if scores are
    distances -- TODO confirm) gets 1.0 and the highest gets 0.0. At most two papers
    per first author are kept, and at most 10 rows are returned.

    Args:
        query: text to search for.
        model: index name passed to get_matches(); must be one of INDEXES.

    Returns:
        gr.Dataframe.update whose value is a list of rows
        [No., Score, Name, Title, Link, Date]. The list is empty when the index
        returned no hits.
    """
    max_rows = 10
    max_per_author = 2

    matches = get_matches(query, model)
    if not matches:
        # min()/max() below would raise on an empty sequence.
        return gr.Dataframe.update(value=[], visible=True)

    # float() accepts both numpy scalars and plain Python floats.
    raw_scores = [round(float(score), 3) for _, score in matches]
    lowest, highest = min(raw_scores), max(raw_scores)
    spread = highest - lowest

    def normalise(x):
        # Min-max scaling into [0, 1]. Dividing by the range (not the max) maps the
        # worst hit to 0; the all-equal case would otherwise divide by zero.
        if spread == 0:
            return 1.0
        return round(1 - (x - lowest) / spread, 3)

    per_author = {}
    rows = []
    for (doc, _), raw in zip(matches, raw_scores):
        # Fall back to 'None' like the date/link fields instead of raising IndexError.
        authors = doc.metadata.get('authors') or ['None']
        author = authors[0].title()
        if per_author.get(author, 0) < max_per_author:
            rows.append([
                len(rows) + 1,
                normalise(raw),
                author,
                doc.metadata['title'],
                doc.metadata.get('link', 'None'),
                doc.metadata.get('date', 'None'),
            ])
            if len(rows) == max_rows:
                break
        per_author[author] = per_author.get(author, 0) + 1

    return gr.Dataframe.update(value=rows, visible=True)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine Arena")
    gr.Markdown(
        "NBDT Recommendation Engine for Editors is a tool for neuroscience "
        "authors/abstracts/journals recommendation built for NBDT journal editors. "
        "It aims to help an editor to find similar reviewers, abstracts, and journals "
        "to a given submitted abstract. "
        "To find a recommendation, paste a `title[SEP]abstract` or `abstract` in the "
        "text box below and click on the appropriate \"Find Matches\" button. "
        "Then, you can hover over the authors/abstracts/journals tab to find a suggested list. "
        "The data in our current demo includes authors associated with the NBDT Journal. "
        "We will update the data monthly to keep the publications up to date."
    )

    article = get_article()
    # NOTE(review): `models` is not wired into any event handler. The get_matchup()
    # call still seeds LEFT_MODEL/RIGHT_MODEL before the first comparison.
    models = gr.State(value=get_matchup())
    prompt = gr.State(value=article)
    # NOTE(review): comparisons run on the `prompt` State, not on this textbox, so
    # user edits to the Abstract text are not used.
    abst = gr.Textbox(value=article, label="Abstract", lines=10)
    action_btn = gr.Button(value="Get comparison")

    # Configuration shared by the two result tables; only the label differs.
    _table_kwargs = dict(
        headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
        datatype=['number', 'number', 'str', 'str', 'str', 'str'],
        col_count=(6, "fixed"),
        wrap=True,
        visible=True,
        show_label=True,
        overflow_row_behaviour='paginate',
        scale=1,
    )

    with gr.Group():
        with gr.Row().style(equal_height=True):
            with gr.Column(scale=1):
                l_output = gr.Dataframe(label='Model A', **_table_kwargs)
            with gr.Column(scale=1):
                r_output = gr.Dataframe(label='Model B', **_table_kwargs)
        with gr.Row().style(equal_height=True):
            l_btn = gr.Button(value="Model A is better", scale=1)
            r_btn = gr.Button(value="Model B is better", scale=1)

    action_btn.click(fn=get_comp,
                     inputs=[prompt],
                     outputs=[l_output, r_output],
                     api_name="arena")
    # send_result returns two values (the next article for the Abstract textbox and
    # for the prompt State), so both vote buttons must declare both outputs.
    l_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'left'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedleft")
    r_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'right'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedright")

demo.launch(debug=True)