import time

import numpy as np

from helium.broadcast import BroadcastModule
from helium.jit import JITCompiler

class _DummyDriver:
    """In-memory stand-in for a Helium driver, backed by NumPy arrays.

    Tensors are stored in ``self.tensors`` keyed by name. Every operation
    takes tensor *names*, computes eagerly with NumPy, stores the result and
    returns its name: the caller-supplied ``output`` name when given,
    otherwise a fresh ``"<op>_result_<n>"`` name.
    """

    def __init__(self):
        self.tensors = {}
        self.counter = 0  # suffix for auto-generated result names

    def create_tensor(self, name, data):
        self.tensors[name] = np.array(data)
        return name

    def get_tensor(self, name):
        return self.tensors[name]

    def _store(self, result, output, prefix):
        """Store *result* under *output*, or under a fresh auto-generated name."""
        if output is None:
            output = f"{prefix}_result_{self.counter}"
            self.counter += 1
        self.tensors[output] = result
        return output

    def matmul(self, a, b, output=None):
        result = np.matmul(self.get_tensor(a), self.get_tensor(b))
        return self._store(result, output, "matmul")

    def add(self, a, b, output=None):
        result = self.get_tensor(a) + self.get_tensor(b)
        return self._store(result, output, "add")

    def relu(self, x, output=None):
        result = np.maximum(0, self.get_tensor(x))
        return self._store(result, output, "relu")

    def fused_matmul_add(self, a, b, bias, output=None):
        """Optimized matmul + add operation"""
        # Out-of-place add, matching the unfused matmul -> add path: in-place
        # ``+=`` fails when an integer matmul result meets a float bias, or
        # when the bias broadcasts to a larger shape than the matmul result.
        result = np.matmul(self.get_tensor(a), self.get_tensor(b)) + self.get_tensor(bias)
        return self._store(result, output, "fused_matmul_add")

    def fused_add_relu(self, a, b, output=None):
        """Optimized add + relu operation"""
        result = np.maximum(0, self.get_tensor(a) + self.get_tensor(b))
        return self._store(result, output, "fused_add_relu")


def _timed(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time``, whose coarse resolution can report 0.0 for fast calls.
    """
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def _speedup(baseline, candidate):
    """Return ``baseline / candidate``, or ``inf`` if ``candidate`` measured as 0."""
    return baseline / candidate if candidate > 0 else float("inf")


def test_jit_compilation():
    """Test JIT compilation of a linear + ReLU layer against eager execution.

    Builds ``linear_relu`` on a NumPy-backed dummy driver, compiles it with
    ``JITCompiler`` and asserts that the compiled output matches the eager
    output for several input sizes. It also prints execution times and the
    resulting speedup.
    """
    # Create driver instance
    driver = _DummyDriver()

    # Define a function to be JIT compiled
    def linear_relu(x_name: str, weight_name: str, bias_name: str) -> str:
        """Function implementing linear layer with ReLU"""
        # Standard implementation
        matmul_result = driver.matmul(x_name, weight_name)
        bias_result = driver.add(matmul_result, bias_name)
        return driver.relu(bias_result)

    # Initialize JIT compiler
    compiler = JITCompiler(driver)

    def run_and_compare(x_name, weight_name, bias_name):
        """Compile ``linear_relu`` for these tensors, time eager vs. JIT runs and
        assert the outputs agree. Returns ``(standard_time, jit_time, max_diff)``.
        """
        inputs = {
            "x_name": x_name,
            "weight_name": weight_name,
            "bias_name": bias_name,
        }
        compiled_fn = compiler.compile(linear_relu, inputs)

        standard_result, standard_time = _timed(linear_relu, **inputs)
        # Snapshot the eager output now: both runs share one driver, so a JIT
        # run that reuses/overwrites tensors must not alter the reference.
        standard_output = driver.get_tensor(standard_result).copy()

        jit_outputs, jit_time = _timed(compiled_fn, **inputs)
        jit_output = driver.get_tensor(jit_outputs["output"])

        # Fail loudly on mismatch (also checks shapes) instead of only
        # printing the difference.
        np.testing.assert_allclose(jit_output, standard_output, rtol=1e-6, atol=1e-9)
        max_diff = np.max(np.abs(standard_output - jit_output))
        return standard_time, jit_time, max_diff

    # Create example inputs and store them in the driver
    batch_size = 32
    in_features = 64
    out_features = 128

    x_name = driver.create_tensor("x", np.random.randn(batch_size, in_features))
    weight_name = driver.create_tensor("weight", np.random.randn(in_features, out_features))
    bias_name = driver.create_tensor("bias", np.random.randn(out_features))

    print("Running standard vs JIT compiled versions...")
    standard_time, jit_time, max_diff = run_and_compare(x_name, weight_name, bias_name)

    print("\nResults:")
    print(f"Standard execution time: {standard_time:.6f} seconds")
    print(f"JIT execution time: {jit_time:.6f} seconds")
    print(f"Speedup: {_speedup(standard_time, jit_time):.2f}x")
    print(f"Max difference in outputs: {max_diff}")

    # NOTE(review): the lines below describe what the compiler is expected to
    # do; this test does not verify that fusion/reuse/reordering happened.
    print("\nOptimizations applied:")
    print("1. Operation fusion:")
    print("   - Matmul + Add -> fused_matmul_add")
    print("   - Add + ReLU -> fused_add_relu")
    print("2. Memory reuse:")
    print("   - Intermediate tensors reuse memory slots")
    print("3. Operation reordering:")
    print("   - Independent operations can run in parallel")

    # Test with different input sizes
    print("\nTesting with different input sizes...")

    sizes = [(16, 32, 64), (64, 128, 256), (128, 256, 512)]

    for batch, in_dim, out_dim in sizes:
        print(f"\nInput size: batch={batch}, in_features={in_dim}, out_features={out_dim}")

        # Key tensors by the full shape so entries sharing a batch size
        # cannot overwrite each other in the driver.
        suffix = f"{batch}x{in_dim}x{out_dim}"
        x_name = driver.create_tensor(f"x_{suffix}", np.random.randn(batch, in_dim))
        weight_name = driver.create_tensor(f"weight_{suffix}", np.random.randn(in_dim, out_dim))
        bias_name = driver.create_tensor(f"bias_{suffix}", np.random.randn(out_dim))

        standard_time, jit_time, _ = run_and_compare(x_name, weight_name, bias_name)

        print(f"Standard time: {standard_time:.6f} seconds")
        print(f"JIT time: {jit_time:.6f} seconds")
        print(f"Speedup: {_speedup(standard_time, jit_time):.2f}x")

if __name__ == "__main__":
    # Allow running the JIT comparison/benchmark directly as a script,
    # in addition to collection by a test runner.
    test_jit_compilation()
