import array
import hashlib
import json
import logging
import math
import sqlite3
from pathlib import Path
from threading import Lock
from typing import Any, Dict, Optional

from tensor_core import TensorCore



def calculate_size_from_shape(shape, dtype_size):
    """Return the size in bytes of a dense array with the given shape.

    Args:
        shape: Iterable of non-negative integer dimensions. An empty shape
            describes a scalar, i.e. one element.
        dtype_size: Size of a single element in bytes.

    Returns:
        ``prod(shape) * dtype_size``.

    Raises:
        ValueError: If any dimension is negative.
    """
    dims = tuple(shape)  # materialize once so iterators/generators work too
    if any(dim < 0 for dim in dims):
        raise ValueError(f"Negative dimension in shape: {dims}")
    return math.prod(dims) * dtype_size

class SQLiteMemoryManager:
    """Memory blocks for simulated chips, persisted in a SQLite database.

    Each block is one row of the ``memory`` table and is looked up by
    ``address`` and ``chip_id``. Blocks can be created from a string key (the
    address is a stable hash of the key), as zero-filled tensors carrying JSON
    metadata, or sequentially via :meth:`allocate`.

    Row-level operations (allocation, reads, writes, frees) hold ``self.lock``
    while using the shared connection.
    """

    DTYPE_SIZES = {
        'float32': 4,
        'float64': 8,
        'int32': 4,
        'int64': 8,
        'uint8': 1,
        'bool': 1
    }

    DB_PATH = "memory_storage.db"

    # Addresses stay below 2**63 so they fit SQLite's signed 64-bit INTEGER.
    _ADDRESS_SPACE = 2**63

    def __init__(self, db_path: Optional[str] = None):
        """Open (or create) the database and make sure the schema exists.

        Args:
            db_path: SQLite database path; defaults to ``DB_PATH``.

        Raises:
            RuntimeError: If the database cannot be opened or initialized.
        """
        self.db_path = db_path or self.DB_PATH
        self.tensor_core = TensorCore()
        self.lock = Lock()
        # Next candidate address for allocate(); occupied addresses are skipped.
        self.next_address = 0
        self._connect()

    def _connect(self) -> sqlite3.Connection:
        """Open the SQLite connection, configure it and create the schema.

        Returns:
            The new connection (also stored as ``self.conn``).

        Raises:
            RuntimeError: If opening or configuring the database fails; the
                underlying exception is chained as ``__cause__``.
        """
        conn = None
        try:
            # check_same_thread=False: the connection is shared between
            # threads and self.lock serializes access to it instead.
            conn = sqlite3.connect(
                self.db_path,
                timeout=30.0,
                isolation_level=None,
                check_same_thread=False,
            )
            # WAL lets readers proceed while a writer is active.
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA synchronous=NORMAL")
            # The PRAGMAs above run in autocommit mode; from here on the
            # sqlite3 module issues BEGIN IMMEDIATE before each write.
            conn.isolation_level = 'IMMEDIATE'
            self.conn = conn
            self.setup_database()
        except Exception as e:
            # Do not leak a half-initialized connection.
            if conn is not None:
                conn.close()
            raise RuntimeError(f"Failed to initialize database: {str(e)}") from e
        return conn

    def _init_db_connection(self) -> sqlite3.Connection:
        """Open the database connection and return it."""
        return self._connect()

    def setup_database(self):
        """Create the ``memory`` table if it does not already exist.

        NOTE(review): ``address`` alone is the primary key although lookups
        filter on ``(address, chip_id)``, so one address cannot be stored for
        two different chips.
        """
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS memory (
                address BIGINT PRIMARY KEY,
                data BLOB,
                size BIGINT,
                chip_id INTEGER,
                tensor_metadata JSON
            )
        """)

    @classmethod
    def _key_to_address(cls, key) -> int:
        """Map ``key`` to a table address that is stable across processes.

        The built-in ``hash()`` of a ``str`` is salted per interpreter run, so
        it cannot find rows persisted by an earlier run; a fixed BLAKE2b digest
        of ``str(key)`` can. Rows written by versions that used ``hash()``
        cannot be found by key.
        """
        digest = hashlib.blake2b(str(key).encode('utf-8'), digest_size=8).digest()
        return int.from_bytes(digest, 'big') % cls._ADDRESS_SPACE

    def _row_exists(self, address, chip_id) -> bool:
        """Return True if a block exists at ``(address, chip_id)``; caller holds the lock."""
        row = self.conn.execute("""
            SELECT 1
            FROM memory
            WHERE address = ? AND chip_id = ?
        """, [address, chip_id]).fetchone()
        return row is not None

    def _insert_block(self, address, size_bytes, chip_id, metadata=None):
        """Insert a zero-filled block; caller holds the lock and commits."""
        self.conn.execute("""
            INSERT INTO memory (address, data, size, chip_id, tensor_metadata)
            VALUES (?, ?, ?, ?, ?)
        """, [address, bytes(size_bytes), size_bytes, chip_id,
              json.dumps(metadata) if metadata else None])

    def _allocate_with_key_locked(self, size_bytes, key, chip_id, tensor_shape, dtype):
        """Create the keyed block unless it exists; caller holds the lock and commits.

        Returns:
            The block's address, whether it was just created or already existed.
        """
        address = self._key_to_address(key)
        if self._row_exists(address, chip_id):
            return address

        metadata = None
        if tensor_shape is not None and dtype is not None:
            metadata = {
                'shape': tensor_shape,
                'dtype': dtype,
                'total_size': calculate_size_from_shape(tensor_shape, self.DTYPE_SIZES[dtype])
            }
        self._insert_block(address, size_bytes, chip_id, metadata)
        return address

    def allocate_with_key(self, size_bytes, key, chip_id=0, tensor_shape=None, dtype=None):
        """Get or create a zero-filled block addressed by ``key``.

        Args:
            size_bytes: Size of a newly created block.
            key: Identifier hashed into the block address.
            chip_id: Chip the block belongs to.
            tensor_shape: Optional shape stored in the JSON metadata
                (only when ``dtype`` is also given).
            dtype: Optional dtype name from ``DTYPE_SIZES``.

        Returns:
            The block address. An existing block is returned unchanged.
        """
        with self.lock, self.conn:
            return self._allocate_with_key_locked(size_bytes, key, chip_id, tensor_shape, dtype)

    def get_buffer(self, address, chip_id=0):
        """Return the handle (address) of an existing block.

        Raises:
            RuntimeError: If no block exists at ``address`` for ``chip_id``.
        """
        with self.lock:
            exists = self._row_exists(address, chip_id)
        if not exists:
            raise RuntimeError(f"No memory at address {address} for chip {chip_id}")
        return address

    def allocate_tensor(self, shape, dtype, chip_id=0):
        """Allocate a zero-filled block sized for a tensor and register it.

        Args:
            shape: Tensor dimensions.
            dtype: One of the names in ``DTYPE_SIZES``.
            chip_id: Chip the tensor belongs to.

        Returns:
            The tensor's block address.

        Raises:
            ValueError: If ``dtype`` is not supported.
        """
        if dtype not in self.DTYPE_SIZES:
            raise ValueError(f"Unsupported dtype: {dtype}")

        itemsize = self.DTYPE_SIZES[dtype]
        size_bytes = calculate_size_from_shape(shape, itemsize)
        # NOTE(review): the key depends only on (chip_id, shape, dtype), so two
        # tensors with the same signature share one block, and allocating the
        # second re-zeroes the first.
        key = f"tensor_{chip_id}_{shape}_{dtype}"
        tensor_meta = {
            'shape': shape,
            'dtype': dtype,
            'is_tensor': True,
            'elements': size_bytes // itemsize
        }
        # All-zero bytes encode 0 / 0.0 / False for every supported dtype, and
        # unlike array.array('l') (8 bytes per item on LP64 platforms) this is
        # exactly size_bytes long on every platform.
        initial_data = bytes(size_bytes)

        with self.lock:
            with self.conn:
                address = self._allocate_with_key_locked(size_bytes, key, chip_id, shape, dtype)
                self.conn.execute("""
                    UPDATE memory
                    SET tensor_metadata = ?,
                        data = ?
                    WHERE address = ? AND chip_id = ?
                """, [json.dumps(tensor_meta), initial_data, address, chip_id])
            # Register only after the row has been committed.
            self.tensor_core.register_tensor(address, shape, dtype)

        return address

    def batch_allocate_with_keys(self, allocations: list[tuple[int, str]], chip_id=0):
        """Get or create several keyed blocks in a single transaction.

        Args:
            allocations: ``(size_bytes, key)`` pairs.
            chip_id: Chip all blocks belong to.

        Returns:
            Addresses in the same order as ``allocations``. If any insert
            fails, the whole batch is rolled back.
        """
        addresses = []
        with self.lock, self.conn:
            for size_bytes, key in allocations:
                addresses.append(
                    self._allocate_with_key_locked(size_bytes, key, chip_id, None, None)
                )
        return addresses

    def list_keys(self):
        """Return every stored address as a string.

        These are key hashes; the original keys cannot be recovered from them.
        """
        with self.lock:
            result = self.conn.execute('SELECT address FROM memory').fetchall()
        return [str(row[0]) for row in result]

    def read_by_key(self, key, chip_id=0):
        """Return the raw bytes of the block addressed by ``key``.

        Raises:
            RuntimeError: If no such block exists for ``chip_id``.
        """
        address = self._key_to_address(key)
        with self.lock:
            row = self.conn.execute("""
                SELECT data
                FROM memory
                WHERE address = ? AND chip_id = ?
            """, [address, chip_id]).fetchone()

        if row is None:
            raise RuntimeError(f"No memory at key {key} (address {address}) for chip {chip_id}")

        return row[0]

    def allocate(self, size_bytes, chip_id=0):
        """Allocate a zero-filled block at the next free sequential address.

        Args:
            size_bytes: Size of the new block.
            chip_id: Chip the block belongs to.

        Returns:
            The new block's address.
        """
        with self.lock, self.conn:
            address = self.next_address
            # ``address`` is the table's primary key and rows from earlier runs
            # or keyed allocations may already occupy it, so step past them.
            while True:
                row = self.conn.execute(
                    'SELECT size FROM memory WHERE address=?', (address,)
                ).fetchone()
                if row is None:
                    break
                address += max(row[0] or 0, 1)
            self._insert_block(address, size_bytes, chip_id)
            # Advance by at least one so a zero-byte block does not reuse its address.
            self.next_address = address + max(size_bytes, 1)
        return address

    def free(self, address, chip_id=0):
        """Delete the block at ``address``; freeing a missing block is a no-op."""
        with self.lock, self.conn:
            self.conn.execute('DELETE FROM memory WHERE address=? AND chip_id=?', (address, chip_id))

    def write_data(self, address, data, chip_id=0):
        """Replace the block's payload with ``bytes(data)``.

        The recorded ``size`` is not changed, so ``read_data`` with its default
        length returns at most that many bytes.

        Raises:
            RuntimeError: If no block exists at ``address`` for ``chip_id``.
        """
        blob = bytes(data)
        with self.lock, self.conn:
            cursor = self.conn.execute(
                'UPDATE memory SET data=? WHERE address=? AND chip_id=?',
                (blob, address, chip_id)
            )
            if cursor.rowcount == 0:
                raise RuntimeError(f"No memory at address {address} for chip {chip_id}")

    def read_data(self, address, size_bytes=None, chip_id=0):
        """Return the block's first ``size_bytes`` bytes as a list of ints.

        ``size_bytes`` defaults to the block's recorded size.

        Raises:
            RuntimeError: If no block exists at ``address`` for ``chip_id``.
        """
        with self.lock:
            result = self.conn.execute(
                'SELECT data, size FROM memory WHERE address=? AND chip_id=?',
                (address, chip_id)
            ).fetchone()

        if result is None:
            raise RuntimeError(f"No memory at address {address} for chip {chip_id}")

        data, total_size = result
        if size_bytes is None:
            size_bytes = total_size

        return list(data[:size_bytes])

    def close(self):
        """Close the database connection, if one was opened."""
        with self.lock:
            if hasattr(self, 'conn'):
                self.conn.close()
