"""
Enhanced 60-second SHA-256 mining test with thread-based parallel execution (one Python thread per simulated hardware thread) and simulated hardware-level thread management
"""
import os
import hashlib
import time
import logging
import threading
import queue
from datetime import datetime
from collections import deque
import struct
import numpy as np
from dataclasses import dataclass

from virtual_cpu import (
    CPU, WORD_SIZE, OP_ADD, OP_XOR, OP_SHR, OP_SHL, OP_AND, OP_OR,
    OP_LOAD, OP_STORE, Memory, MEMORY_SIZE, THREADS_PER_CORE, OP_ROTATE, 
    OP_LOAD_IMMED
)

@dataclass
class InstructionInfo:
    """A single CPU instruction after decoding.

    Attributes:
        op_type: Opcode number, compared against the ``OP_*`` constants
            (a 6-bit field when produced by ``decode_instruction``).
        dest_reg: Index of the register that receives the result (0-31).
        src_reg1: Index of the first source register (0-31).
        src_reg2: Index of the second source register (0-31).
        immediate: Constant operand (the low 11 bits of the encoded word
            when produced by ``decode_instruction``).
    """

    op_type: int    # opcode selector
    dest_reg: int   # result register index
    src_reg1: int   # first operand register index
    src_reg2: int   # second operand register index
    immediate: int  # constant operand / shift amount / address offset

# Configure the root logger: INFO and above, written both to
# mining_stats.log and to the console (handlers are created in that order).
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler('mining_stats.log'), logging.StreamHandler()],
)

