|
|
| import os |
| import gradio as gr |
|
|
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from loguru import logger |
| from tqdm import tqdm |
| from tools.common_utils import save_video |
| from dkt.pipelines.pipeline import DKTPipeline, ModelConfig |
|
|
|
|
| import cv2 |
| import copy |
| import trimesh |
|
|
| from os.path import join |
| from tools.depth2pcd import depth2pcd |
| |
|
|
|
|
| from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map |
| import datetime |
| import tempfile |
| import time |
|
|
|
|
| |
|
|
# Inference configuration shared by every request.
NEGATIVE_PROMPT = ''  # no negative prompt is used for depth prediction
height = 480  # frame height fed to the pipeline
width = 832  # frame width fed to the pipeline
window_size = 21  # sliding-window length (declared global in process_video but not read there)

# Instantiated once at import time so weights are loaded before the first
# request — a heavyweight side effect of importing this module.
DKT_PIPELINE_14B = DKTPipeline(is14B=True)
| |
|
|
# Sample videos shown in the Examples gallery (paths relative to the app root).
example_inputs = [
    "examples/1.mp4",
    "examples/7.mp4",
    "examples/8.mp4",
    "examples/39.mp4",
    "examples/10.mp4",
    "examples/30.mp4",
    "examples/35.mp4",
    "examples/40.mp4",
    "examples/2.mp4",
    "examples/4.mp4",
    "examples/episode_48-camera_head.mp4",
    "examples/input_20251128_121408.mp4",
    "examples/input_20251128_122722.mp4",
    "examples/5eaeaff52b23787a3dc3c610655a49d2.mp4",
    "examples/9f2909760aff526070f169620ff38290.mp4",
    "examples/18.mp4",
    "examples/28.mp4",
    "examples/73fc0b2a3af3474de27c7da0bfbf5faa.mp4",
    "examples/episode_48-camera_third_view.mp4",
    "examples/extra_5.mp4",
    "examples/extra_9.mp4",
    "examples/IMG_5703.MOV",
    "examples/input_20251202_031811.mp4",
    "examples/input_20251202_032007.mp4",
    "examples/teaser_1.mp4",
    "examples/3.mp4",
    "examples/teaser_3.mp4",
    "examples/teaser_7.mp4",
    "examples/teaser_25.mp4",
]
|
|
|
|
|
|
|
|
|
|
def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
    """Pack a masked point map and its per-point colors into a trimesh scene.

    The x and y coordinates are negated — presumably to convert from the
    predictor's camera convention to the GLB viewer's; confirm against the
    viewer if the cloud appears mirrored.
    """
    # Keep only the valid points and flip the x/y axes.
    flip = np.array([-1, -1, 1])
    vertices = point_map[valid_mask] * flip
    point_colors = frame[valid_mask]

    cloud = trimesh.PointCloud(vertices=vertices, colors=point_colors)

    scene = trimesh.Scene()
    scene.add_geometry(cloud)
    return scene
