# 42028-CNN-A2 / detection.py
import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.datasets import CocoDetection, VOCDetection
import torchvision.transforms.v2 as T  # use the new v2 transforms API
from torch.utils.data import DataLoader, Subset, Dataset
import os
import json
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import utils  # helper utilities from the official PyTorch Faster R-CNN references (download or implement); imported but not used directly below
# https://github.com/pytorch/vision/tree/main/references/detection
import argparse  # command-line argument parsing
import glob  # file discovery
import xml.etree.ElementTree as ET  # PASCAL VOC XML parsing
import logging
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
# --- Command-line argument parsing ---
parser = argparse.ArgumentParser(description='Object detection training script')
parser.add_argument('--dataset_type', type=str, default='coco', choices=['coco', 'pascal', 'yolo'],
                    help='Dataset type: coco, pascal, or yolo')
args = parser.parse_args()
DATASET_TYPE = args.dataset_type
print(f"Selected dataset type: {DATASET_TYPE}")
# --- 1. Configuration and hyperparameters ---
# --- Dataset path definitions ---
COCO_BASE_PATH = '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection/coco'
COCO_TRAIN_IMG_DIR = os.path.join(COCO_BASE_PATH, 'train')
COCO_TRAIN_ANN_FILE = os.path.join(COCO_BASE_PATH, 'train/train_annotations.json')
COCO_VAL_IMG_DIR = os.path.join(COCO_BASE_PATH, 'valid')
COCO_VAL_ANN_FILE = os.path.join(COCO_BASE_PATH, 'valid/valid_annotations.json')
COCO_TEST_IMG_DIR = os.path.join(COCO_BASE_PATH, 'test')
COCO_TEST_ANN_FILE = os.path.join(COCO_BASE_PATH, 'test/test_annotations.json')
PASCAL_BASE_PATH = '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection/pascal'
PASCAL_YEAR = '2007'  # assume VOC 2007
YOLO_BASE_PATH = '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection/yolo'
# --- COCO category definitions ---
def load_coco_categories(ann_file):
with open(ann_file, 'r') as f:
data = json.load(f)
categories = {cat['id']: cat['name'] for cat in data['categories']}
return categories, len(categories)
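# For example, if the annotations file listed categories
# [{'id': 1, 'name': 'pod'}, {'id': 2, 'name': 'young'}] (illustrative names),
# this would return ({1: 'pod', 2: 'young'}, 2).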
# --- PASCAL VOC category definitions ---
# Updated to the classes actually present in this dataset
PASCAL_CLASSES = [
    "__background__",  # background class
    "young", "empty_pod",  # classes present in the dataset
]
# Map class names to IDs (starting from 1, by convention)
PASCAL_NAME_TO_ID = {name: i for i, name in enumerate(PASCAL_CLASSES) if name != "__background__"}
PASCAL_ID_TO_NAME = {i: name for name, i in PASCAL_NAME_TO_ID.items()}
PASCAL_NUM_CLASSES = len(PASCAL_CLASSES) - 1  # number of actual object classes
# --- YOLO category definitions ---
def load_yolo_classes(classes_file=None):
    # If a classes.txt file is provided, load the class list from it
    if classes_file and os.path.exists(classes_file):
        with open(classes_file, 'r') as f:
            classes = [line.strip() for line in f.readlines()]
        return {i + 1: name for i, name in enumerate(classes)}, len(classes)
    # Otherwise try the classes defined in data.yaml, falling back to defaults
    yaml_file = os.path.join(YOLO_BASE_PATH, 'data.yaml')
    if os.path.exists(yaml_file):
        try:
            import yaml
            with open(yaml_file, 'r') as f:
                data = yaml.safe_load(f)
            if 'names' in data:
                classes = data['names']
                return {i + 1: name for i, name in enumerate(classes)}, len(classes)
        except Exception as e:
            print(f"Failed to load classes from YAML: {e}")
    # If everything else fails, use the default class list
    classes = ['Ready', 'empty_pod', 'germination', 'pod', 'young']
    return {i + 1: name for i, name in enumerate(classes)}, len(classes)
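# For example, when neither classes_file nor data.yaml is available, the
# default list above yields
# ({1: 'Ready', 2: 'empty_pod', 3: 'germination', 4: 'pod', 5: 'young'}, 5).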
# Load category information for the selected dataset type
if DATASET_TYPE == 'coco':
    category_id_to_name, NUM_CLASSES = load_coco_categories(COCO_TRAIN_ANN_FILE)
elif DATASET_TYPE == 'pascal':
    category_id_to_name = PASCAL_ID_TO_NAME
    NUM_CLASSES = PASCAL_NUM_CLASSES
elif DATASET_TYPE == 'yolo':
    category_id_to_name, NUM_CLASSES = load_yolo_classes()  # no classes file: use data.yaml or defaults
else:
    raise ValueError(f"Unsupported dataset type: {DATASET_TYPE}")
print(f"Number of classes: {NUM_CLASSES}")
# --- Training parameters ---
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")
BATCH_SIZE = 32  # adjust to fit your GPU memory
NUM_EPOCHS = 10  # number of training epochs
LEARNING_RATE = 0.005
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
SAVE_MODEL_PATH = f'faster_rcnn_{DATASET_TYPE}_model.pth'
# The background class counts too, so the model outputs NUM_CLASSES + 1
MODEL_NUM_CLASSES = NUM_CLASSES + 1
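# For example, with the two Pascal classes defined above (young, empty_pod),
# NUM_CLASSES == 2 and MODEL_NUM_CLASSES == 3.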
# --- 2. Datasets and data loaders ---
# get_transform is shared by all dataset types
def get_transform(train):
    transforms = []
    if train:
        # Training-time data augmentation
        transforms.append(T.RandomHorizontalFlip(0.5))
        # More augmentations could be added here...
        # NOTE: in this pipeline only the image is transformed; the bounding
        # boxes are plain tensors, so geometric augmentations like this flip
        # do NOT update them. To flip safely, wrap the boxes in
        # tv_tensors.BoundingBoxes and transform (img, target) jointly.
    # COCO and VOC inputs arrive as PIL Images:
    # convert to tensor and normalize the dtype
    transforms.append(T.PILToTensor())
    transforms.append(T.ConvertImageDtype(torch.float32))  # scales to [0, 1]
    # To force 3 channels, something like this could be inserted:
    # transforms.append(lambda x: x if x.shape[0] == 3 else torch.cat([x, x, x], 0) if x.shape[0] == 1 else x[:3])
    return T.Compose(transforms)
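# A minimal sanity check of the pipeline (commented out; 'sample.jpg' is an
# illustrative path, not a file shipped with this repo):
# sample = Image.open('sample.jpg').convert('RGB')
# out = get_transform(train=False)(sample)
# assert out.dtype == torch.float32 and out.ndim == 3  # C x H x W, values in [0, 1]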
# Custom COCO dataset class
class CustomCocoDetection(CocoDetection):
    def __getitem__(self, index):
        # CocoDetection.__getitem__ already applies self.transforms: the v2
        # Compose from get_transform converts the PIL image and passes the raw
        # COCO annotation list through unchanged, so it must not be applied a
        # second time below.
        img, coco_targets = super(CustomCocoDetection, self).__getitem__(index)
        # Convert the COCO annotation format into what Faster R-CNN expects
        image_id = self.ids[index]
        target = {}
        target["boxes"] = []
        target["labels"] = []
        target["image_id"] = torch.tensor([image_id])
        target["area"] = []  # Faster R-CNN evaluation may use area
        target["iscrowd"] = []  # and iscrowd
        for anno in coco_targets:
            x, y, w, h = anno["bbox"]
            xmin = x
            ymin = y
            xmax = x + w
            ymax = y + h
            # Keep only boxes with positive width and height
            if w > 0 and h > 0:
                target["boxes"].append([xmin, ymin, xmax, ymax])
                target["labels"].append(anno["category_id"])
                target["area"].append(anno["area"])
                target["iscrowd"].append(anno["iscrowd"])
        if len(target["boxes"]) > 0:
            target["boxes"] = torch.as_tensor(target["boxes"], dtype=torch.float32)
            target["labels"] = torch.as_tensor(target["labels"], dtype=torch.int64)
            target["area"] = torch.as_tensor(target["area"], dtype=torch.float32)
            target["iscrowd"] = torch.as_tensor(target["iscrowd"], dtype=torch.uint8)
        else:
            # Handle images with no annotations
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.uint8)
        return img, target
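# Shape of one sample from CustomCocoDetection, e.g. for an image with two
# annotations:
#   img: float32 tensor of shape [3, H, W], values in [0, 1]
#   target: {'boxes': [2, 4] float32 (xmin, ymin, xmax, ymax),
#            'labels': [2] int64, 'image_id': [1] int64,
#            'area': [2] float32, 'iscrowd': [2] uint8}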
# Custom PASCAL VOC dataset class
class CustomVOCDetection(Dataset):
    def __init__(self, root, year='2007', image_set='train', transforms=None):
        self.root = root
        self.year = year
        self.image_set = image_set
        self._transforms = transforms
        # This directory layout keeps images and XML annotations side by side
        self.data_dir = os.path.join(root, image_set)
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Cannot find Pascal VOC dataset directory: {self.data_dir}")
        # Collect all image file names (without extension)
        self.images = []
        for f in os.listdir(self.data_dir):
            if f.endswith('.jpg') or f.endswith('.png'):
                img_id = os.path.splitext(f)[0]
                xml_file = os.path.join(self.data_dir, f"{img_id}.xml")
                # Keep only images that have a matching XML annotation
                if os.path.exists(xml_file):
                    self.images.append(img_id)
        # Class name -> ID mapping
        self.name_to_id = PASCAL_NAME_TO_ID
        print(f"Found {len(self.images)} Pascal VOC {image_set} images")
def __len__(self):
return len(self.images)
    def __getitem__(self, index):
        img_id = self.images[index]
        # Load the image
        img_path = os.path.join(self.data_dir, f"{img_id}.jpg")
        if not os.path.exists(img_path):
            img_path = os.path.join(self.data_dir, f"{img_id}.png")  # fall back to PNG
        img = Image.open(img_path).convert("RGB")
        # Parse the annotation XML
        xml_path = os.path.join(self.data_dir, f"{img_id}.xml")
        tree = ET.parse(xml_path)
        root = tree.getroot()
        target = {}
        target["boxes"] = []
        target["labels"] = []
        target["image_id"] = torch.tensor([index])
        target["area"] = []
        target["iscrowd"] = []
        for obj in root.findall('object'):
            label_name = obj.find('name').text
            if label_name not in self.name_to_id:
                continue
            label_id = self.name_to_id[label_name]
            bbox = obj.find('bndbox')
            xmin = float(bbox.find('xmin').text)
            ymin = float(bbox.find('ymin').text)
            xmax = float(bbox.find('xmax').text)
            ymax = float(bbox.find('ymax').text)
            # Keep only valid boxes
            if xmax > xmin and ymax > ymin:
                target["boxes"].append([xmin, ymin, xmax, ymax])
                target["labels"].append(label_id)
                area = (xmax - xmin) * (ymax - ymin)
                target["area"].append(area)
                target["iscrowd"].append(0)  # VOC has no crowd annotations
        if len(target["boxes"]) > 0:
            target["boxes"] = torch.as_tensor(target["boxes"], dtype=torch.float32)
            target["labels"] = torch.as_tensor(target["labels"], dtype=torch.int64)
            target["area"] = torch.as_tensor(target["area"], dtype=torch.float32)
            target["iscrowd"] = torch.as_tensor(target["iscrowd"], dtype=torch.uint8)
        else:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.uint8)
        # Apply image transforms (image only)
        if self._transforms is not None:
            img = self._transforms(img)
        return img, target
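# The parser above only needs a minimal VOC-style XML, e.g.:
#   <annotation>
#     <object>
#       <name>young</name>
#       <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>220</ymax></bndbox>
#     </object>
#   </annotation>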
# Custom YOLO dataset class
class CustomYOLODataset(Dataset):
    def __init__(self, img_dir, label_dir, classes_file=None, transforms=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transforms = transforms
        # Gather all image files
        self.img_files = sorted(glob.glob(os.path.join(img_dir, '*.jpg')) +
                                glob.glob(os.path.join(img_dir, '*.png')))
        # Keep only images that have a matching label file
        self.valid_files = []
        for img_path in self.img_files:
            base_name = os.path.splitext(os.path.basename(img_path))[0]  # robust to dots in the name
            label_path = os.path.join(label_dir, f"{base_name}.txt")
            if os.path.exists(label_path):
                self.valid_files.append((img_path, label_path))
        print(f"Found {len(self.valid_files)} YOLO-format image/label pairs")
    def __len__(self):
        return len(self.valid_files)
    def __getitem__(self, idx):
        img_path, label_path = self.valid_files[idx]
        # Load the image
        img = Image.open(img_path).convert("RGB")
        img_width, img_height = img.size
        target = {}
        target["boxes"] = []
        target["labels"] = []
        target["image_id"] = torch.tensor([idx])
        target["area"] = []
        target["iscrowd"] = []
        # Parse YOLO-format labels: "class x_center y_center width height",
        # with all coordinates normalized to [0, 1]
        with open(label_path, 'r') as f:
            for line in f.readlines():
                if line.strip():
                    parts = line.strip().split()
                    if len(parts) == 5:
                        cls_id = int(parts[0]) + 1  # YOLO classes start at 0; ours start at 1
                        x_center = float(parts[1]) * img_width
                        y_center = float(parts[2]) * img_height
                        width = float(parts[3]) * img_width
                        height = float(parts[4]) * img_height
                        # Convert center/size to corner coordinates
                        xmin = max(0, x_center - width / 2)
                        ymin = max(0, y_center - height / 2)
                        xmax = min(img_width, x_center + width / 2)
                        ymax = min(img_height, y_center + height / 2)
                        # Keep only valid boxes
                        if xmax > xmin and ymax > ymin:
                            target["boxes"].append([xmin, ymin, xmax, ymax])
                            target["labels"].append(cls_id)
                            area = (xmax - xmin) * (ymax - ymin)
                            target["area"].append(area)
                            target["iscrowd"].append(0)
        if len(target["boxes"]) > 0:
            target["boxes"] = torch.as_tensor(target["boxes"], dtype=torch.float32)
            target["labels"] = torch.as_tensor(target["labels"], dtype=torch.int64)
            target["area"] = torch.as_tensor(target["area"], dtype=torch.float32)
            target["iscrowd"] = torch.as_tensor(target["iscrowd"], dtype=torch.uint8)
        else:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.uint8)
        # Apply image transforms (image only)
        if self.transforms:
            img = self.transforms(img)
        return img, target
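# Worked example of the YOLO -> corner conversion above: on a 640x480 image,
# the label line "0 0.5 0.5 0.25 0.5" gives class 1 (after the +1 shift),
# center (320, 240), size (160, 240), and box [240, 120, 400, 360].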
# Build the datasets for the selected dataset type
try:
    if DATASET_TYPE == 'coco':
        dataset_train = CustomCocoDetection(root=COCO_TRAIN_IMG_DIR,
                                            annFile=COCO_TRAIN_ANN_FILE,
                                            transforms=get_transform(train=True))
        dataset_val = CustomCocoDetection(root=COCO_VAL_IMG_DIR,
                                          annFile=COCO_VAL_ANN_FILE,
                                          transforms=get_transform(train=False))
        dataset_test = CustomCocoDetection(root=COCO_TEST_IMG_DIR,
                                           annFile=COCO_TEST_ANN_FILE,
                                           transforms=get_transform(train=False))
elif DATASET_TYPE == 'pascal':
dataset_train = CustomVOCDetection(root=PASCAL_BASE_PATH,
year=PASCAL_YEAR,
image_set='train',
transforms=get_transform(train=True))
dataset_val = CustomVOCDetection(root=PASCAL_BASE_PATH,
year=PASCAL_YEAR,
image_set='valid',
transforms=get_transform(train=False))
dataset_test = CustomVOCDetection(root=PASCAL_BASE_PATH,
year=PASCAL_YEAR,
image_set='test',
transforms=get_transform(train=False))
        # Debug: print label information for a few samples
print("\nDebug information - Pascal VOC dataset:")
print(f"Category mapping: {PASCAL_NAME_TO_ID}")
for i in range(min(3, len(dataset_train))):
_, target = dataset_train[i]
print(f"Sample {i}:")
print(f" Number of bounding boxes: {len(target['boxes'])}")
print(f" Labels: {target['labels'].tolist()}")
if len(target['labels']) > 0:
print(f" Category names: {[PASCAL_ID_TO_NAME[label.item()] for label in target['labels']]}")
print(f" Bounding boxes: {target['boxes'].shape}")
    elif DATASET_TYPE == 'yolo':
        # YOLO dataset directory structure
        YOLO_TRAIN_IMG_DIR = os.path.join(YOLO_BASE_PATH, 'train', 'images')
        YOLO_TRAIN_LABEL_DIR = os.path.join(YOLO_BASE_PATH, 'train', 'labels')
        YOLO_VAL_IMG_DIR = os.path.join(YOLO_BASE_PATH, 'valid', 'images')
        YOLO_VAL_LABEL_DIR = os.path.join(YOLO_BASE_PATH, 'valid', 'labels')
        YOLO_TEST_IMG_DIR = os.path.join(YOLO_BASE_PATH, 'test', 'images')
        YOLO_TEST_LABEL_DIR = os.path.join(YOLO_BASE_PATH, 'test', 'labels')
        dataset_train = CustomYOLODataset(
            img_dir=YOLO_TRAIN_IMG_DIR,
            label_dir=YOLO_TRAIN_LABEL_DIR,
            classes_file=None,  # classes come from load_yolo_classes (data.yaml or defaults)
            transforms=get_transform(train=True)
        )
dataset_val = CustomYOLODataset(
img_dir=YOLO_VAL_IMG_DIR,
label_dir=YOLO_VAL_LABEL_DIR,
classes_file=None,
transforms=get_transform(train=False)
)
dataset_test = CustomYOLODataset(
img_dir=YOLO_TEST_IMG_DIR,
label_dir=YOLO_TEST_LABEL_DIR,
classes_file=None,
transforms=get_transform(train=False)
)
        # Debug: print label information for a few samples
print("\nDebug information - YOLO dataset:")
print(f"Category mapping: {category_id_to_name}")
for i in range(min(3, len(dataset_train))):
_, target = dataset_train[i]
print(f"Sample {i}:")
print(f" Number of bounding boxes: {len(target['boxes'])}")
print(f" Labels: {target['labels'].tolist()}")
if len(target['labels']) > 0:
print(f" Category names: {[category_id_to_name.get(label.item(), f'unknown-{label.item()}') for label in target['labels']]}")
print(f" Bounding boxes: {target['boxes'].shape}")
else:
raise ValueError(f"Unsupported dataset type: {DATASET_TYPE}")
print(f"Training set size: {len(dataset_train)}")
print(f"Validation set size: {len(dataset_val)}")
print(f"Test set size: {len(dataset_test)}")
    # # Optional: for large datasets, use a subset for quick debugging
    # dataset_train = Subset(dataset_train, range(100))
    # dataset_val = Subset(dataset_val, range(50))
    # dataset_test = Subset(dataset_test, range(50))
    # Data loaders: a custom collate_fn is needed to handle a varying number
    # of targets per image within a batch
def collate_fn(batch):
return tuple(zip(*batch))
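    # For example, a batch of two samples [(img1, t1), (img2, t2)] becomes
    # ((img1, img2), (t1, t2)); this keeps per-image tensors and target dicts
    # separate, since detection images and box counts cannot be stacked.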
    data_loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True,
                                   num_workers=4, collate_fn=collate_fn)
    data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False,  # validation typically uses batch_size=1
                                 num_workers=4, collate_fn=collate_fn)
    data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False,
                                  num_workers=4, collate_fn=collate_fn)
except FileNotFoundError as e:
    print(f"Error: cannot find training or validation files: {e}")
    print(f"Please ensure the file paths for the selected dataset type '{DATASET_TYPE}' are correct and the files exist.")
    exit()
except NotImplementedError as e:
    print(f"Error: {e}")
    print("Please implement the corresponding dataset loading class and try again.")
    exit()
except Exception as e:
    print(f"Error: {e}")
    exit()
# --- 3. Model definition (Faster R-CNN) ---
def get_faster_rcnn_model(num_classes):
    # Load a pretrained Faster R-CNN (ResNet-50 backbone with FPN)
    weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)
    # Number of input features of the classifier head
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the pretrained head with a new one matching our class count
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # Optional: use a MobileNetV3-Large FPN backbone (faster, possibly slightly less accurate)
    # weights_mobilenet = torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT
    # model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=weights_mobilenet)
    # in_features = model.roi_heads.box_predictor.cls_score.in_features
    # model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model
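# A minimal sketch (not called below) of how the AnchorGenerator imported
# above could be used to customize the RPN anchors; the sizes/ratios here
# mirror the torchvision FPN defaults (3 anchors per location), so the
# pretrained RPN head stays compatible. Values are illustrative, not tuned
# for this dataset.
def get_faster_rcnn_model_custom_anchors(num_classes):
    weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)
    # One (sizes, aspect_ratios) tuple per FPN feature-map level (5 levels)
    model.rpn.anchor_generator = AnchorGenerator(
        sizes=((32,), (64,), (128,), (256,), (512,)),
        aspect_ratios=((0.5, 1.0, 2.0),) * 5)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model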
# Instantiate the model
model = get_faster_rcnn_model(MODEL_NUM_CLASSES)
model.to(DEVICE)
# Print model information
print(f"\nModel information:")
print(f"Number of classes: {MODEL_NUM_CLASSES} (including the background class)")
print(f"Device: {DEVICE}")
print(f"Predictor input features: {model.roi_heads.box_predictor.cls_score.in_features}")
print(f"Predictor output classes: {model.roi_heads.box_predictor.cls_score.out_features}")
# --- 4. Training setup ---
# Optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Learning-rate scheduler (optional but recommended)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
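# With LEARNING_RATE = 0.005, step_size=3, and gamma=0.1, stepping once per
# epoch gives: epochs 1-3 at 5e-3, epochs 4-6 at 5e-4, epochs 7-9 at 5e-5,
# and epoch 10 at 5e-6.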
# --- 5. Training loop ---
print("Starting training...")
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'training_{DATASET_TYPE}_Faster_R-CNN.log'),  # log to file
        logging.StreamHandler()  # and to the console
    ]
)
logger = logging.getLogger('faster_rcnn')
logger.info(f"Selected dataset type: {DATASET_TYPE}")
logger.info(f"Training set size: {len(dataset_train)}")
# --- mAP computation ---
def evaluate_map(model, data_loader, device, dataset_type, category_mapping=None):
    model.eval()  # evaluation mode
    # Collect all predictions and ground-truth targets
    all_predictions = []
    all_targets = []
    print("Calculating mAP...")
    with torch.no_grad():
        for images, targets in data_loader:
            images = list(image.to(device) for image in images)
            # Run inference
            predictions = model(images)
            all_predictions.extend(predictions)
            all_targets.extend(targets)
    # Convert predictions and targets to COCO format so pycocotools can
    # compute mAP
    pred_instances = []
    gt_instances = []
    image_id = 0
    instance_id = 1  # COCOeval treats annotation id 0 as "unmatched", so start at 1
    for pred, target in zip(all_predictions, all_targets):
        # Predictions
        boxes = pred['boxes'].cpu().numpy()
        scores = pred['scores'].cpu().numpy()
        labels = pred['labels'].cpu().numpy()
        # Keep only reasonably confident predictions
        keep = scores > 0.05
        boxes = boxes[keep]
        scores = scores[keep]
        labels = labels[keep]
        for box, score, label in zip(boxes, scores, labels):
            width = float(box[2] - box[0])
            height = float(box[3] - box[1])
            pred_instances.append({
                'image_id': image_id,
                'category_id': int(label),
                'bbox': [float(box[0]), float(box[1]), width, height],  # COCO format [x, y, width, height]
                'score': float(score),
                'area': width * height,  # COCOeval needs 'area' on detections too
                'iscrowd': 0,
                'id': instance_id
            })
            instance_id += 1
        # Ground truth
        gt_boxes = target['boxes'].cpu().numpy()
        gt_labels = target['labels'].cpu().numpy()
        for gt_box, gt_label in zip(gt_boxes, gt_labels):
            gt_instances.append({
                'image_id': image_id,
                'category_id': int(gt_label),
                'bbox': [float(gt_box[0]), float(gt_box[1]), float(gt_box[2] - gt_box[0]), float(gt_box[3] - gt_box[1])],  # COCO format
                'area': float((gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])),
                'iscrowd': 0,
                'id': instance_id
            })
            instance_id += 1
        image_id += 1
    # Build COCO objects
    coco_gt = COCO()
    coco_dt = COCO()
    # Category information (the mapping format is the same for every dataset type)
    categories = [{'id': id, 'name': name} for id, name in category_mapping.items()]
    coco_gt.dataset = {
        'images': [{'id': i} for i in range(image_id)],
        'annotations': gt_instances,
        'categories': categories
    }
    coco_gt.createIndex()
    coco_dt.dataset = {
        'images': [{'id': i} for i in range(image_id)],
        'annotations': pred_instances,
        'categories': categories
    }
    coco_dt.createIndex()
    # Run COCO evaluation
    coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    # Return mAP @ IoU=0.50:0.95
    return coco_eval.stats[0]
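# coco_eval.stats holds the standard COCO summary metrics, e.g.
# stats[0] = AP @ IoU=0.50:0.95, stats[1] = AP @ IoU=0.50,
# stats[2] = AP @ IoU=0.75; the remaining entries cover per-size AP and AR.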
for epoch in range(NUM_EPOCHS):
    model.train()  # training mode
    epoch_loss = 0
    for i, (images, targets) in enumerate(data_loader_train):
        images = list(image.to(DEVICE) for image in images)
        # Move the 'boxes' and 'labels' tensors in targets to the device
        targets = [{k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        # Reset gradients
        optimizer.zero_grad()
        # Forward pass; in training mode Faster R-CNN returns a dict of losses
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())  # total loss
        # Debug: periodically print the individual loss components
        if i < 2 or i % 50 == 0:
            print(f"  Loss breakdown: {', '.join([f'{k}: {v.item():.6f}' for k, v in loss_dict.items()])}")
        # Backward pass
        losses.backward()
        # Update weights
        optimizer.step()
        batch_loss = losses.item()
        epoch_loss += batch_loss
        if (i + 1) % 50 == 0:  # log every 50 batches
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(data_loader_train)}], Batch Loss: {batch_loss:.6f}')
    # Update the learning rate
    lr_scheduler.step()
    avg_epoch_loss = epoch_loss / len(data_loader_train)
    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] finished, average training loss: {avg_epoch_loss:.4f}')
    # --- 6. Validation ---
    # Note: in eval mode the model returns predictions rather than losses, so
    # no validation loss is computed here; instead, mAP is evaluated with
    # pycocotools via evaluate_map. (category_id_to_name already holds the
    # right mapping for every dataset type.)
    print("Calculating validation mAP...")
    val_map = evaluate_map(model, data_loader_val, DEVICE, DATASET_TYPE, category_id_to_name)
    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Validation mAP: {val_map:.4f}')
    logger.info(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Validation mAP: {val_map:.4f}")
    # Save a checkpoint each epoch (this could instead be conditioned on validation performance)
    torch.save(model.state_dict(), f'faster_rcnn_{DATASET_TYPE}_epoch_{epoch+1}.pth')
    print(f"Model saved to faster_rcnn_{DATASET_TYPE}_epoch_{epoch+1}.pth")
    # Log the epoch metrics
    logger.info(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Average training loss: {avg_epoch_loss:.4f}, Validation mAP: {val_map:.4f}")
print("Training completed!")
# Save the final model
torch.save(model.state_dict(), SAVE_MODEL_PATH)
print(f"Final model saved to {SAVE_MODEL_PATH}")
# --- 7. Test / inference ---
print("Starting test/inference...")
# Load the trained weights
model.load_state_dict(torch.load(SAVE_MODEL_PATH, map_location=DEVICE))
model.eval()  # make sure the model is in evaluation mode
# Training-set mAP
print("Calculating training-set mAP...")
train_map = evaluate_map(model, data_loader_train, DEVICE, DATASET_TYPE, category_id_to_name)
print(f'Training-set mAP: {train_map:.4f}')
# Validation-set mAP
print("Calculating validation-set mAP...")
val_map = evaluate_map(model, data_loader_val, DEVICE, DATASET_TYPE, category_id_to_name)
print(f'Validation-set mAP: {val_map:.4f}')
# Test-set mAP
print("Calculating test-set mAP...")
test_map = evaluate_map(model, data_loader_test, DEVICE, DATASET_TYPE, category_id_to_name)
print(f'Test-set mAP: {test_map:.4f}')
# Log the final results
logger.info(f"Final results - Training mAP: {train_map:.4f}, Validation mAP: {val_map:.4f}, Test mAP: {test_map:.4f}")
# --- End ---
print("Code execution completed.")