"""
High-level device and tensor management for helium virtual GPU infrastructure.
Provides tensor operations and memory management for virtual devices.
"""
from typing import Optional, Union, List, Tuple, Dict, Any
import numpy as np
from enum import Enum
from virtual_gpu_driver.src.driver_api import get_storage_manager
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, Layout

class DeviceType(Enum):
    """Kinds of virtual device a tensor or module can be placed on.

    Each member's value is the lowercase tag used in device strings
    such as ``"vgpu:0"``.
    """

    # Virtual GPU device.
    VGPU = "vgpu"
    # VRAM device.
    VRAM = "vram"

class Device:
    """Represents a virtual compute device (VGPU or VRAM).

    A device is identified by its ``type`` (a ``DeviceType`` member) and an
    integer ``index``. Two devices with the same type and index compare
    equal and hash identically, so they can be used as dict keys or set
    members. Avoid reassigning ``type``/``index`` on an instance that is
    already stored in a hashed container.
    """

    def __init__(self, device_type: Union[str, 'DeviceType'], index: int = 0):
        """Create a device.

        Args:
            device_type: A ``DeviceType`` member, or its string value in any
                letter case (e.g. ``"VGPU"``).
            index: Device ordinal; defaults to 0.

        Raises:
            ValueError: If ``device_type`` is a string that does not name a
                ``DeviceType`` member.
        """
        if isinstance(device_type, str):
            device_type = DeviceType(device_type.lower())
        self.type = device_type
        self.index = index

    def __str__(self) -> str:
        return f"{self.type.value}:{self.index}"

    def __repr__(self) -> str:
        return f"Device(type={self.type.value!r}, index={self.index!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Device):
            return NotImplemented
        return self.type == other.type and self.index == other.index

    def __hash__(self) -> int:
        # Defined together with __eq__ so equal devices hash alike.
        return hash((self.type, self.index))

    @staticmethod
    def parse(
        device_str: Optional[Union[str, 'Device', 'DeviceType']] = None
    ) -> 'Device':
        """Parse a device specification such as ``'vgpu:0'``.

        Args:
            device_str: ``None`` (yields the default device, VGPU index 0),
                an existing ``Device`` (returned unchanged), a ``DeviceType``
                member (index 0), or a string of the form ``"type"`` or
                ``"type:index"``.

        Returns:
            The corresponding ``Device``.

        Raises:
            ValueError: If the type name is unknown or the index is not an
                integer.
        """
        if device_str is None:
            return Device(DeviceType.VGPU)
        if isinstance(device_str, Device):
            return device_str
        # The constructor accepts enum members, so accept them here as well.
        if isinstance(device_str, DeviceType):
            return Device(device_str)

        parts = device_str.split(':')
        device_type = parts[0].strip().lower()
        try:
            index = int(parts[1]) if len(parts) > 1 else 0
        except ValueError:
            raise ValueError(
                f"Invalid device index in device string {device_str!r}"
            ) from None
        return Device(device_type, index)

