import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchvision.transforms as T
from hydra.utils import instantiate
from omegaconf import ListConfig
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.functional import InterpolationMode

from src.backbone.vit_wrapper import PretrainedViTWrapper
from utils.img import PILToTensor


def seed_worker(worker_id):
    """Seed each DataLoader worker deterministically (pass as worker_init_fn)."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

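# Usage sketch (not called in this module): `seed_worker` is meant to be
# passed as `worker_init_fn` so every worker reseeds NumPy and `random` from
# its torch seed; the dataset below is a placeholder assumption.
#
#     g = torch.Generator()
#     g.manual_seed(0)
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=8, shuffle=True,
#         worker_init_fn=seed_worker, generator=g,
#     )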

def round_to_nearest_multiple(value, multiple=14):
    """Round `value` to the nearest multiple of `multiple` (e.g. a ViT patch
    size): round_to_nearest_multiple(223, 14) == 224."""
    return multiple * round(value / multiple)


def compute_feats(cfg, backbone, image_batch, min_rescale=0.25, max_rescale=0.60):
    """Extract backbone features for the batch at full and at reduced resolution."""
    _, _, H, W = image_batch.shape  # Original height and width

    with torch.no_grad():
        hr_feats = backbone(image_batch)

        if cfg.get("lr_img_size", None) is not None:
            size = (cfg.lr_img_size, cfg.lr_img_size)
        else:
            # Pick a downscale factor, then snap the target size to the patch grid
            if cfg.down_factor == "random":
                downscale_factor = np.random.uniform(min_rescale, max_rescale)
            elif cfg.down_factor == "fixed":
                downscale_factor = 0.5
            else:
                raise ValueError(f"Unknown down_factor: {cfg.down_factor!r}")

            new_H = round_to_nearest_multiple(H * downscale_factor, backbone.patch_size)
            new_W = round_to_nearest_multiple(W * downscale_factor, backbone.patch_size)
            size = (new_H, new_W)
        low_res_batch = F.interpolate(image_batch, size=size, mode="bilinear")
        lr_feats = backbone(low_res_batch)

        return hr_feats, lr_feats

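# Usage sketch: produces a (high-res, low-res) feature pair, e.g. for
# supervising a feature upsampler:
#
#     hr_feats, lr_feats = compute_feats(cfg, backbone, image_batch)
#     # hr_feats: backbone features of the full-resolution batch
#     # lr_feats: features of the downscaled copy (size snapped to patch_size)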

def logger(args, base_log_dir):
    """Create a TensorBoard SummaryWriter in the next free
    version_<n> subdirectory of base_log_dir."""
    os.makedirs(base_log_dir, exist_ok=True)
    existing_versions = [
        int(d.split("_")[-1])
        for d in os.listdir(base_log_dir)
        if os.path.isdir(os.path.join(base_log_dir, d)) and d.startswith("version_")
    ]
    new_version = max(existing_versions, default=-1) + 1
    new_log_dir = os.path.join(base_log_dir, f"version_{new_version}")

    # Create the SummaryWriter with the new log directory
    writer = SummaryWriter(log_dir=new_log_dir)
    return writer, new_version, new_log_dir

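# Usage sketch: the returned writer logs under base_log_dir/version_<n>;
# `loss_value` and `global_step` below are hypothetical names.
#
#     writer, version, log_dir = logger(args, "runs/experiment")
#     writer.add_scalar("train/loss", loss_value, global_step)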

def get_dataloaders(cfg, shuffle=True):
    """Build the train and validation dataloaders.

    Args:
        cfg: Configuration object with dataset/dataloader nodes and image sizes
        shuffle: If True, allow true randomness; if False, use a seeded generator
    """
    # Image transforms: resize and center-crop to a square cfg.img_size input
    transforms = {
        "image": T.Compose(
            [
                T.Resize(cfg.img_size, interpolation=InterpolationMode.BILINEAR),
                T.CenterCrop((cfg.img_size, cfg.img_size)),
                T.ToTensor(),
            ]
        )
    }

    transforms["label"] = T.Compose(
        [
            # T.ToTensor(),
            T.Resize(cfg.target_size, interpolation=InterpolationMode.NEAREST_EXACT),
            T.CenterCrop((cfg.target_size, cfg.target_size)),
            PILToTensor(),
        ]
    )
    train_dataset = cfg.dataset
    val_dataset = cfg.dataset.copy()
    if hasattr(val_dataset, "split"):
        val_dataset.split = "val"

    train_dataset = instantiate(
        train_dataset,
        transform=transforms["image"],
        target_transform=transforms["label"],
    )
    val_dataset = instantiate(
        val_dataset,
        transform=transforms["image"],
        target_transform=transforms["label"],
    )

    # Seeded generator for deterministic ordering when not shuffling
    if not shuffle:
        g = torch.Generator()
        g.manual_seed(0)
    else:
        g = None

    # Prepare dataloader configs - set worker_init_fn to None when shuffling for randomness
    train_dataloader_cfg = cfg.train_dataloader.copy()
    val_dataloader_cfg = cfg.val_dataloader.copy()

    if shuffle:
        # Set worker_init_fn to None to allow true randomness when shuffling
        if "worker_init_fn" in train_dataloader_cfg:
            train_dataloader_cfg["worker_init_fn"] = None
        if "worker_init_fn" in val_dataloader_cfg:
            val_dataloader_cfg["worker_init_fn"] = None

    return (
        instantiate(train_dataloader_cfg, dataset=train_dataset, generator=g),
        instantiate(val_dataloader_cfg, dataset=val_dataset, generator=g),
    )

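# Usage sketch: the two loaders share transforms but differ in split, e.g.:
#
#     train_loader, val_loader = get_dataloaders(cfg, shuffle=True)
#     for batch in train_loader:
#         batch = get_batch(batch, device)  # get_batch is defined below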

def get_batch(batch, device):
    """Process batch and return required tensors."""
    batch["image"] = batch["image"].to(device)
    return batch


def setup_training_optimizations(model, cfg):
    """
    Setup training optimizations based on configuration

    Args:
        model: The model to apply optimizations to
        cfg: Configuration object with use_bf16 and use_checkpointing flags

    Returns:
        tuple: (scaler, use_bf16, use_checkpointing) for use in training loop
    """
    # Get configuration values with defaults
    use_bf16 = getattr(cfg, "use_bf16", False)
    use_checkpointing = getattr(cfg, "use_checkpointing", False)

    # Gradient scaler for mixed precision. Note: bfloat16 usually does not
    # require loss scaling (its exponent range matches fp32); the scaler is
    # kept here so the same training loop also works with fp16.
    scaler = torch.amp.GradScaler("cuda", enabled=use_bf16)

    # Enable gradient checkpointing if requested
    if use_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable()
            print("   ✓ Using built-in gradient checkpointing")
        else:
            # For custom models, wrap forward methods
            def checkpoint_wrapper(module):
                if hasattr(module, "forward"):
                    original_forward = module.forward

                    def checkpointed_forward(*args, **kwargs):
                        # Non-reentrant checkpointing supports keyword arguments
                        return checkpoint.checkpoint(
                            original_forward, *args, use_reentrant=False, **kwargs
                        )

                    module.forward = checkpointed_forward

            # Apply to key modules (adjust based on your model structure)
            checkpointed_modules = []
            for name, module in model.named_modules():
                if any(key in name for key in ["cross_decode", "encoder", "sft"]):
                    checkpoint_wrapper(module)
                    checkpointed_modules.append(name)

            if checkpointed_modules:
                print(f"   ✓ Applied custom gradient checkpointing to: {checkpointed_modules}")
            else:
                print("   ⚠ No modules found for gradient checkpointing")

    print(f"Training optimizations:")
    print(f"  Mixed precision (bfloat16): {use_bf16}")
    print(f"  Gradient checkpointing: {use_checkpointing}")

    return scaler, use_bf16, use_checkpointing

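# Usage sketch (training loop assumed, not part of this module; the loss
# computation is hypothetical):
#
#     scaler, use_bf16, _ = setup_training_optimizations(model, cfg)
#     for batch in train_loader:
#         optimizer.zero_grad(set_to_none=True)
#         with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_bf16):
#             loss = model(batch).loss
#         scaler.scale(loss).backward()
#         scaler.step(optimizer)
#         scaler.update()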

def load_multiple_backbones(cfg, backbone_configs, device):
    """
    Load one or more backbone models based on configuration.

    Args:
        cfg: Hydra configuration object
        backbone_configs: A single backbone config or a list of them
        device: PyTorch device to load the models on

    Returns:
        tuple: (backbones, backbone_names, backbone_img_sizes)
            - backbones: List of loaded backbone models (in eval mode)
            - backbone_names: List of backbone names
            - backbone_img_sizes: List of expected (H, W) input sizes
    """
    backbones = []
    backbone_names = []
    backbone_img_sizes = []

    if not isinstance(backbone_configs, (list, ListConfig)):
        backbone_configs = [backbone_configs]
    print(f"Loading {len(backbone_configs)} backbone(s)...")

    for i, backbone_config in enumerate(backbone_configs):
        name = backbone_config["name"]
        if name == "rgb":
            backbone = instantiate(cfg.backbone)
        else:
            backbone = PretrainedViTWrapper(name=name)
        print(f"  [{i}] Loaded {backbone_config['name']}")

        # Move to device and set to eval mode
        backbone = backbone.to(device)
        backbone.eval()  # Set to eval mode for feature extraction

        # Store backbone, name, and expected input size (H, W)
        backbones.append(backbone)
        backbone_names.append(name)
        backbone_img_sizes.append(backbone.config["input_size"][1:])

    return backbones, backbone_names, backbone_img_sizes
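# Usage sketch (`cfg.backbone_configs` is an assumed config key): load the
# backbones, then resize inputs to each backbone's expected size, e.g.:
#
#     backbones, names, img_sizes = load_multiple_backbones(
#         cfg, cfg.backbone_configs, torch.device("cuda")
#     )
#     for bb, size in zip(backbones, img_sizes):
#         feats = bb(F.interpolate(image_batch, size=size, mode="bilinear"))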