import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import random
import spaces
import hashlib

# Initialize the base model
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
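# "taef1" is a tiny TAESD-style autoencoder for FLUX latents; it trades a little decode quality for speed and memory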
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
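# Largest 32-bit unsigned integer, used as the upper bound for random seeds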
MAX_SEED = 2**32-1

# spaces.GPU requests a GPU allocation for this function when running on Hugging Face ZeroGPU Spaces
@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word, steps, custom_hash):
    # Unload any previously loaded LoRA adapters so repeated calls don't stack them,
    # then load the requested LoRA weights
    pipe.unload_lora_weights()
    pipe.load_lora_weights(lora_path)
    
    # Combine prompt with trigger word
    full_prompt = f"{trigger_word} {prompt}"
    
    # Set up generation parameters
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    
    # Generate image
    image = pipe(
        prompt=full_prompt,
        num_inference_steps=steps,
        guidance_scale=3.5,  # diffusers' default guidance scale for FLUX.1-dev
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    
    # Generate or use provided hash for the image
    if not custom_hash:
        # Generate a hash if custom_hash is not provided
        image_bytes = image.tobytes()
        hash_object = hashlib.sha256(image_bytes)
        image_hash = hash_object.hexdigest()
    else:
        image_hash = custom_hash
    
    # Save the image with the hash as filename
    image_path = f"{image_hash}.png"
    image.save(image_path)
    
    return image, image_hash

def run_lora(prompt, width, height, lora_path, trigger_word, steps, custom_hash):
    return generate_image(prompt, width, height, lora_path, trigger_word, steps, custom_hash)

# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# LoRA Image Generator")
    
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")
    
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
    
    with gr.Row():
        lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
    
    with gr.Row():
        steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=28)
    
    with gr.Row():
        custom_hash = gr.Textbox(label="Custom Hash (optional)", placeholder="Leave blank to auto-generate hash")
    
    generate_button = gr.Button("Generate Image")
    
    output_image = gr.Image(label="Generated Image")
    output_hash = gr.Textbox(label="Image Hash", interactive=False)
    
    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height, lora_path, trigger_word, steps, custom_hash],
        outputs=[output_image, output_hash]
    )

if __name__ == "__main__":
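    # queue() enables Gradio's request queue so concurrent generations are processed in order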
    app.queue().launch(share=False)