Spaces:
Runtime error
Runtime error
File size: 5,299 Bytes
40e4418 2db50b5 40e4418 2db50b5 0bf35b2 1b9e9c5 2db50b5 40e4418 891f199 40e4418 2db50b5 40e4418 2db50b5 40e4418 caad63d 2db50b5 caad63d 2db50b5 40e4418 2db50b5 1130814 2db50b5 40e4418 2db50b5 1130814 2db50b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import gradio as gr
import csv
import random
import uuid
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
# Random identifier created once, when this module is imported; it is written
# into every feedback row by send_result().
USER_ID = uuid.uuid4()

# INDEXES and MODELS are parallel lists: the index named INDEXES[i] is loaded
# with the embedder created from MODELS[i].
INDEXES = ["miread_large", "miread_contrastive", "scibert_contrastive"]
MODELS = [
    "biodatlab/MIReAD-Neuro-Large",
    "biodatlab/MIReAD-Neuro-Contrastive",
    "biodatlab/SciBERT-Neuro-Contrastive",
]

# Shared settings passed to every embedder.
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": False}

# One embedder per entry in MODELS, in the same order.
faiss_embedders = [
    HuggingFaceEmbeddings(
        model_name=model_id,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    for model_id in MODELS
]

# One loaded vector store per entry in INDEXES, each paired positionally with
# its embedder.
vecdbs = list(map(FAISS.load_local, INDEXES, faiss_embedders))
def get_matchup():
    """
    Draw two distinct index names at random from INDEXES.
    Returns - a (left, right) tuple of index names
    """
    pair = random.sample(INDEXES, 2)
    return pair[0], pair[1]
def get_comp(prompt):
    """
    Run the same prompt through two randomly drawn indexes.
    Returns - a (left_output, right_output) tuple of inference() results
    """
    model_a, model_b = get_matchup()
    # Tuple elements are evaluated left to right, so the left model runs first.
    return inference(prompt, model_a), inference(prompt, model_b)
def get_article(db_name="miread_contrastive"):
    """
    Fetch the docstore entry that sits at position 0 of the named index.
    Returns - whatever object the vector store's docstore holds for that id
    """
    store = vecdbs[INDEXES.index(db_name)]
    entries = store.docstore._dict
    first_id = store.index_to_docstore_id[0]
    return entries[first_id]
def send_result(l_output, r_output, prompt, pick):
    """
    Append one feedback row to results.csv in the working directory.

    The row is [USER_ID, l_output, r_output, prompt, pick]. The two outputs
    are logged as received from the caller, because this function is not told
    which indexes produced them.

    l_output - value of the left-hand results component
    r_output - value of the right-hand results component
    prompt   - the text the comparison was run on
    pick     - which side the user preferred (the UI passes 'left' or 'right')
    """
    # The csv module has no open(); use the builtin. newline='' is what the
    # csv docs require so the writer controls line endings itself.
    with open('results.csv', 'a', newline='', encoding='utf-8') as res_file:
        writer = csv.writer(res_file)
        writer.writerow([USER_ID, l_output, r_output, prompt, pick])
def get_matches(query, db_name="miread_contrastive"):
    """
    Look up the vector store registered under db_name and run a
    similarity search with scores on it, asking for 60 results.
    """
    store = vecdbs[INDEXES.index(db_name)]
    return store.similarity_search_with_score(query, k=60)
def inference(query, model="miread_contrastive"):
    """
    This function processes information retrieved by the get_matches() function

    Each match becomes a row [rank, score, author, title, link, date], where
    rank is the 1-based position in the raw match list, score is the match
    score rescaled as 1 - (raw - min) / max over all returned matches, and
    author is the first entry of the document's 'authors' metadata,
    title-cased. At most two rows are kept per author.

    query - text to search for
    model - name of the index to search (one of INDEXES)
    Returns - Gradio update command for the authors tabular output
    """
    matches = get_matches(query, model)
    if not matches:
        # min()/max() below would raise on an empty list; show an empty table.
        return gr.Dataframe.update(value=[], visible=True)

    raw_scores = [round(match[1].item(), 3) for match in matches]
    min_score = min(raw_scores)
    max_score = max(raw_scores)

    def normaliser(x):
        # Avoid dividing by zero when the largest raw score is exactly 0.
        # NOTE(review): 1.0 is returned in that case on the assumption that
        # every score is then 0 - confirm scores are never negative.
        if max_score == 0:
            return 1.0
        return round(1 - (x - min_score) / max_score, 3)

    auth_counts = {}
    n_table = []
    for rank, (match, raw_score) in enumerate(zip(matches, raw_scores), start=1):
        doc = match[0]
        title = doc.metadata['title']
        author = doc.metadata['authors'][0].title()
        date = doc.metadata.get('date', 'None')
        link = doc.metadata.get('link', 'None')

        # Keep at most two rows per author; later matches are dropped.
        seen = auth_counts.get(author, 0)
        if seen >= 2:
            continue
        auth_counts[author] = seen + 1
        n_table.append([rank, normaliser(raw_score), author, title, link, date])

    return gr.Dataframe.update(value=n_table, visible=True)
# Arena UI: one abstract is run through two randomly chosen indexes, the two
# result tables are shown side by side, and the user votes for the better one.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine Arena")
    gr.Markdown("NBDT Recommendation Engine for Editors is a tool for neuroscience authors/abstracts/journalsrecommendation built for NBDT journal editors. \
It aims to help an editor to find similar reviewers, abstracts, and journals to a given submitted abstract.\
To find a recommendation, paste a `title[SEP]abstract` or `abstract` in the text box below and click on the appropriate \"Find Matches\" button.\
Then, you can hover to authors/abstracts/journals tab to find a suggested list.\
The data in our current demo includes authors associated with the NBDT Journal. We will update the data monthly for an up-to-date publications.")

    # Seed the page with the entry at position 0 of the default index.
    article = get_article()
    # NOTE(review): `models` is created but never wired into any event below.
    models = gr.State(value=get_matchup())
    prompt = gr.State(value=article)
    abst = gr.Textbox(value = article, label="Abstract", lines=10)
    action_btn = gr.Button(value="Get comparison")

    with gr.Row().style(equal_height=True):
        with gr.Column(scale=1):
            l_output = gr.Dataframe(
                headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                col_count=(6, "fixed"),
                wrap=True,
                visible=True,
                label='Model A',
                show_label = True,
                scale=1
            )
            l_btn = gr.Button(value="Model A is better",scale=1)
        with gr.Column(scale=1):
            r_output = gr.Dataframe(
                headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                col_count=(6, "fixed"),
                wrap=True,
                visible=True,
                label='Model B',
                show_label = True,
                scale=1
            )
            r_btn = gr.Button(value="Model B is better",scale=1)

    # NOTE(review): the comparison reads the `prompt` state (the initial
    # article), not the editable `abst` textbox - confirm this is intended.
    action_btn.click(fn=get_comp,
                     inputs=[prompt,],
                     outputs=[l_output, r_output],
                     api_name="arena")
    l_btn.click(fn=lambda x,y,z: send_result(x,y,z,'left'),
                inputs=[l_output,r_output,prompt],
                api_name="feedleft")
    # The 'right' vote must be bound to the Model B button. Binding it to
    # l_btn again left r_btn inert and made every Model A click log both votes.
    r_btn.click(fn=lambda x,y,z: send_result(x,y,z,'right'),
                inputs=[l_output,r_output,prompt],
                api_name="feedright")

demo.launch(debug=True)