# --- Example usage with optimizer and console logs ---
if __name__ == "__main__":
    # This example sits *above* the code it exercises: numpy, Tensor and the
    # autograd ops (add, matmul, sub, mul, mean) are only bound further down
    # this module, so using them directly here raises NameError. Import numpy
    # locally and load a fully initialised copy of this module by its package
    # name instead. The module's relative ``compute_graph`` import means it can
    # only run with package context: ``python -m <package>.<module>``.
    import importlib

    import numpy as np

    if __spec__ is None:
        raise SystemExit(
            "Run this example as a module: python -m <package>.<module>"
        )
    _autograd = importlib.import_module(__spec__.name)
    Tensor = _autograd.Tensor
    add = _autograd.add
    matmul = _autograd.matmul
    sub = _autograd.sub
    mul = _autograd.mul
    mean = _autograd.mean

    class DummyDriver:
        """Minimal in-memory driver for demonstration (replace with your real driver).

        Tensors are kept in a dict keyed by name. Op operands may be either a
        stored tensor's name or a raw array-like value; an op's result is saved
        under ``out`` when one is given and is always returned.
        """

        def __init__(self):
            self.tensors = {}

        def _resolve(self, a):
            # Look names up in the store; pass raw values through unchanged.
            return self.tensors[a] if isinstance(a, str) else a

        def _store(self, out, res):
            # Persist the result under ``out`` (if given) and hand it back.
            if out:
                self.tensors[out] = res
            return res

        def tensor_exists(self, name):
            return name in self.tensors

        def create_tensor(self, name, data):
            self.tensors[name] = np.array(data)

        def get_tensor(self, name):
            return self.tensors[name]

        def set_tensor(self, name, data):
            self.tensors[name] = np.array(data)

        def add(self, a, b, out=None):
            return self._store(out, self._resolve(a) + self._resolve(b))

        def mul(self, a, b, out=None):
            return self._store(out, self._resolve(a) * self._resolve(b))

        def sub(self, a, b, out=None):
            return self._store(out, self._resolve(a) - self._resolve(b))

        def div(self, a, b, out=None):
            return self._store(out, self._resolve(a) / self._resolve(b))

        def matmul(self, a, b, out=None, transpose_a=False, transpose_b=False):
            a, b = self._resolve(a), self._resolve(b)
            if transpose_a:
                a = a.T
            if transpose_b:
                b = b.T
            return self._store(out, a @ b)

        def relu(self, a, out):
            return self._store(out, np.maximum(0, self._resolve(a)))

        def relu_grad(self, a):
            a = self._resolve(a)
            return (a > 0).astype(a.dtype)

        def sigmoid(self, a, out):
            return self._store(out, 1 / (1 + np.exp(-self._resolve(a))))

        def tanh(self, a, out):
            return self._store(out, np.tanh(self._resolve(a)))

        def sum(self, a, out, axis=None, keepdims=False):
            return self._store(out, np.sum(self._resolve(a), axis=axis, keepdims=keepdims))

        def mean(self, a, out, axis=None, keepdims=False):
            return self._store(out, np.mean(self._resolve(a), axis=axis, keepdims=keepdims))

        def broadcast_to(self, a, shape):
            return np.broadcast_to(self._resolve(a), shape)

        def softmax(self, a, out, axis=-1):
            a = self._resolve(a)
            # Subtract the max for numerical stability before exponentiating.
            e = np.exp(a - np.max(a, axis=axis, keepdims=True))
            return self._store(out, e / np.sum(e, axis=axis, keepdims=True))

    class DummySGD:
        """Plain SGD for a single parameter: param <- param - lr * grad."""

        def __init__(self, param, driver, lr=0.1):
            self.param = param
            self.driver = driver
            self.lr = lr

        def step(self, grad_name):
            # Always updates self.param, whichever gradient name it is given.
            w = self.driver.get_tensor(self.param.name)
            g = self.driver.get_tensor(grad_name)
            self.driver.set_tensor(self.param.name, w - self.lr * g)

    # Example: simple linear regression y = wx + b
    driver = DummyDriver()
    x = Tensor(np.array([[1.0], [2.0], [3.0]]), requires_grad=False, driver=driver, name="x")
    y_true = Tensor(np.array([[2.0], [4.0], [6.0]]), requires_grad=False, driver=driver, name="y_true")
    w = Tensor(np.array([[0.5]]), requires_grad=True, driver=driver, name="w")
    b = Tensor(np.array([[0.0]]), requires_grad=True, driver=driver, name="b")
    params = [w, b]
    # DummySGD only updates the parameter it was built with, so each trainable
    # parameter (including b) needs its own instance to be trained at all.
    optimizers = [DummySGD(p, driver, lr=0.1) for p in params]

    for epoch in range(3):
        # Forward: y_pred = x @ w + b
        y_pred = add(matmul(x, w), b)
        # Loss: mean squared error
        diff = sub(y_pred, y_true)
        loss = mean(mul(diff, diff))
        print(f"Epoch {epoch} | Loss: {loss.data().item():.4f}")
        # Backward
        for p in params:
            p.zero_grad()
        loss.backward()
        # Optimizer step: each parameter with its own gradient
        for p, opt in zip(params, optimizers):
            opt.step(p.grad_name)
        # Print weights and grads
        print(f"  w: {w.data().item():.4f} | grad: {w.grad().item():.4f}")
        print(f"  b: {b.data().item():.4f} | grad: {b.grad().item():.4f}")
