Skip to content

Commit d5209fc

Browse files
committed
Enable split-K decode SDPA by default with --no-splitk opt-out
Add a `use_splitk_decode` config flag to control whether FullAttention uses the split-K (flash-decoding) SDPA kernel or the tiled SDPA for decode (T=1). The split-K kernel partitions the KV sequence across CTAs, yielding ~20% higher decode throughput on H100:

| Variant      | Decode tok/s (avg across prompts) |
|--------------|-----------------------------------|
| Tiled SDPA   | 88.5                              |
| Split-K SDPA | 107.5 (+21%)                      |

The flag defaults to True (split-K on). Pass `--no-splitk` at export time to disable. Quality is verified identical at temperature=0. This PR was authored with the assistance of Claude.
1 parent ff207ea commit d5209fc

2 files changed

Lines changed: 23 additions & 4 deletions

File tree

examples/models/qwen3_5_moe/export.py

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ def _prepare_and_quantize_mlx(model, config, args):
6868
pack_all_switch_linears(model)
6969

7070

71-
def load_and_quantize(args):
71+
def load_and_quantize(args): # noqa: C901
7272
"""Load model from checkpoint, optionally quantize.
7373
7474
For CUDA: quantizes experts with packed INT4, then transformer layers on CUDA.
@@ -77,6 +77,7 @@ def load_and_quantize(args):
7777
Returns (model, config) ready for export.
7878
"""
7979
backend = getattr(args, "backend", "cuda")
80+
use_splitk = not getattr(args, "no_splitk", False)
8081

8182
if not args.prequantized:
8283
if getattr(args, "tiny_test", False):
@@ -111,6 +112,7 @@ def load_and_quantize(args):
111112
rms_norm_eps=1e-6,
112113
rope_theta=10_000.0,
113114
max_seq_len=64,
115+
use_splitk_decode=use_splitk,
114116
)
115117
print("Building tiny model with random weights...")
116118
torch.manual_seed(42)
@@ -133,6 +135,10 @@ def load_and_quantize(args):
133135
model, config = Qwen35MoE.from_hf_checkpoint(
134136
args.model_dir, max_seq_len=args.max_seq_len
135137
)
138+
config.use_splitk_decode = use_splitk
139+
for layer in model.layers:
140+
if hasattr(layer.attn, "use_splitk_decode"):
141+
layer.attn.use_splitk_decode = use_splitk
136142
model.eval()
137143
print(
138144
f"Model: {config.num_hidden_layers} layers, {config.hidden_size}d, "
@@ -148,7 +154,11 @@ def load_and_quantize(args):
148154

149155
elif backend == "cuda":
150156
if args.prequantized:
151-
return load_prequantized_model(args.prequantized, args.max_seq_len)
157+
return load_prequantized_model(
158+
args.prequantized,
159+
args.max_seq_len,
160+
use_splitk_decode=use_splitk,
161+
)
152162

153163
# CUDA: quantize experts with packed INT4 for Triton kernel
154164
if args.qlinear or args.qembedding:
@@ -162,12 +172,13 @@ def load_and_quantize(args):
162172
return model, config
163173

164174

165-
def load_prequantized_model(prequantized_dir, max_seq_len=4096):
175+
def load_prequantized_model(prequantized_dir, max_seq_len=4096, use_splitk_decode=True):
166176
"""Load a prequantized safetensors bundle into a model.
167177
168178
Args:
169179
prequantized_dir: Directory containing model.safetensors and config.json.
170180
max_seq_len: Maximum sequence length for KV cache.
181+
use_splitk_decode: Use split-K SDPA for decode instead of tiled SDPA.
171182
172183
Returns:
173184
(model, config) ready for export.
@@ -181,6 +192,7 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
181192

182193
config = Qwen35MoEConfig.from_hf_config(config_path)
183194
config.max_seq_len = max_seq_len
195+
config.use_splitk_decode = use_splitk_decode
184196

185197
print(f"Loading prequantized weights from {safetensors_path}...")
186198
state_dict = load_quantized_state_dict(safetensors_path)
@@ -789,6 +801,11 @@ def main():
789801
"No checkpoint download needed. Tests all architectural features "
790802
"(GQA, GDN head ratio, mixed attention, MoE routing) at small scale.",
791803
)
804+
parser.add_argument(
805+
"--no-splitk",
806+
action="store_true",
807+
help="Disable split-K (flash-decoding) SDPA for decode; use tiled SDPA instead.",
808+
)
792809
args = parser.parse_args()
793810

794811
if args.model_id:

examples/models/qwen3_5_moe/model.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@ class Qwen35MoEConfig:
5050
rms_norm_eps: float = 1e-6
5151
rope_theta: float = 10_000_000.0
5252
max_seq_len: int = 4096
53+
use_splitk_decode: bool = True
5354
layer_types: list = field(default_factory=list)
5455

5556
def __post_init__(self):
@@ -231,6 +232,7 @@ def __init__(self, config):
231232

232233
self.kv_cache = KVCache(self.n_kv_heads, self.head_dim, config.max_seq_len)
233234
self.turboquant = False
235+
self.use_splitk_decode = config.use_splitk_decode
234236

235237
self.register_buffer(
236238
"cache_positions",
@@ -289,7 +291,7 @@ def forward(self, x, input_pos):
289291
# The export produces two methods — decode (T=1, static) and
290292
# prefill (T>=2, dynamic). Each traces only one branch, so no
291293
# torch.cond is needed and we avoid GPU→CPU sync overhead.
292-
if T == 1:
294+
if T == 1 and self.use_splitk_decode:
293295
from executorch.backends.cuda.triton.kernels.sdpa import (
294296
sdpa_decode_splitk,
295297
)

0 commit comments

Comments (0)