Skip to content

Commit 9c5519f

Browse files
solderzzc and ericjlake authored
feat(fast): preadIntoOffset for stacked-buffer MoE consumers + PAPPS try_take (#10)
Cherry-pick of ericjlake/mlx-swift@761381f with overflow-safe bounds check fix. Adds mlx_fast_pread_into_offset — writes one expert slab into a stacked destination buffer at a byte offset. Enables the stacked-buffer MoE fast path in mlx-swift-lm's SwitchGLU (PR ml-explore#35). Copilot review fixes applied: - Overflow-safe size_t bounds check - Swift precondition(dstOffset >= 0) guard Co-authored-by: Eric Lake <ericjlake@gmail.com>
1 parent 6b27940 commit 9c5519f

4 files changed

Lines changed: 117 additions & 0 deletions

File tree

Source/Cmlx/include/mlx/c/fast.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,18 @@ int mlx_fast_pread_into(
234234
const char* tensor_name,
235235
uint32_t expert_index);
236236

237+
// Like mlx_fast_pread_into, but writes the expert's bytes into the dst buffer
238+
// starting at byte offset `dst_offset`. Reads exactly `bytes_per_expert` bytes
239+
// (NOT the whole dst array). Use this to populate one slot of a stacked
240+
// `[N_slots, ..., ...]` buffer, where `dst_offset = slot * bytes_per_expert`.
241+
// Bounds check: dst_offset + bytes_per_expert <= dst.nbytes.
242+
int mlx_fast_pread_into_offset(
243+
mlx_array dst,
244+
const char* safetensors_path,
245+
const char* tensor_name,
246+
uint32_t expert_index,
247+
size_t dst_offset);
248+
237249
// mlx_fast_submit_prefetch (PAPPS Background Worker)
238250
int mlx_fast_submit_prefetch(
239251
const char* safetensors_path,

Source/Cmlx/mlx-c/mlx/c/fast.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,3 +1120,71 @@ extern "C" int mlx_fast_pread_into(
11201120
}
11211121
return 0;
11221122
}
1123+
1124+
// mlx_fast_pread_into_offset — variant that writes ONE expert into a slot of
1125+
// a stacked destination buffer. Used by SwitchGLU's stacked-buffer fast path
1126+
// (TEND_MOE_STACKED=1) to avoid `MLX.concatenated` cost when fusing per-expert
1127+
// matmuls into a single gatherQuantizedMM dispatch.
1128+
//
1129+
// dst_offset is bytes (not elements). Reads exactly `bytes_per_expert` bytes
1130+
// from the safetensors file at the requested expert index, into
1131+
// `dst.data() + dst_offset`. Bounds check (overflow-safe):
1132+
// dst_offset <= dst.nbytes() && bytes_per_expert <= dst.nbytes() - dst_offset.
1133+
//
1134+
// PAPPS fast path: if a background worker already preloaded this expert
1135+
// (cache_id = path|tname_<file_offset>, which is independent of dst_offset),
1136+
// take it via try_take() and memcpy into the slot, skipping the synchronous
1137+
// pread. Caller is expected to issue mlx_fast_submit_prefetch ahead of time
1138+
// (e.g. at last-token routing) to populate the PAPPS cache.
1139+
extern "C" int mlx_fast_pread_into_offset(
1140+
mlx_array dst,
1141+
const char* safetensors_path,
1142+
const char* tensor_name,
1143+
uint32_t expert_index,
1144+
size_t dst_offset) {
1145+
try {
1146+
std::string path(safetensors_path);
1147+
std::string tname(tensor_name);
1148+
std::string key = path + "|" + tname;
1149+
1150+
STPReadEntry entry = get_safetensors_entry(path, tname, key);
1151+
1152+
auto& arr = mlx_array_get_(dst);
1153+
void* base = const_cast<void*>(static_cast<const void*>(arr.data<uint8_t>()));
1154+
if (!base) throw std::runtime_error("[pread_into_offset] dst has no data pointer — call eval() first");
1155+
size_t total_nbytes = arr.nbytes();
1156+
size_t bpe = entry.bytes_per_expert;
1157+
if (dst_offset > total_nbytes || bpe > total_nbytes - dst_offset) {
1158+
throw std::runtime_error(
1159+
"[pread_into_offset] dst_offset (" + std::to_string(dst_offset) +
1160+
") + bytes_per_expert (" + std::to_string(bpe) +
1161+
") > dst.nbytes (" + std::to_string(total_nbytes) + ")");
1162+
}
1163+
void* slot_buf = static_cast<uint8_t*>(base) + dst_offset;
1164+
off_t file_offset = static_cast<off_t>(entry.data_start + (size_t)expert_index * bpe);
1165+
1166+
// PAPPS fast path: try to absorb a previously-submitted prefetch.
1167+
// cache_id is keyed on (path,tname,file_offset) — same as full-buffer
1168+
// variant — so a single submit_prefetch call serves both consumers.
1169+
std::string cache_id = key + "_" + std::to_string(file_offset);
1170+
bool hit = false;
1171+
{
1172+
std::lock_guard<std::mutex> lock(global_papps_mutex);
1173+
if (global_papps_queue) {
1174+
hit = global_papps_queue->try_take(cache_id, slot_buf, bpe);
1175+
}
1176+
}
1177+
if (hit) {
1178+
return 0; // memcpy from PAPPS cache complete; no syscall
1179+
}
1180+
1181+
// Cache miss — synchronous pread into the slot.
1182+
ssize_t result = pread(entry.fd, slot_buf, bpe, file_offset);
1183+
if (result < 0 || (size_t)result != bpe)
1184+
throw std::runtime_error("[pread_into_offset] pread failed: got " + std::to_string(result) + " of " + std::to_string(bpe));
1185+
} catch (std::exception& e) {
1186+
mlx_error(e.what());
1187+
return 1;
1188+
}
1189+
return 0;
1190+
}

Source/Cmlx/mlx-c/mlx/c/fast.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,18 @@ int mlx_fast_pread_into(
249249
const char* tensor_name,
250250
uint32_t expert_index);
251251

252+
// Like mlx_fast_pread_into, but writes the expert's bytes into the dst buffer
253+
// starting at byte offset `dst_offset`. Reads exactly `bytes_per_expert` bytes
254+
// (NOT the whole dst array). Use this to populate one slot of a stacked
255+
// `[N_slots, ..., ...]` buffer, where `dst_offset = slot * bytes_per_expert`.
256+
// Bounds check: dst_offset + bytes_per_expert <= dst.nbytes.
257+
int mlx_fast_pread_into_offset(
258+
mlx_array dst,
259+
const char* safetensors_path,
260+
const char* tensor_name,
261+
uint32_t expert_index,
262+
size_t dst_offset);
263+
252264
/**@}*/
253265

254266
// ── SSD Flash-Stream metrics snapshot ────────────────────────────────────────

Source/MLX/MLXFast.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,31 @@ public enum MLXFast {
376376
}
377377
}
378378

