# ProductPlacement / gradio_demo_bg.py
# Uploaded by Ashoka74 ("Upload 7 files", commit 7417c4f, verified).
import os
import math
import gradio as gr
import numpy as np
import torch
import safetensors.torch as sf
import db_examples
import datetime
from pathlib import Path
from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from briarmbg import BriaRMBG
from enum import Enum
from torch.hub import download_url_to_file
import cv2
from typing import Optional
from Depth.depth_anything_v2.dpt import DepthAnythingV2
# from FLORENCE
import spaces
import supervision as sv
import torch
from PIL import Image
from utils.florence import load_florence_model, run_florence_inference, \
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
import torch
# Device used by the Florence-2 / SAM object finder (process_image).
DEVICE = torch.device("cuda")
# Enter a bf16 autocast context for the rest of the process; it is never exited.
# NOTE(review): autocast state is thread-local, so this presumably only affects
# code running on the importing thread — confirm whether Gradio worker threads
# rely on it (process_image carries its own autocast decorator).
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Allow TF32 matmuls/convolutions on GPUs of compute capability 8.x or newer.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Load Florence-2 (open-vocabulary detection) and SAM (segmentation) once at import.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
@spaces.GPU(duration=20)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, text_input) -> Optional[Image.Image]:
    """Segment the object described by *text_input* and return its mask.

    Florence-2 open-vocabulary detection proposes regions for the text prompt,
    SAM refines them into masks, and the first mask is returned as an image
    with 255 for the object and 0 elsewhere.

    Args:
        image_input: the input image. The UI wires a ``gr.Image(type="numpy")``
            here, so NumPy arrays are converted to PIL before use: the Florence /
            SAM helpers and ``resolution_wh=image.size`` need a PIL image
            (on an ndarray ``.size`` is the element count, not (width, height)).
        text_input: free-text description of the target object.

    Returns:
        The mask as a PIL image, or None when input is missing or nothing is found.
    """
    # Use `is None` rather than truthiness: `not ndarray` raises for arrays.
    if image_input is None:
        gr.Info("Please upload an image.")
        return None
    if not text_input:
        gr.Info("Please enter a text prompt.")
        return None
    if isinstance(image_input, np.ndarray):
        image_input = Image.fromarray(image_input)

    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_input
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size
    )
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None
    return Image.fromarray(detections.mask[0].astype("uint8") * 255)
# Probe for xformers once at import; enable_efficient_attention() consults
# XFORMERS_AVAILABLE to decide which attention implementation to install.
try:
    import xformers
    import xformers.ops
except ImportError:
    XFORMERS_AVAILABLE = False
    print("xformers not available - Using default attention")
else:
    XFORMERS_AVAILABLE = True
    print("xformers is available - Using memory efficient attention")
# Base Stable Diffusion 1.5 checkpoint.
# Alternatives: 'stablediffusionapi/realistic-vision-v51', 'runwayml/stable-diffusion-v1-5'
sd15_name = 'stablediffusionapi/realistic-vision-v51'
# Load each SD component from its subfolder of the checkpoint, in this order.
tokenizer, text_encoder, vae, unet = (
    component_cls.from_pretrained(sd15_name, subfolder=subfolder)
    for component_cls, subfolder in (
        (CLIPTokenizer, "tokenizer"),
        (CLIPTextModel, "text_encoder"),
        (AutoencoderKL, "vae"),
        (UNet2DConditionModel, "unet"),
    )
)
# Background-removal (matting) network used by run_rmbg().
rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
# Change UNet: widen the input convolution from 4 to 12 channels so the UNet can
# take the noisy latent concatenated with the foreground and background
# condition latents (see hooked_unet_forward and process()).
# The new input channels start with zero weights, so until the IC-Light offsets
# are merged the widened conv computes the same output as the original.
with torch.no_grad():
    new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
    new_conv_in.weight.zero_()
    # Copy the pretrained weights into the first 4 input channels.
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
    # Reuse the original bias Parameter object as-is.
    new_conv_in.bias = unet.conv_in.bias
    unet.conv_in = new_conv_in
# Keep a handle to the real forward so the hook can delegate to it.
unet_original_forward = unet.forward


def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
    """UNet forward that appends the IC-Light condition latents as extra channels.

    The pipelines pass the condition latents through
    ``cross_attention_kwargs['concat_conds']``. They are cast to the sample's
    device/dtype and tiled along the batch axis to match ``sample`` (presumably
    to cover classifier-free guidance, which doubles the batch). They are then
    concatenated on the channel axis, and the wrapped forward is called with
    empty cross-attention kwargs.
    """
    extra = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
    repeats = sample.shape[0] // extra.shape[0]
    extra = torch.cat([extra] * repeats, dim=0)
    widened = torch.cat([sample, extra], dim=1)
    kwargs['cross_attention_kwargs'] = {}
    return unet_original_forward(widened, timestep, encoder_hidden_states, **kwargs)


