import argparse
import glob
import json
import logging
import os
import xml.etree.ElementTree as ET

import torch
import torchvision
import torchvision.transforms.v2 as T
from PIL import Image
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from torch.utils.data import DataLoader, Dataset
from torchvision import tv_tensors
from torchvision.datasets import CocoDetection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Command-line arguments
parser = argparse.ArgumentParser(description='Object detection training script')
parser.add_argument('--dataset_type', type=str, default='coco', choices=['coco', 'pascal', 'yolo'],
                    help='Dataset type: coco, pascal, or yolo')
args = parser.parse_args()

DATASET_TYPE = args.dataset_type
print(f"Selected dataset type: {DATASET_TYPE}")

# Dataset root paths (COCO / Pascal VOC / YOLO formats)
COCO_BASE_PATH = '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection/coco'
COCO_TRAIN_IMG_DIR = os.path.join(COCO_BASE_PATH, 'train')
COCO_TRAIN_ANN_FILE = os.path.join(COCO_BASE_PATH, 'train/train_annotations.json')
COCO_VAL_IMG_DIR = os.path.join(COCO_BASE_PATH, 'valid')
COCO_VAL_ANN_FILE = os.path.join(COCO_BASE_PATH, 'valid/valid_annotations.json')
COCO_TEST_IMG_DIR = os.path.join(COCO_BASE_PATH, 'test')
COCO_TEST_ANN_FILE = os.path.join(COCO_BASE_PATH, 'test/test_annotations.json')

PASCAL_BASE_PATH = '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection/pascal'
PASCAL_YEAR = '2007'

YOLO_BASE_PATH = '/hdd_16T/Zirui/work/CNNa2/dataset/dataset_42028assg2_24902417/Object_Detection/yolo'

def load_coco_categories(ann_file):
    """Read the {category_id: name} mapping from a COCO annotation file."""
    with open(ann_file, 'r') as f:
        data = json.load(f)
    categories = {cat['id']: cat['name'] for cat in data['categories']}
    return categories, len(categories)
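
# The annotation JSON is expected to carry a 'categories' block like the
# following (illustrative shape only; the actual ids and names come from the
# dataset itself):
#   "categories": [{"id": 1, "name": "young"}, {"id": 2, "name": "empty_pod"}]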

# Pascal VOC classes; index 0 is reserved for the background.
PASCAL_CLASSES = [
    "__background__",
    "young", "empty_pod",
]

PASCAL_NAME_TO_ID = {name: i for i, name in enumerate(PASCAL_CLASSES) if name != "__background__"}
PASCAL_ID_TO_NAME = {i: name for name, i in PASCAL_NAME_TO_ID.items()}
PASCAL_NUM_CLASSES = len(PASCAL_CLASSES) - 1  # foreground classes only

def load_yolo_classes(classes_file=None):
    """Return a {1-based id: name} mapping and the class count for YOLO data."""
    # Prefer an explicit classes file (one class name per line) when given.
    if classes_file and os.path.exists(classes_file):
        with open(classes_file, 'r') as f:
            classes = [line.strip() for line in f if line.strip()]
        return {i + 1: name for i, name in enumerate(classes)}, len(classes)

    # Otherwise try the data.yaml that usually ships with YOLO exports.
    yaml_file = os.path.join(YOLO_BASE_PATH, 'data.yaml')
    if os.path.exists(yaml_file):
        try:
            import yaml
            with open(yaml_file, 'r') as f:
                data = yaml.safe_load(f)
            if 'names' in data:
                classes = data['names']
                return {i + 1: name for i, name in enumerate(classes)}, len(classes)
        except Exception as e:
            print(f"Failed to load classes from YAML: {e}")

    # Fall back to the known class list for this dataset.
    classes = ['Ready', 'empty_pod', 'germination', 'pod', 'young']
    return {i + 1: name for i, name in enumerate(classes)}, len(classes)
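
# A typical data.yaml 'names' entry looks like this (illustrative sketch, not
# necessarily this dataset's actual file):
#   names: ['Ready', 'empty_pod', 'germination', 'pod', 'young']
# The returned ids are shifted by +1 so that label 0 stays free for the
# background class expected by torchvision detection models.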

if DATASET_TYPE == 'coco':
    category_id_to_name, NUM_CLASSES = load_coco_categories(COCO_TRAIN_ANN_FILE)
elif DATASET_TYPE == 'pascal':
    category_id_to_name = PASCAL_ID_TO_NAME
    NUM_CLASSES = PASCAL_NUM_CLASSES
elif DATASET_TYPE == 'yolo':
    category_id_to_name, NUM_CLASSES = load_yolo_classes()
else:
    raise ValueError(f"Unsupported dataset type: {DATASET_TYPE}")

print(f"Number of classes: {NUM_CLASSES}")

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

BATCH_SIZE = 32
NUM_EPOCHS = 10
LEARNING_RATE = 0.005
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
SAVE_MODEL_PATH = f'faster_rcnn_{DATASET_TYPE}_model.pth'

# torchvision detection models reserve label 0 for the background class.
MODEL_NUM_CLASSES = NUM_CLASSES + 1

def get_transform(train):
    transforms = []
    if train:
        # Flip images (and, via tv_tensors, their boxes) with probability 0.5.
        transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.PILToTensor())
    transforms.append(T.ConvertImageDtype(torch.float32))
    return T.Compose(transforms)
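
# The v2 transforms above can update bounding boxes together with the image,
# but only when the boxes are wrapped as tv_tensors.BoundingBoxes; plain
# tensors pass through geometric transforms unchanged. The dataset classes
# below therefore wrap their boxes before applying the transform, roughly:
#   target["boxes"] = tv_tensors.BoundingBoxes(
#       target["boxes"], format="XYXY", canvas_size=(img.height, img.width))
#   img, target = transform(img, target)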

class CustomCocoDetection(CocoDetection):
    """CocoDetection variant that converts COCO annotations into the target
    dict format expected by torchvision detection models."""

    def __getitem__(self, index):
        image_id = self.ids[index]
        # Load the raw image and annotation list directly; the parent
        # __getitem__ would apply self.transforms before the annotations have
        # been converted into transformable boxes.
        img = self._load_image(image_id)
        coco_targets = self._load_target(image_id)

        target = {}
        target["boxes"] = []
        target["labels"] = []
        target["image_id"] = torch.tensor([image_id])
        target["area"] = []
        target["iscrowd"] = []

        for anno in coco_targets:
            # COCO boxes are [x, y, w, h]; convert to [xmin, ymin, xmax, ymax]
            # and skip degenerate boxes.
            x, y, w, h = anno["bbox"]
            if w > 0 and h > 0:
                target["boxes"].append([x, y, x + w, y + h])
                target["labels"].append(anno["category_id"])
                target["area"].append(anno["area"])
                target["iscrowd"].append(anno["iscrowd"])

        if len(target["boxes"]) > 0:
            target["boxes"] = torch.as_tensor(target["boxes"], dtype=torch.float32)
            target["labels"] = torch.as_tensor(target["labels"], dtype=torch.int64)
            target["area"] = torch.as_tensor(target["area"], dtype=torch.float32)
            target["iscrowd"] = torch.as_tensor(target["iscrowd"], dtype=torch.uint8)
        else:
            # Negative sample: keep empty tensors with the right shapes.
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.uint8)

        if self.transforms is not None:
            # Wrap boxes so geometric transforms update them with the image.
            target["boxes"] = tv_tensors.BoundingBoxes(
                target["boxes"], format="XYXY", canvas_size=(img.height, img.width))
            img, target = self.transforms(img, target)

        return img, target
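
# Note: recent torchvision also provides
# torchvision.datasets.wrap_dataset_for_transforms_v2, which performs a
# similar conversion automatically; the explicit class above is kept for
# symmetry with the Pascal VOC and YOLO loaders.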

class CustomVOCDetection(Dataset):
    """Loads a flat Pascal VOC-style directory in which each image sits next
    to its XML annotation file (<id>.jpg or <id>.png beside <id>.xml)."""

    def __init__(self, root, year='2007', image_set='train', transforms=None):
        self.root = root
        self.year = year
        self.image_set = image_set
        self._transforms = transforms

        self.data_dir = os.path.join(root, image_set)
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Cannot find Pascal VOC dataset directory: {self.data_dir}")

        # Keep only images that have a matching XML annotation.
        self.images = []
        for f in os.listdir(self.data_dir):
            if f.endswith('.jpg') or f.endswith('.png'):
                img_id = os.path.splitext(f)[0]
                xml_file = os.path.join(self.data_dir, f"{img_id}.xml")
                if os.path.exists(xml_file):
                    self.images.append(img_id)

        self.name_to_id = PASCAL_NAME_TO_ID
        print(f"Found {len(self.images)} Pascal VOC {image_set} images")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_id = self.images[index]

        # The image may be stored as .jpg or .png.
        img_path = os.path.join(self.data_dir, f"{img_id}.jpg")
        if not os.path.exists(img_path):
            img_path = os.path.join(self.data_dir, f"{img_id}.png")
        img = Image.open(img_path).convert("RGB")

        xml_path = os.path.join(self.data_dir, f"{img_id}.xml")
        tree = ET.parse(xml_path)
        root = tree.getroot()

        target = {}
        target["boxes"] = []
        target["labels"] = []
        target["image_id"] = torch.tensor([index])
        target["area"] = []
        target["iscrowd"] = []

        for obj in root.findall('object'):
            label_name = obj.find('name').text
            # Skip objects whose class is not in the configured class list.
            if label_name not in self.name_to_id:
                continue
            label_id = self.name_to_id[label_name]

            bbox = obj.find('bndbox')
            xmin = float(bbox.find('xmin').text)
            ymin = float(bbox.find('ymin').text)
            xmax = float(bbox.find('xmax').text)
            ymax = float(bbox.find('ymax').text)

            # Skip degenerate boxes.
            if xmax > xmin and ymax > ymin:
                target["boxes"].append([xmin, ymin, xmax, ymax])
                target["labels"].append(label_id)
                target["area"].append((xmax - xmin) * (ymax - ymin))
                target["iscrowd"].append(0)

        if len(target["boxes"]) > 0:
            target["boxes"] = torch.as_tensor(target["boxes"], dtype=torch.float32)
            target["labels"] = torch.as_tensor(target["labels"], dtype=torch.int64)
            target["area"] = torch.as_tensor(target["area"], dtype=torch.float32)
            target["iscrowd"] = torch.as_tensor(target["iscrowd"], dtype=torch.uint8)
        else:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.uint8)

        if self._transforms is not None:
            # Wrap boxes so geometric transforms update them with the image.
            target["boxes"] = tv_tensors.BoundingBoxes(
                target["boxes"], format="XYXY", canvas_size=(img.height, img.width))
            img, target = self._transforms(img, target)

        return img, target
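
# The relevant slice of each VOC XML file looks like this (illustrative):
#   <object>
#     <name>young</name>
#     <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
#   </object>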

class CustomYOLODataset(Dataset):
    """Loads YOLO-format data: one .txt label file per image, each line being
    'class_id x_center y_center width height' with values normalized to [0, 1]."""

    def __init__(self, img_dir, label_dir, classes_file=None, transforms=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transforms = transforms

        self.img_files = sorted(glob.glob(os.path.join(img_dir, '*.jpg')) +
                                glob.glob(os.path.join(img_dir, '*.png')))

        # Keep only images that have a matching label file.
        self.valid_files = []
        for img_path in self.img_files:
            base_name = os.path.splitext(os.path.basename(img_path))[0]
            label_path = os.path.join(label_dir, f"{base_name}.txt")
            if os.path.exists(label_path):
                self.valid_files.append((img_path, label_path))

        print(f"Found {len(self.valid_files)} YOLO-format image-label pairs")

    def __len__(self):
        return len(self.valid_files)

    def __getitem__(self, idx):
        img_path, label_path = self.valid_files[idx]

        img = Image.open(img_path).convert("RGB")
        img_width, img_height = img.size

        target = {}
        target["boxes"] = []
        target["labels"] = []
        target["image_id"] = torch.tensor([idx])
        target["area"] = []
        target["iscrowd"] = []

        with open(label_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) != 5:
                    continue
                # YOLO class ids are 0-based; shift by +1 to keep 0 for background.
                cls_id = int(parts[0]) + 1
                x_center = float(parts[1]) * img_width
                y_center = float(parts[2]) * img_height
                width = float(parts[3]) * img_width
                height = float(parts[4]) * img_height

                # Convert center format to corner format and clamp to the image.
                xmin = max(0, x_center - width / 2)
                ymin = max(0, y_center - height / 2)
                xmax = min(img_width, x_center + width / 2)
                ymax = min(img_height, y_center + height / 2)

                if xmax > xmin and ymax > ymin:
                    target["boxes"].append([xmin, ymin, xmax, ymax])
                    target["labels"].append(cls_id)
                    target["area"].append((xmax - xmin) * (ymax - ymin))
                    target["iscrowd"].append(0)

        if len(target["boxes"]) > 0:
            target["boxes"] = torch.as_tensor(target["boxes"], dtype=torch.float32)
            target["labels"] = torch.as_tensor(target["labels"], dtype=torch.int64)
            target["area"] = torch.as_tensor(target["area"], dtype=torch.float32)
            target["iscrowd"] = torch.as_tensor(target["iscrowd"], dtype=torch.uint8)
        else:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.uint8)

        if self.transforms is not None:
            # Wrap boxes so geometric transforms update them with the image.
            target["boxes"] = tv_tensors.BoundingBoxes(
                target["boxes"], format="XYXY", canvas_size=(img.height, img.width))
            img, target = self.transforms(img, target)

        return img, target
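
# Example label line (illustrative): '4 0.513 0.622 0.210 0.340' denotes class
# id 4 with a box centered at (51.3%, 62.2%) of the image, 21.0% of the image
# wide and 34.0% tall.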

# Build the datasets for the selected format.
try:
    if DATASET_TYPE == 'coco':
        dataset_train = CustomCocoDetection(root=COCO_TRAIN_IMG_DIR,
                                            annFile=COCO_TRAIN_ANN_FILE,
                                            transforms=get_transform(train=True))
        dataset_val = CustomCocoDetection(root=COCO_VAL_IMG_DIR,
                                          annFile=COCO_VAL_ANN_FILE,
                                          transforms=get_transform(train=False))
        dataset_test = CustomCocoDetection(root=COCO_TEST_IMG_DIR,
                                           annFile=COCO_TEST_ANN_FILE,
                                           transforms=get_transform(train=False))

    elif DATASET_TYPE == 'pascal':
        dataset_train = CustomVOCDetection(root=PASCAL_BASE_PATH,
                                           year=PASCAL_YEAR,
                                           image_set='train',
                                           transforms=get_transform(train=True))
        dataset_val = CustomVOCDetection(root=PASCAL_BASE_PATH,
                                         year=PASCAL_YEAR,
                                         image_set='valid',
                                         transforms=get_transform(train=False))
        dataset_test = CustomVOCDetection(root=PASCAL_BASE_PATH,
                                          year=PASCAL_YEAR,
                                          image_set='test',
                                          transforms=get_transform(train=False))

        # Sanity-check a few samples.
        print("\nDebug information - Pascal VOC dataset:")
        print(f"Category mapping: {PASCAL_NAME_TO_ID}")
        for i in range(min(3, len(dataset_train))):
            _, target = dataset_train[i]
            print(f"Sample {i}:")
            print(f"  Number of bounding boxes: {len(target['boxes'])}")
            print(f"  Labels: {target['labels'].tolist()}")
            if len(target['labels']) > 0:
                print(f"  Category names: {[PASCAL_ID_TO_NAME[label.item()] for label in target['labels']]}")
            print(f"  Bounding boxes: {target['boxes'].shape}")

    elif DATASET_TYPE == 'yolo':
        YOLO_TRAIN_IMG_DIR = os.path.join(YOLO_BASE_PATH, 'train', 'images')
        YOLO_TRAIN_LABEL_DIR = os.path.join(YOLO_BASE_PATH, 'train', 'labels')
        YOLO_VAL_IMG_DIR = os.path.join(YOLO_BASE_PATH, 'valid', 'images')
        YOLO_VAL_LABEL_DIR = os.path.join(YOLO_BASE_PATH, 'valid', 'labels')
        YOLO_TEST_IMG_DIR = os.path.join(YOLO_BASE_PATH, 'test', 'images')
        YOLO_TEST_LABEL_DIR = os.path.join(YOLO_BASE_PATH, 'test', 'labels')

        dataset_train = CustomYOLODataset(img_dir=YOLO_TRAIN_IMG_DIR,
                                          label_dir=YOLO_TRAIN_LABEL_DIR,
                                          classes_file=None,
                                          transforms=get_transform(train=True))
        dataset_val = CustomYOLODataset(img_dir=YOLO_VAL_IMG_DIR,
                                        label_dir=YOLO_VAL_LABEL_DIR,
                                        classes_file=None,
                                        transforms=get_transform(train=False))
        dataset_test = CustomYOLODataset(img_dir=YOLO_TEST_IMG_DIR,
                                         label_dir=YOLO_TEST_LABEL_DIR,
                                         classes_file=None,
                                         transforms=get_transform(train=False))

        # Sanity-check a few samples.
        print("\nDebug information - YOLO dataset:")
        print(f"Category mapping: {category_id_to_name}")
        for i in range(min(3, len(dataset_train))):
            _, target = dataset_train[i]
            print(f"Sample {i}:")
            print(f"  Number of bounding boxes: {len(target['boxes'])}")
            print(f"  Labels: {target['labels'].tolist()}")
            if len(target['labels']) > 0:
                print(f"  Category names: {[category_id_to_name.get(label.item(), f'unknown-{label.item()}') for label in target['labels']]}")
            print(f"  Bounding boxes: {target['boxes'].shape}")

    else:
        raise ValueError(f"Unsupported dataset type: {DATASET_TYPE}")

    print(f"Training set size: {len(dataset_train)}")
    print(f"Validation set size: {len(dataset_val)}")
    print(f"Test set size: {len(dataset_test)}")

    def collate_fn(batch):
        # Detection images vary in size, so a batch cannot be stacked into a
        # single tensor; return a tuple of images and a tuple of targets.
        return tuple(zip(*batch))
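
    # With BATCH_SIZE = 2 a batch is ((img_a, img_b), (target_a, target_b)),
    # matching the list-of-images / list-of-targets signature of torchvision
    # detection models.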

    data_loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True,
                                   num_workers=4, collate_fn=collate_fn)
    data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False,
                                 num_workers=4, collate_fn=collate_fn)
    data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False,
                                  num_workers=4, collate_fn=collate_fn)

except FileNotFoundError as e:
    print(f"Error: cannot find training or validation files: {e}")
    print(f"Please check that the paths for the selected dataset '{DATASET_TYPE}' are correct and the files exist.")
    raise SystemExit(1)
except Exception as e:
    print(f"Error while building the datasets: {e}")
    raise SystemExit(1)

def get_faster_rcnn_model(num_classes):
    """Build a Faster R-CNN ResNet-50 FPN model pre-trained on COCO, with its
    box predictor replaced for `num_classes` outputs (background included)."""
    weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)

    # Swap only the classification head; the backbone and RPN keep their
    # COCO-pretrained weights.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model
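
# If the default anchors fit the objects poorly, the RPN anchors can also be
# customized after building the model. A sketch (the sizes and ratios shown
# are the torchvision defaults, not values tuned for this dataset):
#   from torchvision.models.detection.rpn import AnchorGenerator
#   model.rpn.anchor_generator = AnchorGenerator(
#       sizes=((32,), (64,), (128,), (256,), (512,)),
#       aspect_ratios=((0.5, 1.0, 2.0),) * 5)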

model = get_faster_rcnn_model(MODEL_NUM_CLASSES)
model.to(DEVICE)

print("\nModel information:")
print(f"Number of classes: {MODEL_NUM_CLASSES} (including background)")
print(f"Using device: {DEVICE}")
print(f"Predictor input features: {model.roi_heads.box_predictor.cls_score.in_features}")
print(f"Predictor output classes: {model.roi_heads.box_predictor.cls_score.out_features}")

# Optimize only the trainable parameters; SGD with momentum matches the
# torchvision detection reference training recipe.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

# Decay the learning rate by 10x every 3 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
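
# With LEARNING_RATE = 0.005, step_size=3 and gamma=0.1, and the scheduler
# stepped once per epoch, the schedule works out to: epochs 1-3 at 0.005,
# epochs 4-6 at 0.0005, epochs 7-9 at 0.00005, epoch 10 at 0.000005.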

print("Starting training...")

# Log to both a file and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'training_{DATASET_TYPE}_Faster_R-CNN.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('faster_rcnn')

logger.info(f"Selected dataset type: {DATASET_TYPE}")
logger.info(f"Training set size: {len(dataset_train)}")

def evaluate_map(model, data_loader, device, dataset_type, category_mapping=None):
    """Run inference over `data_loader` and compute COCO-style bbox mAP
    (IoU 0.50:0.95) with pycocotools."""
    model.eval()

    all_predictions = []
    all_targets = []

    print("Calculating mAP...")
    with torch.no_grad():
        for images, targets in data_loader:
            images = list(image.to(device) for image in images)
            predictions = model(images)
            all_predictions.extend(predictions)
            all_targets.extend(targets)

    # Flatten predictions and ground truth into COCO-style annotation records.
    pred_instances = []
    gt_instances = []
    image_id = 0
    instance_id = 0

    for pred, target in zip(all_predictions, all_targets):
        boxes = pred['boxes'].cpu().numpy()
        scores = pred['scores'].cpu().numpy()
        labels = pred['labels'].cpu().numpy()

        # Drop near-zero-confidence detections to keep evaluation fast.
        keep = scores > 0.05
        boxes = boxes[keep]
        scores = scores[keep]
        labels = labels[keep]

        for box, score, label in zip(boxes, scores, labels):
            # COCO uses [x, y, width, height] boxes; COCOeval also expects an
            # 'area' field on detections.
            pred_instances.append({
                'image_id': image_id,
                'category_id': int(label),
                'bbox': [float(box[0]), float(box[1]), float(box[2] - box[0]), float(box[3] - box[1])],
                'area': float((box[2] - box[0]) * (box[3] - box[1])),
                'score': float(score),
                'id': instance_id
            })
            instance_id += 1

        gt_boxes = target['boxes'].cpu().numpy()
        gt_labels = target['labels'].cpu().numpy()
        for gt_box, gt_label in zip(gt_boxes, gt_labels):
            gt_instances.append({
                'image_id': image_id,
                'category_id': int(gt_label),
                'bbox': [float(gt_box[0]), float(gt_box[1]), float(gt_box[2] - gt_box[0]), float(gt_box[3] - gt_box[1])],
                'area': float((gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])),
                'iscrowd': 0,
                'id': instance_id
            })
            instance_id += 1

        image_id += 1

    # Build in-memory COCO objects for ground truth and detections. The
    # category list is built the same way for all three dataset formats.
    categories = [{'id': cat_id, 'name': name} for cat_id, name in category_mapping.items()]

    coco_gt = COCO()
    coco_gt.dataset = {
        'images': [{'id': i} for i in range(image_id)],
        'annotations': gt_instances,
        'categories': categories
    }
    coco_gt.createIndex()

    coco_dt = COCO()
    coco_dt.dataset = {
        'images': [{'id': i} for i in range(image_id)],
        'annotations': pred_instances,
        'categories': categories
    }
    coco_dt.createIndex()

    coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    # stats[0] is AP averaged over IoU 0.50:0.95, all areas, maxDets=100.
    return coco_eval.stats[0]
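
# coco_eval.stats layout in pycocotools: [AP@[.5:.95], AP@.5, AP@.75,
# AP_small, AP_medium, AP_large, AR@1, AR@10, AR@100, AR_small, AR_medium,
# AR_large].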

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    for i, (images, targets) in enumerate(data_loader_train):
        images = list(image.to(DEVICE) for image in images)
        # Move every tensor field of each target dict to the training device.
        targets = [{k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        # In train mode, torchvision detection models take (images, targets)
        # and return a dict of losses; the total loss is their sum.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
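
        # For Faster R-CNN the dict holds loss_classifier, loss_box_reg,
        # loss_objectness and loss_rpn_box_reg.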

        # Print a detailed loss breakdown early on and every 50 steps.
        if i < 2 or i % 50 == 0:
            print(f"  Loss breakdown: {', '.join([f'{k}: {v.item():.6f}' for k, v in loss_dict.items()])}")

        losses.backward()
        optimizer.step()

        batch_loss = losses.item()
        epoch_loss += batch_loss

        if (i + 1) % 50 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(data_loader_train)}], Batch Loss: {batch_loss:.6f}')

    # Step the learning-rate schedule once per epoch.
    lr_scheduler.step()

    avg_epoch_loss = epoch_loss / len(data_loader_train)
    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] finished, average training loss: {avg_epoch_loss:.4f}')

    # torchvision detection models only return losses in train mode, so no
    # validation loss is computed here; track validation mAP instead.
    val_map = evaluate_map(model, data_loader_val, DEVICE, DATASET_TYPE, category_id_to_name)
    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Validation mAP: {val_map:.4f}')
    logger.info(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Validation mAP: {val_map:.4f}")

    # Save a checkpoint after every epoch.
    torch.save(model.state_dict(), f'faster_rcnn_{DATASET_TYPE}_epoch_{epoch+1}.pth')
    print(f"Model saved to faster_rcnn_{DATASET_TYPE}_epoch_{epoch+1}.pth")

    logger.info(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Average training loss: {avg_epoch_loss:.4f}, Validation mAP: {val_map:.4f}")

print("Training completed!")

torch.save(model.state_dict(), SAVE_MODEL_PATH)
print(f"Final model saved to {SAVE_MODEL_PATH}")

print("Starting test/inference...")

# Reload the final weights and evaluate on all three splits.
model.load_state_dict(torch.load(SAVE_MODEL_PATH, map_location=DEVICE))
model.eval()

print("Calculating training set mAP...")
train_map = evaluate_map(model, data_loader_train, DEVICE, DATASET_TYPE, category_id_to_name)
print(f'Training set mAP: {train_map:.4f}')

print("Calculating validation set mAP...")
val_map = evaluate_map(model, data_loader_val, DEVICE, DATASET_TYPE, category_id_to_name)
print(f'Validation set mAP: {val_map:.4f}')

print("Calculating test set mAP...")
test_map = evaluate_map(model, data_loader_test, DEVICE, DATASET_TYPE, category_id_to_name)
print(f'Test set mAP: {test_map:.4f}')

logger.info(f"Final results - Training mAP: {train_map:.4f}, Validation mAP: {val_map:.4f}, Test mAP: {test_map:.4f}")

print("Code execution completed.")