# Object-detection training script: fine-tunes a torchvision Faster R-CNN on a
# COCO-, PASCAL-VOC- or YOLO-formatted dataset chosen via --dataset_type.
import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.datasets import CocoDetection, VOCDetection
import torchvision.transforms.v2 as T  # use the new v2 transforms API
from torch.utils.data import DataLoader, Subset, Dataset
import os
import json
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import utils  # helper module from the official PyTorch detection references (download or implement)
# https://github.com/pytorch/vision/tree/main/references/detection
import argparse  # command-line argument parsing
import glob  # file discovery
import xml.etree.ElementTree as ET  # PASCAL VOC XML parsing
import logging
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# --- Command-line argument parsing ---
parser = argparse.ArgumentParser(description='目标检测训练脚本')
parser.add_argument('--dataset_type', type=str, default='coco',
                    choices=['coco', 'pascal', 'yolo'],
                    help='数据集类型:coco, pascal, 或 yolo')
args = parser.parse_args()
DATASET_TYPE = args.dataset_type
print(f"选择的数据集类型: {DATASET_TYPE}")
# --- 1. Configuration and hyper-parameters ---

# tv_tensors lets the v2 transforms update bounding boxes together with the
# image (requires torchvision >= 0.16, which ships transforms.v2 + tv_tensors).
from torchvision import tv_tensors

# --- Dataset path definitions ---
COCO_BASE_PATH = '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection/coco'
COCO_TRAIN_IMG_DIR = os.path.join(COCO_BASE_PATH, 'train')
COCO_TRAIN_ANN_FILE = os.path.join(COCO_BASE_PATH, 'train/train_annotations.json')
COCO_VAL_IMG_DIR = os.path.join(COCO_BASE_PATH, 'valid')
COCO_VAL_ANN_FILE = os.path.join(COCO_BASE_PATH, 'valid/valid_annotations.json')
COCO_TEST_IMG_DIR = os.path.join(COCO_BASE_PATH, 'test')
COCO_TEST_ANN_FILE = os.path.join(COCO_BASE_PATH, 'test/test_annotations.json')

PASCAL_BASE_PATH = '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection/pascal'
PASCAL_YEAR = '2007'  # assumes VOC 2007

YOLO_BASE_PATH = '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection/yolo'


# --- COCO category definitions ---
def load_coco_categories(ann_file):
    """Read a COCO annotation JSON and return ({category_id: name}, category_count)."""
    with open(ann_file, 'r') as f:
        data = json.load(f)
    categories = {cat['id']: cat['name'] for cat in data['categories']}
    return categories, len(categories)


# --- PASCAL VOC category definitions ---
# Updated to the categories actually present in this dataset.
PASCAL_CLASSES = [
    "__background__",  # background class
    "young", "empty_pod",  # categories present in the dataset
]
# Name -> id mapping; ids start at 1 (index 0 is the background class).
PASCAL_NAME_TO_ID = {name: i for i, name in enumerate(PASCAL_CLASSES) if name != "__background__"}
PASCAL_ID_TO_NAME = {i: name for name, i in PASCAL_NAME_TO_ID.items()}
PASCAL_NUM_CLASSES = len(PASCAL_CLASSES) - 1  # number of real object categories


# --- YOLO category definitions ---
def load_yolo_classes(classes_file=None):
    """Return ({class_id: name}, class_count) for a YOLO-format dataset.

    Resolution order: an explicit classes.txt file, then the dataset's
    data.yaml, then a hard-coded default list. Ids start at 1 because 0 is
    reserved for the Faster R-CNN background class.
    """
    # If a classes.txt file is supplied, load from it.
    if classes_file and os.path.exists(classes_file):
        with open(classes_file, 'r') as f:
            classes = [line.strip() for line in f.readlines()]
        return {i + 1: name for i, name in enumerate(classes)}, len(classes)
    # Otherwise try the category names declared in data.yaml.
    yaml_file = os.path.join(YOLO_BASE_PATH, 'data.yaml')
    if os.path.exists(yaml_file):
        try:
            import yaml
            with open(yaml_file, 'r') as f:
                data = yaml.safe_load(f)
            if 'names' in data:
                classes = data['names']
                return {i + 1: name for i, name in enumerate(classes)}, len(classes)
        except Exception as e:
            print(f"从YAML加载类别失败: {e}")
    # Fall back to the default category list.
    classes = ['Ready', 'empty_pod', 'germination', 'pod', 'young']
    return {i + 1: name for i, name in enumerate(classes)}, len(classes)


