# ProductPlacement / gradio_demo_bg_modif.py
# Uploaded by Ashoka74 (Upload folder using huggingface_hub)
# Commit 5ea4356, verified
import os
import math
import gradio as gr
import numpy as np
import torch
import safetensors.torch as sf
import db_examples
import datetime
from pathlib import Path
from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from briarmbg import BriaRMBG
from enum import Enum
from torch.hub import download_url_to_file
# Optional dependency probe: xformers provides memory-efficient attention.
# XFORMERS_AVAILABLE records whether the import succeeded so that
# enable_efficient_attention() can choose an attention backend later.
try:
    import xformers
    import xformers.ops
except ImportError:
    XFORMERS_AVAILABLE = False
    print("xformers not available - Using default attention")
else:
    XFORMERS_AVAILABLE = True
    print("xformers is available - Using memory efficient attention")
# Base Stable Diffusion 1.5 checkpoint. Alternatives that have been used:
# 'stablediffusionapi/realistic-vision-v51'
# 'runwayml/stable-diffusion-v1-5'
sd15_name = 'stablediffusionapi/realistic-vision-v51'
# Load the individual components of the base checkpoint, plus the RMBG
# model that run_rmbg() uses to matte the foreground.
tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
# Change UNet: widen the input convolution to 12 channels. The first 4 input
# channels reuse the pretrained weights; the remaining 8 start at zero.
# NOTE(review): the 8 extra channels presumably carry the foreground and
# background latents concatenated by hooked_unet_forward - confirm.
with torch.no_grad():
    new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
    new_conv_in.weight.zero_()
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
    new_conv_in.bias = unet.conv_in.bias
    unet.conv_in = new_conv_in
# Keep a handle on the unpatched forward so the hook can delegate to it.
unet_original_forward = unet.forward
def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
    """Forward wrapper that appends conditioning latents to the UNet input.

    The conditioning tensor is read from
    ``kwargs['cross_attention_kwargs']['concat_conds']``, cast to the device
    and dtype of ``sample``, tiled along the batch dimension to match
    ``sample``, and concatenated on the channel dimension. The
    ``cross_attention_kwargs`` entry is then replaced with an empty dict
    before delegating to the original forward.

    Raises:
        KeyError: if ``cross_attention_kwargs`` or ``concat_conds`` is absent.
    """
    cond_latents = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
    repeats = sample.shape[0] // cond_latents.shape[0]
    cond_latents = torch.cat([cond_latents] * repeats, dim=0)
    widened_sample = torch.cat([sample, cond_latents], dim=1)
    # The conditioning has been consumed; do not forward it any further.
    forward_kwargs = {**kwargs, 'cross_attention_kwargs': {}}
    return unet_original_forward(widened_sample, timestep, encoder_hidden_states, **forward_kwargs)
# Route every UNet call through the hook that appends the conditioning latents.
unet.forward = hooked_unet_forward

# Load: fetch the IC-Light offset weights on first run.
model_path = './models/iclight_sd15_fbc.safetensors'

if not os.path.exists(model_path):
    # The download writes directly into the destination directory, so it has
    # to exist beforehand (a fresh checkout has no ./models folder).
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fbc.safetensors', dst=model_path)

# Device and dtype setup
device = torch.device('cuda')
dtype = torch.float16  # RTX 2070 works well with float16

# Memory optimizations for RTX 2070
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # NOTE(review): this only sets an attribute on torch.backends.cuda; the
    # allocator's max_split_size_mb is normally configured through the
    # PYTORCH_CUDA_ALLOC_CONF environment variable - confirm this has any effect.
    torch.backends.cuda.max_split_size_mb = 512

# Move models to device with consistent dtype
text_encoder = text_encoder.to(device=device, dtype=dtype)
vae = vae.to(device=device, dtype=dtype)  # Changed from bfloat16 to float16
unet = unet.to(device=device, dtype=dtype)
rmbg = rmbg.to(device=device, dtype=torch.float32)  # Keep this as float32

