import os
import random
import hashlib

import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
from PIL import Image
import spaces
# Initialize the base model, swapping in the tiny TAEF1 autoencoder for faster latent decoding
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"

taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)

MAX_SEED = 2**32 - 1
@spaces.GPU  # Request a GPU for the duration of this call when running on ZeroGPU hardware
def generate_image(prompt, width, height, lora_path, trigger_word, steps, custom_hash):
    # Load LoRA weights
    pipe.load_lora_weights(lora_path)

    # Combine the trigger word with the user prompt
    full_prompt = f"{trigger_word} {prompt}"

    # Set up generation parameters with a random seed
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Generate the image
    image = pipe(
        prompt=full_prompt,
        num_inference_steps=steps,
        guidance_scale=3.5,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # Use the provided hash, or derive one from the image contents
    if not custom_hash:
        # Hash the raw pixel data with SHA-256 when no custom hash is given
        image_bytes = image.tobytes()
        hash_object = hashlib.sha256(image_bytes)
        image_hash = hash_object.hexdigest()
    else:
        image_hash = custom_hash

    # Save the image with the hash as the filename
    image_path = f"{image_hash}.png"
    image.save(image_path)

    return image, image_hash
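# Illustrative sketch, not part of the original Space: `verify_image_hash` is a
# hypothetical helper that recomputes the SHA-256 of a saved image's raw pixel
# bytes and compares it to the hash embedded in the filename. PNG is lossless,
# so reloading the file should reproduce the bytes hashed above; the check is
# only meaningful when the hash was auto-generated rather than user-supplied.
def verify_image_hash(image_path):
    reloaded = Image.open(image_path)
    recomputed = hashlib.sha256(reloaded.tobytes()).hexdigest()
    expected = os.path.splitext(os.path.basename(image_path))[0]
    return recomputed == expected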
def run_lora(prompt, width, height, lora_path, trigger_word, steps, custom_hash):
    # Thin wrapper used as the Gradio click callback
    return generate_image(prompt, width, height, lora_path, trigger_word, steps, custom_hash)
# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# LoRA Image Generator")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
    with gr.Row():
        lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
    with gr.Row():
        steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=28)
    with gr.Row():
        custom_hash = gr.Textbox(label="Custom Hash (optional)", placeholder="Leave blank to auto-generate hash")

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    output_hash = gr.Textbox(label="Image Hash", interactive=False)

    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height, lora_path, trigger_word, steps, custom_hash],
        outputs=[output_image, output_hash],
    )
if __name__ == "__main__":
    app.queue().launch(share=False)
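# Illustrative only (assumed prompt value, not part of the original app): the generator
# can also be called directly, bypassing the UI. FLUX.1-dev at these settings effectively
# requires a CUDA GPU.
#
#   image, image_hash = generate_image(
#       prompt="a lighthouse on a cliff at sunset",
#       width=512,
#       height=512,
#       lora_path="SebastianBodza/Flux_Aquarell_Watercolor_v2",
#       trigger_word="AQUACOLTOK",
#       steps=28,
#       custom_hash="",
#   )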