88#include < cstdint>
99
1010#include " acl/acl.h"
11+ #include " aclnn/aclnn_base.h"
12+ #include " aclnnop/aclnn_cast.h"
1113#include " ascend/atb_common_.h"
1214#include " ascend/common.h"
1315#include " ascend/workspace_pool_.h"
@@ -31,13 +33,15 @@ namespace infini::ops {
3133// before each Execute to bind the VariantPack.
3234//
3335// NOTE: `ReshapeAndCacheParam` requires int32 `slot_mapping`. When the
34- // caller passes int64 (the default in PyTorch / vLLM), this operator casts
35- // to int32 via a pre-allocated device buffer — matching the pattern used in
36- // the ATB rotary_embedding operator.
36+ // caller passes int64 (the PyTorch / vLLM default), this operator issues an
37+ // async `aclnnCast` to a pre-allocated int32 device buffer. The cast
38+ // executor is cached across calls and the whole step stays on the stream
39+ // with no D2H/H2D round-trip, so the int64 path is NPUGraph-capturable and
40+ // roughly on par with the int32 fast path.
3741//
3842// Input layout:
3943// key, value : [num_tokens, num_kv_heads, head_size]
40- // slot_mapping: [num_tokens] (int32 or int64; int64 is cast internally )
44+ // slot_mapping: [num_tokens] (int32 or int64)
4145//
4246// KV cache layout:
4347// kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size]
@@ -78,6 +82,16 @@ class Operator<ReshapeAndCache, Device::Type::kAscend, 2>
7882
7983 slot_is_int32_ = (slot_mapping.element_size () == sizeof (int32_t ));
8084
85+ // Prepare aclnnCast descriptors for the int64 → int32 path. Source
86+ // descriptor's data pointer is refreshed per call; destination is the
87+ // pre-allocated `slot32_buf_`.
88+ if (!slot_is_int32_) {
89+ slot_i64_cache_ = ascend::AclTensorCache (
90+ {T}, ACL_INT64 , const_cast <void *>(slot_mapping.data ()));
91+ slot_i32_cache_ =
92+ ascend::AclTensorCache ({T}, ACL_INT32 , slot32_buf_);
93+ }
94+
8195 // Create the ATB operation (reused across calls).
8296 atb::infer::ReshapeAndCacheParam param;
8397 atb::Status s = atb::CreateOperation (param, &op_);
@@ -88,6 +102,8 @@ class Operator<ReshapeAndCache, Device::Type::kAscend, 2>
88102 ~Operator () {
89103 if (!ascend::isAclRuntimeAlive ()) return ;
90104 if (op_) atb::DestroyOperation (op_);
105+ slot_i64_cache_.release ();
106+ slot_i32_cache_.release ();
91107 if (slot32_buf_) aclrtFree (slot32_buf_);
92108 }
93109
@@ -101,29 +117,31 @@ class Operator<ReshapeAndCache, Device::Type::kAscend, 2>
101117 auto stream = static_cast <aclrtStream>(stream_);
102118
103119 // `ReshapeAndCacheParam` requires int32 `slot_mapping`. When the
104- // caller provides int64 (the PyTorch/vLLM default), cast to int32 via
105- // a pre-allocated device buffer.
120+ // caller provides int64 (the PyTorch/vLLM default), issue an async
121+ // `aclnnCast` to the pre-allocated int32 device buffer — keeps the
122+ // whole step on-stream and NPUGraph-capturable.
106123 void * slot32_ptr;
107124
108125 if (slot_is_int32_) {
109126 // Already int32 — pass through directly.
110127 slot32_ptr = const_cast <void *>(slot_mapping.data ());
111128 } else {
112- // int64 → int32: D2H, CPU cast, H2D.
113- auto T = static_cast <size_t >(num_tokens_);
114- std::vector<int64_t > i64 (T);
115- aclrtMemcpyAsync (i64 .data (), T * sizeof (int64_t ), slot_mapping.data (),
116- T * sizeof (int64_t ), ACL_MEMCPY_DEVICE_TO_HOST , stream);
117- aclrtSynchronizeStream (stream);
118-
119- std::vector<int32_t > i32 (T);
120-
121- for (size_t i = 0 ; i < T; ++i) {
122- i32 [i] = static_cast <int32_t >(i64 [i]);
129+ auto t_src =
130+ slot_i64_cache_.get (const_cast <void *>(slot_mapping.data ()));
131+ auto t_dst = slot_i32_cache_.get (slot32_buf_);
132+
133+ if (!cast_exec_) {
134+ aclnnCastGetWorkspaceSize (t_src, ACL_INT32 , t_dst, &cast_ws_,
135+ &cast_exec_);
136+ aclSetAclOpExecutorRepeatable (cast_exec_);
137+ } else {
138+ aclSetInputTensorAddr (cast_exec_, 0 , t_src,
139+ const_cast <void *>(slot_mapping.data ()));
140+ aclSetOutputTensorAddr (cast_exec_, 0 , t_dst, slot32_buf_);
123141 }
124142
125- aclrtMemcpyAsync (slot32_buf_, slot32_bytes_, i32 . data (), slot32_bytes_,
126- ACL_MEMCPY_HOST_TO_DEVICE , stream);
143+ auto & cast_arena = ascend::GetWorkspacePool (). Ensure (stream, cast_ws_);
144+ aclnnCast (cast_arena. buf , cast_ws_, cast_exec_ , stream);
127145 slot32_ptr = slot32_buf_;
128146 }
129147
@@ -223,6 +241,15 @@ class Operator<ReshapeAndCache, Device::Type::kAscend, 2>
223241
224242 // True if the caller already provides int32 `slot_mapping`.
225243 bool slot_is_int32_ = false ;
244+
245+ // Cached aclnnCast descriptors (int64 slot_mapping → int32 buffer).
246+ mutable ascend::AclTensorCache slot_i64_cache_;
247+
248+ mutable ascend::AclTensorCache slot_i32_cache_;
249+
250+ mutable aclOpExecutor* cast_exec_ = nullptr ;
251+
252+ mutable uint64_t cast_ws_ = 0 ;
226253};
227254
228255} // namespace infini::ops
0 commit comments