Kyle Howells committed on
Commit
3ff2e58
·
1 Parent(s): c4137e6

Replace demucs-mlx conversion scripts with direct PyTorch exporter

Browse files

Removes dependency on demucs-mlx re-implementation. The new
export_from_pytorch.py converts all 8 models directly from the
original PyTorch demucs package to safetensors + JSON config.

convert_demucs_mlx_checkpoint.py DELETED
@@ -1,121 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Export demucs-mlx pickle checkpoint to flat safetensors + JSON metadata.
4
-
5
- This is a preparation step for native Swift/MLX loading.
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- import argparse
11
- import json
12
- import os
13
- import pickle
14
- from pathlib import Path
15
- from typing import Any
16
- from fractions import Fraction
17
-
18
- import mlx.core as mx
19
-
20
-
21
- def flatten_tree(node: Any, prefix: str = "") -> dict[str, mx.array]:
22
- out: dict[str, mx.array] = {}
23
-
24
- if isinstance(node, dict):
25
- for k, v in node.items():
26
- key = f"{prefix}.{k}" if prefix else str(k)
27
- out.update(flatten_tree(v, key))
28
- return out
29
-
30
- if isinstance(node, (list, tuple)):
31
- for idx, v in enumerate(node):
32
- key = f"{prefix}.{idx}" if prefix else str(idx)
33
- out.update(flatten_tree(v, key))
34
- return out
35
-
36
- # MLX array leaf
37
- if isinstance(node, mx.array):
38
- out[prefix] = node
39
- return out
40
-
41
- # Non-array leaf in state tree: ignore.
42
- return out
43
-
44
-
45
- def to_builtin(obj: Any) -> Any:
46
- if isinstance(obj, dict):
47
- return {str(k): to_builtin(v) for k, v in obj.items()}
48
- if isinstance(obj, (list, tuple)):
49
- return [to_builtin(x) for x in obj]
50
- if isinstance(obj, Fraction):
51
- return f"{obj.numerator}/{obj.denominator}"
52
- return obj
53
-
54
-
55
- def main() -> None:
56
- ap = argparse.ArgumentParser()
57
- ap.add_argument(
58
- "--checkpoint",
59
- default=os.path.expanduser("~/.cache/demucs-mlx/htdemucs_mlx.pkl"),
60
- help="Path to demucs-mlx pickle checkpoint",
61
- )
62
- ap.add_argument(
63
- "--out-dir",
64
- default="./Models/htdemucs",
65
- help="Output directory",
66
- )
67
- ap.add_argument(
68
- "--name",
69
- default="htdemucs",
70
- help="Output model basename",
71
- )
72
- args = ap.parse_args()
73
-
74
- ck_path = Path(args.checkpoint).expanduser().resolve()
75
- out_dir = Path(args.out_dir).resolve()
76
- out_dir.mkdir(parents=True, exist_ok=True)
77
-
78
- with ck_path.open("rb") as f:
79
- checkpoint = pickle.load(f)
80
-
81
- if "state" not in checkpoint:
82
- raise ValueError(f"No 'state' key in checkpoint: {ck_path}")
83
-
84
- flat = flatten_tree(checkpoint["state"])
85
- if not flat:
86
- raise ValueError("No MLX arrays found while flattening state tree")
87
-
88
- safetensors_path = out_dir / f"{args.name}.safetensors"
89
- config_path = out_dir / f"{args.name}_config.json"
90
-
91
- mx.save_safetensors(str(safetensors_path), flat)
92
-
93
- metadata = {
94
- "model_name": checkpoint.get("model_name"),
95
- "model_class": checkpoint.get("model_class"),
96
- "sub_model_class": checkpoint.get("sub_model_class"),
97
- "num_models": checkpoint.get("num_models"),
98
- "weights": checkpoint.get("weights"),
99
- "args": to_builtin(checkpoint.get("args", [])),
100
- "kwargs": to_builtin(checkpoint.get("kwargs", {})),
101
- "mlx_version": checkpoint.get("mlx_version"),
102
- "tensor_count": len(flat),
103
- "tensors": {
104
- k: {
105
- "shape": list(v.shape),
106
- "dtype": str(v.dtype),
107
- }
108
- for k, v in flat.items()
109
- },
110
- }
111
-
112
- with config_path.open("w") as f:
113
- json.dump(metadata, f, indent=2)
114
-
115
- print(f"wrote {safetensors_path}")
116
- print(f"wrote {config_path}")
117
- print(f"tensors: {len(flat)}")
118
-
119
-
120
- if __name__ == "__main__":
121
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
export_all_models.py DELETED
@@ -1,206 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Export all available demucs-mlx model checkpoints to safetensors + JSON.
4
-
5
- Usage:
6
- python scripts/export_all_models.py [--cache-dir ~/.cache/demucs-mlx] [--out-dir ./Models]
7
-
8
- This script finds all *_mlx.pkl checkpoints in the demucs-mlx cache directory
9
- and exports each one as:
10
- <out-dir>/<model_name>/<model_name>.safetensors
11
- <out-dir>/<model_name>/<model_name>_config.json
12
-
13
- If you haven't converted models yet, run demucs-mlx first to generate the
14
- pickle checkpoints:
15
- python -m demucs_mlx --model htdemucs -n test.mp3
16
- """
17
-
18
- from __future__ import annotations
19
-
20
- import argparse
21
- import json
22
- import os
23
- import pickle
24
- import sys
25
- from pathlib import Path
26
- from typing import Any
27
- from fractions import Fraction
28
-
29
- import mlx.core as mx
30
-
31
-
32
- # Known model names in demucs-mlx
33
- ALL_MODELS = [
34
- "htdemucs",
35
- "htdemucs_ft",
36
- "htdemucs_6s",
37
- "hdemucs_mmi",
38
- "mdx",
39
- "mdx_extra",
40
- "mdx_q",
41
- "mdx_extra_q",
42
- ]
43
-
44
-
45
- def flatten_tree(node: Any, prefix: str = "") -> dict[str, mx.array]:
46
- out: dict[str, mx.array] = {}
47
- if isinstance(node, dict):
48
- for k, v in node.items():
49
- key = f"{prefix}.{k}" if prefix else str(k)
50
- out.update(flatten_tree(v, key))
51
- return out
52
- if isinstance(node, (list, tuple)):
53
- for idx, v in enumerate(node):
54
- key = f"{prefix}.{idx}" if prefix else str(idx)
55
- out.update(flatten_tree(v, key))
56
- return out
57
- if isinstance(node, mx.array):
58
- out[prefix] = node
59
- return out
60
- return out
61
-
62
-
63
- def to_builtin(obj: Any) -> Any:
64
- if isinstance(obj, dict):
65
- return {str(k): to_builtin(v) for k, v in obj.items()}
66
- if isinstance(obj, (list, tuple)):
67
- return [to_builtin(x) for x in obj]
68
- if isinstance(obj, Fraction):
69
- return f"{obj.numerator}/{obj.denominator}"
70
- return obj
71
-
72
-
73
- def export_checkpoint(ck_path: Path, out_dir: Path, model_name: str) -> bool:
74
- """Export a single checkpoint. Returns True on success."""
75
- if not ck_path.exists():
76
- return False
77
-
78
- print(f"\n--- Exporting {model_name} from {ck_path} ---")
79
-
80
- with ck_path.open("rb") as f:
81
- checkpoint = pickle.load(f)
82
-
83
- if "state" not in checkpoint:
84
- print(f" WARNING: No 'state' key in checkpoint, skipping")
85
- return False
86
-
87
- flat = flatten_tree(checkpoint["state"])
88
- if not flat:
89
- print(f" WARNING: No MLX arrays found, skipping")
90
- return False
91
-
92
- model_dir = out_dir / model_name
93
- model_dir.mkdir(parents=True, exist_ok=True)
94
-
95
- safetensors_path = model_dir / f"{model_name}.safetensors"
96
- config_path = model_dir / f"{model_name}_config.json"
97
-
98
- mx.save_safetensors(str(safetensors_path), flat)
99
-
100
- metadata = {
101
- "model_name": checkpoint.get("model_name", model_name),
102
- "model_class": checkpoint.get("model_class"),
103
- "sub_model_class": checkpoint.get("sub_model_class"),
104
- "num_models": checkpoint.get("num_models"),
105
- "weights": checkpoint.get("weights"),
106
- "args": to_builtin(checkpoint.get("args", [])),
107
- "kwargs": to_builtin(checkpoint.get("kwargs", {})),
108
- "mlx_version": checkpoint.get("mlx_version"),
109
- "tensor_count": len(flat),
110
- }
111
-
112
- # For heterogeneous bags, include per-model class and kwargs
113
- per_model_class = checkpoint.get("per_model_class")
114
- per_model_kwargs = checkpoint.get("per_model_kwargs")
115
-
116
- if per_model_class:
117
- # Map PyTorch class names to MLX class names
118
- class_map = {
119
- 'Demucs': 'DemucsMLX',
120
- 'HDemucs': 'HDemucsMLX',
121
- 'HTDemucs': 'HTDemucsMLX',
122
- }
123
- metadata["sub_model_classes"] = [class_map.get(c, c) for c in per_model_class]
124
-
125
- if per_model_kwargs:
126
- # Build model_configs array with per-model class + kwargs
127
- model_configs = []
128
- for i, kw in enumerate(per_model_kwargs):
129
- mc = "HTDemucsMLX"
130
- if per_model_class and i < len(per_model_class):
131
- mc = class_map.get(per_model_class[i], per_model_class[i])
132
- model_configs.append({
133
- "model_class": mc,
134
- "kwargs": to_builtin(kw),
135
- })
136
- metadata["model_configs"] = model_configs
137
-
138
- # Remove None values for cleaner JSON
139
- metadata = {k: v for k, v in metadata.items() if v is not None}
140
-
141
- with config_path.open("w") as f:
142
- json.dump(metadata, f, indent=2)
143
-
144
- print(f" wrote {safetensors_path} ({len(flat)} tensors)")
145
- print(f" wrote {config_path}")
146
- mc = metadata.get("model_class", "?")
147
- smc = metadata.get("sub_model_class", "")
148
- nm = metadata.get("num_models", 1)
149
- print(f" class={mc}, sub_class={smc}, num_models={nm}")
150
- return True
151
-
152
-
153
- def main() -> None:
154
- ap = argparse.ArgumentParser(description="Export all demucs-mlx checkpoints to safetensors")
155
- ap.add_argument(
156
- "--cache-dir",
157
- default=os.path.expanduser("~/.cache/demucs-mlx"),
158
- help="demucs-mlx cache directory containing *_mlx.pkl files",
159
- )
160
- ap.add_argument(
161
- "--out-dir",
162
- default="./Models",
163
- help="Output root directory (model files go into <out-dir>/<model_name>/)",
164
- )
165
- ap.add_argument(
166
- "--models",
167
- nargs="*",
168
- default=None,
169
- help="Specific model names to export (default: all found)",
170
- )
171
- args = ap.parse_args()
172
-
173
- cache_dir = Path(args.cache_dir).expanduser().resolve()
174
- out_dir = Path(args.out_dir).resolve()
175
-
176
- if not cache_dir.exists():
177
- print(f"Cache directory not found: {cache_dir}")
178
- print("Run demucs-mlx first to download and convert models.")
179
- sys.exit(1)
180
-
181
- models_to_export = args.models or ALL_MODELS
182
-
183
- exported = 0
184
- skipped = 0
185
-
186
- for model_name in models_to_export:
187
- ck_path = cache_dir / f"{model_name}_mlx.pkl"
188
- if export_checkpoint(ck_path, out_dir, model_name):
189
- exported += 1
190
- else:
191
- skipped += 1
192
-
193
- # Also check for any *_mlx.pkl files not in our known list
194
- if args.models is None:
195
- for pkl_file in sorted(cache_dir.glob("*_mlx.pkl")):
196
- name = pkl_file.stem.replace("_mlx", "")
197
- if name not in ALL_MODELS:
198
- print(f"\nFound additional checkpoint: {pkl_file.name}")
199
- if export_checkpoint(pkl_file, out_dir, name):
200
- exported += 1
201
-
202
- print(f"\n=== Done: {exported} exported, {skipped} skipped ===")
203
-
204
-
205
- if __name__ == "__main__":
206
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
export_from_pytorch.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export Demucs PyTorch models directly to safetensors + JSON config for Swift MLX.
4
+
5
+ Converts all 8 pretrained models directly from the original PyTorch demucs package.
6
+ No dependency on demucs-mlx or any other re-implementation.
7
+
8
+ Usage:
9
+ # Export all models
10
+ python scripts/export_from_pytorch.py --out-dir ~/.cache/demucs-mlx-swift-models
11
+
12
+ # Export specific models
13
+ python scripts/export_from_pytorch.py --models htdemucs htdemucs_ft --out-dir ./Models
14
+
15
+ Requirements:
16
+ pip install demucs safetensors numpy
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import inspect
22
+ import json
23
+ import re
24
+ import sys
25
+ from fractions import Fraction
26
+ from pathlib import Path
27
+
28
+ import numpy as np
29
+ import torch
30
+
31
+ ALL_MODELS = [
32
+ "htdemucs",
33
+ "htdemucs_ft",
34
+ "htdemucs_6s",
35
+ "hdemucs_mmi",
36
+ "mdx",
37
+ "mdx_extra",
38
+ "mdx_q",
39
+ "mdx_extra_q",
40
+ ]
41
+
42
+ # Map PyTorch class names to MLX class names used by Swift loader
43
+ CLASS_MAP = {
44
+ "Demucs": "DemucsMLX",
45
+ "HDemucs": "HDemucsMLX",
46
+ "HTDemucs": "HTDemucsMLX",
47
+ }
48
+
49
+ # Conv-like layer names that get .conv. wrapper in MLX
50
+ CONV_LAYER_NAMES = {
51
+ "conv", "conv_tr", "rewrite",
52
+ "channel_upsampler", "channel_downsampler",
53
+ "channel_upsampler_t", "channel_downsampler_t",
54
+ }
55
+
56
+ # DConv attention sub-module names (LocalState)
57
+ DCONV_ATTN_NAMES = {"content", "key", "query", "proj", "query_decay", "query_freqs"}
58
+
59
+
60
+ def to_json_serializable(obj):
61
+ """Convert Python objects to JSON-serializable types."""
62
+ if isinstance(obj, Fraction):
63
+ return f"{obj.numerator}/{obj.denominator}"
64
+ if isinstance(obj, torch.Tensor):
65
+ return obj.item() if obj.numel() == 1 else obj.tolist()
66
+ if isinstance(obj, np.ndarray):
67
+ return obj.tolist()
68
+ if isinstance(obj, (list, tuple)):
69
+ return [to_json_serializable(x) for x in obj]
70
+ if isinstance(obj, dict):
71
+ return {str(k): to_json_serializable(v) for k, v in obj.items()}
72
+ return obj
73
+
74
+
75
+ def transpose_conv_weights(key: str, value: np.ndarray, is_conv_transpose: bool = False) -> np.ndarray:
76
+ """Transpose PyTorch conv weights to MLX layout.
77
+
78
+ Conv1d: (out, in, k) → MLX: (out, k, in) transpose (0,2,1)
79
+ Conv2d: (out, in, h, w) → MLX: (out, h, w, in) transpose (0,2,3,1)
80
+ ConvTranspose1d: (in, out, k) → MLX: (out, k, in) transpose (1,2,0)
81
+ ConvTranspose2d: (in, out, h, w) → MLX: (out, h, w, in) transpose (1,2,3,0)
82
+ """
83
+ if not key.endswith(".weight"):
84
+ return value
85
+
86
+ if len(value.shape) == 3:
87
+ return np.transpose(value, (1, 2, 0) if is_conv_transpose else (0, 2, 1))
88
+ if len(value.shape) == 4:
89
+ return np.transpose(value, (1, 2, 3, 0) if is_conv_transpose else (0, 2, 3, 1))
90
+ return value
91
+
92
+
93
+ def remap_key(
94
+ key: str,
95
+ value: np.ndarray,
96
+ model_type: str = "HTDemucs",
97
+ dconv_conv_slots: set | None = None,
98
+ seq_conv_slots: set | None = None,
99
+ ) -> list[tuple[str, np.ndarray]]:
100
+ """Remap a PyTorch state dict key to MLX key convention.
101
+
102
+ Returns a list of (key, value) pairs (multiple for attention in_proj splits).
103
+ Duplicate target keys (e.g. LSTM bias_ih + bias_hh) are merged by the caller.
104
+
105
+ Args:
106
+ key: PyTorch state dict key
107
+ value: numpy array (already transposed for conv weights)
108
+ model_type: PyTorch class name ("Demucs", "HDemucs", "HTDemucs")
109
+ dconv_conv_slots: set of (block_prefix, slot_str) for DConv slots with 3D weights
110
+ seq_conv_slots: set of (enc_dec, layer, slot) for Demucs v1/v2 Sequential Conv slots
111
+ """
112
+ dconv_conv_slots = dconv_conv_slots or set()
113
+ seq_conv_slots = seq_conv_slots or set()
114
+
115
+ # =========================================================================
116
+ # Step 1: Demucs v1/v2 Sequential insertion
117
+ # encoder.{i}.{j}.rest → encoder.{i}.layers.{j}.rest
118
+ # decoder.{i}.{j}.rest → decoder.{i}.layers.{j}.rest
119
+ # =========================================================================
120
+ if model_type == "Demucs":
121
+ m = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)(\..*)?$", key)
122
+ if m:
123
+ enc_dec, layer, slot, rest = m.groups()
124
+ rest = rest or ""
125
+ key = f"{enc_dec}.{layer}.layers.{slot}{rest}"
126
+
127
+ # =========================================================================
128
+ # Step 1.5: Demucs v1/v2 Sequential Conv/Norm slot wrapping
129
+ # encoder.{i}.layers.{j}.weight → encoder.{i}.layers.{j}.conv.weight (if Conv slot)
130
+ # =========================================================================
131
+ if model_type == "Demucs":
132
+ m = re.match(r"(encoder|decoder)\.(\d+)\.layers\.(\d+)\.(weight|bias)$", key)
133
+ if m:
134
+ enc_dec, layer, slot, param = m.groups()
135
+ if (enc_dec, layer, slot) in seq_conv_slots:
136
+ return [(f"{enc_dec}.{layer}.layers.{slot}.conv.{param}", value)]
137
+ else:
138
+ return [(f"{enc_dec}.{layer}.layers.{slot}.{param}", value)]
139
+
140
+ # =========================================================================
141
+ # Step 2: DConv internal slot handling
142
+ # Matches: *.layers.{block_idx}.{slot_idx}.{rest}
143
+ # Both HDemucs (.dconv.layers.) and Demucs v1/v2 (.layers.{N}.layers.) end
144
+ # with this pattern after Step 1.
145
+ # =========================================================================
146
+ m = re.match(r"(.+\.layers\.\d+)\.(\d+)\.(.+)$", key)
147
+ if m:
148
+ block_prefix = m.group(1)
149
+ slot = m.group(2)
150
+ rest = m.group(3)
151
+
152
+ # --- 2a. Simple weight/bias/scale ---
153
+ if rest in ("weight", "bias", "scale"):
154
+ if rest == "weight" and len(value.shape) >= 2:
155
+ # 3D weight = Conv1d → add .conv.
156
+ return [(f"{block_prefix}.layers.{slot}.conv.{rest}", value)]
157
+ elif rest == "weight":
158
+ # 1D weight = GroupNorm → no wrapper
159
+ return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
160
+ elif rest == "bias":
161
+ if (block_prefix, slot) in dconv_conv_slots:
162
+ return [(f"{block_prefix}.layers.{slot}.conv.{rest}", value)]
163
+ else:
164
+ return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
165
+ else: # scale
166
+ return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
167
+
168
+ # --- 2b. LSTM weights/biases ---
169
+ m_lstm = re.match(r"lstm\.(weight|bias)_(ih|hh)_l(\d+)(_reverse)?$", rest)
170
+ if m_lstm:
171
+ wb, ih_hh, layer_idx, reverse = m_lstm.groups()
172
+ direction = "backward_lstms" if reverse else "forward_lstms"
173
+ if wb == "weight":
174
+ param = "Wx" if ih_hh == "ih" else "Wh"
175
+ return [(f"{block_prefix}.layers.{slot}.{direction}.{layer_idx}.{param}", value)]
176
+ else: # bias — both bias_ih and bias_hh map to same key; caller merges
177
+ return [(f"{block_prefix}.layers.{slot}.{direction}.{layer_idx}.bias", value)]
178
+
179
+ # --- 2c. LSTM linear ---
180
+ m_linear = re.match(r"linear\.(weight|bias)$", rest)
181
+ if m_linear:
182
+ param = m_linear.group(1)
183
+ return [(f"{block_prefix}.layers.{slot}.linear.{param}", value)]
184
+
185
+ # --- 2d. Attention sub-modules (LocalState) ---
186
+ m_attn = re.match(r"(content|key|query|proj|query_decay|query_freqs)\.(weight|bias)$", rest)
187
+ if m_attn:
188
+ attn_name, param = m_attn.groups()
189
+ # These are all Conv1d modules → add .conv. wrapper
190
+ return [(f"{block_prefix}.layers.{slot}.{attn_name}.conv.{param}", value)]
191
+
192
+ # --- 2e. Fallback for unknown compound keys ---
193
+ return [(f"{block_prefix}.layers.{slot}.{rest}", value)]
194
+
195
+ # =========================================================================
196
+ # Step 3: MultiheadAttention in_proj split (HTDemucs transformer)
197
+ # =========================================================================
198
+ m = re.match(r"(.+)\.(self_attn|cross_attn)\.in_proj_(weight|bias)$", key)
199
+ if m:
200
+ prefix, attn_type, param = m.group(1), m.group(2), m.group(3)
201
+ mlx_attn = "attn" if attn_type == "self_attn" else "cross_attn"
202
+ dim = value.shape[0] // 3
203
+ q, k_val, v = value[:dim], value[dim : 2 * dim], value[2 * dim :]
204
+ return [
205
+ (f"{prefix}.{mlx_attn}.query_proj.{param}", q),
206
+ (f"{prefix}.{mlx_attn}.key_proj.{param}", k_val),
207
+ (f"{prefix}.{mlx_attn}.value_proj.{param}", v),
208
+ ]
209
+
210
+ # self_attn.out_proj → attn.out_proj
211
+ m = re.match(r"(.+)\.self_attn\.out_proj\.(weight|bias)$", key)
212
+ if m:
213
+ prefix, param = m.group(1), m.group(2)
214
+ return [(f"{prefix}.attn.out_proj.{param}", value)]
215
+
216
+ # =========================================================================
217
+ # Step 4: norm_out wrapping → norm_out.gn
218
+ # =========================================================================
219
+ m = re.match(r"(.+)\.norm_out\.(weight|bias)$", key)
220
+ if m:
221
+ prefix, param = m.group(1), m.group(2)
222
+ return [(f"{prefix}.norm_out.gn.{param}", value)]
223
+
224
+ # =========================================================================
225
+ # Step 5: Bottleneck LSTM (Demucs v1/v2 and HDemucs)
226
+ # lstm.lstm.weight_ih_l0 → lstm.forward_lstms.0.Wx
227
+ # =========================================================================
228
+ m = re.match(r"(.+)\.lstm\.(weight|bias)_(ih|hh)_l(\d+)(_reverse)?$", key)
229
+ if m:
230
+ prefix = m.group(1)
231
+ wb = m.group(2)
232
+ ih_hh = m.group(3)
233
+ layer_idx = m.group(4)
234
+ reverse = m.group(5)
235
+ direction = "backward_lstms" if reverse else "forward_lstms"
236
+ if wb == "weight":
237
+ param = "Wx" if ih_hh == "ih" else "Wh"
238
+ return [(f"{prefix}.{direction}.{layer_idx}.{param}", value)]
239
+ else: # bias — merge handled by caller
240
+ return [(f"{prefix}.{direction}.{layer_idx}.bias", value)]
241
+
242
+ # =========================================================================
243
+ # Step 6: Conv/ConvTranspose/Rewrite named layers → add .conv. wrapper
244
+ # =========================================================================
245
+ parts = key.rsplit(".", 1)
246
+ if len(parts) == 2:
247
+ path, param = parts
248
+ path_parts = path.split(".")
249
+ last_name = path_parts[-1]
250
+ if last_name in CONV_LAYER_NAMES and param in ("weight", "bias"):
251
+ return [(f"{path}.conv.{param}", value)]
252
+
253
+ # =========================================================================
254
+ # Default: no change
255
+ # =========================================================================
256
+ return [(key, value)]
257
+
258
+
259
+ def convert_sub_model(model, prefix: str) -> dict[str, np.ndarray]:
260
+ """Convert a single sub-model's state dict to MLX-compatible numpy arrays."""
261
+ cls_name = type(model).__name__
262
+
263
+ # --- Pre-scan: identify ConvTranspose modules by type ---
264
+ conv_tr_paths = set()
265
+ for name, module in model.named_modules():
266
+ if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)):
267
+ conv_tr_paths.add(name)
268
+
269
+ # --- Collect state dict as numpy ---
270
+ state_items = []
271
+ for key, tensor in model.state_dict().items():
272
+ arr = tensor.detach().cpu().float().numpy()
273
+ state_items.append((key, arr))
274
+
275
+ # --- Pre-scan: identify DConv Conv slots (3D weights) ---
276
+ # Pattern: *.layers.{block}.{slot}.weight where value is 3D
277
+ # For Demucs v1/v2, apply Sequential insertion first so lookups match remap_key
278
+ dconv_conv_slots: set[tuple[str, str]] = set()
279
+ for key, arr in state_items:
280
+ scan_key = key
281
+ if cls_name == "Demucs":
282
+ m = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)(\..*)?$", scan_key)
283
+ if m:
284
+ enc_dec, layer, slot, rest = m.groups()
285
+ rest = rest or ""
286
+ scan_key = f"{enc_dec}.{layer}.layers.{slot}{rest}"
287
+ m = re.match(r"(.+\.layers\.\d+)\.(\d+)\.weight$", scan_key)
288
+ if m and len(arr.shape) >= 2:
289
+ dconv_conv_slots.add((m.group(1), m.group(2)))
290
+
291
+ # --- Pre-scan: Demucs v1/v2 Sequential Conv slots ---
292
+ seq_conv_slots: set[tuple[str, str, str]] = set()
293
+ if cls_name == "Demucs":
294
+ for key, arr in state_items:
295
+ m = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)\.weight$", key)
296
+ if m and len(arr.shape) >= 2:
297
+ seq_conv_slots.add((m.group(1), m.group(2), m.group(3)))
298
+
299
+ # --- Convert ---
300
+ weights: dict[str, np.ndarray] = {}
301
+ for key, arr in state_items:
302
+ # Determine if this belongs to a ConvTranspose module
303
+ is_conv_tr = any(key.startswith(p + ".") for p in conv_tr_paths)
304
+
305
+ # Transpose conv weights
306
+ arr = transpose_conv_weights(key, arr, is_conv_transpose=is_conv_tr)
307
+
308
+ # Remap key
309
+ remapped = remap_key(key, arr, cls_name, dconv_conv_slots, seq_conv_slots)
310
+ for new_key, new_val in remapped:
311
+ full_key = f"{prefix}{new_key}"
312
+ if full_key in weights:
313
+ # LSTM bias merge: bias_ih + bias_hh → bias (additive)
314
+ weights[full_key] = weights[full_key] + new_val
315
+ else:
316
+ weights[full_key] = new_val
317
+
318
+ return weights
319
+
320
+
321
+ def extract_kwargs(model) -> dict:
322
+ """Extract constructor kwargs from a model using _init_args_kwargs or inspection."""
323
+ if hasattr(model, "_init_args_kwargs"):
324
+ _, kwargs = model._init_args_kwargs
325
+ return {k: to_json_serializable(v) for k, v in kwargs.items()
326
+ if isinstance(v, (int, float, str, bool, list, tuple, type(None), Fraction))}
327
+
328
+ # Fallback: inspect __init__ signature and read matching attributes
329
+ sig = inspect.signature(type(model).__init__)
330
+ kwargs = {}
331
+ for name in sig.parameters:
332
+ if name == "self":
333
+ continue
334
+ if hasattr(model, name):
335
+ val = getattr(model, name)
336
+ kwargs[name] = to_json_serializable(val)
337
+ return kwargs
338
+
339
+
340
+ def export_model(model_name: str, out_dir: Path) -> bool:
341
+ """Export a single model (or bag) to safetensors + config JSON."""
342
+ from demucs.pretrained import get_model
343
+ from demucs.apply import BagOfModels
344
+
345
+ print(f"\n--- Exporting {model_name} ---")
346
+ try:
347
+ model = get_model(model_name)
348
+ except Exception as e:
349
+ print(f" Failed to load model: {e}")
350
+ return False
351
+
352
+ is_bag = isinstance(model, BagOfModels)
353
+
354
+ if is_bag:
355
+ sub_models = list(model.models)
356
+ num_models = len(sub_models)
357
+ bag_weights = model.weights.tolist() if hasattr(model.weights, "tolist") else list(model.weights)
358
+ else:
359
+ sub_models = [model]
360
+ num_models = 1
361
+ bag_weights = None
362
+
363
+ print(f" {'Bag of ' + str(num_models) + ' models' if is_bag else 'Single model'}")
364
+
365
+ # Collect all weights and metadata
366
+ all_weights: dict[str, np.ndarray] = {}
367
+ model_classes: list[str] = []
368
+ model_configs: list[dict] = []
369
+
370
+ for i, sub in enumerate(sub_models):
371
+ cls_name = type(sub).__name__
372
+ mlx_cls = CLASS_MAP.get(cls_name, cls_name)
373
+ model_classes.append(mlx_cls)
374
+ print(f" Model {i}: {cls_name} → {mlx_cls}")
375
+
376
+ prefix = f"model_{i}." if is_bag else ""
377
+ sub_weights = convert_sub_model(sub, prefix)
378
+ all_weights.update(sub_weights)
379
+
380
+ kwargs = extract_kwargs(sub)
381
+ model_configs.append({
382
+ "model_class": mlx_cls,
383
+ "kwargs": kwargs,
384
+ })
385
+
386
+ # Build config JSON
387
+ config: dict = {
388
+ "model_name": model_name,
389
+ "tensor_count": len(all_weights),
390
+ }
391
+
392
+ if is_bag:
393
+ config["model_class"] = "BagOfModelsMLX"
394
+ config["num_models"] = num_models
395
+ config["weights"] = bag_weights
396
+ config["sub_model_classes"] = model_classes
397
+
398
+ # If all sub-models are the same class, set sub_model_class for compat
399
+ unique = set(model_classes)
400
+ if len(unique) == 1:
401
+ config["sub_model_class"] = unique.pop()
402
+
403
+ config["model_configs"] = model_configs
404
+
405
+ # Also put kwargs at top level for single-model bags (common case)
406
+ if num_models == 1:
407
+ config["kwargs"] = model_configs[0]["kwargs"]
408
+ else:
409
+ config["model_class"] = model_classes[0]
410
+ config["kwargs"] = model_configs[0]["kwargs"]
411
+
412
+ # Save files
413
+ model_dir = out_dir / model_name
414
+ model_dir.mkdir(parents=True, exist_ok=True)
415
+
416
+ safetensors_path = model_dir / f"{model_name}.safetensors"
417
+ config_path = model_dir / f"{model_name}_config.json"
418
+
419
+ # Save safetensors (prefer safetensors library, fallback to mlx)
420
+ try:
421
+ from safetensors.numpy import save_file
422
+ save_file(all_weights, str(safetensors_path))
423
+ except ImportError:
424
+ import mlx.core as mx
425
+ mlx_weights = {k: mx.array(v) for k, v in all_weights.items()}
426
+ mx.save_safetensors(str(safetensors_path), mlx_weights)
427
+
428
+ with config_path.open("w") as f:
429
+ json.dump(config, f, indent=2, default=str)
430
+
431
+ size_mb = safetensors_path.stat().st_size / (1024 * 1024)
432
+ print(f" Wrote {safetensors_path} ({len(all_weights)} tensors, {size_mb:.0f} MB)")
433
+ print(f" Wrote {config_path}")
434
+ return True
435
+
436
+
437
+ def main():
438
+ ap = argparse.ArgumentParser(
439
+ description="Export Demucs PyTorch models to safetensors for Swift MLX"
440
+ )
441
+ ap.add_argument(
442
+ "--models",
443
+ nargs="*",
444
+ default=None,
445
+ help=f"Models to export (default: all). Choices: {', '.join(ALL_MODELS)}",
446
+ )
447
+ ap.add_argument(
448
+ "--out-dir",
449
+ default="./Models",
450
+ help="Output root directory (files go into <out-dir>/<model_name>/)",
451
+ )
452
+ args = ap.parse_args()
453
+
454
+ models = args.models or ALL_MODELS
455
+ out_dir = Path(args.out_dir).resolve()
456
+
457
+ exported = 0
458
+ failed = 0
459
+
460
+ for name in models:
461
+ if export_model(name, out_dir):
462
+ exported += 1
463
+ else:
464
+ failed += 1
465
+
466
+ print(f"\n=== Done: {exported} exported, {failed} failed ===")
467
+ if failed:
468
+ sys.exit(1)
469
+
470
+
471
+ if __name__ == "__main__":
472
+ main()
export_mdx.py DELETED
@@ -1,343 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Export mdx/mdx_extra models (heterogeneous bags of Demucs + HDemucs) to safetensors.
4
-
5
- These models contain a mix of Demucs (v1/v2) and HDemucs (v3) sub-models in a
6
- single BagOfModels. The Python MLX converter has a bug that prevents it from
7
- handling these models, so we do a direct PyTorch → safetensors conversion.
8
-
9
- Usage:
10
- python scripts/export_mdx.py --model mdx --out-dir .scratch/models
11
- """
12
- from __future__ import annotations
13
-
14
- import argparse
15
- import json
16
- import os
17
- import sys
18
- from pathlib import Path
19
-
20
- import torch
21
- import numpy as np
22
-
23
-
24
def flatten_state_dict(state_dict: dict, prefix: str = "") -> dict:
    """Recursively flatten a nested state dict into dot-separated keys.

    Tensor leaves are converted to numpy arrays; nested dicts are descended
    into with a ``.``-joined prefix. Values that are neither tensors nor
    dicts are silently dropped.
    """
    result: dict = {}
    for name, entry in state_dict.items():
        path = f"{prefix}{name}" if prefix else name
        if isinstance(entry, dict):
            result.update(flatten_state_dict(entry, f"{path}."))
        elif isinstance(entry, torch.Tensor):
            result[path] = entry.detach().cpu().numpy()
    return result
34
-
35
-
36
def convert_torch_to_mlx_keys(state_dict: dict, model_type: str) -> dict:
    """Convert PyTorch conv weight layouts to MLX channel-last layouts.

    MLX convolutions store channels last, so conv weights are transposed:
      - Conv1d / ConvTranspose1d weight: (out, in, k)    -> (out, k, in)
      - Conv2d / ConvTranspose2d weight: (out, in, h, w) -> (out, h, w, in)

    Args:
        state_dict: Flat mapping of parameter names to numpy arrays.
            Non-array entries are skipped.
        model_type: Currently unused; kept for interface compatibility.

    Returns:
        A new dict with the same keys; 3-D/4-D ``.weight`` tensors are
        transposed, everything else is passed through unchanged.
    """
    converted = {}
    for key, value in state_dict.items():
        # Skip non-tensor items (e.g. metadata accidentally in the dict).
        if not isinstance(value, np.ndarray):
            continue
        # Only conv-style `.weight` tensors need transposing; biases and
        # norm parameters are 1-D and keep their layout.
        if key.endswith('.weight') and value.ndim == 3:
            # 1D (transposed) conv: (out, in, k) -> (out, k, in)
            value = np.transpose(value, (0, 2, 1))
        elif key.endswith('.weight') and value.ndim == 4:
            # 2D (transposed) conv: (out, in, h, w) -> (out, h, w, in)
            value = np.transpose(value, (0, 2, 3, 1))
        converted[key] = value
    return converted
71
-
72
-
73
def remap_demucs_keys(state_dict: dict) -> dict:
    """Remap Demucs v1/v2 PyTorch keys to the MLX key structure.

    PyTorch Demucs uses nn.ModuleList of nn.Sequential:
    - encoder[i] = Sequential(Conv1d, GroupNorm, ..., DConv, ...)

    In the PyTorch state dict, keys look like:
    - encoder.{i}.{j}.weight (for simple layers)
    - encoder.{i}.{j}.layers.{k}.{l}.weight (for DConv)

    MLX uses explicit named sub-modules, so Conv1d parameters must be
    wrapped under a ``.conv`` sub-key (the Conv1dNCL wrapper), while
    norm parameters stay at the slot level. Keys that don't match any
    known pattern are passed through unchanged.
    """
    remapped = {}

    # Map of which sequential indices are Conv1d/ConvTranspose1d
    # and need wrapping in Conv1dNCL/ConvTranspose1dNCL
    for key, value in state_dict.items():
        parts = key.split('.')

        # --- Encoder layers: encoder.{enc_idx}.{layer_idx}.{rest} ---
        # NOTE(review): int(parts[2]) raises ValueError for non-numeric
        # sub-keys; assumes all encoder keys follow the Sequential layout.
        if len(parts) >= 3 and parts[0] == 'encoder':
            enc_idx = parts[1]
            layer_idx = int(parts[2])
            rest = '.'.join(parts[3:])

            # Sequential structure for encoder:
            # 0: Conv1d → Conv1dNCL wrapper (add .conv. prefix)
            # 1: GroupNorm or Identity
            # 2: Identity (GELU placeholder)
            # 3+: DConv (if present), then rewrite Conv1d, GroupNorm, Identity
            if layer_idx == 0 and (rest.startswith('weight') or rest.startswith('bias')):
                # Conv1d → wrap in Conv1dNCL
                new_key = f"encoder.{enc_idx}.layers.{layer_idx}.conv.{rest}"
            elif rest.startswith('layers.'):
                # DConv internal structure - remap sequential to named
                new_key = remap_dconv_key(f"encoder.{enc_idx}.layers.{layer_idx}", rest, value)
                if new_key:
                    remapped[new_key] = value
                    continue
                else:
                    # Fallback: keep original structure
                    new_key = f"encoder.{enc_idx}.layers.{layer_idx}.{rest}"
            else:
                new_key = f"encoder.{enc_idx}.layers.{layer_idx}.{rest}"

            remapped[new_key] = value
            continue

        # --- Decoder layers (similar structure but reversed) ---
        if len(parts) >= 3 and parts[0] == 'decoder':
            dec_idx = parts[1]
            layer_idx = int(parts[2])
            rest = '.'.join(parts[3:])

            # For decoder, rewrite comes first, then DConv, then ConvTranspose
            # Need to check what the sequential order is
            new_key = f"decoder.{dec_idx}.layers.{layer_idx}.{rest}"

            # Conv layers need wrapping. Conv weights are identified by
            # shape (3-D) rather than by sequential position.
            # NOTE(review): bias tensors are 1-D, so they never pass the
            # ndim >= 2 check and never receive the .conv. prefix — confirm
            # the MLX decoder expects conv biases at the slot level.
            if (rest.startswith('weight') or rest.startswith('bias')) and len(value.shape) >= 2:
                # Check if it's a conv by shape
                if len(value.shape) == 3:
                    new_key = f"decoder.{dec_idx}.layers.{layer_idx}.conv.{rest}"
                # else it's a GroupNorm - keep as is

            remapped[new_key] = value
            continue

        # --- LSTM parameters keep their PyTorch naming verbatim ---
        if parts[0] == 'lstm':
            remapped[key] = value
            continue

        # Anything unrecognized is passed through unchanged.
        remapped[key] = value

    return remapped
150
-
151
-
152
- def remap_dconv_key(prefix: str, rest: str, value: np.ndarray) -> str | None:
153
- """Remap DConv internal key structure.
154
-
155
- PyTorch DConv uses nn.Sequential for each block:
156
- - layers[0][0] = Conv1d (depthwise)
157
- - layers[0][1] = GroupNorm
158
- - layers[0][2] = Identity
159
- - layers[0][3] = Conv1d (pointwise)
160
- - layers[0][4] = GroupNorm
161
- - layers[0][5] = Identity
162
- - layers[0][6] = LayerScale
163
-
164
- MLX DConvBlock uses:
165
- - layers[0] = DConvSlot(.conv) → has .conv.weight/.conv.bias
166
- - layers[1] = DConvSlot(.normGELU) → has .weight/.bias
167
- - layers[2] = DConvSlot(.identity) → no params
168
- - layers[3] = DConvSlot(.conv) → has .conv.weight/.conv.bias
169
- - layers[4] = DConvSlot(.normGLU) → has .weight/.bias
170
- - layers[5] = DConvSlot(.identity) → no params
171
- - layers[6] = DConvSlot(.scale) → has .scale
172
- """
173
- # rest looks like: layers.{block_idx}.{seq_idx}.weight
174
- parts = rest.split('.')
175
- if len(parts) < 4:
176
- return None
177
-
178
- block_idx = parts[1]
179
- seq_idx = int(parts[2])
180
- param_rest = '.'.join(parts[3:])
181
-
182
- # Map sequential index to DConvSlot index
183
- # PyTorch seq: 0=Conv, 1=GroupNorm, 2=Identity, 3=Conv1x1, 4=GroupNorm, 5=Identity, 6=Scale
184
- # MLX slots: 0=conv, 1=normGELU, 2=identity, 3=conv, 4=normGLU, 5=identity, 6=scale
185
-
186
- if seq_idx in (0, 3):
187
- # Conv layers - wrap in DConvSlot .conv
188
- new_key = f"{prefix}.layers.{block_idx}.layers.{seq_idx}.conv.{param_rest}"
189
- elif seq_idx in (1, 4):
190
- # GroupNorm - direct weight/bias
191
- new_key = f"{prefix}.layers.{block_idx}.layers.{seq_idx}.{param_rest}"
192
- elif seq_idx == 6:
193
- # LayerScale - has .scale parameter
194
- if param_rest == 'scale':
195
- new_key = f"{prefix}.layers.{block_idx}.layers.{seq_idx}.{param_rest}"
196
- else:
197
- return None
198
- else:
199
- return None
200
-
201
- return new_key
202
-
203
-
204
def export_model(model_name: str, out_dir: Path) -> bool:
    """Export a pretrained Demucs bag-of-models to safetensors + config JSON.

    Downloads/loads the named model via the demucs package, transposes conv
    weights to MLX channel-last layout, prefixes each sub-model's tensors
    with ``model_{i}.``, and writes ``<out_dir>/<model_name>/<model_name>.safetensors``
    plus ``<model_name>_config.json``.

    Returns:
        True on success, False if loading fails or the model is not a
        BagOfModels (errors are printed, not raised).
    """
    # Imported lazily so the script can print a useful error if demucs
    # is not installed rather than failing at import time.
    from demucs.pretrained import get_model

    print(f"\n--- Exporting {model_name} ---")
    try:
        bag = get_model(model_name)
    except Exception as e:
        print(f" Failed to load model: {e}")
        return False

    from demucs.apply import BagOfModels

    # This exporter only handles bags; single models are rejected.
    if not isinstance(bag, BagOfModels):
        print(f" Expected BagOfModels, got {type(bag).__name__}")
        return False

    num_models = len(bag.models)
    print(f" Bag of {num_models} models")

    # Collect all weights with model_X prefix
    all_weights = {}
    model_classes = []
    model_kwargs_list = []

    for i, sub_model in enumerate(bag.models):
        cls_name = type(sub_model).__name__
        print(f" Model {i}: {cls_name}")
        model_classes.append(cls_name)

        # Get state dict and convert each tensor to a numpy array in
        # MLX channel-last layout.
        sd = sub_model.state_dict()
        flat = {}
        for key, tensor in sd.items():
            arr = tensor.detach().cpu().numpy()
            # Transpose conv weights: 3-D (out,in,k)->(out,k,in),
            # 4-D (out,in,h,w)->(out,h,w,in). 1-D/2-D params untouched.
            if key.endswith('.weight'):
                if len(arr.shape) == 3:
                    arr = np.transpose(arr, (0, 2, 1))
                elif len(arr.shape) == 4:
                    arr = np.transpose(arr, (0, 2, 3, 1))
            flat[f"model_{i}.{key}"] = arr

        all_weights.update(flat)

        # Extract constructor kwargs by reflecting on __init__ and reading
        # same-named attributes off the instance.
        # NOTE(review): assumes demucs models store each ctor arg as an
        # attribute of the same name — holds for Demucs/HDemucs, verify
        # for any new model class.
        import inspect
        init_sig = inspect.signature(type(sub_model).__init__)
        kwargs = {}
        for param_name in init_sig.parameters:
            if param_name == 'self':
                continue
            if hasattr(sub_model, param_name):
                val = getattr(sub_model, param_name)
                if isinstance(val, torch.Tensor):
                    val = val.item()
                elif isinstance(val, (list, tuple)):
                    val = list(val)
                kwargs[param_name] = val
        model_kwargs_list.append(kwargs)

    # Save safetensors
    model_dir = out_dir / model_name
    model_dir.mkdir(parents=True, exist_ok=True)

    safetensors_path = model_dir / f"{model_name}.safetensors"
    config_path = model_dir / f"{model_name}_config.json"

    # Convert numpy arrays to mlx arrays and save; fall back to the
    # safetensors library when mlx is unavailable (e.g. non-Apple hosts).
    try:
        import mlx.core as mx
        mlx_weights = {k: mx.array(v) for k, v in all_weights.items()}
        mx.save_safetensors(str(safetensors_path), mlx_weights)
    except ImportError:
        # Fallback: use safetensors library directly
        from safetensors.numpy import save_file
        save_file(all_weights, str(safetensors_path))

    # Build config
    # Map PyTorch class names to MLX class names; unknown classes pass
    # through unchanged via class_map.get(c, c).
    class_map = {
        'Demucs': 'DemucsMLX',
        'HDemucs': 'HDemucsMLX',
        'HTDemucs': 'HTDemucsMLX',
    }

    # Per-sub-model blend weights of the bag (may be a tensor or a list).
    weights = None
    if bag.weights is not None:
        weights = bag.weights.tolist() if hasattr(bag.weights, 'tolist') else list(bag.weights)

    config = {
        "model_name": model_name,
        "model_class": "BagOfModelsMLX",
        "num_models": num_models,
        "weights": weights,
        "sub_model_classes": [class_map.get(c, c) for c in model_classes],
        "model_configs": [],
        "tensor_count": len(all_weights),
    }

    # If all models are the same class, also set sub_model_class for compatibility
    unique_classes = set(config["sub_model_classes"])
    if len(unique_classes) == 1:
        config["sub_model_class"] = unique_classes.pop()

    # Add per-model configs, keeping only JSON-serializable kwargs
    # (int/float/str/bool/list/None); everything else is dropped.
    for i, (cls, kwargs) in enumerate(zip(model_classes, model_kwargs_list)):
        model_config = {
            "model_class": class_map.get(cls, cls),
            "kwargs": {},
        }
        # Convert kwargs to JSON-serializable
        for k, v in kwargs.items():
            if isinstance(v, (int, float, str, bool, list)):
                model_config["kwargs"][k] = v
            elif v is None:
                model_config["kwargs"][k] = None
        config["model_configs"].append(model_config)

    with config_path.open("w") as f:
        json.dump(config, f, indent=2, default=str)

    print(f" Wrote {safetensors_path} ({len(all_weights)} tensors)")
    print(f" Wrote {config_path}")
    return True
330
-
331
-
332
def main():
    """CLI entry point: export a single mdx/mdx_extra model.

    Exits with status 1 when the export fails so that shell scripts and
    CI can detect the error (export_model reports failure via its bool
    return value and would otherwise be silently ignored).
    """
    ap = argparse.ArgumentParser(description="Export mdx/mdx_extra models")
    ap.add_argument("--model", default="mdx", help="Model name")
    ap.add_argument("--out-dir", default=".scratch/models", help="Output directory")
    args = ap.parse_args()

    out_dir = Path(args.out_dir).resolve()
    # Propagate failure to the process exit code.
    if not export_model(args.model, out_dir):
        sys.exit(1)


if __name__ == "__main__":
    main()