from typing import Optional, Union, Tuple, List
import numpy as np
from dataclasses import dataclass
import warnings
from .core.db_manager import HeliumDBManager
from enum import Enum
import hashlib
import json

class PaddingMode(Enum):
    """Border-handling strategy used when padding a Conv2D input.

    Values are lowercase strings, so ``PaddingMode("reflect")`` resolves to
    ``PaddingMode.REFLECT``.
    """
    # Border filled with zeros (np.pad mode 'constant' in Conv2D._pad_input).
    ZEROS = "zeros"
    # Mirrored border (np.pad mode 'reflect' in Conv2D._pad_input).
    REFLECT = "reflect"
    # Outermost value repeated (np.pad mode 'edge' in Conv2D._pad_input).
    REPLICATE = "replicate"
    # Wrap-around border (np.pad mode 'wrap' in Conv2D._pad_input).
    CIRCULAR = "circular"

@dataclass
class Conv2DConfig:
    """Configuration for a ``Conv2D`` layer.

    ``kernel_size``, ``stride``, ``padding`` and ``dilation`` accept either a
    single int (used for both spatial axes) or a ``(height, width)`` pair.

    Raises:
        ValueError: From ``__post_init__`` when a channel count or ``groups``
            is not positive, when either channel count is not divisible by
            ``groups``, or when ``padding_mode`` is not a valid
            ``PaddingMode`` (member or string value).
    """
    in_channels: int
    out_channels: int
    kernel_size: Union[int, Tuple[int, int]]
    stride: Union[int, Tuple[int, int]] = 1
    padding: Union[int, Tuple[int, int]] = 0
    dilation: Union[int, Tuple[int, int]] = 1
    groups: int = 1
    # A PaddingMode member or its string value (e.g. "reflect"); normalized
    # to a member in __post_init__.
    padding_mode: PaddingMode = PaddingMode.ZEROS
    use_bias: bool = True
    # Parameter dtype; a NumPy scalar type such as np.float32 or an np.dtype.
    dtype: Union[np.dtype, type] = np.float32
    use_winograd: bool = True  # Use Winograd algorithm for 3x3 convolutions
    use_im2col: bool = True    # Use im2col optimization for other sizes
    cache_size: int = 1024     # NOTE(review): not read by any code visible in this file

    def __post_init__(self) -> None:
        """Validate channel/group arithmetic and normalize ``padding_mode``."""
        if self.in_channels <= 0 or self.out_channels <= 0:
            raise ValueError(
                f"in_channels and out_channels must be positive, got "
                f"{self.in_channels} and {self.out_channels}"
            )
        if self.groups <= 0:
            raise ValueError(f"groups must be positive, got {self.groups}")
        if self.in_channels % self.groups or self.out_channels % self.groups:
            raise ValueError(
                f"in_channels ({self.in_channels}) and out_channels "
                f"({self.out_channels}) must both be divisible by groups ({self.groups})"
            )
        # Without this, a plain string such as "reflect" never compares equal
        # to a PaddingMode member and silently falls through to circular padding.
        self.padding_mode = PaddingMode(self.padding_mode)

