Adds doc store global and top k parameter
Files changed:
- core/pipelines.py +20 -5
- interface/components.py +2 -2
- interface/pages.py +3 -1
core/pipelines.py
CHANGED
@@ -14,8 +14,14 @@ import os
 data_path = "data/"
 os.makedirs(data_path, exist_ok=True)

+index = "documents"

-def keyword_search(index="documents", split_word_length=100, audio_output=False):
+document_store = InMemoryDocumentStore(index=index)
+
+
+def keyword_search(
+    index="documents", split_word_length=100, top_k=10, audio_output=False
+):
     """
     **Keyword Search Pipeline**

@@ -26,8 +32,10 @@ def keyword_search(index="documents", split_word_length=100, audio_output=False)
     - Documents that have more lexical overlap with the query are more likely to be relevant
     - Words that occur in fewer documents are more significant than words that occur in many documents
     """
-    document_store = InMemoryDocumentStore(index=index)
-    keyword_retriever = TfidfRetriever(document_store=(document_store))
+    global document_store
+    if index != document_store.index:
+        document_store = InMemoryDocumentStore(index=index)
+    keyword_retriever = TfidfRetriever(document_store=(document_store), top_k=top_k)
     processor = PreProcessor(
         clean_empty_lines=True,
         clean_whitespace=True,
@@ -65,6 +73,7 @@ def dense_passage_retrieval(
     split_word_length=100,
     query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
     passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+    top_k=10,
     audio_output=False,
 ):
     """
@@ -76,11 +85,14 @@ def dense_passage_retrieval(
     - One BERT base model to encode queries
     - Ranking of documents done by dot product similarity between query and document embeddings
     """
-    document_store = InMemoryDocumentStore(index=index)
+    global document_store
+    if index != document_store.index:
+        document_store = InMemoryDocumentStore(index=index)
     dpr_retriever = DensePassageRetriever(
         document_store=document_store,
         query_embedding_model=query_embedding_model,
         passage_embedding_model=passage_embedding_model,
+        top_k=top_k,
     )
     processor = PreProcessor(
         clean_empty_lines=True,
@@ -121,6 +133,7 @@ def dense_passage_retrieval_ranker(
     query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
     passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
     ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2",
+    top_k=10,
     audio_output=False,
 ):
     """
@@ -137,8 +150,10 @@ def dense_passage_retrieval_ranker(
         split_word_length=split_word_length,
         query_embedding_model=query_embedding_model,
         passage_embedding_model=passage_embedding_model,
+        # top_k high to allow better recall, the ranker will handle the precision
+        top_k=10000000,
     )
-    ranker = SentenceTransformersRanker(model_name_or_path=ranker_model)
+    ranker = SentenceTransformersRanker(model_name_or_path=ranker_model, top_k=top_k)

     search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])

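In short, core/pipelines.py now keeps one module-level InMemoryDocumentStore and only replaces it when a pipeline is asked for a different index, and each pipeline factory exposes a top_k parameter that is handed to its retriever; the ranker pipeline instead retrieves with a very large top_k and lets the cross-encoder ranker apply the user-facing cut. Below is a minimal sketch of that pattern, assuming Haystack 1.x imports and trimmed to the retriever/ranker setup rather than the full search and indexing pipelines built in the file; ranked_keyword_search is a hypothetical helper that uses the TF-IDF retriever for brevity where the file uses dense passage retrieval.

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import SentenceTransformersRanker, TfidfRetriever

# One shared, module-level store; the pipeline factories reuse it across calls
# so documents indexed earlier stay available between Streamlit reruns.
index = "documents"
document_store = InMemoryDocumentStore(index=index)


def keyword_search(index="documents", top_k=10):
    """Sketch: keyword retriever over the shared store, capped at top_k hits."""
    global document_store
    # Only rebuild the store when a different index is requested.
    if index != document_store.index:
        document_store = InMemoryDocumentStore(index=index)
    return TfidfRetriever(document_store=document_store, top_k=top_k)


def ranked_keyword_search(top_k=10, ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2"):
    """Sketch (hypothetical helper): retrieve broadly, then rank precisely."""
    # Retriever top_k is huge for recall; the ranker trims results to the final top_k.
    retriever = TfidfRetriever(document_store=document_store, top_k=10_000_000)
    ranker = SentenceTransformersRanker(model_name_or_path=ranker_model, top_k=top_k)
    return retriever, ranker

The global-plus-reassignment approach works here because the Space runs as a single process; with a shared backend store such as Elasticsearch one would normally pass the index per request instead of swapping the global object.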
interface/components.py
CHANGED
@@ -27,9 +27,9 @@ def component_select_pipeline(container):
         elif isinstance(value, bool):
             value = st.checkbox(parameter, value)
         elif isinstance(value, int):
-            value = int(st.number_input(parameter, value))
+            value = int(st.number_input(parameter, value=value))
         elif isinstance(value, float):
-            value = float(st.number_input(parameter, value))
+            value = float(st.number_input(parameter, value=value))
         pipeline_func_parameters[index_pipe][parameter] = value
         if (
             st.session_state["pipeline"] is None
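The interface/components.py change is a bug fix rather than style: the second positional argument of st.number_input is min_value, so the old call turned each parameter's current value into a lower bound the user could never go below. Passing value=value only seeds the widget. A tiny illustration, using a hypothetical parameter name:

import streamlit as st

parameter, value = "top_k", 10  # hypothetical pipeline parameter

# Buggy: `value` lands in the min_value slot, so the user can never enter
# anything smaller than the current setting.
# value = int(st.number_input(parameter, value))

# Fixed: passed by keyword, it only sets the widget's initial value.
value = int(st.number_input(parameter, value=value))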
interface/pages.py
CHANGED
@@ -88,7 +88,9 @@ def page_index(container):
     index_results = None
     if st.button("Index"):
         index_results = index(
-            corpus,
+            documents=corpus,
+            pipeline=st.session_state["pipeline"]["index_pipeline"],
+            clear_index=clear_index,
         )
         st.session_state["doc_id"] = doc_id
     if index_results:
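The updated call in page_index() spells out what indexing now needs: the uploaded corpus, the index pipeline of the currently selected pipeline from st.session_state, and the clear_index flag. The index() helper itself is defined elsewhere in the repo; the following is only a hedged sketch of the signature this call site implies, using Haystack 1.x pipeline methods.

# Hypothetical signature implied by the call site above; the real index()
# helper lives elsewhere in this repo and may differ.
def index(documents, pipeline, clear_index=False):
    if clear_index:
        # Wipe whatever is already in the pipeline's document store.
        pipeline.get_document_store().delete_documents()
    # Run the selected indexing pipeline over the uploaded documents.
    return pipeline.run(documents=documents)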