import hashlib
import time
import threading
import struct
import logging
import numpy as np
from datetime import datetime

# Module-level defaults. NOTE(review): run_miner() and mine_block() work with
# their own local stop_event / stats / threads / batch_size, so nothing in this
# module reads or updates these globals; they are kept only for compatibility.
stop_event = threading.Event()  # run_miner() creates and sets its own event instead
stats = dict()                  # per-thread hash counters (unused; see note above)
threads = list()                # worker thread handles (unused; see note above)
batch_size = 800_000            # nonces per worker batch (mine_block has its own value)

# Send INFO-and-above records to the terminal, formatted as "<timestamp> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)

# Shared hash counter for all mining threads; logs the hash rate at most once per interval.
class StatsLogger:
    """Accumulate hash counts reported by worker threads and periodically log the rate.

    Counts are buffered in ``pending_hashes`` and folded into ``total_hashes``
    once at least ``batch_threshold`` hashes have accumulated. At that point, if
    ``log_interval`` seconds have passed since the last report, the running
    total and the hash rate since that report are logged.

    ``log_hash`` is safe to call from several threads at once.
    """

    def __init__(self, log_interval=10):
        """Start the clock; ``log_interval`` is the minimum number of seconds between reports."""
        self.start_time = time.time()
        self.total_hashes = 0
        self.last_log_time = self.start_time
        self.last_hash_count = 0
        self.batch_threshold = 100000  # Accumulate stats until this threshold
        self.pending_hashes = 0
        self.log_interval = log_interval
        # Every mining thread calls log_hash(); the += / reset sequences below
        # are not atomic, so they are serialized with this lock.
        self._lock = threading.Lock()

    def log_hash(self, count=1):
        """Record ``count`` more hashes and log the rate if a report is due."""
        with self._lock:
            self.pending_hashes += count
            # Only update the total and check the clock once the buffer is full.
            if self.pending_hashes < self.batch_threshold:
                return
            self.total_hashes += self.pending_hashes
            self.pending_hashes = 0

            current_time = time.time()
            elapsed = current_time - self.last_log_time
            if elapsed < self.log_interval:
                return
            hashes_since_last = self.total_hashes - self.last_hash_count
            rate = hashes_since_last / elapsed if elapsed > 0 else 0.0
            message = (
                f"Mining Stats - Total Hashes: {self.total_hashes:,} | "
                f"Current Rate: {rate:,.0f} H/s"
            )
            self.last_log_time = current_time
            self.last_hash_count = self.total_hashes
        # Emit outside the lock so other threads are not blocked on logging I/O.
        logging.info(message)


stats_logger = StatsLogger()

class BlockHeader:
    """Plain holder for the six fields of a Bitcoin block header.

    ``prev_block_hash`` and ``merkle_root`` are hex strings. ``serialize`` byte-reverses
    them, so they are expected in display (big-endian) order. The integer fields
    are written as little-endian unsigned 32-bit values.
    """

    def __init__(self, version, prev_block_hash, merkle_root, timestamp, bits, nonce):
        self.version, self.prev_block_hash, self.merkle_root = version, prev_block_hash, merkle_root
        self.timestamp, self.bits, self.nonce = timestamp, bits, nonce

    def serialize(self):
        """Return the 80-byte header: version, two reversed 32-byte hashes, timestamp, bits, nonce."""
        parts = [struct.pack("<I", self.version)]
        # Hex hashes are stored in display order; the wire format wants them reversed.
        parts += (bytes.fromhex(h)[::-1] for h in (self.prev_block_hash, self.merkle_root))
        parts.append(struct.pack("<III", self.timestamp, self.bits, self.nonce))
        return b"".join(parts)

def calculate_target(bits):
    """Expand a compact ``bits`` value into the full integer target.

    The high byte is the target's size in bytes and the low 23 bits are the
    mantissa (the 0x00800000 sign bit is masked off). The mantissa is shifted
    left by ``8 * (size - 3)`` bits, or right when ``size`` is below 3.
    """
    exponent, mantissa = bits >> 24, bits & 0x007fffff
    shift = 8 * (exponent - 3)
    return mantissa << shift if shift >= 0 else mantissa >> -shift

def bits_to_difficulty(bits):
    """Return the difficulty of ``bits`` relative to the difficulty-1 target.

    Difficulty 1 is the target ``0x00ffff * 2**208`` (compact form 0x1d00ffff).
    The result is a float, or 0 when the expanded target is not positive.
    """
    target = calculate_target(bits)
    return (0x00ffff * 2**208) / target if target > 0 else 0

def create_mock_network_data():
    """Return mock block-template parameters for a test mining session.

    ``prev_block_hash`` is the Bitcoin mainnet genesis block hash in display hex.
    ``bits`` encodes a deliberately easy test target. ``coinbase_message`` embeds
    the local time at which this function was called.
    """
    started = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    data = dict(
        prev_block_hash='000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f',
        version=2,
        bits=0x1f07ffff,  # Test difficulty (much easier than real network)
        height=1,
        coinbase_message=b'Test Mining @ ' + started.encode(),
    )
    return data

