Claude
Adopt working configuration from ronedgecomb/ml-sharp Space
f93ea45 unverified
import gradio as gr
import spaces
import torch
from pathlib import Path
import tempfile
import os
import base64
from typing import Optional
import json
# SHARP モデルのインポート (遅延読み込み)
SHARP_AVAILABLE = False
SHARP_ERROR = None
try:
from sharp import Sharp
SHARP_AVAILABLE = True
print("✅ SHARP module loaded successfully")
except ImportError as e:
SHARP_ERROR = str(e)
print(f"❌ SHARP import failed: {e}")
import traceback
traceback.print_exc()
except Exception as e:
SHARP_ERROR = str(e)
print(f"❌ Unexpected error loading SHARP: {e}")
import traceback
traceback.print_exc()
# グローバルモデルインスタンス (メモリ効率のため)
# 注意: ZeroGPUのマルチプロセッシングに対応するため、モジュールレベルで管理
_model = None
def get_model():
"""モデルインスタンスを取得(キャッシング)
GPU workerプロセス内でモデルを初期化してキャッシュします。
これによりpickling問題を回避します。
"""
global _model
if _model is None and SHARP_AVAILABLE:
print("🔄 Initializing SHARP model in GPU worker...")
_model = Sharp()
print("✅ SHARP model initialized successfully")
return _model
def _process_image_impl(image) -> tuple[Optional[str], str, str]:
"""
画像から3D Gaussian Splatsを生成
Args:
image: PIL Image or numpy array
Returns:
tuple: (PLYファイルパス, ステータスメッセージ, PLYデータ(base64))
"""
if not SHARP_AVAILABLE:
error_msg = f"❌ SHARPモデルが利用できません\n\nエラー詳細: {SHARP_ERROR}\n\n"
error_msg += "考えられる原因:\n"
error_msg += "1. ml-sharpパッケージのインストール失敗\n"
error_msg += "2. Python バージョンの非互換性\n"
error_msg += "3. 依存関係の問題\n\n"
error_msg += "ログを確認してください。"
return None, error_msg, ""
if image is None:
return None, "❌ 画像をアップロードしてください", ""
try:
# 一時ファイルとして保存
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_input:
input_path = Path(tmp_input.name)
# PIL Imageとして保存
if hasattr(image, 'save'):
image.save(input_path, format='JPEG')
else:
from PIL import Image
Image.fromarray(image).save(input_path, format='JPEG')
# モデルで推論
model = get_model()
print(f"🔄 Processing image: {input_path}")
gaussians = model.predict(input_path)
# PLYファイルとして保存
with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as tmp_output:
output_path = Path(tmp_output.name)
gaussians.save(str(output_path))
# PLYファイルをBase64エンコード (Three.jsで使用)
with open(output_path, 'rb') as f:
ply_data = f.read()
ply_base64 = base64.b64encode(ply_data).decode('utf-8')
# 統計情報を取得
file_size = output_path.stat().st_size / (1024 * 1024) # MB
# 入力ファイルを削除
if input_path.exists():
input_path.unlink()
status_msg = f"✅ 生成完了!\n📦 ファイルサイズ: {file_size:.2f} MB"
return str(output_path), status_msg, ply_base64
except Exception as e:
import traceback
error_msg = f"❌ エラーが発生しました:\n{str(e)}\n\n{traceback.format_exc()}"
print(error_msg)
return None, error_msg, ""
# ZeroGPUデコレータを適用 (180秒のGPUタイムアウト)
# 注意: モジュールレベル関数に適用することでpickling問題を回避
process_image = spaces.GPU(duration=180)(_process_image_impl)
# Three.js ビューアのHTMLテンプレート
def create_viewer_html(ply_base64: str) -> str:
"""Three.js + GaussianSplats3D ビューアのHTMLを生成"""
if not ply_base64:
return """
<div style="width: 100%; height: 600px; display: flex; align-items: center; justify-content: center; background: #1a1a1a; color: white; border-radius: 8px;">
<div style="text-align: center;">
<h2>🎨 3D Gaussian Splats ビューア</h2>
<p>左側で画像を処理すると、ここに3Dプレビューが表示されます</p>
</div>
</div>
"""
html = f"""
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>3D Gaussian Splats Viewer</title>
<style>
body {{
margin: 0;
padding: 0;
overflow: hidden;
background: #000;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}}
#container {{
width: 100%;
height: 600px;
position: relative;
}}
#loading {{
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
color: white;
font-size: 18px;
z-index: 1000;
}}
#controls {{
position: absolute;
top: 10px;
left: 10px;
background: rgba(0, 0, 0, 0.7);
color: white;
padding: 10px;
border-radius: 5px;
font-size: 12px;
z-index: 1000;
}}
</style>
</head>
<body>
<div id="container">
<div id="loading">🔄 3Dモデルを読み込み中...</div>
<div id="controls">
<div>🖱️ ドラッグ: 回転</div>
<div>🔍 スクロール: ズーム</div>
<div>⌨️ 右クリック: パン</div>
</div>
</div>
<script type="importmap">
{{
"imports": {{
"three": "https://cdn.jsdelivr.net/npm/[email protected]/build/three.module.js",
"three/addons/": "https://cdn.jsdelivr.net/npm/[email protected]/examples/jsm/"
}}
}}
</script>
<script type="module">
import * as THREE from 'three';
import {{ OrbitControls }} from 'three/addons/controls/OrbitControls.js';
// シーンの初期化
const container = document.getElementById('container');
const loading = document.getElementById('loading');
const scene = new THREE.Scene();
scene.background = new THREE.Color(0x1a1a1a);
const camera = new THREE.PerspectiveCamera(
75,
container.clientWidth / container.clientHeight,
0.1,
1000
);
camera.position.set(0, 0, 5);
const renderer = new THREE.WebGLRenderer({{ antialias: true }});
renderer.setSize(container.clientWidth, container.clientHeight);
renderer.setPixelRatio(window.devicePixelRatio);
container.appendChild(renderer.domElement);
// OrbitControls
const controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.05;
// ライト
const ambientLight = new THREE.AmbientLight(0xffffff, 0.5);
scene.add(ambientLight);
const directionalLight = new THREE.DirectionalLight(0xffffff, 1);
directionalLight.position.set(5, 10, 7.5);
scene.add(directionalLight);
// グリッドヘルパー
const gridHelper = new THREE.GridHelper(10, 10);
scene.add(gridHelper);
// PLYローダー
async function loadPLY() {{
try {{
// Base64からArrayBufferに変換
const plyBase64 = '{ply_base64}';
const binaryString = atob(plyBase64);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {{
bytes[i] = binaryString.charCodeAt(i);
}}
// PLYLoaderを動的にインポート
const {{ PLYLoader }} = await import('three/addons/loaders/PLYLoader.js');
const loader = new PLYLoader();
// ArrayBufferをBlob経由でロード
const blob = new Blob([bytes], {{ type: 'application/octet-stream' }});
const url = URL.createObjectURL(blob);
loader.load(
url,
function (geometry) {{
loading.style.display = 'none';
// ポイントクラウドとしてレンダリング
geometry.computeVertexNormals();
// カラー情報があるか確認
const hasColors = geometry.attributes.color !== undefined;
const material = new THREE.PointsMaterial({{
size: 0.01,
vertexColors: hasColors,
color: hasColors ? undefined : 0x00ff00,
sizeAttenuation: true
}});
const points = new THREE.Points(geometry, material);
scene.add(points);
// カメラ位置を調整
geometry.computeBoundingBox();
const bbox = geometry.boundingBox;
const center = new THREE.Vector3();
bbox.getCenter(center);
const size = new THREE.Vector3();
bbox.getSize(size);
const maxDim = Math.max(size.x, size.y, size.z);
const fov = camera.fov * (Math.PI / 180);
let cameraZ = Math.abs(maxDim / Math.tan(fov / 2));
cameraZ *= 1.5;
camera.position.set(center.x, center.y, center.z + cameraZ);
camera.lookAt(center);
controls.target.copy(center);
controls.update();
URL.revokeObjectURL(url);
console.log('✅ PLYファイルの読み込み完了');
}},
function (xhr) {{
const percent = (xhr.loaded / xhr.total * 100).toFixed(0);
loading.textContent = `🔄 読み込み中... ${{percent}}%`;
}},
function (error) {{
console.error('❌ PLY読み込みエラー:', error);
loading.textContent = '❌ 読み込みエラー';
loading.style.color = 'red';
}}
);
}} catch (error) {{
console.error('❌ エラー:', error);
loading.textContent = '❌ エラーが発生しました';
loading.style.color = 'red';
}}
}}
// アニメーションループ
function animate() {{
requestAnimationFrame(animate);
controls.update();
renderer.render(scene, camera);
}}
// リサイズ対応
window.addEventListener('resize', () => {{
camera.aspect = container.clientWidth / container.clientHeight;
camera.updateProjectionMatrix();
renderer.setSize(container.clientWidth, container.clientHeight);
}});
// PLYを読み込んで開始
loadPLY();
animate();
</script>
</body>
</html>
"""
return html
def update_viewer(ply_base64: str) -> str:
"""ビューアを更新"""
return create_viewer_html(ply_base64)
# Gradio UI
with gr.Blocks(
theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"),
title="SHARP: 3D Gaussian Splats Generator"
) as demo:
# SHARPステータスバナー
if SHARP_AVAILABLE:
status_banner = """
# 🎨 SHARP: 単一画像から3D Gaussian Splatsを生成
✅ **SHARPモデル: 正常に読み込まれました**
"""
else:
status_banner = f"""
# 🎨 SHARP: 単一画像から3D Gaussian Splatsを生成
⚠️ **警告: SHARPモデルが読み込めませんでした**
エラー: `{SHARP_ERROR}`
Spaceのログを確認するか、リポジトリの管理者にお問い合わせください。
"""
gr.Markdown(status_banner)
gr.Markdown("""
Appleの最新技術「SHARP」を使用して、1枚の画像から高品質な3D Gaussian Splatsを生成します。
生成された3DモデルはThree.jsで右側にリアルタイムプレビューされます。
### 使い方
1. 左側のエリアに画像をアップロード
2. 「生成開始」ボタンをクリック
3. 右側で3Dモデルをインタラクティブに確認
4. PLYファイルをダウンロード可能
**ZeroGPU (Nvidia H200)** で高速に処理されます 🚀
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 📸 入力画像")
input_image = gr.Image(
label="画像をアップロード",
type="pil",
sources=["upload", "clipboard"],
height=400
)
generate_btn = gr.Button(
"🚀 生成開始",
variant="primary",
size="lg"
)
status_box = gr.Textbox(
label="ステータス",
lines=3,
interactive=False
)
output_file = gr.File(
label="📦 PLYファイルをダウンロード",
interactive=False
)
with gr.Column(scale=1):
gr.Markdown("### 🎬 3Dプレビュー (Three.js)")
viewer_html = gr.HTML(
create_viewer_html(""),
label="3D Viewer"
)
# 非表示のステート (PLY Base64データ)
ply_data_state = gr.State("")
# イベントハンドラ
def on_generate(image):
ply_path, status, ply_base64 = process_image(image)
viewer = create_viewer_html(ply_base64)
return ply_path, status, ply_base64, viewer
generate_btn.click(
fn=on_generate,
inputs=[input_image],
outputs=[output_file, status_box, ply_data_state, viewer_html]
)
gr.Markdown("""
---
### ℹ️ 技術情報
- **モデル**: Apple SHARP (Sharp Monocular View Synthesis)
- **出力形式**: PLY (Polygon File Format)
- **レンダリング**: Three.js + PLYLoader
- **GPU**: ZeroGPU (Nvidia H200, 動的割り当て)
- **処理時間**: 通常1秒以下
### 📚 リソース
- [SHARP GitHub](https://github.com/apple/ml-sharp)
- [論文 (arXiv)](https://arxiv.org/abs/2512.10685)
- [Hugging Face Model](https://huggingface.co/apple/Sharp)
### ⚠️ 注意事項
- 処理にはGPUを使用するため、待機時間が発生する場合があります
- ZeroGPUは60秒のタイムアウトがあります
- 大きな画像は自動的にリサイズされます
""")
if __name__ == "__main__":
demo.launch()