unet.forward = hooked_unet_forward
# Load: fetch the IC-Light "fbc" (foreground + background conditioned) weight
# offsets on first run.
model_path = './models/iclight_sd15_fbc.safetensors'
if not os.path.exists(model_path):
    # download_url_to_file stages its temporary file inside the destination
    # directory, so that directory has to exist before the download starts.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fbc.safetensors', dst=model_path)
# Device and dtype setup: the SD components (text encoder, VAE, UNet) run in
# float16 on CUDA; the matting network stays in float32.
device = torch.device('cuda')
dtype = torch.float16  # RTX 2070 works well with float16
# Backend settings: let cuDNN autotune conv algorithms and allow TF32 math.
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# NOTE: `torch.backends.cuda.max_split_size_mb = 512` used to be set here. That
# attribute is not a PyTorch setting, so the assignment silently did nothing.
# The caching allocator's split size is configured through the environment
# instead, e.g. PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512, set before the
# process initialises CUDA.
# Move models to device with consistent dtype
text_encoder = text_encoder.to(device=device, dtype=dtype)
vae = vae.to(device=device, dtype=dtype)
unet = unet.to(device=device, dtype=dtype)
rmbg = rmbg.to(device=device, dtype=torch.float32)  # keep the matting net in float32
# Monocular depth estimator (Depth-Anything-V2, ViT-S encoder) used by get_depth().
model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location='cpu'))
# load_state_dict only copies values into the module's existing parameters (which
# are created on CPU); map_location never moved the module itself. Move it
# explicitly so it sits on the same device as the rest of the app.
# NOTE(review): upstream DepthAnythingV2.infer_image moves its input to CUDA when
# available — confirm the vendored copy under Depth/ does the same.
model = model.to(device).eval()
from utils.florence import load_florence_model, run_florence_inference, \
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
# Florence-2 and SAM are already loaded once near the top of this module on the
# same CUDA device, and the global bf16 autocast context and TF32 flags are set
# there as well. This spot used to re-enter autocast (stacking a second context
# that is never exited) and call load_sam_image_model() a second time. That
# doubled SAM's start-up cost and briefly held two copies of its weights while
# the old one was being replaced, so the already-loaded SAM_IMAGE_MODEL is reused.
# Fold the IC-Light weight offsets into the base UNet: every parameter becomes
# base + offset, with the offset cast to the UNet's device/dtype. A key that is
# missing from the offset file raises KeyError. The temporaries are deleted
# afterwards so they do not keep the tensors alive.
_iclight_offsets = sf.load_file(model_path)
_base_state = unet.state_dict()
_merged_state = {
    name: weight + _iclight_offsets[name].to(device=device, dtype=dtype)
    for name, weight in _base_state.items()
}
unet.load_state_dict(_merged_state, strict=True)
del _iclight_offsets, _base_state, _merged_state
def _use_sdpa_attention():
    """Install PyTorch 2 scaled-dot-product attention on the UNet and VAE."""
    unet.set_attn_processor(AttnProcessor2_0())
    vae.set_attn_processor(AttnProcessor2_0())


def enable_efficient_attention():
    """Switch the UNet and VAE to a memory-efficient attention implementation.

    Uses xformers when it was importable and diffusers accepts it. Otherwise
    falls back to ``AttnProcessor2_0`` (torch SDPA).

    The previous fallback called ``set_attention_slice_size``, which diffusers
    models do not provide, so it raised AttributeError whenever xformers was
    unavailable. Any slicing it intended would have been replaced by the
    ``set_attn_processor`` calls that followed anyway, so only those are kept.
    """
    if XFORMERS_AVAILABLE:
        try:
            unet.set_use_memory_efficient_attention_xformers(True)
            vae.set_use_memory_efficient_attention_xformers(True)
            print("Enabled xformers memory efficient attention")
            return
        except Exception as e:
            # Best-effort: any xformers failure falls through to SDPA below.
            print(f"Xformers error: {e}")
            print("Falling back to PyTorch SDPA attention (AttnProcessor2_0)")
    else:
        print("Using PyTorch SDPA attention (AttnProcessor2_0)")
    _use_sdpa_attention()
