from typing import Dict, List, Optional, Tuple, Union
import numpy as np

from .broadcast import ModalityType

class AttentionState:
    """Mutable bookkeeping attached to one named attention computation.

    Attributes are plain instance fields:
        driver: backend handle supplied by the caller, stored as given.
        name: label supplied by the caller, stored as given.
        stored_tensors: string-keyed registry, created empty for every
            instance so that instances never share entries.
    """

    def __init__(self, driver, name: str):
        self.driver, self.name = driver, name
        # A new dict per instance; entries are added by callers later on.
        self.stored_tensors: Dict[str, str] = dict()

def split_heads(
    x: Union[str, "HeliumTensor"],
    num_heads: int,
    driver,
    modality: Optional[ModalityType] = None
) -> Union[str, "HeliumTensor"]:
    """Split the hidden dimension of ``x`` into ``num_heads`` attention heads.

    Args:
        x: A tensor, or a tensor name that is resolved with
            ``driver.get_tensor``. The resolved tensor's ``shape`` must
            unpack into exactly ``(batch_size, seq_len, hidden_dim)``.
        num_heads: Number of heads; must be positive and divide
            ``hidden_dim`` evenly.
        driver: Backend object providing ``get_tensor``, ``reshape``,
            ``transpose`` and ``mul_scalar``.
        modality: Optional modality that selects an extra scalar factor:
            ``sqrt(head_dim / 64)`` for ``ModalityType.IMAGE`` and
            ``sqrt(head_dim / 32)`` for ``ModalityType.AUDIO``. Any other
            value (or ``None``) applies no scaling.

    Returns:
        The driver's result after reshaping to
        ``(batch_size, seq_len, num_heads, head_dim)`` and permuting the
        axes with ``(0, 2, 1, 3)``, optionally multiplied by the modality
        factor.

    Raises:
        ValueError: If ``num_heads`` is not positive, or if ``hidden_dim``
            is not divisible by ``num_heads``.
    """
    if num_heads <= 0:
        raise ValueError(f"num_heads must be a positive integer, got {num_heads}")

    if isinstance(x, str):
        x = driver.get_tensor(x)

    batch_size, seq_len, hidden_dim = x.shape

    # Fail early with a clear message instead of letting floor division
    # truncate head_dim and the subsequent reshape fail on element count.
    head_dim, remainder = divmod(hidden_dim, num_heads)
    if remainder:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) must be divisible by "
            f"num_heads ({num_heads})"
        )

    # Reshape and transpose
    x = driver.reshape(x, (batch_size, seq_len, num_heads, head_dim))
    x = driver.transpose(x, (0, 2, 1, 3))

    # Apply modality-specific scaling. Compare against None explicitly so a
    # modality member that happens to be falsy is not silently skipped.
    if modality is not None:
        if modality == ModalityType.IMAGE:
            scale = np.sqrt(head_dim / 64)
        elif modality == ModalityType.AUDIO:
            scale = np.sqrt(head_dim / 32)
        else:
            scale = 1.0

        # Skip the driver call when the factor is the identity.
        if scale != 1.0:
            x = driver.mul_scalar(x, scale)

    return x

