
from typing import Dict, List, Optional, Tuple, Union
import numpy as np


def embedding_lookup(
    input_ids: np.ndarray,
    embedding_weights: np.ndarray,
    driver=None
) -> np.ndarray:
    """
    Look up embeddings for input tokens.

    Args:
        input_ids: Integer token indices of any shape, typically
            (batch_size, sequence_length).
        embedding_weights: Embedding weight matrix of shape (vocab_size, hidden_dim)
        driver: Optional hardware driver; if it exposes an ``embedding_lookup``
            method the call is delegated to it unchanged.

    Returns:
        Embedded tokens of shape ``input_ids.shape + (hidden_dim,)``, i.e.
        (batch_size, sequence_length, hidden_dim) for 2-D input.
    """
    if driver and hasattr(driver, 'embedding_lookup'):
        return driver.embedding_lookup(input_ids, embedding_weights)

    # Fallback to numpy: gather rows along the vocab axis. Advanced indexing
    # with an N-D index array yields input_ids.shape + (hidden_dim,) directly,
    # so no flatten/reshape round-trip is needed and inputs of any rank work
    # (a fixed two-value shape unpack would reject anything but 2-D input).
    # As with plain NumPy indexing, negative ids wrap from the end of the
    # vocabulary and out-of-range ids raise IndexError.
    return embedding_weights[np.asarray(input_ids)]


