import gc
import os
import random
from datetime import datetime
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
from diffusers.models.attention_processor import XFormersAttnProcessor
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from animation.modules.attention_processor import AnimationAttnProcessor
from animation.modules.attention_processor_normalized import AnimationIDAttnNormalizedProcessor
from animation.modules.face_model import FaceModel
from animation.modules.id_encoder import FusionFaceId
from animation.modules.pose_net import PoseNet
from animation.modules.unet import UNetSpatioTemporalConditionModel
from animation.pipelines.inference_pipeline_animation import InferenceAnimationPipeline
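
# StableAnimator Gradio demo: identity-preserving, pose-driven human image animation
# built on Stable Video Diffusion (SVD). The paths below point to the base SVD checkpoint
# and the StableAnimator-specific weights (pose net, face encoder, fine-tuned UNet).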
pretrained_model_name_or_path = "checkpoints/stable-video-diffusion-img2vid-xt"
revision = None
posenet_model_name_or_path = "checkpoints/Animation/pose_net.pth"
face_encoder_model_name_or_path = "checkpoints/Animation/face_encoder.pth"
unet_model_name_or_path = "checkpoints/Animation/unet.pth"
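

# Pose frames are expected to be PNG files named like "<prefix>_<index>.png";
# they are sorted numerically by the index after the underscore.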
def load_images_from_folder(folder, width, height):
    images = []
    files = os.listdir(folder)
    png_files = [f for f in files if f.endswith('.png')]
    png_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
    for filename in png_files:
        img = Image.open(os.path.join(folder, filename)).convert('RGB')
        img = img.resize((width, height))
        images.append(img)
    return images


def save_frames_as_png(frames, output_path):
    pil_frames = [Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame for frame in frames]
    num_frames = len(pil_frames)
    for i in range(num_frames):
        pil_frame = pil_frames[i]
        save_path = os.path.join(output_path, f'frame_{i}.png')
        pil_frame.save(save_path)


def save_frames_as_mp4(frames, output_mp4_path, fps):
    print("Saving frames as MP4...")
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'H264' for better quality
    out = cv2.VideoWriter(output_mp4_path, fourcc, fps, (width, height))
    for frame in frames:
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)
    out.release()


def export_to_gif(frames, output_gif_path, fps):
    """
    Export a list of frames to a GIF.

    Args:
    - frames (list): List of frames (as numpy arrays or PIL Image objects).
    - output_gif_path (str): Output path; a '.mp4' suffix is replaced with '.gif'.
    - fps (int): Frames per second of the resulting GIF.
    """
    # Convert numpy arrays to PIL Images if needed
    pil_frames = [Image.fromarray(frame) if isinstance(
        frame, np.ndarray) else frame for frame in frames]
    pil_frames[0].save(output_gif_path.replace('.mp4', '.gif'),
                       format='GIF',
                       append_images=pil_frames[1:],
                       save_all=True,
                       duration=int(1000 / fps),
                       loop=0)
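

# Main inference entry point, wired to the "Generate" button in the Gradio UI below.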
def generate(
    image_input: str,
    pose_input: str,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    fps: int,
    frames_overlap: int,
    tile_size: int,
    noise_aug_strength: float,
    decode_chunk_size: int,
    seed: int,
):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path("outputs") / timestamp

    if seed == -1:
        seed = random.randint(1, 2**20 - 1)
    generator = torch.Generator(device=device).manual_seed(seed)
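
    # Assemble the inference pipeline from the globally loaded components and move it
    # to the device/dtype selected at startup.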
    pipeline = InferenceAnimationPipeline(
        vae=vae,
        image_encoder=image_encoder,
        unet=unet,
        scheduler=noise_scheduler,
        feature_extractor=feature_extractor,
        pose_net=pose_net,
        face_encoder=face_encoder,
    ).to(device=device, dtype=dtype)

    validation_image_path = image_input
    validation_image = Image.open(image_input).convert('RGB')
    validation_control_images = load_images_from_folder(pose_input, width=width, height=height)
    num_frames = len(validation_control_images)
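
    # Extract an identity embedding from the reference face; the largest detected face
    # (by bounding-box area) is used. If InsightFace detection fails, fall back to face
    # alignment, and finally to a zero embedding.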
    face_model.face_helper.clean_all()
    validation_face = cv2.imread(validation_image_path)
    validation_image_bgr = cv2.cvtColor(validation_face, cv2.COLOR_RGB2BGR)
    validation_image_face_info = face_model.app.get(validation_image_bgr)
    if len(validation_image_face_info) > 0:
        validation_image_face_info = sorted(validation_image_face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
        validation_image_id_ante_embedding = validation_image_face_info['embedding']
    else:
        validation_image_id_ante_embedding = None
    if validation_image_id_ante_embedding is None:
        face_model.face_helper.read_image(validation_image_bgr)
        face_model.face_helper.get_face_landmarks_5(only_center_face=True)
        face_model.face_helper.align_warp_face()
        if len(face_model.face_helper.cropped_faces) == 0:
            validation_image_id_ante_embedding = np.zeros((512,))
        else:
            validation_image_align_face = face_model.face_helper.cropped_faces[0]
            print('Failed to detect a face with InsightFace; extracting the embedding from the aligned face instead.')
            validation_image_id_ante_embedding = face_model.handler_ante.get_feat(validation_image_align_face)

    video_frames = pipeline(
        image=validation_image,
        image_pose=validation_control_images,
        height=height,
        width=width,
        num_frames=num_frames,
        tile_size=tile_size,
        tile_overlap=frames_overlap,
        decode_chunk_size=decode_chunk_size,
        motion_bucket_id=127.0,
        fps=7,
        min_guidance_scale=guidance_scale,
        max_guidance_scale=guidance_scale,
        noise_aug_strength=noise_aug_strength,
        num_inference_steps=num_inference_steps,
        generator=generator,
        output_type="pil",
        validation_image_id_ante_embedding=validation_image_id_ante_embedding,
    ).frames[0]

    out_file = os.path.join(output_dir, "animation_video.mp4")
    for i in range(num_frames):
        img = video_frames[i]
        video_frames[i] = np.array(img)

    png_out_file = os.path.join(output_dir, "animated_images")
    os.makedirs(png_out_file, exist_ok=True)
    save_frames_as_mp4(video_frames, out_file, fps)
    export_to_gif(video_frames, out_file, fps)
    save_frames_as_png(video_frames, png_out_file)

    seed_update = gr.update(visible=True, value=seed)
    return out_file, seed_update
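

# Gradio UI: reference image, pose directory and sampling parameters on the left;
# the generated video and the seed actually used on the right.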
with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False) as demo:
    gr.Markdown("""
    <div>
        <h2 style="font-size: 30px;text-align: center;">StableAnimator</h2>
    </div>
    <div style="text-align: center;">
        <a href="https://github.com/Francis-Rings/StableAnimator">🌐 Github</a> |
        <a href="https://arxiv.org/abs/2411.17697">📜 arXiv</a>
    </div>
    <div style="text-align: center; font-weight: bold; color: red;">
        ⚠️ This demo is for academic research and experiential use only.
    </div>
    """)

    with gr.Row():
        with gr.Column():
            with gr.Group():
                image_input = gr.Image(label="Reference Image", type="filepath")
                pose_input = gr.Textbox(label="Driven Poses", placeholder="Please enter your driven pose directory here.")
            with gr.Group():
                with gr.Row():
                    width = gr.Number(label="Width (supports only 512×512 and 576×1024)", value=512, precision=0)
                    height = gr.Number(label="Height (supports only 512×512 and 576×1024)", value=512, precision=0)
                with gr.Row():
                    guidance_scale = gr.Number(label="Guidance scale (recommended 3.0)", value=3.0, step=0.1, precision=1)
                    num_inference_steps = gr.Number(label="Inference steps (recommended 25)", value=20, precision=0)
                with gr.Row():
                    fps = gr.Number(label="FPS", value=8, precision=0)
                    frames_overlap = gr.Number(label="Overlap Frames (recommended 4)", value=4, precision=0)
                with gr.Row():
                    tile_size = gr.Number(label="Tile Size (recommended 16)", value=16, precision=0)
                    noise_aug_strength = gr.Number(label="Noise Augmentation Strength (recommended 0.02)", value=0.02, step=0.01, precision=2)
                with gr.Row():
                    decode_chunk_size = gr.Number(label="Decode Chunk Size (recommended 4 or 16)", value=4, precision=0)
                    seed = gr.Number(label="Random Seed (enter a positive number, or -1 for a random seed)", value=-1, precision=0)
            generate_button = gr.Button("🎬 Generate The Video")
        with gr.Column():
            video_output = gr.Video(label="Generated Video")
            with gr.Row():
                seed_text = gr.Number(label="Video Generation Seed", visible=False, interactive=False)

    gr.Examples([
        ["inference/case-1/reference.png", "inference/case-1/poses", 512, 512],
        ["inference/case-2/reference.png", "inference/case-2/poses", 512, 512],
        ["inference/case-3/reference.png", "inference/case-3/poses", 512, 512],
        ["inference/case-4/reference.png", "inference/case-4/poses", 512, 512],
        ["inference/case-5/reference.png", "inference/case-5/poses", 576, 1024],
    ], inputs=[image_input, pose_input, width, height])

    generate_button.click(
        generate,
        inputs=[image_input, pose_input, width, height, guidance_scale, num_inference_steps, fps, frames_overlap, tile_size, noise_aug_strength, decode_chunk_size, seed],
        outputs=[video_output, seed_text],
    )
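

# All model components are loaded once at startup and reused as module-level globals by generate().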
if __name__ == "__main__":
    feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_name_or_path, subfolder="feature_extractor", revision=revision)
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=revision)
    vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        low_cpu_mem_usage=True,
    )
    pose_net = PoseNet(noise_latent_channels=unet.config.block_out_channels[0])
    face_encoder = FusionFaceId(
        cross_attention_dim=1024,
        id_embeddings_dim=512,
        # clip_embeddings_dim=image_encoder.config.hidden_size,
        clip_embeddings_dim=1024,
        num_tokens=4,
    )
    face_model = FaceModel()

    lora_rank = 128
    attn_procs = {}
    unet_svd = unet.state_dict()
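
    # Install StableAnimator's attention processors: spatial transformer blocks get the
    # LoRA-rank AnimationAttnProcessor for self-attention or the ID-conditioned
    # AnimationIDAttnNormalizedProcessor for cross-attention (its to_k_ip/to_v_ip weights
    # are initialized from the SVD UNet's to_k/to_v); temporal transformer blocks keep
    # the stock xFormers processor.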
    for name in unet.attn_processors.keys():
        if "transformer_blocks" in name and "temporal_transformer_blocks" not in name:
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AnimationAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
            else:
                layer_name = name.split(".processor")[0]
                weights = {
                    "to_k_ip.weight": unet_svd[layer_name + ".to_k.weight"],
                    "to_v_ip.weight": unet_svd[layer_name + ".to_v.weight"],
                }
                attn_procs[name] = AnimationIDAttnNormalizedProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
                attn_procs[name].load_state_dict(weights, strict=False)
        elif "temporal_transformer_blocks" in name:
            attn_procs[name] = XFormersAttnProcessor()
    unet.set_attn_processor(attn_procs)

    # Resume from the pretrained StableAnimator checkpoints.
    if posenet_model_name_or_path is not None and face_encoder_model_name_or_path is not None and unet_model_name_or_path is not None:
        print("Loading existing posenet weights, face_encoder weights and unet weights.")
        if posenet_model_name_or_path.endswith(".pth"):
            pose_net_state_dict = torch.load(posenet_model_name_or_path, map_location="cpu")
            pose_net.load_state_dict(pose_net_state_dict, strict=True)
        else:
            raise ValueError("Failed to load posenet weights: expected a .pth checkpoint.")
        if face_encoder_model_name_or_path.endswith(".pth"):
            face_encoder_state_dict = torch.load(face_encoder_model_name_or_path, map_location="cpu")
            face_encoder.load_state_dict(face_encoder_state_dict, strict=True)
        else:
            raise ValueError("Failed to load face_encoder weights: expected a .pth checkpoint.")
        if unet_model_name_or_path.endswith(".pth"):
            unet_state_dict = torch.load(unet_model_name_or_path, map_location="cpu")
            unet.load_state_dict(unet_state_dict, strict=True)
        else:
            raise ValueError("Failed to load unet weights: expected a .pth checkpoint.")
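
    # Inference only: freeze every component so no gradients are tracked.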
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    pose_net.requires_grad_(False)
    face_encoder.requires_grad_(False)

    total_vram_in_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    print(f'\033[32mCUDA version: {torch.version.cuda}\033[0m')
    print(f'\033[32mPyTorch version: {torch.__version__}\033[0m')
    print(f'\033[32mGPU type: {torch.cuda.get_device_name()}\033[0m')
    print(f'\033[32mGPU memory: {total_vram_in_gb:.2f} GB\033[0m')

    if torch.cuda.get_device_capability()[0] >= 8:
        print('\033[32mBF16 is supported, using BF16\033[0m')
        dtype = torch.bfloat16
    else:
        print('\033[32mBF16 is not supported, falling back to FP16\033[0m')
        dtype = torch.float16
    device = "cuda" if torch.cuda.is_available() else "cpu"
    demo.queue().launch(inline=False, share=False, debug=True, server_name='0.0.0.0', server_port=7860, allowed_paths=["/content"])