ugaray96 committed
Commit f026256 · unverified · 2 parents: c9524e4 7786dc7

Merge pull request #10 from ugm2/feature/audio_output

.gitignore CHANGED
@@ -128,4 +128,6 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-.vscode/
+.vscode/
+
+data/audio/
core/pipelines.py CHANGED
@@ -2,14 +2,20 @@
 Haystack Pipelines
 """
 
+from pathlib import Path
 from haystack import Pipeline
 from haystack.document_stores import InMemoryDocumentStore
 from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
 from haystack.nodes.preprocessor import PreProcessor
 from haystack.nodes.ranker import SentenceTransformersRanker
+from haystack.nodes.audio.document_to_speech import DocumentToSpeech
+import os
 
+data_path = "data/"
+os.makedirs(data_path, exist_ok=True)
 
-def keyword_search(index="documents", split_word_length=100):
+
+def keyword_search(index="documents", split_word_length=100, audio_output=False):
     """
     **Keyword Search Pipeline**
 
@@ -19,8 +25,6 @@ def keyword_search(index="documents", split_word_length=100):
 
     - Documents that have more lexical overlap with the query are more likely to be relevant
     - Words that occur in fewer documents are more significant than words that occur in many documents
-
-    :warning: **(HAYSTACK BUG) Keyword Search doesn't work if you reindex:** Please refresh page in order to reindex
     """
     document_store = InMemoryDocumentStore(index=index)
     keyword_retriever = TfidfRetriever(document_store=(document_store))
@@ -44,6 +48,15 @@ def keyword_search(index="documents", split_word_length=100):
         document_store, name="DocumentStore", inputs=["Preprocessor"]
     )
 
+    if audio_output:
+        doc2speech = DocumentToSpeech(
+            model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
+            generated_audio_dir=Path(data_path + "audio"),
+        )
+        search_pipeline.add_node(
+            doc2speech, name="DocumentToSpeech", inputs=["TfidfRetriever"]
+        )
+
     return search_pipeline, index_pipeline
 
 
@@ -52,6 +65,7 @@ def dense_passage_retrieval(
     split_word_length=100,
     query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
     passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+    audio_output=False,
 ):
     """
     **Dense Passage Retrieval Pipeline**
@@ -89,6 +103,15 @@ def dense_passage_retrieval(
         document_store, name="DocumentStore", inputs=["DPRRetriever"]
     )
 
+    if audio_output:
+        doc2speech = DocumentToSpeech(
+            model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
+            generated_audio_dir=Path(data_path + "audio"),
+        )
+        search_pipeline.add_node(
+            doc2speech, name="DocumentToSpeech", inputs=["DPRRetriever"]
+        )
+
     return search_pipeline, index_pipeline
 
 
@@ -98,6 +121,7 @@ def dense_passage_retrieval_ranker(
     query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
     passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
     ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2",
+    audio_output=False,
 ):
     """
     **Dense Passage Retrieval Ranker Pipeline**
@@ -118,4 +142,11 @@ def dense_passage_retrieval_ranker(
 
     search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
 
+    if audio_output:
+        doc2speech = DocumentToSpeech(
+            model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
+            generated_audio_dir=Path(data_path + "audio"),
+        )
+        search_pipeline.add_node(doc2speech, name="DocumentToSpeech", inputs=["Ranker"])
+
     return search_pipeline, index_pipeline
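The net effect of the pipelines.py changes: every pipeline factory gains an `audio_output` flag that, when enabled, appends a `DocumentToSpeech` node after the final search node (`TfidfRetriever`, `DPRRetriever`, or `Ranker`) and writes generated speech under `data/audio/`. A minimal usage sketch, not part of this commit; the query text and `top_k` value are illustrative, and the `params` shape follows the Haystack v1 `Pipeline.run` API:

    from core.pipelines import keyword_search

    # Build the keyword-search pipeline with text-to-speech enabled.
    search_pipeline, index_pipeline = keyword_search(audio_output=True)

    # ... index documents through index_pipeline first ...

    # Retrieved documents now pass through DocumentToSpeech, which writes one
    # .wav per document under data/audio/ and attaches it as content_audio.
    results = search_pipeline.run(
        query="what is dense passage retrieval?",
        params={"TfidfRetriever": {"top_k": 3}},
    )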
core/search_index.py CHANGED
@@ -37,15 +37,17 @@ def search(queries, pipeline):
     for res in matches:
         if not score_is_empty:
             score_is_empty = True if res.score is None else False
-        query_results.append(
-            {
-                "text": res.content,
-                "score": res.score,
-                "id": res.meta["id"],
-                "fragment_id": res.id,
-                "meta": res.meta,
-            }
-        )
+        match = {
+            "text": res.content,
+            "id": res.meta["id"],
+            "fragment_id": res.id,
+            "meta": res.meta,
+        }
+        if not score_is_empty:
+            match.update({"score": res.score})
+        if hasattr(res, "content_audio"):
+            match.update({"content_audio": res.content_audio})
+        query_results.append(match)
     if not score_is_empty:
         query_results = sorted(
             query_results, key=lambda x: x["score"], reverse=True
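After this rewrite a result dict only carries the keys that were actually produced: `score` is set only when the retriever returned scores, and `content_audio` only when the documents passed through a `DocumentToSpeech` node. The resulting shape, with all values hypothetical:

    match = {
        "text": "...a retrieved document fragment...",
        "id": "doc-1",
        "fragment_id": "a1b2c3",
        "meta": {"id": "doc-1", "_split_id": 0},
        # present only when the retriever scores results:
        "score": 0.87,
        # present only when the pipeline was built with audio_output=True:
        "content_audio": "data/audio/a1b2c3.wav",
    }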
interface/components.py CHANGED
@@ -1,5 +1,10 @@
 import streamlit as st
-from interface.utils import get_pipelines, extract_text_from_url, extract_text_from_file
+from interface.utils import (
+    get_pipelines,
+    extract_text_from_url,
+    extract_text_from_file,
+    reset_vars_data,
+)
 from interface.draw_pipelines import get_pipeline_graph
 
 
@@ -42,7 +47,7 @@ def component_select_pipeline(container):
         "index_pipeline": index_pipeline,
         "doc": pipeline_funcs[index_pipe].__doc__,
     }
-    st.session_state["doc_id"] = 0
+    reset_vars_data()
 
 
 def component_show_pipeline(pipeline, pipeline_name):
@@ -65,8 +70,10 @@ def component_show_search_result(container, results):
         st.markdown(f"**Document**: {document['id']}")
         if "_split_id" in document["meta"]:
             st.markdown(f"**Document Chunk**: {document['meta']['_split_id']}")
-        if document["score"] is not None:
+        if "score" in document:
             st.markdown(f"**Score**: {document['score']:.3f}")
+        if "content_audio" in document:
+            st.audio(str(document["content_audio"]))
         st.markdown("---")
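Since `score` can now be missing from a result dict entirely (see core/search_index.py above), the old `document["score"] is not None` guard would raise a `KeyError`; the membership tests avoid that, and `st.audio` renders an inline player for the generated file. A standalone sketch of the new branches, using a hypothetical result dict:

    import streamlit as st

    document = {"id": "doc-1", "meta": {}, "content_audio": "data/audio/a1b2c3.wav"}

    if "score" in document:  # key is absent when the retriever returned no scores
        st.markdown(f"**Score**: {document['score']:.3f}")
    if "content_audio" in document:  # path to the .wav written by DocumentToSpeech
        st.audio(str(document["content_audio"]))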
 
interface/pages.py CHANGED
@@ -25,12 +25,12 @@ def page_landing_page(container):
         "\n - Index raw text, URLs, CSVs, PDFs and Images"
         "\n - Use Dense Passage Retrieval, Keyword Search pipeline and DPR Ranker pipelines"
         "\n - Search the indexed documents"
+        "\n - Read your responses out loud using the `audio_output` option!"
     )
     st.markdown(
         "TODO list:"
         "\n - File type classification and converter nodes"
         "\n - Audio to text support for indexing"
-        "\n - Include text to audio to read responses"
         "\n - Build other pipelines"
     )
     st.markdown(
interface/utils.py CHANGED
@@ -1,5 +1,8 @@
 from io import StringIO
+import os
+import shutil
 import core.pipelines as pipelines_functions
+from core.pipelines import data_path
 from inspect import getmembers, isfunction, signature
 from newspaper import Article
 from PyPDF2 import PdfFileReader
@@ -23,6 +26,13 @@ def get_pipelines():
     return pipeline_names, pipeline_funcs, pipeline_func_parameters
 
 
+def reset_vars_data():
+    st.session_state["doc_id"] = 0
+    # Delete data files
+    shutil.rmtree(data_path)
+    os.makedirs(data_path, exist_ok=True)
+
+
 @st.experimental_memo
 def extract_text_from_url(url: str):
     article = Article(url)
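`reset_vars_data` runs whenever the user switches pipelines (see interface/components.py above): it resets the session's document counter and wipes the shared `data/` directory so audio generated by the previous pipeline does not linger. The filesystem half in isolation, with the Streamlit session-state bookkeeping omitted:

    import os
    import shutil

    data_path = "data/"  # as defined in core/pipelines.py

    shutil.rmtree(data_path)               # removes data/audio/*.wav from earlier runs
    os.makedirs(data_path, exist_ok=True)  # recreate an empty data/ for the next run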
requirements.txt CHANGED
@@ -5,4 +5,8 @@ black==22.8.0
 plotly==5.10.0
 newspaper3k==0.2.8
 PyPDF2==2.10.7
-pytesseract==0.3.10
+pytesseract==0.3.10
+soundfile==0.10.3.post1
+espnet
+pydub==0.25.1
+espnet_model_zoo==0.1.7