import numpy as np
from .layer_norm import layer_norm
from .gelu import gelu
from .multihead_attention import multihead_attention

def _make_round_robin_scheduler(num_chips, num_sms):
    """Build a scheduler that cycles over chips first, then SMs, one step per call."""
    next_op = 0

    def schedule(op_name):
        # op_name is unused by the round-robin policy; it is accepted because
        # schedulers are always invoked with the name of the op being placed.
        nonlocal next_op
        chip = next_op % num_chips
        sm = (next_op // num_chips) % num_sms
        next_op += 1
        return chip, sm

    return schedule


def transformer_block(
    x,
    weights,
    num_heads,
    mask=None,
    driver=None,
    chip_id=0,
    sm_id=0,
    scheduler=None,
):
    """
    Run one pre-LayerNorm transformer block, dispatching each op through `driver`.

    Data flow: layernorm -> multi-head attention -> residual add ->
    layernorm -> matmul + bias -> gelu -> matmul + bias -> residual add.

    x: (batch, seq_len, hidden_dim)
    weights: dict with keys for all block weights. Keys read here:
        'ln1.weight', 'ln1.bias',
        'attn.q_proj.weight', 'attn.k_proj.weight',
        'attn.v_proj.weight', 'attn.out_proj.weight',
        'ln2.weight', 'ln2.bias',
        'ff1.weight', 'ff1.bias', 'ff2.weight', 'ff2.bias'
    num_heads: forwarded unchanged to multihead_attention
    mask: optional attention mask, forwarded unchanged to multihead_attention
    driver: VirtualGPUDriver instance. Required despite the None default:
        every op in the block is issued through it.
    chip_id, sm_id: accepted for interface compatibility only. Placement of
        every op is taken from `scheduler`, so the values passed in here are
        overwritten before they are ever read.
    scheduler: callable taking an op name and returning (chip_id, sm_id).
        When None, a round-robin scheduler is built from
        driver.hardware_config ('num_chips' and 'num_sms_per_chip', each
        defaulting to 1). The same scheduler object is handed on to
        multihead_attention.

    Returns the result of the second residual addition (x2 + ff2).

    Raises ValueError if driver is None.
    """
    if driver is None:
        # Fail with a clear message instead of an AttributeError on NoneType.
        raise ValueError("transformer_block requires a driver; got driver=None")

    if scheduler is None:
        # hardware_config is only consulted for the default scheduler, so a
        # caller supplying its own scheduler does not need to provide it.
        hardware = driver.hardware_config
        scheduler = _make_round_robin_scheduler(
            hardware.get('num_chips', 1),
            hardware.get('num_sms_per_chip', 1),
        )

    # LayerNorm 1
    chip_id, sm_id = scheduler('layernorm1')
    x_norm1 = driver.layernorm(
        x, weights['ln1.weight'], weights['ln1.bias'], chip_id=chip_id, sm_id=sm_id
    )

    # Multi-head attention; the second return value is not needed here.
    chip_id, sm_id = scheduler('multihead_attention')
    attn_out, _ = multihead_attention(
        x_norm1,
        weights['attn.q_proj.weight'],
        weights['attn.k_proj.weight'],
        weights['attn.v_proj.weight'],
        weights['attn.out_proj.weight'],
        num_heads,
        mask,
        driver=driver,
        chip_id=chip_id,
        sm_id=sm_id,
        scheduler=scheduler,
    )

    # Residual 1: add the un-normalized input back in.
    x2 = x + attn_out

    # LayerNorm 2
    chip_id, sm_id = scheduler('layernorm2')
    x_norm2 = driver.layernorm(
        x2, weights['ln2.weight'], weights['ln2.bias'], chip_id=chip_id, sm_id=sm_id
    )

    # Feedforward: matmul + bias, gelu, matmul + bias. Each op gets its own
    # scheduler slot; the bias additions are applied to the driver's result.
    chip_id, sm_id = scheduler('ff1')
    ff1 = driver.matmul(x_norm2, weights['ff1.weight'], chip_id=chip_id, sm_id=sm_id) + weights['ff1.bias']
    chip_id, sm_id = scheduler('gelu')
    ff1 = driver.gelu(ff1, chip_id=chip_id, sm_id=sm_id)
    chip_id, sm_id = scheduler('ff2')
    ff2 = driver.matmul(ff1, weights['ff2.weight'], chip_id=chip_id, sm_id=sm_id) + weights['ff2.bias']

    # Residual 2
    return x2 + ff2