def clear_memory():
    """Return cached CUDA blocks to the driver, then wait for queued GPU work.

    Does nothing on machines without CUDA.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
# Install the memory-efficient attention backend before building the pipelines.
enable_efficient_attention()

# Samplers. All three span 1000 training timesteps over the same beta range
# (0.00085 .. 0.012). Only DDIM sets beta_schedule explicitly; the Euler-a and
# DPM++ schedulers keep the diffusers default for that argument.
ddim_scheduler = DDIMScheduler(
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
)
euler_a_scheduler = EulerAncestralDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    num_train_timesteps=1000,
    steps_offset=1,
)
# DPM++ 2M SDE with Karras sigmas: the sampler both pipelines below use.
dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
    algorithm_type="sde-dpmsolver++",
    use_karras_sigmas=True,
    beta_start=0.00085,
    beta_end=0.012,
    num_train_timesteps=1000,
    steps_offset=1,
)
# Pipelines: text-to-image for the base pass and image-to-image for the
# high-res refinement pass. Both share the same patched UNet, VAE, text encoder
# and tokenizer, use the DPM++ 2M SDE Karras sampler, and run without a
# safety checker, feature extractor or image encoder.
_shared_pipe_parts = dict(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=dpmpp_2m_sde_karras_scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    feature_extractor=None,
    image_encoder=None,
)
t2i_pipe = StableDiffusionPipeline(**_shared_pipe_parts)
i2i_pipe = StableDiffusionImg2ImgPipeline(**_shared_pipe_parts)
del _shared_pipe_parts
@torch.inference_mode()
def encode_prompt_inner(txt: str):
max_length = tokenizer.model_max_length
chunk_length = tokenizer.model_max_length - 2
id_start = tokenizer.bos_token_id
id_end = tokenizer.eos_token_id
id_pad = id_end
def pad(x, p, i):
return x[:i] if len(x) >= i else x + [p] * (i - len(x))
tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
chunks = [pad(ck, id_pad, max_length) for ck in chunks]
token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
conds = text_encoder(token_ids).last_hidden_state
return conds
@torch.inference_mode()
def encode_prompt_pair(positive_prompt, negative_prompt):
    """Encode a positive/negative prompt pair into equal-length embeddings.

    The shorter side's chunks are repeated cyclically until both sides have the
    same number of chunks. Each side's chunks are then joined along the
    sequence axis into a single batch-1 embedding.

    Returns:
        (c, uc): positive and negative embeddings of identical shape.
    """
    cond = encode_prompt_inner(positive_prompt)
    uncond = encode_prompt_inner(negative_prompt)
    n_chunks = max(len(cond), len(uncond))

    def _tile_and_join(embeds):
        reps = int(math.ceil(float(n_chunks) / float(len(embeds))))
        tiled = torch.cat([embeds] * reps, dim=0)[:n_chunks]
        return torch.cat([chunk[None, ...] for chunk in tiled], dim=1)

    return _tile_and_join(cond), _tile_and_join(uncond)
@torch.inference_mode()
def pytorch2numpy(imgs, quant=True):
    """Convert CHW tensors in [-1, 1] to a list of HWC NumPy images.

    Args:
        imgs: iterable of CHW tensors (e.g. a decoded NCHW batch).
        quant: if True, return uint8 images in 0..255; otherwise float32 in 0..1.
    """
    if quant:
        scale, offset, upper, out_dtype = 127.5, 127.5, 255, np.uint8
    else:
        scale, offset, upper, out_dtype = 0.5, 0.5, 1, np.float32
    return [
        (img.movedim(0, -1) * scale + offset)
        .detach().float().cpu().numpy()
        .clip(0, upper).astype(out_dtype)
        for img in imgs
    ]
@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack HWC images into an NCHW float tensor scaled to roughly [-1, 1].

    Dividing by 127 (not 127.5) maps the value 127 exactly to 0.0.
    """
    batch = np.stack(imgs, axis=0)
    tensor = torch.from_numpy(batch).float()
    return (tensor / 127.0 - 1.0).movedim(-1, 1)
def resize_and_center_crop(image, target_width, target_height):
    """Scale *image* to cover the target size (Lanczos), then center-crop to it.

    The aspect ratio is preserved: the image is enlarged/shrunk just enough that
    both sides reach the target, and the overflow is trimmed equally from both
    ends.
    """
    src = Image.fromarray(image)
    src_w, src_h = src.size
    k = max(target_width / src_w, target_height / src_h)
    new_w = int(round(src_w * k))
    new_h = int(round(src_h * k))
    scaled = src.resize((new_w, new_h), Image.LANCZOS)
    box = (
        (new_w - target_width) / 2,
        (new_h - target_height) / 2,
        (new_w + target_width) / 2,
        (new_h + target_height) / 2,
    )
    return np.array(scaled.crop(box))
def resize_without_crop(image, target_width, target_height):
    """Resize an image array to exactly (target_width, target_height) with Lanczos.

    The aspect ratio is not preserved.
    """
    return np.array(Image.fromarray(image).resize((target_width, target_height), Image.LANCZOS))
