# =============================================================================
# ZeroGPU Compatibility Layer
#
# This app runs on HuggingFace ZeroGPU (sdk:gradio). Key environment facts:
#   - GPU: NVIDIA H200 (Hopper, compute capability 9.0)
#   - torch.version.cuda: 12.8 (NOT the system CUDA which is 13.0)
#   - Python: 3.10
#   - The Dockerfile in this repo is IGNORED by ZeroGPU
#
# CUDA extensions (diff_gaussian_rasterization, nvdiffrast) must be cu128 wheels
# compiled against torch 2.8.0. They come from:
#   https://github.com/MiroPsota/torch_packages_builder
#
# The MiroPsota wheels use the standard graphdeco-inria rasterizer API, which
# differs from the original TRELLIS fork:
#   - Requires 'antialiasing' field in GaussianRasterizationSettings
#   - Returns 3 values (color, radii, invdepths) instead of 2
#   - Does NOT accept 'kernel_size' or 'subpixel_offset' params
# See gaussian_render.py for the adapted rendering code.
#
# If ZeroGPU updates torch/CUDA, check torch.version.cuda in runtime logs
# and download matching cuXXX wheels from MiroPsota's releases.
# =============================================================================

import os

# PYTORCH_NVML_BASED_CUDA_CHECK=0: ZeroGPU virtualizes the GPU — NVML queries
# can fail because the management API isn't fully proxied. This tells PyTorch
# to skip NVML-based device checks and use the CUDA runtime API instead.
os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '0'

# expandable_segments: Prevents the CUDA caching allocator from failing on
# memory allocation patterns that don't fit its default segment strategy.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D

import shutil
import subprocess
import sys


def install_local_wheels():
    """Install cu128 wheels from the wheels/ directory at runtime.

    Why at runtime (not in requirements.txt):
      - The wheels must be installed with --force-reinstall to override any
        versions that ZeroGPU's build phase may have cached.
      - They need --no-deps to avoid pulling in a different torch version.
      - requirements.txt can't express these pip flags.

    Wheels are installed in sorted filename order. A failed install is logged
    and the remaining wheels are still attempted. Does nothing when the
    wheels/ directory does not exist.
    """
    wheels_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'wheels')
    # isdir (not exists): listdir() below would raise if 'wheels' were a file.
    if not os.path.isdir(wheels_dir):
        return
    for wheel_file in sorted(os.listdir(wheels_dir)):
        if not wheel_file.endswith('.whl'):
            continue
        wheel_path = os.path.join(wheels_dir, wheel_file)
        try:
            subprocess.check_call([
                sys.executable, '-m', 'pip', 'install',
                '--force-reinstall', '--no-deps', wheel_path,
            ])
        except subprocess.CalledProcessError as e:
            print(f"[wheels] Failed to install {wheel_file}: {e}")
        else:
            print(f"[wheels] Installed {wheel_file}")


# Install cu128 wheels before importing trellis (trellis imports diff_gaussian_rasterization)
install_local_wheels()


# Patch gradio_client bug: bool schemas (True/False) cause crashes in json_schema_to_python_type.
# additionalProperties can be True or False in JSON Schema, but this gradio_client version
# only handles dict schemas. We guard both affected functions at the module level so all
# recursive calls are also intercepted.
def _patch_gradio_client_bug():
    """Wrap gradio_client schema helpers so non-dict schemas map to "Any".

    Best-effort: any failure while patching is logged and otherwise ignored.
    """
    try:
        import gradio_client.utils as _gcu

        # Patch get_type (line 863 crash: "argument of type 'bool' is not iterable")
        _orig_get_type = _gcu.get_type

        def _safe_get_type(schema):
            if not isinstance(schema, dict):
                return "Any"
            return _orig_get_type(schema)

        _gcu.get_type = _safe_get_type

        # Patch _json_schema_to_python_type (line 967 crash: APIInfoParseError for bool schema)
        # Recursive calls inside the original function use module-level name lookup,
        # so patching the module attribute intercepts all recursion too.
        _orig_schema_to_type = _gcu._json_schema_to_python_type

        def _safe_schema_to_type(schema, defs=None):
            if not isinstance(schema, dict):
                return "Any"
            return _orig_schema_to_type(schema, defs)

        _gcu._json_schema_to_python_type = _safe_schema_to_type
        print("[patch] gradio_client bool schema bug patched successfully")
    except Exception as e:
        print(f"[patch] Could not patch gradio_client: {e}")


_patch_gradio_client_bug()

os.environ['SPCONV_ALGO'] = 'native'

