from typing import Dict, Optional
import numpy as np

class GPULossFunction:
    """Hands out and releases temporary, uniquely named tensors on a driver.

    Names have the form ``temp_loss_<n>``, where ``n`` is a per-instance
    counter that starts at zero and grows by one for every allocation.
    """

    def __init__(self, driver):
        self.counter = 0
        self.driver = driver

    def get_temp_tensor(self, data):
        """Create a driver tensor holding ``data`` and return its name."""
        handle = f"temp_loss_{self.counter}"
        # The counter advances before the driver call, so a failed creation
        # still consumes the name.
        self.counter = self.counter + 1
        self.driver.create_tensor(handle, data)
        return handle

    def free_temp_tensor(self, name):
        """Delete the tensor called ``name``; do nothing if it is not there."""
        if not self.driver.tensor_exists(name):
            return
        self.driver.delete_tensor(name)

def cross_entropy_loss(logits, targets, driver=None):
    """
    Mean softmax cross-entropy between raw scores and integer class labels.

    logits: (batch, num_classes) - raw model outputs
    targets: (batch,) - integer class labels (array or plain sequence in the
        NumPy fallback; passed unchanged to ``driver.gather`` otherwise)
    driver: optional backend object; when None the loss is computed with NumPy
    Returns: scalar loss (mean over batch)

    The row-wise maximum is subtracted before exponentiating and 1e-9 is added
    to the probabilities before taking the log, so log(0) is never evaluated.
    With a driver, intermediate results are stored as temporary driver tensors
    and are released before returning, including when a driver call raises.
    """
    if driver is None:
        # Fallback to numpy for no driver
        targets = np.asarray(targets)  # lets callers pass a plain list of labels
        logits_max = np.max(logits, axis=1, keepdims=True)
        exp_logits = np.exp(logits - logits_max)
        probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
        log_probs = np.log(probs + 1e-9)
        nll = -log_probs[np.arange(targets.shape[0]), targets]
        return np.mean(nll)

    loss_fn = GPULossFunction(driver)
    temp_names = []

    def _temp(data):
        # Register every temporary so the finally-clause can release it.
        name = loss_fn.get_temp_tensor(data)
        temp_names.append(name)
        return name

    try:
        logits_max = driver.max(logits, axis=1, keepdims=True)
        # NOTE(review): this subtraction uses Python's `-`, so `logits` and the
        # value returned by driver.max must support it - confirm for the driver.
        shifted = _temp(logits - logits_max)

        exp = _temp(driver.exp(shifted))
        sum_exp = driver.sum(exp, axis=1, keepdims=True)
        probs = _temp(driver.div(exp, sum_exp))

        log_probs = _temp(driver.log(driver.add(probs, 1e-9)))
        gathered = driver.gather(log_probs, targets)
        return -driver.mean(gathered)
    finally:
        # Cleanup runs on success and on failure, in creation order.
        for name in temp_names:
            loss_fn.free_temp_tensor(name)

def mse_loss(pred, target, driver=None):
    """Mean of the squared difference between ``pred`` and ``target``.

    With a driver, the difference, the square and the mean are delegated to
    ``driver.sub``, ``driver.mul`` and ``driver.mean``; otherwise NumPy is used.
    """
    if driver is not None:
        residual = driver.sub(pred, target)
        return driver.mean(driver.mul(residual, residual))

    residual = pred - target
    return np.mean(residual ** 2)

class OptimizerState:
    """Common bookkeeping for optimizers that keep their state on a driver.

    ``state`` maps a parameter name to a dict of ``{state key: tensor name}``;
    the tensor contents themselves are held by the driver.
    """

    def __init__(self, param_table: Dict[str, str], driver, prefix: str):
        self.state: Dict[str, Dict[str, str]] = {}
        self.prefix = prefix
        self.driver = driver
        self.param_table = param_table

    def create_state(self, param_name: str, state_dict: Dict[str, np.ndarray]):
        """Create one driver tensor per entry of ``state_dict`` and record it.

        Tensor names are ``<prefix>_<param_name>_<key>``. The mapping for
        ``param_name`` is registered only after every tensor was created.
        """
        tensor_names = {
            key: f"{self.prefix}_{param_name}_{key}" for key in state_dict
        }
        for key, value in state_dict.items():
            self.driver.create_tensor(tensor_names[key], value)
        self.state[param_name] = tensor_names

    def get_state(self, param_name: str, key: str) -> np.ndarray:
        """Fetch the state tensor ``key`` of ``param_name`` from the driver."""
        tensor_name = self.state[param_name][key]
        return self.driver.get_tensor(tensor_name)

    def update_state(self, param_name: str, key: str, value: np.ndarray):
        """Overwrite the state tensor ``key`` of ``param_name`` with ``value``."""
        tensor_name = self.state[param_name][key]
        self.driver.set_tensor(tensor_name, value)

