Skip to content

Commit cf654d4

Browse files
Brooooooklynclaude
andauthored
fix(convert): unblock 250GB+ MoE conversions (CPU stream + drain on write) (mlx-node#63)
## Summary `mlx convert` on a 250 GB Qwen3.5 122B-A10B checkpoint (256 experts × 48 layers) failed two different ways: 1. **macOS Metal watchdog kill** (~5 s) when materializing a 1.6 GB sliced view of a fused \`experts.gate_up_proj\` backed by a cold mmap'd HF shard — surfaces as \`kIOGPUCommandBufferCallbackErrorTimeout\` mid-shard. 2. **Silent OOM-kill at shard 35/49** with MLX allocator at 162 GB active memory — each materialized contiguous backing buffer stayed live in the in-memory \`HashMap<String, MxArray>\` for the entire sharded save, blowing through 128 GB RAM. Both fixes are required to convert this checkpoint at all. ### Fix 1: CPU device + stream for convert Conversion does only slice / reshape / dtype-cast — no real math — so the CPU is semantically correct and immune to the Metal watchdog. A new RAII guard (\`ConvertDefaultStreamGuard\` / \`ConvertGgufDefaultStreamGuard\`) flips both \`set_default_device(CPU)\` AND \`set_default_stream(cpu_default)\` at the start of \`convert_model\` / \`convert_gguf_to_safetensors\` and restores the previous values on drop. Setting the stream alone is NOT enough — MLX dispatches stream-less ops via \`default_stream(default_device())\`, so the device pin is load-bearing. New FFI shims \`mlx_default_device\` / \`mlx_set_default_device\` are added to \`mlx-sys\`. ### Fix 2: Drain the tensor map as each tensor is written \`save_safetensors_single\` / \`_sharded\` / \`save_safetensors\` now take \`&mut HashMap<String, MxArray>\` and call \`.remove(name)\` after each tensor's bytes hit disk. This releases the MLX backing buffer immediately and keeps MLX active memory bounded at ~4.6 GB peak instead of growing unbounded. All callers updated: \`convert.rs\`, \`training_state.rs\`, \`gguf.rs\`, \`foreign_weights.rs\`, \`qwen3/qwen3_5/qwen3_5_moe/model.rs\`. ### Production logs \`info!\` level now exposes: - convert begin/end with structured fields (\`input_dir\`, \`output_dir\`, \`model_type\`, \`quantize\`, \`total_seconds\`, \`num_tensors\`, \`num_parameters\`) - per-shard timing, MB, avg MB/s, MLX \`active_mb\` / \`peak_mb\` / \`cache_mb\` - any single-tensor materialization ≥ 2 s (watchdog / cold-mmap signal) \`debug!\` level keeps the full per-tensor trace for deep debugging via \`MLX_NODE_LOG=\"mlx_core::utils::safetensors=debug\"\`. ## Verification (Qwen3.5 122B-A10B, 250 GB → bf16 MLX) | | Before | After | |---|---|---| | Result | Died at shard 3/49 (Metal watchdog), then shard 35/49 (OOM-kill) | ✓ 49/49 in 11:40 | | MLX peak memory | 162 GB | **4.6 GB** | | MLX active (steady-state) | growing unbounded | 0 MB | | Avg throughput | n/a (crash) | 334 MB/s sustained | \`cargo clippy --all-targets -- -D warnings\` and \`cargo fmt --check\` both clean. ## Test plan - [x] Qwen3.5 122B-A10B full bf16 conversion completes end-to-end (49 shards + index) - [x] \`cargo clippy --all-targets -- -D warnings\` - [x] \`cargo fmt --check\` - [ ] Spot-check that small / already-working conversions (Qwen3 0.6B, smaller MoE) still work — same code path now uses CPU stream, expected to be a no-op or trivially faster - [ ] Spot-check that GGUF→SafeTensors path is unaffected 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- CURSOR_SUMMARY --> --- > [!NOTE] > **Medium Risk** > Changes process-wide MLX default device/stream during convert (documented inference overlap risk) and mutates save APIs site-wide; behavior is intentional for CLI convert but embedders must serialize inference. > > **Overview** > Large HuggingFace / GGUF → MLX conversions are made reliable on huge MoE checkpoints by **routing convert work on CPU** and **releasing MLX memory as each tensor is written**. > > A new **`CpuConvertGuard`** temporarily sets MLX’s default **device and stream** to CPU for `convert_model` and `convert_gguf_to_safetensors`, then restores them on drop—avoiding Metal watchdog timeouts when materializing multi‑GB mmap-backed expert slices. A process-wide **`convert_mutex`** serializes conversions so global MLX defaults aren’t raced. **`mlx_default_device` / `mlx_set_default_device`** are added in `mlx-sys` to support this. > > **SafeTensors writers** now take `&mut HashMap<String, MxArray>` and **`.remove` each tensor after it’s serialized**, so backing buffers don’t accumulate through 49‑shard saves (fixes silent OOM on ~250 GB models). Call sites in convert, GGUF, foreign weights, Qwen saves, and optimizer state were updated; GGUF/foreign paths **snapshot tensor names before save** because the map may be drained. > > **Structured logging** was added for convert start/end, sharded save duration, per-shard throughput, MLX active/peak/cache MB, and slow (≥2 s) tensor materializations. > > <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit 4ba9c3e. Bugbot is set up for automated code reviews on this repo. Configure [here](https://www.cursor.com/dashboard/bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 95d46db commit cf654d4

10 files changed

Lines changed: 310 additions & 24 deletions

File tree

crates/mlx-core/src/convert.rs

Lines changed: 138 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,81 @@ use crate::models::paddleocr_vl::persistence::load_paddleocr_vl_weights;
2020
use crate::models::qianfan_ocr::persistence::load_qianfan_ocr_weights;
2121
use crate::utils::safetensors::load_safetensors_lazy;
2222

23+
/// RAII guard that pins the MLX default device + stream to CPU for one
24+
/// conversion call, then restores the previous values on drop.
25+
///
26+
/// Used by the conversion path to temporarily route every MLX op through
27+
/// CPU for the duration of one `convert_model` /
28+
/// `convert_gguf_to_safetensors` call. Both the default *device* and the
29+
/// default *stream* must be switched: MLX dispatches stream-less ops via
30+
/// `default_stream(default_device())`, so flipping the stream alone is
31+
/// not enough — the device must be CPU too. On drop, the previous
32+
/// device and stream are restored so subsequent inference / training
33+
/// calls keep using the GPU. See the call sites for the rationale.
34+
///
35+
/// MUST be acquired while holding `CONVERT_MUTEX`'s lock — otherwise two
36+
/// overlapping conversions can race on the process-wide MLX defaults and
37+
/// restore each other's `saved_*` fields incorrectly (e.g. both observe
38+
/// the already-flipped CPU device as "original", then both restore to
39+
/// CPU, leaving the process pinned to CPU for the next inference call).
40+
///
41+
/// **Concurrent-inference limitation (intentional):** `convert_mutex`
42+
/// only serializes convert-vs-convert. It does NOT block inference /
43+
/// training entrypoints. If a Node process runs `convert_model` while
44+
/// also serving inference, those inference ops resolve their stream via
45+
/// `default_stream(default_device())` and will be silently routed to
46+
/// CPU until the conversion finishes — typically minutes to hours on
47+
/// large MoE checkpoints, with severe latency degradation. The
48+
/// architecturally correct fix is to plumb explicit `Stream` arguments
49+
/// through every convert-used MLX FFI op so the global default is never
50+
/// touched; that's a substantial refactor outside the scope of this
51+
/// change. For the supported usage today (the `mlx convert` CLI exits
52+
/// after conversion; no other entrypoint in this codebase invokes
53+
/// convert), this is a non-issue. Callers who embed convert inside a
54+
/// long-lived multi-tenant Node process should serialize their own
55+
/// inference against convert externally.
56+
pub(crate) struct CpuConvertGuard {
57+
saved_device: i32,
58+
saved_stream: mlx_sys::mlx_stream,
59+
}
60+
61+
impl CpuConvertGuard {
62+
/// Enter the CPU device + stream. The caller is responsible for holding
63+
/// `CONVERT_MUTEX` for the lifetime of the returned guard.
64+
pub(crate) fn enter_cpu() -> Self {
65+
let saved_device = unsafe { mlx_sys::mlx_default_device() };
66+
let saved_stream = unsafe { mlx_sys::mlx_default_stream(saved_device) };
67+
unsafe { mlx_sys::mlx_set_default_device(0) };
68+
let cpu_stream = unsafe { mlx_sys::mlx_default_stream(0) };
69+
unsafe { mlx_sys::mlx_set_default_stream(cpu_stream) };
70+
Self {
71+
saved_device,
72+
saved_stream,
73+
}
74+
}
75+
}
76+
77+
impl Drop for CpuConvertGuard {
78+
fn drop(&mut self) {
79+
unsafe { mlx_sys::mlx_set_default_stream(self.saved_stream) };
80+
unsafe { mlx_sys::mlx_set_default_device(self.saved_device) };
81+
}
82+
}
83+
84+
/// Process-wide async mutex serializing all conversion calls.
85+
///
86+
/// `convert_model` and `convert_gguf_to_safetensors` mutate MLX's
87+
/// process-wide default device + default stream via `CpuConvertGuard`,
88+
/// which is unsafe under concurrency: two overlapping conversions (or a
89+
/// convert during inference that depends on the GPU default) can race on
90+
/// the global state. Both NAPI entrypoints `.await` this mutex before
91+
/// constructing a `CpuConvertGuard`, so only one conversion runs at a
92+
/// time across the entire Node process.
93+
pub(crate) fn convert_mutex() -> &'static tokio::sync::Mutex<()> {
94+
static CONVERT_MUTEX: std::sync::OnceLock<tokio::sync::Mutex<()>> = std::sync::OnceLock::new();
95+
CONVERT_MUTEX.get_or_init(|| tokio::sync::Mutex::new(()))
96+
}
97+
2398
/// Structure for parsing model.safetensors.index.json
2499
#[derive(Debug, Deserialize)]
25100
struct ShardedModelIndex {
@@ -115,6 +190,39 @@ pub struct ConversionResult {
115190
/// ```
116191
#[napi]
117192
pub async fn convert_model(options: ConversionOptions) -> Result<ConversionResult> {
193+
let _convert_start = std::time::Instant::now();
194+
info!(
195+
target = "mlx_core::convert",
196+
input_dir = %options.input_dir,
197+
output_dir = %options.output_dir,
198+
dtype = ?options.dtype,
199+
model_type = ?options.model_type,
200+
quantize = options.quantize.unwrap_or(false),
201+
quant_mode = ?options.quant_mode,
202+
quant_recipe = ?options.quant_recipe,
203+
"convert_model start"
204+
);
205+
let result = convert_model_inner(options).await;
206+
match &result {
207+
Ok(r) => info!(
208+
target = "mlx_core::convert",
209+
total_seconds = _convert_start.elapsed().as_secs_f64(),
210+
num_tensors = r.num_tensors,
211+
num_parameters = r.num_parameters,
212+
output_path = %r.output_path,
213+
"convert_model finished"
214+
),
215+
Err(e) => tracing::error!(
216+
target = "mlx_core::convert",
217+
total_seconds = _convert_start.elapsed().as_secs_f64(),
218+
error = %e,
219+
"convert_model failed"
220+
),
221+
}
222+
result
223+
}
224+
225+
async fn convert_model_inner(options: ConversionOptions) -> Result<ConversionResult> {
118226
let input_dir = PathBuf::from(&options.input_dir);
119227
let output_dir = PathBuf::from(&options.output_dir);
120228
let target_dtype = options.dtype.unwrap_or_else(|| "float32".to_string());
@@ -249,6 +357,23 @@ pub async fn convert_model(options: ConversionOptions) -> Result<ConversionResul
249357
)));
250358
}
251359

