from typing import Optional, Union, Tuple
import numpy as np
from dataclasses import dataclass
import warnings

@dataclass
class ProjectionConfig:
    """Configuration for the final projection layer.

    Values are validated on construction so a misconfiguration fails fast
    with a clear message instead of surfacing later as NaNs (e.g. a dropout
    rate above 1) or obscure NumPy shape/scale errors.

    Raises:
        ValueError: If ``hidden_dim`` or ``vocab_size`` is not positive, if
            ``dropout_rate`` lies outside [0, 1], or if
            ``initializer_range`` is negative.
    """
    # Size of each input feature vector (rows of the projection weight).
    hidden_dim: int
    # Number of output logits per position (columns of the projection weight).
    vocab_size: int
    # Whether to add a bias vector of shape (vocab_size,).
    use_bias: bool = True
    # Compute in float16 instead of float32.
    use_fp16: bool = False
    # Reuse the (transposed) input-embedding matrix as the projection weight.
    use_weight_tying: bool = False
    # Probability of zeroing an input element during training (inverted dropout).
    dropout_rate: float = 0.0
    # Standard deviation of the zero-mean normal weight initializer.
    initializer_range: float = 0.02

    def __post_init__(self):
        """Validate field values; see the class docstring for the rules."""
        if self.hidden_dim < 1:
            raise ValueError(f"hidden_dim must be positive, got {self.hidden_dim}")
        if self.vocab_size < 1:
            raise ValueError(f"vocab_size must be positive, got {self.vocab_size}")
        if not 0.0 <= self.dropout_rate <= 1.0:
            raise ValueError(
                f"dropout_rate must be in [0, 1], got {self.dropout_rate}"
            )
        if self.initializer_range < 0:
            raise ValueError(
                f"initializer_range must be non-negative, got {self.initializer_range}"
            )

