File size: 5,299 Bytes
40e4418
2db50b5
 
 
40e4418
 
 
2db50b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bf35b2
1b9e9c5
2db50b5
 
 
 
 
 
 
 
40e4418
 
 
 
 
891f199
40e4418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2db50b5
40e4418
 
 
2db50b5
40e4418
 
 
 
 
caad63d
2db50b5
caad63d
 
2db50b5
 
 
 
40e4418
 
 
 
2db50b5
 
1130814
2db50b5
 
 
 
 
 
 
 
40e4418
2db50b5
 
1130814
2db50b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import gradio as gr
import csv
import random
import uuid
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

# uuid4() runs once, at import time, so every session served by this process
# shares this single id.
USER_ID = uuid.uuid4()

# INDEXES[i] is loaded with the embedding model MODELS[i] (they are paired
# positionally via zip below), so keep both lists in the same order.
INDEXES = ["miread_large", "miread_contrastive", "scibert_contrastive"]
MODELS = [
    "biodatlab/MIReAD-Neuro-Large",
    "biodatlab/MIReAD-Neuro-Contrastive",
    "biodatlab/SciBERT-Neuro-Contrastive",
]

# Settings shared by every embedder: run on CPU and keep the raw
# (unnormalised) embedding vectors.
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}

faiss_embedders = [
    HuggingFaceEmbeddings(
        model_name=hub_id,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    for hub_id in MODELS
]

# Load each local FAISS index with the embedder at the same position.
vecdbs = [
    FAISS.load_local(folder, embedder)
    for folder, embedder in zip(INDEXES, faiss_embedders)
]

def get_matchup():
    """
    Draw two different index names from INDEXES uniformly at random.

    Returns - a (left, right) tuple of index names, in draw order.
    """
    return tuple(random.sample(INDEXES, 2))

def get_comp(prompt):
    """
    Run ``prompt`` through two randomly chosen indexes.

    Returns - (left_update, right_update), the ``inference`` results for the
    first and second index drawn by ``get_matchup``. The index names
    themselves are not returned, so callers cannot tell which index produced
    which table.
    """
    pair = get_matchup()
    # Evaluated left-to-right, so the left index is searched first.
    return tuple(inference(prompt, index_name) for index_name in pair)

def get_article(db_name="miread_contrastive"):
    """
    Return the docstore entry mapped to position 0 of index ``db_name``.

    Raises ValueError (from ``list.index``) if ``db_name`` is not in INDEXES.
    """
    store = vecdbs[INDEXES.index(db_name)]
    # NOTE(review): reaches into the docstore's private ``_dict``.
    first_id = store.index_to_docstore_id[0]
    return store.docstore._dict[first_id]


def send_result(l_output, r_output, prompt, pick, models=None,
                path='results.csv'):
    """
    Append one arena vote to a CSV file.

    Each row is [USER_ID, left model, right model, prompt, pick].

    l_output, r_output: the two result tables shown to the user. They are
        accepted for the UI wiring but are not written to the file.
    prompt: the query the two models were compared on.
    pick: the side the user preferred (the UI passes 'left' or 'right').
    models: optional (left, right) pair of index names that produced the two
        tables. When omitted, the model columns are written empty.
    path: CSV file to append to (created if missing).
    """
    # The previous body referenced undefined globals ``left``/``right``
    # (NameError) and called ``csv.open``, which does not exist.
    left, right = models if models else (None, None)
    # newline='' is required by the csv module to avoid blank lines on
    # Windows; utf-8 keeps non-ASCII abstracts portable.
    with open(path, 'a', newline='', encoding='utf-8') as res_file:
        csv.writer(res_file).writerow([USER_ID, left, right, prompt, pick])


def get_matches(query, db_name="miread_contrastive", k=60):
    """
    Wrapper to call the similarity search on the required index.

    query: text to search for.
    db_name: one of INDEXES; raises ValueError (from ``list.index``) otherwise.
    k: number of nearest neighbours to request (defaults to the previous
        hard-coded 60).

    Returns the list produced by ``similarity_search_with_score``; callers
    read each item as ``(document, score)``.
    """
    vecdb = vecdbs[INDEXES.index(db_name)]
    return vecdb.similarity_search_with_score(query, k=k)


def inference(query, model="miread_contrastive"):
    """
    Search index ``model`` for ``query`` and build the table shown in the arena.

    Each row is [rank, score, first author, title, link, date]:
    - rank is the match's 1-based position in the raw search results;
    - at most two rows are kept per first author, so skipped matches leave
      gaps in the rank column;
    - score is the raw search score rescaled by ``normaliser`` (the lowest
      raw score maps to 1.0);
    - link/date fall back to the string 'None' when missing from metadata.

    Returns - a single ``gr.Dataframe.update`` that makes the table visible
    with those rows (an empty table when the search returns nothing).
    """
    matches = get_matches(query, model)
    # Raw scores rounded to 3 decimals, index-aligned with ``matches``.
    scores = [round(match[1].item(), 3) for match in matches]
    if not scores:
        # min()/max() below would raise ValueError on an empty result set.
        return gr.Dataframe.update(value=[], visible=True)
    min_score = min(scores)
    max_score = max(scores)

    def normaliser(x):
        # NOTE(review): divides by max_score rather than the range
        # (max_score - min_score), so values span [min/max, 1] instead of
        # [0, 1]. Kept to preserve the displayed scores; confirm intent.
        if max_score == 0:
            # Would be a ZeroDivisionError. Assuming non-negative distances,
            # this means every match scored 0, i.e. all are equally close.
            return 1.0
        return round(1 - (x - min_score) / max_score, 3)

    auth_counts = {}
    n_table = []
    for i, (match, raw_score) in enumerate(zip(matches, scores)):
        doc = match[0]
        title = doc.metadata['title']
        author = doc.metadata['authors'][0].title()
        seen = auth_counts.get(author, 0)
        if seen >= 2:
            # Cap each first author at two rows.
            continue
        auth_counts[author] = seen + 1
        n_table.append([i + 1,
                        normaliser(raw_score),
                        author,
                        title,
                        doc.metadata.get('link', 'None'),
                        doc.metadata.get('date', 'None')])
    return gr.Dataframe.update(value=n_table, visible=True)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine Arena")
    gr.Markdown(
        "NBDT Recommendation Engine for Editors is a tool for neuroscience "
        "authors/abstracts/journals recommendation built for NBDT journal editors. "
        "It aims to help an editor find similar reviewers, abstracts, and journals "
        "for a given submitted abstract. "
        "To compare two recommendation models, paste a `title[SEP]abstract` or "
        "`abstract` in the text box below and click the \"Get comparison\" button, "
        "then pick the model whose suggested list is better. "
        "The data in our current demo includes authors associated with the NBDT "
        "Journal. We will update the data monthly to keep the publications up to date."
    )
    # Pre-fill the textbox with the first stored document of the default index.
    article = get_article()
    # NOTE(review): this matchup is never read or updated. get_comp() draws its
    # own pair and does not return it, so the models behind "Model A" and
    # "Model B" are not available to the vote handlers below.
    models = gr.State(value=get_matchup())
    # get_article() returns the stored document object; show its text, since
    # this textbox is also the query fed to the search.
    abst = gr.Textbox(value=article.page_content, label="Abstract", lines=10)
    action_btn = gr.Button(value="Get comparison")
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=1):
            l_output = gr.Dataframe(
                headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                col_count=(6, "fixed"),
                wrap=True,
                visible=True,
                label='Model A',
                show_label=True,
                scale=1,
            )
            l_btn = gr.Button(value="Model A is better", scale=1)
        with gr.Column(scale=1):
            r_output = gr.Dataframe(
                headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                col_count=(6, "fixed"),
                wrap=True,
                visible=True,
                label='Model B',
                show_label=True,
                scale=1,
            )
            r_btn = gr.Button(value="Model B is better", scale=1)

    # Read the query from the textbox so pasted abstracts are actually
    # searched. A gr.State seeded with the document object never reflects
    # user edits and is not a text query.
    action_btn.click(fn=get_comp,
                     inputs=[abst],
                     outputs=[l_output, r_output],
                     api_name="arena")
    l_btn.click(fn=lambda l_out, r_out, query: send_result(l_out, r_out, query, 'left'),
                inputs=[l_output, r_output, abst],
                api_name="feedleft")
    # Each vote button records its own side; the "right" vote must be wired
    # to r_btn, not registered a second time on l_btn.
    r_btn.click(fn=lambda l_out, r_out, query: send_result(l_out, r_out, query, 'right'),
                inputs=[l_output, r_output, abst],
                api_name="feedright")

demo.launch(debug=True)