"""
Hardware-accelerated softmax implementation for Helium virtual GPU
"""
from typing import Optional, Union, Tuple, TYPE_CHECKING
import helium as he
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout, Tensor
import numpy as np

def softmax(x: Tensor, dim: int = -1) -> Tensor:
    """
    Compute softmax of a tensor along one dimension using tensor methods.

    NOTE: a second module-level ``softmax`` is defined further down in this
    file, so after import the module attribute ``softmax`` refers to that
    later definition, not to this one.

    Args:
        x: Input tensor
        dim: Dimension that is normalised (defaults to the last one)

    Returns:
        Tensor holding exp(x - max) divided by its sum along ``dim``
    """
    # Shifting by the per-slice maximum keeps exp() from overflowing;
    # the shift cancels out in the ratio below.
    peak = x.max(dim=dim, keepdim=True)
    numerator = (x - peak).exp()
    denominator = numerator.sum(dim=dim, keepdim=True)
    return numerator / denominator

class HeliumSoftmax:
    """
    Softmax operator that issues its work through a Helium device driver.

    The instance keeps a registry (``_temp_tensors``) that maps the string
    ids it hands out (``softmax_temp_<n>``) to the handles returned by
    ``driver.allocate_tensor``. Scratch tensors are released as soon as a
    ``forward`` call finishes; output tensors stay registered and are
    released when the instance is finalised.
    """
    def __init__(self, device_id: Optional[str] = None):
        # Set up the bookkeeping before touching the driver so that __del__
        # finds a consistent object even if the device lookup below raises.
        self._temp_tensors = {}
        self._counter = 0
        self.device_id = device_id

        # Get virtual GPU driver (default device when no id is given)
        self.driver = he.get_device(device_id) if device_id else he.get_default_device()

    def _get_temp_tensor(self, shape: Tuple[int, ...], dtype: str = "float32") -> str:
        """
        Allocate a tensor on the device and register it under a fresh id.

        Args:
            shape: Shape of the tensor to allocate
            dtype: Name of a ``DType`` member (case-insensitive). Objects
                exposing a ``name`` attribute are accepted as well, because
                ``forward`` forwards ``tensor_info.dtype`` unchanged.

        Returns:
            The id under which the allocation is tracked

        Raises:
            AttributeError: If ``DType`` has no member with that name
        """
        tensor_id = f"softmax_temp_{self._counter}"
        self._counter += 1

        # Normalise to a plain name so the DType lookup below works for
        # both strings and dtype-like objects.
        dtype_name = dtype if isinstance(dtype, str) else getattr(dtype, "name", str(dtype))

        descriptor = TensorDescriptor(
            shape=shape,
            dtype=getattr(DType, dtype_name.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )

        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(descriptor)
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str):
        """Release a tracked tensor; ids that are not tracked are ignored"""
        if tensor_id in self._temp_tensors:
            self.driver.free_tensor(self._temp_tensors[tensor_id])
            del self._temp_tensors[tensor_id]

    def __del__(self):
        """Release every tensor that is still tracked by this instance"""
        # The registry is missing when the object was never fully
        # initialised; there is nothing to release in that case.
        tracked = getattr(self, "_temp_tensors", None)
        if not tracked:
            return

        # Keep going after a failed release so one bad handle does not
        # strand the remaining allocations; the first error is re-raised
        # afterwards so it is still reported by the interpreter.
        first_error = None
        for tensor_id in list(tracked):
            try:
                self._free_temp_tensor(tensor_id)
            except Exception as err:
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        dim: int = -1,
        memory_efficient: bool = True,
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute softmax along specified dimension

        Args:
            input_tensor: Tensor name known to the driver, or an object
                exposing ``name``, ``shape`` and ``dtype.name``
            dim: Dimension to compute softmax over (-1 for last dim);
                must lie in ``[-rank, rank - 1]``
            memory_efficient: When True, compose the result from
                reduce_max / broadcast_sub / exp / reduce_sum /
                broadcast_div driver calls using two scratch tensors;
                when False, issue a single ``driver.softmax`` call.
                NOTE(review): the multi-kernel path allocates more tensors
                in this class than the direct path; the memory cost inside
                ``driver.softmax`` is not visible here - confirm the flag
                name matches reality.
            stream_id: Passed through to every driver call

        Returns:
            Id of the output tensor. The tensor stays registered with this
            instance and is released when the instance is finalised.

        Raises:
            IndexError: If ``dim`` is outside ``[-rank, rank - 1]``
        """
        # Get input tensor info
        if isinstance(input_tensor, str):
            tensor_name = input_tensor
            tensor_info = self.driver.get_tensor_info(tensor_name)
            shape = tensor_info.shape
            dtype = tensor_info.dtype
        else:
            tensor_name = input_tensor.name
            shape = input_tensor.shape
            dtype = input_tensor.dtype.name

        # Validate before normalising: adding the rank to a too-negative
        # dim would otherwise yield another negative index that silently
        # selects the wrong axis.
        rank = len(shape)
        if not -rank <= dim < rank:
            raise IndexError(
                f"dim {dim} is out of range for a tensor with {rank} dimension(s)"
            )
        if dim < 0:
            dim += rank

        scratch = []   # ids that must not outlive this call
        output = None
        try:
            if memory_efficient:
                # Per-slice maximum, kept as a size-1 axis for broadcasting
                reduced_shape = list(shape)
                reduced_shape[dim] = 1
                max_tensor = self._get_temp_tensor(tuple(reduced_shape), dtype)
                scratch.append(max_tensor)

                self.driver.reduce_max(
                    input=tensor_name,
                    output=max_tensor,
                    dim=dim,
                    stream_id=stream_id
                )

                # Subtract max (for numerical stability)
                shifted = self._get_temp_tensor(shape, dtype)
                scratch.append(shifted)
                self.driver.broadcast_sub(
                    input=tensor_name,
                    other=max_tensor,
                    output=shifted,
                    dim=dim,
                    stream_id=stream_id
                )

                # Exponentiate in place: the shifted values are not needed again
                exp_tensor = shifted
                self.driver.exp(
                    input=shifted,
                    output=exp_tensor,
                    stream_id=stream_id
                )

                # The maximum is no longer needed, so its buffer receives the sum
                sum_tensor = max_tensor
                self.driver.reduce_sum(
                    input=exp_tensor,
                    output=sum_tensor,
                    dim=dim,
                    stream_id=stream_id
                )

                # Final division
                output = self._get_temp_tensor(shape, dtype)
                self.driver.broadcast_div(
                    input=exp_tensor,
                    other=sum_tensor,
                    output=output,
                    dim=dim,
                    stream_id=stream_id
                )
            else:
                # Direct implementation: one driver call
                output = self._get_temp_tensor(shape, dtype)
                self.driver.softmax(
                    input=tensor_name,
                    output=output,
                    dim=dim,
                    stream_id=stream_id
                )
        except BaseException:
            # A half-written output is useless to the caller; release it
            # instead of leaving it registered until finalisation.
            if output is not None:
                self._free_temp_tensor(output)
            raise
        finally:
            # Release scratch buffers on success and on failure alike
            # (newest first: shifted/exp, then max/sum).
            # NOTE(review): buffers are released right after the last call
            # is issued; if stream execution is asynchronous, confirm that
            # free_tensor is ordered after the pending work.
            for tensor_id in reversed(scratch):
                self._free_temp_tensor(tensor_id)

        return output
            
# Operator instances reused by the functional ``softmax`` below, keyed by
# the ``device_id`` argument (``None`` stands for the default device).
_SOFTMAX_MODULES = {}


def softmax(
    x: Union[str, "HeliumTensor"],
    dim: int = -1,
    device_id: Optional[str] = None,
    memory_efficient: bool = True,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """
    Functional interface for softmax operation

    One ``HeliumSoftmax`` instance is kept per ``device_id`` and reused
    across calls. A throwaway instance would be finalised as soon as this
    function returned, and ``HeliumSoftmax.__del__`` frees every tensor the
    instance still tracks - including the output handed back to the caller.
    Reusing the instance also keeps its id counter running, so successive
    calls return distinct ids.

    NOTE(review): outputs therefore stay allocated for as long as the
    cached instance lives; this module exposes no public call for releasing
    a single output - confirm how callers are expected to free results.

    Args:
        x: Input tensor
        dim: Dimension to compute softmax over
        device_id: Virtual GPU device ID
        memory_efficient: Use memory-efficient algorithm
        stream_id: Optional stream for async execution

    Returns:
        Softmax output tensor
    """
    module = _SOFTMAX_MODULES.get(device_id)
    if module is None:
        module = HeliumSoftmax(device_id)
        _SOFTMAX_MODULES[device_id] = module
    return module.forward(x, dim, memory_efficient, stream_id)
