# NaRCan_demo / app.py
# Author: Koi953215 — commit 79910d2 ("fix bug"), 12.9 kB
import gradio as gr
import numpy as np
import torch
import cv2
import os
import imageio
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import LineartDetector
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from NaRCan_model import Homography, Siren
from util import get_mgrid, apply_homography, jacobian, VideoFitting, TestVideoFitting
# Torch device string used for every model and tensor in this app: the GPU
# when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_example():
    """Return the example clips listed in the Gradio ``Examples`` widget.

    Each row is a single-element list because the widget feeds exactly one
    input component (the video player). A fresh list is built on every call.
    """
    example_clips = (
        'examples/bear.mp4',
        'examples/boat.mp4',
        'examples/woman-drink.mp4',
        'examples/corgi.mp4',
        'examples/yacht.mp4',
        'examples/koolshooters.mp4',
        'examples/overlook-the-ocean.mp4',
        'examples/rotate.mp4',
        'examples/shark-ocean.mp4',
        'examples/surf.mp4',
        'examples/cactus.mp4',
        'examples/gold-fish.mp4',
    )
    return [[clip] for clip in example_clips]
# Suggested edit prompt for each bundled example clip, keyed by file name.
_DEFAULT_PROMPTS = {
    'bear.mp4': 'bear, Van Gogh Style',
    'boat.mp4': 'a burning boat sails on lava',
    'cactus.mp4': 'cactus, made of paper',
    'corgi.mp4': 'a hellhound',
    'gold-fish.mp4': 'Goldfish in the Milky Way',
    'koolshooters.mp4': 'Avatar',
    'overlook-the-ocean.mp4': 'ocean, pixel style',
    'rotate.mp4': 'turbine engine',
    'shark-ocean.mp4': 'A sleek shark, cartoon style',
    'surf.mp4': 'Sailing, The background is a large white cloud, sketch style',
    'woman-drink.mp4': 'a drinking zombie',
    'yacht.mp4': 'yacht, cyberpunk style',
}


def set_default_prompt(video_name):
    """Return the suggested prompt for ``video_name`` (a bare file name).

    Unknown names yield an empty string.
    """
    try:
        return _DEFAULT_PROMPTS[video_name]
    except KeyError:
        return ''
def update_prompt(input_video):
    """Return the suggested prompt for the clip currently in the video player.

    Args:
        input_video: Path of the loaded clip as passed by Gradio, or ``None``
            when nothing is loaded.

    Returns:
        The prompt from ``set_default_prompt`` for the clip's file name, or an
        empty string when no video is loaded.
    """
    if not input_video:
        # Nothing loaded: clear the prompt instead of raising on None.
        return ''
    # basename handles the platform's path separators (split('/') does not on Windows).
    return set_default_prompt(os.path.basename(input_video))
# Per-clip assets, keyed by the example clip's file name:
#   [edited-canonical source image, checkpoint directory, original-frames directory]
video_to_image = {
    f'{clip}.mp4': [f'canonical/{clip}.png', f'pth_file/{clip}', f'examples_frames/{clip}']
    for clip in (
        'bear',
        'boat',
        'cactus',
        'corgi',
        'gold-fish',
        'koolshooters',
        'overlook-the-ocean',
        'rotate',
        'shark-ocean',
        'surf',
        'woman-drink',
        'yacht',
    )
}
def images_to_video(image_list, output_path, fps=10, max_frames=20):
    """Encode frames into an H.264 (libx264) video file.

    Args:
        image_list: Iterable of frames (PIL images or array-likes); each is
            converted with ``np.array`` and cast to ``uint8`` before encoding.
        output_path: Destination file path.
        fps: Frame rate of the written video.
        max_frames: Write at most this many leading frames (the demo keeps
            the first 20). ``None`` writes every frame.
    """
    import itertools

    # Only the frames that will actually be written are converted.
    frames = itertools.islice(image_list, max_frames)
    # The context manager closes (and finalizes) the file even if a frame fails.
    with imageio.get_writer(output_path, fps=fps, codec='libx264') as writer:
        for frame in frames:
            writer.append_data(np.array(frame).astype(np.uint8))
def _load_narcan_models(pth_path):
    """Restore the two deformation networks saved in *pth_path*.

    Returns ``(g_old, g)``: a ``Homography`` model loaded from
    ``homography_g.pth`` and a ``Siren`` model loaded from ``mlp_g.pth``,
    both moved to ``device`` and switched to eval mode.
    """
    # map_location lets checkpoints saved on a GPU load when device == 'cpu'.
    checkpoint_g_old = torch.load(os.path.join(pth_path, "homography_g.pth"), map_location=device)
    checkpoint_g = torch.load(os.path.join(pth_path, "mlp_g.pth"), map_location=device)
    g_old = Homography(hidden_features=256, hidden_layers=2).to(device)
    g = Siren(in_features=3, out_features=2, hidden_features=256,
              hidden_layers=5, outermost_linear=True).to(device)
    g_old.load_state_dict(checkpoint_g_old)
    g.load_state_dict(checkpoint_g)
    g_old.eval()
    g.eval()
    return g_old, g