class ThreadExecutionUnit:
    """Simulated hardware thread that interprets instructions and mines SHA-256.

    Each unit has its own:
    - Register file (32 registers)
    - Instruction pointer into a PROGRAM_SIZE-word program region
    - 5-slot pipeline (Fetch, Decode, Execute, Memory, Writeback)
    - Instruction cache dict ``local_cache`` (ParallelMiningController replaces
      it with one dict shared by all units on the same core)
    - Performance counters (cache hits/misses, clock cycles, instructions
      executed, hashes computed, blocks found)

    Instructions are read from ``shared_memory`` and STORE instructions write
    to it. Every HASH_INTERVAL clock cycles the unit double-SHA-256 hashes the
    80-byte block header held at memory addresses 0, 4, ..., 76 with a fresh
    nonce. Found blocks are published on ``result_queue`` as
    ``(hash_int, thread_id, core_id)`` tuples.
    """

    PROGRAM_SIZE = 512    # instruction words addressable as program memory
    HEADER_SIZE = 80      # bytes in a block header
    HASH_INTERVAL = 64    # clock cycles between hash attempts
    WORD_MASK = 0xFFFFFFFF
    # Bitcoin difficulty-1 target: a little-endian hash integer below this
    # value counts as a found block.
    TARGET = 0x00000000FFFF0000000000000000000000000000000000000000000000000000

    def __init__(self, thread_id, core_id, shared_memory):
        self.thread_id = thread_id
        self.core_id = core_id
        self.registers = [0] * 32
        self.instruction_pointer = 0
        self.pipeline = deque(maxlen=5)  # index 0 = newest entry, index 4 = writeback
        self.local_cache = {}  # instruction cache: address -> raw instruction
        self.shared_memory = shared_memory
        self.running = False
        self.clock_cycles = 0
        self.cache_hits = 0
        self.cache_misses = 0
        self.shared_cache_hits = 0  # Hits on entries added by other threads
        self.instructions_executed = 0  # instructions that reached the execute stage
        self.last_result = None
        self.result_queue = queue.Queue()  # found-block tuples for the controller
        # Kept for backward compatibility but no longer filled per hash:
        # one entry per hash grew without bound when nobody drained it.
        self.hash_queue = queue.Queue()
        self.cache_owners = {}  # address -> thread_id that loaded the entry
        self.total_hashes = 0  # Total hashes computed
        self.blocks_found = 0  # Valid blocks found

        # Warm up the instruction cache
        self.warmup_cache()

    def warmup_cache(self):
        """Pre-load every program address into ``local_cache``.

        Addresses this unit loads count as misses and are recorded as owned
        by this unit; addresses already present count as shared-cache hits.
        """
        for addr in range(self.PROGRAM_SIZE):
            if addr not in self.local_cache:
                self.local_cache[addr] = self.shared_memory.read(addr)
                self.cache_misses += 1  # initial loads count as misses
                self.cache_owners[addr] = self.thread_id
            else:
                self.shared_cache_hits += 1  # entry was loaded by another thread

    def fetch_instruction(self):
        """Return the instruction at the instruction pointer and advance it.

        Reads through ``local_cache``. On a miss the word is read from
        ``shared_memory`` and cached. Addresses wrap at PROGRAM_SIZE.
        """
        addr = self.instruction_pointer % self.PROGRAM_SIZE
        if addr in self.local_cache:
            # .get(): defensive in case a shared cache holds an address this
            # unit never recorded an owner for.
            if self.cache_owners.get(addr) == self.thread_id:
                self.cache_hits += 1
            else:
                self.shared_cache_hits += 1
            instr = self.local_cache[addr]
        else:
            self.cache_misses += 1
            instr = self.shared_memory.read(addr)
            self.local_cache[addr] = instr
            self.cache_owners[addr] = self.thread_id
        self.instruction_pointer += 1
        return instr

    def decode_instruction(self, raw_instruction):
        """Decode a raw instruction word (int, or little-endian bytes).

        Bit layout: [31:26] opcode, [25:21] dest, [20:16] src1, [15:11] src2,
        [10:0] immediate.
        """
        if isinstance(raw_instruction, int):
            instruction = raw_instruction
        else:
            instruction = int.from_bytes(raw_instruction, 'little')

        return InstructionInfo(
            op_type=(instruction >> 26) & 0x3F,
            dest_reg=(instruction >> 21) & 0x1F,
            src_reg1=(instruction >> 16) & 0x1F,
            src_reg2=(instruction >> 11) & 0x1F,
            immediate=instruction & 0x7FF,
        )

    def execute_instruction(self, instruction):
        """Execute a decoded instruction and return its writeback pair.

        Returns:
            ``(dest_reg, result)``. ``dest_reg`` is None for STORE (which
            produces no register result) and for non-InstructionInfo input,
            so the writeback stage leaves the register file untouched.
            Unrecognized opcodes yield ``(dest_reg, 0)``.
        """
        if not isinstance(instruction, InstructionInfo):
            return None, 0

        operand1 = self.registers[instruction.src_reg1]
        operand2 = self.registers[instruction.src_reg2]

        result = 0
        if instruction.op_type == OP_ADD:
            result = (operand1 + operand2) & 0xFFFFFFFF
        elif instruction.op_type == OP_XOR:
            result = operand1 ^ operand2
        elif instruction.op_type == OP_AND:
            result = operand1 & operand2
        elif instruction.op_type == OP_OR:
            result = operand1 | operand2
        elif instruction.op_type == OP_SHL:
            result = (operand1 << instruction.immediate) & 0xFFFFFFFF
        elif instruction.op_type == OP_SHR:
            result = operand1 >> instruction.immediate
        elif instruction.op_type == OP_ROTATE:
            amount = instruction.immediate & 0x1F
            result = ((operand1 >> amount) | (operand1 << (32 - amount))) & 0xFFFFFFFF
        elif instruction.op_type == OP_LOAD:
            addr = operand1 + instruction.immediate
            result = self.shared_memory.read(addr)
        elif instruction.op_type == OP_LOAD_IMMED:
            result = instruction.immediate
        elif instruction.op_type == OP_STORE:
            addr = operand1 + instruction.immediate
            self.shared_memory.write(addr, operand2)
            # A store has no register result. Returning dest_reg here would
            # make writeback clobber that register with 0.
            return None, 0

        return instruction.dest_reg, result

    def _advance_pipeline(self):
        """Move every in-flight entry one stage forward and fetch the next one."""
        pipe = self.pipeline

        # Writeback: retire the oldest entry once all five slots are occupied.
        if len(pipe) >= 5:
            retired = pipe.pop()
            if isinstance(retired, tuple):
                dest_reg, result = retired
                if dest_reg is not None:
                    self.registers[dest_reg] = result

        # Memory (slot 3): nothing to do, because loads and stores already
        # happen in the execute stage.

        # Execute (slot 2).
        if len(pipe) >= 3 and isinstance(pipe[2], InstructionInfo):
            pipe[2] = self.execute_instruction(pipe[2])
            self.instructions_executed += 1

        # Decode (slot 1).
        if len(pipe) >= 2 and isinstance(pipe[1], (bytes, int)):
            pipe[1] = self.decode_instruction(pipe[1])

        # Fetch. At most 4 entries remain here, so appendleft never evicts.
        if self.instruction_pointer < self.PROGRAM_SIZE:
            pipe.appendleft(self.fetch_instruction())
        elif any(entry is not None for entry in pipe):
            # Program exhausted: insert a bubble so in-flight instructions
            # keep moving toward writeback instead of freezing in place.
            pipe.appendleft(None)

    @classmethod
    def _to_word(cls, value):
        """Coerce a memory word (int or little-endian bytes) to an unsigned 32-bit int."""
        if isinstance(value, (bytes, bytearray)):
            value = int.from_bytes(value, 'little')
        return value & cls.WORD_MASK

    def _mine_once(self):
        """Hash the stored header with this unit's next nonce and report any block."""
        header = bytearray(self.HEADER_SIZE)
        for offset in range(0, self.HEADER_SIZE, 4):
            word = self._to_word(self.shared_memory.read(offset))
            header[offset:offset + 4] = word.to_bytes(4, 'little')

        # Register 0 is still the nonce base. Adding total_hashes * THREADS_PER_CORE
        # gives every attempt a new nonce instead of re-hashing the same header.
        # Sibling units on one core (thread_id < THREADS_PER_CORE) then test
        # disjoint nonces when their register-0 bases match.
        # NOTE(review): units on different cores with the same thread_id may
        # still test the same nonces.
        # Masking keeps the value inside the 4-byte nonce field; an unmasked
        # overflow raised OverflowError and stopped the unit.
        nonce = (self.registers[0] + self.thread_id
                 + self.total_hashes * THREADS_PER_CORE) & self.WORD_MASK
        header[76:80] = nonce.to_bytes(4, 'little')

        final_hash = hashlib.sha256(hashlib.sha256(header).digest()).digest()
        # The double-SHA-256 digest is compared as a little-endian integer.
        hash_value = int.from_bytes(final_hash, 'little')
        self.total_hashes += 1

        if hash_value < self.TARGET:
            self.blocks_found += 1
            # result_queue is the queue the controller forwards to its monitor.
            self.result_queue.put((hash_value, self.thread_id, self.core_id))
            logging.debug(f"Thread {self.thread_id} found block with hash: {final_hash.hex()}")

    def pipeline_cycle(self):
        """Run one clock cycle: advance the pipeline, then hash every HASH_INTERVAL cycles.

        Does nothing unless ``running`` is set. Any exception is logged with
        its traceback and stops this unit by clearing ``running``; it is not
        re-raised.
        """
        if not self.running:
            return

        try:
            self._advance_pipeline()
            self.clock_cycles += 1
            if self.clock_cycles % self.HASH_INTERVAL == 0:
                self._mine_once()
        except Exception as e:
            logging.exception(f"Pipeline error in thread {self.thread_id}: {str(e)}")
            self.running = False

