Spaces:
Runtime error
Runtime error
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements a resized feature fuser for stacked decoders in MaX-DeepLab.

Reference:
  MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
    CVPR 2021. https://arxiv.org/abs/2012.00759
      Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
"""
import tensorflow as tf

from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import convolutions
class ResizedFuse(tf.keras.layers.Layer):
  """Fuses a list of features via resizing and 1x1 convolutions.

  Every input feature is brought to a common (height, width, num_channels)
  shape and the results are summed. A feature whose channel count already
  matches `num_channels` is only bilinearly resized (no projection), which
  saves parameters in stacked decoders where features are reused many times.
  Otherwise the feature is projected with a 1x1 convolution; when the target
  spatial shape is exactly half the input shape (rounded up), the projection
  and the downsampling are fused into a single stride-2 convolution, which is
  faster than projecting and then bilinearly resizing.

  The ordering here is relu-conv-bn, so that the final summation happens
  before relu and after bn, following the ResNet bottleneck design.

  Reference:
    MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
      CVPR 2021. https://arxiv.org/abs/2012.00759
        Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
  """

  def __init__(self,
               name,
               height,
               width,
               num_channels,
               activation='relu',
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes a ResizedFuse layer.

    Args:
      name: A string, the name of this layer.
      height: An integer, the desired height of the output.
      width: An integer, the desired width of the output.
      num_channels: An integer, the num of output channels.
      activation: A string, type of activation function to apply.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
    """
    super(ResizedFuse, self).__init__(name=name)
    self._height = height
    self._width = width
    self._num_channels = num_channels
    self._activation_fn = activations.get_activation(activation)
    self._bn_layer = bn_layer
    self._conv_kernel_weight_decay = conv_kernel_weight_decay

  def _needs_stride_two(self, feature_height, feature_width):
    """Returns True if a stride-2 conv yields the target spatial shape."""
    return ((feature_height + 1) // 2 == self._height and
            (feature_width + 1) // 2 == self._width)

  def build(self, input_shapes):
    """Creates the per-input 1x1 projection layers where needed."""
    for idx, shape in enumerate(input_shapes):
      _, in_height, in_width, in_channels = shape
      if in_channels == self._num_channels:
        # No projection needed; the feature is only resized in call().
        continue
      # Fuse projection and downsampling into one stride-2 convolution when
      # it produces exactly the desired spatial shape; otherwise project with
      # stride 1 and bilinearly resize later.
      if self._needs_stride_two(in_height, in_width):
        attr_name = '_strided_conv_bn{}'.format(idx + 1)
        strides = 2
      else:
        attr_name = '_resized_conv_bn{}'.format(idx + 1)
        strides = 1
      utils.safe_setattr(
          self, attr_name, convolutions.Conv2DSame(
              self._num_channels, 1, attr_name[1:],
              strides=strides,
              use_bias=False,
              use_bn=True,
              bn_layer=self._bn_layer,
              activation='none',
              conv_kernel_weight_decay=self._conv_kernel_weight_decay))

  def call(self, inputs, training=False):
    """Performs a forward pass.

    Args:
      inputs: A list of input [batch, input_height, input_width,
        input_channels] tensors to fuse, where each input tensor may have
        different spatial resolutions and number of channels.
      training: A boolean, whether the model is in training mode.

    Returns:
      output: A fused feature [batch, height, width, num_channels] tensor.
    """
    target_size = [self._height, self._width]
    fused = []
    for idx, feature in enumerate(inputs):
      _, in_height, in_width, in_channels = feature.get_shape().as_list()
      if in_channels == self._num_channels:
        # Channels already match: bilinear resize only. Skipping the 1x1
        # projection here saves parameters, since the neighboring ops are
        # usually 1x1 convolutions themselves and features are reused many
        # times in a stacked decoder.
        resized = utils.resize_bilinear(
            feature, target_size, align_corners=True)
      elif self._needs_stride_two(in_height, in_width):
        # relu -> stride-2 conv -> bn, already at the target resolution.
        conv_bn = getattr(self, '_strided_conv_bn{}'.format(idx + 1))
        resized = conv_bn(self._activation_fn(feature), training=training)
      else:
        # relu -> 1x1 conv -> bn, then bilinear resize to the target shape.
        conv_bn = getattr(self, '_resized_conv_bn{}'.format(idx + 1))
        projected = conv_bn(self._activation_fn(feature), training=training)
        resized = utils.resize_bilinear(
            projected, target_size, align_corners=True)
      # Pin the static spatial shape of each output feature when possible.
      resized.set_shape([None, self._height, self._width, self._num_channels])
      fused.append(resized)
    return self._activation_fn(tf.add_n(fused))