"""
Massively parallel multicore system simulation.
With the default configuration (50K cores, each supporting 700K simulated hardware
threads), the system models a total of 35B threads.
"""

import json
import logging
import threading
import time
import uuid
from typing import List, Dict, Any

from core import ThreadedCore
from http_storage import LocalStorage

class MultiCoreSystem:
    """Simulated multicore system that fans work out to ``ThreadedCore`` instances.

    Each core's initial state, and every per-core execution, is recorded through
    ``self.storage.conn`` (the connection exposed by ``LocalStorage``).
    """

    def __init__(self, num_cores=50000, threads_per_core=700000, threads_per_block=32, bits=2, num_registers=2):
        """Create ``num_cores`` cores and record each one's initial state.

        Args:
            num_cores: Number of ``ThreadedCore`` instances to create.
            threads_per_core: Threads per core; also the size of the input chunk
                each core receives in ``execute_all``.
            threads_per_block: Forwarded to ``ThreadedCore``; also stored in the
                per-core metadata under the ``"blocks"`` key.
            bits: Forwarded to ``ThreadedCore``.
            num_registers: Forwarded to ``ThreadedCore``.
        """
        # Initialize storage
        self.storage = LocalStorage()

        self.num_cores = num_cores
        self.threads_per_core = threads_per_core
        self.total_threads = num_cores * threads_per_core
        # execute_all runs _execute_core on one thread per core, and all of those
        # threads write through the same storage connection; this lock serializes
        # those writes.
        self.scheduler_lock = threading.Lock()

        # Identical for every core, so serialize it once. json.dumps is required:
        # a plain string literal would store the variable *names* rather than
        # their values, and would not be valid JSON.
        core_metadata = json.dumps({"threads": threads_per_core, "blocks": threads_per_block})

        # Initialize cores with storage for state persistence
        self.cores = []
        for core_id in range(num_cores):
            core = ThreadedCore(
                num_threads=threads_per_core,
                threads_per_block=threads_per_block,
                bits=bits,
                num_registers=num_registers,
            )
            self.cores.append(core)

            # Store core state in database
            self.storage.conn.execute("""
                INSERT INTO tensor_core_states (
                    core_id, array_id, current_op, register_state, shared_memory_state,
                    metadata, status, is_active
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [
                str(core_id),
                str(uuid.uuid4()),
                None,
                '{}',
                '{}',
                core_metadata,
                'initialized',
                True,
            ])

    def _execute_core(self, core_id: int, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Execute operations on a single core's threads and log them in ``tensor_ops``.

        A row with status ``'started'`` is inserted before execution. After
        execution it is updated to ``'completed'``, with the result count and a
        completion timestamp. Database access is serialized on ``scheduler_lock``
        because this method runs concurrently, one worker thread per core.

        Args:
            core_id: Index into ``self.cores``.
            inputs: Operation inputs for this core's threads.

        Returns:
            Whatever ``ThreadedCore.execute_parallel`` returns for ``inputs``.
        """
        op_id = str(uuid.uuid4())
        with self.scheduler_lock:
            self.storage.conn.execute("""
                INSERT INTO tensor_ops (
                    op_id, core_id, operation_type, input_tensors, output_tensors,
                    metadata, status
                ) VALUES (?, ?, ?, ?, ?, ?, ?)
            """, [
                op_id,
                str(core_id),
                'parallel_execution',
                json.dumps({"inputs": len(inputs)}),
                '{}',
                '{}',
                'started',
            ])

        # Execute operations (outside the lock so cores actually overlap)
        results = self.cores[core_id].execute_parallel(inputs)

        # Update operation status
        with self.scheduler_lock:
            self.storage.conn.execute("""
                UPDATE tensor_ops 
                SET status = ?, output_tensors = ?, completed_at = CURRENT_TIMESTAMP
                WHERE op_id = ?
            """, ['completed', json.dumps({"results": len(results)}), op_id])

        return results

    def execute_all(self, inputs: List[Dict[str, Any]], sync_blocks: bool = True) -> List[Dict[str, Any]]:
        """
        Execute operations across all cores and their threads in parallel.

        inputs: List of operation inputs (one per thread). Inputs beyond
            ``total_threads`` are dropped. The rest are split into consecutive
            chunks of ``threads_per_core``, and chunk ``i`` goes to core ``i``.
        sync_blocks: Whether to call ``barrier_all_threads()`` on every core after
            execution.
        Returns: List of results from all threads, concatenated in core order
            (the same order as ``inputs``).
        Raises: The first exception (in core order) raised by any core. It is
            re-raised only after every worker thread has finished.
        """
        if len(inputs) > self.total_threads:
            inputs = inputs[:self.total_threads]

        # Split inputs across cores
        chunk_size = self.threads_per_core
        core_inputs = [
            inputs[i:i + chunk_size]
            for i in range(0, len(inputs), chunk_size)
        ]

        # One result slot and one error slot per core. Results are reassembled in
        # input order regardless of which worker finishes first, and an exception
        # on a worker thread is surfaced instead of being silently lost.
        per_core_results = [[] for _ in core_inputs]
        failures = [None] * len(core_inputs)

        def run_core(core_id, core_input):
            try:
                per_core_results[core_id] = self._execute_core(core_id, core_input)
            except Exception as exc:
                failures[core_id] = exc

        # core_id and core_input are passed via ``args`` so each thread gets its
        # own values. A closure over the loop variables would instead see
        # whatever they hold when the thread actually runs.
        started = []
        try:
            for core_id, core_input in enumerate(core_inputs):
                worker = threading.Thread(
                    target=run_core, args=(core_id, core_input), name=f"core-{core_id}"
                )
                worker.start()
                started.append(worker)
        finally:
            # Always wait for the workers that did start, even if starting a later
            # one failed (e.g. the OS thread limit was reached).
            for worker in started:
                worker.join()

        for exc in failures:
            if exc is not None:
                raise exc

        # Optionally synchronize thread blocks across all cores
        if sync_blocks:
            for core in self.cores:
                core.barrier_all_threads()

        return [result for chunk in per_core_results for result in chunk]

    def execute_same_all(self, operation: Dict[str, Any], sync_blocks: bool = True) -> List[Dict[str, Any]]:
        """
        Execute the same operation across all cores and their threads.

        operation: Single operation to replicate across all threads. The same
            object is shared by every thread; it is not copied.
        sync_blocks: Whether to synchronize thread blocks after execution.
        Returns: Same as ``execute_all``.

        NOTE(review): this builds a list with ``total_threads`` entries, so memory
        grows linearly with ``num_cores * threads_per_core``.
        """
        inputs = [operation] * self.total_threads
        return self.execute_all(inputs, sync_blocks)

if __name__ == "__main__":
    print("\n--- Massively Parallel MultiCore System Simulation ---")

    # Create system with 50K cores, 700K threads per core.
    # NOTE(review): at this scale execute_same_all materializes one list entry per
    # thread (50,000 * 700,000 = 35B entries), which far exceeds typical memory;
    # shrink these values for local runs.
    system = MultiCoreSystem(
        num_cores=50000,
        threads_per_core=700000,
        threads_per_block=32,
        bits=2,
        num_registers=2
    )

    print(f"Total cores: {system.num_cores:,}")
    print(f"Threads per core: {system.threads_per_core:,}")
    print(f"Total threads: {system.total_threads:,}")

    # Example: Execute same ADD operation across all threads
    print("\nExecuting same operation across all threads...")
    operation = {
        'a': [0.7, 0.0],
        'b': [0.7, 0.7],
        'cin': 0.0,
        'opcode': 0b10,
        'reg_sel': 0
    }

    # perf_counter is monotonic and high-resolution. time.time() is wall-clock
    # time and can jump if the system clock is adjusted mid-measurement.
    start_time = time.perf_counter()
    results = system.execute_same_all(operation)
    elapsed = time.perf_counter() - start_time

    print(f"Executed {len(results):,} thread operations")
    # Guard against an empty result list (e.g. a zero-sized configuration).
    if results:
        print(f"First thread result: {results[0]}")
    print(f"Execution time: {elapsed:.4f} seconds")
    # Avoid ZeroDivisionError if the run completed below timer resolution.
    if elapsed > 0:
        print(f"Operations per second: {len(results) / elapsed:,.0f}")