class FinalProjection:
    """
    Optimized final projection layer for transformer models with support for:
    - Weight tying with input embeddings
    - Mixed precision (FP16/FP32)
    - Memory-efficient computation
    - Optimized matrix multiplication
    - Bias fusion

    Maps hidden states of shape (batch, seq_len, hidden_dim) to logits of
    shape (batch, seq_len, vocab_size) as ``x @ weight + bias``, where
    ``weight`` has shape (hidden_dim, vocab_size).

    The optional ``driver`` is duck-typed:
    - ``prepare_projection`` and ``optimized_projection`` are probed with
      ``hasattr``.
    - When ``prepare_projection`` exists, the driver is also expected to
      provide ``prepare_bias`` (if a bias is used), plus ``matmul`` and
      ``add_bias`` for the non-plan path.
    - Without a driver, computation uses NumPy.

    Driver-prepared parameters and per-shape compute plans are derived from
    ``weight``/``bias`` by ``_setup_cache()``. If either attribute is
    reassigned after construction, call ``_setup_cache()`` again to rebuild
    them.
    """

    def __init__(
        self,
        config: "ProjectionConfig",
        embedding_weights: Optional[np.ndarray] = None,
        driver=None,
    ):
        """
        Initialize the final projection layer.

        Args:
            config: Projection configuration.
            embedding_weights: Optional input-embedding matrix of shape
                (vocab_size, hidden_dim). Used only when
                ``config.use_weight_tying`` is True. The projection weight
                becomes its transpose, which for NumPy arrays is a view, so
                the two stay tied.
            driver: Optional hardware driver for optimized computation.

        Raises:
            ValueError: If tied ``embedding_weights`` do not have shape
                (vocab_size, hidden_dim).
        """
        self.config = config
        self.driver = driver

        if config.use_weight_tying and embedding_weights is not None:
            expected_shape = (config.vocab_size, config.hidden_dim)
            actual_shape = tuple(embedding_weights.shape)
            if actual_shape != expected_shape:
                raise ValueError(
                    f"Tied embedding_weights must have shape {expected_shape} "
                    f"(vocab_size, hidden_dim), got {actual_shape}"
                )
            # Transpose for projection. The dtype is deliberately left as-is:
            # casting would copy and break the tie with the embedding.
            self.weight = embedding_weights.T
        else:
            if config.use_weight_tying:
                warnings.warn(
                    "use_weight_tying=True but no embedding_weights were given; "
                    "initializing untied projection weights",
                    stacklevel=2,
                )
            self.weight = self._initialize_weights()

        if config.use_bias:
            self.bias = np.zeros(config.vocab_size, dtype=self._get_dtype())
        else:
            self.bias = None

        # Driver-prepared parameters and per-shape compute plans.
        self._setup_cache()

    def _get_dtype(self) -> np.dtype:
        """Return np.float16 when ``use_fp16`` is set, else np.float32."""
        return np.float16 if self.config.use_fp16 else np.float32

    def _initialize_weights(self) -> np.ndarray:
        """Draw a (hidden_dim, vocab_size) weight from N(0, initializer_range)."""
        return np.random.normal(
            0.0,
            self.config.initializer_range,
            (self.config.hidden_dim, self.config.vocab_size)
        ).astype(self._get_dtype())

    def _setup_cache(self):
        """(Re)build driver-prepared parameters and clear cached compute plans.

        Both prepared attributes always exist afterwards, so later code can
        test them against None.
        """
        self._cached_shapes = {}
        self._prepared_weight = None
        # Must be initialized unconditionally: with a driver but no bias, the
        # branch below never assigns it.
        self._prepared_bias = None
        if self.driver is not None and hasattr(self.driver, 'prepare_projection'):
            self._prepared_weight = self.driver.prepare_projection(self.weight)
            if self.bias is not None:
                self._prepared_bias = self.driver.prepare_bias(self.bias)

    def _apply_dropout(
        self,
        x: np.ndarray,
        training: bool = False
    ) -> np.ndarray:
        """Apply inverted dropout when training and ``dropout_rate`` > 0.

        Surviving elements are scaled by 1 / (1 - dropout_rate). A rate of 1
        (or more) zeroes the input instead of dividing by zero.
        """
        rate = self.config.dropout_rate
        if not training or rate <= 0:
            return x
        if rate >= 1.0:
            # Every element is dropped. The inverted-dropout scale would be
            # 0/0 = NaN, so return zeros directly.
            return np.zeros_like(x)
        keep_prob = 1.0 - rate
        mask = np.random.binomial(
            1,
            keep_prob,
            x.shape
        ).astype(self._get_dtype()) / keep_prob
        return x * mask

    def _validate_input(self, x: np.ndarray):
        """Raise ValueError unless x is 3D with last dimension == hidden_dim."""
        if x.ndim != 3:
            raise ValueError(
                f"Expected 3D input tensor (batch, seq_len, hidden_dim), got shape {x.shape}"
            )
        if x.shape[-1] != self.config.hidden_dim:
            raise ValueError(
                f"Input hidden dimension {x.shape[-1]} doesn't match "
                f"configured hidden_dim {self.config.hidden_dim}"
            )

    def _optimize_computation(
        self,
        x: np.ndarray,
        batch_size: int,
        seq_len: int
    ) -> Optional[np.ndarray]:
        """Run a driver compute plan for this input shape, if one is available.

        Plans from ``driver.optimized_projection`` are cached per
        (batch_size, seq_len).

        Returns:
            The plan's result, or None when the driver offers no
            ``optimized_projection`` (the caller then uses the standard path).
        """
        shape_key = (batch_size, seq_len)

        plan = self._cached_shapes.get(shape_key)
        if plan is not None:
            return plan(x)

        if self.driver is not None and hasattr(self.driver, 'optimized_projection'):
            # NOTE(review): if the driver lacks prepare_projection, the
            # prepared weight/bias passed here are None. Confirm the driver
            # handles that.
            plan = self.driver.optimized_projection(
                batch_size,
                seq_len,
                self._prepared_weight,
                self._prepared_bias
            )
            # NOTE: unbounded. Each distinct (batch, seq_len) adds one entry.
            self._cached_shapes[shape_key] = plan
            return plan(x)

        return None  # Fall back to standard computation

    def _numpy_projection(
        self,
        x: np.ndarray,
        batch_size: int,
        seq_len: int
    ) -> np.ndarray:
        """Compute ``x @ weight + bias`` with NumPy and return 3D logits."""
        # Flatten batch and sequence axes so one 2D matmul covers all tokens.
        x_2d = x.reshape(-1, self.config.hidden_dim)
        logits = np.matmul(x_2d, self.weight)
        if self.bias is not None:
            logits += self.bias  # logits is a fresh array, so in-place is safe
        return logits.reshape(batch_size, seq_len, self.config.vocab_size)

    def forward(
        self,
        x: np.ndarray,
        training: bool = False
    ) -> np.ndarray:
        """
        Forward pass of the final projection layer.

        Args:
            x: Input array (or array-like) of shape
                (batch_size, seq_len, hidden_dim).
            training: Whether in training mode (enables dropout).

        Returns:
            logits: Output logits of shape (batch_size, seq_len, vocab_size).

        Raises:
            ValueError: If x is not 3D or its last dimension != hidden_dim.
        """
        x = np.asarray(x)
        self._validate_input(x)

        # Cast to the configured compute dtype (always a copy, so the
        # caller's array is never handed to the driver directly).
        x = x.astype(self._get_dtype())
        x = self._apply_dropout(x, training)

        batch_size, seq_len, _ = x.shape

        # 1) Cached or new driver compute plan.
        optimized_result = self._optimize_computation(x, batch_size, seq_len)
        if optimized_result is not None:
            return optimized_result

        # 2) Driver primitives on prepared parameters.
        if self.driver is not None and self._prepared_weight is not None:
            logits = self.driver.matmul(x, self._prepared_weight)
            if self._prepared_bias is not None:
                logits = self.driver.add_bias(logits, self._prepared_bias)
            return logits

        # 3) Plain NumPy.
        return self._numpy_projection(x, batch_size, seq_len)

