import numpy as np
from enum import Enum, auto
from typing import List, Tuple, Optional, Dict, Union
from dataclasses import dataclass

class ModalityType(Enum):
    """Supported modality types for tensors.

    Each member's value is the lowercase form of its name, so a member can be
    recovered from its string form, e.g. ``ModalityType("text")``.
    """
    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    VISION = "vision"
    # Fallback modality assigned by BroadcastModule.broadcast_tensors when a
    # tensor has no metadata of its own.
    LATENT = "latent"
    # EMBEDDING and ATTENTION are the "bridge" modalities that
    # validate_modality_compatibility requires before text-like and spatial
    # tensors may be combined.
    EMBEDDING = "embedding"
    ATTENTION = "attention"
    
@dataclass
class TensorMetadata:
    """Metadata for multi-modal tensors.

    ``modality``, ``shape`` and ``dtype`` are required. The remaining fields
    are optional, modality-specific hints that default to ``1`` (channels)
    or ``None`` (not set).
    """
    modality: ModalityType
    shape: Tuple[int, ...]
    dtype: np.dtype
    channels: int = 1
    # For audio; compared across tensors and handed to the driver's resample().
    sampling_rate: Optional[int] = None
    # For video.
    frame_rate: Optional[int] = None
    # For text/time series; drives truncation or padding of sequences.
    sequence_length: Optional[int] = None
    # For image/video; handed to the driver's resize() as the target size.
    spatial_dims: Optional[Tuple[int, ...]] = None

class BroadcastState:
    """Create and release temporary tensors in driver memory.

    A state is bound to one driver and one name prefix. Every temporary it
    creates is stored through ``driver.create_tensor`` under a generated name
    and, when metadata is supplied, remembered in ``metadata_cache``.

    Attributes:
        driver: Object providing ``create_tensor``, ``tensor_exists`` and
            ``delete_tensor`` (and optionally ``set_tensor_metadata``).
        prefix: Leading part of every generated tensor name.
        counter: Number of temporaries this state has created so far.
        metadata_cache: Metadata of the live temporaries, keyed by name.
    """

    # Sequence number shared by ALL states. Short-lived states are routinely
    # created with the same prefix (one per broadcast / gradient reduction),
    # so a per-instance number would hand out the same tensor name twice and
    # a later result would silently overwrite an earlier one in the driver.
    _next_temp_id = 0

    def __init__(self, driver, prefix: str):
        self.driver = driver
        self.prefix = prefix
        self.counter = 0
        self.metadata_cache: Dict[str, TensorMetadata] = {}

    def get_temp_tensor(
        self,
        data,
        name_suffix: str = "",
        metadata: "Optional[TensorMetadata]" = None
    ) -> str:
        """Store temporary computation results in driver memory with metadata.

        Args:
            data: Payload handed unchanged to ``driver.create_tensor``.
            name_suffix: Trailing label appended to the generated name.
            metadata: Optional metadata to remember for the new tensor. It is
                also forwarded to ``driver.set_tensor_metadata`` when the
                driver offers that method.

        Returns:
            The unique name under which the tensor was stored.
        """
        temp_id = BroadcastState._next_temp_id
        BroadcastState._next_temp_id += 1
        name = f"{self.prefix}_temp_{temp_id}_{name_suffix}"
        self.counter += 1
        self.driver.create_tensor(name, data)

        if metadata is not None:
            self.metadata_cache[name] = metadata
            if hasattr(self.driver, 'set_tensor_metadata'):
                self.driver.set_tensor_metadata(name, metadata)

        return name

    def free_temp_tensor(self, name: str):
        """Delete a temporary tensor and forget its cached metadata.

        Safe to call for a name that no longer exists in the driver.
        """
        # Drop the metadata first so the cache never describes a dead tensor.
        self.metadata_cache.pop(name, None)
        if self.driver.tensor_exists(name):
            self.driver.delete_tensor(name)

