Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List, Optional, Union | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn as nn | |
| from mmengine.dist import all_reduce | |
| from mmengine.model import BaseModule | |
| from mmpretrain.registry import MODELS | |
@torch.no_grad()
def distributed_sinkhorn(out: torch.Tensor, sinkhorn_iterations: int,
                         world_size: int, epsilon: float) -> torch.Tensor:
    """Apply the distributed Sinkhorn-Knopp optimization on the scores matrix
    to find the assignments.

    This function is modified from
    https://github.com/facebookresearch/swav/blob/main/main_swav.py

    The scores are exponentiated and transposed into a ``K x B_local`` matrix
    ``Q`` (``K`` prototypes, ``B_local`` samples held by this process). ``Q``
    is then alternately rescaled so that each row sums to ``1 / K`` and each
    column to ``1 / B`` with ``B = B_local * world_size``. The total mass and
    the row sums are all-reduced on every iteration, so all processes see the
    global values and issue the same sequence of collectives.

    Args:
        out (torch.Tensor): The scores matrix, indexed as
            ``(num_samples, num_prototypes)``; it is transposed internally so
            that prototypes lie along dim 0.
        sinkhorn_iterations (int): Number of iterations in Sinkhorn-Knopp
            algorithm.
        world_size (int): The world size of the process group.
        epsilon (float): Regularization parameter for Sinkhorn-Knopp
            algorithm.

    Returns:
        torch.Tensor: Output of sinkhorn algorithm, same shape as ``out``.
        After at least one iteration every row (sample) sums to 1.
    """
    eps_num_stab = 1e-12
    # Q is K-by-B for consistency with notations from the SwAV paper.
    Q = torch.exp(out / epsilon).t()
    B = Q.shape[1] * world_size  # number of samples to assign (global)
    K = Q.shape[0]  # how many prototypes

    # make the matrix sum to 1 over all processes
    sum_Q = torch.sum(Q)
    all_reduce(sum_Q)
    Q /= sum_Q

    for _ in range(sinkhorn_iterations):
        # normalize each row: total weight per prototype must be 1/K.
        # The row sums are always reduced across processes so the constraint
        # holds for the global batch; after the reduction `u` is identical on
        # every rank, so the zero-guard below takes the same branch
        # everywhere and its collective cannot deadlock.
        u = torch.sum(Q, dim=1, keepdim=True)
        all_reduce(u)
        if torch.any(u == 0):
            # Some prototype got no mass at all: nudge Q to avoid 0 / 0.
            Q += eps_num_stab
            u = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype)
            all_reduce(u)
        Q /= u
        Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # the columns must sum to 1 so that Q is an assignment
    return Q.t()
class MultiPrototypes(BaseModule):
    """Multi-prototypes for SwAV head.

    One bias-free linear projection is created per entry of
    ``num_prototypes`` and registered as a sub-module named ``prototypes0``,
    ``prototypes1``, ... in the given order.

    Args:
        output_dim (int): The output dim from SwAV neck.
        num_prototypes (List[int]): The number of prototypes needed, one
            entry per head.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 output_dim: int,
                 num_prototypes: List[int],
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(num_prototypes, list)
        self.num_heads = len(num_prototypes)
        for head_idx, head_dim in enumerate(num_prototypes):
            head_name = 'prototypes' + str(head_idx)
            head = nn.Linear(output_dim, head_dim, bias=False)
            self.add_module(head_name, head)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Project ``x`` with every prototype head.

        Args:
            x (torch.Tensor): Input features passed unchanged to each head.

        Returns:
            List[torch.Tensor]: One output per head, in head order.
        """
        return [
            getattr(self, 'prototypes' + str(head_idx))(x)
            for head_idx in range(self.num_heads)
        ]