def NaRCan_make_video(edit_canonical, pth_path, frames_path):
    """Warp an edited canonical image back onto every frame and encode a video.

    For each frame, the clip's coordinates are deformed by
    ``apply_homography(xy, g_old(t)) + g(xyt)`` and ``edit_canonical`` is
    sampled at the deformed positions with ``grid_sample``.

    Args:
        edit_canonical: PIL image of the edited canonical frame.
        pth_path: Directory holding ``homography_g.pth`` and ``mlp_g.pth``.
        frames_path: Directory of the clip's frames; its entry count is used
            as the number of frames.

    Returns:
        Path of the written video file.

    Raises:
        ValueError: If ``frames_path`` contains no entries.
    """
    data_len = len(os.listdir(frames_path))
    if data_len == 0:
        raise ValueError(f"no frames found in {frames_path!r}")

    g_old, g = _load_narcan_models(pth_path)

    transform = Compose([
        Resize(512),
        ToTensor(),
        Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5]))
    ])
    v = TestVideoFitting(frames_path, transform)
    videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)
    # Only the coordinate tensor is used; the second item (ground truth) is not.
    model_input, _ = next(iter(videoloader))
    model_input = model_input[0].to(device)

    # The canonical image is the same for every frame: convert it once to a
    # (1, 3, H_c, W_c) float tensor on the device.
    canonical_img = torch.from_numpy(np.array(edit_canonical.convert('RGB'))).float().to(device)
    canonical_img = canonical_img.unsqueeze(0).permute(0, 3, 1, 2)

    chunk = v.H * v.W  # coordinates processed per step
    n_coords = len(model_input)
    pieces = []
    with torch.no_grad():
        for step in range(data_len):
            start = (step * chunk) % n_coords
            end = min(start + chunk, n_coords)
            xyt = model_input[start:end]
            xy, t = xyt[:, :-1], xyt[:, [-1]]
            # Deformation: homography warp plus the MLP's residual offset.
            xy_ = apply_homography(xy, g_old(t)) + g(xyt)
            # Swap the two coordinates and rescale them for grid_sample.
            # NOTE(review): the 1.5 / 2.0 factors are presumably the canonical
            # image's extent in model coordinates — confirm.
            grid_new = xy_.clone()
            grid_new[..., 1] = xy_[..., 0] / 1.5
            grid_new[..., 0] = xy_[..., 1] / 2.0
            results = torch.nn.functional.grid_sample(
                canonical_img,
                grid_new.unsqueeze(1).unsqueeze(0),
                mode='bilinear',
                padding_mode='border')
            # (1, 3, N, 1) -> (N, 3)
            pieces.append(results.squeeze().permute(1, 0))
        # Concatenate once instead of re-copying a growing tensor every step.
        myoutput = torch.cat(pieces)

    # NOTE(review): 512x512 is hard-coded; this assumes the resized frames are
    # 512x512 and that model_input is laid out (H, W, T) — confirm against
    # TestVideoFitting / get_mgrid before supporting other resolutions.
    # Casting to uint8 truncates exactly as the former per-frame PIL round-trip did.
    frames = (myoutput.reshape(512, 512, data_len, 3)
              .permute(2, 0, 1, 3)
              .cpu()
              .numpy()
              .astype(np.uint8))

    # NOTE(review): fixed file name — concurrent requests overwrite each other's output.
    edit_video_path = 'NaRCan_fps_10.mp4'
    images_to_video(frames, edit_video_path)
    return edit_video_path
# Loading Stable Diffusion + ControlNet from disk dominates the cost of an
# edit, so the most recently used pipeline and the lineart annotator are kept
# alive between requests instead of being reloaded on every click.
_PIPELINE_CACHE = {}
_ANNOTATOR_CACHE = {}


def _get_controlnet_pipeline(control_type):
    """Return a Stable Diffusion 1.5 + ControlNet pipeline for *control_type*.

    "Lineart" selects the lineart ControlNet; any other value selects Canny.
    At most one pipeline is cached, so switching control types never keeps
    two full pipelines resident on the device.
    """
    key = "Lineart" if control_type == "Lineart" else "Canny"
    pipe = _PIPELINE_CACHE.get(key)
    if pipe is not None:
        return pipe
    _PIPELINE_CACHE.clear()
    repo = ("lllyasviel/control_v11p_sd15_lineart" if key == "Lineart"
            else "lllyasviel/control_v11p_sd15_canny")
    # Half precision only on CUDA; fp16 inference is not generally supported on CPU.
    dtype = torch.float16 if device == 'cuda' else torch.float32
    controlnet = ControlNetModel.from_pretrained(repo, torch_dtype=dtype)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=dtype
    )
    pipe.to(device)
    _PIPELINE_CACHE[key] = pipe
    return pipe