# Load category information for the selected dataset type.
if DATASET_TYPE == 'coco':
    category_id_to_name, NUM_CLASSES = load_coco_categories(COCO_TRAIN_ANN_FILE)
elif DATASET_TYPE == 'pascal':
    category_id_to_name = PASCAL_ID_TO_NAME
    NUM_CLASSES = PASCAL_NUM_CLASSES
elif DATASET_TYPE == 'yolo':
    # No classes file passed: defaults or data.yaml are used.
    category_id_to_name, NUM_CLASSES = load_yolo_classes()
else:
    raise ValueError(f"不支持的数据集类型: {DATASET_TYPE}")

print(f"类别数量: {NUM_CLASSES}")

# --- Training hyper-parameters ---
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"使用的设备: {DEVICE}")
BATCH_SIZE = 32  # adjust to your GPU memory
NUM_EPOCHS = 10  # number of training epochs
LEARNING_RATE = 0.005
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
SAVE_MODEL_PATH = f'faster_rcnn_{DATASET_TYPE}_model.pth'
# The background class counts too, so the model head outputs NUM_CLASSES + 1.
MODEL_NUM_CLASSES = NUM_CLASSES + 1

# --- 2. Datasets and data loaders ---


def get_transform(train):
    """Build the image/target transform pipeline.

    The returned v2 ``Compose`` is applied as ``img, target = tfm(img, target)``
    so geometric augmentations (the training-time horizontal flip) transform
    the bounding boxes together with the image. Previously the flip was applied
    to the image only, silently corrupting the training boxes.
    """
    transforms = []
    if train:
        # Training-time augmentation; more can be appended here.
        transforms.append(T.RandomHorizontalFlip(0.5))
    # COCO / VOC / YOLO all supply PIL images: convert to tensor, then float [0, 1].
    transforms.append(T.PILToTensor())
    transforms.append(T.ConvertImageDtype(torch.float32))
    return T.Compose(transforms)


def _apply_transforms(transforms, img, target):
    """Run a v2 transform pipeline on (img, target), keeping boxes in sync.

    Boxes are temporarily wrapped as ``tv_tensors.BoundingBoxes`` so the v2
    transforms recognise and transform them, then unwrapped back to a plain
    tensor for the Faster R-CNN training loop. No-op when ``transforms`` is None.
    """
    if transforms is None:
        return img, target
    target["boxes"] = tv_tensors.BoundingBoxes(
        target["boxes"], format="XYXY", canvas_size=(img.height, img.width)
    )
    img, target = transforms(img, target)
    target["boxes"] = target["boxes"].as_subclass(torch.Tensor)
    return img, target


class CustomCocoDetection(CocoDetection):
    """CocoDetection variant that emits targets in Faster R-CNN dict format.

    NOTE(fix): the previous implementation called ``super().__getitem__``,
    which already applies ``self.transforms`` (it was passed to the
    ``CocoDetection`` constructor), and then applied the pipeline a second
    time — i.e. two independent random flips and a double dtype conversion.
    We now load the raw image/annotations and apply the transforms exactly once.
    """

    def __getitem__(self, index):
        image_id = self.ids[index]
        img = self._load_image(image_id)          # raw PIL image
        coco_targets = self._load_target(image_id)  # list of COCO annotation dicts

        # Convert COCO annotations to the dict format Faster R-CNN expects.
        target = {"image_id": torch.tensor([image_id])}
        boxes, labels, areas, iscrowd = [], [], [], []
        for anno in coco_targets:
            x, y, w, h = anno["bbox"]
            # Skip degenerate (zero width/height) boxes.
            if w > 0 and h > 0:
                boxes.append([x, y, x + w, y + h])
                labels.append(anno["category_id"])
                areas.append(anno["area"])
                iscrowd.append(anno["iscrowd"])

        if len(boxes) > 0:
            target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
            target["area"] = torch.as_tensor(areas, dtype=torch.float32)
            target["iscrowd"] = torch.as_tensor(iscrowd, dtype=torch.uint8)
        else:
            # Image without annotations: empty tensors of the right shapes.
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.uint8)

        # Apply the transforms once, to image AND boxes.
        img, target = _apply_transforms(self.transforms, img, target)
        return img, target


