import numpy as np
from typing import Optional, List, Set
import time
import json
import logging
import duckdb
import os
from huggingface_hub import HfApi, HfFileSystem
from config import get_hf_token

# NOTE: no token is initialized here yet - get_hf_token is imported above but is not called anywhere in this file as shown.


class CacheEntry:
    """Entry in the page cache.

    Attributes:
        lba: Logical block address the entry is keyed by.
        phys: Physical address stored for that LBA.
        access_count: Access counter, stored as given.
        last_access: Timestamp of the last access (``time.time()`` float).
        is_dirty: Dirty flag, stored as given.
    """

    def __init__(self, lba: int, phys: int, access_count: int = 0,
                 last_access: Optional[float] = None, is_dirty: bool = False):
        """Store the given fields.

        Args:
            lba: Logical block address.
            phys: Physical address.
            access_count: Initial access counter.
            last_access: Timestamp to record; ``None`` means "now".
            is_dirty: Initial dirty flag.
        """
        self.lba = lba
        self.phys = phys
        self.access_count = access_count
        # Compare against None (not truthiness) so an explicit 0.0 is kept.
        self.last_access = last_access if last_access is not None else time.time()
        self.is_dirty = is_dirty

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(lba={self.lba!r}, phys={self.phys!r}, "
            f"access_count={self.access_count!r}, "
            f"last_access={self.last_access!r}, is_dirty={self.is_dirty!r})"
        )


class BlockMetadata:
    """Metadata for a flash memory block.

    Attributes:
        erase_count: Number of erases recorded for the block.
        valid_pages: Set of page indexes marked valid.
        invalid_pages: Set of page indexes marked invalid.
        last_write_time: Timestamp of the last write (``time.time()`` float).
        temperature: Temperature-based wear leveling metric.
    """

    def __init__(self, erase_count: int = 0,
                 valid_pages: Optional[Set[int]] = None,
                 invalid_pages: Optional[Set[int]] = None,
                 last_write_time: Optional[float] = None,
                 temperature: float = 0.0):
        """Store the given fields.

        Args:
            erase_count: Initial erase counter.
            valid_pages: Set to use for valid pages; ``None`` creates a new
                empty set. A set that is passed in is stored by reference.
            invalid_pages: Same as ``valid_pages`` but for invalid pages.
            last_write_time: Timestamp to record; ``None`` means "now".
            temperature: Initial temperature metric.
        """
        self.erase_count = erase_count
        # A fresh set per instance: a shared default would leak pages
        # between blocks.
        self.valid_pages = valid_pages if valid_pages is not None else set()
        self.invalid_pages = invalid_pages if invalid_pages is not None else set()
        self.last_write_time = last_write_time if last_write_time is not None else time.time()
        self.temperature = temperature  # Temperature-based wear leveling metric

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(erase_count={self.erase_count!r}, "
            f"valid_pages={self.valid_pages!r}, "
            f"invalid_pages={self.invalid_pages!r}, "
            f"last_write_time={self.last_write_time!r}, "
            f"temperature={self.temperature!r})"
        )


