import math
import os

import numpy as np
from transformers import AutoTokenizer

from helium.encoder import TransformerEncoder
from helium.decoder_model import TransformerDecoder
from .final_projection import final_linear_projection
from .utils import (
    parse_hf_config,
    map_hf_weights_to_blocks,
    map_hf_weights_to_blocks_bert,
    map_hf_weights_to_blocks_t5,
    create_causal_mask,
)

# Initialize HuggingFace token from environment
HF_TOKEN = os.getenv("HF_TOKEN")


def _vit_lookup_weight(weights, *keys):
    """Return the first non-None entry of *weights* among *keys*, else None."""
    # Explicit None checks: `a or b` would evaluate the truth value of an
    # ndarray, which raises ValueError for arrays with more than one element.
    for key in keys:
        value = weights.get(key)
        if value is not None:
            return value
    return None


def _vit_matmul(driver, scheduler, stage, lhs, rhs):
    """Multiply lhs @ rhs on the driver, or with NumPy when no driver is given."""
    if driver is None:
        return np.matmul(lhs, rhs)
    chip_id, sm_id = scheduler(stage) if scheduler else (0, 0)
    return driver.matmul(lhs, rhs, chip_id=chip_id, sm_id=sm_id)


def run_vision_transformer_inference(
    hf_weights, config, input_data, driver=None, scheduler=None
):
    """
    Clean ViT-style vision transformer inference pipeline.

    The flattened images are cut into square patches, linearly projected to
    ``hidden_dim``, optionally prefixed with a class token and offset by the
    positional embedding, run through ``TransformerEncoder`` and finally
    projected by the classification head applied to the first token.

    Args:
        hf_weights: dict of model weights (already loaded). Each tensor is
            looked up under its ``vit.``-prefixed key first, then the bare key.
            Missing patch-embedding or head weights are replaced by freshly
            sampled random weights (so the output is then non-deterministic).
        config: model config, forwarded to ``parse_hf_config``. ``patch_size``
            (default 16) is read only when it is a dict; ``num_classes``
            (default 1000) is read when it offers ``.get``.
        input_data: np.ndarray, shape (batch, features) with
            ``features == 3 * img_size * img_size`` (flattened square image).
        driver: object exposing ``matmul(a, b, chip_id=..., sm_id=...)``
            (e.g. a VirtualGPUDriver instance). When None, the two projections
            in this function use ``np.matmul``; the driver is still forwarded
            as-is to ``TransformerEncoder``.
        scheduler: optional callable mapping a stage name ("patch_embed",
            "head") to ``(chip_id, sm_id)``; ``(0, 0)`` is used when omitted.
    Returns:
        output: the result of the head projection (logits), plus bias if any.
    Raises:
        ValueError: if ``features`` is not ``3 * side**2`` or the image side
            cannot be tiled by ``patch_size``.
    """
    cfg = parse_hf_config(config)
    num_layers = cfg["num_layers"]
    num_heads = cfg["num_heads"]
    hidden_dim = cfg["hidden_dim"]
    max_seq_len = cfg["max_seq_len"]

    # --- Patch Embedding ---
    batch, features = input_data.shape
    img_size = math.isqrt(features // 3)
    if 3 * img_size * img_size != features:
        raise ValueError(
            f"input_data has {features} features per sample; expected "
            "3 * side**2 for a flattened square RGB image"
        )
    patch_size = config.get("patch_size", 16) if isinstance(config, dict) else 16
    # An image no larger than one patch is kept as a single (smaller) patch;
    # anything larger must tile exactly or the patches would be ragged.
    if img_size > patch_size and img_size % patch_size:
        raise ValueError(
            f"image side {img_size} is not a multiple of patch_size {patch_size}"
        )
    patch_dim = 3 * patch_size * patch_size
    img = input_data.reshape(batch, 3, img_size, img_size)
    # Row-major walk over the patch grid; each patch is flattened as (C, H, W).
    tiles = [
        img[:, :, row : row + patch_size, col : col + patch_size].reshape(batch, -1)
        for row in range(0, img_size, patch_size)
        for col in range(0, img_size, patch_size)
    ]
    x = np.stack(tiles, axis=1)  # (batch, num_patches, patch_dim)

    # Patch embedding: linear projection to hidden_dim (GPU-backed if driver given)
    patch_embed_w = _vit_lookup_weight(
        hf_weights, "vit.patch_embed.proj.weight", "patch_embed.proj.weight"
    )
    patch_embed_b = _vit_lookup_weight(
        hf_weights, "vit.patch_embed.proj.bias", "patch_embed.proj.bias"
    )
    if patch_embed_w is None:
        # No stored projection: fall back to a small random init and zero bias.
        patch_embed_w = np.random.randn(patch_dim, hidden_dim).astype(np.float32) * 0.02
        patch_embed_b = np.zeros(hidden_dim, dtype=np.float32)
    x = _vit_matmul(driver, scheduler, "patch_embed", x, patch_embed_w)
    if patch_embed_b is not None:
        x = x + patch_embed_b

    # Add class token if present
    class_token = _vit_lookup_weight(hf_weights, "vit.cls_token", "cls_token")
    if class_token is not None:
        cls_tok = np.tile(class_token, (batch, 1, 1))
        x = np.concatenate([cls_tok, x], axis=1)

    # Add positional encoding, truncated to the current sequence length
    pos_embed = _vit_lookup_weight(hf_weights, "vit.pos_embed", "pos_embed")
    if pos_embed is not None:
        x = x + pos_embed[:, : x.shape[1], :]

    # --- Transformer Encoder ---
    block_weights_list = map_hf_weights_to_blocks(
        hf_weights, num_layers, prefix="vit.encoder.layer."
    )
    encoder = TransformerEncoder(
        vocab_size=None,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        max_seq_len=max_seq_len,
        embedding_weights=None,
        block_weights_list=block_weights_list,
        driver=driver,
        scheduler=scheduler,
    )
    x = encoder.forward(x)

    # --- Final projection (classification head) ---
    # First token of the sequence (the class token when one was prepended).
    cls_out = x[:, 0]
    head_w = _vit_lookup_weight(hf_weights, "vit.head.weight", "head.weight")
    head_b = _vit_lookup_weight(hf_weights, "vit.head.bias", "head.bias")
    if head_w is None:
        num_classes = config.get("num_classes", 1000) if hasattr(config, "get") else 1000
        head_w = np.random.randn(hidden_dim, num_classes).astype(np.float32) * 0.02
        head_b = np.zeros(num_classes, dtype=np.float32)
    logits = _vit_matmul(driver, scheduler, "head", cls_out, head_w)
    if head_b is not None:
        logits = logits + head_b
    return logits


def run_gpt2_inference(
    hf_weights,
    config,
    tokenizer_name,
    prompt,
    max_length=20,
    driver=None,
    scheduler=None,
):
    """
    Run GPT-2 (decoder-only, causal LM) inference. Returns decoded text. Uses GPU driver if provided.

    Decoding is greedy: at every step the whole sequence generated so far is
    re-run through the model and the arg-max token of the last position is
    appended.

    Args:
        hf_weights: dict of model weights; must contain
            ``transformer.wte.weight``. ``lm_head.weight`` is optional and
            defaults to the token embedding matrix.
        config: model config, forwarded to ``parse_hf_config``.
        tokenizer_name: name/path given to ``AutoTokenizer.from_pretrained``.
        prompt: text to continue.
        max_length: maximum number of new tokens to generate.
        driver: forwarded to ``TransformerEncoder``.
        scheduler: forwarded to ``TransformerEncoder``.
    Returns:
        str: decoded prompt plus generated tokens (first batch row).
    Raises:
        KeyError: if ``transformer.wte.weight`` is missing from ``hf_weights``.
    """
    cfg = parse_hf_config(config)
    num_layers = cfg["num_layers"]
    num_heads = cfg["num_heads"]
    hidden_dim = cfg["hidden_dim"]
    vocab_size = cfg["vocab_size"]
    max_seq_len = cfg["max_seq_len"]
    block_weights_list = map_hf_weights_to_blocks(hf_weights, num_layers)
    embedding_weights = hf_weights["transformer.wte.weight"]
    # Fall back to the input embedding when no separate output head is stored.
    lm_head = hf_weights.get("lm_head.weight", embedding_weights)
    # The projection matrix does not change between steps; transpose it once.
    lm_head_t = lm_head.T
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    input_ids = tokenizer(prompt, return_tensors="np").input_ids
    decoder = TransformerEncoder(
        vocab_size,
        hidden_dim,
        num_layers,
        num_heads,
        max_seq_len,
        embedding_weights,
        block_weights_list,
        driver=driver,
        scheduler=scheduler,
    )
    generated = input_ids.copy()
    for _ in range(max_length):
        # NOTE(review): forward() receives token ids only -- no causal mask is
        # passed. Confirm TransformerEncoder masks internally; otherwise earlier
        # positions can attend to later tokens, unlike real GPT-2.
        # NOTE(review): nothing here stops the sequence from growing past
        # max_seq_len -- confirm how forward() handles that.
        x = decoder.forward(generated)
        logits = final_linear_projection(x, lm_head_t)
        next_token = np.argmax(logits[:, -1, :], axis=-1)
        generated = np.concatenate([generated, next_token[:, None]], axis=1)
        # Only the first row is inspected; the prompt is tokenized as one string.
        if next_token[0] == tokenizer.eos_token_id:
            break
    return tokenizer.decode(generated[0])


def run_bert_inference(
    hf_weights, config, tokenizer_name, prompt, task="mlm", driver=None, scheduler=None
):
    """
    Run BERT (encoder-only) inference. For 'mlm', returns decoded text. For 'classification', returns logits.

    Args:
        hf_weights: dict of model weights; must contain
            ``bert.embeddings.word_embeddings.weight`` plus, depending on
            ``task``, ``cls.predictions.decoder.weight`` ('mlm') or
            ``cls.seq_relationship.weight`` / ``.bias`` ('classification').
        config: model config, forwarded to ``parse_hf_config``.
        tokenizer_name: name/path given to ``AutoTokenizer.from_pretrained``.
        prompt: input text; padded/truncated to ``max_seq_len`` tokens.
        task: ``"mlm"`` or ``"classification"``.
        driver: forwarded to ``TransformerEncoder``.
        scheduler: forwarded to ``TransformerEncoder``.
    Returns:
        str for 'mlm' (arg-max token at every position, decoded);
        np.ndarray of logits for 'classification' (computed from position 0).
    Raises:
        ValueError: if ``task`` is neither 'mlm' nor 'classification'.
        KeyError: if a required weight is missing from ``hf_weights``.
    """
    # Reject a bad task before loading the tokenizer and running the encoder.
    if task not in ("mlm", "classification"):
        raise ValueError("Unknown BERT task")

    cfg = parse_hf_config(config)
    num_layers = cfg["num_layers"]
    num_heads = cfg["num_heads"]
    hidden_dim = cfg["hidden_dim"]
    vocab_size = cfg["vocab_size"]
    max_seq_len = cfg["max_seq_len"]
    block_weights_list = map_hf_weights_to_blocks_bert(hf_weights, num_layers)
    embedding_weights = hf_weights["bert.embeddings.word_embeddings.weight"]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # NOTE(review): the input is padded to max_seq_len and no attention mask is
    # given to the encoder, so padding positions take part in attention and,
    # for 'mlm', also appear in the decoded output -- confirm this is intended.
    input_ids = tokenizer(
        prompt,
        return_tensors="np",
        padding="max_length",
        max_length=max_seq_len,
        truncation=True,
    ).input_ids
    encoder = TransformerEncoder(
        vocab_size,
        hidden_dim,
        num_layers,
        num_heads,
        max_seq_len,
        embedding_weights,
        block_weights_list,
        driver=driver,
        scheduler=scheduler,
    )
    x = encoder.forward(input_ids)

    if task == "mlm":
        lm_head = hf_weights["cls.predictions.decoder.weight"]
        logits = final_linear_projection(x, lm_head.T)
        pred_ids = np.argmax(logits, axis=-1)
        return tokenizer.decode(pred_ids[0])

    # task == "classification": project the first position's hidden state.
    pooled = x[:, 0, :]
    classifier = hf_weights["cls.seq_relationship.weight"]
    bias = hf_weights["cls.seq_relationship.bias"]
    # NOTE(review): this projection runs with np.matmul, not the driver.
    return np.matmul(pooled, classifier.T) + bias


def run_t5_inference(
    hf_weights,
    config,
    tokenizer_name,
    prompt,
    max_length=20,
    driver=None,
    scheduler=None,
):
    """
    Run T5 (encoder-decoder, seq2seq) inference. Returns decoded text.

    The prompt is encoded once; the decoder then generates greedily, re-running
    on the full target prefix each step and appending the arg-max token of the
    last position until EOS or ``max_length`` new tokens.

    Args:
        hf_weights: dict of model weights; must contain ``shared.weight``.
            ``lm_head.weight`` is optional and defaults to the shared embedding.
        config: model config, forwarded to ``parse_hf_config``.
        tokenizer_name: name/path given to ``AutoTokenizer.from_pretrained``.
        prompt: source text; padded/truncated to ``max_seq_len`` tokens.
        max_length: maximum number of tokens to generate.
        driver: forwarded to the encoder and decoder.
        scheduler: forwarded to the encoder and decoder.
    Returns:
        str: decoded output sequence, including the start token (first row).
    Raises:
        KeyError: if ``shared.weight`` is missing from ``hf_weights``.
    """
    cfg = parse_hf_config(config)
    num_layers = cfg["num_layers"]
    num_heads = cfg["num_heads"]
    hidden_dim = cfg["hidden_dim"]
    vocab_size = cfg["vocab_size"]
    max_seq_len = cfg["max_seq_len"]
    enc_block_weights = map_hf_weights_to_blocks_t5(
        hf_weights, num_layers, prefix="encoder.block."
    )
    dec_block_weights = map_hf_weights_to_blocks_t5(
        hf_weights, num_layers, prefix="decoder.block."
    )
    embedding = hf_weights["shared.weight"]
    # Output projection is loop-invariant: resolve and transpose it once.
    lm_head_t = hf_weights.get("lm_head.weight", embedding).T
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # NOTE(review): the source is padded to max_seq_len and no padding mask is
    # handed to the encoder or decoder -- confirm this is intended.
    src_ids = tokenizer(
        prompt,
        return_tensors="np",
        padding="max_length",
        max_length=max_seq_len,
        truncation=True,
    ).input_ids
    encoder = TransformerEncoder(
        vocab_size,
        hidden_dim,
        num_layers,
        num_heads,
        max_seq_len,
        embedding,
        enc_block_weights,
        driver=driver,
        scheduler=scheduler,
    )
    decoder = TransformerDecoder(
        vocab_size,
        hidden_dim,
        num_layers,
        num_heads,
        max_seq_len,
        embedding,
        dec_block_weights,
        driver=driver,
        scheduler=scheduler,
    )
    enc_out = encoder.forward(src_ids)

    # Start with <pad> or <bos> token for generation, else token id 0.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    bos_id = getattr(tokenizer, "bos_token_id", None)
    if pad_id is not None:
        decoder_input = np.array([[pad_id]])
    elif bos_id is not None:
        decoder_input = np.array([[bos_id]])
    else:
        decoder_input = np.zeros((1, 1), dtype=np.int64)

    generated = decoder_input.copy()
    for _ in range(max_length):
        # Causal self-attention mask sized to the current target prefix.
        self_mask = create_causal_mask(generated.shape[1])
        x = decoder.forward(generated, enc_out, self_mask)
        logits = final_linear_projection(x, lm_head_t)
        next_token = np.argmax(logits[:, -1, :], axis=-1)
        generated = np.concatenate([generated, next_token[:, None]], axis=1)
        if next_token[0] == tokenizer.eos_token_id:
            break
    return tokenizer.decode(generated[0])


# Example usage (requires loaded weights and config):
# output = run_gpt2_inference(hf_weights, config, 'gpt2', 'Hello, world!', max_length=20)
# output = run_bert_inference(hf_weights, config, 'bert-base-uncased', 'The quick brown fox', task='mlm')
# output = run_t5_inference(hf_weights, config, 't5-small', 'Translate English to German: The house is wonderful.', max_length=20)
