import argparse
import json
import os
import subprocess
from glob import glob
from pathlib import Path
from typing import TypedDict

import numpy as np
import torch
from jaxtyping import Float, Int, UInt8
from PIL import Image
from torch import Tensor
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument(
    "--input_base_dir",
    type=str,
    help="base directory containing the 1K, 2K, ..., 11K subdirectories",
)
parser.add_argument(
    "--output_base_dir",
    type=str,
    help="base output directory for processed datasets",
)
parser.add_argument(
    "--img_subdir",
    type=str,
    default="images_4",
    choices=["images_4", "images_8"],
    help="image directory name",
)
parser.add_argument(
    "--n_test", type=int, default=10, help="hold out every n-th scene for testing"
)
parser.add_argument("--which_stage", type=str, default=None, help="dataset stage to process")
parser.add_argument("--detect_overlap", action="store_true")
parser.add_argument("--start_k", type=int, default=1, help="starting K value (default: 1)")
parser.add_argument("--end_k", type=int, default=11, help="ending K value (default: 11)")
args = parser.parse_args()

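# Each saved .torch chunk targets roughly 200 MB (2e8 bytes) of raw image data.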
TARGET_BYTES_PER_CHUNK = int(2e8)


def get_size(path: Path) -> int:
    """Get file or folder size in bytes."""
    return int(subprocess.check_output(["du", "-b", path]).split()[0].decode("utf-8"))


def load_raw(path: Path) -> UInt8[Tensor, " length"]:
    """Read a file's raw bytes into a uint8 tensor without decoding it."""
    return torch.tensor(np.memmap(path, dtype="uint8", mode="r"))


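# Images are kept as raw, undecoded JPG bytes keyed by their frame index.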
def load_images(example_path: Path) -> dict[int, UInt8[Tensor, "..."]]:
    """Load JPG images as raw bytes (do not decode)."""
    return {
        int(path.stem.split("_")[-1]): load_raw(path)
        for path in example_path.iterdir()
        if path.suffix.lower() not in [".npz"]
    }


class Metadata(TypedDict):
    url: str
    timestamps: Int[Tensor, " camera"]
    cameras: Float[Tensor, "camera entry"]


class Example(Metadata):
    key: str
    images: list[UInt8[Tensor, "..."]]


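# Each camera row is [fx, fy, cx, cy, 0, 0] followed by the first three rows of
# the OpenCV world-to-camera matrix, with intrinsics normalized by image size.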
def load_metadata(example_path: Path) -> Metadata:
    blender2opencv = np.array(
        [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
    )
    url = str(example_path).split("/")[-3]
    with open(example_path, "r") as f:
        meta_data = json.load(f)

    store_h, store_w = meta_data["h"], meta_data["w"]
    fx, fy, cx, cy = (
        meta_data["fl_x"],
        meta_data["fl_y"],
        meta_data["cx"],
        meta_data["cy"],
    )
    # Normalize the intrinsics by the stored image resolution.
    saved_fx = float(fx) / float(store_w)
    saved_fy = float(fy) / float(store_h)
    saved_cx = float(cx) / float(store_w)
    saved_cy = float(cy) / float(store_h)

    timestamps = []
    cameras = []
    opencv_c2ws = []

    for frame in meta_data["frames"]:
        timestamps.append(
            int(os.path.basename(frame["file_path"]).split(".")[0].split("_")[-1])
        )
        camera = [saved_fx, saved_fy, saved_cx, saved_cy, 0.0, 0.0]

        # Convert the Blender camera-to-world matrix to the OpenCV convention,
        # then append the first three rows of its inverse (world-to-camera).
        opencv_c2w = np.array(frame["transform_matrix"]) @ blender2opencv
        opencv_c2ws.append(opencv_c2w)
        camera.extend(np.linalg.inv(opencv_c2w)[:3].flatten().tolist())
        cameras.append(np.array(camera))

    timestamps = torch.tensor(timestamps, dtype=torch.int64)
    cameras = torch.tensor(np.stack(cameras), dtype=torch.float32)

    return {"url": url, "timestamps": timestamps, "cameras": cameras}


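# Every n_test-th scene is held out for testing; the rest form the training split.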
def partition_train_test_splits(root_dir, n_test=10):
    sub_folders = sorted(glob(os.path.join(root_dir, "*/")))
    test_list = sub_folders[::n_test]
    train_list = [x for x in sub_folders if x not in test_list]
    return {"train": train_list, "test": test_list}


def is_image_shape_matched(image_dir, target_shape):
    """Check whether the first image in image_dir has the expected (height, width)."""
    image_paths = sorted(glob(str(image_dir / "*")))
    if len(image_paths) == 0:
        return False

    try:
        im = Image.open(image_paths[0])
    except Exception:
        return False
    w, h = im.size
    if (h, w) == target_shape:
        return True
    print("image shape:", h, w)
    return False


def legal_check_for_all_scenes(root_dir, target_shape):
    """Return the scene folders whose images match target_shape and that have a pose file."""
    valid_folders = []
    sub_folders = sorted(glob(os.path.join(root_dir, "*")))
    for sub_folder in tqdm(sub_folders, desc="checking scenes..."):
        img_dir = os.path.join(sub_folder, args.img_subdir)
        if not is_image_shape_matched(Path(img_dir), target_shape):
            print(f"image shape does not match for {sub_folder}")
            continue
        pose_file = os.path.join(sub_folder, "transforms.json")
        if not os.path.isfile(pose_file):
            print(f"cannot find pose file for {sub_folder}")
            continue

        valid_folders.append(sub_folder)

    return valid_folders


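# Pack all valid (non-benchmark) scenes of one K directory into roughly 200 MB .torch chunks.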
def process_single_directory(input_dir: Path, output_dir: Path):
    """Process a single K directory."""
    print(f"\n=== Processing {input_dir.name} ===")

    INPUT_DIR = input_dir
    OUTPUT_DIR = output_dir

    # Expected (height, width) for the chosen downsampled image set.
    if "images_8" in args.img_subdir:
        target_shape = (270, 480)
    elif "images_4" in args.img_subdir:
        target_shape = (540, 960)
    else:
        raise ValueError(f"unsupported img_subdir: {args.img_subdir}")

    print("checking all scenes...")
    valid_scenes = legal_check_for_all_scenes(INPUT_DIR, target_shape)
    print("valid scenes:", len(valid_scenes))

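    # Load the benchmark index; any scene listed there is excluded from the training chunks.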
    test_scenes = "/scratch/azureml/cr/j/e8e7ca980a5641daa86426c3fa644c10/exe/wd/dl3dv_benchmark/index.json"
    with open(test_scenes, "r") as f:
        overlap_scenes = json.load(f)

    assert len(overlap_scenes) == 140, "test scenes should contain 140 scenes"

    for stage in ["train"]:
        error_logs = []
        image_dirs = valid_scenes

        chunk_size = 0
        chunk_index = 0
        chunk: list[Example] = []

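        # Flush the accumulated examples to OUTPUT_DIR/<stage>/<chunk_index>.torch and reset the buffer.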
        def save_chunk():
            nonlocal chunk_size, chunk_index, chunk

            chunk_key = f"{chunk_index:0>6}"
            stage_dir = OUTPUT_DIR / stage
            stage_dir.mkdir(exist_ok=True, parents=True)
            torch.save(chunk, stage_dir / f"{chunk_key}.torch")

            chunk_size = 0
            chunk_index += 1
            chunk = []

        for image_dir in tqdm(image_dirs, desc=f"Processing {stage}"):
            key = os.path.basename(image_dir.strip("/"))

            if key in overlap_scenes:
                print(f"scene {key} in benchmark, skip.")
                continue

            image_dir = Path(image_dir) / args.img_subdir

            if not image_dir.exists():
                print(f"Image directory does not exist for {key}, skipping...")
                continue

            num_bytes = get_size(image_dir)

            try:
                images = load_images(image_dir)
            except Exception:
                print(f"image loading error in {key}, skip.")
                continue

            meta_path = image_dir.parent / "transforms.json"
            if not meta_path.is_file():
                error_msg = f"---------> [ERROR] no meta file in {key}, skip."
                print(error_msg)
                error_logs.append(error_msg)
                continue

            example = load_metadata(meta_path)

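            # Order the raw image bytes to match the timestamps parsed from transforms.json.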
            try:
                example["images"] = [
                    images[timestamp.item()] for timestamp in example["timestamps"]
                ]
            except KeyError:
                error_msg = f"---------> [ERROR] Some images missing in {key}, skip."
                print(error_msg)
                error_logs.append(error_msg)
                continue

            example["key"] = "dl3dv_" + key

            chunk.append(example)
            chunk_size += num_bytes

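            # Flush once the accumulated raw bytes reach the per-chunk target.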
            if chunk_size >= TARGET_BYTES_PER_CHUNK:
                save_chunk()

        # Save any remaining examples that did not fill a full chunk.
        if chunk_size > 0:
            save_chunk()


if __name__ == "__main__":
    base_input_dir = Path(args.input_base_dir)
    base_output_dir = Path(args.output_base_dir)

    total_dirs = args.end_k - args.start_k + 1
    processed_dirs = 0

    for k in range(args.start_k, args.end_k + 1):
        k_dir = f"{k}K"
        input_dir = base_input_dir / k_dir
        output_dir = base_output_dir / k_dir

        if not input_dir.exists():
            print(f"Warning: Input directory {input_dir} does not exist, skipping...")
            continue

        print(f"\n{'='*50}")
        print(f"Processing directory {k_dir} ({processed_dirs + 1}/{total_dirs})")
        print(f"Input: {input_dir}")
        print(f"Output: {output_dir}")
        print(f"{'='*50}")

        process_single_directory(input_dir, output_dir)

        processed_dirs += 1
        print(f"\nCompleted {k_dir} ({processed_dirs}/{total_dirs})")

    print(f"\n{'='*50}")
    print(f"All processing complete! Processed {processed_dirs}/{total_dirs} directories.")
    print(f"{'='*50}")