Spaces:
Sleeping
Sleeping
File size: 10,816 Bytes
be791d6 9e8c2c6 be791d6 9e8c2c6 be791d6 9e8c2c6 be791d6 9e8c2c6 be791d6 9e8c2c6 be791d6 9e8c2c6 be791d6 9e8c2c6 be791d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 |
import os
import torch
import argparse
import torchvision
from pipeline_videogen import VideoGenPipeline
from pipelines.pipeline_inversion import VideoGenInversionPipeline
from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL
from diffusers.models import AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf
import os, sys
sys.path.append(os.path.split(sys.path[0])[0])
from utils import find_model
from models import get_models
import imageio
import decord
import numpy as np
from copy import deepcopy
from PIL import Image
from datasets import video_transforms
from torchvision import transforms
from models.unet import UNet3DConditionModel
from einops import repeat
from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
    """Load an image from disk and encode it into a single-frame VAE latent.

    Args:
        path: Path of the image file; it is opened in binary mode and
            converted to RGB.
        vae: Autoencoder exposing ``encode(...).latent_dist.sample()`` and
            ``config.scaling_factor``.
        transform_video: Callable applied to the ``[1, 3, H, W]`` uint8 frame.
            It must return a 5-tuple whose first element is the transformed
            frame; the other four values are discarded here.
        device: Device the frame is moved to before encoding.
        dtype: Dtype the frame is cast to before encoding.

    Returns:
        The sampled latent multiplied by ``vae.config.scaling_factor``, with a
        new axis inserted at dim 2 (presumably the frame axis, giving
        ``[B, C, 1, h, w]`` -- TODO confirm against the VAE used).
    """
    with open(path, 'rb') as handle:
        rgb = Image.open(handle).convert('RGB')

    # HWC uint8 pixels -> a one-frame batch in channels-first layout.
    pixels = np.array(rgb, dtype=np.uint8, copy=True)
    frame = torch.as_tensor(pixels).unsqueeze(0).permute(0, 3, 1, 2)

    # Only the transformed frame is used; the size / crop values are ignored.
    frame, _, _, _, _ = transform_video(frame)

    frame = frame.to(dtype=dtype, device=device)
    posterior = vae.encode(frame).latent_dist
    latent = posterior.sample().mul_(vae.config.scaling_factor)
    return latent.unsqueeze(2)
def separation_content_motion(video_clip):
    """
    Split a clip into its first frame and the offsets of later frames from it.

    Args:
        video_clip: Clip to split, shape [B, C, F, H, W]
    Return:
        base_frame: First frame along the F axis, shape [B, C, 1, H, W]
        motions: Remaining frames minus the base frame, shape [B, C, F-1, H, W]
    """
    everything = slice(None)

    # Slice (rather than index) the frame axis so dim 2 is kept with size 1
    # and the subtraction below broadcasts over the remaining frames.
    first = video_clip[everything, everything, slice(None, 1), everything, everything]
    rest = video_clip[everything, everything, slice(1, None), everything, everything]

    return first, rest - first
class DecordInit(object):
    """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""

    def __init__(self, num_threads=1):
        """Store the reader settings.

        Args:
            num_threads: Value forwarded as ``num_threads`` to
                ``decord.VideoReader`` on every call.
        """
        self.num_threads = num_threads
        # Decoding context passed to every reader: decord.cpu(0).
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Perform the Decord initialization.

        Args:
            filename: Video source handed to ``decord.VideoReader``.

        Returns:
            The ``decord.VideoReader`` created for ``filename``.
        """
        return decord.VideoReader(filename,
                                  ctx=self.ctx,
                                  num_threads=self.num_threads)

    def __repr__(self):
        # Only report attributes that __init__ actually sets; the previous
        # version read a non-existent ``self.sr`` and raised AttributeError.
        return f'{self.__class__.__name__}(num_threads={self.num_threads})'
