"""
Hardware-accelerated pooling operations for Helium virtual GPU
"""
from typing import Optional, Union, Tuple, List, TYPE_CHECKING
from virtual_gpu_driver.src.ai.tensor_types import (
    TensorDescriptor, DType, Device, Layout,
    PoolingDescriptor, PoolingMode
)
from .main import get_device, get_default_device

if TYPE_CHECKING:
    from .main import HeliumTensor

class HeliumPooling2D:
    """Base class for 2D pooling operations on the Helium virtual GPU.

    Normalises the kernel/stride/padding configuration, acquires a device
    driver, and tracks the device tensors it allocates. Tracked tensors are
    released explicitly via ``_free_temp_tensor`` or automatically when the
    instance is garbage-collected (``__del__``).

    Args:
        kernel_size: Pooling window as ``k`` or ``(kernel_height, kernel_width)``.
        stride: Window step as ``s`` or ``(stride_height, stride_width)``.
            Defaults to the kernel size.
        padding: Padding per side as ``p`` or ``(padding_height, padding_width)``.
        device_id: Device to use; when falsy, the default device is used.

    Raises:
        ValueError: If a size is neither an int nor a pair, if the kernel or
            stride is not positive, or if the padding is negative.
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        padding: Union[int, Tuple[int, int]] = 0,
        device_id: Optional[str] = None
    ):
        # Set up bookkeeping before anything that can raise, so __del__
        # never sees a half-built instance that lacks these attributes.
        self._temp_tensors = {}
        self._counter = 0

        self.kernel_height, self.kernel_width = self._pair(kernel_size, "kernel_size")
        if stride is None:
            # Default to non-overlapping windows.
            self.stride_height, self.stride_width = self.kernel_height, self.kernel_width
        else:
            self.stride_height, self.stride_width = self._pair(stride, "stride")
        self.padding_height, self.padding_width = self._pair(padding, "padding")

        # Validate before touching the device. Otherwise a zero stride would
        # only surface later as a ZeroDivisionError in forward().
        if self.kernel_height <= 0 or self.kernel_width <= 0:
            raise ValueError(f"kernel_size must be positive, got {kernel_size!r}")
        if self.stride_height <= 0 or self.stride_width <= 0:
            raise ValueError(f"stride must be positive, got {stride!r}")
        if self.padding_height < 0 or self.padding_width < 0:
            raise ValueError(f"padding must be non-negative, got {padding!r}")

        # Get virtual GPU driver
        self.driver = get_device(device_id) if device_id else get_default_device()
        self.device_id = device_id

    @staticmethod
    def _pair(value, name: str) -> Tuple[int, int]:
        """Expand an int to ``(value, value)``; unpack anything else as a pair.

        Raises:
            ValueError: If ``value`` is neither an int nor a 2-element iterable.
        """
        if isinstance(value, int):
            return value, value
        try:
            first, second = value
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"{name} must be an int or a pair of ints, got {value!r}"
            ) from err
        return first, second

    @staticmethod
    def _resolve_dtype(dtype):
        """Map a dtype name such as ``"float32"`` to its ``DType`` member.

        Non-string values are passed through unchanged. The subclasses forward
        the driver's ``tensor_info.dtype`` as-is, which may already be a
        ``DType`` rather than a name.
        """
        if isinstance(dtype, str):
            return getattr(DType, dtype.upper())
        return dtype

    def _get_temp_tensor(
        self, shape: Tuple[int, ...], dtype: Union[str, "DType"] = "float32"
    ) -> str:
        """Allocate a device tensor and track it under a new local id.

        Args:
            shape: Tensor shape.
            dtype: ``DType`` member, or a dtype name (case-insensitive).

        Returns:
            The local id (``"pool_temp_<n>"``) under which the driver's
            allocation handle is stored.
        """
        tensor_id = f"pool_temp_{self._counter}"
        self._counter += 1

        descriptor = TensorDescriptor(
            shape=shape,
            dtype=self._resolve_dtype(dtype),
            device=Device.VGPU,
            # Use NHWC for better performance on GPU.
            # NOTE(review): the forward() implementations build NCHW-ordered
            # shapes but the layout is declared NHWC; confirm which one the
            # driver honours.
            layout=Layout.NHWC
        )

        # NOTE(review): only the driver's handle is stored; the local id is
        # what gets passed to the driver as ``output``. Ids also restart at 0
        # per instance. Confirm the driver can resolve them.
        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(descriptor)
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str) -> None:
        """Release a tracked tensor; unknown ids are ignored."""
        if tensor_id not in self._temp_tensors:
            return
        self.driver.free_tensor(self._temp_tensors[tensor_id])
        del self._temp_tensors[tensor_id]

    def __del__(self):
        """Release every tensor this instance still tracks."""
        # getattr: __del__ also runs for instances whose __init__ raised.
        temp_tensors = getattr(self, "_temp_tensors", None)
        if not temp_tensors:
            return
        for tensor_id in list(temp_tensors):
            self._free_temp_tensor(tensor_id)

    def _create_pooling_descriptor(self, mode: "PoolingMode") -> "PoolingDescriptor":
        """Bundle ``mode`` with this layer's window geometry for the driver."""
        return PoolingDescriptor(
            mode=mode,
            kernel_height=self.kernel_height,
            kernel_width=self.kernel_width,
            stride_height=self.stride_height,
            stride_width=self.stride_width,
            padding_height=self.padding_height,
            padding_width=self.padding_width
        )