def validate_modality_compatibility(
    modalities: List[ModalityType],
    shapes: List[Tuple[int, ...]],
    metadata_list: List[TensorMetadata]
) -> bool:
    """
    Decide whether tensors of the given modalities may be broadcast together.

    Two rules are applied:

    * Text-like (TEXT, EMBEDDING) and spatial (IMAGE, VISION, VIDEO) tensors
      may only be mixed when an ATTENTION or EMBEDDING tensor is present.
    * When both temporal modalities (AUDIO and VIDEO) are present, every
      temporal entry of ``metadata_list`` must report the same
      ``sampling_rate``.

    ``shapes`` is accepted for interface symmetry but is not inspected.
    Returns True when both rules hold, False otherwise.
    """
    present = set(modalities)

    textual = {ModalityType.TEXT, ModalityType.EMBEDDING}
    spatial = {ModalityType.IMAGE, ModalityType.VISION, ModalityType.VIDEO}
    bridges = {ModalityType.ATTENTION, ModalityType.EMBEDDING}

    # Rule 1: text + spatial needs a bridging modality
    mixes_text_and_spatial = bool(present & textual) and bool(present & spatial)
    if mixes_text_and_spatial and present.isdisjoint(bridges):
        return False

    # Rule 2: only relevant once more than one temporal modality is involved
    temporal = {ModalityType.AUDIO, ModalityType.VIDEO}
    if len(present & temporal) <= 1:
        return True

    rates = [
        entry.sampling_rate
        for entry in metadata_list
        if entry.modality in temporal
    ]
    return all(rate == rates[0] for rate in rates)

def compute_broadcast_shapes_with_modality(
    *shapes: Tuple[int, ...],
    metadata_list: "Optional[List[TensorMetadata]]" = None
) -> "Tuple[Optional[Tuple[int, ...]], Optional[TensorMetadata]]":
    """
    Compute the common broadcast shape of ``shapes`` with modality awareness.

    Shapes are aligned on their trailing dimensions; a shorter shape behaves
    as if it were left-padded with 1s. Two sizes in the same position are
    compatible when they are equal or one of them is 1.

    Args:
        *shapes: Shapes to broadcast together.
        metadata_list: Optional metadata, one entry per shape. When given,
            the modalities are validated first and combined metadata is
            returned alongside the shape.

    Returns:
        ``(broadcast_shape, broadcast_metadata)``. ``broadcast_metadata`` is
        ``None`` when no metadata was supplied. ``(None, None)`` is returned
        when the modalities or the shapes are not compatible.

    Raises:
        ValueError: If ``metadata_list`` is given but its length differs from
            the number of shapes.
    """
    if metadata_list and len(shapes) != len(metadata_list):
        raise ValueError("Number of shapes must match number of metadata entries")

    if metadata_list:
        modalities = [m.modality for m in metadata_list]
        if not validate_modality_compatibility(modalities, list(shapes), metadata_list):
            return None, None

    # Right-align all shapes by left-padding with 1s. Zipping the raw shapes
    # would stop at the shortest one and drop the leading dimensions of the
    # longer shapes from the result.
    ndim = max((len(s) for s in shapes), default=0)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]

    result = []
    for dims in zip(*padded):
        # Size-1 entries stretch to anything; all remaining entries must agree.
        # Working with the set of non-1 sizes (rather than max) also handles
        # zero-length dimensions, where 0 must win over 1.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            return None, None
        result.append(sizes.pop() if sizes else 1)

    broadcast_shape = tuple(result)

    # Compute broadcast metadata if provided
    if metadata_list:
        def _max_known(values):
            # Largest value actually supplied; None when every input left the
            # field unset, so "unknown" is never reported as 0 / (0,).
            known = [v for v in values if v is not None]
            return max(known) if known else None

        # Take the highest resolution/quality metadata. The first entry is
        # treated as the primary tensor and provides modality and dtype.
        # Note: spatial_dims are tuples, so "largest" is a lexicographic max.
        broadcast_metadata = TensorMetadata(
            modality=metadata_list[0].modality,
            shape=broadcast_shape,
            dtype=metadata_list[0].dtype,
            channels=max(m.channels for m in metadata_list),
            sampling_rate=_max_known(m.sampling_rate for m in metadata_list),
            frame_rate=_max_known(m.frame_rate for m in metadata_list),
            sequence_length=_max_known(m.sequence_length for m in metadata_list),
            spatial_dims=_max_known(m.spatial_dims for m in metadata_list)
        )
        return broadcast_shape, broadcast_metadata

    return broadcast_shape, None

def compute_broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
    """Legacy entry point: broadcast plain shapes without any metadata.

    Forwards to ``compute_broadcast_shapes_with_modality`` and hands back only
    the shape half of its result (the metadata half is discarded).
    """
    return compute_broadcast_shapes_with_modality(*shapes)[0]

