Spaces:
Runtime error
Runtime error
import os | |
import PIL.Image | |
import numpy as np | |
import torch | |
import torchvision | |
from torchvision.transforms import Resize, InterpolationMode | |
import imageio | |
from einops import rearrange | |
import cv2 | |
from PIL import Image | |
import decord | |
from controlnet_aux import OpenposeDetector | |
# Module-level OpenPose detector used by pre_process_pose(). It is created at
# import time via OpenposeDetector.from_pretrained("lllyasviel/ControlNet"), so
# importing this module presumably downloads/loads the annotator weights
# (network/disk access) — NOTE(review): confirm controlnet_aux caching behaviour.
apply_openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
def prepare_video(
    video_path: str,
    resolution: int,
    device,
    dtype,
    normalize=True,
    start_t: float = 0,
    end_t: float = -1,
    output_fps: int = -1,
):
    """Decode, temporally resample and spatially resize a video clip.

    Frames in ``[start_t, end_t)`` (seconds) are sampled uniformly at
    ``output_fps`` with decord. They are rearranged to ``(f, c, h, w)``, moved
    to ``device``/``dtype``, and bilinearly resized (with antialiasing). After
    resizing, the longer side is about ``resolution`` pixels and both sides are
    rounded to a multiple of 64, never below 64.

    Args:
        video_path: Path to a video readable by ``decord.VideoReader``.
        resolution: Target size in pixels for the longer spatial side.
        device: Torch device of the returned tensor.
        dtype: Torch dtype of the returned tensor.
        normalize: If True, map values with ``v / 127.5 - 1`` (0..255 -> -1..1,
            assuming decord yields uint8 frames); otherwise keep the raw range.
        start_t: Clip start in seconds.
        end_t: Clip end in seconds; ``-1`` means the end of the video. Larger
            values are clamped to the video duration.
        output_fps: Sampling rate of the returned frames; ``-1`` means the
            source average fps truncated to an int.

    Returns:
        ``(video, output_fps)`` where ``video`` has shape ``(f, c, h, w)``.

    Raises:
        AssertionError: (when assertions are enabled) if the time window is not
            ``0 <= start_t < end_t`` or ``output_fps`` is not positive.
    """
    vr = decord.VideoReader(video_path)
    initial_fps = vr.get_avg_fps()
    if output_fps == -1:
        output_fps = int(initial_fps)
    duration = len(vr) / initial_fps
    end_t = duration if end_t == -1 else min(duration, end_t)
    assert 0 <= start_t < end_t, f"invalid time window: start_t={start_t}, end_t={end_t}"
    assert output_fps > 0, f"output_fps must be positive, got {output_fps}"

    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    # A window shorter than one output frame interval would yield zero samples
    # and make vr.get_batch fail on an empty index list; take at least one frame.
    num_f = max(1, int((end_t - start_t) * output_fps))
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)

    video = vr.get_batch(sample_idx)
    # get_batch presumably returns a torch tensor when decord's bridge is set to
    # torch, and a decord NDArray otherwise; normalise both to numpy.
    if torch.is_tensor(video):
        video = video.detach().cpu().numpy()
    else:
        video = video.asnumpy()
    _, h, w, _ = video.shape
    video = rearrange(video, "f h w c -> f c h w")
    video = torch.Tensor(video).to(device=device, dtype=dtype)

    # Scale so the LONGER side matches `resolution` (use min(h, w) instead to
    # make the shorter side match), then snap each side to a multiple of 64.
    # The floor of 64 stops extreme aspect ratios from rounding a side to 0,
    # which would make Resize fail.
    k = float(resolution) / max(h, w)
    h = max(64, int(np.round(h * k / 64.0)) * 64)
    w = max(64, int(np.round(w * k / 64.0)) * 64)
    video = Resize((h, w), interpolation=InterpolationMode.BILINEAR, antialias=True)(video)
    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps
def pre_process_pose(input_video, apply_pose_detect: bool = True):
    """Convert video frames into OpenPose control maps.

    Each frame is rearranged from ``(c, h, w)`` to ``(h, w, c)`` and cast to
    uint8. If ``apply_pose_detect`` is True, it is passed through the
    module-level ``apply_openpose`` detector; otherwise the frame itself is
    used. The result is resized back to the frame's size with
    nearest-neighbour interpolation.

    Args:
        input_video: Iterable of ``(c, h, w)`` torch tensors. Values are cast
            straight to uint8, so they are presumably already in 0..255 (e.g.
            ``prepare_video(..., normalize=False)``) — TODO confirm with callers.
        apply_pose_detect: Run pose detection (True) or pass frames through.

    Returns:
        Float tensor of shape ``(f, c, h, w)`` with values divided by 255.

    Raises:
        ValueError: If ``input_video`` yields no frames.
    """
    maps = []
    for frame in input_video:
        img = rearrange(frame, "c h w -> h w c").cpu().numpy().astype(np.uint8)
        H, W = img.shape[:2]
        if apply_pose_detect:
            detected = apply_openpose(img)
            # The original ControlNet annotator returns (map, pose), whereas
            # controlnet_aux's OpenposeDetector returns a single image (PIL or
            # ndarray). Accept either instead of unconditionally unpacking.
            if isinstance(detected, tuple):
                detected = detected[0]
            detected_map = np.asarray(detected)
        else:
            detected_map = img
        # The detector may produce output at its own resolution; restore the
        # frame size so all maps line up with the input video.
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
        maps.append(detected_map[None])
    if not maps:
        raise ValueError("pre_process_pose: input_video contains no frames")
    # np.concatenate already returns a fresh, writable array, so no copy is needed.
    control = torch.from_numpy(np.concatenate(maps)).float() / 255.0
    return rearrange(control, "f h w c -> f c h w")
def create_gif(frames, fps, rescale=False, path=None, watermark=None):
    """Write ``frames`` to an animated GIF and return its path.

    Each frame goes through ``torch.Tensor`` and
    ``torchvision.utils.make_grid(..., nrow=4)``. It is then optionally
    rescaled from -1..1 to 0..1, clamped to 0..1, and quantised to uint8.

    Args:
        frames: Iterable of per-frame arrays/tensors with float values in 0..1
            (or -1..1 when ``rescale=True``).
        fps: Playback rate; ``imageio.mimsave`` receives ``duration=1/fps``.
            NOTE(review): imageio's legacy GIF writer reads ``duration`` in
            seconds, but the pillow v3 plugin reads milliseconds — confirm
            against the installed imageio version.
        rescale: Map values from -1..1 to 0..1 before quantisation.
        path: Output file. Defaults to ``temporal/canny_db.gif``; the
            ``temporal`` directory is created if missing.
        watermark: Accepted for call-site compatibility; currently unused.

    Returns:
        The path the GIF was written to.
    """
    if path is None:
        out_dir = "temporal"
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, 'canny_db.gif')
    outputs = []
    for frame in frames:
        x = torchvision.utils.make_grid(torch.Tensor(frame), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp before quantising: values slightly outside 0..1 would otherwise
        # overflow/wrap when cast to uint8 and show up as speckle artifacts.
        x = (x.clamp(0.0, 1.0) * 255).numpy().astype(np.uint8)
        outputs.append(x)
    imageio.mimsave(path, outputs, duration=1/fps)
    return path
def post_process_gif(list_of_results, image_resolution, output_file="/tmp/ddxk.gif", fps=4):
    """Write ``list_of_results`` to a GIF file and return its path.

    Args:
        list_of_results: Sequence of frames accepted by ``imageio.mimsave``.
        image_resolution: Unused; kept for call-site compatibility.
        output_file: Destination path. The default, ``/tmp/ddxk.gif``, is one
            fixed file that concurrent callers share and overwrite. Pass a
            unique path when several GIFs may be produced at once.
        fps: Playback rate; ``imageio.mimsave`` receives ``duration=1/fps``.
            The default of 4 reproduces the previous hard-coded ``1/4``.

    Returns:
        ``output_file``.
    """
    imageio.mimsave(output_file, list_of_results, duration=1 / fps)
    return output_file