Skip to content

Commit 1a0ca10

Browse files
committed
Weight offloading design surface (CUDA backend)
Design-only PR for CUDA-backend weight offloading: weights live in CPU memory, the runtime streams only the currently-needed ones to GPU through a capped cudaMemPool. Headers and docstrings only -- no implementation bodies, no caller, no wiring on the partitioner or runtime side. All four design files are marked ``EXPERIMENTAL -- NOT YET WIRED``. Public knobs (``CudaPartitioner(weight_offload=True, ...)`` and the ``weight_offload_budget_mb`` runtime spec) are intentionally NOT exposed in this PR; they ship with the implementation. Four open items block wiring and are documented inline: * Probe op preservation -- an identity custom op with ``mutates_args=()`` is a DCE target through inductor; the implementation PR must give probe non-elidable semantics that don't trip torch.export's parameter-output validation, plus a test that asserts the lowered AOTI wrapper actually emits the probe calls. * AOTI blob layout -- ``WeightCatalog::build`` needs per-constant offsets and dtype/shape. AOTI doesn't expose either today; implementation PR must either land upstream shims or serialize the metadata into the offload payload at export time. * Payload transport channel -- the pass has to run from ``AotiBackend.preprocess`` (the partitioner contract forbids mutating the ExportedProgram from ``partition()``); the implementation PR picks between ``processed_bytes`` and a per-method ``NamedDataStore`` entry. * Schedule / cursor order -- the runtime cursor hard-fails on a mismatch against the recorded schedule, but the pass observes parameter order before inductor lowering reorders / duplicates reads. Implementation PR either regenerates the schedule from the post-lowering wrapper or extends probe with a self-identifying ``probe_id`` / FQN arg so no cursor is needed. Read order: backends/cuda/passes/weight_offload_pass.py -- export half backends/cuda/runtime/weight_offload/weight_offload.h -- runtime backends/cuda/runtime/weight_offload/probe_op.h -- c-shim backends/cuda/runtime/weight_offload/prefetcher.h -- copy stream See: #19709
1 parent 54f1f28 commit 1a0ca10

5 files changed

Lines changed: 1317 additions & 0 deletions

