ehartford committed
Commit 9d5a2a6 (verified) · Parent(s): fcdf7bd

Upload quant.py

Files changed (1):
  1. quant.py +171 -0
quant.py ADDED
@@ -0,0 +1,171 @@
+ #!/usr/bin/env python3
+ """Convert a local BF16 model into Marlin-supported quant formats via llm-compressor.
+
+ Each recipe below reloads the base model, applies one quantization scheme through
+ llm-compressor's `oneshot` entrypoint, and saves the result under OUTPUT_ROOT.
+ """
+
+ from __future__ import annotations
+
+ import gc
+ import os
+ import sys
+ from typing import Optional
+
+ import torch
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Allow running against the local llm-compressor checkout without installing.
+ LLM_COMPRESSOR_SRC = "/home/quixi/marlin-cdna/llm-compressor/src"
+ if os.path.isdir(LLM_COMPRESSOR_SRC):
+     sys.path.insert(0, LLM_COMPRESSOR_SRC)
+
+ from llmcompressor import oneshot  # noqa: E402
+ from llmcompressor.modifiers.awq import AWQModifier  # noqa: E402
+ from llmcompressor.modifiers.quantization import (  # noqa: E402
+     GPTQModifier,
+     QuantizationModifier,
+ )
+
+ MODEL_PATH = "/home/quixi/models/Llama-3.2-1B"
+ OUTPUT_ROOT = "/home/quixi/models"
+
+ CALIB_DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+ CALIB_DATASET_SPLIT = "train_sft"
+ NUM_CALIBRATION_SAMPLES = 128
+ MAX_SEQUENCE_LENGTH = 512
+
+
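+ # Build the calibration set for the data-dependent recipes (GPTQ/AWQ): render
+ # each UltraChat conversation with the model's chat template, then tokenize
+ # without padding; `oneshot` handles collation during calibration.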
+ def _load_tokenized_dataset(tokenizer):
+     ds = load_dataset(
+         CALIB_DATASET_ID,
+         split=f"{CALIB_DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
+     ).shuffle(seed=42)
+
+     def preprocess(example):
+         return {
+             "text": tokenizer.apply_chat_template(
+                 example["messages"],
+                 tokenize=False,
+             )
+         }
+
+     ds = ds.map(preprocess)
+
+     def tokenize(sample):
+         return tokenizer(
+             sample["text"],
+             padding=False,
+             max_length=MAX_SEQUENCE_LENGTH,
+             truncation=True,
+             add_special_tokens=False,
+         )
+
+     return ds.map(tokenize, remove_columns=ds.column_names)
+
+
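+ # `dtype="auto"` keeps the checkpoint's native precision (BF16 here); this
+ # argument name assumes a recent transformers release, where the older
+ # `torch_dtype` spelling was renamed.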
+ def _load_model_and_tokenizer():
+     model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, dtype="auto")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+     if torch.cuda.is_available():
+         model.to("cuda")
+     return model, tokenizer
+
+
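+ # Quantization mutates the model in place, so drop the previous copy and clear
+ # the CUDA cache before loading a fresh BF16 model for the next recipe.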
+ def _cleanup(model, tokenizer):
+     del model
+     del tokenizer
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+
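+ # Run one recipe end-to-end: load a fresh model, optionally calibrate, apply
+ # the recipe with `oneshot`, and save to OUTPUT_ROOT/<base>-<name>.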
+ def _run_recipe(
+     name: str,
+     recipe,
+     *,
+     save_compressed: bool,
+     use_calibration: bool,
+ ) -> Optional[str]:
+     print(f"\n=== Quantizing {name} ===")
+     model, tokenizer = _load_model_and_tokenizer()
+
+     oneshot_kwargs = {"model": model, "recipe": recipe}
+     if use_calibration:
+         ds = _load_tokenized_dataset(tokenizer)
+         oneshot_kwargs.update(
+             dataset=ds,
+             max_seq_length=MAX_SEQUENCE_LENGTH,
+             num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+         )
+
+     oneshot(**oneshot_kwargs)
+
+     base_name = os.path.basename(MODEL_PATH.rstrip("/"))
+     save_dir = os.path.join(OUTPUT_ROOT, f"{base_name}-{name}")
+     os.makedirs(save_dir, exist_ok=True)
+
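+     # save_compressed=True writes the packed compressed-tensors layout (the
+     # format Marlin-style kernels consume); the FP8-Dynamic recipe opts out
+     # and saves with the library's default layout instead.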
+     if save_compressed:
+         model.save_pretrained(save_dir, save_compressed=True)
+     else:
+         model.save_pretrained(save_dir)
+     tokenizer.save_pretrained(save_dir)
+
+     _cleanup(model, tokenizer)
+     return save_dir
+
+
+ def main():
+     # GPTQ W4A16 (INT4 weight-only).
+     _run_recipe(
+         "W4A16-GPTQ",
+         GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
+         save_compressed=True,
+         use_calibration=True,
+     )
+
+     # AWQ W4A16 (INT4 weight-only, asymmetric).
+     _run_recipe(
+         "W4A16-AWQ",
+         AWQModifier(
+             targets=["Linear"],
+             scheme="W4A16_ASYM",
+             ignore=["lm_head"],
+             duo_scaling="both",
+         ),
+         save_compressed=True,
+         use_calibration=True,
+     )
+
+     # GPTQ W8A16 (INT8 weight-only).
+     _run_recipe(
+         "W8A16-GPTQ",
+         GPTQModifier(targets="Linear", scheme="W8A16", ignore=["lm_head"]),
+         save_compressed=True,
+         use_calibration=True,
+     )
+
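+     # The remaining recipes are data-free: weight scales come straight from
+     # the weights, and FP8_DYNAMIC scales activations at runtime, so no
+     # calibration set is passed.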
+     # FP8 dynamic (W8A8-FP8).
+     _run_recipe(
+         "FP8-Dynamic",
+         QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]),
+         save_compressed=False,
+         use_calibration=False,
+     )
+
+     # NVFP4A16 (FP4 weights + FP16 activations).
+     _run_recipe(
+         "NVFP4A16",
+         QuantizationModifier(targets="Linear", scheme="NVFP4A16", ignore=["lm_head"]),
+         save_compressed=True,
+         use_calibration=False,
+     )
+
+     # MXFP4 (FP4 weights).
+     _run_recipe(
+         "MXFP4",
+         QuantizationModifier(targets="Linear", scheme="MXFP4", ignore=["lm_head"]),
+         save_compressed=True,
+         use_calibration=False,
+     )
+
+
+ if __name__ == "__main__":
+     main()
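
To smoke-test one of the resulting checkpoints (a minimal sketch, assuming
vLLM is installed and the W4A16-GPTQ run has completed; vLLM selects
Marlin-family kernels for compressed-tensors W4A16 weights on supported GPUs):

    from vllm import LLM, SamplingParams

    # Path follows this script's OUTPUT_ROOT/<base>-<name> naming scheme.
    llm = LLM(model="/home/quixi/models/Llama-3.2-1B-W4A16-GPTQ")
    out = llm.generate(["Hello, world"], SamplingParams(max_tokens=32))
    print(out[0].outputs[0].text)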