from typing import Dict, List, Optional, Union, NamedTuple
from dataclasses import dataclass

from .graphics_types_db import GraphicsTypeDB
from .state_enums import (
    CullMode, FrontFace, CompareOp, StencilOp,
    BlendFactor, BlendOp
)
"""Type definitions for graphics subsystem"""

from enum import Enum, auto
from typing import Dict, List, Optional, Union, NamedTuple, Tuple
from dataclasses import dataclass

from .graphics_types_db import GraphicsTypeDB

# Module-level GraphicsTypeDB instance, constructed once at import time.
# The ShaderType, PrimitiveType and FormatType enums below call it from their
# __init__ hooks while their members are being created, so it must be bound
# before those classes are defined.
db = GraphicsTypeDB()

class ShaderType(Enum):
    """Shader stage identifiers, each resolved against the type database.

    When the class is created, every member's string value is passed to
    ``db.get_shader_type`` and the lookup result is kept on the member as
    ``_db_value``.
    """

    VERTEX = "vertex"
    FRAGMENT = "fragment"
    GEOMETRY = "geometry"
    COMPUTE = "compute"
    TESSELLATION_CONTROL = "tess_control"
    TESSELLATION_EVALUATION = "tess_eval"

    def __init__(self, raw_value):
        super().__init__()
        # One database lookup per member, performed during class creation.
        record = db.get_shader_type(raw_value)
        self._db_value = record

class PrimitiveType(Enum):
    """Primitive topology identifiers, each resolved against the type database.

    When the class is created, every member's integer value is passed to
    ``db.get_primitive_type`` and the lookup result is kept on the member as
    ``_db_value``.
    """

    POINTS = 1
    LINES = 2
    LINE_STRIP = 3
    TRIANGLES = 4
    TRIANGLE_STRIP = 5
    TRIANGLE_FAN = 6
    PATCHES = 7

    def __init__(self, raw_value):
        super().__init__()
        # One database lookup per member, performed during class creation.
        record = db.get_primitive_type(raw_value)
        self._db_value = record

class FormatType(Enum):
    """Data format identifiers, each resolved against the type database.

    When the class is created, every member's integer value is passed to
    ``db.get_format_type`` and the lookup result is kept on the member as
    ``_db_value``.
    """

    FLOAT32 = 1
    VEC2 = 2
    VEC3 = 3
    VEC4 = 4
    INT32 = 5
    UINT32 = 6
    RGBA8_UNORM = 7

    def __init__(self, raw_value):
        super().__init__()
        # One database lookup per member, performed during class creation.
        record = db.get_format_type(raw_value)
        self._db_value = record

@dataclass
class Viewport:
    """Viewport rectangle and depth range.

    The coordinate defaults are float literals so that a default-constructed
    viewport has the same types as its ``float`` annotations.
    """

    x: float = 0.0
    y: float = 0.0
    width: float = 0.0
    height: float = 0.0
    # Depth-range bounds; default to 0.0 .. 1.0.
    min_depth: float = 0.0
    max_depth: float = 1.0

@dataclass
class Scissor:
    """Integer scissor rectangle; everything defaults to zero."""

    x: int = 0          # left edge
    y: int = 0          # top/bottom edge, per the caller's convention
    width: int = 0      # extent along x
    height: int = 0     # extent along y

@dataclass
class RasterizationState:
    """Rasterizer settings: culling, winding, polygon fill mode and depth bias."""

    # Culling and winding; these bind the CullMode/FrontFace names in scope
    # at this point in the file (the ones imported from .state_enums).
    cull_mode: CullMode = CullMode.NONE
    front_face: FrontFace = FrontFace.CCW
    # Expected to be "FILL", "LINE" or "POINT"; not validated here.
    polygon_mode: str = "FILL"
    # Depth-bias parameters, all zero by default.
    depth_bias: float = 0.0
    depth_bias_clamp: float = 0.0
    depth_bias_slope_factor: float = 0.0

@dataclass
class DepthState:
    """Depth-test settings; both test and writes are disabled by default."""

    test_enable: bool = False                   # depth test off
    write_enable: bool = False                  # depth writes off
    compare_op: CompareOp = CompareOp.LESS      # used when the test is enabled

@dataclass
class StencilState:
    """Stencil-test settings; the test is disabled by default."""

    test_enable: bool = False
    # Full 8-bit masks and a zero reference value by default.
    write_mask: int = 0xFF
    compare_mask: int = 0xFF
    reference: int = 0
    # Comparison plus the per-outcome stencil operations.
    compare_op: CompareOp = CompareOp.ALWAYS
    pass_op: StencilOp = StencilOp.KEEP
    fail_op: StencilOp = StencilOp.KEEP
    depth_fail_op: StencilOp = StencilOp.KEEP