class CustomVOCDetection(Dataset):
    """PASCAL-VOC-style dataset where images and XML annotations share one directory."""

    def __init__(self, root, year='2007', image_set='train', transforms=None):
        self.root = root
        self.year = year
        self.image_set = image_set
        self._transforms = transforms
        # Layout used here: <root>/<image_set>/ holds both images and XML files.
        self.data_dir = os.path.join(root, image_set)
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"can't find Pascal data set directory: {self.data_dir}")
        # Collect image ids (file names without extension) that have an XML annotation.
        self.images = []
        for f in os.listdir(self.data_dir):
            if f.endswith('.jpg') or f.endswith('.png'):
                img_id = os.path.splitext(f)[0]
                xml_file = os.path.join(self.data_dir, f"{img_id}.xml")
                if os.path.exists(xml_file):
                    self.images.append(img_id)
        # Category name -> id mapping.
        self.name_to_id = PASCAL_NAME_TO_ID
        print(f"found {len(self.images)} Pascal VOC {image_set} images")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_id = self.images[index]
        # Load the image; fall back to PNG when the JPG does not exist.
        img_path = os.path.join(self.data_dir, f"{img_id}.jpg")
        if not os.path.exists(img_path):
            img_path = os.path.join(self.data_dir, f"{img_id}.png")
        img = Image.open(img_path).convert("RGB")
        img_width, img_height = img.size

        # Parse the annotation XML.
        xml_path = os.path.join(self.data_dir, f"{img_id}.xml")
        tree = ET.parse(xml_path)
        root = tree.getroot()

        target = {"image_id": torch.tensor([index])}
        boxes, labels, areas, iscrowd = [], [], [], []
        for obj in root.findall('object'):
            label_name = obj.find('name').text
            # Skip categories that are not in our mapping.
            if label_name not in self.name_to_id:
                continue
            label_id = self.name_to_id[label_name]
            bbox = obj.find('bndbox')
            xmin = float(bbox.find('xmin').text)
            ymin = float(bbox.find('ymin').text)
            xmax = float(bbox.find('xmax').text)
            ymax = float(bbox.find('ymax').text)
            # Keep only valid (positive-area) boxes.
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(label_id)
                areas.append((xmax - xmin) * (ymax - ymin))
                iscrowd.append(0)  # VOC has no crowd annotations

        if len(boxes) > 0:
            target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
            target["area"] = torch.as_tensor(areas, dtype=torch.float32)
            target["iscrowd"] = torch.as_tensor(iscrowd, dtype=torch.uint8)
        else:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.uint8)

        # Apply the transforms to image AND boxes.
        img, target = _apply_transforms(self._transforms, img, target)
        return img, target