# --- More autograd ops: sigmoid, tanh, softmax, sum, mean, mul, sub, div ---
class SigmoidCtx:
    """Backward context for ``sigmoid``.

    Reuses the cached forward output s = sigmoid(x), since ds/dx = s * (1 - s).
    """

    def __init__(self, inp, out):
        self.inp, self.out = inp, out

    def backward(self, grad_out):
        """Scale ``grad_out`` by s * (1 - s) and hand it to the input tensor."""
        ops = self.inp.driver
        s = self.out.data()
        local_grad = ops.mul(s, ops.sub(1, s))
        grad_inp = ops.mul(grad_out, local_grad)
        if self.inp.requires_grad:
            self.inp.backward(grad_inp)

def sigmoid(inp):
    """Apply the driver's logistic sigmoid element-wise, recording a SigmoidCtx.

    The result is stored by the driver under ``sigmoid_<inp.name>`` and
    requires grad exactly when ``inp`` does.
    """
    backend = inp.driver
    target = f"sigmoid_{inp.name}"
    activated = Tensor(
        backend.sigmoid(inp.name, target),
        requires_grad=inp.requires_grad,
        driver=backend,
        name=target,
    )
    activated.set_ctx(SigmoidCtx(inp, activated))
    return activated

class TanhCtx:
    """Backward context for ``tanh``.

    Reuses the cached forward output t = tanh(x), since dt/dx = 1 - t**2.
    """

    def __init__(self, inp, out):
        self.inp, self.out = inp, out

    def backward(self, grad_out):
        """Scale ``grad_out`` by (1 - t**2) and hand it to the input tensor."""
        ops = self.inp.driver
        t = self.out.data()
        t_squared = ops.mul(t, t)
        grad_inp = ops.mul(grad_out, ops.sub(1, t_squared))
        if self.inp.requires_grad:
            self.inp.backward(grad_inp)

def tanh(inp):
    """Apply the driver's hyperbolic tangent element-wise, recording a TanhCtx.

    The result is stored by the driver under ``tanh_<inp.name>`` and requires
    grad exactly when ``inp`` does.
    """
    backend = inp.driver
    target = f"tanh_{inp.name}"
    activated = Tensor(
        backend.tanh(inp.name, target),
        requires_grad=inp.requires_grad,
        driver=backend,
        name=target,
    )
    activated.set_ctx(TanhCtx(inp, activated))
    return activated

class MulCtx:
    """Backward context for element-wise ``mul``: d(a*b)/da = b and d(a*b)/db = a."""

    def __init__(self, a, b, out):
        self.a = a
        self.b = b
        self.out = out

    def backward(self, grad_out):
        """Send each operand the upstream gradient times the *other* operand.

        The other operand's values are read from the driver by name. The
        operand ``a`` is handled before ``b``.
        NOTE(review): for mul(x, x) this calls x.backward twice with one
        half of the true gradient each; whether the contributions are summed
        depends on Tensor.backward / the compute graph — confirm.
        """
        for operand, partner in ((self.a, self.b), (self.b, self.a)):
            if operand.requires_grad:
                operand.backward(operand.driver.mul(grad_out, partner.name))

