"""
Database-backed memory implementation for a virtual CPU.
Uses SQLite to persist each memory cell's digital value and simulated voltage level.
"""

import contextlib
import os
import sqlite3

from logic_gates import VDD, VSS, VTH

class DatabaseMemory:
    """CPU memory backed by an SQLite database file.

    Each address is stored twice: its digital value in the ``memory_cells``
    table and a simulated voltage (VDD for non-zero values, VSS for zero) in
    the ``voltage_states`` table. Every public method opens its own
    short-lived connection to ``db_path``, so contents persist across
    instances that share the same file.
    """

    def __init__(self, size=256, db_path='cpu_memory.db'):
        """Initialize memory with ``size`` cells stored in ``db_path``.

        Args:
            size: Number of addressable cells; valid addresses are 0..size-1.
            db_path: Path of the SQLite database file (created if missing).
        """
        self.size = size
        self.db_path = db_path
        self.initialize_db()

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection that commits on success, rolls back on error, and is always closed."""
        # ``with sqlite3.connect(...)`` on its own only manages the transaction;
        # it does NOT close the connection, so every call used to leak a handle.
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on normal exit, rollback if an exception escapes
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _voltage_for(value):
        """Return the simulated voltage level for a stored digital value."""
        return VDD if value != 0 else VSS

    def initialize_db(self):
        """Create the tables if needed and ensure a row exists for every address.

        Existing cell values are never overwritten. Missing addresses, for
        example after reopening a database with a larger ``size``, are created
        with value 0 and voltage VSS.
        """
        with self._connect() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS memory_cells (
                    address INTEGER PRIMARY KEY,
                    value INTEGER NOT NULL DEFAULT 0,
                    last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Voltage states used for detailed electrical simulation.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS voltage_states (
                    address INTEGER PRIMARY KEY,
                    voltage REAL NOT NULL DEFAULT 0.0,
                    last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # INSERT OR IGNORE fills in only the missing addresses. The old
            # "seed only when empty" check left new addresses without rows,
            # so writes to them were silently dropped by UPDATE.
            conn.executemany(
                'INSERT OR IGNORE INTO memory_cells (address, value) VALUES (?, 0)',
                [(i,) for i in range(self.size)]
            )
            conn.executemany(
                'INSERT OR IGNORE INTO voltage_states (address, voltage) VALUES (?, ?)',
                [(i, VSS) for i in range(self.size)]
            )

    def read(self, address):
        """Return the value stored at ``address``.

        Returns 0 for addresses outside 0..size-1 or with no stored row.
        """
        if not 0 <= address < self.size:
            return 0

        with self._connect() as conn:
            row = conn.execute(
                'SELECT value FROM memory_cells WHERE address = ?', (address,)
            ).fetchone()

        return 0 if row is None else row[0]

    def write(self, address, data, clk):
        """Store ``data`` at ``address`` when the clock is high.

        The write is ignored if ``address`` is outside 0..size-1 or if
        ``clk`` is at or below the threshold voltage VTH. The cell's
        simulated voltage is updated alongside the digital value.
        """
        if not 0 <= address < self.size:
            return

        # Only latch data while the clock is high.
        if clk <= VTH:
            return

        with self._connect() as conn:
            conn.execute('''
                UPDATE memory_cells
                SET value = ?, last_updated = CURRENT_TIMESTAMP
                WHERE address = ?
            ''', (data, address))
            conn.execute('''
                UPDATE voltage_states
                SET voltage = ?, last_updated = CURRENT_TIMESTAMP
                WHERE address = ?
            ''', (self._voltage_for(data), address))

    def clear(self):
        """Reset every stored cell to value 0 and voltage VSS."""
        with self._connect() as conn:
            conn.execute(
                'UPDATE memory_cells SET value = 0, last_updated = CURRENT_TIMESTAMP'
            )
            conn.execute(
                'UPDATE voltage_states SET voltage = ?, last_updated = CURRENT_TIMESTAMP',
                (VSS,)
            )

    def dump(self, start=0, count=16):
        """Print addresses ``start`` .. ``start + count - 1`` with value, voltage and timestamp."""
        with self._connect() as conn:
            results = conn.execute('''
                SELECT m.address, m.value, v.voltage, m.last_updated
                FROM memory_cells m
                JOIN voltage_states v ON m.address = v.address
                WHERE m.address >= ? AND m.address < ?
                ORDER BY m.address
            ''', (start, start + count)).fetchall()

        print("\nMemory Dump:")
        print("Address  | Value | Voltage | Last Updated")
        print("-" * 50)

        for addr, val, volt, timestamp in results:
            print(f"{addr:08X} | {val:4d} | {volt:7.3f} | {timestamp}")

    def get_state(self):
        """Return a ``{address: value}`` snapshot of every stored cell."""
        with self._connect() as conn:
            rows = conn.execute(
                'SELECT address, value FROM memory_cells ORDER BY address'
            ).fetchall()
        return dict(rows)

    def set_state(self, state_dict):
        """Restore cell values (and matching voltages) from an ``{address: value}`` mapping.

        Addresses without a stored row are left unaffected, because UPDATE
        matches no row for them.
        """
        items = list(state_dict.items())
        with self._connect() as conn:
            conn.executemany('''
                UPDATE memory_cells
                SET value = ?, last_updated = CURRENT_TIMESTAMP
                WHERE address = ?
            ''', [(value, address) for address, value in items])
            conn.executemany('''
                UPDATE voltage_states
                SET voltage = ?, last_updated = CURRENT_TIMESTAMP
                WHERE address = ?
            ''', [(self._voltage_for(value), address) for address, value in items])

# Example usage
if __name__ == "__main__":
    # Demo: store a few sample bytes, read each one back, then dump the first cells.
    mem = DatabaseMemory(size=256)

    samples = {0: 42, 1: 255, 2: 128}
    for addr, value in samples.items():
        mem.write(addr, value, VDD)

    for addr in samples:
        print(f"Address {addr}: {mem.read(addr)}")

    mem.dump(0, 4)