class CustomYOLODataset(Dataset):
    """YOLO-format dataset: per-image .txt labels with normalised cx cy w h."""

    def __init__(self, img_dir, label_dir, classes_file=None, transforms=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transforms = transforms
        # Gather all candidate images.
        self.img_files = sorted(glob.glob(os.path.join(img_dir, '*.jpg')) +
                                glob.glob(os.path.join(img_dir, '*.png')))
        # Keep only images that have a matching label file.
        self.valid_files = []
        for img_path in self.img_files:
            # NOTE(fix): use splitext, not split('.')[0], so file names that
            # contain dots (e.g. "img.aug.1.jpg") map to the right label file.
            base_name = os.path.splitext(os.path.basename(img_path))[0]
            label_path = os.path.join(label_dir, f"{base_name}.txt")
            if os.path.exists(label_path):
                self.valid_files.append((img_path, label_path))
        print(f"found {len(self.valid_files)} YOLO format image-label pairs")

    def __len__(self):
        return len(self.valid_files)

    def __getitem__(self, idx):
        img_path, label_path = self.valid_files[idx]
        img = Image.open(img_path).convert("RGB")
        img_width, img_height = img.size

        target = {"image_id": torch.tensor([idx])}
        boxes, labels, areas, iscrowd = [], [], [], []
        # Each label line: "class cx cy w h", all coordinates normalised to [0, 1].
        with open(label_path, 'r') as f:
            for line in f.readlines():
                if not line.strip():
                    continue
                parts = line.strip().split()
                if len(parts) != 5:
                    continue
                cls_id = int(parts[0]) + 1  # YOLO ids start at 0; ours at 1 (0 = background)
                x_center = float(parts[1]) * img_width
                y_center = float(parts[2]) * img_height
                width = float(parts[3]) * img_width
                height = float(parts[4]) * img_height
                # Convert centre/size to corner coordinates, clamped to the image.
                xmin = max(0, x_center - width / 2)
                ymin = max(0, y_center - height / 2)
                xmax = min(img_width, x_center + width / 2)
                ymax = min(img_height, y_center + height / 2)
                if xmax > xmin and ymax > ymin:
                    boxes.append([xmin, ymin, xmax, ymax])
                    labels.append(cls_id)
                    areas.append((xmax - xmin) * (ymax - ymin))
                    iscrowd.append(0)

        if len(boxes) > 0:
            target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
            target["area"] = torch.as_tensor(areas, dtype=torch.float32)
            target["iscrowd"] = torch.as_tensor(iscrowd, dtype=torch.uint8)
        else:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.uint8)

        # Apply the transforms to image AND boxes.
        img, target = _apply_transforms(self.transforms, img, target)
        return img, target


