from typing import Optional, Union, Callable, Dict, Type, Any
import numpy as np
from enum import Enum
from dataclasses import dataclass
import warnings
from functools import lru_cache
import math
import os
import duckdb
import json
from pathlib import Path
from dotenv import load_dotenv
import hashlib
import pickle
from datetime import datetime

# Load variables from a local .env file (if one exists) into the process
# environment so the os.getenv() lookups below can see them.
load_dotenv()

# Location of the DuckDB database used by DuckDBCache. Overridable through the
# HELIUM_DB_URL environment variable; the fallback is a Hugging Face dataset path.
# NOTE(review): the default is an hf:// URL to a .json file -- confirm that
# DuckDB can open (and write to) this location as a database.
DB_URL = os.getenv('HELIUM_DB_URL', 'hf://datasets/Fred808/helium/storage.json')
import numpy as np
from enum import Enum
from dataclasses import dataclass
import warnings
from functools import lru_cache
import math
import sqlite3
import pickle
import hashlib
import threading
import time
import os
from pathlib import Path

# Hugging Face access token read from the HF_TOKEN environment variable
# (None when the variable is unset).
# NOTE(review): nothing in the visible code references HF_TOKEN -- confirm
# whether the DuckDB connection is supposed to authenticate with it.
HF_TOKEN = os.getenv("HF_TOKEN")


class ActivationType(Enum):
    """Supported activation function types.

    Each member selects one implementation in ``Activation``; the member
    value is the lowercase name of the activation.
    """
    # Piecewise-linear family
    RELU = "relu"
    # Smooth / saturating activations
    GELU = "gelu"
    TANH = "tanh"
    SIGMOID = "sigmoid"
    SWISH = "swish"
    MISH = "mish"
    # ReLU clipped to the range [0, 6]
    RELU6 = "relu6"
    # Exponential-linear family
    ELU = "elu"
    SELU = "selu"
    # ReLU with a configurable slope for negative inputs
    LEAKY_RELU = "leaky_relu"

@dataclass
class ActivationConfig:
    """Configuration for activation functions.

    Bundles the activation kind with the numeric and execution options that
    ``Activation`` reads when it is constructed and called.
    """
    # Which activation to apply; used to pick the implementation method.
    type: ActivationType
    # Inputs whose dtype differs from this one are converted before computing.
    dtype: np.dtype = np.float32
    # Request that the result be written into the input buffer where the
    # implementation supports it.
    inplace: bool = False
    # Negative-side coefficient used by the ELU and Leaky ReLU implementations.
    alpha: float = 0.01
    # GELU only: True selects the tanh approximation, False the erf-based form.
    approximate: bool = True
    # Passed to DuckDBCache as its ``size`` argument.
    cache_size: int = 1024

