import ast
import hashlib
import inspect
import textwrap
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple, Optional, Any, Union, Callable

import numpy as np

@dataclass
class JITOp:
    """Represents a single operation in the JIT compilation.

    Attributes:
        op_type: Operation name. At execution time JITFunction calls the driver
            method with this name; this includes fused op types such as
            ``"fused_add_relu"`` and ``"fused_matmul_add"``.
        inputs: Names of the tensors the op reads, in the order their values
            are passed to the driver.
        outputs: Names of the tensors the op produces. Execution binds only
            ``outputs[0]`` to the driver's return value.
        attributes: Free-form per-op metadata. JITFunction's memory pass
            records a ``"memory_slots"`` mapping (tensor name -> slot name)
            under this dict.
    """
    op_type: str
    inputs: List[str]
    outputs: List[str]
    attributes: Dict[str, Any]
    
@dataclass
class JITTrace:
    """Complete trace of operations for JIT compilation.

    Attributes:
        ops: Operations in the order they were recorded. JITFunction optimizes
            this list and then executes the result sequentially.
        input_shapes: Graph input name -> shape observed on the example inputs
            at trace time. JITFunction.__call__ compares call-time input
            shapes against these.
        output_shapes: Graph output name -> shape. JITFunction.__call__ only
            reads the keys, to decide which tensors to return.
            NOTE(review): JITCompiler._trace_function in this file leaves this
            empty; confirm whether another producer fills it in.
    """
    ops: List[JITOp]
    input_shapes: Dict[str, Tuple[int, ...]]
    output_shapes: Dict[str, Tuple[int, ...]]
    
