
Commit a8ef2a4

low mem export controlled by compile spec
1 parent 2163d1c commit a8ef2a4

5 files changed

Lines changed: 127 additions & 37 deletions

File tree

backends/aoti/aoti_backend.py
backends/cuda/benchmarks/benchmark_int4_matvec.py
backends/cuda/cuda_backend.py
backends/cuda/triton/kernels/fused_moe.py
examples/models/qwen3_5_moe/export.py

backends/aoti/aoti_backend.py

Lines changed: 7 additions & 3 deletions
@@ -88,8 +88,12 @@ def save_data_externally(cls) -> bool:
         return False
 
     @classmethod
-    def get_extra_aoti_compile_context_manager(cls):
-        """Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager."""
+    def get_extra_aoti_compile_context_manager(cls, compile_specs: List[CompileSpec]):
+        """Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager.
+
+        Subclasses may inspect ``compile_specs`` to opt into behaviors that
+        only apply to specific methods/models (e.g. low-memory export).
+        """
         return contextlib.nullcontext()
 
     @classmethod
@@ -195,7 +199,7 @@ def preprocess(
         # Compile with fallback kernel collection
         with cls.collect_unsupported_fallback_kernels(
             missing_fallback_kernels
-        ), torch.no_grad(), cls.get_extra_aoti_compile_context_manager():
+        ), torch.no_grad(), cls.get_extra_aoti_compile_context_manager(compile_specs):
             paths = torch._inductor.aot_compile(
                 edge_program_module, tuple(user_input_placeholders), options=options
             )

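To illustrate the new hook signature, here is a minimal sketch of a subclass that keys extra compile-time behavior off a compile spec. The MyAotiBackend class, the "verbose_logging" spec key, and the _log_compile helper are hypothetical, and the import paths are only inferred from the file paths in this commit:

import contextlib
from typing import List

from executorch.backends.aoti.aoti_backend import AotiBackend  # assumed import path
from executorch.exir.backend.compile_spec_schema import CompileSpec


@contextlib.contextmanager
def _log_compile():
    # Hypothetical helper: print markers around the aoti_compile stage.
    print("entering aoti_compile")
    try:
        yield
    finally:
        print("leaving aoti_compile")


class MyAotiBackend(AotiBackend):
    @classmethod
    def get_extra_aoti_compile_context_manager(cls, compile_specs: List[CompileSpec]):
        # Opt in only when a spec explicitly asks for it; otherwise keep the
        # default no-op behavior of the base class.
        for spec in compile_specs:
            if spec.key == "verbose_logging" and spec.value == b"ON":
                return _log_compile()
        return contextlib.nullcontext()
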
backends/cuda/benchmarks/benchmark_int4_matvec.py

Whitespace-only changes.

backends/cuda/cuda_backend.py

Lines changed: 52 additions & 17 deletions
@@ -332,12 +332,28 @@ def get_aoti_compile_options(
         return options
 
     @classmethod
-    def get_extra_aoti_compile_context_manager(cls):
+    def get_extra_aoti_compile_context_manager(cls, compile_specs: List[CompileSpec]):
         """
         Combine all extra context managers needed during AOTInductor
         compilation for the CUDA backend. Each manager is documented at
         its own `enter_context` call site below.
+
+        The low-memory export monkey-patch (CPU clones for mutated buffers)
+        is gated on the ``low_memory_mode`` compile spec — only models that
+        explicitly opt in (currently Qwen3.5 MoE) get it. Other models go
+        through the unmodified AOTI codepath, which avoids regressions in
+        their cuda CI exports.
         """
+        # Parse compile_specs for low_memory_mode (default OFF)
+        low_memory_mode = "OFF"
+        for spec in compile_specs:
+            if spec.key == "low_memory_mode":
+                mode = spec.value.decode("utf-8").upper()
+                if mode not in ["ON", "OFF"]:
+                    raise ValueError(
+                        f"Invalid low_memory_mode: {mode}. Expected 'ON' or 'OFF'."
+                    )
+                low_memory_mode = mode
 
         @contextlib.contextmanager
         def _combined():
@@ -348,16 +364,30 @@ def _combined():
                 # `ReplaceEdgeOpWithTritonOpPass` are unaffected; this is
                 # only the fallback for the `triton_kernel_mode="OFF"` path.
                 stack.enter_context(torch.nn.attention.sdpa_kernel([SDPBackend.MATH]))
-                # Force AOTI's mutated-buffer clones onto CPU during compile
-                # so we stay under tight GPU memory caps (e.g. 24 GB on a
-                # consumer 4090). See `_compile_time_cpu_clones` for details.
-                stack.enter_context(
-                    _compile_time_cpu_clones(torch.device(cls.get_device_name()))
-                )
+                if low_memory_mode == "ON":
+                    # Force AOTI's mutated-buffer clones onto CPU during
+                    # compile so we stay under tight GPU memory caps (e.g.
+                    # 24 GB on a consumer 4090). See
+                    # `_compile_time_cpu_clones` for details. Only enabled
+                    # for models that explicitly opt in via the
+                    # `low_memory_mode="ON"` compile spec, since the
+                    # monkey-patch can interact poorly with other models'
+                    # AOTI compile pipelines.
+                    stack.enter_context(
+                        _compile_time_cpu_clones(torch.device(cls.get_device_name()))
+                    )
                 yield
 
         return _combined()
 
+    @staticmethod
+    def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
+        """Return True if any compile spec opts into low-memory export."""
+        for spec in compile_specs:
+            if spec.key == "low_memory_mode":
+                return spec.value.decode("utf-8").upper() == "ON"
+        return False
+
     @classmethod
     def preprocess_multimethod(
         cls,
@@ -369,6 +399,11 @@ def preprocess_multimethod(
         between methods (e.g. decode then prefill). Inductor caches hold CUDA
         tensors from the first compilation, causing the second to OOM under
         tight VRAM caps (e.g. 24GB simulating an RTX 4090).
+
+        The aggressive cleanup (resizing every CUDA tensor's storage to 0)
+        is only enabled for methods that opt into ``low_memory_mode="ON"``
+        — it can otherwise break models that expect their CUDA tensors to
+        stay live across method preprocessing.
         """
         import gc
 
@@ -384,17 +419,17 @@ def preprocess_multimethod(
                 preprocess_result = cls.preprocess(program, compile_spec_for_program)
                 results_for_method.append(preprocess_result)
 
-                # Aggressive GPU cleanup between methods
+                # GPU cleanup between methods. Aggressive storage resize is
+                # only run for methods that opt into low-memory mode.
                 if torch.cuda.is_available():
-                    gc.collect()
-                    freed = 0
-                    for obj in gc.get_objects():
-                        if isinstance(obj, torch.Tensor) and obj.is_cuda:
-                            try:
-                                obj.untyped_storage().resize_(0)
-                                freed += 1
-                            except Exception:
-                                pass
+                    if cls._is_low_memory_mode(compile_spec_for_program):
+                        gc.collect()
+                        for obj in gc.get_objects():
+                            if isinstance(obj, torch.Tensor) and obj.is_cuda:
+                                try:
+                                    obj.untyped_storage().resize_(0)
+                                except Exception:
+                                    pass
                     gc.collect()
                     torch.cuda.empty_cache()

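As a quick check of the gating semantics above, here is a self-contained sketch of the same parse logic. The CompileSpec stand-in below is a simplified substitute for the real class in executorch.exir.backend.compile_spec_schema, defined here only so the snippet runs on its own:

from dataclasses import dataclass
from typing import List


@dataclass
class CompileSpec:
    # Simplified stand-in for executorch's CompileSpec (key: str, value: bytes).
    key: str
    value: bytes


def is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
    """Mirror of _is_low_memory_mode: default OFF, case-insensitive value."""
    for spec in compile_specs:
        if spec.key == "low_memory_mode":
            return spec.value.decode("utf-8").upper() == "ON"
    return False


print(is_low_memory_mode([CompileSpec("low_memory_mode", b"ON")]))   # True
print(is_low_memory_mode([CompileSpec("low_memory_mode", b"off")]))  # False
print(is_low_memory_mode([]))                                        # False (default)
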
backends/cuda/triton/kernels/fused_moe.py

Lines changed: 65 additions & 17 deletions
@@ -702,30 +702,74 @@ def moe_align_block_size(
 # Autotune configs for batched GEMM1 (gate+up projection).
 # BLOCK_M is fixed at _BATCHED_BLOCK_M; only N and K are tuned.
 _BATCHED_GEMM1_CONFIGS = [
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=3),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=2
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
     ),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=2
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=2,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=2,
     ),
 ]
 
 # Autotune configs for batched GEMM2 (down projection + SiLU).
 _BATCHED_GEMM2_CONFIGS = [
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=2),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=2
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=2,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=2,
     ),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=2
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=2,
     ),
 ]
 
@@ -831,7 +875,8 @@ def _fused_moe_batched_kernel(
                 B_scale
                 + expert_id * stride_bse
                 + offs_n[None, :] * stride_bsn
-                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
+                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
+                * stride_bsk
             )
             b_scale = tl.load(
                 scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
@@ -967,7 +1012,8 @@ def _fused_moe_batched_int8_kernel(
                 B_scale
                 + expert_id * stride_bse
                 + offs_n[None, :] * stride_bsn
-                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
+                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
+                * stride_bsk
             )
             b_scale = tl.load(
                 scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
@@ -1085,7 +1131,8 @@ def _fused_moe_silu_batched_kernel(
                 B_scale
                 + expert_id * stride_bse
                 + offs_n[None, :] * stride_bsn
-                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
+                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
+                * stride_bsk
             )
             b_scale = tl.load(
                 scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
@@ -1227,7 +1274,8 @@ def _fused_moe_silu_batched_int8_kernel(
                 B_scale
                 + expert_id * stride_bse
                 + offs_n[None, :] * stride_bsn
-                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
+                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
+                * stride_bsk
             )
             b_scale = tl.load(
                 scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0

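The fused_moe.py changes above are formatting only. For context, this is roughly how such config lists are consumed by Triton's autotuner; the kernel below is a made-up placeholder whose body and key names are illustrative, not taken from fused_moe.py:

import triton
import triton.language as tl

# Illustrative config list in the same shape as _BATCHED_GEMM1_CONFIGS.
_EXAMPLE_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
        num_warps=4,
        num_stages=3,
    ),
    triton.Config(
        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16},
        num_warps=4,
        num_stages=2,
    ),
]


# triton.autotune benchmarks each config per (N, K) shape and injects the
# winning dict entries as tl.constexpr parameters of the kernel.
@triton.autotune(configs=_EXAMPLE_CONFIGS, key=["N", "K"])
@triton.jit
def _example_copy_kernel(
    x_ptr,
    N,
    K,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Trivial body; a real kernel would tile its loads and stores by these sizes.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    vals = tl.load(x_ptr + offs, mask=offs < N, other=0.0)
    tl.store(x_ptr + offs, vals, mask=offs < N)
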
examples/models/qwen3_5_moe/export.py

Lines changed: 3 additions & 0 deletions
@@ -886,6 +886,7 @@ def _export_cuda(model, config, args):
         ExecutorchBackendConfig,
         to_edge_transform_and_lower,
     )
+    from executorch.exir.backend.compile_spec_schema import CompileSpec
     from executorch.exir.passes import MemoryPlanningPass
     from torch.export import Dim, export
 
@@ -959,13 +960,15 @@ def _export_cuda(model, config, args):
             CudaPartitioner(
                 [
                     CudaBackend.generate_method_name_compile_spec("decode"),
+                    CompileSpec("low_memory_mode", b"ON"),
                 ]
             )
         ],
         "prefill": [
             CudaPartitioner(
                 [
                     CudaBackend.generate_method_name_compile_spec("prefill"),
+                    CompileSpec("low_memory_mode", b"ON"),
                 ]
             )
         ],

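Because both method entries now carry the same pair of specs, a small helper could keep them in sync. The helper name and import paths below are assumptions; only generate_method_name_compile_spec and CompileSpec("low_memory_mode", b"ON") come from this diff:

from typing import List

from executorch.backends.cuda.cuda_backend import CudaBackend  # assumed import path
from executorch.exir.backend.compile_spec_schema import CompileSpec


def low_memory_specs(method_name: str) -> List[CompileSpec]:
    # Hypothetical convenience helper: method-name spec plus the low-memory opt-in.
    return [
        CudaBackend.generate_method_name_compile_spec(method_name),
        CompileSpec("low_memory_mode", b"ON"),
    ]


# Usage mirroring the export script above:
# partitioners = {
#     "decode": [CudaPartitioner(low_memory_specs("decode"))],
#     "prefill": [CudaPartitioner(low_memory_specs("prefill"))],
# }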