import numpy as np
from typing import Optional, Tuple, Dict, Union, List, Any
from dataclasses import dataclass
from enum import Enum
from .softmax import softmax
from .broadcast import ModalityType, TensorMetadata
from .tensor import HeliumTensor
from .attention_utils import AttentionState
from .utils import split_heads, apply_rotary_embedding, fuse_cross_modal_attention

class AttentionType(Enum):
    """Types of attention patterns"""
    # SELF and CROSS are not special-cased by any branch visible in this file.
    SELF = "self"
    CROSS = "cross"
    # LOCAL selects a sliding-window mask where masks are generated below.
    LOCAL = "local"
    # SPARSE selects a strided mask where masks are generated below.
    SPARSE = "sparse"
    # GLOBAL makes HeliumMultiHeadAttention.forward skip automatic mask creation.
    GLOBAL = "global"

@dataclass
class AttentionConfig:
    """Configuration for multi-modal attention"""
    # Which masking pattern the layer applies (see AttentionType).
    attention_type: AttentionType
    # Number of attention heads; head_dim is derived as hidden_dim // num_heads.
    num_heads: int
    # Width of the model; also the size of every square projection matrix.
    hidden_dim: int
    # Fusion mode handed to fuse_cross_modal_attention:
    # "additive", "multiplicative" or "gated".
    cross_modality_fusion: str = "additive"
    # When True, rotary positional embeddings are applied to queries and keys.
    use_rotary: bool = False
    # Half-width of the LOCAL sliding window; None falls back to q_length // 8
    # inside HeliumMultiHeadAttention._get_attention_mask.
    window_size: Optional[int] = None
    # Stride of the SPARSE pattern; None falls back to 8 in the same method.
    sparsity_factor: Optional[int] = None
    # Enables the extra image/audio mask refinements in the same method.
    modality_specific: bool = False