@dataclass
class BlendState:
    """Color/alpha blending settings; blending is disabled by default."""

    enable: bool = False
    # Color channel: ONE * src + ZERO * dst combined with ADD.
    src_color_factor: BlendFactor = BlendFactor.ONE
    dst_color_factor: BlendFactor = BlendFactor.ZERO
    color_op: BlendOp = BlendOp.ADD
    # Alpha channel: same defaults as the color channel.
    src_alpha_factor: BlendFactor = BlendFactor.ONE
    dst_alpha_factor: BlendFactor = BlendFactor.ZERO
    alpha_op: BlendOp = BlendOp.ADD

@dataclass
class ColorMask:
    """Per-channel write enables; every channel is writable by default."""

    red: bool = True
    green: bool = True
    blue: bool = True
    alpha: bool = True

@dataclass
class VertexAttribute:
    """One vertex input attribute: shader location, data format and layout."""

    # Required: shader input location and element format.
    location: int
    format: FormatType
    # Optional layout details, all zero by default.
    offset: int = 0
    binding: int = 0
    stride: int = 0

@dataclass
class ShaderResource:
    """A resource bound to shaders at a given binding slot."""

    binding: int
    # Resource kind, e.g. UNIFORM_BUFFER, STORAGE_BUFFER, SAMPLED_IMAGE,
    # STORAGE_IMAGE (free-form string).
    type: str
    # Shader stages that access this resource.
    stages: List[ShaderType]
    # Optional format, left unset by default.
    format: Optional[FormatType] = None

@dataclass
class GraphicsPipelineState:
    """Complete graphics pipeline configuration.

    Every collection or sub-state field defaults to ``None`` in the
    constructor. ``__post_init__`` then replaces ``None`` with a fresh default
    object, so no two pipelines share a mutable default. Values the caller
    passes explicitly (including empty containers) are kept as-is.
    """

    shader_stages: Optional[Dict[ShaderType, bytes]] = None
    vertex_attributes: Optional[List[VertexAttribute]] = None
    shader_resources: Optional[List[ShaderResource]] = None
    viewport: Optional[Viewport] = None
    scissor: Optional[Scissor] = None
    rasterization: Optional[RasterizationState] = None
    depth: Optional[DepthState] = None
    stencil: Optional[StencilState] = None
    blend: Optional[BlendState] = None
    color_mask: Optional[ColorMask] = None
    primitive_type: Optional[PrimitiveType] = None
    patch_control_points: int = 0

    def __post_init__(self) -> None:
        # Test with `is None` rather than truthiness. With `or`, a caller's
        # empty list/dict was silently swapped for a new object, while a
        # non-empty one was kept, so aliasing depended on emptiness.
        if self.shader_stages is None:
            self.shader_stages = {}
        if self.vertex_attributes is None:
            self.vertex_attributes = []
        if self.shader_resources is None:
            self.shader_resources = []
        if self.rasterization is None:
            self.rasterization = RasterizationState()
        if self.depth is None:
            self.depth = DepthState()
        if self.stencil is None:
            self.stencil = StencilState()
        if self.blend is None:
            self.blend = BlendState()
        if self.color_mask is None:
            self.color_mask = ColorMask()
        if self.primitive_type is None:
            self.primitive_type = PrimitiveType.TRIANGLES
        # viewport and scissor intentionally stay None when not supplied.

class CullMode(Enum):
    """Face-culling mode options.

    NOTE(review): this definition shadows the ``CullMode`` imported from
    ``.state_enums`` at the top of the file. Classes defined above this point
    bind the imported enum; classes defined below bind this one.
    """

    # Explicit values, identical to what auto() assigns (1, 2, 3, ...).
    NONE = 1
    FRONT = 2
    BACK = 3
    FRONT_AND_BACK = 4

class FrontFace(Enum):
    """Winding-order options: clockwise (CW) or counter-clockwise (CCW).

    NOTE(review): this definition shadows the ``FrontFace`` imported from
    ``.state_enums`` at the top of the file.
    """

    # Explicit values, identical to what auto() assigns.
    CW = 1
    CCW = 2

class CompareOp(Enum):
    """Comparison operators used by the depth and stencil tests.

    NOTE(review): this definition shadows the ``CompareOp`` imported from
    ``.state_enums`` at the top of the file.
    """

    # Explicit values, identical to what auto() assigns (1 .. 8).
    NEVER = 1
    LESS = 2
    EQUAL = 3
    LESS_EQUAL = 4
    GREATER = 5
    NOT_EQUAL = 6
    GREATER_EQUAL = 7
    ALWAYS = 8

class StencilOp(Enum):
    """Operations applied to the stencil buffer value.

    NOTE(review): this definition shadows the ``StencilOp`` imported from
    ``.state_enums`` at the top of the file.
    """

    # Explicit values, identical to what auto() assigns (1 .. 8).
    KEEP = 1
    ZERO = 2
    REPLACE = 3
    INCREMENT_CLAMP = 4
    DECREMENT_CLAMP = 5
    INVERT = 6
    INCREMENT_WRAP = 7
    DECREMENT_WRAP = 8

