"""
Safetensors model loading utilities for Helium virtual GPU infrastructure.
Provides efficient loading and saving of model weights in the safetensors format.
"""

import os
from pathlib import Path
from typing import Dict, Any, Union, Optional, BinaryIO
import numpy as np
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout
from .main import HeliumTensor, get_device, get_default_device

# safetensors is an optional dependency: record whether it could be imported
# so that load_safetensors/save_safetensors can raise a descriptive
# ImportError when they are called, rather than this module failing to import.
try:
    import safetensors.numpy
    HAS_SAFETENSORS = True
except ImportError:
    HAS_SAFETENSORS = False

class SafetensorError(Exception):
    """Root of the exception hierarchy used by the safetensors helpers.

    Catch this type to handle both load and save failures at once.
    """

class SafetensorLoadError(SafetensorError):
    """Signals that reading weights from a safetensors source did not succeed."""

class SafetensorSaveError(SafetensorError):
    """Signals that writing weights out in safetensors format did not succeed."""

def _validate_tensor_dict(tensors: Dict[str, Any]) -> None:
    """Validate tensor dictionary format"""
    if not isinstance(tensors, dict):
        raise ValueError("Expected dictionary of tensors")
    for name, tensor in tensors.items():
        if not isinstance(name, str):
            raise ValueError(f"Tensor key must be string, got {type(name)}")
        if not isinstance(tensor, (np.ndarray, HeliumTensor)):
            raise ValueError(f"Value for key {name} must be numpy array or HeliumTensor")

def load_safetensors(
    file_path: Union[str, Path, BinaryIO],
    device_id: Optional[str] = None,
    dtype: Optional[str] = None
) -> Dict[str, 'HeliumTensor']:
    """
    Load safetensors format model weights into Helium tensors.

    Args:
        file_path: Path to a safetensors file, or a binary file-like object
            (anything exposing ``read()``) whose remaining content is the
            serialized safetensors data.
        device_id: Target virtual GPU device ID (default: current default device)
        dtype: Optional dtype name to cast tensors to. When the name is a
            valid numpy dtype (case-insensitive) the array data is cast
            before being written to the device; otherwise only the tensor
            descriptor and the returned HeliumTensor carry the name.

    Returns:
        Dictionary mapping names to HeliumTensor objects

    Raises:
        SafetensorLoadError: If loading fails
        ImportError: If safetensors package is not available
    """
    if not HAS_SAFETENSORS:
        raise ImportError(
            "safetensors package is required. "
            "Install with 'pip install safetensors'"
        )

    try:
        # Get virtual GPU driver
        driver = get_device(device_id) if device_id else get_default_device()

        # Load raw tensors. File-like objects must be deserialized from their
        # bytes; str() on a file object is its repr, not a usable path.
        if hasattr(file_path, "read"):
            tensors = safetensors.numpy.load(file_path.read())
        else:
            tensors = safetensors.numpy.load_file(str(file_path))

        # Resolve the requested cast once, outside the per-tensor loop.
        cast_dtype = None
        if dtype:
            try:
                cast_dtype = np.dtype(dtype.lower())
            except TypeError:
                # Not a dtype numpy knows about: leave the data untouched and
                # only label the descriptor, as this function did previously.
                cast_dtype = None

        # Convert to Helium tensors
        weights = {}
        for name, array in tensors.items():
            dtype_name = dtype or str(array.dtype)

            # Make the data actually match the dtype declared in the descriptor.
            if cast_dtype is not None and array.dtype != cast_dtype:
                array = array.astype(cast_dtype)

            tensor_desc = TensorDescriptor(
                shape=array.shape,
                dtype=getattr(DType, dtype_name.upper()),
                device=Device.VGPU,
                layout=Layout.ROW_MAJOR
            )

            # Allocate tensor on virtual GPU and copy the data across
            tensor_id = driver.allocate_tensor(tensor_desc)
            driver.write_tensor(tensor_id, array)

            # Create HeliumTensor wrapper (named after the driver's tensor id)
            weights[name] = HeliumTensor(
                name=tensor_id,
                shape=array.shape,
                dtype=dtype_name,
                device_id=device_id
            )

        return weights

    except Exception as e:
        raise SafetensorLoadError(f"Failed to load safetensors: {str(e)}") from e

def save_safetensors(
    weights: Dict[str, Union[np.ndarray, 'HeliumTensor']],
    file_path: Union[str, Path, BinaryIO]
) -> None:
    """
    Save model weights in safetensors format.

    Args:
        weights: Dictionary of tensor names to numpy arrays or HeliumTensors
        file_path: Output path, or a binary file-like object (anything
            exposing ``write()``) that receives the serialized bytes. For
            str/Path targets, missing parent directories are created.

    Raises:
        SafetensorSaveError: If saving fails (including invalid ``weights``)
        ImportError: If safetensors package is not available
    """
    if not HAS_SAFETENSORS:
        raise ImportError(
            "safetensors package is required. "
            "Install with 'pip install safetensors'"
        )

    try:
        _validate_tensor_dict(weights)

        # Convert Helium tensors to numpy; plain arrays pass through unchanged.
        tensors = {
            name: tensor.cpu().numpy() if isinstance(tensor, HeliumTensor) else tensor
            for name, tensor in weights.items()
        }

        if hasattr(file_path, "write"):
            # File-like target: serialize in memory and write the bytes.
            # str() on a file object is its repr, not a usable path.
            file_path.write(safetensors.numpy.save(tensors))
        else:
            # Create parent directories if needed
            if isinstance(file_path, (str, Path)):
                os.makedirs(os.path.dirname(os.path.abspath(file_path)), exist_ok=True)

            # Save tensors
            safetensors.numpy.save_file(tensors, str(file_path))

    except Exception as e:
        raise SafetensorSaveError(f"Failed to save safetensors: {str(e)}") from e