360+
// Serialize all conversions process-wide before touching MLX's default
361+
// device + stream — see `convert_mutex` and `CpuConvertGuard` docs for
362+
// the race this avoids.
363+
let _convert_lock = convert_mutex().lock().await;
364+
365+
// Route every MLX op in this conversion through the CPU device + stream.
366+
//
367+
// The conversion path is slice / reshape / dtype-cast only — no real math.
368+
// On GPU, materializing a 1.6 GB sliced view of a fused expert tensor backed
369+
// by a 250 GB mmap'd source can stall a Metal command buffer past the macOS
370+
// GPU watchdog (~5 s), surfacing as
371+
// `kIOGPUCommandBufferCallbackErrorTimeout` mid-shard for large MoE models
372+
// (e.g. Qwen3.5 122B-A10B with 256 experts × 48 layers). CPU has direct
373+
// access to the mmap'd pages and is immune to the watchdog. `_stream_guard`
374+
// restores the prior default device + stream when convert_model returns.
375+
let _stream_guard = CpuConvertGuard::enter_cpu();
376+
252377
// Check for required files
253378
let config_path = input_dir.join("config.json");
254379
if !config_path.exists() {
@@ -660,9 +785,20 @@ pub async fn convert_model(options: ConversionOptions) -> Result<ConversionResul
660785
tensor_names.sort();
661786

662787
// Save converted model — sharded output with index file (mlx-lm/mlx-vlm compatible)
663-
info!("Saving converted model to: {}", output_dir.display());
788+
info!(
789+
target = "mlx_core::convert",
790+
output_dir = %output_dir.display(),
791+
num_tensors = converted_tensors.len(),
792+
"starting sharded save"
793+
);
664794

665-
crate::utils::safetensors::save_safetensors_sharded(&output_dir, &converted_tensors)?;
795+
let save_start = std::time::Instant::now();
796+
crate::utils::safetensors::save_safetensors_sharded(&output_dir, &mut converted_tensors)?;
797+
info!(
798+
target = "mlx_core::convert",
799+
save_seconds = save_start.elapsed().as_secs_f64(),
800+
"sharded save complete"
801+
);
666802

667803
// Write config.json — clean and sort keys to match mlx-lm/mlx-vlm save_config
668804
let output_config_path = output_dir.join("config.json");

crates/mlx-core/src/models/qwen3/model.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4479,7 +4479,7 @@ impl Qwen3Inner {
44794479
}
44804480
}
44814481

