# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d
from mmcv.ops import point_sample
from mmengine.model import ModuleList, caffe2_xavier_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig, reduce_mean
from ..layers import Mask2FormerTransformerDecoder, SinePositionalEncoding
from ..utils import get_uncertain_point_coords_with_randomness
from .anchor_free_head import AnchorFreeHead
from .maskformer_head import MaskFormerHead

@MODELS.register_module()
class Mask2FormerHead(MaskFormerHead):
    """Implements the Mask2Former head.

    See `Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature map.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things classes. Defaults to 80.
        num_stuff_classes (int): Number of stuff classes. Defaults to 53.
        num_queries (int): Number of queries in the Transformer decoder.
            Defaults to 100.
        num_transformer_feat_level (int): Number of multi-scale feature
            levels fed to the Transformer decoder. Defaults to 3.
        pixel_decoder (:obj:`ConfigDict` or dict): Config for the pixel
            decoder.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of the transformer encoder in
            the pixel decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`ConfigDict` or dict): Config for the
            transformer decoder.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer decoder positional encoding. Defaults to
            dict(num_feats=128, normalize=True).
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss.
        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            the Mask2Former head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            the Mask2Former head.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
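
    Example:
        A typical configuration, heavily abbreviated for illustration (the
        ``pixel_decoder`` and ``transformer_decoder`` dicts below are
        sketches; complete settings live under ``configs/mask2former`` in
        the repository)::

            head_cfg = dict(
                type='Mask2FormerHead',
                in_channels=[256, 512, 1024, 2048],  # e.g. ResNet-50 stages
                feat_channels=256,
                out_channels=256,
                num_things_classes=80,
                num_stuff_classes=53,
                num_queries=100,
                num_transformer_feat_level=3,
                pixel_decoder=dict(type='MSDeformAttnPixelDecoder'),
                transformer_decoder=dict(num_layers=9))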
| """ | |
| def __init__(self, | |
| in_channels: List[int], | |
| feat_channels: int, | |
| out_channels: int, | |
| num_things_classes: int = 80, | |
| num_stuff_classes: int = 53, | |
| num_queries: int = 100, | |
| num_transformer_feat_level: int = 3, | |
| pixel_decoder: ConfigType = ..., | |
| enforce_decoder_input_project: bool = False, | |
| transformer_decoder: ConfigType = ..., | |
| positional_encoding: ConfigType = dict( | |
| num_feats=128, normalize=True), | |
| loss_cls: ConfigType = dict( | |
| type='CrossEntropyLoss', | |
| use_sigmoid=False, | |
| loss_weight=2.0, | |
| reduction='mean', | |
| class_weight=[1.0] * 133 + [0.1]), | |
| loss_mask: ConfigType = dict( | |
| type='CrossEntropyLoss', | |
| use_sigmoid=True, | |
| reduction='mean', | |
| loss_weight=5.0), | |
| loss_dice: ConfigType = dict( | |
| type='DiceLoss', | |
| use_sigmoid=True, | |
| activate=True, | |
| reduction='mean', | |
| naive_dice=True, | |
| eps=1.0, | |
| loss_weight=5.0), | |
| train_cfg: OptConfigType = None, | |
| test_cfg: OptConfigType = None, | |
| init_cfg: OptMultiConfig = None, | |
| **kwargs) -> None: | |
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.layer_cfg. \
            self_attn_cfg.num_levels == num_transformer_feat_level
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = MODELS.build(pixel_decoder_)
        self.transformer_decoder = Mask2FormerTransformerDecoder(
            **transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if (self.decoder_embed_dims != feat_channels
                    or enforce_decoder_input_project):
                self.decoder_input_projs.append(
                    Conv2d(
                        feat_channels, self.decoder_embed_dims, kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = SinePositionalEncoding(
            **positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
                                        feat_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)
    def init_weights(self) -> None:
        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)
    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Classification logits from a single decoder
                layer for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits from a single decoder layer for
                one image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each \
                    image.
                - neg_inds (Tensor): Sampled negative indices for each \
                    image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]
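        # Following Mask2Former's point-based training, the matching cost is
        # computed on `num_points` randomly sampled locations rather than on
        # the full-resolution masks, which keeps assignment cheap in memory.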
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign and sample
        assign_result = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances,
            img_meta=img_meta)
        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((self.num_queries, ))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)
    def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
                             batch_gt_instances: List[InstanceData],
                             batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Classification score logits from a single
                decoder layer for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits from a single decoder layer for
                all images. Shape (batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components for outputs from a single \
                decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         avg_factor) = self.get_targets(cls_scores_list, mask_preds_list,
                                        batch_gt_instances, batch_img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())
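        # average the number of matched masks across GPUs so that every rank
        # normalizes its mask losses by the same factor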
        num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]

        if mask_targets.shape[0] == 0:
            # zero match
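            # an empty `.sum()` gives a zero-valued loss that stays attached
            # to the graph, so gradients (and DDP reductions) remain defined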
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        with torch.no_grad():
            points_coords = get_uncertain_point_coords_with_randomness(
                mask_preds.unsqueeze(1), None, self.num_points,
                self.oversample_ratio, self.importance_sample_ratio)
            # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
            mask_point_targets = point_sample(
                mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
        # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
        mask_point_preds = point_sample(
            mask_preds.unsqueeze(1), points_coords).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

        # mask loss
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_preds = mask_point_preds.reshape(-1)
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_targets = mask_point_targets.reshape(-1)
        loss_mask = self.loss_mask(
            mask_point_preds,
            mask_point_targets,
            avg_factor=num_total_masks * self.num_points)

        return loss_cls, loss_mask, loss_dice
    def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor,
                      attn_mask_target_size: Tuple[int, int]) -> Tuple[Tensor]:
        """Forward for head part which is called after every decoder layer.

        Args:
            decoder_out (Tensor): in shape (batch_size, num_queries, c).
            mask_feature (Tensor): in shape (batch_size, c, h, w).
            attn_mask_target_size (tuple[int, int]): target attention
                mask size.

        Returns:
            tuple: A tuple contain three elements.

                - cls_pred (Tensor): Classification scores in shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred (Tensor): Mask scores in shape \
                    (batch_size, num_queries, h, w).
                - attn_mask (Tensor): Attention mask in shape \
                    (batch_size * num_heads, num_queries, h*w).
        """
        decoder_out = self.transformer_decoder.post_norm(decoder_out)
        # shape (batch_size, num_queries, cls_out_channels)
        cls_pred = self.cls_embed(decoder_out)
        # shape (batch_size, num_queries, c)
        mask_embed = self.mask_embed(decoder_out)
        # shape (batch_size, num_queries, h, w)
        mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
        attn_mask = F.interpolate(
            mask_pred,
            attn_mask_target_size,
            mode='bilinear',
            align_corners=False)
        # shape (batch_size, num_queries, h, w) ->
        #   (batch_size * num_heads, num_queries, h*w)
        attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
            (1, self.num_heads, 1, 1)).flatten(0, 1)
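        # a True entry blocks attention, so each query may only attend to
        # locations the previous prediction marked as foreground
        # (sigmoid >= 0.5): the "masked attention" of Mask2Former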
        attn_mask = attn_mask.sigmoid() < 0.5
        attn_mask = attn_mask.detach()

        return cls_pred, mask_pred, attn_mask
    def forward(self, x: List[Tensor],
                batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
        """Forward function.

        Args:
            x (list[Tensor]): Multi-scale features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instances`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            tuple[list[Tensor]]: A tuple contains two elements.

                - cls_pred_list (list[Tensor]): Classification logits \
                    for each decoder layer. Each is a 3D-tensor with shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred_list (list[Tensor]): Mask logits for each \
                    decoder layer. Each with shape (batch_size, num_queries, \
                    h, w).
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        batch_size = len(batch_img_metas)
        mask_features, multi_scale_memorys = self.pixel_decoder(x)
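        # mask_features: high-resolution per-pixel embeddings used for the
        # dot-product mask prediction in `_forward_head`;
        # multi_scale_memorys: the lower-resolution levels the transformer
        # decoder cross-attends to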
        # multi_scale_memorys (from low resolution to high resolution)
        decoder_inputs = []
        decoder_positional_encodings = []
        for i in range(self.num_transformer_feat_level):
            decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
            # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
            decoder_input = decoder_input.flatten(2).permute(0, 2, 1)
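            # a learned per-level embedding lets the decoder tell the
            # feature scales apart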
            level_embed = self.level_embed.weight[i].view(1, 1, -1)
            decoder_input = decoder_input + level_embed
            # an all-zero (i.e. no padding) mask of shape (batch_size, h, w)
            mask = decoder_input.new_zeros(
                (batch_size, ) + multi_scale_memorys[i].shape[-2:],
                dtype=torch.bool)
            decoder_positional_encoding = self.decoder_positional_encoding(
                mask)
            decoder_positional_encoding = decoder_positional_encoding.flatten(
                2).permute(0, 2, 1)
            decoder_inputs.append(decoder_input)
            decoder_positional_encodings.append(decoder_positional_encoding)
        # shape (num_queries, c) -> (batch_size, num_queries, c)
        query_feat = self.query_feat.weight.unsqueeze(0).repeat(
            (batch_size, 1, 1))
        query_embed = self.query_embed.weight.unsqueeze(0).repeat(
            (batch_size, 1, 1))

        cls_pred_list = []
        mask_pred_list = []
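        # the learned queries are decoded once before any transformer layer
        # runs; these "layer 0" predictions are supervised like the rest and
        # provide the first attention mask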
        cls_pred, mask_pred, attn_mask = self._forward_head(
            query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

        for i in range(self.num_transformer_decoder_layers):
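            # feature levels are visited round-robin, one scale per
            # decoder layer (low resolution to high resolution)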
            level_idx = i % self.num_transformer_feat_level
            # if a mask is all True (all background), then set it all False
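            # (an attention row with every key masked would make the softmax
            # ill-defined, so such queries fall back to full attention)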
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                cross_attn_mask=attn_mask,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None)
            cls_pred, mask_pred, attn_mask = self._forward_head(
                query_feat, mask_features, multi_scale_memorys[
                    (i + 1) % self.num_transformer_feat_level].shape[-2:])

            cls_pred_list.append(cls_pred)
            mask_pred_list.append(mask_pred)

        return cls_pred_list, mask_pred_list