|
from datasets import load_dataset |
|
from collections import Counter, defaultdict |
|
from random import sample, shuffle |
|
from collections import Counter |
|
import datasets |
|
from pandas import DataFrame |
|
from huggingface_hub import list_datasets |
|
import os |
|
import gradio as gr |
|
|
|
import secrets |
|
|
|
|
|
# Kept for backward compatibility; not referenced anywhere in the visible part
# of this module.
parti_prompt_results = []

# Hub organisation hosting one pre-generated image dataset per model.
ORG = "diffusers-parti-prompts"

# Short model key -> dataset name inside ORG.
_SUBMISSION_REPOS = {
    "kand2": "kandinsky-2-2",
    "sdxl": "sdxl-1.0-refiner",
    "wuerst": "wuerstchen",
    "karlo": "karlo-v1",
}

# Short model key -> "train" split of that model's dataset.
# Hub repo ids are always "<namespace>/<name>"; they are not filesystem paths,
# so they are joined with "/" explicitly instead of os.path.join (which would
# produce a backslash on Windows).
SUBMISSIONS = {
    key: load_dataset(f"{ORG}/{name}")["train"]
    for key, name in _SUBMISSION_REPOS.items()
}

# Short model key -> model page on the Hub.
LINKS = {
    "kand2": "https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder",
    "sdxl": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
    "wuerst": "https://huggingface.co/warp-ai/wuerstchen",
    "karlo": "https://huggingface.co/kakaobrain/karlo-v1-alpha",
}
|
# Markdown "result cards", one per model, shown once the questionnaire is done.
# The "\n" escapes add an extra newline so each sentence becomes its own
# Markdown paragraph. Emoji are written as escape sequences so they survive
# any re-encoding of this file.
# NOTE(review): the glyph after "innovative one" in WUERSTCHEN looks like a
# mis-decoded emoji; its original code point cannot be recovered from this
# file, so it is left untouched - please restore it from the upstream source.

# Heading emoji: U+1F3A8 (artist palette).
KANDINSKY = """
## The creative one \U0001F3A8!
![img](https://aeiljuispo.cloudimg.io/v7/https://cdn-uploads.huggingface.co/production/uploads/5dfcb1aada6d0311fd3d5448/rETvCyoUD5Mr9wm6OxUhe.png?w=200&h=200&f=face)
\n You mostly resonate with **Kandinsky 2.2** released by AI Forever.
\n Kandinsky 2.2 has a similar architecture to DALLE-2 and works extremely well for artistic, colorful generations.
\n Check out your soulmate [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder).
"""

# Heading emoji: U+26A1 (high voltage).
SDXL_RESULT = """
## The powerful one \u26A1!
![img](https://huggingface.co/datasets/OpenGenAI/logos/resolve/main/7vmYr2XwVcPtkLzac_jxQ.png)
\n You mostly resonate with **Stable Diffusion XL** released by Stability AI.
\n Stable Diffusion XL consists of two diffusion models that are chained together, a base model and a refiner model. Together, the system contains roughly 5 billion parameters.
\n It's the latest open-source release of Stable Diffusion and allows to render stunning images of much larger sizes than Stable Diffusion v1.
Try it out [here](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
"""

WUERSTCHEN = """
## The innovative one βοΈ !
![img](https://www.gravatar.com/avatar/3219846609129e84790fb83793998d61?d=retro&size=100)
\n You mostly resonate with **Wuerstchen** released by the WARP team.
\n Wuerstchen is a three stage diffusion model that proposed a very novel, innovative model architecture.
\n Wuerstchen is able to generate very large images (up to 1024x2048) in just a few seconds.
\n The model has an amazing image quality vs. speed trade-off.
\n Check out your new best friend [here](https://huggingface.co/warp-ai/wuerstchen).
"""

# Heading emoji: U+1F3AF (direct hit).
KARLO = """
## The precise one \U0001F3AF!
![img](https://huggingface.co/datasets/OpenGenAI/logos/resolve/main/1670220967262-615ed619c807b26d117a49bd.png)
\n You mostly resonate with **Karlo** released by KakaoBrain.
\n Karlo is based on the same architecture as DALLE-2 and has been trained on the [well curated COYO dataset](https://huggingface.co/datasets/kakaobrain/coyo-700m).
\n Play around with it [here](https://huggingface.co/kakaobrain/karlo-v1-alpha).
"""
|
|
|
# Number of prompts a user answers in one questionnaire.
NUM_QUESTIONS = 10

