import numpy as np
import uuid
import time
import json
import logging
from typing import Dict, Optional, Tuple
from datetime import datetime
from pathlib import Path
from .storage_manager import StorageManager, CellState

class MultiLevelCell:
    """Multi-level NAND flash memory cell with quantum state tracking.

    The cell keeps an in-memory mirror of its state (``value``, ``wear_count``,
    ``electron_state``, ``voltage_drift``, ...) and persists it in the
    ``cell_states`` table of the database held in ``self.conn``. Erase and read
    operations are additionally logged to ``cell_operations``.
    """

    # Database opened by ``_init_db_connection``. Override on the class (or a
    # subclass) with the remote database location; the in-memory default only
    # exists so the class can be constructed without one.
    STORAGE_URL = ":memory:"

    # Undrifted thresholds are spaced evenly over [0, _BASE_MAX_VOLTAGE] volts.
    _BASE_MAX_VOLTAGE = 3.3

    # Reference temperature (K) for all temperature-dependent terms.
    _ROOM_TEMPERATURE = 298.15

    # Columns ``_update_cell_state`` may write. Column names are interpolated
    # into the SQL text, so only these identifiers are accepted.
    _UPDATABLE_COLUMNS = frozenset({
        "block_id", "page_id", "value", "wear_count", "trapped_electrons",
        "electron_state", "error_history", "temperature", "voltage_drift",
        "retention_loss", "error_rate", "read_count", "last_program",
        "last_read", "quantum_coherence",
    })

    def __init__(self, channel_length: float, drift_velocity: float, levels: int,
                 block_id: int = 0, page_id: int = 0):
        """Create a cell with a fresh UUID and register it in the database.

        Args:
            channel_length: Channel length used in program/erase time.
            drift_velocity: Drift velocity used in program/erase time.
            levels: Number of distinguishable charge levels.
            block_id: Block the cell belongs to.
            page_id: Page the cell belongs to.
        """
        self.cell_id = str(uuid.uuid4())
        self.block_id = block_id
        self.page_id = page_id
        self.channel_length = channel_length
        self.drift_velocity = drift_velocity
        self.levels = levels

        # Static parameters
        self.max_wear_count = 100000  # P/E cycles
        self.voltage_thresholds = self._base_thresholds()  # V
        self.voltage_drift = 0.0  # V, persisted as cell_states.voltage_drift
        self.error_history = []

        # Initialize storage
        self.storage = StorageManager()
        # Open and prepare the database unless a connection already exists
        # (e.g. assigned by a subclass before this initializer runs).
        if getattr(self, "conn", None) is None:
            self._init_db_connection()
            self._setup_database()
        self._init_cell_state()

    def program(self, value) -> Tuple[float, bool]:
        """Program the cell to ``value``, clamped to ``[0, levels - 1]``.

        The write succeeds with probability ``1 - wear_count / max_wear_count``
        (drawn with ``np.random.random``).

        Args:
            value: Target level; non-integer values are truncated with ``int``.

        Returns:
            ``(prog_time, success)``. ``prog_time`` is
            ``channel_length / (drift_velocity * tunneling_factor * temp_factor)``
            and is returned whether or not the write succeeded. On success the
            new state is persisted via ``_sync_to_remote``; on failure a
            ``('program_fail', wear_count)`` record is appended to
            ``error_history``.
        """
        # int() so the level can index electron_state even for float input.
        target_level = max(0, min(self.levels - 1, int(value)))
        wear_ratio = self.wear_count / self.max_wear_count

        # Wear-dependent tunneling factor and write-success probability.
        tunneling_factor = 0.8 + 0.2 * wear_ratio
        success_prob = 1.0 - wear_ratio

        # Temperature compensation relative to room temperature.
        temp_factor = 1.0 + 0.1 * (self.temperature - self._ROOM_TEMPERATURE) / 50
        prog_time = self.channel_length / (self.drift_velocity * tunneling_factor * temp_factor)

        if np.random.random() < success_prob:
            self.value = target_level
            self.trapped_electrons = target_level * 100  # Approx electrons per level
            self.electron_state = np.zeros(self.levels)
            self.electron_state[target_level] = 1.0  # One-hot on the programmed level
            self.wear_count += 1
            self.retention_loss = 0.0
            self._sync_to_remote()
            return prog_time, True

        self.error_history.append(('program_fail', self.wear_count))
        return prog_time, False

    def _init_db_connection(self):
        """Open ``self.conn`` to ``STORAGE_URL`` with httpfs configured for hf.co.

        Raises:
            Exception: any import, connection or configuration error is logged
            and re-raised.
        """
        try:
            # Imported here because only this method opens DuckDB connections.
            import duckdb

            # INSTALL persists the httpfs extension locally; doing it on a
            # throw-away in-memory connection makes it available before the
            # real (possibly remote) database is opened.
            temp_conn = duckdb.connect(":memory:")
            try:
                self._configure_httpfs(temp_conn)
                self.conn = duckdb.connect(self.STORAGE_URL, config={'http_keep_alive': 'true'})
            finally:
                # Close even when opening the real connection fails.
                temp_conn.close()

            # Settings applied to temp_conn do not carry over to a separately
            # opened database, so repeat them on the real connection.
            self._configure_httpfs(self.conn)

        except Exception as e:
            logging.error(f"Failed to initialize database connection: {str(e)}")
            raise

    def _setup_database(self):
        """Create the ``cell_states`` and ``cell_operations`` tables if missing.

        Raises:
            Exception: any DDL error is logged and re-raised.
        """
        try:
            # One row per cell; JSON payloads are stored as text.
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS cell_states (
                    cell_id VARCHAR PRIMARY KEY,
                    block_id INTEGER,
                    page_id INTEGER,
                    value INTEGER DEFAULT 0,
                    wear_count INTEGER DEFAULT 0,
                    trapped_electrons INTEGER DEFAULT 0,
                    electron_state VARCHAR,  -- JSON array of quantum state
                    error_history VARCHAR,   -- JSON array of error records
                    temperature FLOAT DEFAULT 298.15,  -- Room temp in Kelvin
                    voltage_drift FLOAT DEFAULT 0.0,
                    retention_loss FLOAT DEFAULT 0.0,
                    error_rate FLOAT DEFAULT 0.0,
                    read_count INTEGER DEFAULT 0,
                    last_program TIMESTAMP,
                    last_read TIMESTAMP,
                    quantum_coherence FLOAT DEFAULT 1.0,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # DuckDB does not auto-assign INTEGER PRIMARY KEY values, so the
            # operation id is drawn from a sequence.
            self.conn.execute("CREATE SEQUENCE IF NOT EXISTS cell_operations_seq START 1")

            # Cell operations history
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS cell_operations (
                    operation_id INTEGER PRIMARY KEY DEFAULT nextval('cell_operations_seq'),
                    cell_id VARCHAR,
                    operation_type VARCHAR,
                    old_value INTEGER,
                    new_value INTEGER,
                    success BOOLEAN,
                    error_type VARCHAR,
                    duration_ns BIGINT,
                    operation_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (cell_id) REFERENCES cell_states(cell_id)
                )
            """)

        except Exception as e:
            logging.error(f"Failed to setup database: {str(e)}")
            raise

    def _init_cell_state(self):
        """Load this cell's row from ``cell_states``, or insert a fresh one.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        try:
            state = self.conn.execute("""
                SELECT value, trapped_electrons, wear_count,
                       electron_state, temperature, retention_loss,
                       voltage_drift, quantum_coherence, error_history
                FROM cell_states
                WHERE cell_id = ?
            """, [self.cell_id]).fetchone()

            if state:
                (value, trapped_e, wear_count, e_state, temp,
                 retention, v_drift, _coherence, history) = state

                self.value = value
                self.trapped_electrons = trapped_e
                self.wear_count = wear_count
                self.electron_state = self._decode_electron_state(e_state)
                self.temperature = temp
                self.retention_loss = retention
                self.error_history = self._decode_json(history, [])

                # Thresholds are the nominal ones shifted by the stored drift.
                self.voltage_drift = float(v_drift or 0.0)
                self.voltage_thresholds = self._base_thresholds() + self.voltage_drift

            else:
                self.value = 0
                self.trapped_electrons = 0
                self.wear_count = 0
                self.retention_loss = 0.0
                self.temperature = self._ROOM_TEMPERATURE
                self.electron_state = np.zeros(self.levels)
                self.error_history = []
                self.voltage_drift = 0.0
                self.voltage_thresholds = self._base_thresholds()

                self.conn.execute("""
                    INSERT INTO cell_states (
                        cell_id, block_id, page_id, value,
                        trapped_electrons, wear_count, electron_state,
                        error_history, temperature, voltage_drift,
                        retention_loss, quantum_coherence, created_at, updated_at
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
                """, [
                    self.cell_id,
                    self.block_id,
                    self.page_id,
                    self.value,
                    self.trapped_electrons,
                    self.wear_count,
                    json.dumps(self.electron_state.tolist()),
                    json.dumps(self.error_history),
                    self.temperature,
                    self.voltage_drift,
                    self.retention_loss,
                    1.0,  # Initial perfect quantum coherence
                ])
        except Exception as e:
            logging.error(f"Failed to initialize cell state: {str(e)}")
            raise

    def _get_cell_state(self) -> Dict[str, object]:
        """Return this cell's persisted state as a dict.

        Keys: ``value``, ``trapped_electrons``, ``wear_count``,
        ``retention_loss``, ``temperature``, ``electron_state`` (ndarray) and
        ``error_history`` (list). NULL JSON columns decode to an all-zero state
        and an empty history respectively.

        Raises:
            RuntimeError: if the row does not exist (logged and re-raised, as
            is any database error).
        """
        try:
            state = self.conn.execute("""
                SELECT value, trapped_electrons, wear_count,
                       retention_loss, temperature, electron_state,
                       error_history
                FROM cell_states
                WHERE cell_id = ?
            """, [self.cell_id]).fetchone()

            if not state:
                raise RuntimeError(f"Cell state not found for {self.cell_id}")

            return {
                'value': state[0],
                'trapped_electrons': state[1],
                'wear_count': state[2],
                'retention_loss': state[3],
                'temperature': state[4],
                'electron_state': self._decode_electron_state(state[5]),
                'error_history': self._decode_json(state[6], []),
            }

        except Exception as e:
            logging.error(f"Failed to get cell state: {str(e)}")
            raise

    def _update_cell_state(self, **kwargs):
        """Write the given ``column=value`` pairs to this cell's row.

        ndarray and list values are stored as JSON text. ``last_updated`` and
        ``updated_at`` are refreshed whenever at least one column is written.

        Raises:
            ValueError: if a key is not in ``_UPDATABLE_COLUMNS``.
            Exception: any database error. All errors are logged and re-raised.
        """
        try:
            unknown = set(kwargs) - self._UPDATABLE_COLUMNS
            if unknown:
                raise ValueError(f"Unknown cell_states column(s): {sorted(unknown)}")

            assignments = []
            params = []
            for column, value in kwargs.items():
                if isinstance(value, np.ndarray):
                    value = json.dumps(value.tolist())
                elif isinstance(value, list):
                    value = json.dumps(value)
                assignments.append(f"{column} = ?")
                params.append(value)

            if assignments:
                query = f"""
                    UPDATE cell_states
                    SET {', '.join(assignments)},
                        last_updated = CURRENT_TIMESTAMP,
                        updated_at = CURRENT_TIMESTAMP
                    WHERE cell_id = ?
                """
                params.append(self.cell_id)
                self.conn.execute(query, params)

        except Exception as e:
            logging.error(f"Failed to update cell state: {str(e)}")
            raise

    def erase(self) -> float:
        """Reset the cell to level 0 in one transaction and log the erase.

        Increments ``wear_count``, clears trapped charge and retention loss,
        collapses ``electron_state`` onto level 0 and records an ``'erase'``
        row in ``cell_operations``. The in-memory mirror is updated only after
        the transaction commits.

        Returns:
            Modelled erase time, ``channel_length / drift_velocity``.

        Raises:
            Exception: any database error, after the transaction is rolled back.
        """
        try:
            self.conn.execute("BEGIN TRANSACTION")

            state = self._get_cell_state()
            old_value = state['value']

            # value is 0 after an erase, so the one-hot state sits on level 0
            # (the same encoding program() uses).
            ground_state = np.zeros(self.levels)
            if self.levels > 0:
                ground_state[0] = 1.0

            state.update(
                trapped_electrons=0,
                value=0,
                wear_count=state['wear_count'] + 1,
                retention_loss=0.0,
                electron_state=ground_state,
            )
            self._update_cell_state(**state)

            self.conn.execute("""
                INSERT INTO cell_operations (
                    cell_id, operation_type, old_value,
                    new_value, success
                ) VALUES (?, 'erase', ?, ?, true)
            """, [self.cell_id, old_value, 0])

            erase_time = self.channel_length / self.drift_velocity
            self.conn.execute("COMMIT")

        except Exception as e:
            self._rollback_quietly()
            logging.error(f"Cell erase failed: {str(e)}")
            raise

        # Keep the in-memory mirror consistent with the committed row so a later
        # _sync_to_remote does not resurrect the pre-erase state.
        self.value = 0
        self.trapped_electrons = 0
        self.wear_count = state['wear_count']
        self.retention_loss = 0.0
        self.electron_state = ground_state
        return erase_time

    def read(self) -> Tuple[int, float]:
        """Sense the cell level with simulated noise and retention decay.

        The sensed voltage is the stored level's threshold plus Gaussian noise
        (sigma ``0.01 * (1 + retention_loss)``), the stored voltage drift and a
        temperature term; the nearest threshold gives the measured level. Each
        read of a programmed cell (value > 0) adds retention loss; once it
        exceeds 0.5 the stored level drops by one and the read is logged as a
        ``'retention_error'``. ``error_rate`` is kept as the running mean of
        retention errors over ``read_count`` reads.

        Returns:
            ``(measured_level, confidence)`` where ``confidence`` is
            ``1 - distance_to_nearest_threshold / level_spacing`` (1.0 when
            there is only one level). ``(0, 0.0)`` if the row is missing or any
            error occurs (errors are logged, not raised).
        """
        try:
            start_time = time.time()

            state = self.conn.execute("""
                SELECT value, electron_state, retention_loss, temperature,
                       voltage_drift, wear_count, quantum_coherence
                FROM cell_states
                WHERE cell_id = ?
            """, [self.cell_id]).fetchone()

            if not state:
                return 0, 0.0

            stored_value, e_state, retention, temp, v_drift, wear, _coherence = state
            electron_state = self._decode_electron_state(e_state)

            # Noisy sensed voltage: threshold + decoherence + drift + temperature.
            decoherence = np.random.normal(0, 0.01 * (1 + retention))
            measured_voltage = self.voltage_thresholds[stored_value] + decoherence + v_drift
            measured_voltage += 0.001 * (temp - self._ROOM_TEMPERATURE)

            thresholds = np.asarray(self.voltage_thresholds, dtype=float)
            level_diffs = np.abs(thresholds - measured_voltage)
            measured_level = int(np.argmin(level_diffs))
            min_diff = float(level_diffs[measured_level])

            # Normalise by the spacing between adjacent levels, not by the raw
            # second threshold (which includes any drift offset).
            spacing = float(thresholds[1] - thresholds[0]) if thresholds.size > 1 else 0.0
            confidence = 1.0 - min_diff / spacing if spacing > 0 else 1.0

            new_retention = retention
            new_value = stored_value
            error_occurred = False

            if stored_value > 0:
                new_retention += 0.001 * (1 + wear / self.max_wear_count)
                if new_retention > 0.5:
                    new_value = max(0, stored_value - 1)
                    electron_state[stored_value] = 0
                    electron_state[new_value] = 1
                    new_retention = 0.0
                    error_occurred = True

            # SET expressions see the pre-update row, so error_rate uses the
            # old read_count before it is incremented.
            self.conn.execute("""
                UPDATE cell_states
                SET value = ?,
                    electron_state = ?,
                    retention_loss = ?,
                    quantum_coherence = ?,
                    last_read = CURRENT_TIMESTAMP,
                    updated_at = CURRENT_TIMESTAMP,
                    error_rate = (COALESCE(error_rate, 0) * COALESCE(read_count, 0)
                                  + CASE WHEN ? THEN 1 ELSE 0 END)
                                 / (COALESCE(read_count, 0) + 1),
                    read_count = COALESCE(read_count, 0) + 1
                WHERE cell_id = ?
            """, [
                new_value,
                json.dumps(electron_state.tolist()),
                new_retention,
                float(np.max(electron_state)),
                error_occurred,
                self.cell_id,
            ])

            duration = int((time.time() - start_time) * 1e9)  # nanoseconds
            self.conn.execute("""
                INSERT INTO cell_operations (
                    cell_id, operation_type, old_value,
                    new_value, success, error_type, duration_ns
                ) VALUES (?, 'read', ?, ?, ?, ?, ?)
            """, [
                self.cell_id,
                stored_value,
                measured_level,
                not error_occurred,
                'retention_error' if error_occurred else None,
                duration,
            ])

            # Mirror the (possibly decayed) persisted state in memory.
            self.value = new_value
            self.electron_state = electron_state
            self.retention_loss = new_retention

            return measured_level, float(confidence)

        except Exception as e:
            logging.error(f"Cell read failed: {str(e)}")
            return 0, 0.0

    def get_health_status(self) -> Dict[str, float]:
        """Return health metrics for this cell.

        Keys: ``wear_percentage``, ``retention_loss``, ``voltage_stability``,
        ``estimated_life_remaining``, ``quantum_coherence``, ``temperature``
        and ``error_rate`` (the stored rate, or failed operations per P/E cycle
        when the stored rate is zero/NULL). Returns ``{}`` if the row is missing
        or the query fails (errors are logged, not raised).
        """
        try:
            result = self.conn.execute("""
                SELECT wear_count, retention_loss, temperature,
                       voltage_drift, quantum_coherence, error_rate,
                       (
                           SELECT COUNT(*)
                           FROM cell_operations
                           WHERE cell_id = ? AND success = false
                       ) as error_count
                FROM cell_states
                WHERE cell_id = ?
            """, [self.cell_id, self.cell_id]).fetchone()

            if not result:
                return {}

            wear_count, retention_loss, temp, v_drift, coherence, error_rate, error_count = result

            return {
                'wear_percentage': (wear_count / self.max_wear_count) * 100,
                'retention_loss': retention_loss,
                'voltage_stability': 1.0 - abs(v_drift / self._BASE_MAX_VOLTAGE),
                'estimated_life_remaining': max(0, 1 - (wear_count / self.max_wear_count)),
                'quantum_coherence': coherence,
                'temperature': temp,
                'error_rate': error_rate if error_rate else error_count / max(1, wear_count),
            }

        except Exception as e:
            logging.error(f"Failed to get cell health status: {str(e)}")
            return {}

    def set_temperature(self, temp_kelvin):
        """Set the cell temperature and rescale the threshold voltages.

        The top threshold is scaled by ``1 - 0.002 * (T - 298.15)``. The
        resulting drop of the top threshold is stored as ``voltage_drift``, both
        in memory and in the database.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        try:
            self.temperature = temp_kelvin

            temp_factor = 1.0 - 0.002 * (self.temperature - self._ROOM_TEMPERATURE)
            self.voltage_thresholds = np.linspace(0, self._BASE_MAX_VOLTAGE * temp_factor, self.levels)
            self.voltage_drift = float(self._BASE_MAX_VOLTAGE * (1 - temp_factor))

            self.conn.execute("""
                UPDATE cell_states
                SET temperature = ?,
                    voltage_drift = ?,
                    updated_at = CURRENT_TIMESTAMP
                WHERE cell_id = ?
            """, [
                temp_kelvin,
                self.voltage_drift,
                self.cell_id,
            ])

        except Exception as e:
            logging.error(f"Failed to update temperature: {str(e)}")
            raise

    def _sync_to_remote(self):
        """Upsert the in-memory state into ``cell_states`` (stamps ``last_program``).

        Raises:
            Exception: any database error is logged and re-raised.
        """
        try:
            self.conn.execute("""
                INSERT OR REPLACE INTO cell_states (
                    cell_id, block_id, page_id, value,
                    trapped_electrons, wear_count, electron_state,
                    error_history, temperature, voltage_drift,
                    retention_loss, quantum_coherence, last_program, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            """, [
                self.cell_id,
                self.block_id,
                self.page_id,
                self.value,
                self.trapped_electrons,
                self.wear_count,
                json.dumps(self.electron_state.tolist()),
                json.dumps(self.error_history),
                self.temperature,
                # Persist the tracked drift as-is so it round-trips through
                # _init_cell_state/_load_from_remote without compounding.
                float(self.voltage_drift),
                self.retention_loss,
                float(np.max(self.electron_state)),
            ])

        except Exception as e:
            logging.error(f"Failed to sync cell state: {str(e)}")
            raise

    def _load_from_remote(self):
        """Load this cell's row into memory, or initialise and persist a fresh cell.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        try:
            result = self.conn.execute("""
                SELECT value, trapped_electrons, wear_count,
                       electron_state, temperature, retention_loss,
                       voltage_drift, quantum_coherence, error_history
                FROM cell_states
                WHERE cell_id = ?
            """, [self.cell_id]).fetchone()

            if result:
                self.value = result[0]
                self.trapped_electrons = result[1]
                self.wear_count = result[2]
                self.electron_state = self._decode_electron_state(result[3])
                self.temperature = result[4]
                self.retention_loss = result[5]
                self.voltage_drift = float(result[6] or 0.0)
                self.error_history = self._decode_json(result[8], [])
                # Thresholds are the nominal ones shifted by the stored drift.
                self.voltage_thresholds = self._base_thresholds() + self.voltage_drift

            else:
                self.value = 0
                self.trapped_electrons = 0
                self.wear_count = 0
                self.electron_state = np.zeros(self.levels)
                self.temperature = self._ROOM_TEMPERATURE
                self.retention_loss = 0.0
                self.voltage_drift = 0.0
                self.voltage_thresholds = self._base_thresholds()
                self.error_history = []
                self._sync_to_remote()

        except Exception as e:
            logging.error(f"Failed to load cell state: {str(e)}")
            raise

    # ------------------------------------------------------------------ helpers

    def _base_thresholds(self) -> np.ndarray:
        """Return ``levels`` undrifted thresholds evenly spaced over 0-3.3 V."""
        return np.linspace(0, self._BASE_MAX_VOLTAGE, self.levels)

    @staticmethod
    def _decode_json(raw: Optional[str], default):
        """Parse a JSON text column, returning ``default`` for NULL or empty values."""
        if raw is None or raw == "":
            return default
        return json.loads(raw)

    def _decode_electron_state(self, raw: Optional[str]) -> np.ndarray:
        """Decode the ``electron_state`` column; NULL/empty yields an all-zero state."""
        decoded = self._decode_json(raw, None)
        if decoded is None:
            return np.zeros(self.levels)
        return np.array(decoded)

    @staticmethod
    def _configure_httpfs(conn) -> None:
        """Install/load httpfs on ``conn`` and point its S3 client at hf.co."""
        for statement in (
            "INSTALL httpfs;",
            "LOAD httpfs;",
            "SET s3_endpoint='hf.co';",
            "SET s3_use_ssl=true;",
            "SET s3_url_style='path';",
        ):
            conn.execute(statement)

    def _rollback_quietly(self) -> None:
        """Roll back the open transaction, logging (not raising) a failed ROLLBACK.

        Keeps the caller's original exception as the one that propagates.
        """
        try:
            self.conn.execute("ROLLBACK")
        except Exception as rollback_error:
            logging.error(f"Rollback failed: {rollback_error}")
