import os
import sys
import argparse

import torch
import torchvision

from pipeline_videogen import VideoGenPipeline
from pipelines.pipeline_inversion import VideoGenInversionPipeline

from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL
from diffusers.models import AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf

# Make the repository root importable so that utils, models and datasets resolve.
sys.path.append(os.path.split(sys.path[0])[0])
from utils import find_model
from models import get_models

import imageio
import decord
import numpy as np
from copy import deepcopy
from PIL import Image

from datasets import video_transforms
from torchvision import transforms

def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
    """Load an image, apply the video transform, and encode it as a one-frame latent."""
    with open(path, 'rb') as f:
        image = Image.open(f).convert('RGB')
    # HWC uint8 image -> [1, C, H, W] tensor expected by the video transform.
    image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
    image, ori_h, ori_w, crops_coords_top, crops_coords_left = transform_video(image)
    # Encode with the VAE and add a frame axis: [B, C, H, W] -> [B, C, 1, H, W].
    image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
    image = image.unsqueeze(2)
    return image
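# A minimal usage sketch for prepare_image (not executed here). The concrete latent
# shape assumes a 4-channel VAE with 8x spatial downsampling and a 512x512 crop;
# adjust for the actual config.
#   edit = prepare_image('./video_editing/a_cute_corgi_walking_in_the_park.png',
#                        vae, transform_video, device)
#   edit.shape  # -> torch.Size([1, 4, 1, 64, 64])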
def separation_content_motion(video_clip):
    """
    Separate content and motion in a given video.

    Args:
        video_clip: A given video clip, shape [B, C, F, H, W]

    Returns:
        base_frame: Base frame, shape [B, C, 1, H, W]
        motions: Motions relative to the base frame, shape [B, C, F-1, H, W]
    """
    base_frame = video_clip[:, :, :1, :, :]
    motions = video_clip[:, :, 1:, :, :] - base_frame
    return base_frame, motions
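# Shape-contract sketch for separation_content_motion (illustrative only, using a
# hypothetical random latent clip):
#   clip = torch.randn(1, 4, 16, 64, 64)                  # [B, C, F, H, W]
#   base, motions = separation_content_motion(clip)
#   base.shape     # -> [1, 4, 1, 64, 64]
#   motions.shape  # -> [1, 4, 15, 64, 64]
#   torch.allclose(torch.cat([base, motions + base], dim=2), clip)  # True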
class DecordInit(object):
    """Use Decord (https://github.com/dmlc/decord) to initialize the video reader."""

    def __init__(self, num_threads=1):
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Perform the Decord initialization.

        Args:
            filename (str): Path of the video file to read.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str
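# Usage sketch (illustrative; 'example.mp4' is a placeholder path):
#   vr = DecordInit()('example.mp4')
#   frames = vr.get_batch(np.arange(16)).asnumpy()  # [16, H, W, 3] uint8 frames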
def main(args):

    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16

    # Load the denoising UNet and its checkpoint.
    unet = get_models(args).to(device, dtype=torch.float16)
    state_dict = find_model(args.ckpt)
    unet.load_state_dict(state_dict)

    # vae_for_base_content keeps the precision chosen below; the pipelines use a float16 copy (vae).
    if args.enable_vae_temporal_decoder:
        if args.use_dct:
            vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device)
        else:
            vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
        vae = deepcopy(vae_for_base_content).to(dtype=dtype)
    else:
        vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device, dtype=torch.float64)
        vae = deepcopy(vae_for_base_content).to(dtype=dtype)

    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)

    unet.eval()
    vae.eval()
    text_encoder.eval()
    # Two identically configured DDIM schedulers, one per pipeline.
    scheduler_inversion = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                                        subfolder="scheduler",
                                                        beta_start=args.beta_start,
                                                        beta_end=args.beta_end,
                                                        beta_schedule=args.beta_schedule)

    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                              subfolder="scheduler",
                                              beta_start=args.beta_start,
                                              beta_end=args.beta_end,
                                              beta_schedule=args.beta_schedule)
    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         unet=unet).to(device)

    videogen_pipeline_inversion = VideoGenInversionPipeline(vae=vae,
                                                            text_encoder=text_encoder,
                                                            tokenizer=tokenizer,
                                                            scheduler=scheduler_inversion,
                                                            unet=unet).to(device)
    transform_video = video_transforms.Compose([
        video_transforms.ToTensorVideo(),
        video_transforms.SDXLCenterCrop((args.image_size[0], args.image_size[1])),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    # Source video to be edited; its frames are assumed to already match args.image_size,
    # since no resizing transform is applied before encoding.
    video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'

    video_reader = DecordInit()
    video = video_reader(video_path)
    frame_indice = np.linspace(0, 15, 16, dtype=int)  # first 16 frames
    video = torch.from_numpy(video.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
    # Rescale uint8 frames to [-1, 1] before VAE encoding.
    video = video / 255.0
    video = video * 2.0 - 1.0
    # Encode every frame, then reshape to [B, C, F, H, W].
    latents = vae.encode(video.to(dtype=torch.float16, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor).unsqueeze(0).permute(0, 2, 1, 3, 4)

    # Split the clip into the first-frame content latent and the residual motion latents.
    base_content, motion_latents = separation_content_motion(latents)
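    # At this point (sketch, assuming 16 frames at 512x512 and a 4-channel, 8x-downsampling VAE):
    #   base_content:   [1, 4, 1, 64, 64]
    #   motion_latents: [1, 4, 15, 64, 64]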
    # Edited first frame that will replace the original base content during generation.
    image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png"
    edit_content = prepare_image(image_path, vae, transform_video, device, dtype=torch.float16).to(device)

    if not os.path.exists(args.save_img_path):
        os.makedirs(args.save_img_path)
    # DDIM inversion: recover noise latents for the motion component under the source prompt.
    prompt_inversion = 'a corgi walking in the park at sunrise, oil painting style'
    latents = videogen_pipeline_inversion(prompt_inversion,
                                          latents=motion_latents,
                                          base_content=base_content,
                                          video_length=args.video_length,
                                          height=args.image_size[0],
                                          width=args.image_size[1],
                                          num_inference_steps=args.num_sampling_steps,
                                          guidance_scale=1.0,  # no classifier-free guidance during inversion
                                          motion_bucket_id=args.motion_bucket_id,
                                          output_type="latent").video
    # Regenerate the clip from the inverted latents, using the edited frame as base content.
    prompt = 'a corgi walking in the park at sunrise, oil painting style'
    videos = videogen_pipeline(prompt,
                               latents=latents,
                               base_content=edit_content,
                               video_length=args.video_length,
                               height=args.image_size[0],
                               width=args.image_size[1],
                               num_inference_steps=args.num_sampling_steps,
                               guidance_scale=1.0,
                               motion_bucket_id=args.motion_bucket_id,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
    imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '-imageio.mp4', videos[0], fps=8, quality=8)
    print('save path {}'.format(args.save_img_path))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/sample.yaml")
    args = parser.parse_args()

    main(OmegaConf.load(args.config))
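# Run sketch: invoke this script with a YAML config that defines the fields read in
# main() (ckpt, pretrained_model_path, beta_start, beta_end, beta_schedule, image_size,
# video_length, num_sampling_steps, motion_bucket_id, run_time, save_img_path,
# enable_vae_temporal_decoder, use_dct), e.g.:
#   python <this_script>.py --config ./configs/sample.yaml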