import itertools
from typing import Optional, List, Tuple

from .broadcast import BroadcastModule, BroadcastBackward

class BroadcastOps:
    """Binary tensor operations with automatic broadcasting, plus their gradients.

    Tensors are referred to by the names under which they are registered in
    ``driver``. Driver arithmetic (``add``/``mul``/``div``) is called with those
    names, and its return value is registered back into the driver with
    ``create_tensor`` so every result can again be referred to by name.
    """

    # Operation names accepted by binary_op/backward_binary_op. Each one is also
    # the name of the driver method that implements the forward computation.
    _SUPPORTED_OPS = ("add", "mul", "div")

    # Shared across all instances: every BroadcastTensor builds its own
    # BroadcastOps, presumably over the same driver, so a per-instance counter
    # could hand out colliding tensor names.
    _name_ids = itertools.count()

    def __init__(self, driver):
        self.driver = driver
        self.broadcast_module = BroadcastModule(driver)
        self.backward_module = BroadcastBackward(driver)

    def _register(self, value, prefix: str) -> str:
        """Store ``value`` in the driver under a fresh unique name and return the name.

        A counter is used instead of ``id(value)`` because ids are recycled once
        an object is freed, which could silently overwrite a live tensor.
        NOTE(review): registered intermediates are never released; no API for
        freeing driver tensors is visible here.
        """
        name = f"{prefix}_{next(self._name_ids)}"
        self.driver.create_tensor(name, value)
        return name

    def _check_op(self, op_name: str) -> None:
        """Raise ValueError if ``op_name`` is not a supported binary operation."""
        if op_name not in self._SUPPORTED_OPS:
            raise ValueError(f"Unsupported operation: {op_name}")

    def binary_op(self, op_name: str, a_name: str, b_name: str,
                  save_shapes: bool = True) -> Tuple[str, Optional[Tuple[str, str]]]:
        """
        Perform a binary operation with automatic broadcasting.

        Args:
            op_name: One of ``"add"``, ``"mul"``, ``"div"``.
            a_name: Driver name of the left operand.
            b_name: Driver name of the right operand.
            save_shapes: When True, also return the names of the broadcasted
                operands so a later backward pass can reuse them.

        Returns:
            - Result tensor name
            - Tuple of (broadcasted_a, broadcasted_b) if save_shapes=True, else None

        Raises:
            ValueError: If ``op_name`` is not supported (checked before any work).
        """
        self._check_op(op_name)

        broadcasted = self.broadcast_module.binary_op_broadcast(a_name, b_name)
        # Op names coincide with the driver's method names (add/mul/div).
        result = getattr(self.driver, op_name)(broadcasted[0], broadcasted[1])
        result_name = self._register(result, f"{op_name}_result")

        return result_name, (broadcasted if save_shapes else None)

    def backward_binary_op(self, op_name: str, grad_output_name: str,
                           original_shapes: Tuple[Tuple[int, ...], Tuple[int, ...]],
                           broadcasted: Tuple[str, str]) -> Tuple[Optional[str], Optional[str]]:
        """
        Compute gradients for a broadcasting binary operation.

        Args:
            op_name: The forward operation (``"add"``, ``"mul"`` or ``"div"``).
            grad_output_name: Driver name of the upstream gradient, presumably
                shaped like the broadcast result.
            original_shapes: Pre-broadcast shapes of (a, b); each gradient is
                reduced back to its shape via ``reduce_gradient``.
            broadcasted: Driver names of the broadcasted (a, b). This is only
                read for ``mul`` and ``div``.

        Returns:
            Whatever ``reduce_gradient`` returns for a and for b (annotated as
            driver tensor names).

        Raises:
            ValueError: If ``op_name`` is not supported.
        """
        self._check_op(op_name)
        reduce = self.backward_module.reduce_gradient
        shape_a, shape_b = original_shapes

        if op_name == "add":
            # d(a+b)/da = d(a+b)/db = 1: only the broadcasting has to be undone.
            return reduce(grad_output_name, shape_a), reduce(grad_output_name, shape_b)

        a_bc, b_bc = broadcasted
        if op_name == "mul":
            # d(a*b)/da = b, d(a*b)/db = a. Intermediates are registered so that
            # reduce_gradient always receives a name, exactly as in the add branch.
            full_a = self._register(self.driver.mul(grad_output_name, b_bc), "mul_grad_a")
            full_b = self._register(self.driver.mul(grad_output_name, a_bc), "mul_grad_b")
            return reduce(full_a, shape_a), reduce(full_b, shape_b)

        # div: d(a/b)/da = 1/b, d(a/b)/db = -a / b**2
        full_a = self._register(self.driver.div(grad_output_name, b_bc), "div_grad_a")
        b_squared = self._register(self.driver.mul(b_bc, b_bc), "div_b_squared")
        neg_a = self._register(self.driver.mul(a_bc, -1.0), "div_neg_a")
        ratio = self._register(self.driver.div(neg_a, b_squared), "div_ratio")
        full_b = self._register(self.driver.mul(grad_output_name, ratio), "div_grad_b")
        return reduce(full_a, shape_a), reduce(full_b, shape_b)

