Skip to content

Commit 266ff2d

Browse files
gasoonjia
and co-author authored
Replace chunked FLA with recurrent gated delta rule for T=1 decode (#18667)
The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode, while it is a good fit for the prefill stage. This PR decomposes the single `forward` method into separate `prefill` and `decode` methods that use different FLA implementations (chunked for `prefill`, recurrent for `decode`) while sharing the KV cache, conv_state, and recurrent_state, boosting decode performance from 77.7 tokens/s to 88.3 tokens/s while maintaining prefill performance. --------- Co-authored-by: gasoonjia <gasoonjia@fb.com>
1 parent 80198ca commit 266ff2d

File tree

6 files changed

+506
-49
lines changed

6 files changed

+506
-49
lines changed

backends/aoti/aoti_delegate_handle.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
3131
using AOTInductorStreamHandle = void*;
3232
using AOTIProxyExecutorHandle = void*;
3333

34+
// Opaque types for AOTI constant management.
35+
// AtenTensorOpaque wraps at::Tensor* in the AOTI runtime — distinct from
36+
// AOTITensorHandle which wraps executorch::runtime::etensor::Tensor*.
37+
struct AtenTensorOpaque;
38+
using AtenTensorHandle = AtenTensorOpaque*;
39+
40+
struct AOTInductorConstantMap;
41+
using AOTInductorConstantMapHandle = AOTInductorConstantMap*;
42+
43+
struct AOTInductorConstantMapEntry {
44+
const char* name;
45+
AtenTensorHandle handle;
46+
};
47+
3448
// Function pointer types for AOT Inductor model container operations
3549
using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
3650
AOTInductorModelContainerHandle* container_handle,
@@ -77,6 +91,37 @@ using AOTInductorModelUpdateConstantsFromBlobFunc = AOTIRuntimeError (*)(
7791
AOTInductorModelContainerHandle container_handle,
7892
const uint8_t* weight_blob_ptr);
7993

94+
// Retrieves a constant's AOTI internal name by index.
95+
using AOTInductorModelContainerGetConstantNameFunc = AOTIRuntimeError (*)(
96+
AOTInductorModelContainerHandle container_handle,
97+
size_t idx,
98+
const char** name);
99+
100+
// Retrieves a constant's original fully-qualified name by index.
101+
using AOTInductorModelContainerGetConstantOriginalFQNFunc =
102+
AOTIRuntimeError (*)(
103+
AOTInductorModelContainerHandle container_handle,
104+
size_t idx,
105+
const char** original_fqn);
106+
107+
// Extracts the constants map from the container (active or inactive buffer).
108+
// constant_map_handle should point to a
109+
// std::unordered_map<std::string, AtenTensorHandle>.
110+
using AOTInductorModelContainerExtractConstantsMapFunc = AOTIRuntimeError (*)(
111+
AOTInductorModelContainerHandle container_handle,
112+
AOTInductorConstantMapHandle constant_map_handle,
113+
bool use_inactive);
114+
115+
// Updates the container's constants with user-managed tensor handles.
116+
// DLL-boundary safe — uses a flat C array instead of std::unordered_map.
117+
using AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc =
118+
AOTIRuntimeError (*)(
119+
AOTInductorModelContainerHandle container_handle,
120+
const AOTInductorConstantMapEntry* pairs,
121+
size_t num_pairs,
122+
bool use_inactive,
123+
bool validate_full_update);
124+
80125
} // extern "C"
81126

82127
// AOTI Delegate Handle structure
@@ -93,6 +138,14 @@ struct AOTIDelegateHandle {
93138
AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
94139
AOTInductorModelContainerRunFunc run;
95140
AOTInductorModelUpdateConstantsFromBlobFunc update_constants_from_blob;
141+
142+
// Constant management function pointers (for cross-method buffer sharing)
143+
AOTInductorModelContainerGetNumConstantsFunc get_num_constants;
144+
AOTInductorModelContainerGetConstantNameFunc get_constant_name;
145+
AOTInductorModelContainerGetConstantOriginalFQNFunc get_constant_original_fqn;
146+
AOTInductorModelContainerExtractConstantsMapFunc extract_constants_map;
147+
AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc
148+
update_user_managed_constant_buffer_pairs;
96149
};
97150

