# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch.nn import functional as F

from mmpretrain.registry import MODELS
from ..utils import LeAttention
from .base_backbone import BaseBackbone


class ConvBN2d(Sequential):
    """An implementation of Conv2d + BatchNorm2d with support of fusion.

    Modified from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The size of the convolution kernel.
            Default: 1.
        stride (int): The stride of the convolution.
            Default: 1.
        padding (int): The padding of the convolution.
            Default: 0.
        dilation (int): The dilation of the convolution.
            Default: 1.
        groups (int): The number of groups in the convolution.
            Default: 1.
        bn_weight_init (float): The initial value of the weight of
            the nn.BatchNorm2d layer. Default: 1.0.
        init_cfg (dict): The initialization config of the module.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1.0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.add_module(
            'conv2d',
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=False))
        bn2d = nn.BatchNorm2d(num_features=out_channels)
        # bn initialization
        torch.nn.init.constant_(bn2d.weight, bn_weight_init)
        torch.nn.init.constant_(bn2d.bias, 0)
        self.add_module('bn2d', bn2d)

    def fuse(self):
        conv2d, bn2d = self._modules.values()
        w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5
        w = conv2d.weight * w[:, None, None, None]
        b = bn2d.bias - bn2d.running_mean * bn2d.weight / \
            (bn2d.running_var + bn2d.eps)**0.5
        m = nn.Conv2d(
            in_channels=w.size(1) * self.conv2d.groups,
            out_channels=w.size(0),
            kernel_size=w.shape[2:],
            stride=self.conv2d.stride,
            padding=self.conv2d.padding,
            dilation=self.conv2d.dilation,
            groups=self.conv2d.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
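

# Fusion usage sketch (illustrative, not part of the upstream file):
# ``fuse()`` folds the BatchNorm statistics into the convolution weights and
# returns a plain nn.Conv2d with bias for faster inference. With an
# initialized module in eval mode the fused layer matches numerically:
#
#   >>> block = ConvBN2d(3, 16, kernel_size=3, padding=1).eval()
#   >>> fused = block.fuse()
#   >>> x = torch.randn(1, 3, 32, 32)
#   >>> torch.allclose(block(x), fused(x), atol=1e-5)
#   True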


class PatchEmbed(BaseModule):
    """Patch Embedding for Vision Transformer.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module uses
    Conv2d and BatchNorm2d to implement PatchEmbedding, and the output shape
    is (N, C, H, W).

    Args:
        in_channels (int): The number of input channels.
        embed_dim (int): The embedding dimension.
        resolution (Tuple[int, int]): The resolution of the input feature.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 embed_dim,
                 resolution,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        img_size: Tuple[int, int] = resolution
        self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
        self.num_patches = self.patches_resolution[0] * \
            self.patches_resolution[1]
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.seq = nn.Sequential(
            ConvBN2d(
                in_channels,
                embed_dim // 2,
                kernel_size=3,
                stride=2,
                padding=1),
            build_activation_layer(act_cfg),
            ConvBN2d(
                embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
        )

    def forward(self, x):
        return self.seq(x)
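

# Shape sketch (illustrative): the two stride-2 ConvBN2d layers reduce the
# spatial size by 4x and keep the (N, C, H, W) layout:
#
#   >>> embed = PatchEmbed(in_channels=3, embed_dim=64, resolution=(224, 224))
#   >>> embed(torch.randn(1, 3, 224, 224)).shape
#   torch.Size([1, 64, 56, 56])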


class PatchMerging(nn.Module):
    """Patch Merging for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Different from `mmpretrain.models.utils.PatchMerging`, this module uses
    Conv2d and BatchNorm2d to implement PatchMerging.

    Args:
        resolution (Tuple[int, int]): The resolution of the input feature.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 resolution,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.img_size = resolution
        self.act = build_activation_layer(act_cfg)
        self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1)
        self.conv2 = ConvBN2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=out_channels)
        self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1)
        self.out_resolution = (resolution[0] // 2, resolution[1] // 2)

    def forward(self, x):
        if len(x.shape) == 3:
            # convert (B, L, C) tokens into a (B, C, H, W) feature map
            H, W = self.img_size
            B = x.shape[0]
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        x = x.flatten(2).transpose(1, 2)
        return x
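

# Shape sketch (illustrative): PatchMerging accepts (B, L, C) tokens or a
# (B, C, H, W) feature map, halves the spatial resolution, and returns tokens
# with a quarter of the sequence length:
#
#   >>> merge = PatchMerging(resolution=(56, 56), in_channels=64,
#   ...                      out_channels=128)
#   >>> merge(torch.randn(1, 56 * 56, 64)).shape
#   torch.Size([1, 784, 128])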


class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block for TinyViT. Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        expand_ratio (int): The expand ratio of the hidden channels.
        drop_path (float): The drop path rate of the block.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio,
                 drop_path,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        hidden_channels = int(in_channels * expand_ratio)

        # linear
        self.conv1 = ConvBN2d(in_channels, hidden_channels, kernel_size=1)
        self.act = build_activation_layer(act_cfg)
        # depthwise conv
        self.conv2 = ConvBN2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_channels)
        # linear
        self.conv3 = ConvBN2d(
            hidden_channels, out_channels, kernel_size=1, bn_weight_init=0.0)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        x = self.drop_path(x)
        x += shortcut
        x = self.act(x)
        return x
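

# Residual sketch (illustrative): the shortcut is added element-wise, so the
# block is instantiated with in_channels == out_channels (as ConvStage does
# below), and the spatial shape is preserved:
#
#   >>> block = MBConvBlock(64, 64, expand_ratio=4, drop_path=0.)
#   >>> block(torch.randn(1, 64, 56, 56)).shape
#   torch.Size([1, 64, 56, 56])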


class ConvStage(BaseModule):
    """Convolution Stage for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of blocks in the stage.
        act_cfg (dict): The activation config of the module.
        drop_path (float): The drop path of the block.
        downsample (None | nn.Module): The downsample operation.
            Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
        out_channels (int): The number of output channels.
        conv_expand_ratio (int): The expand ratio of the hidden channels.
            Default: 4.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 act_cfg,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 out_channels=None,
                 conv_expand_ratio=4.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = ModuleList([
            MBConvBlock(
                in_channels=in_channels,
                out_channels=in_channels,
                expand_ratio=conv_expand_ratio,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path)
            for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution
        else:
            self.downsample = None
            self.resolution = resolution

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)

        if self.downsample is not None:
            x = self.downsample(x)
        return x


class MLP(BaseModule):
    """MLP module for TinyViT.

    Args:
        in_channels (int): The number of input channels.
        hidden_channels (int, optional): The number of hidden channels.
            Default: None.
        out_channels (int, optional): The number of output channels.
            Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels=None,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.norm = nn.LayerNorm(in_channels)
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.act = build_activation_layer(act_cfg)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class TinyViTBlock(BaseModule):
    """TinyViT Block.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        num_heads (int): The number of heads in the multi-head attention.
        window_size (int): The size of the window.
            Default: 7.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path (float): The drop path of the block.
            Default: 0.
        local_conv_size (int): The size of the local convolution.
            Default: 3.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 local_conv_size=3,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        self.img_size = resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        assert in_channels % num_heads == 0, \
            'dim must be divisible by num_heads'
        head_dim = in_channels // num_heads

        window_resolution = (window_size, window_size)
        self.attn = LeAttention(
            in_channels,
            head_dim,
            num_heads,
            attn_ratio=1,
            resolution=window_resolution)

        mlp_hidden_dim = int(in_channels * mlp_ratio)
        self.mlp = MLP(
            in_channels=in_channels,
            hidden_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        self.local_conv = ConvBN2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=local_conv_size,
            stride=1,
            padding=local_conv_size // 2,
            groups=in_channels)

    def forward(self, x):
        H, W = self.img_size
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'
        res_x = x
        if H == self.window_size and W == self.window_size:
            x = self.attn(x)
        else:
            x = x.view(B, H, W, C)
            pad_b = (self.window_size -
                     H % self.window_size) % self.window_size
            pad_r = (self.window_size -
                     W % self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0
            if padding:
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size

            # window partition
            x = x.view(B, nH, self.window_size, nW, self.window_size,
                       C).transpose(2, 3).reshape(
                           B * nH * nW,
                           self.window_size * self.window_size, C)
            x = self.attn(x)

            # window reverse
            x = x.view(B, nH, nW, self.window_size, self.window_size,
                       C).transpose(2, 3).reshape(B, pH, pW, C)
            if padding:
                x = x[:, :H, :W].contiguous()
            x = x.view(B, L, C)

        x = res_x + self.drop_path(x)
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.local_conv(x)
        x = x.view(B, C, L).transpose(1, 2)
        x = x + self.drop_path(self.mlp(x))
        return x
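

# Window-attention sketch (illustrative): a TinyViTBlock keeps the (B, L, C)
# token layout; the feature map is padded only when H or W is not a multiple
# of the window size and is cropped back after attention:
#
#   >>> blk = TinyViTBlock(in_channels=160, resolution=(14, 14), num_heads=5,
#   ...                    window_size=7)
#   >>> blk(torch.randn(2, 14 * 14, 160)).shape
#   torch.Size([2, 196, 160])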


class BasicStage(BaseModule):
    """Basic Stage for TinyViT.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of blocks in the stage.
        num_heads (int): The number of heads in the multi-head attention.
        window_size (int): The size of the window.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path (float): The drop path of the block.
            Default: 0.
        downsample (None | nn.Module): The downsample operation.
            Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        local_conv_size (int): The size of the local convolution.
            Default: 3.
        out_channels (int, optional): The number of output channels of the
            downsample layer. Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 local_conv_size=3,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = ModuleList([
            TinyViTBlock(
                in_channels=in_channels,
                resolution=resolution,
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop,
                local_conv_size=local_conv_size,
                act_cfg=act_cfg,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path)
            for i in range(depth)
        ])

        # build patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution
        else:
            self.downsample = None
            self.resolution = resolution

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)

        if self.downsample is not None:
            x = self.downsample(x)
        return x


@MODELS.register_module()
class TinyViT(BaseBackbone):
    """TinyViT.

    A PyTorch implementation of: `TinyViT: Fast Pretraining Distillation
    for Small Vision Transformers <https://arxiv.org/abs/2207.10666>`_

    Inspired by
    https://github.com/microsoft/Cream/blob/main/TinyViT

    Args:
        arch (str | dict): The architecture of TinyViT.
            Default: '5m'.
        img_size (tuple | int): The resolution of the input image.
            Default: (224, 224).
        window_size (list): The size of the window.
            Default: [7, 7, 14, 7].
        in_channels (int): The number of input channels.
            Default: 3.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path_rate (float): The drop path of the block.
            Default: 0.1.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        mbconv_expand_ratio (float): The expand ratio of the mbconv.
            Default: 4.0.
        local_conv_size (int): The size of the local conv.
            Default: 3.
        layer_lr_decay (float): The layer lr decay.
            Default: 1.0.
        out_indices (int | list[int]): Output from which stages.
            Default: -1.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: 0.
        gap_before_final_norm (bool): Whether to add a gap before the final
            norm. Default: True.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """
    arch_settings = {
        '5m': {
            'channels': [64, 128, 160, 320],
            'num_heads': [2, 4, 5, 10],
            'depths': [2, 2, 6, 2],
        },
        '11m': {
            'channels': [64, 128, 256, 448],
            'num_heads': [2, 4, 8, 14],
            'depths': [2, 2, 6, 2],
        },
        '21m': {
            'channels': [96, 192, 384, 576],
            'num_heads': [3, 6, 12, 18],
            'depths': [2, 2, 6, 2],
        },
    }

    def __init__(self,
                 arch='5m',
                 img_size=(224, 224),
                 window_size=[7, 7, 14, 7],
                 in_channels=3,
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 local_conv_size=3,
                 layer_lr_decay=1.0,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'channels' in arch and 'num_heads' in arch and \
                'depths' in arch, \
                'The arch dict must have "channels", "num_heads" and ' \
                f'"depths" keys, but got {arch.keys()}'

        self.channels = arch['channels']
        self.num_heads = arch['num_heads']
        self.window_sizes = window_size
        self.img_size = img_size
        self.depths = arch['depths']
        self.num_stages = len(self.channels)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            '"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = 4 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm
        self.layer_lr_decay = layer_lr_decay

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dim=self.channels[0],
            resolution=self.img_size,
            act_cfg=dict(type='GELU'))
        patches_resolution = self.patch_embed.patches_resolution

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]

        # build stages
        self.stages = ModuleList()
        for i in range(self.num_stages):
            depth = self.depths[i]
            channel = self.channels[i]
            curr_resolution = (patches_resolution[0] // (2**i),
                               patches_resolution[1] // (2**i))
            drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])]
            downsample = PatchMerging if (i < self.num_stages - 1) else None
            out_channels = self.channels[min(i + 1, self.num_stages - 1)]
            if i >= 1:
                stage = BasicStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    num_heads=self.num_heads[i],
                    window_size=self.window_sizes[i],
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    local_conv_size=local_conv_size,
                    out_channels=out_channels,
                    act_cfg=act_cfg)
            else:
                stage = ConvStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    act_cfg=act_cfg,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    out_channels=out_channels,
                    conv_expand_ratio=mbconv_expand_ratio)
            self.stages.append(stage)

            # add output norm
            if i in self.out_indices:
                norm_layer = build_norm_layer(norm_cfg, out_channels)[1]
                self.add_module(f'norm{i}', norm_layer)

    def set_layer_lr_decay(self, layer_lr_decay):
        # TODO: add layer_lr_decay
        pass

    def forward(self, x):
        outs = []
        x = self.patch_embed(x)

        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean(1)
                    outs.append(norm_layer(gap))
                else:
                    out = norm_layer(x)
                    # convert the (B, L, C) format into (B, C, H, W) format,
                    # which would be better for the downstream tasks.
                    B, L, C = out.shape
                    out = out.view(B, *stage.resolution, C)
                    outs.append(out.permute(0, 3, 1, 2))
        return tuple(outs)

    def _freeze_stages(self):
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(TinyViT, self).train(mode)
        self._freeze_stages()
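

# Backbone usage sketch (illustrative): with the default out_indices=-1 and
# gap_before_final_norm=True, the backbone returns a single globally pooled
# and normalized feature from the last stage:
#
#   >>> model = TinyViT(arch='5m', img_size=(224, 224))
#   >>> feats = model(torch.randn(1, 3, 224, 224))
#   >>> feats[0].shape
#   torch.Size([1, 320])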