Skip to content

Commit 3069c79

Browse files
committed
Enable split-K decode SDPA by default with --no-splitk opt-out
Add `use_splitk_decode` config flag to control whether FullAttention uses the split-K (flash-decoding) SDPA kernel or the tiled SDPA for decode (T=1). The split-K kernel partitions the KV sequence across CTAs, yielding ~20% higher decode throughput on H100. Benchmark (decode tok/s, averaged across prompts): tiled SDPA 88.5; split-K SDPA 107.5 (+21%). The flag defaults to True (split-K on). Pass `--no-splitk` at export time to disable. Quality is verified identical at temperature=0. This PR was authored with the assistance of Claude.
1 parent 7807de4 commit 3069c79

2 files changed

Lines changed: 24 additions & 3 deletions

File tree

examples/models/qwen3_5_moe/export.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def load_and_quantize(args):
7777
Returns (model, config) ready for export.
7878
"""
7979
backend = getattr(args, "backend", "cuda")
80+
use_splitk = not getattr(args, "no_splitk", False)
8081

8182
if not args.prequantized:
8283
if getattr(args, "tiny_test", False):
@@ -111,6 +112,7 @@ def load_and_quantize(args):
111112
rms_norm_eps=1e-6,
112113
rope_theta=10_000.0,
113114
max_seq_len=64,
115+
use_splitk_decode=use_splitk,
114116
)
115117
print("Building tiny model with random weights...")
116118
torch.manual_seed(42)
@@ -133,6 +135,10 @@ def load_and_quantize(args):
133135
model, config = Qwen35MoE.from_hf_checkpoint(
134136
args.model_dir, max_seq_len=args.max_seq_len
135137
)
138+
config.use_splitk_decode = use_splitk
139+
for layer in model.layers:
140+
if hasattr(layer.attn, "use_splitk_decode"):
141+
layer.attn.use_splitk_decode = use_splitk
136142
model.eval()
137143
print(
138144
f"Model: {config.num_hidden_layers} layers, {config.hidden_size}d, "
@@ -148,7 +154,11 @@ def load_and_quantize(args):
148154

149155
elif backend == "cuda":
150156
if args.prequantized:
151-
return load_prequantized_model(args.prequantized, args.max_seq_len)
157+
return load_prequantized_model(
158+
args.prequantized,
159+
args.max_seq_len,
160+
use_splitk_decode=use_splitk,
161+
)
152162

153163
# CUDA: quantize experts with packed INT4 for Triton kernel
154164
if args.qlinear or args.qembedding:
@@ -162,12 +172,15 @@ def load_and_quantize(args):
162172
return model, config
163173

164174

165-
def load_prequantized_model(prequantized_dir, max_seq_len=4096):
175+
def load_prequantized_model(
176+
prequantized_dir, max_seq_len=4096, use_splitk_decode=True
177+
):
166178
"""Load a prequantized safetensors bundle into a model.
167179
168180
Args:
169181
prequantized_dir: Directory containing model.safetensors and config.json.
170182
max_seq_len: Maximum sequence length for KV cache.
183+
use_splitk_decode: Use split-K SDPA for decode instead of tiled SDPA.
171184
172185
Returns:
173186
(model, config) ready for export.
@@ -181,6 +194,7 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
181194

182195
config = Qwen35MoEConfig.from_hf_config(config_path)
183196
config.max_seq_len = max_seq_len
197+
config.use_splitk_decode = use_splitk_decode
184198

185199
print(f"Loading prequantized weights from {safetensors_path}...")
186200
state_dict = load_quantized_state_dict(safetensors_path)
@@ -783,6 +797,11 @@ def main():
783797
"No checkpoint download needed. Tests all architectural features "
784798
"(GQA, GDN head ratio, mixed attention, MoE routing) at small scale.",
785799
)
800+
parser.add_argument(
801+
"--no-splitk",
802+
action="store_true",
803+
help="Disable split-K (flash-decoding) SDPA for decode; use tiled SDPA instead.",
804+
)
786805
args = parser.parse_args()
787806

788807
if args.model_id:

examples/models/qwen3_5_moe/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class Qwen35MoEConfig:
5151
rms_norm_eps: float = 1e-6
5252
rope_theta: float = 10_000_000.0
5353
max_seq_len: int = 4096
54+
use_splitk_decode: bool = True
5455
layer_types: list = field(default_factory=list)
5556

5657
def __post_init__(self):
@@ -232,6 +233,7 @@ def __init__(self, config):
232233

233234
self.kv_cache = KVCache(self.n_kv_heads, self.head_dim, config.max_seq_len)
234235
self.turboquant = False
236+
self.use_splitk_decode = config.use_splitk_decode
235237

236238
self.register_buffer(
237239
"cache_positions",
@@ -290,7 +292,7 @@ def forward(self, x, input_pos):
290292
# The export produces two methods — decode (T=1, static) and
291293
# prefill (T>=2, dynamic). Each traces only one branch, so no
292294
# torch.cond is needed and we avoid GPU→CPU sync overhead.
293-
if T == 1:
295+
if T == 1 and self.use_splitk_decode:
294296
y = sdpa_decode_splitk(q, k, v, attn_mask=attn_mask)
295297
else:
296298
y = sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=True)

0 commit comments

Comments
 (0)