class ParallelMiningController:
    """Manages parallel mining across multiple thread execution units.

    - Creates one ThreadExecutionUnit per (core, hardware thread) and runs
      each on its own Python ``threading.Thread``. Under CPython's GIL this
      pure-Python work is interleaved rather than truly parallel.
    - Gives all units of a core one shared instruction-cache dict.
    - Writes the block header into shared memory (``initialize_mining``).
    - Controls mining duration and shutdown, and aggregates hash and block
      statistics per core and overall.
    """

    def __init__(self, cpu):
        self.cpu = cpu
        self.thread_units = []
        self.shared_memory = cpu.memory
        self.running = False
        self.total_hashes = 0
        self.total_blocks = 0
        self.start_time = None
        self.hash_queue = queue.Queue()  # found-block tuples forwarded from units

        # One shared instruction cache and one stats record per core.
        self.core_caches = {}
        self.core_stats = {}
        for core_id in range(len(cpu.cores)):
            self.core_caches[core_id] = {}
            self.core_stats[core_id] = {
                "hashes": 0,
                "blocks": 0
            }

        for core_id in range(len(cpu.cores)):
            for thread_id in range(THREADS_PER_CORE):
                thread_unit = ThreadExecutionUnit(
                    thread_id=thread_id,
                    core_id=core_id,
                    shared_memory=self.shared_memory
                )
                # Share one instruction cache per core. NOTE: the unit's
                # constructor already warmed up its private cache dict, so
                # those entries are discarded by this swap.
                thread_unit.local_cache = self.core_caches[core_id]
                self.thread_units.append(thread_unit)

    def initialize_mining(self):
        """Write an 80-byte Bitcoin-style block header into shared memory and log it.

        The header is stored as twenty little-endian 32-bit words at byte
        addresses 0, 4, ..., 76. ThreadExecutionUnit reads that layout back
        before hashing.
        """
        version = 2
        # Hashes below are in the usual display (RPC) byte order.
        prev_block_hex = '00000000000000000007d331598b77e6f84fa0bc30139318a3b7252bede95501'
        merkle_root_hex = '4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b'  # genesis block merkle root
        # A serialized header stores them byte-reversed (internal byte order).
        prev_block = bytes.fromhex(prev_block_hex)[::-1]
        merkle_root = bytes.fromhex(merkle_root_hex)[::-1]
        timestamp = int(time.time())
        bits = 0x1d00ffff  # compact encoding of the difficulty-1 target (genesis-era bits)
        nonce = 0

        block_header = struct.pack('<L32s32sLLL', version, prev_block, merkle_root,
                                   timestamp, bits, nonce)

        for index, (word,) in enumerate(struct.iter_unpack('<L', block_header)):
            self.shared_memory.write(index * 4, word)

        logging.info("Mining configuration:")
        logging.info(f"Version: {version}")
        logging.info(f"Previous block: {prev_block_hex}")
        logging.info(f"Merkle root: {merkle_root_hex}")
        logging.info(f"Time: {timestamp}")
        logging.info(f"Target bits: {bits:08x}")
        logging.info(f"Initial nonce: {nonce}\n")

    def start_mining(self, duration_seconds=60):
        """Mine for *duration_seconds*; return ``(total_hashes, hashes_per_second)``.

        Each unit runs on its own worker thread while a monitor thread logs
        progress. ``self.running`` is cleared in a ``finally`` block, so the
        workers and the monitor stop even if the wait is interrupted (e.g. by
        KeyboardInterrupt, which still propagates). After all threads finish,
        still-queued results are processed and totals are recomputed from the
        units, so the returned count includes the final partial second.
        """
        self.running = True
        self.start_time = time.time()
        self.total_hashes = 0

        for thread_unit in self.thread_units:
            thread_unit.running = True
            thread_unit.instruction_pointer = 0  # restart the program

        mining_threads = [
            threading.Thread(target=self._run_thread_unit,
                             args=(thread_unit, duration_seconds))
            for thread_unit in self.thread_units
        ]
        monitor_thread = threading.Thread(target=self._monitor_progress)

        try:
            for thread in mining_threads:
                thread.start()
            monitor_thread.start()
            for thread in mining_threads:
                thread.join()
        finally:
            # Without this, an interrupt left the monitor loop running forever.
            self.running = False
            for thread in (*mining_threads, monitor_thread):
                if thread.is_alive():
                    thread.join()

        # Collect anything produced after the monitor's last pass.
        for thread_unit in self.thread_units:
            self._forward_results(thread_unit)
        self._drain_results()
        self._refresh_hash_totals()

        duration = time.time() - self.start_time
        hash_rate = self.total_hashes / duration if duration > 0 else 0.0
        return self.total_hashes, hash_rate

    def _run_thread_unit(self, thread_unit, duration_seconds):
        """Clock *thread_unit* until the deadline, a controller stop, or the unit halting.

        The loop also exits once ``thread_unit.running`` is cleared
        (pipeline_cycle does that on error), rather than busy-spinning on a
        dead unit. Results are forwarded after every cycle.
        """
        end_time = time.time() + duration_seconds

        while self.running and thread_unit.running and time.time() < end_time:
            thread_unit.pipeline_cycle()
            self._forward_results(thread_unit)

    def _forward_results(self, thread_unit):
        """Move a unit's queued results onto ``self.hash_queue``.

        Every item on the unit's ``result_queue`` is forwarded. From its
        ``hash_queue`` only found-block tuples are forwarded. Per-hash ``1``
        markers are dropped, because hash totals are read from
        ``total_hashes`` and undrained markers grew memory for the whole run.
        """
        results = thread_unit.result_queue
        while not results.empty():
            self.hash_queue.put(results.get_nowait())

        markers = thread_unit.hash_queue
        while not markers.empty():
            item = markers.get_nowait()
            if isinstance(item, tuple):
                self.hash_queue.put(item)

    def _drain_results(self):
        """Consume ``self.hash_queue``, counting ``1`` markers and recording found blocks."""
        while not self.hash_queue.empty():
            result = self.hash_queue.get_nowait()
            if result == 1:
                self.total_hashes += 1
            elif isinstance(result, tuple):
                self._record_block(*result)

    def _record_block(self, hash_val, thread_id, core_id):
        """Update block counters and log one found block."""
        self.core_stats[core_id]["blocks"] += 1
        self.total_blocks += 1

        block_hash = f"{hash_val:064x}"
        # Count only leading zeros; str.count('0') counted every zero digit.
        leading_zeros = len(block_hash) - len(block_hash.lstrip('0'))
        logging.info("\n!!! Block Found !!!")
        logging.info(f"Thread {thread_id} on Core {core_id}")
        logging.info(f"Block hash: {block_hash}")
        logging.info(f"Hash starts with {leading_zeros} zeros\n")

    def _refresh_hash_totals(self):
        """Recompute per-core and overall hash counts from the units' counters."""
        per_core = dict.fromkeys(self.core_stats, 0)
        for thread_unit in self.thread_units:
            if thread_unit.core_id in per_core:
                per_core[thread_unit.core_id] += thread_unit.total_hashes
        for core_id, hashes in per_core.items():
            self.core_stats[core_id]["hashes"] = hashes

        self.total_hashes = sum(t.total_hashes for t in self.thread_units)

    def _monitor_progress(self):
        """Log progress about once per second until ``self.running`` is cleared."""
        last_log_time = self.start_time
        last_total_hashes = 0

        while self.running:
            current_time = time.time()
            if current_time - last_log_time >= 1.0:
                self._drain_results()
                self._refresh_hash_totals()

                interval = current_time - last_log_time
                current_rate = (self.total_hashes - last_total_hashes) / interval

                logging.info(f"\nMining Progress at {current_time - self.start_time:.1f}s:")
                logging.info(f"Current hashrate: {format_hashrate(current_rate)}")
                logging.info(f"Total hashes: {self.total_hashes:,}")
                logging.info(f"Valid blocks found: {self.total_blocks}")

                for core_id in sorted(self.core_stats):
                    stats = self.core_stats[core_id]
                    logging.info(f"\nCore {core_id}:")
                    logging.info(f"- Total hashes: {stats['hashes']:,}")
                    logging.info(f"- Blocks found: {stats['blocks']}")

                last_log_time = current_time
                last_total_hashes = self.total_hashes

            time.sleep(0.1)  # avoid busy waiting