def mul(a, b):
    """Element-wise product ``a * b`` recorded for autograd.

    Args:
        a, b: Tensors backed by the same driver.

    Returns:
        Tensor ``mul_<a.name>_<b.name>`` that requires grad when either input
        does, with a ``MulCtx`` attached for the backward pass.

    Raises:
        ValueError: If ``a`` and ``b`` use different drivers.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if a.driver != b.driver:
        raise ValueError("mul: operands must share the same driver")
    driver = a.driver
    out_name = f"mul_{a.name}_{b.name}"
    data = driver.mul(a.name, b.name, out_name)
    out = Tensor(data, requires_grad=a.requires_grad or b.requires_grad, driver=driver, name=out_name)
    out.set_ctx(MulCtx(a, b, out))
    return out

class SubCtx:
    """Backward context for ``sub``: d(a - b)/da = 1 and d(a - b)/db = -1."""

    def __init__(self, a, b, out):
        self.a = a
        self.b = b
        self.out = out

    def backward(self, grad_out):
        """Pass ``grad_out`` to ``a`` unchanged and its negation to ``b``."""
        minuend, subtrahend = self.a, self.b
        if minuend.requires_grad:
            minuend.backward(grad_out)
        if subtrahend.requires_grad:
            # Negate lazily: only when b actually needs a gradient.
            subtrahend.backward(-grad_out)

def sub(a, b):
    """Element-wise difference ``a - b`` recorded for autograd.

    Args:
        a, b: Tensors backed by the same driver.

    Returns:
        Tensor ``sub_<a.name>_<b.name>`` that requires grad when either input
        does, with a ``SubCtx`` attached for the backward pass.

    Raises:
        ValueError: If ``a`` and ``b`` use different drivers.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if a.driver != b.driver:
        raise ValueError("sub: operands must share the same driver")
    driver = a.driver
    out_name = f"sub_{a.name}_{b.name}"
    data = driver.sub(a.name, b.name, out_name)
    out = Tensor(data, requires_grad=a.requires_grad or b.requires_grad, driver=driver, name=out_name)
    out.set_ctx(SubCtx(a, b, out))
    return out

class DivCtx:
    """Backward context for ``div``: d(a/b)/da = 1/b and d(a/b)/db = -a/b**2."""

    def __init__(self, a, b, out):
        self.a = a
        self.b = b
        self.out = out

    def backward(self, grad_out):
        """Propagate the quotient-rule gradients to numerator and denominator."""
        numerator, denominator = self.a, self.b
        if numerator.requires_grad:
            numerator.backward(numerator.driver.div(grad_out, denominator.name))
        if denominator.requires_grad:
            ops = denominator.driver
            scaled = ops.mul(grad_out, numerator.name)
            denom_squared = ops.mul(denominator.name, denominator.name)
            denominator.backward(-ops.div(scaled, denom_squared))

