"""
Tensor Visualization Module
Provides various chart types for tensor data visualization
"""
from ...virtual_gpu_driver.src.graphics.graphics_pipeline import VGPURasterizer
import array
from typing import Tuple, List, Union, Optional

class TensorVisualizer:
    """Render simple charts (line, bar, heatmap) of flat tensor data.

    Geometry is emitted in pixel coordinates of a ``width`` x ``height``
    canvas with a ``margin``-pixel border and handed to the rasterizer's
    ``draw_lines`` / ``draw_quads`` / ``draw_text`` methods. Vertex buffers
    are flat ``array('f')`` of (x, y, z) triples with z always 0.0; index
    buffers are ``array('l')``.

    NOTE(review): data values grow from ``y = margin`` towards
    ``y = height - margin`` while titles are placed at ``y = margin // 2``.
    Whether the title ends up above or below the chart (and whether charts
    appear upright) depends on the rasterizer's y-axis direction, which is
    not visible here — confirm.
    """

    def __init__(self, rasterizer: "VGPURasterizer"):
        # Quoted annotation: not evaluated when the class body is executed.
        self.rasterizer = rasterizer
        self.width = 800
        self.height = 600
        self.margin = 50

    def plot_line(self, data: array.array,
                  color: Tuple[float, float, float, float] = (0.0, 0.4, 0.8, 1.0),
                  title: Optional[str] = None) -> None:
        """Draw a line chart of 1D tensor data.

        Samples are spaced evenly along x across the plot area; values are
        mapped linearly so ``min(data)`` lands on ``y = margin`` and
        ``max(data)`` on ``y = height - margin``.

        Args:
            data: 1D sequence of numeric samples.
            color: RGBA line color.
            title: Optional title drawn at ``(width // 2, margin // 2)``.

        Raises:
            ValueError: if ``data`` is empty.
        """
        count = len(data)
        if count == 0:
            raise ValueError("plot_line requires at least one data point")

        plot_w = self.width - 2 * self.margin
        plot_h = self.height - 2 * self.margin
        min_val = min(data)
        max_val = max(data)
        span = max_val - min_val

        # A constant series has no vertical extent to scale by; draw it as a
        # flat line through the middle of the plot instead of dividing by 0.
        if span:
            scale_y = plot_h / span
            base_y = self.margin
        else:
            scale_y = 0.0
            base_y = self.margin + plot_h / 2
        # A single sample has no horizontal spacing; pin it to the left edge.
        scale_x = plot_w / (count - 1) if count > 1 else 0.0

        vertices = array.array('f')
        for i, val in enumerate(data):
            vertices.extend([self.margin + i * scale_x,
                             base_y + (val - min_val) * scale_y,
                             0.0])

        # One (i, i + 1) index pair per segment between adjacent samples.
        indices = array.array('l')
        for i in range(count - 1):
            indices.extend([i, i + 1])

        self._draw_axes(min_val, max_val, count)
        self.rasterizer.draw_lines(vertices, indices, line_width=2.0, color=color)
        self._draw_title(title)

    def plot_heatmap(self, data: array.array,
                     shape: Tuple[int, int],
                     colormap: str = 'viridis',
                     vmin: Optional[float] = None,
                     vmax: Optional[float] = None) -> None:
        """Draw a heatmap of 2D tensor data stored row-major in ``data``.

        Every one of the ``rows * cols`` samples becomes its own quad with a
        single flat color, so the last row and column are drawn as well.

        Args:
            data: Flat row-major buffer; element ``(i, j)`` is
                ``data[i * cols + j]``. Must hold at least ``rows * cols``
                values; any extra trailing values are ignored.
            shape: ``(rows, cols)``; both must be positive.
            colormap: Colormap name understood by ``_get_colormap_color``.
            vmin: Optional value mapped to the bottom (0) of the colormap.
            vmax: Optional value mapped to the top (1) of the colormap.
                If both ``vmin`` and ``vmax`` are ``None`` (default), values
                are passed to the colormap unchanged. If only one is given,
                the other defaults to the data minimum / maximum.

        Raises:
            ValueError: for a non-positive shape, too few data values, or an
                unsupported colormap.
        """
        rows, cols = shape[0], shape[1]
        if rows <= 0 or cols <= 0:
            raise ValueError(f"heatmap shape must be positive, got {tuple(shape)}")
        cell_count = rows * cols
        if len(data) < cell_count:
            raise ValueError(
                f"heatmap shape {tuple(shape)} needs {cell_count} values, "
                f"got {len(data)}")

        values = data[:cell_count]
        if vmin is not None or vmax is not None:
            lo = min(values) if vmin is None else vmin
            hi = max(values) if vmax is None else vmax
            span = hi - lo
            # Degenerate range: map everything to the bottom of the colormap.
            values = [(v - lo) / span if span else 0.0 for v in values]

        cell_width = (self.width - 2 * self.margin) / cols
        cell_height = (self.height - 2 * self.margin) / rows

        vertices = array.array('f')
        colors = array.array('f')
        indices = array.array('l')

        for idx, value in enumerate(values):
            i, j = divmod(idx, cols)
            # Compute both edges from the grid so neighbouring cells share
            # bit-identical edge coordinates (no seams from accumulated error).
            x0 = self.margin + j * cell_width
            x1 = self.margin + (j + 1) * cell_width
            y0 = self.margin + i * cell_height
            y1 = self.margin + (i + 1) * cell_height

            base = 4 * idx
            vertices.extend([
                x0, y0, 0.0,
                x1, y0, 0.0,
                x0, y1, 0.0,
                x1, y1, 0.0,
            ])
            # Same color on all four corners -> flat-colored cell.
            colors.extend(self._get_colormap_color(value, colormap) * 4)
            # Corner order (x0,y0), (x1,y0), (x0,y1), (x1,y1) matches the
            # original grid indexing (v, v+1, v+cols, v+cols+1).
            # NOTE(review): 4 indices per quad here vs 6 per quad in plot_bar
            # for the same draw_quads call; confirm which layout
            # VGPURasterizer.draw_quads expects.
            indices.extend([base, base + 1, base + 2, base + 3])

        self.rasterizer.draw_quads(vertices, indices, colors)

    def plot_bar(self, data: array.array,
                 color: Tuple[float, float, float, float] = (0.3, 0.6, 0.9, 1.0),
                 title: Optional[str] = None) -> None:
        """Draw a bar chart of 1D tensor data.

        Bars rise from ``y = margin``. The largest-magnitude value spans the
        full plot height; negative values extend below the baseline.

        Args:
            data: 1D sequence of numeric bar heights.
            color: RGBA fill color for all bars.
            title: Optional title drawn at ``(width // 2, margin // 2)``.

        Raises:
            ValueError: if ``data`` is empty.
        """
        count = len(data)
        if count == 0:
            raise ValueError("plot_bar requires at least one data point")

        # Each bar takes 80% of its slot; the gap (a quarter of the bar
        # width, i.e. the remaining 20%) follows it.
        bar_width = (self.width - 2 * self.margin) / count * 0.8
        gap = bar_width * 0.25

        max_val = max(data)
        # Scale by magnitude: avoids dividing by zero for all-zero data and
        # keeps negative bars pointing downward. Equals max(data) whenever
        # all values are non-negative.
        peak = max(abs(v) for v in data)
        scale_y = (self.height - 2 * self.margin) / peak if peak else 0.0

        vertices = array.array('f')
        indices = array.array('l')

        for i, val in enumerate(data):
            left = self.margin + i * (bar_width + gap)
            right = left + bar_width
            bottom = self.margin
            top = bottom + val * scale_y

            base = 4 * i
            vertices.extend([
                left, bottom, 0.0,   # Bottom left
                right, bottom, 0.0,  # Bottom right
                left, top, 0.0,      # Top left
                right, top, 0.0,     # Top right
            ])
            # Six indices per bar: corner triples (0, 1, 2) and (1, 2, 3),
            # presumably two triangles covering the rectangle.
            indices.extend([base, base + 1, base + 2,
                            base + 1, base + 2, base + 3])

        self._draw_axes(0, max_val, count)
        self.rasterizer.draw_quads(vertices, indices, color=color)
        self._draw_title(title)

    def _draw_axes(self, min_val: float, max_val: float, num_points: int) -> None:
        """Draw the two axis lines: x = margin (Y axis) and y = margin (X axis).

        ``min_val``, ``max_val`` and ``num_points`` are accepted by the
        signature but not used by this implementation; no ticks or labels
        are drawn.
        """
        left = self.margin
        bottom = self.margin
        axis_vertices = array.array('f', [
            left, bottom, 0.0,                       # Origin
            left, self.height - self.margin, 0.0,    # Y-axis end
            left, bottom, 0.0,                       # Origin
            self.width - self.margin, bottom, 0.0,   # X-axis end
        ])
        # Two segments: vertices 0-1 (Y axis) and 2-3 (X axis).
        axis_indices = array.array('l', [0, 1, 2, 3])

        self.rasterizer.draw_lines(
            axis_vertices,
            axis_indices,
            line_width=1.0,
            color=(0.3, 0.3, 0.3, 1.0)
        )

    def _draw_title(self, title: Optional[str]) -> None:
        """Draw ``title`` at ``(width // 2, margin // 2)`` when it is non-empty."""
        if title:
            self._draw_text(title, self.width // 2, self.margin // 2)

    def _draw_text(self, text: str, x: float, y: float,
                   color: Tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0)) -> None:
        """Forward ``text`` at position ``(x, y)`` to the rasterizer's ``draw_text``."""
        self.rasterizer.draw_text(text, (x, y), color)

    def _get_colormap_color(self, value: float, colormap: str) -> List[float]:
        """Map ``value`` to an ``[r, g, b, a]`` list using the named colormap.

        Only ``'viridis'`` is supported: a simple viridis-like gradient
        defined by the linear formulas below. ``value`` is clamped to
        ``[0, 1]`` first, so every returned channel lies within ``[0, 1]``.

        Raises:
            ValueError: for any other colormap name.
        """
        if colormap != 'viridis':
            raise ValueError(f"Unsupported colormap: {colormap}")

        # The channel formulas only stay in [0, 1] for inputs in [0, 1].
        t = min(max(value, 0.0), 1.0)
        r = 0.3 + 0.7 * t
        g = 0.4 + 0.6 * (1.0 - abs(t - 0.5))
        b = 0.7 + 0.3 * (1.0 - t)
        return [r, g, b, 1.0]