def format_hashrate(hashrate):
    """Render *hashrate* (hashes per second) with a scaled unit and two decimals.

    Picks the largest of TH/s, GH/s and MH/s whose threshold the value
    reaches. Anything smaller (including 0, negative values and NaN) is
    shown in KH/s.
    """
    for scale, unit in ((1e12, "TH/s"), (1e9, "GH/s"), (1e6, "MH/s")):
        if hashrate >= scale:
            return f"{hashrate / scale:.2f} {unit}"
    return f"{hashrate / 1e3:.2f} KH/s"

def _safe_percent(part, whole):
    """Return ``part / whole * 100``, or 0.0 when *whole* is 0 (nothing measured)."""
    return part / whole * 100 if whole else 0.0


def _log_thread_statistics(thread_units):
    """Log instruction, cycle, cache and IPC counters for every thread unit."""
    logging.info("\n=== Thread Statistics ===")
    for unit in thread_units:
        logging.info(f"\nThread {unit.thread_id} on Core {unit.core_id}:")
        logging.info(f"Instructions executed: {unit.instructions_executed:,}")
        logging.info(f"Clock cycles: {unit.clock_cycles:,}")
        logging.info(f"Cache hits: {unit.cache_hits:,}")
        logging.info(f"Cache misses: {unit.cache_misses:,}")
        hit_rate = _safe_percent(unit.cache_hits, unit.cache_hits + unit.cache_misses)
        logging.info(f"Cache hit rate: {hit_rate:.1f}%")
        # A unit that never ran a cycle (e.g. duration=0) would otherwise
        # raise ZeroDivisionError here.
        ipc = unit.instructions_executed / unit.clock_cycles if unit.clock_cycles else 0.0
        logging.info(f"Instructions per cycle (IPC): {ipc:.2f}")