class MaxPool2D(HeliumPooling2D):
    """2D max pooling layer (driver pooling kernel in ``PoolingMode.MAX``)."""

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute 2D max pooling.

        Args:
            input_tensor: Input tensor (NCHW format), given either as a driver
                tensor name or as a ``HeliumTensor``.
            stream_id: Optional stream for async execution.

        Returns:
            The id of a newly allocated output tensor of shape
            ``(N, C, H_out, W_out)`` with the input's dtype. The tensor is
            tracked by this module.

        Raises:
            ValueError: If the input is not 4-D, or the pooling window does not
                fit inside the padded input.
        """
        # Get input tensor info
        if isinstance(input_tensor, str):
            tensor_name = input_tensor
            tensor_info = self.driver.get_tensor_info(tensor_name)
            input_shape = tensor_info.shape
            # Normalise to a dtype name, matching the HeliumTensor branch below.
            # Plain strings have no ``.name`` and pass through unchanged.
            dtype = getattr(tensor_info.dtype, "name", tensor_info.dtype)
        else:
            tensor_name = input_tensor.name
            input_shape = input_tensor.shape
            dtype = input_tensor.dtype.name

        if len(input_shape) != 4:
            raise ValueError(
                f"MaxPool2D expects a 4-D NCHW input, got shape {tuple(input_shape)}"
            )
        batch_size, channels, in_height, in_width = input_shape

        # Standard floor-mode output size per spatial dimension.
        out_height = (in_height + 2 * self.padding_height - self.kernel_height) // self.stride_height + 1
        out_width = (in_width + 2 * self.padding_width - self.kernel_width) // self.stride_width + 1
        if out_height <= 0 or out_width <= 0:
            raise ValueError(
                f"MaxPool2D window {self.kernel_height}x{self.kernel_width} does not fit "
                f"input {in_height}x{in_width} with padding "
                f"{self.padding_height}x{self.padding_width}"
            )
        output_shape = (batch_size, channels, out_height, out_width)

        output = self._get_temp_tensor(output_shape, dtype)
        pool_desc = self._create_pooling_descriptor(PoolingMode.MAX)

        # Execute pooling on device
        self.driver.pooling_forward(
            pooling_desc=pool_desc,
            input=tensor_name,
            output=output,
            stream_id=stream_id
        )

        return output

class AvgPool2D(HeliumPooling2D):
    """2D average pooling layer (driver pooling kernel in ``PoolingMode.AVERAGE``)."""

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute 2D average pooling.

        Args:
            input_tensor: Input tensor (NCHW format), given either as a driver
                tensor name or as a ``HeliumTensor``.
            stream_id: Optional stream for async execution.

        Returns:
            The id of a newly allocated output tensor of shape
            ``(N, C, H_out, W_out)`` with the input's dtype. The tensor is
            tracked by this module.

        Raises:
            ValueError: If the input is not 4-D, or the pooling window does not
                fit inside the padded input.
        """
        # Get input tensor info
        if isinstance(input_tensor, str):
            tensor_name = input_tensor
            tensor_info = self.driver.get_tensor_info(tensor_name)
            input_shape = tensor_info.shape
            # Normalise to a dtype name, matching the HeliumTensor branch below.
            # Plain strings have no ``.name`` and pass through unchanged.
            dtype = getattr(tensor_info.dtype, "name", tensor_info.dtype)
        else:
            tensor_name = input_tensor.name
            input_shape = input_tensor.shape
            dtype = input_tensor.dtype.name

        if len(input_shape) != 4:
            raise ValueError(
                f"AvgPool2D expects a 4-D NCHW input, got shape {tuple(input_shape)}"
            )
        batch_size, channels, in_height, in_width = input_shape

        # Standard floor-mode output size per spatial dimension.
        out_height = (in_height + 2 * self.padding_height - self.kernel_height) // self.stride_height + 1
        out_width = (in_width + 2 * self.padding_width - self.kernel_width) // self.stride_width + 1
        if out_height <= 0 or out_width <= 0:
            raise ValueError(
                f"AvgPool2D window {self.kernel_height}x{self.kernel_width} does not fit "
                f"input {in_height}x{in_width} with padding "
                f"{self.padding_height}x{self.padding_width}"
            )
        output_shape = (batch_size, channels, out_height, out_width)

        output = self._get_temp_tensor(output_shape, dtype)
        pool_desc = self._create_pooling_descriptor(PoolingMode.AVERAGE)

        # Execute pooling on device
        self.driver.pooling_forward(
            pooling_desc=pool_desc,
            input=tensor_name,
            output=output,
            stream_id=stream_id
        )

        return output

