"""
Enhanced NAND Flash Page with database-backed state tracking and ECC
"""

import time
import json
import logging
import numpy as np
import duckdb
from typing import List, Optional, Tuple, Dict
from datetime import datetime
from huggingface_hub import HfApi, HfFileSystem
from nand_cell import MultiLevelCell
from config import get_hf_token

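# NOTE: the HuggingFace token is expected to come from .env via
# config.get_hf_token(); nothing is initialized at module level here.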



class Page:
    """NAND Flash page with a checksum and database-backed state tracking.

    A page owns ``num_cells`` ``MultiLevelCell`` objects and mirrors its
    bookkeeping (program/erase counters, stored checksum, error statistics)
    into the ``page_states`` table, while every program/erase/read call is
    appended to ``page_operations``. Both tables live in the DuckDB database
    opened at ``STORAGE_URL``.

    The public operations (``program``, ``erase``, ``read``,
    ``get_health_status``) never raise: failures are logged and reported
    through the return value. The constructor, in contrast, re-raises any
    database error after logging it.
    """

    STORAGE_URL = "hf://datasets/Fred808/helium/storage.json"  # Standard DB URL

    # Statements that enable and configure the httpfs extension on a connection.
    _HTTPFS_SETUP = (
        "INSTALL httpfs;",
        "LOAD httpfs;",
        "SET s3_endpoint='hf.co';",
        "SET s3_use_ssl=true;",
        "SET s3_url_style='path';",
    )

    def __init__(self, num_cells: int, channel_length: float,
                drift_velocity: float, levels: int,
                block_id: int = 0, page_id: int = 0):
        """Open the database, create the schema and build the page's cells.

        Args:
            num_cells: Number of cells in the page.
            channel_length: Forwarded unchanged to every ``MultiLevelCell``.
            drift_velocity: Forwarded unchanged to every ``MultiLevelCell``.
            levels: Forwarded unchanged to every ``MultiLevelCell``.
            block_id: Block this page belongs to; part of the state key.
            page_id: Page index; part of the state key.

        Raises:
            Exception: Any error raised while connecting to or preparing the
                database is logged and re-raised.
        """
        self.page_id = page_id
        self.block_id = block_id
        self.num_cells = num_cells

        # Initialize database connection
        self._init_db_connection()
        self._setup_database()

        # Initialize cells with proper IDs
        self.cells = [
            MultiLevelCell(
                channel_length=channel_length,
                drift_velocity=drift_velocity,
                levels=levels,
                block_id=block_id,
                page_id=page_id
            ) for _ in range(num_cells)
        ]

        # Initialize page state
        self._init_page_state()

    def close(self) -> None:
        """Close the database connection if one is open (safe to call twice)."""
        conn = getattr(self, "conn", None)
        if conn is not None:
            self.conn = None
            conn.close()

    @classmethod
    def _configure_httpfs(cls, conn) -> None:
        """Run the httpfs install/load/endpoint statements on ``conn``."""
        for statement in cls._HTTPFS_SETUP:
            conn.execute(statement)

    def _init_db_connection(self):
        """Initialize database connection with HuggingFace configuration.

        The access token is taken from an ``HF_TOKEN`` attribute when a
        subclass or caller provides one, otherwise from ``get_hf_token()``.

        Raises:
            Exception: Any connection/configuration error, after logging it.
        """
        temp_conn = None
        try:
            # First create an in-memory connection to configure settings;
            # this mirrors the original bootstrap sequence.
            temp_conn = duckdb.connect(":memory:")
            self._configure_httpfs(temp_conn)

            # Now create the real connection and configure it the same way
            self.conn = duckdb.connect(self.STORAGE_URL, config={'http_keep_alive': 'true'})
            self._configure_httpfs(self.conn)

            token = getattr(self, "HF_TOKEN", None) or get_hf_token()
            if not token:
                raise RuntimeError("HuggingFace token is not configured")

            # SET statements are built as text, so double any single quote
            # to keep the token from terminating the SQL string literal.
            quoted_token = str(token).replace("'", "''")
            self.conn.execute(f"SET s3_access_key_id='{quoted_token}';")
            self.conn.execute(f"SET s3_secret_access_key='{quoted_token}';")

        except Exception as e:
            logging.error(f"Failed to initialize database connection: {str(e)}")
            # Do not leave a half-configured connection behind.
            self.close()
            raise
        finally:
            # Close temporary connection even when configuration failed
            if temp_conn is not None:
                temp_conn.close()

    def _setup_database(self):
        """Initialize database tables for page state tracking.

        Raises:
            Exception: Any DDL error, after logging it.
        """
        try:
            # Page state table
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS page_states (
                    page_id INTEGER,
                    block_id INTEGER,
                    ecc_value INTEGER DEFAULT 0,
                    error_count INTEGER DEFAULT 0,
                    last_program TIMESTAMP,
                    last_erase TIMESTAMP,
                    program_count INTEGER DEFAULT 0,
                    erase_count INTEGER DEFAULT 0,
                    valid BOOLEAN DEFAULT true,
                    bit_error_rate FLOAT DEFAULT 0.0,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (block_id, page_id)
                )
            """)

            # operation_id is a primary key, so it needs a value on every
            # insert; a sequence supplies it.
            self.conn.execute("""
                CREATE SEQUENCE IF NOT EXISTS page_operation_ids START 1
            """)

            # Page operation history. duration_ns is BIGINT because a 32-bit
            # INTEGER overflows after roughly 2.1 seconds worth of nanoseconds.
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS page_operations (
                    operation_id INTEGER PRIMARY KEY DEFAULT nextval('page_operation_ids'),
                    block_id INTEGER,
                    page_id INTEGER,
                    operation_type VARCHAR,
                    data_size INTEGER,
                    success BOOLEAN,
                    error_type VARCHAR,
                    ecc_corrections INTEGER,
                    duration_ns BIGINT,
                    operation_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (block_id, page_id) REFERENCES page_states(block_id, page_id)
                )
            """)

        except Exception as e:
            logging.error(f"Failed to setup database: {str(e)}")
            raise

    def _init_page_state(self):
        """Insert this page's state row unless one already exists.

        Raises:
            Exception: Any database error, after logging it.
        """
        try:
            self.conn.execute("""
                INSERT OR IGNORE INTO page_states (
                    block_id, page_id
                ) VALUES (?, ?)
            """, [self.block_id, self.page_id])

        except Exception as e:
            logging.error(f"Failed to initialize page state: {str(e)}")
            raise

    def _log_operation(self, operation_type: str, started_ns: int,
                       success: bool, error_type: Optional[str],
                       data_size: Optional[int] = None,
                       ecc_corrections: Optional[int] = None) -> None:
        """Append one row to ``page_operations`` for the finished operation.

        Args:
            operation_type: ``'program'``, ``'erase'`` or ``'read'``.
            started_ns: ``time.perf_counter_ns()`` reading taken at the start.
            success: Outcome reported to the caller.
            error_type: Short failure tag, or ``None`` on success.
            data_size: Number of values handled; stored as NULL when omitted.
            ecc_corrections: Low-confidence cell count; NULL when omitted.
        """
        duration = time.perf_counter_ns() - started_ns
        # nextval() is spelled out so the insert also works against a table
        # that was created without the operation_id default.
        self.conn.execute("""
            INSERT INTO page_operations (
                operation_id, block_id, page_id, operation_type,
                data_size, success, error_type,
                ecc_corrections, duration_ns
            ) VALUES (nextval('page_operation_ids'), ?, ?, ?, ?, ?, ?, ?, ?)
        """, [
            self.block_id,
            self.page_id,
            operation_type,
            data_size,
            success,
            error_type,
            ecc_corrections,
            duration
        ])

    def program(self, data: List[int]) -> bool:
        """Program the page's cells with ``data`` and record the outcome.

        ``data[i]`` is written to cell ``i``; fewer values than cells leaves
        the remaining cells untouched. More values than cells is rejected
        before any cell is modified.

        Args:
            data: Values to program, one per cell.

        Returns:
            True when every cell reported success and the bookkeeping was
            stored; False on any cell failure, oversized input or error.
        """
        try:
            # Monotonic clock: wall-clock adjustments must not skew durations.
            started_ns = time.perf_counter_ns()

            if len(data) > len(self.cells):
                raise ValueError(
                    f"data length {len(data)} exceeds page size {len(self.cells)}"
                )

            # Generate ECC
            ecc_value = self._calculate_ecc(data)

            # Program every cell even after a failure, then summarize.
            outcomes = [cell.program(value) for cell, value in zip(self.cells, data)]
            success = all(outcomes)
            error_type = None if success else 'program_fail'

            # Update page state
            self.conn.execute("""
                UPDATE page_states
                SET ecc_value = ?,
                    program_count = program_count + 1,
                    last_program = CURRENT_TIMESTAMP,
                    error_count = CASE WHEN ? THEN error_count + 1 ELSE error_count END,
                    updated_at = CURRENT_TIMESTAMP
                WHERE block_id = ? AND page_id = ?
            """, [
                ecc_value,
                not success,
                self.block_id,
                self.page_id
            ])

            # Log operation (programming performs no corrections)
            self._log_operation('program', started_ns, success, error_type,
                                data_size=len(data), ecc_corrections=0)

            return success

        except Exception as e:
            logging.error(f"Page program failed: {str(e)}")
            return False

    def erase(self) -> bool:
        """Erase every cell of the page and record the outcome.

        Returns:
            True when every cell reported success and the bookkeeping was
            stored; False otherwise.
        """
        try:
            started_ns = time.perf_counter_ns()

            # Erase all cells; keep going after a failure so none is skipped.
            outcomes = [cell.erase() for cell in self.cells]
            success = all(outcomes)
            error_type = None if success else 'erase_fail'

            # Update page state
            self.conn.execute("""
                UPDATE page_states
                SET ecc_value = 0,
                    erase_count = erase_count + 1,
                    last_erase = CURRENT_TIMESTAMP,
                    error_count = CASE WHEN ? THEN error_count + 1 ELSE error_count END,
                    updated_at = CURRENT_TIMESTAMP
                WHERE block_id = ? AND page_id = ?
            """, [
                not success,
                self.block_id,
                self.page_id
            ])

            # Log operation (data_size and ecc_corrections stay NULL)
            self._log_operation('erase', started_ns, success, error_type)

            return success

        except Exception as e:
            logging.error(f"Page erase failed: {str(e)}")
            return False

    def read(self) -> Tuple[Optional[List[int]], int]:
        """Read all cells and verify them against the stored checksum.

        Cells whose read confidence is below 0.9 are counted and folded into
        the page's error statistics; nothing is actually corrected.

        Returns:
            ``(data, stored_ecc)`` when the recomputed checksum matches the
            stored one, ``(None, stored_ecc)`` on a mismatch, and
            ``(None, 0)`` when the read itself failed.
        """
        try:
            started_ns = time.perf_counter_ns()

            # Read stored ECC
            row = self.conn.execute("""
                SELECT ecc_value
                FROM page_states
                WHERE block_id = ? AND page_id = ?
            """, [self.block_id, self.page_id]).fetchone()
            if row is None:
                raise LookupError(
                    f"no page_states row for block {self.block_id}, page {self.page_id}"
                )
            stored_ecc = row[0]

            # Read cell data
            data = []
            corrections = 0
            for cell in self.cells:
                value, confidence = cell.read()
                data.append(value)

                # Track potential bit errors based on confidence
                if confidence < 0.9:  # Less than 90% confidence
                    corrections += 1

            # Verify ECC
            success = self._calculate_ecc(data) == stored_ecc
            error_type = None if success else 'ecc_mismatch'

            # A page without cells has no bits that could be wrong.
            cell_count = len(self.cells)
            error_rate = corrections / cell_count if cell_count else 0.0

            # Update error statistics
            self.conn.execute("""
                UPDATE page_states
                SET error_count = error_count + ?,
                    bit_error_rate = (
                        bit_error_rate * program_count + ?
                    ) / (program_count + 1),
                    updated_at = CURRENT_TIMESTAMP
                WHERE block_id = ? AND page_id = ?
            """, [
                corrections,
                error_rate,
                self.block_id,
                self.page_id
            ])

            # Log operation
            self._log_operation('read', started_ns, success, error_type,
                                data_size=len(data), ecc_corrections=corrections)

            return (data if success else None), stored_ecc

        except Exception as e:
            logging.error(f"Page read failed: {str(e)}")
            return None, 0

    def _calculate_ecc(self, data: List[int]) -> int:
        """Return a simple XOR-folded checksum of ``data``.

        The low byte of each value is shifted left by ``index % 8`` bits and
        XOR-ed into the accumulator. This detects changes but cannot locate
        or correct them; it is a demonstration stand-in for a real Hamming
        or BCH code.

        Args:
            data: Values to summarize.

        Returns:
            The checksum, or 0 when the calculation fails (the failure is
            logged).
        """
        try:
            ecc = 0
            for i, value in enumerate(data):
                ecc ^= (value & 0xFF) << (i % 8)
            return ecc

        except Exception as e:
            logging.error(f"ECC calculation failed: {str(e)}")
            return 0

    def get_health_status(self) -> Dict[str, float]:
        """Get page health and performance metrics.

        Returns:
            A dict with ``program_count``, ``erase_count``, ``error_count``,
            ``bit_error_rate``, ``errors_per_hour``, ``health_score`` and
            ``estimated_life_remaining``; an empty dict when the page has no
            state row or the query fails.
        """
        try:
            result = self.conn.execute("""
                SELECT program_count, erase_count, error_count,
                       bit_error_rate, 
                       EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - created_at)) as age_seconds,
                       (
                           SELECT COUNT(*) 
                           FROM page_operations 
                           WHERE block_id = ? AND page_id = ? 
                           AND success = false
                       ) as total_errors
                FROM page_states
                WHERE block_id = ? AND page_id = ?
            """, [
                self.block_id, self.page_id,
                self.block_id, self.page_id
            ]).fetchone()

            if not result:
                return {}

            prog_count, erase_count, errors, ber, age, total_errors = result

            # NULLs (e.g. an explicitly cleared column) count as zero, and
            # plain floats keep the arithmetic independent of the driver's
            # numeric type.
            ber = float(ber or 0.0)
            age_seconds = float(age or 0.0)

            # Calculate derived metrics; pages younger than one hour are
            # treated as one hour old so the rate is not inflated.
            hours_since_creation = age_seconds / 3600
            errors_per_hour = total_errors / max(1, hours_since_creation)
            wear_penalty = erase_count / 100

            return {
                "program_count": prog_count,
                "erase_count": erase_count,
                "error_count": errors,
                "bit_error_rate": round(ber, 6),
                "errors_per_hour": round(errors_per_hour, 2),
                "health_score": max(0, 100 - (ber * 100) - wear_penalty),
                "estimated_life_remaining": max(0, 100 - wear_penalty)
            }

        except Exception as e:
            logging.error(f"Failed to get page health status: {str(e)}")
            return {}