File tree

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Weight offloading pass for the CUDA backend.
8+
9+
EXPERIMENTAL -- NOT YET WIRED.
10+
11+
This module is a DESIGN DOCUMENT. ``_apply_weight_offload`` has no
12+
implementation body (its body is ``...``); the ``executorch_weight_offload::probe``
13+
custom op is registered with torch but is not preserved through inductor
14+
lowering (see the "Probe op preservation" open item below); and
15+
``CudaPartitioner`` does NOT yet expose a ``weight_offload`` kwarg that
16+
would invoke this pass. Nothing in this PR calls the function or the
17+
custom op.
18+
19+
Open item -- probe op preservation:
20+
The pass's design intent is to insert one
21+
``executorch_weight_offload::probe(w)`` call before every consumer of
22+
every parameter placeholder, so the runtime can intercept each
23+
parameter read via an AOTI c-shim. The current registration
24+
(``custom_op(..., mutates_args=())`` plus an identity fake) is NOT
25+
sufficient to keep the op alive through inductor's CSE + fusion
26+
passes -- a side-effect-free identity op is a textbook DCE target,
27+
and there is no test in this PR that asserts the lowered AOTI
28+
wrapper actually emits an ``aoti_torch_cuda_probe`` call per
29+
``(consumer, weight)`` pair. The implementation PR must:
30+
1. Give probe explicit non-elidable semantics that survive
31+
inductor lowering AND don't clash with torch.export's
32+
parameter-output validation (parameters are read-only by
33+
convention, so ``mutates_args={"w"}`` won't work directly).
34+
2. Land a regression test that lowers a tiny two-linear model
35+
and asserts the wrapper.cpp contains the expected
36+
``aoti_torch_cuda_probe`` calls in the expected places.
37+
Without both, weight offload silently degrades to eager loading at
38+
runtime even when "enabled" -- which is exactly the surprise this
39+
feature exists to prevent.
40+
41+
Open item -- payload transport:
42+
The pass needs to run from a custom pass inside
43+
``AotiBackend.preprocess`` -- the ``ExportedProgram`` partitioner
44+
contract forbids mutating the program from ``partition()`` (see
45+
``exir/backend/partitioner.py:83``), so neither ``CudaPartitioner``
46+
nor any other partitioner can run the rewrite directly. ``preprocess``
47+
has two candidate channels for the resulting payload (schedule,
48+
floor, pin_fqns, version), both accessible from there and neither
49+
yet wired:
50+
(a) Serialize into the partition's ``processed_bytes`` (currently
51+
``b""`` for AOTI; see ``backends/aoti/aoti_backend.py``).
52+
(b) Attach a per-method ``NamedDataStore`` entry (where
53+
``AotiBackend`` already writes ``_so_blob`` and
54+
``_weights_blob``).
55+
Pick one when wiring. The payload-key constants below are channel-
56+
agnostic.
57+
58+
Open item -- schedule / cursor order:
59+
The runtime cursor in ``Session::serve`` hard-fails on a mismatch
60+
against the recorded schedule (see ``register_schedule`` in
61+
``weight_offload.h``). That contract requires the execution-order
62+
FQN list this pass records to match the order the lowered AOTI
63+
wrapper actually invokes ``aoti_torch_cuda_probe`` in. The pass
64+
observes the graph BEFORE inductor's custom passes / decompositions
65+
/ lowering run (``backends/aoti/aoti_backend.py:206``), all of which
66+
can reorder or duplicate parameter reads. Two options for the
67+
implementation PR: (a) regenerate the schedule from the post-lowering
68+
wrapper order; or (b) extend probe with an explicit
69+
``probe_id`` / FQN argument so each call self-identifies and the
70+
runtime needs no cursor at all. (b) is the more robust choice --
71+
it removes the entire class of "graph order drifted from wrapper
72+
order" silent failures, at the cost of a wider op signature.
73+
74+
The runtime half lives in ``backends/cuda/runtime/weight_offload/``,
75+
which is also marked EXPERIMENTAL -- NOT YET WIRED.
76+
"""
77+
78+
import torch
79+
from torch.library import custom_op, register_fake
80+
81+
82+
_OP_NAMESPACE = "executorch_weight_offload"
83+
_OP_QUALNAME = f"{_OP_NAMESPACE}::probe"
84+
85+
86+
@custom_op(_OP_QUALNAME, mutates_args=())
87+
def probe(w: torch.Tensor) -> torch.Tensor:
88+
"""Identity passthrough in eager. CUDA backend replaces via c-shim at AOTI compile time.
89+
90+
Inserted by ``apply_weight_offload`` before every consumer of every
91+
parameter (or buffer) placeholder. The CUDA runtime's c-shim
92+
(``aoti_torch_cuda_probe``) intercepts each call at runtime and serves
93+
bytes through the bounded GPU pool.
94+
95+
Signature is deliberately minimal — no FQN or schedule-index argument.
96+
The runtime resolves which weight is being probed by looking up the
97+
input tensor's ``data_ptr()`` in the ``ProbeRegistry`` populated by
98+
``Session::bind_placeholder_constants`` at backend init.
99+
100+
Notes:
101+
- The current ``mutates_args=()`` is insufficient: an identity op
102+
with no side effect is a textbook DCE target for inductor.
103+
``mutates_args={"w"}`` clashes with torch.export's
104+
parameter-output validation (parameters are read-only by
105+
convention). The implementation PR must find a third option;
106+
see the "Probe op preservation" open item in the module
107+
docstring above.
108+
- Weight offloading is mutually exclusive with the CUDA backend's
109+
``enable_cuda_graph_for_method`` option: CUDA-graph Replay bypasses
110+
AOTI's ``run()``, so probe ops never fire. The runtime hard-fails
111+
at ``init`` if both are set for the same method.
112+
"""
113+
return w
114+
115+
116+
@register_fake(_OP_QUALNAME)
117+
def _probe_fake(w: torch.Tensor) -> torch.Tensor:
118+
# Fresh fake tensor so inductor doesn't decide to inline the op away.
119+
return torch.empty_like(w)
120+
121+
122+
PROBE_OP_TARGET = torch.ops.executorch_weight_offload.probe.default
123+
124+
125+
# Payload field names. INTERNAL design intent for the partition-payload
126+
# (or NamedDataStore -- see the "payload transport" open item in the
127+
# module docstring) that ``CudaBackend.preprocess`` would write and
128+
# ``cuda_backend.cpp::init`` would parse once wired. Names are
129+
# namespaced by method so prefill and decode each get their own payload
130+
# in the same .pte.
131+
PAYLOAD_KEY_VERSION = "version"
132+
PAYLOAD_KEY_METHOD_NAME = "method_name"
133+
PAYLOAD_KEY_SCHEDULE = "schedule"
134+
PAYLOAD_KEY_FLOOR = "floor_bytes"
135+
PAYLOAD_KEY_PIN_FQNS = "pin_fqns"
136+
137+
# Schema version for the emitted offload payload. Bumped whenever the
138+
# shape of any field above changes (e.g. switching the floor from
139+
# uint64 bytes to a struct with prefetch headroom, or switching the
140+
# budget wire option from ``weight_offload_budget_mb`` to a bytes-typed
141+
# field once ``BackendOptions`` grows int64 support). The runtime
142+
# hard-fails at ``CudaBackend::init`` if the version is missing or
143+
# unknown, naming the expected range — version drift surfaces loudly at
144+
# load instead of silently mis-parsing a payload.
145+
SCHEMA_VERSION = 1
146+
147+
148+
def _apply_weight_offload(
149+
exported_program,
150+
*,
151+
method_name: str,
152+
pin_fqns: list[str] | None = None,
153+
) -> dict:
154+
"""In-place graph rewrite + offload payload computation.
155+
156+
INTERNAL — leading underscore is the Python signal. The only supported
157+
caller is ``CudaPartitioner`` (see ``backends/cuda/cuda_partitioner.py``),
158+
which sources ``method_name`` from its compile specs so prefill
159+
and decode get distinct payloads instead of colliding.
160+
``method_name`` is REQUIRED (no default) precisely so a direct
161+
caller importing this function for a multi-method model cannot
162+
silently collide all methods on ``"forward"``.
163+
164+
Inserts ``probe(w)`` in front of every consumer of every parameter (or
165+
buffer) placeholder, rewriting the consumer's arg to read the probe's
166+
output. One probe call per ``(consumer, weight)`` pair — not per
167+
weight — so the runtime can re-load a weight that was evicted between
168+
two uses inside the same forward pass.
169+
170+
Pinned FQNs still get probes inserted (so the runtime serves them
171+
through the same path), but they are EXCLUDED from the schedule and
172+
EXCLUDED from the floor calculation. See ``Session::Config::pin_fqns``
173+
and ``Session::register_schedule`` in
174+
``backends/cuda/runtime/weight_offload/weight_offload.h``.
175+
176+
The pass is the single authoritative source for the pin set. The
177+
runtime has NO pin-set option; it parses the list out of the
178+
partition payload and passes it to the Session unchanged. Pin set
179+
affects floor correctness (the floor is computed assuming pinned
180+
FQNs do not stream), so the runtime cannot override it.
181+
182+
AOTI constant-folding contract:
183+
The pass operates on parameter placeholders in the ExportedProgram.
184+
AOTI knobs that fold parameters out of the container at compile
185+
time (so they no longer appear in ``get_constant_name(idx)``)
186+
break offload: the pass cannot insert probes for parameters it
187+
cannot see, and the folded constants would be loaded eagerly
188+
through the normal blob path at runtime — silently defeating
189+
offload and reintroducing the OOM this feature exists to prevent.
190+
191+
Exports that enable weight offload must NOT have AOTI fold
192+
parameter constants. The exact ``torch._inductor.config`` knob
193+
and its required value is verified in the implementation PR.
194+
The pass hard-fails at export if it detects folded parameters
195+
(placeholder count below the expected catalog count derived
196+
from ``exported_program.state_dict``); the runtime hard-fails
197+
again at ``Session::bind_placeholder_constants`` if any catalog
198+
FQN is missing a probe binding — defense in depth against the
199+
two halves drifting.
200+
201+
Metadata transport:
202+
The returned ``dict`` is the offload payload that the
203+
implementation PR will route to ``cuda_backend.cpp::init`` via the
204+
AOTI ``preprocess`` path (see the "payload transport" open item
205+
in the module docstring for the design constraint and the two
206+
candidate channels: ``processed_bytes`` vs. ``NamedDataStore``).
207+
208+
Args:
209+
exported_program: an ``ExportedProgram`` produced by
210+
``torch.export.export``. Mutated in place: probe nodes are
211+
inserted, consumer args are rewritten.
212+
method_name: the method this pass is being applied to. Returned
213+
verbatim in the payload so the runtime can validate which
214+
method the bytes belong to.
215+
pin_fqns: FQNs to mark as always-resident. Optional. The list
216+
is propagated verbatim into the payload AND used by this pass
217+
to exclude those FQNs from the schedule and the floor
218+
calculation. Pinning an FQN that does not appear as a parameter
219+
placeholder is a hard error.
220+
221+
Returns: a ``dict`` with the keys defined at module scope (all
222+
internal payload, not opt-in signals):
223+
224+
- ``"version"``: ``int`` schema version (currently ``1``). Runtime
225+
hard-fails on unknown version.
226+
- ``"method_name"``: ``str``. Echoed for runtime validation.
227+
- ``"schedule"``: ``list[str]`` of NON-PINNED parameter FQNs in
228+
execution order. Drives the runtime cursor + prefetch (or is
229+
obviated entirely if the implementation PR picks option (b) of
230+
the "Schedule / cursor order" open item -- a probe-id arg).
231+
Pinned FQNs do not appear here; their probes take a separate
232+
fast-path in ``Session::serve`` that does not touch the cursor.
233+
- ``"floor_bytes"``: ``int`` — minimum GPU byte budget for the
234+
streaming portion of the working set (``max-over-consecutive-
235+
kernel-pairs of (sum bytes K_i + K_{i+1}) + max single weight
236+
size``), computed over the schedule above (i.e. excluding
237+
pinned weights). The runtime asserts
238+
``(weight_offload_budget_mb << 20) - pinned_bytes`` covers
239+
this; below-floor budgets hard-fail at init with the required
240+
minimum spelled out.
241+
- ``"pin_fqns"``: ``list[str]`` of FQNs the runtime keeps
242+
resident. Empty if ``pin_fqns`` is unset.
243+
244+
The opt-in signal is intended to be a separate
245+
``CompileSpec("weight_offload", b"1")`` emitted by ``CudaPartitioner``
246+
when wired -- the enable signal lives in exactly one place, the
247+
compile spec, rather than being duplicated across compile spec +
248+
payload. Neither the compile spec nor the partitioner kwarg exists
249+
in this PR; see the EXPERIMENTAL banner at the top of this module.
250+
251+
Not called by users. Not called by anything in this PR.
252+
"""
253+
...

backends/cuda/runtime/cuda_delegate_handle.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,16 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {
168168

169169
// CUDA graph state (warmup, capture, replay, static buffers)
170170
CudaGraphState cuda_graph_state;
171+
172+
// Weight offloading: the per-handle ``unique_ptr<weight_offload::Session>``
173+
// field lands with the implementation PR alongside ``weight_offload.cpp``,
174+
// which provides the out-of-line ``Session::~Session()`` definition that
175+
// ``unique_ptr<Session>``'s implicit destructor needs. Adding the field
176+
// here in an API-surface-only PR would force every TU that includes this
177+
// header into an unresolved-symbol link error against
178+
// ``Session::~Session()``. See
179+
// ``backends/cuda/runtime/weight_offload/weight_offload.h`` for the
180+
// ownership model that explains why the Session lives per-handle.
171181
};
172182

173183
} // namespace cuda

0 commit comments

Comments
 (0)