mcosarinsky committed
Commit b698ace · 1 Parent(s): edb6fcc
models/HybridGNet2IGSC.py CHANGED
@@ -4,6 +4,110 @@ import torch.nn.functional as F
  from models.modelUtils import ChebConv, Pool, residualBlock
  import torchvision.ops.roi_align as roi_align
  import numpy as np
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+ import json
+ import scipy.sparse as sp
+
+
+ def scipy_to_torch_sparse(scp_matrix):
+     values = scp_matrix.data
+     indices = np.vstack((scp_matrix.row, scp_matrix.col))
+     i = torch.LongTensor(indices)
+     v = torch.FloatTensor(values)
+     shape = scp_matrix.shape
+
+     sparse_tensor = torch.sparse.FloatTensor(i, v, torch.Size(shape))
+     return sparse_tensor
+
+ ## Adjacency Matrix
+ def mOrgan(N):
+     sub = np.zeros([N, N])
+     for i in range(0, N):
+         sub[i, i-1] = 1
+         sub[i, (i+1)%N] = 1
+     return sub
+
+ ## Downsampling Matrix
+ def mOrganD(N):
+     N2 = int(np.ceil(N/2))
+     sub = np.zeros([N2, N])
+
+     for i in range(0, N2):
+         if (2*i+1) == N:
+             sub[i, 2*i] = 1
+         else:
+             sub[i, 2*i] = 1/2
+             sub[i, 2*i+1] = 1/2
+
+     return sub
+
+ def mOrganU(N):
+     N2 = int(np.ceil(N/2))
+     sub = np.zeros([N, N2])
+
+     for i in range(0, N):
+         if i % 2 == 0:
+             sub[i, i//2] = 1
+         else:
+             sub[i, i//2] = 1/2
+             sub[i, (i//2 + 1) % N2] = 1/2
+
+     return sub
+
+ def genMatrixesLungsHeart():
+     RLUNG = 44
+     LLUNG = 50
+     HEART = 26
+
+     Asub1 = mOrgan(RLUNG)
+     Asub2 = mOrgan(LLUNG)
+     Asub3 = mOrgan(HEART)
+
+     ADsub1 = mOrgan(int(np.ceil(RLUNG / 2)))
+     ADsub2 = mOrgan(int(np.ceil(LLUNG / 2)))
+     ADsub3 = mOrgan(int(np.ceil(HEART / 2)))
+
+     Dsub1 = mOrganD(RLUNG)
+     Dsub2 = mOrganD(LLUNG)
+     Dsub3 = mOrganD(HEART)
+
+     Usub1 = mOrganU(RLUNG)
+     Usub2 = mOrganU(LLUNG)
+     Usub3 = mOrganU(HEART)
+
+     p1 = RLUNG
+     p2 = p1 + LLUNG
+     p3 = p2 + HEART
+
+     p1_ = int(np.ceil(RLUNG / 2))
+     p2_ = p1_ + int(np.ceil(LLUNG / 2))
+     p3_ = p2_ + int(np.ceil(HEART / 2))
+
+     A = np.zeros([p3, p3])
+
+     A[:p1, :p1] = Asub1
+     A[p1:p2, p1:p2] = Asub2
+     A[p2:p3, p2:p3] = Asub3
+
+     AD = np.zeros([p3_, p3_])
+
+     AD[:p1_, :p1_] = ADsub1
+     AD[p1_:p2_, p1_:p2_] = ADsub2
+     AD[p2_:p3_, p2_:p3_] = ADsub3
+
+     D = np.zeros([p3_, p3])
+
+     D[:p1_, :p1] = Dsub1
+     D[p1_:p2_, p1:p2] = Dsub2
+     D[p2_:p3_, p2:p3] = Dsub3
+
+     U = np.zeros([p3, p3_])
+
+     U[:p1, :p1_] = Usub1
+     U[p1:p2, p1_:p2_] = Usub2
+     U[p2:p3, p2_:p3_] = Usub3
+
+     return A, AD, D, U
 
  class EncoderConv(nn.Module):
      def __init__(self, latents = 64, hw = 32):
@@ -205,4 +309,198 @@ class Hybrid(nn.Module):
          else:
              z = self.mu
 
-         return self.decode(z, conv6, conv5)
+         return self.decode(z, conv6, conv5)
+
+
+
+ class HybridNoSkip(nn.Module):
+     def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices):
+         super(HybridNoSkip, self).__init__()
+
+         hw = config['inputsize'] // 32
+         self.eval_sampling = config['eval_sampling']
+         self.z = config['latents']
+         self.encoder = EncoderConv(latents = self.z, hw = hw)
+
+         self.downsample_matrices = downsample_matrices
+         self.upsample_matrices = upsample_matrices
+         self.adjacency_matrices = adjacency_matrices
+         self.kld_weight = 1e-5
+
+         n_nodes = config['n_nodes']
+         self.filters = config['filters']
+         self.K = 6
+
+         # Build the decoder's fully connected layer
+         outshape = self.filters[-1] * n_nodes[-1]
+         self.dec_lin = torch.nn.Linear(self.z, outshape)
+
+         self.normalization2u = torch.nn.InstanceNorm1d(self.filters[1])
+         self.normalization3u = torch.nn.InstanceNorm1d(self.filters[2])
+         self.normalization4u = torch.nn.InstanceNorm1d(self.filters[3])
+         self.normalization5u = torch.nn.InstanceNorm1d(self.filters[4])
+         self.normalization6u = torch.nn.InstanceNorm1d(self.filters[5])
+
+         self.graphConv_up6 = ChebConv(self.filters[6], self.filters[5], self.K)
+         self.graphConv_up5 = ChebConv(self.filters[5], self.filters[4], self.K)
+         self.graphConv_up4 = ChebConv(self.filters[4], self.filters[3], self.K)
+         self.graphConv_up3 = ChebConv(self.filters[3], self.filters[2], self.K)
+         self.graphConv_up2 = ChebConv(self.filters[2], self.filters[1], self.K)
+
+         ## Out layer: no bias, normalization or ReLU
+         self.graphConv_up1 = ChebConv(self.filters[1], self.filters[0], 1, bias = False)
+
+         self.pool = Pool()
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.normal_(self.dec_lin.weight, 0, 0.1)
+
+     def sampling(self, mu, log_var):
+         std = torch.exp(0.5*log_var)
+         eps = torch.randn_like(std)
+         return eps.mul(std).add_(mu)
+
+     def encode(self, x):
+         mu, log_var, conv6, conv5 = self.encoder(x)
+         return mu, log_var, conv6, conv5
+
+     def decode(self, z, conv6, conv5):
+         # Decode from latent space z to reconstruct x
+         x = self.dec_lin(z)
+         x = F.relu(x)
+         x = x.reshape(x.shape[0], -1, self.filters[-1])
+
+         x = self.graphConv_up6(x, self.adjacency_matrices[5]._indices())
+         x = self.normalization6u(x)
+         x = F.relu(x)
+
+         x = self.graphConv_up5(x, self.adjacency_matrices[4]._indices())
+         x = self.normalization5u(x)
+         x = F.relu(x)
+
+         x = self.graphConv_up4(x, self.adjacency_matrices[3]._indices())
+         x = self.normalization4u(x)
+         x = F.relu(x)
+
+         x = self.pool(x, self.upsample_matrices[0])
+
+         x = self.graphConv_up3(x, self.adjacency_matrices[2]._indices())
+         x = self.normalization3u(x)
+         x = F.relu(x)
+
+         x = self.graphConv_up2(x, self.adjacency_matrices[1]._indices())
+         x = self.normalization2u(x)
+         x = F.relu(x)
+
+         x = self.graphConv_up1(x, self.adjacency_matrices[0]._indices())  # No ReLU and no bias
+         return x, None, None
+
+     def forward(self, x):
+         # Full forward pass: encode, sample (if training), then decode.
+         self.mu, self.log_var, conv6, conv5 = self.encode(x)
+
+         if self.training:
+             z = self.sampling(self.mu, self.log_var)
+         else:
+             z = self.mu
+
+         return self.decode(z, conv6, conv5)
+
+
+ class HybridGNetHF(nn.Module, PyTorchModelHubMixin):
+     repo_url = "https://github.com/mcosarinsky/CheXmask-U"
+     license = "mit"
+     pipeline_tag = "image-segmentation"
+
+     def __init__(self, latents=64, inputsize=1024, K=6, filters=None,
+                  skip_features=32, eval_sampling=True, use_skip=True,
+                  n_nodes=None, device="cpu", **kwargs):
+         super().__init__()
+
+         self.device = device
+
+         # Defaults
+         if filters is None:
+             filters = [2, 32, 32, 32, 16, 16, 16]
+
+         # Save config
+         self.config = {
+             'latents': latents,
+             'inputsize': inputsize,
+             'K': K,
+             'filters': filters,
+             'skip_features': skip_features,
+             'eval_sampling': eval_sampling,
+             'use_skip': use_skip
+         }
+         self.config.update(kwargs)
+         self.use_skip = use_skip
+
+         # ---- generate matrices ----
+         A, AD, D, U = genMatrixesLungsHeart()
+         N1, N2 = A.shape[0], AD.shape[0]
+         self.config['n_nodes'] = [N1, N1, N1, N2, N2, N2]
+
+         # ---- convert to sparse tensors and move to device ----
+         A_ = [sp.csc_matrix(A).tocoo() for _ in range(3)] + [sp.csc_matrix(AD).tocoo() for _ in range(3)]
+         D_ = [sp.csc_matrix(D).tocoo()]
+         U_ = [sp.csc_matrix(U).tocoo()]
+
+         self.A_t = [scipy_to_torch_sparse(x).to(self.device) for x in A_]
+         self.D_t = [scipy_to_torch_sparse(x).to(self.device) for x in D_]
+         self.U_t = [scipy_to_torch_sparse(x).to(self.device) for x in U_]
+
+         # ---- build model ----
+         if self.use_skip:
+             self.model = Hybrid(self.config, self.D_t, self.U_t, self.A_t)
+         else:
+             self.model = HybridNoSkip(self.config, self.D_t, self.U_t, self.A_t)
+
+         # move model parameters to device
+         self.model.to(self.device)
+
+     def forward(self, x):
+         return self.model(x)
+
+     # -----------------------------
+     # Dynamic from_pretrained from Hugging Face Hub ONLY
+     # -----------------------------
+     @classmethod
+     def from_pretrained(cls, repo_id, subfolder=None, device="cpu", **kwargs):
+         """
+         Loads model directly from Hugging Face Hub. Does NOT use local paths.
+         """
+         # Download config from Hub
+         config_file = hf_hub_download(
+             repo_id=repo_id,
+             filename="config.json",
+             subfolder=subfolder
+         )
+         with open(config_file, "r") as f:
+             config = json.load(f)
+
+         # Merge any additional kwargs
+         config.update(kwargs)
+
+         # Dynamically compute n_nodes
+         A, AD, D, U = genMatrixesLungsHeart()
+         N1, N2 = A.shape[0], AD.shape[0]
+         config['n_nodes'] = [N1, N1, N1, N2, N2, N2]
+
+         # Instantiate model on desired device
+         model = cls(device=device, **config)
+
+         # Download weights from Hub
+         weights_path = hf_hub_download(
+             repo_id=repo_id,
+             filename="pytorch_model.bin",
+             subfolder=subfolder
+         )
+         state_dict = torch.load(weights_path, map_location=device)
+         if not next(iter(state_dict.keys())).startswith("model."):
+             state_dict = {f"model.{k}": v for k, v in state_dict.items()}
+         model.load_state_dict(state_dict)
+
+         return model
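
For context, a minimal usage sketch of the HybridGNetHF wrapper added above (not part of the commit). It assumes the Hub checkpoint referenced later in utils/segmentation.py (repo mcosarinsky/CheXmask-U, subfolder v1_skip) and a single-channel 1024x1024 input, matching the default inputsize=1024.

# Usage sketch (not part of the commit). Assumes the "mcosarinsky/CheXmask-U" repo
# with subfolder "v1_skip" and a 1-channel 1024x1024 input image.
import torch
from models.HybridGNet2IGSC import HybridGNetHF

model = HybridGNetHF.from_pretrained(
    repo_id="mcosarinsky/CheXmask-U",
    subfolder="v1_skip",
    device="cpu",
)
model.eval()

x = torch.rand(1, 1, 1024, 1024)   # dummy image sized to inputsize=1024
with torch.no_grad():
    output = model(x)              # forward() delegates to the wrapped Hybrid model
landmarks = output[0]              # node coordinates come first; 44 + 50 + 26 = 120 nodes, 2 coords each
print(landmarks.shape)             # expected (1, 120, 2)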
requirements.txt CHANGED
@@ -4,3 +4,4 @@ opencv-python==4.8.0.74
  scipy==1.10.1
  torch_geometric==2.3.0
  torchvision==0.15.2
+ huggingface_hub==1.2.3
utils/segmentation.py CHANGED
@@ -9,7 +9,7 @@ from zipfile import ZipFile
  from .plotting import plot_side_by_side_comparison
 
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- from models.HybridGNet2IGSC import Hybrid
+ from models.HybridGNet2IGSC import HybridGNetHF
 
  hybrid = None
 
@@ -20,106 +20,6 @@ def seed_everything(seed=42):
      if torch.cuda.is_available():
          torch.cuda.manual_seed_all(seed)
 
- def scipy_to_torch_sparse(scp_matrix):
-     values = scp_matrix.data
-     indices = np.vstack((scp_matrix.row, scp_matrix.col))
-     i = torch.LongTensor(indices)
-     v = torch.FloatTensor(values)
-     shape = scp_matrix.shape
-
-     sparse_tensor = torch.sparse.FloatTensor(i, v, torch.Size(shape))
-     return sparse_tensor
-
- ## Adjacency Matrix
- def mOrgan(N):
-     sub = np.zeros([N, N])
-     for i in range(0, N):
-         sub[i, i-1] = 1
-         sub[i, (i+1)%N] = 1
-     return sub
-
- ## Downsampling Matrix
- def mOrganD(N):
-     N2 = int(np.ceil(N/2))
-     sub = np.zeros([N2, N])
-
-     for i in range(0, N2):
-         if (2*i+1) == N:
-             sub[i, 2*i] = 1
-         else:
-             sub[i, 2*i] = 1/2
-             sub[i, 2*i+1] = 1/2
-
-     return sub
-
- def mOrganU(N):
-     N2 = int(np.ceil(N/2))
-     sub = np.zeros([N, N2])
-
-     for i in range(0, N):
-         if i % 2 == 0:
-             sub[i, i//2] = 1
-         else:
-             sub[i, i//2] = 1/2
-             sub[i, (i//2 + 1) % N2] = 1/2
-
-     return sub
-
- def genMatrixesLungsHeart():
-     RLUNG = 44
-     LLUNG = 50
-     HEART = 26
-
-     Asub1 = mOrgan(RLUNG)
-     Asub2 = mOrgan(LLUNG)
-     Asub3 = mOrgan(HEART)
-
-     ADsub1 = mOrgan(int(np.ceil(RLUNG / 2)))
-     ADsub2 = mOrgan(int(np.ceil(LLUNG / 2)))
-     ADsub3 = mOrgan(int(np.ceil(HEART / 2)))
-
-     Dsub1 = mOrganD(RLUNG)
-     Dsub2 = mOrganD(LLUNG)
-     Dsub3 = mOrganD(HEART)
-
-     Usub1 = mOrganU(RLUNG)
-     Usub2 = mOrganU(LLUNG)
-     Usub3 = mOrganU(HEART)
-
-     p1 = RLUNG
-     p2 = p1 + LLUNG
-     p3 = p2 + HEART
-
-     p1_ = int(np.ceil(RLUNG / 2))
-     p2_ = p1_ + int(np.ceil(LLUNG / 2))
-     p3_ = p2_ + int(np.ceil(HEART / 2))
-
-     A = np.zeros([p3, p3])
-
-     A[:p1, :p1] = Asub1
-     A[p1:p2, p1:p2] = Asub2
-     A[p2:p3, p2:p3] = Asub3
-
-     AD = np.zeros([p3_, p3_])
-
-     AD[:p1_, :p1_] = ADsub1
-     AD[p1_:p2_, p1_:p2_] = ADsub2
-     AD[p2_:p3_, p2_:p3_] = ADsub3
-
-     D = np.zeros([p3_, p3])
-
-     D[:p1_, :p1] = Dsub1
-     D[p1_:p2_, p1:p2] = Dsub2
-     D[p2_:p3_, p2:p3] = Dsub3
-
-     U = np.zeros([p3, p3_])
-
-     U[:p1, :p1_] = Usub1
-     U[p1:p2, p1_:p2_] = Usub2
-     U[p2:p3, p2_:p3_] = Usub3
-
-     return A, AD, D, U
-
  def zip_files(files, output_name="complete_results.zip"):
      with ZipFile(output_name, "w") as zipObj:
          for file in files:
@@ -167,16 +67,11 @@ def removePreprocess(output, info):
 
  def loadModel(device):
      global hybrid
-     A, AD, D, U = genMatrixesLungsHeart()
-     N1, N2 = A.shape[0], AD.shape[0]
-     A, AD, D, U = [sp.csc_matrix(x).tocoo() for x in [A, AD, D, U]]
-     D_, U_ = [D.copy()], [U.copy()]
-     A_ = [A.copy(), A.copy(), A.copy(), AD.copy(), AD.copy(), AD.copy()]
-     config = {'n_nodes':[N1,N1,N1,N2,N2,N2], 'latents':64, 'inputsize':1024,
-               'filters':[2,32,32,32,16,16,16], 'skip_features':32, 'eval_sampling':True}
-     A_t, D_t, U_t = ([scipy_to_torch_sparse(x).to(device) for x in X] for X in (A_,D_,U_))
-     hybrid = Hybrid(config.copy(), D_t, U_t, A_t).to(device)
-     hybrid.load_state_dict(torch.load("weights/weights.pt", map_location=device))
+     hybrid = HybridGNetHF.from_pretrained(
+         repo_id="mcosarinsky/CheXmask-U",
+         subfolder="v1_skip",
+         device=device
+     )
      hybrid.eval()
      return hybrid
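
With this change, loadModel no longer builds the graph matrices by hand or reads weights/weights.pt from disk; configuration and weights come from the Hub. A minimal sketch of the resulting call path (not part of the commit), assuming the rest of utils/segmentation.py is unchanged:

# Sketch only: exercising the refactored loadModel, which now pulls
# config.json and pytorch_model.bin from the mcosarinsky/CheXmask-U repo.
import torch
from utils.segmentation import loadModel

device = "cuda" if torch.cuda.is_available() else "cpu"
hybrid = loadModel(device)      # downloads and builds HybridGNetHF on `device`
print(type(hybrid).__name__)    # HybridGNetHF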