Deadmon's picture
Update app.py
756128e verified
raw
history blame
10.6 kB
import os
import random
import spaces
import gradio as gr
import numpy as np
import PIL.Image
import torch
import torchvision.transforms.functional as TF
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
from controlnet_aux import PidiNetDetector, HEDdetector
from diffusers.utils import load_image
from huggingface_hub import HfApi, snapshot_download
from pathlib import Path
from PIL import Image, ImageOps
import cv2
from gradio_imageslider import ImageSlider
# JavaScript handed to gr.Blocks(js=...) when the UI is built.
# As written, refresh() only constructs a URL object from window.location and
# discards it, so running it has no visible effect.
# NOTE(review): this looks like a truncated snippet (e.g. one that forces a UI
# theme via a query parameter) — confirm the intended behavior.
js_func = """
function refresh() {
const url = new URL(window.location);
}
"""
def nms(x, t, s):
    """Thin an edge map with directional non-maximum suppression, then binarise it.

    Args:
        x: edge-strength array; it is converted to float32 before smoothing.
        t: threshold; surviving values strictly greater than ``t`` become 255.
        s: sigma of the Gaussian blur applied first (OpenCV derives the kernel
            size from it because ``ksize`` is ``(0, 0)``).

    Returns:
        A uint8 array shaped like the blurred map, containing only 0 and 255.
    """
    smoothed = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
    # 3x3 line-shaped structuring elements: horizontal, vertical,
    # main diagonal and anti-diagonal.
    directions = (
        np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8),
        np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8),
        np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8),
        np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8),
    )
    kept = np.zeros_like(smoothed)
    for kernel in directions:
        # A pixel equal to its dilation along a line is a maximum on that line;
        # copy its smoothed value into the result.
        local_max = cv2.dilate(smoothed, kernel=kernel) == smoothed
        np.putmask(kept, local_max, smoothed)
    return np.where(kept > t, 255, 0).astype(np.uint8)
def HWC3(x):
    """Coerce a uint8 image array into 3-channel height x width x channel layout.

    * 2-D (grayscale) input gets a channel axis and is replicated to 3 channels.
    * 3-channel input is returned unchanged (the same object).
    * 4-channel input is alpha-composited onto a white background.

    Invalid dtype, rank, or channel count trips an ``assert``.
    """
    assert x.dtype == np.uint8
    img = x[:, :, None] if x.ndim == 2 else x
    assert img.ndim == 3
    channels = img.shape[2]
    assert channels in (1, 3, 4)

    if channels == 3:
        return img
    if channels == 1:
        return np.repeat(img, 3, axis=2)

    # channels == 4: blend colour over white using the normalised alpha.
    rgb = img[..., :3].astype(np.float32)
    alpha = img[..., 3:4].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
# Markdown rendered at the top of the UI (currently empty).
DESCRIPTION = ""

# Prompt style presets. Each entry's "prompt" is a template containing the
# "{prompt}" placeholder; its "negative_prompt" is combined with the user's
# negative prompt by apply_style(). The "(No style)" entry must be present:
# DEFAULT_STYLE_NAME is the Style dropdown's default value and the fallback
# used by apply_style() for unknown names, so an empty list makes every
# generation fail with KeyError.
# NOTE(review): any additional presets were elided from this file
# ("style list remains the same"); add them back after the default entry.
style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
]
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "(No style)"
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    """Expand a named style preset around the user's prompts.

    Args:
        style_name: key into ``styles``; unknown names fall back to
            ``DEFAULT_STYLE_NAME`` (and, if that is missing too, to an
            identity template with no extra negative prompt).
        positive: the user's prompt, substituted for ``{prompt}`` in the
            style's template.
        negative: the user's negative prompt.

    Returns:
        ``(styled_prompt, combined_negative_prompt)``.
    """
    style_prompt, style_negative = styles.get(
        style_name, styles.get(DEFAULT_STYLE_NAME, ("{prompt}", ""))
    )
    # Separate the two negative prompts with ", " so the style's last term and
    # the user's first term are not fused into a single token.
    combined_negative = ", ".join(part for part in (style_negative, negative) if part)
    return style_prompt.replace("{prompt}", positive), combined_negative
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One Euler-Ancestral scheduler instance, passed to every pipeline below.
eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)

