|
| 1 | +#ifndef INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ |
| 2 | +#define INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ |
| 3 | + |
| 4 | +#include <cassert> |
| 5 | +#include <cstddef> |
| 6 | +#include <vector> |
| 7 | + |
| 8 | +#include "acl/acl.h" |
| 9 | +#include "aclnn/aclnn_base.h" |
| 10 | +#include "aclnnop/aclnn_fused_infer_attention_score_v4.h" |
| 11 | +#include "ascend/common.h" |
| 12 | +#include "ascend/workspace_pool_.h" |
| 13 | +#include "base/flash_attention.h" |
| 14 | +#include "operator.h" |
| 15 | + |
| 16 | +namespace infini::ops { |
| 17 | + |
| 18 | +namespace detail { |
| 19 | + |
| 20 | +// Extract cu_seqlens differences to a host aclIntArray. |
| 21 | +// cu_seqlens = [0, s1, s1+s2, ...] -> per_seq_lens = [s1, s2, ...]. |
| 22 | +// Used by paged decode (actualSeqLengthsKv = per-sequence KV lengths). |
| 23 | +// |
| 24 | +// When cu_seqlens is a CPU tensor (device type kCpu), the data pointer is |
| 25 | +// already on the host and can be read directly — no D2H sync needed. |
| 26 | +inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens, |
| 27 | + aclrtStream stream) { |
| 28 | + auto n = cu_seqlens.numel(); |
| 29 | + |
| 30 | + const int64_t* cu_host_ptr = nullptr; |
| 31 | + std::vector<int64_t> cu_host_buf; |
| 32 | + |
| 33 | + if (cu_seqlens.device().type() == Device::Type::kCpu) { |
| 34 | + cu_host_ptr = static_cast<const int64_t*>(cu_seqlens.data()); |
| 35 | + } else { |
| 36 | + cu_host_buf.resize(n); |
| 37 | + aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), cu_seqlens.data(), |
| 38 | + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); |
| 39 | + aclrtSynchronizeStream(stream); |
| 40 | + cu_host_ptr = cu_host_buf.data(); |
| 41 | + } |
| 42 | + |
| 43 | + std::vector<int64_t> lengths(n - 1); |
| 44 | + for (size_t i = 0; i < lengths.size(); ++i) { |
| 45 | + lengths[i] = cu_host_ptr[i + 1] - cu_host_ptr[i]; |
| 46 | + } |
| 47 | + |
| 48 | + return aclCreateIntArray(lengths.data(), |
| 49 | + static_cast<int64_t>(lengths.size())); |
| 50 | +} |
| 51 | + |
| 52 | +// Extract cumulative end positions from cu_seqlens to a host aclIntArray. |
| 53 | +// cu_seqlens = [0, s1, s1+s2, ...] -> cum_lens = [s1, s1+s2, ...]. |
| 54 | +// FIA V4 TND varlen uses cumulative end positions, matching the vllm-ascend |
| 55 | +// convention for npu_fused_infer_attention_score actual_seq_lengths. |
| 56 | +// |
| 57 | +// When cu_seqlens is a CPU tensor, reads directly from host memory. |
| 58 | +inline aclIntArray* cumSeqLengths(const Tensor& cu_seqlens, |
| 59 | + aclrtStream stream) { |
| 60 | + auto n = cu_seqlens.numel(); |
| 61 | + |
| 62 | + const int64_t* cu_host_ptr = nullptr; |
| 63 | + std::vector<int64_t> cu_host_buf; |
| 64 | + |
| 65 | + if (cu_seqlens.device().type() == Device::Type::kCpu) { |
| 66 | + cu_host_ptr = static_cast<const int64_t*>(cu_seqlens.data()); |
| 67 | + } else { |
| 68 | + cu_host_buf.resize(n); |
| 69 | + aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), cu_seqlens.data(), |
| 70 | + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); |
| 71 | + aclrtSynchronizeStream(stream); |
| 72 | + cu_host_ptr = cu_host_buf.data(); |
| 73 | + } |
| 74 | + |
| 75 | + // Skip the leading 0; return [s1, s1+s2, ...]. |
| 76 | + return aclCreateIntArray(cu_host_ptr + 1, static_cast<int64_t>(n - 1)); |
| 77 | +} |
| 78 | + |
| 79 | +// Allocate a 2048x2048 lower-triangular UINT8 causal mask on device. |
| 80 | +// Required for `sparseMode` >= 2. |
| 81 | +inline aclTensor* makeCausalMask(void** mask_buf, aclrtStream stream) { |
| 82 | + constexpr int64_t kMaskDim = 2048; |
| 83 | + const int64_t mask_elems = kMaskDim * kMaskDim; |
| 84 | + const size_t mask_bytes = static_cast<size_t>(mask_elems); // uint8_t |
| 85 | + |
| 86 | + aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); |
| 87 | + |
| 88 | + std::vector<uint8_t> host_mask(mask_elems); |
| 89 | + for (int64_t r = 0; r < kMaskDim; ++r) { |
| 90 | + for (int64_t c = 0; c < kMaskDim; ++c) { |
| 91 | + // 1 = masked out (upper triangle); 0 = attend (lower triangle). |
| 92 | + host_mask[r * kMaskDim + c] = (c > r) ? 1 : 0; |
| 93 | + } |
| 94 | + } |
| 95 | + aclrtMemcpyAsync(*mask_buf, mask_bytes, host_mask.data(), mask_bytes, |
| 96 | + ACL_MEMCPY_HOST_TO_DEVICE, stream); |
| 97 | + aclrtSynchronizeStream(stream); |
| 98 | + |
| 99 | + std::vector<int64_t> mask_shape = {kMaskDim, kMaskDim}; |
| 100 | + std::vector<int64_t> mask_strides = {kMaskDim, 1}; |
| 101 | + std::vector<int64_t> mask_storage = {mask_elems}; |
| 102 | + return aclCreateTensor(mask_shape.data(), 2, ACL_UINT8, mask_strides.data(), |
| 103 | + 0, ACL_FORMAT_ND, mask_storage.data(), 1, *mask_buf); |
| 104 | +} |
| 105 | + |
| 106 | +} // namespace detail |
| 107 | + |
| 108 | +template <> |
| 109 | +class Operator<FlashAttention, Device::Type::kAscend> : public FlashAttention { |
| 110 | + public: |
| 111 | + Operator(const Tensor query, const Tensor key, const Tensor value, |
| 112 | + std::optional<Tensor> cu_seqlens_q, |
| 113 | + std::optional<Tensor> cu_seqlens_kv, |
| 114 | + std::optional<Tensor> block_table, int64_t num_heads, |
| 115 | + int64_t num_kv_heads, int64_t head_size, double scale, bool causal, |
| 116 | + int64_t window_left, int64_t window_right, int64_t block_size, |
| 117 | + Tensor output, std::optional<int64_t> sliding_window = std::nullopt) |
| 118 | + : FlashAttention(query, key, value, cu_seqlens_q, cu_seqlens_kv, |
| 119 | + block_table, num_heads, num_kv_heads, head_size, scale, |
| 120 | + causal, window_left, window_right, block_size, output, |
| 121 | + sliding_window) { |
| 122 | + paged_ = block_table.has_value() && block_size > 0; |
| 123 | + aclDataType acl_dt = ascend::ToAclDtype(query.dtype()); |
| 124 | + |
| 125 | + if (!paged_) { |
| 126 | + // Prefill: cache Q and output (TND layout). |
| 127 | + prefill_q_cache_ = ascend::AclTensorCache(query); |
| 128 | + prefill_out_cache_ = ascend::AclTensorCache(output); |
| 129 | + |
| 130 | + // Pre-compute causal mask once (sparse_mode >= 2). Read the |
| 131 | + // resolved pair from base-class members so `sliding_window` |
| 132 | + // normalization is honored at cache-key construction. |
| 133 | + if (causal) { |
| 134 | + int64_t sm = (window_left_ >= 0) ? 4 : 3; |
| 135 | + if (sm >= 2) { |
| 136 | + causal_mask_ = detail::makeCausalMask(&causal_mask_buf_, nullptr); |
| 137 | + } |
| 138 | + } |
| 139 | + } else { |
| 140 | + // Decode: cache Q/output (BNSD), block_table. |
| 141 | + const int64_t N = query.size(1); |
| 142 | + const int64_t D = query.size(2); |
| 143 | + const int64_t B = query.size(0); |
| 144 | + |
| 145 | + decode_q_cache_ = ascend::AclTensorCache({B, N, 1, D}, acl_dt, |
| 146 | + const_cast<void*>(query.data())); |
| 147 | + decode_out_cache_ = |
| 148 | + ascend::AclTensorCache({B, N, 1, D}, acl_dt, output.data()); |
| 149 | + block_table_cache_ = ascend::AclTensorCache(block_table.value()); |
| 150 | + |
| 151 | + // Pre-compute KV reshape metadata. |
| 152 | + const int64_t nb = key.size(0); |
| 153 | + const int64_t bsz = key.size(1); |
| 154 | + const int64_t NkvD = key.size(2) * key.size(3); |
| 155 | + kv_shape_ = {nb, bsz, NkvD}; |
| 156 | + kv_strides_ = {bsz * NkvD, NkvD, 1}; |
| 157 | + kv_storage_shape_ = {nb * bsz * NkvD}; |
| 158 | + kv_acl_dt_ = acl_dt; |
| 159 | + } |
| 160 | + } |
| 161 | + |
| 162 | + ~Operator() { |
| 163 | + if (!ascend::IsAclRuntimeAlive()) return; |
| 164 | + |
| 165 | + if (causal_mask_) aclDestroyTensor(causal_mask_); |
| 166 | + if (causal_mask_buf_) aclrtFree(causal_mask_buf_); |
| 167 | + } |
| 168 | + |
| 169 | + void operator()(const Tensor query, const Tensor key, const Tensor value, |
| 170 | + std::optional<Tensor> cu_seqlens_q, |
| 171 | + std::optional<Tensor> cu_seqlens_kv, |
| 172 | + std::optional<Tensor> block_table, int64_t num_heads, |
| 173 | + int64_t num_kv_heads, int64_t head_size, double scale, |
| 174 | + bool causal, int64_t window_left, int64_t window_right, |
| 175 | + int64_t block_size, Tensor output, |
| 176 | + std::optional<int64_t> sliding_window) const override { |
| 177 | + auto stream = static_cast<aclrtStream>(stream_); |
| 178 | + const bool paged = paged_; |
| 179 | + |
| 180 | + // The base class stored the resolved window pair in `window_left_` / |
| 181 | + // `window_right_` at construction; prefer those over the call-site |
| 182 | + // args so that `sliding_window` is honored here as well. |
| 183 | + int64_t wl = window_left_; |
| 184 | + int64_t wr = window_right_; |
| 185 | + (void)window_left; |
| 186 | + (void)window_right; |
| 187 | + (void)sliding_window; |
| 188 | + |
| 189 | + int64_t sparse_mode; |
| 190 | + int64_t pre_tokens = 2147483647; |
| 191 | + int64_t next_tokens = 2147483647; |
| 192 | + if (causal) { |
| 193 | + if (wl >= 0) { |
| 194 | + sparse_mode = 4; |
| 195 | + pre_tokens = wl; |
| 196 | + next_tokens = 0; |
| 197 | + } else { |
| 198 | + sparse_mode = 3; |
| 199 | + next_tokens = 0; |
| 200 | + } |
| 201 | + } else { |
| 202 | + sparse_mode = 0; |
| 203 | + if (wl >= 0) pre_tokens = wl; |
| 204 | + if (wr >= 0) next_tokens = wr; |
| 205 | + } |
| 206 | + |
| 207 | + if (!paged) { |
| 208 | + // --- Prefill --- |
| 209 | + int64_t T = query.size(0); |
| 210 | + |
| 211 | + // cumSeqLengths / extractSeqLengths automatically skip D2H when |
| 212 | + // cu_seqlens is a CPU tensor (see detail:: helpers above). |
| 213 | + aclIntArray* seq_q = |
| 214 | + cu_seqlens_q.has_value() |
| 215 | + ? detail::cumSeqLengths(cu_seqlens_q.value(), stream) |
| 216 | + : aclCreateIntArray(&T, 1); |
| 217 | + aclIntArray* seq_kv = |
| 218 | + cu_seqlens_kv.has_value() |
| 219 | + ? detail::cumSeqLengths(cu_seqlens_kv.value(), stream) |
| 220 | + : aclCreateIntArray(&T, 1); |
| 221 | + |
| 222 | + aclTensor* t_q = prefill_q_cache_.get(const_cast<void*>(query.data())); |
| 223 | + // K/V descriptors go into TensorList which takes ownership — must be |
| 224 | + // per-call (cannot cache). |
| 225 | + aclTensor* t_k = ascend::BuildAclTensor(key); |
| 226 | + aclTensor* t_v = ascend::BuildAclTensor(value); |
| 227 | + aclTensor* t_out = prefill_out_cache_.get(output.data()); |
| 228 | + |
| 229 | + const aclTensor* k_arr[] = {t_k}; |
| 230 | + const aclTensor* v_arr[] = {t_v}; |
| 231 | + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); |
| 232 | + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); |
| 233 | + |
| 234 | + uint64_t ws_needed = 0; |
| 235 | + aclOpExecutor* executor = nullptr; |
| 236 | + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( |
| 237 | + t_q, key_list, val_list, |
| 238 | + nullptr, // pseShift |
| 239 | + causal_mask_, // attenMask (pre-computed, or nullptr) |
| 240 | + seq_q, // actualSeqLengths |
| 241 | + seq_kv, // actualSeqLengthsKv |
| 242 | + nullptr, nullptr, nullptr, nullptr, |
| 243 | + nullptr, // deqScale1..quantOffset2 |
| 244 | + nullptr, nullptr, // antiquantScale, antiquantOffset |
| 245 | + nullptr, // blockTable |
| 246 | + nullptr, nullptr, // queryPaddingSize, kvPaddingSize |
| 247 | + nullptr, nullptr, nullptr, |
| 248 | + nullptr, // key/value antiquant scale/offset |
| 249 | + nullptr, nullptr, |
| 250 | + nullptr, // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen |
| 251 | + nullptr, nullptr, |
| 252 | + nullptr, // queryRope, keyRope, keyRopeAntiquantScale |
| 253 | + nullptr, nullptr, // dequantScaleQuery, learnableSink |
| 254 | + num_heads, scale, pre_tokens, next_tokens, const_cast<char*>("TND"), |
| 255 | + num_kv_heads, sparse_mode, |
| 256 | + 0, // innerPrecise |
| 257 | + 0, // blockSize (unused for prefill) |
| 258 | + 0, false, // antiquantMode, softmaxLseFlag |
| 259 | + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode |
| 260 | + t_out, nullptr, &ws_needed, &executor); |
| 261 | + assert( |
| 262 | + gws == ACL_SUCCESS && |
| 263 | + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (prefill)"); |
| 264 | + |
| 265 | + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed); |
| 266 | + aclError ret = aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, |
| 267 | + executor, stream); |
| 268 | + assert(ret == ACL_SUCCESS && |
| 269 | + "aclnnFusedInferAttentionScoreV4 failed (prefill)"); |
| 270 | + |
| 271 | + // t_q and t_out are owned by caches — do NOT destroy. |
| 272 | + // t_k and t_v are owned by TensorLists. |
| 273 | + aclDestroyTensorList(key_list); |
| 274 | + aclDestroyTensorList(val_list); |
| 275 | + aclDestroyIntArray(seq_q); |
| 276 | + aclDestroyIntArray(seq_kv); |
| 277 | + return; |
| 278 | + } |
| 279 | + |
| 280 | + // --- Paged decode --- |
| 281 | + assert(cu_seqlens_kv.has_value() && |
| 282 | + "`FlashAttention` paged decode requires `cu_seqlens_kv`"); |
| 283 | + |
| 284 | + aclTensor* t_query = decode_q_cache_.get(const_cast<void*>(query.data())); |
| 285 | + aclTensor* t_output = decode_out_cache_.get(output.data()); |
| 286 | + |
| 287 | + // K/V descriptors go into TensorList which takes ownership — must be |
| 288 | + // per-call. Use pre-computed metadata to avoid heap allocs. |
| 289 | + aclTensor* t_key = aclCreateTensor( |
| 290 | + kv_shape_.data(), static_cast<int64_t>(kv_shape_.size()), kv_acl_dt_, |
| 291 | + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), |
| 292 | + static_cast<int64_t>(kv_storage_shape_.size()), |
| 293 | + const_cast<void*>(key.data())); |
| 294 | + aclTensor* t_value = aclCreateTensor( |
| 295 | + kv_shape_.data(), static_cast<int64_t>(kv_shape_.size()), kv_acl_dt_, |
| 296 | + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), |
| 297 | + static_cast<int64_t>(kv_storage_shape_.size()), |
| 298 | + const_cast<void*>(value.data())); |
| 299 | + |
| 300 | + // extractSeqLengths skips D2H when cu_seqlens_kv is a CPU tensor. |
| 301 | + aclIntArray* seq_kv = |
| 302 | + detail::extractSeqLengths(cu_seqlens_kv.value(), stream); |
| 303 | + aclTensor* t_block_table = |
| 304 | + block_table_cache_.get(const_cast<void*>(block_table.value().data())); |
| 305 | + |
| 306 | + const aclTensor* k_arr[] = {t_key}; |
| 307 | + const aclTensor* v_arr[] = {t_value}; |
| 308 | + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); |
| 309 | + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); |
| 310 | + |
| 311 | + uint64_t ws_needed = 0; |
| 312 | + aclOpExecutor* executor = nullptr; |
| 313 | + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( |
| 314 | + t_query, key_list, val_list, |
| 315 | + nullptr, // pseShift |
| 316 | + nullptr, // attenMask (sparseMode ignored for Q_S=1) |
| 317 | + nullptr, // actualSeqLengths (ignored for Q_S=1) |
| 318 | + seq_kv, // actualSeqLengthsKv (mandatory for paged) |
| 319 | + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, |
| 320 | + t_block_table, // blockTable |
| 321 | + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, |
| 322 | + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, num_heads, scale, |
| 323 | + static_cast<int64_t>(2147483647), static_cast<int64_t>(2147483647), |
| 324 | + const_cast<char*>("BNSD"), num_kv_heads, |
| 325 | + 0, // sparseMode=0 (ignored for Q_S=1) |
| 326 | + 0, // innerPrecise |
| 327 | + block_size, // blockSize |
| 328 | + 0, false, // antiquantMode, softmaxLseFlag |
| 329 | + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode |
| 330 | + t_output, nullptr, &ws_needed, &executor); |
| 331 | + assert(gws == ACL_SUCCESS && |
| 332 | + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (decode)"); |
| 333 | + |
| 334 | + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed); |
| 335 | + aclError ret = |
| 336 | + aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, executor, stream); |
| 337 | + assert(ret == ACL_SUCCESS && |
| 338 | + "aclnnFusedInferAttentionScoreV4 failed (decode)"); |
| 339 | + |
| 340 | + // t_query, t_output, t_block_table owned by caches — do NOT destroy. |
| 341 | + // t_key, t_value owned by TensorLists. |
| 342 | + aclDestroyTensorList(key_list); |
| 343 | + aclDestroyTensorList(val_list); |
| 344 | + aclDestroyIntArray(seq_kv); |
| 345 | + } |
| 346 | + |
| 347 | + private: |
| 348 | + bool paged_ = false; |
| 349 | + |
| 350 | + mutable ascend::AclTensorCache prefill_q_cache_; |
| 351 | + |
| 352 | + mutable ascend::AclTensorCache prefill_out_cache_; |
| 353 | + |
| 354 | + mutable ascend::AclTensorCache decode_q_cache_; |
| 355 | + |
| 356 | + mutable ascend::AclTensorCache decode_out_cache_; |
| 357 | + |
| 358 | + mutable ascend::AclTensorCache block_table_cache_; |
| 359 | + |
| 360 | + aclTensor* causal_mask_ = nullptr; |
| 361 | + |
| 362 | + void* causal_mask_buf_ = nullptr; |
| 363 | + |
| 364 | + std::vector<int64_t> kv_shape_; |
| 365 | + |
| 366 | + std::vector<int64_t> kv_strides_; |
| 367 | + |
| 368 | + std::vector<int64_t> kv_storage_shape_; |
| 369 | + |
| 370 | + aclDataType kv_acl_dt_ = ACL_DT_UNDEFINED; |
| 371 | +}; |
| 372 | + |
| 373 | +} // namespace infini::ops |
| 374 | + |
| 375 | +#endif |
0 commit comments