File size: 5,609 Bytes
d380b77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
#!/usr/bin/env python3
import glob
import os
import shutil
import traceback
import PIL.Image as Image
import numpy as np
from joblib import Parallel, delayed
from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
from saicinpainting.evaluation.utils import load_yaml, SmallMode
from saicinpainting.training.data.masks import MixedMaskGenerator
class MakeManyMasksWrapper:
    """Adapter that asks a single-mask generator for several masks per image.

    ``impl`` is any callable that accepts an array and returns an indexable
    result; only element ``[0]`` of each result is kept.
    """

    def __init__(self, impl, variants_n=2):
        self.variants_n = variants_n
        self.impl = impl

    def get_masks(self, img):
        """Return a list with ``variants_n`` results of ``impl`` for ``img``.

        ``img`` is converted to an array and its axes are reordered with the
        permutation ``(2, 0, 1)`` before it is handed to ``impl``
        (presumably HWC -> CHW; confirm against the generator in use).
        """
        reordered = np.array(img).transpose(2, 0, 1)
        masks = []
        for _ in range(self.variants_n):
            generated = self.impl(reordered)
            masks.append(generated[0])
        return masks
def _resize_shorter_side(image, target_size):
    """Return ``image`` rescaled so that its shorter side becomes ``target_size``."""
    factor = target_size / min(image.size)
    out_size = (np.array(image.size) * factor).round().astype('uint32')
    # Hand Pillow a plain tuple of Python ints rather than a numpy array.
    return image.resize(tuple(int(side) for side in out_size), resample=Image.BICUBIC)


def process_images(src_images, indir, outdir, config):
    """Generate masks (and the matching image crops) for a batch of images.

    Each image is loaded as RGB, rescaled according to ``config.cropping``,
    passed to the configured mask generator, and up to
    ``config.max_masks_per_image`` randomly chosen (image, mask) pairs are
    written below ``outdir`` as ``<stem>_cropNNN.png`` and
    ``<stem>_cropNNN_maskNNN.png``, mirroring the layout of ``indir``.

    Failures on a single file are printed and do not stop the batch; a
    ``KeyboardInterrupt`` ends the batch quietly.

    Args:
        src_images: iterable of input image paths located under ``indir``.
        indir: root folder of the inputs; output paths are made relative to it.
        outdir: root folder for the generated images and masks.
        config: configuration object providing ``generator_kind``,
            ``mask_generator_kwargs``, ``max_masks_per_image``, the
            ``cropping`` section (``out_min_size``, ``handle_small_mode``,
            ``out_square_crop``, ``crop_min_overlap``) and, optionally,
            ``max_tamper_area`` (defaults to 1).

    Raises:
        ValueError: if ``config.generator_kind`` is neither ``'segmentation'``
            nor ``'random'``.
    """
    if config.generator_kind == 'segmentation':
        mask_generator = SegmentationMask(**config.mask_generator_kwargs)
    elif config.generator_kind == 'random':
        # Work on a copy: popping from the config itself would silently drop
        # ``variants_n`` for any later call that shares the same config object.
        generator_kwargs = dict(config.mask_generator_kwargs)
        variants_n = generator_kwargs.pop('variants_n', 2)
        mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**generator_kwargs),
                                              variants_n=variants_n)
    else:
        raise ValueError(f'Unexpected generator kind: {config.generator_kind}')

    max_tamper_area = config.get('max_tamper_area', 1)

    for infile in src_images:
        try:
            # relpath never yields a leading separator, so the join below cannot
            # discard ``outdir`` when ``indir`` lacks a trailing slash.
            file_relpath = os.path.relpath(infile, indir)
            img_outpath = os.path.join(outdir, file_relpath)
            os.makedirs(os.path.dirname(img_outpath), exist_ok=True)

            # Close the underlying file as soon as the pixels are converted.
            with Image.open(infile) as src_image:
                image = src_image.convert('RGB')

            # scale input image to output resolution and filter smaller images
            # NOTE(review): the final ``else`` is taken to belong to the outer
            # size check (large images are scaled down to out_min_size); confirm
            # against the upstream file, whose indentation was lost here.
            if min(image.size) < config.cropping.out_min_size:
                handle_small_mode = SmallMode(config.cropping.handle_small_mode)
                if handle_small_mode == SmallMode.DROP:
                    continue
                if handle_small_mode == SmallMode.UPSCALE:
                    image = _resize_shorter_side(image, config.cropping.out_min_size)
            else:
                image = _resize_shorter_side(image, config.cropping.out_min_size)

            # generate and select masks
            filtered_image_mask_pairs = []
            for cur_mask in mask_generator.get_masks(image):
                if config.cropping.out_square_crop:
                    (crop_left,
                     crop_top,
                     crop_right,
                     crop_bottom) = propose_random_square_crop(cur_mask,
                                                               min_overlap=config.cropping.crop_min_overlap)
                    cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
                    # crop() already returns a new image, no prior copy needed.
                    cur_image = image.crop((crop_left, crop_top, crop_right, crop_bottom))
                else:
                    cur_image = image

                # Skip zero-size masks and masks whose mean exceeds the allowed
                # area (the mean is a coverage fraction if masks are 0/1 valued
                # -- TODO confirm for the generator in use).
                if cur_mask.size == 0 or cur_mask.mean() > max_tamper_area:
                    continue

                filtered_image_mask_pairs.append((cur_image, cur_mask))

            if not filtered_image_mask_pairs:
                # Nothing usable for this image; also avoids calling
                # np.random.choice on an empty population.
                continue

            mask_indices = np.random.choice(len(filtered_image_mask_pairs),
                                            size=min(len(filtered_image_mask_pairs),
                                                     config.max_masks_per_image),
                                            replace=False)

            # crop masks; save masks together with input image
            mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
            for i, idx in enumerate(mask_indices):
                cur_image, cur_mask = filtered_image_mask_pairs[idx]
                cur_basename = mask_basename + f'_crop{i:03d}'
                Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
                                mode='L').save(cur_basename + f'_mask{i:03d}.png')
                cur_image.save(cur_basename + '.png')
        except KeyboardInterrupt:
            return
        except Exception as ex:
            print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