# All three pipelines use the same ControlNet checkpoint. Load it once and
# share the module instead of holding three identical fp16 copies in memory.
# The per-model aliases keep the original names available.
controlnet_union = ControlNetModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)
controlnet_pony = controlnet_union
controlnet_cyber = controlnet_union
controlnet_stallion = controlnet_union


def _load_controlnet_pipeline(repo_id):
    """Download ``repo_id`` and build an fp16 SDXL ControlNet pipeline from it.

    The checkpoint's own ``vae`` subfolder is loaded explicitly; the shared
    ControlNet and scheduler defined above are injected into the pipeline.

    Returns:
        ``(checkpoint_dir, vae, pipeline)``
    """
    ckpt_dir = snapshot_download(repo_id=repo_id)
    vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        ckpt_dir, controlnet=controlnet_union, vae=vae, torch_dtype=torch.float16, scheduler=eulera_scheduler
    )
    return ckpt_dir, vae, pipe


ckpt_dir_pony, vae_pony, pipe_pony = _load_controlnet_pipeline("John6666/pony-realism-v21main-sdxl")
ckpt_dir_cyber, vae_cyber, pipe_cyber = _load_controlnet_pipeline("John6666/cyberrealistic-pony-v61-sdxl")
ckpt_dir_stallion, vae_stallion, pipe_stallion = _load_controlnet_pipeline(
    "John6666/stallion-dreams-pony-realistic-v1-sdxl"
)

