CheXmask-U / models /HybridGNet2IGSC.py
mcosarinsky's picture
update hf
00c20f5
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.modelUtils import ChebConv, Pool, residualBlock
import torchvision.ops.roi_align as roi_align
import numpy as np
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
import json
import scipy.sparse as sp
def scipy_to_torch_sparse(scp_matrix):
    """Convert a SciPy sparse matrix to a torch sparse COO tensor.

    Args:
        scp_matrix: any ``scipy.sparse`` matrix. It is converted to COO first,
            which is a no-op for matrices that are already COO.

    Returns:
        A sparse COO ``torch.Tensor`` with the same shape, ``float32`` values
        and ``int64`` indices. The tensor is NOT coalesced: entries keep the
        order of the COO ``row``/``col`` arrays, so ``._indices()`` returns
        them in that order.
    """
    coo = scp_matrix.tocoo()
    indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.as_tensor(coo.data, dtype=torch.float32)
    # torch.sparse.FloatTensor / torch.LongTensor are deprecated legacy
    # constructors; sparse_coo_tensor is the supported equivalent.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
## Adjacency Matrix
def mOrgan(N):
    """Return the N x N adjacency matrix of a closed contour (cycle) of N nodes.

    Node ``i`` is linked to nodes ``i - 1`` and ``i + 1`` (indices wrap
    modulo N). Entries are 1.0 for an edge and 0.0 otherwise (float64).
    """
    adjacency = np.zeros([N, N])
    rows = np.arange(N)
    # rows - 1 is -1 for node 0, which NumPy resolves to the last column.
    adjacency[rows, rows - 1] = 1
    adjacency[rows, (rows + 1) % N] = 1
    return adjacency
## Downsampling Matrix
def mOrganD(N):
    """Return the (ceil(N/2), N) downsampling matrix for N fine nodes.

    Coarse node ``k`` is the mean of fine nodes ``2k`` and ``2k + 1``. When
    N is odd, the final coarse node has no partner and copies the last fine
    node with weight 1. There is no wrap-around.
    """
    n_coarse = int(np.ceil(N / 2))
    down = np.zeros([n_coarse, N])
    for k in range(n_coarse):
        left = 2 * k
        if left + 1 == N:
            # Odd N: the last fine node has no partner, keep it unchanged.
            down[k, left] = 1
            continue
        down[k, left] = down[k, left + 1] = 1 / 2
    return down