379+
/// Like `preadInto`, but writes the expert's bytes into the destination at
380+
/// byte-offset `dstOffset`. Reads exactly `bytes_per_expert` bytes (the
381+
/// safetensors entry's per-expert slab size), NOT the whole dst.
382+
///
383+
/// Use when you have a stacked `[N_slots, ..., ...]` MLXArray and want to
384+
/// populate slot `k` via `dstOffset = k * bytesPerExpert`. Lets a single
385+
/// `gatherQuantizedMM` call replace a per-expert loop, eliminating both
386+
/// the per-expert kernel-launch overhead and the `MLX.concatenated` Metal
387+
/// copy that would otherwise be needed to fuse N independent buffers.
388+
@discardableResult
389+
public static func preadIntoOffset(
390+
_ dst: MLXArray,
391+
safetensorsPath: String,
392+
tensorName: String,
393+
expertIndex: UInt32,
394+
dstOffset: Int
395+
) -> Int32 {
396+
precondition(dstOffset >= 0, "dstOffset must be non-negative")
397+
return safetensorsPath.withCString { pathPtr in
398+
tensorName.withCString { namePtr in
399+
mlx_fast_pread_into_offset(dst.ctx, pathPtr, namePtr, expertIndex, dstOffset)
400+
}
401+
}
402+
}
403+
379404
/// Submits an asynchronous background prefetch for a specific expert's weights.
380405
/// The fetch is handled by a persistent C++ background thread and placed in a unified memory arena.
381406
public static func pappsPrefetch(

0 commit comments

Comments
 (0)