Skip to content

Commit 9d2a1f1

Browse files
committed
Public offload knobs: partitioner kwarg, weight_offload_budget_mb, runner CLI flag
EXPERIMENTAL. Lands the user-facing API surface. After this commit, the CUDA weight-offload runtime is reachable from outside the stack's own tests: * ``CudaPartitioner(..., weight_offload=True, weight_offload_pin_fqns=[...])`` named kwargs translate to the existing internal compile specs. * ``weight_offload_budget_mb`` runtime spec (int megabytes) accepted by ``cuda_backend.cpp::init`` alongside the existing private byte spec. * ``executor_runner --cuda_runtime_spec=k1=v1,k2=v2`` CLI flag lets tests + manual repros drive the public spec end-to-end. Part A — public partitioner kwarg: ``CudaPartitioner.__init__`` grows two named kwargs: ``weight_offload: bool`` and ``weight_offload_pin_fqns: Optional[List[str]]``. Translation rules in order: 1. Reject pin-without-enable (ValueError). 2. Strict mixed-channel rejection: when ANY public kwarg is non-default, reject ANY raw ``_weight_offload_internal_*`` compile spec entry — not just same-key conflicts. Raw internal specs stay allowed only when both public kwargs are at defaults (preserves the test stack). 3. Dedupe pin_fqns first-seen-order. The runtime parser also hard-fails on duplicates as defense in depth; deduping at the partitioner keeps harmless caller mistakes from reaching that hard-fail. 4. Append the internal compile specs. The four internal key strings are INLINED in ``cuda_partitioner.py`` (not imported from ``weight_offload_pass.py``) to avoid the ``@custom_op`` registration side-effect at import time that would defeat the lazy-import pattern in ``CudaBackend.pre_aoti_transform_and_collect_named_data``. Drift is bounded by ``test_partitioner_internal_keys_match_pass``. Part B — public ``weight_offload_budget_mb`` runtime spec: ``cuda_backend.cpp::init`` tries the public int-MB spec first, falls through to the existing private byte spec, defaults to ``floor_bytes + pinned_bytes_total`` (with checked addition overflow guard) when neither is set. When both are set the public wins so the test path can't accidentally bypass the public route. New ``BudgetSpec`` struct in ``session.h`` carries the spec name + value + value_is_mb flag from the runtime-spec resolution chain into Session::create. The below-floor UX message now: * Names the spec the user actually set (public name for the public path, internal name for the test path; the default-budget path defaults to hinting the public name). * Echoes the user-supplied value (``set via weight_offload_budget_mb=N`` or ``set via _weight_offload_..._bytes=N``). * For the public path, includes an MB-rounded suggested fix (``Set weight_offload_budget_mb >= N``) using division/modulo rounding so the round-up itself can't overflow at uint64 boundaries. * Has a checked-addition guard on ``required_total = floor + pinned`` to match the default-budget guard. Part C — ``executor_runner --cuda_runtime_spec`` CLI flag: Single comma-separated string parsed in ``executor_runner.cpp`` (gflags doesn't natively support repeated flags; comma-splitting internally is simpler). Key-aware parsing via ``kKnownCudaSpecs`` table: * ``weight_offload_budget_mb`` → int * ``_weight_offload_internal_budget_bytes`` → string Unknown keys hard-fail at parse with "known keys: ..." message. Duplicate keys hard-fail at parse. Builds ``std::vector<BackendOption>``, wires through ``LoadBackendOptionsMap::set_options("CudaBackend", Span)`` to the existing-but-currently-nullptr ``backend_options`` arg of ``Program::load_method``. Flag is intentionally CUDA-scoped (``--cuda_runtime_spec`` not ``--backend_option``) because the route feeds load-time backend options for CudaBackend specifically; other backends can add their own ``--<backend>_runtime_spec`` flag if they want similar test access. Tests (8 new + 1 deferred-from-9a un-deferred): Pool side: * ``test_runtime_accepts_public_budget_mb_via_runner_flag``: ``--cuda_runtime_spec=weight_offload_budget_mb=4`` → success summary's ``budget_bytes == 4 << 20``. * ``test_pinning_below_floor_with_pinned_hard_fails``: previously deferred from 9a. ``_LargePinnedModel`` (~1 MB per weight) + ``weight_offload_budget_mb=1`` lands strictly below ``floor + pinned``; init hard-fails with the new UX message format. * ``test_floor_message_names_public_spec_when_user_set``: asserts the error message names the public spec and includes the suggested ``Set weight_offload_budget_mb >= N`` fix line. Partitioner side: * ``test_partitioner_public_kwargs_round_trip``: kwargs produce the expected internal compile specs. * ``test_partitioner_dedupes_pin_fqns``: ``["w1","w2","w1"]`` → ``["w1","w2"]`` first-seen-order. * ``test_partitioner_rejects_pin_without_enable``: ValueError. * ``test_partitioner_rejects_any_mixed_channel``: covers same-key AND different-key conflicts; also covers the raw-without-public-kwarg-still-allowed path. * ``test_partitioner_internal_keys_match_pass``: asserts the inlined key constants equal the canonical pass-side exports. Catches drift at CI time. Banner / docstring updates: * ``weight_offload.h``: banner flips to "OFFLOAD COMPLETE; PUBLIC KNOBS WIRED. MULTI-DEVICE PENDING". "What's NOT YET WIRED" reduces to multi-device only; new "Resolved in commit 9b" section spells out the three public surfaces. * ``session.h``: drops the "commit 9b public knobs" bullet from the deferred list. * ``weight_offload_pass.py``: docstring updated to describe the public partitioner kwarg as the user-facing entry point; the underscore-prefixed compile specs are still documented as accessible from tests for exact-byte budget control.
1 parent 29a8be9 commit 9d2a1f1

