Skip to content

Commit e3b6f16

Browse files
author
zhangyue
committed
feat(ascend): op-cache-attn group — ReshapeAndCache, FlashAttention, PagedAttention, TopkToppSampling
Four KV-cache and attention operators: | op | impl | |---|---| | ReshapeAndCache | 3 impls: aclnnInplaceIndexCopy (kernel.h); custom AscendC (kernel_v2.h); ATB `ReshapeAndCacheParam` (kernel_atb.h, int64 `slot_mapping` handled via cached async `aclnnCast`) | | FlashAttention | `aclnnFusedInferAttentionScoreV4` (prefill + paged decode). Supports both the native `(window_left, window_right)` pair and a new `std::optional<int64_t> sliding_window` entry (additive, vLLM-style) | | PagedAttention | ATB `PagedAttentionParam` with optional CPU-pinned host tensors (`seq_lens_host` / `block_table_host`) that make the op NPUGraph-capturable | | TopkToppSampling | ATB `TopkToppSamplingParam` | Includes vLLM API alignment commits: - `perf(reshape_and_cache)`: int64 slot_mapping routed through cached async `aclnnCast` (no D2H sync, NPUGraph-compatible) - `feat(flash_attention)`: add `sliding_window` entry, additive - `docs(paged_attention)`: base class comment explains the CPU-host tensor contract New `src/base/<op>.h`: paged_attention, topk_topp_sampling. Modified: reshape_and_cache, flash_attention.
1 parent a05713b commit e3b6f16

13 files changed

Lines changed: 2993 additions & 30 deletions