@torch.inference_mode()
def run_rmbg(img, sigma=0.0):
    """Matte the foreground of *img* with BriaRMBG.

    Args:
        img: uint8 image with 3 (RGB) or 4 (RGBA) channels. RGBA input is first
            flattened onto a white background using its own alpha.
        sigma: offset added to the pixel values before they are scaled by the
            matte (process_normal passes 16).

    Returns:
        (composite, rgba): the image blended toward grey 127 where the matte is
        transparent, and the RGB image with the matte (0..255) as alpha channel,
        both uint8.
    """
    if img.shape[-1] == 4:
        alpha_in = img[..., 3:] / 255.0
        white = np.ones_like(img[..., :3]) * 255
        img = (img[..., :3] * alpha_in + white * (1 - alpha_in)).astype(np.uint8)

    H, W, C = img.shape
    assert C == 3
    # Feed the network a ~1024px-per-side (64-aligned) version of the image.
    k = (256.0 / float(H * W)) ** 0.5
    feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
    feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)

    matte = rmbg(feed)[0][0]
    matte = torch.nn.functional.interpolate(matte, size=(H, W), mode="bilinear")
    matte = matte.movedim(1, -1)[0].detach().float().cpu().numpy().clip(0, 1)

    rgba = np.dstack((img, matte * 255)).astype(np.uint8)
    composite = 127 + (img.astype(np.float32) - 127 + sigma) * matte
    return composite.clip(0, 255).astype(np.uint8), rgba
def _make_background(bg_source, input_bg, image_width, image_height):
    """Return the background image to condition on for *bg_source*.

    UPLOAD / UPLOAD_FLIP use the uploaded *input_bg* (optionally mirrored
    left-right). GREY is a flat value-64 image. LEFT/RIGHT/TOP/BOTTOM are
    synthetic brightness ramps (224 on the lit side, 32 on the opposite side)
    of size image_height x image_width.

    Raises:
        ValueError: for a source this function does not handle.
    """
    if bg_source == BGSource.UPLOAD:
        return input_bg
    if bg_source == BGSource.UPLOAD_FLIP:
        return np.fliplr(input_bg)
    if bg_source == BGSource.GREY:
        return np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
    if bg_source in (BGSource.LEFT, BGSource.RIGHT):
        start, stop = (224, 32) if bg_source == BGSource.LEFT else (32, 224)
        ramp = np.tile(np.linspace(start, stop, image_width), (image_height, 1))
    elif bg_source in (BGSource.TOP, BGSource.BOTTOM):
        start, stop = (224, 32) if bg_source == BGSource.TOP else (32, 224)
        ramp = np.tile(np.linspace(start, stop, image_height)[:, None], (1, image_width))
    else:
        # Must be a real exception type: `raise '<str>'` itself fails with TypeError.
        raise ValueError(f"Wrong background source: {bg_source!r}")
    return np.stack((ramp,) * 3, axis=-1).astype(np.uint8)


def _encode_concat_conds(input_fg, input_bg, image_width, image_height):
    """Resize fg/bg to the target size and VAE-encode them as UNet conditions.

    Returns:
        (fg, bg, concat_conds): the resized uint8 images and their scaled latents
        concatenated along the channel axis (fg first).
    """
    fg = resize_and_center_crop(input_fg, image_width, image_height)
    bg = resize_and_center_crop(input_bg, image_width, image_height)
    concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
    concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
    concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
    return fg, bg, concat_conds


