"""
Advanced Graphics Pipeline for Virtual GPU
Integrated with Helium VGPU Pipeline
Stages: Rasterization, Clipping, Blending, Depth/Stencil Testing
"""
import numpy as np
from helium.core.pipeline import VGPUPipeline
from helium.core.memory import MemoryPool

class VGPURasterizer:
    """Triangle rasterizer that owns a stream on a Helium VGPU pipeline.

    NOTE(review): ``rasterize`` previously did three broken things:
    - It allocated VGPU buffers that were never filled with the vertex/index
      data and never freed.
    - It launched a ``rasterize_triangles`` kernel and discarded its result.
    - It then ran a CPU loop over undefined names (``min_x``, ``tri``,
      ``tri_idx``, ...), so every call raised ``NameError``.
    The working CPU path below is now the implementation. Restore a VGPU
    dispatch once buffer upload and the kernel's output format are defined.
    """

    def __init__(self, vgpu_pipeline, memory_pool):
        """Keep the VGPU pipeline and memory pool and open a stream on the pipeline."""
        self.pipeline = vgpu_pipeline
        self.memory_pool = memory_pool
        self.stream_id = self.pipeline.create_stream()

    def rasterize(self, vertices, indices, viewport, width, height):
        """Rasterize indexed triangles into per-pixel fragments.

        Args:
            vertices: Indexable collection of positions. ``vertices[i][0]`` and
                ``vertices[i][1]`` are compared against integer pixel
                coordinates, and ``vertices[i][2]`` is the depth.
            indices: Flat sequence of vertex indices. Each consecutive group of
                three forms one triangle; a trailing partial group is ignored.
            viewport: Inclusive pixel rectangle ``(x0, y0, x1, y1)`` that limits
                the pixels considered, or ``None`` to use the whole target.
            width, height: Render-target size. Only pixels with
                ``0 <= x < width`` and ``0 <= y < height`` are produced.

        Returns:
            A list of ``(x, y, z, bary, tri_idx)`` tuples, one per covered
            integer pixel (edges inclusive). ``bary`` holds the barycentric
            weights of the pixel with respect to the triangle's three vertices.
            ``z`` is the depth interpolated with those weights. ``tri_idx`` is
            the position of the triangle's index triple within ``indices``.
        """
        # Inclusive pixel window shared by every triangle: the render target,
        # optionally narrowed to the viewport.
        lo_x, lo_y = 0, 0
        hi_x, hi_y = width - 1, height - 1
        if viewport is not None:
            vx0, vy0, vx1, vy1 = viewport
            lo_x = max(lo_x, int(np.ceil(vx0)))
            lo_y = max(lo_y, int(np.ceil(vy0)))
            hi_x = min(hi_x, int(np.floor(vx1)))
            hi_y = min(hi_y, int(np.floor(vy1)))

        fragments = []
        for tri_idx in range(len(indices) // 3):
            base = 3 * tri_idx
            tri = tuple(vertices[int(indices[base + k])] for k in range(3))
            xs = [v[0] for v in tri]
            ys = [v[1] for v in tri]
            # Triangle bounding box, clamped to the pixel window; an empty
            # range simply yields no fragments.
            min_x = max(lo_x, int(np.floor(min(xs))))
            max_x = min(hi_x, int(np.ceil(max(xs))))
            min_y = max(lo_y, int(np.floor(min(ys))))
            max_y = min(hi_y, int(np.ceil(max(ys))))
            for y in range(min_y, max_y + 1):
                for x in range(min_x, max_x + 1):
                    bc = self.barycentric(tri, (x, y))
                    # Covered when every weight is non-negative; degenerate
                    # triangles return (-1, -1, -1) and are never covered.
                    if all(b >= 0 for b in bc):
                        z = sum(bc[i] * tri[i][2] for i in range(3))
                        fragments.append((x, y, z, bc, tri_idx))
        return fragments

    def barycentric(self, tri, p):
        """Return the barycentric weights ``(w0, w1, w2)`` of point ``p``.

        The weights correspond to the vertices ``a, b, c`` of ``tri``, using
        only their x/y components. For a zero-area triangle this returns
        ``(-1, -1, -1)``, which callers treat as "not covered".
        """
        a, b, c = tri
        denom = ((b[1] - c[1])*(a[0] - c[0]) + (c[0] - b[0])*(a[1] - c[1]))
        if denom == 0:
            return (-1, -1, -1)
        w0 = ((b[1] - c[1])*(p[0] - c[0]) + (c[0] - b[0])*(p[1] - c[1])) / denom
        w1 = ((c[1] - a[1])*(p[0] - c[0]) + (a[0] - c[0])*(p[1] - c[1])) / denom
        w2 = 1 - w0 - w1
        return (w0, w1, w2)

    def in_viewport(self, v, viewport):
        """Return True if vertex ``v = (x, y, _)`` lies inside the inclusive ``viewport``."""
        x, y, _ = v
        x0, y0, x1, y1 = viewport
        return x0 <= x <= x1 and y0 <= y <= y1

class Clipper:
    """Conservative triangle culling against an axis-aligned clip rectangle."""

    def clip(self, triangles, clip_rect):
        """Drop triangles that lie entirely outside ``clip_rect``.

        A triangle is rejected only when all three vertices are beyond the
        same edge of ``clip_rect = (x0, y0, x1, y1)``: all with ``x < x0``,
        all with ``x > x1``, all with ``y < y0``, or all with ``y > y1``.
        Earlier versions kept a triangle only if some vertex was inside. That
        wrongly culled triangles that straddle or fully cover the rectangle.
        This test is conservative: a triangle whose bounding box overlaps the
        rectangle is kept even if the triangle itself does not.

        Args:
            triangles: Iterable of 3-vertex triangles; each vertex is ``(x, y, ...)``.
            clip_rect: Inclusive rectangle ``(x0, y0, x1, y1)``.

        Returns:
            A list of the kept triangles, in input order.
        """
        x0, y0, x1, y1 = clip_rect

        def _entirely_outside(tri):
            # All vertices on the far side of one rectangle edge.
            xs = [v[0] for v in tri]
            ys = [v[1] for v in tri]
            return max(xs) < x0 or min(xs) > x1 or max(ys) < y0 or min(ys) > y1

        return [tri for tri in triangles if not _entirely_outside(tri)]

    def in_clip(self, v, clip_rect):
        """Return True if vertex ``v = (x, y, _)`` lies inside the inclusive ``clip_rect``."""
        x, y, _ = v
        x0, y0, x1, y1 = clip_rect
        return x0 <= x <= x1 and y0 <= y <= y1

class Blender:
    """Combine an incoming (source) colour with the stored (destination) colour."""

    def blend(self, src_color, dst_color, mode='alpha', alpha=1.0):
        """Blend ``src_color`` over ``dst_color`` according to ``mode``.

        Modes:
            'alpha':    ``alpha * src + (1 - alpha) * dst``
            'add':      ``src + dst`` clipped to ``[0, 1]``
            'multiply': ``src * dst``
        Any other mode returns ``src_color`` unchanged.
        """
        # Each combiner is evaluated lazily, only for the matching mode.
        combiners = (
            ('alpha', lambda: alpha * src_color + (1 - alpha) * dst_color),
            ('add', lambda: np.clip(src_color + dst_color, 0, 1)),
            ('multiply', lambda: src_color * dst_color),
        )
        for name, combine in combiners:
            if mode == name:
                return combine()
        return src_color

class DepthStencil:
    """Per-pixel depth buffer (float32, cleared to +inf) with a uint8 stencil buffer."""

    def __init__(self, width, height):
        # Buffers are indexed [y, x], i.e. shape (height, width).
        self.depth = np.full((height, width), np.inf, dtype=np.float32)
        self.stencil = np.zeros((height, width), dtype=np.uint8)

    def test(self, x, y, z, stencil_ref=1):
        """Run a "less" depth test at pixel ``(x, y)`` and update on success.

        When ``z`` is strictly less than the stored depth, the new depth is
        written, ``stencil_ref`` is written to the stencil buffer, and the
        method returns True. Otherwise it returns False and changes nothing.

        Coordinates outside the buffer always fail the test. Without this
        check, negative indices would silently wrap (NumPy negative indexing)
        and corrupt a pixel on the opposite edge.
        """
        height, width = self.depth.shape
        if not (0 <= x < width and 0 <= y < height):
            return False
        if z < self.depth[y, x]:
            self.depth[y, x] = z
            self.stencil[y, x] = stencil_ref
            return True
        return False

    def clear(self, depth_val=np.inf, stencil_val=0):
        """Reset every depth entry to ``depth_val`` and every stencil entry to ``stencil_val``."""
        self.depth.fill(depth_val)
        self.stencil.fill(stencil_val)

from .texture_buffer_manager import Texture, Buffer, Framebuffer

class GraphicsPipeline:
    """Software graphics pipeline: rasterize, clip, depth-test, shade, blend.

    NOTE(review): the constructor previously instantiated ``Rasterizer()``,
    a name that is neither defined nor imported in this module, so every
    construction raised ``NameError``. It now uses an injected
    ``rasterizer``, or builds a ``VGPURasterizer`` from ``vgpu_pipeline``
    and ``memory_pool``.
    """

    def __init__(self, width, height, num_framebuffers=1, channels=4, *,
                 rasterizer=None, vgpu_pipeline=None, memory_pool=None):
        """Create the pipeline stages and the framebuffer.

        Args:
            width, height: Render-target size, used for the depth/stencil
                buffer and the framebuffer.
            num_framebuffers: Number of framebuffer targets.
            channels: Colour channels per pixel.
            rasterizer: Optional object exposing
                ``rasterize(vertices, indices, viewport, width, height)``.
            vgpu_pipeline, memory_pool: Used to build a ``VGPURasterizer``
                when ``rasterizer`` is not given.

        Raises:
            ValueError: if ``rasterizer`` is omitted and either
                ``vgpu_pipeline`` or ``memory_pool`` is missing.
        """
        if rasterizer is None:
            if vgpu_pipeline is None or memory_pool is None:
                raise ValueError(
                    "GraphicsPipeline needs either a rasterizer or both "
                    "vgpu_pipeline and memory_pool"
                )
            rasterizer = VGPURasterizer(vgpu_pipeline, memory_pool)
        self.rasterizer = rasterizer
        self.clipper = Clipper()
        self.blender = Blender()
        self.depth_stencil = DepthStencil(width, height)
        self.framebuffer = Framebuffer(width, height, num_targets=num_framebuffers, channels=channels)
        # Channel count, used to size the default (unshaded) fragment colour.
        self.channels = channels
        self.textures = {}
        # NOTE(review): vertex and index buffers share this namespace, so
        # uploading both under one name makes the later upload replace the other.
        self.buffers = {}
        self.active_texture = None
        self.active_vertex_buffer = None
        self.active_index_buffer = None

    def upload_texture(self, name, img):
        """Create a ``Texture`` from ``img`` (read as ``(h, w, c)``) and store it under ``name``."""
        h, w, c = img.shape
        tex = Texture(w, h, c, img.dtype)
        tex.upload(img)
        self.textures[name] = tex

    def bind_texture(self, name):
        """Make the texture stored under ``name`` the active one (KeyError if unknown)."""
        self.active_texture = self.textures[name]

    def sample_texture(self, u, v):
        """Sample the active texture at ``(u, v)``, or return None if none is bound."""
        if self.active_texture is not None:
            return self.active_texture.sample(u, v)
        return None

    def upload_vertex_buffer(self, name, data):
        """Wrap ``data`` in a ``Buffer`` and store it under ``name``."""
        self.buffers[name] = Buffer(data)

    def bind_vertex_buffer(self, name):
        """Use the buffer stored under ``name`` as the vertex source for ``run``."""
        self.active_vertex_buffer = self.buffers[name]

    def upload_index_buffer(self, name, data):
        """Wrap ``data`` in a ``Buffer`` and store it under ``name``."""
        self.buffers[name] = Buffer(data)

    def bind_index_buffer(self, name):
        """Use the buffer stored under ``name`` as the index source for ``run``."""
        self.active_index_buffer = self.buffers[name]

    def bind_framebuffer(self, idx=0):
        """Select framebuffer target ``idx``."""
        self.framebuffer.bind(idx)

    def clear_framebuffer(self, color=0):
        """Clear the framebuffer to ``color``."""
        self.framebuffer.clear(color)

    def run(self, viewport, clip_rect, blend_mode='alpha', alpha=1.0, shader_program=None):
        """Draw the bound vertex/index buffers into the framebuffer.

        Stages: rasterize, then discard fragments outside the inclusive
        ``clip_rect`` or the render target. Next, depth-test (closer wins),
        shade (``shader_program.run_fragment`` or default white), and finally
        blend into the framebuffer.

        Returns:
            ``self.framebuffer.targets``.

        Raises:
            RuntimeError: if no vertex buffer or no index buffer is bound.
        """
        if self.active_vertex_buffer is None or self.active_index_buffer is None:
            raise RuntimeError("bind a vertex buffer and an index buffer before run()")
        vertices = self.active_vertex_buffer.get()
        indices = self.active_index_buffer.get()
        width, height = self.framebuffer.width, self.framebuffer.height
        # 1. Rasterization
        fragments = self.rasterizer.rasterize(vertices, indices, viewport, width, height)
        # 2. Clipping (fragments outside clip_rect)
        x0, y0, x1, y1 = clip_rect
        fragments = [f for f in fragments if x0 <= f[0] <= x1 and y0 <= f[1] <= y1]
        # 3. Depth/Stencil Test and Fragment Processing
        for frag in fragments:
            x, y, z, bary, tri_idx = frag
            if not (0 <= x < width and 0 <= y < height):
                continue
            if not self.depth_stencil.test(x, y, z):
                continue
            # 4. Fragment shading (with texture sampling if needed)
            if shader_program is not None:
                color = shader_program.run_fragment(bary, tri_idx, vertices, self.active_texture)
            else:
                # Default white, sized to the framebuffer's channel count
                # (previously hard-coded to 4 channels).
                color = np.ones(self.channels)
            # 5. Blending
            prev = self.framebuffer.read(x, y)
            out = self.blender.blend(color, prev, mode=blend_mode, alpha=alpha)
            self.framebuffer.write(x, y, out)
        return self.framebuffer.targets
