import numpy as np
from .embedding import embedding_lookup, add_positional_encoding
from .positional_encoding import sinusoidal_positional_encoding
from .decoder import transformer_decoder_block

class TransformerDecoder:
    """Stack of transformer decoder blocks on top of a token embedding.

    ``forward`` looks up token embeddings, adds a sinusoidal positional
    encoding precomputed for ``max_seq_len`` positions, and then applies
    ``transformer_decoder_block`` once per entry of ``block_weights_list``,
    in order.

    Args:
        vocab_size: Size of the token vocabulary (stored; not used directly
            by this class).
        hidden_dim: Model width; also the width of the positional encoding.
        num_layers: Number of decoder layers. NOTE(review): this value is only
            stored -- ``forward`` iterates ``block_weights_list``, so that
            list's length is what actually determines depth. Confirm callers
            keep the two consistent.
        num_heads: Attention head count forwarded to every block.
        max_seq_len: Longest target sequence ``forward`` accepts; the
            positional table is built for exactly this many positions.
        embedding_weights: Embedding table, (vocab_size, hidden_dim).
        block_weights_list: List of per-block weight dicts, applied in order.
        driver: Optional object forwarded unchanged to every block.
        scheduler: Optional object forwarded unchanged to every block.
    """

    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, max_seq_len,
                 embedding_weights, block_weights_list, driver=None, scheduler=None):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len
        self.embedding_weights = embedding_weights  # (vocab_size, hidden_dim)
        self.block_weights_list = block_weights_list  # list of dicts, one per block
        # Precompute once; forward() slices it to the current sequence length.
        self.pos_encoding = sinusoidal_positional_encoding(max_seq_len, hidden_dim)
        self.driver = driver
        self.scheduler = scheduler

    def forward(self, input_ids, enc_out, self_mask=None, enc_dec_mask=None):
        """Run the decoder stack over ``input_ids``.

        Args:
            input_ids: Token ids, (batch, tgt_seq_len).
            enc_out: Encoder output forwarded unchanged to every decoder block
                (presumably consumed by cross-attention -- see
                ``transformer_decoder_block``).
            self_mask: Optional mask forwarded unchanged to every block.
            enc_dec_mask: Optional mask forwarded unchanged to every block.

        Returns:
            The output of the last decoder block, or the position-encoded
            embeddings if ``block_weights_list`` is empty.

        Raises:
            ValueError: If ``tgt_seq_len`` exceeds ``max_seq_len``. The
                positional table only covers ``max_seq_len`` positions, so a
                longer sequence would otherwise receive a truncated encoding.
        """
        # Fail fast with a clear message instead of an opaque shape error
        # further down (slicing past the table's end silently clips it).
        seq_len = np.shape(input_ids)[-1]
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"target sequence length {seq_len} exceeds "
                f"max_seq_len {self.max_seq_len}"
            )

        x = embedding_lookup(input_ids, self.embedding_weights)
        x = add_positional_encoding(x, self.pos_encoding[:x.shape[1]])
        for block_weights in self.block_weights_list:
            x = transformer_decoder_block(
                x, enc_out, block_weights, self.num_heads, self_mask, enc_dec_mask,
                driver=self.driver, scheduler=self.scheduler,
            )
        return x
