from typing import Optional, Union, Tuple, List, Dict
import numpy as np
from dataclasses import dataclass
from enum import Enum
import warnings
from .core.db_manager import HeliumDBManager
from virtual_gpu_driver.src.ai.tensor_types import Tensor, Device, DType
import hashlib
import json
from functools import lru_cache

class NormType(Enum):
    """Closed set of normalization schemes understood by this module.

    Each member's value is its lowercase name. The value string (not the
    member itself) is what gets hashed into cache keys and written to cache
    metadata, so these strings must stay stable.
    """

    BATCH = "batch"
    LAYER = "layer"
    GROUP = "group"
    INSTANCE = "instance"
    RMS = "rms"

def normalize(input: Tensor,
              mean: Optional[Tensor] = None,
              variance: Optional[Tensor] = None,
              weight: Optional[Tensor] = None,
              bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Normalize ``input`` as ``(input - mean) / sqrt(variance + eps)``, then
    optionally scale by ``weight`` and shift by ``bias``.

    A statistic that is supplied is always used as given. Only the missing
    one(s) are computed from ``input``, over every axis except the last one
    (``keepdims=True``). The result is one value per position along the last
    axis. A computed ``variance`` is the input's own variance about its own
    mean, even when ``mean`` was supplied.

    The tensor type is expected to provide ``ndim``, ``mean``/``var`` taking
    ``axis``/``keepdims`` keywords, and a ``sqrt()`` method on the variance.

    Args:
        input: Input tensor
        mean: Optional pre-computed mean (used as-is when given)
        variance: Optional pre-computed variance (used as-is when given)
        weight: Optional scale parameter
        bias: Optional bias parameter
        eps: Small constant added to the variance for numerical stability

    Returns:
        Tuple of (normalized tensor, mean used, variance used)
    """
    if mean is None or variance is None:
        # Reduce over all axes except the last. For a 1-D input this tuple is
        # empty, i.e. no axis is reduced.
        axes = tuple(range(input.ndim - 1))
        # Compute only what was not supplied; never discard a caller's value.
        if mean is None:
            mean = input.mean(axis=axes, keepdims=True)
        if variance is None:
            variance = input.var(axis=axes, keepdims=True)

    denom = (variance + eps).sqrt()
    normalized = (input - mean) / denom

    # Optional affine transform.
    if weight is not None:
        normalized = normalized * weight
    if bias is not None:
        normalized = normalized + bias

    return normalized, mean, variance

@dataclass
class NormConfig:
    """Configuration for normalization layers.

    ``__post_init__`` rejects configurations that cannot describe a usable
    layer: a non-positive ``num_features``, or, for group norm, a
    ``num_groups`` that is non-positive or does not divide ``num_features``.
    Without this check, such configurations only fail later with an opaque
    reshape error.
    """
    norm_type: NormType
    num_features: int  # number of features; also the length of gamma/beta
    eps: float = 1e-5  # added to the variance before taking the square root
    momentum: float = 0.1  # weight of the newest statistics in running averages
    affine: bool = True  # allocate learnable gamma (scale) and beta (shift)
    num_groups: int = 32  # For group norm
    track_running_stats: bool = True
    # A NumPy dtype or scalar type (e.g. np.float32) used for gamma/beta.
    dtype: Union[np.dtype, type] = np.float32
    use_cache: bool = True

    def __post_init__(self):
        """Validate field values.

        Raises:
            ValueError: if ``num_features`` is not positive, or if the norm
                type is GROUP and ``num_groups`` is not a positive divisor of
                ``num_features``.
        """
        if self.num_features <= 0:
            raise ValueError(
                f"num_features must be positive, got {self.num_features}"
            )
        # Check num_groups <= 0 first so the modulo below cannot divide by zero.
        if self.norm_type == NormType.GROUP and (
            self.num_groups <= 0 or self.num_features % self.num_groups != 0
        ):
            raise ValueError(
                f"num_features ({self.num_features}) must be divisible by "
                f"num_groups ({self.num_groups})"
            )

class NormalizationCache:
    """Cache manager for normalization computations.

    Persisted activations go through the shared ``HeliumDBManager`` instance.
    Running mean/variance estimates are kept in memory, in two dicts keyed by
    an arbitrary caller-chosen string.
    """
    def __init__(self):
        self.db = HeliumDBManager.get_instance()
        self.running_means: Dict[str, np.ndarray] = {}
        self.running_vars: Dict[str, np.ndarray] = {}

    def _compute_key(self, x: np.ndarray, norm_type: "NormType") -> str:
        """Return a SHA-256 hex digest identifying ``x`` and ``norm_type``.

        The dtype (including byte order) and shape are hashed along with the
        raw bytes. Otherwise the same buffer viewed with a different shape,
        e.g. (2, 3) vs (3, 2), or reinterpreted as another dtype would map
        to the same key and return another input's cached output.
        """
        hasher = hashlib.sha256()
        hasher.update(norm_type.value.encode())
        # Metadata before data: the data length then follows from shape and
        # dtype, so distinct inputs cannot produce the same byte stream.
        hasher.update(f"|{x.dtype.str}|{x.shape}|".encode())
        hasher.update(x.tobytes())
        return hasher.hexdigest()

    def get(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Return the stored entry for ``key`` from the DB manager.

        NOTE(review): callers treat ``None`` as a cache miss. Presumably
        ``get_activation`` returns ``None`` for unknown keys; confirm.
        """
        return self.db.get_activation(key)

    def set(self, key: str, value: Dict[str, np.ndarray], metadata: Dict):
        """Store ``value`` (with ``metadata``) under ``key`` via the DB manager."""
        self.db.set_activation(key, value, metadata)

    def update_running_stats(
        self,
        key: str,
        mean: np.ndarray,
        var: np.ndarray,
        momentum: float
    ):
        """Fold ``mean``/``var`` into the running estimates stored under ``key``.

        The first call for a key stores the statistics directly. Later calls
        blend them as ``(1 - momentum) * running + momentum * new``.
        """
        if key not in self.running_means:
            # Store copies so later in-place edits of the caller's arrays
            # cannot silently corrupt the running estimates.
            self.running_means[key] = np.copy(mean)
            self.running_vars[key] = np.copy(var)
            return

        self.running_means[key] = (
            (1 - momentum) * self.running_means[key] + momentum * mean
        )
        self.running_vars[key] = (
            (1 - momentum) * self.running_vars[key] + momentum * var
        )

class Normalization:
    """
    Unified normalization layer for batch, layer, group, instance and RMS
    normalization.

    Features:
    - Optional acceleration: reductions are delegated to ``driver`` when it
      provides the needed ops (``reduce_mean``, plus ``reduce_var`` for
      mean/variance statistics).
    - Inference-output caching through ``NormalizationCache`` (``use_cache``).
    - Running mean/variance tracking for batch norm. Once accumulated, these
      estimates are used at inference time.

    Statistics per type (``C`` = ``num_features``):
    - BATCH: per channel (axis 1), reduced over the batch and spatial axes.
    - LAYER / RMS: per sample, reduced over the last axis (size ``C``).
      RMS uses ``mean(x**2)`` and subtracts no mean.
    - GROUP: per sample and group, over the group's channels and spatial axes.
    - INSTANCE: per sample and channel, over the spatial axes.
    """

    def __init__(
        self,
        config: NormConfig,
        driver = None
    ):
        """Initialize the layer.

        Args:
            config: Layer configuration.
            driver: Optional accelerator object used for reductions when it
                exposes the required ops (see ``_compute_stats``).
        """
        self.config = config
        self.driver = driver
        self.cache = NormalizationCache()

        # Learnable per-feature scale (gamma) and shift (beta).
        if config.affine:
            self.gamma = np.ones(config.num_features, dtype=config.dtype)
            self.beta = np.zeros(config.num_features, dtype=config.dtype)
        else:
            self.gamma = None
            self.beta = None

    @staticmethod
    @lru_cache(maxsize=128)
    def _get_reshape_dims(
        input_shape: Tuple[int, ...],
        num_features: int,
        norm_type: NormType
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
        """Return ``(param_shape, reduction_axes)`` for an input shape.

        ``param_shape`` broadcasts gamma/beta against the input, and
        ``reduction_axes`` are the axes statistics are computed over. For
        GROUP, ``normalize`` recomputes the axes after regrouping channels.
        """
        ndim = len(input_shape)
        channel_param_shape = (1, num_features) + (1,) * (ndim - 2)
        if norm_type == NormType.BATCH:
            return channel_param_shape, (0,) + tuple(range(2, ndim))
        if norm_type in (NormType.LAYER, NormType.RMS):
            # Per-sample statistics over the trailing feature axis, which is
            # also the axis gamma/beta are laid out along.
            return (1,) * (ndim - 1) + (num_features,), (ndim - 1,)
        # GROUP, INSTANCE: per-sample, per-channel over the spatial axes.
        return channel_param_shape, tuple(range(2, ndim))

    def _check_input(self, x: np.ndarray):
        """Validate ``x`` against the configuration.

        Raises:
            ValueError: if any of the following holds:
                - ``x`` is not at least 2-D.
                - Its feature axis does not match ``num_features``. The
                  feature axis is axis 1 for BATCH/GROUP/INSTANCE and the
                  last axis for LAYER/RMS.
                - An INSTANCE input has no spatial axis.
                - For GROUP, ``num_features`` is not divisible by
                  ``num_groups``.
        """
        if x.ndim < 2:
            raise ValueError(f"Expected at least 2D input, got shape {x.shape}")

        norm_type = self.config.norm_type
        num_features = self.config.num_features
        feature_axis = -1 if norm_type in (NormType.LAYER, NormType.RMS) else 1
        if x.shape[feature_axis] != num_features:
            raise ValueError(
                f"Expected {num_features} features, got {x.shape[feature_axis]}"
            )

        if norm_type == NormType.INSTANCE and x.ndim < 3:
            raise ValueError(
                "Instance norm expects at least 3D input (N, C, ...), "
                f"got shape {x.shape}"
            )

        if norm_type == NormType.GROUP:
            groups = self.config.num_groups
            if groups <= 0 or num_features % groups != 0:
                raise ValueError(
                    f"num_features ({num_features}) must be divisible by "
                    f"num_groups ({groups})"
                )

    def _driver_supports(self, *ops: str) -> bool:
        """Return True if a driver is attached and provides every op in ``ops``."""
        return bool(self.driver) and all(hasattr(self.driver, op) for op in ops)

    def _compute_stats(
        self,
        x: np.ndarray,
        reduction_axes: Tuple[int, ...]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute mean and (biased) variance over ``reduction_axes``.

        The driver is used only if it has both ``reduce_mean`` and
        ``reduce_var``. Otherwise NumPy is used, so a driver lacking
        ``reduce_var`` no longer raises AttributeError.
        """
        if self._driver_supports('reduce_mean', 'reduce_var'):
            mean = self.driver.reduce_mean(x, axis=reduction_axes, keepdims=True)
            var = self.driver.reduce_var(x, axis=reduction_axes, keepdims=True)
        else:
            mean = np.mean(x, axis=reduction_axes, keepdims=True)
            var = np.var(x, axis=reduction_axes, keepdims=True)
        return mean, var

    def _compute_rms_stats(
        self,
        x: np.ndarray,
        reduction_axes: Tuple[int, ...]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(zeros, mean(x**2))`` so RMS fits ``(x - mean) / sqrt(var + eps)``."""
        if self._driver_supports('reduce_mean'):
            mean_sq = self.driver.reduce_mean(x * x, axis=reduction_axes, keepdims=True)
        else:
            mean_sq = np.mean(np.square(x), axis=reduction_axes, keepdims=True)
        return np.zeros_like(mean_sq), mean_sq

    def _running_key(self) -> str:
        """Key under which this layer's running statistics are stored."""
        return str(id(self))

    def _running_stats(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        """Return ``(running_mean, running_var)``, or None if nothing has been tracked yet."""
        key = self._running_key()
        if key not in self.cache.running_means:
            return None
        return self.cache.running_means[key], self.cache.running_vars[key]

    def _cache_key(
        self,
        x: np.ndarray,
        running: Optional[Tuple[np.ndarray, np.ndarray]]
    ) -> str:
        """Build the inference-cache key from the input and all state that shapes the output.

        Hashing only the input would let layers with different gamma/beta,
        eps, num_groups or running statistics return each other's outputs.
        """
        cfg = self.config
        hasher = hashlib.sha256()
        hasher.update(self.cache._compute_key(x, cfg.norm_type).encode())
        hasher.update(repr((
            x.shape, str(x.dtype), cfg.num_features, cfg.num_groups,
            cfg.eps, cfg.affine, running is not None,
        )).encode())
        state = [self.gamma, self.beta]
        if running is not None:
            state.extend(running)
        for arr in state:
            if arr is not None:
                hasher.update(np.ascontiguousarray(arr).tobytes())
        return hasher.hexdigest()

    def normalize(
        self,
        x: np.ndarray,
        training: bool = True
    ) -> np.ndarray:
        """
        Apply normalization to ``x``.

        Args:
            x: Input array, at least 2-D (see ``_check_input``).
            training: Selects training or inference behavior.
                - When True, statistics come from ``x``. For batch norm with
                  ``track_running_stats``, they are also folded into the
                  running estimates.
                - When False, batch norm uses the running estimates once any
                  exist. Outputs are looked up in and stored to the cache
                  when ``use_cache`` is set.

        Returns:
            The normalized array, scaled and shifted when ``affine`` is set,
            with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` fails ``_check_input``.
        """
        self._check_input(x)
        config = self.config
        norm_type = config.norm_type
        input_shape = x.shape

        # Only batch norm has statistics that are meaningful across batches;
        # the other types compute per-sample statistics, whose shapes change
        # with the batch size.
        tracks_running = norm_type == NormType.BATCH and config.track_running_stats
        running = self._running_stats() if (tracks_running and not training) else None

        cache_key = None
        if config.use_cache and not training:
            cache_key = self._cache_key(x, running)
            cached = self.cache.get(cache_key)
            if cached is not None:
                return cached['output']

        param_shape, reduction_axes = self._get_reshape_dims(
            input_shape,
            config.num_features,
            norm_type
        )

        # Group norm: split channels into (groups, channels_per_group) and
        # reduce over each group's channels plus the spatial axes.
        if norm_type == NormType.GROUP:
            N, C = input_shape[:2]
            groups = config.num_groups
            x = x.reshape(N, groups, C // groups, *input_shape[2:])
            reduction_axes = tuple(range(2, x.ndim))

        if running is not None:
            # Reshape so the stored per-channel estimates broadcast against
            # this input's rank, whatever rank they were tracked at.
            mean = np.reshape(running[0], param_shape)
            var = np.reshape(running[1], param_shape)
        elif norm_type == NormType.RMS:
            mean, var = self._compute_rms_stats(x, reduction_axes)
        else:
            mean, var = self._compute_stats(x, reduction_axes)

        if training and tracks_running:
            self.cache.update_running_stats(
                self._running_key(),
                mean,
                var,
                config.momentum
            )

        x_norm = (x - mean) / np.sqrt(var + config.eps)

        if norm_type == NormType.GROUP:
            x_norm = x_norm.reshape(input_shape)

        if config.affine:
            out = self.gamma.reshape(param_shape) * x_norm + self.beta.reshape(param_shape)
        else:
            out = x_norm

        if cache_key is not None:
            self.cache.set(
                cache_key,
                {
                    'output': out,
                    'mean': mean,
                    'var': var
                },
                {
                    # Original input shape, not the regrouped GROUP view.
                    'shape': input_shape,
                    'dtype': str(x.dtype),
                    'norm_type': norm_type.value
                }
            )

        return out

    @classmethod
    def batch_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Batch-normalize ``x`` with a fresh layer (channels on axis 1), in training mode."""
        config = NormConfig(
            norm_type=NormType.BATCH,
            num_features=num_features or x.shape[1],
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def layer_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Layer-normalize ``x`` over its last axis with a fresh layer, in training mode."""
        config = NormConfig(
            norm_type=NormType.LAYER,
            num_features=num_features or x.shape[-1],
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def group_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        num_groups: int = 32,
        **kwargs
    ) -> np.ndarray:
        """Group-normalize ``x`` (channels on axis 1) with a fresh layer, in training mode."""
        config = NormConfig(
            norm_type=NormType.GROUP,
            num_features=num_features or x.shape[1],
            num_groups=num_groups,
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def instance_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Instance-normalize ``x`` (channels on axis 1) with a fresh layer, in training mode."""
        config = NormConfig(
            norm_type=NormType.INSTANCE,
            num_features=num_features or x.shape[1],
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def rms_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """RMS-normalize ``x`` over its last axis with a fresh layer, in training mode."""
        # RMS norm doesn't use running stats. Use setdefault rather than a
        # literal keyword so a caller-supplied value doesn't raise TypeError.
        kwargs.setdefault('track_running_stats', False)
        config = NormConfig(
            norm_type=NormType.RMS,
            num_features=num_features or x.shape[-1],
            **kwargs
        )
        return cls(config).normalize(x)
