Flux-Redux / app.py
MohamedRashad's picture
Update app.py
c7c2e8e verified
import torch
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers.utils import load_image
import gradio as gr
import spaces
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
# Both pipelines are built once at import time and moved to the GPU, so every
# request handled by this module reuses the same loaded objects.
_WEIGHT_DTYPE = torch.bfloat16
_DEVICE = "cuda"

# Redux prior pipeline: called with the input image in enhance_image().
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev",
    torch_dtype=_WEIGHT_DTYPE,
)
pipe_prior_redux = pipe_prior_redux.to(_DEVICE)

# Base FLUX.1-dev pipeline. Both text encoders are explicitly set to None;
# presumably they are unnecessary because the conditioning comes from the
# Redux prior's output rather than from a text prompt -- confirm before
# passing a prompt to this pipeline.
_base_pipeline_options = {
    "text_encoder": None,
    "text_encoder_2": None,
    "torch_dtype": _WEIGHT_DTYPE,
}
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    **_base_pipeline_options,
).to(_DEVICE)

# Disabled alternatives, kept for reference:
#   - bind the live-preview helper as a method of the pipeline:
#       pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
#   - offload the pipeline to the CPU sequentially:
#       pipe.enable_sequential_cpu_offload()
# Largest rendered area (in pixels) allowed when the input's aspect ratio is
# kept; this is the area of a 1024x1024 image.
_MAX_RENDER_PIXELS = 1024 * 1024
# Both sides are snapped to a multiple of this value, which is what the FLUX
# pipeline documents for `height` and `width`.
_DIMENSION_MULTIPLE = 16


def _fit_dimensions(width, height, max_pixels=_MAX_RENDER_PIXELS, multiple=_DIMENSION_MULTIPLE):
    """Return a ``(width, height)`` pair that is safe to hand to the pipeline.

    The input size is scaled down (never up) so that its area does not exceed
    ``max_pixels``, then each side is rounded down to a multiple of
    ``multiple``. The aspect ratio is preserved up to that rounding.

    Args:
        width: Source image width in pixels.
        height: Source image height in pixels.
        max_pixels: Upper bound on ``width * height`` of the result.
        multiple: Granularity both returned sides must be divisible by.

    Returns:
        A ``(width, height)`` tuple of ints, each at least ``multiple``.
    """
    # Guard against degenerate sizes so the division below cannot fail.
    area = max(1, width * height)
    scale = min(1.0, (max_pixels / area) ** 0.5)
    fitted_width = max(multiple, int(width * scale) // multiple * multiple)
    fitted_height = max(multiple, int(height * scale) // multiple * multiple)
    return fitted_width, fitted_height


@spaces.GPU(duration=120)
def enhance_image(image_path, keep_aspect_ratio=False):
    """Generate a variation of the image at ``image_path`` with FLUX.1 Redux.

    The image is loaded, passed through the Redux prior pipeline, and the
    prior's outputs are forwarded as keyword arguments to the base FLUX
    pipeline, which renders the result with a fixed seed (0).

    Args:
        image_path: Location of the input image, as accepted by
            ``diffusers.utils.load_image``. The UI passes a file path.
        keep_aspect_ratio: When true, the output keeps the input's aspect
            ratio; the size is capped and snapped by ``_fit_dimensions`` so
            the pipeline always receives supported dimensions. When false,
            ``height`` and ``width`` are passed as ``None`` and the pipeline
            picks its own defaults.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no image was provided.
    """
    # The UI delivers None when Submit is pressed without an upload; report
    # that clearly instead of failing inside load_image.
    if not image_path:
        raise gr.Error("Please upload an image before pressing Submit.")

    print(image_path)
    image = load_image(image_path)
    print(image.size)

    if keep_aspect_ratio:
        # Raw upload sizes may be huge or not divisible by 16, so fit them
        # rather than passing image.size through unchanged.
        width, height = _fit_dimensions(*image.size)
    else:
        width, height = None, None

    pipe_prior_output = pipe_prior_redux(image)

    # NOTE: a streaming variant built on
    # flux_pipe_call_that_returns_an_iterable_of_images (yielding intermediate
    # images over 50 steps) was sketched here earlier and is disabled; this
    # function returns only the final image.
    images = pipe(
        height=height,
        width=width,
        guidance_scale=2.5,
        num_inference_steps=25,
        # Fixed seed: the same input yields the same variation on every run.
        generator=torch.Generator("cpu").manual_seed(0),
        **pipe_prior_output,
    ).images
    return images[0]
# Gradio front end: an input column (image, aspect-ratio toggle, submit
# button) and an output image, wired to enhance_image().
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation had been lost; the output image is assumed to sit in the row
# beside the input column -- confirm against the deployed layout.
with gr.Blocks(title="Flux.1 Dev Redux") as demo:
    gr.HTML("<center><h1>Flux.1 Dev Redux</h1></center>")
    gr.Markdown(
        "[FLUX.1 Redux](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) is an adapter for all FLUX.1 base models for image variation generation."
    )

    with gr.Row():
        with gr.Column():
            # type="filepath" means the handler receives a path string.
            source_image = gr.Image(label="Image", type="filepath")
            aspect_toggle = gr.Checkbox(label="Keep Aspect Ratio", value=False)
            run_button = gr.Button(value="Submit", variant="primary")
        result_image = gr.Image(label="Enhanced Image", type="pil")

    # Pressing the button runs enhance_image on the two inputs and shows the
    # returned image in the output component.
    run_button.click(
        fn=enhance_image,
        inputs=[source_image, aspect_toggle],
        outputs=result_image,
    )

# Enable request queuing, then start the server without a public share link.
demo.queue()
demo.launch(share=False)