# Source: app.py by NightRaven109 — commit b22f2c5 ("Update app.py", verified), 6.1 kB.
import argparse
import os
import traceback

import gradio as gr
import spaces
import torch
from accelerate import Accelerator
from huggingface_hub import snapshot_download
from PIL import Image

from test_ccsr_tile import load_pipeline, main
# Module-level state shared by initialize_models() (which assigns it) and
# process_image() (which reads it); everything is None until the models load.
pipeline = generator = accelerator = None
class Args:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@spaces.GPU
def initialize_models():
    """Download the CCSR weights and build the module-level pipeline objects.

    On success, assigns the globals ``pipeline``, ``generator`` and
    ``accelerator``.

    Returns:
        ``True`` if loading succeeded, ``False`` otherwise. Failures are printed
        together with their traceback instead of being raised.
    """
    # NOTE(review): on ZeroGPU, @spaces.GPU functions may execute in a separate
    # worker process, in which case these global assignments might not persist
    # across requests — confirm whether the models are reloaded on every call.
    global pipeline, generator, accelerator
    try:
        # The private model repo is fetched with the token stored in the
        # ``Read2`` environment variable (presumably a Space secret — confirm).
        token = os.environ.get("Read2")
        if token is None:
            raise RuntimeError(
                "environment variable 'Read2' (Hugging Face read token) is not set"
            )

        model_path = snapshot_download(
            repo_id="NightRaven109/CCSRModels",
            token=token,
        )

        # Default arguments for load_pipeline.
        args = Args(
            pretrained_model_path=os.path.join(model_path, "stable-diffusion-2-1-base"),
            controlnet_model_path=os.path.join(model_path, "Controlnet"),
            vae_model_path=os.path.join(model_path, "vae"),
            mixed_precision="fp16",
            tile_vae=False,
            sample_method="ddpm",
            vae_encoder_tile_size=1024,
            vae_decoder_tile_size=224,
        )

        accelerator = Accelerator(mixed_precision=args.mixed_precision)
        pipeline = load_pipeline(
            args, accelerator, enable_xformers_memory_efficient_attention=False
        )
        # Generator lives on the accelerator's device so seeded sampling runs there.
        generator = torch.Generator(device=accelerator.device)
        return True
    except Exception as e:
        # Best-effort: report the failure (with traceback) and let the caller decide.
        print(f"Error initializing models: {str(e)}")
        traceback.print_exc()
        return False