# Build datasets and data loaders for the selected format.
try:
    if DATASET_TYPE == 'coco':
        dataset_train = CustomCocoDetection(root=COCO_TRAIN_IMG_DIR,
                                            annFile=COCO_TRAIN_ANN_FILE,
                                            transforms=get_transform(train=True))
        dataset_val = CustomCocoDetection(root=COCO_VAL_IMG_DIR,
                                          annFile=COCO_VAL_ANN_FILE,
                                          transforms=get_transform(train=False))
        dataset_test = CustomCocoDetection(root=COCO_TEST_IMG_DIR,
                                           annFile=COCO_TEST_ANN_FILE,
                                           transforms=get_transform(train=False))
    elif DATASET_TYPE == 'pascal':
        dataset_train = CustomVOCDetection(root=PASCAL_BASE_PATH, year=PASCAL_YEAR,
                                           image_set='train', transforms=get_transform(train=True))
        dataset_val = CustomVOCDetection(root=PASCAL_BASE_PATH, year=PASCAL_YEAR,
                                         image_set='valid', transforms=get_transform(train=False))
        dataset_test = CustomVOCDetection(root=PASCAL_BASE_PATH, year=PASCAL_YEAR,
                                          image_set='test', transforms=get_transform(train=False))
        # Debug info: print label details for a few samples.
        print("\nDebug information - Pascal VOC dataset:")
        print(f"Category mapping: {PASCAL_NAME_TO_ID}")
        for i in range(min(3, len(dataset_train))):
            _, target = dataset_train[i]
            print(f"Sample {i}:")
            print(f" Number of bounding boxes: {len(target['boxes'])}")
            print(f" Labels: {target['labels'].tolist()}")
            if len(target['labels']) > 0:
                print(f" Category names: {[PASCAL_ID_TO_NAME[label.item()] for label in target['labels']]}")
                print(f" Bounding boxes: {target['boxes'].shape}")
    elif DATASET_TYPE == 'yolo':
        # YOLO directory layout: <base>/<split>/{images,labels}.
        YOLO_TRAIN_IMG_DIR = os.path.join(YOLO_BASE_PATH, 'train', 'images')
        YOLO_TRAIN_LABEL_DIR = os.path.join(YOLO_BASE_PATH, 'train', 'labels')
        YOLO_VAL_IMG_DIR = os.path.join(YOLO_BASE_PATH, 'valid', 'images')
        YOLO_VAL_LABEL_DIR = os.path.join(YOLO_BASE_PATH, 'valid', 'labels')
        YOLO_TEST_IMG_DIR = os.path.join(YOLO_BASE_PATH, 'test', 'images')
        YOLO_TEST_LABEL_DIR = os.path.join(YOLO_BASE_PATH, 'test', 'labels')
        dataset_train = CustomYOLODataset(
            img_dir=YOLO_TRAIN_IMG_DIR,
            label_dir=YOLO_TRAIN_LABEL_DIR,
            classes_file=None,  # classes come from load_yolo_classes (data.yaml or defaults)
            transforms=get_transform(train=True)
        )
        dataset_val = CustomYOLODataset(
            img_dir=YOLO_VAL_IMG_DIR,
            label_dir=YOLO_VAL_LABEL_DIR,
            classes_file=None,
            transforms=get_transform(train=False)
        )
        dataset_test = CustomYOLODataset(
            img_dir=YOLO_TEST_IMG_DIR,
            label_dir=YOLO_TEST_LABEL_DIR,
            classes_file=None,
            transforms=get_transform(train=False)
        )
        # Debug info: print label details for a few samples.
        print("\nDebug information - YOLO dataset:")
        print(f"Category mapping: {category_id_to_name}")
        for i in range(min(3, len(dataset_train))):
            _, target = dataset_train[i]
            print(f"Sample {i}:")
            print(f" Number of bounding boxes: {len(target['boxes'])}")
            print(f" Labels: {target['labels'].tolist()}")
            if len(target['labels']) > 0:
                print(f" Category names: {[category_id_to_name.get(label.item(), f'unknown-{label.item()}') for label in target['labels']]}")
                print(f" Bounding boxes: {target['boxes'].shape}")
    else:
        raise ValueError(f"Unsupported dataset type: {DATASET_TYPE}")

    print(f"Training set size: {len(dataset_train)}")
    print(f"Validation set size: {len(dataset_val)}")
    print(f"Test set size: {len(dataset_test)}")

    # # Optional: take a subset of a large dataset for debugging.
    # dataset_train = Subset(dataset_train, range(100))
    # dataset_val = Subset(dataset_val, range(50))
    # dataset_test = Subset(dataset_test, range(50))

    def collate_fn(batch):
        """Batch images/targets as tuples; needed because target counts vary per image."""
        return tuple(zip(*batch))

    data_loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True,
                                   num_workers=4, collate_fn=collate_fn)
    data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False,  # validation usually batch_size=1
                                 num_workers=4, collate_fn=collate_fn)
    data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False,
                                  num_workers=4, collate_fn=collate_fn)
except FileNotFoundError as e:
    print(f"Error: can't find training or validation files: {e}")
    print(f"Please ensure the file paths for the selected dataset {DATASET_TYPE} are correct and files exist.")
    exit()
except NotImplementedError as e:
    print(f"Error: {e}")
    print("Please implement the corresponding dataset loading class and try again.")
    exit()
except Exception as e:
    # Keep the full traceback: the bare message alone hid the failure site.
    import traceback
    traceback.print_exc()
    print(f"Error: {e}")
    exit()
# --- 3. Model definition (Faster R-CNN) ---

def get_faster_rcnn_model(num_classes):
    """Build a COCO-pretrained Faster R-CNN (ResNet-50 + FPN) whose box head
    predicts ``num_classes`` classes (background included).

    A MobileNetV3-Large-FPN backbone (faster, slightly less accurate) can be
    swapped in via ``fasterrcnn_mobilenet_v3_large_fpn`` with the same
    head-replacement steps.
    """
    pretrained = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=pretrained)
    # Swap the pretrained box predictor for one sized to our category count.
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
    return detector


