"""
Enhanced Shader System for Virtual GPU
Integrates with SQLite for shader management and state tracking
"""

import json
import time
from typing import Dict, List, Optional, Union
from enum import Enum
import logging
import hashlib
import sqlite3
import threading
import os
from pathlib import Path

from .shader_compiler_local import ShaderType, ShaderError

class ShaderProgram:
    """A collection of shader stages that can be validated ("linked") together.

    Attaching or detaching a stage invalidates any previous successful link.
    """

    def __init__(self):
        # Attached shader objects, kept in attachment order without duplicates.
        self.shaders = []
        # Uniform / attribute tables; this class only initializes them.
        self.uniforms = {}
        self.attributes = {}
        # True only after link() succeeds; link_error holds the message of the
        # most recent failed link() and is cleared by a successful one.
        self.is_linked = False
        self.link_error = None

    def attach_shader(self, shader):
        """Add *shader* to the program; a no-op if it is already attached."""
        if shader in self.shaders:
            return
        self.shaders.append(shader)
        self.is_linked = False

    def detach_shader(self, shader):
        """Remove *shader* from the program; a no-op if it is not attached."""
        if shader not in self.shaders:
            return
        self.shaders.remove(shader)
        self.is_linked = False

    def link(self):
        """Validate the attached stages and mark the program as linked.

        Every attached shader must expose a ``type`` attribute.

        Raises:
            ValueError: if no vertex shader is attached.
            Exception: any error raised during validation is stored (as a
                string) in ``link_error`` and then re-raised unchanged.
        """
        try:
            stages = set()
            for shader in self.shaders:
                stages.add(shader.type)
            if ShaderType.VERTEX not in stages:
                raise ValueError("Shader program must have a vertex shader")
            # TODO: Additional validation and linking logic
        except Exception as exc:
            self.is_linked = False
            self.link_error = str(exc)
            raise
        else:
            self.link_error = None
            self.is_linked = True

