ugmSorcero commited on
Commit
6c3736e
·
1 Parent(s): e4aa90a

final touches to draw pipelines & manual cache

Browse files
core/pipelines.py CHANGED
@@ -2,15 +2,12 @@
2
  Haystack Pipelines
3
  """
4
 
5
- import tokenizers
6
  from haystack import Pipeline
7
  from haystack.document_stores import InMemoryDocumentStore
8
  from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
9
  from haystack.nodes.preprocessor import PreProcessor
10
- import streamlit as st
11
 
12
 
13
- @st.cache(allow_output_mutation=True)
14
  def keyword_search(
15
  index="documents",
16
  ):
@@ -42,13 +39,6 @@ def keyword_search(
42
  return search_pipeline, index_pipeline
43
 
44
 
45
- @st.cache(
46
- hash_funcs={
47
- tokenizers.Tokenizer: lambda _: None,
48
- tokenizers.AddedToken: lambda _: None,
49
- },
50
- allow_output_mutation=True,
51
- )
52
  def dense_passage_retrieval(
53
  index="documents",
54
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
 
2
  Haystack Pipelines
3
  """
4
 
 
5
  from haystack import Pipeline
6
  from haystack.document_stores import InMemoryDocumentStore
7
  from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
8
  from haystack.nodes.preprocessor import PreProcessor
 
9
 
10
 
 
11
  def keyword_search(
12
  index="documents",
13
  ):
 
39
  return search_pipeline, index_pipeline
40
 
41
 
 
 
 
 
 
 
 
42
  def dense_passage_retrieval(
43
  index="documents",
44
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
interface/components.py CHANGED
@@ -13,10 +13,16 @@ def component_select_pipeline(container):
13
  if "Keyword Search" in pipeline_names
14
  else 0,
15
  )
16
- (
17
- st.session_state["search_pipeline"],
18
- st.session_state["index_pipeline"],
19
- ) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
 
 
 
 
 
 
20
 
21
 
22
  def component_show_pipeline(pipeline):
 
13
  if "Keyword Search" in pipeline_names
14
  else 0,
15
  )
16
+ if st.session_state["pipeline"] is None or st.session_state["pipeline"]["name"] != selected_pipeline:
17
+ (
18
+ search_pipeline,
19
+ index_pipeline,
20
+ ) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
21
+ st.session_state["pipeline"] = {
22
+ 'name': selected_pipeline,
23
+ 'search_pipeline': search_pipeline,
24
+ 'index_pipeline': index_pipeline,
25
+ }
26
 
27
 
28
  def component_show_pipeline(pipeline):
interface/config.py CHANGED
@@ -1,7 +1,9 @@
1
  from interface.pages import page_landing_page, page_search, page_index
2
 
3
  # Define default Session Variables over the whole session.
4
- session_state_variables = {}
 
 
5
 
6
  # Define Pages for the demo