# Instantiate the model and move it to the training device.
model = get_faster_rcnn_model(MODEL_NUM_CLASSES)
model.to(DEVICE)

# Report model configuration.
print(f"\nModel information:")
print(f"Number of input classes: {MODEL_NUM_CLASSES} (including background class)")
print(f"Used device: {DEVICE}")
print(f"Predictor input features: {model.roi_heads.box_predictor.cls_score.in_features}")
print(f"Predictor output classes: {model.roi_heads.box_predictor.cls_score.out_features}")

# --- 4. Training setup ---
# SGD over the trainable parameters only.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params,
                            lr=LEARNING_RATE,
                            momentum=MOMENTUM,
                            weight_decay=WEIGHT_DECAY)
# Step-decay schedule (recommended, though optional).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
# --- 5. Training loop ---
print("Starting training...")

# Logging: persist to a per-dataset file and echo to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'training_{DATASET_TYPE}_Faster_R-CNN.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('faster_rcnn')
logger.info(f"Selected dataset type: {DATASET_TYPE}")
logger.info(f"Training set size: {len(dataset_train)}")


# --- mAP computation ---
def evaluate_map(model, data_loader, device, dataset_type, category_mapping=None):
    """Run the model over ``data_loader`` and return COCO-style mAP@[.5:.95].

    Predictions and ground truth are converted into in-memory COCO datasets
    and scored with pycocotools' COCOeval. ``category_mapping`` maps category
    id -> name (same shape for every ``dataset_type``). Leaves the model in
    eval mode.

    Fixes vs. the previous version: detection annotations now carry the
    'area' field, which COCOeval reads for its area-range filtering (it was
    missing and crashed evaluation), and the log-message typo "Caluating"
    is corrected.
    """
    model.eval()
    all_predictions = []
    all_targets = []
    print("Calculating mAP...")
    with torch.no_grad():
        for images, targets in data_loader:
            images = list(image.to(device) for image in images)
            all_predictions.extend(model(images))
            all_targets.extend(targets)

    # Convert predictions and ground truth to COCO annotation dicts.
    pred_instances = []
    gt_instances = []
    image_id = 0
    instance_id = 0
    for pred, target in zip(all_predictions, all_targets):
        boxes = pred['boxes'].cpu().numpy()
        scores = pred['scores'].cpu().numpy()
        labels = pred['labels'].cpu().numpy()
        # Keep only reasonably confident detections.
        keep = scores > 0.05
        boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
        for box, score, label in zip(boxes, scores, labels):
            w = float(box[2] - box[0])
            h = float(box[3] - box[1])
            pred_instances.append({
                'image_id': image_id,
                'category_id': int(label),
                'bbox': [float(box[0]), float(box[1]), w, h],  # COCO format [x, y, w, h]
                'area': w * h,  # COCOeval requires 'area' on detections too
                'score': float(score),
                'id': instance_id
            })
            instance_id += 1

        gt_boxes = target['boxes'].cpu().numpy()
        gt_labels = target['labels'].cpu().numpy()
        for gt_box, gt_label in zip(gt_boxes, gt_labels):
            gt_instances.append({
                'image_id': image_id,
                'category_id': int(gt_label),
                'bbox': [float(gt_box[0]), float(gt_box[1]),
                         float(gt_box[2] - gt_box[0]), float(gt_box[3] - gt_box[1])],  # COCO format
                'area': float((gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])),
                'iscrowd': 0,
                'id': instance_id
            })
            instance_id += 1
        image_id += 1

    # Category list: identical id -> name construction for coco/pascal/yolo,
    # so the three duplicated branches collapse into one.
    categories = [{'id': id, 'name': name} for id, name in category_mapping.items()]

    # Assemble the in-memory COCO ground-truth and detection datasets.
    coco_gt = COCO()
    coco_gt.dataset = {
        'images': [{'id': i} for i in range(image_id)],
        'annotations': gt_instances,
        'categories': categories
    }
    coco_gt.createIndex()

    coco_dt = COCO()
    coco_dt.dataset = {
        'images': [{'id': i} for i in range(image_id)],
        'annotations': pred_instances,
        'categories': categories
    }
    coco_dt.createIndex()

    coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    # stats[0] is AP@[0.5:0.95] (the headline mAP).
    return coco_eval.stats[0]


