import torch
import torch.nn as nn
import torch.nn.functional as F
from models.modelUtils import ChebConv, Pool, residualBlock
from torchvision.ops import roi_align
import numpy as np
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
import json
import scipy.sparse as sp
def scipy_to_torch_sparse(scp_matrix):
    """Convert a scipy sparse matrix in COO format to a torch sparse COO tensor."""
    values = scp_matrix.data
    indices = np.vstack((scp_matrix.row, scp_matrix.col))

    i = torch.LongTensor(indices)
    v = torch.FloatTensor(values)
    shape = scp_matrix.shape

    sparse_tensor = torch.sparse_coo_tensor(i, v, torch.Size(shape))
    return sparse_tensor
## Adjacency Matrix
def mOrgan(N):
    sub = np.zeros([N, N])
    for i in range(0, N):
        sub[i, i - 1] = 1
        sub[i, (i + 1) % N] = 1
    return sub
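
# Illustrative note (an addition, not part of the original source): mOrgan builds the cyclic
# adjacency of a closed organ contour. For a small ring, e.g. N = 4, mOrgan(4) yields
#   [[0, 1, 0, 1],
#    [1, 0, 1, 0],
#    [0, 1, 0, 1],
#    [1, 0, 1, 0]]
# i.e. each landmark is connected to its two neighbours along the contour.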
## Downsampling Matrix
def mOrganD(N):
    N2 = int(np.ceil(N / 2))
    sub = np.zeros([N2, N])

    for i in range(0, N2):
        if (2 * i + 1) == N:
            sub[i, 2 * i] = 1
        else:
            sub[i, 2 * i] = 1 / 2
            sub[i, 2 * i + 1] = 1 / 2

    return sub
## Upsampling Matrix
def mOrganU(N):
    N2 = int(np.ceil(N / 2))
    sub = np.zeros([N, N2])

    for i in range(0, N):
        if i % 2 == 0:
            sub[i, i // 2] = 1
        else:
            sub[i, i // 2] = 1 / 2
            sub[i, (i // 2 + 1) % N2] = 1 / 2

    return sub
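
# Illustrative note (an addition, not part of the original source): mOrganD(N) averages
# consecutive landmark pairs and mOrganU(N) interpolates them back. For N = 5:
#   mOrganD(5) has shape (3, 5), with rows [0.5, 0.5, 0, 0, 0], [0, 0, 0.5, 0.5, 0], [0, 0, 0, 0, 1]
#   mOrganU(5) has shape (5, 3): even rows copy one coarse node, odd rows average two adjacent ones.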
def genMatrixesLungsHeart():
    RLUNG = 44
    LLUNG = 50
    HEART = 26

    Asub1 = mOrgan(RLUNG)
    Asub2 = mOrgan(LLUNG)
    Asub3 = mOrgan(HEART)

    ADsub1 = mOrgan(int(np.ceil(RLUNG / 2)))
    ADsub2 = mOrgan(int(np.ceil(LLUNG / 2)))
    ADsub3 = mOrgan(int(np.ceil(HEART / 2)))

    Dsub1 = mOrganD(RLUNG)
    Dsub2 = mOrganD(LLUNG)
    Dsub3 = mOrganD(HEART)

    Usub1 = mOrganU(RLUNG)
    Usub2 = mOrganU(LLUNG)
    Usub3 = mOrganU(HEART)

    p1 = RLUNG
    p2 = p1 + LLUNG
    p3 = p2 + HEART

    p1_ = int(np.ceil(RLUNG / 2))
    p2_ = p1_ + int(np.ceil(LLUNG / 2))
    p3_ = p2_ + int(np.ceil(HEART / 2))

    A = np.zeros([p3, p3])
    A[:p1, :p1] = Asub1
    A[p1:p2, p1:p2] = Asub2
    A[p2:p3, p2:p3] = Asub3

    AD = np.zeros([p3_, p3_])
    AD[:p1_, :p1_] = ADsub1
    AD[p1_:p2_, p1_:p2_] = ADsub2
    AD[p2_:p3_, p2_:p3_] = ADsub3

    D = np.zeros([p3_, p3])
    D[:p1_, :p1] = Dsub1
    D[p1_:p2_, p1:p2] = Dsub2
    D[p2_:p3_, p2:p3] = Dsub3

    U = np.zeros([p3, p3_])
    U[:p1, :p1_] = Usub1
    U[p1:p2, p1_:p2_] = Usub2
    U[p2:p3, p2_:p3_] = Usub3

    return A, AD, D, U
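
# Illustrative note (an addition, not part of the original source): with RLUNG = 44, LLUNG = 50
# and HEART = 26 landmarks, genMatrixesLungsHeart() returns block-diagonal matrices of shapes
#   A:  (120, 120)  adjacency at full resolution
#   AD: (60, 60)    adjacency at the downsampled resolution
#   D:  (60, 120)   downsampling operator
#   U:  (120, 60)   upsampling operator
# The three organ contours stay disconnected from one another in every matrix.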
class EncoderConv(nn.Module):
    def __init__(self, latents=64, hw=32):
        super(EncoderConv, self).__init__()

        self.latents = latents
        self.c = 4

        self.size = self.c * np.array([2, 4, 8, 16, 32], dtype=np.intc)

        self.maxpool = nn.MaxPool2d(2)

        self.dconv_down1 = residualBlock(1, self.size[0])
        self.dconv_down2 = residualBlock(self.size[0], self.size[1])
        self.dconv_down3 = residualBlock(self.size[1], self.size[2])
        self.dconv_down4 = residualBlock(self.size[2], self.size[3])
        self.dconv_down5 = residualBlock(self.size[3], self.size[4])
        self.dconv_down6 = residualBlock(self.size[4], self.size[4])

        self.fc_mu = nn.Linear(in_features=self.size[4] * hw * hw, out_features=self.latents)
        self.fc_logvar = nn.Linear(in_features=self.size[4] * hw * hw, out_features=self.latents)

    def forward(self, x):
        x = self.dconv_down1(x)
        x = self.maxpool(x)

        x = self.dconv_down2(x)
        x = self.maxpool(x)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        conv4 = self.dconv_down4(x)
        x = self.maxpool(conv4)

        conv5 = self.dconv_down5(x)
        x = self.maxpool(conv5)

        conv6 = self.dconv_down6(x)

        x = conv6.view(conv6.size(0), -1)  # flatten batch of multi-channel feature maps to a batch of feature vectors

        x_mu = self.fc_mu(x)
        x_logvar = self.fc_logvar(x)

        return x_mu, x_logvar, conv6, conv5
class SkipBlock(nn.Module):
    def __init__(self, in_filters, window):
        super(SkipBlock, self).__init__()

        self.window = window
        self.graphConv_pre = ChebConv(in_filters, 2, 1, bias=False)

    def lookup(self, pos, layer, salida=(1, 1)):
        # `salida` is the ROI-align output size for each sampled window
        B = pos.shape[0]
        N = pos.shape[1]
        F = layer.shape[1]
        h = layer.shape[-1]

        ## Scale from [0,1] to [0, h]
        pos = pos * h

        _x1 = (self.window[0] // 2) * 1.0
        _x2 = (self.window[0] // 2 + 1) * 1.0
        _y1 = (self.window[1] // 2) * 1.0
        _y2 = (self.window[1] // 2 + 1) * 1.0

        boxes = []
        for batch in range(0, B):
            x1 = pos[batch, :, 0].reshape(-1, 1) - _x1
            x2 = pos[batch, :, 0].reshape(-1, 1) + _x2
            y1 = pos[batch, :, 1].reshape(-1, 1) - _y1
            y2 = pos[batch, :, 1].reshape(-1, 1) + _y2

            aux = torch.cat([x1, y1, x2, y2], axis=1)
            boxes.append(aux)

        skip = roi_align(layer, boxes, output_size=salida, aligned=True)
        vista = skip.view([B, N, -1])

        return vista

    def forward(self, x, adj, conv_layer):
        pos = self.graphConv_pre(x, adj)
        skip = self.lookup(pos, conv_layer)

        return torch.cat((x, skip, pos), axis=2), pos
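
# Illustrative note (an addition, not part of the original source): with window = (3, 3) each
# box spans roughly three feature-map pixels around the landmark predicted by graphConv_pre,
# and roi_align with the default (1, 1) output size pools that neighbourhood into a single
# value per channel, which is concatenated to the node features as the image-to-graph skip.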
class Hybrid(nn.Module):
    def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices):
        super(Hybrid, self).__init__()

        self.config = config
        hw = config['inputsize'] // 32
        self.z = config['latents']
        self.encoder = EncoderConv(latents=self.z, hw=hw)
        self.eval_sampling = config['eval_sampling']

        self.downsample_matrices = downsample_matrices
        self.upsample_matrices = upsample_matrices
        self.adjacency_matrices = adjacency_matrices
        self.kld_weight = 1e-5

        n_nodes = config['n_nodes']
        self.filters = config['filters']
        self.K = 6
        self.window = (3, 3)

        # Generate the fully connected layer for the decoder
        outshape = self.filters[-1] * n_nodes[-1]
        self.dec_lin = torch.nn.Linear(self.z, outshape)

        self.normalization2u = torch.nn.InstanceNorm1d(self.filters[1])
        self.normalization3u = torch.nn.InstanceNorm1d(self.filters[2])
        self.normalization4u = torch.nn.InstanceNorm1d(self.filters[3])
        self.normalization5u = torch.nn.InstanceNorm1d(self.filters[4])
        self.normalization6u = torch.nn.InstanceNorm1d(self.filters[5])

        outsize1 = self.encoder.size[4]
        outsize2 = self.encoder.size[4]

        # Store graph convolution layers
        self.graphConv_up6 = ChebConv(self.filters[6], self.filters[5], self.K)
        self.graphConv_up5 = ChebConv(self.filters[5], self.filters[4], self.K)

        self.SC_1 = SkipBlock(self.filters[4], self.window)
        self.graphConv_up4 = ChebConv(self.filters[4] + outsize1 + 2, self.filters[3], self.K)
        self.graphConv_up3 = ChebConv(self.filters[3], self.filters[2], self.K)

        self.SC_2 = SkipBlock(self.filters[2], self.window)
        self.graphConv_up2 = ChebConv(self.filters[2] + outsize2 + 2, self.filters[1], self.K)
        self.graphConv_up1 = ChebConv(self.filters[1], self.filters[0], 1, bias=False)

        self.pool = Pool()
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.normal_(self.dec_lin.weight, 0, 0.1)

    def sampling(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def encode(self, x):
        """Encode the input and return latent representations and skip connections"""
        mu, log_var, conv6, conv5 = self.encoder(x)
        return mu, log_var, conv6, conv5

    def decode(self, z, conv6, conv5):
        """Decode from latent space using skip connections"""
        x = self.dec_lin(z)
        x = F.relu(x)

        x = x.reshape(x.shape[0], -1, self.filters[-1])

        x = self.graphConv_up6(x, self.adjacency_matrices[5]._indices())
        x = self.normalization6u(x)
        x = F.relu(x)

        x = self.graphConv_up5(x, self.adjacency_matrices[4]._indices())
        x = self.normalization5u(x)
        x = F.relu(x)

        x, pos1 = self.SC_1(x, self.adjacency_matrices[3]._indices(), conv6)
        x = self.graphConv_up4(x, self.adjacency_matrices[3]._indices())
        x = self.normalization4u(x)
        x = F.relu(x)

        x = self.pool(x, self.upsample_matrices[0])

        x = self.graphConv_up3(x, self.adjacency_matrices[2]._indices())
        x = self.normalization3u(x)
        x = F.relu(x)

        x, pos2 = self.SC_2(x, self.adjacency_matrices[1]._indices(), conv5)
        x = self.graphConv_up2(x, self.adjacency_matrices[1]._indices())
        x = self.normalization2u(x)
        x = F.relu(x)

        x = self.graphConv_up1(x, self.adjacency_matrices[0]._indices())  # No relu and no bias

        return x, pos1, pos2

    def forward(self, x):
        """Full forward pass (both encoding and decoding)"""
        self.mu, self.log_var, conv6, conv5 = self.encode(x)

        if self.training or self.eval_sampling:
            z = self.sampling(self.mu, self.log_var)
        else:
            z = self.mu

        return self.decode(z, conv6, conv5)
class HybridNoSkip(nn.Module):
    def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices):
        super(HybridNoSkip, self).__init__()

        hw = config['inputsize'] // 32
        self.eval_sampling = config['eval_sampling']
        self.z = config['latents']
        self.encoder = EncoderConv(latents=self.z, hw=hw)

        self.downsample_matrices = downsample_matrices
        self.upsample_matrices = upsample_matrices
        self.adjacency_matrices = adjacency_matrices
        self.kld_weight = 1e-5

        n_nodes = config['n_nodes']
        self.filters = config['filters']
        self.K = 6

        # Generate the fully connected layer for the decoder
        outshape = self.filters[-1] * n_nodes[-1]
        self.dec_lin = torch.nn.Linear(self.z, outshape)

        self.normalization2u = torch.nn.InstanceNorm1d(self.filters[1])
        self.normalization3u = torch.nn.InstanceNorm1d(self.filters[2])
        self.normalization4u = torch.nn.InstanceNorm1d(self.filters[3])
        self.normalization5u = torch.nn.InstanceNorm1d(self.filters[4])
        self.normalization6u = torch.nn.InstanceNorm1d(self.filters[5])

        self.graphConv_up6 = ChebConv(self.filters[6], self.filters[5], self.K)
        self.graphConv_up5 = ChebConv(self.filters[5], self.filters[4], self.K)
        self.graphConv_up4 = ChebConv(self.filters[4], self.filters[3], self.K)
        self.graphConv_up3 = ChebConv(self.filters[3], self.filters[2], self.K)
        self.graphConv_up2 = ChebConv(self.filters[2], self.filters[1], self.K)

        ## Output layer: no bias, no normalization, no relu
        self.graphConv_up1 = ChebConv(self.filters[1], self.filters[0], 1, bias=False)

        self.pool = Pool()
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.normal_(self.dec_lin.weight, 0, 0.1)

    def sampling(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def encode(self, x):
        mu, log_var, conv6, conv5 = self.encoder(x)
        return mu, log_var, conv6, conv5

    def decode(self, z, conv6, conv5):
        # Decode from latent space z to reconstruct x
        x = self.dec_lin(z)
        x = F.relu(x)

        x = x.reshape(x.shape[0], -1, self.filters[-1])

        x = self.graphConv_up6(x, self.adjacency_matrices[5]._indices())
        x = self.normalization6u(x)
        x = F.relu(x)

        x = self.graphConv_up5(x, self.adjacency_matrices[4]._indices())
        x = self.normalization5u(x)
        x = F.relu(x)

        x = self.graphConv_up4(x, self.adjacency_matrices[3]._indices())
        x = self.normalization4u(x)
        x = F.relu(x)

        x = self.pool(x, self.upsample_matrices[0])

        x = self.graphConv_up3(x, self.adjacency_matrices[2]._indices())
        x = self.normalization3u(x)
        x = F.relu(x)

        x = self.graphConv_up2(x, self.adjacency_matrices[1]._indices())
        x = self.normalization2u(x)
        x = F.relu(x)

        x = self.graphConv_up1(x, self.adjacency_matrices[0]._indices())  # No relu and no bias

        return x, None, None

    def forward(self, x):
        # Full forward pass: encode, sample (if training), then decode.
        self.mu, self.log_var, conv6, conv5 = self.encode(x)

        if self.training:
            z = self.sampling(self.mu, self.log_var)
        else:
            z = self.mu

        return self.decode(z, conv6, conv5)
class HybridGNetHF(nn.Module, PyTorchModelHubMixin):
    repo_url = "https://github.com/mcosarinsky/CheXmask-U"
    license = "mit"
    pipeline_tag = "image-segmentation"

    def __init__(self, latents=64, inputsize=1024, K=6, filters=None,
                 skip_features=32, eval_sampling=True, use_skip=True,
                 n_nodes=None, device="cpu", **kwargs):
        super().__init__()
        self.device = device

        # Defaults
        if filters is None:
            filters = [2, 32, 32, 32, 16, 16, 16]

        # Save config
        self.config = {
            'latents': latents,
            'inputsize': inputsize,
            'K': K,
            'filters': filters,
            'skip_features': skip_features,
            'eval_sampling': eval_sampling,
            'use_skip': use_skip
        }
        self.config.update(kwargs)

        self.use_skip = use_skip

        # ---- generate matrices ----
        A, AD, D, U = genMatrixesLungsHeart()
        N1, N2 = A.shape[0], AD.shape[0]
        self.config['n_nodes'] = [N1, N1, N1, N2, N2, N2]

        # ---- convert to sparse tensors and move to device ----
        A_ = [sp.csc_matrix(A).tocoo() for _ in range(3)] + [sp.csc_matrix(AD).tocoo() for _ in range(3)]
        D_ = [sp.csc_matrix(D).tocoo()]
        U_ = [sp.csc_matrix(U).tocoo()]

        self.A_t = [scipy_to_torch_sparse(x).to(self.device) for x in A_]
        self.D_t = [scipy_to_torch_sparse(x).to(self.device) for x in D_]
        self.U_t = [scipy_to_torch_sparse(x).to(self.device) for x in U_]

        # ---- build model ----
        if self.use_skip:
            self.model = Hybrid(self.config, self.D_t, self.U_t, self.A_t)
        else:
            self.model = HybridNoSkip(self.config, self.D_t, self.U_t, self.A_t)

        # move model parameters to device
        self.model.to(self.device)

    def forward(self, x):
        return self.model(x)

    def encode(self, x):
        return self.model.encode(x)

    def decode(self, z, conv6, conv5):
        return self.model.decode(z, conv6, conv5)

    def sampling(self, mu, log_var):
        return self.model.sampling(mu, log_var)
    @classmethod
    def from_pretrained(cls, repo_id, subfolder=None, device="cpu", **kwargs):
        """
        Loads the model directly from the Hugging Face Hub. Does NOT use local paths.
        """
        # Download config from the Hub
        config_file = hf_hub_download(
            repo_id=repo_id,
            filename="config.json",
            subfolder=subfolder
        )
        with open(config_file, "r") as f:
            config = json.load(f)

        # Merge any additional kwargs
        config.update(kwargs)

        # Dynamically compute n_nodes
        A, AD, D, U = genMatrixesLungsHeart()
        N1, N2 = A.shape[0], AD.shape[0]
        config['n_nodes'] = [N1, N1, N1, N2, N2, N2]

        # Instantiate the model on the desired device
        model = cls(device=device, **config)

        # Download weights from the Hub
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename="pytorch_model.bin",
            subfolder=subfolder
        )
        state_dict = torch.load(weights_path, map_location=device)

        # Prefix keys with "model." if the checkpoint was saved from the inner module
        if not next(iter(state_dict.keys())).startswith("model."):
            state_dict = {f"model.{k}": v for k, v in state_dict.items()}

        model.load_state_dict(state_dict)
        return model
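
# Minimal usage sketch (an addition, not part of the original source). The repo id below is a
# placeholder: substitute the Hugging Face repository that actually hosts config.json and
# pytorch_model.bin for this model.
if __name__ == "__main__":
    model = HybridGNetHF.from_pretrained("user/hybridgnet-checkpoint", device="cpu")  # hypothetical repo id
    model.eval()

    # One single-channel chest X-ray, resized to the configured input size (1024 x 1024 by default)
    dummy = torch.rand(1, 1, model.config['inputsize'], model.config['inputsize'])

    with torch.no_grad():
        landmarks, pos1, pos2 = model(dummy)

    # `landmarks` holds the predicted contour coordinates for both lungs and the heart
    print(landmarks.shape)  # expected: torch.Size([1, 120, 2])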