def _write_detailed_stats(path, cpu, thread_units, total_hashes, rate_str):
    """Write the hardware configuration and aggregate counters to *path*."""
    num_cores = len(cpu.cores)
    total_hits = sum(t.cache_hits for t in thread_units)
    total_misses = sum(t.cache_misses for t in thread_units)
    overall_hit_rate = _safe_percent(total_hits, total_hits + total_misses)

    # Build every line before opening the file so a failure cannot leave a
    # half-written report.
    lines = [
        "=== Detailed Mining Statistics ===\n",
        f"Test timestamp: {datetime.now()}\n",
        "Hardware Configuration:\n",
        f"- Physical cores: {num_cores}\n",
        f"- Threads per core: {THREADS_PER_CORE}\n",
        f"- Total threads: {num_cores * THREADS_PER_CORE}\n",
        f"- Memory size: {MEMORY_SIZE:,} bytes\n\n",
        "Performance Results:\n",
        f"- Total hashes: {total_hashes:,}\n",
        f"- Average hash rate: {rate_str}\n",
        f"- Total instructions: {sum(t.instructions_executed for t in thread_units):,}\n",
        f"- Total clock cycles: {sum(t.clock_cycles for t in thread_units):,}\n\n",
        "Memory System Performance:\n",
        f"- Total cache hits: {total_hits:,}\n",
        f"- Total cache misses: {total_misses:,}\n",
        f"- Overall cache hit rate: {overall_hit_rate:.1f}%\n",
    ]
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def run_mining_test(duration=60, num_cores=16):
    """Run the SHA-256 mining test on a simulated multi-core CPU and report results.

    Args:
        duration: Mining time in seconds.
        num_cores: Number of simulated cores passed to ``CPU``. The default
            of 16 matches the previously hard-coded value.

    Logs aggregate and per-thread statistics and writes a summary to
    ``mining_stats_detailed.log``. A KeyboardInterrupt during mining or
    reporting is logged and swallowed.
    """
    logging.info(f"Starting {duration}-second intensive SHA-256 mining test")

    cpu = CPU(num_cores=num_cores)
    mining_controller = ParallelMiningController(cpu)
    mining_controller.initialize_mining()

    try:
        total_hashes, avg_hashrate = mining_controller.start_mining(duration_seconds=duration)

        rate_str = format_hashrate(avg_hashrate)
        active_cores = len(cpu.cores)
        logging.info("\n=== Mining Test Results ===")
        logging.info(f"Test duration: {duration:.2f} seconds")
        logging.info(f"Total hashes completed: {total_hashes:,}")
        logging.info(f"Average hash rate: {rate_str}")
        logging.info(f"Active cores: {active_cores}")
        logging.info(f"Threads per core: {THREADS_PER_CORE}")
        logging.info(f"Total threads: {active_cores * THREADS_PER_CORE}")

        _log_thread_statistics(mining_controller.thread_units)
        _write_detailed_stats('mining_stats_detailed.log', cpu,
                              mining_controller.thread_units, total_hashes, rate_str)

    except KeyboardInterrupt:
        logging.info("\nMining test interrupted by user")

if __name__ == '__main__':
    # Request the highest optimization level from the virtual CPU.
    # NOTE(review): virtual_cpu is imported at the top of this file, so this
    # only has an effect if the variable is read after import; confirm.
    os.environ['VIRTUAL_CPU_OPT_LEVEL'] = '3'

    try:
        run_mining_test(duration=60)
    except Exception as exc:
        # Record the failure in the log, then let the traceback surface.
        logging.error("Test failed: %s", exc)
        raise