# Instruction text appended to the "<current>/<total>" counter above each prompt.
PROMPT_FORMAT = (
    " Select the image that best matches the prompt and click on 'Submit'."
    " Remember that if multiple images match the prompt equally well, select them all."
    " If no image matches the prompt, no image shall be selected."
)

# Short model key -> Markdown card shown when that model collects the most votes.
RESULT = dict(
    kand2=KANDINSKY,
    wuerst=WUERSTCHEN,
    sdxl=SDXL_RESULT,
    karlo=KARLO,
)

# Model keys in SUBMISSIONS order; the "choice_<n>" columns store indexes into
# this list.
submission_names = [*SUBMISSIONS]

# Hub namespace that receives the answers; its name encodes the compared models.
MODEL_KEYS = "-".join(submission_names)
SUBMISSION_ORG = "result-" + MODEL_KEYS

# Row count of the first model's dataset, used as the range of image ids.
num_images = len(SUBMISSIONS[submission_names[0]])
|
|
|
|
|
def load_submissions():
    """Count how many times each image id appears in the stored submissions.

    Lists every dataset published under ``SUBMISSION_ORG``, loads the "train"
    split of each one and tallies the values of its "id" column.

    Returns:
        collections.Counter: image id -> number of occurrences.
    """
    # Collect all repo ids first, then download the datasets one by one.
    repo_ids = [info.id for info in list_datasets(author=SUBMISSION_ORG)]

    collected = []
    for repo_id in repo_ids:
        train_split = load_dataset(repo_id)["train"]
        collected.extend(train_split["id"])

    return Counter(collected)


