# Colorization / ciff_dataset.py
# (File-viewer residue kept for provenance: uploaded by "Noename",
#  commit "add app" 37b43d6, 6.61 kB.)
import os
import random
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
import json
from PIL import Image
import numpy as np
import argparse
from tqdm import tqdm
# Parse command-line arguments.
# NOTE: this runs at import time, so worker processes started with the
# "spawn" method re-parse the same sys.argv when they re-import this module.
parser = argparse.ArgumentParser(description="Dataset creation for image colorization")
parser.add_argument("--source_dir", type=str, required=True, help="Source directory")
parser.add_argument(
    "--target_dir", type=str, required=True, help="Target directory for the dataset"
)
parser.add_argument(
    "--resolution", type=int, default=512, help="Resolution for the dataset"
)
parser.add_argument(
    "--root_dir",
    type=str,
    default="E:/datasets",
    help="Base directory that --source_dir and --target_dir are resolved against",
)
args = parser.parse_args()

# Fail early instead of letting every single image fail inside resize().
if args.resolution <= 0:
    parser.error("--resolution must be a positive integer")

# Resolve input/output paths relative to the dataset root.
root_dir = Path(args.root_dir)
target_dir = root_dir / args.target_dir
source_dir = root_dir / args.source_dir
target_images_dir = target_dir / "images"
target_conditioning_dir = target_dir / "conditioning_images"
metadata_file = target_dir / "metadata.jsonl"

# Create the output directory tree (idempotent).
target_dir.mkdir(parents=True, exist_ok=True)
target_images_dir.mkdir(exist_ok=True)
target_conditioning_dir.mkdir(exist_ok=True)
# Caption pool: one entry is picked at random for every processed image.
prompts = [
    "a color image, realistic style, photo",
    "a color image, high resolution, realistic, painting",
    "a color image, high resolution, realistic, photo",
    "very good quality, absurd, photo, color, 4k image",
    "high resolution, color, photo, realistic",
    "high resolution, color, photo, realistic, 4k image",
    "a color image, high resolution, realistic, 4k image",
    "color, high resolution, photo, realistic",
    "512x512, color, photo, realistic",
]
def process_image(image_path):
    """Create the color/grayscale training pair for one source image.

    The image is center-cropped to a square, converted to RGB, resized to
    ``args.resolution`` x ``args.resolution`` and written as JPEG to
    ``target_images_dir``; a grayscale ("L") copy with the same file name is
    written to ``target_conditioning_dir``.

    Args:
        image_path: ``pathlib.Path`` of the source image.

    Returns:
        A metadata dict with the keys ``image``, ``text`` and
        ``conditioning_image``, or ``None`` when the file could not be
        processed (the error is printed, not raised, so that one bad file
        does not abort the whole run).
    """
    try:
        with Image.open(image_path) as img:
            # Largest centered square that fits inside the image.
            width, height = img.size
            size = min(width, height)
            left = (width - size) // 2
            top = (height - size) // 2
            box = (left, top, left + size, top + size)

            # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...);
            # without the RGB conversion such inputs fail in save() and are
            # dropped from the dataset. Converting before the resize also
            # gives the resampling filter real color values to work on.
            img_cropped = (
                img.crop(box)
                .convert("RGB")
                .resize((args.resolution, args.resolution), Image.LANCZOS)
            )

            # Grayscale conditioning image derived from the final color image.
            img_gray = img_cropped.convert("L")

            # Both outputs share one file name (source stem + ".jpg").
            filename = image_path.stem + ".jpg"

            img_cropped.save(target_images_dir / filename)
            img_gray.save(target_conditioning_dir / filename)

            metadata = {
                "image": str(filename),
                "text": random.choice(prompts),
                "conditioning_image": str(filename),
            }
            return metadata
    except Exception as e:
        # Deliberate best-effort: report and skip unreadable/non-image files.
        print(f"Error processing {image_path}: {e}")
        return None
