Skip to content

Commit 29a8be9

Browse files
committed
Pinning runtime: out-of-pool resident allocations with budget split
EXPERIMENTAL. Lands the runtime half of pinning. Commit 9b will add the public partitioner kwarg, the public ``weight_offload_budget_mb`` runtime spec, and the ``--cuda_runtime_spec`` runner flag that consume what this commit ships. A pinned FQN now: * Has a probe op in the schedule like any other FQN (no change — pass behavior unchanged). * Is allocated once at Session::create via out-of-pool ``cudaMalloc`` + a synchronous ``cudaMemcpyAsync`` from the pinned host mirror + ``cudaStreamSynchronize``. Lives for the Session lifetime in ``pinned_``; freed in the dtor between dummies cleanup and host-mirror free. * At serve(): the pinned fast path bypasses the pool, event waits, and streaming stats — but STILL calls ``opportunistic_prefetch(probe_id)`` so a pinned→streaming transition doesn't lose overlap. * At prefetch lookup: short-circuited (pinned FQNs are already resident; no work needed). Budget accounting splits cleanly: * ``total_budget_bytes_`` — what the user configured (or the default, which is now ``floor + pinned_bytes`` when no spec is provided so a no-spec default never starves pinning). * ``pinned_bytes_total_`` — sum of payload.pin_fqns logical nbytes, computed from VALIDATED metadata (before any GPU work). * ``streaming_budget_bytes_ = total - pinned`` — the cap the miss-path and prefetch-path eviction loops compare against. The pool's release threshold (soft) also uses this. Floor check at init becomes ``streaming_budget >= payload.floor_bytes`` (the pass-computed floor already excludes pinned). Below-floor budgets hard-fail with a descriptive message naming pinned bytes, streaming floor, and required total — "Weight offloading needs at least X bytes... pinned: Y bytes, streaming pool floor: Z bytes, required total: X" — so the user knows exactly what to set. Three-layer dedupe on payload.pin_fqns to prevent double-accounting / overwrite: 1. Pass-side: ``_apply_weight_offload`` dedupes ``pin_fqns`` first-seen-order before payload serialization. 2. Runtime parse: cuda_backend.cpp hard-fails at parse if ``payload.pin_fqns`` contains duplicates (corrupted / hand-rolled artifact protection). 3. Allocation: Session::create's ``pinned_.emplace(fqn, dev)`` asserts not-already-inserted as a third-layer guard. Stream/release-threshold contract: the pool's release threshold is set to ``streaming_budget_bytes_`` (not total) so that requested live offload bytes are capped at ``pinned + streaming = total``. The threshold is SOFT — driver- reserved / pool cache memory may briefly exceed that — but the accounting invariant ``peak_live_bytes <= streaming_budget`` stays self-consistent. Refactoring + cleanup: * Extracted ``Session::wrap_borrowed_tensor`` (used in 4 sites now: hit, miss, pinned, and the prefetched-then- consumed hit path). * Renamed ``budget_bytes_`` to ``total_budget_bytes_`` for clarity; existing accessor renamed to ``total_budget_bytes()`` with new ``pinned_bytes_total()`` and ``streaming_budget_bytes()`` siblings. * Stats log extended with ``pinned_bytes=Pn streaming_budget=Sb`` fields. ``_STATS_RE`` in test_weight_offload_pool.py extends to capture both. Tests: * Existing 6 pool tests + 9 transport tests + 4 catalog tests still pass; their assertions for "pin_fqns hard-fails" flip to "pin_fqns now succeeds". * NEW ``test_pinning_default_budget_covers_pinned`` (pool): no explicit budget + non-empty pin_fqns succeeds; stats show ``streaming + pinned == total`` and ``streaming >= floor``. Validates the v3 default-budget-with-pins fix. * NEW ``test_pinning_pinned_fqn_resident_no_streaming_h2d`` (pool): pin w1, run, assert ``bytes_h2d == 16384`` (w2 only), ``pinned_bytes == 16384``. Confirms pinned allocations bypass the streaming pool entirely. * NEW ``test_pinning_pinned_then_streaming_still_prefetches`` (pool): asserts ``prefetch_attempted >= 1`` even when one of the two probes hits the pinned fast path — proves the pinned fast path still calls opportunistic_prefetch. * ``test_runtime_rejects_nonempty_pin_fqns`` → renamed ``test_runtime_accepts_nonempty_pin_fqns`` and updated to assert the success summary is emitted. * ``test_hard_fails_when_pin_fqns_set`` (catalog) removed; coverage moved to the transport + pool side. * ``test_pinning_below_floor_with_pinned_hard_fails`` is deferred to 9b — needs a way to inject a sub-required budget (the ``--cuda_runtime_spec`` runner flag or a C++ Module harness, both landing in 9b). Banner updates: * session.h: "POOL+LRU+DUMMIES+PREFETCH WIRED" entry under "Resolved" gains a "Pinning (commit 9a)" subsection. The deferred-items list drops the pinning bullet and adds the 9b public-knob bullet. * weight_offload.h: "Resolved in commit 9a" block added spelling out the pinning contract + the streaming-only release-threshold rationale. * weight_offload_pass.py: docstring flipped from "pinning is hard-failed" to the new commit-9a behavior.
1 parent e35ef0c commit 29a8be9

