Spaces:
Running
on
L40S
Running
on
L40S
import os
import traceback

import gradio as gr

from main import main
from arguments import parse_args
def _settings_name(args):
    """Return the run-settings folder name used to locate the saved image.

    The name is built from the model, seed, learning rate, gradient clip,
    iteration count, regularisation weight and every enabled reward
    weighting found on ``args``.

    NOTE(review): this presumably mirrors the directory naming used inside
    ``main``; if ``main`` changes how it names output folders, update this too.
    """
    return (
        f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
        f"{'_no-optim' if args.no_optim else ''}_{args.seed if args.task != 'geneval' else ''}"
        f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}"
        f"_reg{args.reg_weight if args.enable_reg else '0'}"
        f"{'_pickscore' + str(args.pickscore_weighting) if args.enable_pickscore else ''}"
        f"{'_clip' + str(args.clip_weighting) if args.enable_clip else ''}"
        f"{'_hps' + str(args.hps_weighting) if args.enable_hps else ''}"
        f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
        f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
    )


def generate_image(prompt, model, num_iterations, learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Run ReNO on a single prompt and return the best image for the UI.

    Args:
        prompt: Text prompt entered by the user.
        model: Model key chosen in the dropdown (e.g. "sd-turbo").
        num_iterations: Number of optimisation iterations, from a slider.
        learning_rate: Learning rate, from a slider.
        progress: Not referenced in the body. Declaring it lets Gradio attach
            a progress tracker; ``track_tqdm=True`` asks Gradio to mirror
            tqdm progress bars in the UI.

    Returns:
        ``(image_path, status_message)``. ``image_path`` is ``None`` when
        nothing was produced. This matches the ``[output_image, status]``
        outputs of the interface.
    """
    # Reject missing inputs up front. An unselected dropdown yields None, and
    # an empty prompt would make the output path end in a bare "/".
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."
    if not model:
        return None, "Please select a model."

    try:
        # Start from the CLI defaults and override only what the UI exposes.
        args = parse_args()
        args.task = "single"
        args.prompt = prompt
        args.model = model
        # Gradio sliders deliver floats (e.g. 50.0). n_iters is an iteration
        # count and is embedded in the output folder name, so cast to int.
        args.n_iters = int(num_iterations)
        args.lr = learning_rate
        args.cache_dir = "./HF_model_cache"
        args.save_dir = "./outputs"
        args.save_all_images = True

        main(args)

        # Rebuild where results were written (presumably the same layout
        # `main` uses).
        save_dir = f"{args.save_dir}/{args.task}/{_settings_name(args)}/{args.prompt}"
    except Exception as e:
        # UI boundary: report the error in the status box, and keep the full
        # traceback in the server log instead of discarding it.
        traceback.print_exc()
        return None, f"An error occurred: {str(e)}"

    image_path = f"{save_dir}/best_image.png"
    if os.path.exists(image_path):
        return image_path, f"Image generated successfully and saved at {image_path}"
    return None, "Image generation completed, but the file was not found."
# Create Gradio interface
title = "# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description = "Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href='https://github.com/ExplainableML/ReNO'>
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href='https://arxiv.org/abs/2406.04312v1'>
<img src='https://img.shields.io/badge/Paper-Arxiv-red'>
</a>
</div>
""")
        with gr.Row():
            # Left column: inputs, submit button and example prompts.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                # An explicit default ensures the handler never receives None
                # when the user does not touch the dropdown.
                chosen_model = gr.Dropdown(
                    ["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"],
                    value="sd-turbo",
                    label="Model",
                )
                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[
                        "A minimalist logo design of a reindeer, fully rendered. The reindeer features distinct, complete shapes using bold and flat colors. The design emphasizes simplicity and clarity, suitable for logo use with a sharp outline and white background.",
                        "A blue scooter is parked near a curb in front of a green vintage car",
                        "A impressionistic oil painting: a lone figure walking on a misty beach, a weathered lighthouse on a cliff, seagulls above crashing waves",
                        "A bird with 8 legs",
                        "An orange chair to the right of a black airplane",
                        "A pink elephant and a grey cow",
                    ],
                    inputs=[prompt],
                )
            # Right column: generated image and status message.
            with gr.Column():
                output_image = gr.Image(type="filepath", label="Generated Image")
                status = gr.Textbox(label="Status")

    # Output order matches generate_image's (image_path, status_message) return.
    submit_btn.click(
        fn=generate_image,
        inputs=[prompt, chosen_model, n_iter, learning_rate],
        outputs=[output_image, status],
    )

# Launch the app
demo.queue().launch(show_error=True, show_api=False)