Commit eee6926

feich-ms and claude committed
Add ORT_ENFORCE for past KV non-null in kv_empty path
KV-shared layers always reuse another layer's cache, so past_key and past_value must be present. Make this invariant explicit rather than silently falling through to empty internal tensors.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
1 parent 068b686 commit eee6926

1 file changed

Lines changed: 6 additions & 7 deletions

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

@@ -471,13 +471,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
     // Skip CopyKVCache and fused split+rotary+copyKV.
     // Use past_key/past_value directly as the present buffers for attention.
     ORT_ENFORCE(!do_rotary, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV should not be used with kv_sequence_length==0.");
-    if (past_key != nullptr && past_value != nullptr) {
-      // Alias past as present — flash attention only reads present_key/present_value,
-      // and CopyKVCache is skipped when kv_empty, so no writes occur through these pointers.
-      present_key = const_cast<Tensor*>(past_key);
-      present_value = const_cast<Tensor*>(past_value);
-    }
-    // If past is also null, present_key/present_value were already set to internal empty tensors above.
+    ORT_ENFORCE(past_key != nullptr && past_value != nullptr,
+                "kv_empty path requires past KV context (KV-shared layers reuse another layer's cache).");
+    // Alias past as present — flash attention only reads present_key/present_value,
+    // and CopyKVCache is skipped when kv_empty, so no writes occur through these pointers.
+    present_key = const_cast<Tensor*>(past_key);
+    present_value = const_cast<Tensor*>(past_value);
     ORT_ENFORCE(!parameters.past_present_share_buffer_,
                 "kv_empty path must not use past_present_share_buffer (CopyKVCache is skipped).");
   } else if (do_rotary) {
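
The change follows an enforce-then-alias pattern: check the invariant first, then alias the read-only past buffers as present buffers unconditionally. The standalone sketch below illustrates that pattern outside of ONNX Runtime. `Tensor`, `PresentKV`, `SKETCH_ENFORCE`, `AliasPastAsPresent`, and `RejectsMissingPast` are hypothetical stand-ins introduced for illustration and are not part of the actual code; the macro assumes that a failed check should throw, which is how `ORT_ENFORCE` typically behaves in builds with exceptions enabled.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-in for onnxruntime's Tensor; only pointer identity matters here.
struct Tensor {
  int id;
};

// Hypothetical stand-in for ORT_ENFORCE: throws when the condition is false.
#define SKETCH_ENFORCE(cond, msg)               \
  do {                                          \
    if (!(cond)) throw std::runtime_error(msg); \
  } while (0)

// Hypothetical holder for the aliased present buffers.
struct PresentKV {
  Tensor* key;
  Tensor* value;
};

// Aliases past KV as present KV, as in the kv_empty path.
// A missing past KV is rejected instead of silently falling back to empty tensors.
PresentKV AliasPastAsPresent(const Tensor* past_key, const Tensor* past_value) {
  SKETCH_ENFORCE(past_key != nullptr && past_value != nullptr,
                 "kv_empty path requires past KV context");
  // Read-only aliasing: callers must not write through these pointers.
  return PresentKV{const_cast<Tensor*>(past_key), const_cast<Tensor*>(past_value)};
}

// Returns true if a null past KV is rejected with an exception.
bool RejectsMissingPast() {
  try {
    AliasPastAsPresent(nullptr, nullptr);
  } catch (const std::runtime_error&) {
    return true;
  }
  return false;
}
```

Calling `AliasPastAsPresent(&k, &v)` returns pointers to the same objects, while `AliasPastAsPresent(nullptr, nullptr)` throws. The `const_cast` is only safe under the condition stated in the diff comment: nothing writes through the aliased pointers.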