class Adam(OptimizerState):
    """Adam optimizer whose parameters and moment estimates live on the driver.

    Per parameter, three state tensors are kept: ``step`` (a one-element
    counter array), ``exp_avg`` (running mean of gradients) and ``exp_avg_sq``
    (running mean of squared gradients).

    lr: learning rate
    betas: (beta1, beta2) decay rates for the two running means
    eps: constant added to the square root of ``exp_avg_sq`` before dividing
    """

    def __init__(self, param_table, driver, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(param_table, driver, "adam")
        self.lr = lr
        self.betas = betas
        self.eps = eps

        # Initialize state in driver memory
        for param_name in param_table:
            param = driver.get_tensor(param_table[param_name])
            self.create_state(param_name, {
                "step": np.array([0]),
                "exp_avg": np.zeros_like(param),
                "exp_avg_sq": np.zeros_like(param)
            })

    def step(self, grad_table):
        """Apply one Adam update to every parameter listed in ``grad_table``.

        grad_table: maps parameter name -> name of the driver tensor holding
            that parameter's gradient.

        Bias correction is folded into the step size:
        ``lr * sqrt(1 - beta2**t) / (1 - beta1**t)``, with ``eps`` added to the
        uncorrected ``sqrt(exp_avg_sq)``.
        """
        beta1, beta2 = self.betas

        for param_name, grad_name in grad_table.items():
            param = self.driver.get_tensor(self.param_table[param_name])
            grad = self.driver.get_tensor(grad_name)

            # Get state from driver memory
            step = self.get_state(param_name, "step")
            exp_avg = self.get_state(param_name, "exp_avg")
            exp_avg_sq = self.get_state(param_name, "exp_avg_sq")

            # Advance the counter without mutating the array handed back by the
            # driver; dtype and shape of the stored counter are preserved.
            step = np.asarray(step) + 1
            step_count = int(step.reshape(-1)[0])

            exp_avg = self.driver.mul_scalar(exp_avg, beta1)
            exp_avg = self.driver.add(exp_avg, self.driver.mul_scalar(grad, 1 - beta1))

            exp_avg_sq = self.driver.mul_scalar(exp_avg_sq, beta2)
            grad_squared = self.driver.mul(grad, grad)
            exp_avg_sq = self.driver.add(exp_avg_sq, self.driver.mul_scalar(grad_squared, 1 - beta2))

            denom = self.driver.sqrt(exp_avg_sq)
            denom = self.driver.add_scalar(denom, self.eps)

            # Use the integer step count so the step size is a plain float.
            # Computing it from the one-element counter array would hand
            # mul_scalar a shape-(1,) array, which under NumPy-style
            # broadcasting changes the shape of 0-d parameters.
            bias_correction1 = 1 - beta1 ** step_count
            bias_correction2 = 1 - beta2 ** step_count
            step_size = float(self.lr * np.sqrt(bias_correction2) / bias_correction1)

            # Update parameters in driver memory
            update = self.driver.div(exp_avg, denom)
            update = self.driver.mul_scalar(update, step_size)
            param = self.driver.sub(param, update)

            # Store updated values back in driver memory
            self.driver.set_tensor(self.param_table[param_name], param)
            self.update_state(param_name, "step", step)
            self.update_state(param_name, "exp_avg", exp_avg)
            self.update_state(param_name, "exp_avg_sq", exp_avg_sq)

class SGD(OptimizerState):
    """Stochastic gradient descent with optional (Nesterov) momentum.

    When ``momentum`` is positive, one ``momentum_buffer`` state tensor per
    parameter is created on the driver, initialised to zeros.
    """

    def __init__(self, param_table, driver, lr=1e-3, momentum=0, nesterov=False):
        super().__init__(param_table, driver, "sgd")
        self.lr = lr
        self.momentum = momentum
        self.nesterov = nesterov

        if not momentum > 0:
            # Plain SGD keeps no per-parameter state.
            return

        for name, tensor_name in param_table.items():
            weights = driver.get_tensor(tensor_name)
            self.create_state(name, {"momentum_buffer": np.zeros_like(weights)})

    def step(self, grad_table):
        """Apply one update to every parameter listed in ``grad_table``.

        grad_table: maps parameter name -> name of the driver tensor holding
            that parameter's gradient.
        """
        drv = self.driver

        for name, grad_name in grad_table.items():
            weights = drv.get_tensor(self.param_table[name])
            direction = drv.get_tensor(grad_name)

            if self.momentum > 0:
                # velocity <- momentum * velocity + grad
                velocity = self.get_state(name, "momentum_buffer")
                velocity = drv.mul_scalar(velocity, self.momentum)
                velocity = drv.add(velocity, direction)

                if self.nesterov:
                    # Look-ahead direction: grad + momentum * velocity
                    direction = drv.add(
                        direction, drv.mul_scalar(velocity, self.momentum)
                    )
                else:
                    direction = velocity

                self.update_state(name, "momentum_buffer", velocity)

            # weights <- weights - lr * direction, written back to the driver
            scaled = drv.mul_scalar(direction, self.lr)
            new_weights = drv.sub(weights, scaled)
            drv.set_tensor(self.param_table[name], new_weights)
