|
|
|
from enum import Enum, auto |
|
|
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image, ImageEnhance, ImageFilter |
|
import cv2 |
|
import numpy as np |
|
from refiners.fluxion.utils import load_from_safetensors, tensor_to_image |
|
from refiners.foundationals.clip import CLIPTextEncoderL |
|
from refiners.foundationals.latent_diffusion import SD1UNet |
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder |
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight |
|
|
|
|
|
def load_ic_light(device: torch.device, dtype: torch.dtype) -> ICLight:
    """Build an ICLight pipeline from weights hosted on the Hugging Face Hub.

    Four pinned repositories are downloaded (or read from the local hub
    cache): the IC-Light patch weights plus the UNet, CLIP text encoder and
    autoencoder published under ``refiners/sd15.realistic_vision.v5_1.*``.

    Args:
        device: Device the sub-models and the pipeline are created on.
        dtype: Data type the sub-models and the pipeline are created with.

    Returns:
        The assembled ICLight instance.
    """

    def _fetch(repo_id: str, revision: str) -> str:
        # Every repository used here stores its weights under the same file name.
        return hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors",
            revision=revision,
        )

    # Components are created in the order they are handed to ICLight.
    ic_light_patch = load_from_safetensors(
        path=_fetch("refiners/sd15.ic_light.fc", "ea10b4403e97c786a98afdcbdf0e0fec794ea542"),
    )
    unet = SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
        tensors_path=_fetch(
            "refiners/sd15.realistic_vision.v5_1.unet",
            "94f74be7adfd27bee330ea1071481c0254c29989",
        )
    )
    text_encoder = CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
        tensors_path=_fetch(
            "refiners/sd15.realistic_vision.v5_1.text_encoder",
            "7f6fa1e870c8f197d34488e14b89e63fb8d7fd6e",
        )
    )
    autoencoder = SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
        tensors_path=_fetch(
            "refiners/sd15.realistic_vision.v5_1.autoencoder",
            "99f089787a6e1a852a0992da1e286a19fcbbaa50",
        )
    )

    return ICLight(
        patch_weights=ic_light_patch,
        unet=unet,
        clip_text_encoder=text_encoder,
        lda=autoencoder,
        device=device,
        dtype=dtype,
    )
