# CoherentControl / utils.py
# Last commit f687279 by foz: "Fix looping" (file size 3.04 kB)
import os
import PIL.Image
import numpy as np
import torch
import torchvision
from torchvision.transforms import Resize, InterpolationMode
import imageio
from einops import rearrange
import cv2
from PIL import Image
import decord
from controlnet_aux import OpenposeDetector
# Module-level OpenPose detector built once, at import time, from the
# "lllyasviel/ControlNet" checkpoint; pre_process_pose calls it per frame.
# NOTE(review): because this runs on import, importing this module presumably
# loads (and possibly downloads) the detector weights — consider lazy
# construction if import cost matters.
apply_openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
def prepare_video(video_path: str, resolution: int, device, dtype, normalize=True,
                  start_t: float = 0, end_t: float = -1, output_fps: int = -1):
    """Load a time window of a video as a resized tensor of frames.

    Frames in ``[start_t, end_t)`` seconds are sampled uniformly at
    ``output_fps`` with ``decord``, rearranged to ``(frames, channels,
    height, width)``, moved to ``device``/``dtype`` and resized so the longer
    side is about ``resolution`` pixels, with both sides snapped to a
    multiple of 64 (never below 64).

    Args:
        video_path: Path of the video file to read.
        resolution: Target length, in pixels, of the longer frame side before
            snapping to a multiple of 64.
        device: Torch device for the returned tensor.
        dtype: Torch dtype for the returned tensor.
        normalize: If True, map values with ``x / 127.5 - 1`` (i.e.
            ``[0, 255]`` -> ``[-1, 1]``, assuming 8-bit decoded frames).
        start_t: Window start in seconds.
        end_t: Window end in seconds; ``-1`` means the end of the video.
            Values past the end of the video are clamped to its duration.
        output_fps: Sampling rate of the returned clip; ``-1`` means the
            video's average frame rate truncated to an int.

    Returns:
        ``(video, output_fps)``: the frame tensor of shape
        ``(frames, channels, height, width)`` and the effective sampling rate.

    Raises:
        ValueError: If ``not 0 <= start_t < end_t`` (after clamping) or the
            effective ``output_fps`` is not positive.
    """
    vr = decord.VideoReader(video_path)
    initial_fps = vr.get_avg_fps()
    duration = len(vr) / initial_fps
    if output_fps == -1:
        output_fps = int(initial_fps)
    end_t = duration if end_t == -1 else min(duration, end_t)

    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if not 0 <= start_t < end_t:
        raise ValueError(
            f"invalid time range: start_t={start_t}, end_t={end_t} "
            f"(video duration {duration:.3f}s)"
        )
    if output_fps <= 0:
        raise ValueError(f"output_fps must be positive, got {output_fps}")

    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    # A window shorter than one output frame period would otherwise yield
    # zero samples and an empty batch request; always take at least one frame.
    num_f = max(1, int((end_t - start_t) * output_fps))
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)

    video = vr.get_batch(sample_idx)
    # decord presumably hands back a torch tensor when its torch bridge is
    # active and its own NDArray otherwise; normalize both to numpy.
    if torch.is_tensor(video):
        video = video.detach().cpu().numpy()
    else:
        video = video.asnumpy()

    _, h, w, _ = video.shape
    video = rearrange(video, "f h w c -> f c h w")
    # Convert straight from the decoded array to the target device/dtype
    # instead of first materializing a float32 copy on the CPU.
    video = torch.from_numpy(np.ascontiguousarray(video)).to(device=device, dtype=dtype)

    # Scale so the *longer* side matches ``resolution`` (swap max for min to
    # match the shorter side instead), then snap each side to a multiple of
    # 64; clamp to 64 so Resize never receives a zero-sized target.
    k = float(resolution) / max(h, w)
    h = max(64, int(np.round(h * k / 64.0)) * 64)
    w = max(64, int(np.round(w * k / 64.0)) * 64)
    video = Resize((h, w), interpolation=InterpolationMode.BILINEAR, antialias=True)(video)

    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps
def pre_process_pose(input_video, apply_pose_detect: bool = True):
    """Build per-frame control maps for pose conditioning.

    Each frame is rearranged from ``(C, H, W)`` to ``(H, W, C)``, cast to
    ``uint8`` and, when ``apply_pose_detect`` is set, passed through the
    module-level ``apply_openpose`` detector. The resulting map is resized
    back to the frame's size with nearest-neighbour interpolation.

    Args:
        input_video: Iterable of ``(C, H, W)`` frame tensors. Values are cast
            directly to ``uint8``, so they are presumably expected in
            ``[0, 255]`` (e.g. ``prepare_video(..., normalize=False)``) —
            confirm with callers.
        apply_pose_detect: If False, the frames themselves are used as the
            control maps.

    Returns:
        Float tensor of shape ``(F, C, H, W)`` with values divided by 255.

    Raises:
        ValueError: If ``input_video`` yields no frames.
    """
    maps = []
    for frame in input_video:
        img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        if apply_pose_detect:
            detected_map, _ = apply_openpose(img)
        else:
            detected_map = img
        height, width = img.shape[:2]
        # Bring the detector output back to the input frame size (cv2 takes (W, H)).
        maps.append(cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_NEAREST))

    if not maps:
        raise ValueError("pre_process_pose: input_video contains no frames")

    # np.stack already returns a fresh array, so no extra copy is needed
    # before sharing its memory with torch.from_numpy.
    control = torch.from_numpy(np.stack(maps)).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')
def create_gif(frames, fps, rescale=False, path=None, watermark=None):
    """Write ``frames`` to a looping GIF and return the output path.

    Each element of ``frames`` is wrapped with ``torch.Tensor`` and passed
    through ``torchvision.utils.make_grid(..., nrow=4)``. If requested, it is
    mapped from ``[-1, 1]`` to ``[0, 1]``, then scaled by 255 and cast to
    ``uint8`` before being handed to ``imageio.mimsave``.
    NOTE(review): no channel transpose is applied here, so callers presumably
    pass frames already laid out the way imageio expects — confirm.

    Args:
        frames: Iterable of per-frame arrays/tensors, presumably with values
            in ``[0, 1]`` (or ``[-1, 1]`` with ``rescale=True``) since they
            are multiplied by 255.
        fps: Playback rate; passed to imageio as ``duration=1000 / fps``.
            This is milliseconds per frame with the Pillow-backed GIF
            writer — confirm against the installed imageio version.
        rescale: If True, map frame values from ``[-1, 1]`` to ``[0, 1]``
            before scaling.
        path: Output file. Defaults to ``temporal/canny_db.gif``; the
            ``temporal`` directory is created if missing.
        watermark: Accepted for interface compatibility; currently unused.

    Returns:
        The path the GIF was written to.
    """
    if path is None:
        out_dir = "temporal"
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, 'canny_db.gif')

    outputs = []
    for frame in frames:
        grid = torchvision.utils.make_grid(torch.Tensor(frame), nrow=4)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        outputs.append((grid * 255).numpy().astype(np.uint8))

    # loop=0 asks the GIF writer to repeat the animation indefinitely.
    imageio.mimsave(path, outputs, loop=0, duration=1000 / fps)
    return path