import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchvision.transforms as T
from hydra.utils import instantiate
from omegaconf import ListConfig
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.functional import InterpolationMode

from src.backbone.vit_wrapper import PretrainedViTWrapper
from utils.img import PILToTensor


def seed_worker(worker_id):
    # DataLoader passes the worker id to worker_init_fn; derive a per-worker seed
    # from torch's initial seed so numpy and random stay reproducible per worker.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def round_to_nearest_multiple(value, multiple=14):
    # Snap a length to the nearest multiple of the patch size (14 px by default).
    return multiple * round(value / multiple)
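
# Examples (illustrative): lengths snap to the nearest multiple of the patch size, e.g.
#   round_to_nearest_multiple(300, 14)  # -> 294 (21 patches)
#   round_to_nearest_multiple(224, 16)  # -> 224 (already a multiple)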


def compute_feats(cfg, backbone, image_batch, min_rescale=0.25, max_rescale=0.60):
    """Compute high-resolution and low-resolution backbone features for a batch."""
    _, _, H, W = image_batch.shape  # Original height and width
    with torch.no_grad():
        hr_feats = backbone(image_batch)
        if cfg.get("lr_img_size", None) is not None:
            size = (cfg.lr_img_size, cfg.lr_img_size)
        else:
            # Downscale by a random or fixed factor, snapped to the patch size
            if cfg.down_factor == "random":
                downscale_factor = np.random.uniform(min_rescale, max_rescale)
            elif cfg.down_factor == "fixed":
                downscale_factor = 0.5
            else:
                raise ValueError(f"Unknown down_factor: {cfg.down_factor}")
            new_H = round_to_nearest_multiple(H * downscale_factor, backbone.patch_size)
            new_W = round_to_nearest_multiple(W * downscale_factor, backbone.patch_size)
            size = (new_H, new_W)
        low_res_batch = F.interpolate(image_batch, size=size, mode="bilinear")
        lr_feats = backbone(low_res_batch)
    return hr_feats, lr_feats
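
# Usage sketch for compute_feats (illustrative; assumes a loaded backbone with a
# `patch_size` attribute and a cfg exposing `down_factor` / optional `lr_img_size`):
#   images = batch["image"].to(device)                 # (B, 3, H, W)
#   hr_feats, lr_feats = compute_feats(cfg, backbone, images)
#   # hr_feats: features of the full-resolution input; lr_feats: features of the
#   # downscaled copy, e.g. usable as an input/target pair for feature upsampling.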


def logger(args, base_log_dir):
    """Create a TensorBoard SummaryWriter in a fresh `version_<n>` subdirectory."""
    os.makedirs(base_log_dir, exist_ok=True)
    existing_versions = [
        int(d.split("_")[-1])
        for d in os.listdir(base_log_dir)
        if os.path.isdir(os.path.join(base_log_dir, d)) and d.startswith("version_")
    ]
    new_version = max(existing_versions, default=-1) + 1
    new_log_dir = os.path.join(base_log_dir, f"version_{new_version}")
    # Create the SummaryWriter with the new log directory
    writer = SummaryWriter(log_dir=new_log_dir)
    return writer, new_version, new_log_dir


def get_dataloaders(cfg, shuffle=True):
    """Get dataloaders for either training or evaluation.

    Args:
        cfg: Configuration object with dataset, dataloader, and image-size settings
        shuffle: If True, shuffle the data and disable deterministic worker seeding
    """
    transforms = {
        "image": T.Compose(
            [
                T.Resize(cfg.img_size, interpolation=InterpolationMode.BILINEAR),
                T.CenterCrop((cfg.img_size, cfg.img_size)),
                T.ToTensor(),
            ]
        )
    }
    transforms["label"] = T.Compose(
        [
            T.Resize(cfg.target_size, interpolation=InterpolationMode.NEAREST_EXACT),
            T.CenterCrop((cfg.target_size, cfg.target_size)),
            PILToTensor(),
        ]
    )
    train_dataset = cfg.dataset
    val_dataset = cfg.dataset.copy()
    if hasattr(val_dataset, "split"):
        val_dataset.split = "val"
    train_dataset = instantiate(
        train_dataset,
        transform=transforms["image"],
        target_transform=transforms["label"],
    )
    val_dataset = instantiate(
        val_dataset,
        transform=transforms["image"],
        target_transform=transforms["label"],
    )
    # Create a seeded generator for reproducibility when not shuffling
    if not shuffle:
        g = torch.Generator()
        g.manual_seed(0)
    else:
        g = None
    # Prepare dataloader configs - drop worker_init_fn when shuffling for true randomness
    train_dataloader_cfg = cfg.train_dataloader.copy()
    val_dataloader_cfg = cfg.val_dataloader.copy()
    if shuffle:
        if "worker_init_fn" in train_dataloader_cfg:
            train_dataloader_cfg["worker_init_fn"] = None
        if "worker_init_fn" in val_dataloader_cfg:
            val_dataloader_cfg["worker_init_fn"] = None
    return (
        instantiate(train_dataloader_cfg, dataset=train_dataset, generator=g),
        instantiate(val_dataloader_cfg, dataset=val_dataset, generator=g),
    )
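
# Usage sketch for get_dataloaders (illustrative; assumes a Hydra config whose
# `dataset`, `train_dataloader`, and `val_dataloader` nodes are instantiable, e.g.
# `_target_: torch.utils.data.DataLoader` with batch_size/num_workers set):
#   train_loader, val_loader = get_dataloaders(cfg, shuffle=True)
#   for batch in train_loader:
#       batch = get_batch(batch, device)  # moves batch["image"] to the target device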


def get_batch(batch, device):
    """Process batch and return required tensors."""
    batch["image"] = batch["image"].to(device)
    return batch


def setup_training_optimizations(model, cfg):
    """
    Setup training optimizations based on configuration.

    Args:
        model: The model to apply optimizations to
        cfg: Configuration object with use_bf16 and use_checkpointing flags

    Returns:
        tuple: (scaler, use_bf16, use_checkpointing) for use in the training loop
    """
    # Get configuration values with defaults
    use_bf16 = getattr(cfg, "use_bf16", False)
    use_checkpointing = getattr(cfg, "use_checkpointing", False)
    # Initialize gradient scaler for mixed precision
    scaler = torch.amp.GradScaler("cuda", enabled=use_bf16)
    # Enable gradient checkpointing if requested
    if use_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable()
            print(" ✓ Using built-in gradient checkpointing")
        else:
            # For custom models, wrap forward methods
            def checkpoint_wrapper(module):
                if hasattr(module, "forward"):
                    original_forward = module.forward

                    def checkpointed_forward(*args, **kwargs):
                        # Non-reentrant checkpointing is required to pass kwargs through
                        return checkpoint.checkpoint(
                            original_forward, *args, use_reentrant=False, **kwargs
                        )

                    module.forward = checkpointed_forward

            # Apply to key modules (adjust based on your model structure)
            checkpointed_modules = []
            for name, module in model.named_modules():
                if any(key in name for key in ["cross_decode", "encoder", "sft"]):
                    checkpoint_wrapper(module)
                    checkpointed_modules.append(name)
            if checkpointed_modules:
                print(f" ✓ Applied custom gradient checkpointing to: {checkpointed_modules}")
            else:
                print(" ⚠ No modules found for gradient checkpointing")
    print("Training optimizations:")
    print(f" Mixed precision (bfloat16): {use_bf16}")
    print(f" Gradient checkpointing: {use_checkpointing}")
    return scaler, use_bf16, use_checkpointing
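
# Training-loop sketch for setup_training_optimizations (illustrative; `model`,
# `optimizer`, `loss_fn`, `inputs`, and `targets` are assumed to exist elsewhere):
#   scaler, use_bf16, _ = setup_training_optimizations(model, cfg)
#   with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_bf16):
#       loss = loss_fn(model(inputs), targets)
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()
#   optimizer.zero_grad()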


def load_multiple_backbones(cfg, backbone_configs, device):
    """
    Load multiple backbone models based on configuration.

    Args:
        cfg: Hydra configuration object
        backbone_configs: A single backbone config or a list of backbone configs
        device: PyTorch device to load models on

    Returns:
        tuple: (backbones, backbone_names, backbone_img_sizes)
            - backbones: List of loaded backbone models
            - backbone_names: List of backbone names
            - backbone_img_sizes: List of (H, W) input sizes, one per backbone
    """
    backbones = []
    backbone_names = []
    backbone_img_sizes = []
    # Accept a single config as well as a list / ListConfig of configs
    if not isinstance(backbone_configs, (list, ListConfig)):
        backbone_configs = [backbone_configs]
    print(f"Loading {len(backbone_configs)} backbone(s)...")
    for i, backbone_config in enumerate(backbone_configs):
        name = backbone_config["name"]
        if name == "rgb":
            backbone = instantiate(cfg.backbone)
        else:
            backbone = PretrainedViTWrapper(name=name)
        print(f"  [{i}] Loaded {name}")
        # Move to device and set to eval mode for feature extraction
        backbone = backbone.to(device)
        backbone.eval()
        # Store backbone, name, and expected input size (H, W)
        backbones.append(backbone)
        backbone_names.append(name)
        backbone_img_sizes.append(backbone.config["input_size"][1:])
    return backbones, backbone_names, backbone_img_sizes
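
# Usage sketch for load_multiple_backbones (illustrative; the cfg node holding the
# backbone configs, e.g. `cfg.backbones`, is an assumption of this example):
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   backbones, names, img_sizes = load_multiple_backbones(cfg, cfg.backbones, device)
#   for name, (h, w) in zip(names, img_sizes):
#       print(f"{name}: expects {h}x{w} inputs")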