from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
from utils_birefnet import BiRefNet as BiRefNetModel
import aspose.threed as a3d


MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# Lazy-loaded BiRefNet instance (initialized on first use)
_birefnet_instance = None


def _get_birefnet():
    """Return the shared BiRefNet model, creating it on CUDA on first call."""
    global _birefnet_instance
    if _birefnet_instance is None:
        print("[LAZY LOAD] Initializing BiRefNet...", flush=True)
        _birefnet_instance = BiRefNetModel()
        _birefnet_instance.cuda()
    return _birefnet_instance


def start_session(req: gr.Request):
    """Create the per-session working directory under TMP_DIR."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    """Delete the per-session working directory (best-effort cleanup)."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # ignore_errors: the directory may be missing or partially removed already;
    # a cleanup hook on page unload should never raise.
    shutil.rmtree(user_dir, ignore_errors=True)


@spaces.GPU(duration=20)
def _remove_bg_birefnet(image: Image.Image) -> Image.Image:
    """
    Remove background using BiRefNet (GPU).
    Returns RGBA image.
    """
    return _get_birefnet()(image)


def preprocess_image(image: Image.Image) -> Tuple[Image.Image, Image.Image]:
    """
    Preprocess the input image for the 3D pipeline.

    If the image is RGBA with at least one non-opaque pixel, its alpha channel
    is used as the mask. Otherwise the background is removed with BiRefNet. The
    result is cropped around the foreground (alpha > 80%) with a 1.2x margin,
    resized to 518x518, and composited onto black (RGB * alpha).

    Args:
        image (Image.Image): The input image.

    Returns:
        Tuple[Image.Image, Image.Image]: The preprocessed image, twice — once
        for the image prompt and once for the debug preview output.

    Raises:
        gr.Error: If no foreground pixels are found in the alpha mask.
    """
    # If the image already has a real alpha channel, use it directly
    has_alpha = False
    if image.mode == 'RGBA':
        alpha = np.array(image)[:, :, 3]
        has_alpha = not np.all(alpha == 255)
    if not has_alpha:
        image = _remove_bg_birefnet(image)

    # Crop and resize (same logic as trellis pipeline)
    output_np = np.array(image)
    alpha = output_np[:, :, 3]
    foreground = np.argwhere(alpha > 0.8 * 255)
    if foreground.size == 0:
        # Without this guard np.min() below fails with an opaque ValueError.
        raise gr.Error("No foreground object was found in the image. Please try a different image.")
    bbox = np.min(foreground[:, 1]), np.min(foreground[:, 0]), np.max(foreground[:, 1]), np.max(foreground[:, 0])
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    size = int(size * 1.2)
    bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
    output = image.crop(bbox)
    output = output.resize((518, 518), Image.Resampling.LANCZOS)
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))
    return output, output


def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a list of gallery images (background removal, crop, resize).

    Args:
        images (List[Tuple[Image.Image, str]]): Gallery items as (image, caption).

    Returns:
        List[Image.Image]: The preprocessed images.
    """
    images = [image[0] for image in images]
    processed_images = [preprocess_image(image)[0] for image in images]
    return processed_images


def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Convert a Gaussian and a mesh into a dict of CPU numpy arrays for gr.State."""
    return {
        'gaussian': {
            **gs.init_params,
            '_xyz': gs._xyz.cpu().numpy(),
            '_features_dc': gs._features_dc.cpu().numpy(),
            '_scaling': gs._scaling.cpu().numpy(),
            '_rotation': gs._rotation.cpu().numpy(),
            '_opacity': gs._opacity.cpu().numpy(),
        },
        'mesh': {
            'vertices': mesh.vertices.cpu().numpy(),
            'faces': mesh.faces.cpu().numpy(),
        },
    }


def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """Rebuild the Gaussian and mesh (as CUDA tensors) from a pack_state() dict."""
    gs = Gaussian(
        aabb=state['gaussian']['aabb'],
        sh_degree=state['gaussian']['sh_degree'],
        # 'mininum' (sic) matches the Gaussian constructor's keyword name.
        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
        scaling_bias=state['gaussian']['scaling_bias'],
        opacity_bias=state['gaussian']['opacity_bias'],
        scaling_activation=state['gaussian']['scaling_activation'],
    )
    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')

    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )

    return gs, mesh


