import numpy as np
from helium.broadcast_ops import BroadcastTensor
from helium.broadcast import BroadcastModule, BroadcastBackward

def test_broadcasting():
    """Exercise BroadcastTensor add/mul/div and backward with mixed shapes.

    Four scenarios are run against an in-memory, NumPy-backed dummy driver.
    Each scenario prints the observed shapes (as before) and now also asserts
    them, so a broadcasting regression fails the test instead of only
    changing the printed output:

    1. (2, 3) + (3,)            -> result shape (2, 3)
    2. (2, 1, 4) * (3, 1)       -> result shape (2, 3, 4)
    3. backward through add/mul -> gradients shaped like their inputs
    4. (4, 1, 3) / (1, 2, 1)    -> result shape (4, 2, 3), gradients shaped
       like their inputs

    Raises:
        AssertionError: if a forward result or a gradient has an unexpected
            shape.
    """

    class DummyDriver:
        """Minimal driver stand-in that keeps NumPy arrays in a dict by name.

        NOTE(review): the arithmetic helpers return NumPy arrays rather than
        tensor names; presumably the library stores the result itself --
        confirm against the real driver interface.
        """

        def __init__(self):
            self.tensors = {}
            # Not read anywhere in this test; kept in case the library under
            # test accesses it -- TODO confirm and drop if unused.
            self.counter = 0

        def create_tensor(self, name, data):
            """Store ``data`` (converted with ``np.array``) and return ``name``."""
            self.tensors[name] = np.array(data)
            return name

        def get_tensor(self, name):
            """Return the stored array; raises KeyError for an unknown name."""
            return self.tensors[name]

        def delete_tensor(self, name):
            """Remove ``name`` if present; unknown names are ignored."""
            if name in self.tensors:
                del self.tensors[name]

        def tensor_exists(self, name):
            """Return True when ``name`` is currently stored."""
            return name in self.tensors

        def add(self, a, b):
            """Return the element-wise sum of the tensors named ``a`` and ``b``."""
            return self.get_tensor(a) + self.get_tensor(b)

        def mul(self, a, b):
            """Return the element-wise product of the tensors named ``a`` and ``b``."""
            return self.get_tensor(a) * self.get_tensor(b)

        def div(self, a, b):
            """Return the element-wise quotient of the tensors named ``a`` and ``b``."""
            return self.get_tensor(a) / self.get_tensor(b)

        def sum(self, x, axis=None, keepdims=False):
            """Return ``np.sum`` of the tensor named ``x`` over ``axis``."""
            return np.sum(self.get_tensor(x), axis=axis, keepdims=keepdims)

        def ones_like(self, x):
            """Return an array of ones shaped like the tensor named ``x``."""
            return np.ones_like(self.get_tensor(x))

        def broadcast_to(self, x, shape):
            """Return a broadcast view of the tensor named ``x`` with ``shape``."""
            return np.broadcast_to(self.get_tensor(x), shape)

        def reshape(self, x, shape):
            """Return the tensor named ``x`` reshaped to ``shape``."""
            return self.get_tensor(x).reshape(shape)

        def add_(self, a, b):
            """Add the tensor named ``b`` into the tensor named ``a`` in place."""
            self.tensors[a] += self.get_tensor(b)

    # Create driver instance
    driver = DummyDriver()

    # Local, seeded generator: inputs are reproducible from run to run and
    # NumPy's global random state is left untouched.
    rng = np.random.RandomState(0)

    # Test Case 1: Basic Broadcasting
    print("Test Case 1: Basic Broadcasting")
    a = np.array([[1, 2, 3],
                  [4, 5, 6]])  # shape: (2, 3)
    b = np.array([10, 20, 30])  # shape: (3,)

    tensor_a = BroadcastTensor("a", driver, requires_grad=True)
    tensor_b = BroadcastTensor("b", driver, requires_grad=True)

    driver.create_tensor("a", a)
    driver.create_tensor("b", b)

    # Test addition
    c = tensor_a + tensor_b
    result = driver.get_tensor(c.name)
    print(f"Shape a: {a.shape}")
    print(f"Shape b: {b.shape}")
    print(f"Result shape: {result.shape}")
    print("Result:")
    print(result)
    print()
    assert result.shape == (2, 3), (
        f"add: expected shape (2, 3), got {result.shape}"
    )

    # Test Case 2: Complex Broadcasting
    print("Test Case 2: Complex Broadcasting")
    x = rng.randn(2, 1, 4)  # shape: (2, 1, 4)
    y = rng.randn(3, 1)     # shape: (3, 1)

    tensor_x = BroadcastTensor("x", driver, requires_grad=True)
    tensor_y = BroadcastTensor("y", driver, requires_grad=True)

    driver.create_tensor("x", x)
    driver.create_tensor("y", y)

    # Test multiplication
    z = tensor_x * tensor_y
    result = driver.get_tensor(z.name)
    print(f"Shape x: {x.shape}")
    print(f"Shape y: {y.shape}")
    print(f"Result shape: {result.shape}")
    print("Result shape should be (2, 3, 4)")
    print()
    assert result.shape == (2, 3, 4), (
        f"mul: expected shape (2, 3, 4), got {result.shape}"
    )

    # Test Case 3: Gradient Broadcasting
    print("Test Case 3: Gradient Broadcasting")
    m = rng.randn(2, 1)    # shape: (2, 1)
    n = rng.randn(3)       # shape: (3,)

    tensor_m = BroadcastTensor("m", driver, requires_grad=True)
    tensor_n = BroadcastTensor("n", driver, requires_grad=True)

    driver.create_tensor("m", m)
    driver.create_tensor("n", n)

    # Forward pass
    p = tensor_m + tensor_n  # shape will be (2, 3)
    q = p * tensor_m         # involves more broadcasting

    # Backward pass
    q.backward()

    # Check gradient shapes
    m_grad = driver.get_tensor(tensor_m.grad_name)
    n_grad = driver.get_tensor(tensor_n.grad_name)

    print(f"Original m shape: {m.shape}")
    print(f"m gradient shape: {m_grad.shape}")
    print(f"Original n shape: {n.shape}")
    print(f"n gradient shape: {n_grad.shape}")
    print()
    # A gradient must be reduced back to the shape of the tensor it belongs to.
    assert m_grad.shape == m.shape, (
        f"m gradient: expected shape {m.shape}, got {m_grad.shape}"
    )
    assert n_grad.shape == n.shape, (
        f"n gradient: expected shape {n.shape}, got {n_grad.shape}"
    )

    # Test Case 4: Division with Broadcasting
    print("Test Case 4: Division with Broadcasting")
    u = rng.randn(4, 1, 3)  # shape: (4, 1, 3)
    v = rng.randn(1, 2, 1)  # shape: (1, 2, 1)

    tensor_u = BroadcastTensor("u", driver, requires_grad=True)
    tensor_v = BroadcastTensor("v", driver, requires_grad=True)

    driver.create_tensor("u", u)
    driver.create_tensor("v", v)

    # Test division
    w = tensor_u / tensor_v
    result = driver.get_tensor(w.name)
    print(f"Shape u: {u.shape}")
    print(f"Shape v: {v.shape}")
    print(f"Result shape: {result.shape}")
    print("Result shape should be (4, 2, 3)")
    assert result.shape == (4, 2, 3), (
        f"div: expected shape (4, 2, 3), got {result.shape}"
    )

    # Test backward pass
    w.backward()

    u_grad = driver.get_tensor(tensor_u.grad_name)
    v_grad = driver.get_tensor(tensor_v.grad_name)

    print(f"u gradient shape: {u_grad.shape}")
    print(f"v gradient shape: {v_grad.shape}")
    assert u_grad.shape == u.shape, (
        f"u gradient: expected shape {u.shape}, got {u_grad.shape}"
    )
    assert v_grad.shape == v.shape, (
        f"v gradient: expected shape {v.shape}, got {v_grad.shape}"
    )

# Run the broadcasting checks when this file is executed directly as a script;
# importing the module does not trigger them.
if __name__ == "__main__":
    test_broadcasting()
