
"""
NAND Flash SSD Simulation (Modular)
-----------------------------------
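This module defines the NAND storage model used by the modular simulation — NANDMemory (fixed geometry, held in memory) and DynamicNANDMemory (auto-scaling, persisted through a storage manager) — and documents the overall SSD architecture and usage.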

Components:
- nand_cell.py: MultiLevelCell (single cell physics/logic)
- nand_page.py: Page (group of cells, ECC)
- nand_block.py: Block (group of pages)
- nand_plane.py: Plane (group of blocks)
- dram_cache.py: DRAMCache, Buffer (cache, buffer, metadata)
- ftl.py: FTL (Flash Translation Layer, mapping table)
- ssd_controller.py: SSDController (manages all above, FTL, cache, buffer)
- main.py: Demo/entry point

Usage:
------
Import and use the SSDController and other components in your own scripts, or run main.py for a demo.

Example:
    from ssd_controller import SSDController
    ssd = SSDController(...)
    ssd.program(lba, data)
    ssd.read(lba)

See main.py for a full demonstration of SSD features, including DRAM cache, buffer, FTL, wear leveling, garbage collection, and retention simulation.
"""

from typing import Optional, Dict, List
import numpy as np

class NANDMemory:
    """In-memory model of NAND flash addressed by (plane, block, page).

    Pages are write-once until their block is erased: ``write_page`` refuses to
    overwrite a page that already holds data, and ``erase_block`` drops every
    page of a block and increments that block's erase count. Coordinates
    outside the configured geometry are rejected.
    """

    def __init__(self, num_planes=16, blocks_per_plane=512, pages_per_block=512, bytes_per_page=32768):
        """Initialize NAND memory with fixed configuration.

        Args:
            num_planes (int): Number of planes.
            blocks_per_plane (int): Number of blocks in each plane.
            pages_per_block (int): Number of pages in each block.
            bytes_per_page (int): Nominal page size in bytes. It is recorded
                only; ``write_page`` does not check data size against it.
        """
        self.num_planes = num_planes
        self.blocks_per_plane = blocks_per_plane
        self.pages_per_block = pages_per_block
        self.bytes_per_page = bytes_per_page

        # Initialize storage structures
        self.storage: Dict[tuple, np.ndarray] = {}  # (plane, block, page) -> data
        self.erase_counts: Dict[tuple, int] = {}    # (plane, block) -> number of erases
        self.valid_pages: Dict[tuple, set] = {}     # (plane, block) -> indices of written pages

    def read_page(self, plane: int, block: int, page: int) -> Optional[np.ndarray]:
        """Return the data stored at a page, or None if the page is unwritten."""
        return self.storage.get((plane, block, page))

    def write_page(self, plane: int, block: int, page: int, data: np.ndarray) -> bool:
        """Write ``data`` to an erased page.

        Returns:
            bool: True on success. False if the coordinates are outside the
            configured geometry or the page already holds data (NAND pages
            cannot be overwritten without erasing their block first).
        """
        if not self._validate_coordinates(plane, block, page):
            return False
        if not self._is_page_free(plane, block, page):
            return False

        # The array is stored by reference, not copied.
        self.storage[(plane, block, page)] = data
        self.valid_pages.setdefault((plane, block), set()).add(page)
        return True

    def erase_block(self, plane: int, block: int) -> bool:
        """Erase every page of a block and increment the block's erase count.

        Returns:
            bool: True on success. False if the block lies outside the
            configured geometry; in that case nothing is modified.
        """
        if not (0 <= plane < self.num_planes and 0 <= block < self.blocks_per_plane):
            return False

        block_key = (plane, block)
        self.erase_counts[block_key] = self.erase_counts.get(block_key, 0) + 1

        # Drop the block's valid-page set and every page it tracked.
        for page in self.valid_pages.pop(block_key, set()):
            self.storage.pop((plane, block, page), None)
        return True

    def _validate_coordinates(self, plane: int, block: int, page: int) -> bool:
        """Return True if (plane, block, page) lies within the configured geometry."""
        return (0 <= plane < self.num_planes and
                0 <= block < self.blocks_per_plane and
                0 <= page < self.pages_per_block)

    def _is_page_free(self, plane: int, block: int, page: int) -> bool:
        """Check whether a page is free (not written since its block was last erased)."""
        key = (plane, block, page)
        block_key = (plane, block)
        return key not in self.storage or page not in self.valid_pages.get(block_key, set())

class DynamicNANDMemory(NANDMemory):
    """NAND memory that mirrors pages to a remote store and grows on demand.

    Every successful page write is persisted through ``storage_manager``.
    When ``auto_scale`` is enabled, the geometry is enlarged once the fraction
    of valid pages exceeds ``scale_threshold``. State is reloaded from the
    storage manager at construction time.
    """

    # Prefix of the remote-storage key under which each page is persisted;
    # the full key is f"nand_page_{plane}_{block}_{page}".
    _PAGE_KEY_PREFIX = 'nand_page_'

    def __init__(self, storage_manager, initial_planes=16, initial_blocks=512, initial_pages=512,
                 bytes_per_page=32768, auto_scale=True, max_scale_factor=float('inf')):
        """Initialize dynamic NAND memory with scaling capabilities

        Args:
            storage_manager: Remote storage manager instance for persistence
                (this class uses its ``store``, ``retrieve`` and ``list_keys``)
            initial_planes (int): Initial number of planes
            initial_blocks (int): Initial blocks per plane
            initial_pages (int): Initial pages per block
            bytes_per_page (int): Bytes per page
            auto_scale (bool): Whether to automatically scale when needed
            max_scale_factor (float): Maximum scaling factor (inf for unlimited)
        """
        super().__init__(num_planes=initial_planes, blocks_per_plane=initial_blocks,
                         pages_per_block=initial_pages, bytes_per_page=bytes_per_page)

        self.storage_manager = storage_manager
        self.auto_scale = auto_scale
        self.max_scale_factor = max_scale_factor

        # Scaling parameters
        self.initial_planes = initial_planes
        self.initial_blocks = initial_blocks
        self.initial_pages = initial_pages
        self.current_scale = 1.0
        self.scale_threshold = 0.8  # Scale up when 80% full

        # Load existing state if available
        self._load_state()

    def write_page(self, plane: int, block: int, page: int, data: np.ndarray) -> bool:
        """Write a page, scaling up first if needed, and persist it on success.

        Returns:
            bool: True if the page was written; False if the coordinates are
            out of range (after any scale-up) or the page is not free.
        """
        # Grow before validating so coordinates in the new range are accepted.
        if self.auto_scale and self._usage_ratio() > self.scale_threshold:
            self._scale_up()

        if not self._validate_coordinates(plane, block, page):
            return False

        success = super().write_page(plane, block, page, data)
        if success:
            self._persist_page(plane, block, page, data)
        return success

    def erase_block(self, plane: int, block: int) -> bool:
        """Erase a block locally and mirror the erase to remote storage.

        Each page that was valid in the block is overwritten remotely with an
        empty payload (a tombstone that ``_load_state`` skips), so erased data
        is not resurrected on reload. The configuration is then re-persisted
        so the new erase count survives a restart.
        """
        erased_pages = set(self.valid_pages.get((plane, block), ()))
        success = super().erase_block(plane, block)
        if success:
            # NOTE(review): assumes storage_manager.store accepts b"" as a
            # regular payload — confirm against the storage manager API.
            for page in erased_pages:
                self.storage_manager.store(self._page_key(plane, block, page), b"")
            self._persist_config()
        return success

    def _usage_ratio(self) -> float:
        """Return the fraction of pages currently holding valid data.

        An empty geometry (zero total pages) is reported as full (1.0)
        instead of dividing by zero.
        """
        total_pages = self.num_planes * self.blocks_per_plane * self.pages_per_block
        if total_pages <= 0:
            return 1.0
        used_pages = sum(len(pages) for pages in self.valid_pages.values())
        return used_pages / total_pages

    def _scale_up(self):
        """Grow the geometry by up to 2x per step, never exceeding max_scale_factor."""
        if self.current_scale >= self.max_scale_factor:
            return

        scale_factor = min(2.0, self.max_scale_factor / self.current_scale)
        new_scale = self.current_scale * scale_factor

        # Derive dimensions from the cumulative scale. Multiplying the initial
        # size by the per-step factor alone would cap growth at 2x forever.
        # max() keeps a persisted, larger geometry from ever shrinking.
        self.num_planes = max(self.num_planes, int(self.initial_planes * new_scale))
        self.blocks_per_plane = max(self.blocks_per_plane, int(self.initial_blocks * new_scale))
        self.pages_per_block = max(self.pages_per_block, int(self.initial_pages * new_scale))

        self.current_scale = new_scale

        # Persist new configuration
        self._persist_config()

    def _validate_coordinates(self, plane: int, block: int, page: int) -> bool:
        """Return True if (plane, block, page) lies within the current geometry."""
        return (0 <= plane < self.num_planes and
                0 <= block < self.blocks_per_plane and
                0 <= page < self.pages_per_block)

    def _page_key(self, plane: int, block: int, page: int) -> str:
        """Return the remote-storage key for a page."""
        return f"{self._PAGE_KEY_PREFIX}{plane}_{block}_{page}"

    def _parse_page_key(self, key: str) -> Optional[tuple]:
        """Return (plane, block, page) encoded in a page key, or None if malformed."""
        if not key.startswith(self._PAGE_KEY_PREFIX):
            return None
        parts = key[len(self._PAGE_KEY_PREFIX):].split('_')
        if len(parts) != 3:
            return None
        try:
            return tuple(int(part) for part in parts)
        except ValueError:
            return None

    def _persist_page(self, plane: int, block: int, page: int, data: np.ndarray):
        """Persist page data to remote storage as raw bytes."""
        self.storage_manager.store(self._page_key(plane, block, page), data.tobytes())

    def _persist_config(self):
        """Persist current geometry, scale and erase counts to remote storage"""
        config = {
            'num_planes': self.num_planes,
            'blocks_per_plane': self.blocks_per_plane,
            'pages_per_block': self.pages_per_block,
            'current_scale': self.current_scale,
            'erase_counts': self.erase_counts,
        }
        self.storage_manager.store('nand_config', config)

    def _load_state(self):
        """Load saved configuration (if any) and all persisted pages.

        Pages are reloaded even when no configuration was saved. Every write
        persists its page, but the configuration is only written on
        scale-up or erase.
        """
        config = self.storage_manager.retrieve('nand_config')
        if config:
            self.num_planes = config['num_planes']
            self.blocks_per_plane = config['blocks_per_plane']
            self.pages_per_block = config['pages_per_block']
            self.current_scale = config['current_scale']
            self.erase_counts = config['erase_counts']

        for key in self.storage_manager.list_keys(self._PAGE_KEY_PREFIX + '*'):
            coords = self._parse_page_key(key)
            if coords is None:
                continue
            payload = self.storage_manager.retrieve(key)
            if not payload:
                # Missing payload, or a tombstone left by erase_block.
                continue
            # NOTE(review): the original dtype is not persisted, so pages come
            # back as uint8 views over the stored bytes.
            data = np.frombuffer(payload, dtype=np.uint8)
            # Bypass this class's write_page so reloading does not re-persist.
            super().write_page(*coords, data)