class HeliumMultiHeadAttention:
    """
    Multi-modal attention implementation with support for:
    - Cross-modal attention
    - Modality-specific patterns
    - Local/sparse attention
    - Rotary embeddings
    - Fusion mechanisms
    """

    def __init__(self, config: AttentionConfig, device_id: Optional[str] = None):
        """
        Build the per-modality Q/K/V projections and the output projection.

        Args:
            config: Layer configuration (head count, hidden size, mask type...).
            device_id: When truthy, weights and generated masks are wrapped in
                HeliumTensor on that device; otherwise they stay NumPy arrays.
        """
        self.config = config
        self.device_id = device_id
        self.head_dim = config.hidden_dim // config.num_heads

        # Initialize modality-specific projections
        self.projections = self._create_projections()

        # Initialize output projection
        self.output_projection = self._create_projection(scale=1.0)

        # Cache for attention patterns, keyed by
        # (q_modality, k_modality, q_length, k_length).
        self.pattern_cache: Dict[Tuple[Any, Any, int, int], Any] = {}

    def _create_projections(self) -> Dict[ModalityType, Dict[str, Any]]:
        """Create projection matrices for Q,K,V for each modality"""
        projections = {}

        for modality in ModalityType:
            # Image and audio weights get an init scale relative to a
            # reference head size (64 and 32 respectively); others use 1.0.
            scale = 1.0
            if modality == ModalityType.IMAGE:
                scale = np.sqrt(self.head_dim / 64)
            elif modality == ModalityType.AUDIO:
                scale = np.sqrt(self.head_dim / 32)

            projections[modality] = {
                'query': self._create_projection(scale=scale),
                'key': self._create_projection(scale=scale),
                'value': self._create_projection(scale=scale),
            }

        return projections

    def _create_projection(self, scale: float = 1.0) -> Dict[str, Union[np.ndarray, HeliumTensor]]:
        """
        Create a single square projection (weight + zero bias).

        The weight is drawn from N(0, std) with
        std = scale * sqrt(2 / (2 * hidden_dim)).
        """
        hidden_dim = self.config.hidden_dim
        std = scale * np.sqrt(2.0 / (2.0 * hidden_dim))
        weight = np.random.normal(0, std, (hidden_dim, hidden_dim))
        bias = np.zeros(hidden_dim)

        if self.device_id:
            # Move to device if specified
            weight = HeliumTensor(weight, device=self.device_id)
            bias = HeliumTensor(bias, device=self.device_id)

        return {'weight': weight, 'bias': bias}

    def forward(
        self,
        hidden_states: Union[str, HeliumTensor],
        attention_mask: Optional[Union[str, HeliumTensor]] = None,
        modality: Optional[ModalityType] = None,
        cross_states: Optional[Union[str, HeliumTensor]] = None,
        cross_modality: Optional[ModalityType] = None,
        metadata: Optional[TensorMetadata] = None
    ) -> Tuple[Union[str, HeliumTensor], Dict[str, Any]]:
        """
        Multi-modal attention forward pass

        Args:
            hidden_states: Query-side input. Must expose ``.device`` (used as
                the compute driver) and ``.shape`` with the sequence length at
                index 1.
            attention_mask: Optional explicit mask. When omitted and the
                configured attention type is not GLOBAL, one is generated.
            modality: Modality of ``hidden_states``; TEXT projections are used
                when omitted.
            cross_states: Optional key/value-side input for cross attention.
            cross_modality: Modality of ``cross_states``.
            metadata: If truthy, updated in place with modality, operation
                name and output shape.

        Returns:
            ``(output, {'attention_weights': weights})``.

        Raises:
            ValueError: If ``hidden_states`` has no ``.device`` driver.
        """
        # The input's device doubles as the driver for every tensor op here.
        driver = getattr(hidden_states, 'device', None)
        if driver is None:
            raise ValueError("hidden_states must expose a .device driver")

        # Initialize computation state
        state = AttentionState(driver, "mm_attn")

        # Get projection matrices
        projections = self.projections[modality or ModalityType.TEXT]

        # Keys/values come from cross_states when given, else from the input.
        kv_source = hidden_states if cross_states is None else cross_states
        kv_modality = cross_modality or modality
        q_length = hidden_states.shape[1]
        k_length = kv_source.shape[1]

        # Project inputs. Each of Q, K and V uses its own projection matrix so
        # the key/value weights created in __init__ are actually applied.
        q = driver.matmul(hidden_states, projections['query']['weight'])
        k = driver.matmul(kv_source, projections['key']['weight'])
        v = driver.matmul(kv_source, projections['value']['weight'])

        # Split heads with modality awareness.
        # NOTE(review): split_heads as defined in this file looks tensors up by
        # name via driver.get_tensor; confirm driver.matmul returns a handle
        # that get_tensor accepts.
        q = split_heads(q, self.config.num_heads, driver, state, modality)
        k = split_heads(k, self.config.num_heads, driver, state, kv_modality)
        v = split_heads(v, self.config.num_heads, driver, state, kv_modality)

        # Apply rotary embeddings if configured; each side uses its own length.
        if self.config.use_rotary:
            q = apply_rotary_embedding(q, q_length, self.head_dim, driver, state)
            k = apply_rotary_embedding(k, k_length, self.head_dim, driver, state)

        # Handle cross-modal attention
        if cross_states is not None and cross_modality != modality:
            q, k, v = fuse_cross_modal_attention(
                q, k, v,
                modality,
                cross_modality,
                self.config.cross_modality_fusion,
                driver,
                state
            )

        # Get attention mask. Lengths come from the inputs because q/k may be
        # driver tensor names at this point and therefore have no .shape.
        if attention_mask is None and self.config.attention_type != AttentionType.GLOBAL:
            attention_mask = self._get_attention_mask(
                modality or ModalityType.TEXT,
                cross_modality or modality or ModalityType.TEXT,
                q_length,
                k_length
            )

        # Compute attention
        scale = np.sqrt(self.head_dim)
        if modality == ModalityType.IMAGE:
            scale *= 2.0

        # NOTE(review): scaled_dot_product_attention is neither imported nor
        # defined in this file; it must be brought into scope for this call.
        attn_output, attn_weights = scaled_dot_product_attention(
            q, k, v,
            mask=attention_mask,
            scale=scale,
            driver=driver
        )

        # Combine heads: (batch, heads, seq, head_dim) must be transposed to
        # (batch, seq, heads, head_dim) before flattening the last two axes,
        # otherwise values from different positions are interleaved.
        batch, seq_len = attn_output.shape[0], attn_output.shape[2]
        attn_output = driver.reshape(
            driver.transpose(attn_output, (0, 2, 1, 3)),
            (batch, seq_len, self.config.hidden_dim)
        )

        # Project output
        output = driver.matmul(attn_output, self.output_projection['weight'])

        # Add metadata
        if metadata:
            metadata.modality = modality
            metadata.operation = "attention"
            metadata.shape = output.shape

        return output, {'attention_weights': attn_weights}

    def _get_attention_mask(
        self,
        q_modality: ModalityType,
        k_modality: ModalityType,
        q_length: int,
        k_length: int
    ) -> Optional[Union[np.ndarray, HeliumTensor]]:
        """
        Get or create the attention mask for the given modalities and lengths.

        The mask is boolean with shape (q_length, k_length); True marks a
        position excluded by the pattern. LOCAL uses a sliding window, SPARSE
        a strided pattern; every other attention type yields None. Results are
        cached per (q_modality, k_modality, q_length, k_length).

        The optional config attributes ``window_size``, ``sparsity_factor``
        and ``modality_specific`` are read with defaults so that configs
        lacking them keep working.
        """
        key = (q_modality, k_modality, q_length, k_length)
        if key in self.pattern_cache:
            return self.pattern_cache[key]

        # Signed offset between every query position and every key position.
        offsets = np.arange(q_length)[:, None] - np.arange(k_length)[None, :]

        # Create attention mask based on attention type
        mask = None
        if self.config.attention_type == AttentionType.LOCAL:
            # Local attention with sliding window
            window = getattr(self.config, 'window_size', None) or q_length // 8
            mask = np.abs(offsets) > window

        elif self.config.attention_type == AttentionType.SPARSE:
            # Sparse attention with strided pattern
            stride = getattr(self.config, 'sparsity_factor', None) or 8
            mask = offsets % stride != 0

        # Add modal-specific patterns
        if mask is not None and getattr(self.config, 'modality_specific', False):
            if q_modality == ModalityType.IMAGE:
                # Add local 2D structure for images: treat positions as a
                # row-major square grid and keep spatially close pairs.
                side = int(np.sqrt(q_length))
                if side * side == q_length and k_length == q_length:
                    rows, cols = np.divmod(np.arange(q_length), side)
                    dist = (rows[:, None] - rows) ** 2 + (cols[:, None] - cols) ** 2
                    mask = np.logical_and(mask, dist > 4)

            elif q_modality == ModalityType.AUDIO:
                # Add frequency-based patterns for audio
                q_freqs = np.fft.fftfreq(q_length)
                k_freqs = np.fft.fftfreq(k_length)
                mask = np.logical_and(
                    mask, np.abs(q_freqs[:, None] - k_freqs) > 0.25)

        if mask is None:
            return None

        if self.device_id:
            mask = HeliumTensor(mask, device=self.device_id)

        # Cache and return on both the host and the device path.
        self.pattern_cache[key] = mask
        return mask
            