class AdvancedFTL:
    """Advanced Flash Translation Layer with database-backed storage"""
    
    DB_URL = "hf://datasets/Fred808/helium/storage.json"
    
    def __init__(self, total_blocks: int = 1024, pages_per_block: int = 256,
                 *, buffer_size: int = 64, cache_size: int = 1024):
        """Set up in-memory state, then open and initialise the database.

        Args:
            total_blocks: Number of flash blocks managed by this FTL.
            pages_per_block: Number of pages in each block.
            buffer_size: Buffer size in MB (keyword-only). It is only stored
                here; nothing in the visible code reads it.
            cache_size: Maximum number of page-cache entries (keyword-only).
        """
        self.total_blocks = total_blocks
        self.pages_per_block = pages_per_block
        self.buffer_size = buffer_size  # MB
        self.cache_size = cache_size    # pages

        # In-memory bookkeeping
        self.block_metadata = {}  # Block ID -> BlockMetadata
        self.page_cache = {}      # LBA -> CacheEntry
        self.free_blocks = set(range(total_blocks))  # Initially all blocks are free
        self.hot_blocks = set()   # Blocks containing hot data
        self.cold_blocks = set()  # Blocks containing cold data
        self.cache_hits = 0
        self.cache_misses = 0

        # The connection must exist before the tables can be created.
        self._init_db_connection()
        self._setup_database()
        
    def _init_db_connection(self):
        """Open ``self.conn`` on ``DB_URL`` with the httpfs extension set up.

        A throwaway in-memory connection is configured first, then the real
        connection is opened on the ``hf://`` URL and given the same settings.
        The temporary connection is closed in a ``finally`` block so it is
        released even when a statement or the real connect raises.

        Raises:
            Whatever ``duckdb.connect`` / ``execute`` raise; nothing is
            caught here.
        """
        # Same statements are applied to both connections.
        httpfs_setup = (
            "INSTALL httpfs;",
            "LOAD httpfs;",
            "SET s3_endpoint='hf.co';",
            "SET s3_use_ssl=true;",
            "SET s3_url_style='path';",
        )

        # First create an in-memory connection to configure settings
        temp_conn = duckdb.connect(":memory:")
        try:
            # Configure HuggingFace access - done before connecting to URL
            for statement in httpfs_setup:
                temp_conn.execute(statement)

            # Ensure using remote URL
            if not self.DB_URL.startswith('hf://'):
                self.DB_URL = f"hf://datasets/Fred808/helium/{os.path.basename(self.DB_URL)}"

            # Now create the real connection and configure it the same way
            self.conn = duckdb.connect(self.DB_URL, config={'http_keep_alive': 'true'})
            for statement in httpfs_setup:
                self.conn.execute(statement)
        finally:
            # Close temporary connection on success and on failure alike
            temp_conn.close()
        
    def _setup_database(self):
        """Create the tables if missing and seed per-block metadata.

        Creates ``address_mapping``, ``block_metadata``, ``page_cache`` and
        ``access_history``, inserts one ``block_metadata`` row for every block
        id in ``range(total_blocks)`` that has no row yet, and fills the
        in-memory ``self.block_metadata`` dict with fresh ``BlockMetadata``
        objects (these always start from zero; they are not loaded from the
        database).
        """
        # Address mapping table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS address_mapping (
                lba INTEGER PRIMARY KEY,
                phys_addr INTEGER UNIQUE,
                last_update TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Block metadata table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS block_metadata (
                block_id INTEGER PRIMARY KEY,
                erase_count INTEGER DEFAULT 0,
                valid_pages JSON,
                invalid_pages JSON,
                last_write_time TIMESTAMP,
                temperature FLOAT DEFAULT 0.0,
                is_free BOOLEAN DEFAULT true,
                is_hot BOOLEAN DEFAULT false
            )
        """)

        # Page cache table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS page_cache (
                lba INTEGER PRIMARY KEY,
                phys_addr INTEGER,
                access_count INTEGER DEFAULT 0,
                last_access TIMESTAMP,
                is_dirty BOOLEAN DEFAULT false
            )
        """)

        # Access history table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS access_history (
                lba INTEGER,
                access_time TIMESTAMP,
                PRIMARY KEY (lba, access_time)
            )
        """)

        # Seed block metadata for blocks that have no row yet.
        # The table function's output column is aliased explicitly as
        # ``value``: DuckDB's range() does not name its column ``value`` on
        # its own, so the unaliased form cannot resolve ``t.value``.
        self.conn.execute("""
            INSERT INTO block_metadata (block_id, valid_pages, invalid_pages, last_write_time)
            SELECT
                t.value,
                '[]',
                '[]',
                CURRENT_TIMESTAMP
            FROM range(?) AS t(value)
            WHERE NOT EXISTS (
                SELECT 1 FROM block_metadata WHERE block_id = t.value
            )
        """, [self.total_blocks])

        # Fresh in-memory metadata for every block
        for block in range(self.total_blocks):
            self.block_metadata[block] = BlockMetadata(
                erase_count=0,
                valid_pages=set(),
                invalid_pages=set(),
                last_write_time=time.time(),
                temperature=0.0
            )

    def map(self, lba: int, phys: int, is_hot: bool = False):
        """Map a logical block address to a physical page.

        Runs as one transaction that:

        1. marks the page previously mapped to ``lba`` (if any, and if it is
           a different page) as invalid in its block,
        2. upserts the ``address_mapping`` row for ``lba``,
        3. refreshes the ``page_cache`` row for ``lba`` (if one exists) so
           ``get_phys`` cannot return the old address,
        4. marks the new page valid and the new block as written, non-free
           and hot/cold according to ``is_hot``,
        5. records the access and trims ``access_history`` to the 100 most
           recent rows for this LBA.

        Args:
            lba: Logical block address.
            phys: Physical page address (``block * pages_per_block + page``).
            is_hot: Temperature flag written to the destination block.

        Raises:
            RuntimeError: If any step fails. A ROLLBACK is attempted first and
                the original exception is chained as ``__cause__``.
        """
        try:
            # Start transaction
            self.conn.execute("BEGIN TRANSACTION")

            # Get old mapping if exists
            previous = self.conn.execute(
                "SELECT phys_addr FROM address_mapping WHERE lba = ?",
                [lba]
            ).fetchone()
            old_phys = previous[0] if previous else None

            # Re-mapping to the very same page must not invalidate it,
            # otherwise the page would end up both valid and invalid.
            if old_phys is not None and old_phys != phys:
                old_block_id, old_page_id = divmod(old_phys, self.pages_per_block)

                old_meta = self.conn.execute(
                    "SELECT valid_pages, invalid_pages FROM block_metadata WHERE block_id = ?",
                    [old_block_id]
                ).fetchone()

                if old_meta:
                    valid_pages = set(json.loads(old_meta[0]))
                    invalid_pages = set(json.loads(old_meta[1]))
                    # discard(): tolerate metadata that never listed the page
                    # as valid instead of aborting the whole mapping.
                    valid_pages.discard(old_page_id)
                    invalid_pages.add(old_page_id)

                    self.conn.execute("""
                        UPDATE block_metadata
                        SET valid_pages = ?,
                            invalid_pages = ?
                        WHERE block_id = ?
                    """, [
                        json.dumps(sorted(valid_pages)),
                        json.dumps(sorted(invalid_pages)),
                        old_block_id
                    ])

            # Create or update mapping. The conflict target is spelled out
            # because the table has two uniqueness constraints (lba and
            # phys_addr); a phys_addr already owned by another LBA is
            # reported as an error rather than silently replaced.
            self.conn.execute("""
                INSERT INTO address_mapping (lba, phys_addr, last_update)
                VALUES (?, ?, CURRENT_TIMESTAMP)
                ON CONFLICT (lba) DO UPDATE SET
                    phys_addr = excluded.phys_addr,
                    last_update = excluded.last_update
            """, [lba, phys])

            # Keep an existing cache row in step with the new mapping
            self.conn.execute(
                "UPDATE page_cache SET phys_addr = ? WHERE lba = ?",
                [phys, lba]
            )

            # Update new block metadata
            block_id, page_id = divmod(phys, self.pages_per_block)

            meta = self.conn.execute(
                "SELECT valid_pages FROM block_metadata WHERE block_id = ?",
                [block_id]
            ).fetchone()

            valid_pages = set(json.loads(meta[0])) if meta else set()
            valid_pages.add(page_id)

            self.conn.execute("""
                UPDATE block_metadata
                SET valid_pages = ?,
                    last_write_time = CURRENT_TIMESTAMP,
                    is_hot = ?,
                    is_free = false
                WHERE block_id = ?
            """, [
                json.dumps(sorted(valid_pages)),
                is_hot,
                block_id
            ])

            # Add to access history
            self.conn.execute("""
                INSERT INTO access_history (lba, access_time)
                VALUES (?, CURRENT_TIMESTAMP)
            """, [lba])

            # Clean up old history entries (keep the 100 newest per LBA)
            self.conn.execute("""
                DELETE FROM access_history
                WHERE lba = ? AND access_time NOT IN (
                    SELECT access_time
                    FROM access_history
                    WHERE lba = ?
                    ORDER BY access_time DESC
                    LIMIT 100
                )
            """, [lba, lba])

            # Commit transaction
            self.conn.execute("COMMIT")

        except Exception as e:
            # A failing ROLLBACK (e.g. no transaction is active because BEGIN
            # itself failed) must not hide the error that got us here.
            try:
                self.conn.execute("ROLLBACK")
            except Exception as rollback_error:
                logging.error(f"Rollback failed after mapping error: {str(rollback_error)}")
            raise RuntimeError(f"Failed to map addresses: {str(e)}") from e

    def get_phys(self, lba: int) -> Optional[int]:
        """Resolve an LBA to its physical address, going through the cache.

        A row in ``page_cache`` is a hit: its access statistics are bumped and
        its address returned. Otherwise ``address_mapping`` is consulted and,
        if a mapping exists, it is inserted into ``page_cache`` (evicting the
        row with the oldest ``last_access`` when the cache already holds
        ``cache_size`` rows). ``self.cache_hits`` / ``self.cache_misses`` are
        updated accordingly; a lookup of an unmapped LBA counts as a miss.

        Args:
            lba: Logical block address to resolve.

        Returns:
            The physical address, or ``None`` when the LBA is unmapped or a
            database error occurred (the error is logged).
        """
        try:
            # Check cache first
            cached = self.conn.execute("""
                SELECT phys_addr
                FROM page_cache
                WHERE lba = ?
            """, [lba]).fetchone()

            if cached:
                self.cache_hits += 1
                # Update cache statistics
                self.conn.execute("""
                    UPDATE page_cache
                    SET access_count = access_count + 1,
                        last_access = CURRENT_TIMESTAMP
                    WHERE lba = ?
                """, [lba])
                return cached[0]

            # Cache miss - check main mapping
            self.cache_misses += 1
            mapping = self.conn.execute("""
                SELECT phys_addr
                FROM address_mapping
                WHERE lba = ?
            """, [lba]).fetchone()

            if not mapping:
                return None

            phys_addr = mapping[0]

            # Make room if the cache is full
            cache_count = self.conn.execute(
                "SELECT COUNT(*) FROM page_cache"
            ).fetchone()[0]

            if cache_count >= self.cache_size:
                # Evict least recently used entry
                self.conn.execute("""
                    DELETE FROM page_cache
                    WHERE lba IN (
                        SELECT lba
                        FROM page_cache
                        ORDER BY last_access ASC
                        LIMIT 1
                    )
                """)

            # Add to cache
            self.conn.execute("""
                INSERT INTO page_cache (
                    lba, phys_addr, access_count,
                    last_access, is_dirty
                ) VALUES (?, ?, 1, CURRENT_TIMESTAMP, false)
            """, [lba, phys_addr])

            return phys_addr

        except Exception as e:
            # Best effort: callers get None instead of an exception.
            logging.error(f"Failed to get physical address: {str(e)}")
            return None

    def get_free_block(self) -> Optional[int]:
        """Return the id of the least-worn free block recorded in the database.

        Free blocks are ordered by ascending ``erase_count`` and then by
        ascending ``last_write_time``; the first one is returned.

        NOTE(review): a second ``get_free_block`` is defined further down in
        this class. Python keeps the later definition, so this SQL-backed
        version is not the one that gets called.

        Returns:
            The block id, or ``None`` when no free block exists or the query
            failed (the failure is logged).
        """
        try:
            # Lowest erase count first, oldest write as the tie-breaker
            candidate = self.conn.execute("""
                SELECT block_id, erase_count
                FROM block_metadata
                WHERE is_free = true
                ORDER BY erase_count ASC, last_write_time ASC
                LIMIT 1
            """).fetchone()

            if candidate:
                return candidate[0]
            return None

        except Exception as e:
            logging.error(f"Failed to get free block: {str(e)}")
            return None

    def garbage_collect(self, block_id: int) -> bool:
        """Relocate the valid pages of a block and erase it.

        Every valid page is re-mapped (via ``map``) to the same page offset in
        a block obtained from ``get_free_block``; afterwards the victim's page
        sets are cleared, its erase counter is incremented and it is flagged
        free.

        No surrounding transaction is opened here: ``map`` runs its own
        transaction per page, and opening another one around those calls makes
        the inner BEGIN fail. Each page move is therefore atomic on its own;
        if a move fails, the pages moved so far stay moved, the victim is not
        erased, and the collection can simply be retried.

        Args:
            block_id: Block to collect.

        Returns:
            ``True`` if the block was collected and erased. ``False`` if the
            block is unknown, has no invalid pages, no destination block
            (other than the victim itself) is available, or an error occurred
            (errors are logged).
        """
        try:
            # Get block metadata
            block_meta = self.conn.execute("""
                SELECT valid_pages, invalid_pages, erase_count
                FROM block_metadata
                WHERE block_id = ?
            """, [block_id]).fetchone()

            if not block_meta:
                return False

            valid_pages = set(json.loads(block_meta[0]))
            invalid_pages = set(json.loads(block_meta[1]))
            erase_count = block_meta[2]

            # Skip if no invalid pages - nothing would be reclaimed
            if not invalid_pages:
                return False

            # Destination for the valid pages
            new_block_id = self.get_free_block()
            # Copying a block onto itself and then erasing it would drop
            # every valid page, so the victim is never a valid destination.
            if new_block_id is None or new_block_id == block_id:
                return False

            # Move each valid page, keeping its offset within the block
            for page_id in sorted(valid_pages):
                old_phys = block_id * self.pages_per_block + page_id
                new_phys = new_block_id * self.pages_per_block + page_id

                # Get LBA for this physical address
                lba_result = self.conn.execute("""
                    SELECT lba
                    FROM address_mapping
                    WHERE phys_addr = ?
                """, [old_phys]).fetchone()

                if lba_result:
                    # Update mapping to new location (own transaction)
                    self.map(lba_result[0], new_phys)

            # Erase block: a single statement, atomic by itself
            self.conn.execute("""
                UPDATE block_metadata
                SET valid_pages = '[]',
                    invalid_pages = '[]',
                    erase_count = ?,
                    is_free = true,
                    last_write_time = CURRENT_TIMESTAMP
                WHERE block_id = ?
            """, [erase_count + 1, block_id])

            return True

        except Exception as e:
            # map() has already rolled back its own transaction on failure;
            # there is no transaction of ours left to roll back here.
            logging.error(f"Garbage collection failed: {str(e)}")
            return False

    def get_stats(self):
        """Collect FTL statistics into a dict.

        The result merges the key/value rows of the ``ftl_stats`` table (when
        that table can be read) with aggregates over ``block_metadata``:
        ``total_blocks``, ``free_blocks``, ``avg_erase_count`` (rounded to two
        decimals) and ``max_erase_count``. With no block rows the aggregates
        are reported as zero.

        ``ftl_stats`` is read on a best-effort basis because nothing in the
        visible code creates it; a failure there must not discard the block
        statistics.

        Returns:
            The statistics dict, or ``{}`` if the block statistics query
            fails (the failure is logged).
        """
        stats = {}

        # Optional key/value statistics
        try:
            results = self.conn.execute(
                "SELECT stat_key, stat_value FROM ftl_stats"
            ).fetchall()
        except Exception as e:
            logging.debug(f"ftl_stats not available: {str(e)}")
            results = []

        for key, value in results:
            stats[key] = value

        try:
            # Add block statistics
            block_stats = self.conn.execute("""
                SELECT
                    COUNT(*) as total_blocks,
                    SUM(CASE WHEN is_free THEN 1 ELSE 0 END) as free_blocks,
                    AVG(erase_count) as avg_erase_count,
                    MAX(erase_count) as max_erase_count
                FROM block_metadata
            """).fetchone()

            if block_stats:
                total, free, avg_erase, max_erase = block_stats
                # SUM/AVG/MAX are NULL over zero rows; report zeros instead
                # of failing in round().
                stats.update({
                    "total_blocks": total,
                    "free_blocks": free if free is not None else 0,
                    "avg_erase_count": round(avg_erase, 2) if avg_erase is not None else 0.0,
                    "max_erase_count": max_erase if max_erase is not None else 0
                })

            return stats

        except Exception as e:
            logging.error(f"Failed to get stats: {str(e)}")
            return {}

    def _cache_page(self, lba: int, phys: int):
        """Insert a page into the in-memory cache, evicting one entry if full.

        When the cache already holds ``cache_size`` entries, the entry with
        the smallest ``last_access`` is removed first; a dirty victim is
        handed to ``self._write_back`` before removal.

        NOTE(review): ``_write_back`` is not defined in the part of the class
        visible here - confirm it exists before relying on dirty eviction.

        Args:
            lba: Logical block address used as the cache key.
            phys: Physical address to cache for it.
        """
        cache = self.page_cache

        if len(cache) >= self.cache_size:
            # Oldest access time loses
            victim = min(cache.keys(), key=lambda key: cache[key].last_access)
            if cache[victim].is_dirty:
                # Write back dirty page before it is dropped
                self._write_back(victim)
            del self.page_cache[victim]

        fresh_entry = CacheEntry(
            lba=lba,
            phys=phys,
            access_count=1,
            last_access=time.time(),
            is_dirty=False,
        )
        self.page_cache[lba] = fresh_entry

    def get_free_block(self) -> Optional[int]:
        """Pick the least-erased block from the in-memory free set.

        If ``self.free_blocks`` is empty, ``garbage_collection`` is run once
        to try to reclaim blocks before giving up.

        NOTE(review): this definition replaces the earlier SQL-backed
        ``get_free_block`` in this class (the later ``def`` wins), and it
        reads only the in-memory ``free_blocks`` / ``block_metadata`` - check
        that these are kept in step with the database.

        Returns:
            The id of the free block with the lowest erase count, or ``None``
            when no block is free even after garbage collection.
        """
        if not self.free_blocks:
            # Nothing free: try to reclaim some blocks first
            self.garbage_collection()

        if self.free_blocks:
            # Wear leveling: lowest erase count wins
            return min(self.free_blocks,
                       key=lambda b: self.block_metadata[b].erase_count)
        return None

    def garbage_collection(self):
        """Collect up to five of the most fragmented blocks.

        A block is a candidate when more than 70% of its pages are invalid
        according to the in-memory ``block_metadata``. Candidates are scored
        by their invalid-page ratio, with hot blocks weighted by 1.2, and the
        five highest-scoring ones are passed to ``_collect_block``.
        """
        scored = []
        for candidate in range(self.total_blocks):
            invalid_count = len(self.block_metadata[candidate].invalid_pages)
            if invalid_count <= self.pages_per_block * 0.7:
                continue  # not fragmented enough to be worth collecting

            weight = invalid_count / self.pages_per_block
            if candidate in self.hot_blocks:
                weight *= 1.2  # Prefer hot blocks for GC
            scored.append((weight, candidate))

        # Highest score first; ties fall back to the higher block id
        for _, victim in sorted(scored, reverse=True)[:5]:
            self._collect_block(victim)

    def _collect_block(self, block_id: int):
        """Relocate the valid pages of one block, then erase it in memory.

        For every valid page a destination block is chosen with
        ``_allocate_block_for_temperature`` (hot or cold depending on whether
        the victim is in ``hot_blocks``) and the first free page in it is
        used. The move is recorded through ``map`` and in the destination's
        in-memory metadata, so consecutive pages get distinct destination
        pages. If no destination block or page is available the method
        returns early and the victim is left un-erased.

        NOTE(review): ``self.phys_to_lba`` is read here but is not created by
        ``__init__`` in the visible code - confirm where it is maintained.

        Args:
            block_id: Block to collect.
        """
        meta = self.block_metadata[block_id]
        is_hot = block_id in self.hot_blocks

        # The victim must never be picked as its own destination: its pages
        # would be "moved" into it and then wiped by the erase below.
        self.free_blocks.discard(block_id)

        # Relocate valid pages
        for page_id in sorted(meta.valid_pages):
            phys_addr = block_id * self.pages_per_block + page_id
            lba = self.phys_to_lba[phys_addr]

            # Allocate new location based on temperature
            new_block = self._allocate_block_for_temperature(is_hot)
            if new_block is None:
                return  # No space available

            new_page = self._get_free_page_in_block(new_block)
            if new_page is None:
                return  # Destination block has no free page left

            new_phys = new_block * self.pages_per_block + new_page

            # Update mapping
            self.map(lba, new_phys, is_hot)
            # Remember the page as used so the next relocation does not
            # receive the same destination page.
            self.block_metadata[new_block].valid_pages.add(new_page)

        # Erase block
        meta.erase_count += 1
        meta.valid_pages.clear()
        meta.invalid_pages.clear()
        meta.last_write_time = time.time()
        self.free_blocks.add(block_id)
        
    def _allocate_block_for_temperature(self, is_hot: bool) -> Optional[int]:
        """Choose a free block for hot or cold data.

        Free blocks that are members of ``hot_blocks`` (for hot data) or
        ``cold_blocks`` (for cold data) are excluded; of the remaining ones
        the block with the lowest erase count is returned.

        NOTE(review): excluding the set that matches the requested
        temperature looks inverted for hot/cold separation - confirm intent.

        Args:
            is_hot: ``True`` to allocate for hot data, ``False`` for cold.

        Returns:
            The chosen block id, or ``None`` if no candidate remains.
        """
        excluded = self.hot_blocks if is_hot else self.cold_blocks
        candidates = self.free_blocks - excluded
        if candidates:
            # Wear leveling among the candidates
            return min(candidates,
                       key=lambda b: self.block_metadata[b].erase_count)
        return None

    def _get_free_page_in_block(self, block_id: int) -> Optional[int]:
        """Return the lowest page index of a block that is neither valid nor invalid.

        Args:
            block_id: Block whose in-memory metadata is inspected.

        Returns:
            The smallest unused page index in ``range(pages_per_block)``, or
            ``None`` when every page is already valid or invalid.
        """
        info = self.block_metadata[block_id]
        taken = info.valid_pages | info.invalid_pages
        # Scanning upwards, the first unused index is also the minimum
        for page in range(self.pages_per_block):
            if page not in taken:
                return page
        return None
