avtc commited on
Commit
83685eb
·
verified ·
1 Parent(s): 80c6cfe

Upload folder using huggingface_hub

Browse files
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. __init__.py +26 -0
  3. added_tokens.json +56 -0
  4. chat_template.jinja +159 -0
  5. config.json +147 -0
  6. configuration_minimax_m2.py +131 -0
  7. generation_config.json +5 -0
  8. merges.txt +0 -0
  9. model-00001-of-00033.safetensors +3 -0
  10. model-00002-of-00033.safetensors +3 -0
  11. model-00003-of-00033.safetensors +3 -0
  12. model-00004-of-00033.safetensors +3 -0
  13. model-00005-of-00033.safetensors +3 -0
  14. model-00006-of-00033.safetensors +3 -0
  15. model-00007-of-00033.safetensors +3 -0
  16. model-00008-of-00033.safetensors +3 -0
  17. model-00009-of-00033.safetensors +3 -0
  18. model-00010-of-00033.safetensors +3 -0
  19. model-00011-of-00033.safetensors +3 -0
  20. model-00012-of-00033.safetensors +3 -0
  21. model-00013-of-00033.safetensors +3 -0
  22. model-00014-of-00033.safetensors +3 -0
  23. model-00015-of-00033.safetensors +3 -0
  24. model-00016-of-00033.safetensors +3 -0
  25. model-00017-of-00033.safetensors +3 -0
  26. model-00018-of-00033.safetensors +3 -0
  27. model-00019-of-00033.safetensors +3 -0
  28. model-00020-of-00033.safetensors +3 -0
  29. model-00021-of-00033.safetensors +3 -0
  30. model-00022-of-00033.safetensors +3 -0
  31. model-00023-of-00033.safetensors +3 -0
  32. model-00024-of-00033.safetensors +3 -0
  33. model-00025-of-00033.safetensors +3 -0
  34. model-00026-of-00033.safetensors +3 -0
  35. model-00027-of-00033.safetensors +3 -0
  36. model-00028-of-00033.safetensors +3 -0
  37. model-00029-of-00033.safetensors +3 -0
  38. model-00030-of-00033.safetensors +3 -0
  39. model-00031-of-00033.safetensors +3 -0
  40. model-00032-of-00033.safetensors +3 -0
  41. model-00033-of-00033.safetensors +3 -0
  42. model.safetensors.index.json +3 -0
  43. modeling_minimax_m2.py +843 -0
  44. quant_log.csv +0 -0
  45. quantize_config.json +27 -0
  46. special_tokens_map.json +76 -0
  47. test_minimax_m2_hf.py +178 -0
  48. tokenizer.json +3 -0
  49. tokenizer_config.json +498 -0
  50. vocab.json +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model.safetensors.index.json filter=lfs diff=lfs merge=lfs -text