def create_attention_mask(
    q_modality: ModalityType,
    k_modality: ModalityType,
    q_length: int,
    k_length: int,
    attention_type: AttentionType,
    window_size: Optional[int] = None
) -> np.ndarray:
    """
    Build a float32 attention mask of shape (q_length, k_length).

    A value of 1 marks a position that may be attended to and 0 one that may
    not. The mask starts as all ones and is then narrowed:

    - LOCAL with a truthy ``window_size``: keep only keys within
      ``window_size`` positions of the query index.
    - SPARSE: keep only every ``stride``-th key column, where
      ``stride = max(1, k_length // 8)``.
    - IMAGE queries over TEXT keys: even key columns are set to 1 and odd key
      columns to 0.

    Args:
        q_modality: Modality of the queries.
        k_modality: Modality of the keys.
        q_length: Number of query positions.
        k_length: Number of key positions.
        attention_type: Pattern to apply.
        window_size: Half-width of the LOCAL window; ignored otherwise.

    Returns:
        The mask as a new ``np.float32`` array.
    """
    mask = np.ones((q_length, k_length), dtype=np.float32)

    if attention_type == AttentionType.LOCAL and window_size:
        # Create local attention pattern: drop keys outside the window.
        offsets = np.arange(q_length)[:, None] - np.arange(k_length)[None, :]
        mask[np.abs(offsets) > window_size] = 0

    elif attention_type == AttentionType.SPARSE:
        # Create sparse attention pattern. Clear everything first and then
        # enable the strided columns; the reverse order wipes the pattern.
        stride = max(1, k_length // 8)
        mask[:, :] = 0
        mask[:, ::stride] = 1

    # Modality-specific masking. Text attending to an image keeps the full
    # mask, so only the image-over-text case needs handling.
    if q_modality == ModalityType.IMAGE and k_modality == ModalityType.TEXT:
        # Image attends to text sparsely: every other text token.
        # NOTE(review): this also re-enables even columns that a LOCAL/SPARSE
        # pattern above had cleared - confirm that is intended.
        mask[:, ::2] = 1
        mask[:, 1::2] = 0

    return mask

def split_heads(
    x_name: str,
    num_heads: int,
    driver,
    state: AttentionState,
    modality: Optional[ModalityType] = None
) -> str:
    """
    Reshape a (batch, seq, hidden) tensor into (batch, heads, seq, head_dim).

    The tensor is fetched from the driver by name, optionally scaled for its
    modality, reshaped to four dimensions and transposed through the driver.
    Intermediate results are registered with ``state`` as temp tensors.

    Returns: name of the transposed tensor.
    """
    tensor = driver.get_tensor(x_name)
    batch, seq_len, hidden_dim = tensor.shape
    per_head = hidden_dim // num_heads

    # Only image heads get a non-trivial factor; any other supplied modality
    # multiplies by 1.0, and no modality skips the multiplication entirely.
    if modality:
        factor = np.sqrt(per_head / 64) if modality == ModalityType.IMAGE else 1.0
        tensor = tensor * factor

    four_d = tensor.reshape(batch, seq_len, num_heads, per_head)
    reshaped_name = state.get_temp_tensor(four_d, "reshaped")

    # Tag the intermediate with its modality when the driver supports it.
    if modality and hasattr(driver, 'set_tensor_metadata'):
        info = TensorMetadata(
            modality=modality,
            shape=tensor.shape,
            dtype=tensor.dtype
        )
        driver.set_tensor_metadata(reshaped_name, info)

    swapped = driver.transpose(reshaped_name, (0, 2, 1, 3))
    transposed_name = state.get_temp_tensor(swapped, "transposed")

    # The 4-D intermediate is no longer needed once transposed.
    state.free_temp_tensor(reshaped_name)
    return transposed_name

def apply_rotary_embedding(
    x_name: str,
    seq_len: int,
    head_dim: int,
    driver,
    state: AttentionState,
    base: int = 10000
) -> str:
    """
    Rotate adjacent feature pairs of a tensor by position-dependent angles.

    The tensor is fetched from the driver by name and viewed as
    (batch, heads, seq_len, head_dim // 2, 2). Each pair is rotated by
    ``position / base ** (2 * i / head_dim)`` for pair index ``i``. The output
    lays out all first rotated components followed by all second components
    along the last axis.

    Returns: name of the rotated tensor registered with ``state``.
    """
    tensor = driver.get_tensor(x_name)
    batch_size, num_heads = tensor.shape[:2]
    half = head_dim // 2

    # One inverse frequency per feature pair, one angle per (position, pair).
    pair_dims = np.arange(half) * 2
    inv_freq = 1.0 / (base ** (pair_dims / head_dim))
    positions = np.arange(seq_len)
    angles = positions[:, None] * inv_freq[None, :]

    # Leading singleton axes broadcast over batch and heads.
    cos = np.cos(angles)[None, None, :, :]
    sin = np.sin(angles)[None, None, :, :]

    pairs = tensor.reshape(batch_size, num_heads, seq_len, half, 2)
    first = pairs[..., 0]
    second = pairs[..., 1]

    rotated_first = first * cos - second * sin
    rotated_second = first * sin + second * cos
    rotated = np.concatenate([rotated_first, rotated_second], axis=-1)

    return state.get_temp_tensor(rotated, "rotary")

def fuse_cross_modal_attention(
    q_name: str,
    k_name: str,
    v_name: str,
    q_modality: ModalityType,
    k_modality: ModalityType,
    fusion_type: str,
    driver,
    state: AttentionState
) -> Tuple[str, str, str]:
    """
    Fuse attention across different modalities

    Queries and keys are fetched from the driver by name, adjusted according
    to ``fusion_type`` and registered with ``state`` as new temp tensors. The
    value tensor is passed through unchanged.

    Args:
        q_name: Query tensor name
        k_name: Key tensor name
        v_name: Value tensor name
        q_modality: Query modality
        k_modality: Key modality
        fusion_type: Type of fusion (additive, multiplicative, gated)
        driver: Object providing ``get_tensor(name)``
        state: Attention state used to register the fused tensors

    Returns:
        ``(q_fused_name, k_fused_name, v_name)``.

    Raises:
        ValueError: If ``fusion_type`` is not one of the supported modes.
    """
    # Reject unknown modes up front; previously they fell through to the
    # return statement and failed with an UnboundLocalError.
    if fusion_type not in ("additive", "multiplicative", "gated"):
        raise ValueError(f"Unknown fusion type: {fusion_type!r}")

    q = driver.get_tensor(q_name)
    k = driver.get_tensor(k_name)

    if fusion_type == "additive":
        # Per-head bias; currently a constant zero placeholder.
        bias_shape = (1, q.shape[1], 1, q.shape[-1])
        q_fused = q + np.zeros(bias_shape)
        k_fused = k + np.zeros(bias_shape)

    elif fusion_type == "multiplicative":
        # Text tensors are scaled by sqrt(feature size); others are unscaled.
        q_scale = np.sqrt(q.shape[-1]) if q_modality == ModalityType.TEXT else 1.0
        k_scale = np.sqrt(k.shape[-1]) if k_modality == ModalityType.TEXT else 1.0
        q_fused = q * q_scale
        k_fused = k * k_scale

    else:  # "gated"
        # Per-head gate; currently a constant one placeholder.
        gate_shape = (1, q.shape[1], 1, 1)
        q_fused = q * np.ones(gate_shape)
        k_fused = k * np.ones(gate_shape)

    q_fused_name = state.get_temp_tensor(q_fused, "q_fused")
    k_fused_name = state.get_temp_tensor(k_fused, "k_fused")

    return q_fused_name, k_fused_name, v_name

def combine_heads(
    x_name: str,
    driver,
    state: AttentionState,
    modality: Optional[ModalityType] = None
) -> str:
    """
    Merge a (batch, heads, seq, head_dim) tensor back to (batch, seq, hidden).

    The tensor's shape is read from the driver by name; the transpose and the
    reshape are both issued through the driver and their results registered
    with ``state`` as temp tensors. The intermediate transposed tensor is
    released before returning.

    Args:
        x_name: Name of the per-head tensor in the driver.
        driver: Object providing ``get_tensor``, ``transpose`` and ``reshape``.
        state: Attention state used to register and free temp tensors.
        modality: Accepted for signature symmetry with split_heads; not used.

    Returns: name of resulting tensor in driver
    """
    batch, num_heads, seq_len, head_dim = driver.get_tensor(x_name).shape

    # Swap the head and sequence axes so each position's heads are adjacent.
    transposed_name = state.get_temp_tensor(
        driver.transpose(x_name, (0, 2, 1, 3)),
        "transposed_back"
    )

    # Flatten (heads, head_dim) into a single hidden axis.
    reshaped_name = state.get_temp_tensor(
        driver.reshape(transposed_name, (batch, seq_len, num_heads * head_dim)),
        "reshaped_back"
    )

    state.free_temp_tensor(transposed_name)
    return reshaped_name

def multihead_attention(
    x_name: str,
    Wq_name: str,
    Wk_name: str,
    Wv_name: str,
    Wo_name: str,
    num_heads: int,
    mask_name: Optional[str] = None,
    driver = None,
    chip_id: int = 0,
    sm_id: int = 0,
    scheduler = None
) -> Tuple[str, str]:
    """
    Run multi-head attention over tensors referenced by driver names.

    The input is projected with the Q/K/V weights, split into heads, passed
    through scaled dot-product attention, merged back and projected with the
    output weight. Every intermediate is registered with an AttentionState
    and released as soon as the next stage has consumed it.

    Returns: (output_name, attention_weights_name) in driver

    Raises:
        ValueError: If no driver is supplied.
    """
    if driver is None:
        raise ValueError("Driver is required for GPU-backed attention")

    state = AttentionState(driver, f"mha_{chip_id}_{sm_id}")

    # Project the input once per weight, in Q, K, V order.
    weight_labels = ((Wq_name, "Q"), (Wk_name, "K"), (Wv_name, "V"))
    projected = [
        state.get_temp_tensor(
            driver.matmul(x_name, weight_name, chip_id=chip_id, sm_id=sm_id),
            label
        )
        for weight_name, label in weight_labels
    ]

    # Split each projection into heads, then drop the flat projections.
    per_head = [
        split_heads(name, num_heads, driver, state) for name in projected
    ]
    for name in projected:
        state.free_temp_tensor(name)

    q_heads, k_heads, v_heads = per_head
    attn_output_name, attn_weights_name = scaled_dot_product_attention(
        q_heads, k_heads, v_heads,
        mask_name=mask_name,
        driver=driver,
        chip_id=chip_id,
        sm_id=sm_id,
        scheduler=scheduler
    )

    # The per-head inputs are no longer needed once attention has run.
    for name in per_head:
        state.free_temp_tensor(name)

    # Merge heads and release the per-head attention output.
    combined_name = combine_heads(attn_output_name, driver, state)
    state.free_temp_tensor(attn_output_name)

    # Final output projection
    projected_output = driver.matmul(
        combined_name, Wo_name, chip_id=chip_id, sm_id=sm_id
    )
    output_name = state.get_temp_tensor(projected_output, "output")
    state.free_temp_tensor(combined_name)

    return output_name, attn_weights_name
