"""
Enhanced NAND Block implementation with remote state tracking
"""

import time
import json
import logging
import duckdb
import os
from typing import List, Dict, Optional
from dataclasses import dataclass
from huggingface_hub import HfApi, HfFileSystem
from nand_page import Page
from config import get_hf_token

# NOTE: get_hf_token is imported above, but no token is initialized or used in the code below.



@dataclass
class BlockMetrics:
    """Block-level performance and health metrics.

    The field names match the metric columns of the ``block_states`` table,
    and ``Block.get_health_metrics`` fills an instance from one row of it.
    Field order matters: the fallback in ``get_health_metrics`` constructs
    the object positionally.
    """
    # Erase counter; Block.erase increments the stored value.
    erase_count: int
    # Counter bumped by Block.program_page when a program operation fails.
    program_errors: int
    # Counter bumped by Block.read_page when a read operation fails.
    read_errors: int
    # Counter bumped by Block.read_page when the page's retention factor
    # exceeds 0.8.
    retention_errors: int
    # Time of the most recent erase, as stored in block_states.last_erase_time.
    last_erase_time: float
    # Block temperature; the column defaults to 25.0 (unit not stated here).
    temperature: float
    # Voltage drift; the column defaults to 0.0 (unit not stated here).
    voltage_drift: float
    # Overwritten by erase/program/read with the error probability or read
    # stress computed for that operation.
    bit_error_rate: float

