import numpy as np
from typing import Tuple, Optional
from .compute_graph import get_compute_graph
from .autograd import Tensor

def register_ops():
    """Register the built-in operations with the computational graph.

    For each of ``add``, ``mul``, ``matmul``, ``relu`` and ``mean`` a forward
    and a backward function are defined here and handed to
    ``graph.register_operation(name, forward, backward)`` on the graph
    returned by ``get_compute_graph()``.

    Every forward function runs the computation through the driver of its
    first operand, wraps the result in a new ``Tensor`` and, when an input
    requires a gradient, records a node with ``graph.track_operation`` and
    stores what the backward function needs via ``node.save_for_backward``.
    The keyword names passed to ``save_for_backward`` mirror the parameter
    names of the matching backward function.
    """
    graph = get_compute_graph()

    def _unbroadcast(grad: np.ndarray, shape: Optional[Tuple[int, ...]]) -> np.ndarray:
        """Sum ``grad`` down to ``shape``, undoing NumPy broadcasting.

        Axes that broadcasting prepended are summed away; axes that were
        stretched from size 1 are summed with ``keepdims=True``.  ``grad`` is
        returned as-is when ``shape`` is ``None`` or already matches.
        """
        if shape is None:
            return grad
        shape = tuple(shape)
        grad = np.asarray(grad)
        if grad.shape == shape:
            return grad
        extra = grad.ndim - len(shape)
        if extra < 0:
            # ``grad`` cannot be a broadcast of ``shape``; hand it back
            # untouched instead of guessing how to align the axes.
            return grad
        if extra:
            grad = grad.sum(axis=tuple(range(extra)))
        stretched = tuple(
            i
            for i, (have, want) in enumerate(zip(grad.shape, shape))
            if want == 1 and have != 1
        )
        if stretched:
            grad = grad.sum(axis=stretched, keepdims=True)
        return grad

    # Addition
    def add_forward(a: Tensor, b: Tensor) -> Tensor:
        """Compute ``a + b`` through ``a.driver`` and return the wrapped result."""
        driver = a.driver
        out_name = f"add_{a.name}_{b.name}"
        data = driver.add(a.name, b.name, out_name)
        out = Tensor(data, requires_grad=a.requires_grad or b.requires_grad, driver=driver, name=out_name)
        if a.requires_grad or b.requires_grad:
            # Record the node like every other op so gradients can flow
            # through an addition.
            # NOTE(review): the saved keywords follow add_backward's parameter
            # names, as the other ops do -- confirm against the graph engine.
            node = graph.track_operation("add", [a, b], out)
            node.save_for_backward(a=a, b=b)
        return out

    def add_backward(grad_output: np.ndarray, a: Tensor, b: Tensor) -> Tuple[np.ndarray, np.ndarray]:
        """Return the gradients of ``a + b`` with respect to ``a`` and ``b``.

        Both equal ``grad_output``, summed down to the operand's ``shape``
        when the operand exposes one and the forward pass broadcast it.
        """
        grad_a = _unbroadcast(grad_output, getattr(a, "shape", None))
        grad_b = _unbroadcast(grad_output, getattr(b, "shape", None))
        return grad_a, grad_b

    graph.register_operation("add", add_forward, add_backward)

    # Multiplication
    def mul_forward(a: Tensor, b: Tensor) -> Tensor:
        """Compute the element-wise product through ``a.driver``."""
        driver = a.driver
        out_name = f"mul_{a.name}_{b.name}"
        data = driver.mul(a.name, b.name, out_name)
        out = Tensor(data, requires_grad=a.requires_grad or b.requires_grad, driver=driver, name=out_name)
        if a.requires_grad or b.requires_grad:
            node = graph.track_operation("mul", [a, b], out)
            node.save_for_backward(a_data=a.data(), b_data=b.data())
        return out

    def mul_backward(grad_output: np.ndarray, a_data: np.ndarray, b_data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(grad_output * b, grad_output * a)`` reduced to each operand's shape."""
        grad_a = _unbroadcast(grad_output * b_data, np.shape(a_data))
        grad_b = _unbroadcast(grad_output * a_data, np.shape(b_data))
        return grad_a, grad_b

    graph.register_operation("mul", mul_forward, mul_backward)

    # Matrix multiplication
    def matmul_forward(a: Tensor, b: Tensor) -> Tensor:
        """Compute the matrix product through ``a.driver``."""
        driver = a.driver
        out_name = f"matmul_{a.name}_{b.name}"
        data = driver.matmul(a.name, b.name, out_name)
        out = Tensor(data, requires_grad=a.requires_grad or b.requires_grad, driver=driver, name=out_name)
        if a.requires_grad or b.requires_grad:
            node = graph.track_operation("matmul", [a, b], out)
            node.save_for_backward(a_data=a.data(), b_data=b.data())
        return out

    def matmul_backward(grad_output: np.ndarray, a_data: np.ndarray, b_data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return the gradients of ``np.matmul(a, b)`` for both operands.

        Only the last two axes are transposed, so stacked (batched) matrices
        are handled; ``.T`` would reverse every axis.  1-D operands are
        promoted the way ``np.matmul`` promotes them, and gradients are summed
        over any batch axes that were broadcast.
        """
        a_data = np.asarray(a_data)
        b_data = np.asarray(b_data)
        grad = np.asarray(grad_output)

        # np.matmul treats a 1-D left operand as a row vector and a 1-D right
        # operand as a column vector, then drops that axis from the result.
        # Re-insert the dropped axes so everything below is matrix-shaped.
        a_mat = a_data[np.newaxis, :] if a_data.ndim == 1 else a_data
        b_mat = b_data[:, np.newaxis] if b_data.ndim == 1 else b_data
        if b_data.ndim == 1:
            grad = np.expand_dims(grad, -1)
        if a_data.ndim == 1:
            grad = np.expand_dims(grad, -2)

        grad_a = np.matmul(grad, np.swapaxes(b_mat, -1, -2))
        grad_b = np.matmul(np.swapaxes(a_mat, -1, -2), grad)

        # Collapse broadcast batch axes, then restore 1-D operands' shapes.
        grad_a = _unbroadcast(grad_a, a_mat.shape).reshape(a_data.shape)
        grad_b = _unbroadcast(grad_b, b_mat.shape).reshape(b_data.shape)
        return grad_a, grad_b

    graph.register_operation("matmul", matmul_forward, matmul_backward)

    # ReLU
    def relu_forward(x: Tensor) -> Tensor:
        """Apply ReLU through ``x.driver``."""
        driver = x.driver
        out_name = f"relu_{x.name}"
        data = driver.relu(x.name, out_name)
        out = Tensor(data, requires_grad=x.requires_grad, driver=driver, name=out_name)
        if x.requires_grad:
            node = graph.track_operation("relu", [x], out)
            node.save_for_backward(x_data=x.data())
        return out

    def relu_backward(grad_output: np.ndarray, x_data: np.ndarray) -> np.ndarray:
        """Pass ``grad_output`` through where ``x_data > 0`` and zero it elsewhere."""
        return grad_output * (x_data > 0)

    graph.register_operation("relu", relu_forward, relu_backward)

    # Mean
    def mean_forward(x: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
        """Compute the mean of ``x`` through ``x.driver``.

        ``axis`` and ``keepdims`` are forwarded to ``driver.mean`` and saved,
        together with ``x.shape``, for the backward pass.
        """
        driver = x.driver
        # NOTE(review): the output name does not encode axis/keepdims, so
        # different reductions of the same tensor share one name -- confirm
        # the driver tolerates that.
        out_name = f"mean_{x.name}"
        data = driver.mean(x.name, out_name, axis=axis, keepdims=keepdims)
        out = Tensor(data, requires_grad=x.requires_grad, driver=driver, name=out_name)
        if x.requires_grad:
            node = graph.track_operation("mean", [x], out)
            node.save_for_backward(x_shape=x.shape, axis=axis, keepdims=keepdims)
        return out

    def mean_backward(grad_output: np.ndarray, x_shape: Tuple[int, ...],
                     axis: Optional[int], keepdims: bool) -> np.ndarray:
        """Spread ``grad_output`` evenly over the elements that were averaged.

        ``axis`` may be ``None`` (mean over everything), a single integer
        (Python or NumPy) or a sequence of integers.  The result has shape
        ``x_shape`` and is always a freshly allocated array.
        """
        grad = np.asarray(grad_output)
        if axis is None:
            count = np.prod(x_shape)
        else:
            if not keepdims:
                # Put the reduced axes back so the gradient broadcasts
                # against ``x_shape``.
                grad = np.expand_dims(grad, axis)
            reduced = (axis,) if isinstance(axis, (int, np.integer)) else tuple(axis)
            count = np.prod([x_shape[i] for i in reduced])
        # Dividing after broadcasting allocates a new array; returning the
        # read-only view from np.broadcast_to would break in-place updates.
        return np.broadcast_to(grad, x_shape) / count

    graph.register_operation("mean", mean_forward, mean_backward)

# Function interfaces
def add(a: Tensor, b: Tensor) -> Tensor:
    """Add two tensors via the ``"add"`` entry of the compute graph.

    Looks up ``grad_fns["add"]`` on the graph returned by
    ``get_compute_graph()`` and calls its first element with ``(a, b)``.
    """
    add_impl = get_compute_graph().grad_fns["add"][0]
    return add_impl(a, b)
    
def mul(a: Tensor, b: Tensor) -> Tensor:
    """Multiply two tensors via the ``"mul"`` entry of the compute graph.

    Looks up ``grad_fns["mul"]`` on the graph returned by
    ``get_compute_graph()`` and calls its first element with ``(a, b)``.
    """
    mul_impl = get_compute_graph().grad_fns["mul"][0]
    return mul_impl(a, b)
    
def matmul(a: Tensor, b: Tensor) -> Tensor:
    """Matrix-multiply two tensors via the ``"matmul"`` entry of the compute graph.

    Looks up ``grad_fns["matmul"]`` on the graph returned by
    ``get_compute_graph()`` and calls its first element with ``(a, b)``.
    """
    matmul_impl = get_compute_graph().grad_fns["matmul"][0]
    return matmul_impl(a, b)
    
def relu(x: Tensor) -> Tensor:
    """Apply ReLU to ``x`` via the ``"relu"`` entry of the compute graph.

    Looks up ``grad_fns["relu"]`` on the graph returned by
    ``get_compute_graph()`` and calls its first element with ``x``.
    """
    relu_impl = get_compute_graph().grad_fns["relu"][0]
    return relu_impl(x)
    
def mean(x: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    """Average ``x`` via the ``"mean"`` entry of the compute graph.

    Looks up ``grad_fns["mean"]`` on the graph returned by
    ``get_compute_graph()`` and calls its first element with
    ``(x, axis, keepdims)`` passed positionally.
    """
    mean_impl = get_compute_graph().grad_fns["mean"][0]
    return mean_impl(x, axis, keepdims)