def create_coinbase_tx(height, coinbase_message):
    # Create a proper coinbase transaction
    script = (
        bytes([0x03]) +  # Push next 3 bytes
        height.to_bytes(3, 'little') +  # Block height (BIP34)
        coinbase_message  # Arbitrary data
    )
    tx = (
        struct.pack("<I", 1) +  # Version
        bytes([1]) +  # Input count
        bytes.fromhex('0' * 64) +  # Previous transaction hash
        struct.pack("<I", 0xFFFFFFFF) +  # Previous output index
        bytes([len(script)]) + script +  # Script length and script
        struct.pack("<I", 0xFFFFFFFF) +  # Sequence
        bytes([1]) +  # Output count
        (50 * 100000000).to_bytes(8, 'little') +  # 50 BTC reward
        bytes([25]) + bytes.fromhex('76a914') + bytes(20) + bytes.fromhex('88ac') +  # P2PKH script
        struct.pack("<I", 0)  # Lock time
    )
    return hashlib.sha256(hashlib.sha256(tx).digest()).hexdigest()

def create_merkle_root(coinbase_tx_hash):
    """Return the merkle root of a block whose only transaction is the coinbase.

    With a single transaction, the merkle root is that transaction's hash, so the
    argument is returned unchanged. A real block would combine the hashes of
    all its transactions pairwise; this mock does not.
    """
    merkle_root = coinbase_tx_hash
    return merkle_root

def mine_block(thread_id, start_nonce, step, network_data, stop_event, result, stats,
               batch_size=800000):
    """Worker loop: hash block headers over this worker's share of the nonce space.

    Each batch tries ``batch_size`` nonces ``nonce, nonce + step, nonce + 2*step, ...``
    (reduced modulo 2**32), starting from ``start_nonce``. Each header is
    double-SHA256 hashed, and every header whose digest, read as a little-endian
    integer, is below the target expanded from ``network_data['bits']`` is
    recorded. The loop runs until ``stop_event`` is set; the event is only
    checked between batches.

    Args:
        thread_id: Key under which this worker's hash count is kept in ``stats``.
        start_nonce: First nonce tried by this worker.
        step: Distance between consecutive nonces tried by this worker. Giving
            each worker a distinct ``start_nonce`` below ``step`` makes the
            workers interleave instead of repeating each other's work.
        network_data: Dict with 'version', 'prev_block_hash' (hex, display
            order), 'bits', 'height' and 'coinbase_message'.
        stop_event: threading.Event; the loop exits once it is set.
        result: Shared dict. Found blocks are appended to ``result['blocks']``
            and ``result['found']`` is set to True.
        stats: Shared dict of per-thread hash counts.
        batch_size: Nonces hashed per iteration (default 800000).
    """
    target = calculate_target(network_data['bits'])

    # Create block template
    coinbase_tx = create_coinbase_tx(network_data['height'], network_data['coinbase_message'])
    merkle_root = create_merkle_root(coinbase_tx)

    # Header bytes fixed for the whole run: version, previous-block hash, merkle root.
    # The previous-block hash arrives in display (big-endian) hex and must be
    # byte-reversed, as BlockHeader.serialize() does. The merkle root is the raw
    # digest hex returned by create_coinbase_tx(), so it is used as-is.
    static_header_part = (
        struct.pack("<I", network_data['version'])
        + bytes.fromhex(network_data['prev_block_hash'])[::-1]
        + bytes.fromhex(merkle_root)
    )
    bits_bytes = struct.pack("<I", network_data['bits'])
    # Explicit little-endian packing (numpy's tobytes() uses host byte order).
    pack_nonce = struct.Struct("<I").pack
    sha256 = hashlib.sha256

    stats[thread_id] = 0  # Actual hash count
    nonce = start_nonce

    while not stop_event.is_set():
        # The timestamp is sampled once per batch, so every header in a batch shares it.
        timestamp = int(time.time())
        header_prefix = static_header_part + struct.pack("<I", timestamp) + bits_bytes

        # Digests stay plain bytes objects. A NumPy bytes ('S') array would strip
        # trailing NUL bytes and corrupt the input to the second SHA-256.
        found = []
        for i in range(batch_size):
            n = (nonce + i * step) & 0xFFFFFFFF  # header nonce field is 32 bits
            digest = sha256(sha256(header_prefix + pack_nonce(n)).digest()).digest()
            # Proof of work compares the digest read as a little-endian integer.
            if int.from_bytes(digest, 'little') < target:
                found.append((n, digest))

        stats[thread_id] = stats.get(thread_id, 0) + batch_size
        # Advance past every nonce this batch covered (exactly once per batch).
        nonce += batch_size * step

        if found:
            blocks = []
            for found_nonce, digest in found:
                # Reversed digest = display order, so int(found_hash, 16) equals
                # the value that was compared with the target above.
                found_hash = digest[::-1].hex()
                logging.info(f"\nBLOCK FOUND by Thread {thread_id}!")
                logging.info(f"Block hash: {found_hash}")
                logging.info(f"Nonce: {found_nonce}")
                logging.info(f"Target: {hex(target)}")
                blocks.append({
                    'hash': found_hash,
                    'thread': thread_id,
                    'timestamp': timestamp,
                    'target': target,
                    'nonce': found_nonce,
                })
            # setdefault plus one extend() avoids a check-then-create race in which
            # two threads each install their own 'blocks' list and one is lost.
            result.setdefault('blocks', []).extend(blocks)
            result['found'] = True

        stats_logger.log_hash(batch_size)  # Log the batch of hashes
            