8 files changed

Lines changed: 599 additions & 160 deletions

File tree

backends/cuda/passes/weight_offload_pass.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,12 @@
3131
the pinned host mirror + serves probes. No public partitioner
3232
kwarg yet; the only opt-in callers are this stack's own tests.
3333
34-
Pinning (``pin_fqns``) is hard-failed at init for now and lands
35-
with the public partitioner kwarg. Multi-device offload is hard-
36-
failed at init (commit 7 only supports device 0).
34+
Pinning (``pin_fqns``) is supported as of commit 9a: the runtime
35+
allocates each pinned weight once via out-of-pool ``cudaMalloc``
36+
+ a synchronous H2D, then serves it through a resident fast path.
37+
The pass deduplicates pin_fqns before serialization so the
38+
runtime never sees duplicates. Multi-device offload is still
39+
hard-failed at init (commit 7 only supports device 0).
3740
3841
Schedule / cursor order -- RESOLVED:
3942
The probe op carries an explicit ``probe_id: int`` argument assigned
@@ -763,7 +766,18 @@ def _apply_weight_offload(
763766
stack's own tests. See the EXPERIMENTAL banner at the top of
764767
this module for the current wiring state.
765768
"""
766-
pin_fqns = list(pin_fqns or [])
769+
# Canonicalize pin_fqns: dedupe while preserving first-seen
770+
# order so the payload is stable. The commit-9a runtime
771+
# hard-fails on duplicates at parse time; deduping here keeps
772+
# harmless caller mistakes from reaching that hard-fail.
773+
raw_pins = list(pin_fqns or [])
774+
pin_fqns = []
775+
_seen_pin = set()
776+
for fqn in raw_pins:
777+
if fqn in _seen_pin:
778+
continue
779+
_seen_pin.add(fqn)
780+
pin_fqns.append(fqn)
767781
graph = exported_program.graph_module.graph
768782

769783
# Re-entering the pass would wrap each probe's input (still a

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -439,18 +439,27 @@ class ET_EXPERIMENTAL CudaBackend final
439439
// fetching — no point allocating GPU state for a config we'd
440440
// throw away.
441441
//
442-
// Pinning is deferred to commit 9. Session::create also
443-
// hard-fails on non-empty pin_fqns, but pulling the check
444-
// forward here saves the entire AOTI container + catalog walk
445-
// + dummy install before we discover the failure.
446-
if (!offload_payload.pin_fqns.empty()) {
447-
std::fprintf(
448-
stderr,
449-
"[ET_WEIGHT_OFFLOAD][ERROR] method '%s' specifies pin_fqns "
450-
"but pinning is deferred to a future commit; drop pin_fqns "
451-
"or pass empty\n",
452-
method_name.c_str());
453-
return Error::InvalidArgument;
442+
// Deduplicate pin_fqns (commit 9a). The pass + partitioner
443+
// are supposed to emit a canonical list, but a corrupted or
444+
// hand-rolled payload could repeat. Hard-fail at parse time
445+
// so accounting / allocation downstream can rely on a
446+
// 1:1 fqn↔allocation mapping. Session::create's
447+
// pinned_.emplace() is the second-layer guard.
448+
{
449+
std::unordered_set<std::string> pin_seen;
450+
pin_seen.reserve(offload_payload.pin_fqns.size());
451+
for (const auto& fqn : offload_payload.pin_fqns) {
452+
if (!pin_seen.insert(fqn).second) {
453+
std::fprintf(
454+
stderr,
455+
"[ET_WEIGHT_OFFLOAD][ERROR] method '%s' has duplicate "
456+
"FQN '%s' in payload.pin_fqns; the partitioner should "
457+
"have deduplicated before serializing\n",
458+
method_name.c_str(),
459+
fqn.c_str());
460+
return Error::InvalidArgument;
461+
}
462+
}
454463
}
455464

456465
// Single-device constraint for commit 7. The CUDA backend's
@@ -1124,14 +1133,39 @@ class ET_EXPERIMENTAL CudaBackend final
11241133
session_catalog.emplace(fqn, std::move(info));
11251134
}
11261135

1136+
// Compute pinned_bytes_total from payload metadata BEFORE
1137+
// resolving the default budget. With non-empty pin_fqns, the
1138+
// no-spec default must be `floor + pinned` (not just floor),
1139+
// since the floor formula already excludes pinned weights —
1140+
// a `floor`-only default would leave no room for the pinned
1141+
// allocations and trip the streaming-vs-floor check.
1142+
uint64_t pinned_bytes_total = 0;
1143+
for (const auto& fqn : offload_payload.pin_fqns) {
1144+
// Coverage check above already verified the set; defensive.
1145+
auto m_it = fqn_to_meta.find(fqn);
1146+
if (m_it == fqn_to_meta.end()) {
1147+
std::fprintf(
1148+
stderr,
1149+
"[ET_WEIGHT_OFFLOAD][ERROR] pin_fqn '%s' missing from "
1150+
"payload metadata (should have been caught upstream)\n",
1151+
fqn.c_str());
1152+
delete handle;
1153+
return Error::Internal;
1154+
}
1155+
pinned_bytes_total += m_it->second->nbytes;
1156+
}
1157+
11271158
// Resolve the per-load budget from the runtime-spec channel.
1128-
// Default = payload.floor_bytes. Override via the private
1159+
// Default = payload.floor_bytes + pinned_bytes_total when no
1160+
// spec is provided. Override via the private
11291161
// ``_weight_offload_internal_budget_bytes`` runtime spec.
1130-
uint64_t resolved_budget = offload_payload.floor_bytes;
1162+
// (9b adds the public ``weight_offload_budget_mb`` spec.)
1163+
uint64_t resolved_budget = 0;
1164+
bool budget_explicitly_provided = false;
11311165
auto budget_res = context.get_runtime_spec<const char*>(
11321166
"_weight_offload_internal_budget_bytes");
11331167
if (budget_res.error() == Error::NotFound) {
1134-
// Spec absent — keep the default.
1168+
// Spec absent — fall through to default.
11351169
} else if (budget_res.error() == Error::InvalidArgument) {
11361170
std::fprintf(
11371171
stderr,
@@ -1179,6 +1213,14 @@ class ET_EXPERIMENTAL CudaBackend final
11791213
return Error::InvalidArgument;
11801214
}
11811215
resolved_budget = static_cast<uint64_t>(parsed);
1216+
budget_explicitly_provided = true;
1217+
}
1218+
1219+
// Default: cover the streaming floor PLUS pinned bytes. With
1220+
// no pins, this matches the pre-9a behavior of
1221+
// ``resolved_budget = floor_bytes``.
1222+
if (!budget_explicitly_provided) {
1223+
resolved_budget = offload_payload.floor_bytes + pinned_bytes_total;
11821224
}
11831225

11841226
auto session_res = weight_offload::Session::create(

0 commit comments

Comments
 (0)