Spaces:
Build error
Build error
| import torch | |
| from torch import nn | |
def knn(x, k):
    """Find the k nearest neighbours of every point.

    Args:
        x: tensor of shape (batch_size, C, num_points).
        k: number of neighbours to return per point (each point's own index
            is normally among them, since its distance to itself is 0).

    Returns:
        (idx, pairwise_distance): ``idx`` has shape (batch_size, num_points, k)
        and holds neighbour indices ordered from nearest to farthest;
        ``pairwise_distance`` has shape (batch_size, num_points, num_points)
        and holds the NEGATIVE squared Euclidean distance -||x_i - x_j||^2.
    """
    # ||x_i||^2 for every point, kept as (B, 1, N) for broadcasting.
    sq_norms = x.pow(2).sum(dim=1, keepdim=True)
    # -2 <x_i, x_j> via the Gram matrix.
    cross = -2 * torch.matmul(x.transpose(2, 1), x)
    # -(||x_j||^2 - 2<x_i,x_j> + ||x_i||^2) = -||x_i - x_j||^2
    neg_sq_dist = -sq_norms - cross - sq_norms.transpose(2, 1)
    # Largest negative distance == nearest point.
    _, nearest = neg_sq_dist.topk(k=k, dim=-1)
    return nearest, neg_sq_dist
def local_operator(x, k):
    """Build edge features [x_j - x_i, x_j] over each point's k nearest neighbours.

    Args:
        x: point features; viewed as (batch_size, C, num_points).
        k: number of neighbours gathered per point (must not exceed num_points).

    Returns:
        Tensor of shape (batch_size, 2*C, num_points, k): the first C channels
        are neighbour-minus-centre offsets (local), the last C channels are the
        neighbour features themselves (global).
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)  # (batch_size, num_points, k)
    # Offset each batch's indices so they address rows of the flattened
    # (batch_size * num_points, C) tensor. Built on idx's device (not a fixed
    # CPU device) so the addition also works for GPU inputs.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, C)
    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k, num_dims)
    # (B, N, 1, C) broadcasts against (B, N, k, C); no need to materialise k copies.
    center = x.view(batch_size, num_points, 1, num_dims)
    feature = torch.cat((neighbor - center, neighbor), dim=3).permute(0, 3, 1, 2)  # local and global all in
    return feature
def local_operator_withnorm(x, norm_plt, k):
    """Build edge features [x_j - x_i, x_j, n_j] over each point's k nearest neighbours.

    Neighbours are found from ``x`` only; the matching rows of ``norm_plt`` are
    gathered with the same indices.

    Args:
        x: point features; viewed as (batch_size, C, num_points).
        norm_plt: extra per-point attributes (presumably normals, per the name —
            confirm with callers); viewed as (batch_size, C_n, num_points).
            C_n may differ from C.
        k: number of neighbours gathered per point (must not exceed num_points).

    Returns:
        Tensor of shape (batch_size, 2*C + C_n, num_points, k).
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    norm_plt = norm_plt.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)  # (batch_size, num_points, k)
    # Per-batch offsets into the flattened (batch_size * num_points, ·) tensors,
    # created on idx's device so GPU inputs work.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, C)
    norm_plt = norm_plt.transpose(2, 1).contiguous()  # (batch_size, num_points, C_n)
    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor_norm = norm_plt.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k, num_dims)
    # -1 instead of x's channel count: the attribute tensor may have its own width.
    neighbor_norm = neighbor_norm.view(batch_size, num_points, k, -1)
    # (B, N, 1, C) broadcasts against (B, N, k, C).
    center = x.view(batch_size, num_points, 1, num_dims)
    feature = torch.cat((neighbor - center, neighbor, neighbor_norm), dim=3).permute(0, 3, 1, 2)  # 3c
    return feature
