1010#include < vector>
1111
1212#include " acl/acl.h"
13+ #include " ascend/atb_common_.h"
14+ #include " ascend/paged_attention/registry.h"
15+ #include " ascend/workspace_pool_.h"
1316#include " atb/context.h"
1417#include " atb/infer_op_params.h"
1518#include " atb/operation.h"
1619#include " atb/types.h"
17- #include " ascend/atb_common_.h"
18- #include " ascend/paged_attention/registry.h"
19- #include " ascend/workspace_pool_.h"
2020#include " base/paged_attention.h"
2121#include " operator.h"
2222
@@ -47,10 +47,10 @@ template <>
4747class Operator <PagedAttention, Device::Type::kAscend , 0 >
4848 : public PagedAttention {
4949 public:
50- Operator (const Tensor query, const Tensor key_cache,
51- const Tensor value_cache , const Tensor seq_lens ,
52- const Tensor block_table , int64_t num_heads, int64_t num_kv_heads ,
53- int64_t head_size, double scale, int64_t block_size, Tensor output,
50+ Operator (const Tensor query, const Tensor key_cache, const Tensor value_cache,
51+ const Tensor seq_lens , const Tensor block_table, int64_t num_heads ,
52+ int64_t num_kv_heads , int64_t head_size, double scale ,
53+ int64_t block_size, Tensor output,
5454 std::optional<Tensor> seq_lens_host = std::nullopt ,
5555 std::optional<Tensor> block_table_host = std::nullopt )
5656 : PagedAttention(query, key_cache, value_cache, seq_lens, block_table,
@@ -110,8 +110,7 @@ class Operator<PagedAttention, Device::Type::kAscend, 0>
110110 param.qkScale = static_cast <float >(scale_);
111111
112112 atb::Status s = atb::CreateOperation (param, &op_);
113- assert (s == atb::NO_ERROR &&
114- " atb::CreateOperation(PagedAttention) failed" );
113+ assert (s == atb::NO_ERROR && " atb::CreateOperation(PagedAttention) failed" );
115114 }
116115
117116 ~Operator () {
@@ -150,24 +149,23 @@ class Operator<PagedAttention, Device::Type::kAscend, 0>
150149 if (block_table_host.has_value ()) {
151150 bt_host_ptr = const_cast <void *>(block_table_host.value ().data ());
152151 } else {
153- aclrtMemcpy (bt_host_, bt_host_bytes_, block_table.data (),
154- bt_host_bytes_, ACL_MEMCPY_DEVICE_TO_HOST);
152+ aclrtMemcpy (bt_host_, bt_host_bytes_, block_table.data (), bt_host_bytes_,
153+ ACL_MEMCPY_DEVICE_TO_HOST);
155154 }
156155
157156 if (seq_lens_host.has_value ()) {
158157 sl_host_ptr = const_cast <void *>(seq_lens_host.value ().data ());
159158 } else {
160- aclrtMemcpy (sl_host_, sl_host_bytes_, seq_lens.data (),
161- sl_host_bytes_, ACL_MEMCPY_DEVICE_TO_HOST);
159+ aclrtMemcpy (sl_host_, sl_host_bytes_, seq_lens.data (), sl_host_bytes_,
160+ ACL_MEMCPY_DEVICE_TO_HOST);
162161 }
163162
164163 atb::VariantPack vp = buildVariantPack (
165- const_cast <void *>(query.data ()),
166- const_cast <void *>(key_cache.data ()),
164+ const_cast <void *>(query.data ()), const_cast <void *>(key_cache.data ()),
167165 const_cast <void *>(value_cache.data ()),
168166 const_cast <void *>(block_table.data ()),
169- const_cast <void *>(seq_lens.data ()), output.data (),
170- bt_host_ptr, sl_host_ptr);
167+ const_cast <void *>(seq_lens.data ()), output.data (), bt_host_ptr,
168+ sl_host_ptr);
171169
172170 // Setup computes workspace requirements and binds tensor descriptors.
173171 uint64_t ws_size = 0 ;
@@ -197,9 +195,8 @@ class Operator<PagedAttention, Device::Type::kAscend, 0>
197195 // `aclIntArray*` parameters.
198196 atb::VariantPack buildVariantPack (void * query_data, void * key_cache_data,
199197 void * value_cache_data,
200- void * block_table_data,
201- void * seq_lens_data, void * output_data,
202- void * bt_host_ptr,
198+ void * block_table_data, void * seq_lens_data,
199+ void * output_data, void * bt_host_ptr,
203200 void * sl_host_ptr) const {
204201 int64_t B = query_tnd_shape_[0 ];
205202 int64_t N = query_tnd_shape_[1 ];
@@ -214,12 +211,11 @@ class Operator<PagedAttention, Device::Type::kAscend, 0>
214211 int64_t nb = kv_cache_shape_[0 ];
215212 int64_t bs = kv_cache_shape_[1 ];
216213 int64_t Nkv = kv_cache_shape_[2 ];
217- uint64_t kv_bytes =
218- static_cast <uint64_t >(nb * bs * Nkv * D) * elem_size_;
219- atb::Tensor t_key_cache = ascend::toAtbTensor (kv_cache_shape_, acl_dt_,
220- key_cache_data, kv_bytes);
221- atb::Tensor t_value_cache = ascend::toAtbTensor (
222- kv_cache_shape_, acl_dt_, value_cache_data, kv_bytes);
214+ uint64_t kv_bytes = static_cast <uint64_t >(nb * bs * Nkv * D) * elem_size_;
215+ atb::Tensor t_key_cache =
216+ ascend::toAtbTensor (kv_cache_shape_, acl_dt_, key_cache_data, kv_bytes);
217+ atb::Tensor t_value_cache = ascend::toAtbTensor (kv_cache_shape_, acl_dt_,
218+ value_cache_data, kv_bytes);
223219
224220 // Block table [B, max_blocks] — with hostData for `aclIntArray*`.
225221 atb::Tensor t_block_table = ascend::toAtbTensor (
0 commit comments