|
|
|
|
|
|
|
|
|
from pdb import set_trace as bb |
|
import os |
|
from os.path import * |
|
from PIL import Image |
|
|
|
|
|
class ImageSet(object):
    """ Base class for an image dataset.

    root: directory that every entry of `imgs` is relative to.
    imgs: non-empty sequence of image file names (relative to `root`).

    Indexing an instance (`imset[i]`) loads the i-th image as an RGB PIL image.
    """

    def __init__(self, root, imgs):
        self.root = root
        self.imgs = imgs
        # Raised explicitly (instead of `assert`) so the check survives `python -O`;
        # the exception type is kept for callers that catch AssertionError.
        if not imgs:
            raise AssertionError(f'Empty image set in {root}')

    def init_from_folder(self, *args, **kw):
        """ (Re-)initialize this instance from a folder; same arguments as `from_folder`.
        """
        imset = ImageSet.from_folder(*args, **kw)
        ImageSet.__init__(self, imset.root, imset.imgs)

    def __len__(self):
        return len(self.imgs)

    def get_image_path(self, idx):
        """ Return the path of image `idx`, i.e. `root` joined with its relative name.
        """
        return os.path.join(self.root, self.imgs[idx])

    def get_image(self, idx):
        """ Load image `idx` and return it converted to RGB.

        Raises IndexError for an out-of-range `idx` and IOError (chained to the
        underlying error) when the file cannot be opened or decoded.
        """
        # Resolved outside the `try` so that an IndexError is not turned into an IOError.
        fname = self.get_image_path(idx)
        try:
            # The context manager closes the file handle once the pixels are converted.
            with Image.open(fname) as img:
                return img.convert('RGB')
        except Exception as e:
            raise IOError("Could not load image %s (reason: %s)" % (fname, str(e))) from e

    __getitem__ = get_image

    @staticmethod
    def from_folder(root, exts=('.jpg','.jpeg','.png','.ppm'), recursive=False, listing=False, check_imgs=False):
        """ Build an ImageSet from the image files found in `root`.

        exts: lower-case extension (str) or iterable of extensions to keep;
            matching is case-insensitive.
        recursive: bool or func. If a function, it is called with the path of each
            sub-directory, which is only visited when the result is true.
        listing: False, True or a file name. When set, the file list is read from
            that file inside `root` if it can be opened; otherwise the folder is
            scanned and the result is written there (best effort). True selects
            'list_imgs.txt', or 'list_imgs_recursive.txt' when `recursive` is set.
        check_imgs: False, True (use `verify_img`) or a predicate called with each
            image path; images for which it returns false are dropped.

        Raises AssertionError when no image is found, including when `root`
        does not exist.
        """
        # `str.endswith` only accepts a str or a tuple of str.
        if not isinstance(exts, str):
            exts = tuple(exts)
        imgs = ImageSet._scan_folder(root, exts, recursive, listing, check_imgs)
        return ImageSet(root, imgs)

    @staticmethod
    def _scan_folder(root, exts, recursive, listing, check_imgs):
        """ Return the sorted image names under `root` (relative to it), possibly empty.
        """
        flist = None
        if listing:
            if listing is True:
                listing = f"list_imgs{'_recursive' if recursive else ''}.txt"
            flist = os.path.join(root, listing)
            try:
                with open(flist) as f:
                    return f.read().splitlines()
            except IOError:
                print(f'>> ImageSet.from_folder(listing=True): entering {root}...')

        if check_imgs is True:
            check_imgs = verify_img

        # Only the top level of the walk is needed; the default covers a missing
        # or unreadable `root`, for which os.walk yields nothing.
        _, dirnames, dirfiles = next(os.walk(root), (root, [], []))
        imgs = sorted(f for f in dirfiles if f.lower().endswith(exts))
        if check_imgs:
            imgs = [img for img in imgs if check_imgs(os.path.join(root, img))]

        if recursive:
            for dirname in sorted(dirnames):
                subdir = os.path.join(root, dirname)
                if callable(recursive) and not recursive(subdir):
                    continue
                # A sub-directory without images contributes nothing instead of
                # aborting the whole scan.
                sub_imgs = ImageSet._scan_folder(subdir, exts, recursive, listing, check_imgs)
                imgs += [os.path.join(dirname, f) for f in sub_imgs]

        if flist is not None:
            try:
                with open(flist, 'w') as f:
                    f.write('\n'.join(imgs))
            except IOError:
                # Caching the listing is best effort (e.g. read-only dataset folder).
                pass
        return imgs

    @staticmethod
    def from_listing(root, list_path):
        """ Build an ImageSet from a text file holding one relative image path per line.
        """
        with open(list_path) as f:
            return ImageSet(root, f.read().splitlines())

    def circular_pad(self, min_size):
        """ Resize the set to exactly `min_size` images and return self.

        Images are repeated cyclically when there are fewer than `min_size`;
        extra images are dropped when there are more. A `min_size` of 0 keeps
        the set unchanged. The sequence originally given to the constructor is
        never modified.
        """
        # Explicit raise: with a stripped `assert` the loop below would never end.
        if not self.imgs:
            raise AssertionError('cannot pad an empty image set')
        padded = self.imgs
        while len(padded) < min_size:
            # `+` builds a new sequence, whereas `+=` would extend the caller's list.
            padded = padded + padded
        self.imgs = padded[:min_size or None]
        return self

    def __repr__(self):
        prefix = os.path.commonprefix((self.get_image_path(0),self.get_image_path(len(self)-1)))
        return f'{self.__class__.__name__}({len(self)} images from {prefix}...)'
|
|
|
|
|
|
|
def verify_img(path, exts=None):
    """ Tell whether `path` is an image that can be opened and converted to RGB.

    path: path of the file to check.
    exts: optional lower-case extension (str) or tuple of extensions; when given,
        a path whose lower-cased name has none of them is rejected without
        opening the file.

    Returns True if the image loads, False otherwise (never raises on a bad file).
    """
    if exts and not path.lower().endswith(exts):
        return False
    try:
        # `convert` forces a full decode; the context manager closes the file handle.
        with Image.open(path) as img:
            img.convert('RGB')
        return True
    except Exception:
        # Not a bare `except:` so KeyboardInterrupt / SystemExit still propagate.
        return False
|
|