"""Base HAL (Hardware Abstraction Layer) for virtual GPU components using local storage."""

import json
import logging
import os
import threading
import time
import uuid
from enum import Enum
from pathlib import Path
from typing import Optional, Dict, Any

import duckdb

class DeviceType(Enum):
    """Kinds of virtual GPU component.

    The string value of a member is what ``VirtualGPUStorageBase`` writes
    into the ``type`` column of ``device_state``.
    """

    GPU = "gpu_device"            # whole GPU device
    MEMORY = "memory_device"      # memory device
    TENSOR_CORE = "tensor_core"   # tensor-core unit
    WARP = "warp_unit"            # warp execution unit

class VirtualGPUStorageBase:
    """Base class for virtual GPU component storage.

    Each instance owns one DuckDB connection to
    ``~/database/virtual_gpu/<component_name>.db``. On connect it creates
    the ``device_state``, ``operations``, ``tensors`` and ``memory_blocks``
    tables if they do not exist.

    Every statement this class issues through ``self.conn`` is serialized
    with ``self.lock``. A single DuckDB Python connection object is not
    safe to use from several threads at once.

    Instances can be used as context managers; the connection is closed on
    exit.
    """

    def __init__(self, component_name: str, device_type: DeviceType):
        """Initialize storage for a virtual GPU component.

        Args:
            component_name: Stem of the database file created under
                ``~/database/virtual_gpu``.
            device_type: Its ``value`` is written to ``device_state.type``
                and used to filter reads in ``get_device_state``.

        Raises:
            RuntimeError: If the database cannot be opened and initialized
                after all retries (see ``_connect_with_retries``).
        """
        self.component_name = component_name
        self.device_type = device_type

        # Setup storage directory
        self.storage_dir = os.path.join(Path.home(), "database", "virtual_gpu")
        os.makedirs(self.storage_dir, exist_ok=True)

        # Component-specific database file
        self.db_path = os.path.join(self.storage_dir, f"{component_name}.db")

        # Initialize connection (this also creates self.lock)
        self._connect_with_retries()

    def __enter__(self):
        """Return ``self`` so the storage can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the connection; never suppresses the in-flight exception."""
        self.close()
        return False

    def _connect_with_retries(self, max_retries: int = 3, retry_delay: float = 1.0):
        """Open the database connection and initialize the schema, with retries.

        Args:
            max_retries: Total number of attempts before giving up.
            retry_delay: Seconds to sleep between failed attempts.

        Raises:
            RuntimeError: If no attempt succeeds. The last underlying
                exception is chained as ``__cause__``.
        """
        # Create the lock once and keep it across reconnects, so threads
        # already waiting on it keep excluding each other. It is re-entrant
        # so code that already holds ``self.lock`` (e.g. a subclass) can
        # still call the public methods without deadlocking.
        if getattr(self, "lock", None) is None:
            self.lock = threading.RLock()

        last_error: Optional[Exception] = None
        for attempt in range(max_retries):
            conn = None
            try:
                conn = duckdb.connect(self.db_path)

                # Enable JSON extension
                conn.execute("""
                    INSTALL json;
                    LOAD json;
                """)

                # Initialize schema; _init_database reads self.conn.
                with self.lock:
                    self.conn = conn
                    self._init_database()
                return
            except Exception as e:
                last_error = e
                # Don't leak the handle from a half-initialized attempt.
                if conn is not None:
                    try:
                        conn.close()
                    except Exception:
                        # Best-effort cleanup; the original error is what matters.
                        pass
                if attempt < max_retries - 1:
                    time.sleep(retry_delay)

        raise RuntimeError(
            f"Failed to initialize database after {max_retries} attempts: {str(last_error)}"
        ) from last_error

    def _init_database(self):
        """Create the schema tables if they do not already exist.

        Called from ``_connect_with_retries`` while ``self.lock`` is held.
        """
        # Create device state table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS device_state (
                device_id VARCHAR PRIMARY KEY,
                type VARCHAR NOT NULL,
                state JSON,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Create operations table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS operations (
                op_id VARCHAR PRIMARY KEY,
                operation_type VARCHAR,
                params JSON,
                state JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                status VARCHAR
            )
        """)

        # Create tensor storage
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensors (
                tensor_id VARCHAR PRIMARY KEY,
                data BLOB,
                metadata JSON,
                device_id VARCHAR,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP,
                state VARCHAR
            )
        """)

        # Create memory blocks table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS memory_blocks (
                block_id VARCHAR PRIMARY KEY,
                device_id VARCHAR,
                start_address BIGINT,
                size BIGINT,
                data BLOB,
                metadata JSON,
                allocated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP,
                is_free BOOLEAN DEFAULT TRUE
            )
        """)

        self.conn.commit()

    @staticmethod
    def _decode_json(value, default):
        """Decode a JSON column value, returning ``default`` for SQL NULL."""
        return default if value is None else json.loads(value)

    def update_device_state(self, device_id: str, state: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None):
        """Insert or replace the state row for ``device_id``.

        Args:
            device_id: Primary key of the row.
            state: JSON-serializable state, stored in ``state``.
            metadata: JSON-serializable metadata; stored as ``{}`` when
                omitted or empty.

        On conflict the row's ``type`` is also rewritten to this instance's
        device type. This lets a later ``get_device_state`` from this
        instance, which filters by type, see what was just written.
        """
        params = [device_id, self.device_type.value, json.dumps(state), json.dumps(metadata or {})]
        with self.lock:
            self.conn.execute("""
                INSERT INTO device_state (device_id, type, state, metadata, updated_at)
                VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
                ON CONFLICT (device_id) DO UPDATE SET
                    type = excluded.type,
                    state = excluded.state,
                    metadata = excluded.metadata,
                    updated_at = CURRENT_TIMESTAMP
            """, params)

    def get_device_state(self, device_id: str) -> Optional[Dict[str, Any]]:
        """Return ``{'state': ..., 'metadata': ...}`` for ``device_id``.

        Only rows whose ``type`` matches this instance's device type are
        considered. Returns ``None`` when no such row exists. A NULL
        ``state`` decodes to ``None``; a NULL ``metadata`` decodes to ``{}``.
        """
        with self.lock:
            result = self.conn.execute("""
                SELECT state, metadata
                FROM device_state
                WHERE device_id = ? AND type = ?
            """, [device_id, self.device_type.value]).fetchone()

        if result is None:
            return None
        state, metadata = result
        return {
            'state': self._decode_json(state, None),
            'metadata': self._decode_json(metadata, {}),
        }

    def store_tensor(self, tensor_id: str, data: bytes, metadata: Dict[str, Any], device_id: str):
        """Insert or overwrite a tensor's bytes, metadata and owning device.

        New rows get ``state = 'active'``. On conflict, ``data``,
        ``metadata``, ``device_id`` and ``last_accessed`` are replaced. A
        re-store on a different device is therefore reflected by
        ``load_tensor``.
        """
        params = [tensor_id, data, json.dumps(metadata), device_id]
        with self.lock:
            self.conn.execute("""
                INSERT INTO tensors (tensor_id, data, metadata, device_id, last_accessed, state)
                VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, 'active')
                ON CONFLICT (tensor_id) DO UPDATE SET
                    data = excluded.data,
                    metadata = excluded.metadata,
                    device_id = excluded.device_id,
                    last_accessed = CURRENT_TIMESTAMP
            """, params)

    def load_tensor(self, tensor_id: str) -> Optional[Dict[str, Any]]:
        """Return ``{'data', 'metadata', 'device_id'}`` for a tensor, or ``None``.

        A NULL ``metadata`` column decodes to ``{}``.
        """
        with self.lock:
            result = self.conn.execute("""
                SELECT data, metadata, device_id
                FROM tensors
                WHERE tensor_id = ?
            """, [tensor_id]).fetchone()

        if result is None:
            return None
        data, metadata, device_id = result
        return {
            'data': data,
            'metadata': self._decode_json(metadata, {}),
            'device_id': device_id,
        }

    def allocate_memory(self, size: int, device_id: str, metadata: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """Record a new allocated (``is_free = FALSE``) memory block.

        Returns:
            The new block id, or ``None`` if the insert failed; the failure
            is logged.
        """
        # A timestamp alone is not unique: two calls within the clock's
        # resolution would collide on the primary key, and the second
        # allocation would fail. The random suffix makes ids unique.
        block_id = f"block_{time.time_ns()}_{uuid.uuid4().hex}"
        try:
            params = [block_id, device_id, size, json.dumps(metadata or {})]
            with self.lock:
                self.conn.execute("""
                    INSERT INTO memory_blocks (
                        block_id, device_id, size, metadata, is_free
                    ) VALUES (?, ?, ?, ?, FALSE)
                """, params)
        except Exception as e:
            logging.error("Failed to allocate memory block: %s", e)
            return None
        return block_id

    def free_memory(self, block_id: str):
        """Mark a memory block free and drop its stored data.

        Matching no row is not an error.
        """
        with self.lock:
            self.conn.execute("""
                UPDATE memory_blocks
                SET is_free = TRUE, data = NULL
                WHERE block_id = ?
            """, [block_id])

    def close(self):
        """Close the database connection if one was opened."""
        conn = getattr(self, 'conn', None)
        if conn is None:
            return
        with self.lock:
            conn.close()