9 files changed

Lines changed: 937 additions & 138 deletions

File tree

backends/cuda/cuda_partitioner.py

Lines changed: 146 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,133 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import final, List
7+
from typing import final, List, Optional
88

99
from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
1010
from executorch.backends.cuda.cuda_backend import CudaBackend # usort: skip
1111
from executorch.exir._warnings import experimental
1212
from executorch.exir.backend.compile_spec_schema import CompileSpec
1313
from executorch.exir.passes.propagate_device_pass import TARGET_DEVICE_COMPILE_SPEC_KEY
1414

15+
# Inlined copies of the internal compile-spec key strings owned by
16+
# ``backends/cuda/passes/weight_offload_pass.py``. We don't import from
17+
# that module because it registers a custom op at import time
18+
# (``@custom_op`` decorator), which would defeat the lazy-import
19+
# pattern in ``CudaBackend.pre_aoti_transform_and_collect_named_data``.
20+
# The drift hazard is bounded by
21+
# ``test_partitioner_internal_keys_match_pass``, which asserts these
22+
# values match the pass-side constants at CI time.
23+
_WEIGHT_OFFLOAD_ENABLE_SPEC_KEY = "_weight_offload_internal_enable"
24+
_WEIGHT_OFFLOAD_PIN_FQNS_SPEC_KEY = "_weight_offload_internal_pin_fqns"
25+
_WEIGHT_OFFLOAD_INTERNAL_KEY_PREFIX = "_weight_offload_internal_"
26+
27+
28+
def _check_pin_fqns_input(weight_offload_pin_fqns) -> List[str]:
29+
"""Type-validate + content-validate the public
30+
``weight_offload_pin_fqns`` argument. Returns a deduped list of
31+
strings (first-seen order). Raises TypeError / ValueError for
32+
malformed input.
33+
34+
Bare ``str`` is rejected because ``list("w1")`` silently becomes
35+
``["w", "1"]`` — a common caller mistake. Empty and NUL-containing
36+
FQNs are rejected because they can't round-trip through the
37+
NUL-separated internal compile spec.
38+
"""
39+
if weight_offload_pin_fqns is None:
40+
return []
41+
if isinstance(weight_offload_pin_fqns, str):
42+
raise TypeError(
43+
"weight_offload_pin_fqns must be a list of strings, not a "
44+
"bare str (a bare str would be split into characters); "
45+
"pass [...] even for a single FQN"
46+
)
47+
# Strictly require a list. The public contract is List[str] and the
48+
# error messages promise first-seen order; accepting dict keys /
49+
# sets / generators would either drop ordering (set iteration is
50+
# not deterministic across runs) or be one-shot (generators
51+
# consumed by the first scan), neither of which matches what
52+
# callers see in the docstring.
53+
if not isinstance(weight_offload_pin_fqns, list):
54+
raise TypeError(
55+
"weight_offload_pin_fqns must be a list of strings; got "
56+
f"{type(weight_offload_pin_fqns).__name__} (cast to list "
57+
"explicitly if you have another iterable type)"
58+
)
59+
60+
seen: set = set()
61+
out: List[str] = []
62+
for fqn in weight_offload_pin_fqns:
63+
if not isinstance(fqn, str):
64+
raise TypeError(
65+
"weight_offload_pin_fqns must contain only strings; "
66+
f"got element of type {type(fqn).__name__}: {fqn!r}"
67+
)
68+
if fqn == "":
69+
raise ValueError(
70+
"weight_offload_pin_fqns contains an empty string; "
71+
"FQNs must be non-empty"
72+
)
73+
if "\x00" in fqn:
74+
raise ValueError(
75+
f"weight_offload_pin_fqns entry {fqn!r} contains an "
76+
f"embedded NUL byte; the internal compile spec is "
77+
f"NUL-separated and cannot round-trip such values"
78+
)
79+
if fqn in seen:
80+
continue
81+
seen.add(fqn)
82+
out.append(fqn)
83+
return out
84+
85+
86+
def _validate_and_translate_weight_offload_kwargs(
87+
compile_spec: List[CompileSpec],
88+
weight_offload: bool,
89+
weight_offload_pin_fqns: Optional[List[str]],
90+
) -> List[CompileSpec]:
91+
"""Translate the public weight-offload kwargs to internal compile
92+
specs with strict validation. Returns the (possibly augmented)
93+
compile_spec list to pass to the base partitioner."""
94+
pin_fqns_list = _check_pin_fqns_input(weight_offload_pin_fqns)
95+
96+
# Reject pin-without-enable.
97+
if pin_fqns_list and not weight_offload:
98+
raise ValueError(
99+
"weight_offload_pin_fqns is set but weight_offload=False; "
100+
"pinning requires enabling weight offload"
101+
)
102+
103+
# Strict mixed-channel rejection: when ANY public weight-offload
104+
# kwarg is non-default, reject ANY raw `_weight_offload_internal_*`
105+
# compile spec. Raw internal specs stay allowed when both public
106+
# kwargs are at default values (preserves the test stack).
107+
if weight_offload or pin_fqns_list:
108+
offenders = [
109+
spec.key
110+
for spec in compile_spec
111+
if spec.key.startswith(_WEIGHT_OFFLOAD_INTERNAL_KEY_PREFIX)
112+
]
113+
if offenders:
114+
raise ValueError(
115+
f"CudaPartitioner: public weight-offload kwargs conflict "
116+
f"with raw {_WEIGHT_OFFLOAD_INTERNAL_KEY_PREFIX}* entries "
117+
f"in compile_spec ({sorted(set(offenders))!r}); use "
118+
f"exactly one channel - either the public kwargs OR the "
119+
f"raw compile_spec, not both"
120+
)
121+
122+
out = list(compile_spec)
123+
if weight_offload:
124+
out.append(CompileSpec(_WEIGHT_OFFLOAD_ENABLE_SPEC_KEY, b"1"))
125+
if pin_fqns_list:
126+
out.append(
127+
CompileSpec(
128+
_WEIGHT_OFFLOAD_PIN_FQNS_SPEC_KEY,
129+
b"\x00".join(f.encode("utf-8") for f in pin_fqns_list),
130+
)
131+
)
132+
return out
133+
15134

