Skip to content

Commit 5fd8020

Browse files
[Cherry-Pick][BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn (#7216)
1 parent 9c65655 commit 5fd8020

2 files changed

Lines changed: 1 addition & 2 deletions

File tree

custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
5454
PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
5555
PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
5656
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
57-
PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape");
5857
PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape");
5958

6059
constexpr int kBlockM = 128;

fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def forward_mixed(
307307
q,
308308
k,
309309
v,
310-
forward_meta.cu_seqlens_q,
310+
forward_meta.cu_seqlens_q[: forward_meta.attn_cu_seqlens_k.shape[0]],
311311
forward_meta.attn_cu_seqlens_k,
312312
forward_meta.seq_lens_encoder,
313313
res_encoder,

0 commit comments

Comments
 (0)