from typing import Optional, List, Dict, Union, Tuple
import numpy as np
from dataclasses import dataclass
from enum import Enum
import warnings
from .block import TransformerBlock
from .core.db_manager import HeliumDBManager
import json
import hashlib
from contextlib import contextmanager
import time

class ExecutionStrategy(Enum):
    """
    Strategies a transformer stack can use to run its blocks.

    The value is carried by ``StackConfig.execution_strategy`` and is
    dispatched on in ``TransformerStack.forward``; any value other than
    ``PIPELINED`` or ``PARALLEL`` takes the sequential path there.
    """
    # Run every block in order on the full input.
    SEQUENTIAL = "sequential"  # Process blocks one by one
    # Split the batch into chunks and run each chunk through all blocks.
    PIPELINED = "pipelined"   # Pipeline blocks across multiple devices
    # Hand the blocks to the driver's ``parallel_execute`` when it exists;
    # otherwise the stack falls back to sequential execution.
    PARALLEL = "parallel"     # Process blocks in parallel where possible

@dataclass
class StackConfig:
    """
    Configuration for a transformer stack.

    The first five fields are required; the rest have defaults. Field order
    is part of the positional constructor signature generated by
    ``@dataclass``, so it must not be changed.
    """
    # Number of blocks; TransformerStack requires exactly this many weight dicts.
    num_layers: int
    # Model width; passed to each block as ``hidden_size`` and must be
    # divisible by ``num_heads``.
    hidden_dim: int
    # Attention head count; must be positive.
    num_heads: int
    # Passed to each block as ``intermediate_size``.
    intermediate_size: int
    # Longest sequence ``TransformerStack.forward`` accepts; a falsy value
    # (0 / None) disables the length check.
    max_sequence_length: int
    # Forwarded unchanged to every TransformerBlock.
    dropout_rate: float = 0.1
    # Forwarded unchanged to every TransformerBlock.
    layer_norm_epsilon: float = 1e-5
    # NOTE(review): presumably the stack-wide switch for activation caching —
    # confirm how TransformerStack.forward combines it with its own argument.
    use_cache: bool = True
    # NOTE(review): no code in the visible part of this module reads this field.
    use_checkpointing: bool = False
    # Selects the code path taken by ``TransformerStack.forward``.
    execution_strategy: ExecutionStrategy = ExecutionStrategy.SEQUENTIAL
    # Forwarded to every TransformerBlock and folded into cache keys.
    # NOTE(review): annotated as ``np.dtype`` but the default is the scalar
    # type ``np.float32``; both spellings appear to be accepted.
    dtype: np.dtype = np.float32
    # NOTE(review): no code in the visible part of this module reads this field.
    gradient_checkpointing_steps: int = 2
    # NOTE(review): no code in the visible part of this module reads this
    # field; a batch-size limit is not enforced here.
    max_batch_size: Optional[int] = None

class TransformerStackCache:
    """
    Cache manager for transformer stack computations.

    Thin wrapper around the shared ``HeliumDBManager`` instance: it derives
    deterministic string keys for layer outputs and forwards reads/writes to
    the manager's activation store.
    """

    def __init__(self, config: "StackConfig"):
        """
        Bind the cache to a stack configuration and the shared DB manager.

        Args:
            config: Stack configuration; only ``config.dtype`` is read, when
                building cache keys.
        """
        self.config = config
        self.db = HeliumDBManager.get_instance()

    @staticmethod
    def _to_jsonable(value):
        """Convert values ``json`` cannot encode natively into stable JSON data."""
        if isinstance(value, np.generic):
            # NumPy scalars -> the equivalent Python scalar.
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        # Last resort (dtype objects, enums, classes, ...): their string form.
        # This is only used for hashing, never for round-tripping data.
        return str(value)

    def _compute_cache_key(
        self,
        layer_idx: int,
        input_shape: Tuple,
        block_config: Dict
    ) -> str:
        """
        Compute the cache key for a layer's output.

        The key is the SHA-256 hex digest of a canonical JSON document built
        from the arguments and the configured dtype. Dictionary keys are
        sorted so that two configs with the same content but different
        insertion order map to the same key, and the dtype is normalised so
        that ``np.float32`` and ``np.dtype('float32')`` are equivalent.

        Args:
            layer_idx: Index of the layer within the stack.
            input_shape: Shape of the layer input.
            block_config: Description of the block (and anything else the
                caller wants to distinguish cache entries by).

        Returns:
            64-character hexadecimal digest.
        """
        try:
            dtype_name = np.dtype(self.config.dtype).name
        except TypeError:
            # Not something NumPy understands; fall back to its string form.
            dtype_name = str(self.config.dtype)

        cache_data = {
            'layer_idx': layer_idx,
            'input_shape': input_shape,
            'block_config': block_config,
            'dtype': dtype_name
        }
        payload = json.dumps(
            cache_data,
            sort_keys=True,
            default=self._to_jsonable
        )
        return hashlib.sha256(payload.encode()).hexdigest()

    def get(self, key: str) -> Optional[np.ndarray]:
        """Return the activation stored under ``key`` by the DB manager."""
        return self.db.get_activation(key)

    def set(self, key: str, value: np.ndarray, metadata: Dict):
        """Store ``value`` under ``key`` together with ``metadata``."""
        self.db.set_activation(key, value, metadata)

