Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import copy | |
| from typing import Dict, Sequence | |
| from mmengine.hooks import Hook | |
| from mmengine.model import is_model_wrapper | |
| from mmengine.runner import Runner | |
| from mmpose.registry import HOOKS | |
| from mmpose.utils.hooks import rgetattr, rsetattr | |
@HOOKS.register_module()
class YOLOXPoseModeSwitchHook(Hook):
    """Switch the mode of YOLOX-Pose during training.

    Once training reaches the last ``num_last_epochs`` epochs, this hook:

    1) Rebuilds the training dataloader with ``new_train_dataset`` and/or
       ``new_train_pipeline`` (intended to turn off mosaic and mixup data
       augmentation; the new pipeline may also provide instance masks used
       to assist positive anchor selection — that depends on the config).
    2) Enables the auxiliary (L1) loss in the head by setting
       ``model.head.use_aux_loss = True``.

    The switch is applied exactly once. It is also applied when training is
    resumed from a checkpoint taken after the switch epoch, because neither
    the rebuilt dataloader nor the ``use_aux_loss`` flag is restored from a
    checkpoint.

    Args:
        num_last_epochs (int): The number of last epochs at the end of
            training to close the data augmentation and switch to L1 loss.
            Defaults to 20.
        new_train_dataset (dict): New training dataset configuration that
            will be used in place of the original training dataset. Defaults
            to None (keep the original dataset config).
        new_train_pipeline (Sequence[dict]): New data augmentation pipeline
            configuration that will be used in place of the original pipeline
            during training. Defaults to None (keep the dataset's pipeline).
    """

    def __init__(self,
                 num_last_epochs: int = 20,
                 new_train_dataset: dict = None,
                 new_train_pipeline: Sequence[dict] = None):
        self.num_last_epochs = num_last_epochs
        self.new_train_dataset = new_train_dataset
        self.new_train_pipeline = new_train_pipeline
        # Guards against switching more than once now that the condition in
        # `before_train_epoch` is `>=` (needed to handle resumed training).
        self._has_switched = False

    def _modify_dataloader(self, runner: Runner):
        """Rebuild the train dataloader with the new dataset/pipeline configs.

        The runner's original ``train_dataloader`` config is deep-copied, the
        dataset and/or pipeline are replaced, and the resulting dataloader is
        installed on ``runner.train_loop``.
        """
        runner.logger.info(f'New Pipeline: {self.new_train_pipeline}')

        train_dataloader_cfg = copy.deepcopy(runner.cfg.train_dataloader)
        if self.new_train_dataset:
            # Copy so that assigning the pipeline below cannot mutate the
            # dataset config stored on this hook.
            train_dataloader_cfg['dataset'] = copy.deepcopy(
                self.new_train_dataset)
        if self.new_train_pipeline:
            # Item access works for both plain dicts and ConfigDicts.
            train_dataloader_cfg['dataset']['pipeline'] = copy.deepcopy(
                self.new_train_pipeline)

        new_train_dataloader = Runner.build_dataloader(train_dataloader_cfg)
        runner.train_loop.dataloader = new_train_dataloader
        runner.logger.info('Recreated the dataloader!')

    def before_train_epoch(self, runner: Runner):
        """Close mosaic and mixup augmentation, switch to use L1 loss."""
        epoch = runner.epoch
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        # Same threshold arithmetic as before; `>=` (instead of `==`) makes
        # the switch also fire when resuming past the threshold epoch.
        switch_threshold = runner.max_epochs - self.num_last_epochs
        if epoch + 1 >= switch_threshold and not self._has_switched:
            self._modify_dataloader(runner)
            runner.logger.info('Added additional reg loss now!')
            model.head.use_aux_loss = True
            self._has_switched = True
@HOOKS.register_module()
class RTMOModeSwitchHook(Hook):
    """A hook to switch the mode of RTMO during training.

    This hook allows for dynamic adjustments of model attributes at specified
    training epochs. It is designed to modify configurations such as turning
    off specific augmentations or changing loss functions at different stages
    of the training process. All attribute paths are resolved relative to
    ``model.head`` (dotted paths such as ``"loss_cls.loss_weight"`` are
    supported via ``rsetattr``).

    Changes scheduled for an epoch are applied once, before that epoch
    starts. When training is resumed at a later epoch, every scheduled epoch
    that has already been reached is applied in ascending order. The modified
    attributes are plain Python attributes that are not restored from a
    checkpoint.

    Args:
        epoch_attributes (Dict[int, Dict]): A dictionary where keys are epoch
            numbers and values are attribute modification dictionaries. Each
            dictionary maps an attribute path (under ``model.head``) to its
            new value. A list of such dictionaries is also accepted and is
            applied in order.

    Example:
        epoch_attributes = {
            5: {"attr1.subattr": new_value1, "attr2.subattr": new_value2},
            10: {"attr3.subattr": new_value3}
        }
    """

    def __init__(self, epoch_attributes: Dict[int, Dict]):
        self.epoch_attributes = epoch_attributes
        # Scheduled epochs whose changes have already been applied.
        self._applied_epochs = set()

    @staticmethod
    def _iter_changes(changes):
        """Yield ``(attr_path, value)`` pairs from a dict or list of dicts."""
        # ConfigDict subclasses dict, so configs loaded by mmengine match too.
        if isinstance(changes, dict):
            changes = [changes]
        for change in changes:
            yield from change.items()

    def before_train_epoch(self, runner: Runner):
        """Method called before each training epoch.

        Applies, in ascending epoch order, the attribute changes of every
        scheduled epoch that is ``<= runner.epoch`` and has not yet been
        applied.
        """
        epoch = runner.epoch
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        for switch_epoch in sorted(self.epoch_attributes):
            if switch_epoch > epoch:
                break
            if switch_epoch in self._applied_epochs:
                continue
            for key, value in self._iter_changes(
                    self.epoch_attributes[switch_epoch]):
                rsetattr(model.head, key, value)
                runner.logger.info(
                    f'Change model.head.{key} to {rgetattr(model.head, key)}')
            self._applied_epochs.add(switch_epoch)