Unai Garay Maestre commited on
Commit
3900908
·
unverified ·
2 Parent(s): 2ce16ea f65e26a

Merge pull request #1 from ugm2/feature/draw_pipeline

Browse files
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor="#ffbf00"
3
+ backgroundColor="#0e1117"
4
+ secondaryBackgroundColor="#282929"
5
+ textColor = "#ffffff"
6
+ font="sans serif"
core/pipelines.py CHANGED
@@ -2,15 +2,12 @@
2
  Haystack Pipelines
3
  """
4
 
5
- import tokenizers
6
  from haystack import Pipeline
7
  from haystack.document_stores import InMemoryDocumentStore
8
  from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
9
  from haystack.nodes.preprocessor import PreProcessor
10
- import streamlit as st
11
 
12
 
13
- @st.cache(allow_output_mutation=True)
14
  def keyword_search(
15
  index="documents",
16
  ):
@@ -42,13 +39,6 @@ def keyword_search(
42
  return search_pipeline, index_pipeline
43
 
44
 
45
- @st.cache(
46
- hash_funcs={
47
- tokenizers.Tokenizer: lambda _: None,
48
- tokenizers.AddedToken: lambda _: None,
49
- },
50
- allow_output_mutation=True,
51
- )
52
  def dense_passage_retrieval(
53
  index="documents",
54
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
 
2
  Haystack Pipelines
3
  """
4
 
 
5
  from haystack import Pipeline
6
  from haystack.document_stores import InMemoryDocumentStore
7
  from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
8
  from haystack.nodes.preprocessor import PreProcessor
 
9
 
10
 
 
11
  def keyword_search(
12
  index="documents",
13
  ):
 
39
  return search_pipeline, index_pipeline
40
 
