Skip to content

Commit e35ef0c

Browse files
committed
Add depth-1 opportunistic prefetch to the offload Session
EXPERIMENTAL. After every successful Session::serve(probe_id), opportunistic_prefetch() issues an asynchronous H2D for schedule_[(probe_id + 1) % size()]. With wraparound: the last probe of a forward pass prefetches probe_id 0 of the next pass — free warmup for autoregressive decode, one wasted prefetch for one-shot inference. Per the v3 RFC microbenchmarks (Phase 4) + design header choice 4, depth-1 is the depth current measurements support in both regimes: when compute hides PCIe, one-ahead saturates the overlap budget; when PCIe dominates, the copy stream is already serializing back-to-back and deeper queueing doesn't change throughput. Re-measure if hardware/workload shifts meaningfully. Pieces: * SessionStats grows prefetch_attempted / prefetch_succeeded counters. ``attempted`` bumps BEFORE the H2D is issued; ``succeeded`` bumps AFTER cudaMemcpyAsync is queued on copy_stream_. ``attempted - succeeded`` = swallowed errors. Stats log line extends with both fields; the ``_STATS_RE`` regex in test_weight_offload_pool.py captures them. * Session::opportunistic_prefetch() is the new private member. Skips immediately if the target is already live (same FQN case, including 1-FQN-schedule wraparound). Defensive "never evict current_fqn" guard catches the narrow case where pick_lru would target the single immediately-just- served FQN — only protects that case, NOT the general below-floor scenario (a fused kernel with probes for A and B sharing one launch could still have A evicted by a prefetch after B if the floor invariant were violated). The floor hard-fail at init remains the real general safety contract. * Stream-ordering invariant extended: every cudaFreeAsync(e.dev_ptr, compute_stream_) is now preceded by cudaStreamWaitEvent(compute_stream_, e.ready_event, 0). This was implicit pre-commit-8 because every live entry's ready_event had been waited on by the prior serve's hit path — commit 8 introduces prefetched entries whose ready_event is NOT waited on until the NEXT serve consumes them, so the wait must be made explicit. Applied to three sites: prefetch eviction (new), miss-path eviction (retrofit), and ~Session()'s live-cleanup loop (retrofit). * Post-eviction event-batch failure path falls back to cudaStreamSynchronize(compute_stream_). When the batch's cudaEventCreate / cudaEventRecord / cudaStreamWaitEvent(copy_stream_, evict_done) fails AFTER live_/bytes_in_flight_ have been mutated to reflect the evictions, returning Error::Internal alone would leave the Session in a state where a subsequent cudaMallocFromPoolAsync on copy_stream_ races the pending cudaFreeAsyncs on compute_stream_. The sync guarantees the frees physically complete before return. Cheap insurance for a rare error path; applied to both miss-path and prefetch-path eviction batches. Banner flips from "POOL+LRU+DUMMIES WIRED" to "POOL+LRU+DUMMIES+PREFETCH WIRED" in both session.h and weight_offload.h. The "Depth-1 prefetch" entry moves from "NOT YET WIRED" to the resolved list. Tests: * Existing 5 pool tests still pass. * NEW ``test_prefetch_converts_second_probe_to_pool_hit``: on _TwoWeightModel (2 distinct probed FQNs) under a budget that comfortably fits 2+ weights, asserts pool_misses == 1 (just the first cold weight) and prefetch_succeeded >= 1 — proving the second probe hit because the prior serve prefetched it. Test name and docstring make explicit that "pool hit" doesn't mean "no stall": the hit path still does cudaStreamWaitEvent on the ready_event, so the consuming kernel can stall briefly if the prefetch H2D hasn't finished. A true no-stall assertion needs wall-clock measurement (separate workstream).
1 parent 5697cd2 commit e35ef0c

4 files changed

Lines changed: 381 additions & 15 deletions

File tree

backends/cuda/runtime/weight_offload/session.cpp