# Merge the offset weights into the UNet: each parameter becomes
# base + offset, with the offset cast to the UNet's device and dtype.
sd_offset = sf.load_file(model_path)
sd_origin = unet.state_dict()
sd_merged = {k: sd_origin[k] + sd_offset[k].to(device=device, dtype=dtype) for k in sd_origin.keys()}
unet.load_state_dict(sd_merged, strict=True)
# Drop the temporary state dicts so their memory can be reclaimed.
del sd_offset, sd_origin, sd_merged
def enable_efficient_attention():
    """Select an attention implementation for the UNet and the VAE.

    When xformers was imported successfully, its memory-efficient attention
    is enabled on both models. If that raises, or if xformers is missing,
    both models are switched to a slice size of 4 and then given
    ``AttnProcessor2_0`` processors.
    """
    def _fall_back_to_sliced_attention():
        # NOTE(review): confirm `set_attention_slice_size` exists on these
        # diffusers models; the processor set afterwards may also override it.
        for model in (unet, vae):
            model.set_attention_slice_size(4)
        for model in (unet, vae):
            model.set_attn_processor(AttnProcessor2_0())

    if not XFORMERS_AVAILABLE:
        # Fallback for when xformers is not available
        print("Using sliced attention")
        _fall_back_to_sliced_attention()
        return

    try:
        unet.set_use_memory_efficient_attention_xformers(True)
        vae.set_use_memory_efficient_attention_xformers(True)
        print("Enabled xformers memory efficient attention")
    except Exception as e:
        print(f"Xformers error: {e}")
        print("Falling back to sliced attention")
        _fall_back_to_sliced_attention()