@torch.inference_mode()
def process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
    """Relight *input_fg* against a background with IC-Light in two passes.

    A text-to-image pass at image_width x image_height is followed by an
    image-to-image refinement at about ``highres_scale`` times that size
    (rounded to multiples of 64). Both passes condition the patched UNet on the
    VAE latents of the resized foreground and background.

    Args:
        bg_source: a BGSource value (label string) or member.

    Returns:
        (pixels, [fg, bg]): float32 HWC images in [0, 1], one per sample, and the
        uint8 fg/bg conditioning images at the refinement resolution.

    Raises:
        ValueError: if *bg_source* is not a BGSource value.
    """
    clear_memory()
    # Gradio sliders/number boxes may deliver floats; shapes, batch sizes and
    # generator seeds need ints.
    image_width, image_height = int(image_width), int(image_height)
    num_samples, seed, steps = int(num_samples), int(seed), int(steps)

    bg_source = BGSource(bg_source)
    input_bg = _make_background(bg_source, input_bg, image_width, image_height)

    rng = torch.Generator(device=device).manual_seed(seed)

    # --- Base pass (text-to-image) ---
    fg, bg, concat_conds = _encode_concat_conds(input_fg, input_bg, image_width, image_height)
    conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
    latents = t2i_pipe(
        prompt_embeds=conds,
        negative_prompt_embeds=unconds,
        width=image_width,
        height=image_height,
        num_inference_steps=steps,
        num_images_per_prompt=num_samples,
        generator=rng,
        output_type='latent',
        guidance_scale=cfg,
        cross_attention_kwargs={'concat_conds': concat_conds},
    ).images.to(vae.dtype) / vae.config.scaling_factor
    pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels)

    # --- High-res refinement pass (image-to-image) ---
    highres_width = int(round(image_width * highres_scale / 64.0) * 64)
    highres_height = int(round(image_height * highres_scale / 64.0) * 64)
    pixels = [resize_without_crop(image=p, target_width=highres_width, target_height=highres_height)
              for p in pixels]
    pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
    latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
    latents = latents.to(device=unet.device, dtype=unet.dtype)
    # The VAE downsamples by 8; derive the exact refinement size from the latents.
    image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
    fg, bg, concat_conds = _encode_concat_conds(input_fg, input_bg, image_width, image_height)
    latents = i2i_pipe(
        image=latents,
        strength=highres_denoise,
        prompt_embeds=conds,
        negative_prompt_embeds=unconds,
        width=image_width,
        height=image_height,
        # img2img only runs `strength` of the schedule; scale so ~`steps` are executed.
        num_inference_steps=int(round(steps / highres_denoise)),
        num_images_per_prompt=num_samples,
        generator=rng,
        output_type='latent',
        guidance_scale=cfg,
        cross_attention_kwargs={'concat_conds': concat_conds},
    ).images.to(vae.dtype) / vae.config.scaling_factor
    pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels, quant=False)
    clear_memory()
    return pixels, [fg, bg]
