Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from collections import defaultdict | |
| from collections.abc import Sequence | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms.functional as F | |
| from mmcv.transforms import BaseTransform | |
| from mmengine.utils import is_str | |
| from PIL import Image | |
| from mmpretrain.registry import TRANSFORMS | |
| from mmpretrain.structures import DataSample, MultiTaskDataSample | |
def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): The
            data to be converted. Strings are rejected even though they are
            sequences.

    Returns:
        torch.Tensor: The converted tensor. ``int`` becomes a 1-D
        ``LongTensor`` and ``float`` a 1-D ``FloatTensor``.

    Raises:
        TypeError: If ``data`` is of an unsupported type.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        # Trailing space added: the implicitly concatenated message previously
        # rendered as "...tensor.Supported types..." with no separator.
        raise TypeError(
            f'Type {type(data)} cannot be converted to tensor. '
            'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
            '`Sequence`, `int` and `float`')
class PackInputs(BaseTransform):
    """Pack the inputs data.

    **Required Keys:**

    - ``input_key``
    - ``*algorithm_keys``
    - ``*meta_keys``

    **Deleted Keys:**

    All other keys in the dict.

    **Added Keys:**

    - inputs (:obj:`torch.Tensor`): The forward data of models.
    - data_samples (:obj:`~mmpretrain.structures.DataSample`): The
      annotation info of the sample.

    Args:
        input_key (str): The key of element to feed into the model forwarding.
            Defaults to 'img'.
        algorithm_keys (Sequence[str]): The keys of custom elements to be used
            in the algorithm. Defaults to an empty tuple.
        meta_keys (Sequence[str]): The keys of meta information to be saved in
            the data sample. Defaults to :attr:`PackInputs.DEFAULT_META_KEYS`.

    .. admonition:: Default algorithm keys

        Besides the specified ``algorithm_keys``, we will set some default keys
        into the output data sample and do some formatting. Therefore, you
        don't need to set these keys in the ``algorithm_keys``.

        - ``gt_label``: The ground-truth label. The value will be converted
          into a 1-D tensor.
        - ``gt_score``: The ground-truth score. The value will be converted
          into a 1-D tensor.
        - ``mask``: The mask for some self-supervise tasks. The value will
          be converted into a tensor.

    .. admonition:: Default meta keys

        - ``sample_idx``: The id of the image sample.
        - ``img_path``: The path to the image file.
        - ``ori_shape``: The original shape of the image as a tuple (H, W).
        - ``img_shape``: The shape of the image after the pipeline as a
          tuple (H, W).
        - ``scale_factor``: The scale factor between the resized image and
          the original image.
        - ``flip``: A boolean indicating if image flip transform was used.
        - ``flip_direction``: The flipping direction.
    """

    DEFAULT_META_KEYS = ('sample_idx', 'img_path', 'ori_shape', 'img_shape',
                         'scale_factor', 'flip', 'flip_direction')

    def __init__(self,
                 input_key='img',
                 algorithm_keys=(),
                 meta_keys=DEFAULT_META_KEYS):
        self.input_key = input_key
        self.algorithm_keys = algorithm_keys
        self.meta_keys = meta_keys

    # @staticmethod is required: this helper is invoked both as
    # ``self.format_input(...)`` in :meth:`transform` and as
    # ``PackInputs.format_input(...)`` from :class:`PackMultiTaskInputs`.
    # Without the decorator the bound call would pass the instance itself as
    # ``input_`` and drop the real input.
    @staticmethod
    def format_input(input_):
        """Convert an image-like input (ndarray, PIL image, tensor, or a
        list of them) into a channel-first :obj:`torch.Tensor`."""
        if isinstance(input_, list):
            return [PackInputs.format_input(item) for item in input_]
        elif isinstance(input_, np.ndarray):
            if input_.ndim == 2:  # For grayscale image.
                input_ = np.expand_dims(input_, -1)
            if input_.ndim == 3 and not input_.flags.c_contiguous:
                input_ = np.ascontiguousarray(input_.transpose(2, 0, 1))
                input_ = to_tensor(input_)
            elif input_.ndim == 3:
                # convert to tensor first to accelerate, see
                # https://github.com/open-mmlab/mmdetection/pull/9533
                input_ = to_tensor(input_).permute(2, 0, 1).contiguous()
            else:
                # convert input with other shape to tensor without permute,
                # like video input (num_crops, C, T, H, W).
                input_ = to_tensor(input_)
        elif isinstance(input_, Image.Image):
            input_ = F.pil_to_tensor(input_)
        elif not isinstance(input_, torch.Tensor):
            raise TypeError(f'Unsupported input type {type(input_)}.')

        return input_

    def transform(self, results: dict) -> dict:
        """Method to pack the input data."""
        packed_results = dict()
        if self.input_key in results:
            input_ = results[self.input_key]
            packed_results['inputs'] = self.format_input(input_)

        data_sample = DataSample()

        # Set default keys
        if 'gt_label' in results:
            data_sample.set_gt_label(results['gt_label'])
        if 'gt_score' in results:
            data_sample.set_gt_score(results['gt_score'])
        if 'mask' in results:
            data_sample.set_mask(results['mask'])

        # Set custom algorithm keys
        for key in self.algorithm_keys:
            if key in results:
                data_sample.set_field(results[key], key)

        # Set meta keys
        for key in self.meta_keys:
            if key in results:
                data_sample.set_field(results[key], key, field_type='metainfo')

        packed_results['data_samples'] = data_sample
        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f"(input_key='{self.input_key}', "
        repr_str += f'algorithm_keys={self.algorithm_keys}, '
        repr_str += f'meta_keys={self.meta_keys})'
        return repr_str
