"""
Database manager for graphics pipeline state persistence
"""
import duckdb
import json
import time
import logging
import os
from typing import Dict, List, Optional
from pathlib import Path

# Make sure the directory holding the default database file
# (PipelineStateDB.DB_FILE) exists as soon as this module is imported.
Path("data").mkdir(parents=True, exist_ok=True)

class PipelineStateDB:
    """DuckDB-backed persistence for graphics pipeline states.

    Tables maintained by this class:

    * ``pipeline_states`` -- one row per pipeline hash; nested state is
      stored as JSON text.
    * ``resource_bindings`` -- per-pipeline binding points. NOTE(review):
      nothing in this class writes it; presumably populated elsewhere.
    * ``cache_stats`` -- hit counts and last-use timestamps, updated by
      :meth:`get_pipeline` and used by :meth:`prune_cache` for LRU eviction.
    """

    DB_FILE = "data/pipeline_state.db"

    # (database column, key in the dict passed to store_pipeline) for every
    # column stored as JSON text. The one asymmetry: callers pass 'shaders',
    # but the column (and therefore get_pipeline's result key) is
    # 'shader_stages'.
    _JSON_FIELDS = (
        ("shader_stages", "shaders"),
        ("vertex_attributes", "vertex_attributes"),
        ("shader_resources", "shader_resources"),
        ("viewport", "viewport"),
        ("scissor", "scissor"),
        ("rasterization", "rasterization"),
        ("depth", "depth"),
        ("stencil", "stencil"),
        ("blend", "blend"),
        ("color_mask", "color_mask"),
    )
    # Scalar columns copied verbatim from the state dict key of the same name.
    _SCALAR_FIELDS = ("primitive_type", "patch_control_points")

    def __init__(self, db_path: Optional[str] = None, max_retries: int = 3,
                 retry_delay: float = 1.0):
        """Open (or create) the database and make sure the schema exists.

        Args:
            db_path: Database file path; defaults to :attr:`DB_FILE`.
            max_retries: Connection attempts before giving up (at least one
                attempt is always made).
            retry_delay: Seconds to sleep between failed attempts.

        Raises:
            RuntimeError: If every connection attempt fails.
        """
        self.db_path = db_path or self.DB_FILE
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._connect_with_retries()

    @staticmethod
    def _close_quietly(conn) -> None:
        """Best-effort close of *conn*; errors are logged, never raised.

        Only used on connections that have already failed, where an error
        from ``close()`` would merely mask the original problem.
        """
        if conn is None:
            return
        try:
            conn.close()
        except Exception:
            logging.debug("Ignoring error while closing database connection",
                          exc_info=True)

    def _connect_with_retries(self):
        """Open a connection and create the schema, retrying on failure.

        A connection whose schema setup fails is closed before the next
        attempt so failed attempts do not leak open handles.

        Raises:
            RuntimeError: After all attempts fail; the last underlying error
                is chained as ``__cause__``.
        """
        attempts = max(1, self.max_retries)
        # getattr keeps subclasses that bypass __init__ working.
        delay = getattr(self, "retry_delay", 1.0)
        last_error = None
        for attempt in range(attempts):
            conn = None
            try:
                conn = self._init_db_connection()
                self.conn = conn  # _init_tables operates on self.conn
                self._init_tables()
                return
            except Exception as e:  # any driver/IO error is worth a retry
                last_error = e
                self._close_quietly(conn)
                self.conn = None
                if attempt < attempts - 1:
                    time.sleep(delay)
        raise RuntimeError(
            f"Failed to initialize database after {attempts} attempts: {str(last_error)}"
        ) from last_error

    def _init_db_connection(self) -> "duckdb.DuckDBPyConnection":
        """Open a new DuckDB connection to ``self.db_path``."""
        return duckdb.connect(self.db_path)

    def _init_tables(self):
        """Create the schema on ``self.conn`` if it does not already exist."""
        # Pipeline state table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS pipeline_states (
                hash VARCHAR PRIMARY KEY,
                shader_stages JSON,
                vertex_attributes JSON,
                shader_resources JSON,
                viewport JSON,
                scissor JSON,
                rasterization JSON,
                depth JSON,
                stencil JSON,
                blend JSON,
                color_mask JSON,
                primitive_type VARCHAR,
                patch_control_points INTEGER,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Resource bindings table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS resource_bindings (
                pipeline_hash VARCHAR,
                binding_point INTEGER,
                resource_type VARCHAR,
                resource_data JSON,
                FOREIGN KEY (pipeline_hash) REFERENCES pipeline_states(hash),
                PRIMARY KEY (pipeline_hash, binding_point)
            )
        """)

        # Cache statistics
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS cache_stats (
                pipeline_hash VARCHAR PRIMARY KEY,
                hit_count INTEGER DEFAULT 0,
                last_used TIMESTAMP,
                FOREIGN KEY (pipeline_hash) REFERENCES pipeline_states(hash)
            )
        """)

    def ensure_connection(self):
        """Make sure ``self.conn`` is usable, reconnecting if a probe query fails."""
        try:
            self.conn.execute("SELECT 1")
        except Exception:
            # Covers a closed/broken connection as well as a missing one
            # (self.conn is None after a failed connect).
            logging.warning("Database connection lost, attempting to reconnect...")
            self._close_quietly(getattr(self, "conn", None))
            self._connect_with_retries()

    def store_pipeline(self, hash: str, state_dict: Dict):
        """Insert or overwrite the pipeline state stored under *hash*.

        Args:
            hash: Primary key of the pipeline.
            state_dict: Must contain 'shaders', 'vertex_attributes',
                'shader_resources', 'viewport', 'scissor', 'rasterization',
                'depth', 'stencil', 'blend' and 'color_mask' (each serialized
                with ``json.dumps``), plus 'primitive_type' and
                'patch_control_points' (stored as-is).

        Raises:
            KeyError: If a required key is missing from *state_dict*.
            TypeError: If a value is not JSON-serializable.
        """
        self.ensure_connection()

        # Build column names and bound values side by side so their order
        # cannot drift apart.
        columns = ["hash"]
        params = [hash]
        for column, key in self._JSON_FIELDS:
            columns.append(column)
            params.append(json.dumps(state_dict[key]))
        for column in self._SCALAR_FIELDS:
            columns.append(column)
            params.append(state_dict[column])

        # Identifiers come from the class constants above, never from caller
        # input, so interpolating them is safe; all values are bound params.
        column_list = ", ".join(columns)
        placeholders = ", ".join("?" for _ in columns)
        updates = ", ".join(f"{c}=excluded.{c}" for c in columns[1:])
        self.conn.execute(
            f"INSERT INTO pipeline_states ({column_list}) "
            f"VALUES ({placeholders}) "
            f"ON CONFLICT(hash) DO UPDATE SET {updates}",
            params,
        )

    def get_pipeline(self, hash: str) -> Optional[Dict]:
        """Return the stored state for *hash*, or ``None`` if it is unknown.

        A successful lookup also increments the pipeline's hit count and
        refreshes its last-used timestamp in ``cache_stats``.

        Returns:
            A dict keyed by ``pipeline_states`` column name (note:
            'shader_stages', not the 'shaders' key store_pipeline accepts),
            with the JSON columns decoded back into Python objects.
        """
        self.ensure_connection()
        cursor = self.conn.execute(
            "SELECT * FROM pipeline_states WHERE hash = ?", [hash]
        )
        row = cursor.fetchone()
        if row is None:
            return None
        # Rows are plain tuples; take column names from the cursor now,
        # before the stats update below replaces the active result.
        state = dict(zip((desc[0] for desc in cursor.description), row))

        # Update cache statistics
        self.conn.execute("""
            INSERT INTO cache_stats (pipeline_hash, hit_count, last_used)
            VALUES (?, 1, CURRENT_TIMESTAMP)
            ON CONFLICT(pipeline_hash) DO UPDATE SET
            hit_count = hit_count + 1,
            last_used = CURRENT_TIMESTAMP
        """, [hash])

        # Convert JSON text back to Python objects; empty/NULL values and
        # values a driver already decoded are left untouched.
        for column, _ in self._JSON_FIELDS:
            value = state.get(column)
            if isinstance(value, str) and value:
                state[column] = json.loads(value)
        return state

    def prune_cache(self, max_size: int = 1000) -> int:
        """Evict least-recently-used pipelines until at most *max_size* remain.

        Recency is ``cache_stats.last_used``. A pipeline that has never been
        fetched falls back to its ``created_at`` time, so it is still
        eligible for eviction.

        Returns:
            Number of pipelines removed (0 when already within *max_size*).
        """
        self.ensure_connection()
        total = self.conn.execute(
            "SELECT COUNT(*) FROM pipeline_states"
        ).fetchone()[0]
        excess = int(total - max_size)
        if excess <= 0:
            # Nothing to evict (and a negative LIMIT would be an error).
            return 0

        victims = [
            (row[0],)
            for row in self.conn.execute("""
                SELECT ps.hash
                FROM pipeline_states ps
                LEFT JOIN cache_stats cs ON cs.pipeline_hash = ps.hash
                ORDER BY COALESCE(cs.last_used, ps.created_at) ASC
                LIMIT ?
            """, [excess]).fetchall()
        ]

        # Delete referencing rows before their parents: the foreign keys on
        # resource_bindings and cache_stats have no ON DELETE CASCADE.
        # NOTE(review): statements run in autocommit rather than one explicit
        # transaction; DuckDB has had limitations deleting FK-referenced rows
        # in the same transaction as their children. Revisit if atomicity
        # is required.
        for table, key in (("resource_bindings", "pipeline_hash"),
                           ("cache_stats", "pipeline_hash"),
                           ("pipeline_states", "hash")):
            self.conn.executemany(f"DELETE FROM {table} WHERE {key} = ?", victims)
        return len(victims)

    def get_cache_stats(self) -> List[tuple]:
        """Return ``(hash, hit_count, last_used, created_at)`` rows for every pipeline.

        Rows are ordered by hit count, highest first. Pipelines that were
        never fetched have ``None`` for hit_count/last_used and come last.
        """
        self.ensure_connection()
        return self.conn.execute("""
            SELECT ps.hash, cs.hit_count, cs.last_used,
                   ps.created_at
            FROM pipeline_states ps
            LEFT JOIN cache_stats cs ON ps.hash = cs.pipeline_hash
            ORDER BY cs.hit_count DESC NULLS LAST
        """).fetchall()
