Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import random | |
| from src.utils.config import DATA_DIR, SPLIT_DIR, SEED | |
def split_clean_data(train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=SEED):
    """
    Split cleaned data into train/val/test splits under SPLIT_DIR.

    The source tree is expected at ``DATA_DIR/clean/<crop>/<disease>/<image>``.
    Each disease folder is shuffled and copied into
    ``SPLIT_DIR/<split>/<crop>/<disease>/``. Any existing SPLIT_DIR is removed
    first.

    Args:
        train_ratio: Fraction of each class copied to the "train" split.
        val_ratio: Fraction of each class copied to the "val" split.
        test_ratio: Accepted for interface compatibility. The "test" split
            always receives whatever remains after train and val, so this
            value is only used to warn when the three ratios do not sum to 1.
        seed: Seed passed to ``random.seed`` (this reseeds the global RNG).

    Raises:
        OSError: If the clean data directory cannot be listed. This is raised
            before the previous split is deleted.
    """
    random.seed(seed)
    clean_dir = os.path.join(DATA_DIR, "clean")

    # List the source first so that a missing/unreadable clean directory
    # fails before the previous split is destroyed. Sorting makes the
    # traversal (and therefore the seeded shuffles) reproducible, since
    # os.listdir returns entries in arbitrary order.
    crops = sorted(
        entry for entry in os.listdir(clean_dir)
        if os.path.isdir(os.path.join(clean_dir, entry))
    )

    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        print(
            "[WARNING] train/val/test ratios do not sum to 1; "
            "the test split receives the remainder after train and val."
        )

    # Remove previous split if exists
    if os.path.exists(SPLIT_DIR):
        shutil.rmtree(SPLIT_DIR)
    os.makedirs(SPLIT_DIR)

    for crop in crops:
        crop_path = os.path.join(clean_dir, crop)
        # Skip stray files (e.g. OS metadata) that sit next to class folders.
        disease_folders = sorted(
            entry for entry in os.listdir(crop_path)
            if os.path.isdir(os.path.join(crop_path, entry))
        )
        for disease_folder in disease_folders:
            disease_path = os.path.join(crop_path, disease_folder)
            # Only regular files can be copied with shutil.copy; sort so the
            # shuffle below depends solely on the seed and the file names.
            images = sorted(
                entry for entry in os.listdir(disease_path)
                if os.path.isfile(os.path.join(disease_path, entry))
            )
            if not images:
                print(f"[WARNING] No images found in {disease_path}, skipping.")
                continue

            random.shuffle(images)
            n_total = len(images)
            n_train = int(n_total * train_ratio)
            n_val = int(n_total * val_ratio)
            n_test = n_total - n_train - n_val

            # Safety check: at least 1 sample in each split
            if n_train == 0 or n_val == 0 or n_test == 0:
                print(f"[WARNING] Not enough images to split {disease_path} properly, skipping.")
                continue

            splits = {
                "train": images[:n_train],
                "val": images[n_train:n_train + n_val],
                "test": images[n_train + n_val:],
            }
            for split_name, split_images in splits.items():
                target_dir = os.path.join(SPLIT_DIR, split_name, crop, disease_folder)
                os.makedirs(target_dir, exist_ok=True)
                for img_name in split_images:
                    src_img = os.path.join(disease_path, img_name)
                    dst_img = os.path.join(target_dir, img_name)
                    shutil.copy(src_img, dst_img)

    print("[INFO] Finished creating train/val/test split.")