class BroadcastModule:
    """Broadcast driver-resident tensors, addressed by name, to a common shape.

    All tensor work is delegated to ``driver``: ``get_tensor``, ``reshape``
    and ``broadcast_to`` are always required, ``slice`` / ``pad`` are used for
    sequence alignment, and ``resize`` / ``resample`` are used only when the
    driver provides them. New tensors are created through a ``BroadcastState``.

    Attributes:
        driver: The tensor driver.
        metadata_cache: Metadata of result tensors created by this module,
            keyed by tensor name. Consulted before the driver by
            ``_get_tensor_metadata``.
    """

    def __init__(self, driver):
        self.driver = driver
        self.metadata_cache: Dict[str, TensorMetadata] = {}

    def _align_spatial_dims(
        self,
        tensor_name: str,
        target_metadata: TensorMetadata,
        state: BroadcastState
    ) -> str:
        """Align spatial dimensions for image/video tensors.

        Returns the name of a resized temporary, or ``tensor_name`` unchanged
        when the driver cannot resize, no usable target size is known, the
        tensor has no metadata, or the sizes already match.
        """
        if not hasattr(self.driver, 'resize'):
            return tensor_name

        target_dims = target_metadata.spatial_dims
        # A missing or all-zero target carries no spatial information;
        # resizing to it would destroy the tensor.
        if not target_dims or not any(target_dims):
            return tensor_name

        metadata = self._get_tensor_metadata(tensor_name)
        if metadata is None or metadata.spatial_dims == target_dims:
            return tensor_name

        return state.get_temp_tensor(
            self.driver.resize(tensor_name, target_dims, mode='bilinear'),
            "resized",
            target_metadata
        )

    def _align_sampling_rate(
        self,
        tensor_name: str,
        target_metadata: TensorMetadata,
        state: BroadcastState
    ) -> str:
        """Align sampling rates for audio tensors.

        Returns the name of a resampled temporary, or ``tensor_name``
        unchanged when the driver cannot resample, either rate is unknown,
        or the rates already match.
        """
        if not hasattr(self.driver, 'resample'):
            return tensor_name

        target_rate = target_metadata.sampling_rate
        if not target_rate:
            return tensor_name

        metadata = self._get_tensor_metadata(tensor_name)
        # Without a known source rate there is nothing meaningful to pass to
        # the driver's resample().
        if metadata is None or not metadata.sampling_rate:
            return tensor_name

        if metadata.sampling_rate == target_rate:
            return tensor_name

        return state.get_temp_tensor(
            self.driver.resample(tensor_name, metadata.sampling_rate, target_rate),
            "resampled",
            target_metadata
        )

    def _align_sequence_length(
        self,
        tensor_name: str,
        target_metadata: TensorMetadata,
        state: BroadcastState
    ) -> str:
        """Align sequence lengths for text/embedding tensors.

        Longer sequences are truncated with ``driver.slice`` and shorter ones
        padded with ``driver.pad`` (constant mode). Returns ``tensor_name``
        unchanged when either length is unknown or they already match.
        """
        target_length = target_metadata.sequence_length
        if not target_length:
            return tensor_name

        metadata = self._get_tensor_metadata(tensor_name)
        # An unknown source length cannot be ordered against the target.
        if metadata is None or metadata.sequence_length is None:
            return tensor_name

        current_length = metadata.sequence_length
        if current_length == target_length:
            return tensor_name

        if current_length > target_length:
            data = self.driver.slice(tensor_name, (0, target_length))
            suffix = "truncated"
        else:
            data = self.driver.pad(
                tensor_name,
                ((0, target_length - current_length),),
                mode='constant'
            )
            suffix = "padded"

        return state.get_temp_tensor(data, suffix, target_metadata)

    def _expand_dims(self, tensor_name: str, target_dims: int, state: BroadcastState) -> str:
        """Add leading dimensions of size 1 to match target dimensionality.

        Returns the name of the reshaped temporary, or ``tensor_name`` when
        the tensor already has at least ``target_dims`` dimensions.
        """
        tensor = self.driver.get_tensor(tensor_name)
        current_shape = tuple(tensor.shape)
        missing = target_dims - len(current_shape)
        if missing <= 0:
            return tensor_name

        new_shape = (1,) * missing + current_shape
        metadata = self._get_tensor_metadata(tensor_name)
        return state.get_temp_tensor(
            self.driver.reshape(tensor_name, new_shape),
            "expanded",
            metadata
        )

    def _broadcast_to(self, tensor_name: str, target_shape: Tuple[int, ...],
                      state: BroadcastState,
                      metadata: Optional[TensorMetadata] = None) -> str:
        """Broadcast tensor to target shape in driver memory.

        Args:
            tensor_name: Tensor to broadcast.
            target_shape: Shape to broadcast to.
            state: State used to create the result tensor.
            metadata: Optional metadata to attach to a newly created result.

        Returns:
            ``tensor_name`` itself when it already has ``target_shape``
            (no new tensor is created), otherwise the new tensor's name.

        Raises:
            ValueError: If the tensor's shape cannot be broadcast to
                ``target_shape``.
        """
        tensor = self.driver.get_tensor(tensor_name)
        current_shape = tuple(tensor.shape)
        target_shape = tuple(target_shape)

        if current_shape == target_shape:
            return tensor_name

        # Trailing dimensions must match or be 1, and the source may not have
        # more dimensions than the target (zip alone would hide the surplus).
        compatible = len(current_shape) <= len(target_shape) and all(
            c == 1 or c == t
            for c, t in zip(reversed(current_shape), reversed(target_shape))
        )
        if not compatible:
            raise ValueError(f"Shape {current_shape} cannot be broadcast to {target_shape}")

        return state.get_temp_tensor(
            self.driver.broadcast_to(tensor_name, target_shape),
            "broadcast",
            metadata
        )

    def _get_tensor_metadata(self, tensor_name: str) -> Optional[TensorMetadata]:
        """Get metadata for a tensor from this module's cache or the driver.

        Returns ``None`` when the name is not cached and the driver has no
        ``get_tensor_metadata`` method.
        """
        if tensor_name in self.metadata_cache:
            return self.metadata_cache[tensor_name]

        if hasattr(self.driver, 'get_tensor_metadata'):
            return self.driver.get_tensor_metadata(tensor_name)

        return None

    def broadcast_tensors(self, *tensor_names: str) -> List[str]:
        """
        Broadcast tensors to a common shape in driver memory.

        Handles multi-modal tensors with metadata preservation: tensors
        without metadata are treated as LATENT, modality-specific alignment
        runs first, then every tensor is expanded and broadcast.

        Returns:
            List of broadcasted tensor names, in input order. An input that
            already has the target shape is returned under its own name.

        Raises:
            ValueError: If the shapes or modalities cannot be combined.
        """
        state = BroadcastState(self.driver, "broadcast")

        # Get shapes and metadata from driver memory
        shapes = []
        metadata_list = []
        for name in tensor_names:
            tensor = self.driver.get_tensor(name)
            shape = tuple(tensor.shape)
            shapes.append(shape)
            metadata = self._get_tensor_metadata(name)
            if metadata is None:
                metadata = TensorMetadata(
                    modality=ModalityType.LATENT,
                    shape=shape,
                    dtype=tensor.dtype
                )
            metadata_list.append(metadata)

        # Compute target shape with modality awareness
        target_shape, target_metadata = compute_broadcast_shapes_with_modality(
            *shapes,
            metadata_list=metadata_list
        )

        if target_shape is None:
            raise ValueError(
                f"Tensors with shapes {shapes} and modalities "
                f"{[m.modality for m in metadata_list]} cannot be broadcast together"
            )

        # Handle modality-specific transforms before broadcasting
        transformed_names = []
        for name, metadata in zip(tensor_names, metadata_list):
            if metadata.modality in {ModalityType.IMAGE, ModalityType.VIDEO}:
                name = self._align_spatial_dims(name, target_metadata, state)
            elif metadata.modality in {ModalityType.AUDIO}:
                name = self._align_sampling_rate(name, target_metadata, state)
            elif metadata.modality in {ModalityType.TEXT, ModalityType.EMBEDDING}:
                name = self._align_sequence_length(name, target_metadata, state)
            transformed_names.append(name)

        # Expand dimensions to match target dimensionality
        target_dims = len(target_shape)
        expanded_names = [
            self._expand_dims(name, target_dims, state)
            for name in transformed_names
        ]

        # Broadcast each tensor to target shape
        originals = set(tensor_names)
        result_names = []
        for expanded_name in expanded_names:
            broadcast_name = self._broadcast_to(
                expanded_name,
                target_shape,
                state,
                target_metadata
            )
            # Remember the metadata where later lookups can find it. Caller-
            # owned inputs that passed through untouched keep their own.
            if target_metadata is not None and broadcast_name not in originals:
                self.metadata_cache[broadcast_name] = target_metadata
            result_names.append(broadcast_name)

        # Release intermediates from the alignment and expansion stages.
        # A temporary that ended up being a result (its broadcast was a
        # no-op) must survive, and caller-owned inputs are never touched.
        keep = originals | set(result_names)
        for temp_name in dict.fromkeys(transformed_names + expanded_names):
            if temp_name not in keep:
                state.free_temp_tensor(temp_name)

        return result_names

    def binary_op_broadcast(self, a_name: str, b_name: str,
                            op_name: str = "add") -> Tuple[str, str]:
        """
        Broadcast two tensors for a binary operation.

        ``op_name`` is accepted for interface compatibility and is not used.
        Returns tuple of broadcasted tensor names.
        """
        return tuple(self.broadcast_tensors(a_name, b_name))

    def unary_op_broadcast(self, tensor_name: str, target_shape: Tuple[int, ...]) -> str:
        """
        Broadcast a tensor to a target shape for a unary operation.

        Returns broadcasted tensor name (``tensor_name`` itself when it
        already has ``target_shape``).

        Raises:
            ValueError: If the tensor cannot be broadcast to ``target_shape``.
        """
        state = BroadcastState(self.driver, f"unary_{tensor_name}")
        target_shape = tuple(target_shape)

        # First expand dims if needed, then broadcast to target shape
        expanded_name = self._expand_dims(tensor_name, len(target_shape), state)
        result_name = self._broadcast_to(expanded_name, target_shape, state)

        # Free the expanded intermediate only when it is not itself the
        # result (the broadcast is a no-op when expansion already produced
        # the target shape).
        if expanded_name != tensor_name and expanded_name != result_name:
            state.free_temp_tensor(expanded_name)

        return result_name
        
