# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """This file contains functions to post-process MaX-DeepLab results.""" | |
| import functools | |
| from typing import List, Tuple, Dict, Text | |
| import tensorflow as tf | |
| from deeplab2 import common | |
| from deeplab2 import config_pb2 | |
| from deeplab2.data import dataset | |
| from deeplab2.model import utils | |
def _get_transformer_class_prediction(
    transformer_class_probs: tf.Tensor,
    transformer_class_confidence_threshold: float
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
  """Selects the mask slots whose class confidence passes the threshold.

  For every mask slot, the predicted class is the argmax over the class axis
  and the confidence is the corresponding maximum score. Only slots whose
  confidence is greater than or equal to the threshold are kept.

  Args:
    transformer_class_probs: A tf.Tensor of shape [num_mask_slots,
      num_classes] with per-slot class scores. The caller in this file passes
      softmax probabilities with the trailing `void` channel already removed.
    transformer_class_confidence_threshold: A float threshold on the per-slot
      class confidence. Slots with confidence below the threshold are dropped
      and are not counted in num_detections.

  Returns:
    A tuple of:
    - the class index of each detected mask slot, as a float32 tf.Tensor of
      shape [num_detections].
    - the indices of the detected mask slots, as a tf.Tensor of shape
      [num_detections].
    - the number of detections, as a scalar tf.Tensor.
  """
  class_pred = tf.cast(
      tf.argmax(transformer_class_probs, axis=-1), tf.float32)
  class_confidence = tf.reduce_max(
      transformer_class_probs, axis=-1, keepdims=False)
  # Keep only the mask slots whose class confidence reaches the threshold.
  is_detected = tf.greater_equal(
      class_confidence, transformer_class_confidence_threshold)
  # tf.where yields one row per True entry; column 0 is the slot index.
  detected_mask_indices = tf.where(is_detected)[:, 0]
  detected_mask_class_pred = tf.gather(class_pred, detected_mask_indices)
  num_detections = tf.shape(detected_mask_indices)[0]
  return detected_mask_class_pred, detected_mask_indices, num_detections
def _get_mask_id_and_semantic_maps(
    thing_class_ids: List[int],
    stuff_class_ids: List[int],
    pixel_space_mask_logits: tf.Tensor,
    transformer_class_probs: tf.Tensor,
    image_shape: List[int],
    pixel_confidence_threshold=0.4,
    transformer_class_confidence_threshold=0.7,
    pieces=1) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
  """Computes the pixel-level mask ID map and semantic map per image.

  Args:
    thing_class_ids: A List of integers of shape [num_thing_classes] containing
      thing class indices.
    stuff_class_ids: A List of integers of shape [num_stuff_classes] containing
      stuff class indices.
    pixel_space_mask_logits: A tf.Tensor of shape [height, width,
      num_mask_slots] with pixel-level logit scores, where num_mask_slots is
      the number of mask slots (for both thing classes and stuff classes) in
      MaX-DeepLab. The logits are bilinearly resized to `image_shape`.
    transformer_class_probs: A tf.Tensor of shape [num_mask_slots,
      num_thing_stuff_classes + 1] with per-slot class probabilities. The last
      channel indicates a `void` class and is dropped before use.
    image_shape: A list of integers specifying the [height, width] of input
      image.
    pixel_confidence_threshold: A float indicating a threshold for the pixel
      level softmax probability confidence of transformer mask logits. Pixels
      whose confidence is not above the threshold are excluded from both the
      thing and the stuff region masks.
    transformer_class_confidence_threshold: A float for thresholding the
      confidence of the transformer_class_probs. The panoptic mask slots with
      class confidence less than the threshold are filtered and not used for
      panoptic prediction.
    pieces: An integer indicating the number of pieces in the piece-wise
      operation. When computing panoptic prediction and confident regions, the
      mask logits are divided width-wise into multiple pieces and processed
      piece-wise due to the GPU memory limit. Then, the piece-wise outputs are
      concatenated along the width into the original mask shape. Defaults to 1.

  Returns:
    A tuple of:
    - the mask ID prediction (shifted by one) as int32 tf.Tensor with shape
      [height, width]. It is all ones when nothing is detected.
    - the semantic prediction as int32 tf.Tensor with shape [height, width].
    - the thing region mask as float32 tf.Tensor with shape [height, width].
    - the stuff region mask as float32 tf.Tensor with shape [height, width].

  Raises:
    ValueError: When `pieces` is smaller than 1, or when input image's
      `width - 1` is not divisible by `pieces`.
  """
  output_height = image_shape[0]
  output_width = image_shape[1]
  # Validate up front rather than inside the tf.cond branch below: in eager
  # execution only the taken branch runs, so a check placed in the branch
  # would be skipped whenever there are no detections.
  if pieces < 1:
    raise ValueError('`pieces` must be a positive integer.')
  if (output_width - 1) % pieces > 0:
    raise ValueError('`output_width - 1` must be divisible by `pieces`.')

  # The last channel indicates `void` class and thus is not included.
  transformer_class_probs = transformer_class_probs[..., :-1]
  # Generate mapping from mask IDs to dataset's thing and stuff semantic IDs.
  thing_stuff_class_ids = thing_class_ids + stuff_class_ids
  detected_mask_class_pred, detected_mask_indices, num_detections = (
      _get_transformer_class_prediction(transformer_class_probs,
                                        transformer_class_confidence_threshold))

  # If num_detections = 0, return empty result maps.
  def _return_empty_mask_id_and_semantic_maps():
    map_shape = [output_height, output_width]
    return (
        tf.ones(map_shape, dtype=tf.int32),
        tf.zeros(map_shape, dtype=tf.int32),
        tf.zeros(map_shape, dtype=tf.float32),
        tf.zeros(map_shape, dtype=tf.float32))

  # If num_detections > 0:
  def _generate_mask_id_and_semantic_maps():
    output_mask_id_map = []
    output_confident_region = []
    logits_width = pixel_space_mask_logits.get_shape().as_list()[1]
    # Use of input shape of a multiple of the feature stride, plus one, so that
    # it preserves left- and right-alignment.
    piece_output_width = (output_width - 1) // pieces + 1
    piece_logits_stride = (logits_width - 1) // pieces

    for piece_id in range(pieces):
      piece_begin = piece_logits_stride * piece_id
      # Neighbouring pieces share one boundary column (hence the `+ 1`), so
      # that every piece keeps the "multiple of stride, plus one" width.
      piece_end = piece_logits_stride * (piece_id + 1) + 1
      piece_pixel_mask_logits = (
          pixel_space_mask_logits[:, piece_begin:piece_end, :])
      piece_pixel_mask_logits = tf.compat.v1.image.resize_bilinear(
          tf.expand_dims(piece_pixel_mask_logits, 0),
          (output_height, piece_output_width),
          align_corners=True)
      piece_pixel_mask_logits = tf.squeeze(piece_pixel_mask_logits, axis=0)
      piece_detected_pixel_mask_logits = tf.gather(
          piece_pixel_mask_logits, detected_mask_indices, axis=-1)

      # Filter the pixels which are assigned to a mask ID that does not
      # survive: the best logit over all slots must equal the best logit over
      # the detected slots.
      piece_max_logits = tf.reduce_max(piece_pixel_mask_logits, axis=-1)
      piece_detected_max_logits = tf.reduce_max(
          piece_detected_pixel_mask_logits, axis=-1)
      piece_detected_mask = tf.cast(tf.math.equal(
          piece_max_logits, piece_detected_max_logits), tf.float32)

      # Filter with pixel mask threshold.
      piece_pixel_confidence_map = tf.reduce_max(
          tf.nn.softmax(piece_detected_pixel_mask_logits, axis=-1), axis=-1)
      piece_confident_region = tf.cast(
          piece_pixel_confidence_map > pixel_confidence_threshold, tf.float32)
      piece_confident_region = piece_confident_region * piece_detected_mask
      piece_mask_id_map = tf.cast(
          tf.argmax(piece_detected_pixel_mask_logits, axis=-1), tf.int32)

      if piece_id == pieces - 1:
        output_mask_id_map.append(piece_mask_id_map)
        output_confident_region.append(piece_confident_region)
      else:
        # Drop the last column; it is repeated as the first column of the
        # next piece.
        output_mask_id_map.append(piece_mask_id_map[:, :-1])
        output_confident_region.append(piece_confident_region[:, :-1])

    mask_id_map = tf.concat(output_mask_id_map, axis=1)
    confident_region = tf.concat(output_confident_region, axis=1)
    mask_id_map_flat = tf.reshape(mask_id_map, [-1])
    mask_id_semantic_map_flat = tf.gather(
        detected_mask_class_pred, mask_id_map_flat)
    mask_id_semantic_map = tf.reshape(
        mask_id_semantic_map_flat, [output_height, output_width])

    # Generate thing and stuff masks (with value 1/0 indicates the
    # presence/absence). Predicted class indices below len(thing_class_ids)
    # are treated as things, matching the order of `thing_stuff_class_ids`.
    thing_mask = tf.cast(mask_id_semantic_map < len(thing_class_ids),
                         tf.float32) * confident_region
    stuff_mask = tf.cast(mask_id_semantic_map >= len(thing_class_ids),
                         tf.float32) * confident_region

    # Generate semantic_map.
    semantic_map = tf.gather(
        tf.convert_to_tensor(thing_stuff_class_ids),
        tf.cast(tf.round(mask_id_semantic_map_flat), tf.int32))
    semantic_map = tf.reshape(semantic_map, [output_height, output_width])
    # Add 1 because mask ID 0 is reserved for unconfident region.
    mask_id_map_plus_one = mask_id_map + 1
    semantic_map = tf.cast(tf.round(semantic_map), tf.int32)
    return (mask_id_map_plus_one, semantic_map, thing_mask, stuff_mask)

  mask_id_map_plus_one, semantic_map, thing_mask, stuff_mask = tf.cond(
      tf.equal(num_detections, 0),
      _return_empty_mask_id_and_semantic_maps,
      _generate_mask_id_and_semantic_maps)

  return (mask_id_map_plus_one, semantic_map, thing_mask, stuff_mask)
def _filter_by_count(input_index_map: tf.Tensor,
                     area_limit: int) -> Tuple[tf.Tensor, tf.Tensor]:
  """Zeroes out index regions that cover fewer than `area_limit` pixels.

  Pixel counts are gathered per index value and per batch element. A pixel is
  kept when the number of pixels sharing its (rounded) index within the same
  batch element is at least `area_limit`.

  Args:
    input_index_map: A float32 tf.Tensor of shape [batch, height, width].
    area_limit: An integer giving the minimum number of pixels an index region
      must have. Regions below the limit are masked (zeroed) out.

  Returns:
    masked input_index_map: A tf.Tensor with shape [batch, height, width],
      equal to `input_index_map` multiplied by `mask`.
    mask: A tf.Tensor with shape [batch, height, width] and the dtype of
      `input_index_map`, with 1 at pixels whose region reaches the area limit
      and 0 otherwise.
  """
  rounded_ids = tf.cast(tf.round(input_index_map), tf.int32)
  num_images = tf.shape(input_index_map)[0]
  flat_ids = tf.reshape(rounded_ids, [num_images, -1])

  # Histogram of index values per image, then looked up again per pixel.
  per_id_counts = tf.math.bincount(flat_ids, axis=-1)
  per_pixel_counts = tf.reshape(
      tf.gather(per_id_counts, flat_ids, batch_dims=1),
      tf.shape(rounded_ids))

  # Comparing against `area_limit - 0.5` in float keeps counts equal to the
  # limit.
  min_count = tf.cast(area_limit - 0.5, tf.float32)
  reaches_limit = tf.cast(per_pixel_counts, tf.float32) > min_count
  mask = tf.cast(reaches_limit, input_index_map.dtype)
  return input_index_map * mask, mask
def _merge_mask_id_and_semantic_maps(
    mask_id_maps_plus_one: tf.Tensor,
    semantic_maps: tf.Tensor,
    thing_masks: tf.Tensor,
    stuff_masks: tf.Tensor,
    void_label: int,
    label_divisor: int,
    thing_area_limit: int,
    stuff_area_limit: int,) -> tf.Tensor:
  """Combines mask ID maps and semantic maps into panoptic maps.

  Each output pixel is `semantic_id * label_divisor + instance_id`, where the
  instance ID is the (plus-one) mask ID inside surviving thing regions and 0
  elsewhere. Pixels covered by no surviving thing or stuff region get
  `void_label` as semantic ID.

  Args:
    mask_id_maps_plus_one: A tf.Tensor of shape [batch, height, width].
    semantic_maps: A tf.Tensor of shape [batch, height, width].
    thing_masks: A float32 tf.Tensor of shape [batch, height, width] containing
      masks with 1. at thing regions, 0. otherwise.
    stuff_masks: A float32 tf.Tensor of shape [batch, height, width] containing
      masks with 1. at stuff regions, 0. otherwise.
    void_label: An integer specifying the void label.
    label_divisor: An integer specifying the label divisor of the dataset.
    thing_area_limit: An integer giving the minimum number of pixels a thing
      region (one mask ID) needs; smaller regions are re-assigned to
      void_label.
    stuff_area_limit: An integer giving the minimum number of pixels a stuff
      region (one semantic ID) needs; smaller regions are re-assigned to
      void_label.

  Returns:
    panoptic_maps: An int32 tf.Tensor with shape [batch, height, width].
  """
  tf.debugging.assert_less(
      tf.reduce_sum(thing_masks * stuff_masks), 0.5,
      message='thing_masks and stuff_masks must be mutually exclusive.')

  # Shift semantic IDs by one before masking, so that a valid ID of 0 is not
  # confused with a masked-out (zeroed) pixel.
  shifted_semantic = tf.cast(semantic_maps + 1, tf.float32)
  thing_semantic = shifted_semantic * thing_masks
  stuff_semantic = shifted_semantic * stuff_masks
  thing_ids = tf.cast(mask_id_maps_plus_one, tf.float32) * thing_masks

  # Stuff regions are filtered per semantic ID by stuff_area_limit.
  stuff_semantic, _ = _filter_by_count(stuff_semantic, stuff_area_limit)
  # Thing regions are filtered per mask ID by thing_area_limit; the same
  # filter is then applied to the thing semantic IDs.
  thing_ids, thing_keep_mask = _filter_by_count(thing_ids, thing_area_limit)
  thing_semantic = thing_semantic * thing_keep_mask

  # Undo the shift. Pixels that were masked out everywhere are now -1; adding
  # (void_label + 1) there turns them into void_label.
  merged_semantic = thing_semantic + stuff_semantic - 1.0
  is_unassigned = tf.cast(merged_semantic < -0.5, tf.float32)
  merged_semantic = (is_unassigned * tf.cast(void_label + 1, tf.float32)
                     + merged_semantic)

  panoptic_float = merged_semantic * label_divisor + thing_ids
  return tf.cast(tf.round(panoptic_float), tf.int32)
def _get_panoptic_predictions(
    pixel_space_mask_logits: tf.Tensor,
    transformer_class_logits: tf.Tensor,
    thing_class_ids: List[int],
    void_label: int,
    label_divisor: int,
    thing_area_limit: int,
    stuff_area_limit: int,
    image_shape: List[int],
    pixel_confidence_threshold=0.4,
    transformer_class_confidence_threshold=0.7,
    pieces=1) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
  """Computes batched panoptic, mask ID, and semantic maps.

  Every image of the batch is processed separately by
  `_get_mask_id_and_semantic_maps`; the per-image results are stacked and
  merged into panoptic maps by `_merge_mask_id_and_semantic_maps`.

  Args:
    pixel_space_mask_logits: A tf.Tensor of shape [batch, strided_height,
      strided_width, num_mask_slots] with pixel-level logit scores, where
      num_mask_slots is the number of mask slots (for both thing classes and
      stuff classes) in MaX-DeepLab.
    transformer_class_logits: A tf.Tensor of shape [batch, num_mask_slots,
      num_thing_stuff_classes + 1] with per-slot class logits. The last
      channel indicates a `void` class. A softmax over the last axis is
      applied here.
    thing_class_ids: A List of integers of shape [num_thing_classes] containing
      thing class indices.
    void_label: An integer specifying the void label.
    label_divisor: An integer specifying the label divisor of the dataset.
    thing_area_limit: An integer giving the minimum number of pixels a thing
      region needs; smaller regions are re-assigned to void_label.
    stuff_area_limit: An integer giving the minimum number of pixels a stuff
      region needs; smaller regions are re-assigned to void_label.
    image_shape: A list of integers specifying the [height, width] of input
      image.
    pixel_confidence_threshold: A float threshold on the pixel-level softmax
      confidence of the mask logits, forwarded to
      `_get_mask_id_and_semantic_maps`.
    transformer_class_confidence_threshold: A float threshold on the per-slot
      class confidence, forwarded to `_get_mask_id_and_semantic_maps`.
    pieces: An integer number of width-wise pieces used by
      `_get_mask_id_and_semantic_maps` to limit memory usage. Defaults to 1.

  Returns:
    A tuple of:
    - the panoptic prediction as tf.Tensor with shape [batch, height, width].
    - the mask ID prediction (shifted by one) as tf.Tensor with shape
      [batch, height, width].
    - the semantic prediction as tf.Tensor with shape [batch, height, width].
  """
  class_probs = tf.nn.softmax(transformer_class_logits, axis=-1)
  num_images = tf.shape(transformer_class_logits)[0]

  # The static class dimension includes the `void` channel, hence the `- 1`.
  num_thing_stuff_classes = (
      transformer_class_logits.get_shape().as_list()[-1] - 1)
  # NOTE(review): presumably the class IDs that are neither thing classes nor
  # void; confirm against utils.get_stuff_class_ids.
  stuff_class_ids = utils.get_stuff_class_ids(
      num_thing_stuff_classes, thing_class_ids, void_label)

  # One fixed-size array per output, with one entry per image.
  mask_id_array = tf.TensorArray(
      tf.int32, size=num_images, dynamic_size=False)
  semantic_array = tf.TensorArray(
      tf.int32, size=num_images, dynamic_size=False)
  thing_mask_array = tf.TensorArray(
      tf.float32, size=num_images, dynamic_size=False)
  stuff_mask_array = tf.TensorArray(
      tf.float32, size=num_images, dynamic_size=False)

  for image_index in tf.range(num_images):
    per_image_maps = _get_mask_id_and_semantic_maps(
        thing_class_ids=thing_class_ids,
        stuff_class_ids=stuff_class_ids,
        pixel_space_mask_logits=pixel_space_mask_logits[image_index],
        transformer_class_probs=class_probs[image_index],
        image_shape=image_shape,
        pixel_confidence_threshold=pixel_confidence_threshold,
        transformer_class_confidence_threshold=(
            transformer_class_confidence_threshold),
        pieces=pieces)
    mask_id_array = mask_id_array.write(image_index, per_image_maps[0])
    semantic_array = semantic_array.write(image_index, per_image_maps[1])
    thing_mask_array = thing_mask_array.write(image_index, per_image_maps[2])
    stuff_mask_array = stuff_mask_array.write(image_index, per_image_maps[3])

  # This does not work with unknown shapes.
  mask_id_maps_plus_one = mask_id_array.stack()
  semantic_maps = semantic_array.stack()
  thing_masks = thing_mask_array.stack()
  stuff_masks = stuff_mask_array.stack()

  panoptic_maps = _merge_mask_id_and_semantic_maps(
      mask_id_maps_plus_one, semantic_maps, thing_masks, stuff_masks,
      void_label, label_divisor, thing_area_limit, stuff_area_limit)
  return panoptic_maps, mask_id_maps_plus_one, semantic_maps
class PostProcessor(tf.keras.layers.Layer):
  """Keras layer that turns MaX-DeepLab predictions into panoptic outputs."""

  def __init__(
      self,
      config: config_pb2.ExperimentOptions,
      dataset_descriptor: dataset.DatasetDescriptor):
    """Binds the post-processing options taken from config and dataset.

    Args:
      config: A config_pb2.ExperimentOptions configuration. Its
        `evaluator_options` supply the area limits and confidence thresholds,
        and `eval_dataset_options.crop_size` supplies the output image shape.
      dataset_descriptor: A dataset.DatasetDescriptor supplying the thing
        class list, the ignore label and the panoptic label divisor.
    """
    super().__init__(name='PostProcessor')
    evaluator_options = config.evaluator_options
    # All arguments except the two prediction tensors are fixed here; `call`
    # only has to pass the mask logits and the class logits.
    self._post_processor = functools.partial(
        _get_panoptic_predictions,
        thing_class_ids=list(dataset_descriptor.class_has_instances_list),
        void_label=dataset_descriptor.ignore_label,
        label_divisor=dataset_descriptor.panoptic_label_divisor,
        thing_area_limit=evaluator_options.thing_area_limit,
        stuff_area_limit=evaluator_options.stuff_area_limit,
        image_shape=list(config.eval_dataset_options.crop_size),
        transformer_class_confidence_threshold=(
            evaluator_options.transformer_class_confidence_threshold),
        pixel_confidence_threshold=(
            evaluator_options.pixel_confidence_threshold),
        pieces=1)

  def call(self, result_dict: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Post-processes the model predictions.

    Args:
      result_dict: A dictionary of tf.Tensor containing model results. The dict
        has to contain
         - common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY,
         - common.PRED_TRANSFORMER_CLASS_LOGITS_KEY,

    Returns:
      A new dict of tf.Tensor, containing the following:
        - common.PRED_PANOPTIC_KEY, the panoptic maps,
        - common.PRED_INSTANCE_KEY, the mask ID maps (shifted by one),
        - common.PRED_SEMANTIC_KEY, the semantic maps.
    """
    panoptic_maps, mask_id_maps, semantic_maps = self._post_processor(
        result_dict[common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY],
        result_dict[common.PRED_TRANSFORMER_CLASS_LOGITS_KEY])
    return {
        common.PRED_PANOPTIC_KEY: panoptic_maps,
        common.PRED_INSTANCE_KEY: mask_id_maps,
        common.PRED_SEMANTIC_KEY: semantic_maps,
    }