File size: 3,165 Bytes
073b060
f1eb272
 
 
 
073b060
f1eb272
 
 
 
073b060
cce7334
073b060
f1eb272
cce7334
f1eb272
 
 
 
cce7334
f1eb272
 
 
 
 
 
 
 
 
 
 
 
 
073b060
 
 
 
 
 
f1eb272
 
 
073b060
 
 
f1eb272
 
 
 
 
 
 
 
 
073b060
f1eb272
 
 
 
 
 
 
 
 
 
 
 
073b060
 
 
 
 
 
 
 
f1eb272
 
 
 
 
 
 
 
073b060
 
f1eb272
073b060
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

import torch
import numpy as np
from openai import OpenAI
import os

# Module-level OpenAI client; reads the API key from the environment
# (OPENAI_API_KEY) — presumably why `os` is imported, though it is not
# used directly in this file. TODO(review): confirm `os` is needed.
client = OpenAI()

import streamlit as st
from PIL import Image
from diffusers import AutoPipelineForText2Image
import random
@st.cache_data(ttl=600)
def get_prompt_to_guess():
    """Ask GPT-3.5 for a short French image prompt about a random topic.

    Cached by Streamlit for 600 s, so all reruns (and all players hitting
    the same process) share one secret prompt until the cache expires.
    """
    topics = ["tree", "cat", "dog", "consultant", "artificial intelligence", "beauty", "immeuble", "plage", "cyborg", "futuristic"]
    topic = random.choice(topics)
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant to generate one simple prompt in order to generate an image. Your given prompt won't go over 10 words. You only return the prompt. You will also answer in french."},
            {"role": "user", "content": f"Donne moi un prompt pour generer une image de {topic}"},
        ],
    )
    return completion.choices[0].message.content

@st.cache_resource
def get_model():
    """Load the SDXL-Turbo text-to-image pipeline once per process.

    The fp16 weight variant is downloaded but cast to float32
    (`torch_dtype`), matching the CPU-only deployment described in the
    page's disclosure text.
    """
    return AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo", torch_dtype=torch.float32, variant="fp16"
    )

@st.cache_data
def generate_image(_pipe, prompt):
    """Generate one image for `prompt` with the SDXL-Turbo pipeline.

    Parameters
    ----------
    _pipe : diffusers pipeline (leading underscore tells st.cache_data
        not to hash it; caching keys on `prompt` only).
    prompt : str — the text to render.

    Returns
    -------
    PIL.Image.Image — the single generated image.
    """
    # BUG FIX: diffusers pipelines have no `seed` parameter — the original
    # `seed=1` kwarg was silently swallowed, so generation was never
    # actually seeded. Reproducible sampling uses a torch.Generator.
    generator = torch.Generator().manual_seed(1)
    return _pipe(
        prompt=prompt,
        num_inference_steps=1,  # SDXL-Turbo is designed for 1-step inference
        guidance_scale=0.0,     # turbo models run without CFG
        generator=generator,
    ).images[0]

if "ask_answer" not in st.session_state:
    st.session_state["ask_answer"] = False

if "testing" not in st.session_state:
    st.session_state["testing"] = False

if "submit_guess" not in st.session_state:
    st.session_state["submit_guess"] = False

if "real_ask_answer" not in st.session_state:
    st.session_state["real_ask_answer"] = False

def check_prompt(prompt, prompt_to_guess):
    """Return True when the player's guess equals the secret prompt.

    Surrounding whitespace is ignored on both sides; the comparison is
    otherwise exact (case- and accent-sensitive), per the game rules
    ("no fault").
    """
    guess = prompt.strip()
    target = prompt_to_guess.strip()
    return guess == target

# --- page setup: load the model, pick the secret prompt, render target ---
pipe = get_model()
prompt = get_prompt_to_guess()
im_to_guess = generate_image(pipe, prompt)
# BUG FIX: PIL's Image.size is (width, height) — the original unpacked it
# as `h, w`, transposing the np.zeros([h, w, 3]) placeholder built later
# (harmless only while the generated image happens to be square).
w, h = im_to_guess.size

st.title("Guess the prompt by Ekimetrics")
st.text("Rules : guess the prompt (in French, with no fault) to generate the left image with the sdxl turbo model")
st.text("Hint : use right side to help you guess the prompt by testing some")
st.text("Disclosure : this runs on CPU so generation are quite slow (even with sdxl turbo)")
# Two equal-width columns: left = guessing game, right = prompt sandbox.
col_1, col_2 = st.columns([0.5, 0.5])
# Left column: the actual game — submit a guess, compare against the
# secret prompt, optionally reveal the answer. Statement order here is
# the widget render order, so it must not be rearranged.
with col_1:
    st.header("GUESS THE PROMPT")
    guessed_prompt = st.text_area("Input your guess prompt")
    # st.button returns True only on the rerun triggered by its click;
    # mirroring it into session_state keeps the last click observable.
    st.session_state["submit_guess"] = st.button("guess the prompt")
    if st.session_state["submit_guess"]:
        if check_prompt(guessed_prompt, prompt):
            # NOTE(review): the message promises a new prompt in 24h, but
            # get_prompt_to_guess is cached with ttl=600 (10 min) — confirm
            # which duration is intended.
            st.text("Good prompt ! test again in 24h !")
        else:
            st.text("wrong prompt !")
    st.session_state["ask_answer"] = st.button("get the answer")
    if st.session_state["ask_answer"]:
        st.text(f"Cheater ! but here is the prompt : \n {prompt}")
    # The image the player must reproduce, rendered below the controls.
    st.image(im_to_guess)
    
        
if "testing" not in st.session_state:
    st.session_state["testing"] = False

with col_2:
    st.header("TEST THE PROMPT")
    if st.session_state["testing"]:
        im = generate_image(pipe, testing_prompt)
        st.session_state["testing"] = False
    else:
        im = np.zeros([h,w,3])
    testing_prompt = st.text_area("Input your testing prompt")
    st.session_state["testing"] = st.button("test the prompt")
    st.image(im)