# -*- coding: utf-8 -*-

# Copyright 2024 Yiwei Guo
# Derived mostly from fairseq (https://github.com/facebookresearch/fairseq)

"""Prompt Pre-net Modules."""

import math

import torch
import torch.nn as nn

from vec2wav2.models.fairseq_modules.fp32_group_norm import Fp32GroupNorm
from vec2wav2.models.fairseq_modules.layer_norm import Fp32LayerNorm
from vec2wav2.models.fairseq_modules.transpose_last import TransposeLast


def norm_block(is_layer_norm, dim, affine=True):
    """Build an FP32 normalization block over `dim` channels of (B, C, T) input."""
    if is_layer_norm:
        # LayerNorm acts on the last dim, so swap to (B, T, C), normalize, swap back.
        mod = nn.Sequential(
            TransposeLast(),
            Fp32LayerNorm(dim, elementwise_affine=affine),
            TransposeLast(),
        )
    else:
        # One group: all channels normalized jointly, directly on (B, C, T).
        mod = Fp32GroupNorm(1, dim, affine=affine)
    return mod
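
# Usage sketch (illustrative; these calls are not in the original file):
#   norm_block(is_layer_norm=True, dim=256)   # FP32 LayerNorm over channels
#   norm_block(is_layer_norm=False, dim=256)  # FP32 GroupNorm with a single group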


class ZeroPad1d(nn.Module):
    """Zero-pad the last (time) dimension, possibly asymmetrically."""

    def __init__(self, pad_left, pad_right):
        super().__init__()
        self.pad_left = pad_left
        self.pad_right = pad_right

    def forward(self, x):
        return nn.functional.pad(x, (self.pad_left, self.pad_right))
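
# Shape check (illustrative, not in the original file): padding only grows the
# last axis, e.g. ZeroPad1d(2, 1)(torch.zeros(1, 4, 10)).shape == (1, 4, 13).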


class ConvPromptPrenet(nn.Module):
    """Convolutional prompt pre-net: stacked Conv1d blocks with optional scaled residuals."""

    def __init__(
        self,
        conv_layers,
        embed,
        dropout,
        skip_connections,
        residual_scale,
        non_affine_group_norm,
        conv_bias,
        activation,
    ):
        super().__init__()

        def block(n_in, n_out, k, stride, pad):
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, padding=pad),
                nn.Dropout(p=dropout),
                norm_block(False, n_out, affine=not non_affine_group_norm),
                activation,
            )

        in_d = embed
        self.conv_layers = nn.ModuleList()
        self.residual_proj = nn.ModuleList()
        for dim, k, stride, pad in conv_layers:
            if in_d != dim and skip_connections:
                # 1x1 conv so the residual branch matches the new channel count.
                self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False))
            else:
                self.residual_proj.append(None)

            self.conv_layers.append(block(in_d, dim, k, stride, pad))
            in_d = dim
        self.conv_layers = nn.Sequential(*self.conv_layers)
        self.skip_connections = skip_connections
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x):
        # x: (B, C, T). Apply each conv block; optionally add the (projected)
        # residual and rescale so the sum keeps roughly unit variance.
        for rproj, conv in zip(self.residual_proj, self.conv_layers):
            residual = x
            x = conv(x)
            if self.skip_connections:
                if rproj is not None:
                    residual = rproj(residual)
                x = (x + residual) * self.residual_scale
        return x
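

# Minimal usage sketch (added for illustration; the layer spec and sizes below
# are assumptions, not taken from any vec2wav2 config). Each conv_layers entry
# is (out_dim, kernel_size, stride, padding).
if __name__ == "__main__":
    prenet = ConvPromptPrenet(
        conv_layers=[(256, 3, 1, 1), (256, 3, 1, 1)],
        embed=80,                  # e.g. mel-spectrogram channels (assumed)
        dropout=0.1,
        skip_connections=True,
        residual_scale=0.5,
        non_affine_group_norm=False,
        conv_bias=True,
        activation=nn.ReLU(),
    )
    x = torch.randn(2, 80, 100)    # (batch, channels, time)
    y = prenet(x)
    print(y.shape)                 # torch.Size([2, 256, 100])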