98151
} // namespace aoti

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 152 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,30 @@ class ET_EXPERIMENTAL CudaBackend final
207207
Info,
208208
"Failed to load AOTInductorModelUpdateConstantsFromBlob. This .so is probably compiled on an old version of torch (<2.9.0)");
209209
}
210+
211+
// Load constant management symbols (optional — needed for cross-method
212+
// buffer sharing). These are available in torch >= 2.6.
213+
#define LOAD_OPTIONAL_SYMBOL(member, name) \
214+
do { \
215+
auto res = get_function(so_handle, #name); \
216+
handle->member = \
217+
res.ok() ? reinterpret_cast<name##Func>(res.get()) : nullptr; \
218+
} while (0)
219+
220+
LOAD_OPTIONAL_SYMBOL(
221+
get_num_constants, AOTInductorModelContainerGetNumConstants);
222+
LOAD_OPTIONAL_SYMBOL(
223+
get_constant_name, AOTInductorModelContainerGetConstantName);
224+
LOAD_OPTIONAL_SYMBOL(
225+
get_constant_original_fqn,
226+
AOTInductorModelContainerGetConstantOriginalFQN);
227+
LOAD_OPTIONAL_SYMBOL(
228+
extract_constants_map, AOTInductorModelContainerExtractConstantsMap);
229+
LOAD_OPTIONAL_SYMBOL(
230+
update_user_managed_constant_buffer_pairs,
231+
AOTInductorModelContainerUpdateUserManagedConstantBufferPairs);
232+
#undef LOAD_OPTIONAL_SYMBOL
233+
210234
return Error::Ok;
211235
}
212236

@@ -348,9 +372,20 @@ class ET_EXPERIMENTAL CudaBackend final
348372
const void* weights_blob = buffer_res->data();
349373
// Feed the weights blob into the container. Under the hood it's copying
350374
// weights, so we should free the buffer immediately.
351-
ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob(
352-
handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
375+
auto update_err = handle->update_constants_from_blob(
376+
handle->container_handle, static_cast<const uint8_t*>(weights_blob));
377+
if (update_err != Error::Ok) {
378+
ET_LOG(Error, "update_constants_from_blob failed");
379+
return update_err;
380+
}
381+
// Ensure all weight transfers are complete before execution
382+
cudaDeviceSynchronize();
353383
buffer_res->Free();
384+
} else {
385+
ET_LOG(
386+
Info,
387+
"weights_blob '%s' not found or update fn is null",
388+
weights_blob_key.c_str());
354389
}
355390

356391
// Use shared CUDA stream if enabled via options, otherwise create one.
@@ -378,6 +413,105 @@ class ET_EXPERIMENTAL CudaBackend final
378413
method_name.c_str());
379414
}
380415