def GDM(x, M, k=64, tau=0.2, sigma=2):
    """Geometry-Disentangle Module.

    Scores every point by how much it differs from a weighted average of its
    neighbours (Eq.(5)). It then returns the M highest-scoring points (sharp
    variation component) and the M lowest-scoring points (gentle variation
    component).

    Args:
        x: point features; viewed as (batch_size, C, num_points).
        M: number of disentangled points in both sharp and gentle variation
            components (must not exceed num_points).
        k: number of neighbors to decide the range of j in Eq.(5) (must not
            exceed num_points).
        tau: threshold in Eq.(2); point pairs whose Euclidean distance is not
            below tau get zero edge weight.
        sigma: parameter of f (Gaussian function in Eq.(2)).

    Returns:
        (xs, xg): each of shape (batch_size, M, C) — the sharp and gentle
        variation components respectively.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)

    # --- Graph construction ---
    idx, p = knn(x, k=k)  # idx: (B, N, k); p = -||x_i - x_j||^2, (B, N, N)
    # abs() guards against tiny negative rounding before the sqrt.
    mask = torch.sqrt(torch.abs(p)) < tau
    w = torch.exp(p / (sigma * sigma)) * mask.float()  # masked Gaussian weights, (B, N, N)
    # Scale row i of w by 1 / sum_r w[r, i]. This gives the same values as
    # building diag(1/sum) with torch.eye and doing a dense matmul, but in
    # O(N^2) instead of O(N^3).
    inv_deg = 1 / torch.sum(w, dim=1)  # (B, N)
    A = w * inv_deg.unsqueeze(-1)  # normalized adjacency matrix A_hat, (B, N, N)

    # --- Get A_ij in a local area ---
    # Drop neighbour column 0, presumably the point itself (distance 0).
    nbr_idx = idx[:, :, 1:]  # (B, N, k-1)
    A = torch.gather(A, 2, nbr_idx)  # A[b, i, nbr_idx[b, i, j]], (B, N, k-1)

    # --- Disentangle into sharp (xs) and gentle (xg) variation components ---
    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()  # (B, N, C)
    flat_x = x.view(batch_size * num_points, -1)
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    neighbor = flat_x[(nbr_idx + idx_base).reshape(-1), :]
    neighbor = neighbor.view(batch_size, num_points, k - 1, num_dims)  # (B, N, k-1, C)
    n = torch.sum(A.unsqueeze(-1) * neighbor, dim=2)  # sum_j A_ij x_j, (B, N, C)
    pai = torch.norm(x - n, dim=-1).pow(2)  # Eq.(5), (B, N)
    pais = pai.topk(k=M, dim=-1)[1]  # first M points as the sharp variation component
    paig = (-pai).topk(k=M, dim=-1)[1]  # last M points as the gentle variation component
    pai_base = torch.arange(0, batch_size, device=x.device).view(-1, 1) * num_points
    xs = flat_x[(pais + pai_base).view(-1), :].view(batch_size, M, -1)  # b,M,c
    xg = flat_x[(paig + pai_base).view(-1), :].view(batch_size, M, -1)  # b,M,c
    return xs, xg
class SGCAM(nn.Module):
    """Sharp-Gentle Complementary Attention Module.

    An attention block over 1-D feature sequences (it uses ``nn.Conv1d``):
    queries come from ``x``, keys and values come from ``x_2``. The attended
    features are projected back to ``in_channels`` and added residually to
    ``x``. The output projection is zero-initialised, so at initialisation the
    residual branch contributes nothing and the module returns ``x``.

    Args:
        in_channels: channel count of ``x`` and ``x_2``.
        inter_channels: channel count of the query/key/value embeddings;
            defaults to ``in_channels // 2`` (at least 1).
        bn_layer: if True, the output projection is Conv1d followed by
            BatchNorm1d; otherwise it is a single Conv1d.
    """

    def __init__(self, in_channels, inter_channels=None, bn_layer=True):
        super().__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            # Halve the width by default, but never down to zero channels.
            self.inter_channels = max(in_channels // 2, 1)
        conv_nd = nn.Conv1d
        bn = nn.BatchNorm1d
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels),
            )
            # Zero the BN affine params so the residual branch starts as identity.
            # (nn.init.constant without underscore is deprecated.)
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

    def forward(self, x, x_2):
        """Attend from ``x`` (queries) to ``x_2`` (keys/values).

        Args:
            x: tensor of shape (batch_size, in_channels, N).
            x_2: tensor of shape (batch_size, in_channels, N2); N2 may differ from N.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size = x.size(0)
        g_x = self.g(x_2).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)  # (B, N2, C')
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)  # (B, N, C')
        phi_x = self.phi(x_2).view(batch_size, self.inter_channels, -1)  # (B, C', N2)
        W = torch.matmul(theta_x, phi_x)  # Attention Matrix, (B, N, N2)
        # Normalise by the number of key positions rather than with a softmax.
        W_div_C = W / W.size(-1)
        y = torch.matmul(W_div_C, g_x)  # (B, N, C')
        y = y.permute(0, 2, 1).contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
        return self.W(y) + x