class JITFunction:
    """Compiled function with optimized execution path.

    On construction the trace's op list is rewritten by three passes (fusion,
    memory planning, order-preserving reordering). Calling the object executes
    the rewritten ops against ``driver``. The trace itself is not modified.
    """

    def __init__(self, trace: JITTrace, driver):
        self.trace = trace
        self.driver = driver
        self.optimized_ops = self._optimize_trace()

    def _optimize_trace(self) -> List[JITOp]:
        """Apply the optimization passes to a copy of the trace's op list."""
        ops = list(self.trace.ops)

        # Optimization 1: operator fusion.
        ops = self._fuse_operations(ops)

        # Optimization 2: memory reuse. Slot assignment is only valid for the
        # op order it was planned against, so any pass that moves ops must
        # run before this one.
        ops = self._optimize_memory(ops)

        # Optimization 3: operation reordering (order-preserving; see method).
        ops = self._reorder_operations(ops)

        return ops

    def _fuse_operations(self, ops: List[JITOp]) -> List[JITOp]:
        """Fuse compatible consecutive operations.

        Patterns (``t`` is the intermediate tensor):
            t = add(a, b);    y = relu(t)   ->  y = fused_add_relu(a, b)
            t = matmul(a, w); y = add(t, c) ->  y = fused_matmul_add(a, w, c)

        Fusion removes ``t`` from the graph. A pair is therefore fused only
        when all of the following hold:
          * ``t`` is the producer's sole output;
          * ``t`` is read exactly once in the whole trace, by the consumer as
            its first input;
          * ``t`` is not a graph output.
        Operand counts are also checked, so each fused op carries exactly the
        operands JITFunction.__call__ hands to the driver.

        Args:
            ops: Operations in execution order.

        Returns:
            A new list. Ops that are not fused are carried over unchanged.
        """
        # How many times each tensor name is read anywhere in the trace.
        read_counts: Dict[str, int] = {}
        for op in ops:
            for tensor in op.inputs:
                read_counts[tensor] = read_counts.get(tensor, 0) + 1
        graph_outputs = set(self.trace.output_shapes)

        def feeds_only(producer: JITOp, consumer: JITOp) -> bool:
            """True if producer's single output is used solely as consumer's first input."""
            if len(producer.outputs) != 1 or not consumer.inputs:
                return False
            tensor = producer.outputs[0]
            return (
                consumer.inputs[0] == tensor
                and read_counts.get(tensor, 0) == 1
                and tensor not in graph_outputs
            )

        fused_ops: List[JITOp] = []
        i = 0
        while i < len(ops):
            current_op = ops[i]
            next_op = ops[i + 1] if i + 1 < len(ops) else None

            if next_op is not None and feeds_only(current_op, next_op):
                merged_attributes = {**current_op.attributes, **next_op.attributes}

                # add(a, b) + relu(t) -> fused_add_relu(a, b)
                if (current_op.op_type == "add"
                        and next_op.op_type == "relu"
                        and len(current_op.inputs) == 2
                        and len(next_op.inputs) == 1):
                    fused_ops.append(JITOp(
                        op_type="fused_add_relu",
                        inputs=list(current_op.inputs),
                        outputs=list(next_op.outputs),
                        attributes=merged_attributes,
                    ))
                    i += 2
                    continue

                # matmul(a, w) + add(t, bias) -> fused_matmul_add(a, w, bias)
                if (current_op.op_type == "matmul"
                        and next_op.op_type == "add"
                        and len(current_op.inputs) == 2
                        and len(next_op.inputs) == 2):
                    fused_ops.append(JITOp(
                        op_type="fused_matmul_add",
                        inputs=current_op.inputs + [next_op.inputs[1]],
                        outputs=list(next_op.outputs),
                        attributes=merged_attributes,
                    ))
                    i += 2
                    continue

            fused_ops.append(current_op)
            i += 1

        return fused_ops

    def _optimize_memory(self, ops: List[JITOp]) -> List[JITOp]:
        """Optimize memory usage by reusing tensor storage.

        Walks the ops in order and gives every produced tensor a storage slot
        named ``"mem_slot_<k>"``. A slot returns to the free pool in two cases:
          * after the last op that reads or writes its tensor;
          * when the tensor's name is rebound by a later op, because the old
            value can no longer be referenced once that op finishes.
        Graph outputs (keys of ``trace.output_shapes``) keep their slot, since
        the caller reads them after the last op.

        Returns:
            New JITOp objects whose ``attributes`` also hold ``"memory_slots"``.
            This maps each of the op's input/output tensor names to its slot;
            names without a slot, such as graph inputs, map to themselves. The
            incoming ops, which may be shared with ``self.trace``, are left
            unmodified.
        """
        # Index of the last op that mentions each tensor name.
        last_use: Dict[str, int] = {}
        for idx, op in enumerate(ops):
            for tensor in op.inputs + op.outputs:
                last_use[tensor] = idx

        pinned = set(self.trace.output_shapes)
        live_slots: Dict[str, str] = {}  # tensor name -> slot currently holding it
        free_slots: List[str] = []       # LIFO list: deterministic, unlike set.pop()
        slots_created = 0
        planned: List[JITOp] = []

        for idx, op in enumerate(ops):
            released: List[str] = []

            for output in op.outputs:
                if output in live_slots:
                    # Rebinding a name: its previous value dies with this op.
                    released.append(live_slots[output])
                if free_slots:
                    slot = free_slots.pop()
                else:
                    # A dedicated counter keeps new slot names unique even
                    # when the number of live tensors shrinks.
                    slot = f"mem_slot_{slots_created}"
                    slots_created += 1
                live_slots[output] = slot

            planned.append(JITOp(
                op_type=op.op_type,
                inputs=list(op.inputs),
                outputs=list(op.outputs),
                attributes={
                    **op.attributes,
                    "memory_slots": {
                        tensor: live_slots.get(tensor, tensor)
                        for tensor in op.inputs + op.outputs
                    },
                },
            ))

            # Storage of tensors that no later op touches becomes reusable from
            # the next op on. It is released after allocation, so an op never
            # writes over one of its own inputs.
            for tensor in op.inputs + op.outputs:
                if (last_use[tensor] == idx
                        and tensor not in pinned
                        and tensor in live_slots):
                    released.append(live_slots.pop(tensor))

            free_slots.extend(released)

        return planned

    def _reorder_operations(self, ops: List[JITOp]) -> List[JITOp]:
        """Reorder operations for better parallelism (currently order-preserving).

        Returns the ops in their given order. Order must be preserved here:
        this pass runs after _optimize_memory, and the slot reuse it planned
        is only safe for that exact order. A real dependency-based reordering
        would have to run before memory planning.
        """
        return list(ops)

    def __call__(self, **inputs) -> Dict[str, str]:
        """Execute the optimized operation sequence.

        Args:
            **inputs: Graph input name -> driver tensor name.

        Returns:
            Graph output name -> value returned by the driver, for every key of
            ``trace.output_shapes``.

        Raises:
            ValueError: An input's shape differs from the traced shape.
            KeyError: An input name is unknown to the trace, or an op reads a
                tensor that was neither supplied nor produced earlier.
        """
        # Validate input shapes against those recorded at trace time.
        for name, tensor_name in inputs.items():
            shape = self.driver.get_tensor(tensor_name).shape
            expected = self.trace.input_shapes[name]
            if shape != expected:
                raise ValueError(f"Input '{name}' shape mismatch: "
                                 f"expected {expected}, "
                                 f"got {shape}")

        # Tensor name -> value currently bound to it.
        tensor_map: Dict[str, str] = dict(inputs)

        for op in self.optimized_ops:
            # Slot planned for this op's result (None if no plan exists).
            out_slot = op.attributes.get("memory_slots", {}).get(op.outputs[0])

            if op.op_type == "fused_add_relu":
                result = self.driver.fused_add_relu(
                    tensor_map[op.inputs[0]],
                    tensor_map[op.inputs[1]],
                    out_slot,
                )
            elif op.op_type == "fused_matmul_add":
                result = self.driver.fused_matmul_add(
                    tensor_map[op.inputs[0]],
                    tensor_map[op.inputs[1]],
                    tensor_map[op.inputs[2]],  # bias
                    out_slot,
                )
            else:
                # Standard op: driver method named after the op type.
                result = getattr(self.driver, op.op_type)(
                    *[tensor_map[inp] for inp in op.inputs],
                    out_slot,
                )
            tensor_map[op.outputs[0]] = result

        return {name: tensor_map[name] for name in self.trace.output_shapes}

