Spaces:
Runtime error
Runtime error
import gradio as gr | |
import csv | |
import random | |
import uuid | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings import HuggingFaceEmbeddings | |
# Identifier written into every recorded vote.
# NOTE(review): evaluated once at import, so every visitor of a running
# instance shares this id — confirm whether per-session ids are intended.
USER_ID = uuid.uuid4()

# INDEXES[i] is loaded with the embedder built from MODELS[i] (paired via zip below).
INDEXES = ["miread_large", "miread_contrastive", "scibert_contrastive"]
MODELS = [
    "biodatlab/MIReAD-Neuro-Large",
    "biodatlab/MIReAD-Neuro-Contrastive",
    "biodatlab/SciBERT-Neuro-Contrastive",
]

model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}


def _load_embedder(repo_id):
    """Build a CPU HuggingFace embedder for the model repo *repo_id*."""
    return HuggingFaceEmbeddings(
        model_name=repo_id,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )


faiss_embedders = [_load_embedder(repo_id) for repo_id in MODELS]

# NOTE: langchain's FAISS.load_local deserializes a pickled docstore;
# only point it at index folders you trust.
vecdbs = [
    FAISS.load_local(name, embedder)
    for name, embedder in zip(INDEXES, faiss_embedders)
]
def get_matchup(choices=None):
    """
    Pick two index names at random for a head-to-head comparison.

    Args:
        choices: Iterable of candidate index names. Defaults to INDEXES.

    Returns:
        A (left, right) tuple of two entries drawn without replacement
        from *choices*.

    Raises:
        ValueError: if fewer than two candidates are available
            (propagated from random.sample).
    """
    # list() lets callers pass any iterable (random.sample needs a sequence).
    pool = INDEXES if choices is None else list(choices)
    left, right = random.sample(pool, 2)
    return left, right
def get_comp(prompt):
    """
    Run *prompt* through two randomly chosen indexes.

    Returns:
        A (left_output, right_output) pair: the inference() results for the
        first and second index names returned by get_matchup().
    """
    model_a, model_b = get_matchup()
    outputs = (inference(prompt, model_a), inference(prompt, model_b))
    return outputs
def get_article(db_name="miread_contrastive"):
    """
    Return the docstore entry mapped to position 0 of the named index.

    Args:
        db_name: Name of the index; must be an entry of INDEXES.

    Returns:
        The stored object for that id (presumably a langchain Document —
        confirm if the docstore type changes).
    """
    store = vecdbs[INDEXES.index(db_name)]
    first_doc_id = store.index_to_docstore_id[0]
    return store.docstore._dict[first_doc_id]
def send_result(l_output, r_output, prompt, pick, *, models=None,
                user_id=None, path="results.csv"):
    """
    Append one arena vote as a row of a CSV file.

    Row layout: [user_id, left_model, right_model, prompt, pick].

    Args:
        l_output: Model A's table as shown in the UI (not recorded).
        r_output: Model B's table as shown in the UI (not recorded).
        prompt: The query text the two models were compared on.
        pick: Which side won, e.g. 'left' or 'right'.
        models: Optional (left, right) index names that produced the two
            tables. Written as empty strings when not supplied.
        user_id: Identifier to record. Defaults to the module-level USER_ID.
        path: CSV file to append to.
    """
    left, right = models if models is not None else ("", "")
    if user_id is None:
        user_id = USER_ID
    # The csv docs require newline="" so that newlines inside quoted fields
    # (abstracts often contain them) and row terminators are written correctly.
    with open(path, "a", newline="", encoding="utf-8") as res_file:
        csv.writer(res_file).writerow([user_id, left, right, prompt, pick])
def get_matches(query, db_name="miread_contrastive", k=60):
    """
    Run a similarity search for *query* against one of the loaded indexes.

    Args:
        query: Text to embed and search for.
        db_name: Name of the index; must be an entry of INDEXES.
        k: Number of nearest documents to return (default 60, the value
            previously hard-coded).

    Returns:
        The (document, score) pairs from similarity_search_with_score.

    Raises:
        ValueError: if db_name is not in INDEXES (from list.index).
    """
    db = vecdbs[INDEXES.index(db_name)]
    return db.similarity_search_with_score(query, k=k)