File tree

Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
#ifndef INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_
2+
#define INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_
3+
4+
#include <cassert>
5+
#include <cstddef>
6+
#include <vector>
7+
8+
#include "acl/acl.h"
9+
#include "aclnn/aclnn_base.h"
10+
#include "aclnnop/aclnn_fused_infer_attention_score_v4.h"
11+
#include "ascend/common.h"
12+
#include "ascend/workspace_pool_.h"
13+
#include "base/flash_attention.h"
14+
#include "operator.h"
15+
16+
namespace infini::ops {
17+
18+
namespace detail {
19+
20+
// Extract cu_seqlens differences to a host aclIntArray.
21+
// cu_seqlens = [0, s1, s1+s2, ...] -> per_seq_lens = [s1, s2, ...].
22+
// Used by paged decode (actualSeqLengthsKv = per-sequence KV lengths).
23+
//
24+
// When cu_seqlens is a CPU tensor (device type kCpu), the data pointer is
25+
// already on the host and can be read directly — no D2H sync needed.
26+
inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens,
27+
aclrtStream stream) {
28+
auto n = cu_seqlens.numel();
29+
30+
const int64_t* cu_host_ptr = nullptr;
31+
std::vector<int64_t> cu_host_buf;
32+
33+
if (cu_seqlens.device().type() == Device::Type::kCpu) {
34+
cu_host_ptr = static_cast<const int64_t*>(cu_seqlens.data());
35+
} else {
36+
cu_host_buf.resize(n);
37+
aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), cu_seqlens.data(),
38+
n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream);
39+
aclrtSynchronizeStream(stream);
40+
cu_host_ptr = cu_host_buf.data();
41+
}
42+
43+
std::vector<int64_t> lengths(n - 1);
44+
for (size_t i = 0; i < lengths.size(); ++i) {
45+
lengths[i] = cu_host_ptr[i + 1] - cu_host_ptr[i];
46+
}
47+
48+
return aclCreateIntArray(lengths.data(),
49+
static_cast<int64_t>(lengths.size()));
50+
}
51+
52+
// Extract cumulative end positions from cu_seqlens to a host aclIntArray.
53+
// cu_seqlens = [0, s1, s1+s2, ...] -> cum_lens = [s1, s1+s2, ...].
54+
// FIA V4 TND varlen uses cumulative end positions, matching the vllm-ascend
55+
// convention for npu_fused_infer_attention_score actual_seq_lengths.
56+
//
57+
// When cu_seqlens is a CPU tensor, reads directly from host memory.
58+
inline aclIntArray* cumSeqLengths(const Tensor& cu_seqlens,
59+
aclrtStream stream) {
60+
auto n = cu_seqlens.numel();
61+
62+
const int64_t* cu_host_ptr = nullptr;
63+
std::vector<int64_t> cu_host_buf;
64+
65+
if (cu_seqlens.device().type() == Device::Type::kCpu) {
66+
cu_host_ptr = static_cast<const int64_t*>(cu_seqlens.data());
67+
} else {
68+
cu_host_buf.resize(n);
69+
aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), cu_seqlens.data(),
70+
n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream);
71+
aclrtSynchronizeStream(stream);
72+
cu_host_ptr = cu_host_buf.data();
73+
}
74+
75+
// Skip the leading 0; return [s1, s1+s2, ...].
76+
return aclCreateIntArray(cu_host_ptr + 1, static_cast<int64_t>(n - 1));
77+
}
78+
79+
// Allocate a 2048x2048 lower-triangular UINT8 causal mask on device.
80+
// Required for `sparseMode` >= 2.
81+
inline aclTensor* makeCausalMask(void** mask_buf, aclrtStream stream) {
82+
constexpr int64_t kMaskDim = 2048;
83+
const int64_t mask_elems = kMaskDim * kMaskDim;
84+
const size_t mask_bytes = static_cast<size_t>(mask_elems); // uint8_t
85+
86+
aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
87+
88+
std::vector<uint8_t> host_mask(mask_elems);
89+
for (int64_t r = 0; r < kMaskDim; ++r) {
90+
for (int64_t c = 0; c < kMaskDim; ++c) {
91+
// 1 = masked out (upper triangle); 0 = attend (lower triangle).
92+
host_mask[r * kMaskDim + c] = (c > r) ? 1 : 0;
93+
}
94+
}
95+
aclrtMemcpyAsync(*mask_buf, mask_bytes, host_mask.data(), mask_bytes,
96+
ACL_MEMCPY_HOST_TO_DEVICE, stream);
97+
aclrtSynchronizeStream(stream);
98+
99+
std::vector<int64_t> mask_shape = {kMaskDim, kMaskDim};
100+
std::vector<int64_t> mask_strides = {kMaskDim, 1};
101+
std::vector<int64_t> mask_storage = {mask_elems};
102+
return aclCreateTensor(mask_shape.data(), 2, ACL_UINT8, mask_strides.data(),
103+
0, ACL_FORMAT_ND, mask_storage.data(), 1, *mask_buf);
104+
}
105+
106+
} // namespace detail
107+
108+
template <>
109+
class Operator<FlashAttention, Device::Type::kAscend> : public FlashAttention {
110+
public:
111+
Operator(const Tensor query, const Tensor key, const Tensor value,
112+
std::optional<Tensor> cu_seqlens_q,
113+
std::optional<Tensor> cu_seqlens_kv,
114+
std::optional<Tensor> block_table, int64_t num_heads,
115+
int64_t num_kv_heads, int64_t head_size, double scale, bool causal,
116+
int64_t window_left, int64_t window_right, int64_t block_size,
117+
Tensor output, std::optional<int64_t> sliding_window = std::nullopt)
118+
: FlashAttention(query, key, value, cu_seqlens_q, cu_seqlens_kv,
119+
block_table, num_heads, num_kv_heads, head_size, scale,
120+
causal, window_left, window_right, block_size, output,
121+
sliding_window) {
122+
paged_ = block_table.has_value() && block_size > 0;
123+
aclDataType acl_dt = ascend::ToAclDtype(query.dtype());
124+
125+
if (!paged_) {
126+
// Prefill: cache Q and output (TND layout).
127+
prefill_q_cache_ = ascend::AclTensorCache(query);
128+
prefill_out_cache_ = ascend::AclTensorCache(output);
129+
130+
// Pre-compute causal mask once (sparse_mode >= 2). Read the
131+
// resolved pair from base-class members so `sliding_window`
132+
// normalization is honored at cache-key construction.
133+
if (causal) {
134+
int64_t sm = (window_left_ >= 0) ? 4 : 3;
135+
if (sm >= 2) {
136+
causal_mask_ = detail::makeCausalMask(&causal_mask_buf_, nullptr);
137+
}
138+
}
139+
} else {
140+
// Decode: cache Q/output (BNSD), block_table.
141+
const int64_t N = query.size(1);
142+
const int64_t D = query.size(2);
143+
const int64_t B = query.size(0);
144+
145+
decode_q_cache_ = ascend::AclTensorCache({B, N, 1, D}, acl_dt,
146+
const_cast<void*>(query.data()));
147+
decode_out_cache_ =
148+
ascend::AclTensorCache({B, N, 1, D}, acl_dt, output.data());
149+
block_table_cache_ = ascend::AclTensorCache(block_table.value());
150+
151+
// Pre-compute KV reshape metadata.
152+
const int64_t nb = key.size(0);
153+
const int64_t bsz = key.size(1);
154+
const int64_t NkvD = key.size(2) * key.size(3);
155+
kv_shape_ = {nb, bsz, NkvD};
156+
kv_strides_ = {bsz * NkvD, NkvD, 1};
157+
kv_storage_shape_ = {nb * bsz * NkvD};
158+
kv_acl_dt_ = acl_dt;
159+
}
160+
}
161+
162+
~Operator() {
163+
if (!ascend::IsAclRuntimeAlive()) return;
164+
165+
if (causal_mask_) aclDestroyTensor(causal_mask_);
166+
if (causal_mask_buf_) aclrtFree(causal_mask_buf_);
167+
}
168+
169+
void operator()(const Tensor query, const Tensor key, const Tensor value,
170+
std::optional<Tensor> cu_seqlens_q,
171+
std::optional<Tensor> cu_seqlens_kv,
172+
std::optional<Tensor> block_table, int64_t num_heads,
173+
int64_t num_kv_heads, int64_t head_size, double scale,
174+
bool causal, int64_t window_left, int64_t window_right,
175+
int64_t block_size, Tensor output,
176+
std::optional<int64_t> sliding_window) const override {
177+
auto stream = static_cast<aclrtStream>(stream_);
178+
const bool paged = paged_;
179+
180+
// The base class stored the resolved window pair in `window_left_` /
181+
// `window_right_` at construction; prefer those over the call-site
182+
// args so that `sliding_window` is honored here as well.
183+
int64_t wl = window_left_;
184+
int64_t wr = window_right_;
185+
(void)window_left;
186+
(void)window_right;
187+
(void)sliding_window;
188+
189+
int64_t sparse_mode;
190+
int64_t pre_tokens = 2147483647;
191+
int64_t next_tokens = 2147483647;
192+
if (causal) {
193+
if (wl >= 0) {
194+
sparse_mode = 4;
195+
pre_tokens = wl;
196+
next_tokens = 0;
197+
} else {
198+
sparse_mode = 3;
199+
next_tokens = 0;
200+
}
201+
} else {
202+
sparse_mode = 0;
203+
if (wl >= 0) pre_tokens = wl;
204+
if (wr >= 0) next_tokens = wr;
205+
}
206+
207+
if (!paged) {
208+
// --- Prefill ---
209+
int64_t T = query.size(0);
210+
211+
// cumSeqLengths / extractSeqLengths automatically skip D2H when
212+
// cu_seqlens is a CPU tensor (see detail:: helpers above).
213+
aclIntArray* seq_q =
214+
cu_seqlens_q.has_value()
215+
? detail::cumSeqLengths(cu_seqlens_q.value(), stream)
216+
: aclCreateIntArray(&T, 1);
217+
aclIntArray* seq_kv =
218+
cu_seqlens_kv.has_value()
219+
? detail::cumSeqLengths(cu_seqlens_kv.value(), stream)
220+
: aclCreateIntArray(&T, 1);
221+
222+
aclTensor* t_q = prefill_q_cache_.get(const_cast<void*>(query.data()));
223+
// K/V descriptors go into TensorList which takes ownership — must be
224+
// per-call (cannot cache).
225+
aclTensor* t_k = ascend::BuildAclTensor(key);
226+
aclTensor* t_v = ascend::BuildAclTensor(value);
227+
aclTensor* t_out = prefill_out_cache_.get(output.data());
228+
229+
const aclTensor* k_arr[] = {t_k};
230+
const aclTensor* v_arr[] = {t_v};
231+
aclTensorList* key_list = aclCreateTensorList(k_arr, 1);
232+
aclTensorList* val_list = aclCreateTensorList(v_arr, 1);
233+
234+
uint64_t ws_needed = 0;
235+
aclOpExecutor* executor = nullptr;
236+
aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize(
237+
t_q, key_list, val_list,
238+
nullptr, // pseShift
239+
causal_mask_, // attenMask (pre-computed, or nullptr)
240+
seq_q, // actualSeqLengths
241+
seq_kv, // actualSeqLengthsKv
242+
nullptr, nullptr, nullptr, nullptr,
243+
nullptr, // deqScale1..quantOffset2
244+
nullptr, nullptr, // antiquantScale, antiquantOffset
245+
nullptr, // blockTable
246+
nullptr, nullptr, // queryPaddingSize, kvPaddingSize
247+
nullptr, nullptr, nullptr,
248+
nullptr, // key/value antiquant scale/offset
249+
nullptr, nullptr,
250+
nullptr, // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen
251+
nullptr, nullptr,
252+
nullptr, // queryRope, keyRope, keyRopeAntiquantScale
253+
nullptr, nullptr, // dequantScaleQuery, learnableSink
254+
num_heads, scale, pre_tokens, next_tokens, const_cast<char*>("TND"),
255+
num_kv_heads, sparse_mode,
256+
0, // innerPrecise
257+
0, // blockSize (unused for prefill)
258+
0, false, // antiquantMode, softmaxLseFlag
259+
0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode
260+
t_out, nullptr, &ws_needed, &executor);
261+
assert(
262+
gws == ACL_SUCCESS &&
263+
"aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (prefill)");
264+
265+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed);
266+
aclError ret = aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed,
267+
executor, stream);
268+
assert(ret == ACL_SUCCESS &&
269+
"aclnnFusedInferAttentionScoreV4 failed (prefill)");
270+
271+
// t_q and t_out are owned by caches — do NOT destroy.
272+
// t_k and t_v are owned by TensorLists.
273+
aclDestroyTensorList(key_list);
274+
aclDestroyTensorList(val_list);
275+
aclDestroyIntArray(seq_q);
276+
aclDestroyIntArray(seq_kv);
277+
return;
278+
}
279+
280+
// --- Paged decode ---
281+
assert(cu_seqlens_kv.has_value() &&
282+
"`FlashAttention` paged decode requires `cu_seqlens_kv`");
283+
284+
aclTensor* t_query = decode_q_cache_.get(const_cast<void*>(query.data()));
285+
aclTensor* t_output = decode_out_cache_.get(output.data());
286+
287+
// K/V descriptors go into TensorList which takes ownership — must be
288+
// per-call. Use pre-computed metadata to avoid heap allocs.
289+
aclTensor* t_key = aclCreateTensor(
290+
kv_shape_.data(), static_cast<int64_t>(kv_shape_.size()), kv_acl_dt_,
291+
kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(),
292+
static_cast<int64_t>(kv_storage_shape_.size()),
293+
const_cast<void*>(key.data()));
294+
aclTensor* t_value = aclCreateTensor(
295+
kv_shape_.data(), static_cast<int64_t>(kv_shape_.size()), kv_acl_dt_,
296+
kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(),
297+
static_cast<int64_t>(kv_storage_shape_.size()),
298+
const_cast<void*>(value.data()));
299+
300+
// extractSeqLengths skips D2H when cu_seqlens_kv is a CPU tensor.
301+
aclIntArray* seq_kv =
302+
detail::extractSeqLengths(cu_seqlens_kv.value(), stream);
303+
aclTensor* t_block_table =
304+
block_table_cache_.get(const_cast<void*>(block_table.value().data()));
305+
306+
const aclTensor* k_arr[] = {t_key};
307+
const aclTensor* v_arr[] = {t_value};
308+
aclTensorList* key_list = aclCreateTensorList(k_arr, 1);
309+
aclTensorList* val_list = aclCreateTensorList(v_arr, 1);
310+
311+
uint64_t ws_needed = 0;
312+
aclOpExecutor* executor = nullptr;
313+
aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize(
314+
t_query, key_list, val_list,
315+
nullptr, // pseShift
316+
nullptr, // attenMask (sparseMode ignored for Q_S=1)
317+
nullptr, // actualSeqLengths (ignored for Q_S=1)
318+
seq_kv, // actualSeqLengthsKv (mandatory for paged)
319+
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
320+
t_block_table, // blockTable
321+
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
322+
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, num_heads, scale,
323+
static_cast<int64_t>(2147483647), static_cast<int64_t>(2147483647),
324+
const_cast<char*>("BNSD"), num_kv_heads,
325+
0, // sparseMode=0 (ignored for Q_S=1)
326+
0, // innerPrecise
327+
block_size, // blockSize
328+
0, false, // antiquantMode, softmaxLseFlag
329+
0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode
330+
t_output, nullptr, &ws_needed, &executor);
331+
assert(gws == ACL_SUCCESS &&
332+
"aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (decode)");
333+
334+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed);
335+
aclError ret =
336+
aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, executor, stream);
337+
assert(ret == ACL_SUCCESS &&
338+
"aclnnFusedInferAttentionScoreV4 failed (decode)");
339+
340+
// t_query, t_output, t_block_table owned by caches — do NOT destroy.
341+
// t_key, t_value owned by TensorLists.
342+
aclDestroyTensorList(key_list);
343+
aclDestroyTensorList(val_list);
344+
aclDestroyIntArray(seq_kv);
345+
}
346+
347+
private:
348+
bool paged_ = false;
349+
350+
mutable ascend::AclTensorCache prefill_q_cache_;
351+
352+
mutable ascend::AclTensorCache prefill_out_cache_;
353+
354+
mutable ascend::AclTensorCache decode_q_cache_;
355+
356+
mutable ascend::AclTensorCache decode_out_cache_;
357+
358+
mutable ascend::AclTensorCache block_table_cache_;
359+
360+
aclTensor* causal_mask_ = nullptr;
361+
362+
void* causal_mask_buf_ = nullptr;
363+
364+
std::vector<int64_t> kv_shape_;
365+
366+
std::vector<int64_t> kv_strides_;
367+
368+
std::vector<int64_t> kv_storage_shape_;
369+
370+
aclDataType kv_acl_dt_ = ACL_DT_UNDEFINED;
371+
};
372+
373+
} // namespace infini::ops
374+
375+
#endif

0 commit comments

Comments
 (0)