import struct
import numpy as np
from enum import Enum, auto
from typing import List, Dict, Tuple, Optional

class Rasterizer:
    """Software implementation of the triangle rasterization pipeline stages.

    Stages: triangle coverage (``rasterize_triangle``), fragment shading with
    early-Z / Hi-Z culling (``process_fragments``), depth testing
    (``depth_test``) and blended framebuffer writes (``write_to_framebuffer``).

    Conventions used throughout:

    * Vertex positions are clip-space ``(x, y, z, w)``. The viewport transform
      maps NDC ``[-1, 1]`` to window ``[0, width] x [0, height]`` without
      flipping Y, so window Y grows in the same direction as NDC Y.
    * Pixel ``(x, y)`` owns the window-space square ``[x, x+1) x [y, y+1)``
      and is sampled at its center ``(x + 0.5, y + 0.5)`` plus any MSAA offset.
    * Fragment depth is NDC ``z / w``; freshly cleared depth buffers hold 1.0.
    """

    # MSAA sample offsets relative to the pixel center, keyed by sample count.
    # 2x and 8x follow the standard D3D11/Vulkan sample locations; 4x keeps
    # this rasterizer's ordered 2x2 grid.
    _SAMPLE_PATTERNS = {
        1: ((0.0, 0.0),),
        2: ((0.25, 0.25), (-0.25, -0.25)),
        4: ((-0.375, -0.375), (0.375, -0.375),
            (-0.375, 0.375), (0.375, 0.375)),
        8: ((0.0625, -0.1875), (-0.0625, 0.1875),
            (0.3125, 0.0625), (-0.1875, -0.3125),
            (-0.3125, 0.3125), (-0.4375, -0.0625),
            (0.1875, 0.4375), (0.4375, -0.4375)),
    }

    def __init__(self, driver):
        """Store the driver handle and initialise empty depth state.

        Args:
            driver: Opaque driver object; stored but not used by this class.
        """
        self.driver = driver
        # Depth state shared by depth_test() and the early-Z path of
        # process_fragments(); stays None until depth_test() loads a buffer.
        self.depth_buffer = None
        self.framebuffer_width = None
        print("Rasterizer initialized.")

    def _compute_edge_function(self, x0, y0, x1, y1, px, py):
        """
        Evaluate the edge function of point (px, py) against edge (x0,y0)->(x1,y1).

        The value is zero on the line through the edge; its sign tells which
        side of the directed edge the point lies on. With window Y pointing up
        (as produced by the viewport transform in rasterize_triangle), a
        positive value means the point is to the right of the edge. Evaluated
        at the third vertex it equals twice the triangle's signed area.
        """
        return (px - x0) * (y1 - y0) - (py - y0) * (x1 - x0)

    def _is_top_left_edge(self, x0, y0, x1, y1):
        """Classify a directed edge for the top-left fill rule.

        Intended for triangles accepted by rasterize_triangle (interior on the
        positive side of every edge, i.e. clockwise with Y up): a "top" edge is
        horizontal and runs toward +x, a "left" edge runs toward +y. Samples
        exactly on such edges are owned by the triangle, so pixels on an edge
        shared by two triangles are drawn exactly once.
        """
        return (y0 == y1 and x0 < x1) or y0 < y1

    def _compute_perspective_w(self, barycentric, w0, w1, w2):
        """Return clip-space w at a point from its screen-space barycentrics.

        1/w is affine in screen space, so w = 1 / sum(b_i / w_i).
        """
        return 1.0 / (barycentric[0] / w0 + barycentric[1] / w1 + barycentric[2] / w2)

    def _interpolate_perspective(self, barycentric, attr0, attr1, attr2, w0, w1, w2):
        """Perspective-correct interpolation of a per-vertex attribute.

        Attributes may be scalars or numpy arrays; lists/tuples are converted
        to float arrays so vector attributes (colors, UVs) interpolate
        component-wise instead of failing on ``tuple * float``.
        """
        attr0, attr1, attr2 = (
            np.asarray(a, dtype=float) if isinstance(a, (list, tuple)) else a
            for a in (attr0, attr1, attr2)
        )
        w = self._compute_perspective_w(barycentric, w0, w1, w2)
        return w * (
            attr0 * barycentric[0] / w0 +
            attr1 * barycentric[1] / w1 +
            attr2 * barycentric[2] / w2
        )

    def rasterize_triangle(self, v0, v1, v2, framebuffer_width, framebuffer_height,
                           msaa_samples=1, conservative=False):
        """
        Rasterize one triangle into a list of fragments.

        Features:
        - Edge-function coverage with the top-left fill rule
        - Linear screen-space depth (NDC z/w) and perspective-correct attributes
        - MSAA (1, 2, 4 or 8 samples per pixel)
        - Conservative rasterization option

        Args:
            v0, v1, v2: vertex dicts with 'position' = clip-space (x, y, z, w)
                and 'attributes' = {name: value}; every vertex must carry the
                attribute names of v0.
            framebuffer_width, framebuffer_height: target size in pixels.
            msaa_samples: number of MSAA samples (1, 2, 4, or 8).
            conservative: emit every pixel whose square touches the triangle
                (all of its samples count as covered).

        Returns:
            List of fragment dicts {"x", "y", "samples", "coverage"}, where
            each sample dict holds "x", "y", "sample_x", "sample_y", "depth",
            "attributes" and "barycentric". Returns [] for back-facing or
            degenerate triangles and for triangles with any vertex w <= 0.

        Raises:
            ValueError: if msaa_samples is not a supported sample count.
        """
        sample_positions = self._SAMPLE_PATTERNS.get(msaa_samples)
        if sample_positions is None:
            raise ValueError(
                f"msaa_samples must be one of {sorted(self._SAMPLE_PATTERNS)}, "
                f"got {msaa_samples!r}")

        pos0, pos1, pos2 = v0['position'], v1['position'], v2['position']
        w0, w1, w2 = pos0[3], pos1[3], pos2[3]

        # No near-plane clipping is performed. A vertex with w <= 0 lies on or
        # behind the eye plane: w == 0 divides by zero and w < 0 projects
        # mirrored, so reject the triangle rather than draw garbage.
        if w0 <= 0 or w1 <= 0 or w2 <= 0:
            return []

        def to_window(pos, w):
            # Viewport transform: NDC [-1, 1] -> window [0, size].
            return ((pos[0] / w + 1) * 0.5 * framebuffer_width,
                    (pos[1] / w + 1) * 0.5 * framebuffer_height)

        sx0, sy0 = to_window(pos0, w0)
        sx1, sy1 = to_window(pos1, w1)
        sx2, sy2 = to_window(pos2, w2)

        edge = self._compute_edge_function

        # Twice the signed area; non-positive means back-facing or degenerate
        # (the latter would also divide by zero below).
        area = edge(sx0, sy0, sx1, sy1, sx2, sy2)
        if area <= 0:
            return []

        # Pixel bounding box. Use floor (int() truncates toward zero) and
        # expand for conservative mode *before* clamping, so no fragment can
        # land outside the framebuffer.
        min_x = int(np.floor(min(sx0, sx1, sx2)))
        max_x = int(np.floor(max(sx0, sx1, sx2)))
        min_y = int(np.floor(min(sy0, sy1, sy2)))
        max_y = int(np.floor(max(sy0, sy1, sy2)))
        if conservative:
            min_x -= 1
            min_y -= 1
            max_x += 1
            max_y += 1
        min_x = max(0, min_x)
        min_y = max(0, min_y)
        max_x = min(framebuffer_width - 1, max_x)
        max_y = min(framebuffer_height - 1, max_y)

        # Fill-rule ownership of samples lying exactly on an edge.
        tl01 = self._is_top_left_edge(sx0, sy0, sx1, sy1)
        tl12 = self._is_top_left_edge(sx1, sy1, sx2, sy2)
        tl20 = self._is_top_left_edge(sx2, sy2, sx0, sy0)

        # Largest increase of each edge function over a +/-0.5 pixel offset in
        # x and y: E(center) + margin >= 0 iff the pixel square reaches that
        # edge's inner half-plane. Testing all three edges this way
        # overestimates slightly near sharp corners; the bbox bounds it.
        margin01 = 0.5 * (abs(sx1 - sx0) + abs(sy1 - sy0))
        margin12 = 0.5 * (abs(sx2 - sx1) + abs(sy2 - sy1))
        margin20 = 0.5 * (abs(sx0 - sx2) + abs(sy0 - sy2))

        # NDC depth z/w is affine in screen space, so it is interpolated
        # linearly (no perspective correction), as depth buffers expect.
        ndc_z0, ndc_z1, ndc_z2 = pos0[2] / w0, pos1[2] / w1, pos2[2] / w2
        attr_names = list(v0['attributes'])

        fragments = []
        for y in range(min_y, max_y + 1):
            cy = y + 0.5
            for x in range(min_x, max_x + 1):
                cx = x + 0.5
                pixel_overlaps = conservative and (
                    edge(sx0, sy0, sx1, sy1, cx, cy) + margin01 >= 0
                    and edge(sx1, sy1, sx2, sy2, cx, cy) + margin12 >= 0
                    and edge(sx2, sy2, sx0, sy0, cx, cy) + margin20 >= 0
                )

                sample_fragments = []
                for sample_x, sample_y in sample_positions:
                    px, py = cx + sample_x, cy + sample_y
                    e01 = edge(sx0, sy0, sx1, sy1, px, py)
                    e12 = edge(sx1, sy1, sx2, sy2, px, py)
                    e20 = edge(sx2, sy2, sx0, sy0, px, py)

                    inside = (
                        (e01 > 0 or (e01 == 0 and tl01)) and
                        (e12 > 0 or (e12 == 0 and tl12)) and
                        (e20 > 0 or (e20 == 0 and tl20))
                    )
                    if not (inside or pixel_overlaps):
                        continue

                    # Screen-space barycentrics; they extrapolate (may go
                    # negative) for conservative samples outside the triangle.
                    bary = (e12 / area, e20 / area, e01 / area)
                    depth = bary[0] * ndc_z0 + bary[1] * ndc_z1 + bary[2] * ndc_z2
                    attributes = {
                        name: self._interpolate_perspective(
                            bary,
                            v0['attributes'][name],
                            v1['attributes'][name],
                            v2['attributes'][name],
                            w0, w1, w2,
                        )
                        for name in attr_names
                    }
                    sample_fragments.append({
                        "x": x,
                        "y": y,
                        "sample_x": sample_x,
                        "sample_y": sample_y,
                        "depth": depth,
                        "attributes": attributes,
                        "barycentric": bary,
                    })

                if sample_fragments:
                    fragments.append({
                        "x": x,
                        "y": y,
                        "samples": sample_fragments,
                        "coverage": len(sample_fragments) / len(sample_positions),
                    })

        return fragments

    class HiZBuffer:
        """Hierarchical Z-buffer for conservative occlusion culling.

        Level 0 holds one depth per pixel. Each coarser level halves the
        resolution (rounding up, so edge rows/columns are never dropped) and
        stores the *farthest* (maximum) depth of the cells beneath it. A region
        is hidden only if its nearest depth lies behind the farthest depth
        already recorded there, which assumes a LESS/LEQUAL depth test with
        depth writes enabled. All cells start at the far value 1.0.
        """

        def __init__(self, width, height):
            self.width = width
            self.height = height
            self.levels = []

            current_w, current_h = width, height
            while current_w > 0 and current_h > 0:
                self.levels.append(np.full((current_h, current_w), 1.0))
                if current_w == 1 and current_h == 1:
                    break
                # Ceil-halving keeps a parent cell for every child cell.
                current_w = (current_w + 1) // 2
                current_h = (current_h + 1) // 2

        def update_region(self, x, y, z):
            """Record depth z at pixel (x, y) and refresh the coarser levels.

            Coordinates outside the buffer are ignored.
            """
            if not self.levels:
                return
            base = self.levels[0]
            if not (0 <= y < base.shape[0] and 0 <= x < base.shape[1]):
                return
            base[y, x] = min(base[y, x], z)
            # Each parent stores the farthest depth of its (up to) 2x2 children.
            for level in range(1, len(self.levels)):
                x >>= 1
                y >>= 1
                children = self.levels[level - 1][2 * y:2 * y + 2, 2 * x:2 * x + 2]
                self.levels[level][y, x] = children.max()

        def test_region(self, min_x, min_y, max_x, max_y, z):
            """Return True if a region whose nearest depth is z could be visible.

            The inclusive pixel rectangle is clipped to the buffer; regions
            entirely outside it are reported visible.
            """
            min_x, min_y = max(min_x, 0), max(min_y, 0)
            max_x = min(max_x, self.width - 1)
            max_y = min(max_y, self.height - 1)
            if not self.levels or min_x > max_x or min_y > max_y:
                return True

            # Coarsest level whose cells are no larger than the region.
            span = max(max_x - min_x + 1, max_y - min_y + 1)
            level = min(int(span).bit_length() - 1, len(self.levels) - 1)
            grid = self.levels[level]
            farthest = grid[min_y >> level:(max_y >> level) + 1,
                            min_x >> level:(max_x >> level) + 1].max()
            return bool(z <= farthest)

    def process_fragments(self, fragments, fragment_shader_program, chip_id=0,
                          early_z=True, hierarchical_z=True):
        """
        Shade fragments, optionally culling with early-Z and a Hi-Z buffer.

        Args:
            fragments: List of fragments (as produced by rasterize_triangle).
                Fragments without samples are ignored.
            fragment_shader_program: Shader program to execute (passed through
                to _execute_fragment_shader).
            chip_id: GPU chip to use (passed through).
            early_z: Test samples against the depth buffer loaded by the last
                depth_test() call before shading (no-op if none was loaded),
                and process each tile front to back.
            hierarchical_z: Cull whole 32x32 tiles with a Hi-Z buffer built
                fresh for this call.

        Returns:
            List of {"x", "y", "samples", "coverage"} dicts whose samples carry
            "sample_x", "sample_y", "depth" and "color".
        """
        fragments = [f for f in fragments if f.get("samples")]
        if not fragments:
            return []

        # The Hi-Z buffer only needs to span the pixels actually touched.
        fb_width = max(f["x"] for f in fragments) + 1
        fb_height = max(f["y"] for f in fragments) + 1
        hiz = self.HiZBuffer(fb_width, fb_height) if hierarchical_z else None

        # Group fragments into tiles for better cache coherency.
        TILE_SIZE = 32
        tiles = {}
        for fragment in fragments:
            key = (fragment["x"] // TILE_SIZE, fragment["y"] // TILE_SIZE)
            tiles.setdefault(key, []).append(fragment)

        processed_fragments = []
        for (tile_x, tile_y), tile_fragments in tiles.items():
            if early_z:
                # Front-to-back order lets early-Z reject more work.
                tile_fragments.sort(key=lambda f: f["samples"][0]["depth"])

            if hiz is not None:
                tile_min_x = tile_x * TILE_SIZE
                tile_min_y = tile_y * TILE_SIZE
                tile_max_x = min(tile_min_x + TILE_SIZE - 1, fb_width - 1)
                tile_max_y = min(tile_min_y + TILE_SIZE - 1, fb_height - 1)
                nearest = min(s["depth"] for f in tile_fragments for s in f["samples"])
                if not hiz.test_region(tile_min_x, tile_min_y,
                                       tile_max_x, tile_max_y, nearest):
                    continue

            for fragment in tile_fragments:
                x, y = fragment["x"], fragment["y"]
                samples = fragment["samples"]

                # Early-Z: shade the fragment if any of its samples may pass.
                if early_z and not any(self._depth_test(s["depth"], x, y) for s in samples):
                    continue

                processed_samples = [
                    {
                        "sample_x": sample["sample_x"],
                        "sample_y": sample["sample_y"],
                        "depth": sample["depth"],
                        "color": self._execute_fragment_shader(
                            sample, fragment_shader_program, chip_id),
                    }
                    for sample in samples
                ]

                if hiz is not None:
                    # update_region keeps the nearest depth per pixel, so the
                    # fragment's nearest sample is all it needs.
                    hiz.update_region(x, y, min(s["depth"] for s in samples))

                processed_fragments.append({
                    "x": x,
                    "y": y,
                    "samples": processed_samples,
                    "coverage": fragment["coverage"],
                })

        return processed_fragments

    def _execute_fragment_shader(self, fragment, fragment_shader_program, chip_id):
        """
        Simulate execution of a fragment shader for a single fragment/sample.

        fragment_shader_program and chip_id are currently unused; the returned
        RGBA color is derived from the sample's "x", "y" and "depth" fields.
        """
        # In a real implementation, this would dispatch the shader instructions
        # to an available SM and execute them using the SM's cores.

        # For simulation, just return a dummy color based on fragment position
        r = (fragment["x"] % 256) / 255.0
        g = (fragment["y"] % 256) / 255.0
        b = fragment["depth"]
        a = 1.0

        return (r, g, b, a)

    def _depth_test(self, fragment_depth: float, x: int, y: int,
                    depth_func=lambda a, b: a < b) -> bool:
        """
        Test a depth value against the currently loaded depth buffer.

        Args:
            fragment_depth: Fragment's depth value
            x, y: Fragment coordinates
            depth_func: Depth comparison function (fragment, stored) -> bool

        Returns:
            bool: True if the fragment passes. Also True when no depth buffer
            has been loaded yet (depth_test() was never called) or (x, y) lies
            outside it, since nothing stored can occlude it.
        """
        depth_buffer = getattr(self, "depth_buffer", None)
        width = getattr(self, "framebuffer_width", None)
        if depth_buffer is None or not width or not 0 <= x < width:
            return True
        index = y * width + x
        if not 0 <= index < len(depth_buffer):
            return True
        return depth_func(fragment_depth, depth_buffer[index])

    def depth_test(self, fragments: List[Dict], depth_buffer_bytes: bytes,
                   framebuffer_width: int,
                   depth_func: str = 'LESS',
                   depth_write: bool = True,
                   stencil_enabled: bool = False,
                   framebuffer_height: Optional[int] = None) -> Tuple[List[Dict], bytes]:
        """
        Perform depth testing on fragments against a float32 depth buffer.

        All samples of a fragment are tested against the depth stored before
        that fragment, so samples of one fragment never occlude each other. If
        depth_write is set, the winning passing depth (the one that passes
        depth_func against all others, e.g. the nearest for LESS) is stored.

        Args:
            fragments: List of fragments to test
            depth_buffer_bytes: Current depth buffer as native-endian float32
                values, row-major; empty/None means a buffer cleared to 1.0
            framebuffer_width: Width of framebuffer
            depth_func: Depth comparison function ('NEVER', 'LESS', 'EQUAL',
                'LEQUAL', 'GREATER', 'NOTEQUAL', 'GEQUAL', 'ALWAYS')
            depth_write: Whether to write passing fragments to depth buffer
            stencil_enabled: Accepted for API compatibility; stencil testing
                is not implemented
            framebuffer_height: Height used to size a freshly cleared buffer;
                defaults to framebuffer_width (legacy square assumption).
                Ignored when depth_buffer_bytes is given.

        Returns:
            Tuple of (passed fragments, modified depth buffer bytes). Each
            passed fragment keeps only its passing samples, and its coverage
            is scaled by the fraction of samples that passed. Fragments outside
            the buffer are dropped.

        Raises:
            KeyError: unknown depth_func.
            struct.error: depth_buffer_bytes length not a multiple of 4.
        """
        self.framebuffer_width = framebuffer_width

        depth_funcs = {
            'NEVER': lambda a, b: False,
            'LESS': lambda a, b: a < b,
            'EQUAL': lambda a, b: abs(a - b) < 1e-6,
            'LEQUAL': lambda a, b: a <= b,
            'GREATER': lambda a, b: a > b,
            'NOTEQUAL': lambda a, b: abs(a - b) >= 1e-6,
            'GEQUAL': lambda a, b: a >= b,
            'ALWAYS': lambda a, b: True,
        }
        depth_compare = depth_funcs[depth_func]

        if depth_buffer_bytes:
            # bytes() accepts bytes, bytearray, memoryview or a list of ints.
            self.depth_buffer = [d for (d,) in struct.iter_unpack("f", bytes(depth_buffer_bytes))]
        else:
            height = framebuffer_width if framebuffer_height is None else framebuffer_height
            self.depth_buffer = [1.0] * (framebuffer_width * height)
        depth_buffer = self.depth_buffer

        passed_fragments = []
        for fragment in fragments:
            x, y = fragment["x"], fragment["y"]
            samples = fragment["samples"]
            index = y * framebuffer_width + x
            # Skip empty fragments and anything outside the buffer (negative
            # indices would otherwise silently wrap around).
            if not samples or not 0 <= x < framebuffer_width or not 0 <= index < len(depth_buffer):
                continue

            stored = depth_buffer[index]
            passed_samples = [s for s in samples if depth_compare(s["depth"], stored)]
            if not passed_samples:
                continue

            if depth_write:
                winner = passed_samples[0]["depth"]
                for sample in passed_samples[1:]:
                    if depth_compare(sample["depth"], winner):
                        winner = sample["depth"]
                depth_buffer[index] = winner

            result = fragment.copy()
            result["samples"] = passed_samples
            result["coverage"] = (fragment.get("coverage", 1.0)
                                  * len(passed_samples) / len(samples))
            passed_fragments.append(result)

        modified_depth_buffer = struct.pack(f"{len(depth_buffer)}f", *depth_buffer)
        return passed_fragments, modified_depth_buffer

    class BlendMode(Enum):
        """Blend factors for color blending"""
        ZERO = auto()
        ONE = auto()
        SRC_COLOR = auto()
        ONE_MINUS_SRC_COLOR = auto()
        DST_COLOR = auto()
        ONE_MINUS_DST_COLOR = auto()
        SRC_ALPHA = auto()
        ONE_MINUS_SRC_ALPHA = auto()
        DST_ALPHA = auto()
        ONE_MINUS_DST_ALPHA = auto()

    class BlendOp(Enum):
        """Blend operations"""
        ADD = auto()
        SUBTRACT = auto()
        REVERSE_SUBTRACT = auto()
        MIN = auto()
        MAX = auto()

    def _blend_factor(self, mode: BlendMode, src_color, dst_color) -> np.ndarray:
        """Return the 4-component RGBA blend factor for ``mode``.

        Raises:
            ValueError: if mode is not a BlendMode member.
        """
        if mode == self.BlendMode.ZERO:
            return np.zeros(4)
        elif mode == self.BlendMode.ONE:
            return np.ones(4)
        elif mode == self.BlendMode.SRC_COLOR:
            return src_color
        elif mode == self.BlendMode.ONE_MINUS_SRC_COLOR:
            return 1.0 - src_color
        elif mode == self.BlendMode.DST_COLOR:
            return dst_color
        elif mode == self.BlendMode.ONE_MINUS_DST_COLOR:
            return 1.0 - dst_color
        elif mode == self.BlendMode.SRC_ALPHA:
            return np.full(4, src_color[3])
        elif mode == self.BlendMode.ONE_MINUS_SRC_ALPHA:
            return np.full(4, 1.0 - src_color[3])
        elif mode == self.BlendMode.DST_ALPHA:
            return np.full(4, dst_color[3])
        elif mode == self.BlendMode.ONE_MINUS_DST_ALPHA:
            return np.full(4, 1.0 - dst_color[3])
        raise ValueError(f"Unsupported blend mode: {mode!r}")

    def _blend_operation(self, op: BlendOp, src: np.ndarray, dst: np.ndarray) -> np.ndarray:
        """Combine (already factor-weighted, except MIN/MAX) src and dst.

        Raises:
            ValueError: if op is not a BlendOp member.
        """
        if op == self.BlendOp.ADD:
            return src + dst
        elif op == self.BlendOp.SUBTRACT:
            return src - dst
        elif op == self.BlendOp.REVERSE_SUBTRACT:
            return dst - src
        elif op == self.BlendOp.MIN:
            return np.minimum(src, dst)
        elif op == self.BlendOp.MAX:
            return np.maximum(src, dst)
        raise ValueError(f"Unsupported blend op: {op!r}")

    def write_to_framebuffer(self, fragments: List[Dict], color_buffer: bytearray,
                             framebuffer_width: int,
                             blend_enable: bool = True,
                             src_blend: BlendMode = BlendMode.SRC_ALPHA,
                             dst_blend: BlendMode = BlendMode.ONE_MINUS_SRC_ALPHA,
                             blend_op: BlendOp = BlendOp.ADD) -> bytearray:
        """
        Write fragments to an RGBA8 framebuffer with blending and MSAA resolve.

        Each covered sample is blended against the current pixel color and
        clamped. The clamped results are averaged, and the average is weighted
        by the fragment's coverage. Uncovered samples keep the destination
        color, so partially covered edge pixels are anti-aliased.

        Args:
            fragments: List of fragments to write (samples need "color")
            color_buffer: Current framebuffer contents, 4 bytes per pixel,
                row-major; modified in place
            framebuffer_width: Width of framebuffer
            blend_enable: Whether to enable blending
            src_blend: Source blend factor
            dst_blend: Destination blend factor
            blend_op: Blend operation; MIN and MAX ignore the blend factors,
                as in OpenGL/D3D

        Returns:
            The modified color buffer (same object). Fragments without samples
            or outside the buffer are skipped.
        """
        uses_factors = blend_op not in (self.BlendOp.MIN, self.BlendOp.MAX)

        for fragment in fragments:
            samples = fragment["samples"]
            if not samples:
                continue
            x, y = fragment["x"], fragment["y"]
            buffer_index = (y * framebuffer_width + x) * 4
            # Reject out-of-range pixels instead of letting negative indices wrap.
            if (not 0 <= x < framebuffer_width or buffer_index < 0
                    or buffer_index + 4 > len(color_buffer)):
                continue

            dst_color = np.array(list(color_buffer[buffer_index:buffer_index + 4]),
                                 dtype=float) / 255.0

            # Blend each covered sample, then resolve by averaging.
            resolved = []
            for sample in samples:
                src_color = np.asarray(sample["color"], dtype=float)
                if not blend_enable:
                    blended = src_color
                elif uses_factors:
                    blended = self._blend_operation(
                        blend_op,
                        src_color * self._blend_factor(src_blend, src_color, dst_color),
                        dst_color * self._blend_factor(dst_blend, src_color, dst_color),
                    )
                else:
                    blended = self._blend_operation(blend_op, src_color, dst_color)
                resolved.append(np.clip(blended, 0.0, 1.0))
            shaded = np.mean(resolved, axis=0)

            coverage = min(max(float(fragment.get("coverage", 1.0)), 0.0), 1.0)
            final_color = np.clip(coverage * shaded + (1.0 - coverage) * dst_color, 0.0, 1.0)

            # UNORM8 conversion rounds to nearest, so unchanged channels survive
            # the /255 -> *255 round trip exactly.
            for channel in range(4):
                color_buffer[buffer_index + channel] = int(final_color[channel] * 255.0 + 0.5)

        return color_buffer


