Spaces:
Runtime error
Runtime error
ugmSorcero
commited on
Commit
·
f456ef3
1
Parent(s):
b15879f
Adds ranker and setting of parameters in UI
Browse files- app.py +0 -2
- core/pipelines.py +18 -0
- interface/components.py +17 -2
- interface/config.py +4 -1
- interface/utils.py +3 -2
app.py
CHANGED
@@ -37,9 +37,7 @@ def run_demo():
|
|
37 |
styles={
|
38 |
"container": {"border": "2px solid #818494"},
|
39 |
"icon": {"font-size": "22px"},
|
40 |
-
# "nav-item": {},
|
41 |
"nav-link": {"font-size": "20px", "text-align": "left"},
|
42 |
-
# "nav-link-selected": {"background-color": "green"},
|
43 |
},
|
44 |
)
|
45 |
component_select_pipeline(navigation)
|
|
|
37 |
styles={
|
38 |
"container": {"border": "2px solid #818494"},
|
39 |
"icon": {"font-size": "22px"},
|
|
|
40 |
"nav-link": {"font-size": "20px", "text-align": "left"},
|
|
|
41 |
},
|
42 |
)
|
43 |
component_select_pipeline(navigation)
|
core/pipelines.py
CHANGED
@@ -6,6 +6,7 @@ from haystack import Pipeline
|
|
6 |
from haystack.document_stores import InMemoryDocumentStore
|
7 |
from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
|
8 |
from haystack.nodes.preprocessor import PreProcessor
|
|
|
9 |
|
10 |
|
11 |
def keyword_search(
|
@@ -72,3 +73,20 @@ def dense_passage_retrieval(
|
|
72 |
)
|
73 |
|
74 |
return search_pipeline, index_pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from haystack.document_stores import InMemoryDocumentStore
|
7 |
from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
|
8 |
from haystack.nodes.preprocessor import PreProcessor
|
9 |
+
from haystack.nodes.ranker import SentenceTransformersRanker
|
10 |
|
11 |
|
12 |
def keyword_search(
|
|
|
73 |
)
|
74 |
|
75 |
return search_pipeline, index_pipeline
|
76 |
+
|
77 |
+
def dense_passage_retrieval_ranker(
|
78 |
+
index="documents",
|
79 |
+
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
80 |
+
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
81 |
+
ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2"
|
82 |
+
):
|
83 |
+
search_pipeline, index_pipeline = dense_passage_retrieval(
|
84 |
+
index=index,
|
85 |
+
query_embedding_model=query_embedding_model,
|
86 |
+
passage_embedding_model=passage_embedding_model,
|
87 |
+
)
|
88 |
+
ranker = SentenceTransformersRanker(model_name_or_path=ranker_model)
|
89 |
+
|
90 |
+
search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
|
91 |
+
|
92 |
+
return search_pipeline, index_pipeline
|
interface/components.py
CHANGED
@@ -4,7 +4,7 @@ from interface.draw_pipelines import get_pipeline_graph
|
|
4 |
|
5 |
|
6 |
def component_select_pipeline(container):
|
7 |
-
pipeline_names, pipeline_funcs = get_pipelines()
|
8 |
with container:
|
9 |
selected_pipeline = st.selectbox(
|
10 |
"Select pipeline",
|
@@ -13,14 +13,29 @@ def component_select_pipeline(container):
|
|
13 |
if "Keyword Search" in pipeline_names
|
14 |
else 0,
|
15 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
if (
|
17 |
st.session_state["pipeline"] is None
|
18 |
or st.session_state["pipeline"]["name"] != selected_pipeline
|
|
|
19 |
):
|
|
|
20 |
(
|
21 |
search_pipeline,
|
22 |
index_pipeline,
|
23 |
-
) = pipeline_funcs[
|
24 |
st.session_state["pipeline"] = {
|
25 |
"name": selected_pipeline,
|
26 |
"search_pipeline": search_pipeline,
|
|
|
4 |
|
5 |
|
6 |
def component_select_pipeline(container):
|
7 |
+
pipeline_names, pipeline_funcs, pipeline_func_parameters = get_pipelines()
|
8 |
with container:
|
9 |
selected_pipeline = st.selectbox(
|
10 |
"Select pipeline",
|
|
|
13 |
if "Keyword Search" in pipeline_names
|
14 |
else 0,
|
15 |
)
|
16 |
+
index_pipe = pipeline_names.index(selected_pipeline)
|
17 |
+
st.write("---")
|
18 |
+
st.header("Pipeline Parameters")
|
19 |
+
for parameter, value in pipeline_func_parameters[index_pipe].items():
|
20 |
+
if isinstance(value, str):
|
21 |
+
value = st.text_input(parameter, value)
|
22 |
+
elif isinstance(value, bool):
|
23 |
+
value = st.checkbox(parameter, value)
|
24 |
+
elif isinstance(value, int):
|
25 |
+
value = int(st.number_input(parameter, value))
|
26 |
+
elif isinstance(value, float):
|
27 |
+
value = float(st.number_input(parameter, value))
|
28 |
+
pipeline_func_parameters[index_pipe][parameter] = value
|
29 |
if (
|
30 |
st.session_state["pipeline"] is None
|
31 |
or st.session_state["pipeline"]["name"] != selected_pipeline
|
32 |
+
or list(st.session_state["pipeline_func_parameters"][index_pipe].values()) != list(pipeline_func_parameters[index_pipe].values())
|
33 |
):
|
34 |
+
st.session_state["pipeline_func_parameters"] = pipeline_func_parameters
|
35 |
(
|
36 |
search_pipeline,
|
37 |
index_pipeline,
|
38 |
+
) = pipeline_funcs[index_pipe](**pipeline_func_parameters[index_pipe])
|
39 |
st.session_state["pipeline"] = {
|
40 |
"name": selected_pipeline,
|
41 |
"search_pipeline": search_pipeline,
|
interface/config.py
CHANGED
@@ -1,7 +1,10 @@
|
|
1 |
from interface.pages import page_landing_page, page_search, page_index
|
2 |
|
3 |
# Define default Session Variables over the whole session.
|
4 |
-
session_state_variables = {
|
|
|
|
|
|
|
5 |
|
6 |
# Define Pages for the demo
|
7 |
pages = {
|
|
|
1 |
from interface.pages import page_landing_page, page_search, page_index
|
2 |
|
3 |
# Define default Session Variables over the whole session.
|
4 |
+
session_state_variables = {
|
5 |
+
"pipeline": None,
|
6 |
+
"pipeline_func_parameters": []
|
7 |
+
}
|
8 |
|
9 |
# Define Pages for the demo
|
10 |
pages = {
|
interface/utils.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from io import StringIO
|
2 |
import core.pipelines as pipelines_functions
|
3 |
-
from inspect import getmembers, isfunction
|
4 |
from newspaper import Article
|
5 |
from PyPDF2 import PdfFileReader
|
6 |
import streamlit as st
|
@@ -16,7 +16,8 @@ def get_pipelines():
|
|
16 |
pipeline_names = [
|
17 |
" ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
|
18 |
]
|
19 |
-
|
|
|
20 |
|
21 |
|
22 |
@st.experimental_memo
|
|
|
1 |
from io import StringIO
|
2 |
import core.pipelines as pipelines_functions
|
3 |
+
from inspect import getmembers, isfunction, signature
|
4 |
from newspaper import Article
|
5 |
from PyPDF2 import PdfFileReader
|
6 |
import streamlit as st
|
|
|
16 |
pipeline_names = [
|
17 |
" ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
|
18 |
]
|
19 |
+
pipeline_func_parameters = [{key: value.default for key, value in signature(pipe_func).parameters.items()} for pipe_func in pipeline_funcs]
|
20 |
+
return pipeline_names, pipeline_funcs, pipeline_func_parameters
|
21 |
|
22 |
|
23 |
@st.experimental_memo
|