|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +"""Weight offloading pass for the CUDA backend. |
| 8 | +
|
| 9 | +EXPERIMENTAL -- NOT YET WIRED. |
| 10 | +
|
| 11 | +This module is a DESIGN DOCUMENT. ``_apply_weight_offload`` has no |
| 12 | +implementation body (its body is ``...``); the ``executorch_weight_offload::probe`` |
| 13 | +custom op is registered with torch but is not preserved through inductor |
| 14 | +lowering (see the "Probe op preservation" open item below); and |
| 15 | +``CudaPartitioner`` does NOT yet expose a ``weight_offload`` kwarg that |
| 16 | +would invoke this pass. Nothing in this PR calls the function or the |
| 17 | +custom op. |
| 18 | +
|
| 19 | +Open item -- probe op preservation: |
| 20 | + The pass's design intent is to insert one |
| 21 | + ``executorch_weight_offload::probe(w)`` call before every consumer of |
| 22 | + every parameter placeholder, so the runtime can intercept each |
| 23 | + parameter read via an AOTI c-shim. The current registration |
| 24 | + (``custom_op(..., mutates_args=())`` plus an identity fake) is NOT |
| 25 | + sufficient to keep the op alive through inductor's CSE + fusion |
| 26 | + passes -- a side-effect-free identity op is a textbook DCE target, |
| 27 | + and there is no test in this PR that asserts the lowered AOTI |
| 28 | + wrapper actually emits an ``aoti_torch_cuda_probe`` call per |
| 29 | + ``(consumer, weight)`` pair. The implementation PR must: |
| 30 | + 1. Give probe explicit non-elidable semantics that survive |
| 31 | + inductor lowering AND don't clash with torch.export's |
| 32 | + parameter-output validation (parameters are read-only by |
| 33 | + convention, so ``mutates_args={"w"}`` won't work directly). |
| 34 | + 2. Land a regression test that lowers a tiny two-linear model |
| 35 | + and asserts the wrapper.cpp contains the expected |
| 36 | + ``aoti_torch_cuda_probe`` calls in the expected places. |
| 37 | + Without both, weight offload silently degrades to eager loading at |
| 38 | + runtime even when "enabled" -- which is exactly the surprise this |
| 39 | + feature exists to prevent. |
| 40 | +
|
| 41 | +Open item -- payload transport: |
| 42 | + The pass needs to run from a custom pass inside |
| 43 | + ``AotiBackend.preprocess`` -- the ``ExportedProgram`` partitioner |
| 44 | + contract forbids mutating the program from ``partition()`` (see |
| 45 | + ``exir/backend/partitioner.py:83``), so neither ``CudaPartitioner`` |
| 46 | + nor any other partitioner can run the rewrite directly. ``preprocess`` |
| 47 | + has two candidate channels for the resulting payload (schedule, |
| 48 | + floor, pin_fqns, version), both accessible from there and neither |
| 49 | + yet wired: |
| 50 | + (a) Serialize into the partition's ``processed_bytes`` (currently |
| 51 | + ``b""`` for AOTI; see ``backends/aoti/aoti_backend.py``). |
| 52 | + (b) Attach a per-method ``NamedDataStore`` entry (where |
| 53 | + ``AotiBackend`` already writes ``_so_blob`` and |
| 54 | + ``_weights_blob``). |
| 55 | + Pick one when wiring. The payload-key constants below are channel- |
| 56 | + agnostic. |
| 57 | +
|
| 58 | +Open item -- schedule / cursor order: |
| 59 | + The runtime cursor in ``Session::serve`` hard-fails on a mismatch |
| 60 | + against the recorded schedule (see ``register_schedule`` in |
| 61 | + ``weight_offload.h``). That contract requires the execution-order |
| 62 | + FQN list this pass records to match the order the lowered AOTI |
| 63 | + wrapper actually invokes ``aoti_torch_cuda_probe`` in. The pass |
| 64 | + observes the graph BEFORE inductor's custom passes / decompositions |
| 65 | + / lowering run (``backends/aoti/aoti_backend.py:206``), all of which |
| 66 | + can reorder or duplicate parameter reads. Two options for the |
| 67 | + implementation PR: (a) regenerate the schedule from the post-lowering |
| 68 | + wrapper order; or (b) extend probe with an explicit |
| 69 | + ``probe_id`` / FQN argument so each call self-identifies and the |
| 70 | + runtime needs no cursor at all. (b) is the more robust choice -- |
| 71 | + it removes the entire class of "graph order drifted from wrapper |
| 72 | + order" silent failures, at the cost of a wider op signature. |
| 73 | +
|
| 74 | +The runtime half lives in ``backends/cuda/runtime/weight_offload/``, |
| 75 | +which is also marked EXPERIMENTAL -- NOT YET WIRED. |
| 76 | +""" |
| 77 | + |
| 78 | +import torch |
| 79 | +from torch.library import custom_op, register_fake |
| 80 | + |
| 81 | + |
| 82 | +_OP_NAMESPACE = "executorch_weight_offload" |
| 83 | +_OP_QUALNAME = f"{_OP_NAMESPACE}::probe" |
| 84 | + |
| 85 | + |
| 86 | +@custom_op(_OP_QUALNAME, mutates_args=()) |
| 87 | +def probe(w: torch.Tensor) -> torch.Tensor: |
| 88 | + """Identity passthrough in eager. CUDA backend replaces via c-shim at AOTI compile time. |
| 89 | +
|
| 90 | + Inserted by ``apply_weight_offload`` before every consumer of every |
| 91 | + parameter (or buffer) placeholder. The CUDA runtime's c-shim |
| 92 | + (``aoti_torch_cuda_probe``) intercepts each call at runtime and serves |
| 93 | + bytes through the bounded GPU pool. |
| 94 | +
|
| 95 | + Signature is deliberately minimal — no FQN or schedule-index argument. |
| 96 | + The runtime resolves which weight is being probed by looking up the |
| 97 | + input tensor's ``data_ptr()`` in the ``ProbeRegistry`` populated by |
| 98 | + ``Session::bind_placeholder_constants`` at backend init. |
| 99 | +
|
| 100 | + Notes: |
| 101 | + - The current ``mutates_args=()`` is insufficient: an identity op |
| 102 | + with no side effect is a textbook DCE target for inductor. |
| 103 | + ``mutates_args={"w"}`` clashes with torch.export's |
| 104 | + parameter-output validation (parameters are read-only by |
| 105 | + convention). The implementation PR must find a third option; |
| 106 | + see the "Probe op preservation" open item in the module |
| 107 | + docstring above. |
| 108 | + - Weight offloading is mutually exclusive with the CUDA backend's |
| 109 | + ``enable_cuda_graph_for_method`` option: CUDA-graph Replay bypasses |
| 110 | + AOTI's ``run()``, so probe ops never fire. The runtime hard-fails |
| 111 | + at ``init`` if both are set for the same method. |
| 112 | + """ |
| 113 | + return w |
| 114 | + |
| 115 | + |
| 116 | +@register_fake(_OP_QUALNAME) |
| 117 | +def _probe_fake(w: torch.Tensor) -> torch.Tensor: |
| 118 | + # Fresh fake tensor so inductor doesn't decide to inline the op away. |
| 119 | + return torch.empty_like(w) |
| 120 | + |
| 121 | + |
| 122 | +PROBE_OP_TARGET = torch.ops.executorch_weight_offload.probe.default |
| 123 | + |
| 124 | + |
| 125 | +# Payload field names. INTERNAL design intent for the partition-payload |
| 126 | +# (or NamedDataStore -- see the "payload transport" open item in the |
| 127 | +# module docstring) that ``CudaBackend.preprocess`` would write and |
| 128 | +# ``cuda_backend.cpp::init`` would parse once wired. Names are |
| 129 | +# namespaced by method so prefill and decode each get their own payload |
| 130 | +# in the same .pte. |
| 131 | +PAYLOAD_KEY_VERSION = "version" |
| 132 | +PAYLOAD_KEY_METHOD_NAME = "method_name" |
| 133 | +PAYLOAD_KEY_SCHEDULE = "schedule" |
| 134 | +PAYLOAD_KEY_FLOOR = "floor_bytes" |
| 135 | +PAYLOAD_KEY_PIN_FQNS = "pin_fqns" |
| 136 | + |
| 137 | +# Schema version for the emitted offload payload. Bumped whenever the |
| 138 | +# shape of any field above changes (e.g. switching the floor from |
| 139 | +# uint64 bytes to a struct with prefetch headroom, or switching the |
| 140 | +# budget wire option from ``weight_offload_budget_mb`` to a bytes-typed |
| 141 | +# field once ``BackendOptions`` grows int64 support). The runtime |
| 142 | +# hard-fails at ``CudaBackend::init`` if the version is missing or |
| 143 | +# unknown, naming the expected range — version drift surfaces loudly at |
| 144 | +# load instead of silently mis-parsing a payload. |
| 145 | +SCHEMA_VERSION = 1 |
| 146 | + |
| 147 | + |
| 148 | +def _apply_weight_offload( |
| 149 | + exported_program, |
| 150 | + *, |
| 151 | + method_name: str, |
| 152 | + pin_fqns: list[str] | None = None, |
| 153 | +) -> dict: |
| 154 | + """In-place graph rewrite + offload payload computation. |
| 155 | +
|
| 156 | + INTERNAL — leading underscore is the Python signal. The only supported |
| 157 | + caller is ``CudaPartitioner`` (see ``backends/cuda/cuda_partitioner.py``), |
| 158 | + which sources ``method_name`` from its compile specs so prefill |
| 159 | + and decode get distinct payloads instead of colliding. |
| 160 | + ``method_name`` is REQUIRED (no default) precisely so a direct |
| 161 | + caller importing this function for a multi-method model cannot |
| 162 | + silently collide all methods on ``"forward"``. |
| 163 | +
|
| 164 | + Inserts ``probe(w)`` in front of every consumer of every parameter (or |
| 165 | + buffer) placeholder, rewriting the consumer's arg to read the probe's |
| 166 | + output. One probe call per ``(consumer, weight)`` pair — not per |
| 167 | + weight — so the runtime can re-load a weight that was evicted between |
| 168 | + two uses inside the same forward pass. |
| 169 | +
|
| 170 | + Pinned FQNs still get probes inserted (so the runtime serves them |
| 171 | + through the same path), but they are EXCLUDED from the schedule and |
| 172 | + EXCLUDED from the floor calculation. See ``Session::Config::pin_fqns`` |
| 173 | + and ``Session::register_schedule`` in |
| 174 | + ``backends/cuda/runtime/weight_offload/weight_offload.h``. |
| 175 | +
|
| 176 | + The pass is the single authoritative source for the pin set. The |
| 177 | + runtime has NO pin-set option; it parses the list out of the |
| 178 | + partition payload and passes it to the Session unchanged. Pin set |
| 179 | + affects floor correctness (the floor is computed assuming pinned |
| 180 | + FQNs do not stream), so the runtime cannot override it. |
| 181 | +
|
| 182 | + AOTI constant-folding contract: |
| 183 | + The pass operates on parameter placeholders in the ExportedProgram. |
| 184 | + AOTI knobs that fold parameters out of the container at compile |
| 185 | + time (so they no longer appear in ``get_constant_name(idx)``) |
| 186 | + break offload: the pass cannot insert probes for parameters it |
| 187 | + cannot see, and the folded constants would be loaded eagerly |
| 188 | + through the normal blob path at runtime — silently defeating |
| 189 | + offload and reintroducing the OOM this feature exists to prevent. |
| 190 | +
|
| 191 | + Exports that enable weight offload must NOT have AOTI fold |
| 192 | + parameter constants. The exact ``torch._inductor.config`` knob |
| 193 | + and its required value is verified in the implementation PR. |
| 194 | + The pass hard-fails at export if it detects folded parameters |
| 195 | + (placeholder count below the expected catalog count derived |
| 196 | + from ``exported_program.state_dict``); the runtime hard-fails |
| 197 | + again at ``Session::bind_placeholder_constants`` if any catalog |
| 198 | + FQN is missing a probe binding — defense in depth against the |
| 199 | + two halves drifting. |
| 200 | +
|
| 201 | + Metadata transport: |
| 202 | + The returned ``dict`` is the offload payload that the |
| 203 | + implementation PR will route to ``cuda_backend.cpp::init`` via the |
| 204 | + AOTI ``preprocess`` path (see the "payload transport" open item |
| 205 | + in the module docstring for the design constraint and the two |
| 206 | + candidate channels: ``processed_bytes`` vs. ``NamedDataStore``). |
| 207 | +
|
| 208 | + Args: |
| 209 | + exported_program: an ``ExportedProgram`` produced by |
| 210 | + ``torch.export.export``. Mutated in place: probe nodes are |
| 211 | + inserted, consumer args are rewritten. |
| 212 | + method_name: the method this pass is being applied to. Returned |
| 213 | + verbatim in the payload so the runtime can validate which |
| 214 | + method the bytes belong to. |
| 215 | + pin_fqns: FQNs to mark as always-resident. Optional. The list |
| 216 | + is propagated verbatim into the payload AND used by this pass |
| 217 | + to exclude those FQNs from the schedule and the floor |
| 218 | + calculation. Pinning an FQN that does not appear as a parameter |
| 219 | + placeholder is a hard error. |
| 220 | +
|
| 221 | + Returns: a ``dict`` with the keys defined at module scope (all |
| 222 | + internal payload, not opt-in signals): |
| 223 | +
|
| 224 | + - ``"version"``: ``int`` schema version (currently ``1``). Runtime |
| 225 | + hard-fails on unknown version. |
| 226 | + - ``"method_name"``: ``str``. Echoed for runtime validation. |
| 227 | + - ``"schedule"``: ``list[str]`` of NON-PINNED parameter FQNs in |
| 228 | + execution order. Drives the runtime cursor + prefetch (or is |
| 229 | + obviated entirely if the implementation PR picks option (b) of |
| 230 | + the "Schedule / cursor order" open item -- a probe-id arg). |
| 231 | + Pinned FQNs do not appear here; their probes take a separate |
| 232 | + fast-path in ``Session::serve`` that does not touch the cursor. |
| 233 | + - ``"floor_bytes"``: ``int`` — minimum GPU byte budget for the |
| 234 | + streaming portion of the working set (``max-over-consecutive- |
| 235 | + kernel-pairs of (sum bytes K_i + K_{i+1}) + max single weight |
| 236 | + size``), computed over the schedule above (i.e. excluding |
| 237 | + pinned weights). The runtime asserts |
| 238 | + ``(weight_offload_budget_mb << 20) - pinned_bytes`` covers |
| 239 | + this; below-floor budgets hard-fail at init with the required |
| 240 | + minimum spelled out. |
| 241 | + - ``"pin_fqns"``: ``list[str]`` of FQNs the runtime keeps |
| 242 | + resident. Empty if ``pin_fqns`` is unset. |
| 243 | +
|
| 244 | + The opt-in signal is intended to be a separate |
| 245 | + ``CompileSpec("weight_offload", b"1")`` emitted by ``CudaPartitioner`` |
| 246 | + when wired -- the enable signal lives in exactly one place, the |
| 247 | + compile spec, rather than being duplicated across compile spec + |
| 248 | + payload. Neither the compile spec nor the partitioner kwarg exists |
| 249 | + in this PR; see the EXPERIMENTAL banner at the top of this module. |
| 250 | +
|
| 251 | + Not called by users. Not called by anything in this PR. |
| 252 | + """ |
| 253 | + ... |
0 commit comments