"""
Memory optimization and caching system for mining operations using local storage
"""
import io
import json
import math
import os
import sqlite3
import threading
from typing import Dict, Any, Optional

import numpy as np

from http_storage import LocalStorage

class LocalCacheManager:
    """Two-tier cache of NumPy arrays: an in-process dict plus a SQLite table.

    Arrays written through :meth:`cache_computation` are stored both in
    ``self.cache`` and in the ``cache_entries`` table of the database at
    ``db_path``.  :meth:`get_cached` checks memory first, then falls back to
    SQLite and promotes any disk hit back into memory.

    Notes:
        * The byte budget (``cache_size_gb``) is compared against the sum of
          both tiers.  An entry present in memory *and* on disk therefore
          counts twice.  Promoting a disk hit into memory is not
          budget-checked.
        * Eviction removes the entry with the lowest access count (on disk,
          ties go to the oldest ``last_access``).  That is
          least-frequently-used rather than strict LRU.  ``cache_policy`` is
          stored but not consulted.
        * Memory-tier hits do not update the SQLite access statistics.
    """

    # Header written by np.save.  It distinguishes self-describing .npy blobs
    # from legacy rows that stored raw ``ndarray.tobytes()`` output.
    _NPY_MAGIC = b"\x93NUMPY"

    def __init__(self, cache_size_gb: int = 32, cache_policy: str = "lru",
                 db_path: str = "db/coin_miner/cache.db"):
        """Create the cache and open (or create) its SQLite database.

        Args:
            cache_size_gb: Byte budget for both tiers combined, in GiB.
            cache_policy: Stored on the instance; not consulted by this class.
            db_path: SQLite file path.  Its parent directory is created if
                needed.  A bare filename or ``":memory:"`` is also accepted.
        """
        self.cache_size = cache_size_gb * 1024 * 1024 * 1024  # GiB -> bytes
        self.cache_policy = cache_policy
        self.db_path = db_path
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when the path actually has one.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        # Initialize storage and memory cache
        self.storage = LocalStorage()
        self.cache: Dict[str, np.ndarray] = {}
        self.access_count: Dict[str, int] = {}
        # Must be re-entrant: the public methods hold the lock while calling
        # _get_cache_size/_evict_lru, which acquire it again.  A plain Lock
        # deadlocks on that nested acquire.
        self.lock = threading.RLock()

        # Initialize SQLite database for persistent cache
        self._init_db()

    def _init_db(self):
        """Open the SQLite connection and create ``cache_entries`` if absent."""
        # check_same_thread=False: the connection is shared across threads;
        # access to it is serialized through self.lock.
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS cache_entries (
                key TEXT PRIMARY KEY,
                data BLOB,
                metadata TEXT,
                access_count INTEGER,
                last_access REAL,
                size_bytes INTEGER
            )
        """)
        self.conn.commit()

    @staticmethod
    def _encode_array(data: np.ndarray) -> bytes:
        """Serialize *data* as a .npy blob, which records dtype and shape.

        Raises:
            ValueError: for object-dtype arrays, which would need pickling.
        """
        buf = io.BytesIO()
        np.save(buf, data, allow_pickle=False)
        return buf.getvalue()

    @classmethod
    def _decode_array(cls, blob: bytes) -> np.ndarray:
        """Inverse of :meth:`_encode_array`; also reads legacy raw blobs."""
        blob = bytes(blob)
        if blob.startswith(cls._NPY_MAGIC):
            # allow_pickle=False: never unpickle data read from disk.
            return np.load(io.BytesIO(blob), allow_pickle=False)
        # Legacy row: a raw buffer with no dtype/shape recorded.  Decode it
        # the way earlier versions did (1-D, NumPy's default float64).
        return np.frombuffer(blob)

    def _get_cache_size(self) -> int:
        """Return in-memory bytes plus the ``size_bytes`` total in SQLite."""
        with self.lock:
            # Get in-memory cache size
            memory_size = sum(arr.nbytes for arr in self.cache.values())

            # Get persistent cache size (logical nbytes, not blob length)
            db_size = self.conn.execute(
                "SELECT COALESCE(SUM(size_bytes), 0) FROM cache_entries"
            ).fetchone()[0]

            return memory_size + db_size

    def _evict_lru(self) -> bool:
        """Evict one entry, preferring the memory tier.

        The memory victim is the key with the lowest in-memory access count.
        Its SQLite copy is left in place.  Only when memory is empty is a row
        deleted from SQLite: the lowest ``access_count``, with the oldest
        ``last_access`` breaking ties.

        Returns:
            True if something was evicted, False if both tiers are empty.
        """
        with self.lock:
            # Try to evict from memory first
            if self.cache:
                victim = min(self.access_count, key=self.access_count.__getitem__)
                del self.cache[victim]
                del self.access_count[victim]
                return True

            # Memory is empty: evict from disk.  SELECT + DELETE rather than
            # DELETE ... RETURNING, which requires SQLite >= 3.35.
            row = self.conn.execute("""
                SELECT key FROM cache_entries
                ORDER BY access_count ASC, last_access ASC
                LIMIT 1
            """).fetchone()
            if row is None:
                return False
            self.conn.execute("DELETE FROM cache_entries WHERE key = ?", (row[0],))
            self.conn.commit()
            return True

    def _discard(self, key: str) -> None:
        """Drop *key* from both tiers without committing (caller holds lock)."""
        self.cache.pop(key, None)
        self.access_count.pop(key, None)
        self.conn.execute("DELETE FROM cache_entries WHERE key = ?", (key,))

    def cache_computation(self, key: str, data: np.ndarray,
                          metadata: Optional[Dict] = None) -> bool:
        """Store *data* under *key* in memory and in SQLite.

        Args:
            key: Cache key; an existing entry for it is replaced.
            data: Array to cache.  Dtype and shape are preserved on disk.
            metadata: Optional JSON-serializable dict stored alongside.

        Returns:
            True if stored.  False if *data* alone exceeds the budget (the
            existing entries are then left untouched), or if eviction could
            not free enough room.

        Raises:
            TypeError: if *metadata* is not JSON-serializable.
            ValueError: if *data* has object dtype.
        """
        data_size = data.nbytes
        # An entry bigger than the whole budget can never fit.  Refuse it
        # instead of evicting everything first.
        if data_size > self.cache_size:
            return False

        # Encode before touching any state, so a failure leaves the cache intact.
        blob = self._encode_array(data)
        meta_json = json.dumps(metadata or {})

        with self.lock:
            # The old value for this key is being replaced; don't let it count
            # against the budget for its own replacement.
            self._discard(key)

            # Evict until the new entry fits
            while self._get_cache_size() + data_size > self.cache_size:
                if not self._evict_lru():
                    return False

            # Store in memory cache
            self.cache[key] = data
            self.access_count[key] = 0

            # Store in persistent cache
            self.conn.execute("""
                INSERT OR REPLACE INTO cache_entries 
                (key, data, metadata, access_count, last_access, size_bytes)
                VALUES (?, ?, ?, 0, strftime('%s','now'), ?)
            """, (key, blob, meta_json, data_size))

            self.conn.commit()
            return True

    def get_cached(self, key: str) -> Optional[np.ndarray]:
        """Return the cached array for *key*, or None if absent.

        A memory hit returns the stored object itself and bumps its in-memory
        count.  A disk hit bumps the SQLite stats, decodes the array with its
        original dtype/shape, and keeps it in memory with count 1.
        """
        with self.lock:
            # Try memory cache first
            if key in self.cache:
                self.access_count[key] += 1
                return self.cache[key]

            # Try persistent cache
            row = self.conn.execute(
                "SELECT data FROM cache_entries WHERE key = ?", (key,)
            ).fetchone()
            if row is None:
                return None

            # Update access stats
            self.conn.execute("""
                UPDATE cache_entries 
                SET access_count = access_count + 1,
                    last_access = strftime('%s','now')
                WHERE key = ?
            """, (key,))
            self.conn.commit()

            data = self._decode_array(row[0])

            # Cache in memory for faster access next time
            self.cache[key] = data
            self.access_count[key] = 1
            return data

    def close(self):
        """Close the SQLite connection, if one was opened."""
        # getattr: __init__ may have failed before _init_db ran.
        conn = getattr(self, 'conn', None)
        if conn is not None:
            conn.close()

    def __del__(self):
        self.close()

class MemoryPipeline:
    """Pads arrays so their total byte size is a whole number of cache lines.

    ``num_cache_lines`` and ``cache_lines`` are stored on the instance but are
    not used by any method of this class.
    """

    def __init__(self, prefetch_size: int = 4096, 
                 cache_line_size: int = 256,
                 num_cache_lines: int = 1024):
        """
        Args:
            prefetch_size: Element count of the scratch buffer allocated by
                :meth:`_prefetch_next_batch`.
            cache_line_size: Alignment target in bytes.
            num_cache_lines: Stored only.
        """
        self.prefetch_size = prefetch_size
        self.cache_line_size = cache_line_size
        self.num_cache_lines = num_cache_lines
        self.cache_lines = {}
        self.prefetch_buffer = []
        self.lock = threading.Lock()

    def optimize_memory_access(self, data: np.ndarray) -> np.ndarray:
        """Return *data* zero-padded to a cache-line multiple of bytes.

        Also (re)allocates ``self.prefetch_buffer`` with the result's dtype.
        """
        with self.lock:
            # Align data to cache lines
            aligned_data = self._align_to_cache_lines(data)

            # Allocate the scratch buffer for the next batch
            self._prefetch_next_batch(aligned_data)

            return aligned_data

    def _align_to_cache_lines(self, data: np.ndarray) -> np.ndarray:
        """Zero-pad *data* along axis 0 until ``nbytes % cache_line_size == 0``.

        Padding is counted in whole rows (elements for 1-D input), not bytes.
        Trailing dimensions are therefore unchanged.  Arrays that are already
        aligned are returned as-is, and so are 0-d arrays, which cannot be
        padded.
        """
        line = self.cache_line_size
        if data.nbytes % line == 0 or data.ndim == 0:
            return data

        # nbytes > 0 here, so shape[0] > 0 and row_bytes > 0.
        row_bytes = data.nbytes // data.shape[0]
        # Smallest row count whose byte size is a multiple of the line size.
        rows_per_step = line // math.gcd(line, row_bytes)
        target_rows = -(-data.shape[0] // rows_per_step) * rows_per_step
        pad_width = [(0, target_rows - data.shape[0])] + [(0, 0)] * (data.ndim - 1)
        return np.pad(data, pad_width, mode='constant')

    def _prefetch_next_batch(self, current_data: np.ndarray):
        """Allocate an uninitialized ``prefetch_size``-element buffer.

        The buffer uses *current_data*'s dtype.  No data is copied into it.
        """
        self.prefetch_buffer = np.empty(self.prefetch_size, dtype=current_data.dtype)
