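"""
Parallel array distributor with per-core work queues and chunk management.
Splits a tensor into chunks via ChunkManager, drains one asyncio queue per simulated
core concurrently on a single event loop (coroutines, not threads), and reassembles
the processed chunks. Per-core "electron-speed" frequency/drift figures are nominal
constants stored on the distributor.
"""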
import numpy as np
from typing import List, Tuple, Optional, Dict, Any
import asyncio
import time
import logging
from tensor_chunk_manager import ChunkManager
from local_storage_manager import LocalStorageManager

# Root logging is configured once, at import time: INFO-and-above records are
# written both to 'parallel_processing.log' (relative to the current working
# directory) and to a stream handler (stderr by default).
logging.basicConfig(
    handlers=(
        logging.FileHandler('parallel_processing.log'),
        logging.StreamHandler(),
    ),
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class ParallelArrayDistributor:
    """Split a tensor into chunks, process them per core via asyncio queues, reassemble.

    Chunking and the per-chunk computation are delegated to ``ChunkManager``.
    This class owns one ``asyncio.Queue`` per simulated core (``num_sms``),
    drains those queues concurrently with ``asyncio.gather``, stitches the
    processed chunks back together along axis 0 and logs throughput figures.

    NOTE(review): everything here runs as coroutines on one event loop; whether
    ``ChunkManager.process_chunk`` offloads work elsewhere is not visible here.
    """

    def __init__(self, num_sms=8, cores_per_sm=30, persistent=False):
        """Create the distributor, its ChunkManager and one work queue per core.

        Args:
            num_sms: Number of simulated cores; one queue is created per core.
            cores_per_sm: Forwarded to ChunkManager as ``chunks_per_core``.
            persistent: Forwarded to ChunkManager; also enables the
                persistent-assignment helpers of this class.
        """
        self.num_sms = num_sms
        self.cores_per_sm = cores_per_sm
        self.persistent = persistent

        # Nominal per-core constants. They are stored on the instance only;
        # no method of this class reads them back.
        self.base_speed = 8.92e85  # Base speed in Hz
        self.base_drift = 1.96e7   # Base electron drift velocity
        self.process_node = 14e-9  # 14nm process node

        # Core i gets (i + 1) times the base frequency and drift velocity.
        self.core_speeds = {}      # Maps core_id to operating frequency
        self.electron_drifts = {}  # Maps core_id to electron drift velocity
        for i in range(num_sms):
            self.core_speeds[i] = self.base_speed + (i * self.base_speed)
            self.electron_drifts[i] = self.base_drift * (i + 1)

        # Chunk splitting / per-chunk processing backend.
        self.chunk_manager = ChunkManager(num_cores=num_sms, chunks_per_core=cores_per_sm, persistent=persistent)

        # One FIFO work queue per core; entries are (chunk_id, operation).
        self.core_queues = [asyncio.Queue() for _ in range(num_sms)]
        self.total_ops = 0
        # Wall-clock seconds each core spent draining its queue in the last run.
        self.core_timings = {i: 0.0 for i in range(num_sms)}

        # Performance tracking (reset at the start of every process_tensor call).
        self.start_time = time.time()
        self.processed_chunks = 0

        # Store persistent assignments if enabled
        self.persistent_assignments = {} if persistent else None
        self.initialized = False

    async def process_tensor(self, tensor: np.ndarray, operation: str = "matmul") -> np.ndarray:
        """Chunk ``tensor``, process every chunk on its assigned core, reassemble.

        Args:
            tensor: Input array, split by ``ChunkManager.chunk_tensor``.
            operation: Operation name forwarded to ``ChunkManager.process_chunk``.

        Returns:
            A float32 array of ``tensor.shape`` built by placing the processed
            chunks back-to-back along axis 0, in the order in which
            ``chunk_tensor`` returned their ids.
        """
        # Per-run counters: reset together so the summary describes this call only.
        self.start_time = time.time()
        self.processed_chunks = 0
        self.total_ops = 0
        logging.info(f"Starting tensor processing: shape={tensor.shape}, size={tensor.nbytes/1e9:.2f}GB")

        # Initialize chunk manager if not already done
        if self.chunk_manager is None:
            self.chunk_manager = ChunkManager(
                num_cores=self.num_sms,
                chunks_per_core=self.cores_per_sm,
                persistent=self.persistent
            )

        # Split tensor into chunks
        chunk_ids = await self.chunk_manager.chunk_tensor(tensor)
        logging.info(f"Split tensor into {len(chunk_ids)} chunks")

        # Drop anything left over in the queues from a previous (possibly failed) run.
        for q in self.core_queues:
            while not q.empty():
                try:
                    q.get_nowait()
                except asyncio.QueueEmpty:
                    break

        # Enqueue each chunk on its assigned core, remembering per-core FIFO order
        # so each core's result list can be mapped back to chunk ids afterwards.
        ids_per_core = [[] for _ in range(self.num_sms)]
        for chunk_id in chunk_ids:
            core_id = self.chunk_manager.chunk_metadata[chunk_id].core_id
            ids_per_core[core_id].append(chunk_id)
            await self.core_queues[core_id].put((chunk_id, operation))
        logging.info("Distributed all chunks to core queues")

        # Drain all core queues concurrently.
        results = await asyncio.gather(
            *(self._process_core_with_queue(core_id) for core_id in range(self.num_sms))
        )
        logging.info("All cores completed processing")

        # Restore the original chunk order. Concatenating core by core would
        # interleave rows whenever chunks are not assigned to cores in
        # contiguous runs (e.g. round-robin).
        result_by_id = {}
        for core_chunk_ids, core_results in zip(ids_per_core, results):
            result_by_id.update(zip(core_chunk_ids, core_results))
        ordered_results = [result_by_id[chunk_id] for chunk_id in chunk_ids]

        final_result = self._combine_results([ordered_results], tensor.shape)
        self._log_performance()

        return final_result

    async def _process_core_with_queue(self, core_id: int) -> List[np.ndarray]:
        """Drain core ``core_id``'s queue, processing one chunk at a time.

        Updates ``processed_chunks``, ``total_ops`` and ``core_timings[core_id]``.

        Returns:
            The processed chunks, in the order they were dequeued (FIFO).
        """
        results = []
        core_start_time = time.time()
        chunks_processed = 0
        core_queue = self.core_queues[core_id]

        # All puts happen before this coroutine starts and only this coroutine
        # consumes this queue, so empty() is a safe termination condition.
        while not core_queue.empty():
            chunk_id, operation = await core_queue.get()

            chunk_meta = self.chunk_manager.chunk_metadata[chunk_id]
            result = await self.chunk_manager.process_chunk(chunk_id, operation)
            results.append(result)

            # Track performance
            self.processed_chunks += 1
            self.total_ops += chunk_meta.quantum_ops
            chunks_processed += 1

            logging.info(f"Core {core_id} completed chunk {chunk_id}, "
                       f"time: {time.time() - chunk_meta.processing_start:.9f}s, "
                       f"chunks processed: {chunks_processed}")

            core_queue.task_done()

        elapsed = time.time() - core_start_time
        self.core_timings[core_id] = elapsed
        logging.info(f"Core {core_id} completed all chunks in {elapsed:.9f}s")
        return results

    def get_processing_status(self, array_id: int) -> Dict[str, Any]:
        """Report how many stored chunks of ``array_id`` are marked processed.

        Requires a storage object exposing ``query_tensors(metadata_filter=...)``
        on ``self.storage``. Nothing in this class assigns it, so it must be
        attached by the caller. Never raises: failures are returned as a dict
        with an ``'error'`` key.
        """
        storage = getattr(self, 'storage', None)
        if storage is None:
            return {
                'array_id': array_id,
                'error': 'no storage backend attached (set .storage before querying status)',
                'timestamp': time.time()
            }

        try:
            chunks = storage.query_tensors(metadata_filter={'array_id': array_id}) or []
            total_chunks = len(chunks)
            # Each record is expected to be a dict with an optional 'metadata' dict.
            processed_chunks = sum(
                1 for c in chunks if c.get('metadata', {}).get('processed', False)
            )
        except Exception as e:  # best-effort status query: report, don't raise
            logging.warning(f"Status query for array {array_id} failed: {e}")
            return {
                'array_id': array_id,
                'error': str(e),
                'timestamp': time.time()
            }

        return {
            'array_id': array_id,
            'total_chunks': total_chunks,
            'processed_chunks': processed_chunks,
            'completion_percentage': (processed_chunks / total_chunks * 100) if total_chunks > 0 else 0,
            'timestamp': time.time()
        }

    def _combine_results(self, results: List[List[np.ndarray]], original_shape: Tuple[int, ...]) -> np.ndarray:
        """Place processed chunks back-to-back along axis 0 in a float32 array.

        Chunks are consumed in iteration order (outer list, then inner list).
        Rows past ``original_shape[0]`` are dropped. If the chunks cover fewer
        rows, the remainder stays zero and a warning is logged.

        Raises:
            ValueError: if a chunk's rank differs from ``len(original_shape)``.
        """
        combine_start = time.time()
        final_result = np.zeros(original_shape, dtype=np.float32)

        chunks_combined = 0
        current_idx = 0
        total_chunks = sum(len(core_results) for core_results in results)

        for core_results in results:
            for chunk in core_results:
                if len(chunk.shape) != len(original_shape):
                    raise ValueError(f"Chunk shape {chunk.shape} incompatible with original shape {original_shape}")

                # Clip to the output's first dimension; trailing dims must broadcast.
                end_idx = min(current_idx + chunk.shape[0], original_shape[0])
                final_result[current_idx:end_idx] = chunk[:end_idx - current_idx]
                current_idx = end_idx
                chunks_combined += 1

                if chunks_combined % 100 == 0:
                    logging.info(f"Combined {chunks_combined}/{total_chunks} chunks")

        if original_shape and current_idx < original_shape[0]:
            logging.warning(
                f"Combined chunks cover only {current_idx}/{original_shape[0]} rows; "
                f"remaining rows are left as zeros"
            )

        logging.info(f"Result combination completed in {time.time() - combine_start:.6f}s")
        return final_result

    def _create_persistent_assignments(self):
        """Build ``self.persistent_assignments`` (identity core mapping + chunk sizes).

        No-op unless the distributor was created with ``persistent=True``.
        """
        if not self.persistent:
            return

        logging.info("Creating persistent core assignments...")
        # NOTE(review): ``_calculate_optimal_chunk_sizes`` is not defined in this
        # class as shown; confirm it is provided (e.g. by a subclass) before
        # relying on persistent mode.
        self.persistent_assignments = {
            'core_mapping': dict(enumerate(range(self.num_sms))),
            'chunk_sizes': self._calculate_optimal_chunk_sizes(),
            'timestamp': time.time()
        }
        logging.info("Persistent assignments created and stored")

    def _assign_to_persistent_cores(self, chunks: List[np.ndarray]) -> List[int]:
        """Register ``chunks`` with the chunk manager using the persistent core mapping.

        Chunk ``i`` goes to ``core_mapping[i % len(core_mapping)]``. Its size hint is
        taken from ``chunk_sizes`` for that core (``None`` if absent). Requires
        ``_create_persistent_assignments`` to have populated the assignments.

        Returns:
            The ids returned by ``chunk_manager.add_chunk``, in input order.
        """
        chunk_ids = []
        core_mapping = self.persistent_assignments['core_mapping']
        chunk_sizes = self.persistent_assignments['chunk_sizes']

        for i, chunk in enumerate(chunks):
            core_id = core_mapping[i % len(core_mapping)]
            chunk_id = self.chunk_manager.add_chunk(
                chunk,
                core_id=core_id,
                chunk_size=chunk_sizes.get(core_id, None)
            )
            chunk_ids.append(chunk_id)

        return chunk_ids

    def _log_performance(self):
        """Log a throughput summary for the run since ``self.start_time``.

        Rates are reported as NaN instead of raising when the elapsed time or
        the processed-chunk count is zero.
        """
        duration = time.time() - self.start_time
        ops_per_second = self.total_ops / duration if duration > 0 else float("nan")
        avg_per_chunk = duration / self.processed_chunks if self.processed_chunks else float("nan")

        logging.info(f"\nPerformance Summary:")
        logging.info(f"Total duration: {duration:.9f} seconds")
        logging.info(f"Total operations: {self.total_ops:,}")
        logging.info(f"Operations per second: {ops_per_second:.2e}")
        logging.info(f"Processed chunks: {self.processed_chunks}")
        logging.info(f"Average time per chunk: {avg_per_chunk:.12f} seconds")

        # Aggregate per-core counts in a single pass over the chunk metadata.
        core_chunk_counts = {core_id: 0 for core_id in range(self.num_sms)}
        core_op_totals = {core_id: 0 for core_id in range(self.num_sms)}
        for meta in self.chunk_manager.chunk_metadata.values():
            if meta.core_id in core_chunk_counts:
                core_chunk_counts[meta.core_id] += 1
                core_op_totals[meta.core_id] += meta.quantum_ops

        for core_id in range(self.num_sms):
            core_ops = core_op_totals[core_id]
            core_rate = core_ops / duration if duration > 0 else float("nan")
            logging.info(f"\nCore {core_id} stats:")
            logging.info(f"  Chunks processed: {core_chunk_counts[core_id]}")
            logging.info(f"  Operations: {core_ops:,}")
            logging.info(f"  Ops/second: {core_rate:.2e}")
            logging.info(f"  Queue drain time: {self.core_timings.get(core_id, 0.0):.9f}s")