class JITCompiler:
    """JIT compiler for computational graphs.

    Turns a Python function into a JITFunction by statically tracing the
    method calls in its source. Results are cached per (source text, example
    input shapes).
    """

    def __init__(self, driver):
        self.driver = driver
        # Cache key (see _get_cache_key) -> compiled function.
        self.compiled_cache: Dict[str, JITFunction] = {}

    def _trace_function(self, func: Callable, example_inputs: Dict[str, str]) -> JITTrace:
        """Trace a function's source into a linear operation graph.

        Every call of the form ``<expr>.<method>(...)`` becomes one JITOp with
        ``op_type == <method>``. Calls are recorded in Python evaluation order:
        receiver, positional args, keyword args, then the call itself. A nested
        call is therefore emitted before the call that consumes it.

        Inputs of an op are its positional arguments that are plain names or
        nested traced calls; a nested call is referenced by its result name.
        Other arguments, such as constants and keyword arguments, are not
        recorded.

        The output of an op is named as follows:
          * when the call is the whole right-hand side of a single-name
            assignment (``y = drv.add(a, b)``), it takes the variable name;
          * otherwise it gets a fresh ``"%tmp<k>"`` name, which cannot clash
            with any Python identifier.

        When the source is a single function definition, only its body is
        walked. Decorators and default-argument expressions are therefore not
        traced.

        Args:
            func: Function to trace; ``inspect.getsource`` must find its source.
            example_inputs: Graph input name -> driver tensor name. Used only
                to record input shapes.

        Returns:
            The recorded JITTrace. ``output_shapes`` is left empty because
            output shapes are not inferred statically.
            NOTE(review): JITFunction.__call__ therefore returns an empty dict
            for traces built here; TODO populate from the return statement.
        """
        # Methods and nested functions have indented source, which ast.parse
        # rejects unless the common indentation is removed first.
        source = textwrap.dedent(inspect.getsource(func))
        tree = ast.parse(source)

        statements: List[ast.stmt] = list(tree.body)
        if (len(statements) == 1
                and isinstance(statements[0], (ast.FunctionDef, ast.AsyncFunctionDef))):
            statements = statements[0].body

        ops: List[JITOp] = []

        def emit(node: ast.AST, target: Optional[str] = None) -> Optional[str]:
            """Record traced calls under ``node`` in evaluation order.

            Returns the tensor name carrying ``node``'s value when it is a
            plain name or a traced call, else None.
            """
            if isinstance(node, ast.Name):
                return node.id

            if isinstance(node, (ast.Assign, ast.AnnAssign)):
                targets = node.targets if isinstance(node, ast.Assign) else [node.target]
                bound_name = (
                    targets[0].id
                    if len(targets) == 1 and isinstance(targets[0], ast.Name)
                    else None
                )
                # The right-hand side is evaluated before the targets.
                if node.value is not None:
                    emit(node.value, bound_name)
                for tgt in targets:
                    emit(tgt)
                return None

            if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
                emit(node.func.value)  # the receiver may itself contain calls
                inputs: List[str] = []
                for arg in node.args:
                    name = emit(arg)
                    if name is not None:
                        inputs.append(name)
                for keyword in node.keywords:
                    emit(keyword.value)
                # len(ops) grows with every op, so temp names are unique.
                output = target if target is not None else f"%tmp{len(ops)}"
                ops.append(JITOp(node.func.attr, inputs, [output], {}))
                return output

            for child in ast.iter_child_nodes(node):
                emit(child)
            return None

        for statement in statements:
            emit(statement)

        input_shapes = {
            name: self.driver.get_tensor(tensor_name).shape
            for name, tensor_name in example_inputs.items()
        }
        output_shapes: Dict[str, Tuple[int, ...]] = {}

        return JITTrace(ops, input_shapes, output_shapes)

    def _get_cache_key(self, func: Callable, example_inputs: Dict[str, str]) -> str:
        """Generate cache key for compiled function.

        The key covers exactly what the trace depends on: the function's
        source text and the (input name, shape) pairs of the example inputs.
        MD5 is used only as a fingerprint here, not for security.
        """
        func_str = inspect.getsource(func)
        shapes_str = str(sorted([
            (name, self.driver.get_tensor(tensor_name).shape)
            for name, tensor_name in example_inputs.items()
        ]))
        key_str = f"{func_str}:{shapes_str}"
        return hashlib.md5(key_str.encode()).hexdigest()

    def compile(self, func: Callable, example_inputs: Dict[str, str]) -> JITFunction:
        """Compile a function for optimized execution, reusing cached results.

        Args:
            func: Function to trace.
            example_inputs: Graph input name -> driver tensor name; their
                shapes become part of the cache key and of the trace.

        Returns:
            The cached JITFunction for this (source, shapes) pair, compiling
            and caching it first if needed.
        """
        cache_key = self._get_cache_key(func, example_inputs)

        if cache_key in self.compiled_cache:
            return self.compiled_cache[cache_key]

        trace = self._trace_function(func, example_inputs)
        compiled_fn = JITFunction(trace, self.driver)
        self.compiled_cache[cache_key] = compiled_fn

        return compiled_fn
