import struct
from .shader_compiler import ShaderCompiler
from .rasterizer import Rasterizer

class GraphicsAPI:
    def __init__(self, driver):
        """
        Set up the API facade on top of a driver

        Checks the driver for the required methods and enums, creates the
        empty resource tables and binding state, then builds the shader
        compiler and rasterizer subsystems.

        Args:
            driver: Driver object the API forwards its calls to

        Raises:
            RuntimeError: If the driver lacks a required method or enum
        """
        self.driver = driver
        self._verify_driver_compatibility()

        # Resource tracking (tables keyed by the IDs the driver hands out)
        self.buffers = {}
        self.textures = {}
        self.framebuffers = {}
        self.programs = {}

        # Current state
        self.current_program = None
        self.current_framebuffer = None
        self.current_texture_units = {}
        # Defined up front so both can be read before set_viewport() or
        # set_blend_state() has been called, instead of raising
        # AttributeError.
        self.current_viewport = None
        self.blend_state = None

        # Initialize subsystems
        self.shader_compiler = ShaderCompiler()
        self.rasterizer = Rasterizer(driver)

        print("Graphics API initialized.")
        
    def _verify_driver_compatibility(self):
        """Verify driver has required graphics capabilities"""
        required_methods = [
            'create_buffer',
            'create_texture',
            'create_framebuffer',
            'draw_triangles'
        ]
        
        missing_methods = [method for method in required_methods 
                         if not hasattr(self.driver, method)]
                         
        if missing_methods:
            raise RuntimeError(
                f"Driver missing required graphics methods: {missing_methods}"
                "\nEnsure driver has modern graphics API support."
            )
            
        # Verify required enums exist
        required_enums = [
            'TextureFormat',
            'FilterMode',
            'WrapMode',
            'BufferType',
            'MSAASamples'
        ]
        
        missing_enums = [enum for enum in required_enums
                        if not hasattr(self.driver, enum)]
                        
        if missing_enums:
            raise RuntimeError(
                f"Driver missing required graphics enums: {missing_enums}"
                "\nEnsure driver has modern graphics type support."
            )

    def create_buffer(self, data, buffer_type="vertex", dynamic=False, map_write=False):
        """
        Create a new buffer with data
        
        Args:
            data: numpy array or bytes of buffer data
            buffer_type: "vertex", "index", "uniform", "storage", or "indirect"
            dynamic: Whether buffer will be frequently updated
            map_write: Whether buffer should be mappable for CPU writes
        """
        if not self.driver.initialized:
            raise RuntimeError("Driver not initialized.")
            
        # Convert buffer type string to enum
        type_map = {
            "vertex": self.driver.BufferType.VERTEX,
            "index": self.driver.BufferType.INDEX,
            "uniform": self.driver.BufferType.UNIFORM,
            "storage": self.driver.BufferType.STORAGE,
            "indirect": self.driver.BufferType.INDIRECT
        }
        
        if buffer_type not in type_map:
            raise ValueError(f"Invalid buffer type: {buffer_type}")
            
        # Create buffer through driver
        buffer_id = self.driver.create_buffer(
            data=data,
            buffer_type=type_map[buffer_type],
            dynamic=dynamic,
            map_write=map_write
        )
        
        # Store additional metadata
        self.buffers[buffer_id] = {
            "type": buffer_type,
            "size": len(data),
            "dynamic": dynamic,
            "map_write": map_write
        }
        
        print(f"Created {buffer_type} buffer with ID {buffer_id}, size {len(data)} bytes.")
        return buffer_id

    def delete_buffer(self, buffer_id):
        if buffer_id in self.buffers:
            buffer_info = self.buffers[buffer_id]
            self.driver.free_memory(buffer_info["virtual_address"])
            del self.buffers[buffer_id]
            print(f"Deleted buffer with ID {buffer_id}.")
        else:
            print(f"Warning: Attempted to delete non-existent buffer with ID {buffer_id}.")

    def buffer_data(self, buffer_id, data):
        if buffer_id not in self.buffers:
            raise ValueError(f"Buffer with ID {buffer_id} not found.")
        
        buffer_info = self.buffers[buffer_id]
        if len(data) > buffer_info["size"]:
            raise ValueError(f"Data size ({len(data)}) exceeds buffer size ({buffer_info['size']}) for buffer ID {buffer_id}.")

        self.driver.write_memory(buffer_info["virtual_address"], data)
        print(f"Loaded {len(data)} bytes into buffer ID {buffer_id}.")

    def draw_arrays(self, mode, first, count, vertex_buffer,
                   primitive_restart_index=None, instances=1):
        """
        Draw primitives from vertex buffer
        
        Args:
            mode: "points", "lines", "triangles", etc
            first: First vertex to draw
            count: Number of vertices
            vertex_buffer: Vertex buffer ID
            primitive_restart_index: Index value that restarts primitive
            instances: Number of instances to draw
        """
        if not self.driver.initialized:
            raise RuntimeError("Driver not initialized.")
        if not self.current_program:
            raise RuntimeError("No shader program in use.")
        if not self.current_framebuffer:
            raise RuntimeError("No framebuffer bound.")
            
        print(f"Drawing {count} vertices in {mode} mode")
        
        # Draw using driver's optimized path
        self.driver.draw_triangles(
            vertex_buffer_id=vertex_buffer,
            index_buffer_id=None,
            framebuffer_id=self.current_framebuffer,
            shader_program=self.current_program,
            num_vertices=count,
            start_vertex=first
        )
        
        # Add command to command buffer for replay/debug
        self.driver.add_command("draw_arrays", 
            mode=mode,
            first=first,
            count=count,
            vertex_buffer=vertex_buffer,
            primitive_restart_index=primitive_restart_index,
            instances=instances
        )

    def draw_indexed(self, mode, count, vertex_buffer, index_buffer,
                     index_offset=0, base_vertex=0, instances=1,
                     primitive_restart_index=None):
        """
        Draw indexed primitives
        
        Args:
            mode: "points", "lines", "triangles", etc
            count: Number of indices
            vertex_buffer: Vertex buffer ID
            index_buffer: Index buffer ID
            index_offset: Starting offset in index buffer
            base_vertex: Value added to each index
            instances: Number of instances to draw
            primitive_restart_index: Index value that restarts primitive
        """
        if not self.driver.initialized:
            raise RuntimeError("Driver not initialized.")
        if not self.current_program:
            raise RuntimeError("No shader program in use.")
        if not self.current_framebuffer:
            raise RuntimeError("No framebuffer bound.")
            
        if index_buffer not in self.buffers or self.buffers[index_buffer]["type"] != "index":
            raise ValueError(f"Invalid index buffer ID: {index_buffer}")
            
        print(f"Drawing {count} indices in {mode} mode")
        
        # Draw using driver's optimized path
        self.driver.draw_triangles(
            vertex_buffer_id=vertex_buffer,
            index_buffer_id=index_buffer,
            framebuffer_id=self.current_framebuffer,
            shader_program=self.current_program,
            num_vertices=count,
            start_vertex=index_offset
        )
        
        # Add command to command buffer
        self.driver.add_command("draw_indexed",
            mode=mode,
            count=count,
            vertex_buffer=vertex_buffer,
            index_buffer=index_buffer,
            index_offset=index_offset,
            base_vertex=base_vertex,
            instances=instances,
            primitive_restart_index=primitive_restart_index
        )

    def compile_shader(self, shader_source, shader_type="vertex"):
        return self.shader_compiler.compile_shader(shader_source, shader_type)

    def link_program(self, vertex_shader, fragment_shader):
        return self.shader_compiler.link_program(vertex_shader, fragment_shader)

    def use_program(self, program):
        if not self.shader_compiler.validate_program(program):
            raise ValueError("Invalid shader program.")
        self.current_program = program
        print(f"Using shader program: {program['id']}.")

    def create_texture(self, width, height, format="rgba8", 
                    filter="bilinear", wrap="repeat",
                    generate_mipmaps=True, aniso_level=1):
        """
        Create a new texture
        
        Args:
            width, height: Texture dimensions
            format: "r8", "rgba8", "rgba16f", "rgba32f", etc.
            filter: "nearest", "bilinear", "trilinear", "anisotropic"
            wrap: "repeat", "clamp", "mirror"
            generate_mipmaps: Whether to generate mipmaps
            aniso_level: Anisotropic filtering level (1-16)
        """
        # Convert string parameters to enums
        format_map = {
            "r8": self.driver.TextureFormat.R8,
            "rg8": self.driver.TextureFormat.RG8,
            "rgba8": self.driver.TextureFormat.RGBA8,
            "r16f": self.driver.TextureFormat.R16F,
            "rgba16f": self.driver.TextureFormat.RGBA16F,
            "r32f": self.driver.TextureFormat.R32F,
            "rgba32f": self.driver.TextureFormat.RGBA32F,
            "bc1": self.driver.TextureFormat.BC1,
            "bc3": self.driver.TextureFormat.BC3
        }
        
        filter_map = {
            "nearest": self.driver.FilterMode.NEAREST,
            "bilinear": self.driver.FilterMode.BILINEAR,
            "trilinear": self.driver.FilterMode.TRILINEAR,
            "anisotropic": self.driver.FilterMode.ANISOTROPIC
        }
        
        wrap_map = {
            "repeat": self.driver.WrapMode.REPEAT,
            "clamp": self.driver.WrapMode.CLAMP,
            "mirror": self.driver.WrapMode.MIRROR
        }
        
        if format not in format_map:
            raise ValueError(f"Invalid texture format: {format}")
        if filter not in filter_map:
            raise ValueError(f"Invalid filter mode: {filter}")
        if wrap not in wrap_map:
            raise ValueError(f"Invalid wrap mode: {wrap}")
            
        texture_id = self.driver.create_texture(
            width=width,
            height=height,
            format=format_map[format],
            filter_mode=filter_map[filter],
            wrap_mode=wrap_map[wrap],
            generate_mipmaps=generate_mipmaps,
            aniso_level=aniso_level
        )
        
        print(f"Created texture with ID {texture_id}, format {format}, size {width}x{height}")
        return texture_id
        
    def create_framebuffer(self, width, height, color_formats=None,
                          depth_format=None, stencil_format=None,
                          samples=1):
        """
        Create a new framebuffer
        
        Args:
            width, height: Framebuffer dimensions
            color_formats: List of formats for color attachments
            depth_format: Format for depth attachment (None to disable)
            stencil_format: Format for stencil attachment (None to disable)
            samples: MSAA sample count (1, 2, 4, or 8)
        """
        # Default color format if none specified
        if color_formats is None:
            color_formats = ["rgba8"]
            
        # Convert format strings to enums
        format_map = {
            "r8": self.driver.TextureFormat.R8,
            "rgba8": self.driver.TextureFormat.RGBA8,
            "rgba16f": self.driver.TextureFormat.RGBA16F,
            "r32f": self.driver.TextureFormat.R32F
        }
        
        samples_map = {
            1: self.driver.MSAASamples.MSAA_1X,
            2: self.driver.MSAASamples.MSAA_2X,
            4: self.driver.MSAASamples.MSAA_4X,
            8: self.driver.MSAASamples.MSAA_8X
        }
        
        # Convert formats to enums
        color_format_enums = []
        for fmt in color_formats:
            if fmt not in format_map:
                raise ValueError(f"Invalid color format: {fmt}")
            color_format_enums.append(format_map[fmt])
            
        depth_format_enum = None
        if depth_format:
            if depth_format not in format_map:
                raise ValueError(f"Invalid depth format: {depth_format}")
            depth_format_enum = format_map[depth_format]
            
        stencil_format_enum = None
        if stencil_format:
            if stencil_format not in format_map:
                raise ValueError(f"Invalid stencil format: {stencil_format}")
            stencil_format_enum = format_map[stencil_format]
            
        if samples not in samples_map:
            raise ValueError(f"Invalid sample count: {samples}")
            
        # Create framebuffer through driver
        fb_id = self.driver.create_framebuffer(
            width=width,
            height=height,
            color_formats=color_format_enums,
            depth_format=depth_format_enum,
            stencil_format=stencil_format_enum,
            samples=samples_map[samples]
        )
        
        # Store framebuffer info
        self.framebuffers[fb_id] = {
            "width": width,
            "height": height,
            "color_formats": color_formats,
            "depth_format": depth_format,
            "stencil_format": stencil_format,
            "samples": samples
        }
        
        print(f"Created framebuffer with ID {fb_id}, size {width}x{height}, {samples}x MSAA")
        return fb_id

    def bind_framebuffer(self, fb_id):
        """
        Make a tracked framebuffer the current render target

        Args:
            fb_id: ID of a framebuffer present in the tracking table

        Raises:
            ValueError: If fb_id is not a tracked framebuffer
        """
        known = self.framebuffers
        if fb_id not in known:
            raise ValueError(f"Invalid framebuffer ID: {fb_id}")

        self.current_framebuffer = fb_id
        dims = known[fb_id]
        width, height = dims['width'], dims['height']
        print(f"Binding framebuffer {fb_id} ({width}x{height})")
        
    def bind_texture(self, texture_id, unit=0):
        """
        Attach a tracked texture to a texture unit

        Args:
            texture_id: ID of a texture present in the tracking table
            unit: Texture unit to bind to, 0 by default

        Raises:
            ValueError: If texture_id is not a tracked texture
        """
        is_tracked = texture_id in self.textures
        if not is_tracked:
            raise ValueError(f"Invalid texture ID: {texture_id}")

        units = self.current_texture_units
        units[unit] = texture_id
        print(f"Binding texture {texture_id} to unit {unit}")
        
    def set_viewport(self, x, y, width, height):
        """
        Set viewport for rendering

        The rectangle is stored on the API object as current_viewport.

        Args:
            x, y: Origin of the viewport rectangle
            width, height: Size of the viewport rectangle

        Raises:
            RuntimeError: If no framebuffer is bound
            ValueError: If the rectangle extends past the bound
                framebuffer's width or height
        """
        # Compare against None: a framebuffer ID of 0 is falsy but is still
        # a valid binding made through bind_framebuffer().
        if self.current_framebuffer is None:
            raise RuntimeError("No framebuffer bound")

        fb_info = self.framebuffers[self.current_framebuffer]
        if x + width > fb_info['width'] or y + height > fb_info['height']:
            raise ValueError("Viewport exceeds framebuffer dimensions")

        self.current_viewport = {
            'x': x, 'y': y,
            'width': width, 'height': height
        }
        
    def clear(self, color=None, depth=None, stencil=None):
        """Clear current framebuffer attachments"""
        if not self.current_framebuffer:
            raise RuntimeError("No framebuffer bound")
            
        fb_id = self.current_framebuffer
        fb = self.framebuffers[fb_id]
        
        # Get framebuffer from driver
        driver_fb = self.driver.framebuffer_manager[fb_id]
        
        if color is not None:
            r, g, b, a = color
            driver_fb.set_clear_values(color=[r, g, b, a])
            
        if depth is not None:
            driver_fb.set_clear_values(depth=depth)
            
        if stencil is not None:
            driver_fb.set_clear_values(stencil=stencil)
            
        # Perform clear
        driver_fb.clear(
            color=color is not None,
            depth=depth is not None,
            stencil=stencil is not None
        )
        
        print(f"Cleared framebuffer {fb_id}")
        
    def set_blend_state(self, enable=True, 
                       src_factor="src_alpha",
                       dst_factor="one_minus_src_alpha",
                       blend_op="add"):
        """Set blending state"""
        if not self.current_framebuffer:
            raise RuntimeError("No framebuffer bound")
            
        # Convert string parameters to driver enums
        factor_map = {
            "zero": self.driver.BlendMode.ZERO,
            "one": self.driver.BlendMode.ONE,
            "src_color": self.driver.BlendMode.SRC_COLOR,
            "one_minus_src_color": self.driver.BlendMode.ONE_MINUS_SRC_COLOR,
            "dst_color": self.driver.BlendMode.DST_COLOR,
            "one_minus_dst_color": self.driver.BlendMode.ONE_MINUS_DST_COLOR,
            "src_alpha": self.driver.BlendMode.SRC_ALPHA,
            "one_minus_src_alpha": self.driver.BlendMode.ONE_MINUS_SRC_ALPHA,
            "dst_alpha": self.driver.BlendMode.DST_ALPHA,
            "one_minus_dst_alpha": self.driver.BlendMode.ONE_MINUS_DST_ALPHA
        }
        
        op_map = {
            "add": self.driver.BlendOp.ADD,
            "subtract": self.driver.BlendOp.SUBTRACT,
            "reverse_subtract": self.driver.BlendOp.REVERSE_SUBTRACT,
            "min": self.driver.BlendOp.MIN,
            "max": self.driver.BlendOp.MAX
        }
        
        if src_factor not in factor_map:
            raise ValueError(f"Invalid source blend factor: {src_factor}")
        if dst_factor not in factor_map:
            raise ValueError(f"Invalid destination blend factor: {dst_factor}")
        if blend_op not in op_map:
            raise ValueError(f"Invalid blend operation: {blend_op}")
            
        fb_id = self.current_framebuffer
        driver_fb = self.driver.framebuffer_manager[fb_id]
        
        # Store blend state
        self.blend_state = {
            'enable': enable,
            'src_factor': factor_map[src_factor],
            'dst_factor': factor_map[dst_factor],
            'blend_op': op_map[blend_op]
        }
        
        print(f"Set blend state: {enable}, {src_factor}, {dst_factor}, {blend_op}")


