import re
from typing import List, Tuple, Union

import timm
import timm.data
import torch
from einops import rearrange
from timm.models.vision_transformer import VisionTransformer
from torch import nn
from torchvision import transforms

# A list of supported timm model names; more are available on the official timm repo.
MODEL_LIST = [
    # DINO
    "vit_base_patch16_224.dino",
    # DINOv2
    "vit_base_patch14_dinov2.lvd142m",
    # DINOv2-R
    "vit_base_patch14_reg4_dinov2",
    # Franca
    "franca_vitb14",
    # DINOv3-ViT
    "vit_base_patch16_dinov3.lvd1689m",
    "vit_large_patch16_dinov3.lvd1689m",
    "vit_7b_patch16_dinov3.lvd1689m",
    # SigLIP2
    "vit_base_patch16_siglip_512.v2_webli",
    # PE Core
    "vit_pe_core_small_patch16_384.fb",
    # PE Spatial
    "vit_pe_spatial_tiny_patch16_512.fb",
    # RADIO
    "radio_v2.5-b",
    # CAPI
    "capi_vitl14_lvd",
    # MAE
    "vit_large_patch16_224.mae",
]

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


class PretrainedViTWrapper(nn.Module):
    def __init__(
        self,
        name,
        norm: bool = True,
        dynamic_img_size: bool = True,
        dynamic_img_pad: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.name = name

        # "dvt_" / "fit3d_" prefixes select fine-tuned checkpoints loaded from disk below.
        load_weights = False
        if name.startswith("dvt_"):
            load_weights = True
            load_tag = "dvt"
            name = name.replace("dvt_", "")
        if name.startswith("fit3d_"):
            load_weights = True
            load_tag = "fit3d"
            name = name.replace("fit3d_", "")

        # Infer the patch size from the model name, defaulting to 16.
        match = re.search(r"patch(\d+)", name)
        self.patch_size = int(match.group(1)) if match else 16
        if "franca" in name or "capi" in name:
            self.patch_size = 14
        if "convnext" in name:
            self.patch_size = 32

        self.dynamic_img_size = dynamic_img_size
        self.dynamic_img_pad = dynamic_img_pad
        self.model, self.config = self.create_model(name, **kwargs)
        self.config["ps"] = self.patch_size
        self.embed_dim = self.model.embed_dim
        self.norm = norm

        if load_weights:
            ckpt = torch.load(
                f"/home/lchambon/workspace/JAFAR/ckpts/{load_tag}_{name}.pth",
                map_location="cpu",
            )
            if load_tag == "dvt":
                self.load_state_dict(ckpt["model"], strict=True)
            elif load_tag == "fit3d":
                self.model.load_state_dict(ckpt, strict=True)

    def create_model(self, name: str, **kwargs) -> Tuple[VisionTransformer, dict]:
        if "radio" in self.name:
            model = torch.hub.load(
                "NVlabs/RADIO",
                "radio_model",
                version=name,
                progress=True,
                skip_validation=True,
            )
            data_config = {
                "mean": torch.tensor([0.0, 0.0, 0.0]),
                "std": torch.tensor([1.0, 1.0, 1.0]),
                "input_size": (3, 512, 512),
            }
        elif "franca" in self.name:
            model = torch.hub.load("valeoai/Franca", name, use_rasa_head=True)
            data_config = {
                "mean": IMAGENET_DEFAULT_MEAN,
                "std": IMAGENET_DEFAULT_STD,
                "input_size": (3, 448, 448),
            }
        elif "capi" in self.name:
            model = torch.hub.load("facebookresearch/capi:main", name, force_reload=False)
            data_config = {
                "mean": IMAGENET_DEFAULT_MEAN,
                "std": IMAGENET_DEFAULT_STD,
                "input_size": (3, 448, 448),
            }
        else:
            timm_kwargs = dict(
                pretrained=True,
                num_classes=0,
                patch_size=self.patch_size,
            )
            # SAM and ConvNeXt models do not accept the dynamic-size flags.
            if "sam" not in self.name and "convnext" not in self.name:
                timm_kwargs["dynamic_img_size"] = self.dynamic_img_size
                timm_kwargs["dynamic_img_pad"] = self.dynamic_img_pad
            timm_kwargs.update(kwargs)
            model = timm.create_model(name, **timm_kwargs)
            data_config = timm.data.resolve_model_data_config(model=model)
        model = model.eval()
        return model, data_config
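
    # Hedged convenience sketch, not part of the original interface: build the
    # torchvision preprocessing implied by `self.config`. It assumes the
    # "input_size" / "mean" / "std" keys populated by create_model above.
    def make_transform(self) -> transforms.Compose:
        _, height, width = self.config["input_size"]
        return transforms.Compose(
            [
                transforms.Resize((height, width)),
                transforms.ToTensor(),
                transforms.Normalize(mean=self.config["mean"], std=self.config["std"]),
            ]
        )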

    def forward(
        self,
        x: torch.Tensor,
        n: Union[int, List[int], Tuple[int]] = 1,
        return_prefix_tokens: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Intermediate-layer accessor inspired by the DINO / DINOv2 interface.

        Args:
            x: Input image tensor of shape (B, C, H, W).
            n: Take the last n blocks if int, all blocks if None, or the blocks
                at the matching indices if a sequence.
            return_prefix_tokens: Whether to also return the prefix
                (cls/register) tokens alongside the patch features.
        """
        common_kwargs = dict(
            norm=self.norm,
            output_fmt="NCHW",
            intermediates_only=True,
        )
        if "sam" not in self.name and return_prefix_tokens:
            common_kwargs["return_prefix_tokens"] = return_prefix_tokens

        if "franca" in self.name:
            B, C, H, W = x.shape
            feats = self.model.forward_features(x, use_rasa_head=True)
            out = feats["patch_token_rasa"]
            out = rearrange(
                out,
                "b (h w) c -> b c h w",
                h=H // self.patch_size,
                w=W // self.patch_size,
            )
        elif "capi" in self.name:
            *_, out = self.model(x)
            out = out.permute(0, 3, 1, 2)
        else:
            out = self.model.forward_intermediates(x, n, **common_kwargs)

        # "sam" models return feats only; others may return (feats, prefix).
        if not isinstance(out, (list, tuple)):
            out = [out]
        else:
            assert len(out) == 1, f"Out contains {len(out)} elements, expected 1."
        return out[0]
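
# Minimal usage sketch (an illustration, not part of the original file): it
# assumes the checkpoint can be downloaded via timm, and takes the model name
# from MODEL_LIST above.
if __name__ == "__main__":
    backbone = PretrainedViTWrapper("vit_base_patch16_224.dino")
    image = torch.randn(1, 3, 224, 224)  # dummy batch; use make_transform() on real images
    with torch.no_grad():
        feats = backbone(image)  # last-block patch features in NCHW format
    print(feats.shape)  # -> torch.Size([1, 768, 14, 14]) for a /16 ViT-B at 224px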