99#pragma once
1010
1111// ===========================================================================
12- // EXPERIMENTAL -- PER -FQN AOTI CONSTANT METADATA SOURCE
12+ // EXPERIMENTAL -- per -FQN AOTI constant metadata
1313// ===========================================================================
14- // Builds a per-FQN view (dtype, sizes, nbytes, live data pointer) of
15- // the AOTI container's constants by combining
16- // ``get_num_constants`` / ``get_constant_original_fqn`` /
17- // ``extract_constants_map`` (all existing AOTI APIs — no upstream
18- // PyTorch changes required for offload to know what each constant
19- // looks like) .
14+ // Holds the canonical ConstantInfo struct used by the offload runtime
15+ // (Session, cuda_backend.cpp). The runtime builds its catalog inline
16+ // in CudaBackend::init from the payload's per-FQN metadata block --
17+ // not from AOTI's extract_constants_map, which after dummy installation
18+ // returns placeholder metadata for the installed dummies rather than
19+ // the originals .
2020//
21- // Scope: the catalog is the METADATA source for offload. It is NOT
22- // the host-byte source for the eventual host mirror — relying on
23- // ``data_ptr()`` here would imply "load every weight to GPU first,
24- // then free", which is exactly the path the weight-offload feature
25- // exists to avoid. The host-byte source (likely the
26- // ``_weights_blob`` NamedData entry, indexed by per-constant
27- // offsets sourced separately) is a problem the host-mirror commit
28- // solves.
29- //
30- // Built once per (handle, method) AFTER constants are loaded so the
31- // returned ``data_ptr`` reflects the FINAL active handles —
32- // cross-method weight sharing in ``cuda_backend.cpp`` can swap the
33- // container's constant pointers during load, so reading before load
34- // gives stale data_ptrs.
21+ // Consumers that copy bytes (Session::serve) must use the LOGICAL size
22+ // (product(sizes) * elementSize(dtype)) for the H2D length, not
23+ // storage_nbytes -- view-style constants can have storage_nbytes >
24+ // logical, and Session::create hard-fails any scheduled FQN where
25+ // the two disagree.
3526// ===========================================================================
3627
3728#include < cstdint>
3829#include < string>
39- #include < unordered_map>
4030#include < vector>
4131
42- #include < executorch/backends/aoti/aoti_delegate_handle.h>
43- #include < executorch/backends/aoti/common_shims_slim.h>
44- #include < executorch/runtime/core/error.h>
45- #include < executorch/runtime/core/result.h>
46-
4732namespace executorch ::backends::cuda::weight_offload {
4833
4934struct ConstantInfo {
@@ -52,185 +37,18 @@ struct ConstantInfo {
5237 std::vector<int64_t > sizes;
5338 std::vector<int64_t > strides;
5439 int64_t storage_offset{0 };
55- // ``storage_nbytes`` is what ``aoti_torch_get_storage_size``
56- // reports — the byte size of the underlying storage allocation,
57- // which CAN be larger than the logical tensor for view-style
58- // constants. Consumers that copy bytes (Session::serve) must
59- // use the LOGICAL size (``product(sizes) * elementSize(dtype)``)
60- // for the H2D / D2H length, otherwise an offset-zero contiguous
61- // view backed by larger storage will overrun the destination.
62- // Session::create hard-fails any scheduled FQN where
63- // ``storage_nbytes != logical_nbytes``.
6440 uint64_t storage_nbytes{0 };
65- // Device type from `` aoti_torch_get_device_type`` . Session::create
66- // hard-fails any scheduled FQN whose `` device_type != CUDA`` — the
41+ // Device type from aoti_torch_get_device_type. Session::create
42+ // hard-fails any scheduled FQN whose device_type != CUDA -- the
6743 // sync-H2D path has no model for host-resident or other-device
6844 // constants, and silently treating them as device 0 would corrupt
6945 // data on multi-GPU hosts.
7046 int32_t device_type{0 };
7147 int32_t device_index{0 };
72- // Live device pointer for the constant's bytes, valid until the
73- // AOTI container or its user-managed pair table swaps it out.
74- // METADATA OBSERVABILITY ONLY — see the file banner for why this
75- // is not the host-mirror byte source.
48+ // Live device pointer (the installed dummy in the offload path).
49+ // Used for ProbeRegistry registration; the runtime never reads
50+ // bytes from this pointer.
7651 void * data_ptr{nullptr };
7752};
7853
79- // Build the FQN -> ConstantInfo catalog for the AOTI container
80- // associated with ``handle``. Caller MUST have already loaded
81- // constants (``update_constants_from_blob`` or
82- // ``update_user_managed_constant_buffer_pairs``) — see the file
83- // banner.
84- //
85- // Constants whose ``get_constant_original_fqn`` returns null or an
86- // empty string are SKIPPED — AOTI emits unnamed/internal constants
87- // for some lowerings and they're not addressable through the FQN
88- // the pass uses. Matches the existing ``load_constants_with_cache``
89- // filter so opt-in init doesn't hard-fail on containers the
90- // non-offload path accepts cleanly. The schedule ⊆ catalog
91- // validation in ``CudaBackend::init`` then catches the case where
92- // a probed parameter is missing because AOTI folded it (the FQN
93- // won't be in the catalog at all).
94- //
95- // Returns ``Error::Internal`` if any of the required AOTI symbols
96- // are unresolved on the handle, if any AOTI call fails, or if a
97- // named constant from ``get_constant_original_fqn`` is missing
98- // from ``extract_constants_map``. The offload contract is "loud at
99- // init"; opt-in callers should hard-fail on a non-Ok result.
100- inline ::executorch::runtime::Result<
101- std::unordered_map<std::string, ConstantInfo>>
102- build_constant_catalog (
103- ::executorch::backends::aoti::AOTIDelegateHandle* handle) {
104- using ::executorch::backends::aoti::AOTInductorConstantMapHandle;
105- using ::executorch::backends::aoti::AtenTensorHandle;
106- using ::executorch::runtime::Error;
107- using SlimTensor = ::executorch::backends::aoti::Tensor;
108-
109- if (handle == nullptr || handle->container_handle == nullptr ||
110- handle->get_num_constants == nullptr ||
111- handle->get_constant_original_fqn == nullptr ||
112- handle->extract_constants_map == nullptr ) {
113- return Error::Internal;
114- }
115-
116- size_t num_constants = 0 ;
117- if (handle->get_num_constants (handle->container_handle , &num_constants) !=
118- Error::Ok) {
119- return Error::Internal;
120- }
121-
122- // idx -> FQN, skipping unnamed/internal constants whose original
123- // FQN is null or empty (mirrors ``load_constants_with_cache``).
124- std::vector<std::string> named_fqns;
125- named_fqns.reserve (num_constants);
126- for (size_t i = 0 ; i < num_constants; ++i) {
127- const char * fqn = nullptr ;
128- if (handle->get_constant_original_fqn (handle->container_handle , i, &fqn) !=
129- Error::Ok) {
130- return Error::Internal;
131- }
132- if (fqn == nullptr || fqn[0 ] == ' \0 ' ) {
133- continue ;
134- }
135- named_fqns.emplace_back (fqn);
136- }
137-
138- // The AtenTensorHandle values populated below are BORROWED from the
139- // AOTI container — they remain valid for the container's lifetime
140- // and are not owned by ``extracted``. The catalog stores only the
141- // derived dtype / sizes / strides / device fields (see ConstantInfo
142- // below); it never retains the handles past this function, so no
143- // separate teardown of ``extracted`` is required.
144- std::unordered_map<std::string, AtenTensorHandle> extracted;
145- if (handle->extract_constants_map (
146- handle->container_handle ,
147- reinterpret_cast <AOTInductorConstantMapHandle>(&extracted),
148- /* use_inactive=*/ false ) != Error::Ok) {
149- return Error::Internal;
150- }
151-
152- std::unordered_map<std::string, ConstantInfo> catalog;
153- catalog.reserve (named_fqns.size ());
154- for (const auto & fqn : named_fqns) {
155- auto it = extracted.find (fqn);
156- if (it == extracted.end ()) {
157- // get_constant_original_fqn reported this FQN but
158- // extract_constants_map did not surface it — schema drift in
159- // the AOTI container we're not equipped to handle.
160- return Error::Internal;
161- }
162- SlimTensor* tensor = reinterpret_cast <SlimTensor*>(it->second );
163-
164- ConstantInfo info;
165- info.fqn = fqn;
166-
167- int32_t dtype = 0 ;
168- if (::executorch::backends::aoti::aoti_torch_get_dtype (tensor, &dtype) !=
169- Error::Ok) {
170- return Error::Internal;
171- }
172- info.dtype = dtype;
173-
174- int64_t ndim = 0 ;
175- if (::executorch::backends::aoti::aoti_torch_get_dim (tensor, &ndim) !=
176- Error::Ok) {
177- return Error::Internal;
178- }
179- int64_t * sizes_ptr = nullptr ;
180- if (::executorch::backends::aoti::aoti_torch_get_sizes (
181- tensor, &sizes_ptr) != Error::Ok ||
182- (ndim > 0 && sizes_ptr == nullptr )) {
183- return Error::Internal;
184- }
185- info.sizes .assign (sizes_ptr, sizes_ptr + ndim);
186-
187- int64_t * strides_ptr = nullptr ;
188- if (::executorch::backends::aoti::aoti_torch_get_strides (
189- tensor, &strides_ptr) != Error::Ok ||
190- (ndim > 0 && strides_ptr == nullptr )) {
191- return Error::Internal;
192- }
193- info.strides .assign (strides_ptr, strides_ptr + ndim);
194-
195- int64_t storage_offset = 0 ;
196- if (::executorch::backends::aoti::aoti_torch_get_storage_offset (
197- tensor, &storage_offset) != Error::Ok) {
198- return Error::Internal;
199- }
200- info.storage_offset = storage_offset;
201-
202- int64_t storage_size = 0 ;
203- if (::executorch::backends::aoti::aoti_torch_get_storage_size (
204- tensor, &storage_size) != Error::Ok ||
205- storage_size < 0 ) {
206- return Error::Internal;
207- }
208- info.storage_nbytes = static_cast <uint64_t >(storage_size);
209-
210- int32_t device_type = 0 ;
211- if (::executorch::backends::aoti::aoti_torch_get_device_type (
212- tensor, &device_type) != Error::Ok) {
213- return Error::Internal;
214- }
215- info.device_type = device_type;
216- int32_t device_index = 0 ;
217- if (::executorch::backends::aoti::aoti_torch_get_device_index (
218- tensor, &device_index) != Error::Ok) {
219- return Error::Internal;
220- }
221- info.device_index = device_index;
222-
223- void * data_ptr = nullptr ;
224- if (::executorch::backends::aoti::aoti_torch_get_data_ptr (
225- tensor, &data_ptr) != Error::Ok) {
226- return Error::Internal;
227- }
228- info.data_ptr = data_ptr;
229-
230- catalog.emplace (fqn, std::move (info));
231- }
232-
233- return catalog;
234- }
235-
23654} // namespace executorch::backends::cuda::weight_offload
0 commit comments