def div(a, b):
    """Element-wise quotient ``a / b`` recorded for autograd.

    Args:
        a, b: Tensors backed by the same driver.

    Returns:
        Tensor ``div_<a.name>_<b.name>`` that requires grad when either input
        does, with a ``DivCtx`` attached for the backward pass.

    Raises:
        ValueError: If ``a`` and ``b`` use different drivers.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if a.driver != b.driver:
        raise ValueError("div: operands must share the same driver")
    driver = a.driver
    out_name = f"div_{a.name}_{b.name}"
    data = driver.div(a.name, b.name, out_name)
    out = Tensor(data, requires_grad=a.requires_grad or b.requires_grad, driver=driver, name=out_name)
    out.set_ctx(DivCtx(a, b, out))
    return out

class SumCtx:
    """Backward context for ``sum``: every summed element receives the upstream gradient."""

    def __init__(self, inp, out, axis=None, keepdims=False):
        self.inp = inp
        self.out = out
        self.axis = axis
        self.keepdims = keepdims

    def backward(self, grad_out):
        """Broadcast ``grad_out`` back to the input's shape and propagate it.

        Each input element enters the sum with weight 1, so its gradient is the
        upstream gradient of the output cell it was reduced into.
        NOTE(review): assumes ``grad_out`` is an array-like value rather than a
        tensor name when ``axis`` is given with ``keepdims=False``.
        """
        grad_shape = self.inp.data().shape
        if self.axis is not None and not self.keepdims:
            # The reduced axes were dropped from grad_out. Re-insert them as
            # size-1 dims so broadcasting aligns with the correct input axes;
            # plain right-aligned broadcasting would fail or mis-route values.
            grad_out = np.expand_dims(grad_out, self.axis)
        grad = self.inp.driver.broadcast_to(grad_out, grad_shape)
        if self.inp.requires_grad:
            self.inp.backward(grad)

def sum(inp, axis=None, keepdims=False):
    """Reduce ``inp`` by summation over ``axis`` with autograd support.

    Note that this module-level function shadows the builtin ``sum`` within
    this module.

    Args:
        inp: Input Tensor.
        axis: Axis or tuple of axes to reduce over; ``None`` reduces all axes.
        keepdims: Keep the reduced axes as size-1 dimensions.

    Returns:
        A Tensor holding the sums, with a ``SumCtx`` attached for backward.
    """
    driver = inp.driver
    # Results are stored in the driver by name, so the name must encode the
    # reduction: otherwise sum(x, axis=0) and sum(x, axis=1) would both be
    # "sum_x" and the later call would overwrite the earlier one's values.
    # The default reduction (axis=None, keepdims=False) keeps the plain name.
    out_name = f"sum_{inp.name}"
    if axis is not None:
        axes = axis if isinstance(axis, (tuple, list)) else (axis,)
        out_name += "_axis" + "_".join(str(ax) for ax in axes)
    if keepdims:
        out_name += "_keepdims"
    data = driver.sum(inp.name, out_name, axis=axis, keepdims=keepdims)
    out = Tensor(data, requires_grad=inp.requires_grad, driver=driver, name=out_name)
    out.set_ctx(SumCtx(inp, out, axis, keepdims))
    return out

class MeanCtx:
    """Backward context for ``mean``: spreads the upstream gradient evenly over the averaged elements."""

    def __init__(self, inp, out, axis=None, keepdims=False):
        self.inp = inp
        self.out = out
        self.axis = axis
        self.keepdims = keepdims

    def backward(self, grad_out):
        """Propagate ``grad_out / count`` to every input element.

        ``count`` is the number of input elements averaged into each output
        cell. For a full reduction it is the input's size. Otherwise it is the
        product of the reduced axes' lengths *in the input*; grad_out has
        already dropped or collapsed those axes, so its shape cannot be used.
        NOTE(review): assumes ``grad_out`` is an array-like value rather than a
        tensor name when ``axis`` is given with ``keepdims=False``.
        """
        grad_shape = self.inp.data().shape
        if self.axis is None:
            count = np.prod(grad_shape)
        else:
            axes = self.axis if isinstance(self.axis, (tuple, list)) else (self.axis,)
            count = np.prod([grad_shape[ax] for ax in axes])
            if not self.keepdims:
                # Re-insert the dropped axes as size-1 dims so broadcasting
                # lines grad_out up with the correct input axes.
                grad_out = np.expand_dims(grad_out, self.axis)
        grad = self.inp.driver.broadcast_to(grad_out, grad_shape)
        grad = self.inp.driver.div(grad, count)
        if self.inp.requires_grad:
            self.inp.backward(grad)

def mean(inp, axis=None, keepdims=False):
    """Average ``inp`` over ``axis`` with autograd support.

    Args:
        inp: Input Tensor.
        axis: Axis or tuple of axes to reduce over; ``None`` reduces all axes.
        keepdims: Keep the reduced axes as size-1 dimensions.

    Returns:
        A Tensor holding the means, with a ``MeanCtx`` attached for backward.
    """
    driver = inp.driver
    # Results are stored in the driver by name, so the name must encode the
    # reduction: otherwise mean(x, axis=0) and mean(x, axis=1) would both be
    # "mean_x" and the later call would overwrite the earlier one's values.
    # The default reduction (axis=None, keepdims=False) keeps the plain name.
    out_name = f"mean_{inp.name}"
    if axis is not None:
        axes = axis if isinstance(axis, (tuple, list)) else (axis,)
        out_name += "_axis" + "_".join(str(ax) for ax in axes)
    if keepdims:
        out_name += "_keepdims"
    data = driver.mean(inp.name, out_name, axis=axis, keepdims=keepdims)
    out = Tensor(data, requires_grad=inp.requires_grad, driver=driver, name=out_name)
    out.set_ctx(MeanCtx(inp, out, axis, keepdims))
    return out

# --- Softmax (forward only, backward via cross-entropy) ---
def softmax(inp, axis=-1):
    """Normalise ``inp`` into probabilities along ``axis`` (forward pass only).

    Unlike the other ops here, no autograd context is attached to the result.
    The design relies on a cross-entropy loss to supply the gradient.

    NOTE(review): the output name ``softmax_<inp.name>`` does not encode
    ``axis``, so softmaxes of one input along different axes share a single
    driver entry. Confirm this never happens.
    """
    backend = inp.driver
    probs_name = f"softmax_{inp.name}"
    probs = backend.softmax(inp.name, probs_name, axis=axis)
    # Backward handled via cross-entropy loss, hence no set_ctx call.
    return Tensor(probs, requires_grad=inp.requires_grad, driver=backend, name=probs_name)
import uuid

import numpy as np

from .compute_graph import get_compute_graph, NoGrad

class Tensor:
    """
    Enhanced autograd tensor with dynamic computational graph support.
    All data/grad storage and ops routed through driver (e.g., SQLiteMemoryManager).

    A Tensor is a handle: its values are stored in ``driver`` under ``name``.
    When ``requires_grad`` is true, its gradient is stored under
    ``name + "_grad"``.
    """

    def __init__(self, data, requires_grad=False, driver=None, name=None):
        """
        Args:
            data: Initial values. Written only if ``driver`` has no tensor
                called ``name`` yet; otherwise the stored values are kept and
                ``data`` is ignored.
                NOTE(review): possibly intentional, so that persisted tensors
                are reused. Confirm.
            requires_grad: Track a gradient for this tensor, allocating a
                zero-filled gradient entry if none exists.
            driver: Storage/compute backend. Although it defaults to ``None``,
                construction dereferences it, so a driver must be supplied.
            name: Storage key; a unique one is generated when omitted.
        """
        self.driver = driver
        # Random rather than id(self)-based: CPython reuses object ids after
        # garbage collection, while the driver keeps old tensors stored under
        # their names. An id-based default could match a dead tensor's entry,
        # and the existence check below would then silently keep that stale
        # data (e.g. across repeated detach() calls).
        self.name = name or f"tensor_{uuid.uuid4().hex}"
        if not driver.tensor_exists(self.name):
            driver.create_tensor(self.name, data)
        self.requires_grad = requires_grad
        self.grad_name = self.name + "_grad" if requires_grad else None
        if requires_grad and not driver.tensor_exists(self.grad_name):
            driver.create_tensor(self.grad_name, np.zeros_like(data))
        self._ctx = None  # For autograd graph: set by the op that produced this tensor
        self.graph = get_compute_graph()

    @property
    def shape(self):
        """Shape of the stored values (fetched from the driver on each access)."""
        return self.data().shape

    def data(self):
        """Return the values currently stored under ``self.name``."""
        return self.driver.get_tensor(self.name)

    def grad(self):
        """Return the stored gradient, or ``None`` if gradients are not tracked."""
        if self.requires_grad:
            return self.driver.get_tensor(self.grad_name)
        return None

    def zero_grad(self):
        """Reset the stored gradient to zeros shaped like the current data.

        Does nothing when gradients are not tracked.
        """
        if self.requires_grad:
            self.driver.set_tensor(self.grad_name, np.zeros_like(self.data()))

    def set_ctx(self, ctx):
        """Attach the autograd context created by the op that produced this tensor."""
        self._ctx = ctx

    def set_grad(self, grad):
        """Set gradient for this tensor (ignored when ``requires_grad`` is false).

        The value is written with ``driver.set_tensor``. This method does not
        add it to any previously stored gradient.
        """
        if self.requires_grad:
            self.driver.set_tensor(self.grad_name, grad)

    @staticmethod
    def no_grad():
        """Context manager to disable gradient tracking (delegates to the compute graph)."""
        return get_compute_graph().no_grad()

    def detach(self):
        """Create a new tensor detached from the computation graph.

        The copy gets a fresh auto-generated name and ``requires_grad=False``.
        """
        return Tensor(self.data(), requires_grad=False, driver=self.driver)

    def backward(self, grad=None, retain_graph: bool = False):
        """Execute backward pass through the computational graph.

        Args:
            grad: Upstream gradient for this tensor. Defaults to ones shaped
                like the data, the usual seed for a scalar loss.
            retain_graph: Forwarded to the compute graph's ``backward``.

        Does nothing if this tensor does not require grad.
        """
        if not self.requires_grad:
            return

        if grad is None:
            grad = np.ones_like(self.data())

        self.set_grad(grad)
        self.graph.backward(self, retain_graph)


# --- Example operation with autograd ---
class AddCtx:
    """Backward context for ``add``: both partial derivatives are 1."""

    def __init__(self, a, b, out):
        self.a = a
        self.b = b
        self.out = out

    def backward(self, grad_out):
        """Pass ``grad_out`` unchanged to each operand that requires grad (``a`` first)."""
        for operand in (self.a, self.b):
            if operand.requires_grad:
                operand.backward(grad_out)

def add(a, b):
    """Element-wise sum ``a + b`` recorded for autograd.

    Args:
        a, b: Tensors backed by the same driver.

    Returns:
        Tensor ``add_<a.name>_<b.name>`` that requires grad when either input
        does, with an ``AddCtx`` attached for the backward pass.

    Raises:
        ValueError: If ``a`` and ``b`` use different drivers.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if a.driver != b.driver:
        raise ValueError("add: operands must share the same driver")
    driver = a.driver
    out_name = f"add_{a.name}_{b.name}"
    data = driver.add(a.name, b.name, out_name)
    out = Tensor(data, requires_grad=a.requires_grad or b.requires_grad, driver=driver, name=out_name)
    out.set_ctx(AddCtx(a, b, out))
    return out

