from typing import Optional, Dict, List, Union, Tuple
import numpy as np
from enum import Enum
from dataclasses import dataclass
from .embedding import embedding_lookup, add_positional_encoding
from .positional_encoding import sinusoidal_positional_encoding
from .stack import transformer_stack
from .layer_norm import layer_norm
from .core.db_manager import HeliumDBManager
from .broadcast import ModalityType, TensorMetadata

class EncoderType(Enum):
    """Enumerate the encoder architectures this module knows about.

    Member values are lowercase strings, so ``EncoderType("audio")`` looks a
    member up by its serialized name.
    """

    # Single-modality architectures.
    TEXT = "text"
    VISION = "vision"
    AUDIO = "audio"
    # Multi-modal architecture.
    MULTIMODAL = "multimodal"
    
@dataclass
class ModalityConfig:
    """Configuration for a single input modality.

    Instances are stored in ``EncoderConfig.modality_configs``, keyed by
    ``ModalityType``.
    """
    modality_type: ModalityType
    # Channel count reported in the vision/audio encoders' TensorMetadata.
    input_channels: int = 1
    # Patch size used by VisionEncoder when use_patch_embed is True.
    patch_size: Union[int, Tuple[int, ...]] = 16
    # Audio sampling rate; copied into AudioEncoder metadata.
    sampling_rate: Optional[int] = None
    # NOTE(review): not read anywhere in this module — confirm intended use.
    frame_rate: Optional[int] = None
    # Length of the precomputed positional-encoding table for this modality;
    # also summed by EncoderConfig.get_total_sequence_length().
    max_seq_len: int = 1024
    # Whether a sinusoidal positional encoding is added for this modality.
    use_positional: bool = True
    # Whether VisionEncoder splits (B, C, H, W) images into flattened patches.
    use_patch_embed: bool = False
    
@dataclass
class EncoderConfig:
    """Top-level settings used to build a TransformerEncoder.

    ``modality_configs`` maps each enabled modality to its ModalityConfig.
    ``vocab_size`` only matters when a text modality is present.
    ``fusion_type`` selects how per-modality states are merged:
    "concatenate", "add" or "learnable".
    """
    encoder_type: EncoderType
    hidden_dim: int
    num_layers: int
    num_heads: int
    modality_configs: Dict[ModalityType, ModalityConfig]
    vocab_size: Optional[int] = None  # required for the text modality
    dropout_rate: float = 0.1
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    use_cache: bool = True
    use_fp16: bool = False
    fusion_type: str = "concatenate"  # "concatenate" | "add" | "learnable"

    def get_total_sequence_length(self) -> int:
        """Return the sum of ``max_seq_len`` over every configured modality."""
        total = 0
        for modal_cfg in self.modality_configs.values():
            total += modal_cfg.max_seq_len
        return total

class EncoderCache:
    """Per-layer key/value store used during inference.

    Attributes:
        layer_states: one ``(key, value)`` pair per layer, indexed by layer.
        position_offset: count of already-processed positions; read by
            ``TransformerEncoder.forward`` as the past length of its default
            attention mask.
        modality_metadata: metadata of the fused input that produced this
            cache, or ``None`` until it is assigned.
    """

    def __init__(self):
        self.layer_states: List[Tuple[np.ndarray, np.ndarray]] = []
        self.position_offset: int = 0
        # Declared up front so the attribute always exists; callers assign it.
        self.modality_metadata = None

    def update(self, layer_idx: int, key: np.ndarray, value: np.ndarray) -> None:
        """Store, or extend, the key/value states of layer ``layer_idx``.

        The first call for a layer stores ``(key, value)`` as-is; later calls
        concatenate the new states onto the stored ones along axis 1.

        Args:
            layer_idx: index of the layer being updated. Layers must be
                added in order (0, 1, 2, ...).
            key: key states to store or append.
            value: value states to store or append.

        Raises:
            IndexError: if ``layer_idx`` skips past the next unused layer,
                which would otherwise store the states under the wrong index.
        """
        stored = len(self.layer_states)
        if layer_idx > stored:
            raise IndexError(
                f"Cannot update layer {layer_idx}: only {stored} layer(s) cached; "
                f"layers must be added in order"
            )
        if layer_idx == stored:
            self.layer_states.append((key, value))
            return
        prev_k, prev_v = self.layer_states[layer_idx]
        self.layer_states[layer_idx] = (
            np.concatenate([prev_k, key], axis=1),
            np.concatenate([prev_v, value], axis=1),
        )