def inference(query, model="miread_contrastive"):
    """
    Rank the documents returned by get_matches() for a Gradio table.

    At most two rows are kept per (first) author. Raw scores are min-max
    normalised across the retrieved matches, so the lowest raw score maps to
    1.0 and the highest to 0.0. This assumes a lower raw score means a closer
    match, as with FAISS L2 distances — TODO confirm for these indexes.

    Args:
        query: Text to search for.
        model: Name of the index in INDEXES to query.

    Returns:
        A gr.Dataframe update whose rows are
        [rank, score, author, title, link, date]. Rank is the position in the
        raw match list, so it has gaps where rows were dropped by the author cap.
    """
    matches = get_matches(query, model)
    n_table = []

    # Guard: min()/max() below would raise on an empty result.
    if matches:
        # float() accepts both numpy scalars and plain floats.
        raw_scores = [round(float(score), 3) for _, score in matches]
        min_score = min(raw_scores)
        score_range = max(raw_scores) - min_score

        def normaliser(x):
            # A zero range (single match or all-equal scores) would divide by
            # zero; every match is then equally close, so score it 1.0.
            if score_range == 0:
                return 1.0
            return round(1 - (x - min_score) / score_range, 3)

        auth_counts = {}
        for rank, ((doc, _), raw) in enumerate(zip(matches, raw_scores), start=1):
            author = doc.metadata['authors'][0].title()
            # Cap each author at two rows so one author cannot fill the table.
            if auth_counts.get(author, 0) < 2:
                n_table.append([
                    rank,
                    normaliser(raw),
                    author,
                    doc.metadata['title'],
                    doc.metadata.get('link', 'None'),
                    doc.metadata.get('date', 'None'),
                ])
            auth_counts[author] = auth_counts.get(author, 0) + 1

    return gr.Dataframe.update(value=n_table, visible=True)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine Arena")
    gr.Markdown(
        "NBDT Recommendation Engine for Editors is a tool for neuroscience "
        "authors/abstracts/journals recommendation built for NBDT journal editors. "
        "It aims to help an editor to find similar reviewers, abstracts, and journals "
        "to a given submitted abstract. "
        "To find a recommendation, paste a `title[SEP]abstract` or `abstract` in the "
        "text box below and click on the \"Get comparison\" button. "
        "Then, compare the two suggested lists and vote for the one you find more useful. "
        "The data in our current demo includes authors associated with the NBDT Journal. "
        "We will update the data monthly to keep the publications up to date."
    )

    # get_article() returns the docstore entry (a langchain Document);
    # the textbox and the search query need its plain text.
    article = get_article().page_content
    # NOTE(review): this matchup is drawn once at startup and never read;
    # get_comp() draws its own pair on every click, so votes are not yet
    # tied to the models that produced each table.
    models = gr.State(value=get_matchup())
    abst = gr.Textbox(value=article, label="Abstract", lines=10)
    action_btn = gr.Button(value="Get comparison")

    with gr.Row().style(equal_height=True):
        with gr.Column(scale=1):
            l_output = gr.Dataframe(
                headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                col_count=(6, "fixed"),
                wrap=True,
                visible=True,
                label='Model A',
                show_label=True,
                scale=1,
            )
            l_btn = gr.Button(value="Model A is better", scale=1)
        with gr.Column(scale=1):
            r_output = gr.Dataframe(
                headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                col_count=(6, "fixed"),
                wrap=True,
                visible=True,
                label='Model B',
                show_label=True,
                scale=1,
            )
            r_btn = gr.Button(value="Model B is better", scale=1)

    # Read the editable textbox (not a fixed State) so both the comparison and
    # the logged prompt reflect what the user actually submitted.
    action_btn.click(fn=get_comp,
                     inputs=[abst],
                     outputs=[l_output, r_output],
                     api_name="arena")
    l_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'left'),
                inputs=[l_output, r_output, abst],
                api_name="feedleft")
    # The "Model B is better" vote must be wired to r_btn.
    r_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'right'),
                inputs=[l_output, r_output, abst],
                api_name="feedright")

demo.launch(debug=True)