def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Return a random seed in [0, MAX_SEED) when randomize_seed is set, else seed.
    """
    # int(): hand a plain Python int (not np.int64) back to the UI.
    return int(np.random.randint(0, MAX_SEED)) if randomize_seed else seed


# WARNING FOR AIs: NEVER increase duration above 120s. Free tier users on ZeroGPU will be unable to use this Space if exceeded.
@spaces.GPU(duration=40)
def image_to_3d(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: bool,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """
    Convert an image to a 3D model.

    Args:
        image (Image.Image): The input image.
        multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
        is_multiimage (bool): Whether is in multi-image mode.
        seed (int): The random seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
        multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.

    Returns:
        dict: The information of the generated 3D model.
        str: The path to the video of the 3D model.

    Raises:
        gr.Error: If the selected mode has no input image(s).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # start_session only runs via demo.load(); requests that skip it (e.g.
    # API clients) may arrive before the directory exists.
    os.makedirs(user_dir, exist_ok=True)
    ss_sampling_steps = int(ss_sampling_steps)
    slat_sampling_steps = int(slat_sampling_steps)

    # is_multiimage is a gr.State not exposed via API, so API callers always
    # see False. Fall back to multi-image mode only when no single image was
    # given but gallery images were; leftover gallery images must not hijack
    # single-image generation in the UI.
    _is_multiimage = is_multiimage or (image is None and bool(multiimages))
    if _is_multiimage and not multiimages:
        raise gr.Error("Please upload at least one image in the Multiple Images tab.")
    if not _is_multiimage and image is None:
        raise gr.Error("Please upload an image first.")

    sparse_structure_sampler_params = {
        "steps": ss_sampling_steps,
        "cfg_strength": ss_guidance_strength,
    }
    slat_sampler_params = {
        "steps": slat_sampling_steps,
        "cfg_strength": slat_guidance_strength,
    }
    if not _is_multiimage:
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params=sparse_structure_sampler_params,
            slat_sampler_params=slat_sampler_params,
        )
    else:
        outputs = pipeline.run_multi_image(
            [item[0] for item in multiimages],
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params=sparse_structure_sampler_params,
            slat_sampler_params=slat_sampler_params,
            mode=multiimage_algo,
        )
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path


