ugaray96 commited on
Commit
acb72cc
·
1 Parent(s): 893d078

Adds loading message and removes results when loading new pipeline

Browse files
core/pipelines.py CHANGED
@@ -150,8 +150,7 @@ def dense_passage_retrieval_ranker(
150
  split_word_length=split_word_length,
151
  query_embedding_model=query_embedding_model,
152
  passage_embedding_model=passage_embedding_model,
153
- # top_k high to allow better recall, the ranker will handle the precision
154
- top_k=10000000,
155
  )
156
  ranker = SentenceTransformersRanker(model_name_or_path=ranker_model, top_k=top_k)
157
 
 
150
  split_word_length=split_word_length,
151
  query_embedding_model=query_embedding_model,
152
  passage_embedding_model=passage_embedding_model,
153
+ top_k=top_k,
 
154
  )
155
  ranker = SentenceTransformersRanker(model_name_or_path=ranker_model, top_k=top_k)
156
 
interface/components.py CHANGED
@@ -10,44 +10,47 @@ from interface.draw_pipelines import get_pipeline_graph
10
 
11
  def component_select_pipeline(container):
12
  pipeline_names, pipeline_funcs, pipeline_func_parameters = get_pipelines()
13
- with container:
14
- selected_pipeline = st.selectbox(
15
- "Select pipeline",
16
- pipeline_names,
17
- index=pipeline_names.index("Keyword Search")
18
- if "Keyword Search" in pipeline_names
19
- else 0,
20
- )
21
- index_pipe = pipeline_names.index(selected_pipeline)
22
- st.write("---")
23
- st.header("Pipeline Parameters")
24
- for parameter, value in pipeline_func_parameters[index_pipe].items():
25
- if isinstance(value, str):
26
- value = st.text_input(parameter, value)
27
- elif isinstance(value, bool):
28
- value = st.checkbox(parameter, value)
29
- elif isinstance(value, int):
30
- value = int(st.number_input(parameter, value=value))
31
- elif isinstance(value, float):
32
- value = float(st.number_input(parameter, value=value))
33
- pipeline_func_parameters[index_pipe][parameter] = value
34
- if (
35
- st.session_state["pipeline"] is None
36
- or st.session_state["pipeline"]["name"] != selected_pipeline
37
- or list(st.session_state["pipeline_func_parameters"][index_pipe].values())
38
- != list(pipeline_func_parameters[index_pipe].values())
39
- ):
40
- st.session_state["pipeline_func_parameters"] = pipeline_func_parameters
41
- (search_pipeline, index_pipeline,) = pipeline_funcs[
42
- index_pipe
43
- ](**pipeline_func_parameters[index_pipe])
44
- st.session_state["pipeline"] = {
45
- "name": selected_pipeline,
46
- "search_pipeline": search_pipeline,
47
- "index_pipeline": index_pipeline,
48
- "doc": pipeline_funcs[index_pipe].__doc__,
49
- }
50
- reset_vars_data()
 
 
 
51
 
52
 
53
  def component_show_pipeline(pipeline, pipeline_name):
 
10
 
11
  def component_select_pipeline(container):
12
  pipeline_names, pipeline_funcs, pipeline_func_parameters = get_pipelines()
