Commit 7e548ae

Offload cleanup (round 6): trim error UX, module docstring, design header
Net -214 lines, three small items from the latest review pass.

Collapse the 75-line below-floor error message to a one-paragraph form (Session::create's budget check). The two-branch "MB vs bytes" formatter with inline rounding was a UX nicety for a single error path; the new message keeps the substrings tests assert on ("pinned constants", "streaming pool floor", "required total", "set via <spec>=<value>", "Set weight_offload_budget_mb >= <mb>") and drops everything else.

Trim weight_offload_pass.py module docstring 68 -> 13 lines. The removed prose was a list of "RESOLVED" design notes (schedule order, probe-op preservation, payload transport) that read like a status report. The code already embodies the decisions; the docstring now states what the pass does and where the runtime lives.

Delete weight_offload.h. It was a 109-line README-as-header duplicating session.h's overview after several rounds of trim, with no code and no other includer. Folded a tightened version of the overview into session.h's banner. Also collapsed probe_op.h's 18-line "DISPATCH WIRED, RUNTIME BODY TBD" banner -- the runtime body has shipped; the banner described a state from several commits ago.

61 of 61 offload tests pass; lint clean. Authored with Claude.
1 parent bf4da7a commit 7e548ae

5 files changed

Lines changed: 56 additions & 270 deletions

backends/cuda/passes/weight_offload_pass.py

Lines changed: 13 additions & 62 deletions
@@ -4,68 +4,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""Weight offloading pass for the CUDA backend.
-
-EXPERIMENTAL. The runtime now installs pre-loaded GPU dummies in
-place of AOTI's constants, skips ``update_constants_from_blob``
-entirely, sources host bytes directly from the ``_weights_blob``
-NamedData entry, and serves probe calls from a bounded pool with
-LRU + event-ordered free-on-compute-stream. Steady-state peak GPU
-= offload pool only.
-
-What's wired today:
-* The probe op + AOTI c-shim dispatch path. The c-shim is the
-  Session::serve callback for opt-in methods; identity-passthrough
-  for non-opt-in.
-* The ``_apply_weight_offload`` graph rewrite + v2 payload schema
-  (per-FQN dtype/sizes/strides/storage_offset/nbytes/device, plus
-  schedule + floor + pin_fqns).
-* The public partitioner kwargs ``weight_offload=True`` and
-  ``weight_offload_pin_fqns=[...]`` on ``CudaPartitioner``
-  translate to a private compile-spec channel
-  (``_weight_offload_internal_enable``,
-  ``_weight_offload_internal_pin_fqns``) that routes the payload
-  from ``CudaBackend.preprocess`` ->
-  ``AotiBackend.preprocess``'s
-  ``pre_aoti_transform_and_collect_named_data`` hook ->
-  ``NamedDataStore`` -> ``CudaBackend::init``, where the runtime
-  parses + cross-checks against AOTI + installs dummies + builds
-  the pinned host mirror + serves probes. The internal compile
-  spec keys stay underscore-prefixed and remain accessible from
-  tests for exact-byte budget control.
-
-Pinning (``pin_fqns``): the runtime allocates each pinned weight
-once via out-of-pool ``cudaMalloc`` + a synchronous H2D, then
-serves it through a resident fast path. The pass deduplicates
-pin_fqns before serialization so the runtime never sees duplicates.
-Multi-device offload is hard-failed at init (device 0 only today).
-
-Schedule / cursor order -- RESOLVED:
-The probe op carries an explicit ``probe_id: int`` argument assigned
-by the pass (contiguous 0..N-1 in graph order). The runtime indexes
-the schedule directly by ``probe_id``; no cursor is needed. This
-closes the entire class of "graph order drifted from wrapper order"
-silent failures the prior design left open.
-
-Probe op preservation -- RESOLVED:
-The c-shim ``aoti_torch_cuda_probe`` is wired into the CUDA backend's
-``aot_inductor.custom_ops_to_c_shims`` registry. AOTI emits one
-direct call per probe node in the FX graph; distinct ``probe_id``
-constants make otherwise-identical probe calls syntactically distinct
-from inductor's CSE pass. The single-consumer and multi-consumer
-dispatch contracts are asserted by
-``test_weight_offload_probe_dispatch.py``.
-
-Payload transport -- RESOLVED:
-The pass runs from ``AotiBackend.preprocess`` (via the
-``pre_aoti_transform_and_collect_named_data`` hook in
-``CudaBackend``) and the payload ships through ``NamedDataStore``
-alongside ``_so_blob`` / ``_weights_blob``, so all AOTI-side
-artifacts live in one mechanism. The wire format is a small
-custom little-endian binary (see ``_serialize_payload``) — the
-C++ parser in ``runtime/weight_offload/payload.h`` mirrors it.
-
-The runtime half lives in ``backends/cuda/runtime/weight_offload/``.
+"""EXPERIMENTAL: CUDA weight-offloading pass.
+
+Rewrites every parameter / buffer consumer to read through a
+``probe(w, probe_id)`` op, then serializes a v2 payload (schedule
++ floor + pin_fqns + per-FQN dtype / sizes / strides /
+storage_offset / nbytes / device) into ``NamedDataStore``. The
+runtime half (``backends/cuda/runtime/weight_offload/``) parses
+the payload, installs 1-byte GPU dummies in place of AOTI's
+constants, and serves probes from a pinned host mirror through a
+bounded GPU pool. Single-device only.
+
+Public surface: ``CudaPartitioner(weight_offload=True,
+weight_offload_pin_fqns=[...])``.
 """
 
 import struct
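
The removed docstring describes two bookkeeping rules the pass follows: probe ids are contiguous 0..N-1 in graph order so the runtime can index the schedule directly, and `pin_fqns` is deduplicated before serialization. The following is a minimal sketch of those two rules only. The helper names (`assign_probe_ids`, `build_schedule`, `dedupe_pin_fqns`) are hypothetical, plain strings stand in for FX graph nodes, and keeping first-seen order when deduplicating is an assumption rather than something the commit states.

```python
from typing import Iterable, List, Tuple


def assign_probe_ids(consumer_fqns: Iterable[str]) -> List[Tuple[int, str]]:
    """Give each probe site a contiguous id (0..N-1) in graph order.

    Hypothetical helper: `consumer_fqns` stands in for the weight FQN read
    at each probe site, listed in the order the sites appear in the graph.
    """
    return list(enumerate(consumer_fqns))


def build_schedule(probes: List[Tuple[int, str]]) -> List[str]:
    """Build a list where schedule[probe_id] is the FQN to serve."""
    schedule = [""] * len(probes)
    for probe_id, fqn in probes:
        schedule[probe_id] = fqn
    return schedule


def dedupe_pin_fqns(pin_fqns: Iterable[str]) -> List[str]:
    """Drop duplicate pin FQNs; first-seen order is an assumption here."""
    return list(dict.fromkeys(pin_fqns))


# The same weight read at two sites gets two distinct probe ids.
probes = assign_probe_ids(["layer0.weight", "layer0.bias", "layer0.weight"])
schedule = build_schedule(probes)
pins = dedupe_pin_fqns(["layer0.bias", "layer0.bias", "embed.weight"])
```

Because each site carries its own id, a lookup such as `schedule[2]` needs no running cursor, which is the property the removed "Schedule / cursor order" note relied on.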

backends/cuda/runtime/weight_offload/probe_op.h

Lines changed: 4 additions & 19 deletions
@@ -8,25 +8,10 @@
 
 #pragma once
 
-// ===========================================================================
-// EXPERIMENTAL -- DISPATCH WIRED, RUNTIME BODY TBD
-// ===========================================================================
-// State: the c-shim is registered in
-// ``cuda_backend.py::custom_ops_to_c_shims`` and implemented in
-// ``probe_op.cpp`` as an identity passthrough (returns a fresh
-// ``SlimTensor`` handle sharing the input's storage). No pass inserts
-// probe nodes in production, so this shim has no production caller yet;
-// it is exercised by
-// ``backends/cuda/tests/test_weight_offload_probe_dispatch.py`` with
-// hand-rolled probes to validate that AOTI emits one call per FX
-// probe node (including the multi-consumer case inductor CSE could
-// otherwise elide).
-//
-// Future work replaces the identity body with a Session-managed lookup:
-// ``probe_id`` selects the FQN from the schedule and ``Session::serve``
-// returns a SlimTensor backed by pool-managed GPU bytes. See
-// ``backends/cuda/runtime/weight_offload/weight_offload.h``.
-// ===========================================================================
+// EXPERIMENTAL. The c-shim AOTI dispatches to for
+// ``executorch_weight_offload::probe(w, probe_id)``. Body lives in
+// probe_op.cpp; routes through ProbeRegistry to Session::serve when
+// offload is opted in, identity-passthrough otherwise.
 
 #include <cstdint>
 
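
The new banner describes a two-way dispatch: a probe on a registered dummy routes through the registry to `Session::serve`, and any other input passes through unchanged. The sketch below models only that routing decision in Python. Every name in it (`Session`, `register`, `probe`, `_REGISTRY`) is a hypothetical stand-in, object identity stands in for the dummy pointer, and bytes stand in for GPU tensors; it is not the C++ shim.

```python
from typing import Any, Dict, List


class Session:
    """Stand-in for the runtime Session; serve() is heavily simplified."""

    def __init__(self, schedule: List[str], weights: Dict[str, bytes]):
        self._schedule = schedule  # probe_id -> FQN
        self._weights = weights    # FQN -> payload (stand-in for GPU bytes)

    def serve(self, probe_id: int) -> bytes:
        return self._weights[self._schedule[probe_id]]


# Process-global map from a dummy's identity to its owning Session.
_REGISTRY: Dict[int, Session] = {}


def register(dummy: Any, session: Session) -> None:
    """Record that probes on `dummy` belong to `session`."""
    _REGISTRY[id(dummy)] = session


def probe(w: Any, probe_id: int) -> Any:
    """Serve from the Session when `w` is a registered dummy, else return w."""
    session = _REGISTRY.get(id(w))
    return w if session is None else session.serve(probe_id)


dummy = object()
register(dummy, Session(["a.weight"], {"a.weight": b"\x01\x02"}))
plain = [1.0, 2.0]
```

Here `probe(dummy, 0)` goes through the session, while `probe(plain, 0)` returns `plain` itself, which mirrors the identity passthrough for methods that did not opt in.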

backends/cuda/runtime/weight_offload/session.cpp

Lines changed: 22 additions & 74 deletions
@@ -367,86 +367,34 @@ ::executorch::runtime::Result<std::unique_ptr<Session>> Session::create(
 
   const uint64_t required_total = payload.floor_bytes + pinned_bytes_total;
 
-  // 3e. Unified below-budget check covers BOTH "pinned alone
-  // exceeds total" AND "streaming budget < floor": they're
-  // algebraically the same condition (total < pinned + floor).
-  // Both get the same BudgetSpec-aware UX message naming the
-  // spec the user set and the suggested fix.
+  // total < pinned + floor covers both "pinned alone exceeds total"
+  // and "streaming budget < floor". Suggest the public spec name for
+  // the fix (rounded up to whole MB so the value is loadable).
   if (session->total_budget_bytes_ < required_total) {
-    // Pick the spec name + suggested-value-to-set for the hint.
-    // Defaults are safe: caller leaves budget_spec.name == nullptr
-    // when no spec was set, in which case we suggest the PUBLIC
-    // spec since that's what users land on first.
-    const char* spec_name = (budget_spec.name != nullptr)
-        ? budget_spec.name
-        : "weight_offload_budget_mb";
-    // Round required_total UP to whole MB for the public-spec hint
-    // (so the suggested value is actually loadable). Use
-    // division/modulo (NOT ``+ (1MiB - 1)``) so the rounding can't
-    // overflow for required_total values near UINT64_MAX.
     constexpr uint64_t kMiB = 1ull << 20;
-    const uint64_t required_mb_rounded_up =
+    const uint64_t required_mb =
         (required_total / kMiB) + ((required_total % kMiB) != 0 ? 1 : 0);
-
-    if (std::strcmp(spec_name, "weight_offload_budget_mb") == 0) {
-      // Public spec — show MB-rounded values.
-      std::fprintf(
-          stderr,
-          "[ET_WEIGHT_OFFLOAD][ERROR] Weight offloading needs at "
-          "least %llu MB of GPU memory for method '%s':\n"
-          " pinned constants: %llu bytes\n"
-          " streaming pool floor: %llu bytes\n"
-          " ---\n"
-          " required total: %llu bytes (~%llu MB)\n"
-          "but the configured budget is %llu bytes",
-          static_cast<unsigned long long>(required_mb_rounded_up),
-          payload.method_name.c_str(),
-          static_cast<unsigned long long>(session->pinned_bytes_total_),
-          static_cast<unsigned long long>(payload.floor_bytes),
-          static_cast<unsigned long long>(required_total),
-          static_cast<unsigned long long>(required_mb_rounded_up),
-          static_cast<unsigned long long>(session->total_budget_bytes_));
-      if (budget_spec.value > 0 && budget_spec.value_is_mb) {
-        std::fprintf(
-            stderr,
-            " (set via weight_offload_budget_mb=%llu)",
-            static_cast<unsigned long long>(budget_spec.value));
-      }
-      std::fprintf(
-          stderr,
-          ". Set weight_offload_budget_mb >= %llu to load this "
-          "method.\n",
-          static_cast<unsigned long long>(required_mb_rounded_up));
-    } else {
-      // Internal spec (test path) — show exact byte values.
-      std::fprintf(
-          stderr,
-          "[ET_WEIGHT_OFFLOAD][ERROR] Weight offloading needs at "
-          "least %llu bytes of GPU memory for method '%s':\n"
-          " pinned constants: %llu bytes\n"
-          " streaming pool floor: %llu bytes\n"
-          " ---\n"
-          " required total: %llu bytes\n"
-          "but the configured budget is %llu bytes",
-          static_cast<unsigned long long>(required_total),
-          payload.method_name.c_str(),
-          static_cast<unsigned long long>(session->pinned_bytes_total_),
-          static_cast<unsigned long long>(payload.floor_bytes),
-          static_cast<unsigned long long>(required_total),
-          static_cast<unsigned long long>(session->total_budget_bytes_));
-      if (budget_spec.value > 0 && !budget_spec.value_is_mb) {
-        std::fprintf(
-            stderr,
-            " (set via %s=%llu)",
-            spec_name,
-            static_cast<unsigned long long>(budget_spec.value));
-      }
+    std::fprintf(
+        stderr,
+        "[ET_WEIGHT_OFFLOAD][ERROR] method '%s': budget %llu bytes < "
+        "required total %llu (pinned constants %llu + streaming pool "
+        "floor %llu)",
+        payload.method_name.c_str(),
+        static_cast<unsigned long long>(session->total_budget_bytes_),
+        static_cast<unsigned long long>(required_total),
+        static_cast<unsigned long long>(session->pinned_bytes_total_),
+        static_cast<unsigned long long>(payload.floor_bytes));
+    if (budget_spec.value > 0) {
       std::fprintf(
           stderr,
-          ". Set %s >= %llu to load this method.\n",
-          spec_name,
-          static_cast<unsigned long long>(required_total));
+          " [set via %s=%llu]",
+          budget_spec.name,
+          static_cast<unsigned long long>(budget_spec.value));
     }
+    std::fprintf(
+        stderr,
+        ". Set weight_offload_budget_mb >= %llu.\n",
+        static_cast<unsigned long long>(required_mb));
     return Error::InvalidArgument;
   }
 
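
The budget check in this hunk reduces to one comparison (`total < pinned + floor`) plus a round-up to whole MiB done with division and remainder. As a rough illustration, here is a Python mirror of that arithmetic and of the message shape. The function names and the `(name, value)` form of `spec` are hypothetical; Python integers cannot overflow, so the division/remainder form matters for the `uint64_t` code in the diff and is kept here only to show the same arithmetic. The explicit `spec is not None` guard is a choice made for this sketch.

```python
from typing import Optional, Tuple

MIB = 1 << 20


def ceil_mib(nbytes: int) -> int:
    """Round a byte count up to whole MiB using division plus remainder."""
    return nbytes // MIB + (1 if nbytes % MIB else 0)


def budget_error(
    method: str,
    budget: int,
    pinned: int,
    floor: int,
    spec: Optional[Tuple[str, int]] = None,
) -> Optional[str]:
    """Return None if the budget suffices, else a one-paragraph message.

    `spec` is an optional (name, value) pair; this signature is illustrative.
    """
    required_total = pinned + floor
    if budget >= required_total:
        return None
    msg = (
        f"[ET_WEIGHT_OFFLOAD][ERROR] method '{method}': budget {budget} "
        f"bytes < required total {required_total} (pinned constants "
        f"{pinned} + streaming pool floor {floor})"
    )
    if spec is not None and spec[1] > 0:
        msg += f" [set via {spec[0]}={spec[1]}]"
    return msg + f". Set weight_offload_budget_mb >= {ceil_mib(required_total)}.\n"


ok = budget_error("forward", 10 * MIB, 4 * MIB, 6 * MIB)
err = budget_error("forward", MIB, MIB, 1, ("weight_offload_budget_mb", 1))
```

In the second call the requirement is one byte over 1 MiB, so the suggested value rounds up to 2, which is the "loadable value" property the code comment asks for.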

backends/cuda/runtime/weight_offload/session.h

Lines changed: 17 additions & 6 deletions
@@ -8,12 +8,23 @@
 
 #pragma once
 
-// Per-handle CUDA weight-offload Session. See weight_offload.h for
-// the cross-component design overview. Lifecycle: created in
-// CudaBackend::init after the CUDA stream is in place and the local
-// DummyInstallation is built; destroyed first in CudaBackend::destroy
-// so the ProbeRegistry unregister + stream drain + free ordering is
-// explicit. Single-device only.
+// EXPERIMENTAL: per-handle CUDA weight-offload Session.
+//
+// What it owns: a pinned host mirror of the schedule's constants, a
+// bounded cudaMemPool sized by a software byte cap on requested live
+// offload bytes, LRU eviction with event-ordered cudaFreeAsync on the
+// compute stream, depth-1 opportunistic prefetch on the copy stream,
+// optional pinned-resident constants (out-of-pool cudaMalloc), and the
+// DummyInstallation that AOTI's container uses as its "active
+// constants" so probe ops route every constant read here. See
+// session.cpp for the per-step ordering. Single-device only (the
+// payload parser hard-fails device_index != 0).
+//
+// Companion components: payload.{h} (on-wire schema parser, the single
+// trust boundary), probe_op.{h,cpp} (AOTI c-shim), probe_registry.{h,cpp}
+// (process-global dummy_ptr -> Session map). Public entry from the
+// pass side: ``CudaPartitioner(weight_offload=True,
+// weight_offload_pin_fqns=[...])``.
 
 #include <cstdint>
 #include <list>
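
The new session.h banner mentions a software byte cap on live offload bytes combined with LRU eviction. The sketch below shows only that accounting, under loud assumptions: `BytePool` and `acquire` are hypothetical names, there are no CUDA calls, and the event-ordered frees, prefetch, and pinned-resident path from the banner are left out entirely.

```python
from collections import OrderedDict
from typing import List


class BytePool:
    """Byte-capped LRU bookkeeping; models accounting only, no CUDA calls."""

    def __init__(self, cap_bytes: int):
        self.cap_bytes = cap_bytes
        self.live_bytes = 0
        self._resident = OrderedDict()  # FQN -> nbytes, least recent first

    def acquire(self, fqn: str, nbytes: int) -> List[str]:
        """Make `fqn` resident; return the FQNs evicted to make room."""
        if fqn in self._resident:
            self._resident.move_to_end(fqn)  # hit: mark most recently used
            return []
        if nbytes > self.cap_bytes:
            raise ValueError(f"{fqn} ({nbytes} bytes) exceeds the pool cap")
        evicted = []
        # Evict least recently used entries until the new weight fits.
        while self.live_bytes + nbytes > self.cap_bytes:
            victim, victim_bytes = self._resident.popitem(last=False)
            self.live_bytes -= victim_bytes
            evicted.append(victim)
        self._resident[fqn] = nbytes
        self.live_bytes += nbytes
        return evicted


pool = BytePool(cap_bytes=100)
first = pool.acquire("a", 60)
second = pool.acquire("b", 40)
hit = pool.acquire("a", 60)    # refreshes "a", so "b" is now the oldest
third = pool.acquire("c", 50)  # needs 50 bytes: evicts "b", then "a"
```

The cap in this sketch plays the role of the streaming budget; a weight larger than the cap can never become resident, which is loosely analogous to the floor condition that `Session::create` rejects up front.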

backends/cuda/runtime/weight_offload/weight_offload.h

Lines changed: 0 additions & 109 deletions
This file was deleted.
