ugmSorcero commited on
Commit
01b8e8e
Β·
1 Parent(s): a492fff

First app version

Browse files
app.py CHANGED
@@ -1,5 +1,43 @@
1
  import streamlit as st
2
 
3
- st.title("🧠 Neural Search πŸ”Ž")
 
 
 
 
 
4
 
5
- st.write("Coming soon...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
 
3
+ st.set_page_config(
4
+ page_title="Neural Search",
5
+ page_icon="πŸ”Ž",
6
+ layout="wide",
7
+ initial_sidebar_state="expanded",
8
+ )
9
 
10
+ from streamlit_option_menu import option_menu
11
+ from interface.config import session_state_variables, pages
12
+ from interface.components import component_select_pipeline
13
+
14
+ # Initialization of session state
15
+ for key, value in session_state_variables.items():
16
+ if key not in st.session_state:
17
+ st.session_state[key] = value
18
+
19
+
20
+ def run_demo():
21
+
22
+ main_page = st.container()
23
+
24
+ st.sidebar.title("🧠 Neural Search πŸ”Ž")
25
+
26
+ navigation = st.sidebar.container()
27
+
28
+ with navigation:
29
+
30
+ selected_page = option_menu(
31
+ "Navigation",
32
+ list(pages.keys()),
33
+ icons=[f[1] for f in pages.values()],
34
+ menu_icon="cast",
35
+ default_index=0,
36
+ )
37
+ component_select_pipeline(navigation)
38
+
39
+ # Draw the correct page
40
+ pages[selected_page][0](main_page)
41
+
42
+
43
+ run_demo()
core/pipelines.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Haystack Pipelines
3
+ """
4
+
5
+ import tokenizers
6
+ from haystack import Pipeline
7
+ from haystack.document_stores import InMemoryDocumentStore
8
+ from haystack.nodes.retriever import DensePassageRetriever
9
+ from haystack.nodes.preprocessor import PreProcessor
10
+ import streamlit as st
11
+
12
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None}, allow_output_mutation=True)
13
+ def dense_passage_retrieval(
14
+ index='documents',
15
+ query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
16
+ passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
17
+ ):
18
+ document_store = InMemoryDocumentStore(index=index)
19
+ dpr_retriever = DensePassageRetriever(
20
+ document_store=document_store,
21
+ query_embedding_model=query_embedding_model,
22
+ passage_embedding_model=passage_embedding_model,
23
+ )
24
+ processor = PreProcessor(
25
+ clean_empty_lines=True,
26
+ clean_whitespace=True,
27
+ clean_header_footer=True,
28
+ split_by="word",
29
+ split_length=100,
30
+ split_respect_sentence_boundary=True,
31
+ split_overlap=0,
32
+ )
33
+ # SEARCH PIPELINE
34
+ search_pipeline = Pipeline()
35
+ search_pipeline.add_node(dpr_retriever, name="DPRRetriever", inputs=["Query"])
36
+
37
+ # INDEXING PIPELINE
38
+ index_pipeline = Pipeline()
39
+ index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
40
+ index_pipeline.add_node(dpr_retriever, name="DPRRetriever", inputs=["Preprocessor"])
41
+ index_pipeline.add_node(
42
+ document_store, name="DocumentStore", inputs=["DPRRetriever"]
43
+ )
44
+
45
+ return search_pipeline, index_pipeline
core/search_index.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from haystack.schema import Document
2
+ import uuid
3
+
4
+
5
+ def format_docs(documents):
6
+ """Given a list of documents, format the documents and return the documents and doc ids."""
7
+ db_docs: list = []
8
+ for doc in documents:
9
+ doc_id = doc['id'] if doc['id'] is not None else str(uuid.uuid4())
10
+ db_doc = {
11
+ "content": doc['text'],
12
+ "content_type": "text",
13
+ "id": str(uuid.uuid4()),
14
+ "meta": {"id": doc_id},
15
+ }
16
+ db_docs.append(Document(**db_doc))
17
+ return db_docs, [doc.meta["id"] for doc in db_docs]
18
+
19
+ def index(documents, pipeline):
20
+ documents, doc_ids = format_docs(documents)
21
+ pipeline.run(documents=documents)
22
+ return doc_ids
23
+
24
+ def search(queries, pipeline):
25
+ results = []
26
+ matches_queries = pipeline.run_batch(queries=queries)
27
+ for matches in matches_queries["documents"]:
28
+ query_results = []
29
+ for res in matches:
30
+ metadata = res.meta
31
+ query_results.append(
32
+ {
33
+ "text": res.content,
34
+ "score": res.score,
35
+ "id": res.meta["id"],
36
+ "fragment_id": res.id
37
+ }
38
+ )
39
+ results.append(
40
+ sorted(query_results, key=lambda x: x["score"], reverse=True)
41
+ )
42
+ return results
interface/components.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import core.pipelines as pipelines_functions
3
+ from inspect import getmembers, isfunction
4
+
5
+ def component_select_pipeline(container):
6
+ pipeline_names, pipeline_funcs = list(zip(*getmembers(pipelines_functions, isfunction)))
7
+ pipeline_names = [' '.join([n.capitalize() for n in name.split('_')]) for name in pipeline_names]
8
+ with container:
9
+ selected_pipeline = st.selectbox(
10
+ 'Select pipeline',
11
+ pipeline_names
12
+ )
13
+ st.session_state['search_pipeline'], \
14
+ st.session_state['index_pipeline'] = \
15
+ pipeline_funcs[pipeline_names.index(selected_pipeline)]()
16
+
17
+ def component_show_pipeline(container, pipeline):
18
+ """Draw the pipeline"""
19
+ with container:
20
+ pass
21
+
22
+ def component_show_search_result(container, results):
23
+ with container:
24
+ for idx, document in enumerate(results):
25
+ st.markdown(f"### Match {idx+1}")
26
+ st.markdown(f"**Text**: {document['text']}")
27
+ st.markdown(f"**Document**: {document['id']}")
28
+ st.markdown(f"**Score**: {document['score']:.3f}")
29
+ st.markdown("---")
30
+
31
+ def component_text_input(container):
32
+ """Draw the Text Input widget"""
33
+ with container:
34
+ texts = []
35
+ doc_id = 1
36
+ with st.expander("Enter documents"):
37
+ while True:
38
+ text = st.text_input(f"Document {doc_id}", key=doc_id)
39
+ if text != "":
40
+ texts.append({"text": text})
41
+ doc_id += 1
42
+ st.markdown("---")
43
+ else:
44
+ break
45
+ corpus = [
46
+ {"text": doc["text"], "id": doc_id}
47
+ for doc_id, doc in enumerate(texts)
48
+ ]
49
+ return corpus
interface/config.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from interface.pages import page_landing_page, page_search, page_index
2
+
3
+ # Define default Session Variables over the whole session.
4
+ session_state_variables = {}
5
+
6
+ # Define Pages for the demo
7
+ pages = {
8
+ "Introduction": (page_landing_page, "house-fill"),
9
+ "Search": (page_search, "search"),
10
+ "Index": (page_index, "files"),
11
+ }
interface/pages.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_option_menu import option_menu
3
+ from core.search_index import index, search
4
+ from interface.components import component_show_search_result, component_text_input
5
+
6
+ def page_landing_page(container):
7
+ with container:
8
+ st.header("🧠 Neural Search πŸ”Ž")
9
+
10
+ st.markdown(
11
+ "This is a tool to allow indexing & search content using neural capabilities"
12
+ )
13
+
14
+ def page_search(container):
15
+ with container:
16
+ st.title("Query me!")
17
+
18
+ ## SEARCH ##
19
+ query = st.text_input("Query")
20
+
21
+ if st.button("Search"):
22
+ st.session_state['search_results'] = search(
23
+ queries=[query],
24
+ pipeline=st.session_state['search_pipeline'],
25
+ )
26
+ if 'search_results' in st.session_state:
27
+ component_show_search_result(
28
+ container=container,
29
+ results=st.session_state['search_results'][0]
30
+ )
31
+
32
+ def page_index(container):
33
+ with container:
34
+ st.title("Index time!")
35
+
36
+ input_funcs = {
37
+ "Raw Text": (component_text_input, "card-text"),
38
+ }
39
+ selected_input = option_menu(
40
+ "Input Text",
41
+ list(input_funcs.keys()),
42
+ icons=[f[1] for f in input_funcs.values()],
43
+ menu_icon="list",
44
+ default_index=0,
45
+ orientation="horizontal",
46
+ )
47
+
48
+ corpus = input_funcs[selected_input][0](container)
49
+
50
+ if len(corpus) > 0:
51
+ index_results = None
52
+ if st.button("Index"):
53
+ index_results = index(
54
+ corpus,
55
+ st.session_state['index_pipeline'],
56
+ )
57
+ if index_results:
58
+ st.write(index_results)
requirements.txt CHANGED
@@ -1 +1,3 @@
1
- streamlit
 
 
 
1
+ streamlit
2
+ streamlit_option_menu
3
+ farm-haystack