"""
Tensor Processing Pipeline for VGPU
Handles user input conversion, model operations, and visualization
"""
import array
import json
from typing import Any, Dict, List, Tuple, Union

from ..core.pipeline import VGPUPipeline
from ...virtual_gpu_driver.src.driver_api import VirtualGPUDriver
from ...virtual_gpu_driver.src.graphics.graphics_pipeline import VGPURasterizer

class TensorProcessingPipeline:
    """Drive a virtual-GPU workflow: upload input, chain matmuls, plot results.

    Holds one ``VirtualGPUDriver`` (memory manager + tensor ops), a
    ``VGPUPipeline`` wrapped by a ``VGPURasterizer`` for drawing, and a single
    stream on which every tensor operation is enqueued.
    """

    def __init__(self):
        self.driver = VirtualGPUDriver()
        self.vgpu_pipeline = VGPUPipeline()
        # The rasterizer is given the same memory manager the tensor ops use.
        self.graphics = VGPURasterizer(self.vgpu_pipeline, self.driver.memory_manager)
        self.main_stream = self.driver.create_stream()

    def process_user_input(
        self, input_data: Union[str, List[float], Tuple[float, ...]]
    ) -> int:
        """Upload user input as a ``(1, N)`` tensor and return its address.

        Args:
            input_data: Either a string, stored as int32 Unicode code points
                (one element per character), or a list/tuple of numbers,
                stored as float32.

        Returns:
            Address returned by the memory manager for the new tensor.

        Raises:
            ValueError: If ``input_data`` is neither a ``str`` nor a list/tuple.
        """
        if isinstance(input_data, str):
            # 'i' (C int) is 4 bytes in practice, matching the int32 dtype;
            # 'l' (C long) is 8 bytes on LP64 platforms and would hand the
            # allocator a buffer twice the size the dtype implies.
            input_array = array.array('i', [ord(c) for c in input_data])
            dtype = 'int32'
        elif isinstance(input_data, (list, tuple)):
            input_array = array.array('f', input_data)
            dtype = 'float32'
        else:
            raise ValueError("Unsupported input type")

        shape = (1, len(input_array))
        memory = self.driver.memory_manager
        input_addr = memory.allocate_tensor(shape, dtype)
        memory.write_tensor(input_addr, input_array)
        return input_addr

    def apply_model_operations(self, input_addr: int, weights: List[str]) -> int:
        """Multiply the input tensor through a chain of weight matrices.

        For each weight file, the weights are uploaded. Then a ``'matmul'`` of
        ``[current, weight]`` is enqueued on ``self.main_stream``, and its
        output becomes ``current`` for the next file.

        Args:
            input_addr: Address of a 2-D tensor already in device memory.
            weights: Paths to weight files, applied in order (layout described
                in ``_load_weight_file``).

        Returns:
            Address of the final output tensor, or ``input_addr`` itself when
            ``weights`` is empty.

        Raises:
            ValueError: If a weight file's header is not valid JSON, its
                payload does not match its declared shape, or the shape is
                incompatible with the running tensor.
            KeyError: If a weight file's header has no ``"shape"`` entry.
        """
        memory = self.driver.memory_manager
        current_addr = input_addr

        for weight_file in weights:
            weight_data, weight_shape = self._load_weight_file(weight_file)

            # Check compatibility before allocating anything, so an
            # incompatible layer fails without allocating weight memory.
            output_shape = self._calculate_output_shape(
                memory.get_tensor_shape(current_addr),
                weight_shape,
            )

            weight_addr = memory.allocate_tensor(weight_shape, 'float32')
            memory.write_tensor(weight_addr, weight_data)
            output_addr = memory.allocate_tensor(output_shape, 'float32')

            self.driver.execute_tensor_op(
                'matmul',
                [current_addr, weight_addr],
                output_addr,
                self.main_stream,
            )
            current_addr = output_addr

        return current_addr

    @staticmethod
    def _load_weight_file(path: str) -> Tuple[array.array, Tuple[int, ...]]:
        """Read one weight file and return ``(float32 data, shape)``.

        Assumed layout (NOTE(review): confirm against whatever writes these
        files): one UTF-8 JSON line with at least a ``"shape"`` key, followed
        by the raw float32 payload in native byte order.

        Raises:
            ValueError: Bad JSON header, a payload that is not a whole number
                of floats, or an element count that disagrees with the shape.
            KeyError: If the header lacks ``"shape"``.
        """
        with open(path, 'rb') as f:
            # The header line must be read first: a bare f.read() would
            # consume the whole file and leave nothing for the metadata.
            meta = json.loads(f.readline().decode('utf-8'))
            payload = f.read()

        shape = tuple(int(dim) for dim in meta['shape'])
        weight_data = array.array('f')
        # frombytes itself raises ValueError if len(payload) % itemsize != 0.
        weight_data.frombytes(payload)

        expected = 1
        for dim in shape:
            expected *= dim
        if len(weight_data) != expected:
            raise ValueError(
                f"{path}: shape {shape} needs {expected} floats, "
                f"file holds {len(weight_data)}"
            )
        return weight_data, shape

    def generate_visualization(self, output_addr: int,
                               chart_type: str = 'line') -> None:
        """Draw the tensor at ``output_addr`` as a line chart or a heatmap.

        Args:
            output_addr: Address of the tensor to visualize.
            chart_type: ``'line'`` connects consecutive values with segments
                (x = element index, y = value, z = 0). ``'heatmap'`` requires
                a 2-D tensor and renders it with the ``'viridis'`` colormap.

        Raises:
            ValueError: For an unknown ``chart_type``, or a heatmap request on
                a tensor whose shape is not 2-D.
        """
        if chart_type not in ('line', 'heatmap'):
            # Fail loudly instead of reading device memory and drawing nothing.
            raise ValueError(f"Unsupported chart type: {chart_type!r}")

        memory = self.driver.memory_manager
        output_data = memory.read_tensor(output_addr)
        output_shape = memory.get_tensor_shape(output_addr)

        if chart_type == 'line':
            # NOTE(review): assumes read_tensor yields a flat sequence of numbers.
            values = list(output_data)
            vertices = []
            indices = []
            # Each segment gets its own pair of (x, y, z) vertices, so
            # segment i is drawn between vertices 2*i and 2*i + 1.
            for i, (y_start, y_end) in enumerate(zip(values, values[1:])):
                vertices.extend([i, y_start, 0, i + 1, y_end, 0])
                indices.extend([2 * i, 2 * i + 1])

            self.graphics.draw_lines(
                array.array('f', vertices),
                array.array('l', indices),
                line_width=2.0,
                color=(0.0, 0.4, 0.8, 1.0),  # Blue lines
            )
        else:
            if len(output_shape) != 2:
                raise ValueError("Heatmap requires 2D tensor data")
            self.graphics.draw_heatmap(
                output_data,
                output_shape,
                colormap='viridis',
            )

    def _calculate_output_shape(self, input_shape: Tuple[int, ...],
                                weight_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return the shape of ``input @ weight`` for two 2-D operands.

        Raises:
            ValueError: If either shape is not 2-D, or the inner dimensions
                (``input_shape[1]`` and ``weight_shape[0]``) differ.
        """
        if len(input_shape) != 2 or len(weight_shape) != 2:
            raise ValueError("Only 2D tensor operations supported")

        if input_shape[1] != weight_shape[0]:
            raise ValueError("Incompatible tensor shapes for multiplication")

        return (input_shape[0], weight_shape[1])
