Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import spaces | |
| import torch | |
| from pathlib import Path | |
| import tempfile | |
| import os | |
| import base64 | |
| from typing import Optional | |
| import json | |
| # SHARP モデルのインポート (遅延読み込み) | |
| SHARP_AVAILABLE = False | |
| SHARP_ERROR = None | |
| try: | |
| from sharp import Sharp | |
| SHARP_AVAILABLE = True | |
| print("✅ SHARP module loaded successfully") | |
| except ImportError as e: | |
| SHARP_ERROR = str(e) | |
| print(f"❌ SHARP import failed: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| except Exception as e: | |
| SHARP_ERROR = str(e) | |
| print(f"❌ Unexpected error loading SHARP: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| # グローバルモデルインスタンス (メモリ効率のため) | |
| # 注意: ZeroGPUのマルチプロセッシングに対応するため、モジュールレベルで管理 | |
| _model = None | |
| def get_model(): | |
| """モデルインスタンスを取得(キャッシング) | |
| GPU workerプロセス内でモデルを初期化してキャッシュします。 | |
| これによりpickling問題を回避します。 | |
| """ | |
| global _model | |
| if _model is None and SHARP_AVAILABLE: | |
| print("🔄 Initializing SHARP model in GPU worker...") | |
| _model = Sharp() | |
| print("✅ SHARP model initialized successfully") | |
| return _model | |
| def _process_image_impl(image) -> tuple[Optional[str], str, str]: | |
| """ | |
| 画像から3D Gaussian Splatsを生成 | |
| Args: | |
| image: PIL Image or numpy array | |
| Returns: | |
| tuple: (PLYファイルパス, ステータスメッセージ, PLYデータ(base64)) | |
| """ | |
| if not SHARP_AVAILABLE: | |
| error_msg = f"❌ SHARPモデルが利用できません\n\nエラー詳細: {SHARP_ERROR}\n\n" | |
| error_msg += "考えられる原因:\n" | |
| error_msg += "1. ml-sharpパッケージのインストール失敗\n" | |
| error_msg += "2. Python バージョンの非互換性\n" | |
| error_msg += "3. 依存関係の問題\n\n" | |
| error_msg += "ログを確認してください。" | |
| return None, error_msg, "" | |
| if image is None: | |
| return None, "❌ 画像をアップロードしてください", "" | |
| try: | |
| # 一時ファイルとして保存 | |
| with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_input: | |
| input_path = Path(tmp_input.name) | |
| # PIL Imageとして保存 | |
| if hasattr(image, 'save'): | |
| image.save(input_path, format='JPEG') | |
| else: | |
| from PIL import Image | |
| Image.fromarray(image).save(input_path, format='JPEG') | |
| # モデルで推論 | |
| model = get_model() | |
| print(f"🔄 Processing image: {input_path}") | |
| gaussians = model.predict(input_path) | |
| # PLYファイルとして保存 | |
| with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as tmp_output: | |
| output_path = Path(tmp_output.name) | |
| gaussians.save(str(output_path)) | |
| # PLYファイルをBase64エンコード (Three.jsで使用) | |
| with open(output_path, 'rb') as f: | |
| ply_data = f.read() | |
| ply_base64 = base64.b64encode(ply_data).decode('utf-8') | |
| # 統計情報を取得 | |
| file_size = output_path.stat().st_size / (1024 * 1024) # MB | |
| # 入力ファイルを削除 | |
| if input_path.exists(): | |
| input_path.unlink() | |
| status_msg = f"✅ 生成完了!\n📦 ファイルサイズ: {file_size:.2f} MB" | |
| return str(output_path), status_msg, ply_base64 | |
| except Exception as e: | |
| import traceback | |
| error_msg = f"❌ エラーが発生しました:\n{str(e)}\n\n{traceback.format_exc()}" | |
| print(error_msg) | |
| return None, error_msg, "" | |
| # ZeroGPUデコレータを適用 (180秒のGPUタイムアウト) | |
| # 注意: モジュールレベル関数に適用することでpickling問題を回避 | |
| process_image = spaces.GPU(duration=180)(_process_image_impl) | |
| # Three.js ビューアのHTMLテンプレート | |
| def create_viewer_html(ply_base64: str) -> str: | |
| """Three.js + GaussianSplats3D ビューアのHTMLを生成""" | |
| if not ply_base64: | |
| return """ | |
| <div style="width: 100%; height: 600px; display: flex; align-items: center; justify-content: center; background: #1a1a1a; color: white; border-radius: 8px;"> | |
| <div style="text-align: center;"> | |
| <h2>🎨 3D Gaussian Splats ビューア</h2> | |
| <p>左側で画像を処理すると、ここに3Dプレビューが表示されます</p> | |
| </div> | |
| </div> | |
| """ | |
| html = f""" | |
| <!DOCTYPE html> | |
| <html lang="ja"> | |
| <head> | |
| <meta charset="UTF-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
| <title>3D Gaussian Splats Viewer</title> | |
| <style> | |
| body {{ | |
| margin: 0; | |
| padding: 0; | |
| overflow: hidden; | |
| background: #000; | |
| font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
| }} | |
| #container {{ | |
| width: 100%; | |
| height: 600px; | |
| position: relative; | |
| }} | |
| #loading {{ | |
| position: absolute; | |
| top: 50%; | |
| left: 50%; | |
| transform: translate(-50%, -50%); | |
| color: white; | |
| font-size: 18px; | |
| z-index: 1000; | |
| }} | |
| #controls {{ | |
| position: absolute; | |
| top: 10px; | |
| left: 10px; | |
| background: rgba(0, 0, 0, 0.7); | |
| color: white; | |
| padding: 10px; | |
| border-radius: 5px; | |
| font-size: 12px; | |
| z-index: 1000; | |
| }} | |
| </style> | |
| </head> | |
| <body> | |
| <div id="container"> | |
| <div id="loading">🔄 3Dモデルを読み込み中...</div> | |
| <div id="controls"> | |
| <div>🖱️ ドラッグ: 回転</div> | |
| <div>🔍 スクロール: ズーム</div> | |
| <div>⌨️ 右クリック: パン</div> | |
| </div> | |
| </div> | |
| <script type="importmap"> | |
| {{ | |
| "imports": {{ | |
| "three": "https://cdn.jsdelivr.net/npm/[email protected]/build/three.module.js", | |
| "three/addons/": "https://cdn.jsdelivr.net/npm/[email protected]/examples/jsm/" | |
| }} | |
| }} | |
| </script> | |
| <script type="module"> | |
| import * as THREE from 'three'; | |
| import {{ OrbitControls }} from 'three/addons/controls/OrbitControls.js'; | |
| // シーンの初期化 | |
| const container = document.getElementById('container'); | |
| const loading = document.getElementById('loading'); | |
| const scene = new THREE.Scene(); | |
| scene.background = new THREE.Color(0x1a1a1a); | |
| const camera = new THREE.PerspectiveCamera( | |
| 75, | |
| container.clientWidth / container.clientHeight, | |
| 0.1, | |
| 1000 | |
| ); | |
| camera.position.set(0, 0, 5); | |
| const renderer = new THREE.WebGLRenderer({{ antialias: true }}); | |
| renderer.setSize(container.clientWidth, container.clientHeight); | |
| renderer.setPixelRatio(window.devicePixelRatio); | |
| container.appendChild(renderer.domElement); | |
| // OrbitControls | |
| const controls = new OrbitControls(camera, renderer.domElement); | |
| controls.enableDamping = true; | |
| controls.dampingFactor = 0.05; | |
| // ライト | |
| const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); | |
| scene.add(ambientLight); | |
| const directionalLight = new THREE.DirectionalLight(0xffffff, 1); | |
| directionalLight.position.set(5, 10, 7.5); | |
| scene.add(directionalLight); | |
| // グリッドヘルパー | |
| const gridHelper = new THREE.GridHelper(10, 10); | |
| scene.add(gridHelper); | |
| // PLYローダー | |
| async function loadPLY() {{ | |
| try {{ | |
| // Base64からArrayBufferに変換 | |
| const plyBase64 = '{ply_base64}'; | |
| const binaryString = atob(plyBase64); | |
| const bytes = new Uint8Array(binaryString.length); | |
| for (let i = 0; i < binaryString.length; i++) {{ | |
| bytes[i] = binaryString.charCodeAt(i); | |
| }} | |
| // PLYLoaderを動的にインポート | |
| const {{ PLYLoader }} = await import('three/addons/loaders/PLYLoader.js'); | |
| const loader = new PLYLoader(); | |
| // ArrayBufferをBlob経由でロード | |
| const blob = new Blob([bytes], {{ type: 'application/octet-stream' }}); | |
| const url = URL.createObjectURL(blob); | |
| loader.load( | |
| url, | |
| function (geometry) {{ | |
| loading.style.display = 'none'; | |
| // ポイントクラウドとしてレンダリング | |
| geometry.computeVertexNormals(); | |
| // カラー情報があるか確認 | |
| const hasColors = geometry.attributes.color !== undefined; | |
| const material = new THREE.PointsMaterial({{ | |
| size: 0.01, | |
| vertexColors: hasColors, | |
| color: hasColors ? undefined : 0x00ff00, | |
| sizeAttenuation: true | |
| }}); | |
| const points = new THREE.Points(geometry, material); | |
| scene.add(points); | |
| // カメラ位置を調整 | |
| geometry.computeBoundingBox(); | |
| const bbox = geometry.boundingBox; | |
| const center = new THREE.Vector3(); | |
| bbox.getCenter(center); | |
| const size = new THREE.Vector3(); | |
| bbox.getSize(size); | |
| const maxDim = Math.max(size.x, size.y, size.z); | |
| const fov = camera.fov * (Math.PI / 180); | |
| let cameraZ = Math.abs(maxDim / Math.tan(fov / 2)); | |
| cameraZ *= 1.5; | |
| camera.position.set(center.x, center.y, center.z + cameraZ); | |
| camera.lookAt(center); | |
| controls.target.copy(center); | |
| controls.update(); | |
| URL.revokeObjectURL(url); | |
| console.log('✅ PLYファイルの読み込み完了'); | |
| }}, | |
| function (xhr) {{ | |
| const percent = (xhr.loaded / xhr.total * 100).toFixed(0); | |
| loading.textContent = `🔄 読み込み中... ${{percent}}%`; | |
| }}, | |
| function (error) {{ | |
| console.error('❌ PLY読み込みエラー:', error); | |
| loading.textContent = '❌ 読み込みエラー'; | |
| loading.style.color = 'red'; | |
| }} | |
| ); | |
| }} catch (error) {{ | |
| console.error('❌ エラー:', error); | |
| loading.textContent = '❌ エラーが発生しました'; | |
| loading.style.color = 'red'; | |
| }} | |
| }} | |
| // アニメーションループ | |
| function animate() {{ | |
| requestAnimationFrame(animate); | |
| controls.update(); | |
| renderer.render(scene, camera); | |
| }} | |
| // リサイズ対応 | |
| window.addEventListener('resize', () => {{ | |
| camera.aspect = container.clientWidth / container.clientHeight; | |
| camera.updateProjectionMatrix(); | |
| renderer.setSize(container.clientWidth, container.clientHeight); | |
| }}); | |
| // PLYを読み込んで開始 | |
| loadPLY(); | |
| animate(); | |
| </script> | |
| </body> | |
| </html> | |
| """ | |
| return html | |
| def update_viewer(ply_base64: str) -> str: | |
| """ビューアを更新""" | |
| return create_viewer_html(ply_base64) | |
| # Gradio UI | |
| with gr.Blocks( | |
| theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"), | |
| title="SHARP: 3D Gaussian Splats Generator" | |
| ) as demo: | |
| # SHARPステータスバナー | |
| if SHARP_AVAILABLE: | |
| status_banner = """ | |
| # 🎨 SHARP: 単一画像から3D Gaussian Splatsを生成 | |
| ✅ **SHARPモデル: 正常に読み込まれました** | |
| """ | |
| else: | |
| status_banner = f""" | |
| # 🎨 SHARP: 単一画像から3D Gaussian Splatsを生成 | |
| ⚠️ **警告: SHARPモデルが読み込めませんでした** | |
| エラー: `{SHARP_ERROR}` | |
| Spaceのログを確認するか、リポジトリの管理者にお問い合わせください。 | |
| """ | |
| gr.Markdown(status_banner) | |
| gr.Markdown(""" | |
| Appleの最新技術「SHARP」を使用して、1枚の画像から高品質な3D Gaussian Splatsを生成します。 | |
| 生成された3DモデルはThree.jsで右側にリアルタイムプレビューされます。 | |
| ### 使い方 | |
| 1. 左側のエリアに画像をアップロード | |
| 2. 「生成開始」ボタンをクリック | |
| 3. 右側で3Dモデルをインタラクティブに確認 | |
| 4. PLYファイルをダウンロード可能 | |
| **ZeroGPU (Nvidia H200)** で高速に処理されます 🚀 | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 📸 入力画像") | |
| input_image = gr.Image( | |
| label="画像をアップロード", | |
| type="pil", | |
| sources=["upload", "clipboard"], | |
| height=400 | |
| ) | |
| generate_btn = gr.Button( | |
| "🚀 生成開始", | |
| variant="primary", | |
| size="lg" | |
| ) | |
| status_box = gr.Textbox( | |
| label="ステータス", | |
| lines=3, | |
| interactive=False | |
| ) | |
| output_file = gr.File( | |
| label="📦 PLYファイルをダウンロード", | |
| interactive=False | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 🎬 3Dプレビュー (Three.js)") | |
| viewer_html = gr.HTML( | |
| create_viewer_html(""), | |
| label="3D Viewer" | |
| ) | |
| # 非表示のステート (PLY Base64データ) | |
| ply_data_state = gr.State("") | |
| # イベントハンドラ | |
| def on_generate(image): | |
| ply_path, status, ply_base64 = process_image(image) | |
| viewer = create_viewer_html(ply_base64) | |
| return ply_path, status, ply_base64, viewer | |
| generate_btn.click( | |
| fn=on_generate, | |
| inputs=[input_image], | |
| outputs=[output_file, status_box, ply_data_state, viewer_html] | |
| ) | |
| gr.Markdown(""" | |
| --- | |
| ### ℹ️ 技術情報 | |
| - **モデル**: Apple SHARP (Sharp Monocular View Synthesis) | |
| - **出力形式**: PLY (Polygon File Format) | |
| - **レンダリング**: Three.js + PLYLoader | |
| - **GPU**: ZeroGPU (Nvidia H200, 動的割り当て) | |
| - **処理時間**: 通常1秒以下 | |
| ### 📚 リソース | |
| - [SHARP GitHub](https://github.com/apple/ml-sharp) | |
| - [論文 (arXiv)](https://arxiv.org/abs/2512.10685) | |
| - [Hugging Face Model](https://huggingface.co/apple/Sharp) | |
| ### ⚠️ 注意事項 | |
| - 処理にはGPUを使用するため、待機時間が発生する場合があります | |
| - ZeroGPUは60秒のタイムアウトがあります | |
| - 大きな画像は自動的にリサイズされます | |
| """) | |
| if __name__ == "__main__": | |
| demo.launch() | |