class ModalityEncoder:
    """Common base for the per-modality front-ends.

    Subclasses turn raw modality input into embeddings and describe the
    result with a TensorMetadata.

    Args:
        config: settings for this modality.
        hidden_dim: target embedding width for subclasses.
        driver: optional hardware driver; subclasses probe it for helpers.
    """
    def __init__(
        self,
        config: ModalityConfig,
        hidden_dim: int,
        driver=None
    ):
        self.driver = driver
        self.hidden_dim = hidden_dim
        self.config = config

    def encode(self, x: np.ndarray) -> Tuple[np.ndarray, TensorMetadata]:
        """Return ``(embeddings, metadata)`` for ``x``; subclasses must override."""
        raise NotImplementedError
        
class VisionEncoder(ModalityEncoder):
    """Vision encoder: optional patchification followed by projection.

    When ``config.use_patch_embed`` is set, images shaped (B, C, H, W) are cut
    into non-overlapping ``patch_h x patch_w`` patches and flattened to
    (B, num_patches, C * patch_h * patch_w). The last axis is then projected
    to ``hidden_dim``.
    """

    @staticmethod
    def _patch_hw(patch_size: Union[int, Tuple[int, ...]]) -> Tuple[int, int]:
        """Normalize ``patch_size`` (int, 1-tuple or pair) into ``(patch_h, patch_w)``."""
        if isinstance(patch_size, (int, np.integer)):
            return int(patch_size), int(patch_size)
        dims = tuple(patch_size)
        if len(dims) == 1:
            return int(dims[0]), int(dims[0])
        if len(dims) == 2:
            return int(dims[0]), int(dims[1])
        raise ValueError(
            f"patch_size for 2-D images must be an int or a pair, got {patch_size!r}"
        )

    def encode(self, x: np.ndarray) -> Tuple[np.ndarray, TensorMetadata]:
        """Patchify (optionally) and project ``x``; return ``(states, metadata)``.

        Raises:
            ValueError: if the image height/width is not divisible by the
                patch size, or ``patch_size`` is not an int or a pair.
        """
        spatial_dims = None  # original (H, W) when patching is applied
        if self.config.use_patch_embed:
            B, C, H, W = x.shape
            ph, pw = self._patch_hw(self.config.patch_size)
            if H % ph or W % pw:
                raise ValueError(
                    f"Image size ({H}, {W}) is not divisible by patch size ({ph}, {pw})"
                )
            # (B, C, H/ph, ph, W/pw, pw) -> (B, H/ph, W/pw, C, ph, pw)
            x = x.reshape(B, C, H // ph, ph, W // pw, pw).transpose(0, 2, 4, 1, 3, 5)
            x = x.reshape(B, (H // ph) * (W // pw), C * ph * pw)
            spatial_dims = (H, W)

        # Project to hidden dimension
        if hasattr(self.driver, 'linear'):
            x = self.driver.linear(x, self.hidden_dim)
        else:
            # NOTE(review): without a driver ``linear`` the input is discarded
            # and replaced by random values of the right shape. This looks like
            # a placeholder; confirm before relying on the output.
            x = np.random.randn(*x.shape[:-1], self.hidden_dim)

        metadata = TensorMetadata(
            modality=ModalityType.VISION,
            shape=x.shape,
            dtype=x.dtype,
            channels=self.config.input_channels,
            spatial_dims=spatial_dims
        )
        return x, metadata
        
class AudioEncoder(ModalityEncoder):
    """Audio encoder that applies the driver's ``stft`` when one is available.

    NOTE(review): no projection to ``hidden_dim`` happens here, so the last
    axis of the output is whatever the input (or ``stft``) produces. Confirm
    that callers supply features already sized to ``hidden_dim``.
    """
    def encode(self, x: np.ndarray) -> Tuple[np.ndarray, TensorMetadata]:
        """Return ``(features, metadata)`` for a batch of audio input."""
        features = self.driver.stft(x) if hasattr(self.driver, 'stft') else x
        cfg = self.config
        return features, TensorMetadata(
            modality=ModalityType.AUDIO,
            shape=features.shape,
            dtype=features.dtype,
            channels=cfg.input_channels,
            sampling_rate=cfg.sampling_rate
        )
        
class TextEncoder(ModalityEncoder):
    """Text encoder that maps token ids to rows of an embedding table.

    Args:
        config: settings for the text modality.
        hidden_dim: embedding width.
        vocab_size: number of vocabulary entries; stored only, not checked here.
        embedding_weights: lookup table handed to ``embedding_lookup``.
        driver: optional hardware driver forwarded to ``embedding_lookup``.
    """
    def __init__(self, config: ModalityConfig, hidden_dim: int,
                 vocab_size: int, embedding_weights: np.ndarray,
                 driver=None):
        super().__init__(config, hidden_dim, driver)
        self.embedding_weights = embedding_weights
        self.vocab_size = vocab_size

    def encode(self, x: np.ndarray) -> Tuple[np.ndarray, TensorMetadata]:
        """Embed token ids ``x`` and describe the result (axis 1 = sequence)."""
        embedded = embedding_lookup(x, self.embedding_weights, driver=self.driver)
        meta = TensorMetadata(
            modality=ModalityType.TEXT,
            shape=embedded.shape,
            dtype=embedded.dtype,
            sequence_length=embedded.shape[1],
        )
        return embedded, meta

class TransformerEncoder:
    """
    Multi-modal Transformer encoder.

    A forward pass:
    1. embeds each modality with its ModalityEncoder (text, vision, audio),
    2. adds that modality's sinusoidal positional encoding,
    3. zero-pads the modalities to a common sequence length,
    4. fuses them ("concatenate", "add" or "learnable" weighted sum),
    5. runs the fused sequence through ``transformer_stack``.

    Weights can be stored in FP16 (``config.use_fp16``). An optional
    ``driver`` and ``scheduler`` are forwarded to the helpers that accept
    them. ``generate`` samples token sequences from the text modality.
    """
    def __init__(
        self,
        config: EncoderConfig,
        embedding_weights: Optional[np.ndarray] = None,
        block_weights_list: Optional[List[Dict]] = None,
        driver=None,
        scheduler=None
    ):
        """
        Initialize the multi-modal transformer encoder.

        Args:
            config: Encoder configuration with modality settings
            embedding_weights: Word embedding matrix of shape
                (vocab_size, hidden_dim). Required only when a text
                modality is configured.
            block_weights_list: One weight dictionary per transformer block;
                ``None`` means no blocks
            driver: Optional hardware driver for optimized computation
            scheduler: Optional scheduler forwarded to ``transformer_stack``

        Raises:
            ValueError: if the weights do not match ``config``, or a text
                modality is configured without ``embedding_weights``
        """
        self.validate_inputs(config, embedding_weights, block_weights_list)

        self.config = config
        self.driver = driver
        self.scheduler = scheduler

        # One encoder per configured modality. Other ModalityType values get
        # no encoder, so forward() rejects inputs for them.
        self.encoders = {}
        for modality, modal_config in config.modality_configs.items():
            if modality == ModalityType.TEXT:
                if embedding_weights is None:
                    raise ValueError("embedding_weights required for text modality")
                self.encoders[modality] = TextEncoder(
                    modal_config,
                    config.hidden_dim,
                    config.vocab_size,
                    self._prepare_weights(embedding_weights),
                    driver
                )
            elif modality == ModalityType.VISION:
                self.encoders[modality] = VisionEncoder(
                    modal_config,
                    config.hidden_dim,
                    driver
                )
            elif modality == ModalityType.AUDIO:
                self.encoders[modality] = AudioEncoder(
                    modal_config,
                    config.hidden_dim,
                    driver
                )

        # Transformer block weights, converted to the configured precision.
        self.block_weights_list = [
            self._prepare_weights(weights) for weights in (block_weights_list or [])
        ]

        self._init_cached_computations()
        self._init_fusion_layer()

    def _init_cached_computations(self):
        """Precompute per-modality positional encodings and optional attention bias."""
        self.pos_encodings = {}
        dtype = np.float16 if self.config.use_fp16 else np.float32

        for modality, modal_config in self.config.modality_configs.items():
            if modal_config.use_positional:
                self.pos_encodings[modality] = sinusoidal_positional_encoding(
                    modal_config.max_seq_len,
                    self.config.hidden_dim,
                    dtype=dtype
                )

        # NOTE(review): cached_attention_bias is computed but not read in this class.
        if self.driver and hasattr(self.driver, 'precompute_attention_bias'):
            total_seq_len = self.config.get_total_sequence_length()
            self.cached_attention_bias = self.driver.precompute_attention_bias(
                total_seq_len
            )
        else:
            self.cached_attention_bias = None

    def _init_fusion_layer(self):
        """Create fusion weights when ``fusion_type == "learnable"``.

        The weights have shape (num_modalities, 1, 1), one row per entry of
        ``config.modality_configs`` in its iteration order. Without a driver
        ``create_parameter`` they start uniform, in the model's float dtype.
        """
        if self.config.fusion_type != "learnable":
            self.fusion_weights = None
            return
        dtype = np.float16 if self.config.use_fp16 else np.float32
        num_modalities = len(self.config.modality_configs)
        if self.driver and hasattr(self.driver, 'create_parameter'):
            self.fusion_weights = self.driver.create_parameter(
                (num_modalities, 1, 1),
                dtype=dtype
            )
        else:
            # Use the model dtype so the weighted sum is not upcast to float64.
            self.fusion_weights = np.ones((num_modalities, 1, 1), dtype=dtype) / num_modalities

    def _prepare_weights(self, weights: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
        """Cast an array, or every value of a dict of arrays, to FP16 when configured."""
        if self.config.use_fp16:
            if isinstance(weights, np.ndarray):
                return weights.astype(np.float16)
            return {k: v.astype(np.float16) for k, v in weights.items()}
        return weights

    def _fuse_modalities(
        self,
        encoded_states: Dict[ModalityType, np.ndarray],
        encoded_metadata: Dict[ModalityType, TensorMetadata]
    ) -> Tuple[np.ndarray, TensorMetadata]:
        """
        Fuse multiple modalities into a single representation.

        Supports three fusion types:
        1. concatenate: Concatenate along sequence dimension (axis 1)
        2. add: Element-wise addition (requires same shape)
        3. learnable: Weighted sum using ``self.fusion_weights``

        A single modality is returned unchanged with its own metadata.

        Raises:
            ValueError: if no modality is given, shapes mismatch for "add",
                or ``fusion_type`` is unknown
        """
        modalities = list(encoded_states)
        if not modalities:
            raise ValueError("No modality states to fuse")

        if len(modalities) == 1:
            only = modalities[0]
            return encoded_states[only], encoded_metadata[only]

        fusion_type = self.config.fusion_type
        if fusion_type == "concatenate":
            fused = np.concatenate(
                [encoded_states[m] for m in modalities],
                axis=1
            )

        elif fusion_type == "add":
            shapes = [encoded_states[m].shape for m in modalities]
            if any(s != shapes[0] for s in shapes):
                raise ValueError(
                    f"All modalities must have same shape for addition fusion. Got {shapes}"
                )
            fused = sum(encoded_states[m] for m in modalities)

        elif fusion_type == "learnable":
            # fusion_weights rows follow config.modality_configs order, so look
            # each modality's row up there rather than using its position in
            # encoded_states (which depends on the caller's dict order).
            weight_row = {m: i for i, m in enumerate(self.config.modality_configs)}
            fused = sum(
                encoded_states[m] * self.fusion_weights[weight_row[m]]
                for m in modalities
            )

        else:
            raise ValueError(f"Unknown fusion type: {self.config.fusion_type}")

        # Channel counts may be unset (TextEncoder passes none), so count
        # missing/None values as 0 instead of failing the sum.
        total_channels = sum(
            getattr(meta, "channels", None) or 0 for meta in encoded_metadata.values()
        )
        fused_metadata = TensorMetadata(
            modality=ModalityType.LATENT,
            shape=fused.shape,
            dtype=fused.dtype,
            channels=total_channels,
            sequence_length=fused.shape[1]
        )

        return fused, fused_metadata

    @staticmethod
    def validate_inputs(
        config: EncoderConfig,
        embedding_weights: Optional[np.ndarray],
        block_weights_list: Optional[List[Dict]]
    ):
        """Validate weights against ``config``.

        ``embedding_weights`` is checked only when given, because vision- or
        audio-only encoders need none. ``None`` for ``block_weights_list``
        counts as zero blocks.

        Raises:
            ValueError: on an embedding shape or layer-count mismatch
        """
        if embedding_weights is not None:
            if embedding_weights.shape != (config.vocab_size, config.hidden_dim):
                raise ValueError(
                    f"Embedding weights shape {embedding_weights.shape} doesn't match "
                    f"config (vocab_size={config.vocab_size}, hidden_dim={config.hidden_dim})"
                )

        num_blocks = 0 if block_weights_list is None else len(block_weights_list)
        if num_blocks != config.num_layers:
            raise ValueError(
                f"Expected {config.num_layers} transformer blocks, got {num_blocks}"
            )

    def create_attention_mask(
        self,
        input_shape: Tuple[int, int],
        past_length: int = 0
    ) -> np.ndarray:
        """Create an all-ones attention mask.

        Returns an array of shape
        ``(batch_size, 1, seq_length, seq_length + past_length)`` filled with
        ones, so every query may attend to every key, including the
        ``past_length`` cached ones. No causal or padding restriction is
        applied. NOTE(review): confirm ``transformer_stack`` treats 1 as
        "attend".
        """
        batch_size, seq_length = input_shape
        return np.ones((batch_size, 1, seq_length, seq_length + past_length))

    def _encode_modalities(
        self,
        inputs: Dict[ModalityType, np.ndarray]
    ) -> Tuple[Dict[ModalityType, np.ndarray], Dict[ModalityType, TensorMetadata]]:
        """Embed each modality, check its length and add its positional encoding."""
        encoded_states = {}
        encoded_metadata = {}
        for modality, x in inputs.items():
            encoder = self.encoders.get(modality)
            if encoder is None:
                raise ValueError(f"No encoder configured for modality {modality}")

            states, metadata = encoder.encode(x)
            seq_len = states.shape[1]
            limit = self.config.modality_configs[modality].max_seq_len
            if seq_len > limit:
                raise ValueError(
                    f"Input sequence length {seq_len} for modality {modality} exceeds "
                    f"maximum sequence length {limit}"
                )

            pos_enc = self.pos_encodings.get(modality)
            if pos_enc is not None:
                # Added before cross-modality padding so the slice length always
                # matches this modality's real length (and its own table size),
                # and padded slots stay zero.
                states = states + pos_enc[:seq_len]

            encoded_states[modality] = states
            encoded_metadata[modality] = metadata
        return encoded_states, encoded_metadata

    @staticmethod
    def _pad_to_common_length(encoded_states: Dict[ModalityType, np.ndarray]) -> None:
        """Zero-pad every (B, S, H) state along axis 1 to the longest S, in place.

        NOTE(review): padded positions are not masked in the default
        attention mask.
        """
        target = max(states.shape[1] for states in encoded_states.values())
        for modality, states in list(encoded_states.items()):
            missing = target - states.shape[1]
            if missing > 0:
                encoded_states[modality] = np.pad(
                    states,
                    ((0, 0), (0, missing), (0, 0)),
                    mode='constant'
                )

    def forward(
        self,
        inputs: Union[Dict[ModalityType, np.ndarray], np.ndarray],
        attention_mask: Optional[np.ndarray] = None,
        past_cache: Optional[EncoderCache] = None,
        return_cache: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, EncoderCache]]:
        """
        Forward pass of the multi-modal encoder.

        Args:
            inputs: Mapping from modality type to that modality's raw input.
                A bare array is accepted as shorthand for
                ``{ModalityType.TEXT: array}``.
            attention_mask: Optional attention mask. When omitted, an all-ones
                mask sized to the fused sequence is built.
            past_cache: Optional cache whose ``position_offset`` extends the
                default mask; it is passed to ``transformer_stack``
            return_cache: Whether to also return the cache created for this
                pass (``None`` when ``config.use_cache`` is False)

        Returns:
            Encoded hidden states, or ``(hidden_states, cache)`` when
            ``return_cache`` is True.

        Raises:
            ValueError: if ``inputs`` is empty, names a modality without an
                encoder, or exceeds that modality's ``max_seq_len``
        """
        if isinstance(inputs, np.ndarray):
            inputs = {ModalityType.TEXT: inputs}
        if not inputs:
            raise ValueError("forward() requires at least one modality input")

        encoded_states, encoded_metadata = self._encode_modalities(inputs)
        self._pad_to_common_length(encoded_states)

        hidden_states, fused_metadata = self._fuse_modalities(
            encoded_states,
            encoded_metadata
        )

        # Build the default mask from the *fused* length: concatenation makes
        # it the sum of the per-modality lengths, not the longest one.
        if attention_mask is None:
            attention_mask = self.create_attention_mask(
                (hidden_states.shape[0], hidden_states.shape[1]),
                past_length=past_cache.position_offset if past_cache is not None else 0
            )

        current_cache = EncoderCache() if self.config.use_cache else None
        if current_cache is not None:
            current_cache.modality_metadata = fused_metadata

        hidden_states = transformer_stack(
            hidden_states,
            self.block_weights_list,
            self.config.num_heads,
            attention_mask=attention_mask,
            past_cache=past_cache,
            current_cache=current_cache,
            driver=self.driver,
            scheduler=self.scheduler,
            metadata=fused_metadata
        )

        if return_cache:
            return hidden_states, current_cache
        return hidden_states

    def generate(
        self,
        input_ids: np.ndarray,
        max_length: int,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.95
    ) -> np.ndarray:
        """
        Sample token sequences autoregressively from the text modality.

        Each step re-encodes the whole sequence generated so far. Incremental
        decoding through EncoderCache is not used, because nothing in this
        module advances ``position_offset``, so new tokens would all get
        position 0.

        Args:
            input_ids: Initial tokens, shape (batch_size, prompt_length)
            max_length: Total sequence length to reach (prompt included)
            temperature: Sampling temperature; must be positive
            top_k: Keep only the k highest logits (disabled when <= 0)
            top_p: Nucleus threshold on cumulative probability (disabled at >= 1)

        Returns:
            Array of shape (batch_size, max(prompt_length, max_length)).

        Raises:
            ValueError: if ``temperature`` is not positive
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        generated = np.asarray(input_ids)
        for _ in range(max_length - generated.shape[1]):
            hidden = self.forward({ModalityType.TEXT: generated})
            logits = self._next_token_logits(hidden[:, -1, :]) / temperature
            next_tokens = self._sample_next_tokens(logits, top_k, top_p)
            generated = np.concatenate(
                [generated, next_tokens[:, np.newaxis].astype(generated.dtype)],
                axis=1
            )
        return generated

    def _next_token_logits(self, last_hidden: np.ndarray) -> np.ndarray:
        """Map last-position encoder outputs to float64 vocabulary logits.

        Outputs whose last axis already has ``vocab_size`` entries are used
        as logits directly. Otherwise hidden states are scored against the
        text embedding matrix (tied input/output embeddings).
        NOTE(review): confirm weight tying is the intended output head.
        """
        if last_hidden.shape[-1] == self.config.vocab_size:
            return np.asarray(last_hidden, dtype=np.float64)
        text_encoder = self.encoders.get(ModalityType.TEXT)
        if text_encoder is None:
            raise ValueError("generate() requires a configured text modality")
        embedding = np.asarray(text_encoder.embedding_weights, dtype=np.float64)
        return np.asarray(last_hidden, dtype=np.float64) @ embedding.T

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over the last axis."""
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / np.sum(exp, axis=-1, keepdims=True)

    @staticmethod
    def _sample_next_tokens(logits: np.ndarray, top_k: int, top_p: float) -> np.ndarray:
        """Apply per-row top-k and nucleus filtering, then sample one id per row."""
        logits = np.array(logits, dtype=np.float64)  # private copy, filtered in place
        vocab = logits.shape[-1]

        if top_k > 0:
            k = min(top_k, vocab)  # np.partition rejects k > vocab
            kth_largest = np.partition(logits, -k, axis=-1)[:, -k:].min(axis=-1, keepdims=True)
            logits[logits < kth_largest] = -np.inf

        if top_p < 1.0:
            probs = TransformerEncoder._softmax(logits)
            order = np.argsort(-logits, axis=-1)  # descending, per row
            cumulative = np.cumsum(np.take_along_axis(probs, order, axis=-1), axis=-1)
            drop_sorted = cumulative > top_p
            # Shift right so the token that crosses top_p survives, and always
            # keep the most likely token.
            drop_sorted[:, 1:] = drop_sorted[:, :-1].copy()
            drop_sorted[:, 0] = False
            drop = np.zeros_like(drop_sorted)
            np.put_along_axis(drop, order, drop_sorted, axis=-1)
            logits[drop] = -np.inf

        probs = TransformerEncoder._softmax(logits)
        return np.array([np.random.choice(vocab, p=row) for row in probs])