4482-
let params_clone: HashMap<String, MxArray> =
4482+
let mut params_clone: HashMap<String, MxArray> =
44834483
params.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
44844484

44854485
// Build weights.mlx metadata (shape + dtype only; full data is in safetensors).
@@ -4528,7 +4528,11 @@ impl Qwen3Inner {
45284528
"format": "mlx-node",
45294529
"version": "1.0"
45304530
}));
4531-
crate::utils::safetensors::save_safetensors(&safetensors_path, &params_clone, metadata)?;
4531+
crate::utils::safetensors::save_safetensors(
4532+
&safetensors_path,
4533+
&mut params_clone,
4534+
metadata,
4535+
)?;
45324536
info!("Saved weights.safetensors");
45334537

45344538
let weights_str = serde_json::to_string_pretty(&weights_json)?;

crates/mlx-core/src/models/qwen3_5/model.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,7 +1353,7 @@ impl Qwen35Inner {
13531353
}
13541354
}
13551355

1356-
let params_clone: HashMap<String, MxArray> =
1356+
let mut params_clone: HashMap<String, MxArray> =
13571357
params.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
13581358

13591359
// Weights metadata
@@ -1401,7 +1401,11 @@ impl Qwen35Inner {
14011401
"format": "mlx-node",
14021402
"version": "1.0"
14031403
}));
1404-
crate::utils::safetensors::save_safetensors(&safetensors_path, &params_clone, metadata)?;
1404+
crate::utils::safetensors::save_safetensors(
1405+
&safetensors_path,
1406+
&mut params_clone,
1407+
metadata,
1408+
)?;
14051409
info!("Saved weights.safetensors");
14061410