class ResourceManager:
    """
    Tracks the available compute devices and hands them out.

    Each device has a queue in ``device_queues`` holding one marker per
    acquisition that is currently in flight. ``acquire_device`` pushes a
    marker and ``_release_device`` pops one, so the queue length is the
    device's current load and "shortest queue" selection is meaningful.
    """

    def __init__(self, driver=None):
        """
        Args:
            driver: Optional hardware driver. When it exposes
                ``list_devices()`` that list is used; otherwise only
                ``'cpu'`` is available.
        """
        self.driver = driver
        self.available_devices = self._get_available_devices()
        self.device_queues = {device: [] for device in self.available_devices}

    def _get_available_devices(self) -> List[str]:
        """
        Get the list of available compute devices.

        Returns:
            The driver's device list, or ``['cpu']`` when there is no driver,
            the driver cannot list devices, or it reports none.
        """
        if self.driver is not None and hasattr(self.driver, 'list_devices'):
            devices = list(self.driver.list_devices() or [])
            if devices:
                return devices
        # Never return an empty list: device selection needs a candidate.
        return ['cpu']

    @contextmanager
    def acquire_device(self, preferred_device: Optional[str] = None):
        """
        Acquire a compute device for the duration of a ``with`` block.

        Args:
            preferred_device: Device to use if it is available; otherwise the
                least-loaded device is chosen.

        Yields:
            Name of the acquired device.
        """
        device = self._select_device(preferred_device)
        # Record the acquisition so concurrent/nested callers see the load.
        self.device_queues.setdefault(device, []).append(device)
        try:
            yield device
        finally:
            # Always release, even when the body raised.
            self._release_device(device)

    def _select_device(self, preferred_device: Optional[str] = None) -> str:
        """Return ``preferred_device`` if available, else the least-loaded device."""
        if preferred_device and preferred_device in self.available_devices:
            return preferred_device

        # Select device with shortest queue; ties go to the first device.
        return min(
            self.device_queues.items(),
            key=lambda entry: len(entry[1])
        )[0]

    def _release_device(self, device: str):
        """Release one acquisition of ``device``; a no-op if none is recorded."""
        queue = self.device_queues.get(device)
        if queue:
            queue.pop(0)

class TransformerStack:
    """
    Stack of ``TransformerBlock`` layers applied one after another.

    One block is built per entry of ``weights_list``. An input is run through
    the blocks using the strategy selected by ``config.execution_strategy``:

    - ``SEQUENTIAL``: every block in order, with optional activation caching.
    - ``PIPELINED``: the batch is split into one chunk per available device
      and each chunk is run through every block.
    - ``PARALLEL``: delegated to ``driver.parallel_execute`` when the driver
      provides it; otherwise falls back to sequential execution.
    """

    def __init__(
        self,
        config: "StackConfig",
        weights_list: List[Dict],
        driver = None
    ):
        """
        Initialize transformer stack

        Args:
            config: Stack configuration
            weights_list: List of block weights, one dict per layer
            driver: Optional hardware driver

        Raises:
            ValueError: If the configuration is inconsistent
                (see ``_validate_config``).
        """
        self.config = config
        self.weights_list = weights_list
        self.driver = driver

        self._validate_config()
        self._setup_components()

    def _validate_config(self):
        """
        Validate configuration parameters.

        Raises:
            ValueError: If the number of weight dicts differs from
                ``num_layers``, ``num_heads`` is not positive, or
                ``hidden_dim`` is not divisible by ``num_heads``.
        """
        if len(self.weights_list) != self.config.num_layers:
            raise ValueError(
                f"Expected {self.config.num_layers} weight dicts, got {len(self.weights_list)}"
            )

        # Checked before the modulo below so it cannot divide by zero.
        if self.config.num_heads <= 0:
            raise ValueError(f"Invalid number of heads: {self.config.num_heads}")

        if self.config.hidden_dim % self.config.num_heads != 0:
            raise ValueError(
                f"Hidden dimension {self.config.hidden_dim} must be divisible "
                f"by number of heads {self.config.num_heads}"
            )

    def _setup_components(self):
        """Build the blocks, the activation cache and the resource manager."""
        # Initialize blocks
        self.blocks = [
            TransformerBlock(
                hidden_size=self.config.hidden_dim,
                num_heads=self.config.num_heads,
                intermediate_size=self.config.intermediate_size,
                weights=weights,
                dropout_rate=self.config.dropout_rate,
                layer_norm_epsilon=self.config.layer_norm_epsilon,
                dtype=self.config.dtype,
                driver=self.driver
            )
            for weights in self.weights_list
        ]

        # Initialize cache
        self.cache = TransformerStackCache(self.config)

        # Initialize resource manager
        self.resource_manager = ResourceManager(self.driver)

    @staticmethod
    def _array_digest(array: Optional[np.ndarray]) -> Optional[str]:
        """Return a SHA-256 digest of an array's dtype, shape and bytes (None for None)."""
        if array is None:
            return None
        data = np.ascontiguousarray(array)
        hasher = hashlib.sha256()
        # dtype and shape are included so equal byte strings with a different
        # interpretation do not collide.
        hasher.update(str(data.dtype).encode())
        hasher.update(repr(data.shape).encode())
        hasher.update(data.tobytes())
        return hasher.hexdigest()

    def _execute_sequential(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """
        Execute blocks sequentially.

        Args:
            x: Input tensor.
            mask: Optional attention mask, passed unchanged to every block.
            use_cache: When true, each layer's output is looked up in / stored
                to the activation cache.

        Returns:
            Output of the last block.
        """
        current_state = x
        # The mask is the same for every layer, so hash it once.
        mask_digest = self._array_digest(mask) if use_cache else None

        for i, block in enumerate(self.blocks):
            cache_key = None
            if use_cache:
                # A layer's output depends on the VALUES of its input and
                # mask, not just on the shape, so both are part of the key.
                # Without this, two different inputs of the same shape would
                # share one cache entry.
                cache_key = self.cache._compute_cache_key(
                    i,
                    current_state.shape,
                    {
                        'block': block.get_config(),
                        'input_digest': self._array_digest(current_state),
                        'mask_digest': mask_digest
                    }
                )
                cached_result = self.cache.get(cache_key)
                if cached_result is not None:
                    current_state = cached_result
                    continue

            with self.resource_manager.acquire_device() as device:
                current_state = block(
                    current_state,
                    mask=mask,
                    device=device
                )

            # Written after the device is released: storing the result does
            # not need the device.
            if use_cache:
                self.cache.set(
                    cache_key,
                    current_state,
                    {'layer_idx': i, 'shape': current_state.shape}
                )

        return current_state

    def _execute_pipelined(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """
        Execute blocks chunk by chunk.

        The batch is split along axis 0 into at most one chunk per available
        device; each chunk is run through all blocks and the results are
        concatenated in order. Chunks are processed one after another here;
        no overlap between chunks is implemented.

        Args:
            x: Input tensor; axis 0 is the batch axis.
            mask: Optional attention mask. If its leading dimension equals
                the batch size it is split along with the batch; otherwise it
                is treated as shared and passed to every chunk unchanged.

        Returns:
            Concatenation of the per-chunk outputs along axis 0.
        """
        batch_size = x.shape[0]
        # At least one chunk, so an empty batch or an empty device list
        # cannot produce a zero-way split.
        num_chunks = max(
            1,
            min(batch_size, len(self.resource_manager.available_devices))
        )

        # Split input into chunks (sizes may differ by one when the batch is
        # not evenly divisible).
        chunks = np.array_split(x, num_chunks)

        per_sample_mask = (
            mask is not None
            and np.ndim(mask) > 0
            and mask.shape[0] == batch_size
        )

        results = []
        start = 0
        for chunk in chunks:
            # Track real chunk boundaries so the mask rows line up with the
            # samples even for uneven chunks.
            stop = start + chunk.shape[0]
            if mask is None:
                chunk_mask = None
            elif per_sample_mask:
                chunk_mask = mask[start:stop]
            else:
                chunk_mask = mask
            start = stop

            current_state = chunk
            for block in self.blocks:
                with self.resource_manager.acquire_device() as device:
                    current_state = block(
                        current_state,
                        mask=chunk_mask,
                        device=device
                    )
            results.append(current_state)

        # Concatenate results
        return np.concatenate(results, axis=0)

    def _execute_parallel(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """
        Execute blocks through the driver's parallel entry point.

        Args:
            x: Input tensor.
            mask: Optional attention mask.
            use_cache: Only used when falling back to sequential execution.

        Returns:
            Result of ``driver.parallel_execute``, or of sequential execution
            when the driver is missing or lacks that method (a warning is
            emitted in that case).
        """
        if not self.driver or not hasattr(self.driver, 'parallel_execute'):
            warnings.warn("Parallel execution not supported, falling back to sequential")
            return self._execute_sequential(x, mask, use_cache)

        return self.driver.parallel_execute(
            self.blocks,
            x,
            mask,
            self.config.num_layers
        )

    def forward(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """
        Forward pass through transformer stack

        Args:
            x: Input tensor of shape (batch_size, seq_len, hidden_dim)
            mask: Optional attention mask
            use_cache: Whether to use computation caching. Caching is only
                active when ``config.use_cache`` is also true.

        Returns:
            Output tensor of shape (batch_size, seq_len, hidden_dim)

        Raises:
            ValueError: If ``x`` is not 3D, its last dimension differs from
                ``config.hidden_dim``, or its sequence length exceeds
                ``config.max_sequence_length``.
        """
        # Input validation
        if x.ndim != 3:
            raise ValueError(f"Expected 3D input tensor, got shape {x.shape}")

        if x.shape[2] != self.config.hidden_dim:
            raise ValueError(
                f"Expected hidden dimension {self.config.hidden_dim}, got {x.shape[2]}"
            )

        # A falsy max_sequence_length (0 / None) disables the length check.
        if (
            self.config.max_sequence_length and
            x.shape[1] > self.config.max_sequence_length
        ):
            raise ValueError(
                f"Input sequence length {x.shape[1]} exceeds maximum "
                f"allowed length {self.config.max_sequence_length}"
            )

        # Both the per-call flag and the configuration must allow caching.
        cache_enabled = bool(use_cache and self.config.use_cache)

        # Choose execution strategy
        if self.config.execution_strategy == ExecutionStrategy.PIPELINED:
            return self._execute_pipelined(x, mask)
        elif self.config.execution_strategy == ExecutionStrategy.PARALLEL:
            return self._execute_parallel(x, mask, cache_enabled)
        else:
            return self._execute_sequential(x, mask, cache_enabled)

    def __call__(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """Callable interface; identical to ``forward``."""
        return self.forward(x, mask, use_cache)

# Legacy function for backward compatibility
def transformer_stack(
    x: np.ndarray,
    weights_list: List[Dict],
    num_heads: int,
    mask: Optional[np.ndarray] = None,
    driver = None,
    scheduler = None
) -> np.ndarray:
    """
    Legacy transformer stack interface.

    Builds a ``StackConfig`` from the input's shape, constructs a
    ``TransformerStack`` and runs a single forward pass. Deprecated in favour
    of using ``TransformerStack`` directly.

    Args:
        x: Input tensor of shape (batch_size, seq_len, hidden_dim).
        weights_list: List of block weights, one dict per layer.
        num_heads: Number of attention heads.
        mask: Optional attention mask.
        driver: Optional hardware driver.
        scheduler: Accepted for backward compatibility; ignored.

    Returns:
        Output of ``TransformerStack.forward``.

    Raises:
        ValueError: If ``x`` is not 3D, or the derived configuration is
            rejected by ``TransformerStack``.
    """
    # stacklevel=2 attributes the warning to the caller's line, which is
    # where the deprecated call has to be changed.
    warnings.warn(
        "transformer_stack function is deprecated, use TransformerStack class instead",
        DeprecationWarning,
        stacklevel=2
    )

    # The config below indexes x.shape[1] and x.shape[2]; reject other ranks
    # up front with the same error forward() would raise, instead of a bare
    # IndexError.
    if x.ndim != 3:
        raise ValueError(f"Expected 3D input tensor, got shape {x.shape}")

    seq_len, hidden_dim = x.shape[1], x.shape[2]
    config = StackConfig(
        num_layers=len(weights_list),
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        intermediate_size=4 * hidden_dim,  # Standard size
        max_sequence_length=seq_len
    )

    stack = TransformerStack(config, weights_list, driver)
    return stack.forward(x, mask)