37
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
2
+ # SPDX-FileCopyrightText: 2024-2025 [email protected]
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # Contact: [email protected], x.com/qubitium
5
+ #
6
+ # """MiniMax M2 Hugging Face remote code support."""
7
+
8
+ from .configuration_minimax_m2 import MiniMaxM2Config
9
+ from .modeling_minimax_m2 import (
10
+ MiniMaxForCausalLM,
11
+ MiniMaxM2ForCausalLM,
12
+ MiniMaxM2Model,
13
+ MiniMaxM2PreTrainedModel,
14
+ MiniMaxModel,
15
+ MiniMaxPreTrainedModel,
16
+ )
17
+
18
+ __all__ = [
19
+ "MiniMaxM2Config",
20
+ "MiniMaxM2PreTrainedModel",
21
+ "MiniMaxM2Model",
22
+ "MiniMaxM2ForCausalLM",
23
+ "MiniMaxPreTrainedModel",
24
+ "MiniMaxModel",
25
+ "MiniMaxForCausalLM",
26
+ ]
added_tokens.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</minimax:tool_call>": 200053,
3
+ "</think>": 200051,
4
+ "<add_file>": 200036,
5
+ "<code_context>": 200043,
6
+ "<code_interpreter>": 200023,
7
+ "<commit_after>": 200018,
8
+ "<commit_before>": 200016,
9
+ "<commit_message>": 200040,
10
+ "<commit_msg>": 200017,
11
+ "<delete_file>": 200037,
12
+ "<edit_file>": 200039,
13
+ "<empty_output>": 200015,
14
+ "<empty_source_file>": 200041,
15
+ "<file_content>": 200044,
16
+ "<file_sep>": 200049,
17
+ "<filename>": 200006,
18
+ "<filepath>": 200048,
19
+ "<fim_middle>": 200002,
20
+ "<fim_pad>": 200004,
21
+ "<fim_prefix>": 200001,
22
+ "<fim_suffix>": 200003,
23
+ "<function_call>": 200022,
24
+ "<gh_stars>": 200007,
25
+ "<issue_closed>": 200010,
26
+ "<issue_comment>": 200009,
27
+ "<issue_start>": 200008,
28
+ "<jupyter_code>": 200013,
29
+ "<jupyter_error>": 200035,
30
+ "<jupyter_output>": 200014,
31
+ "<jupyter_start>": 200011,
32
+ "<jupyter_text>": 200012,
33
+ "<minimax:tool_call>": 200052,
34
+ "<pr_start>": 200046,
35
+ "<rename_file>": 200038,
36
+ "<repo_struct>": 200042,
37
+ "<reponame>": 200005,
38
+ "<review_comment>": 200047,
39
+ "<source_files>": 200045,
40
+ "<think>": 200050,
41
+ "[e~[": 200020,
42
+ "]!d~[": 200021,
43
+ "]!p~[": 200000,
44
+ "]<]end of image[>[": 200030,
45
+ "]<]end of speech[>[": 200028,
46
+ "]<]end of video[>[": 200032,
47
+ "]<]image[>[": 200025,
48
+ "]<]speech[>[": 200024,
49
+ "]<]start of image[>[": 200029,
50
+ "]<]start of speech[>[": 200027,
51
+ "]<]start of video[>[": 200031,
52
+ "]<]video[>[": 200026,
53
+ "]<]vision pad[>[": 200033,
54
+ "]~!b[": 200034,
55
+ "]~b]": 200019
56
+ }
chat_template.jinja ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {# ----------‑‑‑ special token variables ‑‑‑---------- #}
2
+ {%- set toolcall_begin_token = '<minimax:tool_call>' -%}
3
+ {%- set toolcall_end_token = '</minimax:tool_call>' -%}
4
+ {#- Tool Rendering Functions ============================================== -#}
5
+ {%- macro render_tool_namespace(namespace_name, tool_list) -%}
6
+ {%- for tool in tool_list -%}
7
+ <tool>{{ tool.function | tojson(ensure_ascii=False) }}</tool>
8
+ {% endfor -%}
9
+ {%- endmacro -%}
10
+ {%- macro visible_text(content) -%}
11
+ {%- if content is string -%}
12
+ {{ content }}
13
+ {%- elif content is iterable and content is not mapping -%}
14
+ {%- for item in content -%}
15
+ {%- if item is mapping and item.type == 'text' -%}
16
+ {{- item.text }}
17
+ {%- elif item is string -%}
18
+ {{- item }}
19
+ {%- endif -%}
20
+ {%- endfor -%}
21
+ {%- else -%}
22
+ {{- content }}
23
+ {%- endif -%}
24
+ {%- endmacro -%}
25
+ {#- System Message Construction ============================================ -#}
26
+ {%- macro build_system_message(system_message) -%}
27
+ {%- if system_message and system_message.content -%}
28
+ {{- visible_text(system_message.content) }}
29
+ {%- else -%}
30
+ {%- if model_identity is not defined -%}
31
+ {%- set model_identity = "You are a helpful assistant." -%}
32
+ {%- endif -%}
33
+ {{- model_identity }}
34
+ {%- endif -%}
35
+
36
+ {#- Handle current_date -#}
37
+ {%- if system_message and system_message.current_date -%}
38
+ {{- '\n' ~ 'Current date: ' + system_message.current_date }}
39
+ {%- endif -%}
40
+ {#- Handle current_location -#}
41
+ {%- if system_message and system_message.current_location -%}
42
+ {{- '\n' ~ 'Current location: ' + system_message.current_location }}
43
+ {%- endif -%}
44
+ {%- endmacro -%}
45
+ {#- Main Template Logic ================================================= -#}
46
+ {#- Extract system message (only first message if it's system) -#}
47
+ {%- set system_message = none -%}
48
+ {%- set conversation_messages = messages -%}
49
+ {%- if messages and messages[0].role == "system" -%}
50
+ {%- set system_message = messages[0] -%}
51
+ {%- set conversation_messages = messages[1:] -%}
52
+ {%- endif -%}
53
+ {#- Get the last user message turn, for interleved thinking -#}
54
+ {%- set ns = namespace(last_user_index=-1) %}
55
+ {% for m in conversation_messages %}
56
+ {%- if m.role == 'user' %}
57
+ {% set ns.last_user_index = loop.index0 -%}
58
+ {%- endif %}
59
+ {%- endfor %}
60
+ {#- Render system message -#}
61
+ {{- ']~!b[' ~ ']~b]system' ~ '\n' }}
62
+ {{- build_system_message(system_message) }}
63
+ {#- Render tools if available -#}
64
+ {%- if tools -%}
65
+ {{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }}
66
+ {{- '\n' ~ '<tools>' ~ '\n' }}
67
+ {{- render_tool_namespace("functions", tools) }}
68
+ {{- '</tools>' ~ '\n\n' }}
69
+ {{- 'When making tool calls, use XML format to invoke tools and pass parameters:' ~ '\n' }}
70
+ {{- '\n' ~ toolcall_begin_token }}
71
+ <invoke name="tool-name-1">
72
+ <parameter name="param-key-1">param-value-1</parameter>
73
+ <parameter name="param-key-2">param-value-2</parameter>
74
+ ...
75
+ </invoke>
76
+ {{- '\n' ~ toolcall_end_token }}
77
+ {%- endif -%}
78
+ {{- '[e~[\n' }}
79
+
80
+ {#- Render messages -#}
81
+ {%- set last_tool_call = namespace(name=none) -%}
82
+ {%- for message in conversation_messages -%}
83
+ {%- if message.role == 'assistant' -%}
84
+ {#- Only render reasoning_content if no user message follows -#}
85
+ {{- ']~b]ai' ~ '\n' }}
86
+
87
+ {%- set reasoning_content = '' %}
88
+ {%- set content = visible_text(message.content) %}
89
+ {%- if message.reasoning_content is string %}
90
+ {%- set reasoning_content = message.reasoning_content %}
91
+ {%- else %}
92
+ {%- if '</think>' in content %}
93
+ {%- set reasoning_content = content.split('</think>')[0].strip('\n').split('<think>')[-1].strip('\n') %}
94
+ {%- set content = content.split('</think>')[-1].strip('\n') %}
95
+ {%- endif %}
96
+ {%- endif %}
97
+ {%- if reasoning_content and loop.index0 > ns.last_user_index -%}
98
+ {{- '<think>' ~ '\n' ~ reasoning_content ~ '\n' ~ '</think>' ~ '\n\n' }}
99
+ {%- endif -%}
100
+ {%- if content -%}
101
+ {{- content }}
102
+ {%- endif -%}
103
+ {%- if message.tool_calls -%}
104
+ {{- '\n' ~ toolcall_begin_token ~ '\n' }}
105
+
106
+ {%- for tool_call in message.tool_calls -%}
107
+ {%- if tool_call.function %}
108
+ {%- set tool_call = tool_call.function %}
109
+ {%- endif %}
110
+ {{- '<invoke name="' + tool_call.name + '">' }}
111
+ {% set _args = tool_call.arguments %}
112
+ {%- for k, v in _args.items() %}
113
+ {{- '<parameter name="' + k + '">' }}
114
+ {{- v | tojson(ensure_ascii=False) if v is not string else v }}
115
+ {{- '</parameter>' }}
116
+ {% endfor %}
117
+ {{- '</invoke>' ~ '\n' }}
118
+ {%- endfor -%}
119
+
120
+ {{- toolcall_end_token}}
121
+ {%- set last_tool_call.name = message.tool_calls[-1].name -%}
122
+ {%- else -%}
123
+ {%- set last_tool_call.name = none -%}
124
+ {%- endif -%}
125
+ {{- '[e~[' ~ '\n' }}
126
+
127
+ {%- elif message.role == 'tool' -%}
128
+ {%- if last_tool_call.name is none -%}
129
+ {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
130
+ {%- endif -%}
131
+ {%- if loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%}
132
+ {{- ']~b]tool' }}
133
+ {%- endif -%}
134
+ {%- if message.content is string -%}
135
+ {{- '\n<response>' }}
136
+ {{- message.content }}
137
+ {{- '</response>' }}
138
+ {%- else -%}
139
+ {%- for tr in message.content -%}
140
+ {{- '\n<response>' }}
141
+ {{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }}
142
+ {{- '\n</response>' }}
143
+ {%- endfor -%}
144
+ {%- endif -%}
145
+ {%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%}
146
+ {{- '[e~[\n' -}}
147
+ {%- endif -%}
148
+
149
+ {%- elif message.role == 'user' -%}
150
+ {{- ']~b]user' ~ '\n' }}
151
+ {{- visible_text(message.content) }}
152
+ {{- '[e~[' ~ '\n' }}
153
+ {%- endif -%}
154
+ {%- endfor -%}
155
+
156
+ {#- Generation prompt -#}
157
+ {%- if add_generation_prompt -%}
158
+ {{- ']~b]ai' ~ '\n' ~ '<think>' ~ '\n' }}
159
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MiniMaxM2ForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "attn_type_list": [
7
+ 1,
8
+ 1,
9
+ 1,
10
+ 1,
11
+ 1,
12
+ 1,
13
+ 1,
14
+ 1,
15
+ 1,
16
+ 1,
17
+ 1,
18
+ 1,
19
+ 1,
20
+ 1,
21
+ 1,
22
+ 1,
23
+ 1,
24
+ 1,
25
+ 1,
26
+ 1,
27
+ 1,
28
+ 1,
29
+ 1,
30
+ 1,
31
+ 1,
32
+ 1,
33
+ 1,
34
+ 1,
35
+ 1,
36
+ 1,
37
+ 1,
38
+ 1,
39
+ 1,
40
+ 1,
41
+ 1,
42
+ 1,
43
+ 1,
44
+ 1,
45
+ 1,
46
+ 1,
47
+ 1,
48
+ 1,
49
+ 1,
50
+ 1,
51
+ 1,
52
+ 1,
53
+ 1,
54
+ 1,
55
+ 1,
56
+ 1,
57
+ 1,
58
+ 1,
59
+ 1,
60
+ 1,
61
+ 1,
62
+ 1,
63
+ 1,
64
+ 1,
65
+ 1,
66
+ 1,
67
+ 1,
68
+ 1
69
+ ],
70
+ "attn_window_size": null,
71
+ "auto_map": {
72
+ "AutoConfig": "configuration_minimax_m2.MiniMaxM2Config",
73
+ "AutoModelForCausalLM": "modeling_minimax_m2.MiniMaxM2ForCausalLM"
74
+ },
75
+ "dtype": "bfloat16",
76
+ "head_dim": 128,
77
+ "hidden_act": "silu",
78
+ "hidden_size": 3072,
79
+ "initializer_range": 0.02,
80
+ "intermediate_size": 1536,
81
+ "layernorm_full_attention_beta": 1.0,
82
+ "layernorm_linear_attention_beta": 1.0,
83
+ "layernorm_mlp_beta": 1.0,
84
+ "max_model_len": null,
85
+ "max_position_embeddings": 196608,
86
+ "mlp_intermediate_size": 8192,
87
+ "model_type": "minimax",
88
+ "mtp_transformer_layers": 1,
89
+ "num_attention_heads": 48,
90
+ "num_expert_group": null,
91
+ "num_experts_per_tok": 8,
92
+ "num_hidden_layers": 62,
93
+ "num_key_value_heads": 8,
94
+ "num_local_experts": 256,
95
+ "num_mtp_modules": 3,
96
+ "output_router_logits": false,
97
+ "partial_rotary_factor": 0.5,
98
+ "qk_norm_type": "per_layer",
99
+ "quantization_config": {
100
+ "bits": 4,
101
+ "checkpoint_format": "gptq",
102
+ "desc_act": false,
103
+ "dynamic": {
104
+ "-:.*self_attn": {}
105
+ },
106
+ "group_size": 32,
107
+ "lm_head": false,
108
+ "meta": {
109
+ "act_group_aware": true,
110
+ "damp_auto_increment": 0.01,
111
+ "damp_percent": 0.01,
112
+ "mse": 0.0,
113
+ "quantizer": [
114
+ "gptqmodel:5.0.0-dev0"
115
+ ],
116
+ "static_groups": false,
117
+ "true_sequential": true,
118
+ "uri": "https://github.com/modelcloud/gptqmodel",
119
+ "v2": false,
120
+ "v2_alpha": 0.25
121
+ },
122
+ "pack_dtype": "int32",
123
+ "quant_method": "gptq",
124
+ "sym": true
125
+ },
126
+ "rms_norm_eps": 1e-06,
127
+ "rope_scaling": null,
128
+ "rope_theta": 5000000,
129
+ "rotary_dim": 64,
130
+ "routed_scaling_factor": 1.0,
131
+ "router_aux_loss_coef": 0.001,
132
+ "router_jitter_noise": 0.0,
133
+ "scoring_func": "sigmoid",
134
+ "shared_intermediate_size": 0,
135
+ "shared_moe_mode": "sigmoid",
136
+ "sliding_window": null,
137
+ "swa_rope_theta": -1.0,
138
+ "tie_word_embeddings": false,
139
+ "topk_group": null,
140
+ "transformers_version": "4.57.1",
141
+ "use_cache": true,
142
+ "use_grouped_topk": true,
143
+ "use_mtp": true,
144
+ "use_qk_norm": true,
145
+ "use_routing_bias": true,
146
+ "vocab_size": 200064
147
+ }
configuration_minimax_m2.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
2
+ # SPDX-FileCopyrightText: 2024-2025 [email protected]
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # Contact: [email protected], x.com/qubitium
5
+
6
+ """Configuration for the MiniMax M2 architecture."""
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import List, Optional, Union
11
+
12
+ from transformers.configuration_utils import PretrainedConfig
13
+
14
+
15
+ class MiniMaxM2Config(PretrainedConfig):
16
+ model_type = "minimax"
17
+
18
+ def __init__(
19
+ self,
20
+ vocab_size: int = 200_064,
21
+ hidden_size: int = 3_072,
22
+ intermediate_size: int = 1_536,
23
+ mlp_intermediate_size: int = 8_192,
24
+ num_hidden_layers: int = 62,
25
+ num_attention_heads: int = 48,
26
+ num_key_value_heads: int = 8,
27
+ head_dim: Optional[int] = 128,
28
+ num_local_experts: int = 256,
29
+ num_experts_per_tok: int = 8,
30
+ attn_type_list: Optional[List[int]] = None,
31
+ attention_dropout: float = 0.0,
32
+ hidden_act: str = "silu",
33
+ rms_norm_eps: float = 1e-6,
34
+ max_position_embeddings: int = 196_608,
35
+ rope_theta: float = 5_000_000.0,
36
+ rotary_dim: int = 64,
37
+ rope_scaling: Optional[dict] = None,
38
+ use_qk_norm: bool = True,
39
+ qk_norm_type: str = "per_layer",
40
+ use_routing_bias: bool = True,
41
+ scoring_func: str = "sigmoid",
42
+ router_aux_loss_coef: float = 0.001,
43
+ router_jitter_noise: float = 0.0,
44
+ output_router_logits: bool = False,
45
+ use_grouped_topk: bool = True,
46
+ num_expert_group: Optional[int] = None,
47
+ topk_group: Optional[int] = None,
48
+ routed_scaling_factor: float = 1.0,
49
+ layernorm_full_attention_beta: float = 1.0,
50
+ layernorm_linear_attention_beta: float = 1.0,
51
+ layernorm_mlp_beta: float = 1.0,
52
+ shared_intermediate_size: int = 0,
53
+ shared_moe_mode: str = "sigmoid",
54
+ use_mtp: bool = True,
55
+ num_mtp_modules: int = 3,
56
+ mtp_transformer_layers: int = 1,
57
+ attn_window_size: Optional[Union[int, List[int]]] = None,
58
+ swa_rope_theta: float = -1.0,
59
+ sliding_window: Optional[int] = None,
60
+ initializer_range: float = 0.02,
61
+ tie_word_embeddings: bool = False,
62
+ max_model_len: Optional[int] = None,
63
+ bos_token_id: Optional[int] = None,
64
+ eos_token_id: Optional[int] = None,
65
+ pad_token_id: Optional[int] = None,
66
+ use_cache: bool = True,
67
+ **kwargs,
68
+ ) -> None:
69
+ quantization_config = kwargs.pop("quantization_config", None)
70
+ transformers_version = kwargs.pop("transformers_version", None)
71
+
72
+ super().__init__(
73
+ bos_token_id=bos_token_id,
74
+ eos_token_id=eos_token_id,
75
+ tie_word_embeddings=tie_word_embeddings,
76
+ pad_token_id=pad_token_id,
77
+ **kwargs,
78
+ )
79
+
80
+ self.vocab_size = vocab_size
81
+ self.hidden_size = hidden_size
82
+ self.intermediate_size = intermediate_size
83
+ self.mlp_intermediate_size = mlp_intermediate_size
84
+ self.num_hidden_layers = num_hidden_layers
85
+ self.num_attention_heads = num_attention_heads
86
+ self.num_key_value_heads = num_key_value_heads
87
+ self.head_dim = head_dim or hidden_size // num_attention_heads
88
+ self.num_local_experts = num_local_experts
89
+ self.num_experts_per_tok = num_experts_per_tok
90
+ self.attn_type_list = attn_type_list or [1] * num_hidden_layers
91
+ self.attention_dropout = attention_dropout
92
+ self.hidden_act = hidden_act
93
+ self.rms_norm_eps = rms_norm_eps
94
+ self.max_position_embeddings = max_position_embeddings
95
+ self.rope_theta = rope_theta
96
+ self.rotary_dim = rotary_dim
97
+ self.rope_scaling = rope_scaling
98
+ self.use_qk_norm = use_qk_norm
99
+ self.qk_norm_type = qk_norm_type
100
+ self.use_routing_bias = use_routing_bias
101
+ self.scoring_func = scoring_func
102
+ self.router_aux_loss_coef = router_aux_loss_coef
103
+ self.router_jitter_noise = router_jitter_noise
104
+ self.output_router_logits = output_router_logits
105
+ self.use_grouped_topk = use_grouped_topk
106
+ self.num_expert_group = num_expert_group
107
+ self.topk_group = topk_group
108
+ self.routed_scaling_factor = routed_scaling_factor
109
+ self.layernorm_full_attention_beta = layernorm_full_attention_beta
110
+ self.layernorm_linear_attention_beta = layernorm_linear_attention_beta
111
+ self.layernorm_mlp_beta = layernorm_mlp_beta
112
+ self.shared_intermediate_size = shared_intermediate_size
113
+ self.shared_moe_mode = shared_moe_mode
114
+ self.use_mtp = use_mtp
115
+ self.num_mtp_modules = num_mtp_modules
116
+ self.mtp_transformer_layers = mtp_transformer_layers
117
+ self.attn_window_size = attn_window_size
118
+ self.swa_rope_theta = swa_rope_theta
119
+ self.sliding_window = sliding_window
120
+ self.initializer_range = initializer_range
121
+ self.max_model_len = max_model_len
122
+ self.use_cache = use_cache
123
+
124
+ # Convenient accessor used by rotary embedding helper
125
+ self.partial_rotary_factor = float(self.rotary_dim) / float(self.head_dim)
126
+ if quantization_config is not None:
127
+ self.quantization_config = quantization_config
128
+ self.transformers_version = transformers_version
129
+
130
+
131
+ __all__ = ["MiniMaxM2Config"]
generation_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "do_sample": true,
3
+ "top_k": 40,
4
+ "transformers_version": "4.57.1"
5
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0d505fe80382b35a99d7859384ff7d7dda1405064814454b3276965792a0f42
3
+ size 4278383871
model-00002-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b069194c6331a3ec757baed2327bec020bba55a270113231a6415d0b5a5ff5f9
3
+ size 4293815342
model-00003-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baf552df9e60ffcf767ff837719feab6e2e626ef7d803ed52de5c1c394e31eeb
3
+ size 4294262965
model-00004-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6efdb4fae1c827e418e057e858cb079daeb4212c00177ddbd3ebdac499b5c569
3
+ size 4294256821
model-00005-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67cbff78bb3b5f1deac16772922914e6c2f4e68119b15740172b192bbb05f5db
3
+ size 4294256821
model-00006-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03787e807b9c0a845791354cd13471089812c608cfe2ab91f53d3ac7e3d74985
3
+ size 4294263087
model-00007-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39eca192d119daf9d11035c033fd4c16d5c47cbae525f7c97ef9c5de687c129a
3
+ size 4294258400
model-00008-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8bf1b24512cf51b7f59f579a177684f7b6bd2ec230fb20f174ad7a7ffe51193
3
+ size 4294263223
model-00009-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f2cf119cc0c3e4083726545d857a637c82b540d4934bda9ec432c83d1d82dcd
3
+ size 4294269367
model-00010-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4827b9c3a7e1dbc5ba2f218911bd1008988b6d657ebbbf6d746944480b4e64f7
3
+ size 4294263221
model-00011-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82bcc01bd984477f24002a0f17a226ba77d5208eef4de959bfe5832826f8b310
3
+ size 4294263221
model-00012-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ac2f442b66f840a2d8f76826c3cd9537eedd29373fc794210f3844640c1f10c
3
+ size 4294269365
model-00013-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40210fa4c5eb5fe1c55d78238702aa796bdabb3b8b365c5d3fd059685da12dc3
3
+ size 4294263221
model-00014-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a9f53fa80b8c7880ff2773fa065ada21494eb66b74495cb6cebbfcf82738f51
3
+ size 4294263221
model-00015-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0569ef9e36468cd087baf26e6ff6be9bd2803cfa6a1c68e6858f8ecffbd7e3b
3
+ size 4294269365
model-00016-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f411f6798f66ed49b1c4802ed633a06564ab6a9ad2a4597b97b61e4adb190cf
3
+ size 4294263221
model-00017-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3f2a12e333cac19054e6f85fa3015c0fd5c765d807f12677790790374e40d19
3
+ size 4294263221
model-00018-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e8054767737b7831e4d6e3ae6127b278f0fc257d65d3a42c355656cab597faf
3
+ size 4294269365
model-00019-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c754a4022a3a819e1f789cb5d82bf7fa9ae0af5d946295cf952dfb771e24ce7b
3
+ size 4294263221
model-00020-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:835197416f0e33e14dbd03e2f3e32b3200a630ee6967f22ab0401c78f14a49cc
3
+ size 4294263379
model-00021-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5125dec669fc82eed7c98a60d5891f4b6af0808e489291afe96bdbb26a1087f8
3
+ size 4294269161
model-00022-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:880bcd6b2b33484e60c109345e585a68f78aa6d3af1f9a24d660537d1c6eb29b
3
+ size 4294263095
model-00023-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:496fdf032750c6e77e7378c64890c35d0585c08d427530e47f39cb99f41c5382
3
+ size 4294263095
model-00024-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd79b9693394711ea6568859ccd2c97a5713e4d1452489d6d18e9228bb0c1fac
3
+ size 4294269239
model-00025-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df4b971ef373c5fca6f3fbd22259e33ac6ce7e555fadebfa6e1d2a73bda08bef
3
+ size 4294263095
model-00026-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:742241f5c9d1f390a852fb8959843e19755a64a4faef8719fe153cfbd325d3fe
3
+ size 4294263095
model-00027-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c53a58a4752575f2a8569111aeb7c571f907e6d72c53e56410a5ecc84c1ef0c6
3
+ size 4294269239
model-00028-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:117ee235261df56dda987e2d542c263cb2167ad619963d801e5f8d4b0a4c164e
3
+ size 4294263095
model-00029-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f47c62161256fe2bd48acf3456f7921cfbe8cabc0fd2b0499784eb434f4f94b3
3
+ size 4294263095
model-00030-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e88360d0aad2edf9c99be71e421993832b58b1d2f614907756763def8d75f91f
3
+ size 4294269361
model-00031-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce9d25d5e30662c2768358fde7312c8c3224beea6cac8fd3c2c1554f2097cf39
3
+ size 4294263223
model-00032-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9c0524572bdb09fa0a0ee1f5fc22a97b6d923de4b0c07ab4f3901a739ea5a04
3
+ size 4294263223
model-00033-of-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc2164af38e9eb6290d44668c5c16977c9888da1132eee5cd5753080a182457e
3
+ size 1023964621
model.safetensors.index.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24c77bb30c16a5837aeb31b288170eb7227b8a64cd667e117a2f1cef4aa4ad12
3
+ size 18611843
modeling_minimax_m2.py ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
2
+ # SPDX-FileCopyrightText: 2024-2025 [email protected]
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # Contact: [email protected], x.com/qubitium
5
+
6
+ """PyTorch implementation of the MiniMax M2 architecture for Hugging Face Transformers."""
7
+
8
+ from __future__ import annotations
9
+
10
+ import copy
11
+ import time
12
+ from typing import Optional, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import nn
17
+
18
+ from transformers.activations import ACT2FN
19
+ from transformers.cache_utils import Cache, DynamicCache
20
+ from transformers.generation import GenerationMixin
21
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
22
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging
25
+
26
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, repeat_kv, rotate_half
27
+
28
+ from .configuration_minimax_m2 import MiniMaxM2Config
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+ _CONFIG_FOR_DOC = "MiniMaxM2Config"
33
+ _CHECKPOINT_FOR_DOC = "MiniMaxAI/MiniMax-M2"
34
+
35
+
36
+ def load_balancing_loss_func(
37
+ gate_logits: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
38
+ num_experts: int,
39
+ top_k: int,
40
+ attention_mask: Optional[torch.Tensor] = None,
41
+ ) -> torch.Tensor:
42
+ if gate_logits is None:
43
+ return torch.tensor(0.0)
44
+ if isinstance(gate_logits, torch.Tensor):
45
+ logits = gate_logits
46
+ else:
47
+ logits = torch.cat([layer_gate.to(gate_logits[0].device) for layer_gate in gate_logits], dim=0)
48
+
49
+ routing_weights = torch.softmax(logits, dim=-1, dtype=torch.float32)
50
+ _, selected = torch.topk(routing_weights, top_k, dim=-1)
51
+ expert_mask = torch.nn.functional.one_hot(selected, num_experts)
52
+
53
+ if attention_mask is None:
54
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
55
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
56
+ else:
57
+ batch_size, seq_len = attention_mask.shape
58
+ num_layers = logits.shape[0] // (batch_size * seq_len)
59
+
60
+ expanded_mask = (
61
+ attention_mask[None, :, :, None, None]
62
+ .expand(num_layers, batch_size, seq_len, top_k, num_experts)
63
+ .reshape(-1, top_k, num_experts)
64
+ .to(logits.device)
65
+ )
66
+ tokens_per_expert = torch.sum(expert_mask.float() * expanded_mask, dim=0) / torch.sum(expanded_mask, dim=0)
67
+
68
+ router_mask = (
69
+ attention_mask[None, :, :, None]
70
+ .expand(num_layers, batch_size, seq_len, num_experts)
71
+ .reshape(-1, num_experts)
72
+ .to(logits.device)
73
+ )
74
+ router_prob_per_expert = torch.sum(routing_weights * router_mask, dim=0) / torch.sum(router_mask, dim=0)
75
+
76
+ loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
77
+ return loss * num_experts
78
+
79
+
80
+ def apply_rotary_pos_emb_partial(
81
+ q: torch.Tensor,
82
+ k: torch.Tensor,
83
+ cos: torch.Tensor,
84
+ sin: torch.Tensor,
85
+ rotary_dim: int,
86
+ unsqueeze_dim: int = 2,
87
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
88
+ cos = cos.unsqueeze(unsqueeze_dim)[..., :rotary_dim]
89
+ sin = sin.unsqueeze(unsqueeze_dim)[..., :rotary_dim]
90
+ q_rot = q[..., :rotary_dim]
91
+ k_rot = k[..., :rotary_dim]
92
+
93
+ q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
94
+ k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
95
+
96
+ q = torch.cat((q_rot, q[..., rotary_dim:]), dim=-1)
97
+ k = torch.cat((k_rot, k[..., rotary_dim:]), dim=-1)
98
+ return q, k
99
+
100
+
101
+ class MiniMaxM2RMSNorm(nn.Module):
102
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
103
+ super().__init__()
104
+ self.weight = nn.Parameter(torch.ones(hidden_size))
105
+ self.variance_epsilon = eps
106
+
107
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
108
+ input_dtype = hidden_states.dtype
109
+ hidden_states = hidden_states.to(torch.float32)
110
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
111
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
112
+ return (self.weight * hidden_states).to(input_dtype)
113
+
114
+
115
+ class MiniMaxM2MLP(nn.Module):
116
+ def __init__(self, config: MiniMaxM2Config) -> None:
117
+ super().__init__()
118
+ self.hidden_size = config.hidden_size
119
+ self.intermediate_size = config.intermediate_size
120
+
121
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
122
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
123
+ self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
124
+ self.act_fn = ACT2FN[config.hidden_act]
125
+
126
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
127
+ gate = self.act_fn(self.w1(hidden_states))
128
+ up = self.w3(hidden_states)
129
+ gate.mul_(up)
130
+ del up
131
+ return self.w2(gate)
132
+
133
+
134
+ class MiniMaxM2SparseMoeBlock(nn.Module):
135
+ def __init__(self, config: MiniMaxM2Config) -> None:
136
+ super().__init__()
137
+ self.hidden_dim = config.hidden_size
138
+ self.experts = nn.ModuleList([MiniMaxM2MLP(config) for _ in range(config.num_local_experts)])
139
+ self.num_experts = config.num_local_experts
140
+ self.top_k = config.num_experts_per_tok
141
+ self.jitter_noise = config.router_jitter_noise
142
+ self.use_routing_bias = config.use_routing_bias
143
+ self.scoring_func = getattr(config, "scoring_func", "softmax")
144
+ self.use_grouped_topk = getattr(config, "use_grouped_topk", False)
145
+ self.num_expert_group = getattr(config, "num_expert_group", None)
146
+ self.topk_group = getattr(config, "topk_group", None)
147
+ self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
148
+
149
+ if self.use_grouped_topk:
150
+ if self.num_expert_group is None or self.num_expert_group <= 0:
151
+ self.num_expert_group = 1
152
+ if self.topk_group is None or self.topk_group <= 0:
153
+ self.topk_group = min(self.num_expert_group, self.top_k)
154
+ else:
155
+ self.num_expert_group = 1
156
+ self.topk_group = 1
157
+
158
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
159
+ if self.use_routing_bias:
160
+ self.e_score_correction_bias = nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32))
161
+ else:
162
+ self.register_parameter("e_score_correction_bias", None)
163
+
164
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
165
+ batch_size, seq_len, hidden_dim = hidden_states.shape
166
+ if self.training and self.jitter_noise > 0:
167
+ noise = torch.empty_like(hidden_states).uniform_(
168
+ 1.0 - self.jitter_noise,
169
+ 1.0 + self.jitter_noise,
170
+ )
171
+ hidden_states.mul_(noise)
172
+ del noise
173
+
174
+ hidden_states = hidden_states.view(-1, hidden_dim)
175
+ gate_dtype = self.gate.weight.dtype
176
+ router_logits = self.gate(hidden_states.to(gate_dtype)).to(torch.float32)
177
+ if self.e_score_correction_bias is not None:
178
+ # Bias is applied after scoring (see vLLM/SGLang implementations).
179
+ correction_bias = self.e_score_correction_bias.to(router_logits.device, router_logits.dtype)
180
+ else:
181
+ correction_bias = None
182
+
183
+ if self.scoring_func == "sigmoid":
184
+ scores = torch.sigmoid(router_logits)
185
+ elif self.scoring_func == "softmax":
186
+ scores = torch.softmax(router_logits, dim=-1)
187
+ else:
188
+ raise ValueError(f"Unsupported scoring function: {self.scoring_func}")
189
+
190
+ if correction_bias is not None:
191
+ original_scores = scores
192
+ scores.add_(correction_bias)
193
+ else:
194
+ original_scores = scores
195
+ topk_scores: torch.Tensor
196
+ if self.use_grouped_topk and self.num_expert_group > 1:
197
+ experts_per_group = scores.size(-1) // self.num_expert_group
198
+ scores_grouped = scores.view(scores.size(0), self.num_expert_group, experts_per_group)
199
+ if correction_bias is not None:
200
+ topk_in_group = min(2, experts_per_group)
201
+ if topk_in_group > 0:
202
+ group_scores = scores_grouped.topk(topk_in_group, dim=-1)[0].sum(dim=-1)
203
+ else:
204
+ group_scores = torch.zeros_like(scores_grouped[..., 0])
205
+ else:
206
+ group_scores = scores_grouped.max(dim=-1).values
207
+ group_mask = torch.zeros_like(group_scores)
208
+ selected_groups = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True).indices
209
+ group_mask.scatter_(1, selected_groups, 1.0)
210
+ mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(scores.size())
211
+ masked_scores = scores.masked_fill(mask == 0, float("-inf"))
212
+ topk_scores, selected_experts = torch.topk(masked_scores, self.top_k, dim=-1, sorted=True)
213
+ else:
214
+ topk_scores, selected_experts = torch.topk(scores, self.top_k, dim=-1, sorted=True)
215
+
216
+ if correction_bias is not None:
217
+ routing_weights = original_scores.gather(1, selected_experts)
218
+ else:
219
+ routing_weights = topk_scores
220
+ del scores, original_scores, topk_scores
221
+
222
+ routing_weights.div_(routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12))
223
+ if self.routed_scaling_factor != 1.0:
224
+ routing_weights.mul_(self.routed_scaling_factor)
225
+ routing_weights = routing_weights.to(hidden_states.dtype)
226
+ selected_experts = selected_experts.to(torch.long)
227
+
228
+ final_hidden_states = torch.zeros_like(hidden_states)
229
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
230
+ del selected_experts
231
+ expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten()
232
+
233
+ # To further reduce memory, process tokens routed to each expert in chunks
234
+ # instead of all at once. A chunk size of 1024 is a reasonable default.
235
+ EXPERT_CHUNK_SIZE = 1024
236
+
237
+ for expert_idx in expert_hit.tolist():
238
+ expert_layer = self.experts[expert_idx]
239
+ idx_full, top_x_full = torch.where(expert_mask[expert_idx].squeeze(0))
240
+
241
+ for i in range(0, top_x_full.size(0), EXPERT_CHUNK_SIZE):
242
+ top_x = top_x_full[i : i + EXPERT_CHUNK_SIZE]
243
+ idx = idx_full[i : i + EXPERT_CHUNK_SIZE]
244
+
245
+ token_states = hidden_states.index_select(0, top_x)
246
+ expert_output = expert_layer(token_states)
247
+
248
+ weights = routing_weights[top_x, idx].unsqueeze(-1)
249
+ expert_output.mul_(weights)
250
+
251
+ final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
252
+ del expert_output, token_states, idx, top_x, weights
253
+
254
+ del idx_full, top_x_full
255
+ del hidden_states, routing_weights, expert_mask, expert_hit
256
+ final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
257
+ return final_hidden_states, router_logits
258
+
259
+
260
+ class MiniMaxM2Attention(nn.Module):
261
+ def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None:
262
+ super().__init__()
263
+ self.config = config
264
+ self.layer_idx = layer_idx
265
+
266
+ self.head_dim = config.head_dim
267
+ self.num_heads = config.num_attention_heads
268
+ self.num_key_value_heads = config.num_key_value_heads
269
+ self.num_key_value_groups = self.num_heads // max(1, self.num_key_value_heads)
270
+ self.rotary_dim = config.rotary_dim
271
+ self.scaling = self.head_dim**-0.5
272
+ self.attention_dropout = config.attention_dropout
273
+ self.is_causal = True
274
+
275
+ max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
276
+ max_model_len = getattr(config, "max_model_len", None)
277
+ if max_model_len is not None:
278
+ max_position_embeddings = max(max_position_embeddings, max_model_len)
279
+
280
+ attn_window_size = getattr(config, "attn_window_size", None)
281
+ if isinstance(attn_window_size, list):
282
+ sliding_window = attn_window_size[layer_idx]
283
+ else:
284
+ sliding_window = attn_window_size
285
+ if sliding_window is not None and sliding_window <= 0:
286
+ sliding_window = None
287
+ self.sliding_window = sliding_window
288
+
289
+ swa_rope_theta = getattr(config, "swa_rope_theta", -1.0)
290
+ rope_theta = config.rope_theta
291
+ if self.sliding_window is not None and swa_rope_theta > 0:
292
+ rope_theta = swa_rope_theta
293
+
294
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
295
+ self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
296
+ self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
297
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
298
+
299
+ self.use_qk_norm = config.use_qk_norm
300
+ if self.use_qk_norm:
301
+ self.q_norm = MiniMaxM2RMSNorm(self.num_heads * self.head_dim, eps=config.rms_norm_eps)
302
+ self.k_norm = MiniMaxM2RMSNorm(self.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps)
303
+
304
+ rope_config = copy.deepcopy(config)
305
+ rope_config.hidden_size = config.hidden_size
306
+ rope_config.num_attention_heads = config.num_attention_heads
307
+ rope_config.partial_rotary_factor = float(config.rotary_dim) / float(self.head_dim)
308
+ rope_config.rope_theta = rope_theta
309
+ rope_config.max_position_embeddings = max_position_embeddings
310
+ self.rotary_emb = LlamaRotaryEmbedding(rope_config)
311
+
312
+ def forward(
313
+ self,
314
+ hidden_states: torch.Tensor,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ position_ids: Optional[torch.LongTensor] = None,
317
+ past_key_values: Optional[Cache] = None,
318
+ use_cache: Optional[bool] = False,
319
+ cache_position: Optional[torch.LongTensor] = None,
320
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
321
+ output_attentions: bool = False,
322
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
323
+ bsz, q_len, _ = hidden_states.size()
324
+ device = hidden_states.device
325
+
326
+ # projections
327
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
328
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
329
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
330
+ del hidden_states
331
+
332
+ # optional QK normalization
333
+ if self.use_qk_norm:
334
+ q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
335
+ k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1)
336
+ q_flat = self.q_norm(q_flat)
337
+ k_flat = self.k_norm(k_flat)
338
+ query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
339
+ key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
340
+
341
+ # rotary embeddings
342
+ if position_embeddings is None:
343
+ cos, sin = self.rotary_emb(value_states, position_ids)
344
+ else:
345
+ cos, sin = position_embeddings
346
+
347
+ query_states, key_states = apply_rotary_pos_emb_partial(
348
+ query_states.transpose(1, 2), key_states.transpose(1, 2), cos, sin, self.rotary_dim
349
+ )
350
+ query_states = query_states.transpose(1, 2)
351
+ key_states = key_states.transpose(1, 2)
352
+
353
+ # handle cache
354
+ if past_key_values is not None:
355
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
356
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
357
+
358
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
359
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
360
+
361
+ query_dtype = query_states.dtype
362
+ key_len = key_states.shape[-2]
363
+
364
+ # precompute sliding-window mask
365
+ window_mask = None
366
+ if self.sliding_window is not None and past_key_values is None:
367
+ q_pos = torch.arange(q_len, device=device).view(1, 1, q_len, 1)
368
+ k_pos = torch.arange(key_len, device=device).view(1, 1, 1, key_len)
369
+ wm = k_pos < (q_pos - self.sliding_window)
370
+ if wm.any():
371
+ window_mask = wm.squeeze(1) # (1, q_len, key_len)
372
+ del q_pos, k_pos, wm
373
+
374
+ attn_output_parts = []
375
+ attn_weights_list = [] if output_attentions else None
376
+
377
+ for h in range(self.num_heads):
378
+ # (bsz, q_len, key_len)
379
+ q = query_states[:, h, :, :]
380
+ k = key_states[:, h, :, :]
381
+ v = value_states[:, h, :, :]
382
+
383
+ # Chunked attention computation to reduce peak memory usage
384
+ out_parts = []
385
+ attn_parts = [] if output_attentions else None
386
+
387
+ # A smaller chunk size reduces memory but may be slightly slower
388
+ chunk_size = 1024
389
+ for i in range(0, q.size(1), chunk_size):
390
+ q_chunk = q[:, i:i + chunk_size, :]
391
+
392
+ # attn_chunk has shape (bsz, chunk_size, key_len)
393
+ attn_chunk = torch.matmul(q_chunk, k.transpose(-2, -1))
394
+ attn_chunk.mul_(self.scaling)
395
+
396
+ # Apply masks to the chunk
397
+ if attention_mask is not None:
398
+ attn_chunk.add_(attention_mask.squeeze(1)[:, i:i + chunk_size, :])
399
+
400
+ if window_mask is not None:
401
+ attn_chunk.masked_fill_(window_mask[:, i:i + chunk_size, :], float("-inf"))
402
+
403
+ attn_chunk = torch.softmax(attn_chunk, dim=-1, dtype=torch.float32).to(query_dtype)
404
+
405
+ if self.training and self.attention_dropout > 0:
406
+ attn_chunk = F.dropout(attn_chunk, p=self.attention_dropout, training=True)
407
+
408
+ if output_attentions:
409
+ attn_parts.append(attn_chunk)
410
+
411
+ # output_chunk has shape (bsz, chunk_size, head_dim)
412
+ out_chunk = torch.matmul(attn_chunk, v)
413
+ out_parts.append(out_chunk)
414
+
415
+ del q_chunk, attn_chunk, out_chunk
416
+
417
+ out = torch.cat(out_parts, dim=1)
418
+ attn_output_parts.append(out)
419
+
420
+ if output_attentions:
421
+ attn = torch.cat(attn_parts, dim=1)
422
+ attn_weights_list.append(attn)
423
+ del attn, attn_parts
424
+
425
+ del q, k, v, out, out_parts
426
+
427
+ attn_output = torch.stack(attn_output_parts, dim=1)
428
+ del attn_output_parts, query_states, key_states, value_states
429
+
430
+ attn_weights = torch.stack(attn_weights_list, dim=1) if output_attentions else None
431
+
432
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
433
+ attn_output = self.o_proj(attn_output)
434
+
435
+ return attn_output, attn_weights
436
+
437
+
438
+ class MiniMaxM2LogitsProcessor(nn.Module):
439
+ def __init__(self, config: MiniMaxM2Config) -> None:
440
+ super().__init__()
441
+ self.scale = getattr(config, "logits_scale", 1.0)
442
+
443
+ def forward(self, lm_head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
444
+ logits = lm_head(hidden_states)
445
+ if self.scale != 1.0:
446
+ logits = logits * self.scale
447
+ return logits
448
+
449
+
450
+ class MiniMaxM2DecoderLayer(nn.Module):
451
+ def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None:
452
+ super().__init__()
453
+ self.hidden_size = config.hidden_size
454
+ self.self_attn = MiniMaxM2Attention(config, layer_idx)
455
+ self.block_sparse_moe = MiniMaxM2SparseMoeBlock(config)
456
+ self.input_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
457
+ self.post_attention_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
458
+
459
+ def forward(
460
+ self,
461
+ hidden_states: torch.Tensor,
462
+ attention_mask: Optional[torch.Tensor] = None,
463
+ position_ids: Optional[torch.LongTensor] = None,
464
+ past_key_values: Optional[Cache] = None,
465
+ use_cache: Optional[bool] = False,
466
+ cache_position: Optional[torch.LongTensor] = None,
467
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
468
+ output_attentions: bool = False,
469
+ residual: Optional[torch.Tensor] = None,
470
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
471
+ residual_input = hidden_states if residual is None else residual
472
+ hidden_states = self.input_layernorm(hidden_states)
473
+
474
+ attn_output, attn_weights = self.self_attn(
475
+ hidden_states=hidden_states,
476
+ attention_mask=attention_mask,
477
+ position_ids=position_ids,
478
+ past_key_values=past_key_values,
479
+ use_cache=use_cache,
480
+ cache_position=cache_position,
481
+ position_embeddings=position_embeddings,
482
+ output_attentions=output_attentions,
483
+ )
484
+ hidden_states = residual_input + attn_output
485
+
486
+ residual_post_attn = hidden_states
487
+ hidden_states = self.post_attention_layernorm(hidden_states)
488
+ moe_output, router_logits = self.block_sparse_moe(hidden_states)
489
+ hidden_states = residual_post_attn + moe_output
490
+
491
+ return hidden_states, hidden_states, router_logits, attn_weights
492
+
493
+
494
+ class MiniMaxM2PreTrainedModel(PreTrainedModel):
495
+ config_class = MiniMaxM2Config
496
+ base_model_prefix = "model"
497
+ supports_gradient_checkpointing = True
498
+ _no_split_modules = ["MiniMaxM2DecoderLayer"]
499
+ _supports_flash_attn = False
500
+ _supports_sdpa = False
501
+ _supports_attention_backend = False
502
+
503
+ def _init_weights(self, module: nn.Module) -> None:
504
+ if isinstance(module, nn.Linear):
505
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
506
+ if module.bias is not None:
507
+ module.bias.data.zero_()
508
+ elif isinstance(module, nn.Embedding):
509
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
510
+ if module.padding_idx is not None:
511
+ module.weight.data[module.padding_idx].zero_()
512
+
513
+ def _remap_qkv_weights(self, state_dict):
514
+ num_q = self.config.num_attention_heads * self.config.head_dim
515
+ num_kv = self.config.num_key_value_heads * self.config.head_dim
516
+
517
+ for layer_idx in range(self.config.num_hidden_layers):
518
+ prefix = f"model.layers.{layer_idx}.self_attn"
519
+ weight_key = f"{prefix}.qkv_proj.weight"
520
+ if weight_key in state_dict:
521
+ qkv_weight = state_dict.pop(weight_key)
522
+ q_weight, k_weight, v_weight = qkv_weight.split([num_q, num_kv, num_kv], dim=0)
523
+ state_dict.setdefault(f"{prefix}.q_proj.weight", q_weight)
524
+ state_dict.setdefault(f"{prefix}.k_proj.weight", k_weight)
525
+ state_dict.setdefault(f"{prefix}.v_proj.weight", v_weight)
526
+
527
+ def load_state_dict(self, state_dict, strict: bool = True):
528
+ if not isinstance(state_dict, dict):
529
+ raise TypeError(f"Expected state_dict to be dict, got {type(state_dict)}")
530
+
531
+ filtered_state_dict = {}
532
+ drop_suffixes = ("weight_scale_inv", "weight_scale", "input_scale", "scales", "amax")
533
+ for key, value in state_dict.items():
534
+ if key.endswith(drop_suffixes) or "fp8" in key:
535
+ continue
536
+ filtered_state_dict[key] = value
537
+
538
+ self._remap_qkv_weights(filtered_state_dict)
539
+
540
+ if logger.isEnabledFor(logging.INFO):
541
+ logger.info(
542
+ "MiniMaxM2: loading %d tensors (filtered from %d original).",
543
+ len(filtered_state_dict),
544
+ len(state_dict),
545
+ )
546
+
547
+ load_start = time.perf_counter()
548
+ result = super().load_state_dict(filtered_state_dict, strict=strict)
549
+ load_elapsed = time.perf_counter() - load_start
550
+ if logger.isEnabledFor(logging.INFO):
551
+ logger.info("MiniMaxM2: state_dict load finished in %.2f seconds.", load_elapsed)
552
+
553
+ return result
554
+
555
+
556
+ class MiniMaxM2Model(MiniMaxM2PreTrainedModel):
557
+ def __init__(self, config: MiniMaxM2Config) -> None:
558
+ super().__init__(config)
559
+ self.padding_idx = config.pad_token_id
560
+ self.vocab_size = config.vocab_size
561
+
562
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
563
+ self.layers = nn.ModuleList(
564
+ [MiniMaxM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
565
+ )
566
+ self.norm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
567
+ self.gradient_checkpointing = False
568
+
569
+ self.post_init()
570
+
571
+ def get_input_embeddings(self) -> nn.Module:
572
+ return self.embed_tokens
573
+
574
+ def set_input_embeddings(self, value: nn.Module) -> None:
575
+ self.embed_tokens = value
576
+
577
+ def forward(
578
+ self,
579
+ input_ids: Optional[torch.LongTensor] = None,
580
+ attention_mask: Optional[torch.Tensor] = None,
581
+ position_ids: Optional[torch.LongTensor] = None,
582
+ past_key_values: Optional[Cache] = None,
583
+ inputs_embeds: Optional[torch.Tensor] = None,
584
+ cache_position: Optional[torch.LongTensor] = None,
585
+ use_cache: Optional[bool] = None,
586
+ output_attentions: bool = False,
587
+ output_hidden_states: bool = False,
588
+ output_router_logits: Optional[bool] = None,
589
+ return_dict: Optional[bool] = None,
590
+ ) -> Union[MoeModelOutputWithPast, Tuple]:
591
+ if (input_ids is None) == (inputs_embeds is None):
592
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
593
+
594
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
595
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
596
+ output_router_logits = (
597
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
598
+ )
599
+
600
+ if inputs_embeds is None:
601
+ inputs_embeds = self.embed_tokens(input_ids)
602
+
603
+ if use_cache and past_key_values is None:
604
+ past_key_values = DynamicCache(config=self.config)
605
+
606
+ if cache_position is None:
607
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
608
+ cache_position = torch.arange(
609
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
610
+ )
611
+
612
+ if position_ids is None:
613
+ position_ids = cache_position.unsqueeze(0)
614
+
615
+ if self.config.sliding_window is not None:
616
+ causal_mask = create_sliding_window_causal_mask(
617
+ config=self.config,
618
+ input_embeds=inputs_embeds,
619
+ attention_mask=attention_mask,
620
+ cache_position=cache_position,
621
+ past_key_values=past_key_values,
622
+ position_ids=position_ids,
623
+ )
624
+ else:
625
+ causal_mask = create_causal_mask(
626
+ config=self.config,
627
+ input_embeds=inputs_embeds,
628
+ attention_mask=attention_mask,
629
+ cache_position=cache_position,
630
+ past_key_values=past_key_values,
631
+ position_ids=position_ids,
632
+ )
633
+
634
+ hidden_states = inputs_embeds
635
+
636
+ all_hidden_states = () if output_hidden_states else None
637
+ all_attentions = () if output_attentions else None
638
+ all_router_logits = () if output_router_logits else None
639
+
640
+ residual = None
641
+ for decoder_layer in self.layers:
642
+ if output_hidden_states:
643
+ all_hidden_states = all_hidden_states + (hidden_states,)
644
+
645
+ layer_outputs = decoder_layer(
646
+ hidden_states,
647
+ attention_mask=causal_mask,
648
+ position_ids=position_ids,
649
+ past_key_values=past_key_values,
650
+ use_cache=use_cache,
651
+ cache_position=cache_position,
652
+ position_embeddings=None,
653
+ output_attentions=output_attentions,
654
+ residual=residual,
655
+ )
656
+
657
+ hidden_states, residual, router_logits, attn_weights = layer_outputs
658
+
659
+ if output_router_logits:
660
+ all_router_logits = all_router_logits + (router_logits,)
661
+ if output_attentions:
662
+ all_attentions = all_attentions + (attn_weights,)
663
+
664
+ hidden_states = self.norm(hidden_states)
665
+
666
+ if output_hidden_states:
667
+ all_hidden_states = all_hidden_states + (hidden_states,)
668
+
669
+ if not return_dict:
670
+ outputs = (hidden_states, past_key_values)
671
+ if output_hidden_states:
672
+ outputs += (all_hidden_states,)
673
+ if output_attentions:
674
+ outputs += (all_attentions,)
675
+ if output_router_logits:
676
+ outputs += (all_router_logits,)
677
+ return outputs
678
+
679
+ return MoeModelOutputWithPast(
680
+ last_hidden_state=hidden_states,
681
+ past_key_values=past_key_values,
682
+ hidden_states=all_hidden_states,
683
+ attentions=all_attentions,
684
+ router_logits=all_router_logits,
685
+ )
686
+
687
+
688
+ class MiniMaxM2ForCausalLM(MiniMaxM2PreTrainedModel, GenerationMixin):
689
+ def __init__(self, config: MiniMaxM2Config) -> None:
690
+ super().__init__(config)
691
+ self.model = MiniMaxM2Model(config)
692
+ self.vocab_size = config.vocab_size
693
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
694
+ self.router_aux_loss_coef = config.router_aux_loss_coef
695
+ self.num_experts = config.num_local_experts
696
+ self.num_experts_per_tok = config.num_experts_per_tok
697
+ self.logits_processor = MiniMaxM2LogitsProcessor(config)
698
+
699
+ self.post_init()
700
+
701
+ def get_input_embeddings(self) -> nn.Module:
702
+ return self.model.embed_tokens
703
+
704
+ def set_input_embeddings(self, value: nn.Module) -> None:
705
+ self.model.embed_tokens = value
706
+
707
+ def get_output_embeddings(self) -> nn.Module:
708
+ return self.lm_head
709
+
710
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
711
+ self.lm_head = new_embeddings
712
+
713
+ def prepare_inputs_for_generation(
714
+ self,
715
+ input_ids: torch.LongTensor,
716
+ past_key_values: Optional[Cache] = None,
717
+ attention_mask: Optional[torch.Tensor] = None,
718
+ inputs_embeds: Optional[torch.Tensor] = None,
719
+ **kwargs,
720
+ ):
721
+ if past_key_values is not None:
722
+ input_ids = input_ids[:, -1:]
723
+ if attention_mask is not None:
724
+ attention_mask = attention_mask[:, -past_key_values.get_seq_length() - 1 :]
725
+
726
+ return {
727
+ "input_ids": input_ids,
728
+ "attention_mask": attention_mask,
729
+ "past_key_values": past_key_values,
730
+ "inputs_embeds": inputs_embeds,
731
+ }
732
+
733
+ def forward(
734
+ self,
735
+ input_ids: Optional[torch.LongTensor] = None,
736
+ attention_mask: Optional[torch.Tensor] = None,
737
+ position_ids: Optional[torch.LongTensor] = None,
738
+ past_key_values: Optional[Cache] = None,
739
+ inputs_embeds: Optional[torch.Tensor] = None,
740
+ labels: Optional[torch.LongTensor] = None,
741
+ cache_position: Optional[torch.LongTensor] = None,
742
+ use_cache: Optional[bool] = None,
743
+ output_attentions: bool = False,
744
+ output_hidden_states: bool = False,
745
+ output_router_logits: Optional[bool] = None,
746
+ return_dict: Optional[bool] = None,
747
+ logits_to_keep: Union[int, torch.Tensor] = 0,
748
+ ) -> Union[MoeCausalLMOutputWithPast, Tuple]:
749
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
750
+ output_router_logits = (
751
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
752
+ )
753
+
754
+ model_outputs = self.model(
755
+ input_ids=input_ids,
756
+ attention_mask=attention_mask,
757
+ position_ids=position_ids,
758
+ past_key_values=past_key_values,
759
+ inputs_embeds=inputs_embeds,
760
+ cache_position=cache_position,
761
+ use_cache=use_cache,
762
+ output_attentions=output_attentions,
763
+ output_hidden_states=output_hidden_states,
764
+ output_router_logits=output_router_logits,
765
+ return_dict=True,
766
+ )
767
+
768
+ hidden_states = model_outputs.last_hidden_state
769
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
770
+ logits = self.logits_processor(self.lm_head, hidden_states[:, slice_indices, :])
771
+
772
+ loss = None
773
+ if labels is not None:
774
+ shift_logits = logits[..., :-1, :].contiguous()
775
+ shift_labels = labels[..., 1:].contiguous()
776
+ loss_fct = nn.CrossEntropyLoss()
777
+ loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
778
+
779
+ aux_loss = None
780
+ if output_router_logits and model_outputs.router_logits is not None:
781
+ aux_loss = load_balancing_loss_func(
782
+ model_outputs.router_logits,
783
+ num_experts=self.num_experts,
784
+ top_k=self.num_experts_per_tok,
785
+ attention_mask=attention_mask,
786
+ )
787
+ if loss is not None:
788
+ loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device)
789
+
790
+ if not return_dict:
791
+ output = (logits,) + (model_outputs.past_key_values,)
792
+ if output_hidden_states:
793
+ output += (model_outputs.hidden_states,)
794
+ if output_attentions:
795
+ output += (model_outputs.attentions,)
796
+ if output_router_logits:
797
+ output += (model_outputs.router_logits,)
798
+ return ((loss,) + output) if loss is not None else output
799
+
800
+ return MoeCausalLMOutputWithPast(
801
+ loss=loss,
802
+ aux_loss=aux_loss,
803
+ logits=logits,
804
+ past_key_values=model_outputs.past_key_values,
805
+ hidden_states=model_outputs.hidden_states,
806
+ attentions=model_outputs.attentions,
807
+ router_logits=model_outputs.router_logits,
808
+ )
809
+
810
+ # -----------------------------------------------------------------------------
811
+ # Backward compatibility aliases
812
+ # -----------------------------------------------------------------------------
813
+
814
+ MiniMaxRMSNorm = MiniMaxM2RMSNorm
815
+ MiniMaxSparseMoeBlock = MiniMaxM2SparseMoeBlock
816
+ MiniMaxAttention = MiniMaxM2Attention
817
+ MiniMaxDecoderLayer = MiniMaxM2DecoderLayer
818
+ MiniMaxMLP = MiniMaxM2MLP
819
+ MiniMaxPreTrainedModel = MiniMaxM2PreTrainedModel
820
+ MiniMaxModel = MiniMaxM2Model
821
+
822
+
823
+ class MiniMaxForCausalLM(MiniMaxM2ForCausalLM):
824
+ """Alias for compatibility with checkpoints exporting MiniMaxForCausalLM."""
825
+
826
+
827
+ __all__ = [
828
+ "MiniMaxM2RMSNorm",
829
+ "MiniMaxM2SparseMoeBlock",
830
+ "MiniMaxM2Attention",
831
+ "MiniMaxM2DecoderLayer",
832
+ "MiniMaxM2Model",
833
+ "MiniMaxM2ForCausalLM",
834
+ "MiniMaxM2PreTrainedModel",
835
+ "MiniMaxRMSNorm",
836
+ "MiniMaxSparseMoeBlock",
837
+ "MiniMaxAttention",
838
+ "MiniMaxDecoderLayer",
839
+ "MiniMaxPreTrainedModel",
840
+ "MiniMaxModel",
841
+ "MiniMaxMLP",
842
+ "MiniMaxForCausalLM",
843
+ ]
quant_log.csv ADDED
The diff for this file is too large to render. See raw diff
 
quantize_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bits": 4,
3
+ "dynamic": {
4
+ "-:.*self_attn": {}
5
+ },
6
+ "group_size": 32,
7
+ "desc_act": false,
8
+ "sym": true,
9
+ "lm_head": false,
10
+ "quant_method": "gptq",
11
+ "checkpoint_format": "gptq",
12
+ "pack_dtype": "int32",
13
+ "meta": {
14
+ "quantizer": [
15
+ "gptqmodel:5.0.0-dev0"
16
+ ],
17
+ "uri": "https://github.com/modelcloud/gptqmodel",
18
+ "damp_percent": 0.01,
19
+ "damp_auto_increment": 0.01,
20
+ "static_groups": false,
21
+ "true_sequential": true,
22
+ "mse": 0.0,
23
+ "v2": false,
24
+ "v2_alpha": 0.25,
25
+ "act_group_aware": true
26
+ }
27
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<code_interpreter>",
4
+ "<commit_after>",
5
+ "<commit_before>",
6
+ "<commit_msg>",
7
+ "<empty_output>",
8
+ "<filename>",
9
+ "<fim_middle>",
10
+ "<fim_pad>",
11
+ "<fim_prefix>",
12
+ "<fim_suffix>",
13
+ "<function_call>",
14
+ "<gh_stars>",
15
+ "]<]speech[>[",
16
+ "]<]image[>[",
17
+ "]<]video[>[",
18
+ "]<]start of speech[>[",
19
+ "]<]end of speech[>[",
20
+ "]<]start of image[>[",
21
+ "]<]end of image[>[",
22
+ "]<]start of video[>[",
23
+ "]<]end of video[>[",
24
+ "]<]vision pad[>[",
25
+ "]~!b[",
26
+ "<issue_closed>",
27
+ "<issue_comment>",
28
+ "<issue_start>",
29
+ "<jupyter_code>",
30
+ "<jupyter_output>",
31
+ "<jupyter_start>",
32
+ "<jupyter_text>",
33
+ "<reponame>",
34
+ "[e~[",
35
+ "]!d~[",
36
+ "]!p~[",
37
+ "]~b]",
38
+ "<jupyter_error>",
39
+ "<add_file>",
40
+ "<delete_file>",
41
+ "<rename_file>",
42
+ "<edit_file>",
43
+ "<commit_message>",
44
+ "<empty_source_file>",
45
+ "<repo_struct>",
46
+ "<code_context>",
47
+ "<file_content>",
48
+ "<source_files>",
49
+ "<pr_start>",
50
+ "<review_comment>",
51
+ "<filepath>",
52
+ "<file_sep>"
53
+ ],
54
+ "bos_token": {
55
+ "content": "]~!b[",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false
60
+ },
61
+ "eos_token": {
62
+ "content": "[e~[",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false
67
+ },
68
+ "pad_token": "[e~[",
69
+ "unk_token": {
70
+ "content": "]!d~[",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false
75
+ }
76
+ }
test_minimax_m2_hf.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
2
+ # SPDX-FileCopyrightText: 2024-2025 [email protected]
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # Contact: [email protected], x.com/qubitium
5
+
6
+ """
7
+ MiniMax-M2 Hugging Face checkpoint sanity check with streaming output.
8
+
9
+ Usage:
10
+ python test_minimax_m2_hf.py \
11
+ --model-path /monster/data/model/MiniMax-M2-bf16 \
12
+ --question "How many letter A are there in the word Alphabet? Reply with the number only."
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import threading
19
+ from pathlib import Path
20
+
21
+ import torch.nn as nn
22
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
23
+
24
+ # from gptqmodel.hf_minimax_m2.modeling_minimax_m2 import (
25
+ # MiniMaxAttention,
26
+ # MiniMaxDecoderLayer,
27
+ # MiniMaxForCausalLM,
28
+ # MiniMaxMLP,
29
+ # MiniMaxM2Attention,
30
+ # MiniMaxM2DecoderLayer,
31
+ # MiniMaxM2ForCausalLM,
32
+ # MiniMaxM2MLP,
33
+ # MiniMaxM2RMSNorm,
34
+ # MiniMaxM2SparseMoeBlock,
35
+ # MiniMaxRMSNorm,
36
+ # MiniMaxSparseMoeBlock,
37
+ # )
38
+
39
+
40
+ def parse_args() -> argparse.Namespace:
41
+ parser = argparse.ArgumentParser(description="MiniMax-M2 HF checkpoint smoke test.")
42
+ parser.add_argument(
43
+ "--model-path",
44
+ type=str,
45
+ default="/monster/data/model/MiniMax-M2-bf16",
46
+ help="Path to the MiniMax-M2 Hugging Face checkpoint directory.",
47
+ )
48
+ parser.add_argument(
49
+ "--question",
50
+ type=str,
51
+ default="How many letter A are there in the word Alphabet? Reply with the number only.",
52
+ help="User question to send through the chat template.",
53
+ )
54
+ parser.add_argument(
55
+ "--max-new-tokens",
56
+ type=int,
57
+ default=512,
58
+ help="Maximum number of new tokens to sample from the model.",
59
+ )
60
+ return parser.parse_args()
61
+
62
+
63
+ def build_prompt(tokenizer: AutoTokenizer, question: str) -> str:
64
+ messages = [
65
+ {"role": "system", "content": "You are a helpful assistant."},
66
+ {"role": "user", "content": question},
67
+ ]
68
+ return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
69
+
70
+
71
+ # def assert_module_types(model: MiniMaxM2ForCausalLM) -> None:
72
+ # causal_lm_types = (MiniMaxM2ForCausalLM, MiniMaxForCausalLM)
73
+ # decoder_layer_types = (MiniMaxM2DecoderLayer, MiniMaxDecoderLayer)
74
+ # attention_types = (MiniMaxM2Attention, MiniMaxAttention)
75
+ # moe_block_types = (MiniMaxM2SparseMoeBlock, MiniMaxSparseMoeBlock)
76
+ # norm_types = (MiniMaxM2RMSNorm, MiniMaxRMSNorm)
77
+ # mlp_types = (MiniMaxM2MLP, MiniMaxMLP)
78
+ #
79
+ # assert isinstance(
80
+ # model, causal_lm_types
81
+ # ), f"Expected MiniMaxM2ForCausalLM/MiniMaxForCausalLM, received {type(model).__name__}"
82
+ #
83
+ # decoder = getattr(model, "model", None)
84
+ # assert decoder is not None, "Model is missing the `model` attribute with decoder layers."
85
+ #
86
+ # for layer_idx, layer in enumerate(decoder.layers):
87
+ # assert isinstance(
88
+ # layer, decoder_layer_types
89
+ # ), f"Layer {layer_idx}: expected MiniMax(M2)DecoderLayer, got {type(layer).__name__}"
90
+ # assert isinstance(
91
+ # layer.self_attn, attention_types
92
+ # ), f"Layer {layer_idx}: unexpected self_attn type {type(layer.self_attn).__name__}"
93
+ # assert isinstance(
94
+ # layer.block_sparse_moe, moe_block_types
95
+ # ), f"Layer {layer_idx}: unexpected MoE block type {type(layer.block_sparse_moe).__name__}"
96
+ # assert isinstance(
97
+ # layer.input_layernorm, norm_types
98
+ # ), f"Layer {layer_idx}: unexpected input_layernorm type {type(layer.input_layernorm).__name__}"
99
+ # assert isinstance(
100
+ # layer.post_attention_layernorm, norm_types
101
+ # ), f"Layer {layer_idx}: unexpected post_attention_layernorm type {type(layer.post_attention_layernorm).__name__}"
102
+ #
103
+ # moe_block = layer.block_sparse_moe
104
+ # assert isinstance(
105
+ # moe_block.experts, nn.ModuleList
106
+ # ), f"Layer {layer_idx}: expected experts to be a ModuleList, got {type(moe_block.experts).__name__}"
107
+ # for expert_idx, expert in enumerate(moe_block.experts):
108
+ # assert isinstance(
109
+ # expert, mlp_types
110
+ # ), f"Layer {layer_idx} expert {expert_idx}: expected MiniMax(M2)MLP, got {type(expert).__name__}"
111
+ #
112
+
113
+ def main() -> None:
114
+ args = parse_args()
115
+ model_path = Path(args.model_path).expanduser().resolve()
116
+
117
+ print(f"Loading tokenizer from {model_path}...")
118
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
119
+
120
+ print(f"Loading model from {model_path}...")
121
+ model = AutoModelForCausalLM.from_pretrained(
122
+ model_path,
123
+ dtype="bfloat16",
124
+ device_map="auto",
125
+ trust_remote_code=True,
126
+ )
127
+
128
+ # Uncomment to enforce module type checks.
129
+ # print("Validating module types...")
130
+ # assert_module_types(model)
131
+
132
+ prompt = build_prompt(tokenizer, args.question)
133
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
134
+
135
+ print("Running generation (streaming)...\n")
136
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=False)
137
+ eos_ids = model.generation_config.eos_token_id
138
+ if eos_ids is None:
139
+ eos_ids = []
140
+ elif isinstance(eos_ids, int):
141
+ eos_ids = [eos_ids]
142
+ think_end_id = tokenizer.convert_tokens_to_ids("</think>")
143
+ if think_end_id is not None and think_end_id not in eos_ids:
144
+ eos_ids = eos_ids + [think_end_id]
145
+
146
+ generation_kwargs = dict(
147
+ **inputs,
148
+ max_new_tokens=args.max_new_tokens,
149
+ streamer=streamer,
150
+ eos_token_id=eos_ids if eos_ids else None,
151
+ )
152
+
153
+ generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
154
+ generation_thread.start()
155
+
156
+ completion = []
157
+ first_chunk = True
158
+ seen_end_reasoning = False
159
+ for text in streamer:
160
+ if first_chunk:
161
+ print("<think>", end="", flush=True)
162
+ completion.append("<think>")
163
+ first_chunk = False
164
+ print(text, end="", flush=True)
165
+ completion.append(text)
166
+ if "</think>" in text:
167
+ seen_end_reasoning = True
168
+
169
+ generation_thread.join()
170
+ print("\n\n=== Completed Response ===")
171
+ final_text = "".join(completion).strip()
172
+ print(final_text or "<empty response>")
173
+ if not seen_end_reasoning:
174
+ print("\n[warning] No </think> token detected in streamed output.", flush=True)
175
+
176
+
177
+ if __name__ == "__main__":
178
+ main()
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7b90ed7f55d905175bc26771d6d7d33b40b46742f073675bc816fedaf482ea1
3
+ size 15522763
tokenizer_config.json ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "200000": {
5
+ "content": "]!p~[",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "200001": {
13
+ "content": "<fim_prefix>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "200002": {
21
+ "content": "<fim_middle>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "200003": {
29
+ "content": "<fim_suffix>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "200004": {
37
+ "content": "<fim_pad>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "200005": {
45
+ "content": "<reponame>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "200006": {
53
+ "content": "<filename>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "200007": {
61
+ "content": "<gh_stars>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "200008": {
69
+ "content": "<issue_start>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "200009": {
77
+ "content": "<issue_comment>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "200010": {
85
+ "content": "<issue_closed>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "200011": {
93
+ "content": "<jupyter_start>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "200012": {
101
+ "content": "<jupyter_text>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "200013": {
109
+ "content": "<jupyter_code>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "200014": {
117
+ "content": "<jupyter_output>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "200015": {
125
+ "content": "<empty_output>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "200016": {
133
+ "content": "<commit_before>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "200017": {
141
+ "content": "<commit_msg>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "200018": {
149
+ "content": "<commit_after>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "200019": {
157
+ "content": "]~b]",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "200020": {
165
+ "content": "[e~[",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "200021": {
173
+ "content": "]!d~[",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "200022": {
181
+ "content": "<function_call>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "200023": {
189
+ "content": "<code_interpreter>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "200024": {
197
+ "content": "]<]speech[>[",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "200025": {
205
+ "content": "]<]image[>[",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "200026": {
213
+ "content": "]<]video[>[",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "200027": {
221
+ "content": "]<]start of speech[>[",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "200028": {
229
+ "content": "]<]end of speech[>[",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "200029": {
237
+ "content": "]<]start of image[>[",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "200030": {
245
+ "content": "]<]end of image[>[",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "200031": {
253
+ "content": "]<]start of video[>[",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "200032": {
261
+ "content": "]<]end of video[>[",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "200033": {
269
+ "content": "]<]vision pad[>[",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "200034": {
277
+ "content": "]~!b[",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "200035": {
285
+ "content": "<jupyter_error>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "200036": {
293
+ "content": "<add_file>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "200037": {
301
+ "content": "<delete_file>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "200038": {
309
+ "content": "<rename_file>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "200039": {
317
+ "content": "<edit_file>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "200040": {
325
+ "content": "<commit_message>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "200041": {
333
+ "content": "<empty_source_file>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "200042": {
341
+ "content": "<repo_struct>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "200043": {
349
+ "content": "<code_context>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "200044": {
357
+ "content": "<file_content>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "200045": {
365
+ "content": "<source_files>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "200046": {
373
+ "content": "<pr_start>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "200047": {
381
+ "content": "<review_comment>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "200048": {
389
+ "content": "<filepath>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "200049": {
397
+ "content": "<file_sep>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "200050": {
405
+ "content": "<think>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": false
411
+ },
412
+ "200051": {
413
+ "content": "</think>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": false
419
+ },
420
+ "200052": {
421
+ "content": "<minimax:tool_call>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": false
427
+ },
428
+ "200053": {
429
+ "content": "</minimax:tool_call>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": false
435
+ }
436
+ },
437
+ "additional_special_tokens": [
438
+ "<code_interpreter>",
439
+ "<commit_after>",
440
+ "<commit_before>",
441
+ "<commit_msg>",
442
+ "<empty_output>",
443
+ "<filename>",
444
+ "<fim_middle>",
445
+ "<fim_pad>",
446
+ "<fim_prefix>",
447
+ "<fim_suffix>",
448
+ "<function_call>",
449
+ "<gh_stars>",
450
+ "]<]speech[>[",
451
+ "]<]image[>[",
452
+ "]<]video[>[",
453
+ "]<]start of speech[>[",
454
+ "]<]end of speech[>[",
455
+ "]<]start of image[>[",
456
+ "]<]end of image[>[",
457
+ "]<]start of video[>[",
458
+ "]<]end of video[>[",
459
+ "]<]vision pad[>[",
460
+ "]~!b[",
461
+ "<issue_closed>",
462
+ "<issue_comment>",
463
+ "<issue_start>",
464
+ "<jupyter_code>",
465
+ "<jupyter_output>",
466
+ "<jupyter_start>",
467
+ "<jupyter_text>",
468
+ "<reponame>",
469
+ "[e~[",
470
+ "]!d~[",
471
+ "]!p~[",
472
+ "]~b]",
473
+ "<jupyter_error>",
474
+ "<add_file>",
475
+ "<delete_file>",
476
+ "<rename_file>",
477
+ "<edit_file>",
478
+ "<commit_message>",
479
+ "<empty_source_file>",
480
+ "<repo_struct>",
481
+ "<code_context>",
482
+ "<file_content>",
483
+ "<source_files>",
484
+ "<pr_start>",
485
+ "<review_comment>",
486
+ "<filepath>",
487
+ "<file_sep>"
488
+ ],
489
+ "bos_token": "]~!b[",
490
+ "clean_up_tokenization_spaces": false,
491
+ "eos_token": "[e~[",
492
+ "extra_special_tokens": {},
493
+ "model_max_length": 40960000,
494
+ "pad_token": "[e~[",
495
+ "tokenizer_class": "GPT2TokenizerFast",
496
+ "unk_token": "]!d~[",
497
+ "_commit_hash": null
498
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff