# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.utils.misc import floordiv
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
from ..layers import mask_matrix_nms
from ..utils import center_of_mass, generate_coordinate, multi_apply
from .base_mask_head import BaseMaskHead
from ...structures.mask import mask2bbox


@MODELS.register_module()
class SOLOHead(BaseMaskHead):
    """SOLO mask head used in `SOLO: Segmenting Objects by Locations
    <https://arxiv.org/abs/1912.04488>`_ .

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
            Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 4.
        strides (tuple): Downsample factor of each feature map.
        scale_ranges (tuple[tuple[int, int]]): Scale range of multiple
            level masks, in the format [(min1, max1), (min2, max2), ...].
            A range of (16, 64) means that instances whose scale
            ``sqrt(w * h)`` falls between 16 and 64 are assigned to this
            level.
        pos_scale (float): Constant scale factor to control the center region.
        num_grids (list[int]): Each level divides the image into a uniform
            num_grid x num_grid grid; each feature map has a different grid
            value. The number of output channels of the mask branch is
            grid ** 2. Defaults to [40, 36, 24, 16, 12].
        cls_down_index (int): The index of the downsample operation in the
            classification branch. Defaults to 0.
        loss_mask (dict): Config of mask loss.
        loss_cls (dict): Config of classification loss.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to norm_cfg=dict(type='GN', num_groups=32,
            requires_grad=True).
        train_cfg (dict): Training config of head.
        test_cfg (dict): Testing config of head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
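
    The following is a minimal shape-check sketch rather than a snippet from
    an official config; the sizes are illustrative and it assumes the default
    Dice/Focal losses are resolvable through :obj:`MODELS`.

    Example:
        >>> import torch
        >>> self = SOLOHead(num_classes=80, in_channels=256)
        >>> feats = tuple(
        ...     torch.rand(1, 256, 64 // 2**i, 64 // 2**i) for i in range(5))
        >>> mask_preds, cls_preds = self.forward(feats)
        >>> len(mask_preds) == len(cls_preds) == 5
        True
        >>> tuple(mask_preds[0].shape)  # 40**2 grid cells at level 0
        (1, 1600, 64, 64)
        >>> tuple(cls_preds[0].shape)  # class scores per grid cell
        (1, 80, 40, 40)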
| """ | |

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 4,
        strides: tuple = (4, 8, 16, 32, 64),
        scale_ranges: tuple = ((8, 32), (16, 64), (32, 128), (64, 256),
                               (128, 512)),
        pos_scale: float = 0.2,
        num_grids: list = [40, 36, 24, 16, 12],
        cls_down_index: int = 0,
        loss_mask: ConfigType = dict(
            type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cls: ConfigType = dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg: ConfigType = dict(
            type='GN', num_groups=32, requires_grad=True),
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: MultiConfig = [
            dict(type='Normal', layer='Conv2d', std=0.01),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_mask_list')),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_cls'))
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.cls_out_channels = self.num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.num_grids = num_grids
        # number of FPN feats
        self.num_levels = len(strides)
        assert self.num_levels == len(scale_ranges) == len(num_grids)
        self.scale_ranges = scale_ranges
        self.pos_scale = pos_scale

        self.cls_down_index = cls_down_index
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.norm_cfg = norm_cfg
        self.init_cfg = init_cfg
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # the mask branch consumes two extra coordinate channels
            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
        self.conv_mask_list = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list.append(
                nn.Conv2d(self.feat_channels, num_grid**2, 1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def resize_feats(self, x: Tuple[Tensor]) -> List[Tensor]:
        """Downsample the first feature map and upsample the last one so all
        levels fit the resolutions SOLO expects.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            list[Tensor]: Features after resizing, each is a 4D-tensor.
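
        The following illustrates the effect on spatial sizes only (a sketch
        with made-up sizes, assuming five input levels and that the head can
        be constructed through :obj:`MODELS`):

        Example:
            >>> import torch
            >>> sizes = [64, 32, 16, 8, 4]
            >>> x = tuple(torch.rand(1, 256, s, s) for s in sizes)
            >>> head = SOLOHead(num_classes=80, in_channels=256)
            >>> [tuple(o.shape[-2:]) for o in head.resize_feats(x)]
            [(32, 32), (32, 32), (16, 16), (8, 8), (8, 8)]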
| """ | |
| out = [] | |
| for i in range(len(x)): | |
| if i == 0: | |
| out.append( | |
| F.interpolate(x[0], scale_factor=0.5, mode='bilinear')) | |
| elif i == len(x) - 1: | |
| out.append( | |
| F.interpolate( | |
| x[i], size=x[i - 1].shape[-2:], mode='bilinear')) | |
| else: | |
| out.append(x[i]) | |
| return out | |

    def forward(self, x: Tuple[Tensor]) -> tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask prediction.

                - mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                  Each element in the list has shape
                  (batch_size, num_grids**2, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mlvl_mask_preds = []
        mlvl_cls_preds = []
        for i in range(self.num_levels):
            x = feats[i]
            mask_feat = x
            cls_feat = x
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)

            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)

            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')
            mask_preds = self.conv_mask_list[i](mask_feat)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_preds = F.interpolate(
                    mask_preds.sigmoid(), size=upsampled_size, mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum: keep a grid cell's score only if it is
                # the maximum in its 2x2 neighborhood (a cheap point NMS)
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mlvl_mask_preds.append(mask_preds)
            mlvl_cls_preds.append(cls_pred)
        return mlvl_mask_preds, mlvl_cls_preds

    def loss_by_feat(self, mlvl_mask_preds: List[Tensor],
                     mlvl_cls_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask head.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2, h, w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_levels = self.num_levels
        num_imgs = len(batch_img_metas)

        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]

        # each `BoolTensor` in `pos_masks` marks whether the corresponding
        # grid point is a positive sample
        pos_mask_targets, labels, pos_masks = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            featmap_sizes=featmap_sizes)

        # change the outer list from per-image to per-level
        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
        mlvl_pos_masks = [[] for _ in range(num_levels)]
        mlvl_labels = [[] for _ in range(num_levels)]
        for img_id in range(num_imgs):
            assert num_levels == len(pos_mask_targets[img_id])
            for lvl in range(num_levels):
                mlvl_pos_mask_targets[lvl].append(
                    pos_mask_targets[img_id][lvl])
                mlvl_pos_mask_preds[lvl].append(
                    mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...])
                mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())

        # concatenate multiple images per level
        temp_mlvl_cls_preds = []
        for lvl in range(num_levels):
            mlvl_pos_mask_targets[lvl] = torch.cat(
                mlvl_pos_mask_targets[lvl], dim=0)
            mlvl_pos_mask_preds[lvl] = torch.cat(
                mlvl_pos_mask_preds[lvl], dim=0)
            mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels))

        num_pos = sum(item.sum() for item in mlvl_pos_masks)
        # dice loss
        loss_mask = []
        for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
            if pred.size()[0] == 0:
                loss_mask.append(pred.sum().unsqueeze(0))
                continue
            loss_mask.append(
                self.loss_mask(pred, target, reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = torch.cat(loss_mask).mean()

        flatten_labels = torch.cat(mlvl_labels)
        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)

    def _get_targets_single(self,
                            gt_instances: InstanceData,
                            featmap_sizes: Optional[list] = None) -> tuple:
        """Compute targets for predictions of a single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            featmap_sizes (list[:obj:`torch.size`]): Size of each
                feature map from feature pyramid, each element
                means (feat_h, feat_w). Defaults to None.

        Returns:
            Tuple: Usually returns a tuple containing targets for predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element
                  represents the binary mask targets for positive points in
                  this level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_pos_masks (list[Tensor]): Each element is
                  a `BoolTensor` to represent whether the
                  corresponding point in single level
                  is positive, has shape (num_grid**2,).
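
        Note:
            A grid cell is positive when it falls inside the scaled center
            region of an instance. As a worked example (made-up numbers):
            with ``num_grid=40`` and an upsampled width of 800 pixels, a
            mass center at ``center_w=420`` lands in grid column
            ``(420 / 800) * 40 = 21`` (truncated toward zero), and the
            positive region extends the ``pos_scale``-scaled half-extents of
            the box around that cell, clipped to at most one cell beyond the
            center cell in each direction.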
| """ | |
| gt_labels = gt_instances.labels | |
| device = gt_labels.device | |
| gt_bboxes = gt_instances.bboxes | |
| gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) * | |
| (gt_bboxes[:, 3] - gt_bboxes[:, 1])) | |
| gt_masks = gt_instances.masks.to_tensor( | |
| dtype=torch.bool, device=device) | |
| mlvl_pos_mask_targets = [] | |
| mlvl_labels = [] | |
| mlvl_pos_masks = [] | |
| for (lower_bound, upper_bound), stride, featmap_size, num_grid \ | |
| in zip(self.scale_ranges, self.strides, | |
| featmap_sizes, self.num_grids): | |
| mask_target = torch.zeros( | |
| [num_grid**2, featmap_size[0], featmap_size[1]], | |
| dtype=torch.uint8, | |
| device=device) | |
| # FG cat_id: [0, num_classes -1], BG cat_id: num_classes | |
| labels = torch.zeros([num_grid, num_grid], | |
| dtype=torch.int64, | |
| device=device) + self.num_classes | |
| pos_mask = torch.zeros([num_grid**2], | |
| dtype=torch.bool, | |
| device=device) | |
| gt_inds = ((gt_areas >= lower_bound) & | |
| (gt_areas <= upper_bound)).nonzero().flatten() | |
| if len(gt_inds) == 0: | |
| mlvl_pos_mask_targets.append( | |
| mask_target.new_zeros(0, featmap_size[0], featmap_size[1])) | |
| mlvl_labels.append(labels) | |
| mlvl_pos_masks.append(pos_mask) | |
| continue | |
| hit_gt_bboxes = gt_bboxes[gt_inds] | |
| hit_gt_labels = gt_labels[gt_inds] | |
| hit_gt_masks = gt_masks[gt_inds, ...] | |
| pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] - | |
| hit_gt_bboxes[:, 0]) * self.pos_scale | |
| pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] - | |
| hit_gt_bboxes[:, 1]) * self.pos_scale | |
| # Make sure hit_gt_masks has a value | |
| valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0 | |
| output_stride = stride / 2 | |
| for gt_mask, gt_label, pos_h_range, pos_w_range, \ | |
| valid_mask_flag in \ | |
| zip(hit_gt_masks, hit_gt_labels, pos_h_ranges, | |
| pos_w_ranges, valid_mask_flags): | |
| if not valid_mask_flag: | |
| continue | |
| upsampled_size = (featmap_sizes[0][0] * 4, | |
| featmap_sizes[0][1] * 4) | |
| center_h, center_w = center_of_mass(gt_mask) | |
| coord_w = int( | |
| floordiv((center_w / upsampled_size[1]), (1. / num_grid), | |
| rounding_mode='trunc')) | |
| coord_h = int( | |
| floordiv((center_h / upsampled_size[0]), (1. / num_grid), | |
| rounding_mode='trunc')) | |
| # left, top, right, down | |
| top_box = max( | |
| 0, | |
| int( | |
| floordiv( | |
| (center_h - pos_h_range) / upsampled_size[0], | |
| (1. / num_grid), | |
| rounding_mode='trunc'))) | |
| down_box = min( | |
| num_grid - 1, | |
| int( | |
| floordiv( | |
| (center_h + pos_h_range) / upsampled_size[0], | |
| (1. / num_grid), | |
| rounding_mode='trunc'))) | |
| left_box = max( | |
| 0, | |
| int( | |
| floordiv( | |
| (center_w - pos_w_range) / upsampled_size[1], | |
| (1. / num_grid), | |
| rounding_mode='trunc'))) | |
| right_box = min( | |
| num_grid - 1, | |
| int( | |
| floordiv( | |
| (center_w + pos_w_range) / upsampled_size[1], | |
| (1. / num_grid), | |
| rounding_mode='trunc'))) | |
| top = max(top_box, coord_h - 1) | |
| down = min(down_box, coord_h + 1) | |
| left = max(coord_w - 1, left_box) | |
| right = min(right_box, coord_w + 1) | |
| labels[top:(down + 1), left:(right + 1)] = gt_label | |
| # ins | |
| gt_mask = np.uint8(gt_mask.cpu().numpy()) | |
| # Follow the original implementation, F.interpolate is | |
| # different from cv2 and opencv | |
| gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride) | |
| gt_mask = torch.from_numpy(gt_mask).to(device=device) | |
| for i in range(top, down + 1): | |
| for j in range(left, right + 1): | |
| index = int(i * num_grid + j) | |
| mask_target[index, :gt_mask.shape[0], :gt_mask. | |
| shape[1]] = gt_mask | |
| pos_mask[index] = True | |
| mlvl_pos_mask_targets.append(mask_target[pos_mask]) | |
| mlvl_labels.append(labels) | |
| mlvl_pos_masks.append(pos_mask) | |
| return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks | |

    def predict_by_feat(self, mlvl_mask_preds: List[Tensor],
                        mlvl_cls_scores: List[Tensor],
                        batch_img_metas: List[dict], **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2, h, w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            batch_img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        mlvl_cls_scores = [
            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
        ]
        assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
        num_levels = len(mlvl_cls_scores)

        results_list = []
        for img_id in range(len(batch_img_metas)):
            cls_pred_list = [
                mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
                for lvl in range(num_levels)
            ]
            mask_pred_list = [
                mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
            ]

            cls_pred_list = torch.cat(cls_pred_list, dim=0)
            mask_pred_list = torch.cat(mask_pred_list, dim=0)
            img_meta = batch_img_metas[img_id]

            results = self._predict_by_feat_single(
                cls_pred_list, mask_pred_list, img_meta=img_meta)
            results_list.append(results)
        return results_list

    def _predict_by_feat_single(self,
                                cls_scores: Tensor,
                                mask_preds: Tensor,
                                img_meta: dict,
                                cfg: OptConfigType = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            cls_scores (Tensor): Classification score of all points
                in single image, has shape (num_points, num_classes).
            mask_preds (Tensor): Mask prediction of all points in
                single image, has shape (num_points, feat_h, feat_w).
            img_meta (dict): Meta information of corresponding image.
            cfg (dict, optional): Config used in test phase.
                Defaults to None.

        Returns:
            :obj:`InstanceData`: Processed results of single image.
            It usually contains the following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
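
        Note:
            Scores are rescored with "maskness" before matrix NMS: each kept
            candidate's classification score is multiplied by the mean
            predicted probability inside its binarized mask,
            ``(mask_preds * masks).sum((1, 2)) / sum_masks``. This restates
            the computation in the body below rather than adding a new step.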
| """ | |
| def empty_results(cls_scores, ori_shape): | |
| """Generate a empty results.""" | |
| results = InstanceData() | |
| results.scores = cls_scores.new_ones(0) | |
| results.masks = cls_scores.new_zeros(0, *ori_shape) | |
| results.labels = cls_scores.new_ones(0) | |
| results.bboxes = cls_scores.new_zeros(0, 4) | |
| return results | |
| cfg = self.test_cfg if cfg is None else cfg | |
| assert len(cls_scores) == len(mask_preds) | |
| featmap_size = mask_preds.size()[-2:] | |
| h, w = img_meta['img_shape'][:2] | |
| upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4) | |
| score_mask = (cls_scores > cfg.score_thr) | |
| cls_scores = cls_scores[score_mask] | |
| if len(cls_scores) == 0: | |
| return empty_results(cls_scores, img_meta['ori_shape'][:2]) | |
| inds = score_mask.nonzero() | |
| cls_labels = inds[:, 1] | |
| # Filter the mask mask with an area is smaller than | |
| # stride of corresponding feature level | |
| lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0) | |
| strides = cls_scores.new_ones(lvl_interval[-1]) | |
| strides[:lvl_interval[0]] *= self.strides[0] | |
| for lvl in range(1, self.num_levels): | |
| strides[lvl_interval[lvl - | |
| 1]:lvl_interval[lvl]] *= self.strides[lvl] | |
| strides = strides[inds[:, 0]] | |
| mask_preds = mask_preds[inds[:, 0]] | |
| masks = mask_preds > cfg.mask_thr | |
| sum_masks = masks.sum((1, 2)).float() | |
| keep = sum_masks > strides | |
| if keep.sum() == 0: | |
| return empty_results(cls_scores, img_meta['ori_shape'][:2]) | |
| masks = masks[keep] | |
| mask_preds = mask_preds[keep] | |
| sum_masks = sum_masks[keep] | |
| cls_scores = cls_scores[keep] | |
| cls_labels = cls_labels[keep] | |
| # maskness. | |
| mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks | |
| cls_scores *= mask_scores | |
| scores, labels, _, keep_inds = mask_matrix_nms( | |
| masks, | |
| cls_labels, | |
| cls_scores, | |
| mask_area=sum_masks, | |
| nms_pre=cfg.nms_pre, | |
| max_num=cfg.max_per_img, | |
| kernel=cfg.kernel, | |
| sigma=cfg.sigma, | |
| filter_thr=cfg.filter_thr) | |
| # mask_matrix_nms may return an empty Tensor | |
| if len(keep_inds) == 0: | |
| return empty_results(cls_scores, img_meta['ori_shape'][:2]) | |
| mask_preds = mask_preds[keep_inds] | |
| mask_preds = F.interpolate( | |
| mask_preds.unsqueeze(0), size=upsampled_size, | |
| mode='bilinear')[:, :, :h, :w] | |
| mask_preds = F.interpolate( | |
| mask_preds, size=img_meta['ori_shape'][:2], | |
| mode='bilinear').squeeze(0) | |
| masks = mask_preds > cfg.mask_thr | |
| results = InstanceData() | |
| results.masks = masks | |
| results.labels = labels | |
| results.scores = scores | |
| # create an empty bbox in InstanceData to avoid bugs when | |
| # calculating metrics. | |
| bboxes = mask2bbox(masks) | |
| # results.bboxes = results.scores.new_zeros(len(scores), 4) | |
| results.bboxes = bboxes | |
| return results | |


@MODELS.register_module()
class DecoupledSOLOHead(SOLOHead):
    """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations
    <https://arxiv.org/abs/1912.04488>`_ .

    Args:
        init_cfg (dict or list[dict], optional): Initialization config dict.
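
    The following is a minimal shape-check sketch (illustrative sizes; it
    assumes the default losses resolve through :obj:`MODELS`). Note that the
    decoupled x/y branches each predict ``num_grid`` channels instead of
    ``num_grid**2``.

    Example:
        >>> import torch
        >>> self = DecoupledSOLOHead(num_classes=80, in_channels=256)
        >>> feats = tuple(
        ...     torch.rand(1, 256, 64 // 2**i, 64 // 2**i) for i in range(5))
        >>> preds_x, preds_y, cls_preds = self.forward(feats)
        >>> tuple(preds_x[0].shape), tuple(preds_y[0].shape)
        ((1, 40, 64, 64), (1, 40, 64, 64))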
| """ | |

    def __init__(self,
                 *args,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_x')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_y')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs) -> None:
        super().__init__(*args, init_cfg=init_cfg, **kwargs)

    def _init_layers(self) -> None:
        self.mask_convs_x = nn.ModuleList()
        self.mask_convs_y = nn.ModuleList()
        self.cls_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            # each decoupled mask branch takes one extra coordinate channel
            chn = self.in_channels + 1 if i == 0 else self.feat_channels
            self.mask_convs_x.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
            self.mask_convs_y.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))

        self.conv_mask_list_x = nn.ModuleList()
        self.conv_mask_list_y = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list_x.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
            self.conv_mask_list_y.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask prediction.

                - mlvl_mask_preds_x (list[Tensor]): Multi-level mask
                  prediction from x branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_mask_preds_y (list[Tensor]): Multi-level mask
                  prediction from y branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mask_preds_x = []
        mask_preds_y = []
        cls_preds = []
        for i in range(self.num_levels):
            x = feats[i]
            mask_feat = x
            cls_feat = x
            # generate and concat the coordinate: the x branch only needs the
            # x coordinate channel and the y branch only the y channel
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1)
            mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1)

            for mask_layer_x, mask_layer_y in \
                    zip(self.mask_convs_x, self.mask_convs_y):
                mask_feat_x = mask_layer_x(mask_feat_x)
                mask_feat_y = mask_layer_y(mask_feat_y)

            mask_feat_x = F.interpolate(
                mask_feat_x, scale_factor=2, mode='bilinear')
            mask_feat_y = F.interpolate(
                mask_feat_y, scale_factor=2, mode='bilinear')

            mask_pred_x = self.conv_mask_list_x[i](mask_feat_x)
            mask_pred_y = self.conv_mask_list_y[i](mask_feat_y)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_pred_x = F.interpolate(
                    mask_pred_x.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                mask_pred_y = F.interpolate(
                    mask_pred_y.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')

                cls_pred = cls_pred.sigmoid()
                # get local maximum: keep a grid cell's score only if it is
                # the maximum in its 2x2 neighborhood (a cheap point NMS)
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mask_preds_x.append(mask_pred_x)
            mask_preds_y.append(mask_pred_y)
            cls_preds.append(cls_pred)
        return mask_preds_x, mask_preds_y, cls_preds

    def loss_by_feat(self, mlvl_mask_preds_x: List[Tensor],
                     mlvl_mask_preds_y: List[Tensor],
                     mlvl_cls_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask head.

        Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
                from x branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
                from y branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_levels = self.num_levels
        num_imgs = len(batch_img_metas)

        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]

        pos_mask_targets, labels, xy_pos_indexes = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            featmap_sizes=featmap_sizes)

        # change the outer list from per-image to per-level
        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)]
        mlvl_labels = [[] for _ in range(num_levels)]
        for img_id in range(num_imgs):
            for lvl in range(num_levels):
                mlvl_pos_mask_targets[lvl].append(
                    pos_mask_targets[img_id][lvl])
                mlvl_pos_mask_preds_x[lvl].append(
                    mlvl_mask_preds_x[lvl][img_id,
                                           xy_pos_indexes[img_id][lvl][:, 1]])
                mlvl_pos_mask_preds_y[lvl].append(
                    mlvl_mask_preds_y[lvl][img_id,
                                           xy_pos_indexes[img_id][lvl][:, 0]])
                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())

        # concatenate multiple images per level
        temp_mlvl_cls_preds = []
        for lvl in range(num_levels):
            mlvl_pos_mask_targets[lvl] = torch.cat(
                mlvl_pos_mask_targets[lvl], dim=0)
            mlvl_pos_mask_preds_x[lvl] = torch.cat(
                mlvl_pos_mask_preds_x[lvl], dim=0)
            mlvl_pos_mask_preds_y[lvl] = torch.cat(
                mlvl_pos_mask_preds_y[lvl], dim=0)
            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels))

        num_pos = 0.
        # dice loss
        loss_mask = []
        for pred_x, pred_y, target in \
                zip(mlvl_pos_mask_preds_x,
                    mlvl_pos_mask_preds_y, mlvl_pos_mask_targets):
            num_masks = pred_x.size(0)
            if num_masks == 0:
                # keep a zero-valued term so gradients still flow
                loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0))
                continue
            num_pos += num_masks
            # the final soft mask is the product of the x and y branches
            pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
            loss_mask.append(
                self.loss_mask(pred_mask, target, reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = torch.cat(loss_mask).mean()

        # classification branch
        flatten_labels = torch.cat(mlvl_labels)
        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)

        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)

    def _get_targets_single(self,
                            gt_instances: InstanceData,
                            featmap_sizes: Optional[list] = None) -> tuple:
        """Compute targets for predictions of a single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            featmap_sizes (list[:obj:`torch.size`]): Size of each
                feature map from feature pyramid, each element
                means (feat_h, feat_w). Defaults to None.

        Returns:
            Tuple: Usually returns a tuple containing targets for predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element
                  represents the binary mask targets for positive points in
                  this level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_xy_pos_indexes (list[Tensor]): Each element
                  in the list contains the index of positive samples in
                  corresponding level, has shape (num_pos, 2); the last
                  dimension holds (index_y, index_x), matching the order
                  returned by ``nonzero()`` on the label grid.
| """ | |
| mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks = \ | |
| super()._get_targets_single(gt_instances, | |
| featmap_sizes=featmap_sizes) | |
| mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero() | |
| for item in mlvl_labels] | |
| return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes | |

    def predict_by_feat(self, mlvl_mask_preds_x: List[Tensor],
                        mlvl_mask_preds_y: List[Tensor],
                        mlvl_cls_scores: List[Tensor],
                        batch_img_metas: List[dict], **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
                from x branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
                from y branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            batch_img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        mlvl_cls_scores = [
            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
        ]
        assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores)
        num_levels = len(mlvl_cls_scores)

        results_list = []
        for img_id in range(len(batch_img_metas)):
            cls_pred_list = [
                mlvl_cls_scores[i][img_id].view(
                    -1, self.cls_out_channels).detach()
                for i in range(num_levels)
            ]
            mask_pred_list_x = [
                mlvl_mask_preds_x[i][img_id] for i in range(num_levels)
            ]
            mask_pred_list_y = [
                mlvl_mask_preds_y[i][img_id] for i in range(num_levels)
            ]

            cls_pred_list = torch.cat(cls_pred_list, dim=0)
            mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0)
            mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0)
            img_meta = batch_img_metas[img_id]

            results = self._predict_by_feat_single(
                cls_pred_list,
                mask_pred_list_x,
                mask_pred_list_y,
                img_meta=img_meta)
            results_list.append(results)
        return results_list

    def _predict_by_feat_single(self,
                                cls_scores: Tensor,
                                mask_preds_x: Tensor,
                                mask_preds_y: Tensor,
                                img_meta: dict,
                                cfg: OptConfigType = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            cls_scores (Tensor): Classification score of all points
                in single image, has shape (num_points, num_classes).
            mask_preds_x (Tensor): Mask prediction of x branch of
                all points in single image, has shape
                (sum_num_grids, feat_h, feat_w).
            mask_preds_y (Tensor): Mask prediction of y branch of
                all points in single image, has shape
                (sum_num_grids, feat_h, feat_w).
            img_meta (dict): Meta information of corresponding image.
            cfg (dict, optional): Config used in test phase.
                Defaults to None.

        Returns:
            :obj:`InstanceData`: Processed results of single image.
            It usually contains the following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
| """ | |
| def empty_results(cls_scores, ori_shape): | |
| """Generate a empty results.""" | |
| results = InstanceData() | |
| results.scores = cls_scores.new_ones(0) | |
| results.masks = cls_scores.new_zeros(0, *ori_shape) | |
| results.labels = cls_scores.new_ones(0) | |
| results.bboxes = cls_scores.new_zeros(0, 4) | |
| return results | |
| cfg = self.test_cfg if cfg is None else cfg | |
| featmap_size = mask_preds_x.size()[-2:] | |
| h, w = img_meta['img_shape'][:2] | |
| upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4) | |
| score_mask = (cls_scores > cfg.score_thr) | |
| cls_scores = cls_scores[score_mask] | |
| inds = score_mask.nonzero() | |
| lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0) | |
| num_all_points = lvl_interval[-1] | |
| lvl_start_index = inds.new_ones(num_all_points) | |
| num_grids = inds.new_ones(num_all_points) | |
| seg_size = inds.new_tensor(self.num_grids).cumsum(0) | |
| mask_lvl_start_index = inds.new_ones(num_all_points) | |
| strides = inds.new_ones(num_all_points) | |
| lvl_start_index[:lvl_interval[0]] *= 0 | |
| mask_lvl_start_index[:lvl_interval[0]] *= 0 | |
| num_grids[:lvl_interval[0]] *= self.num_grids[0] | |
| strides[:lvl_interval[0]] *= self.strides[0] | |
| for lvl in range(1, self.num_levels): | |
| lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ | |
| lvl_interval[lvl - 1] | |
| mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ | |
| seg_size[lvl - 1] | |
| num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ | |
| self.num_grids[lvl] | |
| strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ | |
| self.strides[lvl] | |
| lvl_start_index = lvl_start_index[inds[:, 0]] | |
| mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]] | |
| num_grids = num_grids[inds[:, 0]] | |
| strides = strides[inds[:, 0]] | |
| y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids | |
| x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids | |
| y_inds = mask_lvl_start_index + y_lvl_offset | |
| x_inds = mask_lvl_start_index + x_lvl_offset | |
| cls_labels = inds[:, 1] | |
| mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...] | |
| masks = mask_preds > cfg.mask_thr | |
| sum_masks = masks.sum((1, 2)).float() | |
| keep = sum_masks > strides | |
| if keep.sum() == 0: | |
| return empty_results(cls_scores, img_meta['ori_shape'][:2]) | |
| masks = masks[keep] | |
| mask_preds = mask_preds[keep] | |
| sum_masks = sum_masks[keep] | |
| cls_scores = cls_scores[keep] | |
| cls_labels = cls_labels[keep] | |
| # maskness. | |
| mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks | |
| cls_scores *= mask_scores | |
| scores, labels, _, keep_inds = mask_matrix_nms( | |
| masks, | |
| cls_labels, | |
| cls_scores, | |
| mask_area=sum_masks, | |
| nms_pre=cfg.nms_pre, | |
| max_num=cfg.max_per_img, | |
| kernel=cfg.kernel, | |
| sigma=cfg.sigma, | |
| filter_thr=cfg.filter_thr) | |
| # mask_matrix_nms may return an empty Tensor | |
| if len(keep_inds) == 0: | |
| return empty_results(cls_scores, img_meta['ori_shape'][:2]) | |
| mask_preds = mask_preds[keep_inds] | |
| mask_preds = F.interpolate( | |
| mask_preds.unsqueeze(0), size=upsampled_size, | |
| mode='bilinear')[:, :, :h, :w] | |
| mask_preds = F.interpolate( | |
| mask_preds, size=img_meta['ori_shape'][:2], | |
| mode='bilinear').squeeze(0) | |
| masks = mask_preds > cfg.mask_thr | |
| results = InstanceData() | |
| results.masks = masks | |
| results.labels = labels | |
| results.scores = scores | |
| # create an empty bbox in InstanceData to avoid bugs when | |
| # calculating metrics. | |
| bboxes = mask2bbox(masks) | |
| # results.bboxes = results.scores.new_zeros(len(scores), 4) | |
| results.bboxes = bboxes | |
| return results | |


@MODELS.register_module()
class DecoupledSOLOLightHead(DecoupledSOLOHead):
    """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by
    Locations <https://arxiv.org/abs/1912.04488>`_ .

    Args:
        dcn_cfg (dict, optional): Config of DCN applied to the last of the
            stacked mask_convs and cls_convs. Defaults to None, i.e. plain
            convolutions.
        init_cfg (dict or list[dict], optional): Initialization config dict.
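
    The following is a minimal construction sketch (illustrative values; it
    assumes MMCV's deformable conv op is available when ``dcn_cfg`` is set):

    Example:
        >>> self = DecoupledSOLOLightHead(
        ...     num_classes=80,
        ...     in_channels=256,
        ...     dcn_cfg=dict(type='DCNv2'))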
| """ | |

    def __init__(self,
                 *args,
                 dcn_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_x')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_y')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs) -> None:
        assert dcn_cfg is None or isinstance(dcn_cfg, dict)
        self.dcn_cfg = dcn_cfg
        super().__init__(*args, init_cfg=init_cfg, **kwargs)

    def _init_layers(self) -> None:
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            # optionally use DCN for the last stacked conv only
            if self.dcn_cfg is not None \
                    and i == self.stacked_convs - 1:
                conv_cfg = self.dcn_cfg
            else:
                conv_cfg = None

            # the light head shares one mask branch, so it consumes both
            # coordinate channels like the vanilla SOLO head
            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

        self.conv_mask_list_x = nn.ModuleList()
        self.conv_mask_list_y = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list_x.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
            self.conv_mask_list_y.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask prediction.

                - mlvl_mask_preds_x (list[Tensor]): Multi-level mask
                  prediction from x branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_mask_preds_y (list[Tensor]): Multi-level mask
                  prediction from y branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mask_preds_x = []
        mask_preds_y = []
        cls_preds = []
        for i in range(self.num_levels):
            x = feats[i]
            mask_feat = x
            cls_feat = x
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)

            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)

            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')

            mask_pred_x = self.conv_mask_list_x[i](mask_feat)
            mask_pred_y = self.conv_mask_list_y[i](mask_feat)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)

            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_pred_x = F.interpolate(
                    mask_pred_x.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                mask_pred_y = F.interpolate(
                    mask_pred_y.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')

                cls_pred = cls_pred.sigmoid()
                # get local maximum: keep a grid cell's score only if it is
                # the maximum in its 2x2 neighborhood (a cheap point NMS)
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mask_preds_x.append(mask_pred_x)
            mask_preds_y.append(mask_pred_y)
            cls_preds.append(cls_pred)
        return mask_preds_x, mask_preds_y, cls_preds
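

if __name__ == '__main__':
    # A quick smoke-test sketch, not part of the original module: it checks
    # that all three heads run a forward pass on random features. Sizes are
    # illustrative, and it assumes this module is executed in a context where
    # the relative imports resolve and the head classes are not already
    # registered (otherwise the registry decorators above will complain).
    feats = tuple(
        torch.rand(2, 256, 96 // 2**i, 96 // 2**i) for i in range(5))
    for head_cls in (SOLOHead, DecoupledSOLOHead, DecoupledSOLOLightHead):
        head = head_cls(num_classes=80, in_channels=256)
        outs = head(feats)
        # print the level-0 shape of each output branch
        print(head_cls.__name__, [tuple(o[0].shape) for o in outs])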