# 24labsimages / app.py
# Last commit: 3992596 "Update app.py" by erikbeltran (verified)
# (Copied from the Hugging Face file view: raw · history · blame · 3.08 kB)
import os
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import random
import spaces
import hashlib
from PIL import Image
# Load the FLUX.1-dev base pipeline once, at import time, and share it across
# all requests. The tiny autoencoder "madebyollin/taef1" is passed in as the
# pipeline's `vae`, replacing the one that ships with the base model.
dtype = torch.bfloat16
device = "cpu" if not torch.cuda.is_available() else "cuda"
base_model = "black-forest-labs/FLUX.1-dev"

taef1 = AutoencoderTiny.from_pretrained(
    "madebyollin/taef1",
    torch_dtype=dtype,
).to(device)

pipe = DiffusionPipeline.from_pretrained(
    base_model,
    vae=taef1,
    torch_dtype=dtype,
).to(device)

# Inclusive upper bound for randomly drawn seeds (2**32 - 1).
MAX_SEED = (1 << 32) - 1
def _normalize_custom_hash(custom_hash):
    """Return *custom_hash* without surrounding whitespace, or "" when unset.

    The result is used as a filename stem in the current working directory,
    so only ASCII letters, digits, '-' and '_' are accepted. Anything else
    (for example '/' or '\\') raises ``gr.Error``, so UI input cannot write
    outside that directory.
    """
    name = (custom_hash or "").strip()
    if any(not (ch.isascii() and (ch.isalnum() or ch in "-_")) for ch in name):
        raise gr.Error(
            "Custom hash may only contain ASCII letters, digits, '-' and '_'."
        )
    return name


@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word, steps,
                   custom_hash, guidance_scale=3.5, seed=None):
    """Generate one image with a LoRA applied on top of the shared pipeline.

    Args:
        prompt: User prompt text.
        width: Output width in pixels; cast to ``int``.
        height: Output height in pixels; cast to ``int``.
        lora_path: LoRA identifier passed to ``pipe.load_lora_weights``.
        trigger_word: Text prepended to the prompt; may be empty.
        steps: Number of inference steps; cast to ``int``.
        custom_hash: Optional filename stem for the saved PNG. If blank, the
            SHA-256 hex digest of the image's raw bytes is used instead.
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Seed for the torch generator; ``None`` draws a random seed
            in ``[0, MAX_SEED]``.

    Returns:
        ``(image, image_hash)``: the generated image and the stem of the
        file it was saved to (``<image_hash>.png`` in the working directory).

    Raises:
        gr.Error: If ``custom_hash`` contains disallowed characters.
    """
    # Validate before loading weights / running the model so bad input fails fast.
    image_name = _normalize_custom_hash(custom_hash)

    # Join only the non-empty parts, so a blank trigger word adds no stray space.
    full_prompt = " ".join(part for part in (trigger_word, prompt) if part)

    if seed is None:
        seed = random.randint(0, MAX_SEED)
    # Use the module's detected device rather than assuming CUDA.
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        pipe.load_lora_weights(lora_path)
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale,
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]
    finally:
        # `pipe` is a module-level object shared by every request; unload the
        # adapter so LoRAs do not accumulate on it from one call to the next.
        pipe.unload_lora_weights()

    image_hash = image_name or hashlib.sha256(image.tobytes()).hexdigest()
    image.save(f"{image_hash}.png")
    return image, image_hash
def run_lora(prompt, width, height, lora_path, trigger_word, steps, custom_hash):
    """Gradio click handler: forward the UI values to ``generate_image``.

    Returns exactly what ``generate_image`` returns (the image and its hash).
    """
    ui_values = (prompt, width, height, lora_path, trigger_word, steps, custom_hash)
    return generate_image(*ui_values)
# Gradio UI: input rows, then the generate button and the two outputs.
with gr.Blocks() as app:
    gr.Markdown("# LoRA Image Generator")

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            lines=3,
            placeholder="Enter your prompt here",
        )

    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)

    with gr.Row():
        lora_path = gr.Textbox(
            label="LoRA Path",
            value="SebastianBodza/Flux_Aquarell_Watercolor_v2",
        )
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")

    with gr.Row():
        steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=28)

    with gr.Row():
        custom_hash = gr.Textbox(
            label="Custom Hash (optional)",
            placeholder="Leave blank to auto-generate hash",
        )

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    output_hash = gr.Textbox(label="Image Hash", interactive=False)

    # Order must match run_lora's positional parameters.
    input_components = [prompt, width, height, lora_path, trigger_word, steps, custom_hash]
    output_components = [output_image, output_hash]

    generate_button.click(
        fn=run_lora,
        inputs=input_components,
        outputs=output_components,
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve without a public share link.
    queued_app = app.queue()
    queued_app.launch(share=False)