class Block:
    """NAND Block with database-backed state tracking"""
    
    DB_URL = "hf://datasets/Fred808/helium/storage.json"
    
    def __init__(self, block_id: int, num_pages: int, num_cells_per_page: int,
                channel_length: float, drift_velocity: float, levels: int):
        self.block_id = block_id
        self.num_pages = num_pages
        self.num_cells = num_cells_per_page
        self.channel_length = channel_length
        self.drift_velocity = drift_velocity
        self.levels = levels
        
        # Initialize database connection
        self._init_db_connection()
        self._setup_database()
        
        # Initialize pages and metrics
        self._init_block()
        
    def _init_db_connection(self):
        """Initialize database connection with HuggingFace configuration"""
        # Connect directly to HuggingFace URL
        # First create an in-memory connection to configure settings
        temp_conn = duckdb.connect(":memory:")
        
        # Configure HuggingFace access - must be done before connecting to URL
        temp_conn.execute("INSTALL httpfs;")
        temp_conn.execute("LOAD httpfs;")
        temp_conn.execute("SET s3_endpoint='hf.co';")
        temp_conn.execute("SET s3_use_ssl=true;")
        temp_conn.execute("SET s3_url_style='path';")
        
        # Ensure using remote URL
        if not self.DB_URL.startswith('hf://'):
            self.DB_URL = f"hf://datasets/Fred808/helium/{os.path.basename(self.DB_URL)}"
        
        # Now create the real connection with the configured settings
        self.conn = duckdb.connect(self.DB_URL, config={'http_keep_alive': 'true'})
        self.conn.execute("INSTALL httpfs;")
        self.conn.execute("LOAD httpfs;")
        self.conn.execute("SET s3_endpoint='hf.co';")
        self.conn.execute("SET s3_use_ssl=true;")
        self.conn.execute("SET s3_url_style='path';")
        
        # Close temporary connection
        temp_conn.close()
        
    def _setup_database(self):
        """Initialize database tables"""
        # Block state table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS block_states (
                block_id INTEGER PRIMARY KEY,
                page_states JSON,
                erase_count INTEGER DEFAULT 0,
                program_errors INTEGER DEFAULT 0,
                read_errors INTEGER DEFAULT 0,
                retention_errors INTEGER DEFAULT 0,
                last_erase_time TIMESTAMP,
                temperature FLOAT DEFAULT 25.0,
                voltage_drift FLOAT DEFAULT 0.0,
                bit_error_rate FLOAT DEFAULT 0.0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        
        # Block operations history
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS block_operations (
                operation_id INTEGER PRIMARY KEY,
                block_id INTEGER,
                operation_type VARCHAR,
                page_number INTEGER,
                data_size INTEGER,
                success BOOLEAN,
                error_type VARCHAR,
                operation_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                duration_ms INTEGER,
                FOREIGN KEY (block_id) REFERENCES block_states(block_id)
            )
        """)
        
    def _init_block(self):
        """Initialize block state in database"""
        try:
            # Create pages
            self.pages = [
                Page(self.num_cells, self.channel_length,
                     self.drift_velocity, self.levels)
                for _ in range(self.num_pages)
            ]
            
            # Initialize block state in database
            self.conn.execute("""
                INSERT OR IGNORE INTO block_states (
                    block_id, page_states, last_erase_time
                ) VALUES (?, ?, CURRENT_TIMESTAMP)
            """, [
                self.block_id,
                json.dumps([{
                    'erased': True,
                    'program_count': 0,
                    'read_count': 0,
                    'last_program_time': None,
                    'last_read_time': None
                } for _ in range(self.num_pages)])
            ])
            
        except Exception as e:
            logging.error(f"Failed to initialize block {self.block_id}: {str(e)}")
            raise
            
    def erase(self) -> bool:
        """Erase block with wear tracking and error detection"""
        try:
            start_time = time.time()
            
            # Start transaction
            self.conn.execute("BEGIN TRANSACTION")
            
            # Get current block state
            state = self.conn.execute("""
                SELECT erase_count, temperature, voltage_drift 
                FROM block_states 
                WHERE block_id = ?
            """, [self.block_id]).fetchone()
            
            if not state:
                raise RuntimeError(f"Block {self.block_id} state not found")
                
            erase_count, temperature, voltage_drift = state
            
            # Calculate erase stress based on wear
            erase_stress = min(1.0, erase_count / 100000.0)  # Assume 100K P/E cycles life
            error_probability = erase_stress * (1 + temperature/100) * (1 + abs(voltage_drift))
            
            # Perform erase operation
            success = True
            error_type = None
            
            if error_probability > 0.9:  # High risk of failure
                success = False
                error_type = "wear_out"
            else:
                # Erase pages
                for page in self.pages:
                    page.erase()
            
            # Update block state
            self.conn.execute("""
                UPDATE block_states 
                SET erase_count = erase_count + 1,
                    page_states = ?,
                    last_erase_time = CURRENT_TIMESTAMP,
                    bit_error_rate = ?,
                    updated_at = CURRENT_TIMESTAMP
                WHERE block_id = ?
            """, [
                json.dumps([{
                    'erased': True,
                    'program_count': 0,
                    'read_count': 0,
                    'last_program_time': None,
                    'last_read_time': None
                } for _ in range(self.num_pages)]),
                error_probability,
                self.block_id
            ])
            
            # Log operation
            duration = int((time.time() - start_time) * 1000)
            self.conn.execute("""
                INSERT INTO block_operations (
                    block_id, operation_type, success,
                    error_type, duration_ms
                ) VALUES (?, 'erase', ?, ?, ?)
            """, [self.block_id, success, error_type, duration])
            
            # Commit transaction
            self.conn.execute("COMMIT")
            
            return success
            
        except Exception as e:
            self.conn.execute("ROLLBACK")
            logging.error(f"Block {self.block_id} erase failed: {str(e)}")
            return False
            
    def program_page(self, page_num: int, data: bytes) -> bool:
        """Program a page with error tracking"""
        try:
            start_time = time.time()
            
            # Start transaction
            self.conn.execute("BEGIN TRANSACTION")
            
            # Get page state
            block_state = self.conn.execute("""
                SELECT page_states, erase_count, temperature
                FROM block_states
                WHERE block_id = ?
            """, [self.block_id]).fetchone()
            
            if not block_state:
                raise RuntimeError(f"Block {self.block_id} state not found")
                
            page_states = json.loads(block_state[0])
            erase_count = block_state[1]
            temperature = block_state[2]
            
            # Verify page is erased
            if not page_states[page_num]['erased']:
                raise RuntimeError(f"Page {page_num} must be erased before programming")
            
            # Calculate program stress
            program_stress = min(1.0, erase_count / 100000.0)
            error_probability = program_stress * (1 + temperature/100)
            
            # Perform program operation
            success = True
            error_type = None
            
            if error_probability > 0.8:  # High risk of failure
                success = False
                error_type = "program_failure"
            else:
                success = self.pages[page_num].program(data)
                if not success:
                    error_type = "program_error"
            
            # Update page state
            page_states[page_num].update({
                'erased': False,
                'program_count': page_states[page_num]['program_count'] + 1,
                'last_program_time': time.time()
            })
            
            # Update block state
            self.conn.execute("""
                UPDATE block_states
                SET page_states = ?,
                    program_errors = CASE WHEN ? THEN program_errors + 1 ELSE program_errors END,
                    bit_error_rate = ?,
                    updated_at = CURRENT_TIMESTAMP
                WHERE block_id = ?
            """, [
                json.dumps(page_states),
                not success,
                error_probability,
                self.block_id
            ])
            
            # Log operation
            duration = int((time.time() - start_time) * 1000)
            self.conn.execute("""
                INSERT INTO block_operations (
                    block_id, operation_type, page_number,
                    data_size, success, error_type, duration_ms
                ) VALUES (?, 'program', ?, ?, ?, ?, ?)
            """, [
                self.block_id, page_num, len(data),
                success, error_type, duration
            ])
            
            # Commit transaction
            self.conn.execute("COMMIT")
            return success
            
        except Exception as e:
            self.conn.execute("ROLLBACK")
            logging.error(f"Block {self.block_id} page {page_num} program failed: {str(e)}")
            return False
            
    def read_page(self, page_num: int) -> Optional[bytes]:
        """Read a page with error detection and correction"""
        try:
            start_time = time.time()
            
            # Start transaction
            self.conn.execute("BEGIN TRANSACTION")
            
            # Get page state
            block_state = self.conn.execute("""
                SELECT page_states, erase_count, temperature, voltage_drift
                FROM block_states
                WHERE block_id = ?
            """, [self.block_id]).fetchone()
            
            if not block_state:
                raise RuntimeError(f"Block {self.block_id} state not found")
                
            page_states = json.loads(block_state[0])
            erase_count = block_state[1]
            temperature = block_state[2]
            voltage_drift = block_state[3]
            
            # Calculate read stress factors
            retention_time = time.time() - page_states[page_num]['last_program_time'] if page_states[page_num]['last_program_time'] else 0
            retention_factor = min(1.0, retention_time / (365 * 24 * 3600))  # 1 year retention
            read_stress = (erase_count / 100000.0) * (1 + temperature/100) * (1 + abs(voltage_drift))
            
            # Perform read operation
            success = True
            error_type = None
            data = None
            
            if read_stress > 0.9:  # High risk of uncorrectable errors
                success = False
                error_type = "read_failure"
            else:
                try:
                    data = self.pages[page_num].read()
                    if not data:
                        success = False
                        error_type = "read_error"
                except Exception as e:
                    success = False
                    error_type = str(e)
            
            # Update page state
            page_states[page_num]['read_count'] += 1
            page_states[page_num]['last_read_time'] = time.time()
            
            # Update block state
            self.conn.execute("""
                UPDATE block_states
                SET page_states = ?,
                    read_errors = CASE WHEN ? THEN read_errors + 1 ELSE read_errors END,
                    retention_errors = CASE WHEN ? THEN retention_errors + 1 ELSE retention_errors END,
                    bit_error_rate = ?,
                    updated_at = CURRENT_TIMESTAMP
                WHERE block_id = ?
            """, [
                json.dumps(page_states),
                not success,
                retention_factor > 0.8,
                read_stress,
                self.block_id
            ])
            
            # Log operation
            duration = int((time.time() - start_time) * 1000)
            self.conn.execute("""
                INSERT INTO block_operations (
                    block_id, operation_type, page_number,
                    data_size, success, error_type, duration_ms
                ) VALUES (?, 'read', ?, ?, ?, ?, ?)
            """, [
                self.block_id, page_num,
                len(data) if data else 0,
                success, error_type, duration
            ])
            
            # Commit transaction
            self.conn.execute("COMMIT")
            return data if success else None
            
        except Exception as e:
            self.conn.execute("ROLLBACK")
            logging.error(f"Block {self.block_id} page {page_num} read failed: {str(e)}")
            return None
            
    def get_health_metrics(self) -> BlockMetrics:
        """Get block health and performance metrics"""
        try:
            metrics = self.conn.execute("""
                SELECT 
                    erase_count,
                    program_errors,
                    read_errors,
                    retention_errors,
                    last_erase_time,
                    temperature,
                    voltage_drift,
                    bit_error_rate
                FROM block_states
                WHERE block_id = ?
            """, [self.block_id]).fetchone()
            
            if not metrics:
                raise RuntimeError(f"Block {self.block_id} metrics not found")
                
            return BlockMetrics(
                erase_count=metrics[0],
                program_errors=metrics[1],
                read_errors=metrics[2],
                retention_errors=metrics[3],
                last_erase_time=metrics[4],
                temperature=metrics[5],
                voltage_drift=metrics[6],
                bit_error_rate=metrics[7]
            )
            
        except Exception as e:
            logging.error(f"Failed to get block {self.block_id} metrics: {str(e)}")
            return BlockMetrics(0, 0, 0, 0, 0.0, 25.0, 0.0, 0.0)
            
    def predict_remaining_life(self) -> float:
        """Predict remaining life percentage based on wear and errors"""
        try:
            metrics = self.get_health_metrics()
            
            # Calculate wear factor (0-1)
            wear_factor = min(1.0, metrics.erase_count / 100000.0)
            
            # Calculate error factor (0-1)
            total_errors = (
                metrics.program_errors +
                metrics.read_errors +
                metrics.retention_errors
            )
            error_factor = min(1.0, total_errors / 10000.0)
            
            # Calculate environmental stress (0-1)
            temp_stress = max(0, (metrics.temperature - 25) / 100)
            voltage_stress = abs(metrics.voltage_drift)
            env_factor = min(1.0, (temp_stress + voltage_stress) / 2)
            
            # Combined health score (0-100%)
            health_score = (
                (1 - wear_factor) * 0.4 +
                (1 - error_factor) * 0.4 +
                (1 - env_factor) * 0.2
            ) * 100
            
            return max(0.0, health_score)
            
        except Exception as e:
            logging.error(f"Failed to predict block {self.block_id} life: {str(e)}")
            return 0.0