14071411
let weights_str = serde_json::to_string_pretty(&weights_json)?;

crates/mlx-core/src/models/qwen3_5_moe/model.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5303,7 +5303,7 @@ impl Qwen35MoeInner {
53035303
}
53045304
}
53055305

5306-
let params_clone: HashMap<String, MxArray> =
5306+
let mut params_clone: HashMap<String, MxArray> =
53075307
params.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
53085308

53095309
// Weights metadata (reference sidecar)
@@ -5351,7 +5351,11 @@ impl Qwen35MoeInner {
53515351
"format": "mlx-node",
53525352
"version": "1.0"
53535353
}));
5354-
crate::utils::safetensors::save_safetensors(&safetensors_path, &params_clone, metadata)?;
5354+
crate::utils::safetensors::save_safetensors(
5355+
&safetensors_path,
5356+
&mut params_clone,
5357+
metadata,
5358+
)?;
53555359
info!("Saved weights.safetensors");
53565360

53575361
let weights_str = serde_json::to_string_pretty(&weights_json)?;

crates/mlx-core/src/training_state.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ impl ModelThreadTrainingState {
124124
"step": step.to_string(),
125125
"format": "adamw_optimizer_state",
126126
});
127-
crate::utils::safetensors::save_safetensors(path, &tensors, Some(metadata))
127+
crate::utils::safetensors::save_safetensors(path, &mut tensors, Some(metadata))
128128
}
129129

