Skip to content

Commit f757ed6

Browse files
author
zhangyue
committed
perf(ascend/reshape_and_cache): replace int64 slot_mapping D2H with async aclnnCast
The ATB `ReshapeAndCacheParam` (impl=2) int64 path previously performed an `aclrtMemcpyAsync` D2H copy, a CPU int64→int32 cast, and an `aclrtMemcpyAsync` H2D copy, with an explicit `aclrtSynchronizeStream` in between. That sync blocks the stream and makes the int64 path NPUGraph-incompatible, which forced callers (vllm-infini) to pre-cast `slot_mapping` to int32 on the Python side (otherwise 36 redundant Cast launches per decoding step). Route the int64 branch through a cached `aclnnCast` instead: the src/dst tensor descriptors live in `AclTensorCache` slots, the executor is set repeatable, and the cast stays fully async on-stream. The whole op now matches vLLM's native int64 `slot_mapping` convention without the sync penalty.
1 parent c8a3ff2 commit f757ed6

1 file changed

Lines changed: 46 additions & 19 deletions

File tree

src/ascend/reshape_and_cache/kernel_atb.h

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include <cstdint>
99

1010
#include "acl/acl.h"
11+
#include "aclnn/aclnn_base.h"
12+
#include "aclnnop/aclnn_cast.h"
1113
#include "ascend/atb_common_.h"
1214
#include "ascend/common.h"
1315
#include "ascend/workspace_pool_.h"
@@ -31,13 +33,15 @@ namespace infini::ops {
3133
// before each Execute to bind the VariantPack.
3234
//
3335
// NOTE: `ReshapeAndCacheParam` requires int32 `slot_mapping`. When the
34-
// caller passes int64 (the default in PyTorch / vLLM), this operator casts
35-
// to int32 via a pre-allocated device buffer — matching the pattern used in
36-
// the ATB rotary_embedding operator.
36+
// caller passes int64 (the PyTorch / vLLM default), this operator issues an
37+
// async `aclnnCast` to a pre-allocated int32 device buffer. The cast
38+
// executor is cached across calls and the whole step stays on the stream
39+
// with no D2H/H2D round-trip, so the int64 path is NPUGraph-capturable and
40+
// roughly on par with the int32 fast path.
3741
//
3842
// Input layout:
3943
// key, value : [num_tokens, num_kv_heads, head_size]
40-
// slot_mapping: [num_tokens] (int32 or int64; int64 is cast internally)
44+
// slot_mapping: [num_tokens] (int32 or int64)
4145
//
4246
// KV cache layout:
4347
// kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size]
@@ -78,6 +82,16 @@ class Operator<ReshapeAndCache, Device::Type::kAscend, 2>
7882

7983
slot_is_int32_ = (slot_mapping.element_size() == sizeof(int32_t));
8084

85+
// Prepare aclnnCast descriptors for the int64 → int32 path. Source
86+
// descriptor's data pointer is refreshed per call; destination is the
87+
// pre-allocated `slot32_buf_`.
88+
if (!slot_is_int32_) {
89+
slot_i64_cache_ = ascend::AclTensorCache(
90+
{T}, ACL_INT64, const_cast<void*>(slot_mapping.data()));
91+
slot_i32_cache_ =
92+
ascend::AclTensorCache({T}, ACL_INT32, slot32_buf_);
93+
}
94+
8195
// Create the ATB operation (reused across calls).
8296
atb::infer::ReshapeAndCacheParam param;
8397
atb::Status s = atb::CreateOperation(param, &op_);
@@ -88,6 +102,8 @@ class Operator<ReshapeAndCache, Device::Type::kAscend, 2>
88102
~Operator() {
89103
if (!ascend::isAclRuntimeAlive()) return;
90104
if (op_) atb::DestroyOperation(op_);
105+
slot_i64_cache_.release();
106+
slot_i32_cache_.release();
91107
if (slot32_buf_) aclrtFree(slot32_buf_);
92108
}
93109

@@ -101,29 +117,31 @@ class Operator<ReshapeAndCache, Device::Type::kAscend, 2>
101117
auto stream = static_cast<aclrtStream>(stream_);
102118

103119
// `ReshapeAndCacheParam` requires int32 `slot_mapping`. When the
104-
// caller provides int64 (the PyTorch/vLLM default), cast to int32 via
105-
// a pre-allocated device buffer.
120+
// caller provides int64 (the PyTorch/vLLM default), issue an async
121+
// `aclnnCast` to the pre-allocated int32 device buffer — keeps the
122+
// whole step on-stream and NPUGraph-capturable.
106123
void* slot32_ptr;
107124

108125
if (slot_is_int32_) {
109126
// Already int32 — pass through directly.
110127
slot32_ptr = const_cast<void*>(slot_mapping.data());
111128
} else {
112-
// int64 → int32: D2H, CPU cast, H2D.
113-
auto T = static_cast<size_t>(num_tokens_);
114-
std::vector<int64_t> i64(T);
115-
aclrtMemcpyAsync(i64.data(), T * sizeof(int64_t), slot_mapping.data(),
116-
T * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream);
117-
aclrtSynchronizeStream(stream);
118-
119-
std::vector<int32_t> i32(T);
120-
121-
for (size_t i = 0; i < T; ++i) {
122-
i32[i] = static_cast<int32_t>(i64[i]);
129+
auto t_src =
130+
slot_i64_cache_.get(const_cast<void*>(slot_mapping.data()));
131+
auto t_dst = slot_i32_cache_.get(slot32_buf_);
132+
133+
if (!cast_exec_) {
134+
aclnnCastGetWorkspaceSize(t_src, ACL_INT32, t_dst, &cast_ws_,
135+
&cast_exec_);
136+
aclSetAclOpExecutorRepeatable(cast_exec_);
137+
} else {
138+
aclSetInputTensorAddr(cast_exec_, 0, t_src,
139+
const_cast<void*>(slot_mapping.data()));
140+
aclSetOutputTensorAddr(cast_exec_, 0, t_dst, slot32_buf_);
123141
}
124142

125-
aclrtMemcpyAsync(slot32_buf_, slot32_bytes_, i32.data(), slot32_bytes_,
126-
ACL_MEMCPY_HOST_TO_DEVICE, stream);
143+
auto& cast_arena = ascend::GetWorkspacePool().Ensure(stream, cast_ws_);
144+
aclnnCast(cast_arena.buf, cast_ws_, cast_exec_, stream);
127145
slot32_ptr = slot32_buf_;
128146
}
129147

@@ -223,6 +241,15 @@ class Operator<ReshapeAndCache, Device::Type::kAscend, 2>
223241

224242
// True if the caller already provides int32 `slot_mapping`.
225243
bool slot_is_int32_ = false;
244+
245+
// Cached aclnnCast descriptors (int64 slot_mapping → int32 buffer).
246+
mutable ascend::AclTensorCache slot_i64_cache_;
247+
248+
mutable ascend::AclTensorCache slot_i32_cache_;
249+
250+
mutable aclOpExecutor* cast_exec_ = nullptr;
251+
252+
mutable uint64_t cast_ws_ = 0;
226253
};
227254

228255
} // namespace infini::ops

0 commit comments

Comments
 (0)