import gc
import os, sys

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(os.path.dirname(__file__))

from pathlib import Path
import traceback
from typing import List, Literal, Optional, Union, Dict

import numpy as np
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline
from PIL import Image

from streamdiffusion import StreamDiffusion
from streamdiffusion.image_utils import postprocess_image

torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from moviepy.editor import ImageSequenceClip
class StreamDiffusionWrapper:
    def __init__(
        self,
        model_id_or_path: str,
        t_index_list: List[int],
        lora_dict: Optional[Dict[str, float]] = None,
        mode: Literal["img2img", "txt2img"] = "img2img",
        output_type: Literal["pil", "pt", "np", "latent"] = "pil",
        lcm_lora_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        device: Literal["cpu", "cuda"] = "cuda",
        dtype: torch.dtype = torch.float16,
        frame_buffer_size: int = 1,
        width: int = 512,
        height: int = 512,
        warmup: int = 10,
        acceleration: Literal["none", "xformers", "sfast", "tensorrt"] = "tensorrt",
        do_add_noise: bool = True,
        device_ids: Optional[List[int]] = None,
        use_lcm_lora: bool = True,
        use_tiny_vae: bool = True,
        enable_similar_image_filter: bool = False,
        similar_image_filter_threshold: float = 0.98,
        similar_image_filter_max_skip_frame: int = 10,
        use_denoising_batch: bool = True,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
        seed: int = 2,
        use_safety_checker: bool = False,
        engine_dir: Optional[Union[str, Path]] = "engines",
    ):
""" | |
Initializes the StreamDiffusionWrapper. | |
Parameters | |
---------- | |
model_id_or_path : str | |
The model id or path to load. | |
t_index_list : List[int] | |
The t_index_list to use for inference. | |
lora_dict : Optional[Dict[str, float]], optional | |
The lora_dict to load, by default None. | |
Keys are the LoRA names and values are the LoRA scales. | |
Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} | |
mode : Literal["img2img", "txt2img"], optional | |
txt2img or img2img, by default "img2img". | |
output_type : Literal["pil", "pt", "np", "latent"], optional | |
The output type of image, by default "pil". | |
lcm_lora_id : Optional[str], optional | |
The lcm_lora_id to load, by default None. | |
If None, the default LCM-LoRA | |
("latent-consistency/lcm-lora-sdv1-5") will be used. | |
vae_id : Optional[str], optional | |
The vae_id to load, by default None. | |
If None, the default TinyVAE | |
("madebyollin/taesd") will be used. | |
device : Literal["cpu", "cuda"], optional | |
The device to use for inference, by default "cuda". | |
dtype : torch.dtype, optional | |
The dtype for inference, by default torch.float16. | |
frame_buffer_size : int, optional | |
The frame buffer size for denoising batch, by default 1. | |
width : int, optional | |
The width of the image, by default 512. | |
height : int, optional | |
The height of the image, by default 512. | |
warmup : int, optional | |
The number of warmup steps to perform, by default 10. | |
acceleration : Literal["none", "xformers", "tensorrt"], optional | |
The acceleration method, by default "tensorrt". | |
do_add_noise : bool, optional | |
Whether to add noise for following denoising steps or not, | |
by default True. | |
device_ids : Optional[List[int]], optional | |
The device ids to use for DataParallel, by default None. | |
use_lcm_lora : bool, optional | |
Whether to use LCM-LoRA or not, by default True. | |
use_tiny_vae : bool, optional | |
Whether to use TinyVAE or not, by default True. | |
enable_similar_image_filter : bool, optional | |
Whether to enable similar image filter or not, | |
by default False. | |
similar_image_filter_threshold : float, optional | |
The threshold for similar image filter, by default 0.98. | |
similar_image_filter_max_skip_frame : int, optional | |
The max skip frame for similar image filter, by default 10. | |
use_denoising_batch : bool, optional | |
Whether to use denoising batch or not, by default True. | |
cfg_type : Literal["none", "full", "self", "initialize"], | |
optional | |
The cfg_type for img2img mode, by default "self". | |
You cannot use anything other than "none" for txt2img mode. | |
seed : int, optional | |
The seed, by default 2. | |
use_safety_checker : bool, optional | |
Whether to use safety checker or not, by default False. | |
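
        Examples
        --------
        A minimal construction sketch; the model id, t_index_list values, and
        acceleration choice below are placeholders, not required settings:

        >>> wrapper = StreamDiffusionWrapper(
        ...     model_id_or_path="KBlueLeaf/kohaku-v2.1",
        ...     t_index_list=[0, 16, 32, 45],
        ...     mode="txt2img",
        ...     cfg_type="none",
        ...     acceleration="xformers",
        ... )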
""" | |
if not torch.cuda.is_available(): | |
device = 'cpu' | |
dtype = torch.float32 | |
self.sd_turbo = "turbo" in model_id_or_path | |
# print("Mode:",mode) | |
if mode == "txt2img": | |
if cfg_type != "none": | |
raise ValueError( | |
f"txt2img mode accepts only cfg_type = 'none', but got {cfg_type}" | |
) | |
if use_denoising_batch and frame_buffer_size > 1: | |
if not self.sd_turbo: | |
raise ValueError( | |
"txt2img mode cannot use denoising batch with frame_buffer_size > 1." | |
) | |
if mode == "img2img": | |
if not use_denoising_batch: | |
raise NotImplementedError( | |
"img2img mode must use denoising batch for now." | |
) | |
self.device = device | |
self.dtype = dtype | |
self.width = width | |
self.height = height | |
self.mode = mode | |
self.output_type = output_type | |
self.frame_buffer_size = frame_buffer_size | |
self.batch_size = ( | |
len(t_index_list) * frame_buffer_size | |
if use_denoising_batch | |
else frame_buffer_size | |
) | |
self.use_denoising_batch = use_denoising_batch | |
self.use_safety_checker = use_safety_checker | |
self.stream: StreamDiffusion = self._load_model( | |
model_id_or_path=model_id_or_path, | |
lora_dict=lora_dict, | |
lcm_lora_id=lcm_lora_id, | |
vae_id=vae_id, | |
t_index_list=t_index_list, | |
acceleration=acceleration, | |
warmup=warmup, | |
do_add_noise=do_add_noise, | |
use_lcm_lora=use_lcm_lora, | |
use_tiny_vae=use_tiny_vae, | |
cfg_type=cfg_type, | |
seed=seed, | |
engine_dir=engine_dir, | |
) | |
if device_ids is not None: | |
self.stream.unet = torch.nn.DataParallel( | |
self.stream.unet, device_ids=device_ids | |
) | |
if enable_similar_image_filter: | |
self.stream.enable_similar_image_filter(similar_image_filter_threshold, similar_image_filter_max_skip_frame) | |
    def prepare(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 1.2,
        delta: float = 1.0,
    ) -> None:
        """
        Prepares the model for inference.

        Parameters
        ----------
        prompt : str
            The prompt to generate images from.
        negative_prompt : str, optional
            The negative prompt to use, by default "".
        num_inference_steps : int, optional
            The number of inference steps to perform, by default 50.
        guidance_scale : float, optional
            The guidance scale to use, by default 1.2.
        delta : float, optional
            The delta multiplier of virtual residual noise,
            by default 1.0.
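
        Examples
        --------
        Illustrative call; the prompt text is a placeholder:

        >>> wrapper.prepare(
        ...     prompt="a photo of a cat",
        ...     num_inference_steps=50,
        ...     guidance_scale=1.2,
        ... )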
""" | |
self.stream.prepare( | |
prompt, | |
negative_prompt, | |
num_inference_steps=num_inference_steps, | |
guidance_scale=guidance_scale, | |
delta=delta, | |
) | |
    def __call__(
        self,
        image: Optional[Union[str, Image.Image, torch.Tensor]] = None,
        prompt: Optional[str] = None,
    ) -> Union[Image.Image, List[Image.Image]]:
        """
        Performs img2img or txt2img based on the mode.

        Parameters
        ----------
        image : Optional[Union[str, Image.Image, torch.Tensor]]
            The image to generate from (used in img2img mode only).
        prompt : Optional[str]
            The prompt to generate images from.

        Returns
        -------
        Union[Image.Image, List[Image.Image]]
            The generated image.
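
        Examples
        --------
        Illustrative calls; `frame` stands for any accepted image input
        (a file path, a PIL.Image, or a tensor):

        >>> result = wrapper(prompt="a photo of a cat")               # txt2img mode
        >>> result = wrapper(image=frame, prompt="a photo of a cat")  # img2img mode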
""" | |
if self.mode == "img2img": | |
return self.img2img(image, prompt) | |
elif self.mode == "txt2img": | |
return self.txt2img(prompt) | |
    def txt2img(
        self, prompt: Optional[str] = None
    ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
        """
        Performs txt2img.

        Parameters
        ----------
        prompt : Optional[str]
            The prompt to generate images from.

        Returns
        -------
        Union[Image.Image, List[Image.Image]]
            The generated image.
        """
        print("using txt2img")
        if prompt is not None:
            self.stream.update_prompt(prompt)

        if self.sd_turbo:
            image_tensor = self.stream.txt2img_sd_turbo(self.batch_size)
            # print("image_tensor_1:", image_tensor.shape)
        else:
            image_tensor = self.stream.txt2img(self.frame_buffer_size)
            # print("image_tensor_2:", image_tensor.shape)  # torch.Size([1, 3, 512, 512])
        image = self.postprocess_image(image_tensor, output_type=self.output_type)

        if self.use_safety_checker:
            safety_checker_input = self.feature_extractor(
                image, return_tensors="pt"
            ).to(self.device)
            _, has_nsfw_concept = self.safety_checker(
                images=image_tensor.to(self.dtype),
                clip_input=safety_checker_input.pixel_values.to(self.dtype),
            )
            image = self.nsfw_fallback_img if has_nsfw_concept[0] else image

        return image
    def img2img(
        self, image: Union[str, Image.Image, torch.Tensor], prompt: Optional[str] = None
    ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
        """
        Performs img2img.

        Parameters
        ----------
        image : Union[str, Image.Image, torch.Tensor]
            The image to generate from.
        prompt : Optional[str]
            The prompt to generate images from.

        Returns
        -------
        Union[Image.Image, List[Image.Image]]
            The generated image.
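
        Examples
        --------
        Illustrative calls; "input.png" is a placeholder path:

        >>> out = wrapper.img2img("input.png", prompt="a photo of a cat")
        >>> out = wrapper.img2img(Image.open("input.png"), prompt="a photo of a cat")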
""" | |
print("using img2img") | |
if prompt is not None: | |
self.stream.update_prompt(prompt) | |
if isinstance(image, str) or isinstance(image, Image.Image): | |
image = self.preprocess_image(image) | |
image_tensor = self.stream(image) | |
image = self.postprocess_image(image_tensor, output_type=self.output_type) | |
if self.use_safety_checker: | |
safety_checker_input = self.feature_extractor( | |
image, return_tensors="pt" | |
).to(self.device) | |
_, has_nsfw_concept = self.safety_checker( | |
images=image_tensor.to(self.dtype), | |
clip_input=safety_checker_input.pixel_values.to(self.dtype), | |
) | |
image = self.nsfw_fallback_img if has_nsfw_concept[0] else image | |
return image | |
    def preprocess_image(self, image: Union[str, Image.Image]) -> torch.Tensor:
        """
        Preprocesses the image.

        Parameters
        ----------
        image : Union[str, Image.Image]
            The image to preprocess.

        Returns
        -------
        torch.Tensor
            The preprocessed image.
        """
        if isinstance(image, str):
            image = Image.open(image).convert("RGB").resize((self.width, self.height))
        if isinstance(image, Image.Image):
            image = image.convert("RGB").resize((self.width, self.height))

        return self.stream.image_processor.preprocess(
            image, self.height, self.width
        ).to(device=self.device, dtype=self.dtype)
    def postprocess_image(
        self, image_tensor: torch.Tensor, output_type: str = "pil"
    ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
        """
        Postprocesses the image.

        Parameters
        ----------
        image_tensor : torch.Tensor
            The image tensor to postprocess.
        output_type : str, optional
            The output type, by default "pil".

        Returns
        -------
        Union[Image.Image, List[Image.Image]]
            The postprocessed image.
        """
        if self.frame_buffer_size > 1:
            return postprocess_image(image_tensor.cpu(), output_type=output_type)
        else:
            return postprocess_image(image_tensor.cpu(), output_type=output_type)[0]
    def _load_model(
        self,
        model_id_or_path: str,
        t_index_list: List[int],
        lora_dict: Optional[Dict[str, float]] = None,
        lcm_lora_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        acceleration: Literal["none", "xformers", "sfast", "tensorrt"] = "tensorrt",
        warmup: int = 10,
        do_add_noise: bool = True,
        use_lcm_lora: bool = True,
        use_tiny_vae: bool = True,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
        seed: int = 2,
        engine_dir: Optional[Union[str, Path]] = "engines",
    ) -> StreamDiffusion:
        """
        Loads the model.

        This method does the following:

        1. Loads the model from the model_id_or_path.
        2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed.
        3. Loads the VAE model from the vae_id if needed.
        4. Enables acceleration if needed.
        5. Prepares the model for inference.
        6. Loads the safety checker if needed.

        Parameters
        ----------
        model_id_or_path : str
            The model id or path to load.
        t_index_list : List[int]
            The t_index_list to use for inference.
        lora_dict : Optional[Dict[str, float]], optional
            The lora_dict to load, by default None.
            Keys are the LoRA names and values are the LoRA scales.
            Example: {'LoRA_1': 0.5, 'LoRA_2': 0.7, ...}
        lcm_lora_id : Optional[str], optional
            The lcm_lora_id to load, by default None.
        vae_id : Optional[str], optional
            The vae_id to load, by default None.
        acceleration : Literal["none", "xformers", "sfast", "tensorrt"], optional
            The acceleration method, by default "tensorrt".
        warmup : int, optional
            The number of warmup steps to perform, by default 10.
        do_add_noise : bool, optional
            Whether to add noise for following denoising steps or not,
            by default True.
        use_lcm_lora : bool, optional
            Whether to use LCM-LoRA or not, by default True.
        use_tiny_vae : bool, optional
            Whether to use TinyVAE or not, by default True.
        cfg_type : Literal["none", "full", "self", "initialize"], optional
            The cfg_type for img2img mode, by default "self".
            You cannot use anything other than "none" for txt2img mode.
        seed : int, optional
            The seed, by default 2.
        engine_dir : Optional[Union[str, Path]], optional
            The directory used to cache TensorRT engines, by default "engines".

        Returns
        -------
        StreamDiffusion
            The loaded model.
        """
        try:  # Load a diffusers-format model (local directory or Hugging Face Hub)
            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
                model_id_or_path,
            ).to(device=self.device, dtype=self.dtype)
        except ValueError:  # Fall back to loading a single checkpoint file
            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(
                model_id_or_path,
            ).to(device=self.device, dtype=self.dtype)
        except Exception:  # No model found
            traceback.print_exc()
            print("Model load has failed: the model does not exist.")
            exit()
        stream = StreamDiffusion(
            pipe=pipe,
            t_index_list=t_index_list,
            torch_dtype=self.dtype,
            width=self.width,
            height=self.height,
            do_add_noise=do_add_noise,
            frame_buffer_size=self.frame_buffer_size,
            use_denoising_batch=self.use_denoising_batch,
            cfg_type=cfg_type,
        )

        print("self.sd_turbo:", self.sd_turbo)
        print("use_lcm_lora:", use_lcm_lora)
        print("lcm_lora_id:", lcm_lora_id)
        print("lora_dict:", lora_dict)
        print("use_tiny_vae:", use_tiny_vae)
        print("vae_id:", vae_id)
        if not self.sd_turbo:
            if use_lcm_lora:
                if lcm_lora_id is not None:
                    stream.load_lcm_lora(
                        pretrained_model_name_or_path_or_dict=lcm_lora_id
                    )
                else:
                    # Fall back to the default LCM-LoRA
                    # ("latent-consistency/lcm-lora-sdv1-5").
                    stream.load_lcm_lora()
                stream.fuse_lora()

            if lora_dict is not None:
                for lora_name, lora_scale in lora_dict.items():
                    stream.load_lora(lora_name)
                    stream.fuse_lora(lora_scale=lora_scale)
                    print(f"Use LoRA: {lora_name} in weights {lora_scale}")
        if use_tiny_vae:
            if vae_id is not None:
                stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(
                    device=pipe.device, dtype=pipe.dtype
                )
            else:
                stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(
                    device=pipe.device, dtype=pipe.dtype
                )
        try:
            if acceleration == "xformers":
                stream.pipe.enable_xformers_memory_efficient_attention()
            if acceleration == "tensorrt":
                from polygraphy import cuda
                from streamdiffusion.acceleration.tensorrt import (
                    TorchVAEEncoder,
                    compile_unet,
                    compile_vae_decoder,
                    compile_vae_encoder,
                )
                from streamdiffusion.acceleration.tensorrt.engine import (
                    AutoencoderKLEngine,
                    UNet2DConditionModelEngine,
                )
                from streamdiffusion.acceleration.tensorrt.models import (
                    VAE,
                    UNet,
                    VAEEncoder,
                )

                def create_prefix(
                    model_id_or_path: str,
                    max_batch_size: int,
                    min_batch_size: int,
                ):
                    maybe_path = Path(model_id_or_path)
                    if maybe_path.exists():
                        return f"{maybe_path.stem}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}"
                    else:
                        return f"{model_id_or_path}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}"

                engine_dir = Path(engine_dir)
                unet_path = os.path.join(
                    engine_dir,
                    create_prefix(
                        model_id_or_path=model_id_or_path,
                        max_batch_size=stream.trt_unet_batch_size,
                        min_batch_size=stream.trt_unet_batch_size,
                    ),
                    "unet.engine",
                )
                vae_encoder_path = os.path.join(
                    engine_dir,
                    create_prefix(
                        model_id_or_path=model_id_or_path,
                        max_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                        min_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                    ),
                    "vae_encoder.engine",
                )
                vae_decoder_path = os.path.join(
                    engine_dir,
                    create_prefix(
                        model_id_or_path=model_id_or_path,
                        max_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                        min_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                    ),
                    "vae_decoder.engine",
                )
                if not os.path.exists(unet_path):
                    os.makedirs(os.path.dirname(unet_path), exist_ok=True)
                    unet_model = UNet(
                        fp16=True if torch.cuda.is_available() else False,
                        device=stream.device,
                        max_batch_size=stream.trt_unet_batch_size,
                        min_batch_size=stream.trt_unet_batch_size,
                        embedding_dim=stream.text_encoder.config.hidden_size,
                        unet_dim=stream.unet.config.in_channels,
                    )
                    compile_unet(
                        stream.unet,
                        unet_model,
                        unet_path + ".onnx",
                        unet_path + ".opt.onnx",
                        unet_path,
                        opt_batch_size=stream.trt_unet_batch_size,
                    )

                if not os.path.exists(vae_decoder_path):
                    os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True)
                    stream.vae.forward = stream.vae.decode
                    vae_decoder_model = VAE(
                        device=stream.device,
                        max_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                        min_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                    )
                    compile_vae_decoder(
                        stream.vae,
                        vae_decoder_model,
                        vae_decoder_path + ".onnx",
                        vae_decoder_path + ".opt.onnx",
                        vae_decoder_path,
                        opt_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                    )
                    delattr(stream.vae, "forward")

                if not os.path.exists(vae_encoder_path):
                    os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True)
                    vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda"))
                    vae_encoder_model = VAEEncoder(
                        device=stream.device,
                        max_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                        min_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                    )
                    compile_vae_encoder(
                        vae_encoder,
                        vae_encoder_model,
                        vae_encoder_path + ".onnx",
                        vae_encoder_path + ".opt.onnx",
                        vae_encoder_path,
                        opt_batch_size=self.batch_size
                        if self.mode == "txt2img"
                        else stream.frame_bff_size,
                    )
                cuda_stream = cuda.Stream()

                vae_config = stream.vae.config
                vae_dtype = stream.vae.dtype

                stream.unet = UNet2DConditionModelEngine(
                    unet_path, cuda_stream, use_cuda_graph=False
                )
                stream.vae = AutoencoderKLEngine(
                    vae_encoder_path,
                    vae_decoder_path,
                    cuda_stream,
                    stream.pipe.vae_scale_factor,
                    use_cuda_graph=False,
                )
                setattr(stream.vae, "config", vae_config)
                setattr(stream.vae, "dtype", vae_dtype)

                gc.collect()
                torch.cuda.empty_cache()

                print("TensorRT acceleration enabled.")
            if acceleration == "sfast":
                from streamdiffusion.acceleration.sfast import (
                    accelerate_with_stable_fast,
                )

                stream = accelerate_with_stable_fast(stream)
                print("StableFast acceleration enabled.")
        except Exception:
            traceback.print_exc()
            print("Acceleration has failed. Falling back to normal mode.")

        if seed < 0:  # Random seed
            seed = np.random.randint(0, 1000000)

        stream.prepare(
            "",
            "",
            num_inference_steps=50,
            guidance_scale=1.1
            if stream.cfg_type in ["full", "self", "initialize"]
            else 1.0,
            generator=torch.manual_seed(seed),
            seed=seed,
        )

        if self.use_safety_checker:
            from transformers import CLIPFeatureExtractor
            from diffusers.pipelines.stable_diffusion.safety_checker import (
                StableDiffusionSafetyChecker,
            )

            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker"
            ).to(pipe.device)
            self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
            self.nsfw_fallback_img = Image.new("RGB", (512, 512), (0, 0, 0))

        return stream
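

# Minimal usage sketch (illustrative only, not part of the wrapper itself).
# It assumes a CUDA GPU, network access to download the placeholder checkpoint
# "KBlueLeaf/kohaku-v2.1", and that xformers is installed; swap in your own
# model id, t_index_list, and acceleration method as needed.
if __name__ == "__main__":
    wrapper = StreamDiffusionWrapper(
        model_id_or_path="KBlueLeaf/kohaku-v2.1",  # placeholder model id
        t_index_list=[0, 16, 32, 45],              # placeholder denoising indices
        mode="txt2img",
        cfg_type="none",          # txt2img mode only accepts cfg_type="none"
        acceleration="xformers",  # skip TensorRT engine builds for a quick test
    )
    wrapper.prepare(prompt="a photo of a cat wearing a tiny hat")
    image = wrapper.txt2img()
    image.save("output.png")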