16135
@final
17136
@experimental(
@@ -29,6 +148,9 @@ class CudaPartitioner(AotiPartitioner):
29148
def __init__(
30149
self,
31150
compile_spec: List[CompileSpec],
151+
*,
152+
weight_offload: bool = False,
153+
weight_offload_pin_fqns: Optional[List[str]] = None,
32154
) -> None:
33155
"""
34156
Initialize the CUDA partitioner.
@@ -38,13 +160,35 @@ def __init__(
38160
target CUDA device, include a CompileSpec with key
39161
"target_device" (e.g., value "cuda:1"). If not
40162
provided, defaults to "cuda:0".
163+
weight_offload: When True, opt the method into the CUDA weight-
164+
offload runtime: AOTI's eager constant load is
165+
skipped, the runtime installs pre-load dummies,
166+
and probes serve weights through a bounded GPU
167+
pool from a host mirror. The load-time budget is
168+
controlled by the ``weight_offload_budget_mb``
169+
runtime spec (or the default
170+
``floor + pinned_bytes`` when unset). Default
171+
False.
172+
weight_offload_pin_fqns: Optional list of parameter / buffer
173+
FQNs to keep resident on GPU for the Session
174+
lifetime (no streaming). Requires
175+
``weight_offload=True``. Duplicates are removed
176+
first-seen order.
41177
"""
178+
# Translate the public weight-offload kwargs to internal
179+
# compile specs (with validation). Extracted to a helper to
180+
# keep this constructor under the project's cyclomatic-
181+
# complexity cap.
182+
compile_spec = _validate_and_translate_weight_offload_kwargs(
183+
compile_spec, weight_offload, weight_offload_pin_fqns
184+
)
185+
42186
# Add target_device compile spec for device propagation if not already present
43187
has_target_device = any(
44188
spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY for spec in compile_spec
45189
)
46190
if not has_target_device:
47-
compile_spec = list(compile_spec) + [
191+
compile_spec = compile_spec + [
48192
CompileSpec(
49193
TARGET_DEVICE_COMPILE_SPEC_KEY,
50194
b"cuda:0",

backends/cuda/passes/weight_offload_pass.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@
2020
* The ``_apply_weight_offload`` graph rewrite + v2 payload schema
2121
(per-FQN dtype/sizes/strides/storage_offset/nbytes/device, plus
2222
schedule + floor + pin_fqns).
23-
* A private compile-spec opt-in
23+
* The public partitioner kwargs ``weight_offload=True`` and
24+
``weight_offload_pin_fqns=[...]`` on ``CudaPartitioner``
25+
translate to a private compile-spec channel
2426
(``_weight_offload_internal_enable``,
2527
``_weight_offload_internal_pin_fqns``) that routes the payload
2628
from ``CudaBackend.preprocess`` ->
2729
``AotiBackend.preprocess``'s
2830
``pre_aoti_transform_and_collect_named_data`` hook ->
2931
``NamedDataStore`` -> ``CudaBackend::init``, where the runtime
3032
parses + cross-checks against AOTI + installs dummies + builds
31-
the pinned host mirror + serves probes. No public partitioner
32-
kwarg yet; the only opt-in callers are this stack's own tests.
33+
the pinned host mirror + serves probes. The internal compile
34+
spec keys stay underscore-prefixed and remain accessible from
35+
tests for exact-byte budget control.
3336
3437
Pinning (``pin_fqns``) is supported as of commit 9a: the runtime
3538
allocates each pinned weight once via out-of-pool ``cudaMalloc``
@@ -158,13 +161,21 @@ def _probe_fake(w: torch.Tensor, probe_id: int) -> torch.Tensor:
158161
# --------------------------------------------------------------------------
159162
# Private compile-spec keys and NamedData wire format
160163
# --------------------------------------------------------------------------
161-
# EXPERIMENTAL. All knobs below carry leading underscores so they
164+
# EXPERIMENTAL. All keys below carry leading underscores so they
162165
# read as "internal" at every callsite and stay invisible to anyone who
163-
# only inspects the public surface. The matching public partitioner kwarg
164-
# does NOT exist yet; opting in requires constructing the compile specs
165-
# by hand. This is intentional — the only callers today are this stack's
166-
# own tests; the public partitioner kwarg lands later alongside the
167-
# public ``weight_offload_budget_mb`` runtime spec.
166+
# only inspects the public surface. End users opt in through the public
167+
# ``CudaPartitioner(weight_offload=True, weight_offload_pin_fqns=[...])``
168+
# kwargs, which translate to these internal COMPILE specs. The keys
169+
# themselves stay internal so the stack's own callers can still build
170+
# raw compile specs.
171+
#
172+
# Note the separate axis: load-time budget control. The PUBLIC RUNTIME
173+
# spec is ``weight_offload_budget_mb`` (int megabytes); the INTERNAL
174+
# RUNTIME spec for exact-byte budgets is
175+
# ``_weight_offload_internal_budget_bytes`` (decimal string, used by
176+
# tests that need byte-level precision below 1 MB granularity). Both
177+
# are runtime specs (consumed via ``BackendInitContext::get_runtime_spec``),
178+
# not compile specs.
168179

169180
# Compile-spec key that flips the entire offload pipeline on for a method:
170181
# triggers the pass at preprocess time and tells the runtime to skip
@@ -760,11 +771,11 @@ def _apply_weight_offload(
760771
``_weight_offload_internal_enable`` (see ``COMPILE_SPEC_KEY_ENABLE``);
761772
pin FQNs come in via ``_weight_offload_internal_pin_fqns``
762773
(NUL-separated UTF-8). The enable signal lives in exactly one
763-
place the compile spec rather than being duplicated across
764-
compile spec + payload. The matching public partitioner kwarg
765-
does NOT exist yet; the only opt-in callers today are this
766-
stack's own tests. See the EXPERIMENTAL banner at the top of
767-
this module for the current wiring state.
774+
place - the compile spec - rather than being duplicated across
775+
compile spec + payload. End users opt in through the public
776+
``CudaPartitioner(weight_offload=True,
777+
weight_offload_pin_fqns=[...])`` kwargs, which translate to
778+
these internal specs.
768779
"""
769780
# Canonicalize pin_fqns: dedupe while preserving first-seen
770781
# order so the payload is stable. The commit-9a runtime

0 commit comments

Comments
 (0)