class ShaderSystem:
    """Manages shader programs and their state using local SQLite storage.

    A single connection is opened with ``check_same_thread=False`` and all
    access to it is serialized through ``self.lock``. Each mutating method runs
    as one transaction (``with self.conn``): it commits on success and rolls
    back on error. A failed statement is therefore never left pending, where it
    would be committed later by an unrelated call.
    """

    def __init__(self, db_path: str = "db/graphics/shader_system.db"):
        """Open (creating if needed) the shader database at *db_path*.

        Args:
            db_path: SQLite database path. Its parent directory is created if
                missing. A bare filename or ``":memory:"`` is also accepted.

        Raises:
            RuntimeError: if the database cannot be opened or initialized.
        """
        self.db_path = db_path
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when the path actually has one (not for "x.db" or ":memory:").
        directory = os.path.dirname(db_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.lock = threading.Lock()
        self._connect()

    def _connect(self):
        """Establish the database connection and ensure the schema exists."""
        try:
            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
            self._setup_database()
        except Exception as e:
            # Don't leave a half-initialized connection open.
            conn = getattr(self, "conn", None)
            if conn is not None:
                conn.close()
            raise RuntimeError(f"Failed to initialize database: {str(e)}") from e

    def _setup_database(self):
        """Create the programs/shaders/uniforms/attributes tables if missing."""
        with self.lock, self.conn:
            # Programs table
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS programs (
                    program_id TEXT PRIMARY KEY,
                    is_linked INTEGER NOT NULL,
                    link_error TEXT,
                    metadata TEXT,
                    creation_time REAL,
                    last_used REAL
                )
            """)

            # Shaders table
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS shaders (
                    shader_id TEXT PRIMARY KEY,
                    program_id TEXT,
                    shader_type TEXT NOT NULL,
                    source_code TEXT NOT NULL,
                    compiled_code BLOB,
                    metadata TEXT,
                    FOREIGN KEY (program_id) REFERENCES programs(program_id)
                )
            """)

            # Uniforms table
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS uniforms (
                    uniform_id TEXT PRIMARY KEY,
                    program_id TEXT,
                    name TEXT NOT NULL,
                    type TEXT NOT NULL,
                    size INTEGER,
                    location INTEGER,
                    value BLOB,
                    metadata TEXT,
                    FOREIGN KEY (program_id) REFERENCES programs(program_id)
                )
            """)

            # Attributes table
            self.conn.execute("""
                CREATE TABLE IF NOT EXISTS attributes (
                    attribute_id TEXT PRIMARY KEY,
                    program_id TEXT,
                    name TEXT NOT NULL,
                    type TEXT NOT NULL,
                    size INTEGER,
                    location INTEGER,
                    metadata TEXT,
                    FOREIGN KEY (program_id) REFERENCES programs(program_id)
                )
            """)

    def create_program(self, metadata: Optional[Dict] = None) -> str:
        """Create a new, unlinked shader program.

        Args:
            metadata: Optional JSON-serializable metadata stored with the program.

        Returns:
            A new 64-character hex program id.
        """
        # Random rather than time-derived: hashing str(time.time()) yields the
        # same id for programs created within one clock tick, which then fails
        # the PRIMARY KEY insert. Same 64-hex-character format as before.
        program_id = os.urandom(32).hex()

        with self.lock, self.conn:
            self.conn.execute("""
                INSERT INTO programs (program_id, is_linked, metadata, creation_time)
                VALUES (?, 0, ?, strftime('%s','now'))
            """, (program_id, json.dumps(metadata or {})))

        return program_id

    def attach_shader(self, program_id: str, shader_type: "ShaderType",
                      source_code: str, metadata: Optional[Dict] = None) -> str:
        """Attach a shader stage to a program.

        Args:
            program_id: Program to attach the shader to.
            shader_type: Shader stage; its ``.value`` is what gets stored.
            source_code: Shader source text.
            metadata: Optional JSON-serializable metadata for the shader.

        Returns:
            The shader id. It is derived from the program, stage and source, so
            the same source can be attached to several programs. Re-attaching an
            identical shader to the same program is a no-op that returns the
            existing id.
        """
        # Keying on the source alone made the PRIMARY KEY collide whenever one
        # source was shared between programs; include program and stage too.
        key = f"{program_id}\0{shader_type.value}\0{source_code}"
        shader_id = hashlib.sha256(key.encode()).hexdigest()

        with self.lock, self.conn:
            cursor = self.conn.execute("""
                INSERT OR IGNORE INTO shaders (shader_id, program_id, shader_type, source_code, metadata)
                VALUES (?, ?, ?, ?, ?)
            """, (shader_id, program_id, shader_type.value, source_code,
                  json.dumps(metadata or {})))

            if cursor.rowcount:
                # A newly attached stage invalidates any previous link.
                self.conn.execute("""
                    UPDATE programs
                    SET is_linked = 0,
                        link_error = NULL,
                        last_used = strftime('%s','now')
                    WHERE program_id = ?
                """, (program_id,))

        return shader_id

    def set_uniform(self, program_id: str, name: str, type_: str,
                    size: int, location: int, value: bytes = None,
                    metadata: Optional[Dict] = None):
        """Insert or replace the uniform *name* of a program.

        A uniform is identified by ``(program_id, name)``. Calling this again
        for the same pair overwrites every stored column.
        """
        uniform_id = f"{program_id}_{name}"

        with self.lock, self.conn:
            self.conn.execute("""
                INSERT OR REPLACE INTO uniforms
                (uniform_id, program_id, name, type, size, location, value, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, (uniform_id, program_id, name, type_, size, location,
                  value, json.dumps(metadata or {})))

    def set_attribute(self, program_id: str, name: str, type_: str,
                      size: int, location: int, metadata: Optional[Dict] = None):
        """Insert or replace the vertex attribute *name* of a program.

        An attribute is identified by ``(program_id, name)``. Calling this again
        for the same pair overwrites every stored column.
        """
        attribute_id = f"{program_id}_{name}"

        with self.lock, self.conn:
            self.conn.execute("""
                INSERT OR REPLACE INTO attributes
                (attribute_id, program_id, name, type, size, location, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, (attribute_id, program_id, name, type_, size, location,
                  json.dumps(metadata or {})))

    def get_program(self, program_id: str) -> Optional[Dict]:
        """Return a program with its shaders, uniforms and attributes.

        Returns:
            A dict describing the program (JSON metadata columns decoded), or
            ``None`` if no program with *program_id* exists.
        """
        with self.lock:
            program = self.conn.execute("""
                SELECT is_linked, link_error, metadata, creation_time, last_used
                FROM programs
                WHERE program_id = ?
            """, (program_id,)).fetchone()

            if not program:
                return None

            shaders = self.conn.execute("""
                SELECT shader_id, shader_type, source_code, compiled_code, metadata
                FROM shaders
                WHERE program_id = ?
            """, (program_id,)).fetchall()

            uniforms = self.conn.execute("""
                SELECT name, type, size, location, value, metadata
                FROM uniforms
                WHERE program_id = ?
            """, (program_id,)).fetchall()

            attributes = self.conn.execute("""
                SELECT name, type, size, location, metadata
                FROM attributes
                WHERE program_id = ?
            """, (program_id,)).fetchall()

        return {
            'program_id': program_id,
            'is_linked': bool(program[0]),
            'link_error': program[1],
            'metadata': json.loads(program[2]) if program[2] else {},
            'creation_time': program[3],
            'last_used': program[4],
            'shaders': [{
                'shader_id': s[0],
                'type': s[1],
                'source': s[2],
                'compiled': s[3],
                'metadata': json.loads(s[4]) if s[4] else {}
            } for s in shaders],
            'uniforms': [{
                'name': u[0],
                'type': u[1],
                'size': u[2],
                'location': u[3],
                'value': u[4],
                'metadata': json.loads(u[5]) if u[5] else {}
            } for u in uniforms],
            'attributes': [{
                'name': a[0],
                'type': a[1],
                'size': a[2],
                'location': a[3],
                'metadata': json.loads(a[4]) if a[4] else {}
            } for a in attributes]
        }

    def link_program(self, program_id: str, error: Optional[str] = None):
        """Record the outcome of linking a program.

        Args:
            program_id: Program whose link status is updated.
            error: ``None`` (or any falsy value) marks the program linked.
                Otherwise it is marked unlinked and *error* is stored as
                ``link_error``.
        """
        with self.lock, self.conn:
            self.conn.execute("""
                UPDATE programs
                SET is_linked = ?,
                    link_error = ?,
                    last_used = strftime('%s','now')
                WHERE program_id = ?
            """, (0 if error else 1, error, program_id))

    def set_shader_compiled(self, shader_id: str, compiled_code: bytes):
        """Store the compiled binary for a shader."""
        with self.lock, self.conn:
            self.conn.execute("""
                UPDATE shaders
                SET compiled_code = ?
                WHERE shader_id = ?
            """, (compiled_code, shader_id))

    def delete_program(self, program_id: str):
        """Delete a shader program and all related data in one transaction."""
        with self.lock, self.conn:
            # Children first so the parent row is never referenced after removal.
            self.conn.execute("DELETE FROM attributes WHERE program_id = ?", (program_id,))
            self.conn.execute("DELETE FROM uniforms WHERE program_id = ?", (program_id,))
            self.conn.execute("DELETE FROM shaders WHERE program_id = ?", (program_id,))
            self.conn.execute("DELETE FROM programs WHERE program_id = ?", (program_id,))

    def list_programs(self) -> List[str]:
        """List all program IDs (no particular order)."""
        with self.lock:
            results = self.conn.execute(
                "SELECT program_id FROM programs"
            ).fetchall()
        return [r[0] for r in results]

    def close(self):
        """Close the database connection; safe to call more than once."""
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()

    def __del__(self):
        # Best-effort cleanup. close() tolerates a missing or closed connection.
        self.close()