# Tally computed once when the module is loaded; start() reads it to favour
# images with few recorded answers.
SUBMITTED_IDS = load_submissions()
|
|
|
|
|
def generate_random_hash(length=8):
    """Generate a random hexadecimal string of the requested length.

    Args:
        length (int): Number of hex characters to produce. Must be even,
            because every random byte is rendered as two hex characters.

    Returns:
        str: A random lowercase hex string with exactly ``length`` characters.

    Raises:
        ValueError: If ``length`` is odd.
    """
    if length % 2 != 0:
        raise ValueError("Length should be an even number.")

    # token_hex draws its own random bytes; the separate token_bytes() call the
    # previous version made was never used and only wasted entropy.
    return secrets.token_hex(length // 2)
|
|
|
|
|
def refresh(row_number, dataframe):
    """Regenerate the question table once a questionnaire has been completed.

    Args:
        row_number: Index of the current question.
        dataframe: The current question table.

    Returns:
        A freshly sampled table from ``start()`` when ``row_number`` equals
        ``NUM_QUESTIONS``; otherwise ``dataframe`` unchanged.
    """
    if row_number != NUM_QUESTIONS:
        return dataframe

    # start() picks up the vote counts from the module-level SUBMITTED_IDS, so
    # the refreshed tally is published there. The previous version passed it as
    # a positional argument to start(), which is defined without parameters and
    # therefore raised TypeError.
    global SUBMITTED_IDS
    SUBMITTED_IDS = load_submissions()
    return start()
|
|
|
def start(submitted_ids=None):
    """Sample a new questionnaire, favouring images with the fewest answers.

    Every image id in ``range(num_images)`` starts with a count of zero, which
    is then overridden by the recorded counts. Ids are grouped by count, each
    group is shuffled, and the groups are concatenated in ascending count
    order. ``NUM_QUESTIONS`` ids are then drawn at random from the first
    ``10 * NUM_QUESTIONS`` entries of that list.

    Args:
        submitted_ids: Optional mapping of image id -> number of recorded
            answers. When omitted, the module-level ``SUBMITTED_IDS`` is used.

    Returns:
        pandas.DataFrame: one row per question with the columns "prompt",
        "result" (initially empty), "id", "Challenge", "Category", "Note" and
        one "choice_<n>" column per model. "choice_<n>" holds the index into
        ``submission_names`` of the model displayed in slot ``n``.
    """
    recorded = SUBMITTED_IDS if submitted_ids is None else submitted_ids

    votes = dict.fromkeys(range(num_images), 0)
    votes.update(recorded)

    # Bucket the ids by how often they have been answered so far.
    ids_by_votes = defaultdict(list)
    for image_id, count in votes.items():
        ids_by_votes[count].append(image_id)

    # Least-answered buckets first; shuffle inside each bucket so that ties
    # are broken randomly.
    shuffled_ids = []
    for count in sorted(ids_by_votes):
        bucket = ids_by_votes[count]
        shuffle(bucket)
        shuffled_ids.extend(bucket)

    id_candidates = shuffled_ids[: 10 * NUM_QUESTIONS]
    image_ids = sample(id_candidates, k=NUM_QUESTIONS)

    # Prompt metadata is read from the first model's dataset.
    reference = SUBMISSIONS[submission_names[0]]
    images = {}
    for question, image_id in enumerate(image_ids):
        # Random display order of the models for this question.
        order = list(range(len(SUBMISSIONS)))
        shuffle(order)

        row = reference[image_id]
        entry = {
            "prompt": row["Prompt"],
            "result": "",
            "id": image_id,
            "Challenge": row["Challenge"],
            "Category": row["Category"],
            "Note": row["Note"],
        }
        for slot, model_index in enumerate(order):
            entry[f"choice_{slot}"] = model_index
        images[question] = entry

    return DataFrame.from_dict(images, orient="index")
|
|
|
|
|
def process(dataframe, row_number=0):
    """Build the widget values for the question stored at ``row_number``.

    Args:
        dataframe: Question table; the row's "id" and "choice_<n>" columns
            are read.
        row_number: Index of the question to display (defaults to the first).

    Returns:
        tuple: one image per model, then one ``False`` per image (checkbox
        reset), then the prompt as a Markdown heading and the counter text.
        When ``row_number`` equals ``NUM_QUESTIONS`` the images are ``None``
        and both texts are empty strings.
    """
    if row_number == NUM_QUESTIONS:
        # Questionnaire finished: blank out every image, checkbox and text.
        blank_images = [None] * len(RESULT)
        unchecked = [False] * len(RESULT)
        return (*blank_images, *unchecked, "", "")

    current = dataframe.iloc[row_number]
    image_id = current["id"]

    # Resolve each display slot to the model key shown in it.
    choices = []
    for slot in range(len(SUBMISSIONS)):
        choices.append(submission_names[current[f"choice_{slot}"]])

    images = []
    for model_key in choices:
        images.append(SUBMISSIONS[model_key][int(image_id)]["images"])

    prompt_text = SUBMISSIONS[choices[0]][int(image_id)]["Prompt"]
    prompt = f'# "{prompt_text}"'
    counter = f"***{row_number + 1}/{NUM_QUESTIONS} {PROMPT_FORMAT}***"
    image_buttons = [False] * len(images)

    return (*images, *image_buttons, prompt, counter)
|
|
|
|
|
def write_result(user_choice, row_number, dataframe):
    """Store the models picked for the current question and advance the index.

    Args:
        user_choice: Selection flags, one digit per display slot (as produced
            by ``integerize``); a non-zero digit means the slot was ticked.
        row_number: Index of the question being answered.
        dataframe: Question table; its "result" column is updated in place.

    Returns:
        tuple: ``(row_number + 1, dataframe)``; both are returned unchanged
        when ``row_number`` already equals ``NUM_QUESTIONS``.
    """
    if row_number == NUM_QUESTIONS:
        return row_number, dataframe

    # Positions of the ticked checkboxes.
    ticked_slots = [
        slot for slot, flag in enumerate(str(user_choice)) if int(flag)
    ]

    # Translate every ticked slot to the model key displayed in it.
    chosen_models = [
        submission_names[dataframe.iloc[row_number][f"choice_{slot}"]]
        for slot in ticked_slots
    ]

    print(chosen_models)
    dataframe.loc[row_number, "result"] = ",".join(chosen_models)
    return row_number + 1, dataframe
|
|
|
|
|
def get_index(evt: gr.SelectData) -> int:
    """Return the ``index`` attribute carried by a Gradio select event."""
    selected_index = evt.index
    return selected_index
|
|
|
|
|
def change_view(row_number, dataframe):
    """Switch between the gallery view and the final result view.

    While questions remain, only the gallery is shown. Once ``row_number``
    equals ``NUM_QUESTIONS`` the model named most often in the "result" column
    is determined, the "id" and "result" columns are pushed to the Hub as a
    new dataset under ``SUBMISSION_ORG``, and the result view is shown.

    Args:
        row_number: Index of the current question.
        dataframe: Question table including the recorded "result" values.

    Returns:
        dict: component -> update for ``intro_view``, ``result_view``,
        ``gallery_view``, ``start_view`` and ``result``.
    """
    if row_number != NUM_QUESTIONS:
        return {
            intro_view: gr.update(visible=False),
            result_view: gr.update(visible=False),
            gallery_view: gr.update(visible=True),
            start_view: gr.update(visible=False),
            result: "",
        }

    # Every answer is a comma-separated list of model keys; empty strings
    # (questions where nothing was ticked) are dropped.
    votes = [
        model
        for answer in dataframe["result"].values
        for model in answer.split(",")
        if model
    ]

    if votes:
        result_text = RESULT[Counter(votes).most_common(1)[0][0]]
    else:
        # Leaving every question unticked is a legitimate outcome; previously
        # most_common(1)[0] raised IndexError in that case.
        result_text = (
            "## No favorite this time!\n\n"
            "You did not select any image, so no model could be matched to you."
        )

    dataset = datasets.Dataset.from_pandas(dataframe)
    # Keep only the two columns that are published; pass a list, which is the
    # documented argument type of remove_columns.
    unwanted = [
        name for name in dataset.column_names if name not in ("id", "result")
    ]
    dataset = dataset.remove_columns(unwanted)

    # Hub repo ids are "<namespace>/<name>", not filesystem paths, so they are
    # joined with "/" explicitly rather than with os.path.join.
    repo_id = f"{SUBMISSION_ORG}/{generate_random_hash()}"
    dataset.push_to_hub(repo_id, token=os.getenv("HF_TOKEN"))

    return {
        intro_view: gr.update(visible=False),
        result_view: gr.update(visible=True),
        gallery_view: gr.update(visible=False),
        start_view: gr.update(visible=True),
        result: result_text,
    }
|
|
|
|
|
# Static Markdown shown on the landing page. Emoji are written as escape
# sequences so they survive any re-encoding of this file.
# NOTE(review): the glyph after "How it works" in DESCRIPTION looks like a
# mis-decoded emoji; its original code point cannot be recovered from this
# file, so it is left untouched - please restore it from the upstream source.

# Trailing emoji: U+1F469 U+200D U+2695 U+FE0F (woman health worker).
TITLE = "# What AI model is best for you? \U0001F469\u200D\u2695\uFE0F"

# The leaderboard line ends with U+2764 U+FE0F (red heart).
DESCRIPTION = """
***How it works*** π \n\n
- Upon clicking start, you are shown image descriptions alongside four AI generated images.
\n- Select the image that best matches the prompt. If multiple images match the prompt equally well, select all images. If no image matches the prompt, leave all images unchecked.
\n- Answer **10** questions to find out what AI generator most resonates with you.
\n- Your submissions contribute to [**Open Parti Prompts Leaderboard**](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard) \u2764\uFE0F.
\n\n
"""

NOTE = """\n\n\n\n
The prompts you are shown originate from the [Parti Prompts](https://huggingface.co/datasets/nateraw/parti-prompts) dataset.
Parti Prompts is designed to test text-to-image AI models on 1600+ prompts of varying difficulty and categories.
The images you are shown have been pre-generated with 4 state-of-the-art open-sourced text-to-image models.
Your answers will be used to contribute to the official [**Open Parti Prompts Leaderboard**](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard).
Every couple months, the generated images will be updated with possibly improved models. The current models and code that was used to generate the images can be verified here:\n
- [kandinsky-2-2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) \n
- [wuerstchen](https://huggingface.co/warp-ai/wuerstchen) \n
- [sdxl-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) \n
- [karlo](https://huggingface.co/datasets/diffusers-parti-prompts/karlo-v1) \n
"""
|
|
|
# One gallery column per compared model.
GALLERY_COLUMN_NUM = len(SUBMISSIONS)

# ---------------------------------------------------------------------------
# UI layout and event wiring.
# Components are created in the same order as before, because creation order
# inside the Blocks context determines the page layout.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Column(visible=True) as intro_view:
        gr.Markdown(DESCRIPTION)

    # Schema of the question table: six fixed columns plus one "choice_<n>"
    # column per model.
    headers = ["prompt", "result", "id", "Challenge", "Category", "Note"]
    headers.extend(f"choice_{i}" for i in range(len(SUBMISSIONS)))
    datatype = ["str", "str", "number", "str", "str", "str"]
    datatype.extend(["number"] * len(SUBMISSIONS))

    # Hidden state: index of the question currently displayed.
    with gr.Column(visible=False):
        row_number = gr.Number(
            label="Current row selection index",
            value=0,
            precision=0,
            interactive=False,
        )

    # Final screen: result card plus the table of recorded answers.
    with gr.Column(visible=False) as result_view:
        result = gr.Markdown("")
        dataframe = gr.Dataframe(
            headers=headers,
            datatype=datatype,
            row_count=NUM_QUESTIONS,
            col_count=(6 + len(SUBMISSIONS), "fixed"),
            interactive=False,
        )
        gr.Markdown("Click on start to play again!")

    with gr.Column(visible=True) as start_view:
        start_button = gr.Button("Start").style(full_width=True)
        gr.Markdown(NOTE)

    # Hidden state: checkbox flags of the current question as a digit string.
    with gr.Column(visible=False):
        selected_image = gr.Textbox(label="Selected indexes")

    # Question screen: counter, prompt, four image/checkbox pairs, submit.
    with gr.Column(visible=False) as gallery_view:
        with gr.Row():
            counter = gr.Markdown(f"***1/{NUM_QUESTIONS} {PROMPT_FORMAT}***")
        with gr.Row():
            prompt = gr.Markdown("")
        with gr.Blocks():
            with gr.Row():
                with gr.Column(min_width=200) as c1:
                    image_1 = gr.Image(interactive=False)
                    image_1_button = gr.Checkbox(False, label="Image 1").style(full_width=True)
                with gr.Column(min_width=200) as c2:
                    image_2 = gr.Image(interactive=False)
                    image_2_button = gr.Checkbox(False, label="Image 2").style(full_width=True)
                with gr.Column(min_width=200) as c3:
                    image_3 = gr.Image(interactive=False)
                    image_3_button = gr.Checkbox(False, label="Image 3").style(full_width=True)
                with gr.Column(min_width=200) as c4:
                    image_4 = gr.Image(interactive=False)
                    image_4_button = gr.Checkbox(False, label="Image 4").style(full_width=True)
        with gr.Row():
            submit_button = gr.Button("Submit").style(full_width=True)

    # Component groups shared by both event chains below.
    checkbox_inputs = [image_1_button, image_2_button, image_3_button, image_4_button]
    view_outputs = [intro_view, result_view, gallery_view, start_view, result]
    question_outputs = [
        image_1,
        image_2,
        image_3,
        image_4,
        image_1_button,
        image_2_button,
        image_3_button,
        image_4_button,
        prompt,
        counter,
    ]

    # "Start": sample a new table, reset the index if the previous run was
    # completed, switch views, then render the first question.
    start_chain = start_button.click(
        fn=start,
        inputs=[],
        outputs=dataframe,
        show_progress=True,
    )
    start_chain = start_chain.then(
        fn=lambda x: 0 if x == NUM_QUESTIONS else x,
        inputs=[row_number],
        outputs=[row_number],
    )
    start_chain = start_chain.then(
        fn=change_view,
        inputs=[row_number, dataframe],
        outputs=view_outputs,
    )
    start_chain.then(
        fn=process,
        inputs=[dataframe],
        outputs=question_outputs,
    )

    def integerize(x1, x2, x3, x4):
        """Encode the four checkbox states as a string of 0/1 digits."""
        return "".join(str(int(flag)) for flag in (x1, x2, x3, x4))

    # "Submit": encode the checkboxes, record the answer, switch views if the
    # run is finished, render the next question, reset the index after the
    # last question and finally give refresh() the chance to resample.
    submit_chain = submit_button.click(
        fn=integerize,
        inputs=checkbox_inputs,
        outputs=[selected_image],
    )
    submit_chain = submit_chain.then(
        fn=write_result,
        inputs=[selected_image, row_number, dataframe],
        outputs=[row_number, dataframe],
    )
    submit_chain = submit_chain.then(
        fn=change_view,
        inputs=[row_number, dataframe],
        outputs=view_outputs,
    )
    submit_chain = submit_chain.then(
        fn=process,
        inputs=[dataframe, row_number],
        outputs=question_outputs,
    )
    submit_chain = submit_chain.then(
        fn=lambda x: 0 if x == NUM_QUESTIONS else x,
        inputs=[row_number],
        outputs=[row_number],
    )
    submit_chain.then(
        fn=refresh,
        inputs=[row_number, dataframe],
        outputs=[dataframe],
    )

demo.launch()
|
|