"""Split a ``train/<label>/`` dataset by moving a share of each label into ``validation/<label>/``."""

import os
import random
import shutil
from concurrent.futures import ThreadPoolExecutor

# Layout: <dataset_folder>/train/<label>/... is split into
# <dataset_folder>/validation/<label>/...
dataset_folder = 'path/to/dataset'
train_folder = os.path.join(dataset_folder, 'train')
val_folder = os.path.join(dataset_folder, 'validation')

# Create the validation root up front (no-op if it already exists).
os.makedirs(val_folder, exist_ok=True)

# Every non-hidden sub-directory of train_folder is treated as one label.
# Hidden directories (e.g. ".ipynb_checkpoints") are tooling artifacts, not
# classes. Sorting makes the processing order independent of os.listdir,
# whose order is arbitrary.
label_folders = sorted(
    f
    for f in os.listdir(train_folder)
    if not f.startswith('.') and os.path.isdir(os.path.join(train_folder, f))
)


def process_label_folder(label_folder, num_threads, val_ratio=0.2, seed=None):
    """Move a random share of one label's training files into validation.

    Files are moved (not copied) from ``train_folder/<label_folder>`` to
    ``val_folder/<label_folder>``. The destination folder is created if
    needed. Both base folders are read from module-level globals.

    Args:
        label_folder: Name of the label sub-directory inside ``train_folder``.
        num_threads: Maximum number of worker threads used for the moves.
        val_ratio: Fraction of the label's files to move, in ``[0, 1]``.
            Defaults to 0.2 (20%), the original hard-coded split. The count
            is rounded down.
        seed: Optional seed for a private RNG so the split is reproducible.
            When ``None``, the module-level ``random`` state is used.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1]``. Also raised by
            ``ThreadPoolExecutor`` if ``num_threads`` is below 1.
        OSError: If any individual move fails. The first failure is
            re-raised instead of being silently discarded.
    """
    # Validate before touching the filesystem so a bad ratio has no effects.
    if not 0 <= val_ratio <= 1:
        raise ValueError(f"val_ratio must be between 0 and 1, got {val_ratio!r}")

    train_label_folder = os.path.join(train_folder, label_folder)
    val_label_folder = os.path.join(val_folder, label_folder)
    os.makedirs(val_label_folder, exist_ok=True)

    # Only regular, non-hidden files are candidates. Sub-directories and OS
    # metadata files such as ".DS_Store" are not images, so they must not be
    # moved or counted toward the split. Sorting removes the dependence on
    # os.listdir's arbitrary order, so a fixed seed yields a fixed split.
    all_images = sorted(
        name
        for name in os.listdir(train_label_folder)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(train_label_folder, name))
    )

    val_size = int(len(all_images) * val_ratio)
    rng = random if seed is None else random.Random(seed)
    val_images = rng.sample(all_images, val_size)

    def move_image(image):
        """Move a single file from the train label folder to the validation one."""
        shutil.move(
            os.path.join(train_label_folder, image),
            os.path.join(val_label_folder, image),
        )

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # An exception raised in a worker only surfaces when its result is
        # retrieved from the map iterator. Draining the iterator makes a
        # failed move abort the run instead of being silently lost.
        for _ in executor.map(move_image, val_images):
            pass

    print(f"Moved {val_size} images from {label_folder} to validation folder.")


def main():
    """Prompt for a worker-thread count, then split every label folder.

    Re-prompts until a positive integer is entered. Without this,
    non-numeric input crashes in ``int()``, and a value below 1 is only
    rejected by ``ThreadPoolExecutor`` after the first label's validation
    folder has already been created.
    """
    while True:
        raw = input("Enter the number of threads to use: ")
        try:
            num_threads = int(raw)
        except ValueError:
            print(f"{raw!r} is not a whole number; please try again.")
            continue
        if num_threads >= 1:
            break
        print("The number of threads must be at least 1; please try again.")

    # Labels are handled one after another; parallelism is per-label moves.
    for label_folder in label_folders:
        process_label_folder(label_folder, num_threads)

    print("Validation dataset created.")


if __name__ == "__main__":
    main()