# Copyright (c) OpenMMLab. All rights reserved.
import math

import pytest
import torch

from mmocr.models.textrecog.decoders import (ABILanguageDecoder,
                                             ABIVisionDecoder, BaseDecoder,
                                             NRTRDecoder, ParallelSARDecoder,
                                             ParallelSARDecoderWithBS,
                                             SequentialSARDecoder)
from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode


def _create_dummy_input():
    # Shared fixture: a conv feature map (N, C, H, W), a holistic encoder
    # output, padded target indices (index 36 is the padding symbol) and
    # image metas carrying the valid-width ratio used by SAR attention.
    feat = torch.rand(1, 512, 4, 40)
    out_enc = torch.rand(1, 512)
    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
    img_metas = [{'valid_ratio': 1.0}]
    return feat, out_enc, tgt_dict, img_metas


def test_base_decoder():
    decoder = BaseDecoder()
    # BaseDecoder is an abstract interface: both forward passes must be
    # overridden by concrete decoders.
    with pytest.raises(NotImplementedError):
        decoder.forward_train(None, None, None, None)
    with pytest.raises(NotImplementedError):
        decoder.forward_test(None, None, None)


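# The SAR/NRTR tests below use a 37-class vocabulary whose last index (36)
# is the padding symbol; predictions exclude the padding class, hence the
# asserted [1, 5, 36] output shapes.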
def test_parallel_sar_decoder():
    # test parallel sar decoder
    decoder = ParallelSARDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
    # img_metas must be a non-empty list whose length matches the batch size.
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, [], True)
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, img_metas * 2, True)

    out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
    assert out_train.shape == torch.Size([1, 5, 36])

    out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
    assert out_test.shape == torch.Size([1, 5, 36])


def test_sequential_sar_decoder():
    # test sequential sar decoder
    decoder = SequentialSARDecoder(
        num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, [])
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, img_metas * 2)

    out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
    assert out_train.shape == torch.Size([1, 5, 36])

    out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
    assert out_test.shape == torch.Size([1, 5, 36])


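# ParallelSARDecoderWithBS runs beam search on top of the parallel SAR
# decoder at test time; beam_width=1 reduces it to greedy decoding.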
def test_parallel_sar_decoder_with_beam_search():
    # beam_width must be a positive integer.
    with pytest.raises(AssertionError):
        ParallelSARDecoderWithBS(beam_width='beam')
    with pytest.raises(AssertionError):
        ParallelSARDecoderWithBS(beam_width=0)

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
    decoder = ParallelSARDecoderWithBS(
        beam_width=1, num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, [])
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, img_metas * 2)

    out_test = decoder(feat, out_enc, tgt_dict, img_metas, train_mode=False)
    assert out_test.shape == torch.Size([1, 5, 36])

    # test DecodeNode: expects equal-length lists of int indexes and
    # float scores.
    with pytest.raises(AssertionError):
        DecodeNode(1, 1)
    with pytest.raises(AssertionError):
        DecodeNode([1, 2], ['4', '3'])
    with pytest.raises(AssertionError):
        DecodeNode([1, 2], [0.5])
    decode_node = DecodeNode([1, 2], [0.7, 0.8])
    # eval() accumulates the per-step scores: 0.7 + 0.8 = 1.5.
    assert math.isclose(decode_node.eval(), 1.5)


def test_transformer_decoder():
    decoder = NRTRDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()

    out_enc = torch.rand(1, 25, 512)
    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
    img_metas = [{'valid_ratio': 1.0}]

    out_train = decoder(None, out_enc, tgt_dict, img_metas, True)
    assert out_train.shape == torch.Size([1, 5, 36])

    out_test = decoder(None, out_enc, tgt_dict, img_metas, False)
    assert out_test.shape == torch.Size([1, 5, 36])


def test_abi_language_decoder():
    decoder = ABILanguageDecoder(max_seq_len=25)
    logits = torch.randn(2, 25, 90)
    result = decoder(
        feat=None, out_enc=logits, targets_dict=None, img_metas=None)
    assert result['feature'].shape == torch.Size([2, 25, 512])
    assert result['logits'].shape == torch.Size([2, 25, 90])


def test_abi_vision_decoder():
    model = ABIVisionDecoder(
        in_channels=128, num_channels=16, max_seq_len=10, use_result=None)
    x = torch.randn(2, 128, 8, 32)
    result = model(x, None)
    assert result['feature'].shape == torch.Size([2, 10, 128])
    assert result['logits'].shape == torch.Size([2, 10, 90])
    assert result['attn_scores'].shape == torch.Size([2, 10, 8, 32])
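

if __name__ == '__main__':
    # Convenience hook: running this file directly invokes pytest on it,
    # equivalent to `pytest <this file>` from the command line.
    pytest.main([__file__])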