init repo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- FGT_codes/FGT/checkpoint/config.yaml +34 -0
- FGT_codes/FGT/checkpoint/fgt.pth.tar +3 -0
- FGT_codes/FGT/config/data_info.yaml +11 -0
- FGT_codes/FGT/config/davis_name2len.pkl +3 -0
- FGT_codes/FGT/config/davis_name2len_train.pkl +3 -0
- FGT_codes/FGT/config/davis_name2len_val.pkl +3 -0
- FGT_codes/FGT/config/train.yaml +93 -0
- FGT_codes/FGT/config/valid_config.yaml +8 -0
- FGT_codes/FGT/config/youtubevos_name2len.pkl +3 -0
- FGT_codes/FGT/data/__init__.py +49 -0
- FGT_codes/FGT/data/train_dataset.py +165 -0
- FGT_codes/FGT/data/util/MaskModel.py +123 -0
- FGT_codes/FGT/data/util/STTN_mask.py +244 -0
- FGT_codes/FGT/data/util/__init__.py +28 -0
- FGT_codes/FGT/data/util/flow_utils/__init__.py +0 -0
- FGT_codes/FGT/data/util/flow_utils/flow_reversal.py +77 -0
- FGT_codes/FGT/data/util/flow_utils/region_fill.py +142 -0
- FGT_codes/FGT/data/util/freeform_masks.py +266 -0
- FGT_codes/FGT/data/util/mask_generators.py +217 -0
- FGT_codes/FGT/data/util/readers.py +527 -0
- FGT_codes/FGT/data/util/util.py +259 -0
- FGT_codes/FGT/data/util/utils.py +158 -0
- FGT_codes/FGT/flowCheckPoint/config.yaml +11 -0
- FGT_codes/FGT/flowCheckPoint/lafc_single.pth.tar +3 -0
- FGT_codes/FGT/inputs.py +83 -0
- FGT_codes/FGT/metrics/__init__.py +31 -0
- FGT_codes/FGT/metrics/psnr.py +10 -0
- FGT_codes/FGT/metrics/ssim.py +46 -0
- FGT_codes/FGT/models/BaseNetwork.py +46 -0
- FGT_codes/FGT/models/__init__.py +0 -0
- FGT_codes/FGT/models/__pycache__/BaseNetwork.cpython-39.pyc +0 -0
- FGT_codes/FGT/models/__pycache__/__init__.cpython-39.pyc +0 -0
- FGT_codes/FGT/models/__pycache__/model.cpython-39.pyc +0 -0
- FGT_codes/FGT/models/lafc_single.py +114 -0
- FGT_codes/FGT/models/model.py +284 -0
- FGT_codes/FGT/models/temporal_patch_gan.py +76 -0
- FGT_codes/FGT/models/transformer_base/__init__.py +0 -0
- FGT_codes/FGT/models/transformer_base/__pycache__/__init__.cpython-39.pyc +0 -0
- FGT_codes/FGT/models/transformer_base/__pycache__/attention_base.cpython-39.pyc +0 -0
- FGT_codes/FGT/models/transformer_base/__pycache__/attention_flow.cpython-39.pyc +0 -0
- FGT_codes/FGT/models/transformer_base/__pycache__/ffn_base.cpython-39.pyc +0 -0
- FGT_codes/FGT/models/transformer_base/attention_base.py +106 -0
- FGT_codes/FGT/models/transformer_base/attention_flow.py +171 -0
- FGT_codes/FGT/models/transformer_base/ffn_base.py +114 -0
- FGT_codes/FGT/models/utils/RAFT/utils/__init__.py +0 -0
- FGT_codes/FGT/models/utils/RAFT/utils/utils.py +82 -0
- FGT_codes/FGT/models/utils/__init__.py +0 -0
- FGT_codes/FGT/models/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- FGT_codes/FGT/models/utils/__pycache__/network_blocks_2d.cpython-39.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
*.pth.tar filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
*.pth.tar filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.o filter=lfs diff=lfs merge=lfs -text
|
FGT_codes/FGT/checkpoint/config.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PASSMASK: 1
|
2 |
+
alpha: 0.3
|
3 |
+
ape: 1
|
4 |
+
cnum: 64
|
5 |
+
conv_type: vanilla
|
6 |
+
dist_cnum: 32
|
7 |
+
drop: 0
|
8 |
+
frame_hidden: 512
|
9 |
+
gd: 4
|
10 |
+
in_channel: 4
|
11 |
+
init_weights: 1
|
12 |
+
input_resolution: !!python/tuple
|
13 |
+
- 240
|
14 |
+
- 432
|
15 |
+
flow_inChannel: 2
|
16 |
+
flow_cnum: 64
|
17 |
+
flow_hidden: 256
|
18 |
+
kernel_size: !!python/tuple
|
19 |
+
- 7
|
20 |
+
- 7
|
21 |
+
mlp_ratio: 40
|
22 |
+
numBlocks: 8
|
23 |
+
num_head: 4
|
24 |
+
padding: !!python/tuple
|
25 |
+
- 3
|
26 |
+
- 3
|
27 |
+
stride: !!python/tuple
|
28 |
+
- 3
|
29 |
+
- 3
|
30 |
+
sw: 8
|
31 |
+
tw: 2
|
32 |
+
use_bias: 1
|
33 |
+
norm: None
|
34 |
+
model: model
|
FGT_codes/FGT/checkpoint/fgt.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:41352263b2d14aec73f0dcf75c4bf5155ddb23404aba6f023a0300aadfd7672f
|
3 |
+
size 157341393
|
FGT_codes/FGT/config/data_info.yaml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dataset general info
|
2 |
+
frame_path: youtubevos_frames
|
3 |
+
flow_path: youtubevos_flows
|
4 |
+
name2len: config/youtubevos_name2len.pkl
|
5 |
+
|
6 |
+
flow:
|
7 |
+
flow_height: 240
|
8 |
+
flow_width: 432
|
9 |
+
augments: False
|
10 |
+
colors: RGB
|
11 |
+
ext: .jpg
|
FGT_codes/FGT/config/davis_name2len.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6607939cc02910f5badaebff46242f299597e93c07d77b6d740a3004f179f50c
|
3 |
+
size 1621
|
FGT_codes/FGT/config/davis_name2len_train.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1ad5e89d5486b38f74ac62d08924a4ff7caa445d34df827385457e8516d4763f
|
3 |
+
size 1073
|
FGT_codes/FGT/config/davis_name2len_val.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:30b2a23f943f40f2a09e98b474b88b07271e46a1224cb415650432d491cc1896
|
3 |
+
size 188
|
FGT_codes/FGT/config/train.yaml
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### General settings
|
2 |
+
name: FGT_train
|
3 |
+
use_tb_logger: true
|
4 |
+
outputdir: /myData/ret/experiments
|
5 |
+
datadir: /myData
|
6 |
+
record_iter: 16
|
7 |
+
|
8 |
+
### Calling definition
|
9 |
+
model: model
|
10 |
+
datasetName_train: train_dataset
|
11 |
+
network: network
|
12 |
+
|
13 |
+
### datasets
|
14 |
+
datasets:
|
15 |
+
train:
|
16 |
+
name: youtubevos
|
17 |
+
type: video
|
18 |
+
mode: train
|
19 |
+
dataInfo_config: ./config/data_info.yaml
|
20 |
+
use_shuffle: True
|
21 |
+
n_workers: 0
|
22 |
+
batch_size: 2
|
23 |
+
|
24 |
+
val:
|
25 |
+
name: youtubevos
|
26 |
+
type: video
|
27 |
+
mode: val
|
28 |
+
use_shuffle: False
|
29 |
+
n_workers: 1
|
30 |
+
batch_size: 1
|
31 |
+
val_config: ./config/valid_config.yaml
|
32 |
+
|
33 |
+
### train settings
|
34 |
+
train:
|
35 |
+
lr: 0.0001
|
36 |
+
lr_decay: 0.1
|
37 |
+
manual_seed: 10
|
38 |
+
BETA1: 0.9
|
39 |
+
BETA2: 0.999
|
40 |
+
MAX_ITERS: 500000
|
41 |
+
UPDATE_INTERVAL: 300000 # 400000 is also OK
|
42 |
+
WARMUP: ~
|
43 |
+
val_freq: 1 # Set to 1 is for debug, you can enlarge it to 50 in regular training
|
44 |
+
TEMPORAL_GAN: ~ # without temporal GAN
|
45 |
+
|
46 |
+
### logger
|
47 |
+
logger:
|
48 |
+
PRINT_FREQ: 16
|
49 |
+
SAVE_CHECKPOINT_FREQ: 4000 # 100 is for debug consideration
|
50 |
+
|
51 |
+
### Data related parameters
|
52 |
+
flow2rgb: 1
|
53 |
+
flow_direction: for
|
54 |
+
num_frames: 5
|
55 |
+
sample: random
|
56 |
+
max_val: 0.01
|
57 |
+
|
58 |
+
### Model related parameters
|
59 |
+
res_h: 240
|
60 |
+
res_w: 432
|
61 |
+
in_channel: 4
|
62 |
+
cnum: 64
|
63 |
+
flow_inChannel: 2
|
64 |
+
flow_cnum: 64
|
65 |
+
dist_cnum: 32
|
66 |
+
frame_hidden: 512
|
67 |
+
flow_hidden: 256
|
68 |
+
PASSMASK: 1
|
69 |
+
num_blocks: 8
|
70 |
+
kernel_size_w: 7
|
71 |
+
kernel_size_h: 7
|
72 |
+
stride_h: 3
|
73 |
+
stride_w: 3
|
74 |
+
num_head: 4
|
75 |
+
conv_type: vanilla
|
76 |
+
norm: None
|
77 |
+
use_bias: 1
|
78 |
+
ape: 1
|
79 |
+
pos_mode: single
|
80 |
+
mlp_ratio: 40
|
81 |
+
drop: 0
|
82 |
+
init_weights: 1
|
83 |
+
tw: 2
|
84 |
+
sw: 8
|
85 |
+
gd: 4
|
86 |
+
|
87 |
+
### Loss weights
|
88 |
+
L1M: 1
|
89 |
+
L1V: 1
|
90 |
+
adv: 0.01
|
91 |
+
|
92 |
+
### inference parameters
|
93 |
+
ref_length: 10
|
FGT_codes/FGT/config/valid_config.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
flow_height: 240
|
2 |
+
flow_width: 432
|
3 |
+
data_root: davis_valid_flows
|
4 |
+
mask_root: rectMask_96
|
5 |
+
frame_root: JPEGImages/480p
|
6 |
+
flow_root: davis_test_flows
|
7 |
+
batch_size: 1
|
8 |
+
name2len: config/davis_name2len_val.pkl
|
FGT_codes/FGT/config/youtubevos_name2len.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:60410308d4a0e780a531290d8bddc7f204bc0e8a500eab7c01c563b8efce9753
|
3 |
+
size 75501
|
FGT_codes/FGT/data/__init__.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import torch
|
3 |
+
import torch.utils.data
|
4 |
+
from importlib import import_module
|
5 |
+
|
6 |
+
|
7 |
+
def create_dataloader(phase, dataset, dataset_opt, opt=None, sampler=None):
|
8 |
+
logger = logging.getLogger('base')
|
9 |
+
if phase == 'train':
|
10 |
+
num_workers = dataset_opt['n_workers'] * opt['world_size']
|
11 |
+
batch_size = dataset_opt['batch_size']
|
12 |
+
if sampler is not None:
|
13 |
+
logger.info('N_workers: {}, batch_size: {} DDP train dataloader has been established'.format(num_workers,
|
14 |
+
batch_size))
|
15 |
+
return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
|
16 |
+
num_workers=num_workers, sampler=sampler,
|
17 |
+
pin_memory=True)
|
18 |
+
else:
|
19 |
+
logger.info('N_workers: {}, batch_size: {} train dataloader has been established'.format(num_workers,
|
20 |
+
batch_size))
|
21 |
+
return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
|
22 |
+
num_workers=num_workers, shuffle=True,
|
23 |
+
pin_memory=True)
|
24 |
+
|
25 |
+
else:
|
26 |
+
logger.info(
|
27 |
+
'N_workers: {}, batch_size: {} validate/test dataloader has been established'.format(
|
28 |
+
dataset_opt['n_workers'],
|
29 |
+
dataset_opt['batch_size']))
|
30 |
+
return torch.utils.data.DataLoader(dataset, batch_size=dataset_opt['batch_size'], shuffle=False,
|
31 |
+
num_workers=dataset_opt['n_workers'],
|
32 |
+
pin_memory=False)
|
33 |
+
|
34 |
+
|
35 |
+
def create_dataset(dataset_opt, dataInfo, phase, dataset_name):
|
36 |
+
if phase == 'train':
|
37 |
+
dataset_package = import_module('data.{}'.format(dataset_name))
|
38 |
+
dataset = dataset_package.VideoBasedDataset(dataset_opt, dataInfo)
|
39 |
+
|
40 |
+
mode = dataset_opt['mode']
|
41 |
+
logger = logging.getLogger('base')
|
42 |
+
logger.info(
|
43 |
+
'{} train dataset [{:s} - {:s} - {:s}] is created.'.format(dataset_opt['type'].upper(),
|
44 |
+
dataset.__class__.__name__,
|
45 |
+
dataset_opt['name'], mode))
|
46 |
+
else: # validate and test dataset
|
47 |
+
return ValueError('No dataset initialized for valdataset')
|
48 |
+
|
49 |
+
return dataset
|
FGT_codes/FGT/data/train_dataset.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import pickle
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import torch
|
6 |
+
import cv2
|
7 |
+
import os
|
8 |
+
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
import numpy as np
|
11 |
+
import cvbase
|
12 |
+
from .util.STTN_mask import create_random_shape_with_random_motion
|
13 |
+
import imageio
|
14 |
+
from .util.flow_utils import region_fill as rf
|
15 |
+
|
16 |
+
logger = logging.getLogger('base')
|
17 |
+
|
18 |
+
|
19 |
+
class VideoBasedDataset(Dataset):
|
20 |
+
def __init__(self, opt, dataInfo):
|
21 |
+
self.opt = opt
|
22 |
+
self.sampleMethod = opt['sample']
|
23 |
+
self.dataInfo = dataInfo
|
24 |
+
self.height, self.width = self.opt['input_resolution']
|
25 |
+
self.frame_path = dataInfo['frame_path']
|
26 |
+
self.flow_path = dataInfo['flow_path'] # The path of the optical flows
|
27 |
+
self.train_list = os.listdir(self.frame_path)
|
28 |
+
self.name2length = self.dataInfo['name2len']
|
29 |
+
with open(self.name2length, 'rb') as f:
|
30 |
+
self.name2length = pickle.load(f)
|
31 |
+
self.sequenceLen = self.opt['num_frames']
|
32 |
+
self.flow2rgb = opt['flow2rgb'] # whether to change flow to rgb domain
|
33 |
+
self.flow_direction = opt[
|
34 |
+
'flow_direction'] # The direction must be in ['for', 'back', 'bi'], indicating forward, backward and bidirectional flows
|
35 |
+
|
36 |
+
def __len__(self):
|
37 |
+
return len(self.train_list)
|
38 |
+
|
39 |
+
def __getitem__(self, idx):
|
40 |
+
try:
|
41 |
+
item = self.load_item(idx)
|
42 |
+
except:
|
43 |
+
print('Loading error: ' + self.train_list[idx])
|
44 |
+
item = self.load_item(0)
|
45 |
+
return item
|
46 |
+
|
47 |
+
def frameSample(self, frameLen, sequenceLen):
|
48 |
+
if self.sampleMethod == 'random':
|
49 |
+
indices = [i for i in range(frameLen)]
|
50 |
+
sampleIndices = random.sample(indices, sequenceLen)
|
51 |
+
elif self.sampleMethod == 'seq':
|
52 |
+
pivot = random.randint(0, sequenceLen - 1 - frameLen)
|
53 |
+
sampleIndices = [i for i in range(pivot, pivot + frameLen)]
|
54 |
+
else:
|
55 |
+
raise ValueError('Cannot determine the sample method {}'.format(self.sampleMethod))
|
56 |
+
return sampleIndices
|
57 |
+
|
58 |
+
def load_item(self, idx):
|
59 |
+
video = self.train_list[idx]
|
60 |
+
frame_dir = os.path.join(self.frame_path, video)
|
61 |
+
forward_flow_dir = os.path.join(self.flow_path, video, 'forward_flo')
|
62 |
+
backward_flow_dir = os.path.join(self.flow_path, video, 'backward_flo')
|
63 |
+
frameLen = self.name2length[video]
|
64 |
+
flowLen = frameLen - 1
|
65 |
+
assert frameLen > self.sequenceLen, 'Frame length {} is less than sequence length'.format(frameLen)
|
66 |
+
sampledIndices = self.frameSample(frameLen, self.sequenceLen)
|
67 |
+
|
68 |
+
# generate random masks for these sampled frames
|
69 |
+
candidateMasks = create_random_shape_with_random_motion(frameLen, 0.9, 1.1, 1, 10)
|
70 |
+
|
71 |
+
# read the frames and masks
|
72 |
+
frames, masks, forward_flows, backward_flows = [], [], [], []
|
73 |
+
for i in range(len(sampledIndices)):
|
74 |
+
frame = self.read_frame(os.path.join(frame_dir, '{:05d}.jpg'.format(sampledIndices[i])), self.height,
|
75 |
+
self.width)
|
76 |
+
mask = self.read_mask(candidateMasks[sampledIndices[i]], self.height, self.width)
|
77 |
+
frames.append(frame)
|
78 |
+
masks.append(mask)
|
79 |
+
if self.flow_direction == 'for':
|
80 |
+
forward_flow = self.read_forward_flow(forward_flow_dir, sampledIndices[i], flowLen)
|
81 |
+
forward_flow = self.diffusion_flow(forward_flow, mask)
|
82 |
+
forward_flows.append(forward_flow)
|
83 |
+
elif self.flow_direction == 'back':
|
84 |
+
backward_flow = self.read_backward_flow(backward_flow_dir, sampledIndices[i])
|
85 |
+
backward_flow = self.diffusion_flow(backward_flow, mask)
|
86 |
+
backward_flows.append(backward_flow)
|
87 |
+
elif self.flow_direction == 'bi':
|
88 |
+
forward_flow = self.read_forward_flow(forward_flow_dir, sampledIndices[i], flowLen)
|
89 |
+
forward_flow = self.diffusion_flow(forward_flow, mask)
|
90 |
+
forward_flows.append(forward_flow)
|
91 |
+
backward_flow = self.read_backward_flow(backward_flow_dir, sampledIndices[i])
|
92 |
+
backward_flow = self.diffusion_flow(backward_flow, mask)
|
93 |
+
backward_flows.append(backward_flow)
|
94 |
+
else:
|
95 |
+
raise ValueError('Unknown flow direction mode: {}'.format(self.flow_direction))
|
96 |
+
inputs = {'frames': frames, 'masks': masks, 'forward_flo': forward_flows, 'backward_flo': backward_flows}
|
97 |
+
inputs = self.to_tensor(inputs)
|
98 |
+
inputs['frames'] = (inputs['frames'] / 255.) * 2 - 1
|
99 |
+
return inputs
|
100 |
+
|
101 |
+
def diffusion_flow(self, flow, mask):
|
102 |
+
flow_filled = np.zeros(flow.shape)
|
103 |
+
flow_filled[:, :, 0] = rf.regionfill(flow[:, :, 0] * (1 - mask), mask)
|
104 |
+
flow_filled[:, :, 1] = rf.regionfill(flow[:, :, 1] * (1 - mask), mask)
|
105 |
+
return flow_filled
|
106 |
+
|
107 |
+
def read_frame(self, path, height, width):
|
108 |
+
frame = imageio.imread(path)
|
109 |
+
frame = cv2.resize(frame, (width, height), cv2.INTER_LINEAR)
|
110 |
+
return frame
|
111 |
+
|
112 |
+
def read_mask(self, mask, height, width):
|
113 |
+
mask = np.array(mask)
|
114 |
+
mask = mask / 255.
|
115 |
+
raw_mask = (mask > 0.5).astype(np.uint8)
|
116 |
+
raw_mask = cv2.resize(raw_mask, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
|
117 |
+
return raw_mask
|
118 |
+
|
119 |
+
def read_forward_flow(self, forward_flow_dir, sampledIndex, flowLen):
|
120 |
+
if sampledIndex >= flowLen:
|
121 |
+
sampledIndex = flowLen - 1
|
122 |
+
flow = cvbase.read_flow(os.path.join(forward_flow_dir, '{:05d}.flo'.format(sampledIndex)))
|
123 |
+
height, width = flow.shape[:2]
|
124 |
+
flow = cv2.resize(flow, (self.width, self.height), cv2.INTER_LINEAR)
|
125 |
+
flow[:, :, 0] = flow[:, :, 0] / width * self.width
|
126 |
+
flow[:, :, 1] = flow[:, :, 1] / height * self.height
|
127 |
+
return flow
|
128 |
+
|
129 |
+
def read_backward_flow(self, backward_flow_dir, sampledIndex):
|
130 |
+
if sampledIndex == 0:
|
131 |
+
sampledIndex = 0
|
132 |
+
else:
|
133 |
+
sampledIndex -= 1
|
134 |
+
flow = cvbase.read_flow(os.path.join(backward_flow_dir, '{:05d}.flo'.format(sampledIndex)))
|
135 |
+
height, width = flow.shape[:2]
|
136 |
+
flow = cv2.resize(flow, (self.width, self.height), cv2.INTER_LINEAR)
|
137 |
+
flow[:, :, 0] = flow[:, :, 0] / width * self.width
|
138 |
+
flow[:, :, 1] = flow[:, :, 1] / height * self.height
|
139 |
+
return flow
|
140 |
+
|
141 |
+
def to_tensor(self, data_list):
|
142 |
+
"""
|
143 |
+
|
144 |
+
Args:
|
145 |
+
data_list: A list contains multiple numpy arrays
|
146 |
+
|
147 |
+
Returns: The stacked tensor list
|
148 |
+
|
149 |
+
"""
|
150 |
+
keys = list(data_list.keys())
|
151 |
+
for key in keys:
|
152 |
+
if data_list[key] is None or data_list[key] == []:
|
153 |
+
data_list.pop(key)
|
154 |
+
else:
|
155 |
+
item = data_list[key]
|
156 |
+
if not isinstance(item, list):
|
157 |
+
item = torch.from_numpy(np.transpose(item, (2, 0, 1))).float() # [c, h, w]
|
158 |
+
else:
|
159 |
+
item = np.stack(item, axis=0)
|
160 |
+
if len(item.shape) == 3: # [t, h, w]
|
161 |
+
item = item[:, :, :, np.newaxis]
|
162 |
+
item = torch.from_numpy(np.transpose(item, (0, 3, 1, 2))).float() # [t, c, h, w]
|
163 |
+
data_list[key] = item
|
164 |
+
return data_list
|
165 |
+
|
FGT_codes/FGT/data/util/MaskModel.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
class RandomMask():
|
5 |
+
def __init__(self, videoLength, dataInfo):
|
6 |
+
self.videoLength = videoLength
|
7 |
+
self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
|
8 |
+
dataInfo['image']['image_width']
|
9 |
+
self.maskHeight, self.maskWidth = dataInfo['mask']['mask_height'], \
|
10 |
+
dataInfo['mask']['mask_width']
|
11 |
+
try:
|
12 |
+
self.maxDeltaHeight, self.maxDeltaWidth = dataInfo['mask']['max_delta_height'], \
|
13 |
+
dataInfo['mask']['max_delta_width']
|
14 |
+
except KeyError:
|
15 |
+
self.maxDeltaHeight, self.maxDeltaWidth = 0, 0
|
16 |
+
|
17 |
+
try:
|
18 |
+
self.verticalMargin, self.horizontalMargin = dataInfo['mask']['vertical_margin'], \
|
19 |
+
dataInfo['mask']['horizontal_margin']
|
20 |
+
except KeyError:
|
21 |
+
self.verticalMargin, self.horizontalMargin = 0, 0
|
22 |
+
|
23 |
+
def __call__(self):
|
24 |
+
from .utils import random_bbox
|
25 |
+
from .utils import bbox2mask
|
26 |
+
masks = []
|
27 |
+
bbox = random_bbox(self.imageHeight, self.imageWidth, self.verticalMargin, self.horizontalMargin,
|
28 |
+
self.maskHeight, self.maskWidth)
|
29 |
+
if random.uniform(0, 1) > 0.5:
|
30 |
+
mask = bbox2mask(self.imageHeight, self.imageWidth, 0, 0, bbox)
|
31 |
+
for frame in range(self.videoLength):
|
32 |
+
masks.append(mask)
|
33 |
+
else:
|
34 |
+
for frame in range(self.videoLength):
|
35 |
+
delta_h, delta_w = random.randint(-3, 3), random.randint(-3, 3) # 每次向四个方向移动三个像素以内
|
36 |
+
bbox = list(bbox)
|
37 |
+
bbox[0] = min(max(self.verticalMargin, bbox[0] + delta_h), self.imageHeight - self.verticalMargin - bbox[2])
|
38 |
+
bbox[1] = min(max(self.horizontalMargin, bbox[1] + delta_w), self.imageWidth - self.horizontalMargin - bbox[3])
|
39 |
+
mask = bbox2mask(self.imageHeight, self.imageWidth, 0, 0, bbox)
|
40 |
+
masks.append(mask)
|
41 |
+
masks = np.stack(masks, axis=0)
|
42 |
+
if len(masks.shape) == 3:
|
43 |
+
masks = masks[:, :, :, np.newaxis]
|
44 |
+
assert len(masks.shape) == 4, 'Wrong mask dimension {}'.format(len(masks.shape))
|
45 |
+
return masks
|
46 |
+
|
47 |
+
|
48 |
+
class MidRandomMask():
|
49 |
+
### This mask is considered without random motion
|
50 |
+
def __init__(self, videoLength, dataInfo):
|
51 |
+
self.videoLength = videoLength
|
52 |
+
self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
|
53 |
+
dataInfo['image']['image_width']
|
54 |
+
self.maskHeight, self.maskWidth = dataInfo['mask']['mask_height'], \
|
55 |
+
dataInfo['mask']['mask_width']
|
56 |
+
|
57 |
+
def __call__(self):
|
58 |
+
from .utils import mid_bbox_mask
|
59 |
+
mask = mid_bbox_mask(self.imageHeight, self.imageWidth, self.maskHeight, self.maskWidth)
|
60 |
+
masks = []
|
61 |
+
for _ in range(self.videoLength):
|
62 |
+
masks.append(mask)
|
63 |
+
return mask
|
64 |
+
|
65 |
+
|
66 |
+
class MatrixMask():
|
67 |
+
### This mask is considered without random motion
|
68 |
+
def __init__(self, videoLength, dataInfo):
|
69 |
+
self.videoLength = videoLength
|
70 |
+
self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
|
71 |
+
dataInfo['image']['image_width']
|
72 |
+
self.maskHeight, self.maskWidth = dataInfo['mask']['mask_height'], \
|
73 |
+
dataInfo['mask']['mask_width']
|
74 |
+
try:
|
75 |
+
self.row, self.column = dataInfo['mask']['row'], \
|
76 |
+
dataInfo['mask']['column']
|
77 |
+
except KeyError:
|
78 |
+
self.row, self.column = 5, 4
|
79 |
+
|
80 |
+
def __call__(self):
|
81 |
+
from .utils import matrix2bbox
|
82 |
+
mask = matrix2bbox(self.imageHeight, self.imageWidth, self.maskHeight,
|
83 |
+
self.maskWidth, self.row, self.column)
|
84 |
+
masks = []
|
85 |
+
for video in range(self.videoLength):
|
86 |
+
masks.append(mask)
|
87 |
+
return mask
|
88 |
+
|
89 |
+
|
90 |
+
class FreeFormMask():
|
91 |
+
def __init__(self, videoLength, dataInfo):
|
92 |
+
self.videoLength = videoLength
|
93 |
+
self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
|
94 |
+
dataInfo['image']['image_width']
|
95 |
+
self.maxVertex = dataInfo['mask']['max_vertex']
|
96 |
+
self.maxLength = dataInfo['mask']['max_length']
|
97 |
+
self.maxBrushWidth = dataInfo['mask']['max_brush_width']
|
98 |
+
self.maxAngle = dataInfo['mask']['max_angle']
|
99 |
+
|
100 |
+
def __call__(self):
|
101 |
+
from .utils import freeFormMask
|
102 |
+
mask = freeFormMask(self.imageHeight, self.imageWidth,
|
103 |
+
self.maxVertex, self.maxLength,
|
104 |
+
self.maxBrushWidth, self.maxAngle)
|
105 |
+
return mask
|
106 |
+
|
107 |
+
|
108 |
+
class StationaryMask():
|
109 |
+
def __init__(self, videoLength, dataInfo):
|
110 |
+
self.videoLength = videoLength
|
111 |
+
self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
|
112 |
+
dataInfo['image']['image_width']
|
113 |
+
# self.maxPointNum = dataInfo['mask']['max_point_num']
|
114 |
+
# self.maxLength = dataInfo['mask']['max_length']
|
115 |
+
|
116 |
+
def __call__(self):
|
117 |
+
from .STTN_mask import create_random_shape_with_random_motion
|
118 |
+
masks = create_random_shape_with_random_motion(self.videoLength, 0.9, 1.1, 1, 10, self.imageHeight, self.imageWidth)
|
119 |
+
masks = np.stack(masks, axis=0)
|
120 |
+
if len(masks.shape) == 3:
|
121 |
+
masks = masks[:, :, :, np.newaxis]
|
122 |
+
assert len(masks.shape) == 4, 'Your masks with a wrong shape {}'.format(len(masks.shape))
|
123 |
+
return masks
|
FGT_codes/FGT/data/util/STTN_mask.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.patches as patches
|
2 |
+
from matplotlib.path import Path
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import io
|
6 |
+
import cv2
|
7 |
+
import time
|
8 |
+
import math
|
9 |
+
import argparse
|
10 |
+
import shutil
|
11 |
+
import random
|
12 |
+
import zipfile
|
13 |
+
from glob import glob
|
14 |
+
import math
|
15 |
+
import numpy as np
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torchvision.transforms as transforms
|
18 |
+
from PIL import Image, ImageOps, ImageDraw, ImageFilter
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torchvision
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.distributed as dist
|
24 |
+
|
25 |
+
import matplotlib
|
26 |
+
from matplotlib import pyplot as plt
|
27 |
+
matplotlib.use('agg')
|
28 |
+
|
29 |
+
|
30 |
+
class GroupRandomHorizontalFlip(object):
|
31 |
+
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, is_flow=False):
|
35 |
+
self.is_flow = is_flow
|
36 |
+
|
37 |
+
def __call__(self, img_group, is_flow=False):
|
38 |
+
v = random.random()
|
39 |
+
if v < 0.5:
|
40 |
+
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
|
41 |
+
if self.is_flow:
|
42 |
+
for i in range(0, len(ret), 2):
|
43 |
+
# invert flow pixel values when flipping
|
44 |
+
ret[i] = ImageOps.invert(ret[i])
|
45 |
+
return ret
|
46 |
+
else:
|
47 |
+
return img_group
|
48 |
+
|
49 |
+
|
50 |
+
class Stack(object):
|
51 |
+
def __init__(self, roll=False):
|
52 |
+
self.roll = roll
|
53 |
+
|
54 |
+
def __call__(self, img_group):
|
55 |
+
mode = img_group[0].mode
|
56 |
+
if mode == '1':
|
57 |
+
img_group = [img.convert('L') for img in img_group]
|
58 |
+
mode = 'L'
|
59 |
+
if mode == 'L':
|
60 |
+
return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
|
61 |
+
elif mode == 'RGB':
|
62 |
+
if self.roll:
|
63 |
+
return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
|
64 |
+
else:
|
65 |
+
return np.stack(img_group, axis=2)
|
66 |
+
else:
|
67 |
+
raise NotImplementedError("Image mode {}".format(mode))
|
68 |
+
|
69 |
+
|
70 |
+
class ToTorchFormatTensor(object):
|
71 |
+
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
|
72 |
+
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
|
73 |
+
|
74 |
+
def __init__(self, div=True):
|
75 |
+
self.div = div
|
76 |
+
|
77 |
+
def __call__(self, pic):
|
78 |
+
if isinstance(pic, np.ndarray):
|
79 |
+
# numpy img: [L, C, H, W]
|
80 |
+
img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
|
81 |
+
else:
|
82 |
+
# handle PIL Image
|
83 |
+
img = torch.ByteTensor(
|
84 |
+
torch.ByteStorage.from_buffer(pic.tobytes()))
|
85 |
+
img = img.view(pic.size[1], pic.size[0], len(pic.mode))
|
86 |
+
# put it from HWC to CHW format
|
87 |
+
# yikes, this transpose takes 80% of the loading time/CPU
|
88 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
89 |
+
img = img.float().div(255) if self.div else img.float()
|
90 |
+
return img
|
91 |
+
|
92 |
+
|
93 |
+
# ##########################################
|
94 |
+
# ##########################################
|
95 |
+
|
96 |
+
def create_random_shape_with_random_motion(video_length, zoomin, zoomout, rotmin, rotmax, imageHeight=240, imageWidth=432):
|
97 |
+
# get a random shape
|
98 |
+
assert zoomin < 1, "Zoom-in parameter must be smaller than 1"
|
99 |
+
assert zoomout > 1, "Zoom-out parameter must be larger than 1"
|
100 |
+
assert rotmin < rotmax, "Minimum value of rotation must be smaller than maximun value !"
|
101 |
+
height = random.randint(imageHeight//3, imageHeight-1)
|
102 |
+
width = random.randint(imageWidth//3, imageWidth-1)
|
103 |
+
edge_num = random.randint(6, 8)
|
104 |
+
ratio = random.randint(6, 8)/10
|
105 |
+
region = get_random_shape(
|
106 |
+
edge_num=edge_num, ratio=ratio, height=height, width=width)
|
107 |
+
region_width, region_height = region.size
|
108 |
+
# get random position
|
109 |
+
x, y = random.randint(
|
110 |
+
0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
|
111 |
+
velocity = get_random_velocity(max_speed=3)
|
112 |
+
m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
|
113 |
+
m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
|
114 |
+
masks = [m.convert('L')]
|
115 |
+
# return fixed masks
|
116 |
+
if random.uniform(0, 1) > 0.5:
|
117 |
+
return masks*video_length # -> directly copy all the base masks
|
118 |
+
# return moving masks
|
119 |
+
for _ in range(video_length-1):
|
120 |
+
x, y, velocity = random_move_control_points(
|
121 |
+
x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
|
122 |
+
m = Image.fromarray(
|
123 |
+
np.zeros((imageHeight, imageWidth)).astype(np.uint8))
|
124 |
+
### add by kaidong, to simulate zoon-in, zoom-out and rotation
|
125 |
+
extra_transform = random.uniform(0, 1)
|
126 |
+
# zoom in and zoom out
|
127 |
+
if extra_transform > 0.75:
|
128 |
+
resize_coefficient = random.uniform(zoomin, zoomout)
|
129 |
+
region = region.resize((math.ceil(region_width * resize_coefficient), math.ceil(region_height * resize_coefficient)), Image.NEAREST)
|
130 |
+
m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
|
131 |
+
region_width, region_height = region.size
|
132 |
+
# rotation
|
133 |
+
elif extra_transform > 0.5:
|
134 |
+
m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
|
135 |
+
m = m.rotate(random.randint(rotmin, rotmax))
|
136 |
+
# region_width, region_height = region.size
|
137 |
+
### end
|
138 |
+
else:
|
139 |
+
m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
|
140 |
+
masks.append(m.convert('L'))
|
141 |
+
return masks
|
142 |
+
|
143 |
+
|
144 |
+
def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
|
145 |
+
'''
|
146 |
+
There is the initial point and 3 points per cubic bezier curve.
|
147 |
+
Thus, the curve will only pass though n points, which will be the sharp edges.
|
148 |
+
The other 2 modify the shape of the bezier curve.
|
149 |
+
edge_num, Number of possibly sharp edges
|
150 |
+
points_num, number of points in the Path
|
151 |
+
ratio, (0, 1) magnitude of the perturbation from the unit circle,
|
152 |
+
'''
|
153 |
+
points_num = edge_num*3 + 1
|
154 |
+
angles = np.linspace(0, 2*np.pi, points_num)
|
155 |
+
codes = np.full(points_num, Path.CURVE4)
|
156 |
+
codes[0] = Path.MOVETO
|
157 |
+
# Using this instad of Path.CLOSEPOLY avoids an innecessary straight line
|
158 |
+
verts = np.stack((np.cos(angles), np.sin(angles))).T * \
|
159 |
+
(2*ratio*np.random.random(points_num)+1-ratio)[:, None]
|
160 |
+
verts[-1, :] = verts[0, :]
|
161 |
+
path = Path(verts, codes)
|
162 |
+
# draw paths into images
|
163 |
+
fig = plt.figure()
|
164 |
+
ax = fig.add_subplot(111)
|
165 |
+
patch = patches.PathPatch(path, facecolor='black', lw=2)
|
166 |
+
ax.add_patch(patch)
|
167 |
+
ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
|
168 |
+
ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
|
169 |
+
ax.axis('off') # removes the axis to leave only the shape
|
170 |
+
fig.canvas.draw()
|
171 |
+
# convert plt images into numpy images
|
172 |
+
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
173 |
+
data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
|
174 |
+
plt.close(fig)
|
175 |
+
# postprocess
|
176 |
+
data = cv2.resize(data, (width, height))[:, :, 0]
|
177 |
+
data = (1 - np.array(data > 0).astype(np.uint8))*255
|
178 |
+
corrdinates = np.where(data > 0)
|
179 |
+
xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max(
|
180 |
+
corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1])
|
181 |
+
region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
|
182 |
+
return region
|
183 |
+
|
184 |
+
|
185 |
+
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
|
186 |
+
speed, angle = velocity
|
187 |
+
d_speed, d_angle = maxAcceleration
|
188 |
+
if dist == 'uniform':
|
189 |
+
speed += np.random.uniform(-d_speed, d_speed)
|
190 |
+
angle += np.random.uniform(-d_angle, d_angle)
|
191 |
+
elif dist == 'guassian':
|
192 |
+
speed += np.random.normal(0, d_speed / 2)
|
193 |
+
angle += np.random.normal(0, d_angle / 2)
|
194 |
+
else:
|
195 |
+
raise NotImplementedError(
|
196 |
+
f'Distribution type {dist} is not supported.')
|
197 |
+
return (speed, angle)
|
198 |
+
|
199 |
+
|
200 |
+
def get_random_velocity(max_speed=3, dist='uniform'):
|
201 |
+
if dist == 'uniform':
|
202 |
+
speed = np.random.uniform(max_speed)
|
203 |
+
elif dist == 'guassian':
|
204 |
+
speed = np.abs(np.random.normal(0, max_speed / 2))
|
205 |
+
else:
|
206 |
+
raise NotImplementedError(
|
207 |
+
'Distribution type {} is not supported.'.format(dist))
|
208 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
209 |
+
return (speed, angle)
|
210 |
+
|
211 |
+
|
212 |
+
def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
|
213 |
+
region_width, region_height = region_size
|
214 |
+
speed, angle = lineVelocity
|
215 |
+
X += int(speed * np.cos(angle))
|
216 |
+
Y += int(speed * np.sin(angle))
|
217 |
+
lineVelocity = random_accelerate(
|
218 |
+
lineVelocity, maxLineAcceleration, dist='guassian')
|
219 |
+
if ((X > imageHeight - region_height) or (X < 0) or (Y > imageWidth - region_width) or (Y < 0)):
|
220 |
+
lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')
|
221 |
+
new_X = np.clip(X, 0, imageHeight - region_height)
|
222 |
+
new_Y = np.clip(Y, 0, imageWidth - region_width)
|
223 |
+
return new_X, new_Y, lineVelocity
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
# ##############################################
|
228 |
+
# ##############################################
|
229 |
+
|
230 |
+
if __name__ == '__main__':
|
231 |
+
import os
|
232 |
+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
233 |
+
trials = 10
|
234 |
+
for _ in range(trials):
|
235 |
+
video_length = 10
|
236 |
+
# The returned masks are either stationary (50%) or moving (50%)
|
237 |
+
masks = create_random_shape_with_random_motion(video_length, zoomin=0.9, zoomout=1.1, rotmin=1, rotmax=10, imageHeight=240, imageWidth=432)
|
238 |
+
i = 0
|
239 |
+
|
240 |
+
for m in masks:
|
241 |
+
cv2.imshow('mask', np.array(m))
|
242 |
+
cv2.waitKey(500)
|
243 |
+
# m.save('mask_{}.png'.format(i))
|
244 |
+
i += 1
|
FGT_codes/FGT/data/util/__init__.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .STTN_mask import create_random_shape_with_random_motion
|
2 |
+
|
3 |
+
import logging
|
4 |
+
logger = logging.getLogger('base')
|
5 |
+
|
6 |
+
|
7 |
+
def initialize_mask(videoLength, dataInfo):
|
8 |
+
from .MaskModel import RandomMask
|
9 |
+
from .MaskModel import MidRandomMask
|
10 |
+
from .MaskModel import MatrixMask
|
11 |
+
from .MaskModel import FreeFormMask
|
12 |
+
from .MaskModel import StationaryMask
|
13 |
+
|
14 |
+
return {'random': RandomMask(videoLength, dataInfo),
|
15 |
+
'mid': MidRandomMask(videoLength, dataInfo),
|
16 |
+
'matrix': MatrixMask(videoLength, dataInfo),
|
17 |
+
'free': FreeFormMask(videoLength, dataInfo),
|
18 |
+
'stationary': StationaryMask(videoLength, dataInfo)
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def create_mask(maskClass, form):
|
23 |
+
if form == 'mix':
|
24 |
+
from random import randint
|
25 |
+
candidates = list(maskClass.keys())
|
26 |
+
candidate_index = randint(0, len(candidates) - 1)
|
27 |
+
return maskClass[candidates[candidate_index]]()
|
28 |
+
return maskClass[form]()
|
FGT_codes/FGT/data/util/flow_utils/__init__.py
ADDED
File without changes
|
FGT_codes/FGT/data/util/flow_utils/flow_reversal.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def flow_reversal(flow):
|
5 |
+
"""
|
6 |
+
flow: shape [b, c, h, w]
|
7 |
+
return: backward flow in corresponding to the forward flow
|
8 |
+
The formula is borrowed from Quadratic Video Interpolation (4)
|
9 |
+
"""
|
10 |
+
b, c, h, w = flow.shape
|
11 |
+
y = flow[:, 0:1, :, :]
|
12 |
+
x = flow[:, 1:2, :, :] # [b, 1, h, w]
|
13 |
+
|
14 |
+
x = x.repeat(1, c, 1, 1)
|
15 |
+
y = y.repeat(1, c, 1, 1)
|
16 |
+
|
17 |
+
# get the four points of the square (x1, y1), (x1, y2), (x2, y1), (x2, y2)
|
18 |
+
x1 = torch.floor(x)
|
19 |
+
x2 = x1 + 1
|
20 |
+
y1 = torch.floor(y)
|
21 |
+
y2 = y1 + 1
|
22 |
+
|
23 |
+
# get gaussian weights
|
24 |
+
w11, w12, w21, w22 = get_gaussian_weights(x, y, x1, x2, y1, y2)
|
25 |
+
|
26 |
+
# calculate the weight maps for each optical flows
|
27 |
+
flow11, o11 = sample_one(flow, x1, y1, w11)
|
28 |
+
flow12, o12 = sample_one(flow, x1, y2, w12)
|
29 |
+
flow21, o21 = sample_one(flow, x2, y1, w21)
|
30 |
+
flow22, o22 = sample_one(flow, x2, y2, w22)
|
31 |
+
|
32 |
+
# fuse all the reversed flows based on equation (4)
|
33 |
+
flow_o = flow11 + flow12 + flow21 + flow22
|
34 |
+
o = o11 + o12 + o21 + o22
|
35 |
+
|
36 |
+
flow_o = -flow_o
|
37 |
+
flow_o[o > 0] = flow_o[o > 0] / o[o > 0]
|
38 |
+
|
39 |
+
return flow_o
|
40 |
+
|
41 |
+
|
42 |
+
def get_gaussian_weights(x, y, x1, x2, y1, y2):
|
43 |
+
sigma = 1
|
44 |
+
w11 = torch.exp(-((x - x1) ** 2 + (y - y1) ** 2) / (sigma ** 2))
|
45 |
+
w12 = torch.exp(-((x - x1) ** 2 + (y - y2) ** 2) / (sigma ** 2))
|
46 |
+
w21 = torch.exp(-((x - x2) ** 2 + (y - y1) ** 2) / (sigma ** 2))
|
47 |
+
w22 = torch.exp(-((x - x2) ** 2 + (y - y2) ** 2) / (sigma ** 2))
|
48 |
+
return w11, w12, w21, w22
|
49 |
+
|
50 |
+
|
51 |
+
def sample_one(flow, shiftx, shifty, weight):
|
52 |
+
b, c, h, w = flow.shape
|
53 |
+
flat_shiftx = shiftx.view(-1) # [h * w]
|
54 |
+
flat_shifty = shifty.view(-1) # [h * w]
|
55 |
+
flat_basex = torch.arange(0, h, requires_grad=False).view(-1, 1).long().repeat(b, c, 1, w).view(-1) # [h * w]
|
56 |
+
flat_basey = torch.arange(0, w, requires_grad=False).view(-1, 1).long().repeat(b, c, h, 1).view(-1) # [h * w]
|
57 |
+
flat_weight = weight.reshape(-1) # [h * w]
|
58 |
+
flat_flow = flow.reshape(-1)
|
59 |
+
|
60 |
+
idxn = torch.arange(0, b, requires_grad=False).view(b, 1, 1, 1).long().repeat(1, c, h, w).view(-1)
|
61 |
+
idxc = torch.arange(0, c, requires_grad=False).view(1, c, 1, 1).long().repeat(b, 1, h, w).view(-1)
|
62 |
+
idxx = flat_shiftx.long() + flat_basex # size [-1]
|
63 |
+
idxy = flat_shifty.long() + flat_basey # size [-1]
|
64 |
+
|
65 |
+
# record the shifted pixels inside the image boundaries
|
66 |
+
mask = idxx.ge(0) & idxx.lt(h) & idxy.ge(0) & idxy.lt(w)
|
67 |
+
|
68 |
+
# mask off points out of boundaries
|
69 |
+
ids = idxn * c * h * w + idxc * h * w + idxx * w + idxy
|
70 |
+
ids_mask = torch.masked_select(ids, mask).clone()
|
71 |
+
|
72 |
+
# put the value into corresponding regions
|
73 |
+
flow_warp = torch.zeros([b * c * h * w])
|
74 |
+
flow_warp.put_(ids_mask, torch.masked_select(flat_flow * flat_weight, mask), accumulate=True)
|
75 |
+
one_warp = torch.zeros([b * c * h * w])
|
76 |
+
one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)
|
77 |
+
return flow_warp.view(b, c, h, w), one_warp.view(b, c, h, w)
|
FGT_codes/FGT/data/util/flow_utils/region_fill.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
from scipy import sparse
|
4 |
+
from scipy.sparse.linalg import spsolve
|
5 |
+
|
6 |
+
|
7 |
+
# Laplacian filling
|
8 |
+
def regionfill(I, mask, factor=1.0):
|
9 |
+
if np.count_nonzero(mask) == 0:
|
10 |
+
return I.copy()
|
11 |
+
resize_mask = cv2.resize(
|
12 |
+
mask.astype(float), (0, 0), fx=factor, fy=factor) > 0
|
13 |
+
resize_I = cv2.resize(I.astype(float), (0, 0), fx=factor, fy=factor)
|
14 |
+
maskPerimeter = findBoundaryPixels(resize_mask)
|
15 |
+
regionfillLaplace(resize_I, resize_mask, maskPerimeter)
|
16 |
+
resize_I = cv2.resize(resize_I, (I.shape[1], I.shape[0]))
|
17 |
+
resize_I[mask == 0] = I[mask == 0]
|
18 |
+
return resize_I
|
19 |
+
|
20 |
+
|
21 |
+
def findBoundaryPixels(mask):
|
22 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
|
23 |
+
maskDilated = cv2.dilate(mask.astype(float), kernel)
|
24 |
+
return (maskDilated > 0) & (mask == 0)
|
25 |
+
|
26 |
+
|
27 |
+
def regionfillLaplace(I, mask, maskPerimeter):
|
28 |
+
height, width = I.shape
|
29 |
+
rightSide = formRightSide(I, maskPerimeter)
|
30 |
+
|
31 |
+
# Location of mask pixels
|
32 |
+
maskIdx = np.where(mask)
|
33 |
+
|
34 |
+
# Only keep values for pixels that are in the mask
|
35 |
+
rightSide = rightSide[maskIdx]
|
36 |
+
|
37 |
+
# Number the mask pixels in a grid matrix
|
38 |
+
grid = -np.ones((height, width))
|
39 |
+
grid[maskIdx] = range(0, maskIdx[0].size)
|
40 |
+
# Pad with zeros to avoid "index out of bounds" errors in the for loop
|
41 |
+
grid = padMatrix(grid)
|
42 |
+
gridIdx = np.where(grid >= 0)
|
43 |
+
|
44 |
+
# Form the connectivity matrix D=sparse(i,j,s)
|
45 |
+
# Connect each mask pixel to itself
|
46 |
+
i = np.arange(0, maskIdx[0].size)
|
47 |
+
j = np.arange(0, maskIdx[0].size)
|
48 |
+
# The coefficient is the number of neighbors over which we average
|
49 |
+
numNeighbors = computeNumberOfNeighbors(height, width)
|
50 |
+
s = numNeighbors[maskIdx]
|
51 |
+
# Now connect the N,E,S,W neighbors if they exist
|
52 |
+
for direction in ((-1, 0), (0, 1), (1, 0), (0, -1)):
|
53 |
+
# Possible neighbors in the current direction
|
54 |
+
neighbors = grid[gridIdx[0] + direction[0], gridIdx[1] + direction[1]]
|
55 |
+
# ConDnect mask points to neighbors with -1's
|
56 |
+
index = (neighbors >= 0)
|
57 |
+
i = np.concatenate((i, grid[gridIdx[0][index], gridIdx[1][index]]))
|
58 |
+
j = np.concatenate((j, neighbors[index]))
|
59 |
+
s = np.concatenate((s, -np.ones(np.count_nonzero(index))))
|
60 |
+
|
61 |
+
D = sparse.coo_matrix((s, (i.astype(int), j.astype(int)))).tocsr()
|
62 |
+
sol = spsolve(D, rightSide)
|
63 |
+
I[maskIdx] = sol
|
64 |
+
return I
|
65 |
+
|
66 |
+
|
67 |
+
def formRightSide(I, maskPerimeter):
|
68 |
+
height, width = I.shape
|
69 |
+
perimeterValues = np.zeros((height, width))
|
70 |
+
perimeterValues[maskPerimeter] = I[maskPerimeter]
|
71 |
+
rightSide = np.zeros((height, width))
|
72 |
+
|
73 |
+
rightSide[1:height - 1, 1:width - 1] = (
|
74 |
+
perimeterValues[0:height - 2, 1:width - 1] +
|
75 |
+
perimeterValues[2:height, 1:width - 1] +
|
76 |
+
perimeterValues[1:height - 1, 0:width - 2] +
|
77 |
+
perimeterValues[1:height - 1, 2:width])
|
78 |
+
|
79 |
+
rightSide[1:height - 1, 0] = (
|
80 |
+
perimeterValues[0:height - 2, 0] + perimeterValues[2:height, 0] +
|
81 |
+
perimeterValues[1:height - 1, 1])
|
82 |
+
|
83 |
+
rightSide[1:height - 1, width - 1] = (
|
84 |
+
perimeterValues[0:height - 2, width - 1] +
|
85 |
+
perimeterValues[2:height, width - 1] +
|
86 |
+
perimeterValues[1:height - 1, width - 2])
|
87 |
+
|
88 |
+
rightSide[0, 1:width - 1] = (
|
89 |
+
perimeterValues[1, 1:width - 1] + perimeterValues[0, 0:width - 2] +
|
90 |
+
perimeterValues[0, 2:width])
|
91 |
+
|
92 |
+
rightSide[height - 1, 1:width - 1] = (
|
93 |
+
perimeterValues[height - 2, 1:width - 1] +
|
94 |
+
perimeterValues[height - 1, 0:width - 2] +
|
95 |
+
perimeterValues[height - 1, 2:width])
|
96 |
+
|
97 |
+
rightSide[0, 0] = perimeterValues[0, 1] + perimeterValues[1, 0]
|
98 |
+
rightSide[0, width - 1] = (
|
99 |
+
perimeterValues[0, width - 2] + perimeterValues[1, width - 1])
|
100 |
+
rightSide[height - 1, 0] = (
|
101 |
+
perimeterValues[height - 2, 0] + perimeterValues[height - 1, 1])
|
102 |
+
rightSide[height - 1, width - 1] = (perimeterValues[height - 2, width - 1] +
|
103 |
+
perimeterValues[height - 1, width - 2])
|
104 |
+
return rightSide
|
105 |
+
|
106 |
+
|
107 |
+
def computeNumberOfNeighbors(height, width):
|
108 |
+
# Initialize
|
109 |
+
numNeighbors = np.zeros((height, width))
|
110 |
+
# Interior pixels have 4 neighbors
|
111 |
+
numNeighbors[1:height - 1, 1:width - 1] = 4
|
112 |
+
# Border pixels have 3 neighbors
|
113 |
+
numNeighbors[1:height - 1, (0, width - 1)] = 3
|
114 |
+
numNeighbors[(0, height - 1), 1:width - 1] = 3
|
115 |
+
# Corner pixels have 2 neighbors
|
116 |
+
numNeighbors[(0, 0, height - 1, height - 1), (0, width - 1, 0,
|
117 |
+
width - 1)] = 2
|
118 |
+
return numNeighbors
|
119 |
+
|
120 |
+
|
121 |
+
def padMatrix(grid):
|
122 |
+
height, width = grid.shape
|
123 |
+
gridPadded = -np.ones((height + 2, width + 2))
|
124 |
+
gridPadded[1:height + 1, 1:width + 1] = grid
|
125 |
+
gridPadded = gridPadded.astype(grid.dtype)
|
126 |
+
return gridPadded
|
127 |
+
|
128 |
+
|
129 |
+
if __name__ == '__main__':
|
130 |
+
import time
|
131 |
+
x = np.linspace(0, 255, 500)
|
132 |
+
xv, _ = np.meshgrid(x, x)
|
133 |
+
image = ((xv + np.transpose(xv)) / 2.0).astype(int)
|
134 |
+
mask = np.zeros((500, 500))
|
135 |
+
mask[100:259, 100:259] = 1
|
136 |
+
mask = (mask > 0)
|
137 |
+
image[mask] = 0
|
138 |
+
st = time.time()
|
139 |
+
inpaint = regionfill(image, mask, 0.5).astype(np.uint8)
|
140 |
+
print(time.time() - st)
|
141 |
+
cv2.imshow('img', np.concatenate((image.astype(np.uint8), inpaint)))
|
142 |
+
cv2.waitKey()
|
FGT_codes/FGT/data/util/freeform_masks.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import shutil
|
4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) # NOQA
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import argparse
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from .mask_generators import get_video_masks_by_moving_random_stroke, get_masked_ratio
|
11 |
+
from .util import make_dirs, make_dir_under_root, get_everything_under
|
12 |
+
from .readers import MaskReader
|
13 |
+
|
14 |
+
def parse_args():
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument(
|
17 |
+
'-od', '--output_dir',
|
18 |
+
type=str,
|
19 |
+
help="Output directory name"
|
20 |
+
)
|
21 |
+
parser.add_argument(
|
22 |
+
'-im',
|
23 |
+
'--image_masks', action='store_true',
|
24 |
+
help="Set this if you want to generate independent masks in one directory."
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
'-vl', '--video_len',
|
28 |
+
type=int,
|
29 |
+
help="Maximum video length (i.e. #mask)"
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
'-ns', '--num_stroke',
|
33 |
+
type=int,
|
34 |
+
help="Number of stroke in one mask"
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
'-nsb', '--num_stroke_bound',
|
38 |
+
type=int,
|
39 |
+
nargs=2,
|
40 |
+
help="Upper/lower bound of number of stroke in one mask"
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
'-n',
|
44 |
+
type=int,
|
45 |
+
help="Number of mask to generate"
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
'-sp',
|
49 |
+
'--stroke_preset',
|
50 |
+
type=str,
|
51 |
+
default='rand_curve',
|
52 |
+
help="Preset of the stroke parameters"
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
'-iw',
|
56 |
+
'--image_width',
|
57 |
+
type=int,
|
58 |
+
default=320
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
'-ih',
|
62 |
+
'--image_height',
|
63 |
+
type=int,
|
64 |
+
default=180
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
'--cluster_by_area',
|
68 |
+
action='store_true'
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
'--leave_boarder_unmasked',
|
72 |
+
type=int,
|
73 |
+
help='Set this to a number, then a copy of the mask where the mask of boarder is erased.'
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
'--redo_without_generation',
|
77 |
+
action='store_true',
|
78 |
+
help='Set this, and the script will skip the generation and redo the left tasks'
|
79 |
+
'(uncluster -> erase boarder -> re-cluster)'
|
80 |
+
)
|
81 |
+
args = parser.parse_args()
|
82 |
+
return args
|
83 |
+
|
84 |
+
|
85 |
+
def get_stroke_preset(stroke_preset):
|
86 |
+
if stroke_preset == 'object_like':
|
87 |
+
return {
|
88 |
+
"nVertexBound": [5, 30],
|
89 |
+
"maxHeadSpeed": 15,
|
90 |
+
"maxHeadAcceleration": (10, 1.5),
|
91 |
+
"brushWidthBound": (20, 50),
|
92 |
+
"nMovePointRatio": 0.5,
|
93 |
+
"maxPiontMove": 10,
|
94 |
+
"maxLineAcceleration": (5, 0.5),
|
95 |
+
"boarderGap": None,
|
96 |
+
"maxInitSpeed": 10,
|
97 |
+
}
|
98 |
+
elif stroke_preset == 'object_like_middle':
|
99 |
+
return {
|
100 |
+
"nVertexBound": [5, 15],
|
101 |
+
"maxHeadSpeed": 8,
|
102 |
+
"maxHeadAcceleration": (4, 1.5),
|
103 |
+
"brushWidthBound": (20, 50),
|
104 |
+
"nMovePointRatio": 0.5,
|
105 |
+
"maxPiontMove": 5,
|
106 |
+
"maxLineAcceleration": (5, 0.5),
|
107 |
+
"boarderGap": None,
|
108 |
+
"maxInitSpeed": 10,
|
109 |
+
}
|
110 |
+
elif stroke_preset == 'object_like_small':
|
111 |
+
return {
|
112 |
+
"nVertexBound": [5, 20],
|
113 |
+
"maxHeadSpeed": 7,
|
114 |
+
"maxHeadAcceleration": (3.5, 1.5),
|
115 |
+
"brushWidthBound": (10, 30),
|
116 |
+
"nMovePointRatio": 0.5,
|
117 |
+
"maxPiontMove": 5,
|
118 |
+
"maxLineAcceleration": (3, 0.5),
|
119 |
+
"boarderGap": None,
|
120 |
+
"maxInitSpeed": 4,
|
121 |
+
}
|
122 |
+
elif stroke_preset == 'rand_curve':
|
123 |
+
return {
|
124 |
+
"nVertexBound": [10, 30],
|
125 |
+
"maxHeadSpeed": 20,
|
126 |
+
"maxHeadAcceleration": (15, 0.5),
|
127 |
+
"brushWidthBound": (3, 10),
|
128 |
+
"nMovePointRatio": 0.5,
|
129 |
+
"maxPiontMove": 3,
|
130 |
+
"maxLineAcceleration": (5, 0.5),
|
131 |
+
"boarderGap": None,
|
132 |
+
"maxInitSpeed": 6
|
133 |
+
}
|
134 |
+
elif stroke_preset == 'rand_curve_small':
|
135 |
+
return {
|
136 |
+
"nVertexBound": [6, 22],
|
137 |
+
"maxHeadSpeed": 12,
|
138 |
+
"maxHeadAcceleration": (8, 0.5),
|
139 |
+
"brushWidthBound": (2.5, 5),
|
140 |
+
"nMovePointRatio": 0.5,
|
141 |
+
"maxPiontMove": 1.5,
|
142 |
+
"maxLineAcceleration": (3, 0.5),
|
143 |
+
"boarderGap": None,
|
144 |
+
"maxInitSpeed": 3
|
145 |
+
}
|
146 |
+
else:
|
147 |
+
raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
|
148 |
+
|
149 |
+
|
150 |
+
def copy_masks_without_boarder(root_dir, args):
|
151 |
+
def erase_mask_boarder(mask, gap):
|
152 |
+
pix = np.asarray(mask).astype('uint8') * 255
|
153 |
+
pix[:gap, :] = 255
|
154 |
+
pix[-gap:, :] = 255
|
155 |
+
pix[:, :gap] = 255
|
156 |
+
pix[:, -gap:] = 255
|
157 |
+
return Image.fromarray(pix).convert('1')
|
158 |
+
|
159 |
+
wo_boarder_dir = root_dir + '_noBoarder'
|
160 |
+
shutil.copytree(root_dir, wo_boarder_dir)
|
161 |
+
|
162 |
+
for i, filename in enumerate(get_everything_under(wo_boarder_dir)):
|
163 |
+
if args.image_masks:
|
164 |
+
mask = Image.open(filename)
|
165 |
+
mask_wo_boarder = erase_mask_boarder(mask, args.leave_boarder_unmasked)
|
166 |
+
mask_wo_boarder.save(filename)
|
167 |
+
else:
|
168 |
+
# filename is a diretory containing multiple mask files
|
169 |
+
for f in get_everything_under(filename, pattern='*.png'):
|
170 |
+
mask = Image.open(f)
|
171 |
+
mask_wo_boarder = erase_mask_boarder(mask, args.leave_boarder_unmasked)
|
172 |
+
mask_wo_boarder.save(f)
|
173 |
+
|
174 |
+
return wo_boarder_dir
|
175 |
+
|
176 |
+
|
177 |
+
def cluster_by_masked_area(root_dir, args):
|
178 |
+
clustered_dir = root_dir + '_clustered'
|
179 |
+
make_dirs(clustered_dir)
|
180 |
+
radius = 5
|
181 |
+
|
182 |
+
# all masks with ratio in x +- radius will be stored in sub-directory x
|
183 |
+
clustered_centors = np.arange(radius, 100, radius * 2)
|
184 |
+
clustered_subdirs = []
|
185 |
+
for c in clustered_centors:
|
186 |
+
# make sub-directories for each ratio range
|
187 |
+
clustered_subdirs.append(make_dir_under_root(clustered_dir, str(c)))
|
188 |
+
|
189 |
+
for i, filename in enumerate(get_everything_under(root_dir)):
|
190 |
+
if args.image_masks:
|
191 |
+
ratio = get_masked_ratio(Image.open(filename))
|
192 |
+
else:
|
193 |
+
# filename is a diretory containing multiple mask files
|
194 |
+
ratio = np.mean([
|
195 |
+
get_masked_ratio(Image.open(f))
|
196 |
+
for f in get_everything_under(filename, pattern='*.png')
|
197 |
+
])
|
198 |
+
|
199 |
+
# find the nearest centor
|
200 |
+
for i, c in enumerate(clustered_centors):
|
201 |
+
if c - radius <= ratio * 100 <= c + radius:
|
202 |
+
shutil.move(filename, clustered_subdirs[i])
|
203 |
+
break
|
204 |
+
|
205 |
+
shutil.rmtree(root_dir)
|
206 |
+
os.rename(clustered_dir, root_dir)
|
207 |
+
|
208 |
+
|
209 |
+
def decide_nStroke(args):
|
210 |
+
if args.num_stroke is not None:
|
211 |
+
return args.num_stroke
|
212 |
+
elif args.num_stroke_bound is not None:
|
213 |
+
return np.random.randint(args.num_stroke_bound[0], args.num_stroke_bound[1])
|
214 |
+
else:
|
215 |
+
raise ValueError('One of "-ns" or "-nsb" is needed')
|
216 |
+
|
217 |
+
|
218 |
+
def main(args):
|
219 |
+
preset = get_stroke_preset(args.stroke_preset)
|
220 |
+
make_dirs(args.output_dir)
|
221 |
+
|
222 |
+
if args.redo_without_generation:
|
223 |
+
assert(len(get_everything_under(args.output_dir)) > 0)
|
224 |
+
# put back clustered masks
|
225 |
+
for clustered_subdir in get_everything_under(args.output_dir):
|
226 |
+
if not os.path.isdir(clustered_subdir):
|
227 |
+
continue
|
228 |
+
for f in get_everything_under(clustered_subdir):
|
229 |
+
shutil.move(f, args.output_dir)
|
230 |
+
os.rmdir(clustered_subdir)
|
231 |
+
|
232 |
+
else:
|
233 |
+
if args.image_masks:
|
234 |
+
for i in range(args.n):
|
235 |
+
nStroke = decide_nStroke(args)
|
236 |
+
mask = get_video_masks_by_moving_random_stroke(
|
237 |
+
video_len=1, imageWidth=args.image_width, imageHeight=args.image_height,
|
238 |
+
nStroke=nStroke, **preset
|
239 |
+
)[0]
|
240 |
+
mask.save(os.path.join(args.output_dir, f'{i:07d}.png'))
|
241 |
+
|
242 |
+
else:
|
243 |
+
for i in range(args.n):
|
244 |
+
mask_dir = make_dir_under_root(args.output_dir, f'{i:05d}')
|
245 |
+
mask_reader = MaskReader(mask_dir, read=False)
|
246 |
+
|
247 |
+
nStroke = decide_nStroke(args)
|
248 |
+
masks = get_video_masks_by_moving_random_stroke(
|
249 |
+
imageWidth=args.image_width, imageHeight=args.image_height,
|
250 |
+
video_len=args.video_len, nStroke=nStroke, **preset)
|
251 |
+
|
252 |
+
mask_reader.set_files(masks)
|
253 |
+
mask_reader.save_files(output_dir=mask_reader.dir_name)
|
254 |
+
|
255 |
+
if args.leave_boarder_unmasked is not None:
|
256 |
+
dir_leave_boarder = copy_masks_without_boarder(args.output_dir, args)
|
257 |
+
if args.cluster_by_area:
|
258 |
+
cluster_by_masked_area(dir_leave_boarder, args)
|
259 |
+
|
260 |
+
if args.cluster_by_area:
|
261 |
+
cluster_by_masked_area(args.output_dir, args)
|
262 |
+
|
263 |
+
|
264 |
+
if __name__ == "__main__":
|
265 |
+
args = parse_args()
|
266 |
+
main(args)
|
FGT_codes/FGT/data/util/mask_generators.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
from PIL import Image, ImageDraw
|
4 |
+
|
5 |
+
|
6 |
+
def get_video_masks_by_moving_random_stroke(
|
7 |
+
video_len, imageWidth=320, imageHeight=180, nStroke=5,
|
8 |
+
nVertexBound=[10, 30], maxHeadSpeed=15, maxHeadAcceleration=(15, 0.5),
|
9 |
+
brushWidthBound=(5, 20), boarderGap=None, nMovePointRatio=0.5, maxPiontMove=10,
|
10 |
+
maxLineAcceleration=5, maxInitSpeed=5
|
11 |
+
):
|
12 |
+
'''
|
13 |
+
Get video masks by random strokes which move randomly between each
|
14 |
+
frame, including the whole stroke and its control points
|
15 |
+
|
16 |
+
Parameters
|
17 |
+
----------
|
18 |
+
imageWidth: Image width
|
19 |
+
imageHeight: Image height
|
20 |
+
nStroke: Number of drawed lines
|
21 |
+
nVertexBound: Lower/upper bound of number of control points for each line
|
22 |
+
maxHeadSpeed: Max head speed when creating control points
|
23 |
+
maxHeadAcceleration: Max acceleration applying on the current head point (
|
24 |
+
a head point and its velosity decides the next point)
|
25 |
+
brushWidthBound (min, max): Bound of width for each stroke
|
26 |
+
boarderGap: The minimum gap between image boarder and drawed lines
|
27 |
+
nMovePointRatio: The ratio of control points to move for next frames
|
28 |
+
maxPiontMove: The magnitude of movement for control points for next frames
|
29 |
+
maxLineAcceleration: The magnitude of acceleration for the whole line
|
30 |
+
|
31 |
+
Examples
|
32 |
+
----------
|
33 |
+
object_like_setting = {
|
34 |
+
"nVertexBound": [5, 20],
|
35 |
+
"maxHeadSpeed": 15,
|
36 |
+
"maxHeadAcceleration": (15, 3.14),
|
37 |
+
"brushWidthBound": (30, 50),
|
38 |
+
"nMovePointRatio": 0.5,
|
39 |
+
"maxPiontMove": 10,
|
40 |
+
"maxLineAcceleration": (5, 0.5),
|
41 |
+
"boarderGap": 20,
|
42 |
+
"maxInitSpeed": 10,
|
43 |
+
}
|
44 |
+
rand_curve_setting = {
|
45 |
+
"nVertexBound": [10, 30],
|
46 |
+
"maxHeadSpeed": 20,
|
47 |
+
"maxHeadAcceleration": (15, 0.5),
|
48 |
+
"brushWidthBound": (3, 10),
|
49 |
+
"nMovePointRatio": 0.5,
|
50 |
+
"maxPiontMove": 3,
|
51 |
+
"maxLineAcceleration": (5, 0.5),
|
52 |
+
"boarderGap": 20,
|
53 |
+
"maxInitSpeed": 6
|
54 |
+
}
|
55 |
+
get_video_masks_by_moving_random_stroke(video_len=5, nStroke=3, **object_like_setting)
|
56 |
+
'''
|
57 |
+
assert(video_len >= 1)
|
58 |
+
|
59 |
+
# Initilize a set of control points to draw the first mask
|
60 |
+
mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
|
61 |
+
control_points_set = []
|
62 |
+
for i in range(nStroke):
|
63 |
+
brushWidth = np.random.randint(brushWidthBound[0], brushWidthBound[1])
|
64 |
+
Xs, Ys, velocity = get_random_stroke_control_points(
|
65 |
+
imageWidth=imageWidth, imageHeight=imageHeight,
|
66 |
+
nVertexBound=nVertexBound, maxHeadSpeed=maxHeadSpeed,
|
67 |
+
maxHeadAcceleration=maxHeadAcceleration, boarderGap=boarderGap,
|
68 |
+
maxInitSpeed=maxInitSpeed
|
69 |
+
)
|
70 |
+
control_points_set.append((Xs, Ys, velocity, brushWidth))
|
71 |
+
draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
|
72 |
+
|
73 |
+
# Generate the following masks by randomly move strokes and their control points
|
74 |
+
masks = [mask]
|
75 |
+
for i in range(video_len - 1):
|
76 |
+
mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
|
77 |
+
for j in range(len(control_points_set)):
|
78 |
+
Xs, Ys, velocity, brushWidth = control_points_set[j]
|
79 |
+
new_Xs, new_Ys = random_move_control_points(
|
80 |
+
Xs, Ys, velocity, nMovePointRatio, maxPiontMove,
|
81 |
+
maxLineAcceleration, boarderGap
|
82 |
+
)
|
83 |
+
control_points_set[j] = (new_Xs, new_Ys, velocity, brushWidth)
|
84 |
+
for Xs, Ys, velocity, brushWidth in control_points_set:
|
85 |
+
draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
|
86 |
+
masks.append(mask)
|
87 |
+
|
88 |
+
return masks
|
89 |
+
|
90 |
+
|
91 |
+
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
|
92 |
+
speed, angle = velocity
|
93 |
+
d_speed, d_angle = maxAcceleration
|
94 |
+
|
95 |
+
if dist == 'uniform':
|
96 |
+
speed += np.random.uniform(-d_speed, d_speed)
|
97 |
+
angle += np.random.uniform(-d_angle, d_angle)
|
98 |
+
elif dist == 'guassian':
|
99 |
+
speed += np.random.normal(0, d_speed / 2)
|
100 |
+
angle += np.random.normal(0, d_angle / 2)
|
101 |
+
else:
|
102 |
+
raise NotImplementedError(f'Distribution type {dist} is not supported.')
|
103 |
+
|
104 |
+
return (speed, angle)
|
105 |
+
|
106 |
+
|
107 |
+
def random_move_control_points(Xs, Ys, lineVelocity, nMovePointRatio, maxPiontMove, maxLineAcceleration, boarderGap=15):
|
108 |
+
new_Xs = Xs.copy()
|
109 |
+
new_Ys = Ys.copy()
|
110 |
+
|
111 |
+
# move the whole line and accelerate
|
112 |
+
speed, angle = lineVelocity
|
113 |
+
new_Xs += int(speed * np.cos(angle))
|
114 |
+
new_Ys += int(speed * np.sin(angle))
|
115 |
+
lineVelocity = random_accelerate(lineVelocity, maxLineAcceleration, dist='guassian')
|
116 |
+
|
117 |
+
# choose points to move
|
118 |
+
chosen = np.arange(len(Xs))
|
119 |
+
np.random.shuffle(chosen)
|
120 |
+
chosen = chosen[:int(len(Xs) * nMovePointRatio)]
|
121 |
+
for i in chosen:
|
122 |
+
new_Xs[i] += np.random.randint(-maxPiontMove, maxPiontMove)
|
123 |
+
new_Ys[i] += np.random.randint(-maxPiontMove, maxPiontMove)
|
124 |
+
return new_Xs, new_Ys
|
125 |
+
|
126 |
+
|
127 |
+
def get_random_stroke_control_points(
|
128 |
+
imageWidth, imageHeight,
|
129 |
+
nVertexBound=(10, 30), maxHeadSpeed=10, maxHeadAcceleration=(5, 0.5), boarderGap=20,
|
130 |
+
maxInitSpeed=10
|
131 |
+
):
|
132 |
+
'''
|
133 |
+
Implementation the free-form training masks generating algorithm
|
134 |
+
proposed by JIAHUI YU et al. in "Free-Form Image Inpainting with Gated Convolution"
|
135 |
+
'''
|
136 |
+
startX = np.random.randint(imageWidth)
|
137 |
+
startY = np.random.randint(imageHeight)
|
138 |
+
Xs = [startX]
|
139 |
+
Ys = [startY]
|
140 |
+
|
141 |
+
numVertex = np.random.randint(nVertexBound[0], nVertexBound[1])
|
142 |
+
|
143 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
144 |
+
speed = np.random.uniform(0, maxHeadSpeed)
|
145 |
+
|
146 |
+
for i in range(numVertex):
|
147 |
+
speed, angle = random_accelerate((speed, angle), maxHeadAcceleration)
|
148 |
+
speed = np.clip(speed, 0, maxHeadSpeed)
|
149 |
+
|
150 |
+
nextX = startX + speed * np.sin(angle)
|
151 |
+
nextY = startY + speed * np.cos(angle)
|
152 |
+
|
153 |
+
if boarderGap is not None:
|
154 |
+
nextX = np.clip(nextX, boarderGap, imageWidth - boarderGap)
|
155 |
+
nextY = np.clip(nextY, boarderGap, imageHeight - boarderGap)
|
156 |
+
|
157 |
+
startX, startY = nextX, nextY
|
158 |
+
Xs.append(nextX)
|
159 |
+
Ys.append(nextY)
|
160 |
+
|
161 |
+
velocity = get_random_velocity(maxInitSpeed, dist='guassian')
|
162 |
+
|
163 |
+
return np.array(Xs), np.array(Ys), velocity
|
164 |
+
|
165 |
+
|
166 |
+
def get_random_velocity(max_speed, dist='uniform'):
|
167 |
+
if dist == 'uniform':
|
168 |
+
speed = np.random.uniform(max_speed)
|
169 |
+
elif dist == 'guassian':
|
170 |
+
speed = np.abs(np.random.normal(0, max_speed / 2))
|
171 |
+
else:
|
172 |
+
raise NotImplementedError(f'Distribution type {dist} is not supported.')
|
173 |
+
|
174 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
175 |
+
return (speed, angle)
|
176 |
+
|
177 |
+
|
178 |
+
def draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=255):
|
179 |
+
radius = brushWidth // 2 - 1
|
180 |
+
for i in range(1, len(Xs)):
|
181 |
+
draw = ImageDraw.Draw(mask)
|
182 |
+
startX, startY = Xs[i - 1], Ys[i - 1]
|
183 |
+
nextX, nextY = Xs[i], Ys[i]
|
184 |
+
draw.line((startX, startY) + (nextX, nextY), fill=fill, width=brushWidth)
|
185 |
+
for x, y in zip(Xs, Ys):
|
186 |
+
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=fill)
|
187 |
+
return mask
|
188 |
+
|
189 |
+
|
190 |
+
# modified from https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/generate_data.py
|
191 |
+
def get_random_walk_mask(imageWidth=320, imageHeight=180, length=None):
|
192 |
+
action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
|
193 |
+
canvas = np.zeros((imageHeight, imageWidth)).astype("i")
|
194 |
+
if length is None:
|
195 |
+
length = imageWidth * imageHeight
|
196 |
+
x = random.randint(0, imageHeight - 1)
|
197 |
+
y = random.randint(0, imageWidth - 1)
|
198 |
+
x_list = []
|
199 |
+
y_list = []
|
200 |
+
for i in range(length):
|
201 |
+
r = random.randint(0, len(action_list) - 1)
|
202 |
+
x = np.clip(x + action_list[r][0], a_min=0, a_max=imageHeight - 1)
|
203 |
+
y = np.clip(y + action_list[r][1], a_min=0, a_max=imageWidth - 1)
|
204 |
+
x_list.append(x)
|
205 |
+
y_list.append(y)
|
206 |
+
canvas[np.array(x_list), np.array(y_list)] = 1
|
207 |
+
return Image.fromarray(canvas * 255).convert('1')
|
208 |
+
|
209 |
+
|
210 |
+
def get_masked_ratio(mask):
|
211 |
+
"""
|
212 |
+
Calculate the masked ratio.
|
213 |
+
mask: Expected a binary PIL image, where 0 and 1 represent
|
214 |
+
masked(invalid) and valid pixel values.
|
215 |
+
"""
|
216 |
+
hist = mask.histogram()
|
217 |
+
return hist[0] / np.prod(mask.size)
|
FGT_codes/FGT/data/util/readers.py
ADDED
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # NOQA
|
4 |
+
import argparse
|
5 |
+
from math import ceil
|
6 |
+
from glob import glob
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import cv2
|
10 |
+
from PIL import Image, ImageDraw, ImageOps, ImageFont
|
11 |
+
|
12 |
+
from utils.logging_config import logger
|
13 |
+
from utils.util import make_dirs, bbox_offset
|
14 |
+
|
15 |
+
|
16 |
+
DEFAULT_FPS = 6
|
17 |
+
MAX_LENGTH = 60
|
18 |
+
|
19 |
+
|
20 |
+
def parse_args():
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument(
|
23 |
+
'-fps', '--fps',
|
24 |
+
type=int, default=DEFAULT_FPS,
|
25 |
+
help="Output video FPS"
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
'-v', '--video_dir',
|
29 |
+
type=str,
|
30 |
+
help="Video directory name"
|
31 |
+
)
|
32 |
+
parser.add_argument(
|
33 |
+
'-vs', '--video_dirs',
|
34 |
+
nargs='+',
|
35 |
+
type=str,
|
36 |
+
help="Video directory names"
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
'-v2', '--video_dir2',
|
40 |
+
type=str,
|
41 |
+
help="Video directory name"
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
'-sd', '--segms_dir',
|
45 |
+
type=str,
|
46 |
+
help="Segmentation directory name"
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
'-fgd', '--fg_dir',
|
50 |
+
type=str,
|
51 |
+
help="Foreground directory name"
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
'-fgfd', '--fg_frames_dir',
|
55 |
+
type=str,
|
56 |
+
help="Foreground frames directory name"
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
'-fgsd', '--fg_segms_dir',
|
60 |
+
type=str,
|
61 |
+
help="Foreground segmentations directory name"
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
'-syfd', '--syn_frames_dir',
|
65 |
+
type=str,
|
66 |
+
help="Synthesized frames directory name"
|
67 |
+
)
|
68 |
+
parser.add_argument(
|
69 |
+
'-bgfd', '--bg_frames_dir',
|
70 |
+
type=str,
|
71 |
+
help="Background frames directory name"
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
'-rt', '--reader_type',
|
75 |
+
type=str,
|
76 |
+
help="Type of reader"
|
77 |
+
)
|
78 |
+
parser.add_argument(
|
79 |
+
'-od', '--output_dir',
|
80 |
+
type=str,
|
81 |
+
help="Output directory name"
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
'-o', '--output_filename',
|
85 |
+
type=str, required=True,
|
86 |
+
help="Output output filename"
|
87 |
+
)
|
88 |
+
args = parser.parse_args()
|
89 |
+
return args
|
90 |
+
|
91 |
+
|
92 |
+
class Reader:
|
93 |
+
def __init__(self, dir_name, read=True, max_length=None, sample_period=1):
|
94 |
+
self.dir_name = dir_name
|
95 |
+
self.count = 0
|
96 |
+
self.max_length = max_length
|
97 |
+
self.filenames = []
|
98 |
+
self.sample_period = sample_period
|
99 |
+
if read:
|
100 |
+
if os.path.exists(dir_name):
|
101 |
+
# self.filenames = read_filenames_from_dir(dir_name, self.__class__.__name__)
|
102 |
+
# ^^^^^ yield None when reading some videos of face forensics data
|
103 |
+
# (related to 'Too many levels of symbolic links'?)
|
104 |
+
|
105 |
+
self.filenames = sorted(glob(os.path.join(dir_name, '*')))
|
106 |
+
self.filenames = [f for f in self.filenames if os.path.isfile(f)]
|
107 |
+
self.filenames = self.filenames[::sample_period][:max_length]
|
108 |
+
self.files = self.read_files(self.filenames)
|
109 |
+
else:
|
110 |
+
self.files = []
|
111 |
+
logger.warning(f"Directory {dir_name} not exists!")
|
112 |
+
else:
|
113 |
+
self.files = []
|
114 |
+
self.current_index = 0
|
115 |
+
|
116 |
+
def append(self, file_):
|
117 |
+
self.files.append(file_)
|
118 |
+
|
119 |
+
def set_files(self, files):
|
120 |
+
self.files = files
|
121 |
+
|
122 |
+
def read_files(self, filenames):
|
123 |
+
assert type(filenames) == list, f'filenames is not a list; dirname: {self.dir_name}'
|
124 |
+
filenames.sort()
|
125 |
+
frames = []
|
126 |
+
for filename in filenames:
|
127 |
+
file_ = self.read_file(filename)
|
128 |
+
frames.append(file_)
|
129 |
+
return frames
|
130 |
+
|
131 |
+
def save_files(self, output_dir=None):
|
132 |
+
make_dirs(output_dir)
|
133 |
+
logger.info(f"Saving {self.__class__.__name__} files to {output_dir}")
|
134 |
+
for i, file_ in enumerate(self.files):
|
135 |
+
self._save_file(output_dir, i, file_)
|
136 |
+
|
137 |
+
def _save_file(self, output_dir, i, file_):
|
138 |
+
raise NotImplementedError("This is an abstract function")
|
139 |
+
|
140 |
+
def read_file(self, filename):
|
141 |
+
raise NotImplementedError("This is an abstract function")
|
142 |
+
|
143 |
+
def __iter__(self):
|
144 |
+
return self
|
145 |
+
|
146 |
+
def __next__(self):
|
147 |
+
if self.current_index < len(self.files):
|
148 |
+
file_ = self.files[self.current_index]
|
149 |
+
self.current_index += 1
|
150 |
+
return file_
|
151 |
+
else:
|
152 |
+
self.current_index = 0
|
153 |
+
raise StopIteration
|
154 |
+
|
155 |
+
def __getitem__(self, key):
|
156 |
+
return self.files[key]
|
157 |
+
|
158 |
+
def __len__(self):
|
159 |
+
return len(self.files)
|
160 |
+
|
161 |
+
|
162 |
+
class FrameReader(Reader):
|
163 |
+
def __init__(
|
164 |
+
self, dir_name, resize=None, read=True, max_length=MAX_LENGTH,
|
165 |
+
scale=1, sample_period=1
|
166 |
+
):
|
167 |
+
self.resize = resize
|
168 |
+
self.scale = scale
|
169 |
+
self.sample_period = sample_period
|
170 |
+
super().__init__(dir_name, read, max_length, sample_period)
|
171 |
+
|
172 |
+
def read_file(self, filename):
|
173 |
+
origin_frame = Image.open(filename)
|
174 |
+
size = self.resize if self.resize is not None else origin_frame.size
|
175 |
+
origin_frame_resized = origin_frame.resize(
|
176 |
+
(int(size[0] * self.scale), int(size[1] * self.scale))
|
177 |
+
)
|
178 |
+
return origin_frame_resized
|
179 |
+
|
180 |
+
def _save_file(self, output_dir, i, file_):
|
181 |
+
if len(self.filenames) == len(self.files):
|
182 |
+
name = sorted(self.filenames)[i].split('/')[-1]
|
183 |
+
else:
|
184 |
+
name = f"frame_{i:04}.png"
|
185 |
+
filename = os.path.join(
|
186 |
+
output_dir, name
|
187 |
+
)
|
188 |
+
file_.save(filename, "PNG")
|
189 |
+
|
190 |
+
def write_files_to_video(self, output_filename, fps=DEFAULT_FPS, frame_num_when_repeat_list=[1]):
|
191 |
+
logger.info(
|
192 |
+
f"Writeing frames to video {output_filename} with FPS={fps}")
|
193 |
+
video_writer = cv2.VideoWriter(
|
194 |
+
output_filename,
|
195 |
+
cv2.VideoWriter_fourcc(*"MJPG"),
|
196 |
+
fps,
|
197 |
+
self.files[0].size
|
198 |
+
)
|
199 |
+
for frame_num_when_repeat in frame_num_when_repeat_list:
|
200 |
+
for frame in self.files:
|
201 |
+
frame = frame.convert("RGB")
|
202 |
+
frame_cv = np.array(frame)
|
203 |
+
frame_cv = cv2.cvtColor(frame_cv, cv2.COLOR_RGB2BGR)
|
204 |
+
for i in range(frame_num_when_repeat):
|
205 |
+
video_writer.write(frame_cv)
|
206 |
+
video_writer.release()
|
207 |
+
|
208 |
+
|
209 |
+
class SynthesizedFrameReader(FrameReader):
|
210 |
+
def __init__(
|
211 |
+
self, bg_frames_dir, fg_frames_dir,
|
212 |
+
fg_segms_dir, segm_bbox_mask_dir, fg_dir, dir_name,
|
213 |
+
bboxes_list_dir,
|
214 |
+
fg_scale=0.7, fg_location=(48, 27), mask_only=False
|
215 |
+
):
|
216 |
+
self.bg_reader = FrameReader(bg_frames_dir)
|
217 |
+
self.size = self.bg_reader[0].size
|
218 |
+
# TODO: add different location and change scale to var
|
219 |
+
self.fg_reader = ForegroundReader(
|
220 |
+
fg_frames_dir, fg_segms_dir, fg_dir,
|
221 |
+
resize=self.size,
|
222 |
+
scale=fg_scale
|
223 |
+
)
|
224 |
+
self.fg_location = fg_location
|
225 |
+
# self.masks = self.fg_reader.masks
|
226 |
+
# self.bbox_masks = self.fg_reader.bbox_masks
|
227 |
+
super().__init__(dir_name, read=False)
|
228 |
+
self.files = self.synthesize_frames(
|
229 |
+
self.bg_reader, self.fg_reader, mask_only)
|
230 |
+
self.bbox_masks = MaskGenerator(
|
231 |
+
segm_bbox_mask_dir, self.size, self.get_bboxeses()
|
232 |
+
)
|
233 |
+
self.bboxes_list_dir = bboxes_list_dir
|
234 |
+
self.bboxes_list = self.get_bboxeses()
|
235 |
+
self.save_bboxes()
|
236 |
+
|
237 |
+
def save_bboxes(self):
|
238 |
+
make_dirs(self.bboxes_list_dir)
|
239 |
+
logger.info(f"Saving bboxes to {self.bboxes_list_dir}")
|
240 |
+
for i, bboxes in enumerate(self.bboxes_list):
|
241 |
+
save_path = os.path.join(self.bboxes_list_dir, f"bboxes_{i:04}.txt")
|
242 |
+
if len(bboxes) > 0:
|
243 |
+
np.savetxt(save_path, bboxes[0], fmt='%4u')
|
244 |
+
|
245 |
+
def get_bboxeses(self):
|
246 |
+
bboxeses = self.fg_reader.segms.bboxeses
|
247 |
+
new_bboxeses = []
|
248 |
+
for bboxes in bboxeses:
|
249 |
+
new_bboxes = []
|
250 |
+
for bbox in bboxes:
|
251 |
+
offset_bbox = bbox_offset(bbox, self.fg_location)
|
252 |
+
new_bboxes.append(offset_bbox)
|
253 |
+
new_bboxeses.append(new_bboxes)
|
254 |
+
return new_bboxeses
|
255 |
+
|
256 |
+
def synthesize_frames(self, bg_reader, fg_reader, mask_only=False):
|
257 |
+
logger.info(
|
258 |
+
f"Synthesizing {bg_reader.dir_name} and {fg_reader.dir_name}"
|
259 |
+
)
|
260 |
+
synthesized_frames = []
|
261 |
+
for i, bg in enumerate(bg_reader):
|
262 |
+
if i == len(fg_reader):
|
263 |
+
break
|
264 |
+
fg = fg_reader[i]
|
265 |
+
mask = fg_reader.get_mask(i)
|
266 |
+
synthesized_frame = bg.copy()
|
267 |
+
if mask_only:
|
268 |
+
synthesized_frame.paste(mask, self.fg_location, mask)
|
269 |
+
else:
|
270 |
+
synthesized_frame.paste(fg, self.fg_location, mask)
|
271 |
+
synthesized_frames.append(synthesized_frame)
|
272 |
+
return synthesized_frames
|
273 |
+
|
274 |
+
|
275 |
+
class WarpedFrameReader(FrameReader):
|
276 |
+
def __init__(self, dir_name, i, ks):
|
277 |
+
self.i = i
|
278 |
+
self.ks = ks
|
279 |
+
super().__init__(dir_name)
|
280 |
+
|
281 |
+
def _save_file(self, output_dir, i, file_):
|
282 |
+
filename = os.path.join(
|
283 |
+
output_dir,
|
284 |
+
f"warped_frame_{self.i:04}_k{self.ks[i]:02}.png"
|
285 |
+
)
|
286 |
+
file_.save(filename)
|
287 |
+
|
288 |
+
|
289 |
+
class SegmentationReader(FrameReader):
|
290 |
+
def __init__(
|
291 |
+
self, dir_name,
|
292 |
+
resize=None, scale=1
|
293 |
+
):
|
294 |
+
super().__init__(
|
295 |
+
dir_name, resize=resize, scale=scale
|
296 |
+
)
|
297 |
+
|
298 |
+
def read_file(self, filename):
|
299 |
+
origin_frame = Image.open(filename)
|
300 |
+
mask = ImageOps.invert(origin_frame.convert("L"))
|
301 |
+
mask = mask.point(lambda x: 0 if x < 255 else 255, '1')
|
302 |
+
size = self.resize if self.resize is not None else origin_frame.size
|
303 |
+
mask_resized = mask.resize(
|
304 |
+
(int(size[0] * self.scale), int(size[1] * self.scale))
|
305 |
+
)
|
306 |
+
return mask_resized
|
307 |
+
|
308 |
+
|
309 |
+
class MaskReader(Reader):
|
310 |
+
def __init__(self, dir_name, read=True):
|
311 |
+
super().__init__(dir_name, read=read)
|
312 |
+
|
313 |
+
def read_file(self, filename):
|
314 |
+
mask = Image.open(filename)
|
315 |
+
return mask
|
316 |
+
|
317 |
+
def _save_file(self, output_dir, i, file_):
|
318 |
+
filename = os.path.join(
|
319 |
+
output_dir,
|
320 |
+
f"mask_{i:04}.png"
|
321 |
+
)
|
322 |
+
file_.save(filename)
|
323 |
+
|
324 |
+
def get_bboxes(self, i):
|
325 |
+
# TODO: save bbox instead of looking for one
|
326 |
+
mask = self.files[i]
|
327 |
+
mask = ImageOps.invert(mask.convert("L")).convert("1")
|
328 |
+
mask = np.array(mask)
|
329 |
+
image, contours, hier = cv2.findContours(
|
330 |
+
mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
331 |
+
bboxes = []
|
332 |
+
for c in contours:
|
333 |
+
# get the bounding rect
|
334 |
+
x, y, w, h = cv2.boundingRect(c)
|
335 |
+
bbox = ((x, y), (x + w - 1, y + h - 1))
|
336 |
+
bboxes.append(bbox)
|
337 |
+
return bboxes
|
338 |
+
|
339 |
+
def get_bbox(self, i):
|
340 |
+
# TODO: save bbox instead of looking for one
|
341 |
+
mask = self.files[i]
|
342 |
+
mask = ImageOps.invert(mask.convert("L"))
|
343 |
+
mask = np.array(mask)
|
344 |
+
image, contours, hier = cv2.findContours(
|
345 |
+
mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
346 |
+
for c in contours:
|
347 |
+
# get the bounding rect
|
348 |
+
x, y, w, h = cv2.boundingRect(c)
|
349 |
+
bbox = ((x, y), (x + w - 1, y + h - 1))
|
350 |
+
return bbox
|
351 |
+
|
352 |
+
|
353 |
+
class MaskGenerator(Reader):
|
354 |
+
def __init__(
|
355 |
+
self, mask_output_dir, size, bboxeses, save_masks=True
|
356 |
+
):
|
357 |
+
self.bboxeses = bboxeses
|
358 |
+
self.size = size
|
359 |
+
super().__init__(mask_output_dir, read=False)
|
360 |
+
self.files = self.generate_masks()
|
361 |
+
if save_masks:
|
362 |
+
make_dirs(mask_output_dir)
|
363 |
+
self.save_files(mask_output_dir)
|
364 |
+
|
365 |
+
def _save_file(self, output_dir, i, file_):
|
366 |
+
filename = os.path.join(
|
367 |
+
output_dir,
|
368 |
+
f"mask_{i:04}.png"
|
369 |
+
)
|
370 |
+
file_.save(filename)
|
371 |
+
|
372 |
+
def get_bboxes(self, i):
|
373 |
+
return self.bboxeses[i]
|
374 |
+
|
375 |
+
def generate_masks(self):
|
376 |
+
masks = []
|
377 |
+
for i in range(len(self.bboxeses)):
|
378 |
+
mask = self.generate_mask(i)
|
379 |
+
masks.append(mask)
|
380 |
+
return masks
|
381 |
+
|
382 |
+
def generate_mask(self, i):
|
383 |
+
bboxes = self.bboxeses[i]
|
384 |
+
mask = Image.new("1", self.size, 1)
|
385 |
+
draw = ImageDraw.Draw(mask)
|
386 |
+
for bbox in bboxes:
|
387 |
+
draw.rectangle(
|
388 |
+
bbox, fill=0
|
389 |
+
)
|
390 |
+
return mask
|
391 |
+
|
392 |
+
|
393 |
+
class ForegroundReader(FrameReader):
|
394 |
+
def __init__(
|
395 |
+
self, frames_dir, segms_dir, dir_name,
|
396 |
+
resize=None, scale=1
|
397 |
+
):
|
398 |
+
self.frames_dir = frames_dir
|
399 |
+
self.segms_dir = segms_dir
|
400 |
+
self.frames = FrameReader(
|
401 |
+
frames_dir,
|
402 |
+
resize=resize, scale=scale
|
403 |
+
)
|
404 |
+
self.segms = SegmentationReader(
|
405 |
+
segms_dir, resize=resize, scale=scale
|
406 |
+
)
|
407 |
+
super().__init__(dir_name, read=False)
|
408 |
+
self.masks = self.segms.masks
|
409 |
+
# self.bbox_masks = self.segms.bbox_masks
|
410 |
+
self.files = self.generate_fg_frames(self.frames, self.segms)
|
411 |
+
|
412 |
+
def get_mask(self, i):
|
413 |
+
return self.masks[i]
|
414 |
+
|
415 |
+
def generate_fg_frames(self, frames, segms):
|
416 |
+
logger.info(
|
417 |
+
f"Generating fg frames from {self.frames_dir} and {self.segms_dir}"
|
418 |
+
)
|
419 |
+
fg_frames = []
|
420 |
+
for i, frame in enumerate(frames):
|
421 |
+
mask = self.masks[i]
|
422 |
+
fg_frame = Image.new("RGB", frame.size, (0, 0, 0))
|
423 |
+
fg_frame.paste(
|
424 |
+
frame, (0, 0),
|
425 |
+
mask
|
426 |
+
)
|
427 |
+
fg_frames.append(fg_frame)
|
428 |
+
return fg_frames
|
429 |
+
|
430 |
+
|
431 |
+
class CompareFramesReader(FrameReader):
|
432 |
+
def __init__(self, dir_names, col=2, names=[], mask_dir=None):
|
433 |
+
self.videos = []
|
434 |
+
for dir_name in dir_names:
|
435 |
+
# If a method fails on this video, use None to indicate the situation
|
436 |
+
try:
|
437 |
+
self.videos.append(FrameReader(dir_name))
|
438 |
+
except AssertionError:
|
439 |
+
self.videos.append(None)
|
440 |
+
if mask_dir is not None:
|
441 |
+
self.masks = MaskReader(mask_dir)
|
442 |
+
self.names = names
|
443 |
+
self.files = self.combine_videos(self.videos, col)
|
444 |
+
|
445 |
+
def combine_videos(self, videos, col=2, edge_offset=35, h_start_offset=35):
|
446 |
+
combined_frames = []
|
447 |
+
w, h = videos[0][0].size
|
448 |
+
# Prevent the first method fails and have a "None" as its video
|
449 |
+
i = 0
|
450 |
+
while videos[i] is None:
|
451 |
+
i += 1
|
452 |
+
length = len(videos[i])
|
453 |
+
video_num = len(videos)
|
454 |
+
row = ceil(video_num / col)
|
455 |
+
for frame_idx in range(length):
|
456 |
+
width = col * w + (col - 1) * edge_offset
|
457 |
+
height = row * h + (row - 1) * edge_offset + h_start_offset
|
458 |
+
combined_frame = Image.new("RGBA", (width, height))
|
459 |
+
draw = ImageDraw.Draw(combined_frame)
|
460 |
+
for i, video in enumerate(videos):
|
461 |
+
# Give the failed method a black output
|
462 |
+
if video is None or frame_idx >= len(video):
|
463 |
+
failed = True
|
464 |
+
frame = Image.new("RGBA", (w, h))
|
465 |
+
else:
|
466 |
+
frame = video[frame_idx].convert("RGBA")
|
467 |
+
failed = False
|
468 |
+
|
469 |
+
f_x = (i % col) * (w + edge_offset)
|
470 |
+
f_y = (i // col) * (h + edge_offset) + h_start_offset
|
471 |
+
combined_frame.paste(frame, (f_x, f_y))
|
472 |
+
|
473 |
+
# Draw name
|
474 |
+
font = ImageFont.truetype("DejaVuSans.ttf", 12)
|
475 |
+
# font = ImageFont.truetype("DejaVuSans-Bold.ttf", 13)
|
476 |
+
# font = ImageFont.truetype("timesbd.ttf", 14)
|
477 |
+
name = self.names[i] if not failed else f'{self.names[i]} (failed)'
|
478 |
+
draw.text(
|
479 |
+
(f_x + 10, f_y - 20),
|
480 |
+
name, (255, 255, 255), font=font
|
481 |
+
)
|
482 |
+
|
483 |
+
combined_frames.append(combined_frame)
|
484 |
+
return combined_frames
|
485 |
+
|
486 |
+
|
487 |
+
class BoundingBoxesListReader(Reader):
|
488 |
+
def __init__(
|
489 |
+
self, dir_name, resize=None, read=True, max_length=MAX_LENGTH,
|
490 |
+
scale=1
|
491 |
+
):
|
492 |
+
self.resize = resize
|
493 |
+
self.scale = scale
|
494 |
+
super().__init__(dir_name, read, max_length)
|
495 |
+
|
496 |
+
def read_file(self, filename):
|
497 |
+
bboxes = np.loadtxt(filename, dtype=int)
|
498 |
+
bboxes = [bboxes.tolist()]
|
499 |
+
return bboxes
|
500 |
+
|
501 |
+
|
502 |
+
def save_frames_to_dir(frames, dirname):
|
503 |
+
reader = FrameReader(dirname, read=False)
|
504 |
+
reader.set_files(frames)
|
505 |
+
reader.save_files(dirname)
|
506 |
+
|
507 |
+
|
508 |
+
if __name__ == "__main__":
|
509 |
+
args = parse_args()
|
510 |
+
if args.reader_type is None:
|
511 |
+
reader = FrameReader(args.video_dir)
|
512 |
+
elif args.reader_type == 'fg':
|
513 |
+
reader = ForegroundReader(
|
514 |
+
args.video_dir, args.segms_dir, args.fg_dir)
|
515 |
+
elif args.reader_type == 'sy':
|
516 |
+
reader = SynthesizedFrameReader(
|
517 |
+
args.bg_frames_dir, args.fg_frames_dir,
|
518 |
+
args.fg_segms_dir, args.fg_dir, args.syn_frames_dir
|
519 |
+
)
|
520 |
+
elif args.reader_type == 'com':
|
521 |
+
reader = CompareFramesReader(
|
522 |
+
args.video_dirs
|
523 |
+
)
|
524 |
+
reader.write_files_to_video(
|
525 |
+
os.path.join(args.output_dir, args.output_filename),
|
526 |
+
fps=args.fps
|
527 |
+
)
|
FGT_codes/FGT/data/util/util.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import shutil
|
4 |
+
from glob import glob
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from utils.logging_config import logger
|
10 |
+
|
11 |
+
|
12 |
+
def parse_args():
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument(
|
15 |
+
'-v', '--video_dir',
|
16 |
+
type=str,
|
17 |
+
help="Video directory name"
|
18 |
+
)
|
19 |
+
parser.add_argument(
|
20 |
+
'-fl', '--flow_dir',
|
21 |
+
type=str,
|
22 |
+
help="Optical flow ground truth directory name"
|
23 |
+
)
|
24 |
+
parser.add_argument(
|
25 |
+
'-od', '--output_dir',
|
26 |
+
type=str,
|
27 |
+
help="Output directory name"
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
'-o', '--output_filename',
|
31 |
+
type=str,
|
32 |
+
help="Output output filename"
|
33 |
+
)
|
34 |
+
args = parser.parse_args()
|
35 |
+
return args
|
36 |
+
|
37 |
+
|
38 |
+
def make_dirs(dir_name):
|
39 |
+
if not os.path.exists(dir_name):
|
40 |
+
os.makedirs(dir_name)
|
41 |
+
logger.info(f"Directory {dir_name} made")
|
42 |
+
|
43 |
+
|
44 |
+
ensure_dir = make_dirs
|
45 |
+
|
46 |
+
|
47 |
+
def make_dir_under_root(root_dir, name):
|
48 |
+
full_dir_name = os.path.join(root_dir, name)
|
49 |
+
make_dirs(full_dir_name)
|
50 |
+
return full_dir_name
|
51 |
+
|
52 |
+
|
53 |
+
def rm_dirs(dir_name, ignore_errors=False):
|
54 |
+
if os.path.exists(dir_name):
|
55 |
+
shutil.rmtree(dir_name, ignore_errors)
|
56 |
+
logger.info(f"Directory {dir_name} removed")
|
57 |
+
|
58 |
+
|
59 |
+
def read_dirnames_under_root(root_dir, skip_list=[]):
|
60 |
+
dirnames = [
|
61 |
+
name for i, name in enumerate(sorted(os.listdir(root_dir)))
|
62 |
+
if (os.path.isdir(os.path.join(root_dir, name))
|
63 |
+
and name not in skip_list
|
64 |
+
and i not in skip_list)
|
65 |
+
]
|
66 |
+
logger.info(f"Reading directories under {root_dir}, exclude {skip_list}, num: {len(dirnames)}")
|
67 |
+
return dirnames
|
68 |
+
|
69 |
+
|
70 |
+
def bbox_offset(bbox, location):
|
71 |
+
x0, y0 = location
|
72 |
+
(x1, y1), (x2, y2) = bbox
|
73 |
+
return ((x1 + x0, y1 + y0), (x2 + x0, y2 + y0))
|
74 |
+
|
75 |
+
|
76 |
+
def cover2_bbox(bbox1, bbox2):
|
77 |
+
x1 = min(bbox1[0][0], bbox2[0][0])
|
78 |
+
y1 = min(bbox1[0][1], bbox2[0][1])
|
79 |
+
x2 = max(bbox1[1][0], bbox2[1][0])
|
80 |
+
y2 = max(bbox1[1][1], bbox2[1][1])
|
81 |
+
return ((x1, y1), (x2, y2))
|
82 |
+
|
83 |
+
|
84 |
+
def extend_r_bbox(bbox, w, h, r):
|
85 |
+
(x1, y1), (x2, y2) = bbox
|
86 |
+
x1 = max(x1 - r, 0)
|
87 |
+
x2 = min(x2 + r, w)
|
88 |
+
y1 = max(y1 - r, 0)
|
89 |
+
y2 = min(y2 + r, h)
|
90 |
+
return ((x1, y1), (x2, y2))
|
91 |
+
|
92 |
+
|
93 |
+
def mean_squared_error(A, B):
|
94 |
+
return np.square(np.subtract(A, B)).mean()
|
95 |
+
|
96 |
+
|
97 |
+
def bboxes_to_mask(size, bboxes):
|
98 |
+
mask = Image.new("L", size, 255)
|
99 |
+
mask = np.array(mask)
|
100 |
+
for bbox in bboxes:
|
101 |
+
try:
|
102 |
+
(x1, y1), (x2, y2) = bbox
|
103 |
+
except Exception:
|
104 |
+
(x1, y1, x2, y2) = bbox
|
105 |
+
|
106 |
+
mask[y1:y2, x1:x2] = 0
|
107 |
+
mask = Image.fromarray(mask.astype("uint8"))
|
108 |
+
return mask
|
109 |
+
|
110 |
+
|
111 |
+
def get_extended_from_box(img_size, box, patch_size):
|
112 |
+
def _decide_patch_num(box_width, patch_size):
|
113 |
+
num = np.ceil(box_width / patch_size).astype(np.int)
|
114 |
+
if (num * patch_size - box_width) < (patch_size // 2):
|
115 |
+
num += 1
|
116 |
+
return num
|
117 |
+
|
118 |
+
x1, y1 = box[0]
|
119 |
+
x2, y2 = box[1]
|
120 |
+
new_box = (x1, y1, x2 - x1, y2 - y1)
|
121 |
+
box_x_start, box_y_start, box_x_size, box_y_size = new_box
|
122 |
+
|
123 |
+
patchN_x = _decide_patch_num(box_x_size, patch_size)
|
124 |
+
patchN_y = _decide_patch_num(box_y_size, patch_size)
|
125 |
+
|
126 |
+
extend_x = (patch_size * patchN_x - box_x_size) // 2
|
127 |
+
extend_y = (patch_size * patchN_y - box_y_size) // 2
|
128 |
+
img_x_size = img_size[0]
|
129 |
+
img_y_size = img_size[1]
|
130 |
+
|
131 |
+
x_start = max(0, box_x_start - extend_x)
|
132 |
+
x_end = min(box_x_start - extend_x + patchN_x * patch_size, img_x_size)
|
133 |
+
|
134 |
+
y_start = max(0, box_y_start - extend_y)
|
135 |
+
y_end = min(box_y_start - extend_y + patchN_y * patch_size, img_y_size)
|
136 |
+
x_start, y_start, x_end, y_end = int(x_start), int(y_start), int(x_end), int(y_end)
|
137 |
+
extented_box = ((x_start, y_start), (x_end, y_end))
|
138 |
+
return extented_box
|
139 |
+
|
140 |
+
|
141 |
+
# code modified from https://github.com/WonwoongCho/Generative-Inpainting-pytorch/blob/master/util.py
|
142 |
+
def spatial_discounting_mask(mask_width, mask_height, discounting_gamma):
|
143 |
+
"""Generate spatial discounting mask constant.
|
144 |
+
Spatial discounting mask is first introduced in publication:
|
145 |
+
Generative Image Inpainting with Contextual Attention, Yu et al.
|
146 |
+
Returns:
|
147 |
+
np.array: spatial discounting mask
|
148 |
+
"""
|
149 |
+
gamma = discounting_gamma
|
150 |
+
mask_values = np.ones((mask_width, mask_height), dtype=np.float32)
|
151 |
+
for i in range(mask_width):
|
152 |
+
for j in range(mask_height):
|
153 |
+
mask_values[i, j] = max(
|
154 |
+
gamma**min(i, mask_width - i),
|
155 |
+
gamma**min(j, mask_height - j))
|
156 |
+
|
157 |
+
return mask_values
|
158 |
+
|
159 |
+
|
160 |
+
def bboxes_to_discounting_loss_mask(img_size, bboxes, discounting_gamma=0.99):
|
161 |
+
mask = np.zeros(img_size, dtype=np.float32) + 0.5
|
162 |
+
for bbox in bboxes:
|
163 |
+
try:
|
164 |
+
(x1, y1), (x2, y2) = bbox
|
165 |
+
except Exception:
|
166 |
+
(x1, y1, x2, y2) = bbox
|
167 |
+
mask_width, mask_height = y2 - y1, x2 - x1
|
168 |
+
mask[y1:y2, x1:x2] = spatial_discounting_mask(mask_width, mask_height, discounting_gamma)
|
169 |
+
return mask
|
170 |
+
|
171 |
+
|
172 |
+
def find_proper_window(image_size, bbox_point):
|
173 |
+
'''
|
174 |
+
parameters:
|
175 |
+
image_size(2-tuple): (height, width)
|
176 |
+
bbox_point(2-2-tuple): (first_point, last_point)
|
177 |
+
return values:
|
178 |
+
window left-up point, (2-tuple)
|
179 |
+
window right-bottom point, (2-tuple)
|
180 |
+
'''
|
181 |
+
bbox_height = bbox_point[1][0] - bbox_point[0][0]
|
182 |
+
bbox_width = bbox_point[1][1] - bbox_point[0][1]
|
183 |
+
|
184 |
+
window_size = min(
|
185 |
+
max(bbox_height, bbox_width) * 2,
|
186 |
+
image_size[0], image_size[1]
|
187 |
+
)
|
188 |
+
# Limit min window size due to the requirement of VGG16
|
189 |
+
window_size = max(window_size, 32)
|
190 |
+
|
191 |
+
horizontal_span = window_size - (bbox_point[1][1] - bbox_point[0][1])
|
192 |
+
vertical_span = window_size - (bbox_point[1][0] - bbox_point[0][0])
|
193 |
+
|
194 |
+
top_bound, bottom_bound = bbox_point[0][0] - \
|
195 |
+
vertical_span // 2, bbox_point[1][0] + vertical_span // 2
|
196 |
+
left_bound, right_bound = bbox_point[0][1] - \
|
197 |
+
horizontal_span // 2, bbox_point[1][1] + horizontal_span // 2
|
198 |
+
|
199 |
+
if left_bound < 0:
|
200 |
+
right_bound += 0 - left_bound
|
201 |
+
left_bound += 0 - left_bound
|
202 |
+
elif right_bound > image_size[1]:
|
203 |
+
left_bound -= right_bound - image_size[1]
|
204 |
+
right_bound -= right_bound - image_size[1]
|
205 |
+
|
206 |
+
if top_bound < 0:
|
207 |
+
bottom_bound += 0 - top_bound
|
208 |
+
top_bound += 0 - top_bound
|
209 |
+
elif bottom_bound > image_size[0]:
|
210 |
+
top_bound -= bottom_bound - image_size[0]
|
211 |
+
bottom_bound -= bottom_bound - image_size[0]
|
212 |
+
|
213 |
+
return (top_bound, left_bound), (bottom_bound, right_bound)
|
214 |
+
|
215 |
+
|
216 |
+
def drawrect(drawcontext, xy, outline=None, width=0, partial=None):
|
217 |
+
(x1, y1), (x2, y2) = xy
|
218 |
+
if partial is None:
|
219 |
+
points = (x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)
|
220 |
+
drawcontext.line(points, fill=outline, width=width)
|
221 |
+
else:
|
222 |
+
drawcontext.line([(x1, y1), (x1, y1 + partial)], fill=outline, width=width)
|
223 |
+
drawcontext.line([(x1 + partial, y1), (x1, y1)], fill=outline, width=width)
|
224 |
+
|
225 |
+
drawcontext.line([(x2, y1), (x2, y1 + partial)], fill=outline, width=width)
|
226 |
+
drawcontext.line([(x2, y1), (x2 - partial, y1)], fill=outline, width=width)
|
227 |
+
|
228 |
+
drawcontext.line([(x1, y2), (x1 + partial, y2)], fill=outline, width=width)
|
229 |
+
drawcontext.line([(x1, y2), (x1, y2 - partial)], fill=outline, width=width)
|
230 |
+
|
231 |
+
drawcontext.line([(x2 - partial, y2), (x2, y2)], fill=outline, width=width)
|
232 |
+
drawcontext.line([(x2, y2), (x2, y2 - partial)], fill=outline, width=width)
|
233 |
+
|
234 |
+
|
235 |
+
def get_everything_under(root_dir, pattern='*', only_dirs=False, only_files=False):
|
236 |
+
assert not(only_dirs and only_files), 'You will get nothnig '\
|
237 |
+
'when "only_dirs" and "only_files" are both set to True'
|
238 |
+
everything = sorted(glob(os.path.join(root_dir, pattern)))
|
239 |
+
if only_dirs:
|
240 |
+
everything = [f for f in everything if os.path.isdir(f)]
|
241 |
+
if only_files:
|
242 |
+
everything = [f for f in everything if os.path.isfile(f)]
|
243 |
+
|
244 |
+
return everything
|
245 |
+
|
246 |
+
|
247 |
+
def read_filenames_from_dir(dir_name, reader, max_length=None):
|
248 |
+
logger.debug(
|
249 |
+
f"{reader} reading files from {dir_name}")
|
250 |
+
filenames = []
|
251 |
+
for root, dirs, files in os.walk(dir_name):
|
252 |
+
assert len(dirs) == 0, f"There are direcories: {dirs} in {root}"
|
253 |
+
assert len(files) != 0, f"There are no files in {root}"
|
254 |
+
filenames = [os.path.join(root, name) for name in sorted(files)]
|
255 |
+
for name in filenames:
|
256 |
+
logger.debug(name)
|
257 |
+
if max_length is not None:
|
258 |
+
return filenames[:max_length]
|
259 |
+
return filenames
|
FGT_codes/FGT/data/util/utils.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
def random_bbox(img_height, img_width, vertical_margin, horizontal_margin, mask_height, mask_width):
|
6 |
+
maxt = img_height - vertical_margin - mask_height
|
7 |
+
maxl = img_width - horizontal_margin - mask_width
|
8 |
+
|
9 |
+
t = random.randint(vertical_margin, maxt)
|
10 |
+
l = random.randint(horizontal_margin, maxl)
|
11 |
+
h = random.randint(mask_height // 2, mask_height)
|
12 |
+
w = random.randint(mask_width // 2, mask_width)
|
13 |
+
return (t, l, h, w) # 产生随机块状box,这个box后面会发展成为mask
|
14 |
+
|
15 |
+
|
16 |
+
def mid_bbox_mask(img_height, img_width, mask_height, mask_width):
|
17 |
+
def npmask(bbox, height, width):
|
18 |
+
mask = np.zeros((height, width, 1), np.float32)
|
19 |
+
mask[bbox[0]: bbox[0] + bbox[2], bbox[1]: bbox[1] + bbox[3], :] = 255.
|
20 |
+
return mask
|
21 |
+
|
22 |
+
bbox = (img_height * 3 // 8, img_width * 3 // 8, mask_height, mask_width)
|
23 |
+
mask = npmask(bbox, img_height, img_width)
|
24 |
+
|
25 |
+
return mask
|
26 |
+
|
27 |
+
|
28 |
+
def bbox2mask(img_height, img_width, max_delta_height, max_delta_width, bbox):
|
29 |
+
"""Generate mask tensor from bbox.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
bbox: configuration tuple, (top, left, height, width)
|
33 |
+
config: Config should have configuration including IMG_SHAPES,
|
34 |
+
MAX_DELTA_HEIGHT, MAX_DELTA_WIDTH.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
tf.Tensor: output with shape [B, 1, H, W]
|
38 |
+
|
39 |
+
"""
|
40 |
+
|
41 |
+
def npmask(bbox, height, width, delta_h, delta_w):
|
42 |
+
mask = np.zeros((height, width, 1), np.float32)
|
43 |
+
h = np.random.randint(delta_h // 2 + 1) # 防止有0产生
|
44 |
+
w = np.random.randint(delta_w // 2 + 1)
|
45 |
+
mask[bbox[0] + h: bbox[0] + bbox[2] - h, bbox[1] + w: bbox[1] + bbox[3] - w, :] = 255. # height_true = height - 2 * h, width_true = width - 2 * w
|
46 |
+
return mask
|
47 |
+
|
48 |
+
mask = npmask(bbox, img_height, img_width,
|
49 |
+
max_delta_height,
|
50 |
+
max_delta_width)
|
51 |
+
|
52 |
+
return mask
|
53 |
+
|
54 |
+
|
55 |
+
def matrix2bbox(img_height, img_width, mask_height, mask_width, row, column):
|
56 |
+
"""Generate masks with a matrix form
|
57 |
+
@param img_height
|
58 |
+
@param img_width
|
59 |
+
@param mask_height
|
60 |
+
@param mask_width
|
61 |
+
@param row: number of blocks in row
|
62 |
+
@param column: number of blocks in column
|
63 |
+
@return mbbox: multiple bboxes in (y, h, h, w) manner
|
64 |
+
"""
|
65 |
+
assert img_height - column * mask_height > img_height // 2, "Too many masks across a column"
|
66 |
+
assert img_width - row * mask_width > img_width // 2, "Too many masks across a row"
|
67 |
+
|
68 |
+
interval_height = (img_height - column * mask_height) // (column + 1)
|
69 |
+
interval_width = (img_width - row * mask_width) // (row + 1)
|
70 |
+
|
71 |
+
mbbox = []
|
72 |
+
for i in range(row):
|
73 |
+
for j in range(column):
|
74 |
+
y = interval_height * (j+1) + j * mask_height
|
75 |
+
x = interval_width * (i+1) + i * mask_width
|
76 |
+
mbbox.append((y, x, mask_height, mask_width))
|
77 |
+
return mbbox
|
78 |
+
|
79 |
+
|
80 |
+
def mbbox2masks(img_height, img_width, mbbox):
|
81 |
+
|
82 |
+
def npmask(mbbox, height, width):
|
83 |
+
mask = np.zeros((height, width, 1), np.float32)
|
84 |
+
for bbox in mbbox:
|
85 |
+
mask[bbox[0]: bbox[0] + bbox[2], bbox[1]: bbox[1] + bbox[3], :] = 255. # height_true = height - 2 * h, width_true = width - 2 * w
|
86 |
+
return mask
|
87 |
+
|
88 |
+
mask = npmask(mbbox, img_height, img_width)
|
89 |
+
|
90 |
+
return mask
|
91 |
+
|
92 |
+
|
93 |
+
def draw_line(mask, startX, startY, angle, length, brushWidth):
|
94 |
+
"""assume the size of mask is (H,W,1)
|
95 |
+
"""
|
96 |
+
assert len(mask.shape) == 2 or mask.shape[2] == 1, "The channel of mask doesn't fit the opencv format"
|
97 |
+
offsetX = int(np.round(length * np.cos(angle)))
|
98 |
+
offsetY = int(np.round(length * np.sin(angle)))
|
99 |
+
endX = startX + offsetX
|
100 |
+
endY = startY + offsetY
|
101 |
+
if endX > mask.shape[1]:
|
102 |
+
endX = mask.shape[1]
|
103 |
+
if endY > mask.shape[0]:
|
104 |
+
endY = mask.shape[0]
|
105 |
+
mask_processed = cv2.line(mask, (startX, startY), (endX, endY), 255, brushWidth)
|
106 |
+
return mask_processed, endX, endY
|
107 |
+
|
108 |
+
|
109 |
+
def draw_circle(mask, circle_x, circle_y, brushWidth):
|
110 |
+
radius = brushWidth // 2
|
111 |
+
assert len(mask.shape) == 2 or mask.shape[2] == 1, "The channel of mask doesn't fit the opencv format"
|
112 |
+
mask_processed = cv2.circle(mask, (circle_x, circle_y), radius, 255)
|
113 |
+
return mask_processed
|
114 |
+
|
115 |
+
|
116 |
+
def freeFormMask(img_height, img_width, maxVertex, maxLength, maxBrushWidth, maxAngle):
|
117 |
+
mask = np.zeros((img_height, img_width))
|
118 |
+
numVertex = random.randint(1, maxVertex)
|
119 |
+
startX = random.randint(10, img_width)
|
120 |
+
startY = random.randint(10, img_height)
|
121 |
+
brushWidth = random.randint(10, maxBrushWidth)
|
122 |
+
for i in range(numVertex):
|
123 |
+
angle = random.uniform(0, maxAngle)
|
124 |
+
if i % 2 == 0:
|
125 |
+
angle = 2 * np.pi - angle
|
126 |
+
length = random.randint(10, maxLength)
|
127 |
+
mask, endX, endY = draw_line(mask, startX, startY, angle, length, brushWidth)
|
128 |
+
startX = startX + int(length * np.sin(angle))
|
129 |
+
startY = startY + int(length * np.cos(angle))
|
130 |
+
mask = draw_circle(mask, endX, endY, brushWidth)
|
131 |
+
|
132 |
+
if random.random() < 0.5:
|
133 |
+
mask = np.fliplr(mask)
|
134 |
+
if random.random() < 0.5:
|
135 |
+
mask = np.flipud(mask)
|
136 |
+
|
137 |
+
if len(mask.shape) == 2:
|
138 |
+
mask = mask[:, :, np.newaxis]
|
139 |
+
|
140 |
+
return mask
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == "__main__":
|
144 |
+
# for stationary mask generation
|
145 |
+
# stationary_mask_generator(240, 480, 50, 120)
|
146 |
+
|
147 |
+
# for free-form mask generation
|
148 |
+
# mask = freeFormMask(240, 480, 30, 50, 20, np.pi)
|
149 |
+
# cv2.imwrite('mask.png', mask)
|
150 |
+
|
151 |
+
# for matrix mask generation
|
152 |
+
# img_height, img_width = 240, 480
|
153 |
+
# masks = matrix2bbox(240, 480, 20, 20, 5, 4)
|
154 |
+
# matrixMask = mbbox2masks(img_height, img_width, masks)
|
155 |
+
# cv2.imwrite('matrixMask.png', matrixMask)
|
156 |
+
pass
|
157 |
+
|
158 |
+
|
FGT_codes/FGT/flowCheckPoint/config.yaml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PASSMASK: 1
|
2 |
+
cnum: 48
|
3 |
+
conv_type: vanilla
|
4 |
+
flow_interval: 1
|
5 |
+
in_channel: 3
|
6 |
+
init_weights: 1
|
7 |
+
num_flows: 1
|
8 |
+
resBlocks: 1
|
9 |
+
use_bias: 1
|
10 |
+
use_residual: 1
|
11 |
+
model: lafc_single
|
FGT_codes/FGT/flowCheckPoint/lafc_single.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0fa168e8b711852c458594cddf4262afdb81e096253197a802a29b4dec9d6d12
|
3 |
+
size 11547053
|
FGT_codes/FGT/inputs.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
|
4 |
+
def args_parser():
|
5 |
+
parser = argparse.ArgumentParser(description="General top layer trainer")
|
6 |
+
parser.add_argument("--opt", type=str, default="config/train.yaml", help="Path to optional configuration file")
|
7 |
+
parser.add_argument('--model', type=str, default='model',
|
8 |
+
help='Model block name, in the `model` directory')
|
9 |
+
parser.add_argument('--name', type=str, default='FGT_train', help='Experiment name')
|
10 |
+
parser.add_argument('--outputdir', type=str, default='/myData/ret/experiments', help='Output dir to save results')
|
11 |
+
parser.add_argument('--datadir', type=str, default='/myData/', metavar='PATH')
|
12 |
+
parser.add_argument('--datasetName_train', type=str, default='train_dataset_frames_diffusedFlows',
|
13 |
+
help='The file name of the train dataset, in `data` directory')
|
14 |
+
parser.add_argument('--network', type=str, default='network',
|
15 |
+
help='The network file which defines the training process, in the `network` directory')
|
16 |
+
parser.add_argument('--finetune', type=int, default=0, help='Whether to fine tune trained models')
|
17 |
+
# parser.add_argument('--checkPoint', type=str, default='', help='checkpoint path for continue training')
|
18 |
+
parser.add_argument('--gen_state', type=str, default='', help='Checkpoint of the generator')
|
19 |
+
parser.add_argument('--dis_state', type=str, default='', help='Checkpoint of the discriminator')
|
20 |
+
parser.add_argument('--opt_state', type=str, default='', help='Checkpoint of the options')
|
21 |
+
parser.add_argument('--record_iter', type=int, default=16, help='How many iters to print an item of log')
|
22 |
+
parser.add_argument('--flow_checkPoint', type=str, default='flowCheckPoint/',
|
23 |
+
help='The path for flow model filling')
|
24 |
+
parser.add_argument('--dataMode', type=str, default='resize', choices=['resize', 'crop'])
|
25 |
+
|
26 |
+
# data related parameters
|
27 |
+
parser.add_argument('--flow2rgb', type=int, default=1, help='Whether to transform flows from raw data to rgb')
|
28 |
+
parser.add_argument('--flow_direction', type=str, default='for', choices=['for', 'back', 'bi'],
|
29 |
+
help='Which GT flow should be chosen for guidance')
|
30 |
+
parser.add_argument('--num_frames', type=int, default=5, help='How many frames are chosen for frame completion')
|
31 |
+
parser.add_argument('--sample', type=str, default='random', choices=['random', 'seq'],
|
32 |
+
help='Choose the sample method for training in each iterations')
|
33 |
+
parser.add_argument('--max_val', type=float, default=0.01, help='The maximal value to quantize the optical flows')
|
34 |
+
|
35 |
+
# model related parameters
|
36 |
+
parser.add_argument('--res_h', type=int, default=240, help='The height of the frame resolution')
|
37 |
+
parser.add_argument('--res_w', type=int, default=432, help='The width of the frame resolution')
|
38 |
+
parser.add_argument('--in_channel', type=int, default=4, help='The input channel of the frame branch')
|
39 |
+
parser.add_argument('--cnum', type=int, default=64, help='The initial channel number of the frame branch')
|
40 |
+
parser.add_argument('--flow_inChannel', type=int, default=2, help='The input channel of the flow branch')
|
41 |
+
parser.add_argument('--flow_cnum', type=int, default=64, help='The initial channel dimension of the flow branch')
|
42 |
+
parser.add_argument('--dist_cnum', type=int, default=32, help='The initial channel num in the discriminator')
|
43 |
+
parser.add_argument('--frame_hidden', type=int, default=512,
|
44 |
+
help='The channel / patch dimension in the frame branch')
|
45 |
+
parser.add_argument('--flow_hidden', type=int, default=256, help='The channel / patch dimension in the flow branch')
|
46 |
+
parser.add_argument('--PASSMASK', type=int, default=1,
|
47 |
+
help='1 -> concat the mask with the corrupted optical flows to fill the flow')
|
48 |
+
parser.add_argument('--numBlocks', type=int, default=8, help='How many transformer blocks do we need to stack')
|
49 |
+
parser.add_argument('--kernel_size_w', type=int, default=7, help='The width of the kernel for extracting patches')
|
50 |
+
parser.add_argument('--kernel_size_h', type=int, default=7, help='The height of the kernel for extracting patches')
|
51 |
+
parser.add_argument('--stride_h', type=int, default=3, help='The height of the stride')
|
52 |
+
parser.add_argument('--stride_w', type=int, default=3, help='The width of the stride')
|
53 |
+
parser.add_argument('--pad_h', type=int, default=3, help='The height of the padding')
|
54 |
+
parser.add_argument('--pad_w', type=int, default=3, help='The width of the padding')
|
55 |
+
parser.add_argument('--num_head', type=int, default=4, help='The head number for the multihead attention')
|
56 |
+
parser.add_argument('--conv_type', type=str, choices=['vanilla', 'gated', 'partial'], default='vanilla',
|
57 |
+
help='Which kind of conv to use')
|
58 |
+
parser.add_argument('--norm', type=str, default='None', choices=['None', 'BN', 'SN', 'IN'],
|
59 |
+
help='The normalization method for the conv blocks')
|
60 |
+
parser.add_argument('--use_bias', type=int, default=1, help='If 1, use bias in the convolution blocks')
|
61 |
+
parser.add_argument('--ape', type=int, default=1, help='If ape = 1, use absolute positional embedding')
|
62 |
+
parser.add_argument('--pos_mode', type=str, default='single', choices=['single', 'dual'],
|
63 |
+
help='If pos_mode = dual, add positional embedding to flow patches')
|
64 |
+
parser.add_argument('--mlp_ratio', type=int, default=40, help='The mlp dilation rate for the feed forward layers')
|
65 |
+
parser.add_argument('--drop', type=int, default=0, help='The dropout rate, 0 by default')
|
66 |
+
parser.add_argument('--init_weights', type=int, default=1, help='If 1, initialize the network, 1 by default')
|
67 |
+
|
68 |
+
# loss related parameters
|
69 |
+
parser.add_argument('--L1M', type=float, default=1, help='The weight of L1 loss in the masked area')
|
70 |
+
parser.add_argument('--L1V', type=float, default=1, help='The weight of L1 loss in the valid area')
|
71 |
+
parser.add_argument('--adv', type=float, default=0.01, help='The weight of adversarial loss')
|
72 |
+
|
73 |
+
# spatial and temporal related parameters
|
74 |
+
parser.add_argument('--tw', type=int, default=2, help='The number of temporal group in the temporal transformer')
|
75 |
+
parser.add_argument('--sw', type=int, default=8,
|
76 |
+
help='The number of spatial window size in the spatial transformer')
|
77 |
+
parser.add_argument('--gd', type=int, default=4, help='Global downsample rate for spatial transformer')
|
78 |
+
|
79 |
+
parser.add_argument('--ref_length', type=int, default=10, help='The sample interval during inference')
|
80 |
+
parser.add_argument('--use_valid', action='store_true')
|
81 |
+
|
82 |
+
args = parser.parse_args()
|
83 |
+
return args
|
FGT_codes/FGT/metrics/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from skimage.metrics import peak_signal_noise_ratio as psnr
|
3 |
+
from skimage.metrics import structural_similarity as ssim
|
4 |
+
import os
|
5 |
+
|
6 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
7 |
+
|
8 |
+
|
9 |
+
def calculate_metrics(results, gts):
|
10 |
+
B, H, W, C = results.shape
|
11 |
+
psnr_values, ssim_values, L1errors, L2errors = [], [], [], []
|
12 |
+
for i in range(B):
|
13 |
+
result = results[i]
|
14 |
+
gt = gts[i]
|
15 |
+
result_img = result
|
16 |
+
gt_img = gt
|
17 |
+
residual = result - gt
|
18 |
+
L1error = np.mean(np.abs(residual))
|
19 |
+
L2error = np.sum(residual ** 2) ** 0.5 / (H * W * C)
|
20 |
+
psnr_value = psnr(result_img, gt_img)
|
21 |
+
ssim_value = ssim(result_img, gt_img, multichannel=True)
|
22 |
+
L1errors.append(L1error)
|
23 |
+
L2errors.append(L2error)
|
24 |
+
psnr_values.append(psnr_value)
|
25 |
+
ssim_values.append(ssim_value)
|
26 |
+
L1_value = np.mean(L1errors)
|
27 |
+
L2_value = np.mean(L2errors)
|
28 |
+
psnr_value = np.mean(psnr_values)
|
29 |
+
ssim_value = np.mean(ssim_values)
|
30 |
+
|
31 |
+
return {'l1': L1_value, 'l2': L2_value, 'psnr': psnr_value, 'ssim': ssim_value}
|
FGT_codes/FGT/metrics/psnr.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import math
|
3 |
+
|
4 |
+
|
5 |
+
def psnr(img1, img2):
|
6 |
+
mse = numpy.mean( (img1 - img2) ** 2 )
|
7 |
+
if mse == 0:
|
8 |
+
return 100
|
9 |
+
PIXEL_MAX = 255.0
|
10 |
+
return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
|
FGT_codes/FGT/metrics/ssim.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def calculate_ssim(img1, img2):
|
6 |
+
C1 = (0.01 * 255)**2
|
7 |
+
C2 = (0.03 * 255)**2
|
8 |
+
|
9 |
+
img1 = img1.astype(np.float64)
|
10 |
+
img2 = img2.astype(np.float64)
|
11 |
+
kernel = cv2.getGaussianKernel(11, 1.5)
|
12 |
+
window = np.outer(kernel, kernel.transpose())
|
13 |
+
|
14 |
+
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
15 |
+
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
16 |
+
mu1_sq = mu1**2
|
17 |
+
mu2_sq = mu2**2
|
18 |
+
mu1_mu2 = mu1 * mu2
|
19 |
+
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
20 |
+
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
21 |
+
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
22 |
+
|
23 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
24 |
+
(sigma1_sq + sigma2_sq + C2))
|
25 |
+
return ssim_map.mean()
|
26 |
+
|
27 |
+
|
28 |
+
def ssim(img1, img2):
|
29 |
+
'''calculate SSIM
|
30 |
+
the same outputs as MATLAB's
|
31 |
+
img1, img2: [0, 255]
|
32 |
+
'''
|
33 |
+
if not img1.shape == img2.shape:
|
34 |
+
raise ValueError('Input images must have the same dimensions.')
|
35 |
+
if img1.ndim == 2:
|
36 |
+
return calculate_ssim(img1, img2)
|
37 |
+
elif img1.ndim == 3:
|
38 |
+
if img1.shape[2] == 3:
|
39 |
+
ssims = []
|
40 |
+
for i in range(3):
|
41 |
+
ssims.append(calculate_ssim(img1[:, :, i], img2[:, :, i]))
|
42 |
+
return np.array(ssims).mean()
|
43 |
+
elif img1.shape[2] == 1:
|
44 |
+
return calculate_ssim(np.squeeze(img1), np.squeeze(img2))
|
45 |
+
else:
|
46 |
+
raise ValueError('Wrong input image dimensions.')
|
FGT_codes/FGT/models/BaseNetwork.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils.network_blocks_2d import *
|
2 |
+
|
3 |
+
|
4 |
+
class BaseNetwork(nn.Module):
|
5 |
+
def __init__(self, conv_type):
|
6 |
+
super(BaseNetwork, self).__init__()
|
7 |
+
self.conv_type = conv_type
|
8 |
+
if conv_type == 'gated':
|
9 |
+
self.ConvBlock = GatedConv
|
10 |
+
self.DeconvBlock = GatedDeconv
|
11 |
+
if conv_type == 'partial':
|
12 |
+
self.ConvBlock = PartialConv
|
13 |
+
self.DeconvBlock = PartialDeconv
|
14 |
+
if conv_type == 'vanilla':
|
15 |
+
self.ConvBlock = VanillaConv
|
16 |
+
self.DeconvBlock = VanillaDeconv
|
17 |
+
self.ConvBlock2d = self.ConvBlock
|
18 |
+
self.DeconvBlock2d = self.DeconvBlock
|
19 |
+
|
20 |
+
def init_weights(self, init_type='normal', gain=0.02):
|
21 |
+
'''
|
22 |
+
initialize network's weights
|
23 |
+
init_type: normal | xavier | kaiming | orthogonal
|
24 |
+
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
|
25 |
+
'''
|
26 |
+
|
27 |
+
def init_func(m):
|
28 |
+
classname = m.__class__.__name__
|
29 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
30 |
+
if init_type == 'normal':
|
31 |
+
nn.init.normal_(m.weight.data, 0.0, gain)
|
32 |
+
elif init_type == 'xavier':
|
33 |
+
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
34 |
+
elif init_type == 'kaiming':
|
35 |
+
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
36 |
+
elif init_type == 'orthogonal':
|
37 |
+
nn.init.orthogonal_(m.weight.data, gain=gain)
|
38 |
+
|
39 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
40 |
+
nn.init.constant_(m.bias.data, 0.0)
|
41 |
+
|
42 |
+
elif classname.find('BatchNorm2d') != -1:
|
43 |
+
nn.init.normal_(m.weight.data, 1.0, gain)
|
44 |
+
nn.init.constant_(m.bias.data, 0.0)
|
45 |
+
|
46 |
+
self.apply(init_func)
|
FGT_codes/FGT/models/__init__.py
ADDED
File without changes
|
FGT_codes/FGT/models/__pycache__/BaseNetwork.cpython-39.pyc
ADDED
Binary file (1.97 kB). View file
|
|
FGT_codes/FGT/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (163 Bytes). View file
|
|
FGT_codes/FGT/models/__pycache__/model.cpython-39.pyc
ADDED
Binary file (10.3 kB). View file
|
|
FGT_codes/FGT/models/lafc_single.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
import functools
|
5 |
+
from .BaseNetwork import BaseNetwork
|
6 |
+
from models.utils.reconstructionLayers import make_layer, ResidualBlock_noBN
|
7 |
+
|
8 |
+
|
9 |
+
class Model(nn.Module):
|
10 |
+
def __init__(self, config):
|
11 |
+
super(Model, self).__init__()
|
12 |
+
self.net = P3DNet(config['num_flows'], config['cnum'], config['in_channel'], config['PASSMASK'],
|
13 |
+
config['use_residual'],
|
14 |
+
config['resBlocks'], config['use_bias'], config['conv_type'], config['init_weights'])
|
15 |
+
|
16 |
+
def forward(self, flows, masks, edges=None):
|
17 |
+
ret = self.net(flows, masks, edges)
|
18 |
+
return ret
|
19 |
+
|
20 |
+
|
21 |
+
class P3DNet(BaseNetwork):
|
22 |
+
def __init__(self, num_flows, num_feats, in_channels, passmask, use_residual, res_blocks,
|
23 |
+
use_bias, conv_type, init_weights):
|
24 |
+
super().__init__(conv_type)
|
25 |
+
self.passmask = passmask
|
26 |
+
self.encoder2 = nn.Sequential(
|
27 |
+
nn.ReplicationPad2d(2),
|
28 |
+
self.ConvBlock2d(in_channels, num_feats, kernel_size=5, stride=1, padding=0, bias=use_bias, norm=None),
|
29 |
+
self.ConvBlock2d(num_feats, num_feats * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=None)
|
30 |
+
)
|
31 |
+
self.encoder4 = nn.Sequential(
|
32 |
+
self.ConvBlock2d(num_feats * 2, num_feats * 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
|
33 |
+
norm=None),
|
34 |
+
self.ConvBlock2d(num_feats * 2, num_feats * 4, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=None)
|
35 |
+
)
|
36 |
+
residualBlock = functools.partial(ResidualBlock_noBN, nf=num_feats * 4)
|
37 |
+
self.res_blocks = make_layer(residualBlock, res_blocks)
|
38 |
+
self.resNums = res_blocks
|
39 |
+
# dilation convolution to enlarge the receptive field
|
40 |
+
self.middle = nn.Sequential(
|
41 |
+
self.ConvBlock2d(num_feats * 4, num_feats * 4, kernel_size=3, stride=1, padding=8, bias=use_bias,
|
42 |
+
dilation=8, norm=None),
|
43 |
+
self.ConvBlock2d(num_feats * 4, num_feats * 4, kernel_size=3, stride=1, padding=4, bias=use_bias,
|
44 |
+
dilation=4, norm=None),
|
45 |
+
self.ConvBlock2d(num_feats * 4, num_feats * 4, kernel_size=3, stride=1, padding=2, bias=use_bias,
|
46 |
+
dilation=2, norm=None),
|
47 |
+
self.ConvBlock2d(num_feats * 4, num_feats * 4, kernel_size=3, stride=1, padding=1, bias=use_bias,
|
48 |
+
dilation=1, norm=None),
|
49 |
+
)
|
50 |
+
self.decoder2 = nn.Sequential(
|
51 |
+
self.DeconvBlock2d(num_feats * 8, num_feats * 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
|
52 |
+
norm=None),
|
53 |
+
self.ConvBlock2d(num_feats * 2, num_feats * 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
|
54 |
+
norm=None),
|
55 |
+
self.ConvBlock2d(num_feats * 2, num_feats * 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
|
56 |
+
norm=None)
|
57 |
+
)
|
58 |
+
self.decoder = nn.Sequential(
|
59 |
+
self.DeconvBlock2d(num_feats * 4, num_feats, kernel_size=3, stride=1, padding=1, bias=use_bias,
|
60 |
+
norm=None),
|
61 |
+
self.ConvBlock2d(num_feats, num_feats // 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
|
62 |
+
norm=None),
|
63 |
+
self.ConvBlock2d(num_feats // 2, 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
|
64 |
+
norm=None)
|
65 |
+
)
|
66 |
+
self.edgeDetector = EdgeDetection(conv_type)
|
67 |
+
if init_weights:
|
68 |
+
self.init_weights()
|
69 |
+
|
70 |
+
def forward(self, flows, masks, edges=None):
|
71 |
+
if self.passmask:
|
72 |
+
inputs = torch.cat((flows, masks), dim=1)
|
73 |
+
else:
|
74 |
+
inputs = flows
|
75 |
+
if edges is not None:
|
76 |
+
inputs = torch.cat((inputs, edges), dim=1)
|
77 |
+
e2 = self.encoder2(inputs)
|
78 |
+
e4 = self.encoder4(e2)
|
79 |
+
if self.resNums > 0:
|
80 |
+
e4_res = self.res_blocks(e4)
|
81 |
+
else:
|
82 |
+
e4_res = e4
|
83 |
+
c_e4_filled = self.middle(e4_res)
|
84 |
+
c_e4 = torch.cat((c_e4_filled, e4), dim=1)
|
85 |
+
c_e2Post = self.decoder2(c_e4)
|
86 |
+
c_e2 = torch.cat((c_e2Post, e2), dim=1)
|
87 |
+
output = self.decoder(c_e2)
|
88 |
+
edge = self.edgeDetector(output)
|
89 |
+
return output, edge
|
90 |
+
|
91 |
+
|
92 |
+
class EdgeDetection(BaseNetwork):
|
93 |
+
def __init__(self, conv_type, in_channels=2, out_channels=1, mid_channels=16):
|
94 |
+
super(EdgeDetection, self).__init__(conv_type)
|
95 |
+
self.projection = self.ConvBlock2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=1,
|
96 |
+
padding=1, norm=None)
|
97 |
+
self.mid_layer_1 = self.ConvBlock2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=3,
|
98 |
+
stride=1, padding=1, norm=None)
|
99 |
+
self.mid_layer_2 = self.ConvBlock2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=3,
|
100 |
+
stride=1, padding=1, activation=None, norm=None)
|
101 |
+
self.l_relu = nn.LeakyReLU()
|
102 |
+
self.out_layer = self.ConvBlock2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1,
|
103 |
+
activation=None, norm=None)
|
104 |
+
|
105 |
+
def forward(self, flow):
|
106 |
+
flow = self.projection(flow)
|
107 |
+
edge = self.mid_layer_1(flow)
|
108 |
+
edge = self.mid_layer_2(edge)
|
109 |
+
edge = self.l_relu(flow + edge)
|
110 |
+
edge = self.out_layer(edge)
|
111 |
+
edge = torch.sigmoid(edge)
|
112 |
+
return edge
|
113 |
+
|
114 |
+
|
FGT_codes/FGT/models/model.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.BaseNetwork import BaseNetwork
|
2 |
+
from models.transformer_base.ffn_base import FusionFeedForward
|
3 |
+
from models.transformer_base.attention_flow import SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow
|
4 |
+
from models.transformer_base.attention_base import TMHSA
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from functools import reduce
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
class Model(nn.Module):
|
13 |
+
def __init__(self, config):
|
14 |
+
super(Model, self).__init__()
|
15 |
+
self.net = FGT(config['tw'], config['sw'], config['gd'], config['input_resolution'], config['in_channel'],
|
16 |
+
config['cnum'], config['flow_inChannel'], config['flow_cnum'], config['frame_hidden'],
|
17 |
+
config['flow_hidden'], config['PASSMASK'],
|
18 |
+
config['numBlocks'], config['kernel_size'], config['stride'], config['padding'],
|
19 |
+
config['num_head'], config['conv_type'], config['norm'],
|
20 |
+
config['use_bias'], config['ape'],
|
21 |
+
config['mlp_ratio'], config['drop'], config['init_weights'])
|
22 |
+
|
23 |
+
def forward(self, frames, flows, masks):
|
24 |
+
ret = self.net(frames, flows, masks)
|
25 |
+
return ret
|
26 |
+
|
27 |
+
|
28 |
+
class Encoder(nn.Module):
|
29 |
+
def __init__(self, in_channels):
|
30 |
+
super(Encoder, self).__init__()
|
31 |
+
self.group = [1, 2, 4, 8, 1]
|
32 |
+
self.layers = nn.ModuleList([
|
33 |
+
nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
|
34 |
+
nn.LeakyReLU(0.2, inplace=True),
|
35 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
36 |
+
nn.LeakyReLU(0.2, inplace=True),
|
37 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
38 |
+
nn.LeakyReLU(0.2, inplace=True),
|
39 |
+
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
|
40 |
+
nn.LeakyReLU(0.2, inplace=True),
|
41 |
+
nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
|
42 |
+
nn.LeakyReLU(0.2, inplace=True),
|
43 |
+
nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
|
44 |
+
nn.LeakyReLU(0.2, inplace=True),
|
45 |
+
nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
|
46 |
+
nn.LeakyReLU(0.2, inplace=True),
|
47 |
+
nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
|
48 |
+
nn.LeakyReLU(0.2, inplace=True),
|
49 |
+
nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
|
50 |
+
nn.LeakyReLU(0.2, inplace=True)
|
51 |
+
])
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
bt, c, h, w = x.size()
|
55 |
+
h, w = h // 4, w // 4
|
56 |
+
out = x
|
57 |
+
for i, layer in enumerate(self.layers):
|
58 |
+
if i == 8:
|
59 |
+
x0 = out
|
60 |
+
if i > 8 and i % 2 == 0:
|
61 |
+
g = self.group[(i - 8) // 2]
|
62 |
+
x = x0.view(bt, g, -1, h, w)
|
63 |
+
o = out.view(bt, g, -1, h, w)
|
64 |
+
out = torch.cat([x, o], 2).view(bt, -1, h, w)
|
65 |
+
out = layer(out)
|
66 |
+
return out
|
67 |
+
|
68 |
+
|
69 |
+
class AddPosEmb(nn.Module):
|
70 |
+
def __init__(self, h, w, in_channels, out_channels):
|
71 |
+
super(AddPosEmb, self).__init__()
|
72 |
+
self.proj = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True, groups=out_channels)
|
73 |
+
self.h, self.w = h, w
|
74 |
+
|
75 |
+
def forward(self, x, h=0, w=0):
|
76 |
+
B, N, C = x.shape
|
77 |
+
if h == 0 and w == 0:
|
78 |
+
assert N == self.h * self.w, 'Wrong input size'
|
79 |
+
else:
|
80 |
+
assert N == h * w, 'Wrong input size during inference'
|
81 |
+
feat_token = x
|
82 |
+
if h == 0 and w == 0:
|
83 |
+
cnn_feat = feat_token.transpose(1, 2).view(B, C, self.h, self.w)
|
84 |
+
else:
|
85 |
+
cnn_feat = feat_token.transpose(1, 2).view(B, C, h, w)
|
86 |
+
x = self.proj(cnn_feat) + cnn_feat
|
87 |
+
x = x.flatten(2).transpose(1, 2)
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
class Vec2Patch(nn.Module):
|
92 |
+
def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
|
93 |
+
super(Vec2Patch, self).__init__()
|
94 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
95 |
+
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
|
96 |
+
self.embedding = nn.Linear(hidden, c_out)
|
97 |
+
self.restore = nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)
|
98 |
+
self.kernel_size = kernel_size
|
99 |
+
self.stride = stride
|
100 |
+
self.padding = padding
|
101 |
+
|
102 |
+
def forward(self, x, output_h=0, output_w=0):
|
103 |
+
feat = self.embedding(x)
|
104 |
+
feat = feat.permute(0, 2, 1)
|
105 |
+
if output_h != 0 or output_w != 0:
|
106 |
+
feat = F.fold(feat, output_size=(output_h, output_w), kernel_size=self.kernel_size, stride=self.stride,
|
107 |
+
padding=self.padding)
|
108 |
+
else:
|
109 |
+
feat = self.restore(feat)
|
110 |
+
return feat
|
111 |
+
|
112 |
+
|
113 |
+
class TemporalTransformer(nn.Module):
|
114 |
+
def __init__(self, token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio, dropout, n_vecs,
|
115 |
+
t2t_params):
|
116 |
+
super(TemporalTransformer, self).__init__()
|
117 |
+
self.attention = TMHSA(token_size=token_size, group_size=t_groupSize, d_model=frame_hidden, head=num_heads,
|
118 |
+
p=dropout)
|
119 |
+
self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
|
120 |
+
self.norm1 = nn.LayerNorm(frame_hidden)
|
121 |
+
self.norm2 = nn.LayerNorm(frame_hidden)
|
122 |
+
self.dropout = nn.Dropout(p=dropout)
|
123 |
+
|
124 |
+
def forward(self, x, t, h, w, output_size):
|
125 |
+
token_size = h * w
|
126 |
+
s = self.norm1(x)
|
127 |
+
x = x + self.dropout(self.attention(s, t, h, w))
|
128 |
+
y = self.norm2(x)
|
129 |
+
x = x + self.ffn(y, token_size, output_size[0], output_size[1])
|
130 |
+
return x
|
131 |
+
|
132 |
+
|
133 |
+
class SpatialTransformer(nn.Module):
|
134 |
+
def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, s_windowSize, g_downSize, mlp_ratio,
|
135 |
+
dropout, n_vecs, t2t_params):
|
136 |
+
super(SpatialTransformer, self).__init__()
|
137 |
+
self.attention = SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(token_size=token_size, window_size=s_windowSize,
|
138 |
+
kernel_size=g_downSize, d_model=frame_hidden,
|
139 |
+
flow_dModel=flow_hidden, head=num_heads, p=dropout)
|
140 |
+
self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
|
141 |
+
self.norm = nn.LayerNorm(frame_hidden)
|
142 |
+
self.dropout = nn.Dropout(p=dropout)
|
143 |
+
|
144 |
+
def forward(self, x, f, t, h, w, output_size):
|
145 |
+
token_size = h * w
|
146 |
+
x = x + self.dropout(self.attention(x, f, t, h, w))
|
147 |
+
y = self.norm(x)
|
148 |
+
x = x + self.ffn(y, token_size, output_size[0], output_size[1])
|
149 |
+
return x
|
150 |
+
|
151 |
+
|
152 |
+
class TransformerBlock(nn.Module):
|
153 |
+
def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize, g_downSize,
|
154 |
+
mlp_ratio,
|
155 |
+
dropout, n_vecs,
|
156 |
+
t2t_params):
|
157 |
+
super(TransformerBlock, self).__init__()
|
158 |
+
self.t_transformer = TemporalTransformer(token_size=token_size, frame_hidden=frame_hidden, num_heads=num_heads,
|
159 |
+
t_groupSize=t_groupSize, mlp_ratio=mlp_ratio,
|
160 |
+
dropout=dropout, n_vecs=n_vecs,
|
161 |
+
t2t_params=t2t_params) # temporal multi-head self attention
|
162 |
+
self.s_transformer = SpatialTransformer(token_size=token_size, frame_hidden=frame_hidden,
|
163 |
+
flow_hidden=flow_hidden, num_heads=num_heads, s_windowSize=s_windowSize,
|
164 |
+
g_downSize=g_downSize, mlp_ratio=mlp_ratio,
|
165 |
+
dropout=dropout, n_vecs=n_vecs, t2t_params=t2t_params)
|
166 |
+
|
167 |
+
def forward(self, inputs):
|
168 |
+
x, f, t = inputs['x'], inputs['f'], inputs['t']
|
169 |
+
h, w = inputs['h'], inputs['w']
|
170 |
+
output_size = inputs['output_size']
|
171 |
+
x = self.t_transformer(x, t, h, w, output_size)
|
172 |
+
x = self.s_transformer(x, f, t, h, w, output_size)
|
173 |
+
return {'x': x, 'f': f, 't': t, 'h': h, 'w': w, 'output_size': output_size}
|
174 |
+
|
175 |
+
|
176 |
+
class Decoder(BaseNetwork):
|
177 |
+
def __init__(self, conv_type, in_channels, out_channels, use_bias, norm=None):
|
178 |
+
super(Decoder, self).__init__(conv_type)
|
179 |
+
self.layer1 = self.DeconvBlock(in_channels, in_channels, kernel_size=3, padding=1, norm=norm,
|
180 |
+
bias=use_bias)
|
181 |
+
self.layer2 = self.ConvBlock(in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1, norm=norm,
|
182 |
+
bias=use_bias)
|
183 |
+
self.layer3 = self.DeconvBlock(in_channels // 2, in_channels // 2, kernel_size=3, padding=1, norm=norm,
|
184 |
+
bias=use_bias)
|
185 |
+
self.final = self.ConvBlock(in_channels // 2, out_channels, kernel_size=3, stride=1, padding=1, norm=norm,
|
186 |
+
bias=use_bias, activation=None)
|
187 |
+
|
188 |
+
def forward(self, features):
|
189 |
+
feat1 = self.layer1(features)
|
190 |
+
feat2 = self.layer2(feat1)
|
191 |
+
feat3 = self.layer3(feat2)
|
192 |
+
output = self.final(feat3)
|
193 |
+
return output
|
194 |
+
|
195 |
+
|
196 |
+
class FGT(BaseNetwork):
|
197 |
+
def __init__(self, t_groupSize, s_windowSize, g_downSize, input_resolution, in_channels, cnum, flow_inChannel,
|
198 |
+
flow_cnum,
|
199 |
+
frame_hidden, flow_hidden, passmask, numBlocks, kernel_size, stride, padding, num_heads, conv_type,
|
200 |
+
norm, use_bias, ape, mlp_ratio=4, drop=0, init_weights=True):
|
201 |
+
super(FGT, self).__init__(conv_type)
|
202 |
+
self.in_channels = in_channels
|
203 |
+
self.passmask = passmask
|
204 |
+
self.ape = ape
|
205 |
+
self.frame_endoder = Encoder(in_channels)
|
206 |
+
self.flow_encoder = nn.Sequential(
|
207 |
+
nn.ReplicationPad2d(2),
|
208 |
+
self.ConvBlock(flow_inChannel, flow_cnum, kernel_size=5, stride=1, padding=0, bias=use_bias, norm=norm),
|
209 |
+
self.ConvBlock(flow_cnum, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm),
|
210 |
+
self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=norm),
|
211 |
+
self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm)
|
212 |
+
)
|
213 |
+
# patch to vector operation
|
214 |
+
self.patch2vec = nn.Conv2d(cnum * 2, frame_hidden, kernel_size=kernel_size, stride=stride, padding=padding)
|
215 |
+
self.f_patch2vec = nn.Conv2d(flow_cnum * 2, flow_hidden, kernel_size=kernel_size, stride=stride,
|
216 |
+
padding=padding)
|
217 |
+
# initialize transformer blocks for frame completion
|
218 |
+
n_vecs = 1
|
219 |
+
token_size = []
|
220 |
+
output_shape = (input_resolution[0] // 4, input_resolution[1] // 4)
|
221 |
+
for i, d in enumerate(kernel_size):
|
222 |
+
token_nums = int((output_shape[i] + 2 * padding[i] - kernel_size[i]) / stride[i] + 1)
|
223 |
+
n_vecs *= token_nums
|
224 |
+
token_size.append(token_nums)
|
225 |
+
# Add positional embedding to the encode features
|
226 |
+
if self.ape:
|
227 |
+
self.add_pos_emb = AddPosEmb(token_size[0], token_size[1], frame_hidden, frame_hidden)
|
228 |
+
self.token_size = token_size
|
229 |
+
# initialize transformer blocks
|
230 |
+
blocks = []
|
231 |
+
t2t_params = {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_shape}
|
232 |
+
for i in range(numBlocks // 2 - 1):
|
233 |
+
layer = TransformerBlock(token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize,
|
234 |
+
g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
|
235 |
+
blocks.append(layer)
|
236 |
+
self.first_t_transformer = TemporalTransformer(token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio,
|
237 |
+
drop, n_vecs, t2t_params)
|
238 |
+
self.first_s_transformer = SpatialTransformer(token_size, frame_hidden, flow_hidden, num_heads, s_windowSize,
|
239 |
+
g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
|
240 |
+
self.transformer = nn.Sequential(*blocks)
|
241 |
+
# vector to patch operation
|
242 |
+
self.vec2patch = Vec2Patch(cnum * 2, frame_hidden, output_shape, kernel_size, stride, padding)
|
243 |
+
# decoder
|
244 |
+
self.decoder = Decoder(conv_type, cnum * 2, 3, use_bias, norm)
|
245 |
+
|
246 |
+
if init_weights:
|
247 |
+
self.init_weights()
|
248 |
+
|
249 |
+
def forward(self, masked_frames, flows, masks):
|
250 |
+
b, t, c, h, w = masked_frames.shape
|
251 |
+
cf = flows.shape[2]
|
252 |
+
output_shape = (h // 4, w // 4)
|
253 |
+
if self.passmask:
|
254 |
+
inputs = torch.cat((masked_frames, masks), dim=2)
|
255 |
+
else:
|
256 |
+
inputs = masked_frames
|
257 |
+
inputs = inputs.view(b * t, self.in_channels, h, w)
|
258 |
+
flows = flows.view(b * t, cf, h, w)
|
259 |
+
enc_feats = self.frame_endoder(inputs)
|
260 |
+
flow_feats = self.flow_encoder(flows)
|
261 |
+
trans_feat = self.patch2vec(enc_feats)
|
262 |
+
flow_patches = self.f_patch2vec(flow_feats)
|
263 |
+
_, c, h, w = trans_feat.shape
|
264 |
+
cf = flow_patches.shape[1]
|
265 |
+
if h != self.token_size[0] or w != self.token_size[1]:
|
266 |
+
new_h, new_w = h, w
|
267 |
+
else:
|
268 |
+
new_h, new_w = 0, 0
|
269 |
+
output_shape = (0, 0)
|
270 |
+
trans_feat = trans_feat.view(b * t, c, -1).permute(0, 2, 1)
|
271 |
+
flow_patches = flow_patches.view(b * t, cf, -1).permute(0, 2, 1)
|
272 |
+
trans_feat = self.first_t_transformer(trans_feat, t, new_h, new_w, output_shape)
|
273 |
+
trans_feat = self.add_pos_emb(trans_feat, new_h, new_w)
|
274 |
+
trans_feat = self.first_s_transformer(trans_feat, flow_patches, t, new_h, new_w, output_shape)
|
275 |
+
inputs_trans_feat = {'x': trans_feat, 'f': flow_patches, 't': t, 'h': new_h, 'w': new_w,
|
276 |
+
'output_size': output_shape}
|
277 |
+
trans_feat = self.transformer(inputs_trans_feat)['x']
|
278 |
+
trans_feat = self.vec2patch(trans_feat, output_shape[0], output_shape[1])
|
279 |
+
enc_feats = enc_feats + trans_feat
|
280 |
+
|
281 |
+
output = self.decoder(enc_feats)
|
282 |
+
output = torch.tanh(output)
|
283 |
+
return output
|
284 |
+
|
FGT_codes/FGT/models/temporal_patch_gan.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# temporal patch GAN to maintain the temporal consecutive of the flows
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from .BaseNetwork import BaseNetwork
|
5 |
+
|
6 |
+
|
7 |
+
class Discriminator(BaseNetwork):
|
8 |
+
def __init__(self, in_channels, conv_type, dist_cnum, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
|
9 |
+
"""
|
10 |
+
|
11 |
+
Args:
|
12 |
+
in_channels: The input channels of the discriminator
|
13 |
+
use_sigmoid: Whether to use sigmoid for the base network (true for the nsgan)
|
14 |
+
use_spectral_norm: The usage of the spectral norm: always be true for the stability of GAN
|
15 |
+
init_weights: always be True
|
16 |
+
"""
|
17 |
+
super(Discriminator, self).__init__(conv_type)
|
18 |
+
self.use_sigmoid = use_sigmoid
|
19 |
+
nf = dist_cnum
|
20 |
+
|
21 |
+
self.conv = nn.Sequential(
|
22 |
+
spectral_norm(
|
23 |
+
nn.Conv3d(in_channels=in_channels, out_channels=nf * 1, kernel_size=(3, 5, 5), stride=(1, 2, 2),
|
24 |
+
padding=(1, 2, 2),
|
25 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
26 |
+
nn.LeakyReLU(0.2, inplace=True),
|
27 |
+
spectral_norm(
|
28 |
+
nn.Conv3d(in_channels=nf * 1, out_channels=nf * 2, kernel_size=(3, 5, 5), stride=(1, 2, 2),
|
29 |
+
padding=(1, 2, 2),
|
30 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
31 |
+
nn.LeakyReLU(0.2, inplace=True),
|
32 |
+
spectral_norm(
|
33 |
+
nn.Conv3d(in_channels=nf * 2, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
|
34 |
+
padding=(1, 2, 2),
|
35 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
36 |
+
nn.LeakyReLU(0.2, inplace=True),
|
37 |
+
spectral_norm(
|
38 |
+
nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
|
39 |
+
padding=(1, 2, 2),
|
40 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
41 |
+
nn.LeakyReLU(0.2, inplace=True),
|
42 |
+
spectral_norm(
|
43 |
+
nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
|
44 |
+
padding=(1, 2, 2),
|
45 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
46 |
+
nn.LeakyReLU(0.2, inplace=True),
|
47 |
+
nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
|
48 |
+
padding=(1, 2, 2))
|
49 |
+
)
|
50 |
+
|
51 |
+
if init_weights:
|
52 |
+
self.init_weights()
|
53 |
+
|
54 |
+
def forward(self, xs, t):
|
55 |
+
"""
|
56 |
+
|
57 |
+
Args:
|
58 |
+
xs: Input feature, with shape of [bt, c, h, w]
|
59 |
+
|
60 |
+
Returns: The discriminative map from the GAN
|
61 |
+
|
62 |
+
"""
|
63 |
+
bt, c, h, w = xs.shape
|
64 |
+
b = bt // t
|
65 |
+
xs = xs.view(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous()
|
66 |
+
feat = self.conv(xs)
|
67 |
+
if self.use_sigmoid:
|
68 |
+
feat = torch.sigmoid(feat)
|
69 |
+
out = torch.transpose(feat, 1, 2) # [b, t, c, h, w]
|
70 |
+
return out
|
71 |
+
|
72 |
+
|
73 |
+
def spectral_norm(module, mode=True):
|
74 |
+
if mode:
|
75 |
+
return nn.utils.spectral_norm(module)
|
76 |
+
return module
|
FGT_codes/FGT/models/transformer_base/__init__.py
ADDED
File without changes
|
FGT_codes/FGT/models/transformer_base/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (180 Bytes). View file
|
|
FGT_codes/FGT/models/transformer_base/__pycache__/attention_base.cpython-39.pyc
ADDED
Binary file (4.1 kB). View file
|
|
FGT_codes/FGT/models/transformer_base/__pycache__/attention_flow.cpython-39.pyc
ADDED
Binary file (5.51 kB). View file
|
|
FGT_codes/FGT/models/transformer_base/__pycache__/ffn_base.cpython-39.pyc
ADDED
Binary file (4.11 kB). View file
|
|
FGT_codes/FGT/models/transformer_base/attention_base.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class Attention(nn.Module):
|
8 |
+
"""
|
9 |
+
Compute 'Scaled Dot Product Attention
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, p=0.1):
|
13 |
+
super(Attention, self).__init__()
|
14 |
+
self.dropout = nn.Dropout(p=p)
|
15 |
+
|
16 |
+
def forward(self, query, key, value):
|
17 |
+
scores = torch.matmul(query, key.transpose(-2, -1)
|
18 |
+
) / math.sqrt(query.size(-1))
|
19 |
+
p_attn = F.softmax(scores, dim=-1)
|
20 |
+
p_attn = self.dropout(p_attn)
|
21 |
+
p_val = torch.matmul(p_attn, value)
|
22 |
+
return p_val, p_attn
|
23 |
+
|
24 |
+
|
25 |
+
class TMHSA(nn.Module):
|
26 |
+
def __init__(self, token_size, group_size, d_model, head, p=0.1):
|
27 |
+
super(TMHSA, self).__init__()
|
28 |
+
self.h, self.w = token_size
|
29 |
+
self.group_size = group_size # 这里的group size表示可分的组
|
30 |
+
self.wh, self.ww = math.ceil(self.h / self.group_size), math.ceil(self.w / self.group_size)
|
31 |
+
self.pad_r = (self.ww - self.w % self.ww) % self.ww
|
32 |
+
self.pad_b = (self.wh - self.h % self.wh) % self.wh
|
33 |
+
self.new_h, self.new_w = self.h + self.pad_b, self.w + self.pad_r # 只在右侧和下侧进行padding,另一侧不padding,实现起来更加容易
|
34 |
+
self.window_h, self.window_w = self.new_h // self.group_size, self.new_w // self.group_size # 这里面的group表示的是窗口大小,而window_size表示的是group大小(与spatial的定义不同)
|
35 |
+
self.d_model = d_model
|
36 |
+
self.p = p
|
37 |
+
self.query_embedding = nn.Linear(d_model, d_model)
|
38 |
+
self.key_embedding = nn.Linear(d_model, d_model)
|
39 |
+
self.value_embedding = nn.Linear(d_model, d_model)
|
40 |
+
self.output_linear = nn.Linear(d_model, d_model)
|
41 |
+
self.attention = Attention(p=p)
|
42 |
+
self.head = head
|
43 |
+
|
44 |
+
def inference(self, x, t, h, w):
|
45 |
+
# calculate the attention related parameters
|
46 |
+
wh, ww = math.ceil(h / self.group_size), math.ceil(w / self.group_size)
|
47 |
+
pad_r = (ww - w % ww) % ww
|
48 |
+
pad_b = (wh - h % wh) % wh
|
49 |
+
new_h, new_w = h + pad_b, w + pad_r
|
50 |
+
window_h, window_w = new_h // self.group_size, new_w // self.group_size
|
51 |
+
bt, n, c = x.shape
|
52 |
+
b = bt // t
|
53 |
+
c_h = c // self.head
|
54 |
+
x = x.view(bt, h, w, c)
|
55 |
+
if pad_r > 0 or pad_b > 0:
|
56 |
+
x = F.pad(x,
|
57 |
+
(0, 0, 0, pad_r, 0, pad_b)) # channel, channel, left, right, top, bottom -> [bt, new_h, new_w, c]
|
58 |
+
query = self.query_embedding(x)
|
59 |
+
key = self.key_embedding(x)
|
60 |
+
value = self.value_embedding(x)
|
61 |
+
query = query.view(b, t, self.group_size, window_h, self.group_size, window_w, self.head, c_h)
|
62 |
+
query = query.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
|
63 |
+
key = key.view(b, t, self.group_size, window_h, self.group_size, window_w, self.head, c_h)
|
64 |
+
key = key.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
|
65 |
+
value = value.view(b, t, self.group_size, window_h, self.group_size, window_w, self.head, c_h)
|
66 |
+
value = value.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
|
67 |
+
att, _ = self.attention(query, key, value)
|
68 |
+
att = att.view(b, self.group_size, self.group_size, self.head, t, window_h, window_w, c_h)
|
69 |
+
att = att.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(bt, new_h, new_w, c)
|
70 |
+
if pad_b > 0 or pad_r > 0:
|
71 |
+
att = att[:, :h, :w, :]
|
72 |
+
att = att.reshape(bt, n, c)
|
73 |
+
output = self.output_linear(att)
|
74 |
+
return output
|
75 |
+
|
76 |
+
def forward(self, x, t, h=0, w=0):
|
77 |
+
bt, n, c = x.shape
|
78 |
+
if h == 0 and w == 0:
|
79 |
+
assert n == self.h * self.w, 'Wrong input shape: {} with token: h->{}, w->{}'.format(x.shape, self.h,
|
80 |
+
self.w)
|
81 |
+
else:
|
82 |
+
assert n == h * w, 'Wrong input shape: {} with token: h->{}, w->{}'.format(x.shape, h, w)
|
83 |
+
return self.inference(x, t, h, w)
|
84 |
+
b = bt // t
|
85 |
+
c_h = c // self.head
|
86 |
+
x = x.view(bt, self.h, self.w, c)
|
87 |
+
if self.pad_r > 0 or self.pad_b > 0:
|
88 |
+
x = F.pad(x, (
|
89 |
+
0, 0, 0, self.pad_r, 0, self.pad_b)) # channel, channel, left, right, top, bottom -> [bt, new_h, new_w, c]
|
90 |
+
query = self.query_embedding(x)
|
91 |
+
key = self.key_embedding(x)
|
92 |
+
value = self.value_embedding(x)
|
93 |
+
query = query.view(b, t, self.group_size, self.window_h, self.group_size, self.window_w, self.head, c_h)
|
94 |
+
query = query.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
|
95 |
+
key = key.view(b, t, self.group_size, self.window_h, self.group_size, self.window_w, self.head, c_h)
|
96 |
+
key = key.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
|
97 |
+
value = value.view(b, t, self.group_size, self.window_h, self.group_size, self.window_w, self.head, c_h)
|
98 |
+
value = value.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
|
99 |
+
att, _ = self.attention(query, key, value)
|
100 |
+
att = att.view(b, self.group_size, self.group_size, self.head, t, self.window_h, self.window_w, c_h)
|
101 |
+
att = att.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(bt, self.new_h, self.new_w, c)
|
102 |
+
if self.pad_b > 0 or self.pad_r > 0:
|
103 |
+
att = att[:, :self.h, :self.w, :]
|
104 |
+
att = att.reshape(bt, n, c)
|
105 |
+
output = self.output_linear(att)
|
106 |
+
return output
|
FGT_codes/FGT/models/transformer_base/attention_flow.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class Attention(nn.Module):
|
8 |
+
"""
|
9 |
+
Compute 'Scaled Dot Product Attention
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, p=0.1):
|
13 |
+
super(Attention, self).__init__()
|
14 |
+
self.dropout = nn.Dropout(p=p)
|
15 |
+
|
16 |
+
def forward(self, query, key, value):
|
17 |
+
scores = torch.matmul(query, key.transpose(-2, -1)
|
18 |
+
) / math.sqrt(query.size(-1))
|
19 |
+
p_attn = F.softmax(scores, dim=-1)
|
20 |
+
p_attn = self.dropout(p_attn)
|
21 |
+
p_val = torch.matmul(p_attn, value)
|
22 |
+
return p_val, p_attn
|
23 |
+
|
24 |
+
|
25 |
+
class SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(nn.Module):
|
26 |
+
def __init__(self, token_size, window_size, kernel_size, d_model, flow_dModel, head, p=0.1):
|
27 |
+
super(SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow, self).__init__()
|
28 |
+
self.h, self.w = token_size
|
29 |
+
self.head = head
|
30 |
+
self.window_size = window_size
|
31 |
+
self.d_model = d_model
|
32 |
+
self.flow_dModel = flow_dModel
|
33 |
+
in_channels = d_model + flow_dModel
|
34 |
+
self.query_embedding = nn.Linear(in_channels, d_model)
|
35 |
+
self.key_embedding = nn.Linear(in_channels, d_model)
|
36 |
+
self.value_embedding = nn.Linear(d_model, d_model)
|
37 |
+
self.output_linear = nn.Linear(d_model, d_model)
|
38 |
+
self.attention = Attention(p)
|
39 |
+
self.pad_l = self.pad_t = 0
|
40 |
+
self.pad_r = (self.window_size - self.w % self.window_size) % self.window_size
|
41 |
+
self.pad_b = (self.window_size - self.h % self.window_size) % self.window_size
|
42 |
+
self.new_h, self.new_w = self.h + self.pad_b, self.w + self.pad_r
|
43 |
+
self.group_h, self.group_w = self.new_h // self.window_size, self.new_w // self.window_size
|
44 |
+
self.global_extract_v = nn.Conv2d(d_model, d_model, kernel_size=kernel_size, stride=kernel_size, padding=0,
|
45 |
+
groups=d_model)
|
46 |
+
self.global_extract_k = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=kernel_size,
|
47 |
+
padding=0,
|
48 |
+
groups=in_channels)
|
49 |
+
self.q_norm = nn.LayerNorm(d_model + flow_dModel)
|
50 |
+
self.k_norm = nn.LayerNorm(d_model + flow_dModel)
|
51 |
+
self.v_norm = nn.LayerNorm(d_model)
|
52 |
+
self.reweightFlow = nn.Sequential(
|
53 |
+
nn.Linear(in_channels, flow_dModel),
|
54 |
+
nn.Sigmoid()
|
55 |
+
)
|
56 |
+
|
57 |
+
def inference(self, x, f, h, w):
|
58 |
+
pad_r = (self.window_size - w % self.window_size) % self.window_size
|
59 |
+
pad_b = (self.window_size - h % self.window_size) % self.window_size
|
60 |
+
new_h, new_w = h + pad_b, w + pad_r
|
61 |
+
group_h, group_w = new_h // self.window_size, new_w // self.window_size
|
62 |
+
bt, n, c = x.shape
|
63 |
+
cf = f.shape[2]
|
64 |
+
x = x.view(bt, h, w, c)
|
65 |
+
f = f.view(bt, h, w, cf)
|
66 |
+
if pad_r > 0 or pad_b > 0:
|
67 |
+
x = F.pad(x, (0, 0, self.pad_l, pad_r, self.pad_t, pad_b))
|
68 |
+
f = F.pad(f, (0, 0, self.pad_l, pad_r, self.pad_t, pad_b))
|
69 |
+
y = x.permute(0, 3, 1, 2)
|
70 |
+
xf = torch.cat((x, f), dim=-1)
|
71 |
+
flow_weights = self.reweightFlow(xf)
|
72 |
+
f = f * flow_weights
|
73 |
+
qk = torch.cat((x, f), dim=-1) # [b, h, w, c]
|
74 |
+
qk_c = qk.shape[-1]
|
75 |
+
# generate q
|
76 |
+
q = qk.reshape(bt, group_h, self.window_size, group_w, self.window_size, qk_c).transpose(2, 3)
|
77 |
+
q = q.reshape(bt, group_h * group_w, self.window_size * self.window_size, qk_c)
|
78 |
+
# generate k
|
79 |
+
ky = qk.permute(0, 3, 1, 2) # [b, c, h, w]
|
80 |
+
k_global = self.global_extract_k(ky)
|
81 |
+
k_global = k_global.permute(0, 2, 3, 1).reshape(bt, -1, qk_c).unsqueeze(1).repeat(1, group_h * group_w, 1, 1)
|
82 |
+
k = torch.cat((q, k_global), dim=2)
|
83 |
+
# norm q and k
|
84 |
+
q = self.q_norm(q)
|
85 |
+
k = self.k_norm(k)
|
86 |
+
# generate v
|
87 |
+
global_tokens = self.global_extract_v(y) # [bt, c, h', w']
|
88 |
+
global_tokens = global_tokens.permute(0, 2, 3, 1).reshape(bt, -1, c).unsqueeze(1).repeat(1,
|
89 |
+
group_h * group_w,
|
90 |
+
1,
|
91 |
+
1) # [bt, gh * gw, h'*w', c]
|
92 |
+
x = x.reshape(bt, group_h, self.window_size, group_w, self.window_size, c).transpose(2,
|
93 |
+
3) # [bt, gh, gw, ws, ws, c]
|
94 |
+
x = x.reshape(bt, group_h * group_w, self.window_size * self.window_size, c) # [bt, gh * gw, ws^2, c]
|
95 |
+
v = torch.cat((x, global_tokens), dim=2)
|
96 |
+
v = self.v_norm(v)
|
97 |
+
query = self.query_embedding(q) # [bt, self.group_h, self.group_w, self.window_size, self.window_size, c]
|
98 |
+
key = self.key_embedding(k)
|
99 |
+
value = self.value_embedding(v)
|
100 |
+
query = query.reshape(bt, group_h * group_w, self.window_size * self.window_size, self.head,
|
101 |
+
c // self.head).permute(0, 1, 3, 2, 4)
|
102 |
+
key = key.reshape(bt, group_h * group_w, -1, self.head,
|
103 |
+
c // self.head).permute(0, 1, 3, 2, 4)
|
104 |
+
value = value.reshape(bt, group_h * group_w, -1, self.head,
|
105 |
+
c // self.head).permute(0, 1, 3, 2, 4)
|
106 |
+
attn, _ = self.attention(query, key, value)
|
107 |
+
x = attn.transpose(2, 3).reshape(bt, group_h, group_w, self.window_size, self.window_size, c)
|
108 |
+
x = x.transpose(2, 3).reshape(bt, group_h * self.window_size, group_w * self.window_size, c)
|
109 |
+
if pad_r > 0 or pad_b > 0:
|
110 |
+
x = x[:, :h, :w, :].contiguous()
|
111 |
+
x = x.reshape(bt, n, c)
|
112 |
+
output = self.output_linear(x)
|
113 |
+
return output
|
114 |
+
|
115 |
+
def forward(self, x, f, t, h=0, w=0):
|
116 |
+
if h != 0 or w != 0:
|
117 |
+
return self.inference(x, f, h, w)
|
118 |
+
bt, n, c = x.shape
|
119 |
+
cf = f.shape[2]
|
120 |
+
x = x.view(bt, self.h, self.w, c)
|
121 |
+
f = f.view(bt, self.h, self.w, cf)
|
122 |
+
if self.pad_r > 0 or self.pad_b > 0:
|
123 |
+
x = F.pad(x, (0, 0, self.pad_l, self.pad_r, self.pad_t, self.pad_b))
|
124 |
+
f = F.pad(f, (0, 0, self.pad_l, self.pad_r, self.pad_t, self.pad_b)) # [bt, cf, h, w]
|
125 |
+
y = x.permute(0, 3, 1, 2)
|
126 |
+
xf = torch.cat((x, f), dim=-1)
|
127 |
+
weights = self.reweightFlow(xf)
|
128 |
+
f = f * weights
|
129 |
+
qk = torch.cat((x, f), dim=-1) # [b, h, w, c]
|
130 |
+
qk_c = qk.shape[-1]
|
131 |
+
# generate q
|
132 |
+
q = qk.reshape(bt, self.group_h, self.window_size, self.group_w, self.window_size, qk_c).transpose(2, 3)
|
133 |
+
q = q.reshape(bt, self.group_h * self.group_w, self.window_size * self.window_size, qk_c)
|
134 |
+
# generate k
|
135 |
+
ky = qk.permute(0, 3, 1, 2) # [b, c, h, w]
|
136 |
+
k_global = self.global_extract_k(ky) # [b, qk_c, h, w]
|
137 |
+
k_global = k_global.permute(0, 2, 3, 1).reshape(bt, -1, qk_c).unsqueeze(1).repeat(1,
|
138 |
+
self.group_h * self.group_w,
|
139 |
+
1, 1)
|
140 |
+
k = torch.cat((q, k_global), dim=2)
|
141 |
+
# norm q and k
|
142 |
+
q = self.q_norm(q)
|
143 |
+
k = self.k_norm(k)
|
144 |
+
# generate v
|
145 |
+
global_tokens = self.global_extract_v(y) # [bt, c, h', w']
|
146 |
+
global_tokens = global_tokens.permute(0, 2, 3, 1).reshape(bt, -1, c).unsqueeze(1).repeat(1,
|
147 |
+
self.group_h * self.group_w,
|
148 |
+
1,
|
149 |
+
1) # [bt, gh * gw, h'*w', c]
|
150 |
+
x = x.reshape(bt, self.group_h, self.window_size, self.group_w, self.window_size, c).transpose(2,
|
151 |
+
3) # [bt, gh, gw, ws, ws, c]
|
152 |
+
x = x.reshape(bt, self.group_h * self.group_w, self.window_size * self.window_size, c) # [bt, gh * gw, ws^2, c]
|
153 |
+
v = torch.cat((x, global_tokens), dim=2)
|
154 |
+
v = self.v_norm(v)
|
155 |
+
query = self.query_embedding(q) # [bt, self.group_h, self.group_w, self.window_size, self.window_size, c]
|
156 |
+
key = self.key_embedding(k)
|
157 |
+
value = self.value_embedding(v)
|
158 |
+
query = query.reshape(bt, self.group_h * self.group_w, self.window_size * self.window_size, self.head,
|
159 |
+
c // self.head).permute(0, 1, 3, 2, 4)
|
160 |
+
key = key.reshape(bt, self.group_h * self.group_w, -1, self.head,
|
161 |
+
c // self.head).permute(0, 1, 3, 2, 4)
|
162 |
+
value = value.reshape(bt, self.group_h * self.group_w, -1, self.head,
|
163 |
+
c // self.head).permute(0, 1, 3, 2, 4)
|
164 |
+
attn, _ = self.attention(query, key, value)
|
165 |
+
x = attn.transpose(2, 3).reshape(bt, self.group_h, self.group_w, self.window_size, self.window_size, c)
|
166 |
+
x = x.transpose(2, 3).reshape(bt, self.group_h * self.window_size, self.group_w * self.window_size, c)
|
167 |
+
if self.pad_r > 0 or self.pad_b > 0:
|
168 |
+
x = x[:, :self.h, :self.w, :].contiguous()
|
169 |
+
x = x.reshape(bt, n, c)
|
170 |
+
output = self.output_linear(x)
|
171 |
+
return output
|
FGT_codes/FGT/models/transformer_base/ffn_base.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from functools import reduce
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
|
8 |
+
class FeedForward(nn.Module):
|
9 |
+
def __init__(self, frame_hidden, mlp_ratio, n_vecs, t2t_params, p):
|
10 |
+
"""
|
11 |
+
|
12 |
+
Args:
|
13 |
+
frame_hidden: hidden size of frame features
|
14 |
+
mlp_ratio: mlp ratio in the middle layer of the transformers
|
15 |
+
n_vecs: number of vectors in the transformer
|
16 |
+
t2t_params: dictionary -> {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_shape}
|
17 |
+
p: dropout rate, 0 by default
|
18 |
+
"""
|
19 |
+
super(FeedForward, self).__init__()
|
20 |
+
self.conv = nn.Sequential(
|
21 |
+
nn.Linear(frame_hidden, frame_hidden * mlp_ratio),
|
22 |
+
nn.ReLU(inplace=True),
|
23 |
+
nn.Dropout(p),
|
24 |
+
nn.Linear(frame_hidden * mlp_ratio, frame_hidden),
|
25 |
+
nn.Dropout(p)
|
26 |
+
)
|
27 |
+
|
28 |
+
def forward(self, x, n_vecs=0, output_h=0, output_w=0):
|
29 |
+
x = self.conv(x)
|
30 |
+
return x
|
31 |
+
|
32 |
+
|
33 |
+
class FusionFeedForward(nn.Module):
|
34 |
+
def __init__(self, frame_hidden, mlp_ratio, n_vecs, t2t_params, p):
|
35 |
+
super(FusionFeedForward, self).__init__()
|
36 |
+
self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size'])
|
37 |
+
self.t2t_params = t2t_params
|
38 |
+
hidden_size = self.kernel_shape * mlp_ratio
|
39 |
+
self.conv1 = nn.Linear(frame_hidden, hidden_size)
|
40 |
+
self.conv2 = nn.Sequential(
|
41 |
+
nn.ReLU(inplace=True),
|
42 |
+
nn.Dropout(p),
|
43 |
+
nn.Linear(hidden_size, frame_hidden),
|
44 |
+
nn.Dropout(p)
|
45 |
+
)
|
46 |
+
assert t2t_params is not None and n_vecs is not None
|
47 |
+
tp = t2t_params.copy()
|
48 |
+
self.fold = nn.Fold(**tp)
|
49 |
+
del tp['output_size']
|
50 |
+
self.unfold = nn.Unfold(**tp)
|
51 |
+
self.n_vecs = n_vecs
|
52 |
+
|
53 |
+
def forward(self, x, n_vecs=0, output_h=0, output_w=0):
|
54 |
+
x = self.conv1(x)
|
55 |
+
b, n, c = x.size()
|
56 |
+
if n_vecs != 0:
|
57 |
+
normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
|
58 |
+
x = self.unfold(F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), output_size=(output_h, output_w),
|
59 |
+
kernel_size=self.t2t_params['kernel_size'], stride=self.t2t_params['stride'],
|
60 |
+
padding=self.t2t_params['padding']) / F.fold(normalizer,
|
61 |
+
output_size=(output_h, output_w),
|
62 |
+
kernel_size=self.t2t_params[
|
63 |
+
'kernel_size'],
|
64 |
+
stride=self.t2t_params['stride'],
|
65 |
+
padding=self.t2t_params[
|
66 |
+
'padding'])).permute(0,
|
67 |
+
2,
|
68 |
+
1).contiguous().view(
|
69 |
+
b, n, c)
|
70 |
+
else:
|
71 |
+
normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, self.n_vecs, self.kernel_shape).permute(0, 2, 1)
|
72 |
+
x = self.unfold(self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) / self.fold(normalizer)).permute(0,
|
73 |
+
2,
|
74 |
+
1).contiguous().view(
|
75 |
+
b, n, c)
|
76 |
+
x = self.conv2(x)
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
class ResidualBlock_noBN(nn.Module):
|
81 |
+
"""Residual block w/o BN
|
82 |
+
---Conv-ReLU-Conv-+-
|
83 |
+
|________________|
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, nf=64):
|
87 |
+
super(ResidualBlock_noBN, self).__init__()
|
88 |
+
self.conv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
|
89 |
+
self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
|
90 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
"""
|
94 |
+
|
95 |
+
Args:
|
96 |
+
x: with shape of [b, c, t, h, w]
|
97 |
+
|
98 |
+
Returns: processed features with shape [b, c, t, h, w]
|
99 |
+
|
100 |
+
"""
|
101 |
+
identity = x
|
102 |
+
out = self.lrelu(self.conv1(x))
|
103 |
+
out = self.conv2(out)
|
104 |
+
out = identity + out
|
105 |
+
# Remove ReLU at the end of the residual block
|
106 |
+
# http://torch.ch/blog/2016/02/04/resnets.html
|
107 |
+
return out
|
108 |
+
|
109 |
+
|
110 |
+
def make_layer(block, n_layers):
|
111 |
+
layers = []
|
112 |
+
for _ in range(n_layers):
|
113 |
+
layers.append(block())
|
114 |
+
return nn.Sequential(*layers)
|
FGT_codes/FGT/models/utils/RAFT/utils/__init__.py
ADDED
File without changes
|
FGT_codes/FGT/models/utils/RAFT/utils/utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from scipy import interpolate
|
5 |
+
|
6 |
+
|
7 |
+
class InputPadder:
|
8 |
+
""" Pads images such that dimensions are divisible by 8 """
|
9 |
+
def __init__(self, dims, mode='sintel'):
|
10 |
+
self.ht, self.wd = dims[-2:]
|
11 |
+
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
12 |
+
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
13 |
+
if mode == 'sintel':
|
14 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
15 |
+
else:
|
16 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
17 |
+
|
18 |
+
def pad(self, *inputs):
|
19 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
20 |
+
|
21 |
+
def unpad(self,x):
|
22 |
+
ht, wd = x.shape[-2:]
|
23 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
24 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
25 |
+
|
26 |
+
def forward_interpolate(flow):
|
27 |
+
flow = flow.detach().cpu().numpy()
|
28 |
+
dx, dy = flow[0], flow[1]
|
29 |
+
|
30 |
+
ht, wd = dx.shape
|
31 |
+
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
|
32 |
+
|
33 |
+
x1 = x0 + dx
|
34 |
+
y1 = y0 + dy
|
35 |
+
|
36 |
+
x1 = x1.reshape(-1)
|
37 |
+
y1 = y1.reshape(-1)
|
38 |
+
dx = dx.reshape(-1)
|
39 |
+
dy = dy.reshape(-1)
|
40 |
+
|
41 |
+
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
|
42 |
+
x1 = x1[valid]
|
43 |
+
y1 = y1[valid]
|
44 |
+
dx = dx[valid]
|
45 |
+
dy = dy[valid]
|
46 |
+
|
47 |
+
flow_x = interpolate.griddata(
|
48 |
+
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
|
49 |
+
|
50 |
+
flow_y = interpolate.griddata(
|
51 |
+
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
|
52 |
+
|
53 |
+
flow = np.stack([flow_x, flow_y], axis=0)
|
54 |
+
return torch.from_numpy(flow).float()
|
55 |
+
|
56 |
+
|
57 |
+
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
58 |
+
""" Wrapper for grid_sample, uses pixel coordinates """
|
59 |
+
H, W = img.shape[-2:]
|
60 |
+
xgrid, ygrid = coords.split([1,1], dim=-1)
|
61 |
+
xgrid = 2*xgrid/(W-1) - 1
|
62 |
+
ygrid = 2*ygrid/(H-1) - 1
|
63 |
+
|
64 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
65 |
+
img = F.grid_sample(img, grid, align_corners=True)
|
66 |
+
|
67 |
+
if mask:
|
68 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
69 |
+
return img, mask.float()
|
70 |
+
|
71 |
+
return img
|
72 |
+
|
73 |
+
|
74 |
+
def coords_grid(batch, ht, wd):
|
75 |
+
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
76 |
+
coords = torch.stack(coords[::-1], dim=0).float()
|
77 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
78 |
+
|
79 |
+
|
80 |
+
def upflow8(flow, mode='bilinear'):
|
81 |
+
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
|
82 |
+
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
FGT_codes/FGT/models/utils/__init__.py
ADDED
File without changes
|
FGT_codes/FGT/models/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (169 Bytes). View file
|
|
FGT_codes/FGT/models/utils/__pycache__/network_blocks_2d.cpython-39.pyc
ADDED
Binary file (5.41 kB). View file
|
|