ugmSorcero commited on
Commit
f456ef3
·
1 Parent(s): b15879f

Adds ranker and setting of parameters in UI

Browse files
app.py CHANGED
@@ -37,9 +37,7 @@ def run_demo():
37
  styles={
38
  "container": {"border": "2px solid #818494"},
39
  "icon": {"font-size": "22px"},
40
- # "nav-item": {},
41
  "nav-link": {"font-size": "20px", "text-align": "left"},
42
- # "nav-link-selected": {"background-color": "green"},
43
  },
44
  )
45
  component_select_pipeline(navigation)
 
37
  styles={
38
  "container": {"border": "2px solid #818494"},
39
  "icon": {"font-size": "22px"},
 
40
  "nav-link": {"font-size": "20px", "text-align": "left"},
 
41
  },
42
  )
43
  component_select_pipeline(navigation)
core/pipelines.py CHANGED
@@ -6,6 +6,7 @@ from haystack import Pipeline
6
  from haystack.document_stores import InMemoryDocumentStore
7
  from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
8
  from haystack.nodes.preprocessor import PreProcessor
 
9
 
10
 
11
  def keyword_search(
@@ -72,3 +73,20 @@ def dense_passage_retrieval(
72
  )
73
 
74
  return search_pipeline, index_pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from haystack.document_stores import InMemoryDocumentStore
7
  from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
8
  from haystack.nodes.preprocessor import PreProcessor
9
+ from haystack.nodes.ranker import SentenceTransformersRanker
10
 
11
 
12
  def keyword_search(
 
73
  )
74
 
75
  return search_pipeline, index_pipeline
76
+
77
+ def dense_passage_retrieval_ranker(
78
+ index="documents",
79
+ query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
80
+ passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
81
+ ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2"
82
+ ):
83
+ search_pipeline, index_pipeline = dense_passage_retrieval(
84
+ index=index,
85
+ query_embedding_model=query_embedding_model,
86
+ passage_embedding_model=passage_embedding_model,
87
+ )
88
+ ranker = SentenceTransformersRanker(model_name_or_path=ranker_model)
89
+
90
+ search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
91
+
92
+ return search_pipeline, index_pipeline
interface/components.py CHANGED
@@ -4,7 +4,7 @@ from interface.draw_pipelines import get_pipeline_graph
4
 
5
 
6
  def component_select_pipeline(container):
7
- pipeline_names, pipeline_funcs = get_pipelines()
8
  with container:
9
  selected_pipeline = st.selectbox(
10
  "Select pipeline",
@@ -13,14 +13,29 @@ def component_select_pipeline(container):
13
  if "Keyword Search" in pipeline_names
14
  else 0,
15
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  if (
17
  st.session_state["pipeline"] is None
18
  or st.session_state["pipeline"]["name"] != selected_pipeline
 
19
  ):
 
20
  (
21
  search_pipeline,
22
  index_pipeline,
23
- ) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
24
  st.session_state["pipeline"] = {
25
  "name": selected_pipeline,
26
  "search_pipeline": search_pipeline,
 
4
 
5
 
6
  def component_select_pipeline(container):
7
+ pipeline_names, pipeline_funcs, pipeline_func_parameters = get_pipelines()
8
  with container:
9
  selected_pipeline = st.selectbox(
10
  "Select pipeline",
 
13
  if "Keyword Search" in pipeline_names
14
  else 0,
15
  )
16
+ index_pipe = pipeline_names.index(selected_pipeline)
17
+ st.write("---")
18
+ st.header("Pipeline Parameters")
19
+ for parameter, value in pipeline_func_parameters[index_pipe].items():
20
+ if isinstance(value, str):
21
+ value = st.text_input(parameter, value)
22
+ elif isinstance(value, bool):
23
+ value = st.checkbox(parameter, value)
24
+ elif isinstance(value, int):
25
+ value = int(st.number_input(parameter, value))
26
+ elif isinstance(value, float):
27
+ value = float(st.number_input(parameter, value))
28
+ pipeline_func_parameters[index_pipe][parameter] = value
29
  if (
30
  st.session_state["pipeline"] is None
31
  or st.session_state["pipeline"]["name"] != selected_pipeline
32
+ or list(st.session_state["pipeline_func_parameters"][index_pipe].values()) != list(pipeline_func_parameters[index_pipe].values())
33
  ):
34
+ st.session_state["pipeline_func_parameters"] = pipeline_func_parameters
35
  (
36
  search_pipeline,
37
  index_pipeline,
38
+ ) = pipeline_funcs[index_pipe](**pipeline_func_parameters[index_pipe])
39
  st.session_state["pipeline"] = {
40
  "name": selected_pipeline,
41
  "search_pipeline": search_pipeline,
interface/config.py CHANGED
@@ -1,7 +1,10 @@
1
  from interface.pages import page_landing_page, page_search, page_index
2
 
3
  # Define default Session Variables over the whole session.
4
- session_state_variables = {"pipeline": None}
 
 
 
5
 
6
  # Define Pages for the demo
7
  pages = {
 
1
  from interface.pages import page_landing_page, page_search, page_index
2
 
3
  # Define default Session Variables over the whole session.
4
+ session_state_variables = {
5
+ "pipeline": None,
6
+ "pipeline_func_parameters": []
7
+ }
8
 
9
  # Define Pages for the demo
10
  pages = {
interface/utils.py CHANGED
@@ -1,6 +1,6 @@
1
  from io import StringIO
2
  import core.pipelines as pipelines_functions
3
- from inspect import getmembers, isfunction
4
  from newspaper import Article
5
  from PyPDF2 import PdfFileReader
6
  import streamlit as st
@@ -16,7 +16,8 @@ def get_pipelines():
16
  pipeline_names = [
17
  " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
18
  ]
19
- return pipeline_names, pipeline_funcs
 
20
 
21
 
22
  @st.experimental_memo
 
1
  from io import StringIO
2
  import core.pipelines as pipelines_functions
3
+ from inspect import getmembers, isfunction, signature
4
  from newspaper import Article
5
  from PyPDF2 import PdfFileReader
6
  import streamlit as st
 
16
  pipeline_names = [
17
  " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
18
  ]
19
+ pipeline_func_parameters = [{key: value.default for key, value in signature(pipe_func).parameters.items()} for pipe_func in pipeline_funcs]
20
+ return pipeline_names, pipeline_funcs, pipeline_func_parameters
21
 
22
 
23
  @st.experimental_memo