def main(args):
    """Find the input images and generate masks for them, optionally in parallel.

    Args:
        args: parsed command-line namespace with ``config`` (path to the YAML
            config), ``indir``, ``outdir``, ``ext`` (input file extension) and
            ``n_jobs``. ``n_jobs == 0`` processes everything in this process;
            any other value is passed to joblib, so negative values follow
            joblib's convention (e.g. ``-1`` = all CPUs).
    """
    # Keep a trailing slash on the input root without mutating the caller's
    # namespace; process_images derives output paths relative to this root.
    indir = args.indir if args.indir.endswith('/') else args.indir + '/'

    os.makedirs(args.outdir, exist_ok=True)
    config = load_yaml(args.config)
    in_files = glob.glob(os.path.join(indir, '**', f'*.{args.ext}'), recursive=True)

    if args.n_jobs == 0:
        process_images(in_files, indir, args.outdir, config)
        return

    if not in_files:
        # Nothing to distribute; a zero chunk size would make range() raise.
        return

    # Resolve negative n_jobs to a real worker count so the chunk size is
    # always positive; positive values are returned unchanged.
    from joblib import effective_n_jobs
    workers = effective_n_jobs(args.n_jobs)
    chunk_size = -(-len(in_files) // workers)  # ceiling division

    Parallel(n_jobs=args.n_jobs)(
        delayed(process_images)(in_files[start:start + chunk_size], indir, args.outdir, config)
        for start in range(0, len(in_files), chunk_size)
    )
if __name__ == '__main__':
    # Command-line entry point: three positional paths plus two options.
    import argparse

    cli = argparse.ArgumentParser()
    for arg_name, arg_help in (
        ('config', 'Path to config for dataset generation'),
        ('indir', 'Path to folder with images'),
        ('outdir', 'Path to folder to store aligned images and masks to'),
    ):
        cli.add_argument(arg_name, type=str, help=arg_help)
    cli.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
    cli.add_argument('--ext', type=str, default='jpg', help='Input image extension')

    parsed_args = cli.parse_args()
    main(parsed_args)
|