"""
Hardware-accelerated multi-modal transformer decoder implementation for Helium virtual GPU
"""
from typing import Optional, Union, Dict, Any, TYPE_CHECKING, List, Tuple
from dataclasses import dataclass
import numpy as np
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout
from virtual_gpu_driver.src.stream import Stream as ComputeStream
from virtual_gpu_driver.src.stream import StreamManager as KernelSchedule
from .main import get_device, get_default_device
from .layer_norm import HeliumLayerNorm
from .gelu import HeliumGELU
from .multihead_attention import HeliumMultiHeadAttention
from .core.db_manager import HeliumDBManager
from .broadcast import ModalityType, TensorMetadata

@dataclass
class DecoderConfig:
    """Configuration for multi-modal decoder.

    Attributes:
        output_modalities: Modalities the decoder can project its hidden
            states to; each needs the matching modality-specific field below.
        hidden_dim: Width of the hidden states; must be divisible by
            ``num_heads`` (the per-head width is ``hidden_dim // num_heads``).
        num_layers: Number of decoder layers.
        num_heads: Number of attention heads; must be positive.
        intermediate_size: Width of the feed-forward hidden layer.
        max_seq_len: Per-modality maximum sequence length.
        vocab_size: Output vocabulary size; required for TEXT.
        image_size: ``(height, width)`` of generated images; required for IMAGE.
        audio_params: Audio settings; required for AUDIO and must contain
            ``"num_samples"``. Optional keys read elsewhere in this module:
            ``"normalize"`` and ``"sampling_rate"``.
        use_cache: Whether decoding may use a cache.
        dtype: Element type name, resolved as ``getattr(DType, dtype.upper())``.
    """
    output_modalities: List[ModalityType]
    hidden_dim: int
    num_layers: int
    num_heads: int
    intermediate_size: int
    max_seq_len: Dict[ModalityType, int]
    vocab_size: Optional[int] = None  # For text generation
    image_size: Optional[Tuple[int, int]] = None  # For image generation
    audio_params: Optional[Dict[str, Any]] = None  # For audio generation
    use_cache: bool = True
    dtype: str = "float16"

    def validate(self):
        """Validate configuration.

        Raises:
            ValueError: If the head count is not positive, ``hidden_dim`` is
                not divisible by ``num_heads``, ``dtype`` does not name a
                ``DType`` member, or a requested output modality is missing
                its required settings.
        """
        # Check num_heads first so the divisibility test cannot raise
        # ZeroDivisionError.
        if self.num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {self.num_heads}")
        if self.hidden_dim % self.num_heads != 0:
            # head_dim is computed with integer division elsewhere; a
            # remainder would silently drop hidden units.
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        # Tensor descriptors resolve the dtype with getattr(DType, dtype.upper()).
        if not hasattr(DType, str(self.dtype).upper()):
            raise ValueError(f"Unsupported dtype: {self.dtype}")

        for modality in self.output_modalities:
            if modality == ModalityType.TEXT and not self.vocab_size:
                raise ValueError("vocab_size required for text generation")
            elif modality == ModalityType.IMAGE and not self.image_size:
                raise ValueError("image_size required for image generation")
            elif modality == ModalityType.AUDIO:
                if not self.audio_params:
                    raise ValueError("audio_params required for audio generation")
                # The audio projection sizes its output from this key.
                if "num_samples" not in self.audio_params:
                    raise ValueError(
                        "audio_params['num_samples'] required for audio generation"
                    )

if TYPE_CHECKING:
    from .main import HeliumTensor