|
|
|
|
|
|
def create_simple_glb_from_pointcloud(points, colors, glb_filename):
    """Export a point cloud to a GLB file.

    Args:
        points: (N, 3) array-like of xyz coordinates.
        colors: (N, 3) array-like of per-point colors, or None for white.
        glb_filename: destination path of the exported GLB.

    Returns:
        True on success; False when there are no points or the export fails.
    """
    try:
        if len(points) == 0:
            logger.warning(f"No valid points to create GLB for {glb_filename}")
            return False

        if colors is None:
            logger.info("No colors provided, adding default white colors")
            point_colors = np.ones((len(points), 3))
        else:
            point_colors = colors

        # All points are kept; the mask only satisfies pmap_to_glb's signature.
        keep_all = np.ones(len(points), dtype=bool)
        pmap_to_glb(points, keep_all, point_colors).export(glb_filename)
        return True

    except Exception as e:
        logger.error(f"Error creating GLB from pointcloud using trimesh: {str(e)}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
def process_video(
    video_file,
    model_size,
    num_inference_steps,
    overlap
):
    """Run the depth pipeline on a video and save the colorized depth video.

    Args:
        video_file: path to the input video.
        model_size: "14B" or "1.3B"; selects which preloaded pipeline to use.
        num_inference_steps: diffusion denoising steps.
        overlap: number of overlapping frames between consecutive windows.

    Returns:
        Path to the saved colorized depth video (inside a fresh temp dir).

    Raises:
        ValueError: if model_size is not a supported size.
        RuntimeError: if the requested pipeline is not loaded.
    """
    global height
    global width
    global window_size

    global DKT_PIPELINE_14B
    global DKT_PIPELINE

    if model_size == "14B":
        logger.info(f'14B model is chosen')
        pipeline = DKT_PIPELINE_14B
    elif model_size == "1.3B":
        logger.info(f'1.3B model is chosen')
        # DKT_PIPELINE is never instantiated at module load in this build;
        # fail with a clear message instead of an opaque NameError.
        if 'DKT_PIPELINE' not in globals():
            raise RuntimeError("1.3B pipeline is not loaded in this deployment")
        pipeline = DKT_PIPELINE
    else:
        raise ValueError(f"Invalid model size: {model_size}")

    # Unique scratch directory per request; tempfile.mkdtemp avoids name
    # collisions between concurrent requests sharing a timestamp.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')

    start_time = time.time()

    print(f"[1] Starting pipeline...")
    try:
        prediction_result = pipeline(
            video_file,
            negative_prompt=NEGATIVE_PROMPT,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            overlap=overlap,
            return_rgb=True,
            get_moge_intrinsics=False
        )
        print(f"[2] Pipeline done, keys: {prediction_result.keys()}")
    except Exception as e:
        print(f"[ERROR] Pipeline failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise

    end_time = time.time()
    spend_time = end_time - start_time
    print(f"[3] Pipeline time: {spend_time:.2f}s")
    logger.info(f"pipeline spend time: {spend_time:.2f} seconds for depth prediction")

    output_filename = f"output_{timestamp}.mp4"
    output_path = os.path.join(cur_save_dir, output_filename)

    # Mirror the source video's frame rate in the output.
    cap = cv2.VideoCapture(video_file)
    input_fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    # cv2 returns 0 (or NaN) when the fps cannot be read; fall back to a
    # sane default so save_video is not handed a zero frame rate.
    if not input_fps or not input_fps > 0:
        logger.warning(f"Could not read FPS from {video_file}; defaulting to 24")
        input_fps = 24.0

    print(f"[4] Saving video, fps={input_fps}")
    save_video(prediction_result['colored_depth_map'], output_path, fps=input_fps, quality=8)
    print(f"[5] Video saved: {output_path}")
    return output_path
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|
|
|
|
|
|
|
| |
|
|
|
|
# Custom CSS: fixed download height, slider styling, a 4:3 viewport,
# highlighted selected tabs, and centered h1/h2/h3 headings.
css = """
#download {
    height: 118px;
}
.slider .inner {
    width: 5px;
    background: #FFF;
}
.viewport {
    aspect-ratio: 4/3;
}
.tabs button.selected {
    font-size: 20px !important;
    color: crimson !important;
}
h1 {
    text-align: center;
    display: block;
}
h2 {
    text-align: center;
    display: block;
}
h3 {
    text-align: center;
    display: block;
}
.md_feedback li {
    margin-bottom: 0px !important;
}
"""
|
|
|
|
|
|
# Google Analytics (gtag.js) snippet injected into the page <head>.
head_html = """
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-1FWSVCGZTG');
</script>
"""
|
|
|
|
|
|
|
|
# --- Gradio UI definition -------------------------------------------------
with gr.Blocks(css=css, title="DKT", head=head_html) as demo:

    # Title and badge links (website / GitHub / social).
    gr.Markdown(
        """
# Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation
<p align="center">

<a title="Website" href="https://daniellli.github.io/projects/DKT/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
</a>
<a title="Github" href="https://github.com/Daniellli/DKT" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/Daniellli/DKT?style=social" alt="badge-github-stars">
</a>
<a title="Social" href="https://x.com/xshocng1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
"""
    )

    with gr.Row():
        with gr.Column():
            # Left column: input video and inference controls.
            input_video = gr.Video(label="Input Video", elem_id='video-display-input')

            # Only the 14B checkpoint is exposed in this deployment.
            model_size = gr.Radio(
                choices=["14B"],
                value="14B",
                label="Model Size"
            )

            with gr.Accordion("Advanced Parameters", open=False):
                num_inference_steps = gr.Slider(
                    minimum=1, maximum=50, value=5, step=1,
                    label="Number of Inference Steps"
                )
                overlap = gr.Slider(
                    minimum=1, maximum=20, value=3, step=1,
                    label="Overlap"
                )

            submit = gr.Button(value="Compute Depth", variant="primary")

        with gr.Column():
            # Right column: colorized depth output. vis_video is a hidden
            # placeholder kept so callbacks can return two outputs.
            output_video = gr.Video(
                label="Depth Outputs",
                elem_id='video-display-output',
                autoplay=True
            )
            vis_video = gr.Video(
                label="Visualization Video",
                visible=False,
                autoplay=True
            )
|
| def on_submit(video_file, model_size, num_inference_steps, overlap): |
| logger.info('on_submit is calling') |
| if video_file is None: |
| return None, None |
| |
| try: |
| |
| start_time = time.time() |
| output_path = process_video( |
| video_file, model_size, num_inference_steps, overlap |
| ) |
| spend_time = time.time() - start_time |
| logger.info(f"Total spend time in on_submit: {spend_time:.2f} seconds") |
| print(f"Total spend time in on_submit: {spend_time:.2f} seconds") |
|
|
| |
| if output_path is None: |
| return None, None |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| return output_path, None |
| |
| except Exception as e: |
| logger.error(e) |
| return None, None |
|
|
| |
    # Wire the button to the callback; vis_video stays hidden but must be
    # listed because on_submit returns two values.
    submit.click(
        on_submit,
        inputs=[
            input_video, model_size, num_inference_steps, overlap
        ],
        outputs=[
            output_video, vis_video
        ]
    )
| |
|
|
| |
    def on_example_submit(video_file):
        """Wrapper function for examples with default parameters"""
        # Defaults: 14B model, 5 inference steps, overlap of 3.
        return on_submit(video_file, "14B", 5, 3)

    # Clickable example gallery; results are recomputed on click
    # (cache_examples=False) rather than served from a cache.
    examples = gr.Examples(
        examples=example_inputs,
        inputs=[input_video],
        outputs=[
            output_video, vis_video
        ],
        fn=on_example_submit,
        examples_per_page=36,
        cache_examples=False
    )
|
|
|
|
if __name__ == '__main__':
    # Enable Gradio's request queue (long-running inference jobs are
    # serialized) and start the web server.
    demo.queue().launch()
| |