# WARNING FOR AIs: NEVER increase duration above 120s. Free tier users on ZeroGPU will be unable to use this Space if exceeded.
@spaces.GPU(duration=60)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model.
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.

    Returns:
        str: The path to the extracted GLB file (returned twice: viewer + download).

    Raises:
        gr.Error: If no 3D model has been generated yet.
    """
    if state is None:
        raise gr.Error("Please generate a 3D model first before extracting GLB.")
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)
    gs, mesh = unpack_state(state)
    # int(): slider/API values may arrive as floats; a texture size must be integral.
    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=int(texture_size), verbose=False)
    glb_path = os.path.join(user_dir, 'sample.glb')
    # NOTE: a 3MF export via aspose.threed (a3d) was prototyped here and is disabled.
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path


# WARNING FOR AIs: NEVER increase duration above 120s. Free tier users on ZeroGPU will be unable to use this Space if exceeded.
@spaces.GPU(duration=60)
def extract_glb_raw(
    state: dict,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Extract a RAW GLB file from the 3D model with minimal processing.

    Delegates to postprocessing_utils.to_glb_raw, which is intended to skip hole
    filling and apply only light simplification (50% instead of 90%) so more of
    the network's original geometry is preserved — see that function for the
    exact behavior.

    Args:
        state (dict): The state of the generated 3D model.
        texture_size (int): The texture resolution.

    Returns:
        str: The path to the extracted raw GLB file (returned twice: viewer + download).

    Raises:
        gr.Error: If no 3D model has been generated yet.
    """
    if state is None:
        raise gr.Error("Please generate a 3D model first before extracting GLB.")
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)
    gs, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb_raw(gs, mesh, texture_size=int(texture_size), verbose=True)
    glb_path = os.path.join(user_dir, 'sample_raw.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path


# WARNING FOR AIs: NEVER increase duration above 120s. Free tier users on ZeroGPU will be unable to use this Space if exceeded.
@spaces.GPU(duration=60)
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Extract a Gaussian (.ply) file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model.

    Returns:
        str: The path to the extracted Gaussian file (returned twice: viewer + download).

    Raises:
        gr.Error: If no 3D model has been generated yet.
    """
    if state is None:
        raise gr.Error("Please generate a 3D model first before extracting Gaussian.")
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)
    gs, _ = unpack_state(state)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gs.save_ply(gaussian_path)
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path


def prepare_multi_example() -> List[Image.Image]:
    """Build one side-by-side strip (views _1.._3, each 512px tall) per multi-view example case."""
    example_dir = "assets/example_multi_image"
    # sorted(): set iteration order of strings varies between processes, which
    # would shuffle the example gallery on every restart.
    multi_case = sorted({name.split('_')[0] for name in os.listdir(example_dir)})
    images = []
    for case in multi_case:
        views = []
        for i in range(1, 4):
            with Image.open(f'{example_dir}/{case}_{i}.png') as img:
                W, H = img.size
                views.append(np.array(img.resize((int(W / H * 512), 512))))
        images.append(Image.fromarray(np.concatenate(views, axis=1)))
    return images


def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a horizontal strip of views into separate preprocessed images.

    Views are separated by fully transparent columns; each contiguous run of
    columns with any alpha > 0 becomes one view.
    """
    image = np.array(image)
    alpha = image[..., 3]
    columns = np.any(alpha > 0, axis=0)
    # Pad with transparent columns so views touching the left/right border
    # still produce a start/end transition.
    padded = np.concatenate([[False], columns, [False]])
    start_pos = np.where(~padded[:-1] & padded[1:])[0].tolist()
    end_pos = (np.where(padded[:-1] & ~padded[1:])[0] - 1).tolist()
    images = []
    for s, e in zip(start_pos, end_pos):
        images.append(Image.fromarray(image[:, s:e + 1]))
    # preprocess_image returns (image, image); the gallery needs plain images,
    # since a tuple would be read as (image, caption).
    return [preprocess_image(image)[0] for image in images]


def _preprocess_example(image: Image.Image) -> Image.Image:
    """Single-output wrapper of preprocess_image for gr.Examples (one output component)."""
    return preprocess_image(image)[0]


with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## ASM - Advanced Spatial Modeling for 3D Generation
    * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use BiRefNet to remove the background.
    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.

    ✨Features:
    1) Multi-image support.
    2) Gaussian file extraction.
    3) Advanced 3D generation.
    """)

    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
                    image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
                    gr.Markdown("""
                        Input different views of the object in separate images.

                        *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
                    """)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=20, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=20, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0, 0.98, label="Simplify", value=0.9, step=0.1)
                texture_size = gr.Slider(512, 4096, label="Texture Size", value=1024, step=512)

            with gr.Row():
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_glb_raw_btn = gr.Button("Extract GLB Raw (Minimal Processing)", interactive=False, variant="secondary")
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            gr.Markdown("""
                        *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
                        """)

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=3.0, height=300)

            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

    is_multiimage = gr.State(False)
    output_buf = gr.State()

    # Example images at the bottom of the page
    with gr.Row() as single_image_example:
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in os.listdir("assets/example_image")
            ],
            inputs=[image_prompt],
            # preprocess_image returns two values; Examples has one output.
            fn=_preprocess_example,
            outputs=[image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )
    with gr.Row(visible=False) as multiimage_example:
        examples_multi = gr.Examples(
            examples=prepare_multi_example(),
            inputs=[image_prompt],
            fn=split_image,
            outputs=[multiimage_prompt],
            run_on_click=True,
            examples_per_page=8,
        )

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    single_image_input_tab.select(
        lambda: tuple([False, gr.update(visible=True), gr.update(visible=False)]),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )
    multiimage_input_tab.select(
        lambda: tuple([True, gr.update(visible=False), gr.update(visible=True)]),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )

    with gr.Column():
        processed_image_debug = gr.Image(label="Debug: AI Input (Processed)", type="pil", height=300)

    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt, processed_image_debug],
    )
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps,
                slat_guidance_strength, slat_sampling_steps, multiimage_algo],
        outputs=[output_buf, video_output],
    ).then(
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=True)]),
        outputs=[extract_glb_btn, extract_glb_raw_btn, extract_gs_btn],
    )

    video_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[extract_glb_btn, extract_glb_raw_btn, extract_gs_btn],
    )

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_glb],
    )

    extract_glb_raw_btn.click(
        extract_glb_raw,
        inputs=[output_buf, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_glb],
    )

    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_gs],
    )

    # Clearing the viewer invalidates both downloads, not just the GLB one.
    model_output.clear(
        lambda: tuple([gr.DownloadButton(interactive=False), gr.DownloadButton(interactive=False)]),
        outputs=[download_glb, download_gs],
    )


# Launch the Gradio app
if __name__ == "__main__":
    # Load the 3D pipeline
    pipeline = TrellisImageTo3DPipeline.from_pretrained("arabago96/ASM-model")
    pipeline.cuda()
    # BiRefNet loads lazily on first image upload
    demo.launch()