class ModalityProjection:
    """Projects hidden states to modality-specific outputs.

    Each instance owns one linear layer (``weight`` + ``bias``) whose output
    width depends on the modality:

    * TEXT  -> ``config.vocab_size``
    * IMAGE -> ``h * w * 3`` (RGB), reshaped to ``(-1, h, w, 3)`` in forward
    * AUDIO -> ``config.audio_params["num_samples"]``, optionally passed
      through ``tanh`` in forward
    """
    def __init__(
        self,
        config: DecoderConfig,
        modality: ModalityType,
        driver=None
    ):
        """
        Args:
            config: Decoder configuration; ``config.dtype`` selects the
                element type of the projection parameters.
            modality: Output modality produced by this projection.
            driver: Device driver used to allocate and run tensors. When
                omitted, the default device is used.

        Raises:
            ValueError: If ``modality`` has no projection implementation.
        """
        self.config = config
        self.modality = modality
        # Without this fallback the default driver=None always crashed on
        # ``None.allocate_tensor``.
        self.driver = driver if driver is not None else get_default_device()

        if modality == ModalityType.TEXT:
            out_features = config.vocab_size
        elif modality == ModalityType.IMAGE:
            h, w = config.image_size
            out_features = h * w * 3  # RGB channels
        elif modality == ModalityType.AUDIO:
            out_features = config.audio_params["num_samples"]
        else:
            # Fail here rather than leave self.proj unset and crash in forward().
            raise ValueError(f"Unsupported output modality: {modality}")

        self.proj = self._create_linear(config.hidden_dim, out_features)

    def _create_linear(self, in_features: int, out_features: int) -> Dict[str, Any]:
        """Allocate the projection's weight ``(out, in)`` and bias ``(out,)``.

        The element type follows ``config.dtype``, matching the decoder
        block's own linear layers, instead of a hard-coded float16.
        """
        dtype = getattr(DType, self.config.dtype.upper())
        weight_desc = TensorDescriptor(
            shape=(out_features, in_features),
            dtype=dtype,
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        bias_desc = TensorDescriptor(
            shape=(out_features,),
            dtype=dtype,
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        return {
            'weight': self.driver.allocate_tensor(weight_desc),
            'bias': self.driver.allocate_tensor(bias_desc)
        }

    def forward(
        self,
        hidden_states: Union[str, "HeliumTensor"]
    ) -> Union[str, "HeliumTensor"]:
        """Project to modality-specific output space.

        Args:
            hidden_states: Decoder hidden states whose last dimension is
                ``config.hidden_dim`` (assumed; TODO confirm with callers).

        Returns:
            The projected tensor. For IMAGE it is reshaped to
            ``(-1, h, w, 3)``. For AUDIO it is squashed with ``tanh`` unless
            ``audio_params["normalize"]`` is falsy.
        """
        # NOTE(review): weight is allocated as (out_features, in_features);
        # this relies on driver.matmul handling that layout (e.g. an
        # implicit transpose). TODO confirm against the driver.
        out = self.driver.matmul(hidden_states, self.proj['weight'])
        out = self.driver.add(out, self.proj['bias'])

        if self.modality == ModalityType.IMAGE:
            # -1 absorbs all leading dims, so any sequence axis is folded
            # into the first dimension alongside the batch.
            h, w = self.config.image_size
            out = self.driver.reshape(out, (-1, h, w, 3))
        elif self.modality == ModalityType.AUDIO:
            # Audio-specific post-processing: bound outputs to (-1, 1).
            if self.config.audio_params.get("normalize", True):
                out = self.driver.tanh(out)

        return out

class HeliumDecoderBlock:
    """
    Hardware-accelerated multi-modal transformer decoder block

    Implements:
    1. Self-attention with causal mask
    2. Cross-attention with encoder outputs (only when encoder states are given)
    3. Feed-forward network (linear -> GELU -> linear)
    4. Multi-modal output projections
    Each sub-layer applies its LayerNorm before the sub-layer and adds a
    residual connection afterwards. Work is issued inside the block's
    compute stream context.
    """
    def __init__(
        self,
        config: DecoderConfig,
        device_id: Optional[str] = None
    ):
        """
        Args:
            config: Decoder configuration. It is validated before any device
                resources are acquired.
            device_id: Device to run on; the default device is used if omitted.

        Raises:
            ValueError: If ``config`` fails ``DecoderConfig.validate``.
        """
        # Fail fast on a bad config before allocating anything on the device.
        config.validate()

        # Initialize device and stream
        self.driver = get_device(device_id) if device_id else get_default_device()
        self.device_id = device_id
        self.stream = ComputeStream(self.driver)

        # Initialize database connection
        self.db = HeliumDBManager.get_instance()

        # Store configuration
        self.config = config

        # Architecture parameters
        self.hidden_size = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.intermediate_size = config.intermediate_size
        self.dtype = config.dtype

        # Track allocated temporaries (set early so __del__ always finds it)
        self._temp_tensors = {}
        self._counter = 0

        # Attention sub-layers
        self.self_attention = HeliumMultiHeadAttention(
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
            device_id=device_id,
            dtype=self.dtype
        )
        self.cross_attention = HeliumMultiHeadAttention(
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
            device_id=device_id,
            dtype=self.dtype
        )

        # Layer norms: ln1 -> self-attention, ln2 -> cross-attention, ln3 -> FFN
        self.ln1 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype)
        self.ln2 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype)
        self.ln3 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype)

        # Feed-forward layers
        self.ff1 = self._create_linear(self.hidden_size, self.intermediate_size)
        self.ff2 = self._create_linear(self.intermediate_size, self.hidden_size)
        self.gelu = HeliumGELU(device_id=device_id)

        # Modality-specific output projections, sharing this block's driver
        self.output_projections = {
            modality: ModalityProjection(config, modality, self.driver)
            for modality in config.output_modalities
        }

        # Operation scheduling
        self.schedule = KernelSchedule(self.driver)

    def _descriptor(self, shape: tuple) -> TensorDescriptor:
        """Build a row-major VGPU tensor descriptor using the block's dtype."""
        return TensorDescriptor(
            shape=shape,
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )

    def _create_linear(self, in_features: int, out_features: int) -> Dict[str, Any]:
        """Allocate a linear layer's weight ``(out, in)`` and bias ``(out,)``."""
        return {
            'weight': self.driver.allocate_tensor(self._descriptor((out_features, in_features))),
            'bias': self.driver.allocate_tensor(self._descriptor((out_features,)))
        }

    def _get_temp_tensor(self, shape: tuple) -> str:
        """Allocate a temporary tensor and return the id it is tracked under."""
        tensor_id = f"decoder_temp_{self._counter}"
        self._counter += 1
        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(self._descriptor(shape))
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str):
        """Free a tracked temporary tensor; unknown ids are ignored."""
        if tensor_id in self._temp_tensors:
            self.driver.free_tensor(self._temp_tensors[tensor_id])
            del self._temp_tensors[tensor_id]

    def __del__(self):
        """Clean up temporary tensors that are still tracked."""
        # __init__ may have raised before _temp_tensors was assigned; avoid a
        # secondary AttributeError during garbage collection.
        for tensor_id in list(getattr(self, "_temp_tensors", {})):
            self._free_temp_tensor(tensor_id)

    def forward(
        self,
        hidden_states: Union[str, "HeliumTensor"],
        target_modality: ModalityType,
        encoder_hidden_states: Optional[Union[str, "HeliumTensor"]] = None,
        attention_mask: Optional[Union[str, "HeliumTensor"]] = None,
        encoder_attention_mask: Optional[Union[str, "HeliumTensor"]] = None,
        metadata: Optional[TensorMetadata] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Forward pass of decoder block

        Args:
            hidden_states: Input tensor (B, S, H)
            target_modality: Modality whose output projection is applied to
                the final hidden states; must be one of
                ``config.output_modalities``.
            encoder_hidden_states: Optional encoder output (B, S_enc, H). The
                cross-attention sub-layer is skipped when None.
            attention_mask: Optional attention mask for self-attention.
                ``causal_mask=True`` is always passed as well.
            encoder_attention_mask: Optional mask for encoder-decoder attention
            metadata: Optional metadata object updated in place with the
                target modality and modality-specific fields.

        Returns:
            Output of the target modality's projection: ``vocab_size``
            features for TEXT, an image tensor reshaped to ``(-1, h, w, 3)``
            for IMAGE, or ``num_samples`` features for AUDIO.

        Raises:
            ValueError: If no projection exists for ``target_modality``.
        """
        # Validate before doing any device work, not after the whole block runs.
        projection = self.output_projections.get(target_modality)
        if projection is None:
            raise ValueError(f"No projection available for modality {target_modality}")

        residual = hidden_states

        # Self attention branch
        with self.stream:
            hidden_states = self.ln1(hidden_states)
            hidden_states = self.self_attention(
                hidden_states,
                attention_mask=attention_mask,
                causal_mask=True  # Always use causal mask in decoder
            )
            hidden_states = self.driver.add(hidden_states, residual)

        # Cross attention branch (if encoder present)
        if encoder_hidden_states is not None:
            residual = hidden_states

            with self.stream:
                hidden_states = self.ln2(hidden_states)
                hidden_states = self.cross_attention(
                    hidden_states,
                    encoder_hidden_states,
                    attention_mask=encoder_attention_mask
                )
                hidden_states = self.driver.add(hidden_states, residual)

        # Feed-forward branch
        residual = hidden_states

        with self.stream:
            hidden_states = self.ln3(hidden_states)

            # NOTE(review): ff weights are allocated as (out, in); this relies
            # on driver.matmul handling that layout. TODO confirm.
            hidden_states = self.driver.matmul(hidden_states, self.ff1['weight'])
            hidden_states = self.driver.add(hidden_states, self.ff1['bias'])
            hidden_states = self.gelu(hidden_states)

            hidden_states = self.driver.matmul(hidden_states, self.ff2['weight'])
            hidden_states = self.driver.add(hidden_states, self.ff2['bias'])

            # Final residual
            hidden_states = self.driver.add(hidden_states, residual)

        # Project to target modality
        output = projection.forward(hidden_states)

        # Update metadata if provided
        if metadata is not None:
            metadata.modality = target_modality
            if target_modality == ModalityType.IMAGE:
                h, w = self.config.image_size
                metadata.spatial_dims = (h, w)
                metadata.channels = 3
            elif target_modality == ModalityType.AUDIO:
                metadata.sampling_rate = self.config.audio_params.get("sampling_rate")
            elif target_modality == ModalityType.TEXT:
                # NOTE(review): assumes the driver returns a tensor object
                # with .shape here (not a string id) — TODO confirm.
                metadata.sequence_length = output.shape[1]

        return output