|
|
|
|
|
def resize_modulo_8(
    image: Image.Image,
    size: int = 768,
    resample: Image.Resampling | None = None,
    on_short: bool = True,
) -> Image.Image:
    """Resize an image so that both sides are multiples of 8.

    The reference side (the shorter one when ``on_short`` is true, the longer
    one otherwise) ends up exactly ``size`` pixels long. The other side is
    scaled by the same ratio and rounded down to a multiple of 8, but never
    below 8 pixels.

    Args:
        image: Image to resize.
        size: Target length of the reference side; must be a positive
            multiple of 8.
        resample: Resampling filter. ``None`` selects
            ``Image.Resampling.LANCZOS``; any explicit filter, including
            ``Image.Resampling.NEAREST``, is honoured.
        on_short: Use the shorter side as the reference when true, the longer
            side otherwise.

    Returns:
        The resized image.

    Raises:
        AssertionError: If ``size`` is not a multiple of 8.
        ValueError: If ``size`` is not positive.
    """
    assert size % 8 == 0, "Size must be a multiple of 8 because this is the latent compression size."
    if size <= 0:
        # Previously this only failed later, inside `Image.resize`, on an empty target size.
        raise ValueError("Size must be positive.")

    side_size = min(image.size) if on_short else max(image.size)
    divisor = side_size * 8

    def _scaled(length: int) -> int:
        # Exact integer arithmetic: the float form `int(length * (size / divisor))`
        # can land just below a whole number and drop a full 8-pixel step.
        # Clamp to one block so very thin images never collapse to 0 pixels.
        return max((length * size) // divisor, 1) * 8

    new_size = (_scaled(image.width), _scaled(image.height))

    # Compare with None explicitly: `Image.Resampling.NEAREST` equals 0 and is
    # falsy, so `resample or LANCZOS` silently replaced it with LANCZOS.
    if resample is None:
        resample = Image.Resampling.LANCZOS
    return image.resize(new_size, resample=resample)
|
|
|
|
|
def adjust_image(
    image: Image.Image,
    brightness: float = 0.0,
    contrast: float = 0.0,
    temperature: float = 0.0,
    saturation: float = 0.0,
    tint: float = 0.0,
    blur_intensity: float = 0,
    exposure: float = 0.0,
    vibrance: float = 0.0,
    color_mixer_blues: float = 0.0,
) -> Image.Image:
    """Apply a chain of tonal, colour and blur adjustments to an image.

    The image is first converted to RGB. Any adjustment whose argument is left
    at zero is skipped; the others run in this fixed order: exposure,
    brightness, contrast, vibrance, saturation, temperature, tint, blur, blue
    channel offset.

    Args:
        image: Source image.
        brightness: Brightness factor is ``1 + brightness / 5``, floored at 0.01.
        contrast: Contrast factor is ``1 + contrast / 100``, floored at 0.01.
        temperature: Factor ``1 + temperature / 100`` (floored at 0.01); the
            red channel is multiplied by it and the blue channel divided by it.
        saturation: Colour factor is ``1 + saturation / 100``, floored at 0.
        tint: Offset added to the HSV hue channel, wrapped modulo 180.
        blur_intensity: Gaussian blur radius; applied only when positive.
        exposure: Same formula as ``brightness``, applied before it.
        vibrance: Same formula as ``saturation``, applied before it.
        color_mixer_blues: Percentage of 255 added to the blue channel, with
            the result clipped to ``[0, 255]``.

    Returns:
        The adjusted RGB image.
    """
    image = image.convert('RGB')

    def _factor(amount, divisor, floor):
        # Turn a slider-style amount into a multiplicative factor with a lower bound.
        return max(1 + (amount / divisor), floor)

    # Exposure and brightness share the same enhancer and scaling.
    if exposure != 0.0:
        gain = _factor(exposure, 5.0, 0.01)
        image = ImageEnhance.Brightness(image).enhance(gain)

    if brightness != 0.0:
        gain = _factor(brightness, 5.0, 0.01)
        image = ImageEnhance.Brightness(image).enhance(gain)

    if contrast != 0.0:
        gain = _factor(contrast, 100.0, 0.01)
        image = ImageEnhance.Contrast(image).enhance(gain)

    # Vibrance and saturation both drive the colour enhancer; a factor of 0 is allowed.
    if vibrance != 0.0:
        gain = _factor(vibrance, 100.0, 0.0)
        image = ImageEnhance.Color(image).enhance(gain)

    if saturation != 0.0:
        gain = _factor(saturation, 100.0, 0.0)
        image = ImageEnhance.Color(image).enhance(gain)

    if temperature != 0.0:
        warmth = _factor(temperature, 100.0, 0.01)
        red, green, blue = image.split()
        # Scale red and blue in opposite directions; green is left untouched.
        red = red.point(lambda level: level * warmth)
        blue = blue.point(lambda level: level / warmth)
        image = Image.merge('RGB', (red, green, blue))

    if tint != 0.0:
        hsv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2HSV).astype(np.float32)
        hue = hsv[:, :, 0]  # view: writes below land in `hsv`
        hue[...] = (hue + tint) % 180
        # Fractional results in (179, 180) are pulled back before the uint8 cast.
        hue[...] = np.clip(hue, 0, 179)
        image = Image.fromarray(cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB))

    if blur_intensity > 0:
        image = image.filter(ImageFilter.GaussianBlur(radius=blur_intensity))

    if color_mixer_blues != 0.0:
        pixels = np.array(image).astype(np.float32)
        blue_offset = (color_mixer_blues / 100.0) * 255
        pixels[:, :, 2] = np.clip(pixels[:, :, 2] + blue_offset, 0, 255)
        image = Image.fromarray(pixels.astype(np.uint8))

    return image
|
|
|
|
|
class LightingPreference(str, Enum):
    """Side of the frame used to orient the initial lighting gradient."""

    LEFT = auto()
    RIGHT = auto()
    TOP = auto()
    BOTTOM = auto()
    NONE = auto()

    def get_init_image(
        self,
        width: int,
        height: int,
        interval: tuple[float, float] = (0.0, 1.0),
    ) -> Image.Image | None:
        """Create the gradient image that matches this lighting preference.

        LEFT and TOP ramp from ``end`` to ``start``; RIGHT and BOTTOM ramp
        from ``start`` to ``end``. LEFT/RIGHT vary along the axis of length
        ``width``, TOP/BOTTOM along the axis of length ``height``.

        Args:
            width: Number of gradient steps along the last tensor axis.
            height: Number of gradient steps along the second-to-last axis.
            interval: ``(start, end)`` values of the gradient.

        Returns:
            The gradient converted to an RGB image, or ``None`` for ``NONE``.
        """
        first, last = interval
        if self == LightingPreference.NONE:
            return None

        if self in (LightingPreference.LEFT, LightingPreference.TOP):
            ramp_from, ramp_to = last, first
        else:
            ramp_from, ramp_to = first, last

        if self in (LightingPreference.LEFT, LightingPreference.RIGHT):
            # One row of `width` values, tiled `height` times.
            gradient = torch.linspace(ramp_from, ramp_to, width).repeat(1, 1, height, 1)
        else:
            # Built as (1, 1, width, height), then the two last axes are swapped.
            gradient = torch.linspace(ramp_from, ramp_to, height).repeat(1, 1, width, 1).transpose(2, 3)

        return tensor_to_image(gradient).convert("RGB")

    @classmethod
    def from_str(cls, value: str) -> "LightingPreference":
        """Parse a case-insensitive member name such as ``"left"`` or ``"None"``.

        Args:
            value: Name of the preference.

        Returns:
            The matching member.

        Raises:
            ValueError: If ``value`` does not name a member.
        """
        key = value.lower()
        by_name = {member.name.lower(): member for member in cls}
        if key not in by_name:
            raise ValueError(f"Invalid lighting preference: {value}")
        return by_name[key]