13
+ with st.spinner("Loading Pipeline..."):
14
+ with container:
15
+ selected_pipeline = st.selectbox(
16
+ "Select pipeline",
17
+ pipeline_names,
18
+ index=pipeline_names.index("Keyword Search")
19
+ if "Keyword Search" in pipeline_names
20
+ else 0,
21
+ )
22
+ index_pipe = pipeline_names.index(selected_pipeline)
23
+ st.write("---")
24
+ st.header("Pipeline Parameters")
25
+ for parameter, value in pipeline_func_parameters[index_pipe].items():
26
+ if isinstance(value, str):
27
+ value = st.text_input(parameter, value)
28
+ elif isinstance(value, bool):
29
+ value = st.checkbox(parameter, value)
30
+ elif isinstance(value, int):
31
+ value = int(st.number_input(parameter, value=value))
32
+ elif isinstance(value, float):
33
+ value = float(st.number_input(parameter, value=value))
34
+ pipeline_func_parameters[index_pipe][parameter] = value
35
+ if (
36
+ st.session_state["pipeline"] is None
37
+ or st.session_state["pipeline"]["name"] != selected_pipeline
38
+ or list(
39
+ st.session_state["pipeline_func_parameters"][index_pipe].values()
40
+ )
41
+ != list(pipeline_func_parameters[index_pipe].values())
42
+ ):
43
+ st.session_state["pipeline_func_parameters"] = pipeline_func_parameters
44
+ (search_pipeline, index_pipeline,) = pipeline_funcs[
45
+ index_pipe
46
+ ](**pipeline_func_parameters[index_pipe])
47
+ st.session_state["pipeline"] = {
48
+ "name": selected_pipeline,
49
+ "search_pipeline": search_pipeline,
50
+ "index_pipeline": index_pipeline,
51
+ "doc": pipeline_funcs[index_pipe].__doc__,
52
+ }
53
+ reset_vars_data()
54
 
55
 
56
  def component_show_pipeline(pipeline, pipeline_name):
interface/config.py CHANGED
@@ -4,6 +4,7 @@ from interface.pages import page_landing_page, page_search, page_index
4
  session_state_variables = {
5
  "pipeline": None,
6
  "pipeline_func_parameters": [],
 
7
  "doc_id": 0,
8
  }
9
 
 
4
  session_state_variables = {
5
  "pipeline": None,
6
  "pipeline_func_parameters": [],
7
+ "search_results": None,
8
  "doc_id": 0,
9
  }
10
 
interface/pages.py CHANGED
@@ -49,11 +49,12 @@ def page_search(container):
49
  component_show_pipeline(st.session_state["pipeline"], "search_pipeline")
50
 
51
  if st.button("Search"):
52
- st.session_state["search_results"] = search(
53
- queries=[query],
54
- pipeline=st.session_state["pipeline"]["search_pipeline"],
55
- )
56
- if "search_results" in st.session_state:
 
57
  component_show_search_result(
58
  container=container, results=st.session_state["search_results"][0]
59
  )
@@ -87,11 +88,12 @@ def page_index(container):
87
  if len(corpus) > 0:
88
  index_results = None
89
  if st.button("Index"):
90
- index_results = index(
91
- documents=corpus,
92
- pipeline=st.session_state["pipeline"]["index_pipeline"],
93
- clear_index=clear_index,
94
- )
95
- st.session_state["doc_id"] = doc_id
 
96
  if index_results:
97
  st.write(index_results)
 
49
  component_show_pipeline(st.session_state["pipeline"], "search_pipeline")
50
 
51
  if st.button("Search"):
52
+ with st.spinner("Searching..."):
53
+ st.session_state["search_results"] = search(
54
+ queries=[query],
55
+ pipeline=st.session_state["pipeline"]["search_pipeline"],
56
+ )
57
+ if st.session_state["search_results"] is not None:
58
  component_show_search_result(
59
  container=container, results=st.session_state["search_results"][0]
60
  )
 
88
  if len(corpus) > 0:
89
  index_results = None
90
  if st.button("Index"):
91
+ with st.spinner("Indexing..."):
92
+ index_results = index(
93
+ documents=corpus,
94
+ pipeline=st.session_state["pipeline"]["index_pipeline"],
95
+ clear_index=clear_index,
96
+ )
97
+ st.session_state["doc_id"] = doc_id
98
  if index_results:
99
  st.write(index_results)
interface/utils.py CHANGED
@@ -28,6 +28,7 @@ def get_pipelines():
28
 
29
  def reset_vars_data():
30
  st.session_state["doc_id"] = 0
 
31
  # Delete data files
32
  shutil.rmtree(data_path)
33
  os.makedirs(data_path, exist_ok=True)
 
28
 
29
  def reset_vars_data():
30
  st.session_state["doc_id"] = 0
31
+ st.session_state["search_results"] = None
32
  # Delete data files
33
  shutil.rmtree(data_path)
34
  os.makedirs(data_path, exist_ok=True)