Commit 4a05736

Offload cleanup (round 3): serve/prefetch dedup, ConstantInfo move, docstring trim
Three follow-ups from the post-round-2 review.

Extract Session::make_room_and_alloc as the shared eviction + alloc + H2D + ready-event path used by both serve()'s miss case and opportunistic_prefetch(). The two were ~80% copy-paste of the same LRU eviction loop, the same post-batch event-record/wait dance, the same alloc + copy + ready-event sequence, and the same on-error cleanup. They differed only in:

* whether to guard against evicting a specific FQN (prefetch: yes for the just-served FQN; serve miss: no);
* whether to cudaStreamWaitEvent(compute_, ready) before returning (serve: yes; prefetch: no -- next serve() does the wait);
* the log tag ("[ERROR]" vs "[WARN] prefetch");
* which success counter to bump (caller-side).

The helper takes a guard_fqn pointer and a log_tag string; callers own the per-caller bits (final cudaStreamWaitEvent, bytes accounting, live_.emplace, success-counter bump). Net -89 lines on session.cpp + session.h and one eviction codepath instead of two.

Move ConstantInfo into session.h and delete constant_catalog.h entirely. The file had been reduced to ConstantInfo + a stale banner after build_constant_catalog() came out in round 2. Drop the include from cuda_backend.cpp, drop the header from the runtime TARGETS block. The struct's primary consumer is Session, which is where it now lives.

Trim _apply_weight_offload docstring 146 -> 24 lines. The old prose restated what the schedule/probe_id/floor/pin contracts mean -- all of which are already documented at the call sites (the module docstring, _compute_floor_bytes, PAYLOAD_KEY_* constants). Kept the caller contract (internal; method_name required to prevent silent multi-method collision), the mutation summary, the return shape, and a re-entry constraint reminder.

61 of 61 offload tests pass; lint clean.

Authored with Claude.
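The shared-helper shape described in the commit message can be modeled in a few lines. This is a minimal Python sketch of the control flow only, not the C++ in session.cpp: `PoolModel`, its fields, and its method signatures are hypothetical, and the CUDA alloc, H2D copy, and event handling are left out.

```python
from collections import OrderedDict
from typing import Optional


class PoolModel:
    """Toy byte-budgeted LRU pool (hypothetical; stands in for Session)."""

    def __init__(self, budget_bytes: int) -> None:
        self.budget_bytes = budget_bytes
        self.used_bytes = 0
        # FQN -> nbytes, ordered oldest insertion first.
        self.live: "OrderedDict[str, int]" = OrderedDict()

    def make_room(self, nbytes: int, guard_fqn: Optional[str], log_tag: str) -> bool:
        """Shared path: evict oldest entries, skipping guard_fqn, until nbytes fits."""
        victims = [fqn for fqn in self.live if fqn != guard_fqn]
        while self.used_bytes + nbytes > self.budget_bytes and victims:
            self.used_bytes -= self.live.pop(victims.pop(0))
        if self.used_bytes + nbytes > self.budget_bytes:
            # The log tag is the only caller-specific detail inside the helper.
            print(f"{log_tag} no room for {nbytes} bytes")
            return False
        return True

    def serve_miss(self, fqn: str, nbytes: int) -> bool:
        # A serve miss may evict anything, so it passes no guard.
        if not self.make_room(nbytes, None, "[ERROR]"):
            return False
        # Caller-owned bookkeeping stays outside the helper.
        self.live[fqn] = nbytes
        self.used_bytes += nbytes
        return True

    def prefetch(self, fqn: str, nbytes: int, just_served: str) -> bool:
        # Prefetch must not evict the weight that was just served.
        if not self.make_room(nbytes, just_served, "[WARN] prefetch"):
            return False
        self.live[fqn] = nbytes
        self.used_bytes += nbytes
        return True
```

The point of the sketch is the split: the guard and the log tag are parameters of the one eviction loop, while accounting and insertion stay with each caller.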
1 parent 31e6036 commit 4a05736

7 files changed

Lines changed: 282 additions & 528 deletions

backends/cuda/passes/weight_offload_pass.py

Lines changed: 29 additions & 138 deletions
@@ -672,144 +672,35 @@ def _apply_weight_offload(
 ) -> dict:
     """In-place graph rewrite + offload payload computation.

-    INTERNAL — leading underscore is the Python signal. The only
-    supported caller is
-    ``CudaBackend.pre_aoti_transform_and_collect_named_data`` (see
-    ``backends/cuda/cuda_backend.py``), which gates on the private
-    ``_weight_offload_internal_enable`` compile spec, sources
-    ``method_name`` from the AOTI method-name spec, and routes the
-    returned payload into ``NamedDataStore`` for ``CudaBackend::init``
-    to parse. ``method_name`` is REQUIRED (no default) precisely so a
-    direct caller importing this function for a multi-method model
-    cannot silently collide all methods on ``"forward"``.
-
-    Inserts ``probe(w, probe_id)`` in front of every consumer of every
-    parameter (or buffer) placeholder, rewriting the consumer's arg to
-    read the probe's output. One probe call per ``(consumer, weight)``
-    pair — not per weight — so the runtime can re-load a weight that
-    was evicted between two uses inside the same forward pass.
-    ``probe_id`` is assigned contiguously (0..N-1) in graph order; the
-    runtime keys the schedule lookup off ``probe_id`` directly, with
-    no cursor.
-
-    Pinned FQNs are included in the schedule like any other FQN — every
-    consumer of a pinned weight gets a probe with its own ``probe_id``,
-    and ``schedule[probe_id]`` returns the pinned FQN. The pin set
-    ships separately (``pin_fqns``) so the runtime can choose the
-    resident fast-path over the streaming path inside ``serve``. The
-    pin set is, however, EXCLUDED from the floor calculation (pinned
-    weights don't compete for the streaming pool). See
-    ``Session::Config::pin_fqns`` and ``Session::register_schedule`` in
-    ``backends/cuda/runtime/weight_offload/weight_offload.h``.
-
-    The pass is the single authoritative source for the pin set. The
-    runtime has NO pin-set option; it parses the list out of the
-    partition payload and passes it to the Session unchanged. Pin set
-    affects floor correctness (the floor is computed assuming pinned
-    FQNs do not stream), so the runtime cannot override it.
-
-    AOTI constant-folding contract:
-        The pass operates on parameter placeholders in the
-        ExportedProgram. AOTI knobs that fold parameters out of the
-        container at compile time (so they no longer appear in
-        ``get_constant_name(idx)``) break offload: the pass inserts
-        probes against the placeholders it sees pre-AOTI, but the
-        folded constants would be loaded eagerly through the normal
-        blob path at runtime — silently defeating offload and
-        reintroducing the OOM this feature exists to prevent. The pass
-        itself cannot observe AOTI folding (it runs before AOTI
-        compile, where all state_dict entries are still placeholders),
-        so the catch lives in the runtime: ``CudaBackend::init`` walks
-        ``get_constant_from_folded(i)`` for every catalog entry and
-        hard-fails on the first folded one (with dummies pre-installed,
-        ``run_const_fold`` would read other constants as garbage). The
-        set-equality coverage check
-        ``non_folded_catalog == unique(schedule)`` is the second half
-        of the defense. Exports that enable weight offload must also
-        configure ``torch._inductor.config.aot_inductor.use_runtime_constant_folding
-        = False``; the partitioner-side opt-in (future work) is the
-        right place to verify that.
-
-    Metadata transport:
-        The returned ``dict`` is the offload payload that
-        ``CudaBackend.preprocess`` serializes (via
-        ``_serialize_payload``) into the AOTI ``NamedDataStore`` under
-        the per-method ``_weight_offload_payload`` key (see
-        ``named_data_key_for_method``). ``cuda_backend.cpp::init``
-        retrieves and parses it.
-
-    Args:
-        exported_program: an ``ExportedProgram`` produced by
-            ``torch.export.export``. Mutated in place: probe nodes are
-            inserted, consumer args are rewritten.
-        method_name: the method this pass is being applied to. Returned
-            verbatim in the payload so the runtime can validate which
-            method the bytes belong to.
-        pin_fqns: FQNs to mark as always-resident. Optional. The list
-            is propagated verbatim into the payload AND used by this pass
-            to exclude those FQNs from the floor calculation (pinned
-            weights don't compete for the streaming pool). Pinned FQNs
-            DO appear in the schedule like any other FQN — keeping
-            ``probe_id`` dense is what lets the runtime do a single
-            ``schedule[probe_id]`` lookup per probe. Pinning an FQN that
-            does not appear as a parameter placeholder is a hard error.
-
-    Returns: a ``dict`` with the keys defined at module scope (all
-    internal payload, not opt-in signals):
-
-    - ``"version"``: ``int`` schema version (currently ``2``). Runtime
-      hard-fails on any version other than 2 (v1 is rejected with a
-      "rebuild required" message).
-    - ``"method_name"``: ``str``. Echoed for runtime validation.
-    - ``"schedule"``: ``list[str]`` of length N indexed by
-      ``probe_id`` — ``schedule[probe_id]`` is the FQN of the weight
-      that probe site reads. Pinned and non-pinned FQNs BOTH appear
-      here (every probe site contributes one entry regardless of pin
-      status); the runtime checks ``pin_fqns`` inside ``serve`` to
-      choose the resident fast-path over the streaming path. Keeping
-      ``probe_id`` dense and contiguous is what lets the runtime do
-      a single ``schedule[probe_id]`` lookup per probe.
-    - ``"floor_bytes"``: ``int`` — conservative FX fusion upper
-      bound on the streaming pool. NOT a tight kernel-level
-      estimate; that needs post-AOTI kernel grouping that a
-      future commit will land. Computed as ``max over consecutive
-      FX candidate pairs of (sum bytes of the UNION of non-pinned
-      working sets at each side) + max single non-pinned
-      weight``. FX candidates are non-view non-probe
-      ``call_function`` nodes plus the output sink (so Inductor
-      fusing independent final consumers into one multi-output
-      kernel still factors in). Each candidate's working set is
-      built by propagating probe FQNs forward through every
-      fusion-eligible edge (see ``_fusion_dependency_sets``).
-      Defaulting to "fusible" overestimates the floor — safe.
-      Claiming barrier where none exists underestimates it —
-      corruption. The runtime asserts
-      ``(weight_offload_budget_mb << 20) - pinned_bytes`` covers
-      this; below-floor budgets hard-fail at init with the
-      required minimum spelled out.
-    - ``"pin_fqns"``: ``list[str]`` of FQNs the runtime keeps
-      resident. Side set over the FQNs that appear in
-      ``schedule``; an FQN in ``pin_fqns`` must also appear in
-      ``schedule`` (at every site where it is read). Empty if
-      ``pin_fqns`` is unset.
-
-    Per-FQN AOTI constant metadata (dtype / sizes / strides /
-    storage_offset / nbytes / device_type / device_index) arrives in
-    the v2 payload via the ``constants_metadata`` block (one entry
-    per ``unique(schedule)`` FQN). The runtime cross-checks each
-    payload entry against AOTI's own ``get_constant_data_size`` and
-    drives both the source-blob copy length and the SlimTensor
-    metadata Session uses for borrowed wraps.
-
-    The opt-in signal is the private compile spec
-    ``_weight_offload_internal_enable`` (see ``COMPILE_SPEC_KEY_ENABLE``);
-    pin FQNs come in via ``_weight_offload_internal_pin_fqns``
-    (NUL-separated UTF-8). The enable signal lives in exactly one
-    place - the compile spec - rather than being duplicated across
-    compile spec + payload. End users opt in through the public
-    ``CudaPartitioner(weight_offload=True,
-    weight_offload_pin_fqns=[...])`` kwargs, which translate to
-    these internal specs.
+    Internal: the only supported caller is
+    ``CudaBackend.pre_aoti_transform_and_collect_named_data``. The
+    ``method_name`` arg is required (no default) so a direct caller
+    on a multi-method model cannot silently collide all methods on
+    ``"forward"``.
+
+    Mutates ``exported_program`` in place: inserts ``probe(w, probe_id)``
+    in front of every consumer of every parameter / buffer placeholder
+    and rewrites the consumer's arg to read the probe's output. One
+    probe per ``(consumer, weight)`` pair so the runtime can re-load
+    a weight evicted between two uses in the same forward pass.
+    ``probe_id`` is dense 0..N-1 in graph order.
+
+    Returns the offload payload dict (see ``PAYLOAD_KEY_*`` at module
+    scope for the schema):
+    * ``schedule[probe_id]`` is the FQN that probe site reads;
+      pinned FQNs appear here too (the runtime picks the resident
+      path inside ``serve``).
+    * ``floor_bytes`` is a conservative FX-fusion-aware upper bound
+      on the streaming pool, excluding pinned FQNs. The runtime
+      hard-fails if ``budget - pinned < floor``.
+    * ``pin_fqns`` is the resident set, deduped first-seen-order.
+    * ``constants_metadata`` carries per-FQN dtype / sizes / strides
+      / storage_offset / nbytes / device for runtime cross-check.
+
+    Re-entry: this pass MUST run before AOTI compile (it operates on
+    placeholders) and MUST NOT be re-run on a graph that already
+    contains probe nodes -- the second pass would insert probes on
+    the probes' outputs.
     """
     # Canonicalize pin_fqns: dedupe while preserving first-seen order
     # so the payload is stable. The runtime hard-fails on duplicates
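Two contracts surface in this hunk: `pin_fqns` is deduped in first-seen order, and the runtime hard-fails if `budget - pinned < floor`. A minimal Python sketch of both follows. The helper names `canonicalize_pin_fqns` and `streaming_budget_ok` are hypothetical and are not taken from the pass or the runtime; the real runtime check lives in C++ and is not shown in this commit view.

```python
from typing import Iterable, List


def canonicalize_pin_fqns(pin_fqns: Iterable[str]) -> List[str]:
    """Drop duplicate FQNs, keeping the first occurrence of each."""
    # dict preserves insertion order (Python 3.7+), so fromkeys dedupes stably.
    return list(dict.fromkeys(pin_fqns))


def streaming_budget_ok(budget_bytes: int, pinned_bytes: int, floor_bytes: int) -> bool:
    """True when the budget left after pinned weights still covers the floor."""
    # Mirrors the docstring's failure condition: budget - pinned < floor.
    return budget_bytes - pinned_bytes >= floor_bytes
```

Stable ordering matters here because the deduped list is written into the payload, so the same input should always serialize to the same bytes.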

backends/cuda/runtime/TARGETS

Lines changed: 0 additions & 1 deletion
@@ -86,7 +86,6 @@ runtime.cxx_library(
     ],
     headers = [
         "cuda_delegate_handle.h",
-        "weight_offload/constant_catalog.h",
         "weight_offload/payload.h",
         "weight_offload/session.h",
     ],

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 0 additions & 1 deletion
@@ -49,7 +49,6 @@
 #include <executorch/backends/cuda/runtime/platform/platform.h>
 #include <executorch/backends/cuda/runtime/shims/memory.h>
 #include <executorch/backends/cuda/runtime/utils.h>
-#include <executorch/backends/cuda/runtime/weight_offload/constant_catalog.h>
 #include <executorch/backends/cuda/runtime/weight_offload/payload.h>
 #include <executorch/backends/cuda/runtime/weight_offload/session.h>

backends/cuda/runtime/weight_offload/constant_catalog.h

Lines changed: 0 additions & 54 deletions
This file was deleted.