def clear_memory():
    """Empty the CUDA cache and synchronize; a no-op when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
# Pick the attention backend before any pipeline is built.
enable_efficient_attention()

# Samplers. All three share the same training-timestep count, beta range and
# step offset; only the per-scheduler options differ.
_scheduler_common = dict(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    steps_offset=1,
)

ddim_scheduler = DDIMScheduler(
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    **_scheduler_common,
)

euler_a_scheduler = EulerAncestralDiscreteScheduler(**_scheduler_common)

dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
    algorithm_type="sde-dpmsolver++",
    use_karras_sigmas=True,
    **_scheduler_common,
)

# Pipelines. Text-to-image and image-to-image share the same model
# components and scheduler instance; the safety checker is disabled.
_pipeline_common = dict(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=dpmpp_2m_sde_karras_scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    feature_extractor=None,
    image_encoder=None,
)

t2i_pipe = StableDiffusionPipeline(**_pipeline_common)

i2i_pipe = StableDiffusionImg2ImgPipeline(**_pipeline_common)
@torch.inference_mode()
def encode_prompt_inner(txt: str):
    """Encode a prompt of arbitrary length with the CLIP text encoder.

    The prompt is tokenized without truncation and split into chunks of
    ``model_max_length - 2`` tokens. Each chunk is wrapped in BOS/EOS and
    padded with EOS up to ``model_max_length`` before being encoded.

    Args:
        txt: Prompt text. May be empty, in which case a single chunk holding
            only BOS/EOS/padding is encoded.

    Returns:
        The encoder's ``last_hidden_state`` for the batch of chunks, one
        entry along the first dimension per chunk.
    """
    max_length = tokenizer.model_max_length
    chunk_length = max_length - 2  # leave room for the BOS and EOS tokens
    id_start = tokenizer.bos_token_id
    id_end = tokenizer.eos_token_id
    id_pad = id_end

    def pad(x, p, i):
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
    # An empty prompt yields no tokens. max(..., 1) still emits one chunk, so
    # the encoder never receives an empty batch and encode_prompt_pair never
    # divides by a zero chunk count.
    chunks = [
        [id_start] + tokens[i: i + chunk_length] + [id_end]
        for i in range(0, max(len(tokens), 1), chunk_length)
    ]
    chunks = [pad(ck, id_pad, max_length) for ck in chunks]

    token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
    conds = text_encoder(token_ids).last_hidden_state
    return conds
@torch.inference_mode()
def encode_prompt_pair(positive_prompt, negative_prompt):
    """Encode a positive/negative prompt pair to embeddings of equal length.

    Both prompts are encoded chunk-wise by ``encode_prompt_inner``. The one
    with fewer chunks is repeated (and truncated) until both have the same
    chunk count. The chunks of each are then concatenated along dimension 1
    behind a leading dimension of size 1.

    Returns:
        ``(positive_embeds, negative_embeds)``.
    """
    cond_chunks = encode_prompt_inner(positive_prompt)
    uncond_chunks = encode_prompt_inner(negative_prompt)
    target_chunks = max(len(cond_chunks), len(uncond_chunks))

    def _tile_and_flatten(chunks):
        # Repeat whole chunk sets until at least `target_chunks`, then trim.
        repeats = int(math.ceil(float(target_chunks) / float(len(chunks))))
        tiled = torch.cat([chunks] * repeats, dim=0)[:target_chunks]
        # Lay the chunks end to end along the sequence dimension.
        return torch.cat([chunk[None, ...] for chunk in tiled], dim=1)

    return _tile_and_flatten(cond_chunks), _tile_and_flatten(uncond_chunks)
@torch.inference_mode()
def pytorch2numpy(imgs, quant=True):
    """Convert channel-first image tensors in [-1, 1] to channel-last arrays.

    Args:
        imgs: Iterable of tensors whose dimension 0 is moved to the end.
        quant: When true, map to 0..255 and return ``uint8`` arrays;
            otherwise map to 0..1 and keep floating point values.

    Returns:
        A list with one numpy array per input tensor, clipped to the target
        range.
    """
    if quant:
        scale, offset, upper = 127.5, 127.5, 255
    else:
        scale, offset, upper = 0.5, 0.5, 1

    converted = []
    for tensor in imgs:
        channel_last = tensor.movedim(0, -1) * scale + offset
        array = channel_last.detach().float().cpu().numpy().clip(0, upper)
        converted.append(array.astype(np.uint8) if quant else array)
    return converted
@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack channel-last images into a channel-first tensor in [-1, 1].

    This is the inverse of ``pytorch2numpy`` for both of its output formats:
    integer images (0..255) are mapped with ``x / 127.5 - 1`` and floating
    point images (0..1) with ``x * 2 - 1``. Previously the float formula was
    applied to ``uint8`` inputs as well, producing values far outside
    [-1, 1].

    Args:
        imgs: Sequence of equally shaped arrays with channels last.

    Returns:
        A float32 tensor with the stacked images; the last dimension of the
        stacked array is moved to position 1.
    """
    batch = np.stack(imgs, axis=0)
    h = torch.from_numpy(batch).float()
    if np.issubdtype(batch.dtype, np.integer):
        h = h / 127.5 - 1.0
    else:
        h = h * 2.0 - 1.0
    return h.movedim(-1, 1)
def resize_and_center_crop(image, target_width, target_height):
    """Scale an image to cover the target size, then crop it around the center.

    The image is resized with LANCZOS by the larger of the two axis ratios,
    so it fully covers ``target_width`` x ``target_height``. The centered
    window of that size is returned as a numpy array.
    """
    source = Image.fromarray(image)
    src_width, src_height = source.size
    scale = max(target_width / src_width, target_height / src_height)
    scaled_width = int(round(src_width * scale))
    scaled_height = int(round(src_height * scale))
    scaled = source.resize((scaled_width, scaled_height), Image.LANCZOS)
    # Crop box as (left, top, right, bottom), centered in the scaled image.
    crop_box = (
        (scaled_width - target_width) / 2,
        (scaled_height - target_height) / 2,
        (scaled_width + target_width) / 2,
        (scaled_height + target_height) / 2,
    )
    return np.array(scaled.crop(crop_box))
