File size: 4,390 Bytes
2f22a68
 
 
a575d82
2f22a68
a575d82
2f22a68
 
 
 
a575d82
 
 
4a902da
2f22a68
 
 
a575d82
 
 
18905b7
a575d82
 
 
 
 
 
 
 
 
 
 
18905b7
a575d82
 
 
 
 
 
 
 
 
2f22a68
a575d82
 
2f22a68
 
62a24fb
08fbc70
 
 
 
 
 
0589fd4
 
 
 
 
 
 
 
 
 
08fbc70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f22a68
 
0589fd4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
from main import main
from arguments import parse_args
import os

def _settings_name(args):
    """Build the run-specific directory name used under ``<save_dir>/<task>/``.

    NOTE(review): this duplicates the naming scheme that ``main`` is assumed
    to use when writing its outputs (it is not imported from there). If the
    two drift apart, ``generate_image`` will report the image as missing.
    Confirm against main.py when changing either side.
    """
    return (
        f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
        f"{'_no-optim' if args.no_optim else ''}_{args.seed if args.task != 'geneval' else ''}"
        f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}"
        f"_reg{args.reg_weight if args.enable_reg else '0'}"
        f"{'_pickscore' + str(args.pickscore_weighting) if args.enable_pickscore else ''}"
        f"{'_clip' + str(args.clip_weighting) if args.enable_clip else ''}"
        f"{'_hps' + str(args.hps_weighting) if args.enable_hps else ''}"
        f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
        f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
    )


def generate_image(prompt, model, num_iterations, learning_rate, progress = gr.Progress(track_tqdm=True)):
    """Run ReNO on a single prompt and return the path of the best image.

    This is the callback for the "Submit" button. The four positional
    inputs come from the prompt textbox, the model dropdown and the two
    sliders.

    Args:
        prompt: Text prompt to optimise the initial noise for.
        model: Model name chosen in the dropdown (e.g. ``"sd-turbo"``).
            Gradio passes ``None`` when nothing is selected.
        num_iterations: Number of optimisation iterations (``args.n_iters``).
        learning_rate: Optimiser learning rate (``args.lr``).
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars, if any, created during the run.
            It is not referenced directly in this function.

    Returns:
        A ``(image_path, status_message)`` tuple. ``image_path`` is ``None``
        when the input is invalid, when ``main`` (or building the output
        path) raises, or when ``best_image.png`` was not written where
        expected.
    """
    # Reject invalid input up front, before parsing arguments and loading a model.
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."
    if not model:
        return None, "Please select a model."

    # Start from the CLI defaults, then override what the UI controls.
    args = parse_args()
    args.task = "single"
    args.prompt = prompt
    args.model = model
    # Slider values may arrive as floats; an iteration count must be an int.
    args.n_iters = int(num_iterations)
    args.lr = float(learning_rate)
    args.cache_dir = "./HF_model_cache"
    args.save_dir = "./outputs"
    args.save_all_images = True

    try:
        # Run the optimisation; outputs are written to disk under args.save_dir.
        main(args)

        # Plain string formatting (not os.path.join) on purpose: os.path.join
        # would discard the prefix for a prompt starting with "/".
        save_dir = f"{args.save_dir}/{args.task}/{_settings_name(args)}/{args.prompt}"
        image_path = f"{save_dir}/best_image.png"

        if os.path.exists(image_path):
            return image_path, f"Image generated successfully and saved at {image_path}"
        return None, "Image generation completed, but the file was not found."

    except Exception as e:
        # Keep the UI responsive with a status message, but don't lose the
        # traceback: print it to the server's stderr for debugging.
        import traceback
        traceback.print_exc()
        return None, f"An error occurred: {str(e)}"

# Create Gradio interface
title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."

# Model names offered in the UI; the selection is passed to generate_image
# and assigned unchanged to ``args.model``.
MODEL_CHOICES = ["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"]

with gr.Blocks() as demo:
    with gr.Column():
        # Header: title, short description and badge links.
        gr.Markdown(title)
        gr.Markdown(description)
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href='https://github.com/ExplainableML/ReNO'>
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a> 
            <a href='https://arxiv.org/abs/2406.04312v1'>
                <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
            </a>
        </div>
        """)

        with gr.Row():
            # Left column: inputs, submit button and example prompts.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                # Pre-select a model so Submit never sends model=None.
                chosen_model = gr.Dropdown(MODEL_CHOICES, value=MODEL_CHOICES[0], label="Model")
                
                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")

                submit_btn = gr.Button("Submit")

                # Clicking an example fills the prompt textbox only.
                gr.Examples(
                    examples = [
                        "A minimalist logo design of a reindeer, fully rendered. The reindeer features distinct, complete shapes using bold and flat colors. The design emphasizes simplicity and clarity, suitable for logo use with a sharp outline and white background.",
                        "A blue scooter is parked near a curb in front of a green vintage car",
                        "A impressionistic oil painting: a lone figure walking on a misty beach, a weathered lighthouse on a cliff, seagulls above crashing waves",
                        "A bird with 8 legs",
                        "An orange chair to the right of a black airplane",
                        "A pink elephant and a grey cow",
                    ],
                    inputs = [prompt]     
                )
            
            # Right column: result image (as a file path) and status text.
            with gr.Column():
                output_image = gr.Image(type="filepath", label="Generated Image")
                status = gr.Textbox(label="Status")

    # generate_image returns (image_path, status_message), matching these outputs.
    submit_btn.click(
        fn = generate_image,
        inputs = [prompt, chosen_model, n_iter, learning_rate],
        outputs = [output_image, status]
    )

# Launch the app. queue() enables Gradio's request queue, which progress
# reporting uses. show_error surfaces exceptions in the UI, and show_api=False
# hides the API usage page.
demo.queue().launch(show_error=True, show_api=False)