"""
Physics-inspired digital core model for virtual GPU v2.
Contains ThreadedCore class for massive parallel computation.
"""

from logic_gates import ControlUnit, ALU2Bit, RegisterFile2x2, SimpleMMU
import threading
from typing import List, Dict, Any
import numpy as np
from queue import Queue
import time

class ThreadState:
    """Bookkeeping record for one simulated hardware thread.

    Each instance owns a private ``RegisterFile2x2`` and a ``Queue`` that
    ``ThreadedCore`` uses as this thread's result mailbox.

    Args:
        thread_id: Global index of the thread within its core.
        num_registers: Currently unused (the register file is built with no
            arguments); kept for signature compatibility.
        bits: Currently unused; kept for signature compatibility.
    """
    def __init__(self, thread_id: int, num_registers: int = 2, bits: int = 2):
        # Identity and scheduling flags.
        self.thread_id = thread_id
        self.active = True
        self.barrier_count = 0
        # Private per-thread storage: its own registers plus a results mailbox.
        self.regfile = RegisterFile2x2()
        self.result_queue = Queue()
        
class ThreadBlock:
    """A cooperative group of simulated threads sharing one barrier and a scratch dict.

    Args:
        block_id: Index of this block within its core.
        num_threads: Number of parties the block barrier waits for before
            releasing any of them.
    """
    def __init__(self, block_id: int, num_threads: int = 32):
        self.block_id = block_id
        # Block-scoped scratch storage, starts empty.
        self.shared_memory = dict()
        # Member ThreadStates; filled in by the owning core after construction.
        self.threads: List[ThreadState] = list()
        # synchronize() returns only once num_threads callers are waiting here.
        self.barrier = threading.Barrier(parties=num_threads)

    def synchronize(self):
        """Block the caller until every party of this block reaches the barrier."""
        barrier = self.barrier
        barrier.wait()

