"""
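Local storage for NAND cell states and memory-interface telemetry (PCIe
transfers, DMA operations, QoS metrics). Data is persisted in a local DuckDB
database file (db/vram/vram.db), falling back to an in-memory DuckDB database
when the connection cannot be established.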
"""

import json
import logging
import os
import threading
import time
import uuid
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional, List

import duckdb
import numpy as np

from http_storage import LocalStorage



@dataclass()
class CellState:
    """Snapshot of one NAND cell, mirroring a row of the ``cell_states`` table.

    Field order is significant: it defines the positional constructor
    signature and matches the column order used by the storage layer.
    """

    # Identity and location of the cell.
    cell_id: str
    block_id: int
    page_id: int

    # Cell contents and wear-related metrics, as supplied by the caller.
    value: int
    trapped_electrons: int
    wear_count: int
    retention_loss: float

    # Operating conditions recorded with the snapshot.
    temperature: float
    voltage_level: float

    # List of floats; the storage layer serialises it to JSON.
    quantum_state: List[float]

    # Snapshot time; the storage layer writes time.time() epoch seconds on updates.
    timestamp: float

class RemoteStorageManager:
    """Singleton that persists NAND cell state and PCIe/DMA/QoS telemetry in DuckDB.

    Despite the name, data is written to a local DuckDB file
    (``db/vram/vram.db``). If that file cannot be created or opened, an
    in-memory database is used instead and its contents are lost when the
    process exits. The ``httpfs`` extension is loaded on a best-effort basis
    and, when a HuggingFace token is available (the ``HF_TOKEN`` class
    attribute, else the ``HF_TOKEN`` environment variable), DuckDB's S3
    access key id and secret are both set to that token.

    All public read/write methods execute their statements while holding the
    class-level ``_lock``.
    """

    _instance = None
    _lock = threading.Lock()

    # Optional HuggingFace token. When left as None, the HF_TOKEN environment
    # variable is consulted each time a connection is opened.
    HF_TOKEN: Optional[str] = None

    # Location of the persistent database file; the directory is created on demand.
    _DB_DIR = "db/vram"
    _DB_FILE = "vram.db"

    # cell_states.timestamp is a DOUBLE (epoch seconds, see update_cell_value),
    # whereas the telemetry tables declare SQL TIMESTAMP columns. Cutoffs must
    # be bound with a matching type because DuckDB does not cast DOUBLE to
    # TIMESTAMP implicitly.
    _EPOCH_TIMESTAMP_TABLES = ('cell_states',)
    _SQL_TIMESTAMP_TABLES = ('pcie_transfers', 'dma_operations', 'qos_metrics')

    # Explicit column order shared by cell_states reads and writes, so they do
    # not depend on the physical column order of a pre-existing table.
    _CELL_COLUMNS = (
        "cell_id, block_id, page_id, value, trapped_electrons, wear_count, "
        "retention_loss, temperature, voltage_level, quantum_state, timestamp"
    )

    def __new__(cls, dataset_path: str = "hf://datasets/Fred808/helium/storage.json"):
        """Return the shared instance, creating and initialising it on first use.

        Args:
            dataset_path: Stored as ``self.dataset_path`` on the first call only;
                it is not used to open the database.
        """
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._init_storage(dataset_path)
                # Publish only after init succeeds so a failed init does not
                # leave a half-initialised singleton for later callers.
                cls._instance = instance
            return cls._instance

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def _init_storage(self, dataset_path: str):
        """Record configuration and open the database connection.

        Args:
            dataset_path: Kept as ``self.dataset_path`` for callers that read it.
        """
        self.dataset_path = dataset_path
        self.db_path = os.path.join(self._DB_DIR, self._DB_FILE)
        self.con = None
        self._connect()

    def _connect(self):
        """Open the local database (or an in-memory fallback), then configure it.

        After connecting, httpfs/credentials are configured (best effort) and
        the schema is created if missing, so a reconnect yields a usable
        database with all tables present.
        """
        try:
            db_dir = os.path.dirname(self.db_path)
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)
            con = duckdb.connect(self.db_path)
            logging.info("Connected to local database: %s", self.db_path)
        except Exception as e:
            logging.warning("Failed to connect to local database: %s", e)
            logging.info("Using in-memory database instead")
            con = duckdb.connect(":memory:")
        self.con = con
        self._configure_httpfs()
        self._create_schema()

    def _configure_httpfs(self):
        """Best-effort: load httpfs and, if a token is available, set S3 credentials.

        Failures to install/load the extension are logged, not raised, so the
        local database stays usable without network access. Credentials are
        only set once httpfs is loaded, since the settings belong to it.
        """
        try:
            self.con.execute("INSTALL httpfs;")
            self.con.execute("LOAD httpfs;")
        except Exception as e:
            logging.warning("Failed to load httpfs extension: %s", e)
            return

        token = self.HF_TOKEN or os.environ.get("HF_TOKEN")
        if not token:
            return
        # SET does not accept bound parameters, so quote the literal manually.
        literal = self._quote_sql_literal(token)
        self.con.execute(f"SET s3_access_key_id={literal};")
        self.con.execute(f"SET s3_secret_access_key={literal};")

    def _create_schema(self):
        """Create all tables and indexes if they do not already exist."""
        statements = (
            """
            CREATE TABLE IF NOT EXISTS cell_states (
                cell_id VARCHAR,
                block_id INTEGER,
                page_id INTEGER,
                value INTEGER,
                trapped_electrons INTEGER,
                wear_count INTEGER,
                retention_loss DOUBLE,
                temperature DOUBLE,
                voltage_level DOUBLE,
                quantum_state JSON,
                timestamp DOUBLE,
                PRIMARY KEY (cell_id)
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS pcie_interface_state (
                id VARCHAR PRIMARY KEY,
                version VARCHAR,
                lanes INTEGER,
                max_gbps DOUBLE,
                active_lanes INTEGER,
                lane_groups JSON,
                lane_errors JSON,
                qos_profiles JSON,
                bandwidth_allocations JSON,
                timestamp TIMESTAMP
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS pcie_transfers (
                id VARCHAR PRIMARY KEY,
                timestamp TIMESTAMP,
                size_bytes BIGINT,
                direction VARCHAR,
                qos_profile_id INTEGER,
                transfer_time DOUBLE,
                lanes_active INTEGER,
                bandwidth_achieved DOUBLE
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS dma_operations (
                id VARCHAR PRIMARY KEY,
                timestamp TIMESTAMP,
                source_addr BIGINT,
                dest_addr BIGINT,
                size_bytes BIGINT,
                priority INTEGER,
                completion_time DOUBLE,
                status VARCHAR
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS qos_metrics (
                timestamp TIMESTAMP,
                profile_id INTEGER,
                bandwidth_allocated DOUBLE,
                bandwidth_used DOUBLE,
                latency_measured DOUBLE,
                latency_target DOUBLE
            )
            """,
            "CREATE INDEX IF NOT EXISTS idx_block ON cell_states(block_id)",
            "CREATE INDEX IF NOT EXISTS idx_page ON cell_states(page_id)",
            "CREATE INDEX IF NOT EXISTS idx_pcie_timestamp ON pcie_transfers(timestamp)",
            "CREATE INDEX IF NOT EXISTS idx_dma_timestamp ON dma_operations(timestamp)",
            "CREATE INDEX IF NOT EXISTS idx_qos_profile ON qos_metrics(profile_id)",
        )
        for statement in statements:
            self.con.execute(statement)

    def ensure_connection(self):
        """Probe the connection and reconnect if it is unusable.

        Reconnecting reopens the same local database (or a fresh in-memory
        one if that fails) and re-creates the schema; the broken connection
        is closed first.
        """
        try:
            self.con.execute("SELECT 1")
            return
        except Exception as e:
            logging.warning("Database connection check failed (%s); reconnecting", e)

        old_con = getattr(self, "con", None)
        if old_con is not None:
            try:
                old_con.close()
            except Exception:
                # Best effort only: the connection is already unusable.
                pass
        self._connect()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _quote_sql_literal(value: str) -> str:
        """Return *value* as a single-quoted SQL string literal (embedded quotes doubled)."""
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _sql_timestamp(epoch_seconds: float) -> datetime:
        """Convert epoch seconds into a naive ``datetime`` for TIMESTAMP comparisons.

        NOTE(review): assumes the telemetry tables store naive local times
        (e.g. values from ``datetime.now()``); switch to a UTC conversion if
        callers store UTC.
        """
        return datetime.fromtimestamp(epoch_seconds)

    def _execute(self, sql: str, params=None):
        """Run a statement under the lock after verifying the connection."""
        with self._lock:
            self.ensure_connection()
            self.con.execute(sql, params)

    def _fetchone(self, sql: str, params=None):
        """Run a query under the lock and return its first row (or None)."""
        with self._lock:
            self.ensure_connection()
            return self.con.execute(sql, params).fetchone()

    def _fetchall(self, sql: str, params=None):
        """Run a query under the lock and return all rows."""
        with self._lock:
            self.ensure_connection()
            return self.con.execute(sql, params).fetchall()

    @staticmethod
    def _row_to_cell_state(row) -> "CellState":
        """Build a CellState from a row selected in ``_CELL_COLUMNS`` order."""
        (cell_id, block_id, page_id, value, trapped_electrons, wear_count,
         retention_loss, temperature, voltage_level, quantum_json, timestamp) = row
        return CellState(
            cell_id=cell_id,
            block_id=block_id,
            page_id=page_id,
            value=value,
            trapped_electrons=trapped_electrons,
            wear_count=wear_count,
            retention_loss=retention_loss,
            temperature=temperature,
            voltage_level=voltage_level,
            # A NULL JSON column decodes to an empty state instead of crashing json.loads.
            quantum_state=json.loads(quantum_json) if quantum_json is not None else [],
            timestamp=timestamp,
        )

    # ------------------------------------------------------------------
    # Cell state
    # ------------------------------------------------------------------

    def store_cell_state(self, state: "CellState"):
        """Insert the cell's row, replacing any existing row with the same cell_id.

        ``quantum_state`` is serialised to JSON.
        """
        self._execute(
            f"INSERT OR REPLACE INTO cell_states ({self._CELL_COLUMNS}) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                state.cell_id,
                state.block_id,
                state.page_id,
                state.value,
                state.trapped_electrons,
                state.wear_count,
                state.retention_loss,
                state.temperature,
                state.voltage_level,
                json.dumps(state.quantum_state),
                state.timestamp,
            ),
        )

    def get_cell_state(self, cell_id: str) -> Optional["CellState"]:
        """Return the stored state for *cell_id*, or None if it is not present."""
        row = self._fetchone(
            f"SELECT {self._CELL_COLUMNS} FROM cell_states WHERE cell_id = ?",
            [cell_id],
        )
        return self._row_to_cell_state(row) if row else None

    def get_block_states(self, block_id: int) -> List["CellState"]:
        """Return all cell states in *block_id*, ordered by page_id then cell_id."""
        rows = self._fetchall(
            f"SELECT {self._CELL_COLUMNS} FROM cell_states "
            "WHERE block_id = ? ORDER BY page_id, cell_id",
            [block_id],
        )
        return [self._row_to_cell_state(row) for row in rows]

    def update_cell_value(self, cell_id: str, value: int, quantum_state: List[float]):
        """Set value and quantum_state for *cell_id* and stamp it with time.time().

        Does nothing if no row with *cell_id* exists.
        """
        self._execute(
            """
            UPDATE cell_states
            SET value = ?,
                quantum_state = ?,
                timestamp = ?
            WHERE cell_id = ?
            """,
            (value, json.dumps(quantum_state), time.time(), cell_id),
        )

    def get_block_wear_stats(self, block_id: int) -> Dict[str, float]:
        """Return wear aggregates for *block_id*.

        Keys: avg_wear, max_wear, avg_retention_loss, cell_count. The
        aggregates are None (SQL NULL) when the block has no cells.
        """
        result = self._fetchone(
            """
            SELECT
                AVG(wear_count) AS avg_wear,
                MAX(wear_count) AS max_wear,
                AVG(retention_loss) AS avg_retention_loss,
                COUNT(*) AS cell_count
            FROM cell_states
            WHERE block_id = ?
            """,
            [block_id],
        )
        return {
            'avg_wear': result[0],
            'max_wear': result[1],
            'avg_retention_loss': result[2],
            'cell_count': result[3],
        }

    def cleanup_old_states(self, max_age_hours: float = 24.0):
        """Delete rows older than *max_age_hours* from every table except pcie_interface_state.

        The cutoff is bound as epoch seconds for cell_states and as a
        ``datetime`` for the TIMESTAMP-typed telemetry tables.
        """
        cutoff_time = time.time() - max_age_hours * 3600
        cutoff_ts = self._sql_timestamp(cutoff_time)
        with self._lock:
            self.ensure_connection()
            # Table names come from fixed class constants, never from input.
            for table in self._EPOCH_TIMESTAMP_TABLES:
                self.con.execute(f"DELETE FROM {table} WHERE timestamp < ?", [cutoff_time])
            for table in self._SQL_TIMESTAMP_TABLES:
                self.con.execute(f"DELETE FROM {table} WHERE timestamp < ?", [cutoff_ts])

    # ------------------------------------------------------------------
    # PCIe / DMA / QoS telemetry
    # ------------------------------------------------------------------

    def store_interface_state(self, state: Dict):
        """Store a PCIe interface snapshot, replacing any row with the same id.

        A random UUID is used when *state* has no ``id``. The lane/QoS
        collections are serialised to JSON; ``state['timestamp']`` is bound
        as-is to a TIMESTAMP column.
        """
        state_id = state.get('id')
        if state_id is None:
            # UUIDs avoid the collisions a time.time()-based id had for
            # snapshots stored within the clock's resolution.
            state_id = str(uuid.uuid4())
        self._execute(
            """
            INSERT OR REPLACE INTO pcie_interface_state (
                id, version, lanes, max_gbps, active_lanes, lane_groups,
                lane_errors, qos_profiles, bandwidth_allocations, timestamp
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                state_id,
                state['version'],
                state['lanes'],
                state['max_gbps'],
                state['active_lanes'],
                json.dumps(state['lane_groups']),
                json.dumps(state['lane_errors']),
                json.dumps(state['qos_profiles']),
                json.dumps(state['bandwidth_allocations']),
                state['timestamp'],
            ),
        )

    def store_transfer(self, transfer: Dict):
        """Append a PCIe transfer record under a fresh UUID primary key."""
        self._execute(
            """
            INSERT INTO pcie_transfers (
                id, timestamp, size_bytes, direction, qos_profile_id,
                transfer_time, lanes_active, bandwidth_achieved
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                str(uuid.uuid4()),
                transfer['timestamp'],
                transfer['size_bytes'],
                transfer['direction'],
                transfer['qos_profile_id'],
                transfer['transfer_time'],
                transfer['lanes_active'],
                transfer['bandwidth_achieved'],
            ),
        )

    def store_dma_operation(self, dma: Dict):
        """Append a DMA operation record under a fresh UUID primary key."""
        self._execute(
            """
            INSERT INTO dma_operations (
                id, timestamp, source_addr, dest_addr, size_bytes,
                priority, completion_time, status
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                str(uuid.uuid4()),
                dma['timestamp'],
                dma['source_addr'],
                dma['dest_addr'],
                dma['size_bytes'],
                dma['priority'],
                dma['completion_time'],
                dma['status'],
            ),
        )

    def store_qos_metrics(self, metrics: Dict):
        """Append one QoS metrics sample."""
        self._execute(
            """
            INSERT INTO qos_metrics (
                timestamp, profile_id, bandwidth_allocated, bandwidth_used,
                latency_measured, latency_target
            )
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                metrics['timestamp'],
                metrics['profile_id'],
                metrics['bandwidth_allocated'],
                metrics['bandwidth_used'],
                metrics['latency_measured'],
                metrics['latency_target'],
            ),
        )

    def get_transfer_stats(self, time_window: float = 3600) -> Dict:
        """Return PCIe transfer aggregates for the last *time_window* seconds.

        Keys: total_transfers, total_bytes, avg_transfer_time, avg_bandwidth.
        """
        cutoff = self._sql_timestamp(time.time() - time_window)
        result = self._fetchone(
            """
            SELECT
                COUNT(*) AS total_transfers,
                SUM(size_bytes) AS total_bytes,
                AVG(transfer_time) AS avg_transfer_time,
                AVG(bandwidth_achieved) AS avg_bandwidth
            FROM pcie_transfers
            WHERE timestamp > ?
            """,
            [cutoff],
        )
        return {
            'total_transfers': result[0],
            'total_bytes': result[1],
            'avg_transfer_time': result[2],
            'avg_bandwidth': result[3],
        }

    def get_dma_stats(self, time_window: float = 3600) -> Dict:
        """Return DMA aggregates for the last *time_window* seconds.

        Keys: total_operations, total_bytes, avg_completion_time and
        successful_ops (rows whose status is exactly 'completed').
        """
        cutoff = self._sql_timestamp(time.time() - time_window)
        result = self._fetchone(
            """
            SELECT
                COUNT(*) AS total_operations,
                SUM(size_bytes) AS total_bytes,
                AVG(completion_time) AS avg_completion_time,
                COUNT(CASE WHEN status = 'completed' THEN 1 END) AS successful_ops
            FROM dma_operations
            WHERE timestamp > ?
            """,
            [cutoff],
        )
        return {
            'total_operations': result[0],
            'total_bytes': result[1],
            'avg_completion_time': result[2],
            'successful_ops': result[3],
        }

    def get_qos_compliance(self, time_window: float = 3600) -> Dict:
        """Return per-profile QoS compliance for the last *time_window* seconds.

        Maps profile_id -> {'bandwidth_utilization', 'sla_compliance'}, where
        sla_compliance is the fraction of samples with latency_measured <=
        latency_target. Samples with zero allocated bandwidth are excluded
        from the utilisation average instead of dividing by zero.
        """
        cutoff = self._sql_timestamp(time.time() - time_window)
        rows = self._fetchall(
            """
            SELECT
                profile_id,
                AVG(bandwidth_used / NULLIF(bandwidth_allocated, 0)) AS bandwidth_utilization,
                AVG(CASE WHEN latency_measured <= latency_target THEN 1 ELSE 0 END) AS sla_compliance
            FROM qos_metrics
            WHERE timestamp > ?
            GROUP BY profile_id
            """,
            [cutoff],
        )
        return {
            profile_id: {
                'bandwidth_utilization': utilization,
                'sla_compliance': compliance,
            }
            for profile_id, utilization, compliance in rows
        }