for epoch in range(NUM_EPOCHS):
    model.train()  # back to train mode (evaluate_map leaves the model in eval)
    epoch_loss = 0
    for i, (images, targets) in enumerate(data_loader_train):
        images = list(image.to(DEVICE) for image in images)
        # Move every tensor field of each target to the training device.
        targets = [{k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v
                    for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        # In train mode Faster R-CNN returns a dict of losses; sum them.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Periodically print the individual loss components for debugging.
        if i < 2 or i % 50 == 0:
            print(f" 损失详情: {', '.join([f'{k}: {v.item():.6f}' for k, v in loss_dict.items()])}")

        losses.backward()
        optimizer.step()

        batch_loss = losses.item()
        epoch_loss += batch_loss
        if (i + 1) % 50 == 0:  # log every 50 batches
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(data_loader_train)}], Batch Loss: {batch_loss:.6f}')

    lr_scheduler.step()
    avg_epoch_loss = epoch_loss / len(data_loader_train)
    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] 结束, 平均训练损失: {avg_epoch_loss:.4f}')

    # --- 6. Validation: COCO mAP on the validation split ---
    # NOTE(fix): a previous no-grad pass iterated the whole validation loader,
    # moved every batch to the device and then did nothing (`pass`). That dead
    # loop cost a full data-loading epoch per epoch and was removed.
    print("Calculating mAP...")
    val_map = evaluate_map(model, data_loader_val, DEVICE, DATASET_TYPE,
                           category_id_to_name if DATASET_TYPE == 'coco' or DATASET_TYPE == 'yolo' else PASCAL_ID_TO_NAME)
    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Validation mAP: {val_map:.4f}')
    logger.info(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Validation mAP: {val_map:.4f}")

    # Save a checkpoint every epoch.
    torch.save(model.state_dict(), f'faster_rcnn_{DATASET_TYPE}_epoch_{epoch+1}.pth')
    print(f"Model saved to faster_rcnn_{DATASET_TYPE}_epoch_{epoch+1}.pth")

    # Record the per-epoch metrics.
    logger.info(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Average training loss: {avg_epoch_loss:.4f}, Validation mAP: {val_map:.4f}")

print("Training completed!")

# Save the final model.
torch.save(model.state_dict(), SAVE_MODEL_PATH)
print(f"Final model saved to {SAVE_MODEL_PATH}")
# --- 7. Testing / inference ---
print("Starting test/inference...")

# Reload the final trained weights and switch to eval mode.
model.load_state_dict(torch.load(SAVE_MODEL_PATH, map_location=DEVICE))
model.eval()

# Same id -> name mapping rule used during training.
label_mapping = category_id_to_name if DATASET_TYPE == 'coco' or DATASET_TYPE == 'yolo' else PASCAL_ID_TO_NAME

# Evaluate mAP on the training, validation and test splits in turn.
split_maps = {}
for split_name, loader in (('训练', data_loader_train),
                           ('验证', data_loader_val),
                           ('测试', data_loader_test)):
    print(f"计算{split_name}集mAP...")
    split_maps[split_name] = evaluate_map(model, loader, DEVICE, DATASET_TYPE, label_mapping)
    print(f'{split_name}集mAP: {split_maps[split_name]:.4f}')

train_map = split_maps['训练']
val_map = split_maps['验证']
test_map = split_maps['测试']

# Record the final summary.
logger.info(f"Final results - Training mAP: {train_map:.4f}, Validation mAP: {val_map:.4f}, Test mAP: {test_map:.4f}")

# --- About YOLOv3 ---
# --- End ---
print("Code execution completed.")