Skip to content

Commit c548f6f

Browse files
committed
Address remaining PR review comments in triton_fa.py
- quantize_p: raise NotImplementedError when any of q/k/v requires_grad, since backward does not model the quantized P path (inference-only).
- b_start_loc_k: only synthesize a dummy tensor in paged mode; raise ValueError in the contiguous separate-KV path when b_start_loc_k is None. Also validate that v_cache and block_table are provided alongside k_cache.

Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 4248327 commit c548f6f

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

modelopt/torch/kernels/triton_fa.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -985,10 +985,16 @@ def forward(
985985
b_start_loc_k = b_start_loc
986986
max_input_len_k = max_input_len
987987

988-
# Paged mode: b_start_loc_k may be None (KV is in paged cache, not contiguous).
989-
# Provide a dummy tensor so Triton can compile the tl.load (it won't be used).
990-
if b_start_loc_k is None:
991-
b_start_loc_k = torch.zeros_like(b_start_loc)
988+
if is_paged:
989+
if v_cache is None or block_table is None:
990+
raise ValueError("k_cache, v_cache, and block_table must all be provided together")
991+
# Paged mode: b_start_loc_k is never dereferenced, but Triton still needs a tensor.
992+
if b_start_loc_k is None:
993+
b_start_loc_k = torch.zeros_like(b_start_loc)
994+
elif b_start_loc_k is None and b_seq_len_k is not None:
995+
raise ValueError(
996+
"b_start_loc_k is required when K/V are passed as a separate packed tensor"
997+
)
992998

993999
# Pre-multiply scale by log2(e) so the kernel can use exp2()
9941000
# exp(score * sm_scale) = exp2(score * sm_scale * log2(e))
@@ -1012,6 +1018,11 @@ def forward(
10121018
# Therefore the threshold in kernel (log2) space is simply log2(lambda).
10131019
# Do NOT multiply by sm_scale — that factor is already absorbed into the
10141020
# log2(e) conversion above.
1021+
if quantize_p and (q.requires_grad or k.requires_grad or v.requires_grad):
1022+
raise NotImplementedError(
1023+
"quantize_p supports inference only; backward does not model the quantized P path"
1024+
)
1025+
10151026
apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
10161027
if apply_skip:
10171028
skip_threshold_log2 = math.log2(skip_softmax_threshold)

0 commit comments

Comments
 (0)