File size: 4,012 Bytes
cb2ac60
 
 
ffed138
8e37fb8
cb2ac60
 
8e37fb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f58819
8e37fb8
 
c286f15
 
 
 
 
 
 
 
 
 
 
c1cfda4
 
2049545
4325576
2049545
c1cfda4
32f68a8
 
f95204e
851cea0
32f68a8
 
0e8338d
4325576
9f85e8c
482963e
0e8338d
 
a7f2bba
9f85e8c
8e37fb8
9f85e8c
a7f2bba
 
 
bb7c400
a7f2bba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813a400
bb7c400
cb2ac60
ffed138
 
 
 
482963e
 
 
 
9f85e8c
 
8e37fb8
 
0b4b9a3
ffed138
9f85e8c
ffed138
9f85e8c
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python
# coding: utf-8

import html

import streamlit as st

from dalle_mini.backend import ServiceError, get_images_from_backend

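# `streamlit.session_state` is not available on Hugging Face Spaces, so per-session
# state is stored as an attribute on Streamlit's server-side session object instead.
# Session state hack: https://huggingface.slack.com/archives/C025LJDP962/p1626527367443200?thread_ts=1626525999.440500&cid=C025LJDP962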

from streamlit.report_thread import get_report_ctx
def query_cache(q_emb=None):
    """Read, and optionally update, a value cached on the current Streamlit session.

    The value is kept in the ``_query_state`` attribute of the server-side
    session object looked up for the current report context, so it survives
    script reruns within that session. This relies on private Streamlit APIs
    (``Server._get_session_info``) and may break on a Streamlit upgrade.

    Args:
        q_emb: New value to store. ``None`` (the default) means "read only".
            Any other value -- including falsy ones such as ``False`` -- is
            stored, so a previously stored ``True`` can be cleared.

    Returns:
        The value currently stored for this session (``None`` if nothing has
        been stored yet).
    """
    ctx = get_report_ctx()
    session = st.server.server.Server.get_current()._get_session_info(ctx.session_id).session
    # Write when a value is supplied, or initialise the attribute on first use
    # so the read below never raises AttributeError. Using `is not None`
    # (rather than truthiness) lets callers store False, e.g. an unclicked button.
    if q_emb is not None or not hasattr(session, "_query_state"):
        session._query_state = q_emb
    return session._query_state

def set_run_again(state):
    """Store *state* as this session's "run again" flag (``None`` is ignored)."""
    query_cache(q_emb=state)

def should_run_again():
    """Return this session's "run again" flag, or ``False`` if none was stored."""
    cached = query_cache()
    if cached is None:
        return False
    return cached

# Sidebar content, rendered as raw HTML: a centered logo, then a short blurb
# with links to the repository and the project report.
_SIDEBAR_LOGO_HTML = """
<style>
.aligncenter {
    text-align: center;
}
</style>
<p class="aligncenter">
    <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
</p>
"""

_SIDEBAR_ABOUT_HTML = """
___
<p style='text-align: center'>
DALL·E mini is an AI model that generates images from any prompt you give!
</p>

<p style='text-align: center'>
Created by Boris Dayma et al. 2021
<br/>
<a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
</p>
        """

for _sidebar_html in (_SIDEBAR_LOGO_HTML, _SIDEBAR_ABOUT_HTML):
    st.sidebar.markdown(_sidebar_html, unsafe_allow_html=True)

# Main page title block.
st.header('DALL·E mini')
st.subheader('Generate images from text')

prompt = st.text_input("What do you want to see?")

test = st.empty()  # placeholder element; not referenced further in this script
# True: show setup instructions when BACKEND_SERVER is missing;
# False: show a short generic error code instead.
DEBUG = False

# Clicking "Again!" makes Streamlit rerun this whole script with the same
# text-input value, so a non-empty prompt is all that is needed to (re)generate.
if prompt != "":
    container = st.empty()
    # The following mimics `streamlit.info()`.
    # I tried to get the secondary background color using `components.streamlit.config.get_options_for_section("theme")["secondaryBackgroundColor"]`
    # but it returns None.
    # `prompt` is raw user input rendered with unsafe_allow_html, so it must be
    # HTML-escaped before being interpolated into the markup.
    container.markdown(f"""
        <style> p {{ margin:0 }} div {{ margin:0 }} </style>
        <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
        <div class="stAlert">
        <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
        <div class="st-b7">
        <div class="css-whx05o e13vu3m50">
        <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
                <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/img/loading.gif" width="30"/>
                Generating predictions for: <b>{html.escape(prompt)}</b>
        </div>
        </div>
        </div>
        </div>
        </div>
        </div>
        <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
    """, unsafe_allow_html=True)

    # Look up the backend URL on its own, so that only a missing secret (or a
    # missing secrets file) is reported as a configuration error -- a KeyError
    # raised while fetching or rendering images is not misreported as one.
    try:
        backend_url = st.secrets["BACKEND_SERVER"]
    except (KeyError, FileNotFoundError):
        backend_url = None
        if DEBUG:
            container.markdown("""
            **Error: BACKEND_SERVER unset**

            Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
            ```
            BACKEND_SERVER="<server url>"
            ```
            """)
        else:
            container.markdown('Error -5, please try again or [report it](mailto:[email protected]).')

    if backend_url is not None:
        try:
            print(f"Getting selections: {prompt}")
            selected = get_images_from_backend(prompt, backend_url)
        except ServiceError as error:
            container.text(f"Service unavailable, status: {error.status_code}")
        else:
            # Lay the returned images out in a 4-column grid.
            cols = st.beta_columns(4)
            for i, img in enumerate(selected):
                cols[i % 4].image(img)

            # Replace the loading banner with the prompt as a caption.
            container.markdown(f"**{prompt}**")

            set_run_again(st.button('Again!', key='again_button'))