416+
// ---------------------------------------------------------------
417+
// Cross-method constant sharing (e.g., KV cache between prefill/decode).
418+
//
419+
// The first container to initialize extracts its constants (keyed by
420+
// original FQN) and stores the AtenTensorHandle's. Subsequent containers
421+
// with matching FQNs are updated to point to the same GPU tensors via
422+
// UpdateUserManagedConstantBufferPairs (user_managed = true → no copy,
423+
// the source container retains ownership).
424+
// ---------------------------------------------------------------
425+
if (handle->get_num_constants && handle->get_constant_name &&
426+
handle->get_constant_original_fqn && handle->extract_constants_map &&
427+
handle->update_user_managed_constant_buffer_pairs) {
428+
size_t num_constants = 0;
429+
handle->get_num_constants(handle->container_handle, &num_constants);
430+
431+
if (num_constants > 0) {
432+
// Build FQN → internal_name mapping for this container.
433+
std::unordered_map<std::string, std::string> fqn_to_name;
434+
for (size_t i = 0; i < num_constants; i++) {
435+
const char* name = nullptr;
436+
const char* fqn = nullptr;
437+
handle->get_constant_name(handle->container_handle, i, &name);
438+
handle->get_constant_original_fqn(handle->container_handle, i, &fqn);
439+
if (name && fqn && fqn[0] != '\0') {
440+
fqn_to_name[fqn] = name;
441+
}
442+
}
443+
444+
std::lock_guard<std::mutex> guard(shared_constants_mutex_);
445+
446+
if (!constants_extracted_) {
447+
// First container: extract its constants and store by FQN.
448+
std::unordered_map<std::string, AtenTensorHandle> extracted_map;
449+
auto extract_err = handle->extract_constants_map(
450+
handle->container_handle,
451+
reinterpret_cast<AOTInductorConstantMapHandle>(&extracted_map),
452+
/*use_inactive=*/false);
453+
454+
if (extract_err == Error::Ok) {
455+
for (const auto& [fqn, internal_name] : fqn_to_name) {
456+
auto it = extracted_map.find(fqn);
457+
if (it != extracted_map.end()) {
458+
shared_constant_tensors_[fqn] = it->second;
459+
}
460+
}
461+
constants_extracted_ = true;
462+
ET_LOG(
463+
Info,
464+
"Extracted %zu shared constants from method '%s'",
465+
shared_constant_tensors_.size(),
466+
method_name.c_str());
467+
} else {
468+
ET_LOG(
469+
Error,
470+
"Failed to extract constants from '%s'",
471+
method_name.c_str());
472+
}
473+
} else {
474+
// Subsequent container: share matching constants from the first.
475+
std::vector<AOTInductorConstantMapEntry> pairs;
476+
for (const auto& [fqn, internal_name] : fqn_to_name) {
477+
auto it = shared_constant_tensors_.find(fqn);
478+
if (it != shared_constant_tensors_.end()) {
479+
// UpdateUserManagedConstantBufferPairs matches against the
480+
// codegen constant name (underscored), not the original FQN.
481+
pairs.push_back({internal_name.c_str(), it->second});
482+
}
483+
}
484+
485+
if (!pairs.empty()) {
486+
auto update_err = handle->update_user_managed_constant_buffer_pairs(
487+
handle->container_handle,
488+
pairs.data(),
489+
pairs.size(),
490+
/*use_inactive=*/false,
491+
/*validate_full_update=*/false);
492+
493+
if (update_err == Error::Ok) {
494+
ET_LOG(
495+
Info,
496+
"Shared %zu constants into method '%s'",
497+
pairs.size(),
498+
method_name.c_str());
499+
} else {
500+
ET_LOG(
501+
Error,
502+
"Failed to share constants into '%s'",
503+
method_name.c_str());
504+
}
505+
}
506+
}
507+
}
508+
} else {
509+
ET_LOG(
510+
Info,
511+
"Constant sharing APIs not available for method '%s'",
512+
method_name.c_str());
513+
}
514+
381515
return (DelegateHandle*)handle; // Return the handle post-processing
382516
}
383517

@@ -623,6 +757,22 @@ class ET_EXPERIMENTAL CudaBackend final
623757
mutable std::
624758
unordered_map<cuda::CudaDelegateHandle*, std::vector<SlimTensor*>>
625759
cached_outputs_;
760+
761+
// Cross-method constant sharing state.
762+
// When multiple AOTI containers share mutable buffers (e.g., KV cache),
763+
// the first container's constants are extracted and stored here. Subsequent
764+
// containers with matching FQNs share the same GPU tensors via
765+
// UpdateUserManagedConstantBufferPairs.
766+
mutable std::mutex shared_constants_mutex_;
767+
768+
// FQN → AtenTensorHandle from the source (first) container.
769+
// The tensor handles are owned by the source container (which is never
770+
// explicitly deleted — see destroy() comment).
771+
mutable std::unordered_map<std::string, AtenTensorHandle>
772+
shared_constant_tensors_;
773+
774+
// Whether we've already extracted constants from a source container.
775+
mutable bool constants_extracted_ = false;
626776
};
627777

628778
} // namespace executorch::backends::cuda

