Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,848 Bytes
e4c8837 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchvision.transforms as T
from hydra.utils import instantiate
from omegaconf import ListConfig
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.functional import InterpolationMode
from src.backbone.vit_wrapper import PretrainedViTWrapper
from utils.img import PILToTensor
def seed_worker(worker_id=None):
    """Seed NumPy and Python's ``random`` from the current torch seed.

    The signature accepts an optional ``worker_id`` so the function can be
    passed as a DataLoader ``worker_init_fn`` (PyTorch calls that hook with
    the worker index). Calling it with no argument still works.

    Args:
        worker_id: Index supplied by the caller; it is not used, because the
            seed is derived from ``torch.initial_seed()`` only.
    """
    # NumPy only accepts 32-bit seeds, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
def round_to_nearest_multiple(value, multiple=14):
    """Snap ``value`` to the closest integer multiple of ``multiple``.

    The quotient is rounded with the built-in ``round`` (ties go to the even
    integer), so the result can be ``0`` when ``value`` is less than half of
    ``multiple``.

    Args:
        value: Number to snap.
        multiple: Step size; defaults to 14.

    Returns:
        ``multiple`` times the rounded quotient.
    """
    step_count = round(value / multiple)
    return multiple * step_count
def compute_feats(cfg, backbone, image_batch, min_rescale=0.60, max_rescale=0.25):
    """Run the backbone on a batch and on a downscaled copy of it.

    The low-resolution size is ``(cfg.lr_img_size, cfg.lr_img_size)`` when
    that option is set. Otherwise the input height and width are multiplied
    by a downscale factor and snapped to a multiple of
    ``backbone.patch_size``.

    Args:
        cfg: Config object providing ``get`` and, when ``lr_img_size`` is
            unset, a ``down_factor`` of ``"random"`` or ``"fixed"``.
        backbone: Callable feature extractor exposing ``patch_size``.
        image_batch: 4-D tensor whose last two dimensions are read as H, W.
        min_rescale: One bound of the random downscale factor.
        max_rescale: The other bound of the random downscale factor. The two
            bounds may be given in either order.

    Returns:
        tuple: ``(hr_feats, lr_feats)``, the backbone outputs for the
        original and the downscaled batch.

    Raises:
        ValueError: If ``cfg.down_factor`` is not ``"random"`` or ``"fixed"``
            and ``cfg.lr_img_size`` is unset.
    """
    _, _, H, W = image_batch.shape  # Get original height and width

    # NOTE(review): both backbone calls are kept under no_grad here -- confirm
    # that the low-resolution pass is not expected to carry gradients.
    with torch.no_grad():
        hr_feats = backbone(image_batch)

        if cfg.get("lr_img_size", None) is not None:
            size = (cfg.lr_img_size, cfg.lr_img_size)
        else:
            if cfg.down_factor == "random":
                # The default bounds are declared high-to-low; order them so
                # np.random.uniform always receives low <= high.
                low, high = sorted((min_rescale, max_rescale))
                downscale_factor = np.random.uniform(low, high)
            elif cfg.down_factor == "fixed":
                downscale_factor = 0.5
            else:
                raise ValueError(
                    f"Unsupported down_factor {cfg.down_factor!r}; expected 'random' or 'fixed'"
                )

            patch = backbone.patch_size
            # Rounding can reach 0 for small inputs; keep at least one patch
            # so the interpolation target size stays valid.
            new_H = max(patch, round_to_nearest_multiple(H * downscale_factor, patch))
            new_W = max(patch, round_to_nearest_multiple(W * downscale_factor, patch))
            size = (new_H, new_W)

        low_res_batch = F.interpolate(image_batch, size=size, mode="bilinear")
        lr_feats = backbone(low_res_batch)

    return hr_feats, lr_feats
def logger(args, base_log_dir):
    """Create a TensorBoard writer in the next free ``version_<n>`` directory.

    Existing sub-directories of ``base_log_dir`` named ``version_<int>`` are
    scanned and the new run gets the highest index plus one (``0`` when none
    exist). Directories whose suffix is not an integer (for example
    ``version_backup``) are ignored instead of aborting the scan.

    Args:
        args: Unused; kept for interface compatibility.
        base_log_dir: Parent directory for all versions; created if missing.

    Returns:
        tuple: ``(writer, new_version, new_log_dir)``.
    """
    os.makedirs(base_log_dir, exist_ok=True)

    existing_versions = []
    for entry in os.listdir(base_log_dir):
        if not entry.startswith("version_"):
            continue
        if not os.path.isdir(os.path.join(base_log_dir, entry)):
            continue
        try:
            existing_versions.append(int(entry.split("_")[-1]))
        except ValueError:
            # Not a numbered version directory; skip it.
            continue

    new_version = max(existing_versions, default=-1) + 1
    new_log_dir = os.path.join(base_log_dir, f"version_{new_version}")

    # Create the SummaryWriter with the new log directory
    writer = SummaryWriter(log_dir=new_log_dir)
    return writer, new_version, new_log_dir
