"""
Local storage implementation for NAND cell states and memory operations.
Provides persistent storage using SQLite through LocalStorageManager.
"""

import numpy as np
import json
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
import threading
import time
from pathlib import Path
import sqlite3
import os

@dataclass
class CellState:
    """Snapshot of one simulated NAND cell, persisted as a row of ``cell_states``.

    Written by ``LocalStorageManager.store_cell_state`` (INSERT OR REPLACE keyed
    on ``cell_id``, so re-storing an id overwrites the previous row) and rebuilt
    by ``LocalStorageManager.get_cell_state``. Field order mirrors the table's
    column order.
    """

    # Unique cell identifier; primary key of the ``cell_states`` table.
    cell_id: str
    # Block / page identifiers the cell is associated with (stored as-is).
    block_id: int
    page_id: int
    # Stored cell value — presumably a bit or programmed level; TODO confirm encoding.
    value: int
    # Physical-model quantities; units and valid ranges are not defined in this
    # module — TODO confirm with the simulation code that produces them.
    trapped_electrons: int
    wear_count: int
    retention_loss: float
    temperature: float
    voltage_level: float
    # Serialized with json.dumps when stored, so entries must be JSON-serializable.
    quantum_state: List[float]
    # Presumably a time.time() epoch value like the manager's own timestamps — confirm with callers.
    timestamp: float

