Commit 1ea293d

Add CUDA weight offloading

Lets a CUDA-backend ExecuTorch program stream model weights from a
pinned host mirror through a bounded GPU pool on demand, instead of
loading the full constant set into VRAM at program load. This lets a
process run a model whose total weight bytes exceed available GPU
memory, or share GPU across methods whose union exceeds memory.

Public surface (CudaPartitioner kwargs):

  weight_offload: bool = False
      Opt the method into offload. The runtime skips AOTI's eager
      ``update_constants_from_blob``, installs 1-byte GPU dummies in
      place of constants via
      ``update_user_managed_constant_buffer_pairs``, and routes every
      constant read through a probe op into Session::serve.

  weight_offload_pin_fqns: Optional[List[str]] = None
      FQNs to keep resident on GPU for the Session lifetime
      (out-of-pool ``cudaMalloc`` + one-time H2D + sync at create).
      Requires weight_offload=True. Currently device-0-only.

Runtime spec (LoadBackendOptionsMap, or executor_runner's
``--cuda_runtime_spec=weight_offload_budget_mb=N``):

  Total GPU bytes the offload pool + pinned constants may use.
  Defaults to floor_bytes + pinned_bytes_total when unset.

Architecture:

  Export-time pass (backends/cuda/passes/weight_offload_pass.py):
      Rewrites every parameter / buffer consumer to read through
      ``executorch_weight_offload::probe(w, probe_id)`` and serializes
      a v2 payload (schedule + floor + pin_fqns + per-FQN dtype /
      sizes / strides / storage_offset / nbytes / device) into
      NamedDataStore. ``probe_id`` is dense 0..N-1 in graph order so
      the runtime can index ``schedule[probe_id]`` directly.

  Payload parser (backends/cuda/runtime/weight_offload/payload.h):
      Single trust boundary. Validates schema version, framing bounds,
      per-FQN invariants (dtype allow-list, contiguous + offset-zero
      layout, nbytes == elementSize(dtype) * product(sizes),
      device == cuda:0), schedule <-> metadata set equality, and
      pin_fqns dedup + subset-of-schedule. Downstream code trusts the
      parsed Payload struct.

  AOTI c-shim (backends/cuda/runtime/weight_offload/probe_op.cpp):
      Looks the dummy data_ptr up in ProbeRegistry and forwards to
      Session::serve. Lookup miss with the registry empty falls back
      to identity-passthrough (preserves non-offload paths); lookup
      miss with the registry non-empty hard-fails (otherwise the
      runtime would silently read AOTI's eager constant).

  Session (backends/cuda/runtime/weight_offload/session.{h,cpp}):
      Owns the offload lifecycle. Builds the pinned host mirror by
      indexing into the ``_weights_blob`` NamedData entry, allocates a
      bounded ``cudaMemPool``, implements LRU eviction with
      event-ordered ``cudaFreeAsync`` on the compute stream, depth-1
      opportunistic prefetch on the copy stream, optional
      pinned-resident constants. ``serve()``'s miss path and
      ``opportunistic_prefetch()`` share one ``make_room_and_alloc``
      helper.

  ProbeRegistry
  (backends/cuda/runtime/weight_offload/probe_registry.{h,cpp}):
      Process-global ``dummy_ptr -> (callback, context)`` table in the
      shim library. ``lookup`` + ``has_any_context`` take a shared
      lock so concurrent probe dispatch across Sessions does not
      serialise.

  CudaBackend::init (backends/cuda/runtime/cuda_backend.cpp):
      Thin shim for offload methods: parse payload, fail-fast on
      cuda_graph / shared_stream / non-device-0, ``cudaSetDevice(0)``,
      walk the AOTI catalog, coverage check (catalog <-> schedule),
      AOTI <-> payload data_size cross-check (the one genuinely
      cross-source check; AOTI's compiled .so vs the .pte payload),
      fetch ``_weights_blob``, call Session::create.

Tests:

  backends/cuda/tests/test_weight_offload_probe_dispatch.py
  backends/cuda/tests/test_weight_offload_pool.py
  backends/cuda/tests/test_weight_offload_session.py
  backends/cuda/tests/test_weight_offload_transport.py
  backends/cuda/tests/test_cuda_partitioner.py
  backends/cuda/passes/tests/test_weight_offload_pass.py

  61 passed, 1 skipped on the offload surface.
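
The Session's bounded-pool behaviour described above (serve from the pool on a hit, evict least-recently-used entries on a miss until the incoming weight fits) can be sketched in plain Python. This is a toy simulation under stated assumptions, not the runtime's code: `BoundedWeightPool`, its `serve` method, and the byte-count bookkeeping are hypothetical names, and CUDA allocation, streams, events, pinning, and prefetch are all left out.

```python
from collections import OrderedDict
from typing import Dict, List


class BoundedWeightPool:
    """Toy model of a byte-budgeted pool with LRU eviction (hypothetical)."""

    def __init__(self, budget_bytes: int, host_mirror: Dict[str, int]) -> None:
        # host_mirror maps FQN -> nbytes and stands in for the pinned host copy.
        self.budget_bytes = budget_bytes
        self.host_mirror = host_mirror
        # Insertion order doubles as recency order: oldest entry first.
        self.resident: "OrderedDict[str, int]" = OrderedDict()
        self.used_bytes = 0
        self.evictions: List[str] = []

    def serve(self, fqn: str) -> str:
        """Return "hit" or "miss"; on a miss, evict LRU entries until it fits."""
        if fqn in self.resident:
            self.resident.move_to_end(fqn)  # mark as most recently used
            return "hit"
        nbytes = self.host_mirror[fqn]
        if nbytes > self.budget_bytes:
            raise MemoryError(f"{fqn} ({nbytes} B) exceeds the pool budget")
        while self.used_bytes + nbytes > self.budget_bytes:
            victim, victim_bytes = self.resident.popitem(last=False)
            self.used_bytes -= victim_bytes
            self.evictions.append(victim)
        # Stands in for a pool allocation plus a host-to-device copy.
        self.resident[fqn] = nbytes
        self.used_bytes += nbytes
        return "miss"


pool = BoundedWeightPool(budget_bytes=100, host_mirror={"a": 60, "b": 40, "c": 50})
results = [pool.serve(name) for name in ["a", "b", "a", "c"]]
print(results, pool.evictions)
```

In this sketch a single oversized weight raises immediately, which mirrors the idea that the budget must be at least as large as the biggest weight served; how the real runtime derives and enforces its floor is not shown here.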

Not yet wired (deferred):

  multi-device offload (the payload + runtime hard-code device 0
  today; the partitioner rejects target_device != cuda:0 when
  weight_offload=True). Lifting this requires teaching AOTI's
  ``create_with_device(..., "cuda", nullptr)`` to take a per-method
  device index and propagating it through the payload, runtime, and
  partitioner kwargs.

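
The per-FQN payload invariants named in the commit message, including the device-0 restriction, can be illustrated with a small Python check. This is a sketch under assumptions: `FqnMeta`, `ELEMENT_SIZE`, `contiguous_strides`, and `validate_fqn` are hypothetical names, the dtype allow-list is an example rather than the parser's actual list, and the real parser is the C++ code in payload.h.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple

# Example allow-list only; the real parser's list is not shown in the commit.
ELEMENT_SIZE = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


@dataclass(frozen=True)
class FqnMeta:
    dtype: str
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]
    storage_offset: int
    nbytes: int
    device: str


def contiguous_strides(sizes: Tuple[int, ...]) -> Tuple[int, ...]:
    """Row-major strides for the given sizes (simplified: no size-0 handling)."""
    strides = []
    running = 1
    for dim in reversed(sizes):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def validate_fqn(meta: FqnMeta) -> None:
    """Raise ValueError if any of the listed per-FQN invariants is violated."""
    if meta.dtype not in ELEMENT_SIZE:
        raise ValueError(f"dtype {meta.dtype!r} is not on the allow-list")
    if meta.storage_offset != 0:
        raise ValueError("storage_offset must be zero")
    if meta.strides != contiguous_strides(meta.sizes):
        raise ValueError("layout must be contiguous")
    if meta.nbytes != ELEMENT_SIZE[meta.dtype] * prod(meta.sizes):
        raise ValueError("nbytes != elementSize(dtype) * product(sizes)")
    if meta.device != "cuda:0":
        raise ValueError("only cuda:0 is supported")


good = FqnMeta("float16", (4, 8), (8, 1), 0, 64, "cuda:0")
validate_fqn(good)  # passes silently
```

Checking everything once at the parse boundary is what lets later code treat the parsed metadata as trusted, which is the design the commit message describes.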
1 parent 54f1f28 commit 1ea293d

25 files changed

Lines changed: 6405 additions & 23 deletions

backends/aoti/aoti_backend.py

Lines changed: 39 additions & 2 deletions
@@ -9,7 +9,7 @@
 import typing
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set, Tuple

 import torch
 from executorch.backends.aoti.passes.replace_view_copy_with_view import (
@@ -129,6 +129,29 @@ def release_moved_tensors(
         """
         return

+    @classmethod
+    def pre_aoti_transform_and_collect_named_data(
+        cls,
+        device_edge_program: ExportedProgram,
+        compile_specs: List[CompileSpec],
+    ) -> List[Tuple[str, bytes, int, Optional[str]]]:
+        """Backend hook for graph mutation + extra NamedDataStore entries.
+
+        Called between ``run_decompositions`` and ``aot_compile``, so
+        overrides can:
+          1. Mutate ``device_edge_program`` in place (e.g. insert
+             custom ops that the backend's AOTI c-shim registry
+             handles), and/or
+          2. Return ``[(key, blob, alignment, external_tag), ...]``
+             entries to be added to the same ``NamedDataStore`` that
+             carries ``_so_blob`` / ``_weights_blob``.
+
+        Default: no-op, returns ``[]``. CudaBackend overrides this to
+        wire weight offloading when its private compile-spec opt-in
+        is set.
+        """
+        return []
+
     @classmethod
     @contextlib.contextmanager
     def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
@@ -182,7 +205,7 @@ def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_ker
             )

     @classmethod
-    def preprocess(
+    def preprocess(  # noqa: C901
         cls,
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
@@ -217,6 +240,15 @@ def preprocess(
             decomposition_table
         )

+        # Backend extension point: mutate ``device_edge_program`` in
+        # place (e.g. insert custom ops the AOTI shim path will pick
+        # up) and return extra NamedDataStore entries to attach
+        # alongside ``_so_blob`` / ``_weights_blob`` after compile.
+        # Default: no-op.
+        extra_named_data = cls.pre_aoti_transform_and_collect_named_data(
+            device_edge_program, compile_specs
+        )
+
         edge_program_module = device_edge_program.module()

         # Grab all input placeholders from the graph
@@ -289,6 +321,11 @@ def preprocess(
             method_name + "_weights_blob", blob_data, 1, external_tag
         )

+        for extra_key, extra_blob, extra_align, extra_tag in extra_named_data:
+            named_data_store.add_named_data(
+                extra_key, extra_blob, extra_align, extra_tag
+            )
+
         # Clean up the generated files
         os.remove(so_path)
         os.remove(blob_path)
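
The extension-point pattern in this file (a default no-op classmethod that a backend overrides, with its returned tuples appended to the named-data store after the standard blobs) can be sketched without ExecuTorch. Everything below is a stand-in under assumptions: `FakeNamedDataStore`, `BaseBackend`, `OffloadBackend`, the dict used as a "program", and the key `forward_offload_payload` are hypothetical; only the hook name and the `(key, blob, alignment, external_tag)` tuple shape come from the diff.

```python
from typing import Dict, List, Optional, Tuple

NamedDataEntry = Tuple[str, bytes, int, Optional[str]]


class FakeNamedDataStore:
    """Minimal stand-in that records what would be written."""

    def __init__(self) -> None:
        self.entries: Dict[str, Tuple[bytes, int, Optional[str]]] = {}

    def add_named_data(
        self, key: str, data: bytes, alignment: int, external_tag: Optional[str]
    ) -> None:
        self.entries[key] = (data, alignment, external_tag)


class BaseBackend:
    @classmethod
    def pre_aoti_transform_and_collect_named_data(
        cls, program: dict, compile_specs: list
    ) -> List[NamedDataEntry]:
        return []  # default: no mutation, no extra entries

    @classmethod
    def preprocess(cls, program: dict, compile_specs: list) -> FakeNamedDataStore:
        # The hook runs before "compilation" so it can mutate the program.
        extra = cls.pre_aoti_transform_and_collect_named_data(program, compile_specs)
        store = FakeNamedDataStore()
        store.add_named_data("forward_so_blob", b"so", 1, None)
        store.add_named_data("forward_weights_blob", b"weights", 1, None)
        for key, blob, alignment, tag in extra:
            store.add_named_data(key, blob, alignment, tag)
        return store


class OffloadBackend(BaseBackend):
    @classmethod
    def pre_aoti_transform_and_collect_named_data(
        cls, program: dict, compile_specs: list
    ) -> List[NamedDataEntry]:
        program["probes_inserted"] = True  # in-place mutation
        return [("forward_offload_payload", b"payload", 1, None)]


program: dict = {}
base_store = BaseBackend.preprocess({}, [])
offload_store = OffloadBackend.preprocess(program, [])
```

Backends that do not override the hook see no behavioural change, which is why the default returns an empty list rather than raising.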

backends/aoti/aoti_delegate_handle.h

Lines changed: 23 additions & 0 deletions
@@ -122,6 +122,23 @@ using AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc =
         bool use_inactive,
         bool validate_full_update);

+// Retrieves a constant's data size (storage bytes) by index. Required by
+// the weight-offload runtime to replicate AOTI's source-blob layout
+// without calling update_constants_from_blob.
+using AOTInductorModelContainerGetConstantDataSizeFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    size_t idx,
+    size_t* data_size);
+
+// Reports whether a constant is computed by AOTI's runtime const-folding.
+// Required by the weight-offload runtime to hard-fail at init if any
+// folded constant exists (run_const_fold reads other constants and would
+// see dummies in our pre-install model).
+using AOTInductorModelContainerGetConstantFromFoldedFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    size_t idx,
+    bool* from_folded);
+
 } // extern "C"

 // AOTI Delegate Handle structure
@@ -146,6 +163,12 @@ struct AOTIDelegateHandle {
   AOTInductorModelContainerExtractConstantsMapFunc extract_constants_map;
   AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc
       update_user_managed_constant_buffer_pairs;
+
+  // Weight-offload-only function pointers. Optional for handles where
+  // offload is not enabled; cuda_backend.cpp::init dlsym's them only when
+  // the weight_offload compile spec is present, and hard-fails if missing.
+  AOTInductorModelContainerGetConstantDataSizeFunc get_constant_data_size;
+  AOTInductorModelContainerGetConstantFromFoldedFunc get_constant_from_folded;
 };

 } // namespace aoti
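
How the two new accessors might feed the init-time checks from the commit message (hard-fail on folded constants, catalog <-> schedule coverage, data_size cross-check) can be sketched in Python. This is an illustration under assumptions, not the C++ in cuda_backend.cpp: `CatalogEntry` and `check_catalog` are hypothetical, and the exact comparison semantics (strict set equality, exact byte equality) are assumed.

```python
from typing import Dict, List, NamedTuple


class CatalogEntry(NamedTuple):
    """One constant as reported by the compiled model (hypothetical shape)."""

    fqn: str
    data_size: int     # what get_constant_data_size would report
    from_folded: bool  # what get_constant_from_folded would report


def check_catalog(catalog: List[CatalogEntry], payload_nbytes: Dict[str, int]) -> None:
    """Raise RuntimeError on the first violated init-time invariant."""
    # 1. Folded constants would be computed from dummies, so refuse them.
    for entry in catalog:
        if entry.from_folded:
            raise RuntimeError(f"folded constant not supported: {entry.fqn}")
    # 2. Coverage: the compiled catalog and the payload must name the same FQNs.
    catalog_fqns = {entry.fqn for entry in catalog}
    if catalog_fqns != set(payload_nbytes):
        diff = sorted(catalog_fqns ^ set(payload_nbytes))
        raise RuntimeError(f"catalog/payload mismatch: {diff}")
    # 3. Cross-source check: compiled .so sizes vs. sizes recorded in the payload.
    for entry in catalog:
        if entry.data_size != payload_nbytes[entry.fqn]:
            raise RuntimeError(f"data_size mismatch for {entry.fqn}")


catalog = [CatalogEntry("layer.weight", 64, False), CatalogEntry("layer.bias", 16, False)]
check_catalog(catalog, {"layer.weight": 64, "layer.bias": 16})  # passes silently
```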

backends/cuda/CMakeLists.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ install(
103103
)
104104

105105
# CUDA-specific AOTI shim symbols (dynamically linked)
106-
set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
107-
runtime/shims/cuda_guard.cpp
106+
set(_aoti_cuda_shim_sources
107+
runtime/shims/memory.cpp runtime/shims/cuda_guard.cpp
108+
runtime/weight_offload/probe_op.cpp
109+
runtime/weight_offload/probe_registry.cpp
108110
)
109111

110112
# Only build CUDA shims when CUDA language/toolchain is available.
@@ -179,7 +181,9 @@ install(
179181
)
180182

181183
# CUDA backend implementation
182-
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp)
184+
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp
185+
runtime/weight_offload/session.cpp
186+
)
183187

184188
# CUDA backend implementation
185189
add_library(aoti_cuda_backend STATIC ${_aoti_cuda_backend_sources})

backends/cuda/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ runtime.python_library(
77
srcs = [
88
"passes/__init__.py",
99
"passes/move_cond_predicate_to_cpu.py",
10+
"passes/weight_offload_pass.py",
1011
],
1112
visibility = [
1213
"//executorch/backends/cuda/...",

backends/cuda/cuda_backend.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
228228
"aoti_torch_cuda_randint_low_out": None,
229229
"executorch_cuda::int4_plain_mm": None,
230230
"aoti_torch_cuda_int4_plain_mm": None,
231+
# Weight-offload probe op (EXPERIMENTAL). Whitelisted so
232+
# AOTI's missing-fallback-kernel check doesn't fail when a
233+
# graph contains probes; the actual symbol resolution
234+
# happens at .so load time against ``libaoti_cuda_shims``.
235+
"executorch_weight_offload::probe": None,
236+
"aoti_torch_cuda_probe": None,
231237
}
232238

233239
@classmethod
@@ -262,6 +268,60 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
262268
passes.append(ReplaceEdgeOpWithTritonOpPass())
263269
return passes
264270

271+
@classmethod
272+
def pre_aoti_transform_and_collect_named_data(
273+
cls,
274+
device_edge_program,
275+
compile_specs: List[CompileSpec],
276+
):
277+
"""Wire weight offloading when the private opt-in compile spec
278+
is set. Otherwise no-op.
279+
280+
When opted in, the pass writes a NamedData payload describing
281+
the per-method probe schedule, floor budget, and constants
282+
catalog. ``CudaBackend::init`` reads that payload, pre-installs
283+
device dummies, skips AOTI's eager constant load, and builds a
284+
``Session`` that serves weights from a bounded GPU pool backed
285+
by a host mirror (with optional pinned-resident weights).
286+
"""
287+
# Cheap opt-in probe first — avoids importing the offload
288+
# pass machinery for the (overwhelmingly common) non-offload
289+
# export. The key string is duplicated against
290+
# ``COMPILE_SPEC_KEY_ENABLE`` on purpose so this branch has
291+
# no import dependency at all.
292+
opt_in = any(
293+
spec.key == "_weight_offload_internal_enable" and len(spec.value) > 0
294+
for spec in compile_specs
295+
)
296+
if not opt_in:
297+
return []
298+
299+
# Opt-in is set: the offload pass MUST be importable. A failing
300+
# import here means a broken build, not a "skip and hope for
301+
# the best" — the runtime would otherwise hard-fail at init
302+
# with a "payload missing" error miles away from the actual
303+
# cause.
304+
from executorch.backends.cuda.passes.weight_offload_pass import (
305+
_apply_weight_offload,
306+
_serialize_payload,
307+
named_data_key_for_method,
308+
pin_fqns_from_specs,
309+
)
310+
311+
method_name = cls.method_name_from_compile_specs(compile_specs)
312+
pin_fqns = pin_fqns_from_specs(compile_specs)
313+
payload = _apply_weight_offload(
314+
device_edge_program,
315+
method_name=method_name,
316+
pin_fqns=pin_fqns,
317+
)
318+
blob = _serialize_payload(payload)
319+
# Merge into the .pte rather than the external .ptd: the
320+
# payload is metadata (a few KB at most) and the runtime
321+
# parses it in ``init`` before any heavy data loading, so
322+
# keeping it inline avoids dragging the .ptd open early.
323+
return [(named_data_key_for_method(method_name), blob, 1, None)]
324+
265325
@classmethod
266326
def get_aoti_compile_options(
267327
cls, compile_specs: List[CompileSpec]
@@ -314,6 +374,31 @@ def get_aoti_compile_options(
314374
# int4_dispatch.py not imported — op not registered, skip C shim mapping
315375
pass
316376

377+
# Weight-offload probe op (EXPERIMENTAL). The op is defined in
378+
# ``backends/cuda/passes/weight_offload_pass.py``; import it
379+
# lazily so a build that excludes the offload pass still
380+
# compiles. ``probe`` is registered unconditionally because the
381+
# AOTI wrapper only emits the call when an FX-graph node
382+
# references it, and the runtime symbol always exists in
383+
# ``aoti_cuda_shims``.
384+
try:
385+
from executorch.backends.cuda.passes import ( # noqa: F401
386+
weight_offload_pass,
387+
)
388+
389+
shims = options.setdefault("aot_inductor.custom_ops_to_c_shims", {})
390+
shims.setdefault(
391+
torch.ops.executorch_weight_offload.probe.default,
392+
[
393+
"AOTITorchError aoti_torch_cuda_probe("
394+
"AtenTensorHandle, int64_t, AtenTensorHandle*)"
395+
],
396+
)
397+
except AttributeError:
398+
# weight_offload_pass importable but op not registered (e.g.
399+
# torch.library shim missing in the running interpreter); skip.
400+
pass
401+
317402
# Parse compile_specs to check for platform
318403

319404
platform = "linux"
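
Two idioms from this file can be tried in isolation: the import-free opt-in test over compile specs, and the `setdefault`-based shim registration that leaves existing entries alone. The sketch below makes assumptions: `CompileSpec` is a local stand-in dataclass with `key` and `value` fields, `offload_opted_in` and `register_probe_shim` are hypothetical helper names, and a plain string key replaces the `torch.ops` overload object so the example runs without torch.

```python
from dataclasses import dataclass
from typing import Dict, List

ENABLE_KEY = "_weight_offload_internal_enable"
PROBE_DECL = (
    "AOTITorchError aoti_torch_cuda_probe("
    "AtenTensorHandle, int64_t, AtenTensorHandle*)"
)


@dataclass
class CompileSpec:
    """Stand-in for the real compile spec type: a key plus a bytes value."""

    key: str
    value: bytes


def offload_opted_in(compile_specs: List[CompileSpec]) -> bool:
    # Opted in only if the key is present AND carries a non-empty value.
    return any(
        spec.key == ENABLE_KEY and len(spec.value) > 0 for spec in compile_specs
    )


def register_probe_shim(options: Dict[str, dict]) -> None:
    # setdefault creates the mapping on first use and never overwrites
    # an entry that a caller has already supplied.
    shims = options.setdefault("aot_inductor.custom_ops_to_c_shims", {})
    shims.setdefault("probe", [PROBE_DECL])  # string key is a placeholder


specs_on = [CompileSpec("method_name", b"forward"), CompileSpec(ENABLE_KEY, b"1")]
specs_empty_value = [CompileSpec(ENABLE_KEY, b"")]

fresh_options: Dict[str, dict] = {}
register_probe_shim(fresh_options)

preset_options = {"aot_inductor.custom_ops_to_c_shims": {"probe": ["custom decl"]}}
register_probe_shim(preset_options)
```

A spec whose key matches but whose value is empty counts as not opted in, matching the `len(spec.value) > 0` guard in the diff.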