class Conv2D:
    """
    2D convolution layer over NCHW tensors.

    Computes cross-correlation (the kernel is not flipped), matching the usual
    deep-learning definition of convolution. One execution strategy is picked
    per call:

    - Winograd F(2x2, 3x3) for 3x3 kernels with stride 1, dilation 1, groups 1
    - An optional hardware ``driver`` exposing a ``conv2d`` attribute
    - im2col + matrix multiplication (supports stride, dilation and groups)
    - A direct nested-loop fallback (supports stride, dilation and groups)

    Results can be cached through ``HeliumDBManager``'s
    ``get_activation``/``set_activation``. The cache key covers the input data,
    the weights/bias and the layer hyper-parameters.
    """

    def __init__(
        self,
        config: Conv2DConfig,
        weight: Optional[np.ndarray] = None,
        bias: Optional[np.ndarray] = None,
        driver = None
    ):
        """Initialize Conv2D layer.

        Args:
            config: Layer hyper-parameters.
            weight: Optional kernel of shape
                (out_channels, in_channels // groups, kernel_h, kernel_w).
                Randomly initialized (Xavier-normal) when omitted.
            bias: Optional bias of shape (out_channels,). Defaults to zeros when
                ``config.use_bias`` is set; ignored when it is not.
            driver: Optional accelerator object. It is used when it has a
                ``conv2d`` attribute and Winograd was not selected.

        Raises:
            ValueError: If ``weight`` or ``bias`` has the wrong shape, or a
                spatial hyper-parameter is not an int or a pair of ints.
        """
        self.config = config
        self.driver = driver
        self.db = HeliumDBManager.get_instance()

        # Normalize scalar / sequence hyper-parameters to (height, width) tuples.
        self.kernel_size = self._to_tuple(config.kernel_size)
        self.stride = self._to_tuple(config.stride)
        self.padding = self._to_tuple(config.padding)
        self.dilation = self._to_tuple(config.dilation)

        # Initialize or validate weights
        if weight is not None:
            self._validate_weight(weight)
            self.weight = weight.astype(config.dtype)
        else:
            self.weight = self._initialize_weight()

        if config.use_bias:
            if bias is not None:
                self._validate_bias(bias)
                self.bias = bias.astype(config.dtype)
            else:
                self.bias = np.zeros(config.out_channels, dtype=config.dtype)
        else:
            self.bias = None

        # Select the execution strategy once.
        self._setup_optimizations()

    def _to_tuple(self, value: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
        """Normalize an int or a length-2 sequence to a tuple of two Python ints.

        Lists, arrays and NumPy integers are normalized too. This keeps
        comparisons such as ``self.kernel_size == (3, 3)`` and the JSON
        serialization in the cache key independent of how the value was given.

        Raises:
            ValueError: If a sequence does not have exactly two elements.
        """
        if isinstance(value, (tuple, list, np.ndarray)):
            if len(value) != 2:
                raise ValueError(f"Expected an int or a pair of ints, got {value!r}")
            return (int(value[0]), int(value[1]))
        return (int(value), int(value))

    def _validate_weight(self, weight: np.ndarray):
        """Raise ValueError unless ``weight`` has shape
        (out_channels, in_channels // groups, kernel_h, kernel_w)."""
        expected_shape = (
            self.config.out_channels,
            self.config.in_channels // self.config.groups,
            self.kernel_size[0],
            self.kernel_size[1]
        )
        if weight.shape != expected_shape:
            raise ValueError(f"Weight shape {weight.shape} doesn't match expected {expected_shape}")

    def _validate_bias(self, bias: np.ndarray):
        """Raise ValueError unless ``bias`` has shape (out_channels,)."""
        if bias.shape != (self.config.out_channels,):
            raise ValueError(
                f"Bias shape {bias.shape} doesn't match output channels {self.config.out_channels}"
            )

    def _initialize_weight(self) -> np.ndarray:
        """Draw a Xavier/Glorot-normal kernel.

        Each output unit only sees ``in_channels // groups`` input channels, so
        the fan-in uses that per-group count.
        """
        receptive = self.kernel_size[0] * self.kernel_size[1]
        fan_in = (self.config.in_channels // self.config.groups) * receptive
        fan_out = self.config.out_channels * receptive
        std = np.sqrt(2.0 / (fan_in + fan_out))
        return np.random.normal(
            0.0,
            std,
            (
                self.config.out_channels,
                self.config.in_channels // self.config.groups,
                self.kernel_size[0],
                self.kernel_size[1]
            )
        ).astype(self.config.dtype)

    def _setup_optimizations(self):
        """Decide between Winograd, im2col and the direct fallback."""
        # This Winograd kernel is F(2x2, 3x3): 3x3 kernel, unit stride and
        # dilation, and a single group (it contracts over all input channels).
        self.use_winograd = (
            self.config.use_winograd and
            self.kernel_size == (3, 3) and
            self.stride == (1, 1) and
            self.dilation == (1, 1) and
            self.config.groups == 1
        )

        self.use_im2col = (
            self.config.use_im2col and
            not self.use_winograd
        )

        if self.use_winograd:
            self._setup_winograd()

    def _setup_winograd(self):
        """Build the F(2x2, 3x3) transform matrices (Lavin & Gray).

        One 2x2 output tile is ``A.T @ ((G @ g @ G.T) * (B.T @ d @ B)) @ A``
        for a 3x3 kernel ``g`` and a 4x4 input tile ``d``.
        """
        dtype = self.config.dtype
        # Kernel transform G (4x3).
        self.G = np.array([
            [1.0, 0.0, 0.0],
            [0.5, 0.5, 0.5],
            [0.5, -0.5, 0.5],
            [0.0, 0.0, 1.0]
        ], dtype=dtype)

        # The literature lists B^T; store B so that B.T @ d @ B is the input transform.
        b_transposed = np.array([
            [1, 0, -1, 0],
            [0, 1, 1, 0],
            [0, -1, 1, 0],
            [0, 1, 0, -1]
        ], dtype=dtype)
        self.B = b_transposed.T

        # Output transform: A^T is 2x4, so A is 4x2 and A.T @ M @ A is 2x2.
        a_transposed = np.array([
            [1, 1, 1, 0],
            [0, 1, -1, -1]
        ], dtype=dtype)
        self.A = a_transposed.T

    def _compute_cache_key(self, input_shape: Tuple, x: Optional[np.ndarray] = None) -> str:
        """Return a SHA-256 cache key for this layer applied to an input.

        The key covers the input shape, the layer hyper-parameters, the padding
        mode, and the raw bytes/dtype/shape of the weights, the bias and (when
        given) the input tensor ``x``. Hashing the data is what makes a cache
        hit imply an identical result; a shape-only key would return one
        input's output for any other input of the same shape.
        """
        config_str = json.dumps({
            'input_shape': [int(d) for d in input_shape],
            'in_channels': int(self.config.in_channels),
            'out_channels': int(self.config.out_channels),
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'padding': self.padding,
            'dilation': self.dilation,
            'groups': int(self.config.groups),
            'padding_mode': PaddingMode(self.config.padding_mode).value,
        })
        hasher = hashlib.sha256(config_str.encode())
        for tag, array in (('input', x), ('weight', self.weight), ('bias', self.bias)):
            hasher.update(tag.encode())
            if array is None:
                hasher.update(b'<none>')
                continue
            array = np.ascontiguousarray(array)
            hasher.update(f'{array.dtype.str}{array.shape}'.encode())
            hasher.update(array.tobytes())
        return hasher.hexdigest()

    def _pad_input(self, x: np.ndarray) -> np.ndarray:
        """Pad the two spatial axes of ``x`` according to the configured mode.

        Raises:
            ValueError: If ``config.padding_mode`` is not a PaddingMode member
                or one of its string values.
        """
        np_modes = {
            PaddingMode.ZEROS: 'constant',
            PaddingMode.REFLECT: 'reflect',
            PaddingMode.REPLICATE: 'edge',
            PaddingMode.CIRCULAR: 'wrap',
        }
        # PaddingMode(...) accepts members and string values; unknown values
        # raise instead of silently falling back to circular padding.
        mode = PaddingMode(self.config.padding_mode)
        pad_h, pad_w = self.padding
        return np.pad(
            x,
            ((0, 0), (0, 0), (pad_h, pad_h), (pad_w, pad_w)),
            mode=np_modes[mode]
        )

    def _output_size(self, height: int, width: int) -> Tuple[int, int]:
        """Return (out_h, out_w) for an unpadded input of the given spatial size."""
        kernel_h, kernel_w = self.kernel_size
        out_h = (height + 2 * self.padding[0] - self.dilation[0] * (kernel_h - 1) - 1) // self.stride[0] + 1
        out_w = (width + 2 * self.padding[1] - self.dilation[1] * (kernel_w - 1) - 1) // self.stride[1] + 1
        return out_h, out_w

    def _im2col(self, x: np.ndarray) -> np.ndarray:
        """Unfold an already padded (B, C, H, W) tensor into patch rows.

        Returns:
            Array of shape (B * out_h * out_w, C * kernel_h * kernel_w). Rows
            are ordered (batch, out_y, out_x) and columns (channel, ky, kx),
            which matches ``weight.reshape(out_channels, -1)`` when groups == 1.
        """
        batch_size, channels, height, width = x.shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        dilation_h, dilation_w = self.dilation

        out_h = (height - dilation_h * (kernel_h - 1) - 1) // stride_h + 1
        out_w = (width - dilation_w * (kernel_w - 1) - 1) // stride_w + 1

        col = np.empty(
            (batch_size, channels, kernel_h, kernel_w, out_h, out_w),
            dtype=x.dtype
        )

        # For each kernel tap, gather exactly out_h x out_w strided samples.
        for ky in range(kernel_h):
            top = ky * dilation_h
            bottom = top + stride_h * (out_h - 1) + 1
            for kx in range(kernel_w):
                left = kx * dilation_w
                right = left + stride_w * (out_w - 1) + 1
                col[:, :, ky, kx, :, :] = x[:, :, top:bottom:stride_h, left:right:stride_w]

        return col.transpose(0, 4, 5, 1, 2, 3).reshape(
            batch_size * out_h * out_w, -1
        )

    def _im2col_conv(self, x_padded: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
        """Convolve via im2col + one batched matmul per group.

        Returns:
            Array of shape (B, out_channels, out_h, out_w).
        """
        batch_size = x_padded.shape[0]
        groups = self.config.groups
        out_channels = self.config.out_channels

        col = self._im2col(x_padded)
        n_rows = col.shape[0]
        # Columns are ordered (channel, ky, kx) with channel slowest, so
        # splitting them into `groups` equal chunks isolates each channel group.
        col = col.reshape(n_rows, groups, -1).transpose(1, 0, 2)          # (g, N, K)
        weight = self.weight.reshape(groups, out_channels // groups, -1)   # (g, Og, K)
        output = np.matmul(col, weight.transpose(0, 2, 1))                # (g, N, Og)
        output = output.transpose(1, 0, 2).reshape(n_rows, out_channels)
        if self.bias is not None:
            output = output + self.bias

        output = output.reshape(batch_size, out_h, out_w, out_channels)
        return output.transpose(0, 3, 1, 2)

    def _direct_conv(self, x_padded: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
        """Reference nested-loop convolution on an already padded input."""
        batch_size = x_padded.shape[0]
        out_channels = self.config.out_channels
        groups = self.config.groups
        in_per_group = x_padded.shape[1] // groups
        out_per_group = out_channels // groups
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        dilation_h, dilation_w = self.dilation
        # Extent of the dilated kernel; slicing [start:start+span:dilation]
        # yields exactly kernel_h (kernel_w) taps.
        span_h = dilation_h * (kernel_h - 1) + 1
        span_w = dilation_w * (kernel_w - 1) + 1

        output = np.zeros(
            (batch_size, out_channels, out_h, out_w),
            dtype=x_padded.dtype
        )

        for b in range(batch_size):
            for oc in range(out_channels):
                group = oc // out_per_group
                channels = slice(group * in_per_group, (group + 1) * in_per_group)
                for i in range(out_h):
                    top = i * stride_h
                    for j in range(out_w):
                        left = j * stride_w
                        receptive_field = x_padded[
                            b,
                            channels,
                            top:top + span_h:dilation_h,
                            left:left + span_w:dilation_w
                        ]
                        output[b, oc, i, j] = np.sum(receptive_field * self.weight[oc])
                if self.bias is not None:
                    output[b, oc] += self.bias[oc]

        return output

    def _winograd_conv(self, x: np.ndarray) -> np.ndarray:
        """Winograd F(2x2, 3x3) convolution of an already padded input.

        Returns:
            Array of shape (B, out_channels, H - 2, W - 2), with bias added.
        """
        batch_size, channels, height, width = x.shape
        out_h, out_w = height - 2, width - 2
        tiles_h = -(-out_h // 2)   # ceil(out_h / 2)
        tiles_w = -(-out_w // 2)

        # Zero-extend bottom/right so every 4x4 input tile is complete; the
        # extra output row/column this produces is cropped below.
        extra_h = 2 * tiles_h + 2 - height
        extra_w = 2 * tiles_w + 2 - width
        if extra_h or extra_w:
            x = np.pad(x, ((0, 0), (0, 0), (0, extra_h), (0, extra_w)), mode='constant')

        # tiles[b, c, ti, tj] = x[b, c, 2ti:2ti+4, 2tj:2tj+4]
        tiles = np.empty((batch_size, channels, tiles_h, tiles_w, 4, 4), dtype=x.dtype)
        for r in range(4):
            for s in range(4):
                tiles[..., r, s] = x[:, :, r:r + 2 * tiles_h:2, s:s + 2 * tiles_w:2]

        # Recomputed every call because self.weight is a public, mutable attribute.
        U = self.G @ self.weight @ self.G.T                  # (O, C, 4, 4)
        V = self.B.T @ tiles @ self.B                        # (B, C, th, tw, 4, 4)
        M = np.einsum('ocrs,bcijrs->boijrs', U, V, optimize=True)
        Y = self.A.T @ M @ self.A                            # (B, O, th, tw, 2, 2)

        output = Y.transpose(0, 1, 2, 4, 3, 5).reshape(
            batch_size, self.config.out_channels, 2 * tiles_h, 2 * tiles_w
        )[:, :, :out_h, :out_w]

        if self.bias is not None:
            output = output + self.bias.reshape(1, -1, 1, 1)
        return output

    def forward(
        self,
        x: np.ndarray,
        use_cache: bool = True
    ) -> np.ndarray:
        """
        Forward pass of Conv2D layer

        Args:
            x: Input tensor of shape (batch, channels, height, width)
            use_cache: Whether to use database caching

        Returns:
            Output tensor of shape (batch, out_channels, out_height, out_width)

        Raises:
            ValueError: If ``x`` is not 4-D, has the wrong channel count, or is
                too small to produce at least one output pixel.
        """
        if x.ndim != 4:
            raise ValueError(f"Expected 4D input tensor, got shape {x.shape}")

        if x.shape[1] != self.config.in_channels:
            raise ValueError(
                f"Expected {self.config.in_channels} input channels, got {x.shape[1]}"
            )

        out_h, out_w = self._output_size(x.shape[2], x.shape[3])
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                f"Input spatial size {x.shape[2:]} is too small for kernel_size="
                f"{self.kernel_size}, padding={self.padding}, dilation={self.dilation}"
            )

        cache_key = None
        if use_cache:
            cache_key = self._compute_cache_key(x.shape, x)
            cached_result = self.db.get_activation(cache_key)
            if cached_result is not None:
                return cached_result

        x_padded = self._pad_input(x)

        if self.use_winograd:
            output = self._winograd_conv(x_padded)
            algorithm = 'winograd'
        elif self.driver is not None and hasattr(self.driver, 'conv2d'):
            # Use hardware acceleration
            output = self.driver.conv2d(
                x_padded,
                self.weight,
                self.bias,
                self.stride,
                self.dilation,
                self.config.groups
            )
            algorithm = 'driver'
        elif self.use_im2col:
            output = self._im2col_conv(x_padded, out_h, out_w)
            algorithm = 'im2col'
        else:
            output = self._direct_conv(x_padded, out_h, out_w)
            algorithm = 'basic'

        if use_cache:
            metadata = {
                'shape': x.shape,
                'dtype': str(x.dtype),
                'algorithm': algorithm
            }
            self.db.set_activation(cache_key, output, metadata)

        return output

# Legacy function for backward compatibility
def conv2d(
    input: np.ndarray,
    weight: np.ndarray,
    bias: Optional[np.ndarray] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    driver = None,
    chip_id: int = 0,
    sm_id: int = 0,
    scheduler = None
) -> np.ndarray:
    """Legacy functional conv2d interface (deprecated; use ``Conv2D``).

    Builds a one-off ``Conv2D`` layer from the given tensors and returns
    ``layer.forward(input)``.

    Args:
        input: Input tensor of shape (batch, in_channels, height, width).
        weight: Kernel of shape (out_channels, in_channels // groups, kh, kw).
        bias: Optional bias of shape (out_channels,); the layer is built
            without bias when this is None.
        stride, padding, dilation, groups: Convolution hyper-parameters.
        driver: Optional accelerator forwarded to ``Conv2D``.
        chip_id, sm_id, scheduler: Accepted for backward compatibility only;
            they are not used.

    Raises:
        ValueError: If ``input`` or ``weight`` is not 4-D, or if ``Conv2D``
            rejects the shapes.
    """
    # stacklevel=2 attributes the warning to the caller, not to this module.
    warnings.warn(
        "conv2d function is deprecated, use Conv2D class instead",
        DeprecationWarning,
        stacklevel=2
    )

    # Check rank up front so input.shape[1] / weight.shape[2:] below are meaningful.
    if input.ndim != 4:
        raise ValueError(f"Expected 4D input tensor, got shape {input.shape}")
    if weight.ndim != 4:
        raise ValueError(f"Expected 4D weight tensor, got shape {weight.shape}")

    config = Conv2DConfig(
        in_channels=input.shape[1],
        out_channels=weight.shape[0],
        kernel_size=tuple(weight.shape[2:]),
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        use_bias=bias is not None,
        dtype=input.dtype
    )

    layer = Conv2D(config, weight, bias, driver)
    return layer.forward(input)
