NCTCMumbai
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -35,7 +35,7 @@ template = env.get_template('template.j2')
|
|
35 |
template_html = env.get_template('template_html.j2')
|
36 |
|
37 |
# crossEncoder
|
38 |
-
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
39 |
#cross_encoder = CrossEncoder('BAAI/bge-reranker-base')
|
40 |
# Examples
|
41 |
examples = ['What is the 4 digit classification heading for Gold jewellery?',
|
@@ -49,7 +49,7 @@ def add_text(history, text):
|
|
49 |
return history, gr.Textbox(value="", interactive=False)
|
50 |
|
51 |
|
52 |
-
def bot(history,
|
53 |
top_rerank = 15
|
54 |
top_k_rank = 10
|
55 |
query = history[-1][0]
|
@@ -66,7 +66,7 @@ def bot(history, api_kind):
|
|
66 |
logger.warning(f'Finished query vec')
|
67 |
doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
|
68 |
|
69 |
-
|
70 |
|
71 |
logger.warning(f'Finished search')
|
72 |
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
|
@@ -74,6 +74,10 @@ def bot(history, api_kind):
|
|
74 |
logger.warning(f'start cross encoder {len(documents)}')
|
75 |
# Retrieve documents relevant to query
|
76 |
query_doc_pair = [[query, doc] for doc in documents]
|
|
|
|
|
|
|
|
|
77 |
cross_scores = cross_encoder.predict(query_doc_pair)
|
78 |
sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
|
79 |
logger.warning(f'Finished cross encoder {len(documents)}')
|
@@ -88,16 +92,7 @@ def bot(history, api_kind):
|
|
88 |
prompt = template.render(documents=documents, query=query)
|
89 |
prompt_html = template_html.render(documents=documents, query=query)
|
90 |
|
91 |
-
|
92 |
-
generate_fn = generate_hf
|
93 |
-
elif api_kind == "OpenAI":
|
94 |
-
generate_fn = generate_openai
|
95 |
-
elif api_kind is None:
|
96 |
-
gr.Warning("API name was not provided")
|
97 |
-
raise ValueError("API name was not provided")
|
98 |
-
else:
|
99 |
-
gr.Warning(f"API {api_kind} is not supported")
|
100 |
-
raise ValueError(f"API {api_kind} is not supported")
|
101 |
|
102 |
history[-1][1] = ""
|
103 |
for character in generate_fn(prompt, history[:-1]):
|
@@ -125,19 +120,19 @@ with gr.Blocks() as demo:
|
|
125 |
)
|
126 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
127 |
|
128 |
-
|
129 |
|
130 |
prompt_html = gr.HTML()
|
131 |
# Turn off interactivity while generating if you click
|
132 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
133 |
-
bot, [chatbot,
|
134 |
|
135 |
# Turn it back on
|
136 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
137 |
|
138 |
# Turn off interactivity while generating if you hit enter
|
139 |
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
140 |
-
bot, [chatbot,
|
141 |
|
142 |
# Turn it back on
|
143 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
|
|
35 |
template_html = env.get_template('template_html.j2')
|
36 |
|
37 |
# crossEncoder
|
38 |
+
#cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
39 |
#cross_encoder = CrossEncoder('BAAI/bge-reranker-base')
|
40 |
# Examples
|
41 |
examples = ['What is the 4 digit classification heading for Gold jewellery?',
|
|
|
49 |
return history, gr.Textbox(value="", interactive=False)
|
50 |
|
51 |
|
52 |
+
def bot(history, cross_encoder):
|
53 |
top_rerank = 15
|
54 |
top_k_rank = 10
|
55 |
query = history[-1][0]
|
|
|
66 |
logger.warning(f'Finished query vec')
|
67 |
doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
|
68 |
|
69 |
+
|
70 |
|
71 |
logger.warning(f'Finished search')
|
72 |
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
|
|
|
74 |
logger.warning(f'start cross encoder {len(documents)}')
|
75 |
# Retrieve documents relevant to query
|
76 |
query_doc_pair = [[query, doc] for doc in documents]
|
77 |
+
if cross_encoder=='MiniLM-L6v2' :
|
78 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
79 |
+
else:
|
80 |
+
cross_encoder = CrossEncoder('BAAI/bge-reranker-base')
|
81 |
cross_scores = cross_encoder.predict(query_doc_pair)
|
82 |
sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
|
83 |
logger.warning(f'Finished cross encoder {len(documents)}')
|
|
|
92 |
prompt = template.render(documents=documents, query=query)
|
93 |
prompt_html = template_html.render(documents=documents, query=query)
|
94 |
|
95 |
+
generate_fn = generate_hf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
history[-1][1] = ""
|
98 |
for character in generate_fn(prompt, history[:-1]):
|
|
|
120 |
)
|
121 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
122 |
|
123 |
+
cross_encoder = gr.Radio(choices=['MiniLM-L6v2','BGE reranker'], value='MiniLM-L6v2')
|
124 |
|
125 |
prompt_html = gr.HTML()
|
126 |
# Turn off interactivity while generating if you click
|
127 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
128 |
+
bot, [chatbot, cross_encoder], [chatbot, prompt_html])
|
129 |
|
130 |
# Turn it back on
|
131 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
132 |
|
133 |
# Turn off interactivity while generating if you hit enter
|
134 |
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
135 |
+
bot, [chatbot, cross_encoder], [chatbot, prompt_html])
|
136 |
|
137 |
# Turn it back on
|
138 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|