def mOrganU(N):
    """Return the (N, ceil(N/2)) upsampling matrix for a closed contour of N nodes.

    Even fine nodes ``i`` copy coarse node ``i // 2``. Odd fine nodes take
    the mean of coarse node ``i // 2`` and the next coarse node, wrapping to
    coarse node 0 because the contour is closed. Every row sums to 1.
    """
    N2 = int(np.ceil(N / 2))
    sub = np.zeros([N, N2])
    for i in range(N):
        if i % 2 == 0:
            sub[i, i // 2] = 1
        else:
            # Accumulate instead of assigning. When N2 == 1 (N == 2), both
            # neighbours are the same coarse node, and plain assignment would
            # leave the row summing to 0.5 instead of 1.
            sub[i, i // 2] += 1 / 2
            sub[i, (i // 2 + 1) % N2] += 1 / 2
    return sub
def genMatrixesLungsHeart():
    """Build block-diagonal graph operators for the lungs-and-heart contours.

    There are three closed contours, placed in this order: right lung
    (44 nodes), left lung (50) and heart (26). Each operator is assembled
    block-diagonally from the per-organ matrices of ``mOrgan``, ``mOrganD``
    and ``mOrganU``.

    Returns:
        ``(A, AD, D, U)``:

        - ``A``: full-resolution adjacency, shape (120, 120).
        - ``AD``: adjacency of the half-resolution graph, shape (60, 60).
        - ``D``: downsampling matrix, shape (60, 120).
        - ``U``: upsampling matrix, shape (120, 60).
    """
    organ_sizes = [44, 50, 26]  # RLUNG, LLUNG, HEART
    coarse_sizes = [int(np.ceil(n / 2)) for n in organ_sizes]

    n_fine = sum(organ_sizes)
    n_coarse = sum(coarse_sizes)

    A = np.zeros([n_fine, n_fine])
    AD = np.zeros([n_coarse, n_coarse])
    D = np.zeros([n_coarse, n_fine])
    U = np.zeros([n_fine, n_coarse])

    fine_start = 0
    coarse_start = 0
    for n, n_c in zip(organ_sizes, coarse_sizes):
        fine = slice(fine_start, fine_start + n)
        coarse = slice(coarse_start, coarse_start + n_c)
        A[fine, fine] = mOrgan(n)
        AD[coarse, coarse] = mOrgan(n_c)
        D[coarse, fine] = mOrganD(n)
        U[fine, coarse] = mOrganU(n)
        fine_start += n
        coarse_start += n_c

    return A, AD, D, U
class EncoderConv(nn.Module):
    """Convolutional encoder producing mean and log-variance latent vectors.

    There are six residual stages. A 2x max-pool follows each of the first
    five, so for an input of side ``32 * hw`` the deepest map is ``hw x hw``.
    That map is flattened and projected by two linear heads.

    Args:
        latents: size of the latent vectors returned by ``forward``.
        hw: side of the deepest feature map. The linear heads expect
            ``self.size[4] * hw * hw`` input features.
    """

    def __init__(self, latents = 64, hw = 32):
        super(EncoderConv, self).__init__()
        self.latents = latents
        self.c = 4
        # Channel width of each stage: c * [2, 4, 8, 16, 32] (stage 6 reuses the last).
        self.size = self.c * np.array([2, 4, 8, 16, 32], dtype=np.intc)
        self.maxpool = nn.MaxPool2d(2)

        widths = self.size
        # Single-channel input into the first stage.
        self.dconv_down1 = residualBlock(1, widths[0])
        self.dconv_down2 = residualBlock(widths[0], widths[1])
        self.dconv_down3 = residualBlock(widths[1], widths[2])
        self.dconv_down4 = residualBlock(widths[2], widths[3])
        self.dconv_down5 = residualBlock(widths[3], widths[4])
        self.dconv_down6 = residualBlock(widths[4], widths[4])

        flat_features = widths[4] * hw * hw
        self.fc_mu = nn.Linear(in_features=flat_features, out_features=self.latents)
        self.fc_logvar = nn.Linear(in_features=flat_features, out_features=self.latents)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(mu, logvar, conv6, conv5)``. ``conv6`` is the output of the last
            residual stage and ``conv5`` the output of stage 5 before its
            pooling. Both are returned so callers can use them as skip
            connections.
        """
        # Stages 1-4: residual block, then 2x down-sampling. Their activations
        # are not needed outside this method.
        for stage in (self.dconv_down1, self.dconv_down2,
                      self.dconv_down3, self.dconv_down4):
            x = self.maxpool(stage(x))
        conv5 = self.dconv_down5(x)
        conv6 = self.dconv_down6(self.maxpool(conv5))
        # Flatten everything but the first (batch) dimension for the linear heads.
        flat = conv6.view(conv6.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat), conv6, conv5
class SkipBlock(nn.Module):
    """Image-to-graph skip connection sampled at predicted node positions.

    ``graphConv_pre`` (``ChebConv(in_filters, 2, 1, bias=False)``) maps each
    node's features to two coordinates. The feature map ``conv_layer`` is
    sampled around those coordinates with ``roi_align``. The sampled
    features and the coordinates are concatenated to the node features.

    Args:
        in_filters: number of input features per node.
        window: ``(wx, wy)`` extent of the sampling box, in feature-map units.
    """

    def __init__(self, in_filters, window):
        super(SkipBlock, self).__init__()
        self.window = window
        self.graphConv_pre = ChebConv(in_filters, 2, 1, bias = False)

    def lookup(self, pos, layer, salida = (1,1)):
        """Sample ``layer`` in a box around every node position.

        Args:
            pos: node coordinates indexed as ``pos[b, n, 0]`` (x) and
                ``pos[b, n, 1]`` (y). Presumably normalized to [0, 1]
                (TODO confirm).
            layer: feature map of shape ``(B, F, H, W)``. Both coordinates are
                scaled by its last dimension, so a square map is assumed.
            salida: ``output_size`` forwarded to ``roi_align``. The name is
                Spanish for "output".

        Returns:
            Tensor of shape ``(B, N, F * salida[0] * salida[1])``.
        """
        B, N = pos.shape[0], pos.shape[1]
        h = layer.shape[-1]
        # Scale from [0, 1] to [0, h] feature-map coordinates.
        pos = pos * h

        # Box extent around each position: [p - w//2, p + w//2 + 1].
        # NOTE(review): this box is asymmetric about p. Combined with
        # aligned=True (which shifts boxes by -0.5 px) it is presumably meant
        # to be centred on p. Trained weights depend on it, so keep as-is.
        x_lo = (self.window[0] // 2) * 1.0
        x_hi = (self.window[0] // 2 + 1) * 1.0
        y_lo = (self.window[1] // 2) * 1.0
        y_hi = (self.window[1] // 2 + 1) * 1.0

        x = pos[:, :, 0]
        y = pos[:, :, 1]
        # (B, N, 4) boxes in (x1, y1, x2, y2) order. roi_align takes a list
        # with one (N, 4) box tensor per image.
        boxes = torch.stack([x - x_lo, y - y_lo, x + x_hi, y + y_hi], dim=-1)
        skip = roi_align(layer, list(boxes.unbind(0)), output_size = salida, aligned=True)
        return skip.view([B, N, -1])

    def forward(self, x, adj, conv_layer):
        """Predict node positions, sample ``conv_layer`` there, and concatenate.

        Returns:
            ``(features, pos)``. ``features`` is ``x``, the sampled skip
            features and ``pos`` concatenated along dim 2. ``pos`` holds the
            predicted coordinates before scaling.
        """
        pos = self.graphConv_pre(x, adj)
        skip = self.lookup(pos, conv_layer)
        return torch.cat((x, skip, pos), axis = 2), pos
class Hybrid(nn.Module):
    """HybridGNet variant with two image-to-graph skip connections.

    ``EncoderConv`` maps the image to the mean / log-variance of a latent
    vector. The graph decoder (``ChebConv`` layers) maps a latent vector to
    per-node outputs. Two ``SkipBlock`` modules (``SC_1``, ``SC_2``) predict
    intermediate node positions and concatenate encoder features (``conv6``
    and ``conv5``) sampled at those positions.

    Args:
        config: dict read for ``'inputsize'``, ``'latents'``,
            ``'eval_sampling'``, ``'n_nodes'`` and ``'filters'``. It is also
            stored as ``self.config``.
        downsample_matrices: stored but not used by this class.
        upsample_matrices: list whose element 0 is passed to ``self.pool`` in
            ``decode``.
        adjacency_matrices: list of sparse tensors. ``decode`` passes
            ``adjacency_matrices[k]._indices()`` (k = 0..5) to the graph
            convolutions.
    """
    def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices):
        super(Hybrid, self).__init__()
        self.config = config
        # Side of the deepest encoder feature map (input side / 32).
        hw = config['inputsize'] // 32
        self.z = config['latents']
        self.encoder = EncoderConv(latents = self.z, hw = hw)
        # If true, forward() samples the latent in eval mode too (not only in training).
        self.eval_sampling = config['eval_sampling']
        self.downsample_matrices = downsample_matrices
        self.upsample_matrices = upsample_matrices
        self.adjacency_matrices = adjacency_matrices
        # Stored, not used here; presumably the KL weight for a loss computed by the caller.
        self.kld_weight = 1e-5
        n_nodes = config['n_nodes']
        self.filters = config['filters']
        # Order passed to every ChebConv except the last. NOTE(review): hard-coded; config['K'] is not read.
        self.K = 6
        # (wx, wy) box size, in feature-map units, used by the SkipBlocks.
        self.window = (3,3)
        # Generate the fully connected layer for the decoder:
        # latent -> filters[-1] features for each of n_nodes[-1] nodes.
        outshape = self.filters[-1] * n_nodes[-1]
        self.dec_lin = torch.nn.Linear(self.z, outshape)
        # InstanceNorm1d has no affine params or running stats here, so num_features
        # does not affect the result. The decoder's x is (B, nodes, features), so each
        # node's feature vector is normalized. NOTE(review): PyTorch may warn that
        # dim 1 (nodes) != num_features.
        self.normalization2u = torch.nn.InstanceNorm1d(self.filters[1])
        self.normalization3u = torch.nn.InstanceNorm1d(self.filters[2])
        self.normalization4u = torch.nn.InstanceNorm1d(self.filters[3])
        self.normalization5u = torch.nn.InstanceNorm1d(self.filters[4])
        self.normalization6u = torch.nn.InstanceNorm1d(self.filters[5])
        # Channel counts of conv6 / conv5 (both encoder stages 5 and 6 output size[4]),
        # i.e. the feature sizes the SkipBlocks add to each node.
        outsize1 = self.encoder.size[4]
        outsize2 = self.encoder.size[4]
        # Store graph convolution layers
        self.graphConv_up6 = ChebConv(self.filters[6], self.filters[5], self.K)
        self.graphConv_up5 = ChebConv(self.filters[5], self.filters[4], self.K)
        self.SC_1 = SkipBlock(self.filters[4], self.window)
        # Input = node features + sampled conv6 features + 2 predicted coordinates.
        self.graphConv_up4 = ChebConv(self.filters[4] + outsize1 + 2, self.filters[3], self.K)
        self.graphConv_up3 = ChebConv(self.filters[3], self.filters[2], self.K)
        self.SC_2 = SkipBlock(self.filters[2], self.window)
        # Input = node features + sampled conv5 features + 2 predicted coordinates.
        self.graphConv_up2 = ChebConv(self.filters[2] + outsize2 + 2, self.filters[1], self.K)
        # Output layer: order 1, no bias, followed by no normalization or ReLU.
        self.graphConv_up1 = ChebConv(self.filters[1], self.filters[0], 1, bias = False)
        self.pool = Pool()
        self.reset_parameters()
    def reset_parameters(self):
        """Initialize the decoder's linear weights from N(0, 0.1^2)."""
        torch.nn.init.normal_(self.dec_lin.weight, 0, 0.1)
    def sampling(self, mu, log_var):
        """Reparameterized sample: ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)
    def encode(self, x):
        """Encode the input and return latent representations and skip connections.

        Returns:
            ``(mu, log_var, conv6, conv5)`` as produced by ``self.encoder``.
        """
        mu, log_var, conv6, conv5 = self.encoder(x)
        return mu, log_var, conv6, conv5
    def decode(self, z, conv6, conv5):
        """Decode from latent space using skip connections.

        Returns:
            ``(x, pos1, pos2)``: the output of ``graphConv_up1`` and the node
            positions predicted by ``SC_1`` and ``SC_2``.
        """
        x = self.dec_lin(z)
        x = F.relu(x)
        # (B, outshape) -> (B, nodes, filters[-1])
        x = x.reshape(x.shape[0], -1, self.filters[-1])
        x = self.graphConv_up6(x, self.adjacency_matrices[5]._indices())
        x = self.normalization6u(x)
        x = F.relu(x)
        x = self.graphConv_up5(x, self.adjacency_matrices[4]._indices())
        x = self.normalization5u(x)
        x = F.relu(x)
        # Skip connection 1: sample conv6 at positions predicted from x.
        x, pos1 = self.SC_1(x, self.adjacency_matrices[3]._indices(), conv6)
        x = self.graphConv_up4(x, self.adjacency_matrices[3]._indices())
        x = self.normalization4u(x)
        x = F.relu(x)
        # Presumably maps node features onto the larger graph via the upsampling
        # matrix (Pool is defined in models.modelUtils) - TODO confirm.
        x = self.pool(x, self.upsample_matrices[0])
        x = self.graphConv_up3(x, self.adjacency_matrices[2]._indices())
        x = self.normalization3u(x)
        x = F.relu(x)
        # Skip connection 2: sample conv5 at positions predicted from x.
        x, pos2 = self.SC_2(x, self.adjacency_matrices[1]._indices(), conv5)
        x = self.graphConv_up2(x, self.adjacency_matrices[1]._indices())
        x = self.normalization2u(x)
        x = F.relu(x)
        x = self.graphConv_up1(x, self.adjacency_matrices[0]._indices()) # No relu and no bias
        return x, pos1, pos2
    def forward(self, x):
        """Full forward pass (both encoding and decoding).

        The latent is sampled in training mode, or in eval mode when
        ``self.eval_sampling`` is true; otherwise ``mu`` is used directly.
        ``mu`` and ``log_var`` are stored on ``self`` as a side effect.
        """
        self.mu, self.log_var, conv6, conv5 = self.encode(x)
        if self.training or self.eval_sampling:
            z = self.sampling(self.mu, self.log_var)
        else:
            z = self.mu
        return self.decode(z, conv6, conv5)
class HybridNoSkip(nn.Module):
    """HybridGNet variant without image-to-graph skip connections.

    ``EncoderConv`` maps the image to the mean / log-variance of a latent
    vector. A stack of ``ChebConv`` graph convolutions decodes the latent
    vector alone. ``decode`` accepts ``conv6``/``conv5`` only for interface
    compatibility and ignores them.

    Args:
        config: dict read for ``'inputsize'``, ``'eval_sampling'``,
            ``'latents'``, ``'n_nodes'`` and ``'filters'``.
        downsample_matrices: stored but not used by this class.
        upsample_matrices: list whose element 0 is passed to ``self.pool``.
        adjacency_matrices: list of sparse tensors whose ``_indices()`` feed
            the graph convolutions.
    """
    def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices):
        super(HybridNoSkip, self).__init__()
        # Side of the deepest encoder feature map (input side / 32).
        hw = config['inputsize'] // 32
        # If true, forward() samples the latent in eval mode too (not only in training).
        self.eval_sampling = config['eval_sampling']
        self.z = config['latents']
        self.encoder = EncoderConv(latents = self.z, hw = hw)
        self.downsample_matrices = downsample_matrices
        self.upsample_matrices = upsample_matrices
        self.adjacency_matrices = adjacency_matrices
        # Stored, not used here; presumably the KL weight for a loss computed by the caller.
        self.kld_weight = 1e-5
        n_nodes = config['n_nodes']
        self.filters = config['filters']
        self.K = 6  # order passed to every ChebConv except the last (hard-coded)
        # Build the decoder's fully connected layer:
        # latent -> filters[-1] features for each of n_nodes[-1] nodes.
        outshape = self.filters[-1] * n_nodes[-1]
        self.dec_lin = torch.nn.Linear(self.z, outshape)
        self.normalization2u = torch.nn.InstanceNorm1d(self.filters[1])
        self.normalization3u = torch.nn.InstanceNorm1d(self.filters[2])
        self.normalization4u = torch.nn.InstanceNorm1d(self.filters[3])
        self.normalization5u = torch.nn.InstanceNorm1d(self.filters[4])
        self.normalization6u = torch.nn.InstanceNorm1d(self.filters[5])
        self.graphConv_up6 = ChebConv(self.filters[6], self.filters[5], self.K)
        self.graphConv_up5 = ChebConv(self.filters[5], self.filters[4], self.K)
        self.graphConv_up4 = ChebConv(self.filters[4], self.filters[3], self.K)
        self.graphConv_up3 = ChebConv(self.filters[3], self.filters[2], self.K)
        self.graphConv_up2 = ChebConv(self.filters[2], self.filters[1], self.K)
        ## Output layer: no bias, no normalization and no ReLU
        self.graphConv_up1 = ChebConv(self.filters[1], self.filters[0], 1, bias = False)
        self.pool = Pool()
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the decoder's linear weights from N(0, 0.1^2)."""
        torch.nn.init.normal_(self.dec_lin.weight, 0, 0.1)

    def sampling(self, mu, log_var):
        """Reparameterized sample: ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def encode(self, x):
        """Return ``(mu, log_var, conv6, conv5)`` from ``self.encoder``."""
        mu, log_var, conv6, conv5 = self.encoder(x)
        return mu, log_var, conv6, conv5

    def decode(self, z, conv6, conv5):
        """Decode the latent ``z``. ``conv6``/``conv5`` are ignored.

        Returns:
            ``(x, None, None)``. This mirrors the 3-tuple ``(x, pos1, pos2)``
            of the skip-connection variant, with no positions to report.
        """
        x = self.dec_lin(z)
        x = F.relu(x)
        # (B, outshape) -> (B, nodes, filters[-1])
        x = x.reshape(x.shape[0], -1, self.filters[-1])
        x = self.graphConv_up6(x, self.adjacency_matrices[5]._indices())
        x = self.normalization6u(x)
        x = F.relu(x)
        x = self.graphConv_up5(x, self.adjacency_matrices[4]._indices())
        x = self.normalization5u(x)
        x = F.relu(x)
        x = self.graphConv_up4(x, self.adjacency_matrices[3]._indices())
        x = self.normalization4u(x)
        x = F.relu(x)
        x = self.pool(x, self.upsample_matrices[0])
        x = self.graphConv_up3(x, self.adjacency_matrices[2]._indices())
        x = self.normalization3u(x)
        x = F.relu(x)
        x = self.graphConv_up2(x, self.adjacency_matrices[1]._indices())
        x = self.normalization2u(x)
        x = F.relu(x)
        x = self.graphConv_up1(x, self.adjacency_matrices[0]._indices()) # No relu and no bias
        return x, None, None

    def forward(self, x):
        """Encode, pick a latent, then decode.

        The latent is sampled in training mode, or in eval mode when
        ``self.eval_sampling`` is true (the same rule as the skip-connection
        variant). Previously the flag was stored but ignored. Otherwise ``mu``
        is used. ``mu`` and ``log_var`` are stored on ``self`` as a side effect.
        """
        self.mu, self.log_var, conv6, conv5 = self.encode(x)
        if self.training or self.eval_sampling:
            z = self.sampling(self.mu, self.log_var)
        else:
            z = self.mu
        return self.decode(z, conv6, conv5)
class HybridGNetHF(nn.Module, PyTorchModelHubMixin):
    """Hugging Face Hub wrapper around the ``Hybrid`` / ``HybridNoSkip`` models.

    Builds the lungs-and-heart graph matrices, converts them to torch sparse
    tensors on ``device`` and instantiates the inner model as ``self.model``.
    ``forward``, ``encode``, ``decode`` and ``sampling`` delegate to it.

    NOTE(review): the sparse graph matrices live in plain Python lists
    (``self.A_t``, ``self.D_t``, ``self.U_t``), not registered buffers, so
    ``Module.to()`` does not move them. Choose ``device`` at construction.
    """
    repo_url = "https://github.com/mcosarinsky/CheXmask-U"
    license = "mit"
    pipeline_tag = "image-segmentation"

    def __init__(self, latents=64, inputsize=1024, K=6, filters=None,
                 skip_features=32, eval_sampling=True, use_skip=True,
                 n_nodes=None, device="cpu", **kwargs):
        """Build the graph matrices and the inner model.

        Args:
            latents: latent vector size.
            inputsize: input image side; the encoder uses ``inputsize // 32``.
            K: stored in ``self.config`` only. NOTE(review): the inner models
                hard-code their graph-convolution order to 6 and do not read it.
            filters: per-layer node feature sizes (7 entries); defaults to
                ``[2, 32, 32, 32, 16, 16, 16]``.
            skip_features: stored in ``self.config`` only. The inner models do
                not read this key.
            eval_sampling: passed to the inner model as
                ``config['eval_sampling']``.
            use_skip: build ``Hybrid`` if true, otherwise ``HybridNoSkip``.
            n_nodes: ignored; node counts are derived from the generated
                matrices.
            device: device for the graph matrices and model parameters.
            **kwargs: extra entries merged into ``self.config``.
        """
        super().__init__()
        self.device = device
        # Defaults
        if filters is None:
            filters = [2, 32, 32, 32, 16, 16, 16]
        # Save config
        self.config = {
            'latents': latents,
            'inputsize': inputsize,
            'K': K,
            'filters': filters,
            'skip_features': skip_features,
            'eval_sampling': eval_sampling,
            'use_skip': use_skip
        }
        self.config.update(kwargs)
        self.use_skip = use_skip
        # ---- generate matrices ----
        A, AD, D, U = genMatrixesLungsHeart()
        N1, N2 = A.shape[0], AD.shape[0]
        self.config['n_nodes'] = [N1, N1, N1, N2, N2, N2]
        # ---- convert to sparse tensors and move to device ----
        # Adjacency list: entries 0-2 are copies of A, entries 3-5 copies of AD.
        A_ = [sp.csc_matrix(A).tocoo() for _ in range(3)] + [sp.csc_matrix(AD).tocoo() for _ in range(3)]
        D_ = [sp.csc_matrix(D).tocoo()]
        U_ = [sp.csc_matrix(U).tocoo()]
        self.A_t = [scipy_to_torch_sparse(x).to(self.device) for x in A_]
        self.D_t = [scipy_to_torch_sparse(x).to(self.device) for x in D_]
        self.U_t = [scipy_to_torch_sparse(x).to(self.device) for x in U_]
        # ---- build model ----
        if self.use_skip:
            self.model = Hybrid(self.config, self.D_t, self.U_t, self.A_t)
        else:
            self.model = HybridNoSkip(self.config, self.D_t, self.U_t, self.A_t)
        # move model parameters to device
        self.model.to(self.device)

    def forward(self, x):
        """Run the inner model's full forward pass."""
        return self.model(x)

    def encode(self, x):
        """Delegate to the inner model's ``encode``."""
        return self.model.encode(x)

    def decode(self, z, conv6, conv5):
        """Delegate to the inner model's ``decode``."""
        return self.model.decode(z, conv6, conv5)

    def sampling(self, mu, log_var):
        """Delegate to the inner model's ``sampling`` (reparameterization)."""
        return self.model.sampling(mu, log_var)

    @classmethod
    def from_pretrained(cls, repo_id, subfolder=None, device="cpu", **kwargs):
        """Load config and weights from the Hugging Face Hub. Does NOT use local paths.

        Args:
            repo_id: Hub repository id.
            subfolder: optional subfolder inside the repository.
            device: device to build the model on; also ``map_location`` for
                the weights.
            **kwargs: overrides for entries of the downloaded ``config.json``.

        Returns:
            The model with weights loaded, in eval mode.
        """
        # Download config from Hub
        config_file = hf_hub_download(
            repo_id=repo_id,
            filename="config.json",
            subfolder=subfolder
        )
        with open(config_file, "r") as f:
            config = json.load(f)
        # Merge any additional kwargs
        config.update(kwargs)
        # n_nodes is not computed here: __init__ ignores the argument and derives
        # the node counts from genMatrixesLungsHeart() itself.
        model = cls(device=device, **config)
        # Download weights from Hub
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename="pytorch_model.bin",
            subfolder=subfolder
        )
        # torch.load unpickles the file. weights_only=True restricts it to
        # tensors/containers, so a tampered Hub file cannot execute code.
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        # Checkpoints saved from the inner model lack this wrapper's "model." prefix.
        # The emptiness guard avoids StopIteration; load_state_dict then reports
        # the missing keys.
        if state_dict and not next(iter(state_dict)).startswith("model."):
            state_dict = {f"model.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        # Return in inference mode, as PyTorchModelHubMixin's own loader does.
        model.eval()
        return model