Pedro Cuenca commited on
Commit
3f58819
·
1 Parent(s): bbc3a60

streamlit session_state hack for v0.79

Browse files
Files changed (1) hide show
  1. app/app.py +21 -7
app/app.py CHANGED
@@ -6,12 +6,26 @@ from dalle_mini.backend import ServiceError, get_images_from_backend
6
 
7
  import streamlit as st
8
 
9
- # st.sidebar.title("DALL·E mini")
 
10
 
11
- # sc = st.sidebar.beta_columns(2)
12
- # st.sidebar.image('../img/logo.png', width=150)
13
- # sc[1].write(" ")
14
- # st.sidebar.markdown("Generate images from text")
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  st.sidebar.markdown("""
17
  <style>
@@ -41,7 +55,7 @@ prompt = st.text_input("What do you want to see?")
41
  #TODO: I think there's an issue where we can't run twice the same inference (not due to caching) - may need to use st.form
42
 
43
  DEBUG = False
44
- if prompt != "" or (st.session_state.get("again", False) and prompt != ""):
45
  container = st.empty()
46
  container.markdown(f"Generating predictions for: **{prompt}**")
47
 
@@ -56,7 +70,7 @@ if prompt != "" or (st.session_state.get("again", False) and prompt != ""):
56
 
57
  container.markdown(f"**{prompt}**")
58
 
59
- st.session_state["again"] = st.button('Again!', key='again_button')
60
 
61
  except ServiceError as error:
62
  container.text(f"Service unavailable, status: {error.status_code}")
 
6
 
7
  import streamlit as st
8
 
9
+ # streamlit.session_state is not available in Huggingface spaces.
10
+ # Session state hack https://huggingface.slack.com/archives/C025LJDP962/p1626527367443200?thread_ts=1626525999.440500&cid=C025LJDP962
11
 
12
+ from streamlit.report_thread import get_report_ctx
13
+ def query_cache(q_emb=None):
14
+ ctx = get_report_ctx()
15
+ session_id = ctx.session_id
16
+ session = st.server.server.Server.get_current()._get_session_info(session_id).session
17
+ if not hasattr(session, "_query_state"):
18
+ setattr(session, "_query_state", q_emb)
19
+ if q_emb:
20
+ session._query_state = q_emb
21
+ return session._query_state
22
+
23
+ def set_run_again(state):
24
+ query_cache(state)
25
+
26
+ def should_run_again():
27
+ state = query_cache()
28
+ return state if state is not None else False
29
 
30
  st.sidebar.markdown("""
31
  <style>
 
55
  #TODO: I think there's an issue where we can't run twice the same inference (not due to caching) - may need to use st.form
56
 
57
  DEBUG = False
58
+ if prompt != "" or (should_run_again and prompt != ""):
59
  container = st.empty()
60
  container.markdown(f"Generating predictions for: **{prompt}**")
61
 
 
70
 
71
  container.markdown(f"**{prompt}**")
72
 
73
+ set_run_again(st.button('Again!', key='again_button'))
74
 
75
  except ServiceError as error:
76
  container.text(f"Service unavailable, status: {error.status_code}")