class BroadcastTensor:
    """Named driver tensor with broadcasting arithmetic and reverse-mode autograd.

    ``+``, ``*`` and ``/`` produce a new BroadcastTensor whose ``_ctx`` records
    the operation and references to both operands. ``backward()`` on a result
    walks those references recursively. Gradients are accumulated only on leaf
    tensors (no ``_ctx``) that have ``requires_grad=True``.
    """

    def __init__(self, name: str, driver, requires_grad: bool = False):
        self.name = name
        self.driver = driver
        self.requires_grad = requires_grad
        # Driver name that will hold the accumulated gradient. The tensor itself
        # is only created on the first accumulation.
        self.grad_name = None if not requires_grad else f"{name}_grad"
        self.broadcast_ops = BroadcastOps(driver)
        self._ctx = None
        # Whether the tensor named ``grad_name`` exists in the driver yet.
        self._grad_created = False

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the underlying driver tensor."""
        return self.driver.get_tensor(self.name).shape

    def _create_ctx(self, op_name: str, other: 'BroadcastTensor',
                    result_name: str, broadcasted: Tuple[str, str],
                    lhs: Optional['BroadcastTensor'] = None):
        """Record on this result tensor what backward() needs.

        Args:
            op_name: Forward operation name.
            other: Right operand.
            result_name: Driver name of this result.
            broadcasted: Driver names of the broadcasted operands.
            lhs: Left operand. When omitted (legacy call form), its shape falls
                back to this tensor's shape and no gradient is sent to it.
        """
        lhs_requires_grad = lhs.requires_grad if lhs is not None else self.requires_grad
        other_requires_grad = other is not None and other.requires_grad
        if not (lhs_requires_grad or other_requires_grad):
            return
        self._ctx = {
            'op': op_name,
            # Pre-broadcast operand shapes; reduce_gradient reduces back to these.
            'shapes': ((lhs if lhs is not None else self).shape,
                       other.shape if other is not None else None),
            'broadcasted': broadcasted,
            'lhs': lhs,
            'other': other,
            'self_name': lhs.name if lhs is not None else self.name,
            'other_name': other.name if other is not None else None,
            'result_name': result_name,
        }

    def _accumulate_grad(self, grad: str) -> None:
        """Add the gradient tensor named ``grad`` into this tensor's ``grad_name``."""
        if self.grad_name is None:
            # requires_grad may have been switched on after construction.
            self.grad_name = f"{self.name}_grad"
        if not self._grad_created:
            # Multiply by 1.0 to get a private copy (assumes driver.mul returns
            # a new tensor, as its use in BroadcastOps.binary_op suggests).
            # ``grad`` may be shared with other tensors' gradients, and later
            # accumulation happens in place via add_.
            self.driver.create_tensor(self.grad_name, self.driver.mul(grad, 1.0))
            self._grad_created = True
        else:
            self.driver.add_(self.grad_name, grad)

    def backward(self, grad_name: Optional[str] = None):
        """Backpropagate a gradient through the graph rooted at this tensor.

        Args:
            grad_name: Driver name of the gradient w.r.t. this tensor. It
                defaults to a ones tensor built with ``driver.ones_like``.

        A leaf tensor accumulates the gradient into ``grad_name``. A result
        tensor splits the gradient between its operands and recurses into every
        operand with ``requires_grad``. Operands reached along several paths
        receive the sum of all contributions.
        """
        if not self.requires_grad:
            return

        if grad_name is None:
            grad_name = f"{self.name}_ones_grad"
            self.driver.create_tensor(
                grad_name,
                self.driver.ones_like(self.name)
            )

        ctx = self._ctx
        if ctx is None:
            self._accumulate_grad(grad_name)
            return

        grad_lhs, grad_other = self.broadcast_ops.backward_binary_op(
            ctx['op'], grad_name, ctx['shapes'], ctx['broadcasted']
        )
        for parent, grad in ((ctx.get('lhs'), grad_lhs), (ctx.get('other'), grad_other)):
            if parent is not None and grad is not None and parent.requires_grad:
                parent.backward(grad)

    def _binary(self, op_name: str, other: 'BroadcastTensor') -> 'BroadcastTensor':
        """Apply ``op_name`` to (self, other) and wire the result into the graph."""
        result_name, broadcasted = self.broadcast_ops.binary_op(
            op_name, self.name, other.name
        )
        result = BroadcastTensor(
            result_name,
            self.driver,
            requires_grad=self.requires_grad or other.requires_grad
        )
        result._create_ctx(op_name, other, result_name, broadcasted, lhs=self)
        return result

    def __add__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
        return self._binary("add", other)

    def __mul__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
        return self._binary("mul", other)

    def __truediv__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
        return self._binary("div", other)

# Example usage (assumes "tensor_a" and "tensor_b" are already registered in the driver):
"""
# Initialize
driver = YourDriver()

# Create tensors with broadcasting support
a = BroadcastTensor("tensor_a", driver, requires_grad=True)  # shape: (2, 1, 4)
b = BroadcastTensor("tensor_b", driver, requires_grad=True)  # shape: (3, 1)

# Operations with automatic broadcasting
c = a + b  # shape: (2, 3, 4)
d = c * a  # broadcasting happens automatically

# Backward pass with proper gradient broadcasting
d.backward()
"""