7
  pages = {
 
1
  from interface.pages import page_landing_page, page_search, page_index
2
 
3
  # Define default Session Variables over the whole session.
4
+ session_state_variables = {
5
+ "pipeline": None
6
+ }
7
 
8
  # Define Pages for the demo
9
  pages = {
interface/draw_pipelines.py CHANGED
@@ -3,11 +3,9 @@ from typing import List
3
  from itertools import chain
4
  import networkx as nx
5
  import plotly.graph_objs as go
6
- import streamlit as st
7
  import numpy as np
8
 
9
 
10
- @st.cache(allow_output_mutation=True)
11
  def get_pipeline_graph(pipeline):
12
  # Controls for how the graph is drawn
13
  nodeColor = "#ffbf00"
@@ -16,13 +14,37 @@ def get_pipeline_graph(pipeline):
16
  lineColor = "#ffffff"
17
 
18
  G = pipeline.graph
19
- initial_coordinate = (0, len(G.nodes))
20
- fixed_pos = {
21
- node: np.array([initial_coordinate[0], initial_coordinate[1] - float(idx)])
22
- for idx, node in enumerate(G.nodes)
23
- }
24
- pos = nx.spring_layout(G, pos=fixed_pos, seed=42)
25
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  for node in G.nodes:
27
  G.nodes[node]["pos"] = list(pos[node])
28
 
 
3
  from itertools import chain
4
  import networkx as nx
5
  import plotly.graph_objs as go
 
6
  import numpy as np
7
 
8
 
 
9
  def get_pipeline_graph(pipeline):
10
  # Controls for how the graph is drawn
11
  nodeColor = "#ffbf00"
 
14
  lineColor = "#ffffff"
15
 
16
  G = pipeline.graph
17
+ current_coordinate = (0, len(set([edge[0] for edge in G.edges()])) + 1)
18
+ # Transform G.edges into {node : all_connected_nodes} format
19
+ node_connections = {}
20
+ for in_node, out_node in G.edges():
21
+ if in_node in node_connections:
22
+ node_connections[in_node].append(out_node)
23
+ else:
24
+ node_connections[in_node] = [out_node]
25
+ # Get node coordinates/pos
26
+ fixed_pos_nodes = {}
27
+ for idx, (in_node, out_nodes) in enumerate(node_connections.items()):
28
+ if in_node not in fixed_pos_nodes:
29
+ fixed_pos_nodes[in_node] = np.array([current_coordinate[0], current_coordinate[1]])
30
+ current_coordinate = (current_coordinate[0], current_coordinate[1] - 1)
31
+ # If more than 1 out node, then branch out in X coordinate
32
+ if len(out_nodes) > 1:
33
+ # if length is odd
34
+ if (len(out_nodes) % 2) != 0:
35
+ middle_node = out_nodes[round(len(out_nodes)/2, 0) - 1]
36
+ fixed_pos_nodes[middle_node] = np.array([current_coordinate[0], current_coordinate[1]])
37
+ out_nodes = [n for n in out_nodes if n != middle_node]
38
+ correction_coordinate = - len(out_nodes) / 2
39
+ for out_node in out_nodes:
40
+ fixed_pos_nodes[out_node] = np.array([int(current_coordinate[0] + correction_coordinate), int(current_coordinate[1])])
41
+ if correction_coordinate == -1:
42
+ correction_coordinate += 1
43
+ correction_coordinate += 1
44
+ current_coordinate = (current_coordinate[0], current_coordinate[1] - 1)
45
+ elif len(node_connections) - 1 == idx:
46
+ fixed_pos_nodes[out_nodes[0]] = np.array([current_coordinate[0], current_coordinate[1]])
47
+ pos = nx.spring_layout(G, pos=fixed_pos_nodes, fixed=G.nodes(), seed=42)
48
  for node in G.nodes:
49
  G.nodes[node]["pos"] = list(pos[node])
50
 
interface/pages.py CHANGED
@@ -36,12 +36,12 @@ def page_search(container):
36
  ## SEARCH ##
37
  query = st.text_input("Query")
38
 
39
- component_show_pipeline(st.session_state["search_pipeline"])
40
 
41
  if st.button("Search"):
42
  st.session_state["search_results"] = search(
43
  queries=[query],
44
- pipeline=st.session_state["search_pipeline"],
45
  )
46
  if "search_results" in st.session_state:
47
  component_show_search_result(
@@ -53,7 +53,7 @@ def page_index(container):
53
  with container:
54
  st.title("Index time!")
55
 
56
- component_show_pipeline(st.session_state["index_pipeline"])
57
 
58
  input_funcs = {
59
  "Raw Text": (component_text_input, "card-text"),
@@ -74,7 +74,7 @@ def page_index(container):
74
  if st.button("Index"):
75
  index_results = index(
76
  corpus,
77
- st.session_state["index_pipeline"],
78
  )
79
  if index_results:
80
  st.write(index_results)
 
36
  ## SEARCH ##
37
  query = st.text_input("Query")
38
 
39
+ component_show_pipeline(st.session_state["pipeline"]["search_pipeline"])
40
 
41
  if st.button("Search"):
42
  st.session_state["search_results"] = search(
43
  queries=[query],
44
+ pipeline=st.session_state["pipeline"]["search_pipeline"],
45
  )
46
  if "search_results" in st.session_state:
47
  component_show_search_result(
 
53
  with container:
54
  st.title("Index time!")
55
 
56
+ component_show_pipeline(st.session_state["pipeline"]["index_pipeline"])
57
 
58
  input_funcs = {
59
  "Raw Text": (component_text_input, "card-text"),
 
74
  if st.button("Index"):
75
  index_results = index(
76
  corpus,
77
+ st.session_state["pipeline"]["index_pipeline"],
78
  )
79
  if index_results:
80
  st.write(index_results)