1+ /* Copyright 2026 The xLLM Authors. All Rights Reserved.
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License.
14+ ==============================================================================*/
15+
#include <c10/core/Device.h>
#include <glog/logging.h>
#include <torch/torch.h>
#include <torch_npu/csrc/libs/init_npu.h>
#include <torch_npu/torch_npu.h>

#include <cmath>
#include <initializer_list>
#include <nlohmann/json.hpp>
#ifdef TORCH_HIGHER_THAN_PTA6
#include <torch_npu/csrc/framework/OpCommand.h>
#else
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
#endif

#include "acl/acl.h"
#include "aclnn_recurrent_gated_delta_rule.h"
#include "core/common/macros.h"
#include "core/kernels/npu/utils.h"
#include "npu_ops_api.h"
35+
36+ namespace xllm ::kernel::npu {
37+
38+ torch::Tensor npu_recurrent_gated_delta_rule (
39+ const torch::Tensor& query,
40+ const torch::Tensor& key,
41+ const torch::Tensor& value,
42+ torch::Tensor& state,
43+ const std::optional<torch::Tensor>& beta,
44+ const std::optional<double > scale,
45+ const std::optional<torch::Tensor>& actual_seq_lengths,
46+ const std::optional<torch::Tensor>& ssm_state_indices,
47+ const std::optional<torch::Tensor>& num_accepted_tokens,
48+ const std::optional<torch::Tensor>& g,
49+ const std::optional<torch::Tensor>& gk) {
50+ check_tensor (query, " query" , " recurrent_gated_delta_rule" );
51+ check_tensor (key, " key" , " recurrent_gated_delta_rule" );
52+ check_tensor (value, " value" , " recurrent_gated_delta_rule" );
53+ check_tensor (state, " state" , " recurrent_gated_delta_rule" );
54+
55+ aclTensor* query_ids = nullptr ;
56+ aclTensor* key_ids = nullptr ;
57+ aclTensor* value_ids = nullptr ;
58+ aclTensor* state_ids = nullptr ;
59+ aclTensor* beta_ids = nullptr ;
60+ aclTensor* actual_seq_lengths_ids = nullptr ;
61+ aclTensor* ssm_state_indices_ids = nullptr ;
62+ aclTensor* num_accepted_tokens_ids = nullptr ;
63+ aclTensor* g_ids = nullptr ;
64+ aclTensor* gk_ids = nullptr ;
65+ aclTensor* out_ids = nullptr ;
66+
67+ int32_t device_id = query.device ().index ();
68+ aclrtStream stream = c10_npu::getCurrentNPUStream (device_id).stream ();
69+
70+ create_acltensor (&query_ids, query);
71+ create_acltensor (&key_ids, key);
72+ create_acltensor (&value_ids, value);
73+ create_acltensor (&state_ids, state);
74+
75+ if (beta.has_value () && beta.value ().defined ()) {
76+ create_acltensor (&beta_ids, beta.value ());
77+ }
78+ if (actual_seq_lengths.has_value () && actual_seq_lengths.value ().defined ()) {
79+ create_acltensor (&actual_seq_lengths_ids, actual_seq_lengths.value ());
80+ }
81+ if (ssm_state_indices.has_value () && ssm_state_indices.value ().defined ()) {
82+ create_acltensor (&ssm_state_indices_ids, ssm_state_indices.value ());
83+ }
84+ if (num_accepted_tokens.has_value () &&
85+ num_accepted_tokens.value ().defined ()) {
86+ create_acltensor (&num_accepted_tokens_ids, num_accepted_tokens.value ());
87+ }
88+ if (g.has_value () && g.value ().defined ()) {
89+ create_acltensor (&g_ids, g.value ());
90+ }
91+ if (gk.has_value () && gk.value ().defined ()) {
92+ create_acltensor (&gk_ids, gk.value ());
93+ }
94+
95+ at::Tensor out_result = at::empty_like (value);
96+ create_acltensor (&out_ids, out_result);
97+
98+ float scale_value = static_cast <float >(scale.value ());
99+
100+ uint64_t workspace_size = 0 ;
101+ aclOpExecutor* executor = nullptr ;
102+
103+ CHECK_ACL_SUCCESS (
104+ aclnnRecurrentGatedDeltaRuleGetWorkspaceSize (query_ids,
105+ key_ids,
106+ value_ids,
107+ beta_ids,
108+ state_ids,
109+ actual_seq_lengths_ids,
110+ ssm_state_indices_ids,
111+ g_ids,
112+ gk_ids,
113+ num_accepted_tokens_ids,
114+ scale_value,
115+ out_ids,
116+ &workspace_size,
117+ &executor),
118+ " recurrent_gated_delta_rule: failed to get workspace size" );
119+
120+ void * workspace_addr = nullptr ;
121+ if (workspace_size > 0 ) {
122+ CHECK_ACL_SUCCESS (
123+ aclrtMalloc (&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST),
124+ " recurrent_gated_delta_rule: failed to allocate workspace" );
125+ }
126+
127+ CHECK_ACL_SUCCESS (aclnnRecurrentGatedDeltaRule (
128+ workspace_addr, workspace_size, executor, stream),
129+ " recurrent_gated_delta_rule: failed to perform recurrent "
130+ " gated delta rule" );
131+
132+ aclDestroyTensor (query_ids);
133+ aclDestroyTensor (key_ids);
134+ aclDestroyTensor (value_ids);
135+ aclDestroyTensor (state_ids);
136+ aclDestroyTensor (out_ids);
137+
138+ if (beta_ids != nullptr ) {
139+ aclDestroyTensor (beta_ids);
140+ }
141+ if (actual_seq_lengths_ids != nullptr ) {
142+ aclDestroyTensor (actual_seq_lengths_ids);
143+ }
144+ if (ssm_state_indices_ids != nullptr ) {
145+ aclDestroyTensor (ssm_state_indices_ids);
146+ }
147+ if (num_accepted_tokens_ids != nullptr ) {
148+ aclDestroyTensor (num_accepted_tokens_ids);
149+ }
150+ if (g_ids != nullptr ) {
151+ aclDestroyTensor (g_ids);
152+ }
153+ if (gk_ids != nullptr ) {
154+ aclDestroyTensor (gk_ids);
155+ }
156+
157+ if (workspace_size > 0 ) {
158+ CHECK_ACL_SUCCESS (aclrtFree (workspace_addr),
159+ " recurrent_gated_delta_rule: failed to free workspace" );
160+ }
161+
162+ return out_result;
163+ }
164+
165+ } // namespace xllm::kernel::npu