def resize_without_crop(image, target_width, target_height):
    """Resize an image array to exactly the target size (LANCZOS).

    The aspect ratio is not preserved. Returns a numpy array.
    """
    target_size = (target_width, target_height)
    return np.array(Image.fromarray(image).resize(target_size, Image.LANCZOS))
@torch.inference_mode()
def run_rmbg(img, sigma=0.0):
    """Matte the foreground of an RGB image with the RMBG model.

    The image is resized to roughly 256 * 64 * 64 pixels (each side a
    multiple of 64) before being fed to the model. The predicted alpha is
    bilinearly resized back to the input resolution and clipped to [0, 1].

    Args:
        img: Image array of shape (H, W, 3).
        sigma: Offset added to the pixel values before alpha blending.

    Returns:
        ``(result, alpha)``: ``result`` is the uint8 image blended towards
        mid-grey (127) where alpha is low; ``alpha`` is the matte as a float
        array with channels last.
    """
    height, width, channels = img.shape
    assert channels == 3
    scale = (256.0 / float(height * width)) ** 0.5
    feed_width = int(64 * round(width * scale))
    feed_height = int(64 * round(height * scale))
    feed = resize_without_crop(img, feed_width, feed_height)
    feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
    alpha = rmbg(feed)[0][0]
    alpha = torch.nn.functional.interpolate(alpha, size=(height, width), mode="bilinear")
    alpha = alpha.movedim(1, -1)[0]
    alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
    # Pull background pixels towards neutral grey in proportion to (1 - alpha).
    blended = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
    return blended.clip(0, 255).astype(np.uint8), alpha
def resize_to_match(image, target_width, target_height):
    """Resize an image array to exactly the target size (LANCZOS), no cropping.

    Same operation as ``resize_without_crop``; kept as a separate name.
    """
    return resize_without_crop(image, target_width, target_height)
def _build_background(bg_source, input_bg, width, height):
    """Return the background image selected by ``bg_source``.

    Uploaded backgrounds are returned as-is (optionally mirrored). The
    synthetic options produce a flat grey image or a linear 224 -> 32 grey
    ramp whose bright side is named by the option.
    """
    if bg_source == BGSource.UPLOAD:
        return input_bg
    if bg_source == BGSource.UPLOAD_FLIP:
        return np.fliplr(input_bg)
    if bg_source == BGSource.GREY:
        return np.zeros(shape=(height, width, 3), dtype=np.uint8) + 64

    horizontal_ramps = {BGSource.LEFT: (224, 32), BGSource.RIGHT: (32, 224)}
    vertical_ramps = {BGSource.TOP: (224, 32), BGSource.BOTTOM: (32, 224)}
    if bg_source in horizontal_ramps:
        start, stop = horizontal_ramps[bg_source]
        image = np.tile(np.linspace(start, stop, width), (height, 1))
    elif bg_source in vertical_ramps:
        start, stop = vertical_ramps[bg_source]
        image = np.tile(np.linspace(start, stop, height)[:, None], (1, width))
    else:
        # Raising a plain string is a TypeError in Python 3.
        raise ValueError('Wrong background source!')
    return np.stack((image,) * 3, axis=-1).astype(np.uint8)


def _encode_concat_conds(fg, bg):
    """VAE-encode the foreground and background and stack them on channels."""
    conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
    conds = vae.encode(conds).latent_dist.mode() * vae.config.scaling_factor
    return torch.cat([c[None, ...] for c in conds], dim=1)


