Spaces:
Runtime error
Runtime error
Commit
·
3b015d2
1
Parent(s):
58774f9
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,17 @@
|
|
1 |
import base64
|
|
|
|
|
|
|
2 |
import os
|
3 |
import shutil
|
|
|
4 |
import uuid
|
5 |
import zipfile
|
6 |
-
|
|
|
|
|
|
|
|
|
7 |
from glob import glob
|
8 |
from io import BytesIO
|
9 |
from itertools import cycle
|
@@ -14,6 +22,31 @@ import streamlit as st
|
|
14 |
from PIL import Image
|
15 |
from st_btn_select import st_btn_select
|
16 |
from streamlit_image_select import image_select
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
if "key" not in st.session_state:
|
19 |
st.session_state["key"] = uuid.uuid4().hex
|
@@ -43,22 +76,138 @@ if "captcha_response" not in st.session_state:
|
|
43 |
if "captcha" not in st.session_state:
|
44 |
st.session_state["captcha"] = {}
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
def callback():
|
47 |
st.session_state["button_clicked"] = True
|
48 |
|
49 |
-
os.system('aws configure set default.s3.multipart_threshold 200MB')
|
50 |
|
51 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
if not os.path.exists(identifier):
|
53 |
os.makedirs(identifier)
|
|
|
|
|
54 |
for num, uploaded_file in enumerate(uploaded_files):
|
55 |
file_ = Image.open(uploaded_file).convert("RGB")
|
56 |
file_.save(f"{identifier}/{num}_test.png")
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
CAPTCHA_ENDPOINT = "https://captcha-api.akshit.me/v2/generate"
|
64 |
VERIFY_ENDPOINT = "https://captcha-api.akshit.me/v2/verify"
|
@@ -71,18 +220,19 @@ def generate_captcha():
|
|
71 |
# If the request was successful, return the API response
|
72 |
if response.status_code == 200:
|
73 |
return response.json()
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
77 |
|
78 |
# Create a function to verify the captcha
|
79 |
def verify_captcha(captcha_id, captcha_response):
|
80 |
# Make a POST request to the API endpoint with the captcha ID and response
|
81 |
-
|
82 |
response = requests.post(
|
83 |
-
VERIFY_ENDPOINT, json=
|
84 |
)
|
85 |
-
|
86 |
|
87 |
# If the request was successful, return the API response
|
88 |
if response.status_code == 200:
|
@@ -91,130 +241,192 @@ def verify_captcha(captcha_id, captcha_response):
|
|
91 |
# Otherwise, return an error message
|
92 |
return {"error": "Failed to verify captcha"}
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
def train_model(model_inputs):
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
with
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
st.session_state["s3_face_file_path"] = zip_and_upload_images(
|
114 |
-
identifier, uploaded_files, "face"
|
115 |
-
)
|
116 |
-
st.success(f'Uploading {len(uploaded_files)} files done!')
|
117 |
-
|
118 |
-
preset_theme_images = st.empty()
|
119 |
-
with preset_theme_images.form("choose-preset-theme"):
|
120 |
-
img = image_select(
|
121 |
-
"Choose a Theme!",
|
122 |
-
images=[
|
123 |
-
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
|
124 |
-
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
|
125 |
-
"https://ichef.bbci.co.uk/images/ic/640x360/p09t1hg0.jpg",
|
126 |
-
],
|
127 |
-
captions=["Game of Thrones", "Iron Man", "Thor"],
|
128 |
-
return_value="index",
|
129 |
-
)
|
130 |
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
with col2:
|
152 |
-
submitted_4 = st.form_submit_button(
|
153 |
-
"If none of the themes interest you, click here!"
|
154 |
-
)
|
155 |
-
if submitted_4:
|
156 |
-
st.session_state["view"] = True
|
157 |
-
|
158 |
-
if st.session_state["view"]:
|
159 |
-
custom_theme_images = st.empty()
|
160 |
-
with custom_theme_images.form("input_custom_themes"):
|
161 |
-
st.markdown("If none of the themes interest you, please input your own!")
|
162 |
-
uploaded_files_2 = st.file_uploader(
|
163 |
-
"Choose image files",
|
164 |
-
accept_multiple_files=True,
|
165 |
-
type=["png", "jpg", "jpeg"],
|
166 |
)
|
167 |
-
|
168 |
-
|
169 |
-
if submitted_3:
|
170 |
with st.spinner('Uploading...'):
|
171 |
-
st.session_state["
|
172 |
-
identifier,
|
173 |
)
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
st.
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
if submitted:
|
209 |
-
|
210 |
-
|
211 |
-
st.session_state['captcha'] = {}
|
212 |
-
with st.spinner("Model Fine Tuning..."):
|
213 |
-
st.session_state["model_inputs"]["identifier"] = st.session_state["key"]
|
214 |
-
st.session_state["model_inputs"]["email"] = email
|
215 |
-
train_model(st.session_state["model_inputs"])
|
216 |
-
st.session_state["train_view"] = True
|
217 |
else:
|
218 |
-
st.
|
219 |
-
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import base64
|
2 |
+
import boto3
|
3 |
+
from botocore.config import Config
|
4 |
+
from dotenv import load_dotenv
|
5 |
import os
|
6 |
import shutil
|
7 |
+
from typing import List, Tuple
|
8 |
import uuid
|
9 |
import zipfile
|
10 |
+
import argparse
|
11 |
+
import logging
|
12 |
+
import sendgrid
|
13 |
+
from sendgrid.helpers.mail import Mail, Email, To, Content
|
14 |
+
|
15 |
from glob import glob
|
16 |
from io import BytesIO
|
17 |
from itertools import cycle
|
|
|
22 |
from PIL import Image
|
23 |
from st_btn_select import st_btn_select
|
24 |
from streamlit_image_select import image_select
|
25 |
+
import smart_open
|
26 |
+
|
27 |
+
logging.basicConfig()
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
logger.setLevel(logging.INFO)
|
30 |
+
|
31 |
+
# Looks for .env file in current directory to pull environment variables. Should
|
32 |
+
# not overwrite already set environment variables. Used for S3 credentials.
|
33 |
+
load_dotenv()
|
34 |
+
|
35 |
+
_S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
|
36 |
+
|
37 |
+
# Command-line arguments to control some stuff for easier local testing.
|
38 |
+
# Eventually may want to move everything into functions and have a
|
39 |
+
# if __name__ == "main" setup instead of everything inline.
|
40 |
+
parser = argparse.ArgumentParser()
|
41 |
+
parser.add_argument(
|
42 |
+
"--dry-run", action="store_true",
|
43 |
+
help="Skip sending train request to backend server.",
|
44 |
+
)
|
45 |
+
parser.add_argument(
|
46 |
+
"--train-endpoint-url", default=None,
|
47 |
+
help="URL of backend server to send train request to. If None, use hardcoded banana setup.",
|
48 |
+
)
|
49 |
+
cli_args = parser.parse_args()
|
50 |
|
51 |
if "key" not in st.session_state:
|
52 |
st.session_state["key"] = uuid.uuid4().hex
|
|
|
76 |
if "captcha" not in st.session_state:
|
77 |
st.session_state["captcha"] = {}
|
78 |
|
79 |
+
if "login" not in st.session_state:
|
80 |
+
st.session_state["login"] = None
|
81 |
+
|
82 |
+
if "user_auth_sess" not in st.session_state:
|
83 |
+
st.session_state["user_auth_sess"] = False
|
84 |
+
|
85 |
+
if "user_email" not in st.session_state:
|
86 |
+
st.session_state["email_provided"] = True
|
87 |
+
|
88 |
def callback():
|
89 |
st.session_state["button_clicked"] = True
|
90 |
|
|
|
91 |
|
92 |
+
def bucket_parts(s3_path: str) -> Tuple[str, str]:
|
93 |
+
"""Split an S3 path into bucket and key.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
s3_path: path starting with "s3:"
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
Tuple of bucket and key for the path
|
100 |
+
"""
|
101 |
+
parts = s3_path.split("/")
|
102 |
+
bucket = parts[2]
|
103 |
+
key = "/".join(parts[3:])
|
104 |
+
return bucket, key
|
105 |
+
|
106 |
+
|
107 |
+
def generate_s3_get_url(s3_path: str, expiration_seconds: int) -> str:
|
108 |
+
"""Generate a presigned S3 url to read from an S3 path.
|
109 |
+
|
110 |
+
A presigned url allows anyone accessing that url to read the s3 path without
|
111 |
+
needing s3 credentials until the url expires.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
s3_path: path starting with "s3:"
|
115 |
+
expiration_seconds: how long the url will be valid (does not influence
|
116 |
+
lifetime of the underlying s3 object, only the presigned url)
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
The presigned url
|
120 |
+
"""
|
121 |
+
bucket, key = bucket_parts(s3_path)
|
122 |
+
|
123 |
+
s3_client = boto3.client("s3", config=Config(signature_version="s3v4", s3={"addressing_style": "path"}))
|
124 |
+
download_url = s3_client.generate_presigned_url(
|
125 |
+
"get_object",
|
126 |
+
Params={
|
127 |
+
"Bucket": bucket,
|
128 |
+
"Key": key
|
129 |
+
},
|
130 |
+
ExpiresIn=expiration_seconds
|
131 |
+
)
|
132 |
+
return download_url
|
133 |
+
|
134 |
+
|
135 |
+
def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
|
136 |
+
"""Generate a presigned S3 url to write to an S3 path.
|
137 |
+
|
138 |
+
A presigned url allows anyone accessing that url to write to the s3 path
|
139 |
+
without needing s3 credentials until the url expires.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
s3_path: path starting with "s3:"
|
143 |
+
expiration_seconds: how long the url will be valid (does not influence
|
144 |
+
lifetime of the underlying s3 object, only the presigned url)
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
The presigned url
|
148 |
+
"""
|
149 |
+
bucket, key = bucket_parts(s3_path)
|
150 |
+
|
151 |
+
s3_client = boto3.client("s3", config=Config(signature_version="s3v4", s3={"addressing_style": "path"}))
|
152 |
+
upload_url = s3_client.generate_presigned_url(
|
153 |
+
"put_object",
|
154 |
+
Params={
|
155 |
+
"Bucket": bucket,
|
156 |
+
"Key": key
|
157 |
+
},
|
158 |
+
ExpiresIn=expiration_seconds
|
159 |
+
)
|
160 |
+
return upload_url
|
161 |
+
|
162 |
+
|
163 |
+
def zip_and_upload_images(identifier: str, uploaded_files: List[str], image_type: str) -> str:
|
164 |
+
"""Save images as zip file to s3 for use in backend.
|
165 |
+
|
166 |
+
Blocks until images are processed, added to zip file, and uploaded to S3.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
identifier: unique identifier for the run, used in s3 link
|
170 |
+
uploaded_files: list of file names
|
171 |
+
image_type: string to identify different batches of images used in the
|
172 |
+
backend model/training. Currently used values: "face", "theme"
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
S3 location of zip file containing png images.
|
176 |
+
"""
|
177 |
if not os.path.exists(identifier):
|
178 |
os.makedirs(identifier)
|
179 |
+
|
180 |
+
logger.info("Processing uploaded images")
|
181 |
for num, uploaded_file in enumerate(uploaded_files):
|
182 |
file_ = Image.open(uploaded_file).convert("RGB")
|
183 |
file_.save(f"{identifier}/{num}_test.png")
|
184 |
+
local_zip_filestem = f"{identifier}_{image_type}_images"
|
185 |
+
logger.info("Making zip archive")
|
186 |
+
shutil.make_archive(local_zip_filestem, "zip", identifier)
|
187 |
+
local_zip_filename = f"{local_zip_filestem}.zip"
|
188 |
+
|
189 |
+
logger.info("Uploading zip file to s3")
|
190 |
+
# TODO: can we define expiration when making the s3 path?
|
191 |
+
# Probably if we use the boto3 library instead of smart open
|
192 |
+
s3_path = _S3_PATH_OUTPUT.format(identifier=identifier, image_type=image_type)
|
193 |
+
|
194 |
+
with open(local_zip_filename, "rb") as fin:
|
195 |
+
with smart_open.open(s3_path, "wb") as fout:
|
196 |
+
fout.write(fin.read())
|
197 |
+
logger.info(f"Completed upload to {s3_path}")
|
198 |
+
|
199 |
+
return s3_path
|
200 |
+
|
201 |
+
def send_email(to_email, user_code):
|
202 |
+
sg = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
|
203 |
+
from_email = Email("[email protected]")
|
204 |
+
to_email = To(to_email)
|
205 |
+
subject = "One Time Code"
|
206 |
+
content = Content("text/plain", f"Here is your one-time code: {user_code}")
|
207 |
+
mail = Mail(from_email, to_email, subject, content)
|
208 |
+
mail_json = mail.get()
|
209 |
+
response = sg.client.mail.send.post(request_body=mail_json)
|
210 |
+
|
211 |
|
212 |
CAPTCHA_ENDPOINT = "https://captcha-api.akshit.me/v2/generate"
|
213 |
VERIFY_ENDPOINT = "https://captcha-api.akshit.me/v2/verify"
|
|
|
220 |
# If the request was successful, return the API response
|
221 |
if response.status_code == 200:
|
222 |
return response.json()
|
223 |
+
else:
|
224 |
+
logger.warn(f"Error from generate captcha request: {response.json()}")
|
225 |
+
# Otherwise, return an error message
|
226 |
+
return {"error": "Failed to generate captcha"}
|
227 |
|
228 |
# Create a function to verify the captcha
|
229 |
def verify_captcha(captcha_id, captcha_response):
|
230 |
# Make a POST request to the API endpoint with the captcha ID and response
|
231 |
+
verify_json = {"uuid": captcha_id, "captcha": captcha_response}
|
232 |
response = requests.post(
|
233 |
+
VERIFY_ENDPOINT, json=verify_json,
|
234 |
)
|
235 |
+
logger.info(f"Response from captcha verify: {response}")
|
236 |
|
237 |
# If the request was successful, return the API response
|
238 |
if response.status_code == 200:
|
|
|
241 |
# Otherwise, return an error message
|
242 |
return {"error": "Failed to verify captcha"}
|
243 |
|
|
|
|
|
|
|
244 |
def train_model(model_inputs):
|
245 |
+
if cli_args.dry_run:
|
246 |
+
logger.info("Skipping model training since --dry-run is enabled.")
|
247 |
+
logger.info(f"model_inputs: {model_inputs}")
|
248 |
+
return
|
249 |
|
250 |
+
if cli_args.train_endpoint_url is None:
|
251 |
+
# Use banana backend
|
252 |
+
api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
|
253 |
+
model_key = "1a3b4ce5-164f-4efb-9f4a-c2ad3a930d0b"
|
254 |
+
st.markdown(str(model_inputs))
|
255 |
+
_ = banana.run(api_key, model_key, model_inputs)
|
256 |
+
else:
|
257 |
+
# Send request directly to specified url
|
258 |
+
_ = requests.post(cli_args.train_endpoint_url, json=model_inputs)
|
259 |
|
260 |
+
if st.session_state["email_provided"]:
|
261 |
+
user_email_input = st.empty()
|
262 |
+
with user_email_input.form(key='user_auth'):
|
263 |
+
text_input = st.text_input(label='Please Enter Your Email')
|
264 |
+
submit_button = st.form_submit_button(label='Submit')
|
265 |
+
if submit_button:
|
266 |
+
st.session_state["user_auth_sess"] = True
|
267 |
+
st.session_state["email_provided"] = False
|
268 |
+
send_email(text_input, str(st.session_state["key"]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
+
if st.session_state["user_auth_sess"]:
|
271 |
+
user_auth = st.empty()
|
272 |
+
user_email_input.empty()
|
273 |
+
with user_auth.form("one-code"):
|
274 |
+
text_input = st.text_input(label='Please Input One Time Code')
|
275 |
+
submit_button = st.form_submit_button(label='Submit')
|
276 |
+
if submit_button:
|
277 |
+
if text_input == st.session_state["key"]:
|
278 |
+
st.session_state["login"] = True
|
279 |
+
else:
|
280 |
+
st.markdown("Please Enter Correct Code!")
|
281 |
+
|
282 |
+
if st.session_state["login"]:
|
283 |
+
identifier = st.session_state["key"]
|
284 |
+
user_auth.empty()
|
285 |
+
user_email_input.empty()
|
286 |
+
face_images = st.empty()
|
287 |
+
with face_images.form("my_form"):
|
288 |
+
uploaded_files = st.file_uploader(
|
289 |
+
"Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
)
|
291 |
+
submitted = st.form_submit_button(f"Upload")
|
292 |
+
if submitted:
|
|
|
293 |
with st.spinner('Uploading...'):
|
294 |
+
st.session_state["s3_face_file_path"] = zip_and_upload_images(
|
295 |
+
identifier, uploaded_files, "face"
|
296 |
)
|
297 |
+
st.success(f'Uploading {len(uploaded_files)} files done!')
|
298 |
+
|
299 |
+
preset_theme_images = st.empty()
|
300 |
+
with preset_theme_images.form("choose-preset-theme"):
|
301 |
+
img = image_select(
|
302 |
+
"Choose a Theme!",
|
303 |
+
images=[
|
304 |
+
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
|
305 |
+
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
|
306 |
+
"https://ichef.bbci.co.uk/images/ic/640x360/p09t1hg0.jpg",
|
307 |
+
],
|
308 |
+
captions=["Game of Thrones", "Iron Man", "Thor"],
|
309 |
+
return_value="index",
|
310 |
+
)
|
311 |
+
|
312 |
+
col1, col2 = st.columns([0.17, 1])
|
313 |
+
with col1:
|
314 |
+
submitted_3 = st.form_submit_button("Submit!")
|
315 |
+
if submitted_3:
|
316 |
+
with st.spinner():
|
317 |
+
dictionary = {
|
318 |
+
0: [
|
319 |
+
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/game-of-thrones.zip",
|
320 |
+
"game-of-thrones",
|
321 |
+
],
|
322 |
+
1: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/iron-man.zip", "iron-man"],
|
323 |
+
2: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/thor.zip", "thor"],
|
324 |
+
}
|
325 |
+
st.session_state["model_inputs"] = {
|
326 |
+
"superhero_file_path": dictionary[img][0],
|
327 |
+
# Use presigned url since backend does not have credentials
|
328 |
+
"person_file_path": generate_s3_get_url(st.session_state["s3_face_file_path"], expiration_seconds=3600),
|
329 |
+
"superhero_prompt": dictionary[img][1],
|
330 |
+
"num_images": 50,
|
331 |
+
}
|
332 |
+
st.success("Success!")
|
333 |
+
with col2:
|
334 |
+
submitted_4 = st.form_submit_button(
|
335 |
+
"If none of the themes interest you, click here!"
|
336 |
+
)
|
337 |
+
if submitted_4:
|
338 |
+
st.session_state["view"] = True
|
339 |
+
|
340 |
+
if st.session_state["view"]:
|
341 |
+
custom_theme_images = st.empty()
|
342 |
+
with custom_theme_images.form("input_custom_themes"):
|
343 |
+
st.markdown("If none of the themes interest you, please input your own!")
|
344 |
+
uploaded_files_2 = st.file_uploader(
|
345 |
+
"Choose image files",
|
346 |
+
accept_multiple_files=True,
|
347 |
+
type=["png", "jpg", "jpeg"],
|
348 |
+
)
|
349 |
+
title = st.text_input("Theme Name")
|
350 |
+
submitted_3 = st.form_submit_button("Submit!")
|
351 |
+
if submitted_3:
|
352 |
+
with st.spinner('Uploading...'):
|
353 |
+
st.session_state["s3_theme_file_path"] = zip_and_upload_images(
|
354 |
+
identifier, uploaded_files_2, "theme"
|
355 |
+
)
|
356 |
+
st.session_state["model_inputs"] = {
|
357 |
+
# Use presigned urls since backend does not have credentials
|
358 |
+
"superhero_file_path": generate_s3_get_url(st.session_state["s3_theme_file_path"], expiration_seconds=3600),
|
359 |
+
"person_file_path": generate_s3_get_url(st.session_state["s3_face_file_path"], expiration_seconds=3600),
|
360 |
+
"superhero_prompt": title,
|
361 |
+
"num_images": 50,
|
362 |
+
}
|
363 |
+
st.success('Done!')
|
364 |
+
|
365 |
+
train = st.empty()
|
366 |
+
with train.form("training"):
|
367 |
+
col1, col3, col2 = st.columns(3)
|
368 |
+
with col1:
|
369 |
+
email = st.text_input("Enter Email")
|
370 |
+
with col2:
|
371 |
+
submitted = st.form_submit_button("Train Model!")
|
372 |
if submitted:
|
373 |
+
if not email:
|
374 |
+
st.markdown('Please input an email!')
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
else:
|
376 |
+
st.session_state["captcha_bool"] = True
|
377 |
+
|
378 |
+
if st.session_state["captcha_bool"]:
|
379 |
+
captcha_form = st.empty()
|
380 |
+
with captcha_form.form("captcha_form", clear_on_submit=True):
|
381 |
+
# Create container to create image/text input out of order from the
|
382 |
+
# format submit button. Needed since we need to know the status of the
|
383 |
+
# form submit to know what the captcha should do.
|
384 |
+
captcha_container = st.container()
|
385 |
+
display_captcha = True
|
386 |
+
# TODO: Submit button renders first, then drops down once the image is
|
387 |
+
# fetched leading to page reflow. Would be nice to not have reflow, but
|
388 |
+
# we need to know if the submit button was previously pressed and if the
|
389 |
+
# captcha was solved to generate and display a new captcha or not.
|
390 |
+
# Possible solution is use an on_click callback to set a session_state
|
391 |
+
# variable to access whether the button was pushed or not instead of the
|
392 |
+
# return value here.
|
393 |
+
submitted = st.form_submit_button("Submit Captcha!")
|
394 |
+
|
395 |
+
if submitted:
|
396 |
+
result = verify_captcha(st.session_state['captcha']['uuid'], st.session_state["captcha_response"])
|
397 |
+
del st.session_state["captcha_response"]
|
398 |
+
if 'message' in result and result['message'] == 'CAPTCHA_SOLVED':
|
399 |
+
st.session_state['captcha'] = {}
|
400 |
+
display_captcha = False
|
401 |
+
with st.spinner("Model Fine Tuning..."):
|
402 |
+
st.session_state["model_inputs"]["identifier"] = st.session_state["key"]
|
403 |
+
st.session_state["model_inputs"]["email"] = email
|
404 |
+
s3_output_path = _S3_PATH_OUTPUT.format(identifier=st.session_state["key"], image_type="generated")
|
405 |
+
# The backend does not have s3 credentials, so generate
|
406 |
+
# presigned urls for the backend to use to write and read
|
407 |
+
# the generated images.
|
408 |
+
st.session_state["model_inputs"]["output_s3_url_get"] = generate_s3_get_url(
|
409 |
+
s3_output_path, expiration_seconds=60 * 60 * 24,
|
410 |
+
)
|
411 |
+
st.session_state["model_inputs"]["output_s3_url_put"] = generate_s3_put_url(
|
412 |
+
s3_output_path, expiration_seconds=3600,
|
413 |
+
)
|
414 |
+
train_model(st.session_state["model_inputs"])
|
415 |
+
st.session_state["train_view"] = True
|
416 |
+
else:
|
417 |
+
st.error(result['error'])
|
418 |
+
|
419 |
+
if display_captcha:
|
420 |
+
# Generate new captcha and display. Occurs on first load with the
|
421 |
+
# captcha_bool=True, or after previously failed captcha attempts.
|
422 |
+
result = generate_captcha()
|
423 |
+
captcha_id = result['uuid']
|
424 |
+
captcha_image = result['captcha']
|
425 |
+
|
426 |
+
st.session_state['captcha']['uuid'] = captcha_id
|
427 |
+
st.session_state['captcha']['captcha'] = captcha_image
|
428 |
+
|
429 |
+
captcha_container.image(captcha_image, width=300)
|
430 |
+
|
431 |
+
captcha_container.text_input("Enter the captcha response", key="captcha_response")
|
432 |
+
# Submit button already setup previously.
|