def get_dataloaders(cfg, shuffle=True):
    """Build the training and validation dataloaders described by ``cfg``.

    Both datasets are instantiated from ``cfg.dataset``; the validation copy
    has its ``split`` set to ``"val"`` when that attribute exists. Images are
    resized (bilinear), center-cropped to ``cfg.img_size`` and converted with
    ``ToTensor``. Labels are resized (nearest-exact), center-cropped to
    ``cfg.target_size`` and converted with ``PILToTensor``.

    Args:
        cfg: Configuration object providing ``img_size``, ``target_size``,
            ``dataset``, ``train_dataloader`` and ``val_dataloader``.
        shuffle: When true, no generator is passed and any configured
            ``worker_init_fn`` is reset to ``None``. When false, both loaders
            share one generator seeded with 0. This flag does not itself set
            the loaders' ``shuffle`` option.

    Returns:
        tuple: ``(train_dataloader, val_dataloader)``.
    """
    image_tf = T.Compose(
        [
            T.Resize(cfg.img_size, interpolation=InterpolationMode.BILINEAR),
            T.CenterCrop((cfg.img_size, cfg.img_size)),
            T.ToTensor(),
        ]
    )
    label_tf = T.Compose(
        [
            T.Resize(cfg.target_size, interpolation=InterpolationMode.NEAREST_EXACT),
            T.CenterCrop((cfg.target_size, cfg.target_size)),
            PILToTensor(),
        ]
    )

    # The validation config is a copy so that changing its split leaves the
    # training config untouched.
    train_dataset_cfg = cfg.dataset
    val_dataset_cfg = cfg.dataset.copy()
    if hasattr(val_dataset_cfg, "split"):
        val_dataset_cfg.split = "val"

    train_set = instantiate(train_dataset_cfg, transform=image_tf, target_transform=label_tf)
    val_set = instantiate(val_dataset_cfg, transform=image_tf, target_transform=label_tf)

    # A fixed-seed generator is only used for the deterministic case.
    generator = None
    if not shuffle:
        generator = torch.Generator()
        generator.manual_seed(0)

    train_loader_cfg = cfg.train_dataloader.copy()
    val_loader_cfg = cfg.val_dataloader.copy()
    if shuffle:
        # Drop any seeding hook so the workers are not re-seeded.
        for loader_cfg in (train_loader_cfg, val_loader_cfg):
            if "worker_init_fn" in loader_cfg:
                loader_cfg["worker_init_fn"] = None

    train_loader = instantiate(train_loader_cfg, dataset=train_set, generator=generator)
    val_loader = instantiate(val_loader_cfg, dataset=val_set, generator=generator)
    return (train_loader, val_loader)
def get_batch(batch, device):
    """Move the batch's ``"image"`` entry to ``device`` and return the batch.

    The batch is modified in place; other entries are left untouched.
    """
    image_on_device = batch["image"].to(device)
    batch["image"] = image_on_device
    return batch
def setup_training_optimizations(model, cfg):
"""
Setup training optimizations based on configuration
Args:
model: The model to apply optimizations to
cfg: Configuration object with use_bf16 and use_checkpointing flags
Returns:
tuple: (scaler, use_bf16, use_checkpointing) for use in training loop
"""
# Get configuration values with defaults
use_bf16 = getattr(cfg, "use_bf16", False)
use_checkpointing = getattr(cfg, "use_checkpointing", False)
# Initialize gradient scaler for mixed precision
scaler = torch.amp.GradScaler("cuda", enabled=use_bf16)
# Enable gradient checkpointing if requested
if use_checkpointing:
if hasattr(model, "gradient_checkpointing_enable"):
model.gradient_checkpointing_enable()
print(" ✓ Using built-in gradient checkpointing")
else:
# For custom models, wrap forward methods
def checkpoint_wrapper(module):
if hasattr(module, "forward"):
original_forward = module.forward
def checkpointed_forward(*args, **kwargs):
return checkpoint.checkpoint(original_forward, *args, **kwargs)
module.forward = checkpointed_forward
# Apply to key modules (adjust based on your model structure)
checkpointed_modules = []
for name, module in model.named_modules():
if any(key in name for key in ["cross_decode", "encoder", "sft"]):
checkpoint_wrapper(module)
checkpointed_modules.append(name)
if checkpointed_modules:
print(f" ✓ Applied custom gradient checkpointing to: {checkpointed_modules}")
else:
print(" ⚠ No modules found for gradient checkpointing")
print(f"Training optimizations:")
print(f" Mixed precision (bfloat16): {use_bf16}")
print(f" Gradient checkpointing: {use_checkpointing}")
return scaler, use_bf16, use_checkpointing
def load_multiple_backbones(cfg, backbone_configs, device):
    """Load one or more backbones, move them to ``device`` and set eval mode.

    Args:
        cfg: Hydra configuration object; ``cfg.backbone`` is instantiated for
            an entry whose name is ``"rgb"``.
        backbone_configs: A single backbone config or a list/``ListConfig``
            of them. Each entry is indexed with ``"name"``.
        device: PyTorch device to load models on.

    Returns:
        tuple: ``(backbones, backbone_names, backbone_img_sizes)``
            - backbones: List of loaded backbone models
            - backbone_names: List of backbone names
            - backbone_img_sizes: For each backbone,
              ``backbone.config["input_size"][1:]``
    """
    # Accept a lone config by wrapping it into a one-element list.
    if not isinstance(backbone_configs, (list, ListConfig)):
        backbone_configs = [backbone_configs]

    print(f"Loading {len(backbone_configs)} backbone(s)...")

    backbones, backbone_names, backbone_img_sizes = [], [], []
    for i, backbone_config in enumerate(backbone_configs):
        name = backbone_config["name"]
        backbone = instantiate(cfg.backbone) if name == "rgb" else PretrainedViTWrapper(name=name)
        print(f" [{i}] Loaded {backbone_config['name']}")

        # Feature extraction only: place on the target device, eval mode.
        backbone = backbone.to(device)
        backbone.eval()

        backbones.append(backbone)
        backbone_names.append(backbone_config["name"])
        backbone_img_sizes.append(backbone.config["input_size"][1:])

    return backbones, backbone_names, backbone_img_sizes
|