# NOTE(review): `MODELS` is imported by this file but no registration of this
# class is visible here; confirm whether `@MODELS.register_module()` is
# expected on SwAVLoss.
class SwAVLoss(BaseModule):
    """The Loss for SwAV.

    This Loss contains clustering and sinkhorn algorithms to compute Q codes.
    Part of the code is borrowed from `script
    <https://github.com/facebookresearch/swav>`_.
    The queue is built in `engine/hooks/swav_hook.py`.

    Args:
        feat_dim (int): feature dimension of the prototypes.
        sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp
            algorithm. Defaults to 3.
        epsilon (float): regularization parameter for Sinkhorn-Knopp
            algorithm. Defaults to 0.05.
        temperature (float): temperature parameter in training loss.
            Defaults to 0.1.
        crops_for_assign (List[int], optional): list of crops id used for
            computing assignments. Defaults to None, which means ``[0, 1]``.
        num_crops (List[int], optional): list of number of crops. Defaults
            to None, which means ``[2]``.
        num_prototypes (int | List[int]): number of prototypes. A list
            builds a :class:`MultiPrototypes` layer, which :meth:`forward`
            does not support. Defaults to 3000.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 feat_dim: int,
                 sinkhorn_iterations: int = 3,
                 epsilon: float = 0.05,
                 temperature: float = 0.1,
                 crops_for_assign: Optional[List[int]] = None,
                 num_crops: Optional[List[int]] = None,
                 num_prototypes: Union[int, List[int]] = 3000,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.sinkhorn_iterations = sinkhorn_iterations
        self.epsilon = epsilon
        self.temperature = temperature
        # Copy list arguments so instances never share (or alias a caller's)
        # mutable list.
        self.crops_for_assign = ([0, 1] if crops_for_assign is None else
                                 list(crops_for_assign))
        self.num_crops = [2] if num_crops is None else list(num_crops)
        # The queue is attached externally (see class docstring). `forward`
        # indexes it as queue[i] -> (queue_len, feat_dim) for the i-th entry
        # of `crops_for_assign`; `use_queue` latches to True once used.
        self.use_queue = False
        self.queue = None
        # Captured once at construction time.
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1

        # prototype layer
        self.prototypes = None
        if isinstance(num_prototypes, list):
            self.prototypes = MultiPrototypes(feat_dim, num_prototypes)
        elif num_prototypes > 0:
            self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False)
        assert self.prototypes is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function of SwAV loss.

        Args:
            x (torch.Tensor): NxC input features, laid out crop by crop:
                rows ``[k * bs, (k + 1) * bs)`` belong to crop ``k``, where
                ``bs = N // sum(num_crops)``.

        Returns:
            torch.Tensor: The returned loss.

        Raises:
            NotImplementedError: If the prototypes were built as a
                :class:`MultiPrototypes` (``num_prototypes`` was a list).
        """
        if isinstance(self.prototypes, MultiPrototypes):
            # The code below reads `self.prototypes.weight` and slices a
            # single score tensor; neither exists for multiple heads.
            raise NotImplementedError(
                'SwAVLoss.forward does not support multiple prototype heads '
                '(num_prototypes given as a list).')

        # normalize the prototypes
        with torch.no_grad():
            w = self.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            self.prototypes.weight.copy_(w)

        embedding, output = x, self.prototypes(x)
        embedding = embedding.detach()

        total_crops = sum(self.num_crops)
        bs = embedding.size(0) // total_crops

        loss = 0
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                out = output[bs * crop_id:bs * (crop_id + 1)].detach()

                # time to use the queue
                if self.queue is not None:
                    if self.use_queue or not torch.all(
                            self.queue[i, -1, :] == 0):
                        self.use_queue = True
                        queue_scores = torch.mm(self.queue[i],
                                                self.prototypes.weight.t())
                        out = torch.cat((queue_scores, out))
                    # fill the queue: shift older entries back, newest first
                    self.queue[i, bs:] = self.queue[i, :-bs].clone()
                    self.queue[i, :bs] = embedding[crop_id * bs:(crop_id + 1) *
                                                   bs]

                # get assignments (batch_size * num_prototypes); keep only
                # the rows of the current batch (queue rows come first).
                q = distributed_sinkhorn(out, self.sinkhorn_iterations,
                                         self.world_size, self.epsilon)[-bs:]

            # cluster assignment prediction from every other crop
            subloss = 0
            for v in range(total_crops):
                if v == crop_id:
                    continue
                logits = output[bs * v:bs * (v + 1)] / self.temperature
                subloss -= torch.mean(
                    torch.sum(
                        q * nn.functional.log_softmax(logits, dim=1), dim=1))
            loss += subloss / (total_crops - 1)
        loss /= len(self.crops_for_assign)
        return loss