def _lineart_control_image(image_path):
    """Return the lineart control image for the canonical image at *image_path*.

    Detection runs on a 768x768 resize; the result is resized back to the
    canonical image's original size.
    """
    detector = _ANNOTATOR_CACHE.get("lineart")
    if detector is None:
        detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
        _ANNOTATOR_CACHE["lineart"] = detector
    size_ = 768
    with Image.open(image_path) as canonical_image:
        ori_size = canonical_image.size
        resized = canonical_image.resize((size_, size_))
    image = detector(resized, coarse=False, detect_resolution=size_, image_resolution=size_)
    return image.resize(ori_size, resample=Image.BILINEAR)


def _canny_control_image(image_path):
    """Return a 3-channel Canny edge map (thresholds 100/200) of *image_path*."""
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read canonical image {image_path!r}")
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    return Image.fromarray(np.stack([edges, edges, edges], axis=2))


def edit_with_pnp(input_video, prompt, num_steps, guidance_scale, seed, n_prompt, control_type="Lineart"):
    """Edit an example clip's canonical image with ControlNet and re-render the clip.

    Args:
        input_video: Path of the loaded clip (only its file name is used).
        prompt: Text prompt for the edit.
        num_steps: Number of diffusion inference steps.
        guidance_scale: Classifier-free guidance scale.
        seed: RNG seed; ``-1`` leaves the generator unseeded.
        n_prompt: Negative prompt.
        control_type: "Lineart" for lineart conditioning, anything else for Canny.

    Returns:
        Path of the rendered video, or ``None`` when no video is loaded or
        the clip is not one of the bundled examples.
    """
    if not input_video:
        return None
    assets = video_to_image.get(os.path.basename(input_video))
    if assets is None:
        return None
    image_path, pth_path, frames_path = assets

    if control_type == "Lineart":
        control_image = _lineart_control_image(image_path)
    else:
        control_image = _canny_control_image(image_path)

    pipe = _get_controlnet_pipeline(control_type)
    generator = torch.manual_seed(seed) if seed != -1 else None
    output_images = pipe(
        prompt=prompt,
        image=control_image,
        # Gradio sliders may deliver floats; the scheduler needs an int.
        num_inference_steps=int(num_steps),
        guidance_scale=guidance_scale,
        negative_prompt=n_prompt,
        generator=generator
    ).images
    # Only the first generated image is propagated through the video.
    return NaRCan_make_video(output_images[0], pth_path, frames_path)
########
# demo #
########
intro = """
<div style="text-align:center">
<h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
NaRCan - <small>Natural Refined Canonical Image</small>
</h1>
<span>[<a target="_blank" href="https://koi953215.github.io/NaRCan_page/">Project page</a>], [<a target="_blank" href="https://huggingface.co/papers/2406.06523">Paper</a>]</span>
<div style="display:flex; justify-content: center;margin-top: 0.5em">Each edit takes ~10 sec </div>
</div>
"""
# Component creation order below determines the page layout; keep it stable.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    with gr.Row():
        input_video = gr.Video(label="Input Video", interactive=False, elem_id="input_video", value='examples/bear.mp4')
        output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video")
        # NOTE(review): Component.style() is deprecated in late Gradio 3.x and
        # removed in Gradio 4 — pass height/width to the constructors when upgrading.
        input_video.style(height=365, width=365)
        output_video.style(height=365, width=365)
    with gr.Row():
        prompt = gr.Textbox(
            label="Describe your edited video",
            max_lines=1,
            value="bear, Van Gogh Style"
        )
    with gr.Row():
        run_button = gr.Button("Edit your video!", visible=True)
    with gr.Accordion('Advanced options', open=False):
        control_type = gr.Dropdown(
            ["Canny", "Lineart"],
            label="Control Type",
            info="Canny or Lineart",
            value="Lineart"
        )
        num_steps = gr.Slider(label='Steps',
                              minimum=1,
                              maximum=100,
                              value=20,
                              step=1)
        guidance_scale = gr.Slider(label='Guidance Scale',
                                   minimum=0.1,
                                   maximum=30.0,
                                   value=9.0,
                                   step=0.1)
        # -1 means "unseeded" to edit_with_pnp; the initial value is randomized.
        seed = gr.Slider(label='Seed',
                         minimum=-1,
                         maximum=2147483647,
                         step=1,
                         randomize=True)
        n_prompt = gr.Textbox(
            label='Negative Prompt',
            value=""
        )

    # Refresh the suggested prompt whenever a different example clip is loaded.
    input_video.change(
        fn=update_prompt,
        inputs=[input_video],
        outputs=[prompt],
        queue=False)
    run_button.click(fn=edit_with_pnp,
                     inputs=[input_video,
                             prompt,
                             num_steps,
                             guidance_scale,
                             seed,
                             n_prompt,
                             control_type,
                             ],
                     outputs=[output_video]
                     )
    gr.Examples(
        examples=get_example(),
        label='Examples',
        inputs=[input_video],
        outputs=[output_video],
        examples_per_page=8
    )
demo.queue()
demo.launch()