Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Callable, Dict, Optional | |
| import torch | |
| from mmdeploy.codebase.base import CODEBASE, MMCodebase | |
| from mmdeploy.codebase.mmdet.deploy import ObjectDetection | |
| from mmdeploy.utils import Codebase, Task | |
| from mmengine import Config | |
| from mmengine.registry import Registry | |
# Registry collecting the deploy task classes of the MMYOLO codebase; the
# MMYOLO codebase class exposes it as its ``task_registry``.
MMYOLO_TASK = Registry(name='mmyolo_tasks')
@CODEBASE.register_module(Codebase.MMYOLO.value)
class MMYOLO(MMCodebase):
    """MMYOLO codebase class.

    Registered in mmdeploy's ``CODEBASE`` registry under
    ``Codebase.MMYOLO.value`` so it can be looked up by name. Deploy tasks for
    this codebase are resolved through ``MMYOLO_TASK``.
    """

    task_registry = MMYOLO_TASK

    @classmethod
    def register_deploy_modules(cls):
        """Register the mmdet deploy rewriters, ops and structures.

        MMYOLO reuses mmdet's deploy modules; the packages below are imported
        only for their registration side effects.
        """
        import mmdeploy.codebase.mmdet.models  # noqa: F401
        import mmdeploy.codebase.mmdet.ops  # noqa: F401
        import mmdeploy.codebase.mmdet.structures  # noqa: F401

    @classmethod
    def register_all_modules(cls):
        """Register the deploy modules plus all mmyolo and mmdet modules."""
        from mmdet.utils.setup_env import \
            register_all_modules as register_all_modules_mmdet
        from mmyolo.utils.setup_env import \
            register_all_modules as register_all_modules_mmyolo

        cls.register_deploy_modules()
        # NOTE(review): the positional flag is presumably
        # ``init_default_scope`` -- mmyolo becomes the default scope while
        # mmdet does not; confirm against the installed versions.
        register_all_modules_mmyolo(True)
        register_all_modules_mmdet(False)
| def _get_dataset_metainfo(model_cfg: Config): | |
| """Get metainfo of dataset. | |
| Args: | |
| model_cfg Config: Input model Config object. | |
| Returns: | |
| list[str]: A list of string specifying names of different class. | |
| """ | |
| from mmyolo import datasets # noqa | |
| from mmyolo.registry import DATASETS | |
| module_dict = DATASETS.module_dict | |
| for dataloader_name in [ | |
| 'test_dataloader', 'val_dataloader', 'train_dataloader' | |
| ]: | |
| if dataloader_name not in model_cfg: | |
| continue | |
| dataloader_cfg = model_cfg[dataloader_name] | |
| dataset_cfg = dataloader_cfg.dataset | |
| dataset_cls = module_dict.get(dataset_cfg.type, None) | |
| if dataset_cls is None: | |
| continue | |
| if hasattr(dataset_cls, '_load_metainfo') and isinstance( | |
| dataset_cls._load_metainfo, Callable): | |
| meta = dataset_cls._load_metainfo( | |
| dataset_cfg.get('metainfo', None)) | |
| if meta is not None: | |
| return meta | |
| if hasattr(dataset_cls, 'METAINFO'): | |
| return dataset_cls.METAINFO | |
| return None | |
@MMYOLO_TASK.register_module(Task.OBJECT_DETECTION.value)
class YOLOObjectDetection(ObjectDetection):
    """YOLO Object Detection task.

    Registered in ``MMYOLO_TASK`` under ``Task.OBJECT_DETECTION.value`` so
    the MMYOLO codebase can resolve it. Extends mmdet's ``ObjectDetection``
    with mmyolo-specific visualizer metainfo and model building.
    """

    def get_visualizer(self, name: str, save_dir: str):
        """Get visualizer.

        Args:
            name (str): Name of visualizer.
            save_dir (str): Directory to save visualization results.

        Returns:
            Visualizer: A visualizer instance whose ``dataset_meta`` is set
            from the model config's dataset metainfo when one is found.
        """
        # NOTE(review): imported only for its side effect (presumably making
        # mmdet's visualizer available to the parent); the name is unused.
        from mmdet.visualization import DetLocalVisualizer  # noqa: F401,F403
        metainfo = _get_dataset_metainfo(self.model_cfg)
        visualizer = super().get_visualizer(name, save_dir)
        if metainfo is not None:
            visualizer.dataset_meta = metainfo
        return visualizer

    def build_pytorch_model(self,
                            model_checkpoint: Optional[str] = None,
                            cfg_options: Optional[Dict] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize torch model.

        Args:
            model_checkpoint (str, optional): The checkpoint file of torch
                model, defaults to `None`.
            cfg_options (dict, optional): Optional config key-pair
                parameters. Not used by this implementation.
            **kwargs: Not used by this implementation.

        Returns:
            nn.Module: The model built from ``self.model_cfg.model`` with sync
            BN reverted, deploy mode switched on, moved to ``self.device``
            and set to eval mode.
        """
        from copy import deepcopy

        from mmengine.model import revert_sync_batchnorm
        from mmengine.registry import MODELS
        from mmyolo.utils import switch_to_deploy

        model = deepcopy(self.model_cfg.model)
        # Keys from the top-level ``data_preprocessor`` override those from
        # ``preprocess_cfg``; the merged dict is used only when the model
        # config does not already define its own ``data_preprocessor``.
        preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
        preprocess_cfg.update(
            deepcopy(self.model_cfg.get('data_preprocessor', {})))
        model.setdefault('data_preprocessor', preprocess_cfg)
        model = MODELS.build(model)
        if model_checkpoint is not None:
            from mmengine.runner.checkpoint import load_checkpoint
            load_checkpoint(model, model_checkpoint, map_location=self.device)

        model = revert_sync_batchnorm(model)
        switch_to_deploy(model)
        model = model.to(self.device)
        model.eval()
        return model