# 42028-CNN-A2 / detect_yolo.py
# Ziruibest's picture
# Upload folder CNNa2
# 938fb27 verified
import torch
# import torchvision # Keep if needed for image loading/manipulation during conversion
# import torchvision.transforms.v2 as T # Keep if needed during conversion
from torch.utils.data import DataLoader, Dataset # Keep if needed during conversion
import os
import json
from PIL import Image
import matplotlib.pyplot as plt # Keep for visualization if desired
# import matplotlib.patches as patches
import glob
import xml.etree.ElementTree as ET
import logging
import numpy as np
# from pycocotools.coco import COCO # Keep if reading COCO format
# from pycocotools.cocoeval import COCOeval # YOLOv5 handles evaluation internally
import argparse
import yaml # Keep for reading/writing YAML
import sys
from pathlib import Path
import random
import shutil # For file operations during conversion
import subprocess # To call YOLOv5 training script
# --- Command-line argument parsing ---
# Options for evaluation batch size and logging/eval/save step intervals are
# intentionally absent: the original author noted that YOLOv5 handles them.
parser = argparse.ArgumentParser(description='目标检测训练脚本 (YOLOv5)')
parser.add_argument(
    '--dataset_type',
    type=str,
    default='coco',
    choices=['coco', 'pascal', 'yolo'],
    help='原始数据集类型:coco, pascal, 或 yolo',
)
parser.add_argument(
    '--yolov5_model',
    type=str,
    default='yolov5s',
    help='YOLOv5 模型类型 (e.g., yolov5s, yolov5m, yolov5l, yolov5x, or path to .pt)',
)
parser.add_argument(
    '--img_size',
    type=int,
    default=640,
    help='YOLOv5 训练图像尺寸',
)
parser.add_argument(
    '--batch_size',
    type=int,
    default=16,
    help='训练批次大小 (adjust based on GPU memory)',
)
parser.add_argument(
    '--epochs',
    type=int,
    default=50,
    help='训练轮数',
)
# NOTE: this value is parsed and stored, but nothing in this section forwards
# it to the training command.
parser.add_argument(
    '--lr',
    type=float,
    default=0.01,
    help='初始学习率 (YOLOv5 train.py default)',
)
parser.add_argument(
    '--output_dir',
    type=str,
    default='./yolov5_training_results',
    help='YOLOv5 训练输出项目目录',
)
parser.add_argument(
    '--run_name',
    type=str,
    default='exp',
    help='YOLOv5 训练运行名称',
)
parser.add_argument(
    '--subset_ratio',
    type=float,
    default=1.0,
    help='使用原始数据集的比例 (用于快速转换测试, 1.0 使用全部)',
)
# Path to the locally cloned YOLOv5 repository.
parser.add_argument(
    '--yolov5_repo_path',
    type=str,
    default='./yolov5',
    help='本地 YOLOv5 代码库路径',
)
parser.add_argument(
    '--converted_data_dir',
    type=str,
    default='./yolov5_data',
    help='转换后 YOLOv5 格式数据的存放目录',
)
parser.add_argument(
    '--workers',
    type=int,
    default=8,
    help='YOLOv5 dataloader workers',
)
args = parser.parse_args()

# --- Module-level settings derived from the parsed arguments ---
# Directory-like options are wrapped in Path objects; the rest are used as-is.
DATASET_TYPE = args.dataset_type
YOLOV5_MODEL = args.yolov5_model
IMG_SIZE = args.img_size
BATCH_SIZE = args.batch_size
NUM_EPOCHS = args.epochs
LEARNING_RATE = args.lr
SUBSET_RATIO = args.subset_ratio
WORKERS = args.workers
RUN_NAME = args.run_name
OUTPUT_DIR = Path(args.output_dir)
YOLOV5_REPO_PATH = Path(args.yolov5_repo_path)
CONVERTED_DATA_DIR = Path(args.converted_data_dir)

# Echo the key settings both to stdout and to the root logger.
for _startup_message in (
    f"选择的原始数据集类型: {DATASET_TYPE}",
    f"使用的 YOLOv5 模型: {YOLOV5_MODEL}",
    f"转换后的数据目录: {CONVERTED_DATA_DIR}",
    f"YOLOv5 代码库路径: {YOLOV5_REPO_PATH}",
):
    print(_startup_message)
    logging.info(_startup_message)
# --- 1. Configuration: dataset locations and category state ---
# All three dataset layouts (coco / pascal / yolo) live under one root folder.
# The root defaults to the original author's machine-specific path, but can be
# overridden with the OBJECT_DETECTION_DATASET_ROOT environment variable so the
# script can run on other machines without editing the source.
_DEFAULT_DATASET_ROOT = (
    '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection'
)
# `or` (rather than a .get default) also falls back when the variable is set
# to an empty string.
DATASET_ROOT = Path(os.environ.get('OBJECT_DETECTION_DATASET_ROOT') or _DEFAULT_DATASET_ROOT)

# COCO layout: one folder per split, each holding its own annotation JSON.
COCO_BASE_PATH = DATASET_ROOT / 'coco'
COCO_TRAIN_IMG_DIR = COCO_BASE_PATH / 'train'
COCO_TRAIN_ANN_FILE = COCO_BASE_PATH / 'train/train_annotations.json'
COCO_VAL_IMG_DIR = COCO_BASE_PATH / 'valid'
COCO_VAL_ANN_FILE = COCO_BASE_PATH / 'valid/valid_annotations.json'
COCO_TEST_IMG_DIR = COCO_BASE_PATH / 'test'
COCO_TEST_ANN_FILE = COCO_BASE_PATH / 'test/test_annotations.json'  # Optional

# Pascal VOC layout.
PASCAL_BASE_PATH = DATASET_ROOT / 'pascal'
PASCAL_YEAR = '2007'

# YOLO layout; class names are expected in its data.yaml.
YOLO_BASE_PATH = DATASET_ROOT / 'yolo'
YOLO_CLASSES_FILE = YOLO_BASE_PATH / 'data.yaml'

# --- Category state (filled in by the loader functions, used for data.yaml) ---
label2id = {}
id2label = {}
NUM_CLASSES = 0
class_names = []  # Ordered list of class names written to data.yaml
def load_coco_categories(ann_file):
    """Load class information from a COCO annotation JSON file.

    The ``categories`` entries are sorted by their COCO ``id`` and re-indexed
    from 0 in that order. The module-level ``label2id``, ``id2label``,
    ``NUM_CLASSES`` and ``class_names`` are overwritten with the result.

    Args:
        ann_file: Path to a COCO-format annotation file with a top-level
            ``categories`` list whose items carry ``id`` and ``name``.

    Returns:
        Tuple ``(id2label, label2id, NUM_CLASSES, class_names)``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If ``categories`` (or an item's ``id``/``name``) is missing.
    """
    global label2id, id2label, NUM_CLASSES, class_names
    # Read as UTF-8 explicitly so non-ASCII category names do not depend on
    # the platform's default locale encoding.
    with open(ann_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    categories = sorted(data['categories'], key=lambda x: x['id'])
    class_names = [cat['name'] for cat in categories]
    label2id = {name: idx for idx, name in enumerate(class_names)}
    id2label = dict(enumerate(class_names))
    NUM_CLASSES = len(categories)
    print("COCO Categories Loaded:")
    print(f" Num Classes: {NUM_CLASSES}")
    print(f" Class Names: {class_names}")
    return id2label, label2id, NUM_CLASSES, class_names
def setup_pascal_categories():
    """Install the fixed Pascal VOC class list into the module-level state.

    Overwrites the globals ``label2id``, ``id2label``, ``NUM_CLASSES`` and
    ``class_names`` with a hard-coded two-class list, prints a short summary,
    and returns the same values.

    Returns:
        Tuple ``(id2label, label2id, NUM_CLASSES, class_names)``.
    """
    global label2id, id2label, NUM_CLASSES, class_names
    voc_names = ["young", "empty_pod"]
    # Index -> name first, then invert it for the name -> index mapping.
    id2label = dict(enumerate(voc_names))
    label2id = {name: idx for idx, name in id2label.items()}
    class_names = voc_names
    NUM_CLASSES = len(voc_names)
    for summary_line in (
        "Pascal VOC Categories Setup:",
        f" Num Classes: {NUM_CLASSES}",
        f" Class Names: {class_names}",
    ):
        print(summary_line)
    return id2label, label2id, NUM_CLASSES, class_names
def _read_yaml_class_names(yaml_file):
    """Return the class names stored under ``names`` in a YOLO data YAML.

    Both the list form (``names: [a, b]``) and the index-keyed mapping form
    (``names: {0: a, 1: b}``) are accepted. Any failure (PyYAML missing,
    unreadable file, malformed content) is reported and yields an empty list,
    so the caller can fall back to its defaults.
    """
    try:
        import yaml  # PyYAML; imported here so a missing package only disables this path
        with open(yaml_file, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        names = data.get('names') if isinstance(data, dict) else None
        if isinstance(names, dict):
            # Mapping form: order the names by their class index.
            return [names[key] for key in sorted(names)]
        if isinstance(names, list):
            return list(names)
    except Exception as e:
        print(f"从YAML加载类别失败: {e}, 使用默认值")
    return []


def load_yolo_classes(classes_file=None):
    """Load YOLO class names and publish them in the module-level state.

    Lookup order:
      1. ``classes_file`` if it exists. A ``.yaml``/``.yml`` file is parsed as
         YAML and its ``names`` entry is used; any other file is treated as a
         plain list with one class name per non-empty line.
      2. Otherwise ``YOLO_BASE_PATH / 'data.yaml'`` if it exists.
      3. Otherwise (or if nothing could be read) a built-in default list.

    Overwrites the globals ``label2id``, ``id2label``, ``NUM_CLASSES`` and
    ``class_names``.

    Args:
        classes_file: Optional ``pathlib.Path`` to a class list or data YAML.

    Returns:
        Tuple ``(id2label, label2id, NUM_CLASSES, class_names)``.
    """
    global label2id, id2label, NUM_CLASSES, class_names
    classes = []
    if classes_file and classes_file.exists():
        if classes_file.suffix.lower() in ('.yaml', '.yml'):
            # A data.yaml must be parsed, not read line by line: otherwise
            # entries such as "nc: 5" would be mistaken for class names.
            classes = _read_yaml_class_names(classes_file)
        else:
            with open(classes_file, 'r', encoding='utf-8') as f:
                classes = [line.strip() for line in f if line.strip()]
    else:
        yaml_file = YOLO_BASE_PATH / 'data.yaml'
        if yaml_file.exists():
            classes = _read_yaml_class_names(yaml_file)
    if not classes:
        classes = ['Ready', 'empty_pod', 'germination', 'pod', 'young']
        print("Using default YOLO classes.")
    label2id = {name: idx for idx, name in enumerate(classes)}
    id2label = {idx: name for idx, name in enumerate(classes)}
    class_names = classes
    NUM_CLASSES = len(classes)
    print("YOLO Categories Loaded:")
    print(f" Num Classes: {NUM_CLASSES}")
    print(f" Class Names: {class_names}")
    return id2label, label2id, NUM_CLASSES, class_names
# Load the class information that matches the selected dataset type.
if DATASET_TYPE == 'coco':
    # Prefer the training annotations; fall back to the validation ones.
    ann_file_to_load = COCO_TRAIN_ANN_FILE
    if not ann_file_to_load.exists():
        ann_file_to_load = COCO_VAL_ANN_FILE
    if not ann_file_to_load.exists():
        raise FileNotFoundError("无法找到COCO注释文件来加载类别")
    id2label, label2id, NUM_CLASSES, class_names = load_coco_categories(ann_file_to_load)
elif DATASET_TYPE == 'pascal':
    id2label, label2id, NUM_CLASSES, class_names = setup_pascal_categories()
elif DATASET_TYPE == 'yolo':
    id2label, label2id, NUM_CLASSES, class_names = load_yolo_classes(YOLO_CLASSES_FILE)
else:
    raise ValueError(f"不支持的数据集类型: {DATASET_TYPE}")

# A dataset without classes cannot be trained on; stop early.
if NUM_CLASSES == 0:
    raise ValueError("未能加载类别信息,类别数量为0")

# Report the final category setup on stdout and via the root logger.
for _category_message in (
    f"最终类别数量 (Num Classes): {NUM_CLASSES}",
    f"ID to Label mapping: {id2label}",
    f"Class Names List: {class_names}",
):
    print(_category_message)
    logging.info(_category_message)

# --- Training parameters ---
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_device_message = f"使用的设备: {DEVICE}"
print(_device_message)
logging.info(_device_message)

# Create the output tree: <converted>/{images,labels}/{train,val,test}.
CONVERTED_DATA_DIR.mkdir(parents=True, exist_ok=True)
for _content_kind in ('images', 'labels'):
    for _split_name in ('train', 'val', 'test'):
        (CONVERTED_DATA_DIR / _content_kind / _split_name).mkdir(parents=True, exist_ok=True)
def convert_to_yolo_format(image_path, annotations, img_width, img_height, label_map, output_label_path):
    """Write the annotations of one image as a YOLO-format label file.

    Each annotation becomes one line ``<class> <cx> <cy> <w> <h>`` with the
    box centre and size normalised by the image dimensions.

    Box corners are clipped to the image rectangle *before* the centre and
    size are computed. Clamping centre and size independently (as was done
    previously) leaves boxes that stick out of the image with a wrong centre,
    and turns boxes lying completely outside the image into phantom boxes.

    Args:
        image_path: Path of the source image (not used by the conversion).
        annotations: Iterable of dicts with ``'bbox'`` as
            ``[xmin, ymin, xmax, ymax]`` in pixels and ``'category_id'`` as the
            class index to write.
        img_width: Image width in pixels.
        img_height: Image height in pixels.
        label_map: Name-to-index mapping (not used by the conversion).
        output_label_path: Destination ``.txt`` file. It is only written when
            at least one valid box remains; otherwise no file is created.
    """
    yolo_lines = []
    for ann in annotations:
        label_id = ann['category_id']
        x_min, y_min, x_max, y_max = ann['bbox']

        # Clip the corners to the image so the resulting box never extends
        # past the [0, 1] range after normalisation.
        x_min = min(max(x_min, 0.0), img_width)
        x_max = min(max(x_max, 0.0), img_width)
        y_min = min(max(y_min, 0.0), img_height)
        y_max = min(max(y_max, 0.0), img_height)

        # Skip boxes that are degenerate or entirely outside the image.
        if x_max <= x_min or y_max <= y_min:
            continue

        center_x = ((x_min + x_max) / 2) / img_width
        center_y = ((y_min + y_max) / 2) / img_height
        width = (x_max - x_min) / img_width
        height = (y_max - y_min) / img_height
        yolo_lines.append(f"{label_id} {center_x:.6f} {center_y:.6f} {width:.6f} {height:.6f}")

    # Images without any valid object get no label file at all.
    if yolo_lines:
        with open(output_label_path, 'w') as f:
            f.write("\n".join(yolo_lines))
def _subset_count(total, subset_ratio):
    """Return how many of ``total`` items ``subset_ratio`` selects, clamped to [0, total]."""
    return max(0, min(total, int(total * subset_ratio)))


def _collect_coco_items(split, use_subset, subset_ratio):
    """Gather ``(img_path, annotations, width, height)`` tuples for a COCO split.

    Returns ``None`` (after printing a warning) when the split's annotation
    file does not exist.
    """
    if split == 'train':
        img_dir, ann_file = COCO_TRAIN_IMG_DIR, COCO_TRAIN_ANN_FILE
    elif split == 'val':
        img_dir, ann_file = COCO_VAL_IMG_DIR, COCO_VAL_ANN_FILE
    else:
        img_dir, ann_file = COCO_TEST_IMG_DIR, COCO_TEST_ANN_FILE
    if not ann_file.exists():
        print(f"Warning: Annotation file {ann_file} not found for split {split}. Skipping.")
        return None

    # Imported here so pycocotools is only required for COCO input.
    from pycocotools.coco import COCO

    coco = COCO(str(ann_file))
    image_ids = sorted(coco.imgs.keys())
    if use_subset:
        num_subset = _subset_count(len(image_ids), subset_ratio)
        image_ids = random.sample(image_ids, num_subset)
        print(f"Using a subset of {num_subset} COCO {split} images.")

    # Map the COCO category ids onto the 0-based ids used for training.
    coco_cat_id_to_label_id = {
        coco_cat_id: label2id[cat['name']]
        for coco_cat_id, cat in coco.cats.items() if cat['name'] in label2id
    }

    items = []
    for img_id in image_ids:
        img_info = coco.loadImgs(img_id)[0]
        target_annotations = []
        for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):
            if ann.get('iscrowd', 0) or ann.get('ignore', 0):
                continue
            if ann['category_id'] not in coco_cat_id_to_label_id:
                continue
            x, y, w, h = ann['bbox']
            if w > 0 and h > 0:
                # COCO stores [x, y, w, h]; the converter expects corners.
                target_annotations.append({
                    "bbox": [x, y, x + w, y + h],
                    "category_id": coco_cat_id_to_label_id[ann['category_id']],
                })
        items.append((
            img_dir / img_info['file_name'],
            target_annotations,
            img_info['width'],
            img_info['height'],
        ))
    return items


def _collect_pascal_items(split, use_subset, subset_ratio):
    """Gather ``(img_path, annotations, width, height)`` tuples for a Pascal VOC split.

    Only ``.jpg``/``.png`` images with a same-named ``.xml`` file next to them
    are used. Returns ``None`` (after printing a warning) when the split
    directory does not exist. Files that fail to load or parse are reported
    and skipped.
    """
    # The validation split is stored in a folder named 'valid'.
    data_dir = PASCAL_BASE_PATH / ('valid' if split == 'val' else split)
    if not data_dir.exists():
        print(f"Warning: Pascal directory {data_dir} not found for split {split}. Skipping.")
        return None

    candidates = list(data_dir.glob('*.jpg')) + list(data_dir.glob('*.png'))
    image_files = [
        img_file for img_file in candidates
        if (data_dir / f"{img_file.stem}.xml").exists()
    ]
    if use_subset:
        num_subset = _subset_count(len(image_files), subset_ratio)
        image_files = random.sample(image_files, num_subset)
        print(f"Using a subset of {num_subset} Pascal {split} images.")

    items = []
    for img_path in image_files:
        xml_path = data_dir / f"{img_path.stem}.xml"
        try:
            # Context manager closes the file handle once the size is read.
            with Image.open(img_path) as img:
                img_width, img_height = img.size
            root = ET.parse(xml_path).getroot()
            target_annotations = []
            for obj in root.findall('object'):
                label_name = obj.find('name').text
                if label_name not in label2id:
                    continue
                bbox = obj.find('bndbox')
                xmin = float(bbox.find('xmin').text)
                ymin = float(bbox.find('ymin').text)
                xmax = float(bbox.find('xmax').text)
                ymax = float(bbox.find('ymax').text)
                if xmax > xmin and ymax > ymin:
                    target_annotations.append({
                        "bbox": [xmin, ymin, xmax, ymax],
                        "category_id": label2id[label_name],
                    })
            items.append((img_path, target_annotations, img_width, img_height))
        except Exception as e:
            # Best effort: one bad image/XML pair must not abort the split.
            print(f"Error processing Pascal file {img_path} or {xml_path}: {e}")
    return items


def _copy_yolo_split(split, output_img_dir, output_label_dir, use_subset, subset_ratio):
    """Copy image/label pairs of an already YOLO-formatted split unchanged."""
    # The validation split is stored in a folder named 'valid', not 'val'.
    split_dir = YOLO_BASE_PATH / ('valid' if split == 'val' else split)
    img_dir = split_dir / 'images'
    label_dir = split_dir / 'labels'
    if not img_dir.exists() or not label_dir.exists():
        print(f"Warning: YOLO directory {img_dir} or {label_dir} not found for split {split}. Skipping.")
        return

    img_files = sorted(list(img_dir.glob('*.jpg')) + list(img_dir.glob('*.png')))
    # Keep only images that have a matching label file.
    valid_files = [
        (img_path, label_dir / f"{img_path.stem}.txt")
        for img_path in img_files
        if (label_dir / f"{img_path.stem}.txt").exists()
    ]
    if use_subset:
        num_subset = _subset_count(len(valid_files), subset_ratio)
        valid_files = random.sample(valid_files, num_subset)
        print(f"Using a subset of {num_subset} YOLO {split} images.")

    print(f"Copying {len(valid_files)} YOLO {split} files...")
    for img_path, label_path in valid_files:
        shutil.copy(img_path, output_img_dir / img_path.name)
        shutil.copy(label_path, output_label_dir / label_path.name)


def process_dataset(split, use_subset=False, subset_ratio=1.0):
    """Convert one dataset split into the YOLOv5 directory layout.

    Reads the split of the dataset selected by ``DATASET_TYPE`` and writes
    images to ``CONVERTED_DATA_DIR/images/<split>`` and label files to
    ``CONVERTED_DATA_DIR/labels/<split>``. COCO and Pascal VOC annotations are
    converted; YOLO data is copied as-is. A missing source split is reported
    with a warning and skipped.

    Args:
        split: ``'train'``, ``'val'`` or ``'test'``.
        use_subset: When true, only a random sample of the split is used.
        subset_ratio: Fraction of the split to sample when ``use_subset`` is
            set; the resulting count is clamped to the number of available
            items.

    Raises:
        ValueError: If ``DATASET_TYPE`` is not one of the supported types.
    """
    print(f"\nProcessing {split} split...")
    output_img_dir = CONVERTED_DATA_DIR / 'images' / split
    output_label_dir = CONVERTED_DATA_DIR / 'labels' / split

    if DATASET_TYPE == 'coco':
        items_to_process = _collect_coco_items(split, use_subset, subset_ratio)
    elif DATASET_TYPE == 'pascal':
        items_to_process = _collect_pascal_items(split, use_subset, subset_ratio)
    elif DATASET_TYPE == 'yolo':
        # Already in YOLO format: a plain copy replaces the conversion loop.
        _copy_yolo_split(split, output_img_dir, output_label_dir, use_subset, subset_ratio)
        return
    else:
        raise ValueError(f"Unsupported dataset type for conversion: {DATASET_TYPE}")

    if items_to_process is None:
        # The source split was missing; a warning has already been printed.
        return

    # --- Conversion loop ---
    print(f"Converting {len(items_to_process)} items for {split} split...")
    processed_count = 0
    for img_path, annotations, width, height in items_to_process:
        if not img_path.exists():
            print(f"Warning: Image file not found: {img_path}. Skipping.")
            continue
        output_label_path = output_label_dir / f"{img_path.stem}.txt"
        # 1. Copy the image, 2. write its converted label file.
        shutil.copy(img_path, output_img_dir / img_path.name)
        convert_to_yolo_format(img_path, annotations, width, height, label2id, output_label_path)
        processed_count += 1
    print(f"Finished converting {processed_count} items for {split} split.")
# --- Run conversion for the train/val/test splits ---
# A ratio below 1.0 switches on random sub-sampling for train and val;
# the test split is always converted in full.
use_subset_flag = SUBSET_RATIO < 1.0
process_dataset('train', use_subset=use_subset_flag, subset_ratio=SUBSET_RATIO)
process_dataset('val', use_subset=use_subset_flag, subset_ratio=SUBSET_RATIO)
process_dataset('test', use_subset=False)

# --- 3. Create the data.yaml file consumed by YOLOv5 ---
data_yaml_path = CONVERTED_DATA_DIR / 'data.yaml'
test_images_dir = CONVERTED_DATA_DIR / 'images' / 'test'
# iterdir() returns a generator, which is always truthy; any() actually checks
# whether the directory contains at least one entry.
has_test_images = any(test_images_dir.iterdir())
data_yaml_content = {
    # Absolute paths, so the file works regardless of the training process's
    # working directory.
    'train': str((CONVERTED_DATA_DIR / 'images' / 'train').resolve()),
    'val': str((CONVERTED_DATA_DIR / 'images' / 'val').resolve()),
    # The test entry is left empty when no test images were converted.
    'test': str(test_images_dir.resolve()) if has_test_images else '',
    'nc': NUM_CLASSES,
    'names': class_names
}
with open(data_yaml_path, 'w') as f:
    yaml.dump(data_yaml_content, f, default_flow_style=None, sort_keys=False)
print(f"\nCreated data configuration file: {data_yaml_path}")
print("Dataset conversion to YOLOv5 format complete.")
# --- 4. Execute YOLOv5 training ---
print("\nStarting YOLOv5 training...")

# The YOLOv5 repository must be available locally; abort otherwise.
if not (YOLOV5_REPO_PATH / 'train.py').exists():
    print(f"Error: YOLOv5 train.py not found in {YOLOV5_REPO_PATH}")
    print("Please clone the YOLOv5 repository: git clone https://github.com/ultralytics/yolov5.git")
    print("And ensure --yolov5_repo_path points to it.")
    # sys.exit is used instead of the site-provided exit() helper, which is
    # meant for interactive sessions.
    sys.exit(1)

# Build the training command.
# --project is the main output folder and --name the run's subfolder.
# LEARNING_RATE is not forwarded here, so train.py uses its own
# hyperparameter settings for the learning rate.
train_cmd = [
    sys.executable,  # Run train.py with the current Python interpreter
    str(YOLOV5_REPO_PATH / 'train.py'),
    '--img', str(IMG_SIZE),
    '--batch', str(BATCH_SIZE),
    '--epochs', str(NUM_EPOCHS),
    '--data', str(data_yaml_path.resolve()),  # Absolute path to data.yaml
    # A bare model name gets the '.pt' suffix; a path ending in '.pt' is used as given.
    '--weights', f"{YOLOV5_MODEL}.pt" if not Path(YOLOV5_MODEL).suffix == '.pt' else YOLOV5_MODEL,
    '--project', str(OUTPUT_DIR.resolve()),
    '--name', RUN_NAME,
    '--workers', str(WORKERS),
    '--exist-ok'  # Allow reusing an existing project/name directory
]

# Pass the current CUDA device index when running on a GPU; on CPU no
# --device argument is added to the training command.
if DEVICE.type == 'cuda':
    gpu_index = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    train_cmd.extend(['--device', str(gpu_index)])

print("\nExecuting command:")
print(" ".join(train_cmd))

# Run training and, if it produced best.pt, an explicit validation pass.
# Failures are reported and then propagated through a non-zero exit status so
# that calling shells and schedulers can detect them.
try:
    subprocess.run(train_cmd, check=True)
    print("\nYOLOv5 training finished successfully.")
    print(f"Results saved in: {OUTPUT_DIR / RUN_NAME}")

    # --- Optional: validate the best weights explicitly ---
    best_weights_path = OUTPUT_DIR / RUN_NAME / 'weights' / 'best.pt'
    if best_weights_path.exists():
        print("\nRunning validation on best weights...")
        val_cmd = [
            sys.executable,
            str(YOLOV5_REPO_PATH / 'val.py'),
            '--img', str(IMG_SIZE),
            '--batch', str(BATCH_SIZE * 2),  # Validation uses twice the training batch size
            '--data', str(data_yaml_path.resolve()),
            '--weights', str(best_weights_path.resolve()),
            '--project', str(OUTPUT_DIR.resolve()),
            '--name', f'{RUN_NAME}_val_best',
            '--task', 'val',
            '--workers', str(WORKERS),
            # gpu_index only exists on the CUDA path; the conditional keeps it
            # from being evaluated on CPU.
            '--device', str(gpu_index) if DEVICE.type == 'cuda' else 'cpu',
            '--exist-ok'
        ]
        print("\nExecuting command:")
        print(" ".join(val_cmd))
        subprocess.run(val_cmd, check=True)
        print(f"Validation results saved in: {OUTPUT_DIR / f'{RUN_NAME}_val_best'}")
    else:
        print("Could not find best.pt for validation.")
except subprocess.CalledProcessError as e:
    print(f"\nYOLOv5 training/validation failed with error code {e.returncode}")
    print(e)
    sys.exit(e.returncode or 1)
except FileNotFoundError:
    print(f"Error: Could not find Python interpreter or train.py/val.py script.")
    print("Ensure Python is in your PATH and --yolov5_repo_path is correct.")
    sys.exit(1)
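# Example invocation (the option notes are listed below the command, because a
# comment placed after a line-continuation backslash would break the command):
#
#   python your_modified_script_name.py \
#     --dataset_type coco \
#     --yolov5_model yolov5s \
#     --img_size 640 \
#     --batch_size 16 \
#     --epochs 50 \
#     --output_dir ./yolov5_finetuned_coco \
#     --run_name coco_run1 \
#     --yolov5_repo_path ./yolov5 \
#     --converted_data_dir ./coco_yolov5_format \
#     --subset_ratio 0.1
#
# Notes:
#   --dataset_type      one of: coco, pascal, yolo
#   --yolov5_repo_path  must point to the cloned yolov5 folder
#   --subset_ratio      optional; 0.1 uses 10% of the data for a quick test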