class DuckDBCache:
    """Database-backed cache for activation function computations.

    Results are pickled and stored in the ``activation_cache`` table of the
    DuckDB database at ``DB_URL``, keyed by a SHA-256 digest of the input
    array (bytes, dtype and shape) and the activation name. When a new key
    is added, the least recently used rows are evicted so that at most
    ``size`` entries remain.
    """

    def __init__(self, size: int = 1024):
        """Open the database connection and make sure the cache table exists.

        Args:
            size: Maximum number of entries kept in the cache. A value of
                zero or less disables storing new entries.
        """
        self.size = size
        self._connect_db()
        self._init_tables()

    @staticmethod
    def _configure_httpfs(conn) -> None:
        """Install/load the httpfs extension and apply the endpoint settings."""
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='hf.co';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")

    def _connect_db(self):
        """Connect to DuckDB database with HuggingFace configuration"""
        # A throwaway in-memory connection is configured first, then the real
        # connection to DB_URL is opened and configured the same way.
        temp_conn = duckdb.connect(":memory:")
        try:
            self._configure_httpfs(temp_conn)

            self.conn = duckdb.connect(DB_URL, config={'http_keep_alive': 'true'})
            self._configure_httpfs(self.conn)
        finally:
            # Always release the temporary connection, even when connecting
            # or configuring the real one fails.
            temp_conn.close()

    def _init_tables(self):
        """Create the cache table and its key index if they do not exist."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS activation_cache (
                key VARCHAR PRIMARY KEY,
                value BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)

        # Create index for faster lookups
        # NOTE(review): the PRIMARY KEY on `key` presumably already provides
        # an index, which would make this one redundant -- confirm.
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_activation_cache_key 
            ON activation_cache(key)
        """)

    def _compute_key(self, data: np.ndarray, activation_type: str) -> str:
        """Compute the cache key for an input array and activation name.

        The digest covers the activation name, the array's dtype and shape,
        and its raw bytes. Dtype and shape are included because arrays with
        identical bytes but a different layout (for example ``(4,)`` versus
        ``(2, 2)``, or float32 versus int32) would otherwise share a key and
        receive each other's cached results.

        Returns:
            Hex-encoded SHA-256 digest (64 characters).
        """
        hasher = hashlib.sha256()
        # '|' separators keep adjacent header fields from running together.
        header = f"{activation_type}|{data.dtype.str}|{data.shape}|"
        hasher.update(header.encode())
        hasher.update(data.tobytes())
        return hasher.hexdigest()

    def clear(self):
        """Clear all cached computations"""
        self.conn.execute("DELETE FROM activation_cache")

    def get(self, data: np.ndarray, activation_type: str) -> Optional[np.ndarray]:
        """Look up a cached result.

        Args:
            data: Input array the result was computed from.
            activation_type: Name of the activation that was applied.

        Returns:
            The cached value, or ``None`` when there is no entry or the stored
            blob cannot be deserialized (a warning is emitted in that case).
        """
        key = self._compute_key(data, activation_type)

        row = self.conn.execute("""
            SELECT value, metadata FROM activation_cache 
            WHERE key = ?
        """, [key]).fetchone()

        if not row:
            return None

        value_blob = row[0]

        # Record the hit so LRU eviction keeps this entry around.
        self.conn.execute("""
            UPDATE activation_cache 
            SET last_accessed = CURRENT_TIMESTAMP 
            WHERE key = ?
        """, [key])

        # SECURITY NOTE: pickle.loads executes arbitrary code embedded in the
        # blob. This is only safe while every writer to the database at
        # DB_URL is trusted.
        try:
            return pickle.loads(value_blob)
        except Exception as e:
            warnings.warn(f"Failed to deserialize cached value: {e}")
            return None

    def set(self, data: np.ndarray, activation_type: str, value: np.ndarray):
        """Store a computation result, evicting old entries if necessary.

        Args:
            data: Input array the result was computed from.
            activation_type: Name of the activation that was applied.
            value: Result to cache; it is serialized with pickle.
        """
        if self.size <= 0:
            # A cache with no capacity stores nothing.
            return

        key = self._compute_key(data, activation_type)

        # Prepare metadata
        metadata = {
            'shape': data.shape,
            'dtype': str(data.dtype),
            'activation_type': activation_type,
            'timestamp': datetime.now().isoformat()
        }

        # Serialize the value
        try:
            value_blob = pickle.dumps(value)
        except Exception as e:
            warnings.warn(f"Failed to serialize value for caching: {e}")
            return

        # Replacing an existing key does not grow the table, so eviction is
        # only needed when the key is new.
        already_cached = self.conn.execute(
            "SELECT 1 FROM activation_cache WHERE key = ?", [key]
        ).fetchone() is not None

        if not already_cached:
            total = self.conn.execute(
                "SELECT COUNT(*) FROM activation_cache"
            ).fetchone()[0]
            # Leave room for the row about to be inserted.
            overflow = total - (self.size - 1)
            if overflow > 0:
                # `overflow` is an integer computed above, so embedding it in
                # the statement text is safe. Rows never read back have a
                # NULL last_accessed, so fall back to created_at for ordering.
                self.conn.execute(f"""
                    DELETE FROM activation_cache 
                    WHERE key IN (
                        SELECT key FROM activation_cache 
                        ORDER BY COALESCE(last_accessed, created_at) ASC 
                        LIMIT {int(overflow)}
                    )
                """)

        # Insert new value; a fresh entry counts as just accessed.
        self.conn.execute("""
            INSERT OR REPLACE INTO activation_cache (key, value, metadata, last_accessed)
            VALUES (?, ?, ?, CURRENT_TIMESTAMP)
        """, [key, value_blob, json.dumps(metadata)])

    def cleanup_old_entries(self, max_age_days: int = 30):
        """Remove entries not accessed within the last ``max_age_days`` days.

        Entries that were never read back are aged by their creation time.
        """
        self.conn.execute("""
            DELETE FROM activation_cache 
            WHERE COALESCE(last_accessed, created_at)
                  < CAST(CURRENT_TIMESTAMP AS TIMESTAMP) - to_days(CAST(? AS INTEGER))
        """, [max_age_days])

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dict with ``total_entries``, ``total_size_mb`` (size of the stored
            blobs), ``oldest_entry`` and ``last_accessed``.
        """
        stats = self.conn.execute("""
            SELECT 
                COUNT(*) as total_entries,
                SUM(octet_length(value)) as total_size_bytes,
                MIN(created_at) as oldest_entry,
                MAX(last_accessed) as last_accessed
            FROM activation_cache
        """).fetchone()

        total_bytes = stats[1]
        return {
            'total_entries': stats[0],
            # SUM() yields NULL on an empty table.
            'total_size_mb': total_bytes / (1024 * 1024) if total_bytes else 0,
            'oldest_entry': stats[2],
            'last_accessed': stats[3]
        }

    def __del__(self):
        """Close database connection on cleanup"""
        conn = getattr(self, 'conn', None)
        if conn is None:
            return
        try:
            conn.close()
        except Exception:
            # Finalizers may run during interpreter shutdown, when the
            # driver is already partially torn down; there is nothing useful
            # to do with a failure here.
            pass

class Activation:
    """
    Activation function dispatcher backed by NumPy.

    The activation is chosen from ``config.type``. Each implementation first
    offers the computation to an optional hardware ``driver`` (any object
    exposing a method named after the activation) and falls back to NumPy
    when the driver has no such method or returns ``None``.

    When ``config.inplace`` is true, results are written back into the array
    handed to the implementation whenever the result dtype allows it.

    A ``DuckDBCache`` is available through the ``cache`` attribute; it is
    created on first access so that constructing an ``Activation`` does not
    open a database connection.
    """

    def __init__(
        self,
        config: "ActivationConfig",
        driver = None
    ):
        """
        Initialize activation function.

        Args:
            config: Activation configuration
            driver: Optional hardware driver for optimized computation
        """
        self.config = config
        self.driver = driver
        self._cache = None
        self._setup_implementation()

    @property
    def cache(self):
        """Computation cache, connected lazily on first access."""
        cache = getattr(self, '_cache', None)
        if cache is None:
            cache = DuckDBCache(self.config.cache_size)
            self._cache = cache
        return cache

    @cache.setter
    def cache(self, value):
        self._cache = value

    def _setup_implementation(self):
        """Bind ``self._impl`` to the method implementing ``config.type``."""
        implementations = {
            ActivationType.RELU: self._relu,
            ActivationType.GELU: self._gelu,
            ActivationType.TANH: self._tanh,
            ActivationType.SIGMOID: self._sigmoid,
            ActivationType.SWISH: self._swish,
            ActivationType.MISH: self._mish,
            ActivationType.RELU6: self._relu6,
            ActivationType.ELU: self._elu,
            ActivationType.SELU: self._selu,
            ActivationType.LEAKY_RELU: self._leaky_relu
        }
        self._impl = implementations[self.config.type]

    @staticmethod
    @lru_cache(maxsize=128)
    def _calculate_constants(dtype: np.dtype) -> Dict[str, float]:
        """Calculate and cache constants used in activation functions.

        Every constant is produced in ``dtype`` so that multiplying it with
        an array of that dtype neither promotes the result nor truncates the
        constant. Non-floating dtypes fall back to float64, because casting
        these constants to an integer type would zero some of them.

        Args:
            dtype: Dtype of the array the constants will be combined with.

        Returns:
            Dict with ``sqrt_2_pi``, ``alpha_gelu``, ``selu_alpha`` and
            ``selu_scale`` as NumPy scalars.
        """
        resolved = np.dtype(dtype)
        if not np.issubdtype(resolved, np.inexact):
            resolved = np.dtype(np.float64)
        scalar = resolved.type
        return {
            'sqrt_2_pi': scalar(math.sqrt(2.0 / math.pi)),
            'alpha_gelu': scalar(0.044715),
            'selu_alpha': scalar(1.6732632423543772),
            'selu_scale': scalar(1.0507009873554805)
        }

    @staticmethod
    @lru_cache(maxsize=1)
    def _resolve_erf():
        """Return an element-wise error function.

        Uses ``scipy.special.erf`` when SciPy is installed; otherwise
        ``math.erf`` is wrapped with ``np.vectorize`` (correct but slow).
        """
        try:
            from scipy.special import erf
        except ImportError:
            return np.vectorize(math.erf, otypes=[np.float64])
        return erf

    def _validate_input(self, x: np.ndarray):
        """Validate input tensor.

        Raises:
            TypeError: If ``x`` is not a ``numpy.ndarray``.
        """
        if not isinstance(x, np.ndarray):
            raise TypeError(f"Expected numpy.ndarray, got {type(x)}")

    def _prepare_input(self, x: np.ndarray) -> np.ndarray:
        """Convert the input to the configured dtype.

        No defensive copy is made: without ``inplace`` the implementations
        never write to their input, and with ``inplace`` the caller's buffer
        is precisely what should be overwritten. A dtype conversion
        necessarily yields a new array, so in that case the in-place result
        lands in the converted array instead.
        """
        if x.dtype != self.config.dtype:
            x = x.astype(self.config.dtype)
        return x

    def _try_driver_implementation(
        self,
        x: np.ndarray,
        func_name: str
    ) -> Optional[np.ndarray]:
        """Try to use driver implementation if available.

        Returns:
            The driver's result, or ``None`` when no driver is set or it has
            no attribute named ``func_name``.
        """
        if self.driver is None:
            return None
        driver_func = getattr(self.driver, func_name, None)
        if driver_func is None:
            return None
        return driver_func(x)

    def _finalize(self, x: np.ndarray, result) -> np.ndarray:
        """Honor ``config.inplace`` for results computed out of place.

        With ``inplace`` enabled the result is copied into ``x`` and ``x`` is
        returned. If the result dtype cannot be stored in ``x`` under NumPy's
        ``same_kind`` rule, or ``inplace`` is off, the result is returned
        unchanged.
        """
        if not self.config.inplace:
            return result
        result = np.asarray(result)
        if not np.can_cast(result.dtype, x.dtype, casting='same_kind'):
            return result
        np.copyto(x, result)
        return x

    def _relu(self, x: np.ndarray) -> np.ndarray:
        """ReLU: ``max(x, 0)``."""
        result = self._try_driver_implementation(x, 'relu')
        if result is not None:
            return result
        return np.maximum(x, 0, out=x if self.config.inplace else None)

    def _gelu(self, x: np.ndarray) -> np.ndarray:
        """GELU, using the tanh approximation or the exact erf form.

        ``config.approximate`` selects between the two.
        """
        result = self._try_driver_implementation(x, 'gelu')
        if result is not None:
            return result

        if self.config.approximate:
            constants = self._calculate_constants(x.dtype)
            # Fast approximation
            inner = x + constants['alpha_gelu'] * np.power(x, 3)
            inner *= constants['sqrt_2_pi']
            output = 0.5 * x * (1 + np.tanh(inner))
        else:
            # Exact computation using error function. math.erf only accepts
            # scalars, so an element-wise erf is used instead.
            erf = self._resolve_erf()
            output = 0.5 * x * (1 + erf(x / math.sqrt(2.0)))
            if np.issubdtype(x.dtype, np.inexact):
                # The fallback erf computes in float64; keep the input dtype.
                output = np.asarray(output).astype(x.dtype, copy=False)
        return self._finalize(x, output)

    def _tanh(self, x: np.ndarray) -> np.ndarray:
        """Hyperbolic tangent."""
        result = self._try_driver_implementation(x, 'tanh')
        if result is not None:
            return result
        return np.tanh(x, out=x if self.config.inplace else None)

    def _sigmoid(self, x: np.ndarray) -> np.ndarray:
        """Logistic sigmoid: ``1 / (1 + exp(-x))``."""
        result = self._try_driver_implementation(x, 'sigmoid')
        if result is not None:
            return result
        if self.config.inplace:
            # Run every step of the formula inside x's own buffer so that x
            # ends up holding the sigmoid rather than an intermediate value.
            np.negative(x, out=x)
            np.exp(x, out=x)
            x += 1
            np.reciprocal(x, out=x)
            return x
        return 1 / (1 + np.exp(-x))

    def _swish(self, x: np.ndarray) -> np.ndarray:
        """Swish: ``x * sigmoid(x)``."""
        result = self._try_driver_implementation(x, 'swish')
        if result is not None:
            return result
        # In in-place mode _sigmoid overwrites its argument, so give it a
        # copy; x itself must stay intact until the multiplication is done.
        gate_input = x.copy() if self.config.inplace else x
        return self._finalize(x, x * self._sigmoid(gate_input))

    def _mish(self, x: np.ndarray) -> np.ndarray:
        """Mish: ``x * tanh(softplus(x))``."""
        result = self._try_driver_implementation(x, 'mish')
        if result is not None:
            return result
        # logaddexp(0, x) == log(1 + exp(x)) without overflowing for large x.
        softplus = np.logaddexp(0, x)
        return self._finalize(x, x * np.tanh(softplus))

    def _relu6(self, x: np.ndarray) -> np.ndarray:
        """ReLU6 implementation (min(max(0, x), 6))"""
        result = self._try_driver_implementation(x, 'relu6')
        if result is not None:
            return result
        return np.clip(x, 0, 6, out=x if self.config.inplace else None)

    def _elu(self, x: np.ndarray) -> np.ndarray:
        """ELU: ``x`` for ``x > 0``, else ``alpha * (exp(x) - 1)``."""
        result = self._try_driver_implementation(x, 'elu')
        if result is not None:
            return result
        # np.where evaluates both branches; clamping to <= 0 keeps the
        # exponential branch from overflowing on large positive inputs.
        negative_branch = self.config.alpha * np.expm1(np.minimum(x, 0))
        return self._finalize(x, np.where(x > 0, x, negative_branch))

    def _selu(self, x: np.ndarray) -> np.ndarray:
        """SELU: ``scale * (x if x > 0 else alpha * (exp(x) - 1))``."""
        result = self._try_driver_implementation(x, 'selu')
        if result is not None:
            return result
        constants = self._calculate_constants(x.dtype)
        # Clamp before exponentiating for the same reason as in _elu.
        negative_branch = constants['selu_alpha'] * np.expm1(np.minimum(x, 0))
        output = constants['selu_scale'] * np.where(x > 0, x, negative_branch)
        return self._finalize(x, output)

    def _leaky_relu(self, x: np.ndarray) -> np.ndarray:
        """Leaky ReLU: ``x`` for ``x > 0``, else ``alpha * x``."""
        result = self._try_driver_implementation(x, 'leaky_relu')
        if result is not None:
            return result
        return self._finalize(x, np.where(x > 0, x, self.config.alpha * x))

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Apply activation function to input tensor.

        Args:
            x: Input tensor

        Returns:
            Output tensor after activation

        Raises:
            TypeError: If ``x`` is not a ``numpy.ndarray``.
        """
        self._validate_input(x)
        x = self._prepare_input(x)
        return self._impl(x)