@torch.inference_mode()
def process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
    """Relight a foreground against a background in two diffusion passes.

    A text-to-image pass conditioned on the encoded foreground/background is
    followed by an image-to-image refinement pass at ``highres_denoise``
    strength.

    Args:
        input_fg: Foreground image array (H, W, 3).
        input_bg: Background image array (H, W, 3). Its size, rounded down
            to a multiple of 64, defines the working resolution.
        prompt: User prompt; ``a_prompt`` is appended after a comma.
        image_width, image_height: Ignored; overwritten from ``input_bg``.
        num_samples: Number of images generated per prompt.
        seed: Seed for the torch generator.
        steps: Inference steps of the first pass.
        a_prompt: Text appended to the positive prompt.
        n_prompt: Negative prompt.
        cfg: Guidance scale.
        highres_scale: Currently unused.
            NOTE(review): no upscaling happens between the two passes -
            confirm whether that is intended.
        highres_denoise: Strength of the second pass; the second pass's step
            count is ``steps / highres_denoise``.
        bg_source: A ``BGSource`` value (string) selecting the background.

    Returns:
        ``(pixels, [fg, bg])``: ``pixels`` is a list of float arrays in
        [0, 1] with channels last; ``fg``/``bg`` are the resized
        conditioning images.

    Raises:
        ValueError: if ``bg_source`` is not a valid ``BGSource`` value.
    """
    clear_memory()
    bg_source = BGSource(bg_source)

    # Work at the background's resolution, rounded down to a multiple of 64
    # but never below 64, so tiny inputs do not produce a zero-sized target.
    image_height, image_width, _ = input_bg.shape
    image_width = max(64, (image_width // 64) * 64)
    image_height = max(64, (image_height // 64) * 64)

    input_bg = _build_background(bg_source, input_bg, image_width, image_height)

    rng = torch.Generator(device=device).manual_seed(int(seed))

    fg = resize_and_center_crop(input_fg, image_width, image_height)
    bg = resize_and_center_crop(input_bg, image_width, image_height)
    concat_conds = _encode_concat_conds(fg, bg)

    conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)

    # First pass: text-to-image, returning latents.
    latents = t2i_pipe(
        prompt_embeds=conds,
        negative_prompt_embeds=unconds,
        width=image_width,
        height=image_height,
        num_inference_steps=steps,
        num_images_per_prompt=num_samples,
        generator=rng,
        output_type='latent',
        guidance_scale=cfg,
        cross_attention_kwargs={'concat_conds': concat_conds},
    ).images.to(vae.dtype) / vae.config.scaling_factor

    pixels = vae.decode(latents).sample
    # Use quant=False to keep high-precision float32 images. vae.encode needs
    # a tensor, so the clipped float images are converted back instead of
    # being passed as a list of numpy arrays.
    pixels = pytorch2numpy(pixels, quant=False)
    pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
    latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
    latents = latents.to(device=unet.device, dtype=unet.dtype)

    # Re-derive the pixel size from the latents and rebuild the conditioning
    # so that it always matches the second pass's resolution.
    image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
    fg = resize_and_center_crop(input_fg, image_width, image_height)
    bg = resize_and_center_crop(input_bg, image_width, image_height)
    concat_conds = _encode_concat_conds(fg, bg)

    # Second pass: image-to-image refinement of the first result.
    latents = i2i_pipe(
        image=latents,
        strength=highres_denoise,
        prompt_embeds=conds,
        negative_prompt_embeds=unconds,
        width=image_width,
        height=image_height,
        num_inference_steps=int(round(steps / highres_denoise)),
        num_images_per_prompt=num_samples,
        generator=rng,
        output_type='latent',
        guidance_scale=cfg,
        cross_attention_kwargs={'concat_conds': concat_conds},
    ).images.to(vae.dtype) / vae.config.scaling_factor

    pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels, quant=False)
    clear_memory()
    return pixels, [fg, bg]
def save_images(images, prefix="relight"):
    """Write images as PNG files into the ``outputs`` directory.

    Files are named ``{prefix}_{YYYYmmdd_HHMMSS}_{n}.png`` with ``n``
    counting from 1. Numpy arrays are cast to ``uint8`` and converted to PIL
    images; any other object is saved through its own ``save`` method.

    Args:
        images: Iterable of numpy arrays or PIL-like images.
        prefix: Filename prefix.

    Returns:
        The list of written file paths, in input order.
    """
    target_dir = Path("outputs")
    target_dir.mkdir(exist_ok=True)

    # One timestamp per call; the running index keeps names unique within it.
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    written = []
    for index, image in enumerate(images, start=1):
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype(np.uint8))
        destination = target_dir / f"{prefix}_{stamp}_{index}.png"
        image.save(destination)
        written.append(destination)
    return written
@torch.inference_mode()
def process_relight(image_editor_output, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
    """Gradio handler: matte the foreground, relight it, save and return results.

    The foreground is read from editor layer 1 and the background from
    layer 0. The foreground is matted with ``run_rmbg``, then ``process``
    generates the relit images, which are converted to ``uint8`` and written
    to disk via ``save_images``.

    Returns:
        The list of relit images as ``uint8`` arrays.
    """
    # NOTE(review): this assumes each editor layer is a mapping with an
    # "image" entry holding an (H, W, 3) array - confirm against the value
    # that gr.ImageEditor(type="pil") actually delivers.
    input_fg = image_editor_output["layers"][1]["image"]
    input_bg = image_editor_output["layers"][0]["image"]
    input_fg, _matting = run_rmbg(input_fg)
    # The conditioning images returned alongside the results are not shown.
    results, _extra_images = process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
    # process() returns float images in [0, 1]; convert them for saving/display.
    results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
    # Save the generated images
    save_images(results, prefix="relight")
    # NOTE(review): a list is returned while the bound output component is a
    # single gr.Image - confirm the intended output component.
    return results
@torch.inference_mode()
def process_normal(image_editor_output, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
    """Gradio handler: estimate a normal map from four directional relightings.

    The foreground (editor layer 1) is matted and relit with light from the
    left, right, bottom and top, one sample each. ``num_samples`` and
    ``bg_source`` are not used. The horizontal and vertical differences of
    the relightings, relative to their mean, form the normal's x and y
    components. Outside the matte the normal is set to (0, 0, 1).

    Returns:
        A list of ``uint8`` images: the normal map, the four ambient-relative
        relightings, and the four raw relightings. The same images are
        written to disk via ``save_images``.
    """
    # Extract foreground and background images from the image editor
    input_fg = image_editor_output["layers"][1]["image"]
    input_bg = image_editor_output["layers"][0]["image"]
    input_fg, matting = run_rmbg(input_fg, sigma=16)

    def _relight_from(label, source):
        # One single-sample relighting with a synthetic light direction.
        print(label)
        return process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, source.value)[0][0]

    left = _relight_from('left ...', BGSource.LEFT)
    right = _relight_from('right ...', BGSource.RIGHT)
    bottom = _relight_from('bottom ...', BGSource.BOTTOM)
    top = _relight_from('top ...', BGSource.TOP)

    # Raw relightings, remapped from [0, 1] to [-1, 1] for the final output.
    inner_results = [shot * 2.0 - 1.0 for shot in (left, right, bottom, top)]

    ambient = (left + right + bottom + top) / 4.0
    height, width, _ = ambient.shape
    matting_u8 = (matting[..., 0] * 255.0).clip(0, 255).astype(np.uint8)
    matting = resize_and_center_crop(matting_u8, width, height).astype(np.float32)[..., None] / 255.0

    # Express each relighting relative to the ambient mean; the epsilon
    # guards against division by zero.
    eps = 1e-5
    left, right, bottom, top = [
        ((shot + eps) / (ambient + eps)) - 1.0
        for shot in (left, right, bottom, top)
    ]

    sigma = 10.0
    u = np.mean((right - left) * 0.5, axis=2)
    v = np.mean((top - bottom) * 0.5, axis=2)
    depth = (1.0 - u ** 2.0 - v ** 2.0).clip(0, 1e5) ** (0.5 * sigma)
    zeros = np.zeros_like(depth)

    normal = np.stack([u, v, depth], axis=2)
    normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
    # Blend towards the flat normal (0, 0, 1) where the matte is transparent.
    flat_normal = np.stack([zeros, zeros, 1 - zeros], axis=2)
    normal = normal * matting + flat_normal * (1 - matting)

    results = [normal, left, right, bottom, top] + inner_results
    results = [(x * 127.5 + 127.5).clip(0, 255).astype(np.uint8) for x in results]
    # Save the generated images
    save_images(results, prefix="normal")
    return results
# Prompt presets for the quick list. Each preset is wrapped in its own
# single-element list, i.e. one row with one column per sample.
quick_prompts = [
    [preset]
    for preset in (
        'beautiful woman',
        'handsome man',
        'beautiful woman, cinematic lighting',
        'handsome man, cinematic lighting',
        'beautiful woman, natural lighting',
        'handsome man, natural lighting',
        'beautiful woman, neo punk lighting, cyberpunk',
        'handsome man, neo punk lighting, cyberpunk',
    )
]
class BGSource(Enum):
    """Background options offered in the UI.

    The member values are the labels shown in the "Background Source" radio
    group. ``process`` converts the selected label back into a member.
    Member order determines the order of the radio choices.
    """
    # Use the uploaded background unchanged / mirrored left-to-right.
    UPLOAD = "Use Background Image"
    UPLOAD_FLIP = "Use Flipped Background Image"
    # Synthetic grey ramps; the name gives the bright side of the ramp.
    LEFT = "Left Light"
    RIGHT = "Right Light"
    TOP = "Top Light"
    BOTTOM = "Bottom Light"
    # Flat grey background.
    GREY = "Ambient"
# Gradio UI: inputs in the left column, outputs in the right, examples below.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
        # The emoji was mis-encoded (UTF-8 bytes decoded with a legacy
        # code page); restored to the intended floppy-disk character.
        gr.Markdown("💾 Generated images are automatically saved to 'outputs' folder")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_editor = gr.ImageEditor(label="Edit Images", type="pil")
            prompt = gr.Textbox(label="Prompt")
            bg_source = gr.Radio(choices=[e.value for e in BGSource],
                                 value=BGSource.UPLOAD.value,
                                 label="Background Source", type='value')
            example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt])
            bg_gallery = gr.Gallery(height=450, label='Background Quick List', value=db_examples.bg_samples, columns=5, allow_preview=False)
            relight_button = gr.Button(value="Relight")

            with gr.Group():
                with gr.Row():
                    num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                    seed = gr.Number(label="Seed", value=12345, precision=0)
                with gr.Row():
                    # NOTE: process() overrides both sizes with the
                    # background image's dimensions.
                    image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
                    image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)

            with gr.Accordion("Advanced options", open=False):
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
                highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
                highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='lowres, bad anatomy, bad hands, cropped, worst quality')
                normal_button = gr.Button(value="Compute Normal (4x Slower)")
        with gr.Column():
            # NOTE(review): both handlers return a list of images while this
            # is a single-image component - confirm the intended component.
            result_gallery = gr.Image(height=832, label='Outputs')
    with gr.Row():
        dummy_image_for_outputs = gr.Image(visible=False, label='Result')
        # Clicking an example shows its stored result (the last input) directly.
        gr.Examples(
            fn=lambda *args: [args[-1]],
            examples=db_examples.background_conditioned_examples,
            inputs=[
                image_editor, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
            ],
            outputs=[result_gallery],
            run_on_click=True, examples_per_page=1024
        )
    # Input order must match the parameter order of the two handlers.
    ips = [image_editor, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
    relight_button.click(fn=process_relight, inputs=ips, outputs=[result_gallery])
    normal_button.click(fn=process_normal, inputs=ips, outputs=[result_gallery])
    example_prompts.click(lambda x: x[0], inputs=example_prompts, outputs=prompt, show_progress=False, queue=False)

    def bg_gallery_selected(gal, evt: gr.SelectData):
        """Return the selected gallery entry's 'name' field for the editor."""
        # NOTE(review): assumes gallery entries are mappings with a 'name'
        # key - confirm against the installed Gradio version.
        return gal[evt.index]['name']

    bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=image_editor)

block.launch(server_name='0.0.0.0')