def add_positional_encoding(
    embeddings: np.ndarray,
    max_position: int,
    hidden_dim: int,
    dtype: np.dtype = np.float32,
    driver=None
) -> np.ndarray:
    """
    Add sinusoidal positional encodings to input embeddings.

    Even feature columns receive sin(pos * w_k) and odd columns cos(pos * w_k),
    where w_k = 10000 ** (-2k / hidden_dim).

    Args:
        embeddings: Input embeddings of shape (batch_size, sequence_length, hidden_dim)
        max_position: Maximum sequence length. Only forwarded to the driver;
            the NumPy fallback encodes exactly ``sequence_length`` positions.
        hidden_dim: Hidden dimension size (odd sizes are supported)
        dtype: Data type of the positional-encoding table
        driver: Optional hardware driver; if it exposes an
            ``add_positional_encoding`` method the call is delegated to it.

    Returns:
        ``embeddings + pe`` where ``pe`` has shape (1, sequence_length, hidden_dim)
        and is broadcast over the batch axis.
    """
    if driver and hasattr(driver, 'add_positional_encoding'):
        return driver.add_positional_encoding(
            embeddings,
            max_position,
            hidden_dim,
            dtype
        )

    # Fallback to numpy implementation (3-value unpack keeps the rank check).
    _, seq_length, _ = embeddings.shape

    position = np.arange(seq_length)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
    )
    angles = position * div_term  # (seq_length, ceil(hidden_dim / 2))

    pos_encoding = np.zeros((seq_length, hidden_dim), dtype=dtype)
    pos_encoding[:, 0::2] = np.sin(angles)
    # With an odd hidden_dim there is one fewer odd column than even column,
    # so drop the last frequency; otherwise the assignment fails to broadcast.
    pos_encoding[:, 1::2] = np.cos(angles[:, : hidden_dim // 2])

    # Leading batch axis lets the table broadcast over every sequence.
    return embeddings + pos_encoding[np.newaxis, :, :]


class EmbeddingState:
    """Allocate and release sequentially named scratch tensors on a driver.

    Names have the form ``"<prefix>_temp_<n>_<suffix>"``, where ``n`` is a
    per-instance counter starting at 0.
    """

    def __init__(self, driver, prefix: str):
        self.driver = driver
        self.prefix = prefix
        self.counter = 0  # number used for the next temp-tensor name

    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Store ``data`` on the driver under a fresh temp name and return that name."""
        slot = self.counter
        self.counter = slot + 1
        tensor_name = "{}_temp_{}_{}".format(self.prefix, slot, name_suffix)
        self.driver.create_tensor(tensor_name, data)
        return tensor_name

    def free_temp_tensor(self, name: str):
        """Delete ``name`` from the driver; names the driver does not hold are ignored."""
        if not self.driver.tensor_exists(name):
            return
        self.driver.delete_tensor(name)

class Embedding:
    """
    GPU/DB-backed Embedding layer for NLP/graph models.
    All weights/tensors are stored and accessed via the driver (e.g., SQLiteMemoryManager), not Python RAM.

    Driver tensors owned by this layer:
      * ``<prefix>_weight`` -- (vocab_size, embedding_dim) embedding table
      * ``<prefix>_grad`` -- gradient accumulator of the same shape
      * ``<prefix>_cache_indices`` / ``<prefix>_cache_shape`` -- forward-pass
        cache written when ``training=True`` and deleted by ``backward``
    """
    def __init__(self, vocab_size: int, embedding_dim: int, driver, prefix: str = "embed", init_std: float = 0.02):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.driver = driver
        self.prefix = prefix

        # Create unique names for persistent tensors
        self.weight_name = f"{prefix}_weight"
        self.grad_name = f"{prefix}_grad"

        # Initialize embedding matrix in driver memory if not present
        if not driver.tensor_exists(self.weight_name):
            weights = driver.random_normal(
                (vocab_size, embedding_dim),
                mean=0.0,
                std=init_std
            )
            driver.create_tensor(self.weight_name, weights)
        # Check the gradient buffer on its own: if it were only created along
        # with fresh weights, reusing a driver that already holds the weights
        # (but no gradient) would leave backward()/zero_grad() without a target.
        if not driver.tensor_exists(self.grad_name):
            driver.create_tensor(
                self.grad_name,
                np.zeros((vocab_size, embedding_dim))
            )

    def forward(
        self,
        indices_name: str,
        training: bool = True
    ) -> str:
        """
        Gather embedding rows for the indices stored on the driver.

        Args:
            indices_name: name of the driver tensor holding the (presumably
                integer) indices; any shape is accepted.
            training: when True, cache the indices and their shape so that
                ``backward`` can be called afterwards.

        Returns:
            Name of the driver tensor holding the output, of shape
            ``indices.shape + (embedding_dim,)``.

        NOTE(review): a fresh EmbeddingState is created per call, so the
        returned name is identical on every call (``<prefix>_fwd_temp_2_output``).
        Whether a later call overwrites or rejects an earlier, still-live
        output depends on ``driver.create_tensor``; confirm before holding
        several outputs at once.
        """
        state = EmbeddingState(self.driver, f"{self.prefix}_fwd")

        # Get shape info from driver
        indices = self.driver.get_tensor(indices_name)
        original_shape = indices.shape

        # Flatten indices in driver memory
        flat_name = state.get_temp_tensor(indices.reshape(-1), "flat")
        try:
            # Gather embedding rows in driver memory
            gathered_name = state.get_temp_tensor(
                self.driver.gather(self.weight_name, flat_name),
                "gathered"
            )
        finally:
            # Release the scratch index tensor even if the gather fails.
            state.free_temp_tensor(flat_name)

        # Reshape to original dimensions + embedding_dim
        output_shape = original_shape + (self.embedding_dim,)
        try:
            output_name = state.get_temp_tensor(
                self.driver.reshape(gathered_name, output_shape),
                "output"
            )
        finally:
            state.free_temp_tensor(gathered_name)

        if training:
            # Store intermediate results needed for backward
            self.save_for_backward(indices_name, original_shape)

        return output_name

    def save_for_backward(self, indices_name: str, shape: Tuple[int, ...]):
        """Cache the forward indices and their shape in driver memory for ``backward``.

        A cache left by an earlier training-mode forward that was never
        followed by ``backward`` is deleted first, so repeated forwards do not
        depend on whether ``create_tensor`` accepts an existing name.
        """
        cache_indices = f"{self.prefix}_cache_indices"
        cache_shape = f"{self.prefix}_cache_shape"
        for stale in (cache_indices, cache_shape):
            if self.driver.tensor_exists(stale):
                self.driver.delete_tensor(stale)
        self.driver.create_tensor(
            cache_indices,
            self.driver.get_tensor(indices_name)
        )
        self.driver.create_tensor(
            cache_shape,
            np.array(shape)
        )

    def backward(self, grad_output_name: str) -> None:
        """
        Accumulate the embedding gradient in driver memory.

        The incoming gradient is reshaped to (-1, embedding_dim) and passed to
        ``driver.scatter_add`` into ``<prefix>_grad`` at the cached indices
        (presumably summing rows for repeated indices -- confirm against the
        driver). Requires a preceding ``forward`` with ``training=True``; the
        forward cache is deleted afterwards.

        Args:
            grad_output_name: name of the driver tensor holding the gradient
                w.r.t. the forward output.
        """
        state = EmbeddingState(self.driver, f"{self.prefix}_bwd")

        # Get cached indices from driver
        indices = self.driver.get_tensor(f"{self.prefix}_cache_indices")

        # Reshape gradient to match the flattened gather layout
        reshaped_grad_name = state.get_temp_tensor(
            self.driver.reshape(grad_output_name, (-1, self.embedding_dim)),
            "reshaped_grad"
        )
        try:
            self.driver.scatter_add(
                self.grad_name,       # Accumulate into gradient tensor
                indices.reshape(-1),  # Flattened indices
                reshaped_grad_name    # Reshaped gradients
            )
        finally:
            # Free the scratch gradient even if scatter_add fails; the cache is
            # kept in that case so backward can be retried.
            state.free_temp_tensor(reshaped_grad_name)

        # Cleanup cached tensors
        self.driver.delete_tensor(f"{self.prefix}_cache_indices")
        self.driver.delete_tensor(f"{self.prefix}_cache_shape")

    def parameters(self) -> Dict[str, str]:
        """Return names of parameter tensors in driver"""
        return {
            "weight": self.weight_name,
            "grad": self.grad_name
        }

    def zero_grad(self) -> None:
        """Reset gradients to zero in driver memory"""
        self.driver.fill(self.grad_name, 0.0)