class ThreadedCore:
    """
    Simulates a massively parallel core with:
    - 700K hardware threads (the default ``num_threads``)
    - Shared control unit
    - Thread-local register files
    - Shared ALU with time-multiplexing
    - Thread synchronization capabilities

    Every simulated thread gets a ``ThreadState`` (own register file and
    result queue) and belongs to a ``ThreadBlock`` of at most
    ``threads_per_block`` threads. Access to the shared control unit and ALU
    is serialized through ``scheduler_lock``.
    """
    def __init__(self, num_threads: int = 700000, threads_per_block: int = 32, bits: int = 2, num_registers: int = 2):
        """
        Args:
            num_threads: Total number of simulated threads.
            threads_per_block: Maximum threads per block; must be >= 1.
            bits: Word width forwarded to SimpleMMU and each ThreadState.
            num_registers: Register count forwarded to SimpleMMU and each ThreadState.

        Raises:
            ValueError: If ``threads_per_block`` is less than 1 (the block
                count below would otherwise divide by zero or go negative).
        """
        if threads_per_block < 1:
            raise ValueError(f"threads_per_block must be >= 1, got {threads_per_block}")

        self.control = ControlUnit()
        self.alu = ALU2Bit()  # Shared ALU
        self.mmu = SimpleMMU(num_registers=num_registers, bits=bits)
        self.clk = 0.7  # High voltage for clock
        self.bits = bits
        self.num_registers = num_registers

        # Thread management
        self.num_threads = num_threads
        self.threads_per_block = threads_per_block
        # Ceiling division: a trailing partial block still gets its own block.
        self.num_blocks = (num_threads + threads_per_block - 1) // threads_per_block

        # Initialize thread blocks and states
        self.blocks: List[ThreadBlock] = []
        self.thread_states: Dict[int, ThreadState] = {}
        self._initialize_threads()

        # Thread scheduling
        self.scheduler_lock = threading.Lock()
        self.active_threads = set(range(num_threads))
        self.thread_pool = []  # Will hold thread objects

    def _initialize_threads(self):
        """Create every ThreadBlock and its ThreadStates.

        Thread ``t`` is placed in block ``t // threads_per_block``. The final
        block may hold fewer than ``threads_per_block`` threads, so each
        block's barrier is sized to the threads it actually contains.
        Otherwise ``synchronize()`` on a partial block could never complete.
        """
        for block_id in range(self.num_blocks):
            first_thread = block_id * self.threads_per_block
            threads_in_block = min(self.threads_per_block, self.num_threads - first_thread)
            block = ThreadBlock(block_id, threads_in_block)

            for thread_id in range(first_thread, first_thread + threads_in_block):
                thread_state = ThreadState(thread_id, num_registers=self.num_registers, bits=self.bits)
                block.threads.append(thread_state)
                self.thread_states[thread_id] = thread_state

            self.blocks.append(block)

    def _execute_thread(self, thread_id: int, a, b, cin, opcode, reg_sel):
        """Run one ALU operation on behalf of a single simulated thread.

        Args:
            thread_id: Key into ``self.thread_states``.
            a, b: Indexable operands; elements [0] and [1] of each are passed to the ALU.
            cin: Carry-in passed to the ALU.
            opcode: Loaded into the shared control unit.
            reg_sel: Register index written with the ALU result and then read back.

        Returns:
            The result dict, which is also pushed onto the thread's
            ``result_queue``, or None if the thread is inactive.
        """
        thread_state = self.thread_states[thread_id]
        if not thread_state.active:
            return None

        # The control unit and ALU are shared by every thread, so the
        # opcode -> control signals -> ALU -> write-back sequence must be atomic.
        with self.scheduler_lock:
            self.control.set_opcode(opcode)
            ctrl = self.control.get_control_signals()
            (r0, r1), cout = self.alu.operate(a[0], a[1], b[0], b[1], cin, ctrl['alu_op'])
            thread_state.regfile.write(r0, r1, self.clk, reg_sel)

        # NOTE(review): `ctrl` is stored exactly as returned by the control unit.
        # If get_control_signals() returns a shared mutable object, later
        # threads could alter it after this point; confirm against logic_gates.
        result = {
            'thread_id': thread_id,
            'alu_result': (r0, r1),
            'carry_out': cout,
            'regfile_out': thread_state.regfile.read(reg_sel),
            'control': ctrl
        }
        thread_state.result_queue.put(result)
        return result

    def execute_parallel(self, inputs: List[Dict[str, Any]]):
        """
        Execute operations across all threads in parallel.

        inputs: Sized sequence of per-thread operation dicts with keys 'a',
            'b', 'cin', 'opcode' and 'reg_sel'. Entry i runs on thread i;
            entries beyond ``num_threads`` are ignored.

        Returns the result dicts in thread-id order. A thread that is
        inactive, or whose worker raised, contributes no entry.

        NOTE(review): one OS thread is started per input, so very large
        inputs can exhaust the platform's thread limit.
        """
        from queue import Empty

        workers = []
        try:
            for thread_id, inp in enumerate(inputs):
                if thread_id >= self.num_threads:
                    break

                worker = threading.Thread(
                    target=self._execute_thread,
                    args=(thread_id, inp['a'], inp['b'], inp['cin'], inp['opcode'], inp['reg_sel'])
                )
                workers.append(worker)
                worker.start()
        finally:
            # Always join what was started, even if building/starting a later
            # worker failed, so no worker outlives this call.
            for worker in workers:
                worker.join()

        results = []
        for thread_id in range(min(len(inputs), self.num_threads)):
            thread_state = self.thread_states.get(thread_id)
            if thread_state is None:
                continue
            try:
                results.append(thread_state.result_queue.get_nowait())
            except Empty:
                # Inactive thread, or its worker raised before queueing a result.
                pass

        return results

    def synchronize_block(self, block_id: int):
        """Wait on the barrier of block ``block_id``; out-of-range ids are ignored."""
        if 0 <= block_id < len(self.blocks):
            self.blocks[block_id].synchronize()

    def barrier_all_threads(self):
        """Wait on every block's barrier in order.

        NOTE(review): each block barrier only releases once as many callers
        as the block has threads are waiting on it, and blocks are visited
        sequentially. A caller that is not accompanied by enough concurrent
        waiters will block indefinitely.
        """
        for block in self.blocks:
            block.synchronize()

if __name__ == "__main__":
    print("\n--- Threaded Core Simulation ---")
    core = ThreadedCore(num_threads=700000, threads_per_block=32)

    # Example: run the same operation on the first 1000 threads.
    inputs = [
        {'a': [0.7, 0.0], 'b': [0.7, 0.7], 'cin': 0.0, 'opcode': 0b10, 'reg_sel': 0}
        for _ in range(1000)  # Test with 1000 threads
    ]

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so it is the appropriate clock for measuring elapsed time.
    start_time = time.perf_counter()
    results = core.execute_parallel(inputs)
    elapsed = time.perf_counter() - start_time

    print(f"Executed {len(results)} thread operations")
    # Guard against an empty result list (e.g. every worker failed).
    if results:
        print(f"First thread result: {results[0]}")
    else:
        print("No results were produced")
    print(f"Execution time: {elapsed:.4f} seconds")
