@@ -672,144 +672,35 @@ def _apply_weight_offload(
672672) -> dict :
673673 """In-place graph rewrite + offload payload computation.
674674
675- INTERNAL — leading underscore is the Python signal. The only
676- supported caller is
677- ``CudaBackend.pre_aoti_transform_and_collect_named_data`` (see
678- ``backends/cuda/cuda_backend.py``), which gates on the private
679- ``_weight_offload_internal_enable`` compile spec, sources
680- ``method_name`` from the AOTI method-name spec, and routes the
681- returned payload into ``NamedDataStore`` for ``CudaBackend::init``
682- to parse. ``method_name`` is REQUIRED (no default) precisely so a
683- direct caller importing this function for a multi-method model
684- cannot silently collide all methods on ``"forward"``.
685-
686- Inserts ``probe(w, probe_id)`` in front of every consumer of every
687- parameter (or buffer) placeholder, rewriting the consumer's arg to
688- read the probe's output. One probe call per ``(consumer, weight)``
689- pair — not per weight — so the runtime can re-load a weight that
690- was evicted between two uses inside the same forward pass.
691- ``probe_id`` is assigned contiguously (0..N-1) in graph order; the
692- runtime keys the schedule lookup off ``probe_id`` directly, with
693- no cursor.
694-
695- Pinned FQNs are included in the schedule like any other FQN — every
696- consumer of a pinned weight gets a probe with its own ``probe_id``,
697- and ``schedule[probe_id]`` returns the pinned FQN. The pin set
698- ships separately (``pin_fqns``) so the runtime can choose the
699- resident fast-path over the streaming path inside ``serve``. The
700- pin set is, however, EXCLUDED from the floor calculation (pinned
701- weights don't compete for the streaming pool). See
702- ``Session::Config::pin_fqns`` and ``Session::register_schedule`` in
703- ``backends/cuda/runtime/weight_offload/weight_offload.h``.
704-
705- The pass is the single authoritative source for the pin set. The
706- runtime has NO pin-set option; it parses the list out of the
707- partition payload and passes it to the Session unchanged. Pin set
708- affects floor correctness (the floor is computed assuming pinned
709- FQNs do not stream), so the runtime cannot override it.
710-
711- AOTI constant-folding contract:
712- The pass operates on parameter placeholders in the
713- ExportedProgram. AOTI knobs that fold parameters out of the
714- container at compile time (so they no longer appear in
715- ``get_constant_name(idx)``) break offload: the pass inserts
716- probes against the placeholders it sees pre-AOTI, but the
717- folded constants would be loaded eagerly through the normal
718- blob path at runtime — silently defeating offload and
719- reintroducing the OOM this feature exists to prevent. The pass
720- itself cannot observe AOTI folding (it runs before AOTI
721- compile, where all state_dict entries are still placeholders),
722- so the catch lives in the runtime: ``CudaBackend::init`` walks
723- ``get_constant_from_folded(i)`` for every catalog entry and
724- hard-fails on the first folded one (with dummies pre-installed,
725- ``run_const_fold`` would read other constants as garbage). The
726- set-equality coverage check
727- ``non_folded_catalog == unique(schedule)`` is the second half
728- of the defense. Exports that enable weight offload must also
729- configure ``torch._inductor.config.aot_inductor.use_runtime_constant_folding
730- = False``; the partitioner-side opt-in (future work) is the
731- right place to verify that.
732-
733- Metadata transport:
734- The returned ``dict`` is the offload payload that
735- ``CudaBackend.preprocess`` serializes (via
736- ``_serialize_payload``) into the AOTI ``NamedDataStore`` under
737- the per-method ``_weight_offload_payload`` key (see
738- ``named_data_key_for_method``). ``cuda_backend.cpp::init``
739- retrieves and parses it.
740-
741- Args:
742- exported_program: an ``ExportedProgram`` produced by
743- ``torch.export.export``. Mutated in place: probe nodes are
744- inserted, consumer args are rewritten.
745- method_name: the method this pass is being applied to. Returned
746- verbatim in the payload so the runtime can validate which
747- method the bytes belong to.
748- pin_fqns: FQNs to mark as always-resident. Optional. The list
749- is propagated verbatim into the payload AND used by this pass
750- to exclude those FQNs from the floor calculation (pinned
751- weights don't compete for the streaming pool). Pinned FQNs
752- DO appear in the schedule like any other FQN — keeping
753- ``probe_id`` dense is what lets the runtime do a single
754- ``schedule[probe_id]`` lookup per probe. Pinning an FQN that
755- does not appear as a parameter placeholder is a hard error.
756-
757- Returns: a ``dict`` with the keys defined at module scope (all
758- internal payload, not opt-in signals):
759-
760- - ``"version"``: ``int`` schema version (currently ``2``). Runtime
761- hard-fails on any version other than 2 (v1 is rejected with a
762- "rebuild required" message).
763- - ``"method_name"``: ``str``. Echoed for runtime validation.
764- - ``"schedule"``: ``list[str]`` of length N indexed by
765- ``probe_id`` — ``schedule[probe_id]`` is the FQN of the weight
766- that probe site reads. Pinned and non-pinned FQNs BOTH appear
767- here (every probe site contributes one entry regardless of pin
768- status); the runtime checks ``pin_fqns`` inside ``serve`` to
769- choose the resident fast-path over the streaming path. Keeping
770- ``probe_id`` dense and contiguous is what lets the runtime do
771- a single ``schedule[probe_id]`` lookup per probe.
772- - ``"floor_bytes"``: ``int`` — conservative FX fusion upper
773- bound on the streaming pool. NOT a tight kernel-level
774- estimate; that needs post-AOTI kernel grouping that a
775- future commit will land. Computed as ``max over consecutive
776- FX candidate pairs of (sum bytes of the UNION of non-pinned
777- working sets at each side) + max single non-pinned
778- weight``. FX candidates are non-view non-probe
779- ``call_function`` nodes plus the output sink (so Inductor
780- fusing independent final consumers into one multi-output
781- kernel still factors in). Each candidate's working set is
782- built by propagating probe FQNs forward through every
783- fusion-eligible edge (see ``_fusion_dependency_sets``).
784- Defaulting to "fusible" overestimates the floor — safe.
785- Claiming barrier where none exists underestimates it —
786- corruption. The runtime asserts
787- ``(weight_offload_budget_mb << 20) - pinned_bytes`` covers
788- this; below-floor budgets hard-fail at init with the
789- required minimum spelled out.
790- - ``"pin_fqns"``: ``list[str]`` of FQNs the runtime keeps
791- resident. Side set over the FQNs that appear in
792- ``schedule``; an FQN in ``pin_fqns`` must also appear in
793- ``schedule`` (at every site where it is read). Empty if
794- ``pin_fqns`` is unset.
795-
796- Per-FQN AOTI constant metadata (dtype / sizes / strides /
797- storage_offset / nbytes / device_type / device_index) arrives in
798- the v2 payload via the ``constants_metadata`` block (one entry
799- per ``unique(schedule)`` FQN). The runtime cross-checks each
800- payload entry against AOTI's own ``get_constant_data_size`` and
801- drives both the source-blob copy length and the SlimTensor
802- metadata Session uses for borrowed wraps.
803-
804- The opt-in signal is the private compile spec
805- ``_weight_offload_internal_enable`` (see ``COMPILE_SPEC_KEY_ENABLE``);
806- pin FQNs come in via ``_weight_offload_internal_pin_fqns``
807- (NUL-separated UTF-8). The enable signal lives in exactly one
808- place - the compile spec - rather than being duplicated across
809- compile spec + payload. End users opt in through the public
810- ``CudaPartitioner(weight_offload=True,
811- weight_offload_pin_fqns=[...])`` kwargs, which translate to
812- these internal specs.
675+ Internal: the only supported caller is
676+ ``CudaBackend.pre_aoti_transform_and_collect_named_data``. The
677+ ``method_name`` arg is required (no default) so a direct caller
678+ on a multi-method model cannot silently collide all methods on
679+ ``"forward"``.
680+
681+ Mutates ``exported_program`` in place: inserts ``probe(w, probe_id)``
682+ in front of every consumer of every parameter / buffer placeholder
683+ and rewrites the consumer's arg to read the probe's output. One
684+ probe per ``(consumer, weight)`` pair so the runtime can re-load
685+ a weight evicted between two uses in the same forward pass.
686+ ``probe_id`` is dense 0..N-1 in graph order.
687+
688+ Returns the offload payload dict (see ``PAYLOAD_KEY_*`` at module
689+ scope for the schema):
690+ * ``schedule[probe_id]`` is the FQN that probe site reads;
691+ pinned FQNs appear here too (the runtime picks the resident
692+ path inside ``serve``).
693+ * ``floor_bytes`` is a conservative FX-fusion-aware upper bound
694+ on the streaming pool, excluding pinned FQNs. The runtime
695+ hard-fails if ``budget - pinned < floor``.
696+ * ``pin_fqns`` is the resident set, deduped first-seen-order.
697+ * ``constants_metadata`` carries per-FQN dtype / sizes / strides
698+ / storage_offset / nbytes / device for runtime cross-check.
699+
700+ Re-entry: this pass MUST run before AOTI compile (it operates on
701+ placeholders) and MUST NOT be re-run on a graph that already
702+ contains probe nodes -- the second pass would insert probes on
703+ the probes' outputs.
813704 """
814705 # Canonicalize pin_fqns: dedupe while preserving first-seen order
815706 # so the payload is stable. The runtime hard-fails on duplicates
0 commit comments