import numpy as np
import ctypes
from typing import Dict, List, Set, Tuple, Optional, Any, Union, TYPE_CHECKING
from collections import defaultdict

if TYPE_CHECKING:
    from .autograd import Tensor

class ComputationalNode:
    """One recorded operation in the autograd graph.

    Input tensor names and the output tensor name are written to ``driver``
    storage under ``{node_id}_inputs`` / ``{node_id}_output``. Values saved for
    the backward pass are written under ``{node_id}_saved_{key}``.
    """

    def __init__(self, op_name: str, inputs: List['Tensor'], output: 'Tensor', driver):
        self.op_name = op_name
        self.node_id = f"node_{id(self)}"
        self.driver = driver

        # Store input and output tensor names in driver storage.
        input_names = [inp.name for inp in inputs]
        self.driver.create_tensor(f"{self.node_id}_inputs", np.array(input_names))
        self.driver.create_tensor(f"{self.node_id}_output", np.array(output.name))

        # Backward callable; assigned by the graph that records this node.
        self.gradient_fn = None
        # key -> driver name of every value saved via save_for_backward.
        self.saved_tensor_names: Dict[str, str] = {}
        # Non-ndarray values (shapes, axes, scalars, None) exactly as passed in.
        # Round-tripping them through np.array turns e.g. ``None`` into an object
        # array that numpy reductions reject as an ``axis`` argument.
        self._saved_values: Dict[str, Any] = {}

    def save_for_backward(self, **tensors: Any) -> None:
        """Save values needed by the backward pass.

        ndarrays are stored in driver storage. Any other value is also written
        to the driver (as ``np.array(value)``) but is returned unchanged by
        :meth:`get_saved_tensor`.
        """
        for key, value in tensors.items():
            tensor_name = f"{self.node_id}_saved_{key}"
            if isinstance(value, np.ndarray):
                self.driver.create_tensor(tensor_name, value)
                # An array now replaces any earlier non-array value for this key.
                self._saved_values.pop(key, None)
            else:
                self.driver.create_tensor(tensor_name, np.array(value))
                self._saved_values[key] = value
            self.saved_tensor_names[key] = tensor_name

    def get_saved_tensor(self, key: str) -> Any:
        """Return the value saved under ``key``.

        Raises:
            KeyError: if nothing was saved under ``key``.
        """
        if key in self._saved_values:
            return self._saved_values[key]
        return self.driver.get_tensor(self.saved_tensor_names[key])

    def get_inputs(self):
        """Return the array of input tensor names from driver storage."""
        return self.driver.get_tensor(f"{self.node_id}_inputs")

    def get_output(self):
        """Return the output tensor name (as a Python scalar) from driver storage."""
        return self.driver.get_tensor(f"{self.node_id}_output").item()

class _InMemoryDriver:
    """Minimal dict-backed tensor store used when ``ComputeGraph`` gets no driver.

    Provides only the driver methods this module calls: ``create_tensor``
    (create or overwrite), ``get_tensor`` and ``tensor_exists``. Values are
    stored as given, without copying.
    """

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    def create_tensor(self, name: str, value: Any) -> None:
        self._store[name] = value

    def get_tensor(self, name: str) -> Any:
        return self._store[name]

    def tensor_exists(self, name: str) -> bool:
        return name in self._store


