import os
from typing import Dict, Any, Optional, Union
import time
import json
import sqlite3
import hashlib
import logging
import numpy as np
from pathlib import Path
from electron_speed import max_switch_freq, drift_velocity, transit_time

# Root logger setup: INFO and above, rendered as "<time> - <level> - <message>"
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

class LocalStorage:
    """SQLite-backed persistence for hardware state.

    Each hardware component occupies one row of the ``hardware_state`` table,
    keyed by its id, with the state dictionary serialized as JSON text.
    """

    def __init__(self, db_path: str = "db/hardware.db"):
        """Open the database at ``db_path``, creating it and its table if needed.

        Args:
            db_path: Location of the SQLite file. A bare filename or the
                special name ``":memory:"`` is accepted as well as a path
                with directories.
        """
        self.db_path = db_path
        # A path without a directory component yields "" here, and
        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when there actually is one.
        parent_dir = os.path.dirname(db_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        self.conn = sqlite3.connect(db_path)
        self._init_db()

    def _init_db(self):
        """Create the ``hardware_state`` table if it does not exist yet."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS hardware_state (
                id TEXT PRIMARY KEY,
                type TEXT NOT NULL,
                state TEXT NOT NULL,
                updated_at REAL NOT NULL
            )
        """)
        self.conn.commit()

    def store_state(self, hardware_id: str, hardware_type: str, state: Dict[str, Any]):
        """Insert or overwrite the stored state for ``hardware_id``.

        Args:
            hardware_id: Primary key of the row.
            hardware_type: Stored in the ``type`` column.
            state: JSON-serializable dictionary.

        Raises:
            TypeError: If ``state`` cannot be serialized by ``json.dumps``.
        """
        # Serialize first so a bad payload never opens a transaction.
        payload = json.dumps(state)
        # The connection context manager commits on success and rolls back
        # if the statement raises.
        with self.conn:
            self.conn.execute("""
                INSERT OR REPLACE INTO hardware_state (id, type, state, updated_at)
                VALUES (?, ?, ?, ?)
            """, (hardware_id, hardware_type, payload, time.time()))

    def load_state(self, hardware_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored state for ``hardware_id``, or ``None`` if absent."""
        row = self.conn.execute("""
            SELECT state FROM hardware_state WHERE id = ?
        """, (hardware_id,)).fetchone()
        if row is None:
            return None
        return json.loads(row[0])

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()

class ElectronBuffer:
    """
    Timing model built from the constants exported by ``electron_speed``.

    The object stores no payload data; it only exposes the imported switching
    frequency, drift velocity and transit time, plus a derived cycle time.
    """

    def __init__(self):
        # Copy the imported constants onto the instance in a single step.
        self.switch_freq, self.drift_speed, self.gate_delay = (
            max_switch_freq,  # Hz (original note: ~9.80e14)
            drift_velocity,   # m/s
            transit_time,     # seconds
        )
        # One cycle is the reciprocal of the switching frequency.
        self.cycle_time = 1.0 / self.switch_freq

    def compute_timing(self, data_size: int) -> float:
        """Return ``data_size`` multiplied by the per-gate delay."""
        per_unit_delay = self.gate_delay
        return data_size * per_unit_delay
        
class DirectStorage:
    """
    Direct electron-speed storage implementation with disk-based persistence.
    Uses memory-mapped files and SQLite for zero-RAM operation.
    """

    def _init_singleton(self):
        """Initialize the singleton instance with disk storage"""
        # Create data directories if they don't exist
        self.data_dir = Path('data')
        self.chunk_dir = self.data_dir / 'chunks'
        self.data_dir.mkdir(exist_ok=True)
        self.chunk_dir.mkdir(exist_ok=True)
        
        # Setup SQLite database with WAL mode for better concurrent access
        self.db_path = self.data_dir / 'storage.db'
        self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        self.conn.execute('PRAGMA journal_mode=WAL')  # Write-Ahead Logging
        self.conn.execute('PRAGMA synchronous=NORMAL')  # Better performance
        self.conn.execute('PRAGMA mmap_size=2147483648')  # 2GB memory map
        self.cursor = self.conn.cursor()
        
        # Create tables for metadata and chunk mapping
        self.cursor.execute('''
            CREATE TABLE IF NOT EXISTS chunks (
                chunk_id TEXT PRIMARY KEY,
                file_path TEXT NOT NULL,
                size INTEGER NOT NULL,
                created_at INTEGER NOT NULL,
                last_accessed INTEGER NOT NULL
            )
        ''')
        
        self.cursor.execute('''
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                chunk_ids TEXT NOT NULL,
                total_size INTEGER NOT NULL,
                created_at INTEGER NOT NULL,
                last_modified INTEGER NOT NULL
            )
        ''')
        self.conn.commit()
        
        # Initialize memory mapping for active chunks
        self.active_mappings = {}
        
    def _create_chunk_file(self, chunk_id: str, size: int) -> Path:
        """Create a new chunk file with specified size"""
        chunk_path = self.chunk_dir / f"{chunk_id}.dat"
        with open(chunk_path, 'wb') as f:
            f.seek(size - 1)
            f.write(b'\0')
        return chunk_path
        
    def _get_chunk_mapping(self, chunk_id: str, size: int) -> np.ndarray:
        """Get memory-mapped array for chunk, creating if needed"""
        import mmap
        
        if chunk_id not in self.active_mappings:
            # Get chunk file path from database
            self.cursor.execute('SELECT file_path FROM chunks WHERE chunk_id = ?', (chunk_id,))
            result = self.cursor.fetchone()
            
            if not result:
                # Create new chunk file
                chunk_path = self._create_chunk_file(chunk_id, size)
                # Register in database
                self.cursor.execute('''
                    INSERT INTO chunks (chunk_id, file_path, size, created_at, last_accessed)
                    VALUES (?, ?, ?, ?, ?)
                ''', (chunk_id, str(chunk_path), size, int(time.time()), int(time.time())))
                self.conn.commit()
            else:
                chunk_path = Path(result[0])
            
            # Create memory mapping
            fd = os.open(str(chunk_path), os.O_RDWR)
            mapping = mmap.mmap(fd, 0, access=mmap.ACCESS_WRITE)
            self.active_mappings[chunk_id] = mapping
            
        # Update last accessed time
        self.cursor.execute('''
            UPDATE chunks SET last_accessed = ? WHERE chunk_id = ?
        ''', (int(time.time()), chunk_id))
        self.conn.commit()
        
        return np.frombuffer(self.active_mappings[chunk_id], dtype=np.uint8)
        
    def store_data(self, key: str, data: Union[bytes, np.ndarray]) -> None:
        """Store data using disk-based chunks"""
        # Convert numpy array to bytes if needed
        if isinstance(data, np.ndarray):
            data = data.tobytes()
            
        # Generate chunk ID from content hash
        chunk_id = hashlib.sha256(data).hexdigest()
        size = len(data)
        
        # Store data in chunk file
        chunk_array = self._get_chunk_mapping(chunk_id, size)
        chunk_array[:size] = np.frombuffer(data, dtype=np.uint8)
        
        # Update metadata
        self.cursor.execute('''
            INSERT OR REPLACE INTO metadata (key, chunk_ids, total_size, created_at, last_modified)
            VALUES (?, ?, ?, ?, ?)
        ''', (key, json.dumps([chunk_id]), size, int(time.time()), int(time.time())))
        self.conn.commit()
        
    def load_data(self, key: str) -> Optional[bytes]:
        """Load data from disk-based chunks"""
        # Get metadata
        self.cursor.execute('SELECT chunk_ids, total_size FROM metadata WHERE key = ?', (key,))
        result = self.cursor.fetchone()
        
        if not result:
            return None
            
        chunk_ids = json.loads(result[0])
        total_size = result[1]
        
        # Load and combine chunks
        data = bytearray()
        for chunk_id in chunk_ids:
            chunk_array = self._get_chunk_mapping(chunk_id, total_size)
            data.extend(chunk_array.tobytes())
            
        return bytes(data)
        
    def cleanup_old_chunks(self, max_age_hours: int = 24) -> None:
        """Remove chunks that haven't been accessed in specified hours"""
        cutoff_time = int(time.time()) - (max_age_hours * 3600)
        
        # Find old chunks
        self.cursor.execute('SELECT chunk_id, file_path FROM chunks WHERE last_accessed < ?', (cutoff_time,))
        old_chunks = self.cursor.fetchall()
        
        for chunk_id, file_path in old_chunks:
            # Remove from active mappings if present
            if chunk_id in self.active_mappings:
                self.active_mappings[chunk_id].close()
                del self.active_mappings[chunk_id]
                
            # Delete file
            try:
                Path(file_path).unlink()
            except Exception as e:
                logging.warning(f"Failed to delete chunk file {file_path}: {e}")
                
            # Remove from database
            self.cursor.execute('DELETE FROM chunks WHERE chunk_id = ?', (chunk_id,))
            
        self.conn.commit()
        
    def optimize_storage(self) -> None:
        """Optimize database and chunk storage"""
        # Vacuum database to reclaim space
        self.cursor.execute('VACUUM')
        
        # Clear unused memory mappings
        current_time = int(time.time())
        for chunk_id in list(self.active_mappings.keys()):
            self.cursor.execute('SELECT last_accessed FROM chunks WHERE chunk_id = ?', (chunk_id,))
            result = self.cursor.fetchone()
            
            if not result or (current_time - result[0]) > 3600:  # 1 hour
                self.active_mappings[chunk_id].close()
                del self.active_mappings[chunk_id]
        
    def close(self) -> None:
        """Close all memory mappings and the database connection.

        NOTE(review): a second ``close`` is defined further down in this
        class; being later in the class body, that definition is the one
        bound to the name, so this body appears to be dead code - confirm
        before relying on it.

        NOTE(review): everything after the connection is closed looks like
        initialization logic (state dicts, counters, table creation) rather
        than cleanup. It uses ``self.cursor`` after ``self.conn.close()`` and
        reads ``self.tables``, which is not assigned anywhere in the visible
        part of this file - presumably misplaced; verify the intended location.
        """
        # Close all memory mappings
        for mapping in self.active_mappings.values():
            mapping.close()
        self.active_mappings.clear()
        
        # Close database connection
        if hasattr(self, 'conn') and self.conn:
            self.conn.close()
        # Short identifier: first 8 hex characters of the MD5 of the database path
        self.storage_id = hashlib.md5(str(self.db_path).encode()).hexdigest()[:8]
        
        # Initialize electron-speed buffer
        self.buffer = ElectronBuffer()
        
        # Direct state tracking
        self.sm_states = {}  # Streaming multiprocessor states
        self.vram_states = {}  # VRAM states
        
        # Performance tracking
        self.op_count = 0
        self.last_op_time = time.time()
        self.cycle_time = 1.0 / self.buffer.switch_freq
        
        # Check if tables exist first
        # NOTE(review): this runs on a cursor whose connection was closed above
        existing_tables = self.cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table'"
        ).fetchall()
        existing_tables = {table[0] for table in existing_tables}
        
        # Add existing tables to our table tracking
        # NOTE(review): self.tables must already exist as a mapping here - confirm where it is created
        for table in existing_tables:
            self.tables[table] = self._get_table_schema(table)
        
        # Only create tables if they don't exist
        if 'sm_states' not in existing_tables:
            self.cursor.execute("""
            CREATE TABLE sm_states (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                sm_id TEXT,
                chip_id TEXT,
                state_json TEXT,
                sm_key TEXT,
                state_data TEXT,
                timestamp INTEGER,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )""")
            self.cursor.execute("CREATE INDEX idx_sm_states_sm_id ON sm_states(sm_id)")
            # Add to our table tracking
            self.tables['sm_states'] = self._get_table_schema('sm_states')
            
        if 'vram_states' not in existing_tables:
            self.cursor.execute("""
            CREATE TABLE vram_states (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                vram_id TEXT UNIQUE,
                state_data TEXT,
                timestamp INTEGER,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )""")
            self.cursor.execute("CREATE INDEX idx_vram_states_vram_id ON vram_states(vram_id)")
            # Add to our table tracking
            self.tables['vram_states'] = self._get_table_schema('vram_states')
            
        self.conn.commit()
        
    def _get_table_schema(self, table_name: str) -> Dict[str, str]:
        """Get the schema of an existing table"""
        try:
            schema = {}
            for row in self.cursor.execute(f"PRAGMA table_info({table_name})"):
                col_name = row[1]  # Column name is the second field
                col_type = row[2]  # Type is the third field
                schema[col_name] = col_type
            return schema
        except sqlite3.Error as e:
            print(f"Error getting schema for table {table_name}: {e}")
            return {}
            
    def store_tensor(self, tensor_id: str, data: Any, shape: tuple) -> bool:
        """Record a tensor in the in-memory ``self.tensor_buffer`` dict.

        Stores ``data``, ``shape`` and the start timestamp under
        ``tensor_id`` and prints a warning when the assignment took longer
        than ``self.tensor_cycle``. Both attributes are created by
        ``init_electron_buffers``.

        NOTE(review): a second ``store_tensor`` with a different signature is
        defined further down in this class; being later in the class body,
        that definition is the one bound to the name, so this body appears
        to be dead code - confirm.

        Returns:
            True on success, False if any exception was raised (it is printed).
        """
        start_time = time.time()
        
        try:
            # Direct electron-speed update
            self.tensor_buffer[tensor_id] = {
                'data': data,
                'shape': shape,
                'timestamp': start_time
            }
            
            # Validate electron timing
            elapsed = time.time() - start_time
            if elapsed > self.tensor_cycle:
                print(f"Warning: Tensor operation exceeded electron cycle time")
            
            return True
        except Exception as e:
            print(f"Error in electron-speed tensor operation: {e}")
            return False
        
        # NOTE(review): unreachable - both branches above return. These lines
        # look like the tail of an initializer; verify where they belong.
        # Initialize database tables
        self._init_db_tables()
        
        self._connected = True
        self.initialized = True
        
    def _init_db_tables(self):
        """Initialize database tables"""
        with self.lock:
            pass  # Skip DB initialization since we're not using storage
            
    def init_electron_buffers(self):
        """Create the empty in-memory buffers and derive the timing values."""
        # Three independent dict buffers: tensors, SM state, VRAM state
        for buffer_name in ('tensor_buffer', 'sm_buffer', 'vram_buffer'):
            setattr(self, buffer_name, {})

        # Timing values derived from the electron_speed constants
        self.tensor_cycle = 1.0 / max_switch_freq
        # NOTE(review): velocity * time is dimensionally a distance, not a latency - confirm intent
        self.buffer_latency = drift_velocity * transit_time
            
            
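    # Read-only statistics and registry properties.
    # NOTE(review): this comment previously said that all storage operations are
    # delegated to a HuggingFace storage manager, but the properties below query
    # the local SQLite cursor directly - confirm which is intended.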
    @property
    def resource_monitor(self):
        """Get resource monitor stats"""
        with self.lock:
            self.cursor.execute("SELECT COUNT(*) FROM tensors")
            tensor_count = self.cursor.fetchone()[0]
            self.cursor.execute("SELECT COUNT(*) FROM models")
            model_count = self.cursor.fetchone()[0]
            
            return {
                'vram_used': 0,  # Local storage doesn't track VRAM
                'active_tensors': tensor_count,
                'loaded_models': model_count,
                'last_updated': time.time()
            }
    
    @property
    def stats(self):
        """Get storage stats"""
        with self.lock:
            self.cursor.execute("SELECT COUNT(*), SUM(LENGTH(data)) FROM tensors")
            tensor_stats = self.cursor.fetchone()
            self.cursor.execute("SELECT COUNT(*), SUM(LENGTH(data)) FROM models")
            model_stats = self.cursor.fetchone()
            
            total_size = (tensor_stats[1] or 0) + (model_stats[1] or 0)
            return {
                'total_size': total_size,
                'available_size': float('inf'),  # Limited by disk space
                'model_count': model_stats[0],
                'tensor_count': tensor_stats[0]
            }
    
    @property
    def model_registry(self):
        """Get model registry"""
        with self.lock:
            self.cursor.execute("SELECT name, config FROM models")
            return {row[0]: json.loads(row[1]) if row[1] else {} 
                   for row in self.cursor.fetchall()}
    
    @property
    def tensor_registry(self):
        """Get tensor registry"""
        with self.lock:
            self.cursor.execute("SELECT id, shape, dtype FROM tensors")
            return {row[0]: {'shape': json.loads(row[1]), 'dtype': row[2]} 
                   for row in self.cursor.fetchall()}
        
    def is_connected(self) -> bool:
        """Check if storage is connected"""
        return self._connected and not self._closing
        
    def close(self):
        """Close storage connection"""
        with self.lock:
            self._closing = True
            self._connected = False
            if hasattr(self, 'conn'):
                self.conn.close()

    def is_model_loaded(self, model_id: str) -> bool:
        """Check if a model is loaded in storage"""
        with self.lock:
            self.cursor.execute("SELECT 1 FROM models WHERE name = ?", (model_id,))
            return bool(self.cursor.fetchone())
        
    def wait_for_connection(self, timeout: float = 30.0) -> bool:
        """Wait for storage connection to be ready"""
        return self.storage_manager.wait_for_connection(timeout)
        
    def __init__(self):
        """This will actually just return the singleton instance. 
        The actual initialization happens in __new__ and _init_singleton"""
        pass

    def _check_storage_ready(self) -> bool:
        """Check if storage is ready for use"""
        return self.storage_manager._check_storage_ready()

    def _check_storage(self) -> Dict[str, Any]:
        """Check storage status and usage"""
        try:
            stats = self.storage_manager.get_stats()
            return {"status": "ok", "monitor": stats.get('resource_monitor', {})}
        except Exception as e:
            logging.error(f"Error checking storage: {e}")
            return {"status": "error", "message": str(e)}

    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
        """Store tensor data in local SQLite database"""
        try:
            with self.lock:
                self.cursor.execute("""
                    INSERT OR REPLACE INTO tensors (id, data, shape, dtype) 
                    VALUES (?, ?, ?, ?)
                """, (
                    tensor_id,
                    data.tobytes(),
                    json.dumps(data.shape),
                    str(data.dtype)
                ))
                self.conn.commit()
                return True
        except Exception as e:
            logging.error(f"Error storing tensor {tensor_id}: {e}")
            return False

    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Load tensor data from local SQLite database"""
        try:
            with self.lock:
                self.cursor.execute("""
                    SELECT data, shape, dtype FROM tensors WHERE id = ?
                """, (tensor_id,))
                row = self.cursor.fetchone()
                if row:
                    data_bytes, shape_str, dtype_str = row
                    shape = json.loads(shape_str)
                    return np.frombuffer(data_bytes, dtype=dtype_str).reshape(shape)
                return None
        except Exception as e:
            logging.error(f"Error loading tensor {tensor_id}: {e}")
            return None

    def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
        """Store component state in local SQLite database"""
        try:
            with self.lock:
                self.cursor.execute("""
                    INSERT OR REPLACE INTO states (component, state_id, data)
                    VALUES (?, ?, ?)
                """, (
                    component,
                    state_id,
                    json.dumps(state_data)
                ))
                self.conn.commit()
                return True
        except Exception as e:
            logging.error(f"Error storing state {component}/{state_id}: {e}")
            return False

    def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
        """Load component state from local SQLite database"""
        try:
            with self.lock:
                self.cursor.execute("""
                    SELECT data FROM states WHERE component = ? AND state_id = ?
                """, (component, state_id))
                row = self.cursor.fetchone()
                if row:
                    return json.loads(row[0])
                return None
        except Exception as e:
            logging.error(f"Error loading state {component}/{state_id}: {e}")
            return None
    
    def load_model(self, model_name: str, model_data: Optional[Union[bytes, Dict]] = None, model_config: Optional[Dict] = None) -> bool:
        """Store/load model in local SQLite database"""
        try:
            with self.lock:
                if model_data is not None:
                    # Store model
                    data_bytes = model_data if isinstance(model_data, bytes) else json.dumps(model_data).encode()
                    self.cursor.execute("""
                        INSERT OR REPLACE INTO models (name, data, config)
                        VALUES (?, ?, ?)
                    """, (
                        model_name,
                        data_bytes,
                        json.dumps(model_config) if model_config else None
                    ))
                    self.conn.commit()
                    return True
                else:
                    # Load model
                    self.cursor.execute("SELECT data FROM models WHERE name = ?", (model_name,))
                    row = self.cursor.fetchone()
                    if row:
                        return row[0]
                    return None
        except Exception as e:
            logging.error(f"Error with model {model_name}: {e}")
            return False

    def ping(self) -> bool:
        """Check if storage is accessible"""
        try:
            with self.lock:
                self.cursor.execute("SELECT 1")
                return bool(self.cursor.fetchone())
        except Exception:
            return False
            
    def __del__(self):
        """Ensure connection is closed on deletion"""
        self.close()

# Compatibility aliases for existing code