Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List | |
| import torch | |
| import torch.nn as nn | |
| from mmengine.model import BaseModule | |
| from mmengine.structures import InstanceData | |
| from torch import Tensor | |
| from mmdet.registry import MODELS | |
| from mmdet.structures.bbox import bbox_cxcywh_to_xyxy | |
| from mmdet.structures.det_data_sample import SampleList | |
| from mmdet.utils import InstanceList, OptConfigType | |
# NOTE(review): `MODELS` is imported at file level but no registration
# decorator is visible on this class -- confirm whether
# `@MODELS.register_module()` was dropped and should be restored.
class EmbeddingRPNHead(BaseModule):
    """RPNHead in the `Sparse R-CNN <https://arxiv.org/abs/2011.12450>`_ .

    Unlike traditional RPNHead, this module does not need FPN input, but just
    decode `init_proposal_bboxes` and expand the first dimension of
    `init_proposal_bboxes` and `init_proposal_features` to the batch_size.

    Args:
        num_proposals (int): Number of init_proposals. Defaults to 100.
        proposal_feature_channel (int): Channel number of
            init_proposal_feature. Defaults to 256.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict]): Initialization config dict. Must be left as None;
            any other value is rejected. Defaults to None.
    """

    def __init__(self,
                 num_proposals: int = 100,
                 proposal_feature_channel: int = 256,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        # `**kwargs` is accepted and ignored so that extra config keys do
        # not make construction fail.
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg=init_cfg)
        self.num_proposals = num_proposals
        self.proposal_feature_channel = proposal_feature_channel
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize a sparse set of proposal boxes and proposal features.

        Both are stored as `nn.Embedding` tables with one row per proposal:
        4 values per row for the boxes and `proposal_feature_channel` values
        per row for the features.
        """
        self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
        self.init_proposal_features = nn.Embedding(
            self.num_proposals, self.proposal_feature_channel)

    def init_weights(self) -> None:
        """Initialize the init_proposal_bboxes as normalized
        [c_x, c_y, w, h], and we initialize it to the size of the entire
        image.

        The first two columns (centre) are set to 0.5 and the last two
        columns (width, height) are set to 1.
        """
        super().init_weights()
        nn.init.constant_(self.init_proposal_bboxes.weight[:, :2], 0.5)
        nn.init.constant_(self.init_proposal_bboxes.weight[:, 2:], 1)

    def _decode_init_proposals(self, x: List[Tensor],
                               batch_data_samples: SampleList) -> InstanceList:
        """Decode init_proposal_bboxes according to the size of images and
        expand dimension of init_proposal_features to batch_size.

        Args:
            x (list[Tensor]): List of FPN features. Only `x[0]` is used, and
                only as the dtype/device template for the image-size tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. The `metainfo` of each sample must provide
                `img_shape`, whose first two entries are read as
                (height, width).

        Returns:
            List[:obj:`InstanceData`]: One item per data sample (an empty
            list when `batch_data_samples` is empty). Each item contains
            the following keys.

            - bboxes: Decoded proposal bboxes in xyxy order, scaled by the
              image size, has shape (num_proposals, 4).
            - features: A copy of the init_proposal_features weight,
              has shape (num_proposals, proposal_feature_channel).
            - imgs_whwh: Tensor with shape (num_proposals, 4), the
              dimension means
              [img_width, img_height, img_width, img_height].
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        num_imgs = len(batch_img_metas)
        if num_imgs == 0:
            # Nothing to decode; building the size tensor below would fail
            # on an empty batch.
            return []

        # Learned boxes are stored as cxcywh; convert a copy to xyxy so the
        # embedding weight itself is not modified.
        proposals = bbox_cxcywh_to_xyxy(
            self.init_proposal_bboxes.weight.clone())

        # One [w, h, w, h] row per image, on the same dtype/device as x[0].
        whwh_rows = []
        for meta in batch_img_metas:
            h, w = meta['img_shape'][:2]
            whwh_rows.append([w, h, w, h])
        imgs_whwh = x[0].new_tensor(whwh_rows)

        # (num_imgs, 1, 4) so that it broadcasts against (num_proposals, 4)
        # and yields per-image proposals of shape
        # (num_imgs, num_proposals, 4).
        imgs_whwh = imgs_whwh[:, None, :]
        proposals = proposals * imgs_whwh

        rpn_results_list = []
        for idx in range(num_imgs):
            rpn_results = InstanceData()
            rpn_results.bboxes = proposals[idx]
            rpn_results.imgs_whwh = imgs_whwh[idx].repeat(
                self.num_proposals, 1)
            # Each image gets its own copy of the shared proposal features.
            rpn_results.features = self.init_proposal_features.weight.clone()
            rpn_results_list.append(rpn_results)
        return rpn_results_list

    def loss(self, *args, **kwargs):
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Raises:
            NotImplementedError: Always; this head has no loss of its own.
        """
        raise NotImplementedError(
            'EmbeddingRPNHead does not have `loss`, please use '
            '`predict` or `loss_and_predict` instead.')

    def predict(self, x: List[Tensor], batch_data_samples: SampleList,
                **kwargs) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (list[Tensor]): List of FPN features.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples.

        Returns:
            List[:obj:`InstanceData`]: The decoded initial proposals, as
            returned by `_decode_init_proposals`.
        """
        # `**kwargs` is accepted and ignored to keep the call signature
        # tolerant of extra arguments.
        return self._decode_init_proposals(
            x=x, batch_data_samples=batch_data_samples)

    def loss_and_predict(self, x: List[Tensor], batch_data_samples: SampleList,
                         **kwargs) -> tuple:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.

        Args:
            x (list[Tensor]): List of FPN features.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples.

        Returns:
            tuple: `(losses, predictions)` where `losses` is always an empty
            dict and `predictions` is the list returned by
            `_decode_init_proposals`.
        """
        # `**kwargs` is accepted and ignored to keep the call signature
        # tolerant of extra arguments.
        predictions = self._decode_init_proposals(
            x=x, batch_data_samples=batch_data_samples)
        return dict(), predictions