"""
Integrates TensorCore_v2 with ParallelArrayDistributor for optimal performance.
"""

import asyncio
import logging
import time
from typing import Dict, Any

import numpy as np

from tensor_core_v2 import TensorCore
from parallel_array_distributor import ParallelArrayDistributor

class TensorCoreDistributor:
    """Manages integration between TensorCore_v2 and ParallelArrayDistributor.

    Holds one ``TensorCore`` (read only for metrics) and one
    ``ParallelArrayDistributor`` (which performs all tensor processing), plus
    a bounded in-memory cache of results produced by ``process_tensor``.
    """

    # When the cache grows past MAX_CACHE_ENTRIES it is trimmed back down to
    # the CACHE_KEEP_ENTRIES most recently inserted results.
    MAX_CACHE_ENTRIES = 1000
    CACHE_KEEP_ENTRIES = 500

    def __init__(self, num_sms: int = 16, cores_per_sm: int = 4000):
        """Create the tensor core, the distributor and an empty result cache.

        Args:
            num_sms: Forwarded as the first argument to
                ``ParallelArrayDistributor``.
            cores_per_sm: Forwarded as the second argument to
                ``ParallelArrayDistributor``.
        """
        self.tensor_core = TensorCore(core_id=0)  # Initialize with core_id 0
        self.distributor = ParallelArrayDistributor(num_sms, cores_per_sm)
        self.tensor_cache = {}  # Cache for frequently used tensors
        self.cache_hits = 0
        self.cache_misses = 0

    async def process_tensor(self, tensor: np.ndarray, operation: str = "matmul") -> np.ndarray:
        """Process a tensor through the distributor, caching the result.

        Args:
            tensor: Input array; its shape, dtype and raw bytes form the
                cache key together with ``operation``.
            operation: Operation name forwarded to the distributor.

        Returns:
            The distributor's result. On a cache hit the stored object itself
            is returned (not a copy), so callers should not mutate it.
        """
        # The raw bytes alone are ambiguous: the same buffer can be viewed
        # with a different shape or dtype, which would hand back a result
        # computed for a different tensor. Include both in the key.
        cache_key = (tensor.shape, tensor.dtype, tensor.tobytes(), operation)

        # Check cache
        if cache_key in self.tensor_cache:
            self.cache_hits += 1
            return self.tensor_cache[cache_key]

        # Process tensor
        self.cache_misses += 1
        result = await self.distributor.process_tensor(tensor, operation)

        # Cache result
        self.tensor_cache[cache_key] = result

        # Manage cache size: dicts preserve insertion order, so the leading
        # keys are the oldest inserted entries (hits do not refresh an entry).
        if len(self.tensor_cache) > self.MAX_CACHE_ENTRIES:
            excess = len(self.tensor_cache) - self.CACHE_KEEP_ENTRIES
            for stale_key in list(self.tensor_cache)[:excess]:
                del self.tensor_cache[stale_key]

        return result

    async def process_model_weights(self, weights: Dict[str, np.ndarray], operation: str = "matmul") -> Dict[str, np.ndarray]:
        """Process every weight tensor of a model through the distributor.

        Tensors are awaited one after another in the iteration order of
        ``weights``; this method bypasses the result cache.

        Args:
            weights: Mapping from weight name to tensor.
            operation: Operation name forwarded to the distributor.

        Returns:
            Mapping from weight name to the distributor's result for it.
        """
        processed_weights = {}
        total_size = sum(w.nbytes for w in weights.values())

        logging.info(f"Processing model weights: {len(weights)} tensors, "
                    f"total size: {total_size/1e9:.2f}GB")

        # Each tensor is handed to the distributor in turn; any parallelism
        # happens inside the distributor call, not across tensors.
        for name, tensor in weights.items():
            processed_weights[name] = await self.distributor.process_tensor(tensor, operation)

            # Log progress
            logging.info(f"Processed weight tensor: {name}, "
                        f"shape: {tensor.shape}, "
                        f"size: {tensor.nbytes/1e9:.2f}GB")

        return processed_weights

    async def run_inference(self, input_tensor: np.ndarray) -> np.ndarray:
        """Run model inference by sending the input to the distributor.

        Args:
            input_tensor: Input array forwarded unchanged.

        Returns:
            The distributor's result for the ``"inference"`` operation.
        """
        # Delegates to the distributor (not the tensor core) and skips the cache.
        return await self.distributor.process_tensor(input_tensor, "inference")

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Get combined performance metrics from both systems.

        Returns:
            A dict with two sub-dicts: ``'distributor_stats'`` (processed
            chunks, total ops, and seconds elapsed since the distributor's
            ``start_time``) and ``'tensor_core_stats'`` (electron cycles and
            quantum ops read from the tensor core).
        """
        # Elapsed wall-clock time; assumes start_time was taken with
        # time.time() — TODO confirm against ParallelArrayDistributor.
        elapsed = time.time() - self.distributor.start_time

        metrics = {
            'distributor_stats': {
                'processed_chunks': self.distributor.processed_chunks,
                'total_ops': self.distributor.total_ops,
                'processing_time': elapsed
            },
            'tensor_core_stats': {
                'electron_cycles': self.tensor_core.electron_cycles,
                'quantum_ops': self.tensor_core.quantum_ops
            }
        }

        return metrics