# coding=utf-8
# Copyright 2025 NVIDIA Corporation. All rights reserved.
"""Nemotron-Flash model configuration."""

import math

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

class NemotronFlashConfig(PretrainedConfig):
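    """
    Configuration class for Nemotron-Flash models.

    Stores the hyperparameters of the hybrid decoder stack defined below: standard
    attention settings (``num_attention_heads``, ``num_key_value_heads``,
    ``sliding_window``), mixture-of-experts settings (``num_experts``,
    ``num_experts_per_tok``, ``router_aux_loss_coef``), and Mamba settings
    (``mamba_d_state``, ``mamba_d_conv``, ``mamba_expand``, ``mamba_dt_rank``).
    Inherits from :class:`~transformers.PretrainedConfig`; remaining keyword
    arguments are forwarded to the parent constructor.
    """
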
model_type = "nemotron_flash"
keys_to_ignore_at_inference = ["past_key_values"]
    def __init__(
        self,
        vocab_size=65536,
        tie_word_embeddings=False,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        calc_logits_for_entire_prompt=False,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        sliding_window=None,
        max_position_embeddings=262144,
        orig_max_position_embeddings=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_experts=16,
        use_mamba_kernels=True,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_dt_rank="auto",
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        mamba_inner_layernorms=True,
        hybrid_decoder_layer="mamba",
        global_attn_idx=None,
        attn_implementation_new="flash_attention_2",
        mamba2_headdim=64,
        rope_type=None,
        layer_types=None,
        ffn_expand_ratio=None,
        d_conv=4,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.max_position_embeddings = max_position_embeddings
        self.orig_max_position_embeddings = orig_max_position_embeddings
        self.attention_dropout = attention_dropout
        # Grouped-query attention: fall back to one key/value head per attention
        # head when num_key_value_heads is not specified.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.calc_logits_for_entire_prompt = calc_logits_for_entire_prompt
        # Mixture-of-experts routing settings.
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        # Mamba (state-space) layer settings.
        self.use_mamba_kernels = use_mamba_kernels
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        # "auto" sizes the dt-projection rank from the hidden size.
        self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
        self.mamba_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias
        self.mamba_inner_layernorms = mamba_inner_layernorms
        # Optional overrides that are only read from **kwargs.
        self.kq_norm = kwargs.pop("kq_norm", None)
        self.rope = kwargs.pop("rope", False)
        self.rope_theta = kwargs.pop("rope_theta", 10000.0)
        self.num_memory_tokens = kwargs.pop("num_memory_tokens", 0)
        self.attn_hidden_size = kwargs.pop("attn_hidden_size", -1)
        self.kq_head_dim = kwargs.pop("kq_head_dim", -1)
        self.v_head_dim = kwargs.pop("v_head_dim", -1)
        self.new_seq_length = 2048
        self.hybrid_decoder_layer = hybrid_decoder_layer
        self.global_attn_idx = global_attn_idx
        self.attn_implementation_new = attn_implementation_new
        self.mamba2_headdim = mamba2_headdim
        self.rope_type = rope_type
        self.layer_types = layer_types
        self.ffn_expand_ratio = ffn_expand_ratio
        self.d_conv = d_conv
        self.mlp_hidden_act = kwargs.pop("mlp_hidden_act", "silu")

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
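

# A minimal usage sketch, assuming only the behaviour visible in the class above
# and the standard `PretrainedConfig` API from transformers: it exercises the
# "auto" rule for `mamba_dt_rank` and the inherited dict/JSON serialization.
if __name__ == "__main__":
    # With the default hidden_size=4096, "auto" resolves to ceil(4096 / 16) == 256.
    config = NemotronFlashConfig()
    assert config.mamba_dt_rank == 256

    # Overriding hidden_size changes the resolved rank: ceil(2048 / 16) == 128.
    small = NemotronFlashConfig(hidden_size=2048, num_hidden_layers=24)
    assert small.mamba_dt_rank == 128

    # Round-trip through the serialization inherited from PretrainedConfig.
    restored = NemotronFlashConfig.from_dict(small.to_dict())
    assert restored.hidden_size == small.hidden_size
    print(restored.to_json_string()[:200])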