@@ -24,8 +24,8 @@ void FlashAttnUnpaddedKernel(const paddle::Tensor& q,
2424 int num_heads,
2525 int head_dim,
2626 int num_kv_heads,
27- int max_seqlens_q ,
28- int max_seqlens_k ,
27+ const paddle::Tensor& max_seqlens_q_ ,
28+ const paddle::Tensor& max_seqlens_k_ ,
2929 bool causal,
3030 float scale,
3131 paddle::Tensor& out) {
@@ -148,10 +148,13 @@ void FlashAttnUnpaddedKernel(const paddle::Tensor& q,
148148 cuinferTensorDescriptor_t lse_desc;
149149 CUINFER_CHECK (cuinferCreateTensorDescriptor (&lse_desc));
150150
151+ const int32_t * max_seqlens_q = max_seqlens_q_.data <int32_t >();
152+ const int32_t * max_seqlens_k = max_seqlens_k_.data <int32_t >();
153+
151154 FmhaFwdFuncArguments args;
152155 args.batch = batch_size;
153- args.max_seqlen_q = max_seqlens_q;
154- args.max_seqlen_k = max_seqlens_k;
156+ args.max_seqlen_q = * max_seqlens_q;
157+ args.max_seqlen_k = * max_seqlens_k;
155158 args.is_causal = causal;
156159 args.scaling = scale;
157160 args.window_size_left = -1 ;
@@ -197,8 +200,8 @@ std::vector<paddle::Tensor> FlashAttnUnpadded(
197200 const paddle::Tensor& v,
198201 const paddle::Tensor& cu_seqlens_q,
199202 const paddle::Tensor& cu_seqlens_k,
200- int max_seqlens_q,
201- int max_seqlens_k,
203+ const paddle::Tensor& max_seqlens_q,
204+ const paddle::Tensor& max_seqlens_k,
202205 bool causal,
203206 float scale,
204207 bool training) {
@@ -248,21 +251,31 @@ std::vector<paddle::Tensor> FlashAttnUnpadded(
248251}
249252
// Shape-inference hook for the unpadded flash-attention op.
//
// The op has a single output whose shape is the first three dimensions of
// `q_shape`. The remaining shape arguments are accepted only so that the
// parameter list lines up with the op's inputs; none of them is read here.
//
// NOTE(review): `q_shape` is indexed at 0..2 without a size check, so callers
// must pass a shape with at least three dimensions.
std::vector<std::vector<int64_t>> FlashAttnUnpaddedInferShape(
    const std::vector<int64_t>& q_shape,
    const std::vector<int64_t>& k_shape,
    const std::vector<int64_t>& v_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& cu_seqlens_k_shape,
    const std::vector<int64_t>& max_seqlens_q_shape,
    const std::vector<int64_t>& max_seqlens_k_shape) {
  std::vector<std::vector<int64_t>> output_shapes;
  output_shapes.push_back({q_shape[0], q_shape[1], q_shape[2]});
  return output_shapes;
}
254263
255264std::vector<paddle::DataType> FlashAttnUnpaddedInferDtype (
256- const paddle::DataType& q_dtype) {
265+ const paddle::DataType& q_dtype,
266+ const paddle::DataType& k_dtype,
267+ const paddle::DataType& v_dtype,
268+ const paddle::DataType& cu_seqlens_q_dtype,
269+ const paddle::DataType& cu_seqlens_v_dtype,
270+ const paddle::DataType& max_seqlens_q_dtype,
271+ const paddle::DataType& max_seqlens_k_dtype) {
257272 return {q_dtype};
258273}
259274
260275PD_BUILD_STATIC_OP (cuinfer_flash_attn_unpadded)
261- .Inputs({" q" , " k" , " v" , " cu_seqlens_q" , " cu_seqlens_k" })
276+ .Inputs({" q" , " k" , " v" , " cu_seqlens_q" , " cu_seqlens_k" , " max_seqlens_q " , " max_seqlens_k " })
262277 .Outputs({" out" })
263- .Attrs({" max_seqlens_q:int" ,
264- " max_seqlens_k:int" ,
265- " causal:bool" ,
278+ .Attrs({" causal:bool" ,
266279 " scale:float" ,
267280 " training:bool" })
268281 .SetKernelFn(PD_KERNEL (FlashAttnUnpadded))
0 commit comments