"""
Virtual GPU interface with modality support.
Provides a high-level interface for multi-modal tensor operations.
"""

from typing import Optional, Dict, Any, List, Union, Tuple
import numpy as np
from .modality import ModalityType, ModalityConfig
from .modality_aware_tensor_core import ModalityAwareTensorCore

class ModalityAwareVirtualGPU:
    """High-level interface for modality-aware virtual GPU operations.

    Owns a fixed list of ``ModalityAwareTensorCore`` instances and routes
    each operation to the core that the operation's modality has been
    assigned to via :meth:`assign_modality_to_core`. The routing maps hold
    at most one core per modality and at most one modality per core.
    """

    def __init__(
        self,
        num_tensor_cores: int = 1,
        bits: int = 2,
        memory_size: Optional[int] = None,
        bandwidth_tbps: float = 10000
    ):
        """Create ``num_tensor_cores`` identically configured tensor cores.

        Args:
            num_tensor_cores: Number of tensor cores to create.
            bits: Forwarded to every ``ModalityAwareTensorCore``.
            memory_size: Forwarded to every ``ModalityAwareTensorCore``.
            bandwidth_tbps: Forwarded to every ``ModalityAwareTensorCore``.
        """
        self.tensor_cores: List[ModalityAwareTensorCore] = [
            ModalityAwareTensorCore(
                bits=bits,
                memory_size=memory_size,
                bandwidth_tbps=bandwidth_tbps
            )
            for _ in range(num_tensor_cores)
        ]

        # Bidirectional modality <-> core-index routing tables; kept mutually
        # consistent by assign_modality_to_core.
        self.modality_core_map: Dict[ModalityType, int] = {}
        self.core_modality_map: Dict[int, ModalityType] = {}

        # One (initially empty) pool per ModalityType member.
        # NOTE(review): no method in this class reads or writes this pool;
        # store_tensor/load_tensor delegate to the tensor cores instead. Kept
        # because external callers may access the attribute directly.
        self.modality_memory: Dict[ModalityType, Dict[str, np.ndarray]] = {
            modality: {} for modality in ModalityType
        }

    def _require_core(
        self,
        modality: Optional[ModalityType]
    ) -> ModalityAwareTensorCore:
        """Return the core assigned to ``modality`` or raise ``ValueError``."""
        core = self.get_core_for_modality(modality)
        if core is None:
            raise ValueError(f"No core assigned for modality: {modality}")
        return core

    def assign_modality_to_core(
        self,
        modality: ModalityType,
        core_idx: int
    ) -> None:
        """Assign a modality to a specific tensor core.

        Any previous assignment of ``modality`` to another core, and any
        previous modality held by ``core_idx``, is dropped from the routing
        maps. Otherwise a stale entry would keep routing a modality to a core
        that has since been reconfigured for a different modality.

        Args:
            modality: Modality to route to the core.
            core_idx: Index into ``self.tensor_cores``.

        Raises:
            ValueError: If ``core_idx`` is outside ``[0, len(tensor_cores))``.
                Negative indices are rejected rather than wrapping around.
        """
        if not 0 <= core_idx < len(self.tensor_cores):
            raise ValueError(f"Invalid core index: {core_idx}")

        # Configure the core before touching the maps so that a failure in
        # set_modality leaves the routing state unchanged.
        self.tensor_cores[core_idx].set_modality(modality)

        # The modality is moving off its previous core: unmap that core.
        old_core = self.modality_core_map.get(modality)
        if old_core is not None and old_core != core_idx:
            self.core_modality_map.pop(old_core, None)

        # The core is switching away from its previous modality: unmap it.
        old_modality = self.core_modality_map.get(core_idx)
        if old_modality is not None and old_modality != modality:
            self.modality_core_map.pop(old_modality, None)

        self.modality_core_map[modality] = core_idx
        self.core_modality_map[core_idx] = modality

    def get_core_for_modality(
        self,
        modality: Optional[ModalityType]
    ) -> Optional[ModalityAwareTensorCore]:
        """Return the tensor core assigned to ``modality``, or ``None``."""
        core_idx = self.modality_core_map.get(modality)
        if core_idx is not None:
            return self.tensor_cores[core_idx]
        return None

    def matmul(
        self,
        x: np.ndarray,
        y: np.ndarray,
        x_modality: ModalityType,
        y_modality: Optional[ModalityType] = None
    ) -> np.ndarray:
        """Matrix multiplication with modality handling.

        Runs on the core assigned to ``x_modality``. ``y_modality`` defaults
        to ``x_modality`` when omitted.

        Raises:
            ValueError: If no core is assigned to ``x_modality``.
        """
        # Explicit None check: an enum member may be falsy (e.g. IntEnum 0),
        # and must not be silently replaced by x_modality.
        y_mod = x_modality if y_modality is None else y_modality
        core = self._require_core(x_modality)
        return core.matmul(x, y, x_modality, y_mod)

    def conv(
        self,
        x: np.ndarray,
        weight: np.ndarray,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[int, Tuple[int, ...]] = 0,
        modality: Optional[ModalityType] = None
    ) -> np.ndarray:
        """Convolution on the core assigned to ``modality``.

        Raises:
            ValueError: If no core is assigned to ``modality``.
        """
        core = self._require_core(modality)
        return core.conv(x, weight, stride, padding, modality)

    def attention(
        self,
        q: np.ndarray,
        k: np.ndarray,
        v: np.ndarray,
        mask: Optional[np.ndarray] = None,
        modality: Optional[ModalityType] = None
    ) -> np.ndarray:
        """Attention on the core assigned to ``modality``.

        Raises:
            ValueError: If no core is assigned to ``modality``.
        """
        core = self._require_core(modality)
        return core.attention(q, k, v, mask, modality)

    def pool(
        self,
        x: np.ndarray,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Optional[Union[int, Tuple[int, ...]]] = None,
        mode: str = 'max',
        modality: Optional[ModalityType] = None
    ) -> np.ndarray:
        """Pooling on the core assigned to ``modality``.

        Raises:
            ValueError: If no core is assigned to ``modality``.
        """
        core = self._require_core(modality)
        return core.pool(x, kernel_size, stride, mode, modality)

    def normalize(
        self,
        x: np.ndarray,
        modality: Optional[ModalityType] = None,
        eps: float = 1e-5
    ) -> np.ndarray:
        """Normalization on the core assigned to ``modality``.

        Raises:
            ValueError: If no core is assigned to ``modality``.
        """
        core = self._require_core(modality)
        return core.normalize(x, modality, eps)

    def fuse_modalities(
        self,
        x: np.ndarray,
        y: np.ndarray,
        x_modality: ModalityType,
        y_modality: ModalityType
    ) -> np.ndarray:
        """Fuse tensors from different modalities.

        Runs on the core assigned to ``x_modality``. ``y_modality`` is passed
        through and does not need a core assigned here.

        Raises:
            ValueError: If no core is assigned to ``x_modality``.
        """
        core = self._require_core(x_modality)
        return core.fuse_modalities(x, y, x_modality, y_modality)

    def unfuse_modalities(
        self,
        z: np.ndarray,
        x_modality: ModalityType,
        y_modality: ModalityType
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Separate a fused tensor back into its two modalities.

        Runs on the core assigned to ``x_modality``.

        Raises:
            ValueError: If no core is assigned to ``x_modality``.
        """
        core = self._require_core(x_modality)
        return core.unfuse_modalities(z, x_modality, y_modality)

    def store_tensor(
        self,
        data: np.ndarray,
        modality: ModalityType,
        key: str
    ) -> None:
        """Store ``data`` under ``key`` in the assigned core's modality memory.

        Raises:
            ValueError: If no core is assigned to ``modality``.
        """
        core = self._require_core(modality)
        core.store_modality_tensor(data, modality, key)

    def load_tensor(
        self,
        modality: ModalityType,
        key: str
    ) -> Optional[np.ndarray]:
        """Load the tensor stored under ``key`` from the assigned core.

        Raises:
            ValueError: If no core is assigned to ``modality``.
        """
        core = self._require_core(modality)
        return core.load_modality_tensor(modality, key)

    def clear_memory(
        self,
        modality: Optional[ModalityType] = None
    ) -> None:
        """Clear modality-specific memory.

        With ``modality=None`` every core's memory is cleared. Otherwise only
        the assigned core's memory for that modality is cleared. An unassigned
        modality is a silent no-op (unlike the compute methods, which raise).
        """
        # Explicit None check: a falsy enum member must not fall through to
        # the "clear everything" branch.
        if modality is None:
            for core in self.tensor_cores:
                core.clear_modality_memory()
            return

        core = self.get_core_for_modality(modality)
        if core is not None:
            core.clear_modality_memory(modality)
