import argparse
import json
import os
import random
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

from PIL import Image
from tqdm import tqdm

# Argument parsing
parser = argparse.ArgumentParser(description="Dataset creation for image colorization")
parser.add_argument("--source_dir", type=str, required=True, help="Source directory")
parser.add_argument(
    "--target_dir", type=str, required=True, help="Target directory for the dataset"
)
parser.add_argument(
    "--resolution", type=int, default=512, help="Resolution for the dataset"
)
args = parser.parse_args()

# Path setup
root_dir = Path("E:/datasets")
target_dir = root_dir / args.target_dir
source_dir = root_dir / args.source_dir
target_images_dir = target_dir / "images"
target_conditioning_dir = target_dir / "conditioning_images"
metadata_file = target_dir / "metadata.jsonl"

# Create the output directories
target_dir.mkdir(parents=True, exist_ok=True)
target_images_dir.mkdir(exist_ok=True)
target_conditioning_dir.mkdir(exist_ok=True)

# Candidate captions; one is picked at random for each image
prompts = [
    "a color image, realistic style, photo",
    "a color image, high resolution, realistic, painting",
    "a color image, high resolution, realistic, photo",
    "very good quality, absurd, photo, color, 4k image",
    "high resolution, color, photo, realistic",
    "high resolution, color, photo, realistic, 4k image",
    "a color image, high resolution, realistic, 4k image",
    "color, high resolution, photo, realistic",
    "512x512, color, photo, realistic",
]


def process_image(image_path):
    try:
        # Load the image and compute a centered square crop box
        with Image.open(image_path) as img:
            width, height = img.size
            size = min(width, height)
            left = (width - size) // 2
            top = (height - size) // 2
            right = left + size
            bottom = top + size

            # Crop and resize; convert to RGB first so RGBA/palette images
            # can be saved as JPEG
            img_cropped = (
                img.convert("RGB")
                .crop((left, top, right, bottom))
                .resize((args.resolution, args.resolution), Image.LANCZOS)
            )

            # Grayscale version used as the conditioning image
            img_gray = img_cropped.convert("L")

            # Save both images under the same filename
            filename = image_path.stem + ".jpg"
            img_cropped.save(target_images_dir / filename)
            img_gray.save(target_conditioning_dir / filename)

            # Metadata row for metadata.jsonl
            metadata = {
                "image": str(filename),
                "text": random.choice(prompts),
                "conditioning_image": str(filename),
            }
            return metadata
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return None


def generate_dataset_loader(target_dir):
    # Derive the loader class name from the target directory's name
    # (e.g. ciff_dataset -> CiffDataset)
    dir_name = target_dir.name
    class_name = "".join(word.capitalize() for word in dir_name.split("_"))

    # The loader script lives next to the data and is named after the directory
    file_name = f"{dir_name}.py"
    file_path = target_dir / file_name

    # Generate the Hugging Face `datasets` loader code
    code = f'''import os
from pathlib import Path

import datasets
import pandas as pd

_VERSION = datasets.Version("0.0.2")

_DESCRIPTION = "TODO"
_HOMEPAGE = "TODO"
_LICENSE = "TODO"
_CITATION = "TODO"

_FEATURES = datasets.Features(
    {{
        "image": datasets.Image(),
        "conditioning_image": datasets.Image(),
        "text": datasets.Value("string"),
    }}
)

_DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)


class {class_name}(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = [_DEFAULT_CONFIG]
    DEFAULT_CONFIG_NAME = "default"

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        # Note: _base_path is a private attribute of the download manager;
        # it points at the dataset directory when the script is loaded locally.
        base_path = Path(dl_manager._base_path)
        metadata_path = base_path / "metadata.jsonl"
        images_dir = base_path / "images"
        conditioning_images_dir = base_path / "conditioning_images"

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={{
                    "metadata_path": metadata_path,
                    "images_dir": images_dir,
                    "conditioning_images_dir": conditioning_images_dir,
                }},
            ),
        ]

    def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir):
        metadata = pd.read_json(metadata_path, lines=True)

        for idx, row in metadata.iterrows():
            text = row["text"]

            image_path = os.path.join(images_dir, row["image"])
            image = open(image_path, "rb").read()

            conditioning_image_path = os.path.join(
                conditioning_images_dir, row["conditioning_image"]
            )
            conditioning_image = open(conditioning_image_path, "rb").read()

            yield idx, {{
                "text": text,
                "image": {{
                    "path": image_path,
                    "bytes": image,
                }},
                "conditioning_image": {{
                    "path": conditioning_image_path,
                    "bytes": conditioning_image,
                }},
            }}
'''

    # Write the generated loader script
    with open(file_path, "w") as f:
        f.write(code)

    print(f"Dataset loader file created: {file_path}")


def main():
    # Collect the source files; non-image files fail gracefully inside
    # process_image and are skipped
    image_files = list(source_dir.glob("*"))

    # Worker count: three quarters of the available CPU cores
    num_workers = (3 * os.cpu_count()) // 4

    # Process images in parallel
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        results = list(
            tqdm(
                executor.map(process_image, image_files),
                total=len(image_files),
                desc="Processing images",
            )
        )

    # Write metadata.jsonl, skipping images that failed to process
    with open(metadata_file, "w") as f:
        for metadata in results:
            if metadata:
                json.dump(metadata, f)
                f.write("\n")

    # Generate the dataset loader script
    generate_dataset_loader(target_dir)


if __name__ == "__main__":
    main()
    print(f"Dataset creation completed. Output directory: {target_dir}")
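
# Usage sketch (the script filename and the directory names below are
# hypothetical, not part of the original script):
#
#   python create_colorization_dataset.py --source_dir raw_photos --target_dir colorize_512
#
# The resulting folder under E:/datasets/colorize_512 can then be loaded via the
# generated loader script, e.g.:
#
#   from datasets import load_dataset
#   ds = load_dataset("E:/datasets/colorize_512/colorize_512.py", split="train")
#   sample = ds[0]  # {"image": PIL image, "conditioning_image": PIL image, "text": str}
#
# Depending on the installed `datasets` version, loading a script-based dataset
# may additionally require trust_remote_code=True.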