"""NBDT Recommendation Engine Arena.

Gradio app that shows one neuroscience abstract, runs two randomly chosen
recommendation indexes on it side by side, and logs the user's preference
to a Google Sheet for later model comparison.
"""

import json
import os
import pickle
import random
import uuid

import gradio as gr
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from googleapiclient.discovery import build
from google.oauth2 import service_account

# Anonymous per-session identifier, recorded with every vote row.
USER_ID = uuid.uuid4()

# Google Sheets client used to append vote rows.
SERVICE_ACCOUNT_JSON = os.environ.get('GOOGLE_SHEET_CREDENTIALS')
creds = service_account.Credentials.from_service_account_info(
    json.loads(SERVICE_ACCOUNT_JSON))
SPREADSHEET_ID = '1o0iKPxWYKYKEPjqB2YwrTgrLzvGyb9ULj9tnw_cfJb0'
service = build('sheets', 'v4', credentials=creds)

# Module-level state shared between Gradio callbacks: the two indexes in the
# current matchup and the abstract they were scored on.
LEFT_MODEL = None
RIGHT_MODEL = None
PROMPT = None

# Pool of candidate abstracts served as prompts.
with open("article_list.pkl", "rb") as articles:
    article_list = tuple(pickle.load(articles))

# FAISS index names and the HuggingFace models that embed queries for them
# (kept in matching order so INDEXES.index(name) selects the right vecdb).
INDEXES = ["miread_large", "miread_contrastive", "scibert_contrastive"]
MODELS = [
    "biodatlab/MIReAD-Neuro-Large",
    "biodatlab/MIReAD-Neuro-Contrastive",
    "biodatlab/SciBERT-Neuro-Contrastive",
]
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
faiss_embedders = [
    HuggingFaceEmbeddings(
        model_name=name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    for name in MODELS
]
vecdbs = [
    FAISS.load_local(index_name, faiss_embedder)
    for index_name, faiss_embedder in zip(INDEXES, faiss_embedders)
]


def get_matchup():
    """Pick two distinct random indexes and record them as the current matchup.

    Returns:
        (left, right): the two index names chosen for this round.
    """
    global LEFT_MODEL, RIGHT_MODEL
    left, right = random.sample(INDEXES, 2)
    LEFT_MODEL, RIGHT_MODEL = left, right
    return left, right


def get_comp(prompt):
    """Run a fresh random matchup on `prompt`.

    Fix: the original ignored the `prompt` argument and scored the (possibly
    stale) global PROMPT. We now score the abstract actually shown and sync
    the global so send_result() logs the same text.

    Returns:
        (left_output, right_output): Dataframe updates for each side.
    """
    global PROMPT
    PROMPT = prompt
    left, right = get_matchup()
    left_output = inference(prompt, left)
    right_output = inference(prompt, right)
    return left_output, right_output


def get_article():
    """Return a random abstract from the preloaded article list."""
    return random.choice(article_list)


def send_result(l_output, r_output, prompt, pick):
    """Log a vote to the Google Sheet and serve a fresh prompt.

    Args:
        l_output, r_output: the displayed tables (unused; carried by Gradio).
        prompt: current prompt state (unused; global PROMPT is logged).
        pick: 'left' or 'right' — which side the user preferred.

    Returns:
        (new_prompt, state_update): new abstract text for the Textbox and an
        update for the prompt State component.
    """
    global PROMPT
    picked_model = LEFT_MODEL if pick == 'left' else RIGHT_MODEL
    row = [str(x) for x in (USER_ID, PROMPT, LEFT_MODEL, RIGHT_MODEL, picked_model)]
    body = {'values': [row]}
    result = service.spreadsheets().values().append(
        spreadsheetId=SPREADSHEET_ID,
        range='A1:E1',
        valueInputOption='RAW',
        body=body,
    ).execute()
    print(f"Appended {result['updates']['updatedCells']} cells.")
    new_prompt = get_article()
    PROMPT = new_prompt
    # NOTE(review): gr.State.update follows the old Gradio component-update
    # API used elsewhere in this file (gr.Dataframe.update) — confirm it is
    # supported by the pinned Gradio version.
    return new_prompt, gr.State.update(value=new_prompt)


def get_matches(query, db_name="miread_contrastive"):
    """Similarity-search the named index for `query`.

    Returns:
        List of (document, score) pairs, k=30 nearest neighbours.
    """
    vecdb = vecdbs[INDEXES.index(db_name)]
    return vecdb.similarity_search_with_score(query, k=30)


def inference(query, model="miread_contrastive"):
    """Retrieve matches for `query` and format them as a results table.

    Scores are min-max normalised to [0, 1] with 1 = best, and at most two
    rows per first author are kept to diversify the table.

    Returns:
        Gradio Dataframe update with up to 10 rows:
        [rank, score, author, title, link, date].
    """
    matches = get_matches(query, model)
    scores = [round(match[1].item(), 3) for match in matches]
    min_score = min(scores)
    max_score = max(scores)
    # Fix: original divided by max_score instead of the (max - min) range,
    # producing an incorrect normalisation; also guard against a zero span.
    span = max_score - min_score

    def normaliser(x):
        return round(1 - (x - min_score) / span, 3) if span else 1.0

    auth_counts = {}
    n_table = []
    rank = 1
    for doc, raw_score in matches:
        score = normaliser(round(raw_score.item(), 3))
        title = doc.metadata['title']
        author = doc.metadata['authors'][0].title()
        date = doc.metadata.get('date', 'None')
        link = doc.metadata.get('link', 'None')
        if auth_counts.get(author, 0) < 2:
            n_table.append([rank, score, author, title, link, date])
            rank += 1
        auth_counts[author] = auth_counts.get(author, 0) + 1
    return gr.Dataframe.update(value=n_table[:10], visible=True)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine Arena")
    gr.Markdown(
        "NBDT Recommendation Engine Arena is a tool designed to compare "
        "neuroscience abstract recommendations by our models. We will use "
        "this data to compare the performance of models and their preference "
        "by various neuroscientists. Click on the 'Get comparison' button to "
        "run two random models on the displayed prompt. Then use the correct "
        "'Model X is Better' button to give your vote. All models were "
        "trained on data provided to us by the NBDT Journal.")
    article = get_article()
    models = gr.State(value=get_matchup())
    prompt = gr.State(value=article)
    PROMPT = article
    abst = gr.Textbox(value=article, label="Abstract", lines=10)
    action_btn = gr.Button(value="Get comparison")
    with gr.Group():
        with gr.Row().style(equal_height=True):
            with gr.Column(scale=1):
                l_output = gr.Dataframe(
                    headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                    datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                    col_count=(6, "fixed"),
                    wrap=True,
                    visible=True,
                    label='Model A',
                    show_label=True,
                    overflow_row_behaviour='paginate',
                    scale=1,
                )
            with gr.Column(scale=1):
                r_output = gr.Dataframe(
                    headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                    datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                    col_count=(6, "fixed"),
                    wrap=True,
                    visible=True,
                    label='Model B',
                    show_label=True,
                    overflow_row_behaviour='paginate',
                    scale=1,
                )
        with gr.Row().style(equal_height=True):
            l_btn = gr.Button(value="Model A is better", scale=1)
            r_btn = gr.Button(value="Model B is better", scale=1)

    action_btn.click(fn=get_comp,
                     inputs=[prompt],
                     outputs=[l_output, r_output],
                     api_name="arena")
    # Fix: send_result returns two values, so both buttons must list two
    # outputs (the original l_btn listed only [abst], a count mismatch).
    l_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'left'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedleft")
    r_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'right'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedright")

demo.launch(debug=True)