@@ -1120,3 +1120,71 @@ extern "C" int mlx_fast_pread_into(
11201120 }
11211121 return 0 ;
11221122}
1123+
1124+ // mlx_fast_pread_into_offset — variant that writes ONE expert into a slot of
1125+ // a stacked destination buffer. Used by SwitchGLU's stacked-buffer fast path
1126+ // (TEND_MOE_STACKED=1) to avoid `MLX.concatenated` cost when fusing per-expert
1127+ // matmuls into a single gatherQuantizedMM dispatch.
1128+ //
1129+ // dst_offset is bytes (not elements). Reads exactly `bytes_per_expert` bytes
1130+ // from the safetensors file at the requested expert index, into
1131+ // `dst.data() + dst_offset`. Bounds check (overflow-safe):
1132+ // dst_offset <= dst.nbytes() && bytes_per_expert <= dst.nbytes() - dst_offset.
1133+ //
1134+ // PAPPS fast path: if a background worker already preloaded this expert
1135+ // (cache_id = path|tname_<file_offset>, which is independent of dst_offset),
1136+ // take it via try_take() and memcpy into the slot, skipping the synchronous
1137+ // pread. Caller is expected to issue mlx_fast_submit_prefetch ahead of time
1138+ // (e.g. at last-token routing) to populate the PAPPS cache.
1139+ extern " C" int mlx_fast_pread_into_offset (
1140+ mlx_array dst,
1141+ const char * safetensors_path,
1142+ const char * tensor_name,
1143+ uint32_t expert_index,
1144+ size_t dst_offset) {
1145+ try {
1146+ std::string path (safetensors_path);
1147+ std::string tname (tensor_name);
1148+ std::string key = path + " |" + tname;
1149+
1150+ STPReadEntry entry = get_safetensors_entry (path, tname, key);
1151+
1152+ auto & arr = mlx_array_get_ (dst);
1153+ void * base = const_cast <void *>(static_cast <const void *>(arr.data <uint8_t >()));
1154+ if (!base) throw std::runtime_error (" [pread_into_offset] dst has no data pointer — call eval() first" );
1155+ size_t total_nbytes = arr.nbytes ();
1156+ size_t bpe = entry.bytes_per_expert ;
1157+ if (dst_offset > total_nbytes || bpe > total_nbytes - dst_offset) {
1158+ throw std::runtime_error (
1159+ " [pread_into_offset] dst_offset (" + std::to_string (dst_offset) +
1160+ " ) + bytes_per_expert (" + std::to_string (bpe) +
1161+ " ) > dst.nbytes (" + std::to_string (total_nbytes) + " )" );
1162+ }
1163+ void * slot_buf = static_cast <uint8_t *>(base) + dst_offset;
1164+ off_t file_offset = static_cast <off_t >(entry.data_start + (size_t )expert_index * bpe);
1165+
1166+ // PAPPS fast path: try to absorb a previously-submitted prefetch.
1167+ // cache_id is keyed on (path,tname,file_offset) — same as full-buffer
1168+ // variant — so a single submit_prefetch call serves both consumers.
1169+ std::string cache_id = key + " _" + std::to_string (file_offset);
1170+ bool hit = false ;
1171+ {
1172+ std::lock_guard<std::mutex> lock (global_papps_mutex);
1173+ if (global_papps_queue) {
1174+ hit = global_papps_queue->try_take (cache_id, slot_buf, bpe);
1175+ }
1176+ }
1177+ if (hit) {
1178+ return 0 ; // memcpy from PAPPS cache complete; no syscall
1179+ }
1180+
1181+ // Cache miss — synchronous pread into the slot.
1182+ ssize_t result = pread (entry.fd , slot_buf, bpe, file_offset);
1183+ if (result < 0 || (size_t )result != bpe)
1184+ throw std::runtime_error (" [pread_into_offset] pread failed: got " + std::to_string (result) + " of " + std::to_string (bpe));
1185+ } catch (std::exception& e) {
1186+ mlx_error (e.what ());
1187+ return 1 ;
1188+ }
1189+ return 0 ;
1190+ }
0 commit comments