def save_images(images, prefix="relight", output_dir="outputs"):
    """Save *images* as PNG files and return the written paths.

    Args:
        images: NumPy arrays (converted with ``Image.fromarray``) or objects with
            a ``save(path)`` method such as PIL images.
        prefix: filename prefix.
        output_dir: target directory; created (including parents) if missing.

    Returns:
        list of ``Path``: one per image, in input order. Files are named
        ``{prefix}_{YYYYmmdd_HHMMSS}_{i}.png`` with a 1-based index, so a second
        call with the same prefix within the same second overwrites the first.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    saved_paths = []
    for i, img in enumerate(images, start=1):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        filepath = out_dir / f"{prefix}_{timestamp}_{i}.png"
        img.save(filepath)
        # Previously never appended, so callers always received [].
        saved_paths.append(filepath)
    return saved_paths
@torch.inference_mode()
def process_relight(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
    """Matte the foreground, relight it with process() and save the results.

    Returns:
        list of uint8 HWC images, one per sample. The fg/bg conditioning images
        that process() also returns are neither returned nor saved.
    """
    input_fg, _matting = run_rmbg(input_fg)
    results, _conditioning_images = process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
    results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
    save_images(results, prefix="relight")
    return results
@torch.inference_mode()
def process_normal(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
    """Estimate a normal map by relighting the foreground from four directions.

    The foreground is relit once each with left, right, bottom and top ramp
    lighting (one sample each; *num_samples* and *bg_source* are ignored). The
    horizontal/vertical brightness differences relative to their mean give the
    x/y normal components. Outside the foreground matte the normal is (0, 0, 1).

    Returns:
        list of 9 uint8 images, also saved with prefix "normal":
        [normal, left, right, bottom, top, left_raw, right_raw, bottom_raw, top_raw],
        where the first five are ratio/normal maps encoded as v * 127.5 + 127.5.
    """
    input_fg, matting = run_rmbg(input_fg, sigma=16)

    def relight_from(direction):
        return process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps,
                       a_prompt, n_prompt, cfg, highres_scale, highres_denoise, direction.value)[0][0]

    print('left ...')
    left = relight_from(BGSource.LEFT)
    print('right ...')
    right = relight_from(BGSource.RIGHT)
    print('bottom ...')
    bottom = relight_from(BGSource.BOTTOM)
    print('top ...')
    top = relight_from(BGSource.TOP)

    inner_results = [left * 2.0 - 1.0, right * 2.0 - 1.0, bottom * 2.0 - 1.0, top * 2.0 - 1.0]
    ambient = (left + right + bottom + top) / 4.0
    h, w, _ = ambient.shape
    # run_rmbg returns an RGBA uint8 image as its second value; the matte is its
    # alpha channel, already in 0..255. (Reading channel 0 and multiplying by 255
    # used the red channel, saturated to 255, instead of the matte.)
    matting = resize_and_center_crop(matting[..., 3], w, h).astype(np.float32)[..., None] / 255.0

    def safe_divide(a, b):
        # Relative brightness change versus ambient; epsilon avoids 0/0.
        e = 1e-5
        return ((a + e) / (b + e)) - 1.0

    left = safe_divide(left, ambient)
    right = safe_divide(right, ambient)
    bottom = safe_divide(bottom, ambient)
    top = safe_divide(top, ambient)

    u = np.mean((right - left) * 0.5, axis=2)
    v = np.mean((top - bottom) * 0.5, axis=2)
    sigma = 10.0
    nz = (1.0 - u ** 2.0 - v ** 2.0).clip(0, 1e5) ** (0.5 * sigma)
    z = np.zeros_like(nz)
    normal = np.stack([u, v, nz], axis=2)
    normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
    normal = normal * matting + np.stack([z, z, 1 - z], axis=2) * (1 - matting)

    results = [normal, left, right, bottom, top] + inner_results
    results = [(x * 127.5 + 127.5).clip(0, 255).astype(np.uint8) for x in results]
    save_images(results, prefix="normal")
    return results
# Prompt presets for the "Prompt Quick List" dataset. gr.Dataset expects one
# list of component values per sample, hence the single-element rows.
quick_prompts = [[text] for text in (
    'modern sofa in living room',
    'elegant dining table with chairs',
    'luxurious bed in bedroom, cinematic lighting',
    'minimalist office desk, natural lighting',
    'vintage wooden cabinet, warm lighting',
    'contemporary bookshelf, ambient lighting',
    'designer armchair, dramatic lighting',
    'modern kitchen island, bright lighting',
)]
class BGSource(Enum):
    """Where the background/lighting condition for IC-Light comes from.

    Values are the exact labels shown in the "Background Source" radio, and
    process() turns the selected label back into a member with ``BGSource(label)``.
    Member order is the order of the radio choices.
    """
    # The uploaded background image, as-is.
    UPLOAD = "Use Background Image"
    # The uploaded background image mirrored left-to-right.
    UPLOAD_FLIP = "Use Flipped Background Image"
    # Synthetic brightness ramps, brightest on the named side.
    LEFT = "Left Light"
    RIGHT = "Right Light"
    TOP = "Top Light"
    BOTTOM = "Bottom Light"
    # Flat grey background (pixel value 64).
    GREY = "Ambient"
class MaskMover:
    """Holds the extracted foreground cut-out and pastes it onto backgrounds."""

    def __init__(self):
        # Cut-out as last given to set_extracted_fg (None until a foreground is extracted).
        self.extracted_fg = None
        # Untouched copy used as the source for every composite.
        self.original_fg = None

    def set_extracted_fg(self, fg_image):
        """Store copies of *fg_image* (the RGBA cut-out) and return it unchanged."""
        self.extracted_fg = fg_image.copy()
        self.original_fg = fg_image.copy()
        return fg_image

    def create_composite(self, background, x_pos, y_pos, scale=1.0):
        """Paste the stored foreground onto *background*, centred at (x_pos, y_pos).

        Args:
            background: NumPy array or PIL image.
            x_pos, y_pos: centre of the pasted foreground in background pixels.
            scale: resize factor applied to the foreground first.

        Returns:
            The composite as a NumPy array. *background* is returned unchanged
            when there is no foreground yet or no background.
        """
        if self.original_fg is None or background is None:
            return background

        bg = Image.fromarray(background) if isinstance(background, np.ndarray) else background
        fg = Image.fromarray(self.original_fg) if isinstance(self.original_fg, np.ndarray) else self.original_fg

        # Clamp to at least 1 px: PIL refuses to resize to a zero-sized image,
        # which small scale factors on small cut-outs would otherwise request.
        new_width = max(1, int(fg.width * scale))
        new_height = max(1, int(fg.height * scale))
        fg = fg.resize((new_width, new_height), Image.LANCZOS)

        # Top-left corner that centres the scaled foreground on the position.
        x = int(x_pos - new_width / 2)
        y = int(y_pos - new_height / 2)

        result = bg.copy()
        if fg.mode == 'RGBA':
            # Use the alpha channel as the paste mask.
            result.paste(fg, (x, y), fg.split()[3])
        else:
            result.paste(fg, (x, y))
        return np.array(result)
def get_depth(image):
    """Estimate depth for *image* and return it colour-mapped for display.

    Args:
        image: RGB background image (array or PIL), or None.

    Returns:
        A PIL image of the depth map normalised to 0..255 and coloured with
        OpenCV's INFERNO map, or None when no image is given. A constant depth
        map (max == min) is rendered as all zeros instead of dividing 0 by 0,
        which produced NaNs and undefined uint8 values.
    """
    if image is None:
        return None
    # The depth model is fed a BGR array, converted here from the RGB input.
    raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    depth = model.infer_image(raw_img)  # HxW raw depth map
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth = ((depth - depth_min) / depth_range * 255).astype(np.uint8)
    else:
        depth = np.zeros(depth.shape, dtype=np.uint8)
    depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
    depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
    return Image.fromarray(depth_colored)
# def find_objects(image_input):
# detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
# if len(detections) == 0:
# gr.Info("No objects detected.")
# return None
# return Image.fromarray(detections.mask[0].astype("uint8") * 255)
block = gr.Blocks().queue()
with block:
    # Holds the extracted RGBA foreground shared by the handlers below.
    mask_mover = MaskMover()
    with gr.Row():
        gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
        gr.Markdown("💾 Generated images are automatically saved to 'outputs' folder")
    with gr.Row():
        with gr.Column():
            # Step 1: Input and Extract
            with gr.Group():
                gr.Markdown("### Step 1: Extract Foreground")
                input_image = gr.Image(type="numpy", label="Input Image", height=480)
                input_text = gr.Textbox(label="Describe target object")
                find_objects_button = gr.Button(value="Find Objects")
                extract_button = gr.Button(value="Remove Background")
                extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
            # Step 2: Background and Position
            with gr.Group():
                gr.Markdown("### Step 2: Position on Background")
                input_bg = gr.Image(type="numpy", label="Background Image", height=480)
                with gr.Row():
                    x_slider = gr.Slider(
                        minimum=0,
                        maximum=1000,
                        label="X Position",
                        value=500,
                        visible=False
                    )
                    y_slider = gr.Slider(
                        minimum=0,
                        maximum=1000,
                        label="Y Position",
                        value=500,
                        visible=False
                    )
                fg_scale_slider = gr.Slider(
                    label="Foreground Scale",
                    minimum=0.01,
                    maximum=3.0,
                    value=1.0,
                    step=0.01
                )
                get_depth_button = gr.Button(value="Get Depth")
                depth_image = gr.Image(type="numpy", label="Depth Image", height=480)
                # Hidden and not wired to any event at the moment.
                editor = gr.ImageEditor(
                    type="numpy",
                    label="Position Foreground",
                    height=480,
                    visible=False
                )
            # Step 3: Relighting Options
            with gr.Group():
                gr.Markdown("### Step 3: Relighting Settings")
                prompt = gr.Textbox(label="Prompt")
                bg_source = gr.Radio(
                    choices=[e.value for e in BGSource],
                    value=BGSource.UPLOAD.value,
                    label="Background Source",
                    type='value'
                )
                example_prompts = gr.Dataset(
                    samples=quick_prompts,
                    label='Prompt Quick List',
                    components=[prompt]
                )
                relight_button = gr.Button(value="Relight")
            # Additional settings
            with gr.Group():
                with gr.Row():
                    num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                    seed = gr.Number(label="Seed", value=12345, precision=0)
                with gr.Row():
                    image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
                    image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
                with gr.Accordion("Advanced options", open=False):
                    steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
                    highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
                    highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
                    a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value='lowres, bad anatomy, bad hands, cropped, worst quality'
                    )
                normal_button = gr.Button(value="Compute Normal (4x Slower)")
        with gr.Column():
            result_gallery = gr.Image(height=832, label='Outputs')

    # ---------------- Event handlers ----------------
    def extract_foreground(image):
        """Matte *image*, remember its RGBA cut-out and reveal the position sliders."""
        if image is None:
            return None, gr.update(visible=True), gr.update(visible=True)
        result, rgba = run_rmbg(image)
        mask_mover.set_extracted_fg(rgba)
        return result, gr.update(visible=True), gr.update(visible=True)

    # Clean copy of the background; input_bg itself is overwritten with the
    # composite preview by update_position.
    original_bg = None

    extract_button.click(
        fn=extract_foreground,
        inputs=[input_image],
        outputs=[extracted_fg, x_slider, y_slider]
    )
    find_objects_button.click(
        fn=process_image,
        inputs=[input_image, input_text],
        outputs=[extracted_fg]
    )
    get_depth_button.click(
        fn=get_depth,
        inputs=[input_bg],
        outputs=[depth_image]
    )

    def reset_background(background):
        """Remember a freshly uploaded background as the clean compositing base."""
        global original_bg
        original_bg = None if background is None else background.copy()

    # Without this, the first background ever seen was cached forever and later
    # uploads were ignored by the position preview.
    input_bg.upload(fn=reset_background, inputs=[input_bg], outputs=None)

    def update_position(background, x_pos, y_pos, scale):
        """Re-render the composite preview when position or scale changes."""
        global original_bg
        if background is None:
            return None
        if original_bg is None:
            original_bg = background.copy()
        return mask_mover.create_composite(original_bg, float(x_pos), float(y_pos), float(scale))

    for _slider in (x_slider, y_slider, fg_scale_slider):
        _slider.change(
            fn=update_position,
            inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
            outputs=[input_bg]
        )

    # Inputs for both position-aware buttons: the 14 process_relight /
    # process_normal arguments (slot 0 is a placeholder for the foreground and is
    # replaced by the crop composite), followed by x, y and scale.
    ips = [input_bg, input_bg, prompt, image_width, image_height, num_samples,
           seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise,
           bg_source, x_slider, y_slider, fg_scale_slider]

    def _run_on_crop(args, render):
        """Run *render* on the background region around the positioned foreground.

        The crop is the scaled foreground plus 20% context on each side, rounded
        to multiples of 8 and clamped to the background. *render* receives the
        forwarded argument list (foreground, background and size slots replaced
        by the crop) and must return one uint8 image. That image is resized to
        the crop if needed and pasted into a copy of the background.
        """
        min_side = 64  # smallest crop side sent to the diffusion model
        if mask_mover.extracted_fg is None:
            gr.Warning("Please extract foreground first")
            return None
        if args[1] is None:
            gr.Warning("Please upload a background image first")
            return None
        # Prefer the clean background: input_bg may hold the preview that
        # already contains the pasted foreground.
        background = original_bg if original_bg is not None else args[1]
        x_pos, y_pos, scale = (float(v) for v in args[-3:])

        # Size of the scaled foreground.
        fg = Image.fromarray(mask_mover.original_fg)
        new_width = int(fg.width * scale)
        new_height = int(fg.height * scale)

        # Crop box around the foreground with padding, rounded up to multiples of 8.
        padding = 0.2
        crop_x = int(int(x_pos - new_width / 2) - new_width * padding)
        crop_y = int(int(y_pos - new_height / 2) - new_height * padding)
        crop_width = ((int(new_width * (1 + 2 * padding)) + 7) // 8) * 8
        crop_height = ((int(new_height * (1 + 2 * padding)) + 7) // 8) * 8

        # Keep the box inside the background, then shrink to multiples of 8.
        bg_height, bg_width = background.shape[:2]
        crop_x = max(0, min(crop_x, bg_width - crop_width))
        crop_y = max(0, min(crop_y, bg_height - crop_height))
        crop_width = (min(crop_width, bg_width - crop_x) // 8) * 8
        crop_height = (min(crop_height, bg_height - crop_y) // 8) * 8
        if crop_width < min_side or crop_height < min_side:
            gr.Warning("The region around the foreground is too small to process")
            return None

        crop_region = background[crop_y:crop_y + crop_height, crop_x:crop_x + crop_width]
        # The foreground centre (x_pos, y_pos) maps to (x_pos - crop_x,
        # y_pos - crop_y) in crop coordinates, including after clamping.
        cropped_composite = mask_mover.create_composite(crop_region, x_pos - crop_x, y_pos - crop_y, scale)

        crop_args = list(args[:-3])  # drop position and scale
        crop_args[0] = cropped_composite
        crop_args[1] = crop_region
        crop_args[3] = crop_width
        crop_args[4] = crop_height

        rendered = render(crop_args)
        # The high-res pass changes the output size; fit it back into the crop.
        if rendered.shape[:2] != (crop_height, crop_width):
            rendered = resize_without_crop(rendered, crop_width, crop_height)
        result = background.copy()
        result[crop_y:crop_y + crop_height, crop_x:crop_x + crop_width] = rendered
        return result

    def process_relight_with_position(*args):
        """Relight the region around the positioned foreground (first sample)."""
        return _run_on_crop(args, lambda crop_args: process_relight(*crop_args)[0])

    relight_button.click(
        fn=process_relight_with_position,
        inputs=ips,
        outputs=[result_gallery]
    )

    def process_normal_with_position(*args):
        """Compute the normal map for the positioned foreground and paste it back."""
        # process_normal returns [normal, left, right, bottom, top, ...]; only the
        # normal map is pasted into the background. The previous version tried to
        # assign the whole list into an array slice and tested an ndarray's
        # truthiness, which always raised.
        return _run_on_crop(args, lambda crop_args: process_normal(*crop_args)[0])

    normal_button.click(
        fn=process_normal_with_position,
        inputs=ips,
        outputs=[result_gallery]
    )
    example_prompts.click(
        fn=lambda x: x[0],
        inputs=example_prompts,
        outputs=prompt,
        show_progress=False,
        queue=False
    )
block.launch(server_name='0.0.0.0')