# Legacy function interfaces for backward compatibility

def relu(x: np.ndarray, driver=None) -> np.ndarray:
    """Apply ReLU to ``x`` (legacy functional interface).

    Builds a one-off ``Activation`` configured for ReLU with the input's own
    dtype and runs it, optionally delegating to ``driver``.
    """
    relu_config = ActivationConfig(dtype=x.dtype, type=ActivationType.RELU)
    activation = Activation(relu_config, driver)
    return activation(x)

def gelu(x: np.ndarray, driver=None, approximate: bool = True) -> np.ndarray:
    """Apply GELU to ``x`` (legacy functional interface).

    Args:
        x: Input array.
        driver: Optional hardware driver forwarded to ``Activation``.
        approximate: Use the tanh approximation when true, the erf-based
            form otherwise.

    Returns:
        Array with GELU applied. Floating-point inputs keep their dtype;
        other inputs (integers, booleans) are computed in float64.
    """
    # GELU is real-valued. Passing an integer dtype through the config would
    # have the GELU constants evaluated in that integer type, truncating
    # them, so non-floating input is promoted to float64 instead.
    if np.issubdtype(x.dtype, np.inexact):
        compute_dtype = x.dtype
    else:
        compute_dtype = np.dtype(np.float64)
    config = ActivationConfig(
        type=ActivationType.GELU,
        dtype=compute_dtype,
        approximate=approximate,
    )
    return Activation(config, driver)(x)

def tanh(x: np.ndarray, driver=None) -> np.ndarray:
    """Apply tanh to ``x`` (legacy functional interface).

    Builds a one-off ``Activation`` configured for tanh with the input's own
    dtype and runs it, optionally delegating to ``driver``.
    """
    tanh_config = ActivationConfig(dtype=x.dtype, type=ActivationType.TANH)
    activation = Activation(tanh_config, driver)
    return activation(x)

def sigmoid(x: np.ndarray, driver=None) -> np.ndarray:
    """Apply the logistic sigmoid to ``x`` (legacy functional interface).

    Builds a one-off ``Activation`` configured for sigmoid with the input's
    own dtype and runs it, optionally delegating to ``driver``.
    """
    sigmoid_config = ActivationConfig(dtype=x.dtype, type=ActivationType.SIGMOID)
    activation = Activation(sigmoid_config, driver)
    return activation(x)
