# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, kaiming_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
                         OptInstanceList, reduce_mean)
from ..task_modules.prior_generators import MlvlPointGenerator
from ..utils import (aligned_bilinear, filter_scores_and_topk, multi_apply,
                     relative_coordinate_maps, select_single_mlvl)
from ..utils.misc import empty_instances
from .base_mask_head import BaseMaskHead
from .fcos_head import FCOSHead

INF = 1e8


class CondInstBboxHead(FCOSHead):
    """CondInst box head used in https://arxiv.org/abs/2003.05664.

    Note that the CondInst box head is an extension of the FCOS head.
    Two differences are described as follows:

    1. CondInst box head predicts a set of params for each instance.
    2. CondInst box head returns the pos_gt_inds and pos_inds.

    Args:
        num_params (int): Number of params for instance segmentation.
    """

    def __init__(self, *args, num_params: int = 169, **kwargs) -> None:
        self.num_params = num_params
        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        super()._init_layers()
        self.controller = nn.Conv2d(
            self.feat_channels, self.num_params, 3, padding=1)
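
    # The controller is a plain 3x3 conv on the regression tower whose output
    # channels carry the per-location dynamic-filter parameters. Its
    # `num_params` is expected to match `CondInstMaskHead.num_params`; the
    # default of 169 corresponds to the default dynamic-conv setup worked out
    # in the comment under `CondInstMaskHead._init_layers` below.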

    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple: scores for each class, bbox predictions, centerness
            predictions and param predictions of input feature maps.
        """
        cls_score, bbox_pred, cls_feat, reg_feat = \
            super(FCOSHead, self).forward_single(x)
        if self.centerness_on_reg:
            centerness = self.conv_centerness(reg_feat)
        else:
            centerness = self.conv_centerness(cls_feat)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(bbox_pred).float()
        if self.norm_on_bbox:
            # bbox_pred needed for gradient computation has been modified
            # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
            # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
            bbox_pred = bbox_pred.clamp(min=0)
            if not self.training:
                bbox_pred *= stride
        else:
            bbox_pred = bbox_pred.exp()

        param_pred = self.controller(reg_feat)
        return cls_score, bbox_pred, centerness, param_pred
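
    # Per level, for an input FPN map of shape (N, C, H, W), the heads above
    # return cls_score (N, cls_out_channels, H, W), bbox_pred (N, 4, H, W),
    # centerness (N, 1, H, W) and param_pred (N, num_params, H, W).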

    def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        centernesses: List[Tensor],
        param_preds: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            centernesses (list[Tensor]): centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            param_preds (List[Tensor]): param_pred for each scale level, each
                is a 4D-tensor, the channel number is num_params.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        # Need stride for rel coord compute
        all_level_points_strides = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device,
            with_stride=True)
        all_level_points = [i[:, :2] for i in all_level_points_strides]
        all_level_strides = [i[:, 2] for i in all_level_points_strides]
        labels, bbox_targets, pos_inds_list, pos_gt_inds_list = \
            self.get_targets(all_level_points, batch_gt_instances)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        num_pos = torch.tensor(
            len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
        num_pos = max(reduce_mean(num_pos), 1.0)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels, avg_factor=num_pos)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        # centerness weighted iou loss
        centerness_denorm = max(
            reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)

        if len(pos_inds) > 0:
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_preds)
            pos_decoded_target_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=centerness_denorm)
            loss_centerness = self.loss_centerness(
                pos_centerness, pos_centerness_targets, avg_factor=num_pos)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        self._raw_positive_infos.update(cls_scores=cls_scores)
        self._raw_positive_infos.update(centernesses=centernesses)
        self._raw_positive_infos.update(param_preds=param_preds)
        self._raw_positive_infos.update(all_level_points=all_level_points)
        self._raw_positive_infos.update(all_level_strides=all_level_strides)
        self._raw_positive_infos.update(pos_gt_inds_list=pos_gt_inds_list)
        self._raw_positive_infos.update(pos_inds_list=pos_inds_list)

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)

    def get_targets(
        self, points: List[Tensor], batch_gt_instances: InstanceList
    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]:
        """Compute regression, classification and centerness targets for
        points in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple: Targets of each level.

            - concat_lvl_labels (list[Tensor]): Labels of each level.
            - concat_lvl_bbox_targets (list[Tensor]): BBox targets of each
              level.
            - pos_inds_list (list[Tensor]): pos_inds of each image.
            - pos_gt_inds_list (List[Tensor]): pos_gt_inds of each image.
        """
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)

        # the number of points per img, per lvl
        num_points = [center.size(0) for center in points]

        # get labels and bbox_targets of each image
        labels_list, bbox_targets_list, pos_inds_list, pos_gt_inds_list = \
            multi_apply(
                self._get_targets_single,
                batch_gt_instances,
                points=concat_points,
                regress_ranges=concat_regress_ranges,
                num_points_per_lvl=num_points)

        # split to per img, per level
        labels_list = [labels.split(num_points, 0) for labels in labels_list]
        bbox_targets_list = [
            bbox_targets.split(num_points, 0)
            for bbox_targets in bbox_targets_list
        ]

        # concat per level image
        concat_lvl_labels = []
        concat_lvl_bbox_targets = []
        for i in range(num_levels):
            concat_lvl_labels.append(
                torch.cat([labels[i] for labels in labels_list]))
            bbox_targets = torch.cat(
                [bbox_targets[i] for bbox_targets in bbox_targets_list])
            if self.norm_on_bbox:
                bbox_targets = bbox_targets / self.strides[i]
            concat_lvl_bbox_targets.append(bbox_targets)
        return (concat_lvl_labels, concat_lvl_bbox_targets, pos_inds_list,
                pos_gt_inds_list)
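
    # When `norm_on_bbox` is enabled, the (l, t, r, b) targets of each level
    # are divided by that level's stride, so, assuming the usual FCOS strides
    # (8, 16, 32, 64, 128), a 64-pixel offset at the stride-8 level becomes a
    # regression target of 8. This mirrors the `bbox_pred *= stride` rescaling
    # applied in `forward_single` at test time.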

    def _get_targets_single(
        self, gt_instances: InstanceData, points: Tensor,
        regress_ranges: Tensor, num_points_per_lvl: List[int]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute regression and classification targets for a single
        image."""
        num_points = points.size(0)
        num_gts = len(gt_instances)
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.get('masks', None)

        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.num_classes), \
                   gt_bboxes.new_zeros((num_points, 4)), \
                   gt_bboxes.new_zeros((0,), dtype=torch.int64), \
                   gt_bboxes.new_zeros((0,), dtype=torch.int64)

        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
            gt_bboxes[:, 3] - gt_bboxes[:, 1])
        # TODO: figure out why these two are different
        # areas = areas[None].expand(num_points, num_gts)
        areas = areas[None].repeat(num_points, 1)
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)

        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        if self.center_sampling:
            # condition1: inside a `center bbox`
            radius = self.center_sample_radius
            # if gt_masks is not None, use the mask centroid to determine
            # the center region rather than the gt_bbox center
            if gt_masks is None:
                center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
                center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
            else:
                h, w = gt_masks.height, gt_masks.width
                masks = gt_masks.to_tensor(
                    dtype=torch.bool, device=gt_bboxes.device)
                yys = torch.arange(
                    0, h, dtype=torch.float32, device=masks.device)
                xxs = torch.arange(
                    0, w, dtype=torch.float32, device=masks.device)
                # m00/m10/m01 represent the moments of a contour;
                # the centroid is computed as (m10/m00, m01/m00)
                m00 = masks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
                m10 = (masks * xxs).sum(dim=-1).sum(dim=-1)
                m01 = (masks * yys[:, None]).sum(dim=-1).sum(dim=-1)
                center_xs = m10 / m00
                center_ys = m01 / m00
            center_xs = center_xs[None].expand(num_points, num_gts)
            center_ys = center_ys[None].expand(num_points, num_gts)
            center_gts = torch.zeros_like(gt_bboxes)
            stride = center_xs.new_zeros(center_xs.shape)

            # project the points on current lvl back to the `original` sizes
            lvl_begin = 0
            for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
                lvl_end = lvl_begin + num_points_lvl
                stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
                lvl_begin = lvl_end

            x_mins = center_xs - stride
            y_mins = center_ys - stride
            x_maxs = center_xs + stride
            y_maxs = center_ys + stride
            center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
                                             x_mins, gt_bboxes[..., 0])
            center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
                                             y_mins, gt_bboxes[..., 1])
            center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
                                             gt_bboxes[..., 2], x_maxs)
            center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
                                             gt_bboxes[..., 3], y_maxs)

            cb_dist_left = xs - center_gts[..., 0]
            cb_dist_right = center_gts[..., 2] - xs
            cb_dist_top = ys - center_gts[..., 1]
            cb_dist_bottom = center_gts[..., 3] - ys
            center_bbox = torch.stack(
                (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom),
                -1)
            inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
        else:
            # condition1: inside a gt bbox
            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0

        # condition2: limit the regression range for each location
        max_regress_distance = bbox_targets.max(-1)[0]
        inside_regress_range = (
            (max_regress_distance >= regress_ranges[..., 0])
            & (max_regress_distance <= regress_ranges[..., 1]))

        # if there are still more than one objects for a location,
        # we choose the one with minimal area
        areas[inside_gt_bbox_mask == 0] = INF
        areas[inside_regress_range == 0] = INF
        min_area, min_area_inds = areas.min(dim=1)

        labels = gt_labels[min_area_inds]
        labels[min_area == INF] = self.num_classes  # set as BG
        bbox_targets = bbox_targets[range(num_points), min_area_inds]

        # return pos_inds & pos_gt_inds
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().reshape(-1)
        pos_gt_inds = min_area_inds[labels < self.num_classes]
        return labels, bbox_targets, pos_inds, pos_gt_inds
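
    # Worked example of the moment-based centroid above (a hypothetical 3x3
    # mask, not taken from the codebase): for foreground pixels at (y, x) in
    # {(0, 1), (1, 1), (1, 2)}, m00 = 3, m10 = 1 + 1 + 2 = 4 and
    # m01 = 0 + 1 + 1 = 2, giving the centroid (x, y) = (4/3, 2/3). Center
    # sampling then builds the radius-limited box around this centroid
    # instead of around the bbox center.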

    def get_positive_infos(self) -> InstanceList:
        """Get positive information from sampling results.

        Returns:
            list[:obj:`InstanceData`]: Positive information of each image,
            usually including positive bboxes, positive labels, positive
            priors, etc.
        """
        assert len(self._raw_positive_infos) > 0

        pos_gt_inds_list = self._raw_positive_infos['pos_gt_inds_list']
        pos_inds_list = self._raw_positive_infos['pos_inds_list']
        num_imgs = len(pos_gt_inds_list)

        cls_score_list = []
        centerness_list = []
        param_pred_list = []
        point_list = []
        stride_list = []
        for cls_score_per_lvl, centerness_per_lvl, param_pred_per_lvl, \
            point_per_lvl, stride_per_lvl in \
                zip(self._raw_positive_infos['cls_scores'],
                    self._raw_positive_infos['centernesses'],
                    self._raw_positive_infos['param_preds'],
                    self._raw_positive_infos['all_level_points'],
                    self._raw_positive_infos['all_level_strides']):
            cls_score_per_lvl = \
                cls_score_per_lvl.permute(
                    0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
            centerness_per_lvl = \
                centerness_per_lvl.permute(
                    0, 2, 3, 1).reshape(num_imgs, -1, 1)
            param_pred_per_lvl = \
                param_pred_per_lvl.permute(
                    0, 2, 3, 1).reshape(num_imgs, -1, self.num_params)
            point_per_lvl = point_per_lvl.unsqueeze(0).repeat(num_imgs, 1, 1)
            stride_per_lvl = stride_per_lvl.unsqueeze(0).repeat(num_imgs, 1)

            cls_score_list.append(cls_score_per_lvl)
            centerness_list.append(centerness_per_lvl)
            param_pred_list.append(param_pred_per_lvl)
            point_list.append(point_per_lvl)
            stride_list.append(stride_per_lvl)
        cls_scores = torch.cat(cls_score_list, dim=1)
        centernesses = torch.cat(centerness_list, dim=1)
        param_preds = torch.cat(param_pred_list, dim=1)
        all_points = torch.cat(point_list, dim=1)
        all_strides = torch.cat(stride_list, dim=1)

        positive_infos = []
        for i, (pos_gt_inds,
                pos_inds) in enumerate(zip(pos_gt_inds_list, pos_inds_list)):
            pos_info = InstanceData()
            pos_info.points = all_points[i][pos_inds]
            pos_info.strides = all_strides[i][pos_inds]
            pos_info.scores = cls_scores[i][pos_inds]
            pos_info.centernesses = centernesses[i][pos_inds]
            pos_info.param_preds = param_preds[i][pos_inds]
            pos_info.pos_assigned_gt_inds = pos_gt_inds
            pos_info.pos_inds = pos_inds
            positive_infos.append(pos_info)
        return positive_infos

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        score_factors: Optional[List[Tensor]] = None,
                        param_preds: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are usually
        multiplied by it to obtain the real score used in NMS, such as the
        centerness in FCOS or the IoU branch in ATSS.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            score_factors (list[Tensor], optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            param_preds (list[Tensor], optional): Params for all scale
                level, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_params, H, W).
            batch_img_metas (list[dict], optional): Batch image meta info.
                Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)

        if score_factors is None:
            # e.g. Retina, FreeAnchor, Foveabox, etc.
            with_score_factors = False
        else:
            # e.g. FCOS, PAA, ATSS, AutoAssign, etc.
            with_score_factors = True
            assert len(cls_scores) == len(score_factors)

        num_levels = len(cls_scores)

        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        all_level_points_strides = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device,
            with_stride=True)
        all_level_points = [i[:, :2] for i in all_level_points_strides]
        all_level_strides = [i[:, 2] for i in all_level_points_strides]

        result_list = []

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            cls_score_list = select_single_mlvl(
                cls_scores, img_id, detach=True)
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            if with_score_factors:
                score_factor_list = select_single_mlvl(
                    score_factors, img_id, detach=True)
            else:
                score_factor_list = [None for _ in range(num_levels)]
            param_pred_list = select_single_mlvl(
                param_preds, img_id, detach=True)

            results = self._predict_by_feat_single(
                cls_score_list=cls_score_list,
                bbox_pred_list=bbox_pred_list,
                score_factor_list=score_factor_list,
                param_pred_list=param_pred_list,
                mlvl_points=all_level_points,
                mlvl_strides=all_level_strides,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list

    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                param_pred_list: List[Tensor],
                                mlvl_points: List[Tensor],
                                mlvl_strides: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            param_pred_list (List[Tensor]): Param prediction from all scale
                levels of a single image, each item has shape
                (num_priors * num_params, H, W).
            mlvl_points (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid.
                It has shape (num_priors, 2).
            mlvl_strides (List[Tensor]): Each element in the list is
                the stride of a single level in feature pyramid.
                It has shape (num_priors, ).
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        if score_factor_list[0] is None:
            # e.g. Retina, FreeAnchor, etc.
            with_score_factors = False
        else:
            # e.g. FCOS, PAA, ATSS, etc.
            with_score_factors = True

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bbox_preds = []
        mlvl_param_preds = []
        mlvl_valid_points = []
        mlvl_valid_strides = []
        mlvl_scores = []
        mlvl_labels = []
        if with_score_factors:
            mlvl_score_factors = []
        else:
            mlvl_score_factors = None

        for level_idx, (cls_score, bbox_pred, score_factor,
                        param_pred, points, strides) in \
                enumerate(zip(cls_score_list, bbox_pred_list,
                              score_factor_list, param_pred_list,
                              mlvl_points, mlvl_strides)):

            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

            dim = self.bbox_coder.encode_size
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
            if with_score_factors:
                score_factor = score_factor.permute(1, 2,
                                                    0).reshape(-1).sigmoid()
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0
                # BG cat_id: num_class
                scores = cls_score.softmax(-1)[:, :-1]
            param_pred = param_pred.permute(1, 2,
                                            0).reshape(-1, self.num_params)

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            score_thr = cfg.get('score_thr', 0)
            results = filter_scores_and_topk(
                scores, score_thr, nms_pre,
                dict(
                    bbox_pred=bbox_pred,
                    param_pred=param_pred,
                    points=points,
                    strides=strides))
            scores, labels, keep_idxs, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            param_pred = filtered_results['param_pred']
            points = filtered_results['points']
            strides = filtered_results['strides']

            if with_score_factors:
                score_factor = score_factor[keep_idxs]

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_param_preds.append(param_pred)
            mlvl_valid_points.append(points)
            mlvl_valid_strides.append(strides)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

            if with_score_factors:
                mlvl_score_factors.append(score_factor)

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = cat_boxes(mlvl_valid_points)
        bboxes = self.bbox_coder.decode(
            priors, bbox_pred, max_shape=img_shape)

        results = InstanceData()
        results.bboxes = bboxes
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)
        results.param_preds = torch.cat(mlvl_param_preds)
        results.points = torch.cat(mlvl_valid_points)
        results.strides = torch.cat(mlvl_valid_strides)
        if with_score_factors:
            results.score_factors = torch.cat(mlvl_score_factors)

        return self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)


class MaskFeatModule(BaseModule):
    """CondInst mask feature map branch used in \
    https://arxiv.org/abs/2003.05664.

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels of the mask feature
            map branch.
        start_level (int): The starting feature map level from the FPN that
            will be used to predict the mask feature map.
        end_level (int): The ending feature map level from the FPN that
            will be used to predict the mask feature map.
        out_channels (int): Number of output channels of the mask feature
            map branch. This is the channel count of the mask
            feature map that is dynamically convolved with the predicted
            kernels.
        mask_stride (int): Downsample factor of the mask feature map output.
            Defaults to 4.
        num_stacked_convs (int): Number of convs in mask feature branch.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int,
                 feat_channels: int,
                 start_level: int,
                 end_level: int,
                 out_channels: int,
                 mask_stride: int = 4,
                 num_stacked_convs: int = 4,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01)
                 ],
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.start_level = start_level
        self.end_level = end_level
        self.mask_stride = mask_stride
        self.num_stacked_convs = num_stacked_convs
        assert start_level >= 0 and end_level >= start_level
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.convs_all_levels = nn.ModuleList()
        for i in range(self.start_level, self.end_level + 1):
            convs_per_level = nn.Sequential()
            convs_per_level.add_module(
                f'conv{i}',
                ConvModule(
                    self.in_channels,
                    self.feat_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=False,
                    bias=False))
            self.convs_all_levels.append(convs_per_level)

        conv_branch = []
        for _ in range(self.num_stacked_convs):
            conv_branch.append(
                ConvModule(
                    self.feat_channels,
                    self.feat_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=False))
        self.conv_branch = nn.Sequential(*conv_branch)

        self.conv_pred = nn.Conv2d(
            self.feat_channels, self.out_channels, 1, stride=1)

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        super().init_weights()
        kaiming_init(self.convs_all_levels, a=1, distribution='uniform')
        kaiming_init(self.conv_branch, a=1, distribution='uniform')
        kaiming_init(self.conv_pred, a=1, distribution='uniform')

    def forward(self, x: Tuple[Tensor]) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            Tensor: The predicted mask feature map.
        """
        inputs = x[self.start_level:self.end_level + 1]
        assert len(inputs) == (self.end_level - self.start_level + 1)
        feature_add_all_level = self.convs_all_levels[0](inputs[0])
        target_h, target_w = feature_add_all_level.size()[2:]
        for i in range(1, len(inputs)):
            input_p = inputs[i]
            x_p = self.convs_all_levels[i](input_p)
            h, w = x_p.size()[2:]
            factor_h = target_h // h
            factor_w = target_w // w
            assert factor_h == factor_w
            feature_per_level = aligned_bilinear(x_p, factor_h)
            feature_add_all_level = feature_add_all_level + \
                feature_per_level

        feature_add_all_level = self.conv_branch(feature_add_all_level)
        feature_pred = self.conv_pred(feature_add_all_level)
        return feature_pred
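
    # Note on resolution: every level after `start_level` is upsampled with
    # `aligned_bilinear` back to the resolution of the first used level
    # before summation (the integer factors above assume the usual
    # power-of-two FPN strides), so the branch output has the stride of the
    # `start_level` feature map. The `mask_stride` argument is expected to be
    # set consistently with that stride by the config.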


class CondInstMaskHead(BaseMaskHead):
    """CondInst mask head used in https://arxiv.org/abs/2003.05664.

    This head outputs the mask for CondInst.

    Args:
        mask_feature_head (dict): Config of MaskFeatModule.
        num_layers (int): Number of dynamic conv layers.
        feat_channels (int): Number of channels in the dynamic conv.
        mask_out_stride (int): The stride of the output masks.
        size_of_interest (int): The size of the region used in rel coord.
        max_masks_to_train (int): Maximum number of masks to train for
            each image.
        topk_masks_per_img (int): Maximum number of masks used for loss
            computation per image; the budget is split evenly over the GT
            instances. Defaults to -1 (keep all).
        loss_mask (:obj:`ConfigDict` or dict, optional): Config of
            mask loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config
            of head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            head.
    """

    def __init__(self,
                 mask_feature_head: ConfigType,
                 num_layers: int = 3,
                 feat_channels: int = 8,
                 mask_out_stride: int = 4,
                 size_of_interest: int = 8,
                 max_masks_to_train: int = -1,
                 topk_masks_per_img: int = -1,
                 loss_mask: ConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None) -> None:
        super().__init__()
        self.mask_feature_head = MaskFeatModule(**mask_feature_head)
        self.mask_feat_stride = self.mask_feature_head.mask_stride
        self.in_channels = self.mask_feature_head.out_channels
        self.num_layers = num_layers
        self.feat_channels = feat_channels
        self.size_of_interest = size_of_interest
        self.mask_out_stride = mask_out_stride
        self.max_masks_to_train = max_masks_to_train
        self.topk_masks_per_img = topk_masks_per_img
        self.prior_generator = MlvlPointGenerator([self.mask_feat_stride])

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.loss_mask = MODELS.build(loss_mask)
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        weight_nums, bias_nums = [], []
        for i in range(self.num_layers):
            if i == 0:
                weight_nums.append(
                    (self.in_channels + 2) * self.feat_channels)
                bias_nums.append(self.feat_channels)
            elif i == self.num_layers - 1:
                weight_nums.append(self.feat_channels * 1)
                bias_nums.append(1)
            else:
                weight_nums.append(self.feat_channels * self.feat_channels)
                bias_nums.append(self.feat_channels)
        self.weight_nums = weight_nums
        self.bias_nums = bias_nums
        self.num_params = sum(weight_nums) + sum(bias_nums)
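
    # Worked example of the parameter count (assuming the defaults
    # in_channels=8, feat_channels=8, num_layers=3): the three dynamic conv
    # layers need (8 + 2) * 8 = 80, 8 * 8 = 64 and 8 * 1 = 8 weights plus
    # 8 + 8 + 1 = 17 biases, i.e. 80 + 64 + 8 + 17 = 169 params per instance,
    # matching the num_params=169 default of CondInstBboxHead's controller.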

    def parse_dynamic_params(
            self, params: Tensor) -> Tuple[List[Tensor], List[Tensor]]:
        """Parse the dynamic params for dynamic conv."""
        num_insts = params.size(0)
        params_splits = list(
            torch.split_with_sizes(
                params, self.weight_nums + self.bias_nums, dim=1))
        weight_splits = params_splits[:self.num_layers]
        bias_splits = params_splits[self.num_layers:]
        for i in range(self.num_layers):
            if i < self.num_layers - 1:
                # Note: this reshape implicitly assumes
                # in_channels == feat_channels, which holds for the default
                # CondInst configuration (both are 8).
                weight_splits[i] = weight_splits[i].reshape(
                    num_insts * self.in_channels, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(num_insts *
                                                        self.in_channels)
            else:
                # out_channels x in_channels x 1 x 1
                weight_splits[i] = weight_splits[i].reshape(
                    num_insts * 1, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(num_insts)

        return weight_splits, bias_splits

    def dynamic_conv_forward(self, features: Tensor, weights: List[Tensor],
                             biases: List[Tensor], num_insts: int) -> Tensor:
        """Dynamic forward; every conv layer except the last is followed by a
        ReLU."""
        n_layers = len(weights)
        x = features
        for i, (w, b) in enumerate(zip(weights, biases)):
            x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts)
            if i < n_layers - 1:
                x = F.relu(x)
        return x
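
    # Illustrative sketch (comments only, not executed): how the grouped conv
    # above implements per-instance dynamic filters. With num_insts=2
    # instances, in_channels=8 mask-feature channels (+2 relative-coordinate
    # channels) and feat_channels=8, the inputs are flattened into one batch
    # item so each instance gets its own filter group:
    #
    #   feats = torch.randn(1, 2 * (8 + 2), 40, 40)   # (1, num_insts*(C+2), H, W)
    #   w0 = torch.randn(2 * 8, 8 + 2, 1, 1)          # layer-0 weights, grouped
    #   b0 = torch.randn(2 * 8)
    #   out = F.conv2d(feats, w0, bias=b0, groups=2)  # -> (1, 2 * 8, 40, 40)
    #
    # Stacking num_layers of these 1x1 grouped convs (ReLU in between) and a
    # final 1-channel layer yields one mask logit map per instance.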

    def forward(self, x: tuple, positive_infos: InstanceList) -> tuple:
        """Forward features from the upstream network to get the mask feature
        map, then generate instance masks by applying per-instance dynamic
        convolutions whose filters come from the bbox head's controller
        params.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each a
                4D-tensor.
            positive_infos (List[:obj:`InstanceData`]): Positive information
                computed by the detection head.

        Returns:
            tuple: Predicted instance segmentation masks of each image.
        """
        mask_feats = self.mask_feature_head(x)
        return multi_apply(self.forward_single, mask_feats, positive_infos)

    def forward_single(self, mask_feat: Tensor,
                       positive_info: InstanceData) -> Tensor:
        """Forward features of a single image."""
        pos_param_preds = positive_info.get('param_preds')
        pos_points = positive_info.get('points')
        pos_strides = positive_info.get('strides')

        num_inst = pos_param_preds.shape[0]
        mask_feat = mask_feat[None].repeat(num_inst, 1, 1, 1)
        _, _, H, W = mask_feat.size()
        if num_inst == 0:
            return (pos_param_preds.new_zeros((0, 1, H, W)), )

        locations = self.prior_generator.single_level_grid_priors(
            mask_feat.size()[2:], 0, device=mask_feat.device)

        rel_coords = relative_coordinate_maps(locations, pos_points,
                                              pos_strides,
                                              self.size_of_interest,
                                              mask_feat.size()[2:])
        mask_head_inputs = torch.cat([rel_coords, mask_feat], dim=1)
        mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)

        weights, biases = self.parse_dynamic_params(pos_param_preds)
        mask_preds = self.dynamic_conv_forward(mask_head_inputs, weights,
                                               biases, num_inst)
        mask_preds = mask_preds.reshape(-1, H, W)
        mask_preds = aligned_bilinear(
            mask_preds.unsqueeze(0),
            int(self.mask_feat_stride / self.mask_out_stride)).squeeze(0)

        return (mask_preds, )
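
    # The `rel_coords` produced above give every instance a 2-channel map of
    # (x, y) offsets between each mask-feature location and that instance's
    # center point, normalized by `size_of_interest` times the point's
    # stride; concatenated with the shared mask features, this is what
    # conditions each instance's dynamic filters on location. The final
    # `aligned_bilinear` upsamples by mask_feat_stride / mask_out_stride
    # (a factor of 2 under the assumption of stride-8 mask features and
    # stride-4 output masks).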

    def loss_by_feat(self, mask_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict],
                     positive_infos: InstanceList, **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mask_preds (list[Tensor]): List of predicted masks, each has
                shape (num_pos, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.
            positive_infos (List[:obj:`InstanceData`]): Information of
                positive samples of each image that are assigned in detection
                head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert positive_infos is not None, \
            'positive_infos should not be None in `CondInstMaskHead`'
        losses = dict()

        loss_mask = 0.
        num_imgs = len(mask_preds)
        total_pos = 0

        for idx in range(num_imgs):
            (mask_pred, pos_mask_targets, num_pos) = \
                self._get_targets_single(
                    mask_preds[idx], batch_gt_instances[idx],
                    positive_infos[idx])
            # mask loss
            total_pos += num_pos
            if num_pos == 0 or pos_mask_targets is None:
                loss = mask_pred.new_zeros(1).mean()
            else:
                loss = self.loss_mask(
                    mask_pred, pos_mask_targets,
                    reduction_override='none').sum()
            loss_mask += loss

        if total_pos == 0:
            total_pos += 1  # avoid nan
        loss_mask = loss_mask / total_pos
        losses.update(loss_mask=loss_mask)
        return losses

    def _get_targets_single(self, mask_preds: Tensor,
                            gt_instances: InstanceData,
                            positive_info: InstanceData):
        """Compute targets for predictions of single image.

        Args:
            mask_preds (Tensor): Predicted masks of a single image, with
                shape (num_pos, H, W).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned in detection head. It usually
                contains following keys.

                - pos_assigned_gt_inds (Tensor): Assigned GT indices of
                  positive proposals, has shape (num_pos, )
                - pos_inds (Tensor): Positive index of image, has
                  shape (num_pos, ).
                - param_pred (Tensor): Positive param predictions
                  with shape (num_pos, num_params).

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - mask_preds (Tensor): Positive predicted mask with shape
              (num_pos, mask_h, mask_w).
            - pos_mask_targets (Tensor): Positive mask targets with shape
              (num_pos, mask_h, mask_w).
            - num_pos (int): Number of positive samples.
        """
        gt_bboxes = gt_instances.bboxes
        device = gt_bboxes.device
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device).float()

        # process with mask targets
        pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
        scores = positive_info.get('scores')
        centernesses = positive_info.get('centernesses')
        num_pos = pos_assigned_gt_inds.size(0)

        if gt_masks.size(0) == 0 or num_pos == 0:
            return mask_preds, None, 0
        # Since we're producing (near) full image masks,
        # it'd take too much vram to backprop on every single mask.
        # Thus we select only a subset.
        if (self.max_masks_to_train != -1) and \
                (num_pos > self.max_masks_to_train):
            perm = torch.randperm(num_pos)
            select = perm[:self.max_masks_to_train]
            mask_preds = mask_preds[select]
            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
            num_pos = self.max_masks_to_train
        elif self.topk_masks_per_img != -1:
            unique_gt_inds = pos_assigned_gt_inds.unique()
            num_inst_per_gt = max(
                int(self.topk_masks_per_img / len(unique_gt_inds)), 1)

            keep_mask_preds = []
            keep_pos_assigned_gt_inds = []
            for gt_ind in unique_gt_inds:
                per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
                mask_preds_per_inst = mask_preds[per_inst_pos_inds]
                gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
                if sum(per_inst_pos_inds) > num_inst_per_gt:
                    per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
                        dim=1)[0]
                    per_inst_centerness = centernesses[
                        per_inst_pos_inds].sigmoid().reshape(-1, )
                    select = (per_inst_scores * per_inst_centerness).topk(
                        k=num_inst_per_gt, dim=0)[1]
                    mask_preds_per_inst = mask_preds_per_inst[select]
                    gt_inds_per_inst = gt_inds_per_inst[select]

                keep_mask_preds.append(mask_preds_per_inst)
                keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
            mask_preds = torch.cat(keep_mask_preds)
            pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
            num_pos = pos_assigned_gt_inds.size(0)

        # Follow the original implementation: downsample the GT masks to the
        # output mask resolution.
        start = int(self.mask_out_stride // 2)
        gt_masks = gt_masks[:, start::self.mask_out_stride,
                            start::self.mask_out_stride]
        gt_masks = gt_masks.gt(0.5).float()
        pos_mask_targets = gt_masks[pos_assigned_gt_inds]

        return (mask_preds, pos_mask_targets, num_pos)
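
    # With the default mask_out_stride of 4, the slicing above samples GT
    # mask pixels at rows/columns 2, 6, 10, ... (start = 4 // 2 = 2), i.e.
    # the centers of the 4x4 cells covered by each output mask logit, so
    # targets and predictions stay aligned without any interpolation.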

    def predict_by_feat(self,
                        mask_preds: List[Tensor],
                        results_list: InstanceList,
                        batch_img_metas: List[dict],
                        rescale: bool = True,
                        **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mask_preds (list[Tensor]): Predicted masks of each image, each
                with shape (num_instances, H, W).
            results_list (List[:obj:`InstanceData`]): BBoxHead results.
            batch_img_metas (list[dict]): Meta information of all images.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains
            following keys.

            - scores (Tensor): Classification scores, has shape
              (num_instance,).
            - labels (Tensor): Has shape (num_instances,).
            - masks (Tensor): Processed mask results, has
              shape (num_instances, h, w).
        """
        assert len(mask_preds) == len(results_list) == len(batch_img_metas)

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            results = results_list[img_id]
            bboxes = results.bboxes
            mask_pred = mask_preds[img_id]
            if bboxes.shape[0] == 0 or mask_pred.shape[0] == 0:
                results_list[img_id] = empty_instances(
                    [img_meta],
                    bboxes.device,
                    task_type='mask',
                    instance_results=[results])[0]
            else:
                im_mask = self._predict_by_feat_single(
                    mask_preds=mask_pred,
                    bboxes=bboxes,
                    img_meta=img_meta,
                    rescale=rescale)
                results.masks = im_mask
        return results_list

    def _predict_by_feat_single(self,
                                mask_preds: Tensor,
                                bboxes: Tensor,
                                img_meta: dict,
                                rescale: bool,
                                cfg: OptConfigType = None):
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            mask_preds (Tensor): Predicted masks of a single image, has
                shape (num_instances, H, W).
            bboxes (Tensor): Predicted bboxes of a single image, has
                shape (num_instances, 4).
            img_meta (dict): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If rescale is False, then returned masks will
                fit the scale of imgs[0].
            cfg (dict, optional): Config used in test phase.
                Defaults to None.

        Returns:
            Tensor: Binary masks of a single image after the post process,
            with shape (num_instances, h, w).
        """
        cfg = self.test_cfg if cfg is None else cfg

        img_h, img_w = img_meta['img_shape'][:2]
        ori_h, ori_w = img_meta['ori_shape'][:2]

        mask_preds = mask_preds.sigmoid().unsqueeze(0)
        mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
        mask_preds = mask_preds[:, :, :img_h, :img_w]
        if rescale:  # rescale the bboxes to the original image space in place
            scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
                (1, 2))
            bboxes /= scale_factor

            masks = F.interpolate(
                mask_preds, (ori_h, ori_w),
                mode='bilinear',
                align_corners=False).squeeze(0) > cfg.mask_thr
        else:
            masks = mask_preds.squeeze(0) > cfg.mask_thr

        return masks
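
# Rough usage sketch (comments only): how a CondInst-style detector might
# wire these heads together during training. `neck_feats`,
# `batch_gt_instances` and `batch_img_metas` are hypothetical placeholders,
# not symbols defined in this file.
#
#   bbox_outs = bbox_head(neck_feats)                 # cls, bbox, ctr, params
#   bbox_losses = bbox_head.loss_by_feat(
#       *bbox_outs, batch_gt_instances, batch_img_metas)
#   positive_infos = bbox_head.get_positive_infos()   # per-image InstanceData
#   mask_preds, = mask_head(neck_feats, positive_infos)
#   mask_losses = mask_head.loss_by_feat(
#       mask_preds, batch_gt_instances, batch_img_metas, positive_infos)
#
# At test time, `predict_by_feat` of the bbox head keeps `param_preds`,
# `points` and `strides` alongside the surviving boxes so the mask head can
# run its dynamic convolutions on exactly those detections.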