# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import argparse
import datetime
import inspect
import os
import random
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from magicanimate.models.appearance_encoder import AppearanceEncoderModel
from magicanimate.models.controlnet import ControlNetModel
from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
from magicanimate.models.unet_controlnet import UNet3DConditionModel
from magicanimate.pipelines.pipeline_animation import AnimationPipeline
from magicanimate.utils.dist_tools import distributed_init
from magicanimate.utils.util import save_videos_grid
from magicanimate.utils.videoreader import VideoReader
def main(args):
    """Run MagicAnimate inference for every (source image, driving video) pair.

    Loads the diffusion components (tokenizer, text encoder, UNet, appearance
    encoder, VAE, ControlNet), assembles an ``AnimationPipeline``, then renders
    each configured test video and writes the results under
    ``samples/<name>/videos``.  When ``args.dist`` is set, the per-rank
    ``dist_kwargs`` are forwarded to the pipeline so work can be sharded
    across GPUs; only rank 0 writes output files.

    Args:
        args: argparse namespace with ``config``, ``rank``, ``world_size``
            and ``dist`` attributes (see the CLI definition in this file).
    """
    # Capture this frame's argument values; only used below to record
    # "global_step" from a self-trained motion-module checkpoint.
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    config = OmegaConf.load(args.config)

    # Each rank drives its own CUDA device.
    device = torch.device(f"cuda:{args.rank}")
    dist_kwargs = {"rank": args.rank, "world_size": args.world_size, "dist": args.dist}

    if config.savename is None:
        time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        savedir = f"samples/{Path(args.config).stem}-{time_str}"
    else:
        savedir = f"samples/{config.savename}"

    if args.dist:
        # NOTE(review): broadcast_object_list mutates the temporary list in
        # place, so non-zero ranks never read rank 0's savedir back out of it;
        # with a timestamped savedir the ranks could disagree — verify.
        dist.broadcast_object_list([savedir], 0)
        dist.barrier()

    if args.rank == 0:
        os.makedirs(savedir, exist_ok=True)

    inference_config = OmegaConf.load(config.inference_config)
    motion_module = config.motion_module

    ### >>> create animation pipeline >>> ###
    tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
    if config.pretrained_unet_path:
        unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
    else:
        unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
    appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device)
    # Reference attention: the appearance encoder writes reference features,
    # the UNet reads them back while denoising.
    reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
    reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
    if config.pretrained_vae_path is not None:
        vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
    else:
        vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")

    ### Load controlnet
    controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)

    unet.enable_xformers_memory_efficient_attention()
    appearance_encoder.enable_xformers_memory_efficient_attention()
    controlnet.enable_xformers_memory_efficient_attention()

    # Run all model weights in half precision.
    vae.to(torch.float16)
    unet.to(torch.float16)
    text_encoder.to(torch.float16)
    appearance_encoder.to(torch.float16)
    controlnet.to(torch.float16)

    pipeline = AnimationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
        scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
        # NOTE: UniPCMultistepScheduler
    )

    # 1. unet ckpt
    # 1.1 motion module
    motion_module_state_dict = torch.load(motion_module, map_location="cpu")
    if "global_step" in motion_module_state_dict:
        func_args.update({"global_step": motion_module_state_dict["global_step"]})
    motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
    try:
        # Extra steps for self-trained models: strip the DistributedDataParallel
        # "module." prefix before loading into the pipeline's UNet.
        state_dict = OrderedDict()
        for key in motion_module_state_dict.keys():
            if key.startswith("module."):
                _key = key.split("module.")[-1]
                state_dict[_key] = motion_module_state_dict[key]
            else:
                state_dict[key] = motion_module_state_dict[key]
        motion_module_state_dict = state_dict
        del state_dict
        missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
        assert len(unexpected) == 0
    except Exception:
        # FIX: was a bare `except:` which also swallowed SystemExit and
        # KeyboardInterrupt.  Fallback for checkpoints that keep their motion
        # modules under a "unet." prefix.
        _tmp_ = OrderedDict()
        for key in motion_module_state_dict.keys():
            if "motion_modules" in key:
                if key.startswith("unet."):
                    _key = key.split('unet.')[-1]
                    _tmp_[_key] = motion_module_state_dict[key]
                else:
                    _tmp_[key] = motion_module_state_dict[key]
        missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
        assert len(unexpected) == 0
        del _tmp_
    del motion_module_state_dict

    pipeline.to(device)
    ### <<< create validation pipeline <<< ###

    # Normalize seeds to a list with one entry per source image.
    random_seeds = config.get("seed", [-1])
    random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
    random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds

    # input test videos (either source video / conditions)
    test_videos = config.video_path
    source_images = config.source_image
    num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps)

    # read size, step from yaml file
    sizes = [config.size] * len(test_videos)
    steps = [config.S] * len(test_videos)

    config.random_seed = []
    prompt = n_prompt = ""
    # NOTE: `source_images` is rebound to a frame array inside the loop; the
    # zip iterator was built beforehand, so iteration is unaffected.
    for idx, (source_image, test_video, random_seed, size, step) in tqdm(
        enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)),
        total=len(test_videos),
        disable=(args.rank != 0),
    ):
        samples_per_video = []

        # manually set random seed for reproduction
        if random_seed != -1:
            torch.manual_seed(random_seed)
            set_seed(random_seed)
        else:
            torch.seed()
        config.random_seed.append(torch.initial_seed())

        if test_video.endswith('.mp4'):
            control = VideoReader(test_video).read()
            # NOTE(review): only the first frame's height is compared to
            # `size`; non-matching inputs are resized square to (size, size).
            if control[0].shape[0] != size:
                control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
            if config.max_length is not None:
                control = control[config.offset: (config.offset + config.max_length)]
            control = np.array(control)

        if source_image.endswith(".mp4"):
            source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size)))
        else:
            source_image = np.array(Image.open(source_image).resize((size, size)))
        H, W, C = source_image.shape

        print(f"current seed: {torch.initial_seed()}")
        init_latents = None

        # Pad the control frames so their count is a multiple of clip length L.
        original_length = control.shape[0]
        if control.shape[0] % config.L > 0:
            control = np.pad(control, ((0, config.L - control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge')

        # NOTE(review): the generator is pinned to cuda:0 on every rank —
        # presumably so all ranks share one noise stream; confirm before
        # changing to the per-rank device.
        generator = torch.Generator(device=torch.device("cuda:0"))
        generator.manual_seed(torch.initial_seed())

        sample = pipeline(
            prompt,
            negative_prompt=n_prompt,
            num_inference_steps=config.steps,
            guidance_scale=config.guidance_scale,
            width=W,
            height=H,
            video_length=len(control),
            controlnet_condition=control,
            init_latents=init_latents,
            generator=generator,
            num_actual_inference_steps=num_actual_inference_steps,
            appearance_encoder=appearance_encoder,
            reference_control_writer=reference_control_writer,
            reference_control_reader=reference_control_reader,
            source_image=source_image,
            **dist_kwargs,
        ).videos

        # Only rank 0 assembles and writes the output grids.
        if args.rank == 0:
            source_images = np.array([source_image] * original_length)
            source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
            samples_per_video.append(source_images)

            control = control / 255.0
            control = rearrange(control, "t h w c -> 1 c t h w")
            control = torch.from_numpy(control)
            samples_per_video.append(control[:, :, :original_length])

            samples_per_video.append(sample[:, :, :original_length])

            samples_per_video = torch.cat(samples_per_video)

            video_name = os.path.basename(test_video)[:-4]
            source_name = os.path.basename(config.source_image[idx]).split(".")[0]
            save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4")
            save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4")

            if config.save_individual_videos:
                save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4")
                save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4")

        if args.dist:
            dist.barrier()

    if args.rank == 0:
        OmegaConf.save(config, f"{savedir}/config.yaml")
def distributed_main(device_id, args):
    """Per-worker entry point used by ``torch.multiprocessing.spawn``.

    Records this worker's rank/device on ``args``, binds the process to its
    CUDA device when one is available, initializes the process group, then
    runs inference via ``main``.
    """
    args.rank = args.device_id = device_id
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device_id)
        torch.cuda.init()
    distributed_init(args)
    main(args)
def run(args):
    """Launch inference in a single process or one process per GPU.

    With ``--dist`` and more than one visible GPU, spawns ``world_size``
    workers rendezvousing on a random local TCP port; without ``--dist``,
    runs ``main`` directly in this process.
    """
    if not args.dist:
        main(args)
        return

    args.world_size = max(1, torch.cuda.device_count())
    assert args.world_size <= torch.cuda.device_count()

    # NOTE(review): with --dist and exactly one GPU, nothing is launched on
    # this path — mirrors the original control flow; confirm it is intended.
    if args.world_size > 0 and torch.cuda.device_count() > 1:
        port = random.randint(10000, 20000)
        args.init_method = f"tcp://localhost:{port}"
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args,),
            nprocs=args.world_size,
        )
if __name__ == "__main__":
    # Command-line entry point: parse distributed-run options and dispatch.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--dist", action="store_true", required=False)
    cli.add_argument("--rank", type=int, default=0, required=False)
    cli.add_argument("--world_size", type=int, default=1, required=False)

    run(cli.parse_args())