class GlobalAvgPool2D:
    """Global average pooling layer.

    Reduces an NCHW input with the driver's ``reduce_mean`` over dims 2 and 3,
    producing an ``(N, C)`` output.

    Unlike ``HeliumPooling2D``, this class defines no ``__del__``. Output
    tensors stay allocated until they are released with ``_free_temp_tensor``.

    Args:
        device_id: Device to use; when falsy, the default device is used.
    """

    def __init__(self, device_id: Optional[str] = None):
        self.driver = get_device(device_id) if device_id else get_default_device()
        self.device_id = device_id
        self._temp_tensors = {}
        self._counter = 0

    def _get_temp_tensor(self, shape: Tuple[int, ...], dtype: str = "float32") -> str:
        """Allocate a device tensor and track it under a new local id.

        Args:
            shape: Tensor shape.
            dtype: Dtype name (case-insensitive), or an already-resolved
                ``DType``, which is used as-is.

        Returns:
            The local id (``"global_pool_temp_<n>"``) for the allocation.
        """
        tensor_id = f"global_pool_temp_{self._counter}"
        self._counter += 1

        resolved_dtype = getattr(DType, dtype.upper()) if isinstance(dtype, str) else dtype
        descriptor = TensorDescriptor(
            shape=shape,
            dtype=resolved_dtype,
            device=Device.VGPU,
            # NOTE(review): the output here is 2-D (N, C) yet declared NHWC;
            # confirm the driver accepts this.
            layout=Layout.NHWC
        )

        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(descriptor)
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str) -> None:
        """Release a tracked tensor; unknown ids are ignored."""
        if tensor_id not in self._temp_tensors:
            return
        self.driver.free_tensor(self._temp_tensors[tensor_id])
        del self._temp_tensors[tensor_id]

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute global average pooling.

        Args:
            input_tensor: Input tensor (NCHW format), given either as a driver
                tensor name or as a ``HeliumTensor``.
            stream_id: Optional stream for async execution.

        Returns:
            The id of a newly allocated ``(N, C)`` output tensor with the
            input's dtype.

        Raises:
            ValueError: If the input is not 4-D (the reduction dims are fixed
                to 2 and 3).
        """
        # Get input info
        if isinstance(input_tensor, str):
            tensor_name = input_tensor
            tensor_info = self.driver.get_tensor_info(tensor_name)
            input_shape = tensor_info.shape
            # Normalise to a dtype name, matching the HeliumTensor branch below.
            dtype = getattr(tensor_info.dtype, "name", tensor_info.dtype)
        else:
            tensor_name = input_tensor.name
            input_shape = input_tensor.shape
            dtype = input_tensor.dtype.name

        if len(input_shape) != 4:
            raise ValueError(
                f"GlobalAvgPool2D expects a 4-D NCHW input, got shape {tuple(input_shape)}"
            )
        output_shape = (input_shape[0], input_shape[1])

        output = self._get_temp_tensor(output_shape, dtype)

        # Execute global pooling over the height and width dimensions.
        self.driver.reduce_mean(
            input=tensor_name,
            output=output,
            dims=(2, 3),
            stream_id=stream_id
        )

        return output

# Functional interface
def max_pool2d(
    x: Union[str, "HeliumTensor"],
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
    device_id: Optional[str] = None,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """Functional interface for 2D max pooling.

    Builds a throwaway ``MaxPool2D`` with the given configuration and runs it
    once on ``x``.

    Returns:
        The id of the pooled output tensor. The temporary module stops
        tracking it, so it is not freed when the module is garbage-collected.
    """
    module = MaxPool2D(kernel_size, stride, padding, device_id)
    output = module.forward(x, stream_id)
    # ``module`` is discarded when this function returns, and the inherited
    # __del__ frees every tensor it still tracks, including the result.
    # Untrack the result first so the caller never receives a freed tensor.
    # NOTE(review): no Python object holds the driver handle for ``output``
    # after this; confirm how callers are expected to release it.
    module._temp_tensors.pop(output, None)
    return output

def avg_pool2d(
    x: Union[str, "HeliumTensor"],
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
    device_id: Optional[str] = None,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """Functional interface for 2D average pooling.

    Builds a throwaway ``AvgPool2D`` with the given configuration and runs it
    once on ``x``.

    Returns:
        The id of the pooled output tensor. The temporary module stops
        tracking it, so it is not freed when the module is garbage-collected.
    """
    module = AvgPool2D(kernel_size, stride, padding, device_id)
    output = module.forward(x, stream_id)
    # ``module`` is discarded when this function returns, and the inherited
    # __del__ frees every tensor it still tracks, including the result.
    # Untrack the result first so the caller never receives a freed tensor.
    # NOTE(review): no Python object holds the driver handle for ``output``
    # after this; confirm how callers are expected to release it.
    module._temp_tensors.pop(output, None)
    return output

def global_avg_pool2d(
    x: Union[str, "HeliumTensor"],
    device_id: Optional[str] = None,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """Run global average pooling once on ``x`` through a temporary layer.

    Args:
        x: Input tensor (NCHW), as a driver tensor name or a ``HeliumTensor``.
        device_id: Device to use; the default device when falsy.
        stream_id: Optional stream for async execution.

    Returns:
        Whatever ``GlobalAvgPool2D.forward`` returns for ``x``.
    """
    return GlobalAvgPool2D(device_id).forward(x, stream_id)
