# ADVERTISE/options/Banner_Model/Image2Image.py
import imageio
import numpy as np
from PIL import Image
import torch
# from .controlnet_flux import FluxControlNetModel
# from .transformer_flux import FluxTransformer2DModel
# from .pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
import random
from typing import Tuple
from diffusers import FluxInpaintPipeline
import gradio as gr

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Upper bound for randomized seeds (assumed value; not defined elsewhere in this file).
MAX_SEED = np.iinfo(np.int32).max
print(f"Using device for I2I: {DEVICE}")
# # Load the inpainting pipeline
# def resize_image(image, height, width):
#     """Resize image tensor to the desired height and width."""
#     return torch.nn.functional.interpolate(image, size=(height, width), mode='nearest')

# def dummy(img):
#     """Save the composite image and generate a mask from the alpha channel."""
#     imageio.imwrite("output_image.png", img["composite"])
#     # Extract alpha channel from the first layer to create the mask
#     alpha_channel = img["layers"][0][:, :, 3]
#     mask = np.where(alpha_channel == 0, 0, 255).astype(np.uint8)
#     return img["background"], mask
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = 1024
) -> Tuple[int, int]:
    """Scale (width, height) so the longest side fits maximum_dimension,
    then round both sides down to multiples of 32."""
    width, height = original_resolution_wh
    # if width <= maximum_dimension and height <= maximum_dimension:
    #     width = width - (width % 32)
    #     height = height - (height % 32)
    #     return width, height
    if width > height:
        scaling_factor = maximum_dimension / width
    else:
        scaling_factor = maximum_dimension / height
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    new_width = new_width - (new_width % 32)
    new_height = new_height - (new_height % 32)
    return new_width, new_height
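
# Example: a 2048x1024 input is scaled to fit the 1024-pixel cap on its longest
# side, then both sides are rounded down to the nearest multiple of 32 so the
# dimensions divide evenly for the diffusion pipeline:
#     resize_image_dimensions((2048, 1024))  # -> (1024, 512)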
# @spaces.GPU(duration=100)
def I2I(
    input_image_editor: dict,
    input_text: str,
    seed_slicer: int = 42,
    randomize_seed_checkbox: bool = True,
    strength_slider: float = 0.85,
    num_inference_steps_slider: int = 20,
    progress=gr.Progress(track_tqdm=True)):
    """Inpaint the masked region of the edited image according to the text prompt."""
    # Load the FLUX.1-schnell inpainting pipeline (reloaded on every call).
    pipe = FluxInpaintPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to(DEVICE)
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None
    image = input_image_editor['background']
    mask = input_image_editor['layers'][0]
    if not image:
        gr.Info("Please upload an image.")
        return None
    if not mask:
        gr.Info("Please draw a mask on the image.")
        return None
    # Resize image and mask to dimensions the pipeline accepts.
    width, height = resize_image_dimensions(original_resolution_wh=image.size)
    resized_image = image.resize((width, height), Image.LANCZOS)
    resized_mask = mask.resize((width, height), Image.LANCZOS)
    if randomize_seed_checkbox:
        seed_slicer = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed_slicer)
    result = pipe(
        prompt=input_text,
        image=resized_image,
        mask_image=resized_mask,
        width=width,
        height=height,
        strength=strength_slider,
        generator=generator,
        num_inference_steps=num_inference_steps_slider
    ).images[0]
    print('INFERENCE DONE')
    return result
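
# Usage sketch (assumptions: the dict is shaped like the value gr.ImageEditor
# produces with type="pil", i.e. {"background": PIL image, "layers": [RGBA mask
# layer], "composite": ...}; file names and prompt are placeholders):
#     editor_value = {
#         "background": Image.open("banner.png").convert("RGB"),
#         "layers": [Image.open("mask.png").convert("RGBA")],
#         "composite": None,
#     }
#     inpainted = I2I(editor_value, "a red sports car on an empty beach",
#                     seed_slicer=42, randomize_seed_checkbox=False)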
def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
    """Make near-black pixels fully transparent, keeping everything else opaque."""
    image = image.convert("RGBA")
    data = image.getdata()
    new_data = []
    for item in data:
        avg = sum(item[:3]) / 3
        if avg < threshold:
            # Pixel is darker than the threshold: treat it as background.
            new_data.append((0, 0, 0, 0))
        else:
            new_data.append(item)
    image.putdata(new_data)
    return image
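
# Usage sketch (assumption: the inpainted result has a dark backdrop that should
# become transparent before compositing onto a banner):
#     cutout = remove_background(inpainted, threshold=50)
#     cutout.save("asset_with_alpha.png")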
# def I2I(prompt, image, width=1024, height=1024, guidance_scale=8.0, num_inference_steps=20, strength=0.99):
#     controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16)
#     transformer = FluxTransformer2DModel.from_pretrained(
#         "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16
#     )
#     pipe = FluxControlNetInpaintingPipeline.from_pretrained(
#         "black-forest-labs/FLUX.1-dev",
#         controlnet=controlnet,
#         transformer=transformer,
#         torch_dtype=torch.bfloat16
#     ).to(DEVICE)
#     pipe.transformer.to(torch.bfloat16)
#     pipe.controlnet.to(torch.bfloat16)
#     pipe.set_attn_processor(FluxAttnProcessor2_0())
#     img_url, mask = dummy(image)
#     # Resize image and mask to the target dimensions (width x height)
#     img_url = Image.fromarray(img_url, mode="RGB").resize((width, height))
#     mask_url = Image.fromarray(mask, mode="L").resize((width, height))
#     generator = torch.Generator(device=DEVICE).manual_seed(0)
#     # Generate the inpainted image
#     result = pipe(
#         prompt=prompt,
#         height=height,
#         width=width,
#         control_image=img_url,
#         control_mask=mask_url,
#         num_inference_steps=28,
#         generator=generator,
#         controlnet_conditioning_scale=0.9,
#         guidance_scale=3.5,
#         negative_prompt="",
#         true_guidance_scale=3.5
#     ).images[0]
#     return result
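
# Minimal standalone wiring sketch (assumption: in the real app these components
# live in a larger Gradio layout; this only demonstrates how I2I maps onto an
# ImageEditor-plus-prompt interface, with the remaining parameters left at their
# Python defaults):
if __name__ == "__main__":
    demo = gr.Interface(
        fn=I2I,
        inputs=[
            gr.ImageEditor(type="pil", label="Image + mask"),
            gr.Textbox(label="Prompt"),
        ],
        outputs=gr.Image(label="Inpainted result"),
    )
    demo.launch()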