class LocalStorageManager:
    """Manages local storage operations using SQLite.

    Process-wide singleton: the first ``LocalStorageManager(db_path)`` call opens
    (creating if needed) the SQLite database at ``db_path`` together with its
    tables; later calls return that same instance and ignore their ``db_path``
    argument. After :meth:`close`, the next construction opens a fresh
    connection instead of handing back the closed one.

    The connection is opened with ``check_same_thread=False`` and every public
    method serializes access to it through the class-level ``_lock``. Writes run
    inside ``with self.conn:`` so multi-statement operations are committed
    together, or rolled back together if any statement fails.
    """

    # The single live instance; None before first use and after close().
    _instance = None
    # Guards singleton creation/teardown and every use of the shared connection.
    # Non-reentrant: methods holding it must not call each other.
    _lock = threading.Lock()

    def __new__(cls, db_path: str = "db/vram/storage.db"):
        """Return the shared instance, creating and initializing it on first use.

        Args:
            db_path: SQLite database path. Only used when no instance exists yet;
                bare file names and ``":memory:"`` are accepted.

        Raises:
            OSError: if the database's parent directory cannot be created.
            sqlite3.Error: if the database cannot be opened or initialized.
        """
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                # Publish only after initialization succeeds, so a failed
                # connect/setup does not leave a half-built singleton cached.
                instance._init_storage(db_path)
                cls._instance = instance
            return cls._instance

    def _init_storage(self, db_path: str):
        """Open the SQLite connection at ``db_path`` and ensure the schema exists.

        The parent directory is created when missing. ``os.path.dirname`` is
        empty for bare file names and ``":memory:"``; ``os.makedirs("")`` would
        raise, so directory creation is skipped in that case.
        """
        self.db_path = db_path
        parent = os.path.dirname(db_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        # Shared across threads; access is serialized by ``_lock`` instead of
        # sqlite3's same-thread check.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        try:
            self._setup_tables()
        except Exception:
            # Don't leak the connection when schema setup fails.
            self.conn.close()
            raise

    def _setup_tables(self):
        """Create the ``cell_states``, ``memory_blocks`` and ``memory_pages``
        tables if they do not already exist, then commit."""
        # Cell states table: one row per cell, keyed by cell_id.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS cell_states (
                cell_id TEXT PRIMARY KEY,
                block_id INTEGER,
                page_id INTEGER,
                value INTEGER,
                trapped_electrons INTEGER,
                wear_count INTEGER,
                retention_loss REAL,
                temperature REAL,
                voltage_level REAL,
                quantum_state TEXT,
                timestamp REAL
            )
        """)

        # Memory blocks table: block_id is an INTEGER PRIMARY KEY, so SQLite
        # assigns it automatically on INSERT (see allocate_block).
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS memory_blocks (
                block_id INTEGER PRIMARY KEY,
                allocation_time REAL,
                last_access REAL,
                wear_level INTEGER,
                is_available INTEGER,
                metadata TEXT
            )
        """)

        # Memory pages table: one row per (page_id, block_id). The FOREIGN KEY
        # is only enforced if PRAGMA foreign_keys is enabled, which this class
        # does not do.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS memory_pages (
                page_id INTEGER,
                block_id INTEGER,
                data BLOB,
                last_write REAL,
                error_count INTEGER,
                PRIMARY KEY (page_id, block_id),
                FOREIGN KEY (block_id) REFERENCES memory_blocks(block_id)
            )
        """)

        self.conn.commit()

    def store_cell_state(self, state: CellState):
        """Insert or replace the ``cell_states`` row for ``state.cell_id``.

        ``state.quantum_state`` is stored as JSON text.
        """
        params = (
            state.cell_id,
            state.block_id,
            state.page_id,
            state.value,
            state.trapped_electrons,
            state.wear_count,
            state.retention_loss,
            state.temperature,
            state.voltage_level,
            json.dumps(state.quantum_state),
            state.timestamp,
        )
        with self._lock, self.conn:
            self.conn.execute("""
                INSERT OR REPLACE INTO cell_states
                (cell_id, block_id, page_id, value, trapped_electrons,
                wear_count, retention_loss, temperature, voltage_level,
                quantum_state, timestamp)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, params)

    def get_cell_state(self, cell_id: str) -> Optional[CellState]:
        """Return the stored state for ``cell_id``, or None if there is no row.

        Columns are selected explicitly so the positional unpacking below does
        not depend on the table's physical column order.
        """
        with self._lock:
            row = self.conn.execute("""
                SELECT cell_id, block_id, page_id, value, trapped_electrons,
                       wear_count, retention_loss, temperature, voltage_level,
                       quantum_state, timestamp
                FROM cell_states
                WHERE cell_id = ?
            """, (cell_id,)).fetchone()

        if row is None:
            return None

        return CellState(
            cell_id=row[0],
            block_id=row[1],
            page_id=row[2],
            value=row[3],
            trapped_electrons=row[4],
            wear_count=row[5],
            retention_loss=row[6],
            temperature=row[7],
            voltage_level=row[8],
            quantum_state=json.loads(row[9]),
            timestamp=row[10],
        )

    def allocate_block(self, metadata: Optional[Dict[str, Any]] = None) -> int:
        """Allocate a memory block and return its ``block_id``.

        Reuses the available block with the lowest ``wear_level`` if there is
        one; otherwise inserts a new block with ``wear_level`` 0. Either way the
        block is marked unavailable and its allocation/access times and
        JSON-encoded ``metadata`` (``{}`` when None) are recorded.

        Raises:
            TypeError: if ``metadata`` is not JSON-serializable (before any write).
        """
        payload = json.dumps(metadata or {})
        with self._lock, self.conn:
            now = time.time()
            row = self.conn.execute("""
                SELECT block_id
                FROM memory_blocks
                WHERE is_available = 1
                ORDER BY wear_level ASC
                LIMIT 1
            """).fetchone()

            if row is not None:
                block_id = row[0]
                self.conn.execute("""
                    UPDATE memory_blocks
                    SET is_available = 0,
                        allocation_time = ?,
                        last_access = ?,
                        metadata = ?
                    WHERE block_id = ?
                """, (now, now, payload, block_id))
            else:
                # lastrowid is the new INTEGER PRIMARY KEY; avoids RETURNING,
                # which needs SQLite >= 3.35.
                cursor = self.conn.execute("""
                    INSERT INTO memory_blocks
                    (allocation_time, last_access, wear_level, is_available, metadata)
                    VALUES (?, ?, 0, 0, ?)
                """, (now, now, payload))
                block_id = cursor.lastrowid
        return block_id

    def free_block(self, block_id: int):
        """Mark ``block_id`` as available for reuse (no-op if it does not exist).

        Pages and metadata of the block are left untouched.
        """
        with self._lock, self.conn:
            self.conn.execute("""
                UPDATE memory_blocks
                SET is_available = 1
                WHERE block_id = ?
            """, (block_id,))

    def write_page(self, block_id: int, page_id: int, data: bytes):
        """Store ``data`` for (``block_id``, ``page_id``) and record block wear.

        Replaces any existing page (resetting its ``error_count`` to 0), then
        bumps the block's ``last_access`` and increments its ``wear_level``.
        Both statements commit together or roll back together. The block's
        existence is not checked.
        """
        with self._lock, self.conn:
            now = time.time()
            self.conn.execute("""
                INSERT OR REPLACE INTO memory_pages
                (page_id, block_id, data, last_write, error_count)
                VALUES (?, ?, ?, ?, 0)
            """, (page_id, block_id, data, now))

            self.conn.execute("""
                UPDATE memory_blocks
                SET last_access = ?,
                    wear_level = wear_level + 1
                WHERE block_id = ?
            """, (now, block_id))

    def read_page(self, block_id: int, page_id: int) -> Optional[bytes]:
        """Return the stored data for (``block_id``, ``page_id``), or None.

        When the page exists, the block's ``last_access`` is updated as well.
        """
        with self._lock, self.conn:
            row = self.conn.execute("""
                SELECT data
                FROM memory_pages
                WHERE block_id = ? AND page_id = ?
            """, (block_id, page_id)).fetchone()

            if row is None:
                return None

            self.conn.execute("""
                UPDATE memory_blocks
                SET last_access = ?
                WHERE block_id = ?
            """, (time.time(), block_id))
            return row[0]

    def get_block_info(self, block_id: int) -> Optional[Dict[str, Any]]:
        """Return a dict describing ``block_id``, or None if it does not exist.

        Keys: ``allocation_time``, ``last_access``, ``wear_level``,
        ``is_available`` (as bool) and ``metadata`` (JSON-decoded).
        """
        with self._lock:
            row = self.conn.execute("""
                SELECT allocation_time, last_access, wear_level, is_available, metadata
                FROM memory_blocks
                WHERE block_id = ?
            """, (block_id,)).fetchone()

        if row is None:
            return None

        return {
            "allocation_time": row[0],
            "last_access": row[1],
            "wear_level": row[2],
            "is_available": bool(row[3]),
            "metadata": json.loads(row[4]),
        }

    def close(self):
        """Close the database connection and release the singleton slot.

        Safe to call more than once. Afterwards, constructing
        ``LocalStorageManager`` opens a new connection instead of returning
        this closed instance.
        """
        cls = type(self)
        with cls._lock:
            if cls._instance is self:
                cls._instance = None
            conn = getattr(self, "conn", None)
            if conn is not None:
                conn.close()

    def __del__(self):
        # A finalizer only runs once nothing references this object. The class
        # keeps the live singleton in ``_instance``, so that means after close()
        # or at interpreter shutdown, and no other thread can be using the
        # connection. The non-reentrant ``_lock`` is deliberately not taken here,
        # to avoid deadlocking if GC fires while a thread holds it.
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()
