Commit 31e6036

Offload cleanup (round 2): aggressive trim of headers, validation, docstrings

Net -829 lines. Implementation unchanged; presentation tightened further per the second review pass.

Delete the catalog test file. test_weight_offload_catalog.py had 3 tests: two verified success-log fields, one tested the schedule-not-in-catalog hard-fail. The hard-fail itself stays in cuda_backend.cpp and any offload test that exercises a real model will fail loudly if the check ever regresses. -445 lines + 25 from the TARGETS entry.

Remove build_constant_catalog() from constant_catalog.h. The runtime builds its catalog inline in CudaBackend::init from the payload's per-FQN metadata block (extract_constants_map returns dummy metadata after install, so the old AOTI-walk approach can't work post-dummy). The 135-line helper had zero callers. ConstantInfo stays -- it's still the canonical per-FQN struct used by Session + cuda_backend.cpp.

Shrink session.h banner. The 52-line file banner restated content already in weight_offload.h's design overview. Down to a 6-line note about ownership + lifecycle. Inlined BudgetSpec field comments. 337 -> 216 lines.

Tighten ~25 offload-init error fprintfs in cuda_backend.cpp. Each site was 4-8 lines with prose; collapsed to one-liners that retain method name + FQN + error code. Behavior identical -- only the diagnostic format changes. -78 lines.

Trim 5 multi-paragraph module docstrings in test files. Test names + assertions already document intent; the docstrings restated them in prose. -55 lines.

Test string update: test_runtime_hard_fails_on_corrupted_payload matched the old "failed to parse" phrasing; updated to match the new "payload parse failed".

64 of 64 offload tests pass (1 skipped); lint clean.

Authored with Claude.
1 parent e3db391 commit 31e6036

10 files changed

Lines changed: 139 additions & 968 deletions

backends/cuda/passes/tests/test_weight_offload_pass.py

Lines changed: 4 additions & 15 deletions
@@ -4,21 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-"""Contract tests for ``_apply_weight_offload``.
-
-The pass owns the export-time half of CUDA weight offloading: it
-rewrites parameter consumers to read through ``probe(w, probe_id)``
-nodes and returns the v1 offload payload (``version``,
-``method_name``, ``schedule``, ``floor_bytes``, ``pin_fqns``).
-
-These tests assert the public contract — what the returned payload
-contains, where probe nodes appear, how ``probe_id`` lines up with
-``schedule``, how view chains on weights are duplicated per
-consumer, the set-union semantics of the floor calculation, and
-which inputs hard-fail. They do NOT exercise the runtime serve path
-(covered by the dispatch test under ``backends/cuda/tests``) or any
-partitioner / opt-in plumbing (still unwired in this PR).
-"""
+"""Contract tests for ``_apply_weight_offload``: payload contents,
+probe-node placement, probe_id ↔ schedule alignment, view-chain
+duplication, floor set-union semantics, and hard-fail paths. Runtime
+serve and partitioner plumbing are covered by other test files."""

 import unittest

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 78 additions & 115 deletions
Large diffs are not rendered by default.

backends/cuda/runtime/weight_offload/constant_catalog.h

Lines changed: 17 additions & 199 deletions
@@ -9,41 +9,26 @@
 #pragma once

 // ===========================================================================
-// EXPERIMENTAL -- PER-FQN AOTI CONSTANT METADATA SOURCE
+// EXPERIMENTAL -- per-FQN AOTI constant metadata
 // ===========================================================================
-// Builds a per-FQN view (dtype, sizes, nbytes, live data pointer) of
-// the AOTI container's constants by combining
-// ``get_num_constants`` / ``get_constant_original_fqn`` /
-// ``extract_constants_map`` (all existing AOTI APIs — no upstream
-// PyTorch changes required for offload to know what each constant
-// looks like).
+// Holds the canonical ConstantInfo struct used by the offload runtime
+// (Session, cuda_backend.cpp). The runtime builds its catalog inline
+// in CudaBackend::init from the payload's per-FQN metadata block --
+// not from AOTI's extract_constants_map, which after dummy installation
+// returns placeholder metadata for the installed dummies rather than
+// the originals.
 //
-// Scope: the catalog is the METADATA source for offload. It is NOT
-// the host-byte source for the eventual host mirror — relying on
-// ``data_ptr()`` here would imply "load every weight to GPU first,
-// then free", which is exactly the path the weight-offload feature
-// exists to avoid. The host-byte source (likely the
-// ``_weights_blob`` NamedData entry, indexed by per-constant
-// offsets sourced separately) is a problem the host-mirror commit
-// solves.
-//
-// Built once per (handle, method) AFTER constants are loaded so the
-// returned ``data_ptr`` reflects the FINAL active handles —
-// cross-method weight sharing in ``cuda_backend.cpp`` can swap the
-// container's constant pointers during load, so reading before load
-// gives stale data_ptrs.
+// Consumers that copy bytes (Session::serve) must use the LOGICAL size
+// (product(sizes) * elementSize(dtype)) for the H2D length, not
+// storage_nbytes -- view-style constants can have storage_nbytes >
+// logical, and Session::create hard-fails any scheduled FQN where
+// the two disagree.
 // ===========================================================================

 #include <cstdint>
 #include <string>
-#include <unordered_map>
 #include <vector>

-#include <executorch/backends/aoti/aoti_delegate_handle.h>
-#include <executorch/backends/aoti/common_shims_slim.h>
-#include <executorch/runtime/core/error.h>
-#include <executorch/runtime/core/result.h>
-
 namespace executorch::backends::cuda::weight_offload {

 struct ConstantInfo {
@@ -52,185 +37,18 @@ struct ConstantInfo {
   std::vector<int64_t> sizes;
   std::vector<int64_t> strides;
   int64_t storage_offset{0};
-  // ``storage_nbytes`` is what ``aoti_torch_get_storage_size``
-  // reports — the byte size of the underlying storage allocation,
-  // which CAN be larger than the logical tensor for view-style
-  // constants. Consumers that copy bytes (Session::serve) must
-  // use the LOGICAL size (``product(sizes) * elementSize(dtype)``)
-  // for the H2D / D2H length, otherwise an offset-zero contiguous
-  // view backed by larger storage will overrun the destination.
-  // Session::create hard-fails any scheduled FQN where
-  // ``storage_nbytes != logical_nbytes``.
   uint64_t storage_nbytes{0};
-  // Device type from ``aoti_torch_get_device_type``. Session::create
-  // hard-fails any scheduled FQN whose ``device_type != CUDA`` — the
+  // Device type from aoti_torch_get_device_type. Session::create
+  // hard-fails any scheduled FQN whose device_type != CUDA -- the
   // sync-H2D path has no model for host-resident or other-device
   // constants, and silently treating them as device 0 would corrupt
   // data on multi-GPU hosts.
   int32_t device_type{0};
   int32_t device_index{0};
-  // Live device pointer for the constant's bytes, valid until the
-  // AOTI container or its user-managed pair table swaps it out.
-  // METADATA OBSERVABILITY ONLY — see the file banner for why this
-  // is not the host-mirror byte source.
+  // Live device pointer (the installed dummy in the offload path).
+  // Used for ProbeRegistry registration; the runtime never reads
+  // bytes from this pointer.
   void* data_ptr{nullptr};
 };

-// Build the FQN -> ConstantInfo catalog for the AOTI container
-// associated with ``handle``. Caller MUST have already loaded
-// constants (``update_constants_from_blob`` or
-// ``update_user_managed_constant_buffer_pairs``) — see the file
-// banner.
-//
-// Constants whose ``get_constant_original_fqn`` returns null or an
-// empty string are SKIPPED — AOTI emits unnamed/internal constants
-// for some lowerings and they're not addressable through the FQN
-// the pass uses. Matches the existing ``load_constants_with_cache``
-// filter so opt-in init doesn't hard-fail on containers the
-// non-offload path accepts cleanly. The schedule ⊆ catalog
-// validation in ``CudaBackend::init`` then catches the case where
-// a probed parameter is missing because AOTI folded it (the FQN
-// won't be in the catalog at all).
-//
-// Returns ``Error::Internal`` if any of the required AOTI symbols
-// are unresolved on the handle, if any AOTI call fails, or if a
-// named constant from ``get_constant_original_fqn`` is missing
-// from ``extract_constants_map``. The offload contract is "loud at
-// init"; opt-in callers should hard-fail on a non-Ok result.
-inline ::executorch::runtime::Result<
-    std::unordered_map<std::string, ConstantInfo>>
-build_constant_catalog(
-    ::executorch::backends::aoti::AOTIDelegateHandle* handle) {
-  using ::executorch::backends::aoti::AOTInductorConstantMapHandle;
-  using ::executorch::backends::aoti::AtenTensorHandle;
-  using ::executorch::runtime::Error;
-  using SlimTensor = ::executorch::backends::aoti::Tensor;
-
-  if (handle == nullptr || handle->container_handle == nullptr ||
-      handle->get_num_constants == nullptr ||
-      handle->get_constant_original_fqn == nullptr ||
-      handle->extract_constants_map == nullptr) {
-    return Error::Internal;
-  }
-
-  size_t num_constants = 0;
-  if (handle->get_num_constants(handle->container_handle, &num_constants) !=
-      Error::Ok) {
-    return Error::Internal;
-  }
-
-  // idx -> FQN, skipping unnamed/internal constants whose original
-  // FQN is null or empty (mirrors ``load_constants_with_cache``).
-  std::vector<std::string> named_fqns;
-  named_fqns.reserve(num_constants);
-  for (size_t i = 0; i < num_constants; ++i) {
-    const char* fqn = nullptr;
-    if (handle->get_constant_original_fqn(handle->container_handle, i, &fqn) !=
-        Error::Ok) {
-      return Error::Internal;
-    }
-    if (fqn == nullptr || fqn[0] == '\0') {
-      continue;
-    }
-    named_fqns.emplace_back(fqn);
-  }
-
-  // The AtenTensorHandle values populated below are BORROWED from the
-  // AOTI container — they remain valid for the container's lifetime
-  // and are not owned by ``extracted``. The catalog stores only the
-  // derived dtype / sizes / strides / device fields (see ConstantInfo
-  // below); it never retains the handles past this function, so no
-  // separate teardown of ``extracted`` is required.
-  std::unordered_map<std::string, AtenTensorHandle> extracted;
-  if (handle->extract_constants_map(
-          handle->container_handle,
-          reinterpret_cast<AOTInductorConstantMapHandle>(&extracted),
-          /*use_inactive=*/false) != Error::Ok) {
-    return Error::Internal;
-  }
-
-  std::unordered_map<std::string, ConstantInfo> catalog;
-  catalog.reserve(named_fqns.size());
-  for (const auto& fqn : named_fqns) {
-    auto it = extracted.find(fqn);
-    if (it == extracted.end()) {
-      // get_constant_original_fqn reported this FQN but
-      // extract_constants_map did not surface it — schema drift in
-      // the AOTI container we're not equipped to handle.
-      return Error::Internal;
-    }
-    SlimTensor* tensor = reinterpret_cast<SlimTensor*>(it->second);
-
-    ConstantInfo info;
-    info.fqn = fqn;
-
-    int32_t dtype = 0;
-    if (::executorch::backends::aoti::aoti_torch_get_dtype(tensor, &dtype) !=
-        Error::Ok) {
-      return Error::Internal;
-    }
-    info.dtype = dtype;
-
-    int64_t ndim = 0;
-    if (::executorch::backends::aoti::aoti_torch_get_dim(tensor, &ndim) !=
-        Error::Ok) {
-      return Error::Internal;
-    }
-    int64_t* sizes_ptr = nullptr;
-    if (::executorch::backends::aoti::aoti_torch_get_sizes(
-            tensor, &sizes_ptr) != Error::Ok ||
-        (ndim > 0 && sizes_ptr == nullptr)) {
-      return Error::Internal;
-    }
-    info.sizes.assign(sizes_ptr, sizes_ptr + ndim);
-
-    int64_t* strides_ptr = nullptr;
-    if (::executorch::backends::aoti::aoti_torch_get_strides(
-            tensor, &strides_ptr) != Error::Ok ||
-        (ndim > 0 && strides_ptr == nullptr)) {
-      return Error::Internal;
-    }
-    info.strides.assign(strides_ptr, strides_ptr + ndim);
-
-    int64_t storage_offset = 0;
-    if (::executorch::backends::aoti::aoti_torch_get_storage_offset(
-            tensor, &storage_offset) != Error::Ok) {
-      return Error::Internal;
-    }
-    info.storage_offset = storage_offset;
-
-    int64_t storage_size = 0;
-    if (::executorch::backends::aoti::aoti_torch_get_storage_size(
-            tensor, &storage_size) != Error::Ok ||
-        storage_size < 0) {
-      return Error::Internal;
-    }
-    info.storage_nbytes = static_cast<uint64_t>(storage_size);
-
-    int32_t device_type = 0;
-    if (::executorch::backends::aoti::aoti_torch_get_device_type(
-            tensor, &device_type) != Error::Ok) {
-      return Error::Internal;
-    }
-    info.device_type = device_type;
-    int32_t device_index = 0;
-    if (::executorch::backends::aoti::aoti_torch_get_device_index(
-            tensor, &device_index) != Error::Ok) {
-      return Error::Internal;
-    }
-    info.device_index = device_index;
-
-    void* data_ptr = nullptr;
-    if (::executorch::backends::aoti::aoti_torch_get_data_ptr(
-            tensor, &data_ptr) != Error::Ok) {
-      return Error::Internal;
-    }
-    info.data_ptr = data_ptr;
-
-    catalog.emplace(fqn, std::move(info));
-  }
-
-  return catalog;
-}
-
 } // namespace executorch::backends::cuda::weight_offload
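
The new banner in constant_catalog.h states that byte-copying consumers must use the logical size, product(sizes) * elementSize(dtype), and that Session::create rejects scheduled constants whose storage_nbytes differs from it. The following is a minimal sketch of that arithmetic only, not code from this commit. The helper names `logical_nbytes` and `storage_matches_logical` are hypothetical, and the element size is passed in directly because the mapping from AOTI dtype codes to element sizes is not shown in the diff.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not part of the commit): logical byte size of a
// tensor, i.e. product(sizes) * element_size. A rank-0 tensor has one
// element. Returns std::nullopt for a negative dimension or if the
// product would overflow uint64_t.
inline std::optional<uint64_t> logical_nbytes(
    const std::vector<int64_t>& sizes,
    uint64_t element_size) {
  uint64_t total = element_size;
  for (int64_t dim : sizes) {
    if (dim < 0) {
      return std::nullopt;
    }
    const uint64_t d = static_cast<uint64_t>(dim);
    if (d != 0 && total > UINT64_MAX / d) {
      return std::nullopt;
    }
    total *= d;
  }
  return total;
}

// Hypothetical check mirroring the rule in the banner: accept a constant
// only when its storage size equals its logical size, so a copy of
// logical_nbytes bytes covers exactly the storage.
inline bool storage_matches_logical(
    const std::vector<int64_t>& sizes,
    uint64_t element_size,
    uint64_t storage_nbytes) {
  const std::optional<uint64_t> logical = logical_nbytes(sizes, element_size);
  return logical.has_value() && *logical == storage_nbytes;
}
```

For a 2x3 tensor of 4-byte elements the logical size is 24 bytes, so a view-style constant backed by a 32-byte storage would be rejected by this check rather than copied with the wrong length.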