def final_linear_projection(
    x: np.ndarray,
    W: np.ndarray,
    b: Optional[np.ndarray] = None,
    driver = None
) -> np.ndarray:
    """
    Legacy function for backward compatibility.

    Deprecated: build a FinalProjection and call ``forward`` instead. Emits a
    DeprecationWarning attributed to the caller's line.

    Args:
        x: Input tensor (batch, seq_len, hidden_dim)
        W: Weight matrix (hidden_dim, vocab_size)
        b: Optional bias vector (vocab_size,)
        driver: Optional hardware driver

    Returns:
        logits: Output logits (batch, seq_len, vocab_size)

    Raises:
        ValueError: If W is not 2D, or if x is not 3D or its last dimension
            differs from W.shape[0] (the latter raised by
            FinalProjection.forward).
    """
    warnings.warn(
        "final_linear_projection is deprecated, use FinalProjection class instead",
        DeprecationWarning,
        stacklevel=2,
    )

    if np.ndim(W) != 2:
        raise ValueError(
            f"Expected 2D weight matrix (hidden_dim, vocab_size), got shape {np.shape(W)}"
        )

    config = ProjectionConfig(
        hidden_dim=W.shape[0],
        vocab_size=W.shape[1],
        use_bias=b is not None
    )

    projection = FinalProjection(config, driver=driver)
    projection.weight = W
    if b is not None:
        projection.bias = b
    # FinalProjection.__init__ already prepared driver-side copies of its
    # randomly initialized weight and bias. Rebuild them (and clear cached
    # plans) from the caller's W and b; otherwise a driver would compute
    # with the discarded random parameters.
    projection._setup_cache()

    return projection.forward(x)
