import numpy as np
from typing import Optional, Union, Tuple, Dict, Any
import math
from contextlib import contextmanager
from virtual_gpu_driver.src.ai.tensor_types import Tensor, Device, DType
from .module import HeliumModule

class HeliumGELU(HeliumModule):
    """
    Implements the Gaussian Error Linear Unit (GELU) activation function.
    GELU(x) = x * Φ(x) where Φ(x) is the Gaussian CDF.

    ``forward`` evaluates the tanh approximation
    0.5x(1 + tanh(sqrt(2/π)(x + 0.044715x³))) regardless of ``approximate``.

    Args:
        approximate: Requested approximation mode. Only ``'tanh'`` is
            implemented. Any other value is still stored on
            ``self.approximate`` for compatibility, but it emits a
            ``UserWarning`` because it would otherwise be silently ignored.
    """

    # Approximation modes that forward() actually implements.
    _SUPPORTED_APPROXIMATIONS = ('tanh',)

    def __init__(self, approximate: str = 'tanh'):
        super().__init__()
        if approximate not in self._SUPPORTED_APPROXIMATIONS:
            import warnings  # local import: only needed on this unusual path
            warnings.warn(
                f"HeliumGELU: approximate={approximate!r} is not implemented; "
                "the tanh approximation will be used.",
                UserWarning,
                stacklevel=2,
            )
        self.approximate = approximate

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of GELU activation (tanh approximation)."""
        # NOTE(review): ``gelu`` is looked up at call time. This file defines
        # it twice, and the later (tensor-name/array + optional driver)
        # version is the one bound. With no driver it only accepts an
        # ndarray, list or tuple — confirm ``Tensor`` inputs satisfy that.
        return gelu(x)

def gelu(x: Tensor) -> Tensor:
    """
    Apply GELU to ``x`` using the tanh approximation.

    Exact definition: GELU(x) = x * Φ(x), Φ being the standard normal CDF.
    Approximation evaluated here:
        GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))

    NOTE: a later module-level ``def gelu`` in this file rebinds this name,
    so once the module has loaded, ``gelu`` refers to that later function.

    Args:
        x: Input tensor. It must work with ``np.power``/``np.tanh`` and with
           the arithmetic operators used below.

    Returns:
        The GELU-activated values.
    """
    scale = np.sqrt(2 / np.pi)
    cubic_term = 0.044715 * np.power(x, 3)
    tanh_arg = scale * (x + cubic_term)
    # Approximate Φ(x) via 0.5 * (1 + tanh(...)).
    phi_approx = 0.5 * (1.0 + np.tanh(tanh_arg))
    return x * phi_approx

class TensorResourceManager:
    """
    Create and track temporary driver tensors so they are reliably deleted.

    Each temporary tensor is named ``{prefix}_temp_{n}_{suffix}``, where ``n``
    is a per-manager sequence number. The name is remembered until the tensor
    is freed.

    Args:
        driver: Object exposing ``create_tensor(name, data)``,
            ``tensor_exists(name)`` and ``delete_tensor(name)``.
        prefix: Leading component of every generated tensor name.
    """

    def __init__(self, driver, prefix: str):
        self.driver = driver
        self.prefix = prefix
        # Next sequence number; advanced once per temp_tensor() call.
        self.counter = 0
        # Names created through temp_tensor() that have not been freed yet.
        self._temp_tensors = set()

    @contextmanager
    def temp_tensor(self, data, name_suffix: str = ""):
        """
        Register ``data`` as a driver tensor for the duration of a ``with``.

        Yields:
            The generated tensor name. The tensor is freed when the ``with``
            block exits, whether normally or through an exception.
        """
        tensor_name = f"{self.prefix}_temp_{self.counter}_{name_suffix}"
        self.counter = self.counter + 1
        self.driver.create_tensor(tensor_name, data)
        self._temp_tensors.add(tensor_name)
        try:
            yield tensor_name
        finally:
            self.free_tensor(tensor_name)

    def free_tensor(self, name: str):
        """
        Delete ``name`` if this manager created it and has not freed it yet.

        Names the manager does not track are ignored. The driver is asked to
        delete the tensor only if it still reports it as existing.
        """
        if name not in self._temp_tensors:
            return
        if self.driver.tensor_exists(name):
            self.driver.delete_tensor(name)
        self._temp_tensors.remove(name)

    def cleanup(self):
        """Free every temporary tensor this manager still tracks."""
        # Snapshot first: free_tensor() mutates the set being iterated.
        for tensor_name in tuple(self._temp_tensors):
            self.free_tensor(tensor_name)

def validate_input(x: Union[str, np.ndarray], driver) -> Tuple[bool, str]:
    """
    Check whether ``x`` is an acceptable GELU input.

    When a driver is given, ``x`` is treated as a tensor name. The driver must
    report that it exists and is valid. Without a driver, ``x`` must be an
    ``ndarray``, ``list`` or ``tuple``.

    Returns:
        ``(True, "")`` if the input is acceptable, otherwise
        ``(False, message)`` describing the problem.
    """
    if driver is None:
        if isinstance(x, (np.ndarray, list, tuple)):
            return True, ""
        return False, "Input must be a numpy array when driver is None"

    if not driver.tensor_exists(x):
        return False, f"Tensor {x} does not exist"
    if not driver.is_valid_tensor(x):
        return False, f"Invalid tensor {x}"
    return True, ""

def gelu_numpy(x: np.ndarray) -> np.ndarray:
    """
    Compute GELU with NumPy using the tanh approximation.

        GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))

    Args:
        x: Array-like input.
            - A floating-point ``ndarray`` is evaluated and returned in
              ``np.promote_types(x.dtype, np.float32)``. float64 input keeps
              full precision; float16 is widened to float32.
            - Anything else (integer arrays, lists, tuples, Python scalars)
              is converted to float32.

    Returns:
        ``ndarray`` with the same shape as ``x`` containing GELU(x).
        ``-inf`` maps to ``0`` (the limit of GELU) and ``+inf`` to ``+inf``.
    """
    # Choose the working dtype up front. Never narrow a floating ndarray
    # below its own precision, and use float32 as the floor/default.
    if isinstance(x, np.ndarray) and np.issubdtype(x.dtype, np.floating):
        dtype = np.promote_types(x.dtype, np.float32)
    else:
        dtype = np.dtype(np.float32)
    arr = np.asarray(x, dtype=dtype)

    # Use a Python float (not np.float64) so NumPy's promotion rules keep
    # the computation in ``dtype`` instead of upcasting float32 to float64.
    sqrt_2_over_pi = float(np.sqrt(2.0 / np.pi))

    # For very large |x| the cube overflows to ±inf. tanh then saturates to
    # ±1, which yields the correct limits, so the overflow warning is noise.
    # ``-inf * 0`` yields nan ("invalid"); that case is patched below.
    with np.errstate(over="ignore", invalid="ignore"):
        inner = sqrt_2_over_pi * (arr + 0.044715 * (arr * arr * arr))
        result = 0.5 * arr * (1.0 + np.tanh(inner))

    # No small-|x| special case is needed: for tiny x this formula already
    # reduces to 0.5x(1 + sqrt(2/π)x) within working precision.
    result = np.where(np.isneginf(arr), 0.0, result)
    return result.astype(dtype, copy=False)

def gelu(
    x_name: Union[str, np.ndarray],
    driver = None,
    chip_id: int = 0,
    sm_id: int = 0
) -> Union[str, np.ndarray]:
    """
    Optimized GELU activation function with support for both driver-based and NumPy computation.

    This module-level definition rebinds the name ``gelu`` and replaces the
    earlier ``gelu(x: Tensor)`` defined in this file. Code that calls
    ``gelu`` at run time (including ``HeliumGELU.forward``) gets this version.

    The unfused driver path follows the tanh approximation
    0.5x(1 + tanh(sqrt(2/π)(x + 0.044715x³))).

    Args:
        x_name: Input tensor name (str) when ``driver`` is given. Otherwise
            a numpy array; ``validate_input`` also accepts a list or tuple,
            which is handed to ``gelu_numpy``.
        driver: Computation driver instance (optional). When ``None``, the
            result is computed by ``gelu_numpy``.
        chip_id: Chip identifier. Here it is only used to build the
            temporary and result tensor names.
        sm_id: Streaming multiprocessor identifier. Here it is only used to
            build the temporary and result tensor names.

    Returns:
        np.ndarray: Computed GELU values when ``driver`` is None.
        Whatever ``driver.fused_gelu(x_name)`` returns, if the driver has a
        ``fused_gelu`` attribute.
        str: Otherwise, the name ``gelu_result_{chip_id}_{sm_id}`` of the
        output tensor in driver memory.

    Raises:
        ValueError: If input validation fails
    """
    # Input validation
    is_valid, error_msg = validate_input(x_name, driver)
    if not is_valid:
        raise ValueError(error_msg)

    # NumPy fallback with optimized implementation
    if driver is None:
        return gelu_numpy(x_name)

    # Driver-based computation. Temporaries get this name prefix and are
    # deleted by the manager; the final result tensor (below) is not.
    manager = TensorResourceManager(driver, f"gelu_{chip_id}_{sm_id}")
    
    try:
        # Constants for the tanh approximation, as float32 scalars
        # (only used on the unfused path below).
        SQRT_2_DIV_PI = np.float32(np.sqrt(2 / np.pi))
        COEFF = np.float32(0.044715)

        # Prefer the driver's own GELU kernel when it provides one; its
        # return value is passed through unchanged.
        has_fused_ops = hasattr(driver, 'fused_gelu')
        if has_fused_ops:
            return driver.fused_gelu(x_name)

        # Unfused path: each driver-op result is registered as a temporary
        # tensor and deleted when its enclosing `with` exits.
        with manager.temp_tensor(driver.power(x_name, 3), "cube") as cube_name:
            # Intended: x + 0.044715x³.
            # NOTE(review): assumes fused_multiply_add(a, b, c) computes
            # a + b * c. The driver's signature is not visible here; confirm.
            with manager.temp_tensor(
                driver.fused_multiply_add(x_name, cube_name, COEFF),
                "inner"
            ) as inner_name:
                # Compute tanh(sqrt(2/π) * (x + 0.044715x³))
                # NOTE(review): unlike the other driver calls here,
                # driver.tanh receives the raw return value of
                # driver.mul_scalar rather than a registered tensor name.
                # Confirm that driver.tanh accepts this.
                with manager.temp_tensor(
                    driver.tanh(driver.mul_scalar(inner_name, SQRT_2_DIV_PI)),
                    "tanh"
                ) as tanh_name:
                    # Intended final computation: 0.5x(1 + tanh(...))
                    # NOTE(review): how the 4-argument call
                    # fused_multiply_add(x, tanh, 1.0, 0.5) maps onto this
                    # formula is not visible here; verify against the driver.
                    with manager.temp_tensor(
                        driver.fused_multiply_add(x_name, tanh_name, 1.0, 0.5),
                        "output"
                    ) as output_name:
                        # Copy the output into a tensor created directly on
                        # the driver (not via the manager), so it survives the
                        # temporary-tensor cleanup that `return` triggers.
                        # NOTE(review): this name depends only on chip_id and
                        # sm_id, so repeated calls reuse it. Confirm that
                        # driver.create_tensor tolerates an existing name.
                        result_name = f"gelu_result_{chip_id}_{sm_id}"
                        driver.create_tensor(result_name, driver.get_tensor_data(output_name))
                        return result_name

    finally:
        # Safety net: free any temporaries the manager still tracks (the
        # `with` blocks above normally free them already).
        manager.cleanup()
