You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Lets a CUDA-backend ExecuTorch program stream model weights from a
pinned host mirror through a bounded GPU pool on demand, instead of
loading the full constant set into VRAM at program load. This lets a
process run a model whose total weight bytes exceed available GPU
memory, or share GPU across methods whose union exceeds memory.
Public surface (CudaPartitioner kwargs):
weight_offload: bool = False
Opt the method into offload. The runtime skips AOTI's eager
``update_constants_from_blob``, installs 1-byte GPU dummies in
place of constants via
``update_user_managed_constant_buffer_pairs``, and routes every
constant read through a probe op into Session::serve.
weight_offload_pin_fqns: Optional[List[str]] = None
FQNs to keep resident on GPU for the Session lifetime (out-of-
pool ``cudaMalloc`` + one-time H2D + sync at create). Requires
weight_offload=True. Currently device-0-only.
Runtime spec (LoadBackendOptionsMap, or executor_runner's
``--cuda_runtime_spec=weight_offload_budget_mb=N``):
Total GPU bytes the offload pool + pinned constants may use.
Defaults to floor_bytes + pinned_bytes_total when unset.
Architecture:
Export-time pass (backends/cuda/passes/weight_offload_pass.py):
Rewrites every parameter / buffer consumer to read through
``executorch_weight_offload::probe(w, probe_id)`` and serializes
a v2 payload (schedule + floor + pin_fqns + per-FQN dtype /
sizes / strides / storage_offset / nbytes / device) into
NamedDataStore. ``probe_id`` is dense 0..N-1 in graph order so
the runtime can index ``schedule[probe_id]`` directly.
Payload parser (backends/cuda/runtime/weight_offload/payload.h):
Single trust boundary. Validates schema version, framing bounds,
per-FQN invariants (dtype allow-list, contiguous + offset-zero
layout, nbytes == elementSize(dtype) * product(sizes), device
== cuda:0), schedule <-> metadata set equality, and pin_fqns
dedup + subset-of-schedule. Downstream code trusts the parsed
Payload struct.
AOTI c-shim (backends/cuda/runtime/weight_offload/probe_op.cpp):
Looks the dummy data_ptr up in ProbeRegistry and forwards to
Session::serve. Lookup miss with the registry empty falls back
to identity-passthrough (preserves non-offload paths); lookup
miss with the registry non-empty hard-fails (otherwise the
runtime would silently read AOTI's eager constant).
Session (backends/cuda/runtime/weight_offload/session.{h,cpp}):
Owns the offload lifecycle. Builds the pinned host mirror by
indexing into the ``_weights_blob`` NamedData entry, allocates a
bounded ``cudaMemPool``, implements LRU eviction with
event-ordered ``cudaFreeAsync`` on the compute stream, depth-1
opportunistic prefetch on the copy stream, optional pinned-
resident constants. ``serve()``'s miss path and
``opportunistic_prefetch()`` share one ``make_room_and_alloc``
helper.
ProbeRegistry (backends/cuda/runtime/weight_offload/probe_registry.{h,cpp}):
Process-global ``dummy_ptr -> (callback, context)`` table in the
shim library. ``lookup`` + ``has_any_context`` take a shared
lock so concurrent probe dispatch across Sessions does not
serialise.
CudaBackend::init (backends/cuda/runtime/cuda_backend.cpp):
Thin shim for offload methods: parse payload, fail-fast on
cuda_graph / shared_stream / non-device-0, ``cudaSetDevice(0)``,
walk the AOTI catalog, coverage check (catalog <-> schedule),
AOTI <-> payload data_size cross-check (the one genuinely
cross-source check; AOTI's compiled .so vs the .pte payload),
fetch ``_weights_blob``, call Session::create.
Tests:
backends/cuda/tests/test_weight_offload_probe_dispatch.py
backends/cuda/tests/test_weight_offload_pool.py
backends/cuda/tests/test_weight_offload_session.py
backends/cuda/tests/test_weight_offload_transport.py
backends/cuda/tests/test_cuda_partitioner.py
backends/cuda/passes/tests/test_weight_offload_pass.py
61 passed, 1 skipped on the offload surface.
Not yet wired (deferred): multi-device offload (the payload + runtime
hard-code device 0 today; the partitioner rejects target_device !=
cuda:0 when weight_offload=True). Lifting this requires teaching AOTI's
``create_with_device(..., "cuda", nullptr)`` to take a per-method
device index and propagating it through the payload, runtime, and
partitioner kwargs.
0 commit comments