class Viewport(NamedTuple):
    """Immutable viewport: integer rectangle plus a float depth range."""

    # Rectangle (all required).
    x: int
    y: int
    width: int
    height: int
    # Depth-range bounds; default to 0.0 .. 1.0.
    min_depth: float = 0.0
    max_depth: float = 1.0

class Scissor(NamedTuple):
    """Immutable integer scissor rectangle; every field is required."""

    x: int       # left edge
    y: int       # vertical origin
    width: int   # extent along x
    height: int  # extent along y

class ColorMask(NamedTuple):
    """Immutable per-channel (r, g, b, a) write enables; all on by default."""

    r: bool = True  # red
    g: bool = True  # green
    b: bool = True  # blue
    a: bool = True  # alpha

class StencilState(NamedTuple):
    """Immutable stencil configuration; every field has a default."""

    # Per-outcome stencil operations (all KEEP).
    fail_op: StencilOp = StencilOp.KEEP
    pass_op: StencilOp = StencilOp.KEEP
    depth_fail_op: StencilOp = StencilOp.KEEP
    # Comparison that always passes by default.
    compare_op: CompareOp = CompareOp.ALWAYS
    # Full 8-bit masks and a zero reference value.
    compare_mask: int = 0xFF
    write_mask: int = 0xFF
    reference: int = 0

class DepthState(NamedTuple):
    """Immutable depth configuration; test and writes are on by default."""

    # Depth test and depth writes enabled, comparing with LESS.
    test_enable: bool = True
    write_enable: bool = True
    compare_op: CompareOp = CompareOp.LESS
    # Depth-bounds test, off by default, with bounds 0.0 .. 1.0.
    bounds_test: bool = False
    min_bounds: float = 0.0
    max_bounds: float = 1.0

class RasterizationState(NamedTuple):
    """Immutable rasterizer configuration (back-face culling by default)."""

    # Culling and winding.
    cull_mode: CullMode = CullMode.BACK
    front_face: FrontFace = FrontFace.CCW
    # Depth bias: disabled, all coefficients zero.
    depth_bias_enable: bool = False
    depth_bias_constant: float = 0.0
    depth_bias_slope: float = 0.0
    depth_bias_clamp: float = 0.0
    # Miscellaneous toggles, both off.
    depth_clamp_enable: bool = False
    rasterizer_discard_enable: bool = False

class BlendState(NamedTuple):
    """Immutable blend configuration; factors and ops are plain strings.

    Blending is disabled by default. The color defaults describe the
    src_alpha / one_minus_src_alpha pair.
    """

    enable: bool = False
    # Color channel.
    src_color_factor: str = "src_alpha"
    dst_color_factor: str = "one_minus_src_alpha"
    color_op: str = "add"
    # Alpha channel.
    src_alpha_factor: str = "one"
    dst_alpha_factor: str = "zero"
    alpha_op: str = "add"
    # Constant blend color (RGBA), zero by default.
    blend_constants: tuple = (0.0, 0.0, 0.0, 0.0)

class ShaderResource(NamedTuple):
    """Immutable shader resource binding (descriptor set + binding slot)."""

    # Resource kind, e.g. "uniform_buffer", "sampler", "storage_buffer".
    type: str
    binding: int
    set: int = 0
    # Stages that use the resource. None means all stages. The annotation is
    # Optional because None is the default.
    stages: Optional[List[ShaderType]] = None

class VertexAttribute(NamedTuple):
    """Immutable vertex input attribute with string-typed format and rate."""

    location: int
    # Element format as a string, e.g. "float", "vec2", "vec3", "vec4".
    format: str
    offset: int = 0
    stride: int = 0
    # Step rate: "vertex" (default) or "instance".
    input_rate: str = "vertex"

class GraphicsPipelineState:
    """Mutable container holding the full state of a graphics pipeline.

    The constructor takes no arguments. Every attribute starts at a default,
    and a pipeline is configured by assigning attributes afterwards.
    """

    def __init__(self):
        # Shader code and shader inputs.
        self.shader_stages: Dict[ShaderType, str] = {}
        self.vertex_attributes: List[VertexAttribute] = []
        self.shader_resources: List[ShaderResource] = []

        # Output rectangles; unset until the caller provides them.
        self.viewport: Optional[Viewport] = None
        self.scissor: Optional[Scissor] = None

        # Fixed-function sub-states, each at its own defaults.
        self.rasterization: RasterizationState = RasterizationState()
        self.depth: DepthState = DepthState()
        self.stencil: StencilState = StencilState()
        self.blend: BlendState = BlendState()
        self.color_mask: ColorMask = ColorMask()

        # Input assembly; patch_control_points is used for tessellation.
        self.primitive_type: PrimitiveType = PrimitiveType.TRIANGLES
        self.patch_control_points: int = 3
