from typing import Optional, Union, Dict, Any, TYPE_CHECKING
import numpy as np
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, Device, DType, Layout
from virtual_gpu_driver.src.stream import Stream
from .module import HeliumModule
from .core.db_manager import HeliumDBManager

if TYPE_CHECKING:
    from .tensor import HeliumTensor

class HeliumLayerNorm(HeliumModule):
    """
    Hardware-accelerated Layer Normalization implementation

    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper "Layer Normalization" [Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton]

    y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta

    where gamma (weight) and beta (bias) are learnable parameters.

    All computations are issued to the driver; tensors are referred to by the
    handles returned from ``driver.allocate_tensor``.

    Tensor ownership:
        * ``weight`` / ``bias`` are allocated in ``__init__`` and freed in
          ``__del__``.
        * Scratch tensors are tracked in ``_temp_tensors`` (local id ->
          driver handle) and freed as soon as ``forward`` is done with them.
        * The tensor returned by ``forward`` is handed over to the caller and
          is no longer tracked by this module.
    """

    def __init__(
        self,
        normalized_shape: int,
        eps: float = 1e-5,
        device_id: int = 0,
        dtype: str = "float32"
    ):
        """
        Initialize layer normalization module

        Args:
            normalized_shape: Size of the last dimension
            eps: Small value added to variance for numerical stability
            device_id: Virtual GPU device ID
            dtype: Data type for computations (name of a ``DType`` member,
                case-insensitive)
        """
        # NOTE(review): self.driver and self.dtype are presumably set up by
        # HeliumModule.__init__ -- confirm against the base class.
        super().__init__(device_id=device_id, dtype=dtype)

        self.normalized_shape = normalized_shape
        self.eps = eps

        # Create parameter tensors
        # NOTE(review): the parameters are only allocated here, never filled;
        # presumably gamma/beta values are loaded elsewhere -- confirm.
        self.weight = self._create_param(normalized_shape)
        self.bias = self._create_param(normalized_shape)

        # Create stream for async execution
        self.stream = Stream(self.driver)

        # Get database manager instance
        self.db = HeliumDBManager.get_instance()

        # Scratch-tensor registry: local id ("ln_temp_<n>") -> driver handle
        self._temp_tensors: Dict[str, Any] = {}
        self._counter = 0

    def _describe(self, shape: tuple) -> "TensorDescriptor":
        """Build a row-major VGPU descriptor of the module's dtype for `shape`."""
        return TensorDescriptor(
            shape=shape,
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )

    def _create_param(self, size: int) -> str:
        """Allocate a 1-D parameter tensor of length `size`; returns its driver handle"""
        return self.driver.allocate_tensor(self._describe((size,)))

    def _get_temp_tensor(self, shape: tuple) -> str:
        """
        Allocate a temporary tensor for intermediate computations

        Returns the *local* id under which the tensor is registered in
        ``self._temp_tensors``; the driver handle to pass to driver calls is
        ``self._temp_tensors[<id>]``.
        """
        tensor_id = f"ln_temp_{self._counter}"
        self._counter += 1

        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(
            self._describe(shape)
        )
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str):
        """Free a temporary tensor by local id (unknown ids are ignored)"""
        if tensor_id in self._temp_tensors:
            self.driver.free_tensor(self._temp_tensors[tensor_id])
            del self._temp_tensors[tensor_id]

    def __del__(self):
        """Clean up allocated tensors"""
        # hasattr guards: __del__ also runs when __init__ failed part-way.
        if hasattr(self, '_temp_tensors'):
            for tensor_id in list(self._temp_tensors.keys()):
                self._free_temp_tensor(tensor_id)

        # Free parameter tensors
        if hasattr(self, 'weight'):
            self.driver.free_tensor(self.weight)
        if hasattr(self, 'bias'):
            self.driver.free_tensor(self.bias)

    def _check_input_shape(self, input_shape: tuple):
        """
        Validate input shape

        Raises:
            ValueError: if the shape is empty or its last dimension differs
                from ``normalized_shape``.
        """
        if len(input_shape) == 0:
            raise ValueError(
                f"Expected input with at least one dimension of size "
                f"{self.normalized_shape}, got shape {tuple(input_shape)}"
            )
        if input_shape[-1] != self.normalized_shape:
            raise ValueError(
                f"Expected last dimension to be {self.normalized_shape}, "
                f"got {input_shape[-1]}"
            )

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        scale: Optional[Union[str, "HeliumTensor"]] = None,
        offset: Optional[Union[str, "HeliumTensor"]] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Apply layer normalization

        All computations are done in driver memory.

        Args:
            input_tensor: Input of shape (*, normalized_shape)
            scale: Optional override for weight parameter
            offset: Optional override for bias parameter

        Returns:
            Driver handle of the normalized tensor (same shape as input).
            Ownership passes to the caller: the tensor is no longer tracked
            by this module and is not freed by ``__del__``.

        Raises:
            ValueError: if the input's last dimension is not
                ``normalized_shape``.
        """
        input_shape = tuple(self.driver.get_tensor_shape(input_tensor))
        self._check_input_shape(input_shape)

        # The reductions below run with keepdims=True, so their result keeps a
        # trailing axis of length 1.
        stats_shape = input_shape[:-1] + (1,)

        scale_tensor = scale if scale is not None else self.weight
        offset_tensor = offset if offset is not None else self.bias

        with self.stream:
            # Local ids of every temp allocated by this call that we still
            # own; whatever is left in here is freed on the way out, also
            # when a driver call raises.
            owned = []
            try:
                # Calculate mean
                mean_id = self._get_temp_tensor(stats_shape)
                owned.append(mean_id)
                mean = self._temp_tensors[mean_id]
                self.driver.reduce_mean(
                    input_tensor,
                    mean,
                    axis=-1,
                    keepdims=True
                )

                # Calculate variance
                variance_id = self._get_temp_tensor(stats_shape)
                owned.append(variance_id)
                variance = self._temp_tensors[variance_id]
                self.driver.reduce_variance(
                    input_tensor,
                    variance,
                    mean,
                    axis=-1,
                    keepdims=True
                )

                # Normalize
                normalized_id = self._get_temp_tensor(input_shape)
                owned.append(normalized_id)
                normalized = self._temp_tensors[normalized_id]
                self.driver.normalize(
                    input_tensor,
                    mean,
                    variance,
                    normalized,
                    eps=self.eps
                )

                # Scale and offset
                output_id = self._get_temp_tensor(input_shape)
                owned.append(output_id)
                output = self._temp_tensors[output_id]
                self.driver.scale_and_shift(
                    normalized,
                    scale_tensor,
                    offset_tensor,
                    output
                )

                # Success: hand the output over to the caller so that neither
                # the cleanup below nor __del__ frees it.
                owned.remove(output_id)
                del self._temp_tensors[output_id]
            finally:
                # Clean up intermediate tensors (kept inside the stream
                # context, as the frees were originally issued there).
                for tensor_id in owned:
                    self._free_temp_tensor(tensor_id)

        return output

    def compute_variance(self, state, driver, x_name, mean_name, gamma_name, beta_name, eps=1e-5):
        """
        Layer-normalize `x_name` in driver memory from a precomputed mean

        Despite its name this runs the whole pipeline with elementwise driver
        ops: variance of (x - mean), division by sqrt(var + eps), then
        multiplication by gamma and addition of beta.

        Args:
            state: Object providing ``get_temp_tensor(tensor, label)`` and
                ``free_temp_tensor(name)`` for intermediate bookkeeping
            driver: Driver providing sub / mul / mean / add_scalar / sqrt /
                div / add
            x_name: Input tensor
            mean_name: Mean of the input; NOTE this tensor is freed here via
                ``state.free_temp_tensor``
            gamma_name: Scale tensor
            beta_name: Offset tensor
            eps: Value added to the variance before the square root (the
                module's ``self.eps`` is not consulted)

        Returns:
            Name of the normalized, scaled and shifted tensor as registered
            through ``state.get_temp_tensor``
        """
        diff_name = state.get_temp_tensor(
            driver.sub(x_name, mean_name),
            "diff"
        )
        squared_name = state.get_temp_tensor(
            driver.mul(diff_name, diff_name),
            "squared"
        )
        var_name = state.get_temp_tensor(
            driver.mean(squared_name, axis=-1, keepdims=True),
            "var"
        )

        # Free intermediates
        state.free_temp_tensor(squared_name)

        # Normalize in driver memory
        std_name = state.get_temp_tensor(
            driver.sqrt(driver.add_scalar(var_name, eps)),
            "std"
        )
        normalized_name = state.get_temp_tensor(
            driver.div(diff_name, std_name),
            "normalized"
        )

        # Free more intermediates
        state.free_temp_tensor(diff_name)
        state.free_temp_tensor(mean_name)
        state.free_temp_tensor(var_name)
        state.free_temp_tensor(std_name)

        # Scale and shift in driver memory
        scaled_name = state.get_temp_tensor(
            driver.mul(normalized_name, gamma_name),
            "scaled"
        )
        output_name = state.get_temp_tensor(
            driver.add(scaled_name, beta_name),
            "output"
        )

        # Free final intermediates
        state.free_temp_tensor(normalized_name)
        state.free_temp_tensor(scaled_name)

        return output_name

def layer_norm(
    input_tensor: Union[str, "HeliumTensor"],
    normalized_shape: int,
    weight: Optional[Union[str, "HeliumTensor"]] = None,
    bias: Optional[Union[str, "HeliumTensor"]] = None,
    eps: float = 1e-5,
    device_id: int = 0,
    dtype: str = "float32"
) -> Union[str, "HeliumTensor"]:
    """
    Apply Layer Normalization over a mini-batch of inputs

    Functional wrapper: builds a throwaway ``HeliumLayerNorm`` and runs its
    ``forward`` once.

    Args:
        input_tensor: Input of shape (*, normalized_shape)
        normalized_shape: Size of last dimension to normalize over
        weight: Optional scale parameter
        bias: Optional offset parameter
        eps: Small value for numerical stability
        device_id: Virtual GPU device ID
        dtype: Data type for computations

    Returns:
        Driver handle of the normalized tensor (same shape as input). The
        caller owns it; it is not freed when the temporary module goes away.
    """
    module = HeliumLayerNorm(
        normalized_shape=normalized_shape,
        eps=eps,
        device_id=device_id,
        dtype=dtype
    )
    result = module.forward(input_tensor, weight, bias)

    # The module is destroyed when this function returns, and its __del__
    # frees every tensor still registered in module._temp_tensors. If the
    # output is still registered there (under its local id), detach it and
    # return the underlying driver handle so the caller does not receive a
    # tensor that has already been freed.
    temp_registry = getattr(module, "_temp_tensors", None)
    if temp_registry and isinstance(result, str) and result in temp_registry:
        return temp_registry.pop(result)
    return result