Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2021 The Deeplab2 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Tests for mobilenet.""" | |
| from absl.testing import parameterized | |
| import tensorflow as tf | |
| from deeplab2.model import test_utils | |
| from deeplab2.model.encoder import mobilenet | |
class MobilenetTest(tf.test.TestCase, parameterized.TestCase):
  """Shape and weight-reload tests for the MobileNetV3 encoders."""

  # Encoder constructors under test, keyed by the `model_name` test parameter.
  _MODELS = {
      'MobileNetV3Small': mobilenet.MobileNetV3Small,
      'MobileNetV3Large': mobilenet.MobileNetV3Large,
  }

  # Expected number of channels of the collected endpoints 'res2'..'res5'
  # (in that order) for width_multiplier = 1.0.
  _ENDPOINT_CHANNELS = {
      'MobileNetV3Small': [16, 24, 48, 576],
      'MobileNetV3Large': [24, 40, 112, 960],
  }

  def setUp(self):
    super().setUp()
    # Every expected shape below is written as NHWC.
    tf.keras.backend.set_image_data_format('channels_last')

  @parameterized.parameters('MobileNetV3Small', 'MobileNetV3Large')
  def test_mobilenetv3_construct_graph(self, model_name):
    """Checks endpoint shapes when no output_stride is passed.

    Args:
      model_name: Key into `_MODELS` selecting the encoder under test.
    """
    input_size = 128
    network = self._MODELS[model_name](width_multiplier=1.0)
    inputs = tf.ones([1, input_size, input_size, 3])
    endpoints = network(inputs)
    for idx, num_filters in enumerate(self._ENDPOINT_CHANNELS[model_name]):
      # Endpoint 'res{level}' is expected to be downsampled by 2**level.
      # Integer division keeps the expected shape integral.
      level = idx + 2
      spatial = input_size // 2**level
      self.assertAllEqual(
          [1, spatial, spatial, num_filters],
          endpoints['res' + str(level)].shape.as_list())

  @parameterized.product(
      model_name=['MobileNetV3Small', 'MobileNetV3Large'],
      output_stride=[4, 8, 16, 32])
  def test_mobilenetv3_atrous_endpoint_shape(self, model_name, output_stride):
    """Checks endpoint shapes for each supported output stride.

    Once the requested output stride is reached, deeper endpoints keep that
    spatial size instead of being downsampled further.

    Args:
      model_name: Key into `_MODELS` selecting the encoder under test.
      output_stride: Output stride passed to the encoder constructor.
    """
    input_size = 321
    batch_size = 2
    # Expected spatial size of 'res2'..'res5' for a 321x321 input, per stride.
    stride_spatial_shapes_map = {
        4: [81, 81, 81, 81],
        8: [81, 41, 41, 41],
        16: [81, 41, 21, 21],
        32: [81, 41, 21, 11],
    }
    network = self._MODELS[model_name](
        width_multiplier=1.0,
        output_stride=output_stride)
    spatial_shapes = stride_spatial_shapes_map[output_stride]
    inputs = tf.ones([batch_size, input_size, input_size, 3])
    endpoints = network(inputs)
    for idx, num_filters in enumerate(self._ENDPOINT_CHANNELS[model_name]):
      expected_shape = [
          batch_size, spatial_shapes[idx], spatial_shapes[idx], num_filters
      ]
      self.assertAllEqual(
          expected_shape,
          endpoints['res' + str(idx + 2)].shape.as_list())

  @parameterized.parameters('MobileNetV3Small', 'MobileNetV3Large')
  def test_mobilenet_reload_weights(self, model_name):
    """Checks that copying weights between two instances reproduces outputs.

    Args:
      model_name: Key into `_MODELS` selecting the encoder under test.
    """
    tf.random.set_seed(0)
    pixel_inputs = test_utils.create_test_input(1, 320, 320, 3)

    network1 = self._MODELS[model_name](
        width_multiplier=1.0,
        output_stride=32,
        name='m1')
    # The first call also builds the variables; False is the training flag.
    outputs1 = network1(pixel_inputs, False)
    pixel_outputs = outputs1['res5']

    # A second, independently named network with the same configuration.
    network2 = self._MODELS[model_name](
        width_multiplier=1.0,
        output_stride=32,
        name='m2')
    # Call once so the variables exist before set_weights is used.
    network2(pixel_inputs, False)
    # Make the two networks use the same weights.
    network2.set_weights(network1.get_weights())
    outputs2 = network2(pixel_inputs, False)
    expected = outputs2['res5']

    self.assertAllClose(network1.get_weights(), network2.get_weights(),
                        atol=1e-4, rtol=1e-4)
    self.assertAllClose(pixel_outputs, expected, atol=1e-4, rtol=1e-4)
if __name__ == '__main__':
  # Run the test cases defined in this module with TensorFlow's test runner.
  tf.test.main()