"""
Memory management using DuckDB with LocalStorage for persistence.
"""

import hashlib
import json
import logging
import os
import time
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

import duckdb
import numpy as np

from tensor_core import TensorCore
from http_storage import LocalStorage  # Import LocalStorage for persistence

def calculate_size_from_shape(shape: tuple, dtype_size: int) -> int:
    """Return the number of bytes occupied by an array of the given shape.

    The element count is the product of all extents in ``shape`` (an empty
    shape counts as a single scalar element); it is then scaled by
    ``dtype_size``, the byte width of one element.
    """
    element_count = 1
    for extent in shape:
        element_count = element_count * extent
    byte_count = element_count * dtype_size
    return byte_count

class DuckDBMemoryManager:
    """Byte blocks addressed by integer address, stored in a local DuckDB file.

    Blocks live in the ``memory`` table of ``~/database/memory.db``. Every
    ``write_data`` call also exports the whole table to
    ``~/database/memory_store.json``; ``read_data`` falls back to that file
    when a requested block is missing from the table.
    """

    # Bytes per element for each supported dtype name.
    DTYPE_SIZES = {
        'float32': 4,
        'float64': 8,
        'int32': 4,
        'int64': 8,
        'uint8': 1,
        'bool': 1
    }

    # Keyed addresses are reduced into [0, 2**63) so they fit a signed BIGINT.
    _ADDRESS_MASK = (1 << 63) - 1

    def __init__(self, max_retries: int = 3):
        """Initialize DuckDB memory manager with local storage in user's home directory

        Args:
            max_retries: Number of connection retry attempts (at least one
                attempt is always made).

        Raises:
            RuntimeError: if the database cannot be opened after all retries.
        """
        # Setup database directory in user's home folder
        home_dir = os.path.expanduser("~")
        self.db_dir = os.path.join(home_dir, "database")
        os.makedirs(self.db_dir, exist_ok=True)

        # Database file paths
        self.db_file = os.path.join(self.db_dir, "memory.db")
        self.json_file = os.path.join(self.db_dir, "memory_store.json")

        self.tensor_core = TensorCore()
        # Cursor for sequentially allocated tensor addresses. It restarts at 0
        # per process; allocate_tensor skips addresses already in the table.
        self.next_address = 0
        self.max_retries = max_retries

        # Initialize connection with retries
        self._connect_with_retries()

    @staticmethod
    def _key_to_address(key: str) -> int:
        """Map a string key to a stable address in ``[0, 2**63)``.

        The built-in ``hash()`` is salted per process for ``str`` (see
        PYTHONHASHSEED), so it cannot locate persisted blocks after a restart.
        A fixed BLAKE2b digest gives the same address in every run.
        """
        digest = hashlib.blake2b(key.encode("utf-8"), digest_size=8).digest()
        return int.from_bytes(digest, "big") & DuckDBMemoryManager._ADDRESS_MASK

    @staticmethod
    def _sql_string_literal(value: str) -> str:
        """Quote ``value`` as a SQL string literal, doubling embedded quotes.

        Used for file paths in COPY / read_json_auto, which are spliced into
        the SQL text rather than bound as parameters.
        """
        return "'" + value.replace("'", "''") + "'"

    def _connect_with_retries(self):
        """Open the database file and ensure the schema exists.

        Makes ``max(1, max_retries)`` attempts, sleeping one second between
        failed attempts.

        Raises:
            RuntimeError: if every attempt fails; the last error is chained.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                # Connect to local database file
                self.conn = duckdb.connect(database=self.db_file, read_only=False)

                # Initialize database schema
                self._setup_database()

                logging.info(f"Connected to local database: {self.db_file}")
                return
            except Exception as e:
                if attempt == attempts - 1:
                    raise RuntimeError(
                        f"Failed to initialize storage after {attempts} attempts: {str(e)}"
                    ) from e
                logging.warning(f"Storage initialization attempt {attempt + 1} failed: {str(e)}")
                time.sleep(1)

    def ensure_connection(self):
        """Ping the connection and reconnect if the ping fails."""
        try:
            self.conn.execute("SELECT 1")
        except Exception:
            # Covers a closed/broken connection as well as a missing self.conn.
            logging.warning("Database connection lost, attempting to reconnect...")
            self._connect_with_retries()

    def _setup_database(self):
        """Create the ``memory`` table and its indexes if they do not exist."""
        # Create local memory table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS memory (
                address BIGINT PRIMARY KEY,
                data BLOB,
                size BIGINT,
                chip_id INTEGER,
                tensor_metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Create indexes
        self.conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_chip ON memory(chip_id)")
        self.conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_size ON memory(size)")

    def allocate_with_key(self, size_bytes: int, key: str, chip_id: int = 0,
                          tensor_shape: Optional[Tuple] = None, dtype: Optional[str] = None) -> int:
        """Allocate a zero-filled block whose address is derived from ``key``.

        If a block already exists at that address for ``chip_id``, its address
        is returned and the block is left untouched (no resize, no reset).

        Args:
            size_bytes: Length of the zero-filled block.
            key: Name the address is derived from (stable across runs).
            chip_id: Owning chip.
            tensor_shape: Optional shape; stored as metadata together with dtype.
            dtype: Optional dtype name; metadata is only stored when both
                ``tensor_shape`` and ``dtype`` are truthy.

        Returns:
            The block address.
        """
        self.ensure_connection()
        address = self._key_to_address(key)

        # Check if already exists
        existing = self.conn.execute("""
            SELECT address FROM memory
            WHERE address = ? AND chip_id = ?
        """, [address, chip_id]).fetchone()

        if existing is not None:
            return address

        # NOTE: the primary key is ``address`` alone, so a block with this
        # address owned by another chip makes the INSERT below fail.
        empty = np.zeros(size_bytes, dtype=np.uint8).tobytes()
        metadata = None
        if tensor_shape and dtype:
            metadata = json.dumps({
                'shape': tensor_shape,
                'dtype': dtype,
                'is_tensor': True,
                'elements': size_bytes // self.DTYPE_SIZES.get(dtype, 1)
            })

        self.conn.execute("""
            INSERT INTO memory (address, data, size, chip_id, tensor_metadata)
            VALUES (?, ?, ?, ?, ?)
        """, [address, empty, size_bytes, chip_id, metadata])

        return address

    def allocate_tensor(self, shape: Tuple, dtype: str, chip_id: int = 0) -> int:
        """Allocate a zero-initialized tensor at the next free sequential address.

        Args:
            shape: Tensor shape.
            dtype: One of the keys of ``DTYPE_SIZES``.
            chip_id: Owning chip.

        Returns:
            The tensor's address; it is also registered with ``tensor_core``.

        Raises:
            ValueError: if ``dtype`` is not supported.
        """
        self.ensure_connection()
        if dtype not in self.DTYPE_SIZES:
            raise ValueError(f"Unsupported dtype: {dtype}")

        size_bytes = calculate_size_from_shape(shape, self.DTYPE_SIZES[dtype])

        # next_address restarts at 0 each process while the table persists, so
        # step over any address that is already taken instead of hitting the
        # primary-key constraint.
        address = self.next_address
        while True:
            taken = self.conn.execute(
                "SELECT size FROM memory WHERE address = ?", [address]
            ).fetchone()
            if taken is None:
                break
            address += max(taken[0] or 0, 1)
        # Advance by at least one byte so zero-sized tensors get distinct addresses.
        self.next_address = address + max(size_bytes, 1)

        # Every dtype in DTYPE_SIZES is a valid NumPy dtype name.
        data = np.zeros(shape, dtype=np.dtype(dtype))

        # Store tensor
        tensor_meta = {
            'shape': shape,
            'dtype': dtype,
            'is_tensor': True,
            'elements': size_bytes // self.DTYPE_SIZES[dtype]
        }

        self.conn.execute("""
            INSERT INTO memory (address, data, size, chip_id, tensor_metadata)
            VALUES (?, ?, ?, ?, ?)
        """, [address, data.tobytes(), size_bytes, chip_id, json.dumps(tensor_meta)])

        # Register with tensor core
        self.tensor_core.register_tensor(address, shape, dtype)
        return address

    def batch_allocate_with_keys(self, allocations: List[Tuple[int, str]], chip_id: int = 0) -> List[int]:
        """Allocate several zero-filled keyed blocks in one statement.

        Args:
            allocations: ``(size_bytes, key)`` pairs.
            chip_id: Owning chip for all new blocks.

        Returns:
            The address for each pair, in input order. Addresses that already
            exist are left untouched (``ON CONFLICT DO NOTHING``), even if they
            belong to a different chip.
        """
        if not allocations:
            return []
        self.ensure_connection()

        addresses = []
        batch_data = []
        for size_bytes, key in allocations:
            address = self._key_to_address(key)
            empty = np.zeros(size_bytes, dtype=np.uint8).tobytes()
            batch_data.append((address, empty, size_bytes, chip_id))
            addresses.append(address)

        # Execute batch insert
        self.conn.executemany("""
            INSERT INTO memory (address, data, size, chip_id)
            VALUES (?, ?, ?, ?)
            ON CONFLICT (address) DO NOTHING
        """, batch_data)

        return addresses

    def list_keys(self) -> List[str]:
        """Return every block address, as strings, ordered by creation time."""
        self.ensure_connection()
        return [str(row[0]) for row in self.conn.execute(
            "SELECT address FROM memory ORDER BY created_at"
        ).fetchall()]

    def read_by_key(self, key: str, chip_id: int = 0) -> bytes:
        """Return the block stored under ``key`` and touch its ``last_accessed``.

        Raises:
            RuntimeError: if no block exists for that key and chip.
        """
        self.ensure_connection()
        address = self._key_to_address(key)
        result = self.conn.execute("""
            UPDATE memory
            SET last_accessed = CURRENT_TIMESTAMP
            WHERE address = ? AND chip_id = ?
            RETURNING data
        """, [address, chip_id]).fetchone()

        if result is None:
            raise RuntimeError(f"No memory at key {key} (address {address}) for chip {chip_id}")
        return result[0]

    def write_data(self, address: int, data: bytes, chip_id: int = 0):
        """Upsert ``data`` at ``address`` and re-export the table to JSON.

        The stored blob is replaced wholesale, so ``size`` is set to
        ``len(data)`` to keep it consistent with what is stored.

        Raises:
            Any error from the JSON export is logged and re-raised.
        """
        self.ensure_connection()

        # Update memory table
        self.conn.execute("""
            INSERT INTO memory (
                address, data, size, chip_id, last_accessed
            ) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
            ON CONFLICT (address) DO UPDATE
            SET data = EXCLUDED.data,
                size = EXCLUDED.size,
                chip_id = EXCLUDED.chip_id,
                last_accessed = CURRENT_TIMESTAMP
        """, [address, data, len(data), chip_id])

        # Export to JSON file for persistence. Note: this rewrites the whole
        # table on every write, so its cost grows with the table size.
        try:
            self.conn.execute(f"""
                COPY (
                    SELECT * FROM memory
                ) TO {self._sql_string_literal(self.json_file)} (FORMAT JSON)
            """)
        except Exception as e:
            logging.error(f"Failed to sync to JSON file: {e}")
            raise

    def read_data(self, address: int, size_bytes: int, chip_id: int = 0) -> bytes:
        """Return the first ``size_bytes`` bytes of the block at ``address``.

        If the block is not in the table but the JSON export exists, matching
        rows are re-imported from the JSON file first.

        Raises:
            RuntimeError: if the block is found in neither place.
        """
        self.ensure_connection()

        # First try to read from memory table
        result = self.conn.execute("""
            SELECT data FROM memory
            WHERE address = ? AND chip_id = ?
        """, [address, chip_id]).fetchone()

        # If not found, try to load from JSON file
        if result is None and os.path.exists(self.json_file):
            try:
                # Column order matches because the export used SELECT *.
                self.conn.execute(f"""
                    INSERT INTO memory
                    SELECT * FROM read_json_auto({self._sql_string_literal(self.json_file)})
                    WHERE address = ? AND chip_id = ?
                """, [address, chip_id])

                # Try reading again
                result = self.conn.execute("""
                    SELECT data FROM memory
                    WHERE address = ? AND chip_id = ?
                """, [address, chip_id]).fetchone()
            except Exception as e:
                logging.error(f"Failed to read from JSON file: {e}")
                raise

        if result is None:
            raise RuntimeError(f"No memory at address {address} for chip {chip_id}")
        return result[0][:size_bytes]

    def free(self, address: int, chip_id: int = 0):
        """Delete the block at ``address`` owned by ``chip_id`` (no-op if absent)."""
        self.ensure_connection()
        self.conn.execute("DELETE FROM memory WHERE address = ? AND chip_id = ?",
                          [address, chip_id])

    def get_stats(self) -> Dict[str, Any]:
        """Summarize the table.

        Returns:
            A dict with ``block_count``, ``total_size_bytes`` (0 when empty),
            ``tensor_count`` and ``average_idle_time_seconds`` (None when the
            table is empty).
        """
        self.ensure_connection()
        stats = self.conn.execute("""
            SELECT
                COUNT(*) as block_count,
                COALESCE(SUM(size), 0) as total_size,
                COUNT(CASE WHEN tensor_metadata IS NOT NULL THEN 1 END) as tensor_count,
                AVG(EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - last_accessed))) as avg_idle_time
            FROM memory
        """).fetchone()

        return {
            'block_count': stats[0],
            'total_size_bytes': stats[1],
            'tensor_count': stats[2],
            'average_idle_time_seconds': stats[3]
        }
