Commit 6c815cd

Offload cleanup (round 4): centralize validation at the payload parser
Net -288 lines. The trust-boundary architecture is now clean: the payload parser in payload.h validates everything a downstream consumer might want to know about a parsed Payload, and downstream code (cuda_backend.cpp init, Session) just trusts the parsed struct.

Parser additions (payload.h):
* dtype allow-list (the set of dtypes for which the runtime can compute elementSize and perform the H2D copy). Rejects unknown / corrupt dtype codes before they can truncate through the int8_t-backed ScalarType cast downstream.
* nbytes == element_size(dtype) * product(sizes) cross-field check with overflow-aware multiplication. Catches schema drift in the serializer.
* C-contiguous + offset-zero check (storage_offset == 0, strides[i] == product(sizes[i+1..])). The host mirror is sized for logical bytes and the H2D copy is dense, so any non-contiguous layout would over- or under-read.
* pin_fqns dedup + pin_fqns subset-of schedule.

Downstream removals (session.cpp):
* logical_nbytes_for(): use info.storage_nbytes directly (the parser already validated nbytes == elementSize * product).
* validate_scheduled_layout(): contiguity / storage_offset / storage==logical all validated at parse.
* Per-FQN device_type / device_index single-device tracking (parser hard-fails device_index != 0; session unconditionally on device 0).
* Per-FQN catalog-lookup defensive error (use catalog.at()).
* pin_set_.insert duplicate check (parser dedups).
* pin_fqns subset-of validated check (parser enforces).
* pinned_bytes_total recompute loop + overflow check (now a parameter into Session::create from init).
* required_total = floor + pinned overflow check (init resolves the budget; the < required_total enforcement stays).

Downstream removals (cuda_backend.cpp init):
* pin_seen dedup loop (parser).
* device_index re-check on first metadata entry (parser).
* Per-FQN dtype allow-list + sizes-positive + logical_nbytes overflow + payload-nbytes-vs-derived-logical cross-check (parser owns all of these).

KEPT the AOTI data_sizes[i] vs payload nbytes check -- AOTI is the one genuinely independent source of truth, and the two could drift if the .pte and .so were built from different versions.

Downstream removals (weight_offload_pass.py):
* pin_fqns dedup loop in _apply_weight_offload (the partitioner is the user-facing API; raw-spec test paths flow through the parser which now dedups).
* Pruned the contiguity-check rationale to one line (Python int can't overflow so no overflow comment).

API change: Session::create gains a uint64_t pinned_bytes_total parameter (init computes it once and passes it in, instead of session recomputing). The next round will fold this and several other init-managed inputs into Session itself as the structural refactor lands.

61 of 61 offload tests pass; lint clean.

Authored with Claude.
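
The per-entry parser checks listed above (overflow-aware byte count, C-contiguity, offset zero, nbytes cross-check) can be sketched in isolation. This is a minimal illustration under stated assumptions, not the code in payload.h: `TensorMeta`, `logical_nbytes`, `is_c_contiguous`, and `validate` are hypothetical names, and the element size is passed in directly instead of being derived from a dtype code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical stand-in for one parsed metadata entry.
struct TensorMeta {
  uint64_t element_size;  // bytes per element; 0 means "unsupported dtype"
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  int64_t storage_offset;
  uint64_t nbytes;
};

// element_size * product(sizes), or nullopt when the element size is 0,
// a size is non-positive, or the running product would overflow uint64_t.
inline std::optional<uint64_t> logical_nbytes(
    uint64_t element_size,
    const std::vector<int64_t>& sizes) {
  if (element_size == 0) {
    return std::nullopt;
  }
  uint64_t total = element_size;
  for (int64_t s : sizes) {
    if (s <= 0) {
      return std::nullopt;
    }
    const uint64_t s_u = static_cast<uint64_t>(s);
    // For unsigned a * b with b > 0: overflow iff a > UINT64_MAX / b.
    if (total > std::numeric_limits<uint64_t>::max() / s_u) {
      return std::nullopt;
    }
    total *= s_u;
  }
  return total;
}

// True when strides[i] == product(sizes[i+1..]) for every i.
// Precondition: sizes are positive and their product fits in uint64_t
// (both hold once logical_nbytes has succeeded on the same sizes).
inline bool is_c_contiguous(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  if (sizes.size() != strides.size()) {
    return false;
  }
  uint64_t expected = 1;
  for (size_t i = sizes.size(); i-- > 0;) {
    assert(sizes[i] > 0);
    if (strides[i] < 0 || static_cast<uint64_t>(strides[i]) != expected) {
      return false;
    }
    expected *= static_cast<uint64_t>(sizes[i]);
  }
  return true;
}

// Combined check: the byte count is computed first so the contiguity
// check only ever runs on sizes that are already known to be sane.
inline bool validate(const TensorMeta& m) {
  const std::optional<uint64_t> logical =
      logical_nbytes(m.element_size, m.sizes);
  return logical.has_value() && m.storage_offset == 0 &&
      is_c_contiguous(m.sizes, m.strides) && m.nbytes == *logical;
}
```

A scalar (empty `sizes`) passes with `nbytes == element_size`, matching the behaviour the commit describes for ndim == 0. The sketch accumulates the expected stride in `uint64_t` so that it cannot hit signed overflow; that is a choice made for this illustration only.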
1 parent 4a05736 commit 6c815cd

5 files changed

Lines changed: 128 additions & 416 deletions

backends/cuda/passes/weight_offload_pass.py

Lines changed: 7 additions & 22 deletions
@@ -702,18 +702,7 @@ def _apply_weight_offload(
     contains probe nodes -- the second pass would insert probes on
     the probes' outputs.
     """
-    # Canonicalize pin_fqns: dedupe while preserving first-seen order
-    # so the payload is stable. The runtime hard-fails on duplicates
-    # at parse time; deduping here keeps harmless caller mistakes
-    # from reaching that hard-fail.
-    raw_pins = list(pin_fqns or [])
-    pin_fqns = []
-    _seen_pin = set()
-    for fqn in raw_pins:
-        if fqn in _seen_pin:
-            continue
-        _seen_pin.add(fqn)
-        pin_fqns.append(fqn)
+    pin_fqns = list(pin_fqns or [])
     graph = exported_program.graph_module.graph

     # Re-entering the pass would wrap each probe's input (still a
@@ -781,19 +770,15 @@ def _nbytes(fqn: str) -> int:
                 f"weight offload: FQN {fqn!r} appears as a placeholder but "
                 f"is missing from state_dict and constants"
             )
-        # The runtime serves weights by raw-byte H2D from a contiguous
-        # host mirror sized at numel * element_size. A non-contiguous
-        # parameter / buffer (e.g. a register_buffer holding a strided
-        # view) would have storage_nbytes > logical_nbytes; the runtime
-        # would later hard-fail in validate_scheduled_layout, but the
-        # error is far clearer if we name the offending FQN at export.
+        # The host mirror is sized at numel * element_size; a non-
+        # contiguous tensor would over-read its storage. The runtime
+        # parser also rejects non-contiguous metadata, but flagging
+        # here names the FQN with a Python stack trace.
         if not t.is_contiguous():
             raise ValueError(
-                f"weight offload: FQN {fqn!r} is a non-contiguous tensor "
+                f"weight offload: FQN {fqn!r} is non-contiguous "
                 f"(shape={tuple(t.shape)}, strides={tuple(t.stride())}); "
-                f"the offload runtime only supports contiguous "
-                f"parameters and buffers. Call .contiguous() on the "
-                f"source tensor before exporting."
+                f"call .contiguous() on the source tensor before exporting"
             )
         return t.numel() * t.element_size()

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 17 additions & 144 deletions
@@ -430,49 +430,12 @@ class ET_EXPERIMENTAL CudaBackend final
       offload_payload = std::move(parsed.get());
       offload_buf->Free();

-      // Fail-fast on payload-derived configurations we do NOT
-      // support yet, BEFORE container creation / .so load / blob
-      // fetching — no point allocating GPU state for a config we'd
-      // throw away.
+      // Fail-fast on incompatible runtime modes BEFORE container
+      // creation / .so load / blob fetching — no point allocating
+      // GPU state for a config we'd throw away. pin_fqns dedup,
+      // device_index==0, and per-FQN metadata invariants are all
+      // enforced by the payload parser.
       //
-      // Deduplicate pin_fqns. The pass + partitioner emit a canonical
-      // list, but a corrupted or hand-rolled payload could repeat.
-      // Hard-fail at parse time so accounting / allocation downstream
-      // can rely on a 1:1 fqn↔allocation mapping. Session::create's
-      // pinned_.emplace() is the second-layer guard.
-      {
-        std::unordered_set<std::string> pin_seen;
-        pin_seen.reserve(offload_payload.pin_fqns.size());
-        for (const auto& fqn : offload_payload.pin_fqns) {
-          if (!pin_seen.insert(fqn).second) {
-            std::fprintf(
-                stderr,
-                "[ET_WEIGHT_OFFLOAD][ERROR] method '%s': duplicate FQN "
-                "'%s' in payload.pin_fqns\n",
-                method_name.c_str(),
-                fqn.c_str());
-            return Error::InvalidArgument;
-          }
-        }
-      }
-
-      // Single-device constraint. ``create_with_device("cuda", nullptr)``
-      // doesn't take a per-method device index; dummies + stream + pool
-      // land on device 0 regardless of payload. The parser already
-      // validated device_index == 0 per-entry; this is a belt-and-
-      // braces re-check on the first entry (and a no-op for empty
-      // metadata).
-      if (!offload_payload.constants_metadata.empty() &&
-          offload_payload.constants_metadata[0].device_index != 0) {
-        std::fprintf(
-            stderr,
-            "[ET_WEIGHT_OFFLOAD][ERROR] method '%s' has device_index=%d "
-            "in payload metadata; only device 0 is supported\n",
-            method_name.c_str(),
-            offload_payload.constants_metadata[0].device_index);
-        return Error::InvalidArgument;
-      }
-
       // Disallow shared-stream mode with offload. The shared stream
       // (see create_shared_cuda_stream) is created on whichever
       // device happened to be current at the time of the first
@@ -876,115 +839,24 @@ class ET_EXPERIMENTAL CudaBackend final
       for (const auto& m : offload_payload.constants_metadata) {
         fqn_to_meta[m.fqn] = &m;
       }
+      // Cross-check AOTI's per-constant data_size against the
+      // payload's nbytes. This is the one check init still needs to
+      // do because the two sides are independent sources of truth:
+      // the parser validated payload internals (dtype + sizes ->
+      // nbytes consistency, contiguity, etc.), but AOTI's container
+      // is a separate origin and could disagree with the payload if
+      // the .pte and the .so were built from drifted sources.
       for (size_t i = 0; i < num_constants; ++i) {
         const auto& fqn = aoti_fqns[i];
-        auto m_it = fqn_to_meta.find(fqn);
-        if (m_it == fqn_to_meta.end()) {
-          // Parser's cross-field check should have caught this, but
-          // defensively re-verify.
-          std::fprintf(
-              stderr,
-              "[ET_WEIGHT_OFFLOAD][ERROR] FQN '%s' missing from payload "
-              "metadata\n",
-              fqn.c_str());
-          delete handle;
-          return Error::Internal;
-        }
-        const auto& m = *m_it->second;
-        // Validate the RAW int32 dtype against a supported code set
-        // BEFORE casting to slim ScalarType: ScalarType is int8_t-
-        // backed, so a corrupted dtype like 256 would silently
-        // truncate to 0 (Byte) on cast. Mirrors the pass-side
-        // _TORCH_DTYPE_TO_C10 map.
-        static constexpr int32_t kSupportedDtypeCodes[] = {
-            0, // Byte / uint8
-            1, // Char / int8
-            2, // Short / int16
-            3, // Int / int32
-            4, // Long / int64
-            5, // Half / float16
-            6, // Float / float32
-            11, // Bool
-            15, // BFloat16
-        };
-        bool dtype_supported = false;
-        for (int32_t code : kSupportedDtypeCodes) {
-          if (m.dtype == code) {
-            dtype_supported = true;
-            break;
-          }
-        }
-        if (!dtype_supported) {
-          std::fprintf(
-              stderr,
-              "[ET_WEIGHT_OFFLOAD][ERROR] FQN '%s': dtype=%d not in the "
-              "offload-supported set\n",
-              fqn.c_str(),
-              m.dtype);
-          delete handle;
-          return Error::InvalidArgument;
-        }
-        auto slim_dtype =
-            static_cast<::executorch::backends::aoti::slim::c10::ScalarType>(
-                m.dtype);
-        uint64_t logical =
-            ::executorch::backends::aoti::slim::c10::elementSize(slim_dtype);
-        for (int64_t s : m.sizes) {
-          if (s <= 0) {
-            std::fprintf(
-                stderr,
-                "[ET_WEIGHT_OFFLOAD][ERROR] FQN '%s': non-positive size "
-                "%lld in payload\n",
-                fqn.c_str(),
-                static_cast<long long>(s));
-            delete handle;
-            return Error::Internal;
-          }
-          // Portable overflow check (MSVC has no __builtin_mul_overflow
-          // for 64-bit). For unsigned a * b: overflow iff
-          // b != 0 && a > UINT64_MAX / b. Guard b != 0 first;
-          // logical starts at elementSize(dtype) which is > 0, and
-          // s > 0 from the check above, so the second condition is
-          // the actual safety net.
-          const uint64_t s_u = static_cast<uint64_t>(s);
-          if (s_u != 0 &&
-              logical > std::numeric_limits<uint64_t>::max() / s_u) {
-            std::fprintf(
-                stderr,
-                "[ET_WEIGHT_OFFLOAD][ERROR] FQN '%s': logical nbytes "
-                "overflow (dtype=%d)\n",
-                fqn.c_str(),
-                m.dtype);
-            delete handle;
-            return Error::Internal;
-          }
-          logical *= s_u;
-        }
-        if (logical != static_cast<uint64_t>(data_sizes[i])) {
+        const auto& m = *fqn_to_meta.at(fqn);
+        if (m.nbytes != static_cast<uint64_t>(data_sizes[i])) {
           std::fprintf(
               stderr,
               "[ET_WEIGHT_OFFLOAD][ERROR] FQN '%s': AOTI data_size=%zu "
-              "vs payload logical nbytes=%llu\n",
+              "vs payload nbytes=%llu (pass <-> AOTI drift)\n",
               fqn.c_str(),
               data_sizes[i],
-              static_cast<unsigned long long>(logical));
-          delete handle;
-          return Error::InvalidArgument;
-        }
-        // Payload also carries an explicit nbytes field for defense
-        // in depth. The pass computes it as dtype*product(sizes), so
-        // it must equal `logical` we just computed. Treat any
-        // mismatch as schema drift / corrupted payload — if v2
-        // promises this field, a stale value here means we can't
-        // trust the rest of the metadata either.
-        if (m.nbytes != logical) {
-          std::fprintf(
-              stderr,
-              "[ET_WEIGHT_OFFLOAD][ERROR] FQN '%s': payload nbytes=%llu "
-              "vs dtype+sizes logical=%llu\n",
-              fqn.c_str(),
-              static_cast<unsigned long long>(m.nbytes),
-              static_cast<unsigned long long>(logical));
+              static_cast<unsigned long long>(m.nbytes));
           delete handle;
           return Error::InvalidArgument;
         }
@@ -1274,6 +1146,7 @@ class ET_EXPERIMENTAL CudaBackend final
           session_catalog,
           handle->get_cuda_stream(),
           resolved_budget,
+          pinned_bytes_total,
           static_cast<const uint8_t*>(blob_buf->data()),
           fqn_offsets,
           std::move(dummy_guard),
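
The one cross-check that init keeps, AOTI's per-constant size against the payload's `nbytes`, reduces to a map lookup and a comparison. The following is a hedged sketch with hypothetical names (`ConstantMeta`, `first_size_drift`), not the backend's code. Note that `std::unordered_map::at()` throws `std::out_of_range` on a missing key, so the sketch assumes an earlier step has already established that every AOTI FQN is present in the payload metadata.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical, trimmed-down view of one parsed metadata entry.
struct ConstantMeta {
  std::string fqn;
  uint64_t nbytes;
};

// Returns the FQN of the first constant whose AOTI-reported size differs
// from the payload's nbytes, or nullopt when the two sides agree.
// Assumes every entry of aoti_fqns has a matching metadata entry.
inline std::optional<std::string> first_size_drift(
    const std::vector<ConstantMeta>& metadata,
    const std::vector<std::string>& aoti_fqns,
    const std::vector<size_t>& aoti_data_sizes) {
  assert(aoti_fqns.size() == aoti_data_sizes.size());
  std::unordered_map<std::string, const ConstantMeta*> by_fqn;
  by_fqn.reserve(metadata.size());
  for (const auto& m : metadata) {
    by_fqn[m.fqn] = &m;
  }
  for (size_t i = 0; i < aoti_fqns.size(); ++i) {
    const ConstantMeta& m = *by_fqn.at(aoti_fqns[i]);
    if (m.nbytes != static_cast<uint64_t>(aoti_data_sizes[i])) {
      return aoti_fqns[i];
    }
  }
  return std::nullopt;
}
```

Because the payload's internal consistency is settled at parse time, a mismatch here can only mean the two artifacts disagree with each other, which is the drift case the commit message calls out.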

backends/cuda/runtime/weight_offload/payload.h

Lines changed: 85 additions & 9 deletions
@@ -35,7 +35,9 @@
 #include <cstddef>
 #include <cstdint>
 #include <cstring>
+#include <limits>
 #include <string>
+#include <unordered_set>
 #include <vector>

 #include <executorch/runtime/core/error.h>
@@ -166,23 +168,53 @@ class Cursor {
   size_t offset_{0};
 };

-// Read a single ConstantMetadata entry with per-field bounds. Used in
-// the inner loop of parse_payload.
+// Element size by dtype code. Returns 0 for unsupported codes, which
+// the caller treats as an invalid-payload signal. The supported set
+// mirrors the pass-side ``_TORCH_DTYPE_TO_C10`` map; extending one
+// without the other will hard-fail at parse, which is the intended
+// drift signal.
+inline uint64_t element_size(int32_t dtype) {
+  switch (dtype) {
+    case 0: // uint8
+    case 1: // int8
+    case 11: // bool
+      return 1;
+    case 2: // int16
+    case 5: // float16
+    case 15: // bfloat16
+      return 2;
+    case 3: // int32
+    case 6: // float32
+      return 4;
+    case 4: // int64
+      return 8;
+    default:
+      return 0;
+  }
+}
+
+// Read a single ConstantMetadata entry with per-field bounds + cross-
+// field consistency (dtype is supported, sizes positive, strides
+// describe a C-contiguous layout, storage_offset == 0, nbytes ==
+// elementSize(dtype) * product(sizes)). Catching these at parse means
+// downstream code can trust the parsed struct directly.
 inline ::executorch::runtime::Error read_constant_metadata(
     Cursor& cur,
     ConstantMetadata& m) {
   using ::executorch::runtime::Error;
   if (cur.read_bounded_string(m.fqn, kMaxStrLen) != Error::Ok) {
     return Error::InvalidArgument;
   }
-  // Per-entry FQN must be non-empty — empty fqn means no addressable
-  // probe routing.
   if (m.fqn.empty()) {
     return Error::InvalidArgument;
   }
   if (cur.read_i32(m.dtype) != Error::Ok) {
     return Error::InvalidArgument;
   }
+  const uint64_t esize = element_size(m.dtype);
+  if (esize == 0) {
+    return Error::InvalidArgument;
+  }
   uint32_t ndim = 0;
   if (cur.read_u32(ndim) != Error::Ok) {
     return Error::InvalidArgument;
@@ -191,30 +223,57 @@ inline ::executorch::runtime::Error read_constant_metadata(
     return Error::InvalidArgument;
   }
   m.sizes.resize(ndim);
+  uint64_t logical = esize;
   for (uint32_t k = 0; k < ndim; ++k) {
     if (cur.read_i64(m.sizes[k]) != Error::Ok) {
       return Error::InvalidArgument;
     }
-    // Positive sizes only — zero-product constants are hard-failed
-    // per the v8 zero-byte policy. Scalars (ndim==0) skip this loop
-    // entirely and are accepted by construction: numel == 1, no
-    // dimension to validate.
+    // Positive sizes only. Scalars (ndim==0) skip this loop entirely
+    // and are accepted: logical stays at element_size, numel == 1.
     if (m.sizes[k] <= 0) {
       return Error::InvalidArgument;
     }
+    const uint64_t s_u = static_cast<uint64_t>(m.sizes[k]);
+    if (logical > std::numeric_limits<uint64_t>::max() / s_u) {
+      return Error::InvalidArgument;
+    }
+    logical *= s_u;
   }
   m.strides.resize(ndim);
   for (uint32_t k = 0; k < ndim; ++k) {
     if (cur.read_i64(m.strides[k]) != Error::Ok) {
       return Error::InvalidArgument;
     }
   }
+  // Strides must describe a C-contiguous layout: strides[i] ==
+  // product(sizes[i+1..]). The offload host mirror is sized for
+  // logical bytes and the H2D copy is dense, so any non-contiguous
+  // layout would over- or under-read.
+  {
+    int64_t expected = 1;
+    for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
+      if (m.strides[i] != expected) {
+        return Error::InvalidArgument;
+      }
+      expected *= m.sizes[i];
+    }
+  }
   if (cur.read_i64(m.storage_offset) != Error::Ok) {
     return Error::InvalidArgument;
   }
+  if (m.storage_offset != 0) {
+    return Error::InvalidArgument;
+  }
   if (cur.read_u64(m.nbytes) != Error::Ok) {
     return Error::InvalidArgument;
   }
+  // nbytes must equal the logical byte count derived from dtype +
+  // sizes. The pass writes it as `element_size * product(sizes)`;
+  // catching drift here means downstream consumers can read either
+  // field interchangeably.
+  if (m.nbytes != logical) {
+    return Error::InvalidArgument;
+  }
   if (cur.read_i32(m.device_type) != Error::Ok) {
     return Error::InvalidArgument;
   }
@@ -331,7 +390,10 @@ inline ::executorch::runtime::Result<Payload> parse_payload(
   // Cross-field invariants for v2:
   //   - constants_metadata FQN set must equal unique(schedule).
   //   - No duplicate FQNs across metadata entries.
-  // Catching these at parse means we hard-fail before any GPU work.
+  //   - No duplicate FQNs in pin_fqns.
+  //   - Every pin_fqn must appear in the schedule.
+  // Catching these at parse means downstream code (init, Session)
+  // can trust the parsed struct without re-validating.
   if (!p.constants_metadata.empty() || !p.schedule.empty()) {
     std::vector<std::string> md_fqns;
     md_fqns.reserve(p.constants_metadata.size());
@@ -353,6 +415,20 @@ inline ::executorch::runtime::Result<Payload> parse_payload(
       return Error::InvalidArgument;
     }
   }
+  if (!p.pin_fqns.empty()) {
+    std::unordered_set<std::string> sched_set(
+        p.schedule.begin(), p.schedule.end());
+    std::unordered_set<std::string> pin_set;
+    pin_set.reserve(p.pin_fqns.size());
+    for (const auto& f : p.pin_fqns) {
+      if (!pin_set.insert(f).second) {
+        return Error::InvalidArgument; // duplicate in pin_fqns
+      }
+      if (sched_set.find(f) == sched_set.end()) {
+        return Error::InvalidArgument; // pin_fqn not in schedule
+      }
+    }
+  }

   return p;
 }
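
The pin_fqns invariants added to parse_payload (no duplicates, every pin present in the schedule) can be exercised on their own. The sketch below is an illustration under assumptions: `PinCheck` and `check_pins` are hypothetical names, and it returns a distinct result per failure where the parser above returns `Error::InvalidArgument` for both.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical result type; the real parser collapses both failures
// into a single InvalidArgument error.
enum class PinCheck { Ok, Duplicate, NotScheduled };

// Single pass over pin_fqns: insert() reports a duplicate through the
// bool in its return value, and membership in the schedule is an O(1)
// average-case set lookup.
inline PinCheck check_pins(
    const std::vector<std::string>& pin_fqns,
    const std::vector<std::string>& schedule) {
  const std::unordered_set<std::string> scheduled(
      schedule.begin(), schedule.end());
  std::unordered_set<std::string> seen;
  seen.reserve(pin_fqns.size());
  for (const auto& f : pin_fqns) {
    if (!seen.insert(f).second) {
      return PinCheck::Duplicate;
    }
    if (scheduled.count(f) == 0) {
      return PinCheck::NotScheduled;
    }
  }
  return PinCheck::Ok;
}
```

The schedule itself may repeat an FQN (a weight used more than once), which is why it is collapsed into a set first while pin_fqns is checked for uniqueness.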