def _prepare_validation_image(input_image, process_size, upscale):
    """Turn the Gradio image array into the RGB PIL image fed to the pipeline.

    Small inputs are first enlarged so that both sides reach at least
    ``process_size // upscale``. The image is then upscaled by ``upscale``, and
    finally both sides are floored to a multiple of 8.

    Returns:
        ``(validation_image, resize_flag, (ori_width, ori_height))``, where
        ``resize_flag`` records whether the pre-enlargement happened.
    """
    # Normalize to RGB so RGBA (e.g. PNG with alpha) or grayscale uploads do not
    # reach the pipeline with an unexpected channel count.
    validation_image = Image.fromarray(input_image).convert("RGB")
    ori_width, ori_height = validation_image.size

    min_side = process_size // upscale
    resize_flag = False
    if ori_width < min_side or ori_height < min_side:
        scale = min_side / min(ori_width, ori_height)
        validation_image = validation_image.resize(
            (round(scale * ori_width), round(scale * ori_height))
        )
        resize_flag = True

    validation_image = validation_image.resize(
        (validation_image.size[0] * upscale, validation_image.size[1] * upscale)
    )
    # Floor both sides to a multiple of 8 before handing the image to the pipeline.
    validation_image = validation_image.resize(
        (validation_image.size[0] // 8 * 8, validation_image.size[1] // 8 * 8)
    )
    return validation_image, resize_flag, (ori_width, ori_height)


@spaces.GPU
def process_image(
    input_image,
    prompt="clean, high-resolution, 8k",
    negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    guidance_scale=1.0,
    conditioning_scale=1.0,
    num_inference_steps=20,
    seed=42,
    upscale_factor=2,
    color_fix_method="adain",
):
    """Super-resolve one image with the CCSR pipeline.

    Loads the models on first use via ``initialize_models``.

    Args:
        input_image: Image array from ``gr.Image`` (converted with
            ``Image.fromarray``); ``None`` when nothing was uploaded.
        prompt: Positive text prompt passed to the pipeline.
        negative_prompt: Negative text prompt passed to the pipeline.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        conditioning_scale: Passed to the pipeline as ``conditioning_scale``.
        num_inference_steps: Number of inference steps, coerced to ``int``.
        seed: Seed for the shared generator, coerced to ``int`` because
            ``gr.Number`` may deliver a float. ``None`` leaves the generator
            state untouched.
        upscale_factor: Output scale factor, coerced to ``int``.
        color_fix_method: ``"none"`` skips color fixing, ``"wavelet"`` uses
            ``wavelet_color_fix``, and any other value uses ``adain_color_fix``.

    Returns:
        The output ``PIL.Image``, or ``None`` if there was no input, model
        loading failed, or processing raised (the traceback is printed).
    """
    if input_image is None:
        # Reject early instead of loading the models only to fail on fromarray.
        print("Error processing image: no input image provided")
        return None

    if pipeline is None and not initialize_models():
        return None

    try:
        # Gradio sliders/numbers may hand back floats; PIL sizes and step counts need ints.
        upscale = int(upscale_factor)
        steps = int(num_inference_steps)

        args = Args(
            added_prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            conditioning_scale=conditioning_scale,
            num_inference_steps=steps,
            seed=seed,
            upscale=upscale,
            process_size=512,
            align_method=color_fix_method,
            t_max=0.6666,
            t_min=0.0,
            tile_diffusion=False,
            tile_diffusion_size=None,
            tile_diffusion_stride=None,
            start_steps=999,
            start_point="lr",
            use_vae_encode_condition=False,
            sample_times=1,
        )

        if seed is not None:
            # torch.Generator.manual_seed rejects floats, so coerce explicitly.
            generator.manual_seed(int(seed))

        validation_image, resize_flag, (ori_width, ori_height) = _prepare_validation_image(
            input_image, args.process_size, args.upscale
        )
        width, height = validation_image.size

        _inference_time, output = pipeline(
            args.t_max,
            args.t_min,
            args.tile_diffusion,
            args.tile_diffusion_size,
            args.tile_diffusion_stride,
            args.added_prompt,
            validation_image,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            guidance_scale=args.guidance_scale,
            negative_prompt=args.negative_prompt,
            conditioning_scale=args.conditioning_scale,
            start_steps=args.start_steps,
            start_point=args.start_point,
            use_vae_encode_condition=args.use_vae_encode_condition,
        )
        image = output.images[0]

        if args.align_method != "none":
            from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
            fix_func = wavelet_color_fix if args.align_method == "wavelet" else adain_color_fix
            image = fix_func(image, validation_image)

        if resize_flag:
            # Undo the pre-enlargement so the output is exactly original size * upscale.
            image = image.resize((ori_width * args.upscale, ori_height * args.upscale))
        return image
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        traceback.print_exc()
        return None
# Create Gradio interface.
# The order of ``inputs`` must match the positional parameters of process_image.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        # type="numpy" is Gradio's default and is what process_image expects
        # (it calls Image.fromarray on the value).
        gr.Image(label="Input Image", type="numpy"),
        gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
        gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
        gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
        gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
        # precision=0 makes Gradio round and deliver an int; without it the seed
        # arrives as a float, which torch.Generator.manual_seed rejects.
        gr.Number(label="Seed", value=42, precision=0),
        gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
        gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Controllable Conditional Super-Resolution",
    description="Upload an image to enhance its resolution using CCSR.",
)

if __name__ == "__main__":
    # Start the Gradio server when the file is run as a script.
    iface.launch()