# Largest seed offered by the UI (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# HED edge detector used when the "use HED detector" option is enabled.
processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh seed in ``[0, MAX_SEED]`` when requested, else ``seed`` unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def _fit_dimensions(width: int, height: int, max_size: int = 1024, multiple: int = 8) -> tuple[int, int]:
    """Scale (width, height) to fit inside max_size x max_size, keeping aspect ratio.

    The image may be scaled up as well as down. Each side is floored to a
    multiple of ``multiple`` (never below ``multiple``) because the SDXL
    pipeline rejects heights/widths that are not divisible by 8.
    """
    ratio = min(max_size / width, max_size / height)
    new_width = max(multiple, int(width * ratio) // multiple * multiple)
    new_height = max(multiple, int(height * ratio) // multiple * multiple)
    return new_width, new_height


def _make_control_image(resized_image, use_hed, use_canny, seed):
    """Build the ControlNet conditioning input; return (preview, pipeline_image)."""
    if use_canny:
        # Canny takes priority when both options are enabled.
        edges = HWC3(cv2.Canny(np.array(resized_image), 100, 200))
        return edges, Image.fromarray(edges)
    if not use_hed:
        # Use the sketch / uploaded image directly.
        return resized_image, resized_image
    hed = np.array(processor(resized_image, scribble=False))
    hed = nms(hed, 127, 3)
    hed = cv2.GaussianBlur(hed, (0, 0), 3)
    # Binarisation threshold in roughly [2, 25]. It is drawn from an RNG seeded
    # with the user's seed so that a fixed seed reproduces the same scribble.
    threshold = int(round(random.Random(seed).uniform(0.01, 0.10), 2) * 255)
    hed[hed > threshold] = 255
    hed[hed < 255] = 0
    return hed, Image.fromarray(hed)


def _select_pipe(model_choice):
    """Map a Model Choice label to its pipeline; any other label selects Stallion Dreams."""
    return {
        "Pony Realism v21": pipe_pony,
        "Cyber Realistic Pony v61": pipe_cyber,
    }.get(model_choice, pipe_stallion)


@spaces.GPU(duration=120)
def run(
    image: dict,
    prompt: str,
    negative_prompt: str,
    model_choice: str,
    style_name: str = DEFAULT_STYLE_NAME,
    num_steps: int = 25,
    guidance_scale: float = 5,
    controlnet_conditioning_scale: float = 1.0,
    seed: int = 0,
    use_hed: bool = False,
    use_canny: bool = False,
    progress=gr.Progress(track_tqdm=True),
) -> tuple:
    """Generate an image from a sketch/upload with the selected SDXL ControlNet pipeline.

    Args:
        image: ImageEditor value; only its ``"composite"`` PIL image is used.
        prompt, negative_prompt: user prompts, expanded by ``apply_style``.
        model_choice: one of the Model Choice dropdown labels.
        style_name: key of a style preset.
        num_steps, guidance_scale, controlnet_conditioning_scale: passed to the pipeline.
        seed: seeds the torch generator (and the HED threshold RNG).
        use_hed: turn the input into a HED scribble first.
        use_canny: use Canny edges instead (takes priority over ``use_hed``).

    Returns:
        ``(control_image_preview, generated_image)`` for the ImageSlider.

    Raises:
        gr.Error: if the editor holds no image.
    """
    composite_image = image.get("composite") if image else None
    if composite_image is None:
        raise gr.Error("Please sketch or upload an image first.")

    new_width, new_height = _fit_dimensions(*composite_image.size)
    resized_image = composite_image.resize((new_width, new_height), Image.LANCZOS)
    controlnet_img, control_image = _make_control_image(resized_image, use_hed, use_canny, seed)

    prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
    generator = torch.Generator(device=device).manual_seed(seed)

    pipe = _select_pipe(model_choice)
    pipe.to(device)
    try:
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image,
            num_inference_steps=num_steps,
            generator=generator,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            guidance_scale=guidance_scale,
            width=new_width,
            height=new_height,
        ).images[0]
    finally:
        # Return the weights to CPU even if generation fails, so the GPU is freed.
        pipe.to("cpu")
        torch.cuda.empty_cache()
    return (controlnet_img, out)
# Gradio UI: editor and options on the left, before/after slider on the right.
with gr.Blocks(css="style.css", js=js_func) as demo:
    gr.Markdown(DESCRIPTION, elem_id="description")
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                image = gr.ImageEditor(type="pil", label="Sketch your image or upload one", width=512, height=512)
                prompt = gr.Textbox(label="Prompt")
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
                # These labels must match the ones run() dispatches on.
                model_choice = gr.Dropdown(
                    ["Pony Realism v21", "Cyber Realistic Pony v61", "Stallion Dreams Pony Realistic v1"],
                    label="Model Choice",
                    value="Pony Realism v21",
                )
                use_hed = gr.Checkbox(label="use HED detector", value=False, info="check this box if you upload an image and want to turn it to a sketch")
                use_canny = gr.Checkbox(label="use Canny", value=False, info="check this to use ControlNet canny instead of scribble")
                run_button = gr.Button("Run")
            with gr.Accordion("Advanced options", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )
                num_steps = gr.Slider(
                    label="Number of steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=25,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1,
                    value=5,
                )
                controlnet_conditioning_scale = gr.Slider(
                    label="controlnet conditioning scale",
                    minimum=0.5,
                    maximum=5.0,
                    step=0.1,
                    value=0.9,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Column():
            with gr.Group():
                image_slider = ImageSlider(position=0.5)

    # Order must match run()'s positional parameters.
    inputs = [
        image,
        prompt,
        negative_prompt,
        model_choice,
        style,
        num_steps,
        guidance_scale,
        controlnet_conditioning_scale,
        seed,
        use_hed,
        use_canny,
    ]
    outputs = [image_slider]
    # 1) optionally randomise the seed, 2) clear the slider, 3) generate.
    # The clearing step has inputs=None, so Gradio calls it with no arguments;
    # the lambda must therefore take none (a one-parameter lambda raises TypeError).
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(lambda: None, inputs=None, outputs=image_slider).then(
        fn=run, inputs=inputs, outputs=outputs
    )

demo.queue().launch(show_error=True, ssl_verify=False)