ugaray96 committed on
Commit
893d078
·
unverified ·
1 Parent(s): f026256

Adds a document-store global and a top-k parameter

Browse files
core/pipelines.py CHANGED
@@ -14,8 +14,14 @@ import os
14
  data_path = "data/"
15
  os.makedirs(data_path, exist_ok=True)
16
 
 
17
 
18
- def keyword_search(index="documents", split_word_length=100, audio_output=False):
 
 
 
 
 
19
  """
20
  **Keyword Search Pipeline**
21
 
@@ -26,8 +32,10 @@ def keyword_search(index="documents", split_word_length=100, audio_output=False)
26
  - Documents that have more lexical overlap with the query are more likely to be relevant
27
  - Words that occur in fewer documents are more significant than words that occur in many documents
28
  """
29
- document_store = InMemoryDocumentStore(index=index)
30
- keyword_retriever = TfidfRetriever(document_store=(document_store))
 
 
31
  processor = PreProcessor(
32
  clean_empty_lines=True,
33
  clean_whitespace=True,
@@ -65,6 +73,7 @@ def dense_passage_retrieval(
65
  split_word_length=100,
66
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
67
  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
 
68
  audio_output=False,
69
  ):
70
  """
@@ -76,11 +85,14 @@ def dense_passage_retrieval(
76
  - One BERT base model to encode queries
77
  - Ranking of documents done by dot product similarity between query and document embeddings
78
  """
79
- document_store = InMemoryDocumentStore(index=index)
 
 
80
  dpr_retriever = DensePassageRetriever(
81
  document_store=document_store,
82
  query_embedding_model=query_embedding_model,
83
  passage_embedding_model=passage_embedding_model,
 
84
  )
85
  processor = PreProcessor(
86
  clean_empty_lines=True,
@@ -121,6 +133,7 @@ def dense_passage_retrieval_ranker(
121
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
122
  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
123
  ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2",
 
124
  audio_output=False,
125
  ):
126
  """
@@ -137,8 +150,10 @@ def dense_passage_retrieval_ranker(
137
  split_word_length=split_word_length,
138
  query_embedding_model=query_embedding_model,
139
  passage_embedding_model=passage_embedding_model,
 
 
140
  )
141
- ranker = SentenceTransformersRanker(model_name_or_path=ranker_model)
142
 
143
  search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
144
 
 
14
  data_path = "data/"
15
  os.makedirs(data_path, exist_ok=True)
16
 
17
+ index = "documents"
18
 
19
+ document_store = InMemoryDocumentStore(index=index)
20
+
21
+
22
+ def keyword_search(
23
+ index="documents", split_word_length=100, top_k=10, audio_output=False
24
+ ):
25
  """
26
  **Keyword Search Pipeline**
27
 
 
32
  - Documents that have more lexical overlap with the query are more likely to be relevant
33
  - Words that occur in fewer documents are more significant than words that occur in many documents
34
  """
35
+ global document_store
36
+ if index != document_store.index:
37
+ document_store = InMemoryDocumentStore(index=index)
38
+ keyword_retriever = TfidfRetriever(document_store=(document_store), top_k=top_k)
39
  processor = PreProcessor(
40
  clean_empty_lines=True,
41
  clean_whitespace=True,
 
73
  split_word_length=100,
74
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
75
  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
76
+ top_k=10,
77
  audio_output=False,
78
  ):
79
  """
 
85
  - One BERT base model to encode queries
86
  - Ranking of documents done by dot product similarity between query and document embeddings
87
  """
88
+ global document_store
89
+ if index != document_store.index:
90
+ document_store = InMemoryDocumentStore(index=index)
91
  dpr_retriever = DensePassageRetriever(
92
  document_store=document_store,
93
  query_embedding_model=query_embedding_model,
94
  passage_embedding_model=passage_embedding_model,
95
+ top_k=top_k,
96
  )
97
  processor = PreProcessor(
98
  clean_empty_lines=True,
 
133
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
134
  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
135
  ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2",
136
+ top_k=10,
137
  audio_output=False,
138
  ):
139
  """
 
150
  split_word_length=split_word_length,
151
  query_embedding_model=query_embedding_model,
152
  passage_embedding_model=passage_embedding_model,
153
+ # top_k high to allow better recall, the ranker will handle the precision
154
+ top_k=10000000,
155
  )
156
+ ranker = SentenceTransformersRanker(model_name_or_path=ranker_model, top_k=top_k)
157
 
158
  search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
159
 
interface/components.py CHANGED
@@ -27,9 +27,9 @@ def component_select_pipeline(container):
27
  elif isinstance(value, bool):
28
  value = st.checkbox(parameter, value)
29
  elif isinstance(value, int):
30
- value = int(st.number_input(parameter, value))
31
  elif isinstance(value, float):
32
- value = float(st.number_input(parameter, value))
33
  pipeline_func_parameters[index_pipe][parameter] = value
34
  if (
35
  st.session_state["pipeline"] is None
 
27
  elif isinstance(value, bool):
28
  value = st.checkbox(parameter, value)
29
  elif isinstance(value, int):
30
+ value = int(st.number_input(parameter, value=value))
31
  elif isinstance(value, float):
32
+ value = float(st.number_input(parameter, value=value))
33
  pipeline_func_parameters[index_pipe][parameter] = value
34
  if (
35
  st.session_state["pipeline"] is None
interface/pages.py CHANGED
@@ -88,7 +88,9 @@ def page_index(container):
88
  index_results = None
89
  if st.button("Index"):
90
  index_results = index(
91
- corpus, st.session_state["pipeline"]["index_pipeline"], clear_index
 
 
92
  )
93
  st.session_state["doc_id"] = doc_id
94
  if index_results:
 
88
  index_results = None
89
  if st.button("Index"):
90
  index_results = index(
91
+ documents=corpus,
92
+ pipeline=st.session_state["pipeline"]["index_pipeline"],
93
+ clear_index=clear_index,
94
  )
95
  st.session_state["doc_id"] = doc_id
96
  if index_results: