### demo.py
# Define model classes for inference.
###
import json
import os

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from einops import rearrange
from sklearn.metrics import confusion_matrix
from transformers import BertTokenizer

from svitt.config import load_cfg, setup_config
from svitt.datasets import VideoClassyDataset
from svitt.evaluation import get_mean_accuracy
from svitt.evaluation_charades import charades_map
from svitt.model import SViTT
from svitt.video_transforms import Permute

class VideoModel(nn.Module):
    """Base model for video understanding built on the SViTT architecture."""

    def __init__(self, config):
        """Initialize the model.

        Parameters:
            config: path to the config file describing the model and data.
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')
        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise Exception('no model config found')
        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)
        print(f"Loading checkpoint from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]
        # Fix for zero-shot evaluation: mirror "bert.*" weights under the key
        # names the text encoder expects.
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]
        if torch.cuda.is_available():
            model.cuda()
        # strict=False: the checkpoint may carry pretraining heads that are
        # not part of the inference model.
        model.load_state_dict(state_dict, strict=False)
        return model
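
    # For reference, build_model reads two entries from cfg['model']; a sketch
    # of the expected shape (the paths are placeholders, not a guaranteed
    # repo layout):
    #
    #   model:
    #     pretrain: checkpoints/svitt_pretrain.pth
    #     config: configs/model_svitt.yml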

    def eval(self):
        """Freeze all parameters for inference (overrides nn.Module.eval)."""
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA)."""

    def __init__(self, config):
        super(VideoCLSModel, self).__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            with open(labelmap, 'r') as f:
                meta = json.load(f)
            labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
        else:
            from svitt.preprocess import generate_label_map
            labels, mapping_vn2act = generate_label_map(self.dataset)
            meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
            meta_dir = f'meta/{self.dataset}'
            os.makedirs(meta_dir, exist_ok=True)
            with open(f'{meta_dir}/label_map.json', 'w') as f:
                json.dump(meta, f)
            print(f"=> Label map generated and saved to {meta_dir}/label_map.json")
        return labels, mapping_vn2act
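
    # For reference, the label map JSON mirrors the dict built above; a sketch
    # with hypothetical entries (the actual label strings depend on the
    # dataset):
    #
    #   {
    #     "labels": ["hold a book", "open a door", ...],
    #     "mapping_vn2act": {"hold a book": 0, "open a door": 1, ...}
    #   }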

    def load_data(self, idx=None):
        print("=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305],
            ),
        ])
        if idx is None:
            metadata_val = data_cfg['metadata_val']
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)
        if dataset in ['charades_ego', 'egtea']:
            val_dataset = VideoClassyDataset(
                dataset,
                data_cfg['root'],
                metadata_val,
                transform=val_transform,
                is_training=False,
                label_mapping=self.mapping_vn2act,
                is_trimmed=False,
                num_clips=1,
                clip_length=data_cfg['clip_length'],
                clip_stride=data_cfg['clip_stride'],
                sparse_sample=data_cfg['sparse_sample'],
            )
        else:
            raise NotImplementedError
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False,
        )
        return val_loader

    def get_text_features(self):
        print('=> Extracting text features')
        embeddings = self.tokenizer(
            self.labels,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        ).to(self.device)
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings

    def forward(self, idx=None):
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for i, values in enumerate(val_loader):
            images = values[0]
            target = values[1]
            images = images.to(self.device)
            # Encode images: reorder to (batch, frames, channels, H, W), then
            # split each clip into 4-frame segments for the video encoder.
            images = rearrange(images, 'b c k h w -> b k c h w')
            dims = images.shape
            images = images.reshape(-1, 4, dims[-3], dims[-2], dims[-1])
            image_features, _ = self.model.encode_image(images)
            if image_features.ndim == 3:
                image_features = rearrange(image_features, '(b k) n d -> b (k n) d', b=1)
            else:
                image_features = rearrange(image_features, '(b k) d -> b k d', b=1)
            # Cosine similarity against the class text embeddings as logits.
            similarity = self.model.get_sim(image_features, self.text_features)[0]
            all_outputs.append(similarity.cpu())
            all_targets.append(target)
        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)
        return all_outputs, all_targets

    def predict(self, idx=0):
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        # Report the top-5 predictions; an adaptive cutoff such as
        # sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.06)[0][0]
        # could be used instead.
        sel = 5
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action

    def evaluate(self):
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, m_aps = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError
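

# Minimal usage sketch, assuming a config accepted by load_cfg with
# 'model.pretrain', 'model.config', and 'data' entries; the path below is
# illustrative only.
if __name__ == "__main__":
    model = VideoCLSModel("configs/charades_ego.yml")  # hypothetical config path
    pred_action, gt_action = model.predict(idx=0)      # top-5 predictions for one clip
    print("Predicted:", pred_action)
    print("Ground truth:", gt_action)
    model.evaluate()                                   # dataset-level metrics (mAP / accuracy)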