examples/models/qwen3_5_moe/export.py

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
107107
# Any missing weight key indicates a version mismatch between the
108108
# checkpoint and the model (e.g., unfused vs fused projections).
109109
runtime_prefixes = (
110+
".mask",
111+
".inv_freq",
110112
".kv_cache.",
111113
".conv_state",
112114
".recurrent_state",
@@ -312,10 +314,11 @@ def _materialize_buffers(model, config):
312314
"""Materialize meta-device buffers before torch.export.
313315
314316
Replaces meta buffers with real tensors on CPU, recomputes RoPE
315-
inv_freq and causal masks.
317+
inv_freq and causal masks. State buffers (KV cache, conv/recurrent
318+
state) are zero-initialized registered buffers that will be shared
319+
across methods via share_mutable_buffers.
316320
"""
317-
# State buffers (KV cache, conv/recurrent state) are bf16 to match
318-
# compute dtype. Masks stay bool, inv_freq stays float32.
321+
# Masks stay bool, inv_freq stays float32.
319322
for fqn, buf in list(model.named_buffers()):
320323
if buf.device.type == "meta":
321324
dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool
@@ -378,7 +381,18 @@ def _apply_turboquant(model, config):
378381

379382

380383
def export_and_lower(model, config, args):
381-
"""Export model to .pte via torch.export + CUDA backend."""
384+
"""Export model to .pte via torch.export + CUDA backend.
385+
386+
Exports two methods:
387+
- "decode": decode path (T=1), uses native PyTorch recurrent FLA
388+
so AOTI can fuse with surrounding ops for maximum decode throughput.
389+
- "prefill": prefill path (T>=2), uses chunked FLA triton_op with
390+
dynamic sequence length.
391+
392+
Both methods share mutable state buffers (KV cache, conv_state,
393+
recurrent_state) via share_mutable_buffers=True. The model uses
394+
registered buffers with in-place updates — no state in/out args.
395+
"""
382396
import torch._inductor.config as inductor_config
383397

384398
from executorch.backends.cuda.cuda_backend import CudaBackend
@@ -398,25 +412,39 @@ def export_and_lower(model, config, args):
398412
# -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
399413
inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
400414

401-
# Dynamic shapes
402-
example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
403-
example_input_pos = torch.tensor([0, 1], dtype=torch.long)
404-
seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
405-
dynamic_shapes = ({1: seq_dim}, {0: seq_dim})
406-
407-
print("Exporting with torch.export...")
415+
# --- Decode method (T=1, static shape) ---
416+
print("Exporting decode method...")
417+
decode_tokens = torch.tensor([[0]], dtype=torch.long)
418+
decode_pos = torch.tensor([0], dtype=torch.long)
419+
with torch.no_grad():
420+
decode_ep = export(
421+
model,
422+
(decode_tokens, decode_pos),
423+
strict=True,
424+
)
425+
print("Decode export successful!")
426+
427+
# --- Prefill method (T>=2, dynamic shape) ---
428+
print("Exporting prefill method...")
429+
prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
430+
prefill_pos = torch.tensor([0, 1], dtype=torch.long)
431+
seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
432+
prefill_dynamic_shapes = (
433+
{1: seq_dim}, # tokens
434+
{0: seq_dim}, # input_pos
435+
)
408436
with torch.no_grad():
409-
exported = export(
437+
prefill_ep = export(
410438
model,
411-
(example_tokens, example_input_pos),
412-
dynamic_shapes=dynamic_shapes,
439+
(prefill_tokens, prefill_pos),
440+
dynamic_shapes=prefill_dynamic_shapes,
413441
strict=True,
414442
)
415-
print("Export successful!")
443+
print("Prefill export successful!")
416444

417-
# Lower with CUDA backend
445+
# Lower with CUDA backend (per-method partitioners to avoid so_blob collision)
418446
print("Lowering to ExecuTorch with CUDA...")
419-
compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
447+
420448
metadata = {
421449
"get_max_seq_len": config.max_seq_len,
422450
"get_vocab_size": config.vocab_size,
@@ -426,8 +454,19 @@ def export_and_lower(model, config, args):
426454
"enable_dynamic_shape": True,
427455
}
428456
et_prog = to_edge_transform_and_lower(
429-
exported,
430-
partitioner=[CudaPartitioner(compile_specs)],
457+
{"decode": decode_ep, "prefill": prefill_ep},
458+
partitioner={
459+
"decode": [
460+
CudaPartitioner(
461+
[CudaBackend.generate_method_name_compile_spec("decode")]
462+
)
463+
],
464+
"prefill": [
465+
CudaPartitioner(
466+
[CudaBackend.generate_method_name_compile_spec("prefill")]
467+
)
468+
],
469+
},
431470
compile_config=EdgeCompileConfig(
432471
_check_ir_validity=False,
433472
_skip_dim_order=True,
@@ -438,7 +477,11 @@ def export_and_lower(model, config, args):
438477
config=ExecutorchBackendConfig(
439478
extract_delegate_segments=True,
440479
do_quant_fusion_and_const_prop=True,
441-
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
480+
memory_planning_pass=MemoryPlanningPass(
481+
alloc_graph_input=False,
482+
share_mutable_buffers=True,
483+
),
484+
emit_mutable_buffer_names=True,
442485
),
443486
)
444487

0 commit comments

Comments
 (0)