def apply_rotary_embedding(
    x: Union[str, "HeliumTensor"],
    seq_len: int,
    head_dim: int,
    driver
) -> Union[str, "HeliumTensor"]:
    """Apply rotary positional embeddings to ``x``.

    Each consecutive feature pair ``(2i, 2i + 1)`` at position ``p`` is
    rotated by the angle ``p * 10000 ** (-2i / head_dim)``::

        out[2i]     = x[2i] * cos - x[2i + 1] * sin
        out[2i + 1] = x[2i + 1] * cos + x[2i] * sin

    Args:
        x: A tensor, or a tensor name that is resolved with
            ``driver.get_tensor``.
            NOTE(review): assumed to have trailing axes
            ``(seq_len, head_dim)`` - confirm against callers.
        seq_len: Number of positions to generate angles for.
        head_dim: Size of the feature axis; must be positive and even so
            that features can be paired.
        driver: Backend object providing ``get_tensor``, ``to_gpu``,
            ``matmul``, ``mul`` and ``add``.

    Returns:
        The driver's result for the rotated tensor.

    Raises:
        ValueError: If ``head_dim`` is not a positive even number.
    """
    if head_dim <= 0 or head_dim % 2:
        raise ValueError(
            f"head_dim must be a positive even number, got {head_dim}"
        )

    if isinstance(x, str):
        x = driver.get_tensor(x)

    # Generate position indices
    positions = np.arange(seq_len)

    # Generate one frequency per feature pair
    freqs = np.exp(-np.arange(0, head_dim, 2) * np.log(10000) / head_dim)
    angles = positions[:, None] * freqs[None, :]

    # Duplicate every angle so both members of a pair share it; the tables
    # then have shape (seq_len, head_dim).
    cos = np.repeat(np.cos(angles), 2, axis=-1)
    sin = np.repeat(np.sin(angles), 2, axis=-1)

    # Signed permutation over the feature axis:
    # (x @ pair_swap)[2i] = -x[2i + 1] and (x @ pair_swap)[2i + 1] = x[2i].
    even = np.arange(0, head_dim, 2)
    pair_swap = np.zeros((head_dim, head_dim))
    pair_swap[even + 1, even] = -1.0
    pair_swap[even, even + 1] = 1.0

    # Move to device
    cos = driver.to_gpu(cos)
    sin = driver.to_gpu(sin)
    pair_swap = driver.to_gpu(pair_swap)

    # Apply rotations through driver ops only, so results that are driver
    # handles never go through Python arithmetic operators.
    # NOTE(review): assumes driver.mul broadcasts a (seq_len, head_dim)
    # table over the leading axes of x - confirm against the driver.
    x_swapped = driver.matmul(x, pair_swap)
    return driver.add(driver.mul(x, cos), driver.mul(x_swapped, sin))

def fuse_cross_modal_attention(
    q: Union[str, "HeliumTensor"],
    k: Union[str, "HeliumTensor"],
    v: Union[str, "HeliumTensor"],
    q_modality: ModalityType,
    kv_modality: ModalityType,
    fusion_type: str,
    driver,
    state: AttentionState
) -> Tuple[Union[str, "HeliumTensor"], Union[str, "HeliumTensor"], Union[str, "HeliumTensor"]]:
    """Fuse the query and key of a cross-modal attention step.

    Args:
        q: Query, handed to the driver as given.
        k: Key, handed to the driver as given.
        v: Value; returned untouched.
        q_modality: Accepted for interface compatibility; not read here.
        kv_modality: Accepted for interface compatibility; not read here.
        fusion_type: One of
            ``"additive"`` - fused = ``q + k``;
            ``"multiplicative"`` - fused = ``q * k``;
            ``"gated"`` - ``gate = sigmoid(q @ gate_weight)`` and
            fused = ``gate * q + (1 - gate) * k``.
            Any other value leaves ``q`` and ``k`` unchanged.
        driver: Backend object providing ``add``, ``mul``, ``sub``,
            ``matmul`` and ``sigmoid``.
        state: Object whose ``stored_tensors`` mapping supplies
            ``"gate_weight"`` for gated fusion.

    Returns:
        ``(q, k, v)``. For the three recognised fusion types both ``q`` and
        ``k`` are the same fused result.

    Raises:
        ValueError: If ``fusion_type`` is ``"gated"`` and
            ``state.stored_tensors`` has no ``"gate_weight"`` entry.
    """
    if fusion_type == "additive":
        # Simple additive fusion
        q = driver.add(q, k)
        k = q
    elif fusion_type == "multiplicative":
        # Element-wise multiplication
        q = driver.mul(q, k)
        k = q
    elif fusion_type == "gated":
        # Gated fusion with learned parameters. Check for the weight up
        # front so a missing entry is reported clearly instead of handing
        # None to driver.matmul.
        gate_weight = state.stored_tensors.get("gate_weight")
        if gate_weight is None:
            raise ValueError(
                'gated fusion requires state.stored_tensors["gate_weight"] '
                "to be set"
            )
        gate = driver.sigmoid(driver.matmul(q, gate_weight))
        q = driver.add(
            driver.mul(gate, q),
            driver.mul(driver.sub(1.0, gate), k)
        )
        k = q

    return q, k, v