130130
/// Restore AdamW moment tensors + step from a SafeTensors file.
@@ -389,7 +389,7 @@ mod tests {
389389
let arr = MxArray::from_float32(&[*val], &[1]).unwrap();
390390
tensor_map.insert(key.to_string(), arr);
391391
}
392-
crate::utils::safetensors::save_safetensors(path, &tensor_map, metadata).unwrap();
392+
crate::utils::safetensors::save_safetensors(path, &mut tensor_map, metadata).unwrap();
393393
}
394394

395395
// =========================================================================

crates/mlx-core/src/utils/foreign_weights.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ pub fn convert_foreign_weights(
6969
))
7070
})?;
7171

72-
let (tensors, config_json) = match options.model_type.as_str() {
72+
let (mut tensors, config_json) = match options.model_type.as_str() {
7373
"pp-lcnet-ori" => convert_pp_lcnet_ori(&input_path, verbose)?,
7474
"uvdoc" => convert_uvdoc(&input_path, verbose)?,
7575
other => {
@@ -81,9 +81,9 @@ pub fn convert_foreign_weights(
8181

8282
// Save SafeTensors
8383
let weights_path = output_dir.join("model.safetensors");
84-
save_safetensors(&weights_path, &tensors, None)?;
85-
84+
// Capture names BEFORE save (save drains the map for memory reasons).
8685
let mut tensor_names: Vec<String> = tensors.keys().cloned().collect();
86+
save_safetensors(&weights_path, &mut tensors, None)?;
8787
tensor_names.sort();
8888
let num_tensors = tensor_names.len() as i32;
8989

crates/mlx-core/src/utils/gguf.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,6 +1279,15 @@ pub async fn convert_gguf_to_safetensors(
12791279
)));
12801280
}
12811281

1282+
// Serialize all conversions process-wide before touching MLX's default
1283+
// device + stream. Then route every MLX op through CPU for the duration
1284+
// of this call. See `crate::convert::convert_mutex` and
1285+
// `crate::convert::CpuConvertGuard` for the full rationale — same
1286+
// reasoning applies here for GGUF→SafeTensors conversion of huge MoE
1287+
// checkpoints.
1288+
let _convert_lock = crate::convert::convert_mutex().lock().await;
1289+
let _stream_guard = crate::convert::CpuConvertGuard::enter_cpu();
1290+
12821291
// Parse GGUF header and metadata
12831292
info!("Parsing GGUF file: {}", input_path.display());
12841293
let gguf = parse_gguf(&input_path)?;
@@ -1569,10 +1578,16 @@ pub async fn convert_gguf_to_safetensors(
15691578
.unwrap_or("model.safetensors");
15701579
let safetensors_path = output_dir.join(safetensors_filename);
15711580
info!("Saving to {}", safetensors_path.display());
1581+
// Capture tensor names BEFORE `save_safetensors` — it drains `weights`
1582+
// as it streams each tensor to disk so MLX-allocated backing buffers
1583+
// can be released immediately on large MoE checkpoints. Reading
1584+
// `weights.keys()` after the save would return an empty list and the
1585+
// GgufConversionResult would report num_tensors = 0 to JS callers.
1586+
let tensor_names: Vec<String> = weights.keys().cloned().collect();
15721587
// Add "format: mlx" metadata so loaders (e.g., mlx-vlm) know weights are
15731588
// already in MLX layout and skip sanitize (which would double-apply +1.0 to norms).
15741589
let st_metadata = serde_json::json!({ "format": "mlx" });
1575-
save_safetensors(&safetensors_path, &weights, Some(st_metadata))?;
1590+
save_safetensors(&safetensors_path, &mut weights, Some(st_metadata))?;
15761591

15771592
// Only write config.json and tokenizer files for the primary model file.
15781593
// Secondary files (e.g., vision.safetensors for mmproj) should not overwrite
@@ -1652,8 +1667,6 @@ pub async fn convert_gguf_to_safetensors(
16521667
.collect::<Vec<_>>()
16531668
.join(", ");
16541669

1655-
let tensor_names: Vec<String> = weights.keys().cloned().collect();
1656-
16571670
Ok(GgufConversionResult {
16581671
num_tensors: tensor_names.len() as i32,
16591672
num_parameters,

0 commit comments

Comments
 (0)