| | |
| | |
| | |
| | |
| | |
| |
|
| | from itertools import product |
| |
|
| | import pytest |
| | import torch |
| |
|
| | from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock |
| | from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d |
| |
|
| |
|
class TestSEANetModel:
    """Shape and per-block normalization tests for ``SEANetEncoder`` / ``SEANetDecoder``."""

    @staticmethod
    def _assert_roundtrip_shapes(encoder, decoder):
        """Encode then decode a random ``(1, 1, 24000)`` tensor and check both output shapes.

        Args:
            encoder: Callable mapping the input tensor to a latent tensor.
            decoder: Callable mapping the latent tensor back to the input shape.
        """
        x = torch.randn(1, 1, 24000)
        z = encoder(x)
        assert list(z.shape) == [1, 128, 75], z.shape
        y = decoder(z)
        assert y.shape == x.shape, (x.shape, y.shape)

    def test_base(self):
        """Default encoder/decoder produce the expected latent and output shapes."""
        self._assert_roundtrip_shapes(SEANetEncoder(), SEANetDecoder())

    def test_causal(self):
        """Models built with ``causal=True`` keep the same shapes."""
        self._assert_roundtrip_shapes(SEANetEncoder(causal=True), SEANetDecoder(causal=True))

    def test_conv_skip_connection(self):
        """Models built with ``true_skip=False`` keep the same shapes."""
        self._assert_roundtrip_shapes(SEANetEncoder(true_skip=False), SEANetDecoder(true_skip=False))

    def test_seanet_encoder_decoder_final_act(self):
        """A decoder built with ``final_activation='Tanh'`` keeps the same shapes."""
        encoder = SEANetEncoder(true_skip=False)
        decoder = SEANetDecoder(true_skip=False, final_activation='Tanh')
        self._assert_roundtrip_shapes(encoder, decoder)

    def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
        """Assert that every conv in ``encoder.model`` carries the expected ``norm_type``.

        Convolutions are counted in order of appearance; a conv whose 1-based index is
        ``<= n_disable_blocks`` must have ``norm_type == 'none'``, every other one must
        have ``norm_type == norm``.

        Args:
            encoder: Encoder whose ``model`` layers are inspected.
            n_disable_blocks: Number of leading blocks expected to have no normalization.
            norm: Normalization name expected on the remaining blocks.
        """
        n_blocks = 0
        for layer in encoder.model:
            if isinstance(layer, StreamableConv1d):
                n_blocks += 1
                # Compute the expected value first: writing
                # `assert a == 'none' if cond else norm` would assert the (truthy)
                # string `norm` whenever `cond` is false and check nothing.
                expected = 'none' if n_blocks <= n_disable_blocks else norm
                assert layer.conv.norm_type == expected, (n_blocks, layer.conv.norm_type, expected)
            elif isinstance(layer, SEANetResnetBlock):
                # Residual layers are compared using n_blocks + 1, i.e. the index the
                # next counted conv will receive, since n_blocks is only incremented
                # when that conv is reached.
                expected = 'none' if (n_blocks + 1) <= n_disable_blocks else norm
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected, \
                            (n_blocks + 1, resnet_layer.conv.norm_type, expected)

    def test_encoder_disable_norm(self):
        """Check encoder norm placement over a grid of residual counts, disabled blocks and norms."""
        n_residuals = [0, 1, 3]
        disable_blocks_options = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        for n_res, n_disable, norm in product(n_residuals, disable_blocks_options, norms):
            encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_encoder_blocks_norm(encoder, n_disable, norm)

    def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
        """Assert that every conv in ``decoder.model`` carries the expected ``norm_type``.

        Convolutions and transposed convolutions are counted in order of appearance; a
        layer for which ``decoder.n_blocks - index < n_disable_blocks`` must have
        ``norm_type == 'none'``, every other one must have ``norm_type == norm``.

        Args:
            decoder: Decoder whose ``model`` layers are inspected.
            n_disable_blocks: Number of trailing blocks expected to have no normalization.
            norm: Normalization name expected on the remaining blocks.
        """
        n_blocks = 0
        for layer in decoder.model:
            if isinstance(layer, StreamableConv1d):
                n_blocks += 1
                # Expected value is computed separately so the assertion is a real
                # comparison in both branches (see the encoder check for details).
                expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
                assert layer.conv.norm_type == expected, (n_blocks, layer.conv.norm_type, expected)
            elif isinstance(layer, StreamableConvTranspose1d):
                n_blocks += 1
                expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
                assert layer.convtr.norm_type == expected, (n_blocks, layer.convtr.norm_type, expected)
            elif isinstance(layer, SEANetResnetBlock):
                # Residual layers use the current n_blocks (no increment).
                expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected, \
                            (n_blocks, resnet_layer.conv.norm_type, expected)

    def test_decoder_disable_norm(self):
        """Check decoder norm placement over a grid of residual counts, disabled blocks and norms."""
        n_residuals = [0, 1, 3]
        disable_blocks_options = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        for n_res, n_disable, norm in product(n_residuals, disable_blocks_options, norms):
            decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_decoder_blocks_norm(decoder, n_disable, norm)

    def test_disable_norm_raises_exception(self):
        """Invalid ``disable_norm_outer_blocks`` values must raise ``AssertionError``."""
        # Negative value.
        with pytest.raises(AssertionError):
            SEANetEncoder(disable_norm_outer_blocks=-1)

        # NOTE(review): 7 is presumably larger than the number of blocks built from
        # four ratios -- confirm against the SEANetEncoder constructor.
        with pytest.raises(AssertionError):
            SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)

        # Same two cases for the decoder.
        with pytest.raises(AssertionError):
            SEANetDecoder(disable_norm_outer_blocks=-1)

        with pytest.raises(AssertionError):
            SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
| |
|