class PackMultiTaskInputs(BaseTransform):
    """Convert all image labels of multi-task dataset to a dict of tensor.

    Args:
        multi_task_fields (Sequence[str]): Fields (e.g. ``'gt_label'``)
            whose values are dicts mapping a task name to that task's value;
            each such field is split out per task before packing.
        input_key (str): The key of the element to feed into the model
            forwarding. Defaults to 'img'.
        task_handlers (dict, optional): Mapping from task name to the
            transform config used to pack that task's results. Tasks without
            an explicit handler fall back to a default :class:`PackInputs`.
            Defaults to None (no explicit handlers).
    """

    def __init__(self,
                 multi_task_fields,
                 input_key='img',
                 task_handlers=None):
        # ``None`` instead of a mutable ``dict()`` default (shared-state
        # pitfall); passing a dict behaves exactly as before.
        self.multi_task_fields = multi_task_fields
        self.input_key = input_key
        # Unknown tasks get a freshly-constructed default PackInputs handler.
        self.task_handlers = defaultdict(PackInputs)
        for task_name, task_handler in (task_handlers or {}).items():
            self.task_handlers[task_name] = TRANSFORMS.build(task_handler)

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Example of an input ``results``::

            result = {'img_path': 'a.png', 'gt_label': {'task1': 1, 'task3': 3},
                'img': array([[[ 0, 0, 0])
        """
        packed_results = dict()
        results = results.copy()

        if self.input_key in results:
            input_ = results[self.input_key]
            packed_results['inputs'] = PackInputs.format_input(input_)

        # Regroup: {field: {task: value}} -> {task: {field: value}}.
        task_results = defaultdict(dict)
        for field in self.multi_task_fields:
            if field in results:
                value = results.pop(field)
                for k, v in value.items():
                    task_results[k].update({field: v})

        data_sample = MultiTaskDataSample()
        for task_name, task_result in task_results.items():
            task_handler = self.task_handlers[task_name]
            # Each task handler sees the shared (non-task) keys plus its own.
            task_pack_result = task_handler({**results, **task_result})
            data_sample.set_field(task_pack_result['data_samples'], task_name)

        packed_results['data_samples'] = data_sample
        return packed_results

    def __repr__(self):
        # Named ``repr_str`` to avoid shadowing the ``repr`` builtin.
        repr_str = self.__class__.__name__
        task_handlers = ', '.join(
            f"'{name}': {handler.__class__.__name__}"
            for name, handler in self.task_handlers.items())
        repr_str += f'(multi_task_fields={self.multi_task_fields}, '
        repr_str += f"input_key='{self.input_key}', "
        repr_str += f'task_handlers={{{task_handlers}}})'
        return repr_str
class Transpose(BaseTransform):
    """Transpose numpy array.

    **Required Keys:**

    - ``*keys``

    **Modified Keys:**

    - ``*keys``

    Args:
        keys (List[str]): The fields to convert to tensor.
        order (List[int]): The output dimensions order.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def transform(self, results):
        """Transpose the array stored under each configured key in place."""
        for field in self.keys:
            array = results[field]
            results[field] = array.transpose(self.order)
        return results

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(keys={self.keys}, order={self.order})'
class NumpyToPIL(BaseTransform):
    """Convert the image from OpenCV format to :obj:`PIL.Image.Image`.

    **Required Keys:**

    - ``img``

    **Modified Keys:**

    - ``img``

    Args:
        to_rgb (bool): Whether to convert img to rgb. Defaults to False.
    """

    def __init__(self, to_rgb: bool = False) -> None:
        self.to_rgb = to_rgb

    def transform(self, results: dict) -> dict:
        """Method to convert images to :obj:`PIL.Image.Image`."""
        img = results['img']
        # OpenCV images are BGR; optionally swap channels before wrapping.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img

        results['img'] = Image.fromarray(img)
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(to_rgb={self.to_rgb})'
class PILToNumpy(BaseTransform):
    """Convert img to :obj:`numpy.ndarray`.

    **Required Keys:**

    - ``img``

    **Modified Keys:**

    - ``img``

    Args:
        to_bgr (bool): Whether to convert img to bgr. Defaults to False.
        dtype (str, optional): The dtype of the converted numpy array.
            Defaults to None.
    """

    def __init__(self, to_bgr: bool = False, dtype=None) -> None:
        self.to_bgr = to_bgr
        self.dtype = dtype

    def transform(self, results: dict) -> dict:
        """Method to convert img to :obj:`numpy.ndarray`."""
        img = np.array(results['img'], dtype=self.dtype)
        # PIL images are RGB; optionally swap channels to OpenCV's BGR order.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img

        results['img'] = img
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + \
            f'(to_bgr={self.to_bgr}, dtype={self.dtype})'
class Collect(BaseTransform):
    """Collect and only reserve the specified fields.

    **Required Keys:**

    - ``*keys``

    **Deleted Keys:**

    All keys except those in the argument ``*keys``.

    Args:
        keys (Sequence[str]): The keys of the fields to be collected.
    """

    def __init__(self, keys):
        self.keys = keys

    def transform(self, results):
        """Build a new dict containing only the configured keys."""
        return {key: results[key] for key in self.keys}

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'