def main(args):
    """Edit a video: invert a source clip, then re-sample it from an edited first frame.

    The steps, in order:
      1. Load the UNet, VAE(s), CLIP tokenizer/text encoder and two DDIM
         schedulers from ``args.pretrained_model_path``.
      2. Read 16 frames of the hard-coded source video, encode them with the
         VAE and split the latents with ``separation_content_motion``.
      3. Encode the hard-coded edited image with ``prepare_image``.
      4. Run ``VideoGenInversionPipeline`` on the motion latents.
      5. If ``args.use_dct`` is set, mix the inverted latents with a noised copy
         of the edited content via ``exchanged_mixed_dct_freq``.
      6. Run ``VideoGenPipeline`` and write the first result as an MP4 inside
         ``args.save_img_path``.

    Args:
        args: Configuration object (the entry point passes the result of
            ``OmegaConf.load``). Attributes read here: ``pretrained_model_path``,
            ``enable_vae_temporal_decoder``, ``use_dct``, ``beta_start``,
            ``beta_end``, ``beta_schedule``, ``image_size`` (indexed as
            ``[0]`` = height, ``[1]`` = width), ``save_img_path``,
            ``video_length``, ``num_sampling_steps``, ``guidance_scale``,
            ``motion_bucket_id`` and ``run_time``.
    """
    # torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16  # precision of the UNet, text encoder and main VAE

    # unet = get_models(args).to(device, dtype=torch.float16)
    # state_dict = find_model(args.ckpt)
    # unet.load_state_dict(state_dict)
    unet = UNet3DConditionModel.from_pretrained(
        args.pretrained_model_path, subfolder="unet"
    ).to(device, dtype=dtype)

    # ``vae_for_base_content`` encodes the edited image (float64 when DCT
    # mixing is used); ``vae`` is a float16 copy used for everything else.
    if args.enable_vae_temporal_decoder:
        content_vae_dtype = torch.float64 if args.use_dct else torch.float16
        vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(
            args.pretrained_model_path,
            subfolder="vae_temporal_decoder",
            torch_dtype=content_vae_dtype,
        ).to(device)
    else:
        vae_for_base_content = AutoencoderKL.from_pretrained(
            args.pretrained_model_path, subfolder="vae",
        ).to(device, dtype=torch.float64)
    vae = deepcopy(vae_for_base_content).to(dtype=dtype)

    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16
    ).to(device)

    # set eval mode
    unet.eval()
    vae.eval()
    text_encoder.eval()

    scheduler_inversion = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                                        subfolder="scheduler",
                                                        beta_start=args.beta_start,
                                                        beta_end=args.beta_end,
                                                        beta_schedule=args.beta_schedule,)

    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                              subfolder="scheduler",
                                              beta_start=args.beta_start,
                                              beta_end=args.beta_end,
                                              # beta_end=0.017,
                                              beta_schedule=args.beta_schedule,)

    # NOTE(review): the sampling pipeline receives ``scheduler_inversion`` and
    # the inversion pipeline receives ``scheduler``. Both schedulers are built
    # from identical arguments above, so the pairing is kept unchanged.
    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler_inversion,
                                         unet=unet).to(device)

    videogen_pipeline_inversion = VideoGenInversionPipeline(vae=vae,
                                                            text_encoder=text_encoder,
                                                            tokenizer=tokenizer,
                                                            scheduler=scheduler,
                                                            unet=unet).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()
    # videogen_pipeline.enable_vae_slicing()

    transform_video = video_transforms.Compose([
        video_transforms.ToTensorVideo(),
        video_transforms.SDXLCenterCrop((args.image_size[0], args.image_size[1])),  # center crop using short edge, then resize
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    # ---- source video -> latents ------------------------------------------
    # video_path = './video_editing/A_man_walking_on_the_beach.mp4'
    # video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
    video_path = './video_editing/test_03.mp4'
    video_reader = DecordInit()
    video = video_reader(video_path)
    frame_indice = np.linspace(0, 15, 16, dtype=int)  # frames 0..15 of the source clip
    video = torch.from_numpy(video.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
    # Map pixel values from [0, 255] to [-1, 1].
    video = video / 255.0
    video = video * 2.0 - 1.0

    # Encode frames as a batch, then add a batch axis and swap the first two
    # data axes so the result can be split along dim 2.
    latents = (
        vae.encode(video.to(dtype=torch.float16, device=device))
        .latent_dist.sample()
        .mul_(vae.config.scaling_factor)
        .unsqueeze(0)
        .permute(0, 2, 1, 3, 4)
    )
    base_content, motion_latents = separation_content_motion(latents)

    # ---- edited first frame -> latent -------------------------------------
    # image_path = "./video_editing/a_man_walking_in_the_park.png"
    # image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png"
    image_path = "./video_editing/test_03.png"

    # Encode in the dtype the content VAE was actually loaded with. Passing a
    # fixed float16 input here fails when that VAE holds float64 weights
    # (enable_vae_temporal_decoder=False together with use_dct=False).
    edit_content = prepare_image(image_path,
                                 vae_for_base_content,
                                 transform_video,
                                 device,
                                 dtype=vae_for_base_content.dtype).to(device)

    os.makedirs(args.save_img_path, exist_ok=True)

    # ---- inversion ---------------------------------------------------------
    # prompt_inversion = 'a man walking on the beach'
    # prompt_inversion = 'a corgi walking in the park at sunrise, oil painting style'
    # prompt_inversion = 'A girl is playing the guitar in her room'
    prompt_inversion = 'A man is walking inside the church'
    latents = videogen_pipeline_inversion(prompt_inversion,
                                          latents=motion_latents,
                                          base_content=base_content,
                                          video_length=args.video_length,
                                          height=args.image_size[0],
                                          width=args.image_size[1],
                                          num_inference_steps=args.num_sampling_steps,
                                          guidance_scale=1.0,
                                          # guidance_scale=args.guidance_scale,
                                          motion_bucket_id=args.motion_bucket_id,
                                          output_type="latent").video

    # prompt = 'a man walking in the park'
    # prompt = 'a corgi walking in the park at sunrise, oil painting style'
    # prompt = 'A girl is playing the guitar in her room'
    prompt = 'A man is walking inside the church'

    if args.use_dct:
        # filter params
        print("Using DCT!")
        # Repeat the single edited frame 15 times along the frame axis.
        edit_content_repeat = repeat(edit_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()

        # define filter
        freq_filter = dct_low_pass_filter(dct_coefficients=edit_content,
                                          percentage=0.23)

        noise = latents.to(dtype=torch.float64)

        # add noise to base_content
        diffuse_timesteps = torch.full((1,), int(985))
        diffuse_timesteps = diffuse_timesteps.long()

        # 3d content
        edit_content_noise = scheduler.add_noise(
            original_samples=edit_content_repeat.to(device),
            noise=noise,
            timesteps=diffuse_timesteps.to(device))

        # 3d content
        latents = exchanged_mixed_dct_freq(noise=noise,
                                           base_content=edit_content_noise,
                                           LPF_3d=freq_filter).to(dtype=torch.float16)

    # The sampling pipeline runs on float16 modules, so hand it float16
    # tensors regardless of the precision used for encoding / DCT mixing.
    latents = latents.to(dtype=torch.float16)
    edit_content = edit_content.to(dtype=torch.float16)

    # ---- sampling ----------------------------------------------------------
    videos = videogen_pipeline(prompt,
                               latents=latents,
                               base_content=edit_content,
                               video_length=args.video_length,
                               height=args.image_size[0],
                               width=args.image_size[1],
                               num_inference_steps=args.num_sampling_steps,
                               # guidance_scale=1.0,
                               guidance_scale=args.guidance_scale,
                               motion_bucket_id=args.motion_bucket_id,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video

    # Join with os.path.join so the file lands inside save_img_path even when
    # the configured directory has no trailing separator.
    out_name = prompt.replace(' ', '_') + '_%04d' % args.run_time + '-imageio.mp4'
    out_path = os.path.join(args.save_img_path, out_name)
    imageio.mimwrite(out_path, videos[0], fps=8, quality=8)  # highest quality is 10, lowest is 0
    print('save path {}'.format(args.save_img_path))
if __name__ == "__main__":
    # The command line carries only the path of the YAML config; every other
    # setting is read from that file by OmegaConf and handed to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, default="./configs/sample.yaml")
    config_path = cli.parse_args().config
    main(OmegaConf.load(config_path))
|