"""
Multi-modal generation controller for Helium
Handles coordinated generation of text, images, audio, and video
"""
from typing import Dict, Optional, List, Union, Any
import numpy as np
from dataclasses import dataclass
from .decoder import DecoderConfig, HeliumDecoderBlock
from .broadcast import ModalityType, TensorMetadata
from .core.db_manager import HeliumDBManager

@dataclass
class GenerationConfig:
    """
    Settings for one multi-modal generation run.

    ``max_new_tokens`` is the only required field; every other field has a
    default. Field order is significant for positional construction.
    """
    # Generation budget, keyed by output modality.
    max_new_tokens: Dict[ModalityType, int]

    # Sampling knobs (softmax temperature, top-k / nucleus filtering).
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 0.95
    # NOTE(review): semantics inferred from the names only — confirm with the
    # consumer of this config which of these are actually honoured.
    repetition_penalty: float = 1.0
    do_sample: bool = True

    # Presumably beam-search settings (by name) — TODO confirm usage.
    num_beams: int = 1
    length_penalty: float = 1.0
    early_stopping: bool = False

class HeliumMultiModalGenerator:
    """
    Coordinated multi-modal generation system.

    For every requested output modality, each generation step runs the full
    stack of ``HeliumDecoderBlock`` layers once. The final layer's output is
    post-processed: token sampling for text, clipping to [0, 1] for images,
    and peak normalization for audio. The step outputs are then combined into
    one array per modality.

    Capabilities:
    - Text, image and audio generation (other modalities are rejected with
      ``ValueError``)
    - Conditional generation via optional per-modality encoder outputs

    NOTE(review): each target modality must be given its own prompt, so true
    cross-modal generation (e.g. text prompt -> image output) and video are
    not implemented here.
    """
    def __init__(
        self,
        decoder_config: DecoderConfig,
        device_id: Optional[str] = None
    ):
        """
        Args:
            decoder_config: Decoder settings. ``num_layers`` decoder blocks are
                built from it. ``hidden_dim`` and ``output_modalities`` are
                read during generation.
            device_id: Optional device identifier forwarded to every decoder block.
        """
        self.config = decoder_config
        self.device_id = device_id
        self.db = HeliumDBManager.get_instance()

        # One decoder block per configured layer, applied in list order.
        self.decoder_blocks = [
            HeliumDecoderBlock(decoder_config, device_id)
            for _ in range(decoder_config.num_layers)
        ]

        # Per-modality step outputs of the most recent generate() call.
        self.current_sequences: Dict[ModalityType, List[np.ndarray]] = {}
        self.metadata_cache: Dict[ModalityType, TensorMetadata] = {}

    def prepare_initial_inputs(
        self,
        prompts: Dict[ModalityType, np.ndarray],
        encoder_outputs: Optional[Dict[ModalityType, np.ndarray]] = None
    ) -> Dict[str, Any]:
        """
        Build the initial decoder states and metadata from the prompts.

        Args:
            prompts: Prompt array per modality. TEXT prompts (token IDs) are
                used unchanged. IMAGE and AUDIO prompts are reshaped to
                ``(1, -1, config.hidden_dim)``.
            encoder_outputs: Optional conditioning tensors, passed through as-is.

        Returns:
            Dict with keys ``"states"`` (modality -> array), ``"metadata"``
            (modality -> TensorMetadata) and ``"encoder_outputs"``.

        Raises:
            ValueError: If a prompt's modality is not TEXT, IMAGE or AUDIO, or
                an IMAGE/AUDIO prompt's size is not a multiple of ``hidden_dim``.
        """
        initial_states = {}
        initial_metadata = {}

        for modality, prompt in prompts.items():
            prompt = np.asarray(prompt)
            if modality == ModalityType.TEXT:
                # Text prompts are token IDs
                state = prompt
            elif modality in (ModalityType.IMAGE, ModalityType.AUDIO):
                # Pixel / waveform values folded into hidden_dim-wide vectors
                state = prompt.reshape(1, -1, self.config.hidden_dim)
            else:
                # Reject explicitly instead of failing later with a bare KeyError
                raise ValueError(f"Unsupported prompt modality: {modality}")

            initial_states[modality] = state
            initial_metadata[modality] = TensorMetadata(
                modality=modality,
                shape=state.shape,
                dtype=state.dtype
            )

        return {
            "states": initial_states,
            "metadata": initial_metadata,
            "encoder_outputs": encoder_outputs
        }

    def generate(
        self,
        prompts: Dict[ModalityType, np.ndarray],
        target_modalities: List[ModalityType],
        encoder_outputs: Optional[Dict[ModalityType, np.ndarray]] = None,
        generation_config: Optional[GenerationConfig] = None
    ) -> Dict[ModalityType, np.ndarray]:
        """
        Generate multi-modal outputs.

        Each step runs the whole decoder stack once per still-active modality,
        post-processes the result and appends exactly one output to that
        modality's sequence. The post-processed output is the input to the
        next step.

        Args:
            prompts: Input prompts per modality. Every target modality needs one.
            target_modalities: Modalities to generate. Duplicates are ignored.
            encoder_outputs: Optional encoder outputs for conditioning.
            generation_config: Generation parameters. Defaults to 100 steps per
                target modality. ``temperature``, ``top_k``, ``top_p`` and
                ``do_sample`` are applied to text. ``repetition_penalty``,
                ``num_beams``, ``length_penalty`` and ``early_stopping`` are
                currently not applied.

        Returns:
            Generated outputs per modality. A modality with a budget of 0
            yields an empty array.

        Raises:
            ValueError: If a target modality is not in
                ``config.output_modalities``, has no prompt, has no
                ``max_new_tokens`` entry, or is unsupported.
        """
        # Dedupe while preserving order; duplicates would double-append per step.
        target_modalities = list(dict.fromkeys(target_modalities))
        if not target_modalities:
            return {}

        if generation_config is None:
            generation_config = GenerationConfig(
                max_new_tokens={m: 100 for m in target_modalities}
            )

        for modality in target_modalities:
            if modality not in self.config.output_modalities:
                raise ValueError(f"Model not configured for {modality} generation")
            if modality not in prompts:
                raise ValueError(f"No prompt provided for {modality} generation")
            if modality not in generation_config.max_new_tokens:
                raise ValueError(f"No max_new_tokens budget set for {modality}")

        inputs = self.prepare_initial_inputs(prompts, encoder_outputs)
        hidden_states = inputs["states"]
        metadata = inputs["metadata"]
        conditioning = inputs["encoder_outputs"] or {}
        budgets = {m: generation_config.max_new_tokens[m] for m in target_modalities}

        for modality in target_modalities:
            self.current_sequences[modality] = []

        for _ in range(max(budgets.values())):
            active_modalities = [
                m for m in target_modalities
                if len(self.current_sequences[m]) < budgets[m]
            ]
            if not active_modalities:
                break

            for modality in active_modalities:
                # Run the full layer stack once; sampling happens only on the
                # final layer's output, never between layers.
                output = self._run_decoder_stack(
                    hidden_states[modality],
                    modality,
                    conditioning.get(modality),
                    metadata[modality]
                )
                output = self._postprocess_step(modality, output, generation_config)

                self.current_sequences[modality].append(output)
                # The post-processed output becomes the next step's input.
                hidden_states[modality] = output

        return {
            m: self._finalize_sequence(m, self.current_sequences[m])
            for m in target_modalities
        }

    def _run_decoder_stack(
        self,
        hidden_states: np.ndarray,
        modality: ModalityType,
        encoder_hidden_states: Optional[np.ndarray],
        metadata: TensorMetadata
    ) -> np.ndarray:
        """Apply every decoder block in order and return the last block's output."""
        # NOTE(review): `metadata` describes the initial prompt state and is not
        # refreshed as shapes change between steps — confirm whether
        # HeliumDecoderBlock.forward relies on it matching `hidden_states`.
        for block in self.decoder_blocks:
            hidden_states = block.forward(
                hidden_states=hidden_states,
                target_modality=modality,
                encoder_hidden_states=encoder_hidden_states,
                metadata=metadata
            )
        return hidden_states

    def _postprocess_step(
        self,
        modality: ModalityType,
        output: np.ndarray,
        config: GenerationConfig
    ) -> np.ndarray:
        """Apply the modality-specific sampling/filtering to one step's output."""
        if modality == ModalityType.TEXT:
            return self._sample_text_output(
                output,
                temperature=config.temperature,
                top_k=config.top_k,
                top_p=config.top_p,
                do_sample=config.do_sample
            )
        if modality == ModalityType.IMAGE:
            return self._filter_image_output(output)
        if modality == ModalityType.AUDIO:
            return self._filter_audio_output(output)
        return output

    def _finalize_sequence(
        self,
        modality: ModalityType,
        sequence: List[np.ndarray]
    ) -> np.ndarray:
        """Combine one modality's step outputs with the matching finalizer."""
        if modality == ModalityType.TEXT:
            return self._finalize_text_sequence(sequence)
        if modality == ModalityType.IMAGE:
            return self._finalize_image_sequence(sequence)
        if modality == ModalityType.AUDIO:
            return self._finalize_audio_sequence(sequence)
        raise ValueError(f"Unsupported output modality: {modality}")

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over the last axis (-inf entries get 0)."""
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / np.sum(exp, axis=-1, keepdims=True)

    def _sample_text_output(
        self,
        logits: np.ndarray,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True
    ) -> np.ndarray:
        """
        Pick next tokens from logits over the last axis.

        Args:
            logits: Scores with the vocabulary on the last axis. Not modified.
            temperature: Divides the logits before filtering. Must be > 0
                when sampling.
            top_k: Keep only the ``top_k`` highest logits (0 disables;
                values >= vocabulary size keep everything).
            top_p: Nucleus filtering. Drop the least likely tokens whose
                cumulative probability is below ``1 - top_p``. The most
                likely token is always kept.
            do_sample: If False, return the argmax (greedy) instead of sampling.

        Returns:
            Integer token indices with shape ``logits.shape[:-1]``.

        Raises:
            ValueError: If sampling with ``temperature <= 0``.
        """
        # Float copy: never mutate the caller's array, and allow -inf masking
        # even if integer logits are passed.
        logits = np.array(logits, dtype=np.float64)
        vocab_size = logits.shape[-1]

        if not do_sample:
            return np.argmax(logits, axis=-1)

        if temperature <= 0:
            raise ValueError(f"temperature must be > 0 when sampling, got {temperature}")
        if temperature != 1.0:
            logits = logits / temperature

        # Top-k: the k-th largest value per row is the inclusion threshold.
        if 0 < top_k < vocab_size:
            kth_largest = np.partition(logits, -top_k, axis=-1)[..., -top_k, None]
            logits = np.where(logits < kth_largest, -np.inf, logits)

        # Top-p: walk tokens from least to most likely and drop the low tail.
        if top_p < 1.0:
            order = np.argsort(logits, axis=-1)
            sorted_logits = np.take_along_axis(logits, order, axis=-1)
            cumulative_probs = np.cumsum(self._softmax(sorted_logits), axis=-1)
            sorted_remove = cumulative_probs < (1 - top_p)
            # Always keep the single most likely token so a row is never empty.
            sorted_remove[..., -1] = False
            remove = np.empty_like(sorted_remove)
            np.put_along_axis(remove, order, sorted_remove, axis=-1)
            logits = np.where(remove, -np.inf, logits)

        probs = self._softmax(logits)

        # np.random.choice only accepts a 1-D `p`, so sample each row separately.
        flat_probs = probs.reshape(-1, vocab_size)
        samples = np.array(
            [np.random.choice(vocab_size, p=row) for row in flat_probs],
            dtype=np.int64
        )
        return samples.reshape(logits.shape[:-1])

    def _filter_image_output(self, pixels: np.ndarray) -> np.ndarray:
        """Post-process generated image pixels by clipping them to [0, 1]."""
        return np.clip(pixels, 0, 1)

    def _filter_audio_output(self, waveform: np.ndarray) -> np.ndarray:
        """
        Peak-normalize a generated waveform so max(|x|) == 1.

        Empty and all-zero (silent) waveforms are returned unchanged rather
        than being divided by zero.
        """
        waveform = np.asarray(waveform)
        if waveform.size == 0:
            return waveform
        peak = np.max(np.abs(waveform))
        if peak == 0:
            return waveform
        return waveform / peak

    def _finalize_text_sequence(self, sequence: List[np.ndarray]) -> np.ndarray:
        """Concatenate generated text tokens along axis 1 (empty -> empty array)."""
        if not sequence:
            return np.empty((0,))
        return np.concatenate(sequence, axis=1)

    def _finalize_image_sequence(self, sequence: List[np.ndarray]) -> np.ndarray:
        """Stack generated image frames along a new leading axis (empty -> empty array)."""
        if not sequence:
            return np.empty((0,))
        return np.stack(sequence, axis=0)

    def _finalize_audio_sequence(self, sequence: List[np.ndarray]) -> np.ndarray:
        """Concatenate generated audio segments along axis 1 (empty -> empty array)."""
        if not sequence:
            return np.empty((0,))
        return np.concatenate(sequence, axis=1)
