"""
Local Storage Manager Implementation
Replaces HuggingFace remote storage with local SQLite database
"""

import sqlite3
import json
import threading
import time
import gc  # Add garbage collection
import numpy as np
from typing import Dict, Any, Optional, Union, List
from pathlib import Path
import os

class LocalStorageManager:
    """
    Local storage backed by a SQLite database file.

    Replaces remote HuggingFace storage with local file-based storage.

    Behaviour worth knowing:
    - One instance (and one SQLite connection) exists per thread. The first
      ``LocalStorageManager(path)`` call in a thread opens the connection;
      later calls in that thread return the same instance and ignore
      ``db_path``.
    - The schema is created on connect using ``CREATE ... IF NOT EXISTS``.
    - After ``close()`` (or after leaving a ``with`` block) the next
      operation transparently reopens the connection.
    - The store/load helpers report failures by printing a message and
      returning ``False`` / ``None`` instead of raising.
    """

    _thread_local = threading.local()
    _main_lock = threading.Lock()  # Serialises connection + schema setup across threads
    _cleanup_interval = 30  # Seconds between cleanup operations
    _last_cleanup = 0

    def __new__(cls, db_path: str = "data/local_storage.db"):
        """Return this thread's instance, creating and connecting it on first use."""
        instance = getattr(cls._thread_local, "instance", None)
        if instance is None:
            instance = super().__new__(cls)
            instance.db_path = db_path
            instance.conn = None
            instance.cursor = None
            # Hold the lock while connecting so concurrent threads do not race
            # each other through schema creation.
            with cls._main_lock:
                instance._connect_db(is_initial_setup=True)
            cls._thread_local.instance = instance
        return instance

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    def _connect_db(self, is_initial_setup=False):
        """
        Open the SQLite connection and apply the connection pragmas.

        Args:
            is_initial_setup: When true, also create any missing tables.

        Raises:
            sqlite3.Error: If the database cannot be opened or configured.
        """
        try:
            # dirname() is "" for bare filenames and ":memory:"; makedirs("")
            # would raise, so only create a directory when there is one.
            parent_dir = os.path.dirname(self.db_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            self.conn = sqlite3.connect(self.db_path, timeout=60.0)
            self.conn.execute("PRAGMA journal_mode=WAL")  # Write-Ahead Logging for better concurrency
            self.conn.execute("PRAGMA synchronous=NORMAL")  # Faster writes with reasonable safety
            self.conn.execute("PRAGMA cache_size=-2000000")  # Negative value = KiB, i.e. up to ~2GB of page cache
            self.cursor = self.conn.cursor()

            if is_initial_setup:
                self._setup_tables()
        except sqlite3.Error as e:
            print(f"Database connection error: {e}")
            # Do not keep a half-configured connection around.
            self._discard_connection()
            raise

    def _discard_connection(self):
        """Drop the current connection, ignoring errors raised while closing it."""
        conn, self.conn, self.cursor = self.conn, None, None
        if conn is not None:
            try:
                conn.close()
            except sqlite3.Error:
                pass

    def _ensure_connection(self):
        """Reopen the connection if it was closed (e.g. by ``close()`` or ``with``)."""
        if self.conn is None or self.cursor is None:
            # Setup is idempotent, so it is safe to run it again on reconnect.
            self._connect_db(is_initial_setup=True)

    def _write(self, sql: str, params: tuple = ()) -> None:
        """Execute one modifying statement and commit it; roll back on failure."""
        self._ensure_connection()
        try:
            self.cursor.execute(sql, params)
            self.conn.commit()
        except sqlite3.Error:
            # Leave no implicit transaction open: it would hold the write lock
            # and block other connections until the next commit.
            self.conn.rollback()
            raise

    def _fetch_one(self, sql: str, params: tuple = ()):
        """Execute a query and return its first row, or ``None`` if there is none."""
        self._ensure_connection()
        self.cursor.execute(sql, params)
        return self.cursor.fetchone()

    @staticmethod
    def _dump_metadata(metadata: Optional[Dict]) -> Optional[str]:
        """Serialise metadata to JSON; empty or missing metadata is stored as NULL."""
        return json.dumps(metadata) if metadata else None

    @staticmethod
    def _decode_payload(payload):
        """Return ``payload`` parsed as JSON when possible, otherwise unchanged."""
        try:
            return json.loads(payload)
        except (ValueError, TypeError, RecursionError):
            # ValueError covers invalid JSON and undecodable bytes; TypeError
            # covers NULL columns. Anything that is not JSON is returned raw.
            return payload

    def tensor_exists(self, tensor_id: str) -> bool:
        """
        Check whether a tensor with the given ID exists in storage.

        Looks in ``tensor_storage`` (written by ``store_tensor``) and in the
        ``tensors`` table.

        Returns:
            True if a matching row exists. False if not, or if the query
            failed with a SQLite error (a reconnect is attempted in that case).

        Raises:
            RuntimeError: If a non-SQLite error occurs while connecting/querying.
            sqlite3.Error: If the reconnect attempt itself fails.
        """
        try:
            row = self._fetch_one(
                """
                SELECT EXISTS(SELECT 1 FROM tensor_storage WHERE key = ?)
                    OR EXISTS(SELECT 1 FROM tensors WHERE id = ?)
                """,
                (tensor_id, tensor_id),
            )
            return bool(row and row[0])
        except sqlite3.Error:
            # Try to reconnect; close the old handle first so it is not leaked,
            # and re-run setup in case the failure was a missing table.
            self._discard_connection()
            self._connect_db(is_initial_setup=True)
            return False
        except Exception as e:
            raise RuntimeError(f"Failed to connect to database: {str(e)}") from e

    def cleanup(self):
        """
        Perform storage cleanup and optimization.

        Does nothing if called again within ``_cleanup_interval`` seconds of
        the previous successful run, or if the connection is closed. Errors
        are printed, not raised.
        """
        current_time = time.time()
        if current_time - self._last_cleanup < self._cleanup_interval:
            return
        if self.conn is None or self.cursor is None:
            return

        try:
            # Optimize database
            self.cursor.execute("PRAGMA optimize")
            self.cursor.execute("PRAGMA incremental_vacuum")

            # Ask SQLite to release as much memory as it can
            self.cursor.execute("PRAGMA shrink_memory")

            # Reset WAL file if it's too large
            wal_path = self.db_path + "-wal"
            if os.path.exists(wal_path) and os.path.getsize(wal_path) > 1_000_000_000:  # 1GB
                self.cursor.execute("PRAGMA wal_checkpoint(TRUNCATE)")

            # Force Python garbage collection
            gc.collect()

            self._last_cleanup = current_time

        except Exception as e:
            print(f"Cleanup error: {e}")

    def __enter__(self):
        """Context manager entry; returns the manager itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: run cleanup, then close the connection."""
        self.cleanup()
        # Use close() so conn/cursor are reset and the per-thread instance can
        # reconnect on its next use instead of holding a dead connection.
        self.close()

    def _setup_tables(self):
        """Create every table and trigger this class relies on, if missing."""
        # NOTE: the script is deliberately idempotent (no DROP statements)
        # because it runs each time a thread opens its connection.
        self.cursor.executescript("""
            -- Tensor tables keyed by id
            CREATE TABLE IF NOT EXISTS tensors (
                id TEXT PRIMARY KEY,
                data BLOB,
                metadata TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            CREATE TABLE IF NOT EXISTS tensor_chunks (
                chunk_id TEXT PRIMARY KEY,
                tensor_id TEXT,
                chunk_data BLOB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (tensor_id) REFERENCES tensors(id)
            );

            -- Main storage table for binary data
            CREATE TABLE IF NOT EXISTS storage_data (
                key TEXT PRIMARY KEY,
                data BLOB,
                metadata TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            -- Table for tensor data
            CREATE TABLE IF NOT EXISTS tensor_storage (
                key TEXT PRIMARY KEY,
                data BLOB,
                shape TEXT,
                dtype TEXT,
                metadata TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            -- Table for model states
            CREATE TABLE IF NOT EXISTS model_states (
                model_id TEXT PRIMARY KEY,
                state_data BLOB,
                config TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            -- Cache table for computation results
            CREATE TABLE IF NOT EXISTS computation_cache (
                cache_key TEXT PRIMARY KEY,
                result BLOB,
                metadata TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            );

            -- Mining-specific tables
            CREATE TABLE IF NOT EXISTS cpu_config (
                cpu_id INTEGER PRIMARY KEY,
                gate_delay REAL,
                switch_freq REAL,
                drift_speed REAL,
                clock_freq REAL,
                group_type INTEGER,
                core_count INTEGER,
                thread_count INTEGER,
                initialized INTEGER DEFAULT 0,
                last_updated INTEGER DEFAULT 0
            );

            CREATE TABLE IF NOT EXISTS core_state (
                core_id INTEGER PRIMARY KEY,
                cpu_id INTEGER,
                current_temp REAL DEFAULT 0,
                power_state INTEGER DEFAULT 0,
                FOREIGN KEY(cpu_id) REFERENCES cpu_config(cpu_id)
            );

            CREATE TABLE IF NOT EXISTS thread_state (
                thread_id INTEGER PRIMARY KEY,
                core_id INTEGER,
                cpu_id INTEGER,
                current_op INTEGER DEFAULT 0,
                cycles_left INTEGER DEFAULT 0,
                FOREIGN KEY(core_id) REFERENCES core_state(core_id)
            );

            CREATE TABLE IF NOT EXISTS mining_stats (
                cpu_id INTEGER,
                core_id INTEGER,
                thread_id INTEGER,
                hashes INTEGER DEFAULT 0,
                blocks_found INTEGER DEFAULT 0,
                last_hash_time REAL,
                PRIMARY KEY (cpu_id, core_id, thread_id)
            );

            CREATE TABLE IF NOT EXISTS cpu_system_config (
                id INTEGER PRIMARY KEY,
                total_cpus INTEGER,
                cores_per_cpu INTEGER,
                threads_per_core INTEGER,
                batch_size INTEGER
            );

            -- Cleanup and trigger for automatic maintenance
            CREATE TRIGGER IF NOT EXISTS cleanup_old_data 
            AFTER INSERT ON storage_data
            BEGIN
                DELETE FROM storage_data 
                WHERE key NOT IN (
                    SELECT key FROM storage_data 
                    ORDER BY updated_at DESC 
                    LIMIT 1000000
                );
            END;
        """)
        self.conn.commit()

    # ------------------------------------------------------------------
    # Tensors
    # ------------------------------------------------------------------

    def store_tensor(self, key: str, tensor: np.ndarray, metadata: Optional[Dict] = None) -> bool:
        """
        Store a tensor (raw bytes plus shape and dtype) under ``key``.

        An existing row with the same key is replaced.

        Returns:
            True on success, False if serialisation or the write failed.
        """
        try:
            params = (
                key,
                tensor.tobytes(),
                json.dumps(tensor.shape),
                str(tensor.dtype),
                self._dump_metadata(metadata),
            )
            self._write("""
                INSERT OR REPLACE INTO tensor_storage 
                (key, data, shape, dtype, metadata, updated_at)
                VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
            """, params)
            return True
        except Exception as e:
            print(f"Error storing tensor: {str(e)}")
            return False

    def load_tensor(self, key: str) -> Optional[np.ndarray]:
        """
        Load the tensor stored under ``key``.

        Returns:
            The array rebuilt with ``np.frombuffer`` (a read-only view over
            the stored bytes; copy it before mutating), or None if the key is
            missing or loading failed.
        """
        try:
            row = self._fetch_one("""
                SELECT data, shape, dtype FROM tensor_storage WHERE key = ?
            """, (key,))
            if row is None:
                return None
            raw_bytes, shape_json, dtype_name = row
            return np.frombuffer(raw_bytes, dtype=dtype_name).reshape(json.loads(shape_json))
        except Exception as e:
            print(f"Error loading tensor: {str(e)}")
            return None

    # ------------------------------------------------------------------
    # Model states
    # ------------------------------------------------------------------

    def store_model(self, model_id: str, state_data: Union[bytes, Dict], 
                   model_config: Optional[Dict] = None) -> bool:
        """
        Store a model state under ``model_id``; dicts are saved as JSON bytes.

        Returns:
            True on success, False if serialisation or the write failed.
        """
        try:
            if isinstance(state_data, dict):
                state_data = json.dumps(state_data).encode()

            self._write("""
                INSERT OR REPLACE INTO model_states 
                (model_id, state_data, config, updated_at)
                VALUES (?, ?, ?, CURRENT_TIMESTAMP)
            """, (model_id, state_data, self._dump_metadata(model_config)))
            return True
        except Exception as e:
            print(f"Error storing model: {str(e)}")
            return False

    def load_model(self, model_id: str) -> Optional[Union[bytes, Dict]]:
        """
        Load the model state stored under ``model_id``.

        Returns:
            The JSON-decoded value if the stored data parses as JSON, the raw
            stored value otherwise, or None if missing or on error.
        """
        try:
            row = self._fetch_one("""
                SELECT state_data FROM model_states WHERE model_id = ?
            """, (model_id,))
            return None if row is None else self._decode_payload(row[0])
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            return None

    # ------------------------------------------------------------------
    # Generic data
    # ------------------------------------------------------------------

    def store_data(self, key: str, data: Union[bytes, Dict], 
                  metadata: Optional[Dict] = None) -> bool:
        """
        Store generic data under ``key``; dicts are saved as JSON bytes.

        Returns:
            True on success, False if serialisation or the write failed.
        """
        try:
            if isinstance(data, dict):
                data = json.dumps(data).encode()

            self._write("""
                INSERT OR REPLACE INTO storage_data 
                (key, data, metadata, updated_at)
                VALUES (?, ?, ?, CURRENT_TIMESTAMP)
            """, (key, data, self._dump_metadata(metadata)))
            return True
        except Exception as e:
            print(f"Error storing data: {str(e)}")
            return False

    def load_data(self, key: str) -> Optional[Union[bytes, Dict]]:
        """
        Load generic data stored under ``key``.

        Returns:
            The JSON-decoded value if the stored data parses as JSON, the raw
            stored value otherwise, or None if missing or on error.
        """
        try:
            row = self._fetch_one("""
                SELECT data FROM storage_data WHERE key = ?
            """, (key,))
            return None if row is None else self._decode_payload(row[0])
        except Exception as e:
            print(f"Error loading data: {str(e)}")
            return None

    # ------------------------------------------------------------------
    # Computation cache
    # ------------------------------------------------------------------

    def cache_computation(self, key: str, result: Union[bytes, Dict],
                        metadata: Optional[Dict] = None) -> bool:
        """
        Cache a computation result under ``key``; dicts are saved as JSON bytes.

        Returns:
            True on success, False if serialisation or the write failed.
        """
        try:
            if isinstance(result, dict):
                result = json.dumps(result).encode()

            self._write("""
                INSERT OR REPLACE INTO computation_cache 
                (cache_key, result, metadata, last_accessed)
                VALUES (?, ?, ?, CURRENT_TIMESTAMP)
            """, (key, result, self._dump_metadata(metadata)))
            return True
        except Exception as e:
            print(f"Error caching computation: {str(e)}")
            return False

    def get_cached_computation(self, key: str) -> Optional[Union[bytes, Dict]]:
        """
        Return the cached result for ``key`` and refresh its ``last_accessed``.

        Returns:
            The JSON-decoded value if the stored data parses as JSON, the raw
            stored value otherwise, or None if missing or on error.
        """
        try:
            # SELECT then UPDATE (instead of UPDATE ... RETURNING) so this also
            # works on SQLite builds older than 3.35.
            row = self._fetch_one("""
                SELECT result FROM computation_cache WHERE cache_key = ?
            """, (key,))
            if row is None:
                return None

            # _write commits, so the touch does not leave a write transaction
            # open that would block other connections.
            self._write("""
                UPDATE computation_cache 
                SET last_accessed = CURRENT_TIMESTAMP
                WHERE cache_key = ?
            """, (key,))
            return self._decode_payload(row[0])
        except Exception as e:
            print(f"Error getting cached computation: {str(e)}")
            return None

    def cleanup_cache(self, older_than_days: int = 7):
        """Delete cache entries last accessed more than ``older_than_days`` days ago."""
        try:
            self._write("""
                DELETE FROM computation_cache 
                WHERE last_accessed < datetime('now', ?)
            """, (f'-{older_than_days} days',))
        except Exception as e:
            print(f"Error cleaning cache: {str(e)}")

    # ------------------------------------------------------------------
    # Shutdown
    # ------------------------------------------------------------------

    def close(self):
        """Close the database connection; the next operation reopens it."""
        conn = self.conn
        # Reset state first so the manager is reusable even if close() raises.
        self.conn = None
        self.cursor = None
        if conn is not None:
            conn.close()

    def __del__(self):
        """Best-effort close of the connection when the manager is collected."""
        if getattr(self, "conn", None) is None:
            return
        try:
            self.close()
        except Exception:
            # Finalizers may run on a different thread than the one that
            # created the connection (sqlite3 rejects that) or during
            # interpreter shutdown; there is nothing useful to do then.
            pass