def generate_dataset_loader(target_dir):
    """Write a Hugging Face ``datasets`` loading script into *target_dir*.

    The script is named ``<directory name>.py`` and defines a
    ``GeneratorBasedBuilder`` subclass whose name is the CamelCase form of
    the directory name (e.g. ``ciff_dataset`` -> ``CiffDataset``).

    Args:
        target_dir: ``pathlib.Path`` of the dataset directory; it must
            already exist.
    """
    import keyword
    import re

    dir_name = target_dir.name

    # Split on every non-word character as well as "_" so that names such as
    # "my-data.v2" still yield a usable class name.
    class_name = ''.join(
        word.capitalize() for word in re.split(r'[\W_]+', dir_name)
    )
    # Names that are empty, start with a digit, or collide with a keyword
    # ("none" -> "None") would make the generated file fail to compile.
    if not class_name.isidentifier() or keyword.iskeyword(class_name):
        class_name = f"Dataset{class_name}"

    file_name = f"{dir_name}.py"
    file_path = target_dir / file_name

    # Loader source. Doubled braces are literal braces in the output;
    # {class_name} is the only substituted field.
    code = f'''
import pandas as pd
from pathlib import Path
import datasets
import os

_VERSION = datasets.Version("0.0.2")

_DESCRIPTION = "TODO"
_HOMEPAGE = "TODO"
_LICENSE = "TODO"
_CITATION = "TODO"

_FEATURES = datasets.Features(
    {{
        "image": datasets.Image(),
        "conditioning_image": datasets.Image(),
        "text": datasets.Value("string"),
    }}
)

_DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)


class {class_name}(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = [_DEFAULT_CONFIG]
    DEFAULT_CONFIG_NAME = "default"

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        base_path = Path(dl_manager._base_path)
        metadata_path = base_path / "metadata.jsonl"
        images_dir = base_path / "images"
        conditioning_images_dir = base_path / "conditioning_images"

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={{
                    "metadata_path": metadata_path,
                    "images_dir": images_dir,
                    "conditioning_images_dir": conditioning_images_dir,
                }},
            ),
        ]

    def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir):
        metadata = pd.read_json(metadata_path, lines=True)

        for idx, row in metadata.iterrows():
            text = row["text"]

            image_path = os.path.join(images_dir, row["image"])
            with open(image_path, "rb") as image_file:
                image = image_file.read()

            conditioning_image_path = os.path.join(conditioning_images_dir, row["conditioning_image"])
            with open(conditioning_image_path, "rb") as conditioning_file:
                conditioning_image = conditioning_file.read()

            yield idx, {{
                "text": text,
                "image": {{
                    "path": image_path,
                    "bytes": image,
                }},
                "conditioning_image": {{
                    "path": conditioning_image_path,
                    "bytes": conditioning_image,
                }},
            }}
'''

    # Explicit encoding so the result does not depend on the platform locale.
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(code)

    # Status message (Korean): "The dataset loader file has been created".
    print(f"데이터셋 로더 파일이 생성되었습니다: {file_path}")
def main():
    """Build the dataset: process all images, write metadata and the loader.

    Every regular file directly inside ``source_dir`` is handed to
    ``process_image`` in a process pool. The metadata of the images that
    were processed successfully is written to ``metadata_file`` as JSON
    Lines, and finally the ``datasets`` loading script is generated in
    ``target_dir``.
    """
    # Only regular files: sub-directories cannot be opened as images.
    # Sorted so the order of metadata.jsonl is reproducible between runs.
    image_files = sorted(path for path in source_dir.glob("*") if path.is_file())

    # Use roughly three quarters of the available cores, but never fewer
    # than one worker: os.cpu_count() may return None, and the integer
    # division yields 0 on a single-core machine, which the executor rejects.
    num_workers = max(1, (3 * (os.cpu_count() or 1)) // 4)
    if os.name == "nt":
        # ProcessPoolExecutor refuses more than 61 workers on Windows.
        num_workers = min(num_workers, 61)

    # Process the images in parallel; map() preserves the input order.
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        results = list(
            tqdm(
                executor.map(process_image, image_files),
                total=len(image_files),
                desc="Processing images",
            )
        )

    # One JSON object per line; failed images (None) are skipped.
    with open(metadata_file, "w", encoding="utf-8") as f:
        for metadata in results:
            if metadata:
                json.dump(metadata, f)
                f.write("\n")

    # Emit the loading script next to the data.
    generate_dataset_loader(target_dir)
# Script entry point. The guard keeps worker processes that re-import this
# module from starting the pipeline again.
if __name__ == "__main__":
    main()
    print(f"Dataset creation completed. Output directory: {target_dir}")