41
 
 
 
 
 
 
 
 
42
  def dense_passage_retrieval(
43
  index="documents",
44
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
interface/components.py CHANGED
@@ -1,16 +1,10 @@
1
  import streamlit as st
2
- import core.pipelines as pipelines_functions
3
- from inspect import getmembers, isfunction
4
- from networkx.drawing.nx_agraph import to_agraph
5
 
6
 
7
  def component_select_pipeline(container):
8
- pipeline_names, pipeline_funcs = list(
9
- zip(*getmembers(pipelines_functions, isfunction))
10
- )
11
- pipeline_names = [
12
- " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
13
- ]
14
  with container:
15
  selected_pipeline = st.selectbox(
16
  "Select pipeline",
@@ -19,18 +13,26 @@ def component_select_pipeline(container):
19
  if "Keyword Search" in pipeline_names
20
  else 0,
21
  )
22
- (
23
- st.session_state["search_pipeline"],
24
- st.session_state["index_pipeline"],
25
- ) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
 
 
 
 
 
 
 
 
 
26
 
27
 
28
- def component_show_pipeline(container, pipeline):
29
  """Draw the pipeline"""
30
  with st.expander("Show pipeline"):
31
- graphviz = to_agraph(pipeline.graph)
32
- graphviz.layout("dot")
33
- st.graphviz_chart(graphviz.string())
34
 
35
 
36
  def component_show_search_result(container, results):
 
1
  import streamlit as st
2
+ from interface.utils import get_pipelines
3
+ from interface.draw_pipelines import get_pipeline_graph
 
4
 
5
 
6
  def component_select_pipeline(container):
7
+ pipeline_names, pipeline_funcs = get_pipelines()
 
 
 
 
 
8
  with container:
9
  selected_pipeline = st.selectbox(
10
  "Select pipeline",
 
13
  if "Keyword Search" in pipeline_names
14
  else 0,
15
  )
16
+ if (
17
+ st.session_state["pipeline"] is None
18
+ or st.session_state["pipeline"]["name"] != selected_pipeline
19
+ ):
20
+ (
21
+ search_pipeline,
22
+ index_pipeline,
23
+ ) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
24
+ st.session_state["pipeline"] = {
25
+ "name": selected_pipeline,
26
+ "search_pipeline": search_pipeline,
27
+ "index_pipeline": index_pipeline,
28
+ }
29
 
30
 
31
+ def component_show_pipeline(pipeline):
32
  """Draw the pipeline"""
33
  with st.expander("Show pipeline"):
34
+ fig = get_pipeline_graph(pipeline)
35
+ st.plotly_chart(fig, use_container_width=True)
 
36
 
37
 
38
  def component_show_search_result(container, results):
interface/config.py CHANGED
@@ -1,7 +1,7 @@
1
  from interface.pages import page_landing_page, page_search, page_index
2
 
3
  # Define default Session Variables over the whole session.
4
- session_state_variables = {}
5
 
6
  # Define Pages for the demo
7
  pages = {
 
1
  from interface.pages import page_landing_page, page_search, page_index
2
 
3
  # Define default Session Variables over the whole session.
4
+ session_state_variables = {"pipeline": None}
5
 
6
  # Define Pages for the demo
7
  pages = {
interface/draw_pipelines.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+ from itertools import chain
4
+ import networkx as nx
5
+ import plotly.graph_objs as go
6
+ import numpy as np
7
+
8
+
9
+ def get_pipeline_graph(pipeline):
10
+ # Controls for how the graph is drawn
11
+ nodeColor = "#ffbf00"
12
+ nodeSize = 40
13
+ lineWidth = 2
14
+ lineColor = "#ffffff"
15
+
16
+ G = pipeline.graph
17
+ current_coordinate = (0, len(set([edge[0] for edge in G.edges()])) + 1)
18
+ # Transform G.edges into {node : all_connected_nodes} format
19
+ node_connections = {}
20
+ for in_node, out_node in G.edges():
21
+ if in_node in node_connections:
22
+ node_connections[in_node].append(out_node)
23
+ else:
24
+ node_connections[in_node] = [out_node]
25
+ # Get node coordinates/pos
26
+ fixed_pos_nodes = {}
27
+ for idx, (in_node, out_nodes) in enumerate(node_connections.items()):
28
+ if in_node not in fixed_pos_nodes:
29
+ fixed_pos_nodes[in_node] = np.array(
30
+ [current_coordinate[0], current_coordinate[1]]
31
+ )
32
+ current_coordinate = (current_coordinate[0], current_coordinate[1] - 1)
33
+ # If more than 1 out node, then branch out in X coordinate
34
+ if len(out_nodes) > 1:
35
+ # if length is odd
36
+ if (len(out_nodes) % 2) != 0:
37
+ middle_node = out_nodes[round(len(out_nodes) / 2, 0) - 1]
38
+ fixed_pos_nodes[middle_node] = np.array(
39
+ [current_coordinate[0], current_coordinate[1]]
40
+ )
41
+ out_nodes = [n for n in out_nodes if n != middle_node]
42
+ correction_coordinate = -len(out_nodes) / 2
43
+ for out_node in out_nodes:
44
+ fixed_pos_nodes[out_node] = np.array(
45
+ [
46
+ int(current_coordinate[0] + correction_coordinate),
47
+ int(current_coordinate[1]),
48
+ ]
49
+ )
50
+ if correction_coordinate == -1:
51
+ correction_coordinate += 1
52
+ correction_coordinate += 1
53
+ current_coordinate = (current_coordinate[0], current_coordinate[1] - 1)
54
+ elif len(node_connections) - 1 == idx:
55
+ fixed_pos_nodes[out_nodes[0]] = np.array(
56
+ [current_coordinate[0], current_coordinate[1]]
57
+ )
58
+ pos = nx.spring_layout(G, pos=fixed_pos_nodes, fixed=G.nodes(), seed=42)
59
+ for node in G.nodes:
60
+ G.nodes[node]["pos"] = list(pos[node])
61
+
62
+ # Make list of nodes for plotly
63
+ node_x = []
64
+ node_y = []
65
+ node_name = []
66
+ for node in G.nodes():
67
+ node_name.append(G.nodes[node]["component"].name)
68
+ x, y = G.nodes[node]["pos"]
69
+ node_x.append(x)
70
+ node_y.append(y)
71
+
72
+ # Make a list of edges for plotly, including line segments that result in arrowheads
73
+ edge_x = []
74
+ edge_y = []
75
+ for edge in G.edges():
76
+ start = G.nodes[edge[0]]["pos"]
77
+ end = G.nodes[edge[1]]["pos"]
78
+ # addEdge(start, end, edge_x, edge_y, lengthFrac=1, arrowPos = None, arrowLength=0.025, arrowAngle = 30, dotSize=20)
79
+ edge_x, edge_y = addEdge(
80
+ start,
81
+ end,
82
+ edge_x,
83
+ edge_y,
84
+ lengthFrac=0.5,
85
+ arrowPos="end",
86
+ arrowLength=0.04,
87
+ arrowAngle=40,
88
+ dotSize=nodeSize,
89
+ )
90
+
91
+ edge_trace = go.Scatter(
92
+ x=edge_x,
93
+ y=edge_y,
94
+ line=dict(width=lineWidth, color=lineColor),
95
+ hoverinfo="none",
96
+ mode="lines",
97
+ )
98
+
99
+ node_trace = go.Scatter(
100
+ x=node_x,
101
+ y=node_y,
102
+ mode="markers+text",
103
+ textposition="middle right",
104
+ hoverinfo="none",
105
+ text=node_name,
106
+ marker=dict(showscale=False, color=nodeColor, size=nodeSize),
107
+ textfont=dict(size=18),
108
+ )
109
+
110
+ fig = go.Figure(
111
+ data=[edge_trace, node_trace],
112
+ layout=go.Layout(
113
+ showlegend=False,
114
+ hovermode="closest",
115
+ margin=dict(b=20, l=5, r=5, t=40),
116
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
117
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
118
+ ),
119
+ )
120
+
121
+ fig.update_layout(
122
+ yaxis=dict(scaleanchor="x", scaleratio=1), plot_bgcolor="rgb(14,17,23)"
123
+ )
124
+
125
+ return fig
126
+
127
+
128
+ def addEdge(
129
+ start,
130
+ end,
131
+ edge_x,
132
+ edge_y,
133
+ lengthFrac=1,
134
+ arrowPos=None,
135
+ arrowLength=0.025,
136
+ arrowAngle=30,
137
+ dotSize=20,
138
+ ):
139
+
140
+ # Get start and end cartesian coordinates
141
+ x0, y0 = start
142
+ x1, y1 = end
143
+
144
+ # Incorporate the fraction of this segment covered by a dot into total reduction
145
+ length = math.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)
146
+ dotSizeConversion = 0.0565 / 20 # length units per dot size
147
+ convertedDotDiameter = dotSize * dotSizeConversion
148
+ lengthFracReduction = convertedDotDiameter / length
149
+ lengthFrac = lengthFrac - lengthFracReduction
150
+
151
+ # If the line segment should not cover the entire distance, get actual start and end coords
152
+ skipX = (x1 - x0) * (1 - lengthFrac)
153
+ skipY = (y1 - y0) * (1 - lengthFrac)
154
+ x0 = x0 + skipX / 2
155
+ x1 = x1 - skipX / 2
156
+ y0 = y0 + skipY / 2
157
+ y1 = y1 - skipY / 2
158
+
159
+ # Append line corresponding to the edge
160
+ edge_x.append(x0)
161
+ edge_x.append(x1)
162
+ edge_x.append(
163
+ None
164
+ ) # Prevents a line being drawn from end of this edge to start of next edge
165
+ edge_y.append(y0)
166
+ edge_y.append(y1)
167
+ edge_y.append(None)
168
+
169
+ # Draw arrow
170
+ if not arrowPos == None:
171
+
172
+ # Find the point of the arrow; assume is at end unless told middle
173
+ pointx = x1
174
+ pointy = y1
175
+
176
+ eta = math.degrees(math.atan((x1 - x0) / (y1 - y0))) if y1 != y0 else 90.0
177
+
178
+ if arrowPos == "middle" or arrowPos == "mid":
179
+ pointx = x0 + (x1 - x0) / 2
180
+ pointy = y0 + (y1 - y0) / 2
181
+
182
+ # Find the directions the arrows are pointing
183
+ signx = (x1 - x0) / abs(x1 - x0) if x1 != x0 else +1 # verify this once
184
+ signy = (y1 - y0) / abs(y1 - y0) if y1 != y0 else +1 # verified
185
+
186
+ # Append first arrowhead
187
+ dx = arrowLength * math.sin(math.radians(eta + arrowAngle))
188
+ dy = arrowLength * math.cos(math.radians(eta + arrowAngle))
189
+ edge_x.append(pointx)
190
+ edge_x.append(pointx - signx**2 * signy * dx)
191
+ edge_x.append(None)
192
+ edge_y.append(pointy)
193
+ edge_y.append(pointy - signx**2 * signy * dy)
194
+ edge_y.append(None)
195
+
196
+ # And second arrowhead
197
+ dx = arrowLength * math.sin(math.radians(eta - arrowAngle))
198
+ dy = arrowLength * math.cos(math.radians(eta - arrowAngle))
199
+ edge_x.append(pointx)
200
+ edge_x.append(pointx - signx**2 * signy * dx)
201
+ edge_x.append(None)
202
+ edge_y.append(pointy)
203
+ edge_y.append(pointy - signx**2 * signy * dy)
204
+ edge_y.append(None)
205
+
206
+ return edge_x, edge_y
207
+
208
+
209
+ def add_arrows(
210
+ source_x: List[float],
211
+ target_x: List[float],
212
+ source_y: List[float],
213
+ target_y: List[float],
214
+ arrowLength=0.025,
215
+ arrowAngle=30,
216
+ ):
217
+ pointx = list(map(lambda x: x[0] + (x[1] - x[0]) / 2, zip(source_x, target_x)))
218
+ pointy = list(map(lambda x: x[0] + (x[1] - x[0]) / 2, zip(source_y, target_y)))
219
+ etas = list(
220
+ map(
221
+ lambda x: math.degrees(math.atan((x[1] - x[0]) / (x[3] - x[2]))),
222
+ zip(source_x, target_x, source_y, target_y),
223
+ )
224
+ )
225
+
226
+ signx = list(
227
+ map(lambda x: (x[1] - x[0]) / abs(x[1] - x[0]), zip(source_x, target_x))
228
+ )
229
+ signy = list(
230
+ map(lambda x: (x[1] - x[0]) / abs(x[1] - x[0]), zip(source_y, target_y))
231
+ )
232
+
233
+ dx = list(map(lambda x: arrowLength * math.sin(math.radians(x + arrowAngle)), etas))
234
+ dy = list(map(lambda x: arrowLength * math.cos(math.radians(x + arrowAngle)), etas))
235
+ none_spacer = [None for _ in range(len(pointx))]
236
+ arrow_line_x = list(
237
+ map(lambda x: x[0] - x[1] ** 2 * x[2] * x[3], zip(pointx, signx, signy, dx))
238
+ )
239
+ arrow_line_y = list(
240
+ map(lambda x: x[0] - x[1] ** 2 * x[2] * x[3], zip(pointy, signx, signy, dy))
241
+ )
242
+
243
+ arrow_line_1x_coords = list(chain(*zip(pointx, arrow_line_x, none_spacer)))
244
+ arrow_line_1y_coords = list(chain(*zip(pointy, arrow_line_y, none_spacer)))
245
+
246
+ dx = list(map(lambda x: arrowLength * math.sin(math.radians(x - arrowAngle)), etas))
247
+ dy = list(map(lambda x: arrowLength * math.cos(math.radians(x - arrowAngle)), etas))
248
+ none_spacer = [None for _ in range(len(pointx))]
249
+ arrow_line_x = list(
250
+ map(lambda x: x[0] - x[1] ** 2 * x[2] * x[3], zip(pointx, signx, signy, dx))
251
+ )
252
+ arrow_line_y = list(
253
+ map(lambda x: x[0] - x[1] ** 2 * x[2] * x[3], zip(pointy, signx, signy, dy))
254
+ )
255
+
256
+ arrow_line_2x_coords = list(chain(*zip(pointx, arrow_line_x, none_spacer)))
257
+ arrow_line_2y_coords = list(chain(*zip(pointy, arrow_line_y, none_spacer)))
258
+
259
+ x_arrows = arrow_line_1x_coords + arrow_line_2x_coords
260
+ y_arrows = arrow_line_1y_coords + arrow_line_2y_coords
261
+
262
+ return x_arrows, y_arrows
interface/pages.py CHANGED
@@ -36,12 +36,12 @@ def page_search(container):
36
  ## SEARCH ##
37
  query = st.text_input("Query")
38
 
39
- # component_show_pipeline(container, st.session_state["search_pipeline"])
40
 
41
  if st.button("Search"):
42
  st.session_state["search_results"] = search(
43
  queries=[query],
44
- pipeline=st.session_state["search_pipeline"],
45
  )
46
  if "search_results" in st.session_state:
47
  component_show_search_result(
@@ -53,7 +53,7 @@ def page_index(container):
53
  with container:
54
  st.title("Index time!")
55
 
56
- # component_show_pipeline(container, st.session_state["index_pipeline"])
57
 
58
  input_funcs = {
59
  "Raw Text": (component_text_input, "card-text"),
@@ -74,7 +74,7 @@ def page_index(container):
74
  if st.button("Index"):
75
  index_results = index(
76
  corpus,
77
- st.session_state["index_pipeline"],
78
  )
79
  if index_results:
80
  st.write(index_results)
 
36
  ## SEARCH ##
37
  query = st.text_input("Query")
38
 
39
+ component_show_pipeline(st.session_state["pipeline"]["search_pipeline"])
40
 
41
  if st.button("Search"):
42
  st.session_state["search_results"] = search(
43
  queries=[query],
44
+ pipeline=st.session_state["pipeline"]["search_pipeline"],
45
  )
46
  if "search_results" in st.session_state:
47
  component_show_search_result(
 
53
  with container:
54
  st.title("Index time!")
55
 
56
+ component_show_pipeline(st.session_state["pipeline"]["index_pipeline"])
57
 
58
  input_funcs = {
59
  "Raw Text": (component_text_input, "card-text"),
 
74
  if st.button("Index"):
75
  index_results = index(
76
  corpus,
77
+ st.session_state["pipeline"]["index_pipeline"],
78
  )
79
  if index_results:
80
  st.write(index_results)
interface/utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import core.pipelines as pipelines_functions
2
+ from inspect import getmembers, isfunction
3
+
4
+
5
+ def get_pipelines():
6
+ pipeline_names, pipeline_funcs = list(
7
+ zip(*getmembers(pipelines_functions, isfunction))
8
+ )
9
+ pipeline_names = [
10
+ " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
11
+ ]
12
+ return pipeline_names, pipeline_funcs
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  streamlit
2
  streamlit_option_menu
3
  farm-haystack
4
- black
 
 
1
  streamlit
2
  streamlit_option_menu
3
  farm-haystack
4
+ black
5
+ plotly