"""
Database-backed tensor caching system for distributed tensor processing.
Provides persistent storage and efficient retrieval of tensor chunks across cores.
"""

import contextlib
import hashlib
import json
import logging
import sqlite3
import time
import zlib
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union

import numpy as np

logger = logging.getLogger(__name__)

class TensorDBCache:
    """Manages tensor chunk caching using SQLite databases.

    Three SQLite files are kept under ``cache_dir``:

    * the tensor DB holds the raw (optionally zlib-compressed) chunk bytes
      together with the dtype/shape needed to rebuild each array,
    * the types DB holds per-tensor metadata and core assignments,
    * the metrics DB holds per-chunk and per-core performance samples.

    Every public method opens a short-lived connection that is committed (or
    rolled back on error) and then closed.
    """

    def __init__(self,
                 cache_dir: str = "data",
                 tensor_db: str = "tensors.db",
                 types_db: str = "tensor_types.db",
                 metrics_db: str = "tensor_metrics.db"):
        """Initialize the tensor caching system

        Args:
            cache_dir: Directory to store DB files (created if missing,
                including any missing parent directories)
            tensor_db: Database for tensor chunk data
            types_db: Database for tensor metadata and types
            metrics_db: Database for performance metrics
        """
        self.cache_dir = Path(cache_dir)
        # parents=True so a nested cache_dir such as "var/cache/tensors"
        # works on first use instead of raising FileNotFoundError.
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.tensor_db_path = self.cache_dir / tensor_db
        self.types_db_path = self.cache_dir / types_db
        self.metrics_db_path = self.cache_dir / metrics_db

        self.initialize_db()

    @staticmethod
    @contextlib.contextmanager
    def _connect(db_path):
        """Yield a connection that is committed/rolled back AND closed.

        ``with sqlite3.connect(...)`` on its own only manages the transaction;
        the connection object stays open afterwards. This helper wraps the
        transaction in a ``try/finally`` that also closes the handle.

        Args:
            db_path: Path of the SQLite file to open
        """
        conn = sqlite3.connect(str(db_path))
        try:
            with conn:  # commit on success, rollback on exception
                yield conn
        finally:
            conn.close()

    def initialize_db(self):
        """Initialize all required database tables

        Safe to call repeatedly: tables and indices are created only when
        missing, and a tensor DB created before the ``dtype``/``shape``
        columns existed is upgraded in place.
        """
        # Tensor chunks database
        with self._connect(self.tensor_db_path) as conn:
            # Chunk payload plus the dtype/shape required to rebuild the array
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tensor_chunks (
                    chunk_id TEXT PRIMARY KEY,
                    tensor_id TEXT NOT NULL,
                    core_id INTEGER NOT NULL,
                    data BLOB NOT NULL,
                    compressed BOOLEAN DEFAULT TRUE,
                    checksum TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_accessed TIMESTAMP,
                    access_count INTEGER DEFAULT 0,
                    dtype TEXT,  -- numpy dtype string, NULL for legacy rows
                    shape TEXT   -- JSON list of dimensions, NULL for legacy rows
                )
            ''')

            # Upgrade tables created by earlier versions of this class.
            present = {row[1] for row in
                       conn.execute('PRAGMA table_info(tensor_chunks)')}
            for column in ('dtype', 'shape'):
                if column not in present:
                    conn.execute(
                        f'ALTER TABLE tensor_chunks ADD COLUMN {column} TEXT')

            conn.execute('CREATE INDEX IF NOT EXISTS idx_tensor_id ON tensor_chunks(tensor_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_core_id ON tensor_chunks(core_id)')

        # Tensor types and metadata database
        with self._connect(self.types_db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tensor_metadata (
                    tensor_id TEXT PRIMARY KEY,
                    name TEXT,
                    shape TEXT NOT NULL,  -- JSON string of shape
                    dtype TEXT NOT NULL,
                    total_chunks INTEGER NOT NULL,
                    chunk_distribution TEXT NOT NULL,  -- JSON mapping of chunk_id -> core_id
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_modified TIMESTAMP,
                    is_weight BOOLEAN DEFAULT FALSE
                )
            ''')

            conn.execute('''
                CREATE TABLE IF NOT EXISTS core_assignments (
                    core_id INTEGER PRIMARY KEY,
                    active_chunks INTEGER DEFAULT 0,
                    total_memory REAL,
                    used_memory REAL,
                    last_updated TIMESTAMP
                )
            ''')

        # Performance metrics database
        with self._connect(self.metrics_db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS chunk_metrics (
                    chunk_id TEXT NOT NULL,
                    core_id INTEGER NOT NULL,
                    operation TEXT NOT NULL,
                    processing_time REAL,
                    memory_used REAL,
                    cache_hit BOOLEAN,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (chunk_id, core_id, timestamp)
                )
            ''')

            conn.execute('''
                CREATE TABLE IF NOT EXISTS core_metrics (
                    core_id INTEGER NOT NULL,
                    metric_name TEXT NOT NULL,
                    metric_value REAL NOT NULL,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (core_id, metric_name, timestamp)
                )
            ''')

    def store_tensor_chunk(self, chunk_id: str, tensor_id: str, core_id: int,
                          data: np.ndarray, compress: bool = True) -> bool:
        """Store a tensor chunk in the database

        The array's dtype and shape are stored next to the bytes so that
        ``load_tensor_chunk`` can rebuild it. An existing row with the same
        ``chunk_id`` is replaced, which restarts its ``created_at`` and
        ``access_count`` values.

        Args:
            chunk_id: Unique identifier for this chunk
            tensor_id: ID of the parent tensor
            core_id: ID of the core this chunk is assigned to
            data: Numpy array containing the chunk data
            compress: Whether to compress the data before storing

        Returns:
            bool: True if storage was successful, False if anything failed
            (the failure is logged)
        """
        try:
            if data.dtype.hasobject:
                # tobytes() on object arrays serialises pointers, not values.
                raise ValueError("object-dtype arrays cannot be cached")

            raw = data.tobytes()
            # Checksum always covers the uncompressed bytes.
            checksum = hashlib.sha256(raw).hexdigest()
            payload = self._compress_array(data) if compress else raw

            with self._connect(self.tensor_db_path) as conn:
                # access_count starts at 1: the store itself counts as the
                # first access of the (new or replaced) row.
                conn.execute('''
                    INSERT OR REPLACE INTO tensor_chunks
                    (chunk_id, tensor_id, core_id, data, compressed, checksum,
                     last_accessed, access_count, dtype, shape)
                    VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, 1, ?, ?)
                ''', (chunk_id, tensor_id, int(core_id), payload,
                      bool(compress), checksum,
                      data.dtype.str, json.dumps(list(data.shape))))

            return True

        except Exception:
            logger.exception("Failed to store tensor chunk %s", chunk_id)
            return False

    def load_tensor_chunk(self, chunk_id: str) -> Optional[np.ndarray]:
        """Load a tensor chunk from the database

        The stored bytes are verified against the saved SHA-256 checksum
        before being decoded. Access statistics are updated only when the
        chunk was returned successfully.

        Rows written before dtype/shape were recorded are decoded the old
        way, as a flat float64 array.

        Args:
            chunk_id: ID of the chunk to load

        Returns:
            np.ndarray or None: The tensor chunk data if found and intact.
            The array is a read-only view over the loaded bytes; copy it
            before modifying in place. Structured dtypes are recorded only by
            their ``dtype.str`` and may not round-trip field names.
        """
        try:
            with self._connect(self.tensor_db_path) as conn:
                row = conn.execute('''
                    SELECT data, compressed, checksum, dtype, shape
                    FROM tensor_chunks
                    WHERE chunk_id = ?
                ''', (chunk_id,)).fetchone()

                if row is None:
                    return None

                blob, is_compressed, stored_checksum, dtype_str, shape_json = row
                raw = zlib.decompress(blob) if is_compressed else bytes(blob)

                if hashlib.sha256(raw).hexdigest() != stored_checksum:
                    logger.error("Checksum mismatch for chunk %s", chunk_id)
                    return None

                shape = None if shape_json is None else tuple(json.loads(shape_json))
                array = self._array_from_bytes(raw, dtype_str, shape)

                conn.execute('''
                    UPDATE tensor_chunks
                    SET last_accessed = CURRENT_TIMESTAMP,
                        access_count = access_count + 1
                    WHERE chunk_id = ?
                ''', (chunk_id,))

                return array

        except Exception:
            logger.exception("Failed to load tensor chunk %s", chunk_id)
            return None

    def get_core_chunks(self, core_id: int) -> List[str]:
        """Get all chunk IDs assigned to a specific core

        Args:
            core_id: ID of the core

        Returns:
            List[str]: List of chunk IDs assigned to this core
        """
        with self._connect(self.tensor_db_path) as conn:
            rows = conn.execute('''
                SELECT chunk_id
                FROM tensor_chunks
                WHERE core_id = ?
            ''', (int(core_id),)).fetchall()

        return [row[0] for row in rows]

    def update_core_metrics(self, core_id: int, metrics: Dict[str, float]):
        """Update performance metrics for a core

        One row per metric is written with the database's current timestamp.
        Because that timestamp has one-second resolution and is part of the
        primary key, a second sample of the same metric within the same
        second replaces the earlier one instead of raising IntegrityError.

        Args:
            core_id: ID of the core
            metrics: Dictionary of metric_name -> value
        """
        # float()/int() so numpy scalars can be bound by sqlite3.
        rows = [(int(core_id), name, float(value))
                for name, value in metrics.items()]
        if not rows:
            return

        with self._connect(self.metrics_db_path) as conn:
            conn.executemany('''
                INSERT OR REPLACE INTO core_metrics (core_id, metric_name, metric_value)
                VALUES (?, ?, ?)
            ''', rows)

    @staticmethod
    def _array_from_bytes(raw: bytes, dtype=None, shape=None) -> np.ndarray:
        """Rebuild an array from raw bytes.

        Args:
            raw: Uncompressed array bytes in C order
            dtype: Numpy dtype (or dtype string); None means numpy's default
                for ``frombuffer`` (float64), which is how legacy rows decode
            shape: Target shape; None leaves the array one-dimensional
        """
        if dtype is None:
            array = np.frombuffer(raw)
        else:
            array = np.frombuffer(raw, dtype=np.dtype(dtype))
        if shape is not None:
            array = array.reshape(shape)
        return array

    def _compress_array(self, arr: np.ndarray) -> bytes:
        """Compress a numpy array's raw bytes with zlib for storage"""
        return zlib.compress(arr.tobytes())

    def _decompress_array(self, data: bytes, dtype=None, shape=None) -> np.ndarray:
        """Decompress stored array data

        Args:
            data: zlib-compressed array bytes
            dtype: Optional dtype of the original array (default float64,
                matching the previous behaviour of this helper)
            shape: Optional shape of the original array (default flat)
        """
        return self._array_from_bytes(zlib.decompress(data), dtype, shape)

    def get_tensor_distribution(self, tensor_id: str) -> Dict[str, int]:
        """Get the chunk distribution for a tensor

        Args:
            tensor_id: ID of the tensor

        Returns:
            Dict[str, int]: Mapping of chunk_id to core_id, or an empty dict
            when the tensor has no metadata row
        """
        with self._connect(self.types_db_path) as conn:
            row = conn.execute('''
                SELECT chunk_distribution
                FROM tensor_metadata
                WHERE tensor_id = ?
            ''', (tensor_id,)).fetchone()

        if row:
            return json.loads(row[0])
        return {}

    def cleanup_old_chunks(self, max_age_days: int = 7):
        """Remove old unused chunks to free up space

        Chunks whose parent tensor is flagged ``is_weight`` in the metadata
        database are kept regardless of age. Chunks with no metadata row are
        treated as non-weights (matching the column's default of FALSE).
        Rows whose ``last_accessed`` is NULL are never removed.

        Args:
            max_age_days: Maximum age of unused chunks to keep

        Returns:
            int: Number of chunks deleted

        Raises:
            ValueError: If ``max_age_days`` is negative
        """
        if max_age_days < 0:
            raise ValueError("max_age_days must be non-negative")

        with self._connect(self.tensor_db_path) as conn:
            # is_weight lives in tensor_metadata, which is a different
            # database file, so it is attached for the duration of this
            # connection. ATTACH must run before any transaction starts.
            conn.execute('ATTACH DATABASE ? AS types_db',
                         (str(self.types_db_path),))

            # NOT EXISTS (rather than NOT IN) so a NULL tensor_id in the
            # metadata table cannot suppress every deletion.
            removed = conn.execute('''
                DELETE FROM tensor_chunks
                WHERE last_accessed < datetime('now', ?)
                AND NOT EXISTS (
                    SELECT 1
                    FROM types_db.tensor_metadata AS m
                    WHERE m.tensor_id = tensor_chunks.tensor_id
                    AND m.is_weight = 1
                )
            ''', (f'-{max_age_days} days',)).rowcount

        if removed > 0:
            logger.info("Cleaned up %d old tensor chunks", removed)
        return removed