|
import sys |
|
import argparse |
|
import os |
|
import cv2 |
|
import glob |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
|
|
from .raft import RAFT |
|
from .utils import flow_viz |
|
from .utils.utils import InputPadder |
|
|
|
|
|
|
|
# Run on the GPU when one is available; fall back to the CPU otherwise so the
# module can still be used on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
def load_image(imfile):
    """Load an image file as a float tensor in channels-first layout.

    The image is converted to 3-channel RGB first. Grayscale and palette
    images therefore yield a (3, H, W) tensor instead of a 2-D array that
    would make ``permute`` fail. RGBA images drop their alpha channel.

    Args:
        imfile: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        torch.Tensor of shape (3, H, W) and dtype float32. Pixel values are
        in [0, 255]; no normalization is applied.
    """
    # The context manager releases the file handle promptly, instead of
    # whenever the Image object happens to be garbage-collected.
    with Image.open(imfile) as pil_img:
        arr = np.array(pil_img.convert('RGB'), dtype=np.uint8)
    # (H, W, C) -> (C, H, W), then uint8 -> float32.
    return torch.from_numpy(arr).permute(2, 0, 1).float()
|
|
|
|
|
def load_image_list(image_files):
    """Load, batch, move and pad a collection of image files.

    Files are loaded with ``load_image`` in sorted path order. They are
    stacked into one batch, moved to ``DEVICE`` and padded with
    ``InputPadder``.

    Args:
        image_files: Iterable of image paths. All images must share the
            same size, because they are combined with ``torch.stack``.

    Returns:
        The first element of ``InputPadder(...).pad(batch)``, i.e. the
        padded (N, C, H', W') batch. The exact padded size is determined by
        ``InputPadder``.

    Raises:
        ValueError: If ``image_files`` is empty.
    """
    paths = sorted(image_files)
    if not paths:
        # torch.stack on an empty list raises an opaque RuntimeError; fail
        # early with a message that names the actual problem.
        raise ValueError('load_image_list() received no image files')

    batch = torch.stack([load_image(path) for path in paths], dim=0)
    batch = batch.to(DEVICE)

    padder = InputPadder(batch.shape)
    return padder.pad(batch)[0]
|
|
|
|
|
def viz(img, flo, out_path='/home/chengao/test/flow.png'):
    """Colorize the first flow field of a batch and write it to ``out_path``.

    Args:
        img: Batched image tensor. It does not contribute to the written
            image; the parameter is kept so existing callers keep working.
        flo: Batched flow tensor. ``flo[0]`` is moved to channels-last layout
            (presumably (N, 2, H, W) -> (H, W, 2); TODO confirm) and passed
            to ``flow_viz.flow_to_image``.
        out_path: Destination file. The default is the path that was
            previously hard-coded, so callers that omit it see the same
            behavior. Note that every call overwrites this one file.

    Returns:
        bool: The success flag returned by ``cv2.imwrite``. A failed write
        (for example, a missing output directory) also emits a
        ``RuntimeWarning``.
    """
    import warnings

    flow_hwc = flo[0].permute(1, 2, 0).cpu().numpy()
    flow_color = flow_viz.flow_to_image(flow_hwc)

    # cv2.imwrite expects BGR channel order, so reverse the channels of the
    # color-coded flow image (presumably RGB) before writing.
    ok = cv2.imwrite(out_path, flow_color[:, :, [2, 1, 0]])
    if not ok:
        # cv2.imwrite reports failure through its return value rather than an
        # exception; surface it instead of silently dropping the output.
        warnings.warn(f'cv2.imwrite failed to write {out_path!r}', RuntimeWarning)
    return ok
|
|
|
|
|
|
|
|
|
def demo(args):
    """Estimate and visualize optical flow between consecutive images.

    All ``*.png`` and ``*.jpg`` files in ``args.path`` are loaded in sorted
    order. RAFT is run on every consecutive pair with 20 iterations in test
    mode, and each upsampled flow is passed to ``viz``.

    Args:
        args: Parsed command-line namespace. It must provide ``path`` and
            ``model``, plus whatever ``RAFT(args)`` reads.

    Raises:
        FileNotFoundError: If ``args.path`` contains no PNG or JPG files.
    """
    image_files = (glob.glob(os.path.join(args.path, '*.png'))
                   + glob.glob(os.path.join(args.path, '*.jpg')))
    if not image_files:
        # Fail before the (slow) model load instead of deep inside torch.stack.
        raise FileNotFoundError(f'no *.png or *.jpg files found in {args.path!r}')

    # Model construction and checkpoint loading are shared with RAFT_infer.
    model = RAFT_infer(args)

    with torch.no_grad():
        images = load_image_list(image_files)
        for i in range(images.shape[0] - 1):
            # Index with None to keep a leading batch dimension of size 1.
            image1 = images[i, None]
            image2 = images[i + 1, None]

            # Only the upsampled flow is visualized; the low-res one is unused.
            _, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up)
|
|
|
|
|
def RAFT_infer(args):
    """Build a RAFT model, load its weights and prepare it for inference.

    Args:
        args: Namespace passed to ``RAFT(args)``; ``args.model`` is the
            checkpoint path.

    Returns:
        The unwrapped RAFT module, moved to ``DEVICE`` and in eval mode.
    """
    # Wrap in DataParallel only so the state-dict keys line up, then unwrap.
    # Presumably the checkpoint was saved from a DataParallel model, so its
    # keys carry a 'module.' prefix (TODO confirm).
    wrapped = torch.nn.DataParallel(RAFT(args))

    # Deserialize on the CPU. load_state_dict copies the tensors into the
    # model's parameters, and the model is moved to DEVICE below. This also
    # lets a GPU-saved checkpoint load on a machine without CUDA.
    state_dict = torch.load(args.model, map_location='cpu')
    wrapped.load_state_dict(state_dict)

    model = wrapped.module
    model.to(DEVICE)
    model.eval()
    return model
|
|