"""
Database manager for Helium components using DuckDB
"""
from typing import Optional, Dict, Any, Union
import os
import duckdb
import json
import pickle
import numpy as np
from pathlib import Path
from datetime import datetime
import hashlib
from dotenv import load_dotenv
import warnings

# Load environment variables (e.g. from a .env file) BEFORE reading any of
# them; otherwise values defined only in .env would not be visible below.
load_dotenv()

# HuggingFace token from the environment (None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")

class HeliumDBManager:
    """Centralized DuckDB-backed cache for Helium component results.

    Results are pickled into BLOB columns of four cache tables
    (``_TABLES``), each keyed by a string key (see ``_compute_key``).

    NOTE(review): cached values are restored with ``pickle.loads``. The default
    database URL points at a remote HuggingFace dataset; unpickling data from a
    location other parties can write to allows arbitrary code execution. Only
    point ``HELIUM_DB_URL`` at storage you trust.
    """

    _instance = None

    # Every cache table managed by this class.
    _TABLES = ('activation_cache', 'layer_norm_cache', 'encoder_cache', 'decoder_cache')

    # BLOB payload columns of each table. The tables do not share a common
    # payload column name, so size accounting must be per table.
    _BLOB_COLUMNS = {
        'activation_cache': ('value',),
        'layer_norm_cache': ('mean', 'var', 'normalized'),
        'encoder_cache': ('key_states', 'value_states'),
        'decoder_cache': ('self_attn_states', 'cross_attn_states'),
    }

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Connect to ``HELIUM_DB_URL`` (or the default HF URL) and create tables."""
        self.db_url = os.getenv('HELIUM_DB_URL', 'hf://datasets/Fred808/helium/storage.json')
        self._connect_db()
        self._init_tables()

    def _connect_db(self):
        """Open ``self.conn`` to ``self.db_url`` with the httpfs extension loaded.

        A throwaway in-memory connection installs/loads ``httpfs`` first,
        presumably so the extension is already installed on disk when DuckDB
        opens the remote URL.

        NOTE(review): no HuggingFace credentials are configured here; the
        module-level HF_TOKEN is not used by this method.
        """
        temp_conn = duckdb.connect(":memory:")
        try:
            temp_conn.execute("INSTALL httpfs;")
            temp_conn.execute("LOAD httpfs;")
        finally:
            # Always release the temporary handle, even if the install fails.
            # (SET statements on it would not carry over to the real
            # connection, so none are issued here.)
            temp_conn.close()

        self.conn = duckdb.connect(self.db_url, config={'http_keep_alive': 'true'})
        for statement in (
            "INSTALL httpfs;",
            "LOAD httpfs;",
            "SET s3_endpoint='hf.co';",
            "SET s3_use_ssl=true;",
            "SET s3_url_style='path';",
        ):
            self.conn.execute(statement)

    def _init_tables(self):
        """Create the four cache tables if they do not already exist.

        Each table is keyed by ``key VARCHAR PRIMARY KEY``. DuckDB backs a
        primary key with its own index, so no extra index on ``key`` is
        created; an extra index would only duplicate write cost.
        """
        # Activation cache table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS activation_cache (
                key VARCHAR PRIMARY KEY,
                value BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)

        # Layer normalization cache table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS layer_norm_cache (
                key VARCHAR PRIMARY KEY,
                mean BLOB,
                var BLOB,
                normalized BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)

        # Encoder state cache table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS encoder_cache (
                key VARCHAR PRIMARY KEY,
                key_states BLOB,
                value_states BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)

        # Decoder state cache table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS decoder_cache (
                key VARCHAR PRIMARY KEY,
                self_attn_states BLOB,
                cross_attn_states BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)

    def _compute_key(self, data: Union[np.ndarray, bytes], component_type: str, extra_data: str = "") -> str:
        """Return a SHA-256 hex digest identifying ``data`` for ``component_type``.

        Args:
            data: input array, or any bytes-like object.
            component_type: name of the component producing the cached result.
            extra_data: optional extra discriminator (e.g. configuration).

        For arrays, dtype and shape are hashed along with the raw bytes, so a
        (4,) and a (2, 2) float32 array of zeros get different keys. Every
        field is length-prefixed so field boundaries cannot shift and collide
        (e.g. data=b"ab", component_type="c" vs data=b"a", component_type="bc").
        Keys derived by this scheme differ from the previous, collision-prone
        scheme, so previously stored entries simply become cache misses.
        """
        hasher = hashlib.sha256()

        def _field(raw: bytes) -> None:
            # Length prefix makes the concatenation of fields unambiguous.
            hasher.update(len(raw).to_bytes(8, 'big'))
            hasher.update(raw)

        if isinstance(data, np.ndarray):
            _field(data.dtype.str.encode())
            _field(repr(data.shape).encode())
            _field(data.tobytes())
        else:
            # memoryview() rejects non-buffer inputs with TypeError, like
            # hasher.update() did originally.
            _field(bytes(memoryview(data)))
        _field(component_type.encode())
        _field(extra_data.encode())
        return hasher.hexdigest()

    def get_activation(self, key: str) -> Optional[np.ndarray]:
        """Return the cached activation for ``key`` (touching its access time), or None."""
        result = self.conn.execute("""
            SELECT value, metadata FROM activation_cache 
            WHERE key = ?
        """, [key]).fetchone()

        if result:
            self._update_access_time('activation_cache', key)
            # NOTE(review): unpickles database content; see class docstring.
            return pickle.loads(result[0])
        return None

    def set_activation(self, key: str, value: np.ndarray, metadata: Dict[str, Any]):
        """Insert or overwrite the cached activation for ``key``."""
        self.conn.execute("""
            INSERT OR REPLACE INTO activation_cache (key, value, metadata)
            VALUES (?, ?, ?)
        """, [key, pickle.dumps(value), json.dumps(metadata)])

    def get_layer_norm(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Return cached ``{'mean', 'var', 'normalized'}`` for ``key``, or None."""
        result = self.conn.execute("""
            SELECT mean, var, normalized, metadata 
            FROM layer_norm_cache 
            WHERE key = ?
        """, [key]).fetchone()

        if result:
            self._update_access_time('layer_norm_cache', key)
            return {
                'mean': pickle.loads(result[0]),
                'var': pickle.loads(result[1]),
                'normalized': pickle.loads(result[2])
            }
        return None

    def set_layer_norm(self, key: str, mean: np.ndarray, var: np.ndarray,
                      normalized: np.ndarray, metadata: Dict[str, Any]):
        """Insert or overwrite the cached layer-norm result for ``key``."""
        self.conn.execute("""
            INSERT OR REPLACE INTO layer_norm_cache 
            (key, mean, var, normalized, metadata)
            VALUES (?, ?, ?, ?, ?)
        """, [
            key,
            pickle.dumps(mean),
            pickle.dumps(var),
            pickle.dumps(normalized),
            json.dumps(metadata)
        ])

    def get_encoder_state(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Return cached ``{'key_states', 'value_states'}`` for ``key``, or None."""
        result = self.conn.execute("""
            SELECT key_states, value_states, metadata 
            FROM encoder_cache 
            WHERE key = ?
        """, [key]).fetchone()

        if result:
            self._update_access_time('encoder_cache', key)
            return {
                'key_states': pickle.loads(result[0]),
                'value_states': pickle.loads(result[1])
            }
        return None

    def set_encoder_state(self, key: str, key_states: np.ndarray,
                         value_states: np.ndarray, metadata: Dict[str, Any]):
        """Insert or overwrite the cached encoder state for ``key``."""
        self.conn.execute("""
            INSERT OR REPLACE INTO encoder_cache 
            (key, key_states, value_states, metadata)
            VALUES (?, ?, ?, ?)
        """, [
            key,
            pickle.dumps(key_states),
            pickle.dumps(value_states),
            json.dumps(metadata)
        ])

    def get_decoder_state(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Return cached ``{'self_attn_states', 'cross_attn_states'}`` for ``key``, or None."""
        result = self.conn.execute("""
            SELECT self_attn_states, cross_attn_states, metadata 
            FROM decoder_cache 
            WHERE key = ?
        """, [key]).fetchone()

        if result:
            self._update_access_time('decoder_cache', key)
            return {
                'self_attn_states': pickle.loads(result[0]),
                'cross_attn_states': pickle.loads(result[1])
            }
        return None

    def set_decoder_state(self, key: str, self_attn_states: np.ndarray,
                         cross_attn_states: np.ndarray, metadata: Dict[str, Any]):
        """Insert or overwrite the cached decoder state for ``key``."""
        self.conn.execute("""
            INSERT OR REPLACE INTO decoder_cache 
            (key, self_attn_states, cross_attn_states, metadata)
            VALUES (?, ?, ?, ?)
        """, [
            key,
            pickle.dumps(self_attn_states),
            pickle.dumps(cross_attn_states),
            json.dumps(metadata)
        ])

    def _update_access_time(self, table: str, key: str):
        """Set ``last_accessed`` to now for ``key`` in ``table``.

        ``table`` is interpolated into the SQL, so it must be one of
        ``_TABLES`` (all callers in this class pass literals).
        """
        self.conn.execute(f"""
            UPDATE {table}
            SET last_accessed = CURRENT_TIMESTAMP 
            WHERE key = ?
        """, [key])

    def cleanup_old_entries(self, max_age_days: int = 30):
        """Delete entries not used within the last ``max_age_days`` days.

        An entry's age is taken from ``last_accessed``, falling back to
        ``created_at`` for entries that were written but never read. Such
        entries have a NULL ``last_accessed`` and would otherwise never expire.
        """
        for table in self._TABLES:
            # DuckDB has no DATEADD(); subtract an INTERVAL built by to_days().
            # CURRENT_TIMESTAMP is cast to TIMESTAMP, matching how the
            # TIMESTAMP columns' DEFAULT CURRENT_TIMESTAMP values are stored.
            self.conn.execute(f"""
                DELETE FROM {table}
                WHERE COALESCE(last_accessed, created_at)
                      < CAST(CURRENT_TIMESTAMP AS TIMESTAMP) - to_days(CAST(? AS INTEGER))
            """, [max_age_days])

    def get_stats(self) -> Dict[str, Any]:
        """Return per-table entry count, payload size (MB) and timestamps.

        Size is the sum of the byte lengths of the table's own BLOB columns
        (``_BLOB_COLUMNS``); NULL blobs count as 0 bytes.
        """
        stats = {}
        for table in self._TABLES:
            size_expr = " + ".join(
                f"COALESCE(octet_length({column}), 0)"
                for column in self._BLOB_COLUMNS[table]
            )
            table_stats = self.conn.execute(f"""
                SELECT 
                    COUNT(*) as total_entries,
                    SUM({size_expr}) as total_size_bytes,
                    MIN(created_at) as oldest_entry,
                    MAX(last_accessed) as last_accessed
                FROM {table}
            """).fetchone()

            stats[table] = {
                'total_entries': table_stats[0],
                'total_size_mb': table_stats[1] / (1024 * 1024) if table_stats[1] else 0,
                'oldest_entry': table_stats[2],
                'last_accessed': table_stats[3]
            }
        return stats

    def __del__(self):
        """Close the database connection, if one was opened (best effort)."""
        conn = getattr(self, 'conn', None)
        if conn is None:
            return
        try:
            conn.close()
        except Exception:
            # Deliberately best-effort: finalizers may run during interpreter
            # shutdown when the connection or module state is already torn
            # down, and an exception raised here could never be handled.
            pass