# --- MatMul operation with autograd ---
class MatMulCtx:
    """Backward context for ``matmul`` (C = A @ B)."""

    def __init__(self, a, b, out):
        self.a = a
        self.b = b
        self.out = out

    def backward(self, grad_out):
        """Propagate dL/dA = dL/dC @ B^T and dL/dB = A^T @ dL/dC."""
        lhs, rhs = self.a, self.b
        if lhs.requires_grad:
            lhs.backward(lhs.driver.matmul(grad_out, rhs.name, transpose_b=True))
        if rhs.requires_grad:
            rhs.backward(rhs.driver.matmul(lhs.name, grad_out, transpose_a=True))

def matmul(a, b):
    """Matrix product ``a @ b`` recorded for autograd.

    Args:
        a, b: Tensors backed by the same driver.

    Returns:
        Tensor ``matmul_<a.name>_<b.name>`` that requires grad when either
        input does, with a ``MatMulCtx`` attached for the backward pass.

    Raises:
        ValueError: If ``a`` and ``b`` use different drivers.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if a.driver != b.driver:
        raise ValueError("matmul: operands must share the same driver")
    driver = a.driver
    out_name = f"matmul_{a.name}_{b.name}"
    data = driver.matmul(a.name, b.name, out_name)
    out = Tensor(data, requires_grad=a.requires_grad or b.requires_grad, driver=driver, name=out_name)
    out.set_ctx(MatMulCtx(a, b, out))
    return out

# --- ReLU operation with autograd ---
class ReLUCtx:
    """Backward context for ``relu``: gates the upstream gradient by the input's sign mask."""

    def __init__(self, inp, out):
        self.inp, self.out = inp, out

    def backward(self, grad_out):
        """Multiply ``grad_out`` by the driver's ReLU derivative mask of the input."""
        ops = self.inp.driver
        gated = ops.mul(grad_out, ops.relu_grad(self.inp.name))
        if self.inp.requires_grad:
            self.inp.backward(gated)

def relu(inp):
    """Apply the driver's ReLU element-wise and record a ReLUCtx for backward.

    The result is stored by the driver under ``relu_<inp.name>`` and requires
    grad exactly when ``inp`` does.
    """
    backend = inp.driver
    target = f"relu_{inp.name}"
    activated = Tensor(
        backend.relu(inp.name, target),
        requires_grad=inp.requires_grad,
        driver=backend,
        name=target,
    )
    activated.set_ctx(ReLUCtx(inp, activated))
    return activated

# --- Integration with Optimizers ---
def step_optimizer(optimizer, params):
    """
    Call ``optimizer.step`` once per trainable parameter, passing its gradient name.

    params: iterable of Tensor objects; those without ``requires_grad`` are skipped.
    optimizer: SGD, Adam, etc. (from loss_and_optim.py); must expose ``step(grad_name)``.
    """
    trainable = (param for param in params if param.requires_grad)
    for param in trainable:
        optimizer.step(param.grad_name)