Lines changed: 270 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,18 @@ Session::~Session() {
497497

498498
for (auto& [fqn, e] : live_) {
499499
if (e.dev_ptr != nullptr && compute_stream_ != nullptr) {
500+
// Stream-order the free behind the entry's H2D. Required
501+
// since commit 8 introduced prefetched entries whose
502+
// ready_event has not been waited on by any prior serve()
503+
// — without this, cudaFreeAsync on compute_stream_ would
504+
// queue the free before the in-flight cudaMemcpyAsync on
505+
// copy_stream_ drains, freeing memory still being written
506+
// to. cudaStreamWaitEvent is a queued cross-stream
507+
// dependency (not a host sync), and a no-op if the event
508+
// has already completed.
509+
if (e.ready_event != nullptr) {
510+
cudaStreamWaitEvent(compute_stream_, e.ready_event, 0);
511+
}
500512
cudaFreeAsync(e.dev_ptr, compute_stream_);
501513
}
502514
if (e.ready_event != nullptr) {
@@ -538,15 +550,17 @@ Session::~Session() {
538550
stderr,
539551
"[ET_WEIGHT_OFFLOAD_STATS] method=%s hits=%llu misses=%llu "
540552
"evictions=%llu bytes_h2d=%llu peak_live_bytes=%llu budget=%llu "
541-
"floor=%llu\n",
553+
"floor=%llu prefetch_attempted=%llu prefetch_succeeded=%llu\n",
542554
method_name_.c_str(),
543555
static_cast<unsigned long long>(stats_.pool_hits),
544556
static_cast<unsigned long long>(stats_.pool_misses),
545557
static_cast<unsigned long long>(stats_.evictions),
546558
static_cast<unsigned long long>(stats_.bytes_h2d_copied),
547559
static_cast<unsigned long long>(peak_live_bytes_),
548560
static_cast<unsigned long long>(budget_bytes_),
549-
static_cast<unsigned long long>(floor_bytes_));
561+
static_cast<unsigned long long>(floor_bytes_),
562+
static_cast<unsigned long long>(stats_.prefetch_attempted),
563+
static_cast<unsigned long long>(stats_.prefetch_succeeded));
550564
}
551565
}
552566

@@ -660,6 +674,10 @@ ::executorch::runtime::Error Session::serve(
660674
return err == Error::Ok ? Error::Internal : err;
661675
}
662676
*output = wrapped;
677+
// Best-effort depth-1 prefetch. Errors are logged inside the
678+
// helper and never propagated — the current probe is already
679+
// populated in *output.
680+
(void)opportunistic_prefetch(probe_id);
663681
return Error::Ok;
664682
}
665683

