from typing import Optional, Dict, Set
import time
import json
from .ftl import FTL, BlockMetadata
from .nand_memory import DynamicNANDMemory

class DynamicFTL(FTL):
    """Dynamic Flash Translation Layer that can grow its backing NAND on demand.

    Extends :class:`FTL` with:

    * automatic NAND growth when no free block is available and enough blocks
      hold mapped data (see :meth:`get_free_block`);
    * on-demand, geometric growth of the mapping table (see :meth:`map`);
    * per-block "temperature" tracking used to bias garbage collection.
    """

    def __init__(self, nand_memory: DynamicNANDMemory, dynamic_mapping: bool = True, unlimited_blocks: bool = True):
        """Create an FTL on top of ``nand_memory``.

        Args:
            nand_memory: Backing NAND; its current block count and
                ``pages_per_block`` size the parent FTL.
            dynamic_mapping: When True, :meth:`map` grows the mapping table
                before mapping an LBA beyond the tracked capacity.
            unlimited_blocks: When True, :meth:`get_free_block` may grow the
                NAND when the parent FTL has no free block.
        """
        super().__init__(total_blocks=nand_memory.get_total_blocks(), pages_per_block=nand_memory.pages_per_block)
        self.nand_memory = nand_memory
        self.dynamic_mapping = dynamic_mapping
        self.unlimited_blocks = unlimited_blocks

        # Dynamic scaling thresholds
        self.scale_threshold = 0.85  # Grow NAND once this fraction of blocks hold mapped LBAs
        self.scale_factor = 2.0      # Growth multiplier for both NAND and mapping-table capacity

        # Dynamic state
        self.block_allocations: Dict[int, int] = {}     # lba -> block id it was last mapped into
        self.block_temperatures: Dict[int, float] = {}  # block id -> time.time() of last map into it
        self.mapping_table_size = 0  # Mapping-table capacity (number of LBAs) tracked by this class

    def get_free_block(self) -> Optional[int]:
        """Return a free block id, growing the NAND first if warranted.

        Growth is attempted only when the parent FTL has no free block,
        ``unlimited_blocks`` is enabled, and the fraction of blocks holding
        mapped LBAs is at least ``scale_threshold``. Below that threshold the
        caller is expected to reclaim space with garbage collection instead.

        Returns:
            A block id, or None if no block is free and growth was not
            possible or not warranted.
        """
        block = super().get_free_block()
        if block is not None or not self.unlimited_blocks:
            return block

        total_blocks = self.nand_memory.get_total_blocks()
        used_blocks = len(set(self.block_allocations.values()))
        # A device with no blocks counts as full so that it can still grow.
        utilization = used_blocks / total_blocks if total_blocks > 0 else 1.0
        if utilization < self.scale_threshold:
            return None

        # Always request at least one extra block so small devices actually grow.
        new_size = max(int(total_blocks * self.scale_factor), total_blocks + 1)
        self.nand_memory.scale_to(new_size)

        self.total_blocks = self.nand_memory.get_total_blocks()
        # Newly added block ids are not in the parent's free pool yet; register
        # them the same way _collect_block returns erased blocks to it.
        if self.total_blocks > total_blocks:
            self.free_blocks.update(range(total_blocks, self.total_blocks))

        return super().get_free_block()

    def map(self, lba: int, phys: int, is_hot: bool = False):
        """Map logical address ``lba`` to physical page address ``phys``.

        When ``dynamic_mapping`` is enabled and ``lba`` lies beyond the tracked
        capacity, the table is grown geometrically (by ``scale_factor``) first.
        A run of increasing LBAs therefore triggers O(log n) table rebuilds
        rather than one rebuild per write.

        Args:
            lba: Logical block address.
            phys: Physical page address (``block * pages_per_block + page``).
            is_hot: Forwarded unchanged to the parent ``FTL.map``.
        """
        if self.dynamic_mapping and lba >= self.mapping_table_size:
            target = max(lba + 1, int(self.mapping_table_size * self.scale_factor))
            self.scale_mapping_table(target)

        super().map(lba, phys, is_hot)

        # Record tracking data only after the parent accepted the mapping, so a
        # failed map does not leave stale temperature/allocation entries.
        block_id = phys // self.pages_per_block
        self.block_temperatures[block_id] = time.time()
        self.block_allocations[lba] = block_id

    def scale_mapping_table(self, new_size: int):
        """Rebuild ``address_mapping`` and record a capacity of ``new_size`` LBAs.

        Every row is copied into a temporary table, ``address_mapping`` is
        recreated with the schema below, and the rows are copied back. The
        work runs inside a SAVEPOINT rather than BEGIN/COMMIT. That way it
        also works when the connection already has an open transaction, and a
        failure undoes only this method's own statements.

        Args:
            new_size: Requested capacity. A request that does not exceed the
                current ``mapping_table_size`` is a no-op.

        Raises:
            RuntimeError: If any statement fails; the original error is chained.
        """
        if new_size <= self.mapping_table_size:
            return

        try:
            self.conn.execute("SAVEPOINT scale_mapping_table")

            # Snapshot current rows into a temporary table
            self.conn.execute("""
                CREATE TEMPORARY TABLE new_mapping AS
                SELECT * FROM address_mapping
            """)

            # Drop old table and recreate it
            self.conn.execute("DROP TABLE address_mapping")
            self.conn.execute("""
                CREATE TABLE address_mapping (
                    lba INTEGER PRIMARY KEY,
                    phys_addr INTEGER UNIQUE,
                    last_update TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Copy data back
            self.conn.execute("""
                INSERT INTO address_mapping
                SELECT * FROM new_mapping
            """)

            self.conn.execute("DROP TABLE new_mapping")

            # Commits if this is the outermost savepoint; otherwise folds into
            # the caller's transaction.
            self.conn.execute("RELEASE SAVEPOINT scale_mapping_table")

        except Exception as e:
            try:
                self.conn.execute("ROLLBACK TO SAVEPOINT scale_mapping_table")
                self.conn.execute("RELEASE SAVEPOINT scale_mapping_table")
            except Exception:
                # Deliberately ignored: a failed rollback (e.g. the savepoint
                # was never created) must not mask the original error below.
                pass
            raise RuntimeError(f"Failed to scale mapping table: {str(e)}") from e

        # Only record the new capacity once the rebuild is durable.
        self.mapping_table_size = new_size

    def get_block_temperature(self, block_id: int, decay_seconds: float = 3600.0) -> float:
        """Return a heat score in [0.0, 1.0] for ``block_id``.

        The score is 1.0 right after the block was last mapped into. It then
        falls *linearly* to 0.0 over ``decay_seconds``. A block with no
        recorded access is treated as last touched at time 0, i.e. cold.

        Args:
            block_id: Block to score.
            decay_seconds: Time for the score to fall from 1.0 to 0.0
                (default: one hour).

        Raises:
            ValueError: If ``decay_seconds`` is not positive.
        """
        if decay_seconds <= 0:
            raise ValueError("decay_seconds must be positive")
        last_access = self.block_temperatures.get(block_id, 0)
        age = time.time() - last_access
        # Clamp both ends: old blocks bottom out at 0. A backwards wall-clock
        # step (negative age) must not push the score above 1.
        return min(1.0, max(0.0, 1.0 - age / decay_seconds))

    def garbage_collect(self, max_blocks: int = 5) -> int:
        """Reclaim up to ``max_blocks`` blocks that contain invalid pages.

        Each block with at least one invalid page is scored as
        ``invalid_fraction * (1 + temperature)``. The highest-scoring blocks
        are collected first. A block whose live pages cannot all be relocated
        is left intact, and collection moves on to the next candidate.

        Args:
            max_blocks: Maximum number of candidate blocks to process.

        Returns:
            Number of blocks erased and returned to the free pool.
        """
        candidates = []
        for block_id in range(self.total_blocks):
            meta = self.block_metadata.get(block_id)
            if meta and len(meta.invalid_pages) > 0:
                invalid_fraction = len(meta.invalid_pages) / self.pages_per_block
                # NOTE(review): (1 + temperature) favours recently written
                # blocks, whereas cost-benefit GC usually favours cold ones.
                # Confirm this bias is intended.
                score = invalid_fraction * (1 + self.get_block_temperature(block_id))
                candidates.append((score, block_id))

        # Higher score = better GC candidate
        candidates.sort(reverse=True)

        reclaimed = 0
        for _, block_id in candidates[:max_blocks]:
            if self._collect_block(block_id):
                reclaimed += 1
        return reclaimed

    def _collect_block(self, block_id: int) -> bool:
        """Relocate live pages out of ``block_id`` and then erase it.

        Live pages are packed into as few destination blocks as possible. The
        current destination is reused until it has no free page, and only then
        is another block requested. If space runs out, the block is NOT erased,
        so pages still mapped to it are not lost.

        Returns:
            True if the block was erased and returned to the free pool,
            False if some live page could not be relocated.
        """
        meta = self.block_metadata[block_id]
        # map() records temperature for the destination block, not this one,
        # so the source block's hotness is computed once up front.
        is_hot = self.get_block_temperature(block_id) > 0.5

        dest_block: Optional[int] = None
        # sorted() iterates a snapshot, so map() may safely mutate valid_pages.
        for page_id in sorted(meta.valid_pages):
            phys_addr = block_id * self.pages_per_block + page_id
            lba = self.phys_to_lba.get(phys_addr)
            if lba is None:
                continue

            new_page = None
            if dest_block is not None:
                new_page = self._get_free_page_in_block(dest_block)
            if new_page is None:
                # No destination yet, or it is full: take a fresh block.
                dest_block = self.get_free_block()
                if dest_block is None:
                    return False
                new_page = self._get_free_page_in_block(dest_block)
                if new_page is None:
                    return False

            new_phys = dest_block * self.pages_per_block + new_page
            self.map(lba, new_phys, is_hot)

        # Every live page has been relocated; erase the block.
        meta.erase_count += 1
        meta.valid_pages.clear()
        meta.invalid_pages.clear()
        meta.last_write_time = time.time()
        self.free_blocks.add(block_id)
        return True
