Skip to content

Commit a93f9cd

Browse files
committed
Per-weight constant cache for CUDA backend
Replace the old update_constants_from_blob and cross-method sharing with a unified per-weight caching approach. The first method to initialize loads its constants from the blob and caches them by FQN. Subsequent methods with matching FQNs reuse the cached GPU tensors via update_user_managed_constant_buffer_pairs, skipping blob loading entirely. This eliminates duplicate GPU weight allocations for multi-method models (e.g., prefill/decode), reducing peak GPU memory from ~35 GB to ~17.6 GB for Qwen3.5 MoE. Also adds GPU peak memory reporting to the Qwen3.5 MoE runner and a CI check (< 20 GB) in test_model_e2e.sh.
2 parents 865f118 + 87e65ac commit a93f9cd

22 files changed

Lines changed: 2836 additions & 108 deletions

backends/aoti/aoti_delegate_handle.h

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -104,23 +104,6 @@ using AOTInductorModelContainerGetConstantOriginalFQNFunc =
104104
size_t idx,
105105
const char** original_fqn);
106106

107-
// Retrieves a constant's data size in bytes by index.
108-
using AOTInductorModelContainerGetConstantDataSizeFunc = AOTIRuntimeError (*)(
109-
AOTInductorModelContainerHandle container_handle,
110-
size_t idx,
111-
size_t* data_size);
112-
113-
// Retrieves whether a constant was produced by constant folding.
114-
using AOTInductorModelContainerGetConstantFromFoldedFunc = AOTIRuntimeError (*)(
115-
AOTInductorModelContainerHandle container_handle,
116-
size_t idx,
117-
bool* from_folded);
118-
119-
// Retrieves the total size of the constants blob.
120-
using AOTInductorModelContainerGetConstantsBlobSizeFunc = AOTIRuntimeError (*)(
121-
AOTInductorModelContainerHandle container_handle,
122-
uint64_t* ret_size);
123-
124107
// Extracts the constants map from the container (active or inactive buffer).
125108
// constant_map_handle should point to a
126109
// std::unordered_map<std::string, AtenTensorHandle>.
@@ -160,9 +143,6 @@ struct AOTIDelegateHandle {
160143
AOTInductorModelContainerGetNumConstantsFunc get_num_constants;
161144
AOTInductorModelContainerGetConstantNameFunc get_constant_name;
162145
AOTInductorModelContainerGetConstantOriginalFQNFunc get_constant_original_fqn;
163-
AOTInductorModelContainerGetConstantDataSizeFunc get_constant_data_size;
164-
AOTInductorModelContainerGetConstantFromFoldedFunc get_constant_from_folded;
165-
AOTInductorModelContainerGetConstantsBlobSizeFunc get_constants_blob_size;
166146
AOTInductorModelContainerExtractConstantsMapFunc extract_constants_map;
167147
AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc
168148
update_user_managed_constant_buffer_pairs;

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from .decompose_glu_pass import DecomposeGluPass # noqa
5454
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
5555
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
56+
from .decompose_gru_pass import DecomposeGruPass # noqa
5657
from .decompose_index_copy_pass import DecomposeIndexCopyPass # noqa
5758
from .decompose_index_select_to_gather_pass import ( # noqa
5859
DecomposeIndexSelectToGatherPass,
@@ -70,13 +71,15 @@
7071
from .decompose_linear_pass import DecomposeLinearPass # noqa
7172
from .decompose_log1p_pass import DecomposeLog1pPass # noqa
7273
from .decompose_logit_pass import DecomposeLogitPass # noqa
74+
from .decompose_lstm_pass import DecomposeLstmPass # noqa
7375
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
7476
from .decompose_matmul import DecomposeMatmulPass # noqa
7577
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
7678
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
7779
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
7880
from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa
7981
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
82+
from .decompose_rnn_pass import DecomposeRnnPass # noqa
8083
from .decompose_round_pass import DecomposeRoundPass # noqa
8184
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
8285
from .decompose_select import DecomposeSelectPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
DecomposeGluPass,
6262
DecomposeGroupedConvPass,
6363
DecomposeGroupNormPass,
64+
DecomposeGruPass,
6465
DecomposeIndexCopyPass,
6566
DecomposeIndexSelectToGatherPass,
6667
DecomposeIndexTensorToGatherPass,
@@ -71,13 +72,15 @@
7172
DecomposeLinearPass,
7273
DecomposeLog1pPass,
7374
DecomposeLogitPass,
75+
DecomposeLstmPass,
7476
DecomposeMaskedFillPass,
7577
DecomposeMatmulPass,
7678
DecomposeMaxPool2dPass,
7779
DecomposeMeanDimPass,
7880
DecomposeNotEqualPass,
7981
DecomposeQuantNodesPass,
8082
DecomposeRemainderPass,
83+
DecomposeRnnPass,
8184
DecomposeRoundPass,
8285
DecomposeScaledDotProductAttentionPass,
8386
DecomposeSelectPass,
@@ -360,6 +363,9 @@ def _tosa_pipeline(
360363
ConvertToClampPass(),
361364
DecomposeTOSAUnsupportedClampPass(),
362365
DecomposeGroupNormPass(),
366+
DecomposeGruPass(),
367+
DecomposeLstmPass(),
368+
DecomposeRnnPass(),
363369
DecomposeLayerNormPass(),
364370
DecomposeVarPass(),
365371
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
@@ -578,6 +584,9 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
578584
self.add_passes(
579585
[
580586
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
587+
DecomposeGruPass(tfa_pass=True),
588+
DecomposeLstmPass(tfa_pass=True),
589+
DecomposeRnnPass(tfa_pass=True),
581590
DecomposeNotEqualPass(tfa_pass=True),
582591
DecomposeCosineSimilarityPass(tfa_pass=True),
583592
DecomposeGluPass(tfa_pass=True),

0 commit comments

Comments
 (0)