"""
Modern graphics pipeline state manager for Virtual GPU
"""
from typing import Dict, List, Optional, Union
import numpy as np

from .pipeline_db import PipelineStateDB
from .graphics_types import (
    ShaderType, PrimitiveType, GraphicsPipelineState,
    Viewport, Scissor, RasterizationState, DepthState,
    StencilState, BlendState, ColorMask, VertexAttribute,
    ShaderResource
)

class GraphicsPipelineManager:
    """Manage the active graphics pipeline state and a store of saved pipelines.

    Pipelines are flattened to plain dictionaries and kept in a
    ``PipelineStateDB`` under the SHA-256 hash of their JSON form, so
    creating a pipeline from an identical state yields the same key.
    """

    def __init__(self, driver, db_path: Optional[str] = None):
        """Set up a default current state and the pipeline database.

        Args:
            driver: Driver object; stored on the instance but not otherwise
                used by this class.
            db_path: Path handed to ``PipelineStateDB``. When omitted (or
                otherwise falsy) ``":memory:"`` is passed instead.
        """
        self.driver = driver
        self.current_state = GraphicsPipelineState()
        self.db = PipelineStateDB(db_path or ":memory:")

    def create_pipeline(self, state: GraphicsPipelineState) -> str:
        """Store ``state`` in the database and return its hash.

        Args:
            state: Pipeline state to persist.

        Returns:
            The hex SHA-256 digest of the serialized state, which is also
            the key the state is stored under.
        """
        state_hash = self._hash_pipeline_state(state)
        # The stored payload comes from the same serializer the hash uses,
        # so the key and the stored data cannot drift apart.
        self.db.store_pipeline(state_hash, self._state_to_dict(state))
        return state_hash

    def bind_pipeline(self, pipeline_hash: str) -> None:
        """Load a stored pipeline and make it the current state.

        Args:
            pipeline_hash: Key previously returned by ``create_pipeline``.

        Raises:
            ValueError: If the database returns nothing (or an empty value)
                for ``pipeline_hash``.
        """
        state_dict = self.db.get_pipeline(pipeline_hash)
        if not state_dict:
            raise ValueError(f"Invalid pipeline hash: {pipeline_hash}")

        self.current_state = self._state_from_dict(state_dict)

    def set_viewport(self, viewport: Viewport):
        """Replace the viewport of the current state."""
        self.current_state.viewport = viewport

    def set_scissor(self, scissor: Scissor):
        """Replace the scissor of the current state."""
        self.current_state.scissor = scissor

    def set_vertex_attributes(self, attributes: List[VertexAttribute]):
        """Replace the vertex input attributes of the current state."""
        self.current_state.vertex_attributes = attributes

    def set_shader_resources(self, resources: List[ShaderResource]):
        """Replace the shader resource bindings of the current state."""
        self.current_state.shader_resources = resources

    def set_rasterization_state(self, state: RasterizationState):
        """Replace the rasterization state of the current state."""
        self.current_state.rasterization = state

    def set_depth_state(self, state: DepthState):
        """Replace the depth state of the current state."""
        self.current_state.depth = state

    def set_stencil_state(self, state: StencilState):
        """Replace the stencil state of the current state."""
        self.current_state.stencil = state

    def set_blend_state(self, state: BlendState):
        """Replace the blend state of the current state."""
        self.current_state.blend = state

    def set_color_mask(self, mask: ColorMask):
        """Replace the color write mask of the current state."""
        self.current_state.color_mask = mask

    @staticmethod
    def _state_to_dict(state: GraphicsPipelineState) -> dict:
        """Flatten ``state`` into the dictionary form used for hashing and storage.

        Enum members are replaced by their ``.value``; the structured fields
        are converted with their ``_asdict()`` method. A falsy viewport or
        scissor is recorded as ``None``.
        """
        return {
            "shaders": {k.value: v for k, v in state.shader_stages.items()},
            "vertex_attributes": [attr._asdict() for attr in state.vertex_attributes],
            "shader_resources": [res._asdict() for res in state.shader_resources],
            "viewport": state.viewport._asdict() if state.viewport else None,
            "scissor": state.scissor._asdict() if state.scissor else None,
            "rasterization": state.rasterization._asdict(),
            "depth": state.depth._asdict(),
            "stencil": state.stencil._asdict(),
            "blend": state.blend._asdict(),
            "color_mask": state.color_mask._asdict(),
            "primitive_type": state.primitive_type.value,
            "patch_control_points": state.patch_control_points,
        }

    @staticmethod
    def _state_from_dict(state_dict: dict) -> GraphicsPipelineState:
        """Rebuild a ``GraphicsPipelineState`` from its dictionary form.

        This is the inverse of ``_state_to_dict``.
        NOTE(review): assumes the database hands back the dictionary in the
        shape it was stored (e.g. shader keys still equal to the enum
        values); confirm against ``PipelineStateDB``.
        """
        state = GraphicsPipelineState()
        state.shader_stages = {ShaderType(k): v for k, v in state_dict['shaders'].items()}
        state.vertex_attributes = [VertexAttribute(**attr) for attr in state_dict['vertex_attributes']]
        state.shader_resources = [ShaderResource(**res) for res in state_dict['shader_resources']]
        state.viewport = Viewport(**state_dict['viewport']) if state_dict['viewport'] else None
        state.scissor = Scissor(**state_dict['scissor']) if state_dict['scissor'] else None
        state.rasterization = RasterizationState(**state_dict['rasterization'])
        state.depth = DepthState(**state_dict['depth'])
        state.stencil = StencilState(**state_dict['stencil'])
        state.blend = BlendState(**state_dict['blend'])
        state.color_mask = ColorMask(**state_dict['color_mask'])
        state.primitive_type = PrimitiveType(state_dict['primitive_type'])
        state.patch_control_points = state_dict['patch_control_points']
        return state

    def _hash_pipeline_state(self, state: GraphicsPipelineState) -> str:
        """Return the hex SHA-256 digest of the JSON-serialized ``state``.

        Keys are sorted during serialization so the digest does not depend
        on dictionary insertion order.
        """
        import hashlib
        import json

        state_str = json.dumps(self._state_to_dict(state), sort_keys=True)
        return hashlib.sha256(state_str.encode()).hexdigest()

    def get_current_state(self) -> GraphicsPipelineState:
        """Return the current pipeline state.

        The live object is returned, not a copy, so changes the caller makes
        to it are visible to this manager.
        """
        return self.current_state

    def validate_state(self) -> List[str]:
        """Check the current pipeline state and report problems.

        Returns:
            A list of human-readable error messages; empty when no problem
            was found.
        """
        errors = []
        stages = self.current_state.shader_stages

        # A vertex shader is always required.
        if ShaderType.VERTEX not in stages:
            errors.append("Missing vertex shader")

        # TODO: Validate vertex attributes against shader reflection data.

        # Patch primitives need a tessellation control stage.
        if (self.current_state.primitive_type == PrimitiveType.PATCHES and
                ShaderType.TESSELLATION_CONTROL not in stages):
            errors.append("Tessellation shaders required for patch primitives")

        return errors