@@ -687,6 +705,23 @@ ::executorch::runtime::Error Session::serve(
687705
return Error::Internal;
688706
}
689707
auto& v = victim_it->second;
708+
// Stream-order the free behind the entry's H2D — required for
709+
// prefetched entries whose ready_event hasn't been waited on
710+
// yet (no prior serve has consumed them). Hit-path entries
711+
// already had their ready_event waited on, so this is a no-op
712+
// for them. See "Existing miss-path eviction + ~Session() need
713+
// the same wait" in the commit-8 subplan.
714+
cudaError_t wait_err =
715+
cudaStreamWaitEvent(compute_stream_, v.ready_event, 0);
716+
if (wait_err != cudaSuccess) {
717+
std::fprintf(
718+
stderr,
719+
"[ET_WEIGHT_OFFLOAD][ERROR] cudaStreamWaitEvent on victim "
720+
"'%s' before eviction failed: %s\n",
721+
victim_it->first.c_str(),
722+
cudaGetErrorString(wait_err));
723+
return Error::Internal;
724+
}
690725
// Check cudaFreeAsync — failure here means the device pointer is
691726
// NOT being freed, so we must NOT decrement bytes_in_flight_ or
692727
// erase from live_, otherwise our accounting diverges from the
@@ -715,38 +750,52 @@ ::executorch::runtime::Error Session::serve(
715750
// One event per eviction batch — all cudaFreeAsyncs are on
716751
// compute_stream_ and stream-ordered relative to each other,
717752
// so a single event after the batch covers them all.
753+
//
754+
// Any failure here is delicate: live_/bytes_in_flight_ have
755+
// already been mutated to reflect the evictions, but if we
756+
// don't establish the copy_stream_ ↔ compute_stream_ ordering
757+
// a subsequent cudaMallocFromPoolAsync on copy_stream_ may
758+
// race the still-pending cudaFreeAsyncs. Fallback on every
759+
// failure: cudaStreamSynchronize(compute_stream_) so the
760+
// frees are GUARANTEED physically complete before we return —
761+
// the Session state stays consistent at the cost of a brief
762+
// host block. Cheap insurance for a rare error path.
718763
cudaEvent_t evict_done = nullptr;
719764
cudaError_t ev_err =
720765
cudaEventCreateWithFlags(&evict_done, cudaEventDisableTiming);
721766
if (ev_err != cudaSuccess) {
722767
std::fprintf(
723768
stderr,
724769
"[ET_WEIGHT_OFFLOAD][ERROR] cudaEventCreate for eviction batch "
725-
"failed: %s\n",
770+
"failed: %s; falling back to cudaStreamSynchronize(compute) to "
771+
"guarantee the frees physically complete before return\n",
726772
cudaGetErrorString(ev_err));
773+
(void)cudaStreamSynchronize(compute_stream_);
727774
return Error::Internal;
728775
}
729776
cudaError_t rec_err = cudaEventRecord(evict_done, compute_stream_);
730777
if (rec_err != cudaSuccess) {
731778
std::fprintf(
732779
stderr,
733780
"[ET_WEIGHT_OFFLOAD][ERROR] cudaEventRecord for eviction batch "
734-
"failed: %s; the subsequent cudaMallocFromPoolAsync would be "
735-
"unordered against the cudaFreeAsync, risking reuse before free\n",
781+
"failed: %s; falling back to cudaStreamSynchronize(compute) to "
782+
"guarantee the frees physically complete before return\n",
736783
cudaGetErrorString(rec_err));
737784
cudaEventDestroy(evict_done);
785+
(void)cudaStreamSynchronize(compute_stream_);
738786
return Error::Internal;
739787
}
740788
cudaError_t wait_err = cudaStreamWaitEvent(copy_stream_, evict_done, 0);
741789
if (wait_err != cudaSuccess) {
742790
std::fprintf(
743791
stderr,
744792
"[ET_WEIGHT_OFFLOAD][ERROR] cudaStreamWaitEvent for eviction "
745-
"batch on copy_stream failed: %s; the subsequent "
746-
"cudaMallocFromPoolAsync would be unordered against the "
747-
"cudaFreeAsync\n",
793+
"batch on copy_stream failed: %s; falling back to "
794+
"cudaStreamSynchronize(compute) to guarantee the frees "
795+
"physically complete before return\n",
748796
cudaGetErrorString(wait_err));
749797
cudaEventDestroy(evict_done);
798+
(void)cudaStreamSynchronize(compute_stream_);
750799
return Error::Internal;
751800
}
752801
cudaEventDestroy(evict_done);
@@ -858,6 +907,219 @@ ::executorch::runtime::Error Session::serve(
858907
stats_.bytes_h2d_copied += need;
859908

860909
*output = wrapped;
910+
// Best-effort depth-1 prefetch. Errors are logged inside the
911+
// helper and never propagated — the current probe is already
912+
// populated in *output.
913+
(void)opportunistic_prefetch(probe_id);
914+
return Error::Ok;
915+
}
916+
917+
::executorch::runtime::Error Session::opportunistic_prefetch(
918+
int64_t current_probe_id) {
919+
using ::executorch::runtime::Error;
920+
921+
if (schedule_.empty()) {
922+
return Error::Ok;
923+
}
924+
const int64_t next_id =
925+
(current_probe_id + 1) % static_cast<int64_t>(schedule_.size());
926+
const std::string& fqn = schedule_[static_cast<size_t>(next_id)];
927+
928+
// Step 1: already-live → no work needed.
929+
if (live_.find(fqn) != live_.end()) {
930+
return Error::Ok;
931+
}
932+
933+
// Step 1b: defensive guard — never evict the FQN the c-shim is
934+
// about to hand back to AOTI for kernel launch. The floor formula
935+
// (budget >= bytes(current) + bytes(next)) should make this
936+
// unreachable today; the guard catches the single-immediately-
937+
// just-served-FQN case if a future commit ever allows budgets
938+
// below the floor. Narrow protection — it does NOT cover
939+
// multi-probe-before-one-launch (fused kernels with probes A, B
940+
// before one launch could still have A evicted by a prefetch
941+
// after B if the floor invariant were violated). The floor
942+
// hard-fail at init is the real general contract.
943+
const std::string& current_fqn =
944+
schedule_[static_cast<size_t>(current_probe_id)];
945+
946+
auto host_it = host_entries_.find(fqn);
947+
if (host_it == host_entries_.end()) {
948+
std::fprintf(
949+
stderr,
950+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch skipped: no host mirror "
951+
"for FQN '%s'\n",
952+
fqn.c_str());
953+
return Error::Internal;
954+
}
955+
const HostEntry& host = host_it->second;
956+
const uint64_t need = host.nbytes;
957+
958+
// From here on a real prefetch is attempted. Count the attempt
959+
// regardless of whether it succeeds — `attempted - succeeded`
960+
// = swallowed errors.
961+
stats_.prefetch_attempted++;
962+
963+
// Step 2: eviction (same logic as the miss path, plus the
964+
// current_fqn guard and the stream-order wait before each free).
965+
bool evicted = false;
966+
while (bytes_in_flight_ + need > budget_bytes_) {
967+
auto victim_it = pick_lru();
968+
if (victim_it == live_.end()) {
969+
std::fprintf(
970+
stderr,
971+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch skipped: no evictable "
972+
"entries for FQN '%s' (%llu bytes, budget %llu)\n",
973+
fqn.c_str(),
974+
static_cast<unsigned long long>(need),
975+
static_cast<unsigned long long>(budget_bytes_));
976+
return Error::Internal;
977+
}
978+
if (victim_it->first == current_fqn) {
979+
// Floor formula should prevent this; if it ever fires, skip
980+
// the prefetch instead of risking corruption of the
981+
// just-served entry.
982+
std::fprintf(
983+
stderr,
984+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch skipped: LRU victim "
985+
"for FQN '%s' would be the just-served '%s'; floor "
986+
"formula may be violated\n",
987+
fqn.c_str(),
988+
current_fqn.c_str());
989+
return Error::Internal;
990+
}
991+
auto& v = victim_it->second;
992+
cudaError_t wait_err =
993+
cudaStreamWaitEvent(compute_stream_, v.ready_event, 0);
994+
if (wait_err != cudaSuccess) {
995+
std::fprintf(
996+
stderr,
997+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch eviction "
998+
"cudaStreamWaitEvent on victim '%s' failed: %s; prefetch "
999+
"skipped\n",
1000+
victim_it->first.c_str(),
1001+
cudaGetErrorString(wait_err));
1002+
return Error::Internal;
1003+
}
1004+
cudaError_t free_err = cudaFreeAsync(v.dev_ptr, compute_stream_);
1005+
if (free_err != cudaSuccess) {
1006+
std::fprintf(
1007+
stderr,
1008+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch eviction cudaFreeAsync "
1009+
"for victim '%s' failed: %s; prefetch skipped\n",
1010+
victim_it->first.c_str(),
1011+
cudaGetErrorString(free_err));
1012+
return Error::Internal;
1013+
}
1014+
cudaEventDestroy(v.ready_event);
1015+
bytes_in_flight_ -= v.nbytes;
1016+
stats_.evictions++;
1017+
live_.erase(victim_it);
1018+
evicted = true;
1019+
}
1020+
if (evicted) {
1021+
// Mirror the miss-path event ordering: one event per eviction
1022+
// batch, made copy_stream_ wait on it before allocating.
1023+
//
1024+
// Failure here is the same hazard as the miss-path equivalent:
1025+
// live_/bytes_in_flight_ have been mutated to reflect the
1026+
// evictions, but without the copy_stream_ ↔ compute_stream_
1027+
// ordering, a subsequent cudaMallocFromPoolAsync would race
1028+
// the pending cudaFreeAsyncs. Even though serve() ignores our
1029+
// return value (this is best-effort), Session state must stay
1030+
// consistent for the NEXT serve(). Fall back to
1031+
// cudaStreamSynchronize(compute_stream_) so the frees are
1032+
// guaranteed done before we return.
1033+
cudaEvent_t evict_done = nullptr;
1034+
cudaError_t ev_err =
1035+
cudaEventCreateWithFlags(&evict_done, cudaEventDisableTiming);
1036+
if (ev_err != cudaSuccess) {
1037+
std::fprintf(
1038+
stderr,
1039+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch cudaEventCreate for "
1040+
"eviction batch failed: %s; syncing compute_stream and "
1041+
"skipping prefetch\n",
1042+
cudaGetErrorString(ev_err));
1043+
(void)cudaStreamSynchronize(compute_stream_);
1044+
return Error::Internal;
1045+
}
1046+
if (cudaEventRecord(evict_done, compute_stream_) != cudaSuccess ||
1047+
cudaStreamWaitEvent(copy_stream_, evict_done, 0) != cudaSuccess) {
1048+
cudaEventDestroy(evict_done);
1049+
std::fprintf(
1050+
stderr,
1051+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch event-record/wait for "
1052+
"eviction batch failed; syncing compute_stream and skipping "
1053+
"prefetch\n");
1054+
(void)cudaStreamSynchronize(compute_stream_);
1055+
return Error::Internal;
1056+
}
1057+
cudaEventDestroy(evict_done);
1058+
}
1059+
1060+
// Step 3: allocate + copy on copy_stream_. SAME as miss path
1061+
// except we do NOT cudaStreamWaitEvent(compute_, ready) — the
1062+
// next serve() that consumes this entry as a hit does that.
1063+
void* dev = nullptr;
1064+
cudaError_t malloc_err =
1065+
cudaMallocFromPoolAsync(&dev, need, pool_, copy_stream_);
1066+
if (malloc_err != cudaSuccess) {
1067+
std::fprintf(
1068+
stderr,
1069+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch cudaMallocFromPoolAsync "
1070+
"for FQN '%s' (%llu bytes) failed: %s; prefetch skipped\n",
1071+
fqn.c_str(),
1072+
static_cast<unsigned long long>(need),
1073+
cudaGetErrorString(malloc_err));
1074+
return Error::Internal;
1075+
}
1076+
auto free_on_error = [&]() {
1077+
if (dev != nullptr) {
1078+
cudaFreeAsync(dev, copy_stream_);
1079+
dev = nullptr;
1080+
}
1081+
};
1082+
1083+
if (cudaMemcpyAsync(
1084+
dev, host.host_ptr, need, cudaMemcpyHostToDevice, copy_stream_) !=
1085+
cudaSuccess) {
1086+
std::fprintf(
1087+
stderr,
1088+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch cudaMemcpyAsync for FQN "
1089+
"'%s' failed; prefetch skipped\n",
1090+
fqn.c_str());
1091+
free_on_error();
1092+
return Error::Internal;
1093+
}
1094+
cudaEvent_t ready = nullptr;
1095+
if (cudaEventCreateWithFlags(&ready, cudaEventDisableTiming) != cudaSuccess) {
1096+
std::fprintf(
1097+
stderr,
1098+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch cudaEventCreate for FQN "
1099+
"'%s' ready event failed; prefetch skipped\n",
1100+
fqn.c_str());
1101+
free_on_error();
1102+
return Error::Internal;
1103+
}
1104+
if (cudaEventRecord(ready, copy_stream_) != cudaSuccess) {
1105+
cudaEventDestroy(ready);
1106+
free_on_error();
1107+
std::fprintf(
1108+
stderr,
1109+
"[ET_WEIGHT_OFFLOAD][WARN] prefetch cudaEventRecord for FQN "
1110+
"'%s' failed; prefetch skipped\n",
1111+
fqn.c_str());
1112+
return Error::Internal;
1113+
}
1114+
1115+
bytes_in_flight_ += need;
1116+
peak_live_bytes_ = std::max(peak_live_bytes_, bytes_in_flight_);
1117+
// Treat the prefetched entry as "newest" — see commit-8 subplan
1118+
// for the option-(b) rationale (the FQN is about to be served
1119+
// next, so its expected next-use is sooner than the just-served).
1120+
live_.emplace(fqn, LiveAllocation{dev, need, ready, next_seq_++});
1121+
stats_.bytes_h2d_copied += need;
1122+
stats_.prefetch_succeeded++;
8611123
return Error::Ok;
8621124
}
8631125

0 commit comments

Comments
 (0)