class dtensor:
    """
    Device tensor with automatic memory management and device placement
    Similar to PyTorch's tensor but using our virtual GPU infrastructure

    The tensor itself only holds a storage-manager handle plus placement
    metadata; shape, dtype and contents are fetched from the storage manager
    on demand.
    """
    def __init__(
        self,
        data: Union[np.ndarray, List, Tuple],
        device: Optional[Union[str, 'Device']] = None,
        dtype: Optional[np.dtype] = None,
        requires_grad: bool = False,
        pin_memory: bool = False
    ):
        """Allocate a tensor on ``device`` initialised from ``data``.

        Args:
            data: Array-like input; converted with ``np.asarray``.
            device: Target device (string or ``Device``); ``None`` selects
                the default returned by ``Device.parse``.
            dtype: Optional dtype passed to ``np.asarray``.
            requires_grad: Stored on the tensor as-is.
            pin_memory: Stored on the tensor and forwarded to the allocator.
        """
        self.device = Device.parse(device)
        self.requires_grad = requires_grad
        self.pin_memory = pin_memory
        self.grad = None

        # Initialize tensor on device
        storage = get_storage_manager()
        self._handle = storage.allocate_tensor(
            np.asarray(data, dtype=dtype),
            device_type=self.device.type.value,
            device_index=self.device.index,
            pin_memory=self.pin_memory
        )

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape as reported by the storage manager for this handle."""
        return get_storage_manager().get_tensor_shape(self._handle)

    @property
    def dtype(self) -> np.dtype:
        """Dtype as reported by the storage manager for this handle."""
        return get_storage_manager().get_tensor_dtype(self._handle)

    def to(self, device: Union[str, 'Device']) -> 'dtensor':
        """Move tensor to specified device.

        The move happens in place: this object's handle and ``device`` are
        updated and ``self`` is returned (no new ``dtensor`` is created).
        If the tensor is already on the requested device nothing is done.
        """
        device = Device.parse(device)
        if device.type == self.device.type and device.index == self.device.index:
            return self

        # Move to new device
        storage = get_storage_manager()
        new_handle = storage.move_tensor(
            self._handle,
            device_type=device.type.value,
            device_index=device.index
        )

        self._handle = new_handle
        self.device = device
        return self

    def vgpu(self, index: int = 0) -> 'dtensor':
        """Move tensor to virtual GPU"""
        return self.to(Device(DeviceType.VGPU, index))

    def vram(self, index: int = 0) -> 'dtensor':
        """Move tensor to VRAM"""
        return self.to(Device(DeviceType.VRAM, index))

    @staticmethod
    def from_numpy(array: np.ndarray) -> 'dtensor':
        """Create tensor from numpy array on default device"""
        return dtensor(array)

    def numpy(self) -> np.ndarray:
        """Get tensor data as numpy array"""
        return get_storage_manager().get_tensor_data(self._handle)

    def __repr__(self) -> str:
        # There is no ``data`` attribute on this class; the contents have to
        # be fetched from the storage manager via numpy().
        return f"dtensor({self.numpy()}, device='{self.device}')"

class TopK:
    """Efficient top-k computation on device"""
    def __init__(self, k: int, dim: int = -1):
        """Store the selection size ``k`` and the dimension to select along."""
        self.k = k
        self.dim = dim

    def __call__(self, tensor: 'dtensor') -> Tuple['dtensor', 'dtensor']:
        """Return top k values and indices.

        The computation is delegated to the storage manager's ``topk``; the
        two handles it returns are wrapped in ``dtensor`` objects placed on
        the same device as ``tensor``.
        """
        storage = get_storage_manager()
        values_handle, indices_handle = storage.topk(
            tensor._handle,
            k=self.k,
            dim=self.dim,
            device_type=tensor.device.type.value,
            device_index=tensor.device.index
        )
        return (
            self._wrap(values_handle, tensor.device),
            self._wrap(indices_handle, tensor.device),
        )

    @staticmethod
    def _wrap(handle, device) -> 'dtensor':
        """Build a dtensor around an existing storage handle.

        ``dtensor.__init__`` always allocates fresh storage, which would be
        thrown away immediately when the handle is replaced, so the
        constructor is bypassed and the same attributes it sets are filled
        in directly with their default values.
        """
        result = dtensor.__new__(dtensor)
        result.device = device
        result.requires_grad = False
        result.pin_memory = False
        result.grad = None
        result._handle = handle
        return result

class Module:
    """Base class for all neural network modules.

    A module keeps three registries -- parameters, buffers and child
    modules -- plus the device that newly registered tensors are moved to.
    """
    def __init__(self):
        self.training = True
        self._parameters = {}
        self._buffers = {}
        self._modules = {}
        # DeviceType only defines VGPU and VRAM (there is no CPU member), so
        # start on the default device that Device.parse() yields for None.
        self._device = Device.parse(None)

    def register_parameter(self, name: str, param: Optional['dtensor']) -> None:
        """Register a parameter with the module.

        A non-None parameter is moved to the module's current device.
        """
        if param is None:
            self._parameters[name] = None
        else:
            self._parameters[name] = param.to(self._device)

    def register_buffer(self, name: str, tensor: Optional['dtensor']) -> None:
        """Register a persistent buffer.

        A non-None buffer is moved to the module's current device.
        """
        if tensor is None:
            self._buffers[name] = None
        else:
            self._buffers[name] = tensor.to(self._device)

    def add_module(self, name: str, module: Optional['Module']) -> None:
        """Register a child module"""
        self._modules[name] = module

    def to(self, device: Union[str, 'Device']) -> 'Module':
        """Move module, its tensors and its children to specified device.

        Returns ``self`` so calls can be chained.
        """
        device = Device.parse(device)
        self._device = device

        # dtensor.to() relocates the tensor in place, so the registries keep
        # pointing at the same (now moved) objects.
        for param in self._parameters.values():
            if param is not None:
                param.to(device)

        for buffer in self._buffers.values():
            if buffer is not None:
                buffer.to(device)

        for module in self._modules.values():
            if module is not None:
                module.to(device)

        return self

    def train(self, mode: bool = True) -> 'Module':
        """Set training mode on this module and all child modules"""
        self.training = mode
        for module in self._modules.values():
            if module is not None:
                module.train(mode)
        return self

    def eval(self) -> 'Module':
        """Set evaluation mode"""
        return self.train(False)

    def parameters(self) -> List['dtensor']:
        """Get all parameters: own first, then those of child modules"""
        params = [p for p in self._parameters.values() if p is not None]
        for module in self._modules.values():
            if module is not None:
                params.extend(module.parameters())
        return params

    def state_dict(self) -> Dict[str, Any]:
        """Get module state.

        Returns:
            A dict mapping each registered parameter/buffer name to the
            result of its ``numpy()`` call, and each child-module name to
            that child's own state dict. ``None`` entries are omitted.
        """
        state = {}
        # numpy() already hands back the tensor data, and dtensor has no
        # cpu() method, so no intermediate device hop is made.
        for name, param in self._parameters.items():
            if param is not None:
                state[name] = param.numpy()
        for name, buffer in self._buffers.items():
            if buffer is not None:
                state[name] = buffer.numpy()
        for name, module in self._modules.items():
            if module is not None:
                state[name] = module.state_dict()
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load module state.

        Args:
            state_dict: Mapping shaped like the output of ``state_dict()``.

        Raises:
            KeyError: If a key matches no registered parameter, buffer or
                child module. Keys whose registered entry is ``None`` are
                skipped silently.
        """
        for name, value in state_dict.items():
            if name in self._parameters:
                if self._parameters[name] is not None:
                    self._overwrite(self._parameters[name], value)
            elif name in self._buffers:
                if self._buffers[name] is not None:
                    self._overwrite(self._buffers[name], value)
            elif name in self._modules:
                if self._modules[name] is not None:
                    self._modules[name].load_state_dict(value)
            else:
                raise KeyError(f"Unexpected key in state_dict: {name}")

    @staticmethod
    def _overwrite(target: 'dtensor', array: Any) -> None:
        """Replace the contents of ``target`` with ``array``, keeping the object.

        dtensor offers no in-place copy, so the data is allocated as a new
        tensor on the target's device (cast to the target's dtype) and the
        existing object is re-pointed at that storage. Anything that already
        references ``target`` therefore sees the loaded values.
        """
        fresh = dtensor(
            array,
            device=target.device,
            dtype=target.dtype,
            requires_grad=target.requires_grad,
            pin_memory=target.pin_memory
        )
        # NOTE(review): the previous handle is dropped without an explicit
        # release, as dtensor.to() does -- confirm the storage manager
        # reclaims unreferenced handles.
        target._handle = fresh._handle

# Lazily created MemoryManager shared by every caller in this module.
_storage_manager = None


def get_storage_manager():
    """Get global storage manager instance.

    The ``MemoryManager`` is created on first use and the same object is
    returned on every later call, so handles produced by one call remain
    valid for the next.

    Note: this definition rebinds the ``get_storage_manager`` name imported
    from ``driver_api`` at the top of the file, so code in this module uses
    this function.
    """
    global _storage_manager
    if _storage_manager is None:
        # Imported lazily, inside the function, as in the original code.
        from virtual_gpu_driver.src.memory.memory_manager import MemoryManager
        _storage_manager = MemoryManager()
    return _storage_manager