def run_miner(runtime=160, num_threads=None):
    """Run a timed mock-mining session across worker threads, then log a summary.

    Args:
        runtime: Seconds to mine before stopping (default 160). The deadline is
            checked once per second, and each worker then finishes its current batch.
        num_threads: Number of worker threads; defaults to
            ``multiprocessing.cpu_count()``.

    The module-level ``result`` is rebound to a fresh dict (initially
    ``{'found': False}``). Workers add found blocks to it under ``'blocks'``.

    NOTE(review): the workers are Python threads hashing one header at a time
    in Python code, so the GIL likely caps them at roughly one core in total.
    Real parallelism would need multiprocessing; confirm by profiling.
    """
    import multiprocessing

    global result
    start_time = time.time()
    if num_threads is None:
        num_threads = multiprocessing.cpu_count()
    logging.info(f"Starting mining with {num_threads} threads for {runtime} seconds...")

    # Get mock network data
    network_data = create_mock_network_data()
    difficulty = bits_to_difficulty(network_data['bits'])
    logging.info(f"Current difficulty: {difficulty:.2f}")

    stop_event = threading.Event()
    result = {'found': False}
    stats = {i: 0 for i in range(num_threads)}  # Will store actual hash counts
    threads = []
    try:
        # Worker i gets start nonce i and stride num_threads, so the workers can
        # interleave the nonce space.
        for i in range(num_threads):
            t = threading.Thread(
                target=mine_block,
                args=(i, i, num_threads, network_data, stop_event, result, stats),
            )
            t.start()
            threads.append(t)  # only started threads are joined below
        _monitor_mining(start_time, runtime, stop_event, result)
    finally:
        # Always stop and join the workers, including on KeyboardInterrupt. They
        # are non-daemon threads, so leaving them running would keep the process alive.
        stop_event.set()
        for t in threads:
            t.join()

    total_time = time.time() - start_time
    _log_found_blocks(result)
    _log_session_summary(stats, total_time)


def _monitor_mining(start_time, runtime, stop_event, result):
    """Poll once per second until ``runtime`` seconds have elapsed or ``stop_event`` is set."""
    while not stop_event.is_set():
        time.sleep(1)
        if time.time() - start_time >= runtime:
            logging.info(f"\nReached runtime limit of {runtime} seconds")
            stop_event.set()
            break
        # Finding blocks does not end the session; mining continues until the runtime limit.
        if result.get('blocks'):
            logging.info(f"\nTotal blocks found: {len(result['blocks'])}")


def _log_found_blocks(result):
    """Log each block recorded in ``result['blocks']``, or that none were found."""
    if 'blocks' not in result:
        logging.info("\nNo blocks found during this session.")
        return
    blocks = result['blocks']
    logging.info(f"\nTotal valid blocks found: {len(blocks)}")
    for i, block in enumerate(blocks, 1):
        logging.info(f"\nBlock {i}:")
        logging.info(f"Thread: {block['thread']}")
        logging.info(f"Block hash: {block['hash']}")
        logging.info(f"Nonce: {block['nonce']}")
        logging.info(f"Target: {hex(block['target'])}")
        # NOTE: this re-check is only meaningful when 'hash' is stored in display
        # (byte-reversed) order, i.e. int(hash, 16) equals the little-endian digest value.
        hash_int = int(block['hash'], 16)
        logging.info(f"Hash int: {hex(hash_int)}")
        logging.info(f"Valid: {hash_int < block['target']}")
        logging.info(f"Timestamp: {datetime.fromtimestamp(block['timestamp'])}")


def _log_session_summary(stats, total_time):
    """Log total time, total hashes and hash rates, overall and per worker thread."""
    def rate(count):
        return count / total_time if total_time > 0 else 0.0

    final_total_hashes = sum(stats.values())
    logging.info("\nMining Session Summary")
    logging.info("---------------------")
    logging.info(f"Total time: {total_time:.2f} seconds")
    logging.info(f"Total hashes performed: {final_total_hashes:,d}")
    logging.info(f"True average hashrate: {rate(final_total_hashes):,.0f} H/s")
    logging.info("\nPer Thread Breakdown:")
    for thread_id, hash_count in stats.items():
        logging.info(f"Thread {thread_id}: {hash_count:,d} hashes ({rate(hash_count):,.0f} H/s)")

if __name__ == "__main__":
    # Script entry point: run one session; a KeyboardInterrupt escaping
    # run_miner() is reported with a short message instead of a traceback.
    logging.info("Starting mining session...")
    try:
        run_miner()
    except KeyboardInterrupt:
        logging.info("\nMiner stopped by user.")