class ComputeGraph:
    """Records operations on tensors and runs reverse-mode backward passes.

    Node ids, registered operation keys and gradients are written to
    ``driver`` storage. Python objects the driver cannot hold (node objects,
    callables, tensors) are kept in Python-side registries so they stay alive
    and can be looked up again in :meth:`backward`.
    """

    def __init__(self, driver=None):
        # ``ComputeGraph()`` with no driver must still work (the module-level
        # GLOBAL_GRAPH is built that way); otherwise ``None.create_tensor`` fails.
        self.driver = driver if driver is not None else _InMemoryDriver()
        self.graph_id = f"graph_{id(self)}"
        # Node ids of recorded operations, in recording order.
        self.driver.create_tensor(f"{self.graph_id}_nodes", np.array([]))
        # Driver keys of registered operations.
        self.driver.create_tensor(f"{self.graph_id}_grad_fns", np.array([]))
        self.requires_grad = set()  # Small enough to keep in Python
        self.is_training = True
        # Strong references to registered callables, keyed by op name. The driver
        # only receives their id(); an id alone does not keep a function alive,
        # so reviving it from the id could touch freed memory.
        self._forward_fns: Dict[str, Any] = {}
        self._backward_fns: Dict[str, Any] = {}
        # Recorded nodes and the tensors they touch, so ids/names read back from
        # the driver can be mapped to live objects.
        self._nodes: Dict[str, 'ComputationalNode'] = {}
        self._tensors: Dict[str, 'Tensor'] = {}

    def register_operation(self, op_name: str, forward_fn: Any, backward_fn: Any):
        """Register (or re-register) an operation's forward and backward callables."""
        fn_name = f"{self.graph_id}_fn_{op_name}"
        self._forward_fns[op_name] = forward_fn
        self._backward_fns[op_name] = backward_fn
        self.driver.create_tensor(fn_name, np.array([id(forward_fn), id(backward_fn)]))
        # Update the op list, without duplicating an op that is re-registered.
        ops_key = f"{self.graph_id}_grad_fns"
        ops = list(self.driver.get_tensor(ops_key))
        if fn_name not in ops:
            ops.append(fn_name)
            self.driver.create_tensor(ops_key, np.array(ops))

    def track_operation(self, op_name: str, inputs: List['Tensor'],
                        output: 'Tensor') -> Optional['ComputationalNode']:
        """Record an operation in the computational graph.

        Nothing is recorded (and ``None`` is returned) while not training or when
        no input requires grad. Otherwise the new node is returned so the caller
        can ``save_for_backward`` on it.

        Raises:
            KeyError: if ``op_name`` was never registered.
        """
        if not self.is_training:
            return None
        if not any(inp.requires_grad for inp in inputs):
            return None

        try:
            backward_fn = self._backward_fns[op_name]
        except KeyError:
            raise KeyError(f"operation {op_name!r} has not been registered") from None

        node = ComputationalNode(op_name, inputs, output, self.driver)
        node.gradient_fn = backward_fn
        self._nodes[node.node_id] = node
        for tensor in (*inputs, output):
            self._tensors[tensor.name] = tensor

        # Add node to graph.
        nodes_key = f"{self.graph_id}_nodes"
        nodes = list(self.driver.get_tensor(nodes_key))
        nodes.append(node.node_id)
        self.driver.create_tensor(nodes_key, np.array(nodes))
        return node

    def get_node_by_id(self, node_id) -> 'ComputationalNode':
        """Return the recorded node with id ``node_id`` (KeyError if unknown)."""
        return self._nodes[str(node_id)]

    def get_tensor_by_name(self, name) -> 'Tensor':
        """Return a tensor seen by ``track_operation`` by name (KeyError if unknown)."""
        return self._tensors[str(name)]

    def _topological_order(self, node_ids: List[str], loss_name: str) -> List[str]:
        """Return ids of nodes the loss depends on, each listed after its producers."""
        producers: Dict[str, List[str]] = defaultdict(list)
        for nid in node_ids:
            producers[self.get_node_by_id(nid).get_output()].append(nid)

        visited: Set[str] = set()
        order: List[str] = []
        # Iterative post-order DFS so deep graphs do not hit the recursion limit.
        stack: List[Tuple[str, bool]] = [
            (nid, False) for nid in producers.get(loss_name, ())
        ]
        while stack:
            nid, finished = stack.pop()
            if finished:
                order.append(nid)
                continue
            if nid in visited:
                continue
            visited.add(nid)
            stack.append((nid, True))
            for input_name in self.get_node_by_id(nid).get_inputs():
                # Only differentiable inputs need their producers processed.
                if self.get_tensor_by_name(input_name).requires_grad:
                    stack.extend((pid, False) for pid in producers.get(str(input_name), ()))
        return order

    def backward(self, loss_tensor: 'Tensor', retain_graph: bool = False):
        """Back-propagate from ``loss_tensor`` through the recorded operations.

        The loss is seeded with ``np.ones_like(loss_tensor.data())``. Each node's
        ``gradient_fn`` is called as ``gradient_fn(grad_output, **saved)``. It
        returns one gradient per input; a non-tuple is treated as a 1-tuple, and
        ``None`` entries are skipped. Gradients reaching the same tensor are
        summed. Each input tensor that requires grad and received a gradient in
        this pass gets it once via ``set_grad``. Unless ``retain_graph`` is
        true, the graph is cleared afterwards.
        """
        node_ids = [str(nid) for nid in self.driver.get_tensor(f"{self.graph_id}_nodes")]
        if not node_ids:
            return

        grad_id = f"{self.graph_id}_grads"
        # Gradient keys written during *this* call. Gradients from earlier
        # backward() calls remain in the driver, so ``tensor_exists`` would
        # wrongly fold them into this pass.
        written: Set[str] = set()

        def accumulate(tensor_name, grad) -> None:
            key = f"{grad_id}_{tensor_name}"
            if key in written:
                grad = self.driver.get_tensor(key) + grad
            self.driver.create_tensor(key, grad)
            written.add(key)

        loss_key = f"{grad_id}_{loss_tensor.name}"
        self.driver.create_tensor(loss_key, np.ones_like(loss_tensor.data()))
        written.add(loss_key)

        topo_order = self._topological_order(node_ids, loss_tensor.name)
        # Keep the traversal in driver storage under the _visited / _topo keys.
        self.driver.create_tensor(f"{self.graph_id}_visited", np.array(topo_order))
        self.driver.create_tensor(f"{self.graph_id}_topo", np.array(topo_order))

        # Consumers come later in topo_order, so walking it in reverse guarantees a
        # node's output gradient is complete before its gradient_fn runs.
        for nid in reversed(topo_order):
            node = self.get_node_by_id(nid)
            grad_output = self.driver.get_tensor(f"{grad_id}_{node.get_output()}")
            saved = {key: node.get_saved_tensor(key) for key in node.saved_tensor_names}
            grad_inputs = node.gradient_fn(grad_output, **saved)
            if not isinstance(grad_inputs, tuple):
                grad_inputs = (grad_inputs,)
            for grad_input, input_name in zip(grad_inputs, node.get_inputs()):
                if grad_input is None:
                    continue
                if self.get_tensor_by_name(input_name).requires_grad:
                    accumulate(input_name, grad_input)

        # Publish gradients once per tensor; a tensor consumed by several nodes
        # would otherwise receive set_grad repeatedly.
        published: Set[str] = set()
        for nid in node_ids:
            for input_name in self.get_node_by_id(nid).get_inputs():
                name = str(input_name)
                key = f"{grad_id}_{name}"
                if name in published or key not in written:
                    continue
                input_tensor = self.get_tensor_by_name(name)
                if input_tensor.requires_grad:
                    input_tensor.set_grad(self.driver.get_tensor(key))
                    published.add(name)

        if not retain_graph:
            self.clear()

    def clear(self):
        """Drop all recorded nodes and tensor references (registered ops are kept)."""
        self.driver.create_tensor(f"{self.graph_id}_nodes", np.array([]))
        self._nodes.clear()
        self._tensors.clear()

    def no_grad(self):
        """Context manager to disable gradient computation"""
        return NoGrad(self)
        
class NoGrad:
    """Context manager that turns off ``graph.is_training`` for its body.

    The previous ``is_training`` value is restored on exit, including when the
    body raises; exceptions are not suppressed. Instances nest correctly
    because each one remembers its own previous state.
    """

    def __init__(self, graph: 'ComputeGraph'):
        self.graph = graph
        self.prev_state = None

    def __enter__(self) -> 'NoGrad':
        self.prev_state = self.graph.is_training
        self.graph.is_training = False
        # Return self so ``with graph.no_grad() as ctx`` binds something useful.
        return self

    def __exit__(self, *args) -> bool:
        self.graph.is_training = self.prev_state
        # Falsy return: let any exception from the body propagate.
        return False

# Process-wide default graph, handed out by get_compute_graph().
GLOBAL_GRAPH: ComputeGraph = ComputeGraph()


def get_compute_graph() -> ComputeGraph:
    """Return the module-level default ``ComputeGraph`` instance."""
    default_graph = GLOBAL_GRAPH
    return default_graph
