Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from collections import OrderedDict | |
| from mmengine.dist import all_reduce_dict, get_dist_info | |
| from mmengine.hooks import Hook | |
| from torch import nn | |
| from mmpose.registry import HOOKS | |
def get_norm_states(module: nn.Module) -> OrderedDict:
    """Get the state_dict of batch norms in the module.

    Walks ``module.named_modules()`` and, for every sub-module that is an
    instance of ``nn.modules.batchnorm._NormBase``, copies the entries of
    its ``state_dict()`` into a single flat mapping.

    Args:
        module (nn.Module): Module to scan. It may itself be a norm layer.

    Returns:
        OrderedDict: Mapping from ``'<sub-module name>.<state key>'`` to
        the corresponding state tensor. Empty if no norm layer is found.
    """
    async_norm_states = OrderedDict()
    for name, child in module.named_modules():
        if not isinstance(child, nn.modules.batchnorm._NormBase):
            continue
        # ``named_modules`` reports the root module under the empty name;
        # prepending '' + '.' would give keys such as '.weight' that do not
        # match the root module's own state_dict keys.
        prefix = f'{name}.' if name else ''
        for key, value in child.state_dict().items():
            async_norm_states[prefix + key] = value
    return async_norm_states
class SyncNormHook(Hook):
    """Synchronize Norm states before validation."""

    # NOTE(review): ``HOOKS`` is imported in this file but no registration
    # decorator is visible on this class -- confirm whether
    # ``@HOOKS.register_module()`` is intended.

    def before_val_epoch(self, runner):
        """Synchronize normalization statistics.

        Does nothing when the world size reported by ``get_dist_info()`` is 1
        or when the model has no norm states. Otherwise the norm states are
        reduced across processes with ``op='mean'`` and loaded back into the
        model. Any exception raised while doing so is logged as a warning
        instead of being propagated.

        Args:
            runner: Object exposing ``model`` and ``logger`` attributes
                (presumably the mmengine runner -- confirm against caller).
        """
        module = runner.model
        _, world_size = get_dist_info()

        if world_size == 1:
            return

        norm_states = get_norm_states(module)
        if not norm_states:
            return

        try:
            norm_states = all_reduce_dict(norm_states, op='mean')
            # ``norm_states`` only holds norm-layer entries, so strict
            # loading would raise on every other key of the model and
            # report a failure even though nothing went wrong.
            module.load_state_dict(norm_states, strict=False)
        except Exception as e:
            # Best effort: a failed sync must not abort validation.
            runner.logger.warning(f'SyncNormHook failed: {str(e)}')