import gc
import os
import random
from datetime import datetime
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
from diffusers.models.attention_processor import XFormersAttnProcessor
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from animation.modules.attention_processor import AnimationAttnProcessor
from animation.modules.attention_processor_normalized import AnimationIDAttnNormalizedProcessor
from animation.modules.face_model import FaceModel
from animation.modules.id_encoder import FusionFaceId
from animation.modules.pose_net import PoseNet
from animation.modules.unet import UNetSpatioTemporalConditionModel
from animation.pipelines.inference_pipeline_animation import InferenceAnimationPipeline

pretrained_model_name_or_path = "checkpoints/stable-video-diffusion-img2vid-xt"
revision = None
posenet_model_name_or_path = "checkpoints/Animation/pose_net.pth"
face_encoder_model_name_or_path = "checkpoints/Animation/face_encoder.pth"
unet_model_name_or_path = "checkpoints/Animation/unet.pth"


def load_images_from_folder(folder, width, height):
    """Load the pose frames `frame_<idx>.png` from a directory, sorted by frame index."""
    images = []
    files = os.listdir(folder)
    png_files = [f for f in files if f.endswith('.png')]
    png_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
    for filename in png_files:
        img = Image.open(os.path.join(folder, filename)).convert('RGB')
        img = img.resize((width, height))
        images.append(img)
    return images


def save_frames_as_png(frames, output_path):
    """Save each frame as `frame_<idx>.png` inside `output_path`."""
    pil_frames = [Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame
                  for frame in frames]
    for i, pil_frame in enumerate(pil_frames):
        save_path = os.path.join(output_path, f'frame_{i}.png')
        pil_frame.save(save_path)


def save_frames_as_mp4(frames, output_mp4_path, fps):
    """Encode a list of RGB numpy frames into an MP4 file with OpenCV."""
    print("Starting saving the frames as mp4")
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'H264' for better quality
    out = cv2.VideoWriter(output_mp4_path, fourcc, fps, (width, height))
    for frame in frames:
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)
    out.release()


def export_to_gif(frames, output_gif_path, fps):
    """
    Export a list of frames to a GIF.

    Args:
        frames (list): Frames as numpy arrays or PIL Image objects.
        output_gif_path (str): Path of the output video; its '.mp4' suffix is replaced with '.gif'.
        fps (int): Target frame rate (currently unused; each GIF frame lasts 125 ms).
    """
    # Convert numpy arrays to PIL Images if needed
    pil_frames = [Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame
                  for frame in frames]
    pil_frames[0].save(
        output_gif_path.replace('.mp4', '.gif'),
        format='GIF',
        append_images=pil_frames[1:],
        save_all=True,
        duration=125,
        loop=0,
    )


def generate(
    image_input: str,
    pose_input: str,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    fps: int,
    frames_overlap: int,
    tile_size: int,
    noise_aug_strength: float,
    decode_chunk_size: int,
    seed: int,
):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(Path("outputs"), timestamp)

    if seed == -1:
        seed = random.randint(1, 2**20 - 1)
    generator = torch.Generator(device=device).manual_seed(seed)

    pipeline = InferenceAnimationPipeline(
        vae=vae,
        image_encoder=image_encoder,
        unet=unet,
        scheduler=noise_scheduler,
        feature_extractor=feature_extractor,
        pose_net=pose_net,
        face_encoder=face_encoder,
    ).to(device=device, dtype=dtype)

    validation_image_path = image_input
    validation_image = Image.open(image_input).convert('RGB')
    validation_control_images = load_images_from_folder(pose_input, width=width, height=height)
    num_frames = len(validation_control_images)

    # Extract the face identity embedding from the reference image.
    face_model.face_helper.clean_all()
    validation_face = cv2.imread(validation_image_path)
    validation_image_bgr = cv2.cvtColor(validation_face, cv2.COLOR_RGB2BGR)
    validation_image_face_info = face_model.app.get(validation_image_bgr)
    if len(validation_image_face_info) > 0:
        # Keep the largest detected face.
        validation_image_face_info = sorted(
            validation_image_face_info,
            key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
        validation_image_id_ante_embedding = validation_image_face_info['embedding']
    else:
        validation_image_id_ante_embedding = None

    if validation_image_id_ante_embedding is None:
        # Fallback: align the face and extract the embedding from the aligned crop.
        face_model.face_helper.read_image(validation_image_bgr)
        face_model.face_helper.get_face_landmarks_5(only_center_face=True)
        face_model.face_helper.align_warp_face()
        if len(face_model.face_helper.cropped_faces) == 0:
            validation_image_id_ante_embedding = np.zeros((512,))
        else:
            validation_image_align_face = face_model.face_helper.cropped_faces[0]
            print('Failed to detect a face with insightface; extracting the embedding from the aligned face.')
            validation_image_id_ante_embedding = face_model.handler_ante.get_feat(validation_image_align_face)

    video_frames = pipeline(
        image=validation_image,
        image_pose=validation_control_images,
        height=height,
        width=width,
        num_frames=num_frames,
        tile_size=tile_size,
        tile_overlap=frames_overlap,
        decode_chunk_size=decode_chunk_size,
        motion_bucket_id=127.,
        fps=7,
        min_guidance_scale=guidance_scale,
        max_guidance_scale=guidance_scale,
        noise_aug_strength=noise_aug_strength,
        num_inference_steps=num_inference_steps,
        generator=generator,
        output_type="pil",
        validation_image_id_ante_embedding=validation_image_id_ante_embedding,
    ).frames[0]

    out_file = os.path.join(output_dir, "animation_video.mp4")
    video_frames = [np.array(frame) for frame in video_frames]

    png_out_file = os.path.join(output_dir, "animated_images")
    os.makedirs(png_out_file, exist_ok=True)
    save_frames_as_mp4(video_frames, out_file, fps)
    export_to_gif(video_frames, out_file, fps)
    save_frames_as_png(video_frames, png_out_file)

    seed_update = gr.update(visible=True, value=seed)
    return out_file, seed_update
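# generate() builds an InferenceAnimationPipeline from the globally loaded modules
# (VAE, CLIP image encoder, UNet, PoseNet, face encoder), extracts a face identity
# embedding from the reference image (insightface detection, with an aligned-face
# fallback), and renders the animation in overlapping tiles before writing the
# MP4, GIF, and per-frame PNG outputs.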
with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False) as demo:
    gr.Markdown("""

# StableAnimator

🌐 Github | 📜 arXiv
⚠️ This demo is for academic research and experiential use only.
""") with gr.Row(): with gr.Column(): with gr.Group(): image_input = gr.Image(label="Reference Image", type="filepath") pose_input = gr.Textbox(label="Driven Poses", placeholder="Please enter your driven pose directory here.") with gr.Group(): with gr.Row(): width = gr.Number(label="Width (supports only 512×512 and 576×1024)", value=512) height = gr.Number(label="Height (supports only 512×512 and 576×1024)", value=512) with gr.Row(): guidance_scale = gr.Number(label="Guidance scale (recommended 3.0)", value=3.0, step=0.1, precision=1) num_inference_steps = gr.Number(label="Inference steps (recommended 25)", value=20) with gr.Row(): fps = gr.Number(label="FPS", value=8) frames_overlap = gr.Number(label="Overlap Frames (recommended 4)", value=4) with gr.Row(): tile_size = gr.Number(label="Tile Size (recommended 16)", value=16) noise_aug_strength = gr.Number(label="Noise Augmentation Strength (recommended 0.02)", value=0.02, step=0.01, precision=2) with gr.Row(): decode_chunk_size = gr.Number(label="Decode Chunk Size (recommended 4 or 16)", value=4) seed = gr.Number(label="Random Seed (Enter a positive number, -1 for random)", value=-1) generate_button = gr.Button("🎬 Generate The Video") with gr.Column(): video_output = gr.Video(label="Generate The Video") with gr.Row(): seed_text = gr.Number(label="Video Generation Seed", visible=False, interactive=False) gr.Examples([ ["inference/case-1/reference.png","inference/case-1/poses",512,512], ["inference/case-2/reference.png","inference/case-2/poses",512,512], ["inference/case-3/reference.png","inference/case-3/poses",512,512], ["inference/case-4/reference.png","inference/case-4/poses",512,512], ["inference/case-5/reference.png","inference/case-5/poses",576,1024], ], inputs=[image_input, pose_input, width, height]) generate_button.click( generate, inputs=[image_input, pose_input, width, height, guidance_scale, num_inference_steps, fps, frames_overlap, tile_size, noise_aug_strength, decode_chunk_size, seed], outputs=[video_output, seed_text], ) if __name__ == "__main__": feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_name_or_path, subfolder="feature_extractor", revision=revision) noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=revision) vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision) unet = UNetSpatioTemporalConditionModel.from_pretrained( pretrained_model_name_or_path, subfolder="unet", low_cpu_mem_usage=True, ) pose_net = PoseNet(noise_latent_channels=unet.config.block_out_channels[0]) face_encoder = FusionFaceId( cross_attention_dim=1024, id_embeddings_dim=512, # clip_embeddings_dim=image_encoder.config.hidden_size, clip_embeddings_dim=1024, num_tokens=4, ) face_model = FaceModel() lora_rank = 128 attn_procs = {} unet_svd = unet.state_dict() for name in unet.attn_processors.keys(): if "transformer_blocks" in name and "temporal_transformer_blocks" not in name: cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) 
if __name__ == "__main__":
    feature_extractor = CLIPImageProcessor.from_pretrained(
        pretrained_model_name_or_path, subfolder="feature_extractor", revision=revision)
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder="scheduler")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        pretrained_model_name_or_path, subfolder="image_encoder", revision=revision)
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision)
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        low_cpu_mem_usage=True,
    )
    pose_net = PoseNet(noise_latent_channels=unet.config.block_out_channels[0])
    face_encoder = FusionFaceId(
        cross_attention_dim=1024,
        id_embeddings_dim=512,
        # clip_embeddings_dim=image_encoder.config.hidden_size,
        clip_embeddings_dim=1024,
        num_tokens=4,
    )
    face_model = FaceModel()

    # Swap in the animation attention processors on the spatial transformer blocks;
    # temporal transformer blocks keep the default xFormers processors.
    lora_rank = 128
    attn_procs = {}
    unet_svd = unet.state_dict()
    for name in unet.attn_processors.keys():
        if "transformer_blocks" in name and "temporal_transformer_blocks" not in name:
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                # Self-attention: LoRA-augmented animation processor.
                attn_procs[name] = AnimationAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
            else:
                # Cross-attention: ID-conditioned processor initialised from the SVD key/value weights.
                layer_name = name.split(".processor")[0]
                weights = {
                    "to_k_ip.weight": unet_svd[layer_name + ".to_k.weight"],
                    "to_v_ip.weight": unet_svd[layer_name + ".to_v.weight"],
                }
                attn_procs[name] = AnimationIDAttnNormalizedProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
                attn_procs[name].load_state_dict(weights, strict=False)
        elif "temporal_transformer_blocks" in name:
            attn_procs[name] = XFormersAttnProcessor()
    unet.set_attn_processor(attn_procs)

    # Resume from the released StableAnimator checkpoints.
    if (posenet_model_name_or_path is not None
            and face_encoder_model_name_or_path is not None
            and unet_model_name_or_path is not None):
        print("Loading existing posenet weights, face_encoder weights and unet weights.")
        if posenet_model_name_or_path.endswith(".pth"):
            pose_net_state_dict = torch.load(posenet_model_name_or_path, map_location="cpu")
            pose_net.load_state_dict(pose_net_state_dict, strict=True)
        else:
            raise ValueError("posenet weights loading failed: expected a .pth checkpoint")
        if face_encoder_model_name_or_path.endswith(".pth"):
            face_encoder_state_dict = torch.load(face_encoder_model_name_or_path, map_location="cpu")
            face_encoder.load_state_dict(face_encoder_state_dict, strict=True)
        else:
            raise ValueError("face_encoder weights loading failed: expected a .pth checkpoint")
        if unet_model_name_or_path.endswith(".pth"):
            unet_state_dict = torch.load(unet_model_name_or_path, map_location="cpu")
            unet.load_state_dict(unet_state_dict, strict=True)
        else:
            raise ValueError("unet weights loading failed: expected a .pth checkpoint")

    # Inference only: freeze every module.
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    pose_net.requires_grad_(False)
    face_encoder.requires_grad_(False)

    # A CUDA GPU is assumed here; the device queries below fail on CPU-only machines.
    total_vram_in_gb = torch.cuda.get_device_properties(0).total_memory / 1073741824
    print(f'\033[32mCUDA version: {torch.version.cuda}\033[0m')
    print(f'\033[32mPytorch version: {torch.__version__}\033[0m')
    print(f'\033[32mGPU Type: {torch.cuda.get_device_name()}\033[0m')
    print(f'\033[32mGPU Memory: {total_vram_in_gb:.2f} GB\033[0m')
    if torch.cuda.get_device_capability()[0] >= 8:
        print('\033[32mBF16 is supported, using BF16\033[0m')
        dtype = torch.bfloat16
    else:
        print('\033[32mBF16 is not supported, using FP16. The 5B model is not recommended.\033[0m')
        dtype = torch.float16
    device = "cuda" if torch.cuda.is_available() else "cpu"

    demo.queue().launch(
        inline=False,
        share=False,
        debug=True,
        server_name='0.0.0.0',
        server_port=7860,
        allowed_paths=["/content"],
    )