"""
Unified Pipeline Controller
Coordinates all pipeline components and ensures proper data flow
"""
from typing import Any, Dict, Union, Optional
import array
import json
from pathlib import Path

from .tensor_processing import TensorProcessingPipeline
from ..core.pipeline import Pipeline as CorePipeline
from ...virtual_gpu_driver.src.driver_api import VirtualGPUDriver
from ...virtual_gpu_driver.src.memory.memory_manager import MemoryManager
from ..pipeline.tensor_visualizer import TensorVisualizer

class UnifiedPipelineController:
    """
    Coordinates all pipeline components:
    1. User Input Processing
    2. Memory Management
    3. Model Execution
    4. Output Processing
    5. Visualization

    Typical call order is ``process_input`` -> ``execute_model`` ->
    ``process_output``. The sampling configuration used by the most recent
    ``execute_model`` call is kept in ``last_config`` so that
    ``process_output`` can decode text output with the same settings.
    """

    # File-name suffixes used by _detect_input_type to classify string input.
    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.bmp')
    _AUDIO_EXTENSIONS = ('.wav', '.mp3')
    _VIDEO_EXTENSIONS = ('.mp4', '.avi')

    def __init__(self):
        """Create the driver, memory manager, pipelines and memory pools."""
        # Initialize core components
        self.driver = VirtualGPUDriver()
        # NOTE(review): this is a separate MemoryManager from
        # self.driver.memory_manager, on which the pools below are created.
        # Confirm the two are meant to be distinct objects.
        self.memory_manager = MemoryManager()
        self.core_pipeline = CorePipeline()

        # Initialize processing pipelines
        self.tensor_pipeline = TensorProcessingPipeline()
        self.visualizer = TensorVisualizer(self.tensor_pipeline.graphics)

        # Configuration of the most recent execute_model() call; empty until
        # a model has been executed.
        self.last_config: Dict[str, Any] = {}

        # Set up memory pools
        self._initialize_memory_pools()

    def _initialize_memory_pools(self):
        """Initialize specialized memory pools for different data types"""
        gib = 1024 * 1024 * 1024
        pool_manager = self.driver.memory_manager

        # Input pool for user data
        pool_manager.create_memory_pool(
            "input_pool",
            size_bytes=1 * gib,  # 1GB
            stream_buffering=True
        )

        # Model pool for weights and activations
        pool_manager.create_memory_pool(
            "model_pool",
            size_bytes=4 * gib,  # 4GB
            stream_buffering=False
        )

        # Output pool for results
        pool_manager.create_memory_pool(
            "output_pool",
            size_bytes=1 * gib,  # 1GB
            stream_buffering=True
        )

    def process_input(self, input_data: Any, input_type: Optional[str] = None) -> int:
        """
        Process user input through the pipeline

        Args:
            input_data: Raw user input (text, file path, sequence, image...).
            input_type: One of the type names produced by
                ``_detect_input_type``; auto-detected when ``None``.

        Returns: Memory address of processed tensor
        """
        from ..core.probability import HeliumTokenizer

        # Auto-detect input type if not specified
        if input_type is None:
            input_type = self._detect_input_type(input_data)

        # All user input currently goes to the shared input pool
        pool_id = "input_pool"

        # Initialize tokenizer for text input; it is reused later by
        # process_output() to decode generated tokens.
        if input_type == 'text':
            self.tokenizer = HeliumTokenizer()
            # vocab_path is optional and is not set by this class; it is only
            # honoured when some caller has assigned it on the instance.
            if hasattr(self, 'vocab_path'):
                self.tokenizer.load_vocabulary(self.vocab_path)

        # Process through tensor pipeline with proper memory allocation
        tensor_addr = self.tensor_pipeline.process_user_input({
            'data': input_data,
            'type': input_type,
            'pool': pool_id
        })

        return tensor_addr

    def execute_model(self, input_addr: int,
                       temperature: float = 1.0,
                       top_k: Optional[int] = None,
                       top_p: Optional[float] = None) -> int:
        """
        Execute model operations on input tensor

        Args:
            input_addr: Address returned by ``process_input``.
            temperature: Sampling temperature recorded in the model config.
            top_k: Optional top-k sampling limit.
            top_p: Optional nucleus (top-p) sampling threshold.

        Returns: Memory address of output tensor

        Side effects: sets ``self.prob_calc`` and, on success,
        ``self.last_config`` for use by ``process_output``.
        """
        from ..core.probability import ProbabilityCalculator

        # Create execution stream
        stream = self.driver.create_stream()

        # Look up the input tensor's metadata. The result is not used here;
        # the call is retained because it may validate the address
        # (NOTE(review): confirm, otherwise it can be dropped).
        self.memory_manager.get_tensor_info(input_addr)

        # Initialize probability calculator
        self.prob_calc = ProbabilityCalculator()

        # Prepare model configuration
        model_config = {
            'batch_size': 1,
            'stream_id': stream,
            'use_cache': True,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p
        }

        # Execute through core pipeline
        output_addr = self.core_pipeline.run_inference(
            input_addr,
            model_config
        )

        # Remember the settings of this run so process_output() can decode
        # with the same temperature / top-k / top-p. Stored as a copy so
        # later mutation of model_config elsewhere does not leak in.
        self.last_config = dict(model_config)

        # Synchronize stream
        self.driver.stream_synchronize(stream)

        return output_addr

    def process_output(self, output_addr: int,
                      output_type: str,
                      visualization_type: Optional[str] = None) -> Any:
        """
        Process and visualize model output

        Args:
            output_addr: Address returned by ``execute_model``.
            output_type: Requested output format; ``'text'`` triggers token
                selection and decoding when a tokenizer is available.
            visualization_type: When truthy, a visualization is generated and
                returned alongside the processed data.

        Returns: Processed output in requested format. For text with a
        tokenizer this is the decoded string; with ``visualization_type``
        it is a dict with ``'data'`` and ``'visualization'`` keys; otherwise
        the tensor pipeline's post-processed output.
        """
        # Read output tensor
        output_data = self.memory_manager.read_tensor(output_addr)
        output_meta = self.memory_manager.get_tensor_info(output_addr)

        # Process logits for text output
        if output_type == 'text':
            # Fall back to an empty config (temperature 1.0, greedy argmax)
            # when no execute_model() call has recorded one.
            config = getattr(self, 'last_config', None) or {}
            temperature = config.get('temperature', 1.0)
            top_k = config.get('top_k')
            top_p = config.get('top_p')

            # Create the calculator lazily if execute_model() has not done so.
            prob_calc = getattr(self, 'prob_calc', None)
            if prob_calc is None:
                from ..core.probability import ProbabilityCalculator
                prob_calc = self.prob_calc = ProbabilityCalculator()

            # Convert logits to probabilities
            probs = prob_calc.compute_probabilities(
                output_data,
                temperature=temperature
            )

            # Sample or take argmax
            if top_k or top_p:
                # NOTE(review): temperature is passed both here and to
                # compute_probabilities above — confirm it is not applied twice.
                token_id = prob_calc.sample_from_probs(
                    probs,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p
                )
            else:
                token_id = prob_calc.argmax(probs)

            # Decode token to text. Without a tokenizer (no text input was
            # processed) fall through to generic post-processing below.
            tokenizer = getattr(self, 'tokenizer', None)
            if tokenizer is not None:
                return tokenizer.decode([token_id])

        # Process output based on type
        processed_output = self.tensor_pipeline.postprocess_output({
            'data': output_data,
            'meta': output_meta,
            'type': output_type
        })

        # Generate visualization if requested
        if visualization_type:
            viz_output = self.visualizer.generate_visualization(
                processed_output,
                visualization_type
            )
            return {
                'data': processed_output,
                'visualization': viz_output
            }

        return processed_output

    def _detect_input_type(self, input_data: Any) -> str:
        """
        Detect input data type

        Strings are classified by file-name suffix (case-insensitive) as
        'image', 'audio' or 'video', and as 'text' otherwise. Lists, tuples
        and ``array.array`` objects are 'tensor'; any other object with a
        ``mode`` attribute is treated as an image; everything else is
        'unknown'.
        """
        if isinstance(input_data, str):
            lowered = input_data.lower()
            # str.endswith accepts a tuple of suffixes
            if lowered.endswith(self._IMAGE_EXTENSIONS):
                return 'image'
            if lowered.endswith(self._AUDIO_EXTENSIONS):
                return 'audio'
            if lowered.endswith(self._VIDEO_EXTENSIONS):
                return 'video'
            return 'text'
        if isinstance(input_data, (list, tuple, array.array)):
            return 'tensor'
        if hasattr(input_data, 'mode'):  # PIL Image
            return 'image'
        return 'unknown'

    def get_pipeline_status(self) -> Dict[str, Any]:
        """
        Get current status of all pipeline components

        Returns a dict with per-pool status, per-stream status, the number
        of entries in the memory manager's tensor cache and the visualizer's
        active visualization count.
        """
        return {
            'memory_pools': {
                name: pool.get_status()
                for name, pool in self.driver.memory_manager.memory_pools.items()
            },
            # NOTE(review): assumes driver.stream_manager is dict-like
            # (exposes .items()) — confirm against the driver API.
            'streams': {
                stream_id: stream.get_status()
                for stream_id, stream in self.driver.stream_manager.items()
            },
            'tensor_cache': len(self.memory_manager.tensor_cache),
            'active_visualizations': self.visualizer.active_count
        }
