| |
| |
| |
| |
| |
|
|
| from typing import Tuple, Union, Iterable |
| from omegaconf import OmegaConf |
| import os |
|
|
| import torch |
| import torch.distributed as dist |
|
|
def dist_all_gather(x, *, dim=0, group=None):
    """Gather ``x`` from every rank and concatenate the pieces.

    Requires ``torch.distributed`` to be initialised before calling.

    Args:
        x: This rank's tensor. ``dist.all_gather`` needs every rank to
            contribute a tensor with the same shape and dtype. The receive
            buffers here are built with ``torch.zeros_like(x)``, so no padding
            or shape reconciliation is done.
        dim: Dimension along which the gathered tensors are concatenated.
            Defaults to 0, the previously hard-coded behaviour.
        group: Process group to gather over. ``None`` uses the default group.

    Returns:
        The gathered tensors concatenated along ``dim``, in the order
        ``dist.all_gather`` fills ``tensor_list``. That order is rank order.

    NOTE(review): ``dist.all_gather`` writes into freshly allocated buffers;
    gradients are presumably not propagated back through it (see
    ``torch.distributed.nn.functional.all_gather``) -- confirm if needed.
    """
    # Only forward ``group`` when one is given, so the default path issues
    # exactly the same calls as before.
    group_kwargs = {} if group is None else {"group": group}
    world_size = dist.get_world_size(**group_kwargs)
    tensor_list = [torch.zeros_like(x) for _ in range(world_size)]
    dist.all_gather(tensor_list, x, **group_kwargs)
    return torch.cat(tensor_list, dim=dim)
|
|
def any_2tuple(data: Union[int, Iterable[int]]) -> Tuple[int, int]:
    """Normalise a size specification to a 2-tuple.

    Args:
        data: Either a single integer, which is duplicated, or an iterable
            of exactly two items, returned as a tuple. The error messages
            describe the pair as ``(w, h)``.

    Returns:
        ``(data, data)`` for an integer input; otherwise ``tuple(data)``.

    Raises:
        ValueError: If ``data`` is an iterable that does not yield exactly
            two items, or is neither an integer nor a non-string iterable.
    """
    import numbers  # local import keeps the file-level imports unchanged

    # ``numbers.Integral`` also accepts numpy/torch integer scalars, not just
    # builtin ``int``. ``bool`` still qualifies, as it did before.
    if isinstance(data, numbers.Integral):
        return (data, data)
    # Strings are iterable but never a meaningful size, so reject them.
    if isinstance(data, Iterable) and not isinstance(data, (str, bytes)):
        # Materialise first so one-shot iterables (generators) are supported
        # and only consumed once.
        items = tuple(data)
        # Raise rather than assert, so validation survives ``python -O``.
        if len(items) != 2:
            raise ValueError("target size must be tuple of (w, h)")
        return items
    raise ValueError("target size must be int or tuple of (w, h)")
|
|