class BroadcastBackward:
    """Undo broadcasting on gradients: sum them back to the input's shape.

    All tensor work is delegated to ``driver`` (``get_tensor``, ``sum``,
    ``reshape``); new tensors are created through a ``BroadcastState``.
    """

    def __init__(self, driver):
        self.driver = driver

    @staticmethod
    def _reduction_axes(grad_shape: Tuple[int, ...],
                        original_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return the gradient axes that must be summed, in ascending order.

        These are the leading axes the original tensor did not have, plus
        every axis where the original had size 1 but the gradient does not.
        """
        grad_dims = len(grad_shape)
        # Leading dimensions that broadcasting prepended
        axes = list(range(max(grad_dims - len(original_shape), 0)))

        # Shapes are aligned on their trailing dimensions, so walk both from
        # the right; the k-th dimension from the right is axis grad_dims - k.
        pairs = zip(reversed(original_shape), reversed(grad_shape))
        for k, (orig_size, grad_size) in enumerate(pairs, start=1):
            if orig_size == 1 and grad_size != 1:
                axes.append(grad_dims - k)

        return tuple(sorted(axes))

    def reduce_gradient(self, grad_name: str, original_shape: Tuple[int, ...]) -> str:
        """
        Reduce gradient back to original tensor shape after broadcasting.
        All operations done in driver memory.

        Args:
            grad_name: Name of the gradient tensor (broadcast shape).
            original_shape: Shape of the tensor before it was broadcast.

        Returns:
            ``grad_name`` itself when nothing has to be summed, otherwise the
            name of a new tensor of shape ``original_shape``.
        """
        grad_shape = tuple(self.driver.get_tensor(grad_name).shape)
        original_shape = tuple(original_shape)

        # Nothing to reduce if shapes match
        if grad_shape == original_shape:
            return grad_name

        reduce_dims = self._reduction_axes(grad_shape, original_shape)
        if not reduce_dims:
            return grad_name

        state = BroadcastState(self.driver, f"reduce_{grad_name}")

        # Sum with keepdims so the axis count is stable, then reshape to drop
        # the leading axes that the original tensor never had.
        reduced_name = state.get_temp_tensor(
            self.driver.sum(grad_name, axis=reduce_dims, keepdims=True),
            "reduced"
        )
        try:
            result_name = state.get_temp_tensor(
                self.driver.reshape(reduced_name, original_shape),
                "reshaped"
            )
        finally:
            # The summed intermediate is never handed out, even on failure.
            state.free_temp_tensor(reduced_name)

        return result_name
        
# Example usage:
"""
# Initialize
driver = YourDriver()
broadcast_module = BroadcastModule(driver)
backward_module = BroadcastBackward(driver)

# Forward pass with broadcasting
a_name = "tensor_a"  # shape: (2, 1, 4)
b_name = "tensor_b"  # shape: (3, 1)
c_name, d_name = broadcast_module.binary_op_broadcast(a_name, b_name)
# c_name and d_name now have shape (2, 3, 4)

# Backward pass
grad_name = "output_grad"  # shape: (2, 3, 4)
grad_a = backward_module.reduce_gradient(grad_name, (2, 1, 4))
grad_b = backward_module.reduce_gradient(grad_name, (3, 1))
"""
