Skip to content

Commit a9523de

Browse files
author
zhangyue
committed
style: apply clang-format to all modified C++ files
1 parent f982122 commit a9523de

File tree

25 files changed

+206
-259
lines changed

25 files changed

+206
-259
lines changed

src/ascend/add_rms_norm/kernel.h

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
#include "aclnn/aclnn_base.h"
88
#include "aclnn_add.h"
99
#include "aclnn_rms_norm.h"
10-
#include "ascend/common.h"
1110
#include "ascend/add_rms_norm/registry.h"
11+
#include "ascend/common.h"
1212
#include "ascend/workspace_pool_.h"
1313
#include "operator.h"
1414

@@ -63,10 +63,8 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
6363
&add_exec_);
6464
aclSetAclOpExecutorRepeatable(add_exec_);
6565
} else {
66-
aclSetInputTensorAddr(add_exec_, 0, t_x1,
67-
const_cast<void*>(x1.data()));
68-
aclSetInputTensorAddr(add_exec_, 1, t_x2,
69-
const_cast<void*>(x2.data()));
66+
aclSetInputTensorAddr(add_exec_, 0, t_x1, const_cast<void*>(x1.data()));
67+
aclSetInputTensorAddr(add_exec_, 1, t_x2, const_cast<void*>(x2.data()));
7068
aclSetOutputTensorAddr(add_exec_, 0, t_x_out, x_out.data());
7169
}
7270
auto& add_arena = ascend::workspacePool().ensure(stream, add_ws_);
@@ -78,18 +76,17 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
7876

7977
// Lazily create rstd tensor descriptor on first call.
8078
if (!rstd_tensor_) {
81-
rstd_tensor_ = aclCreateTensor(
82-
rstd_shape_.data(), 2, ACL_FLOAT,
83-
/*strides=*/nullptr, 0, ACL_FORMAT_ND, rstd_shape_.data(), 2,
84-
rstd_arena.buf);
79+
rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT,
80+
/*strides=*/nullptr, 0, ACL_FORMAT_ND,
81+
rstd_shape_.data(), 2, rstd_arena.buf);
8582
} else {
8683
aclSetRawTensorAddr(rstd_tensor_, rstd_arena.buf);
8784
}
8885

8986
// Step 2: y_out = rms_norm(x_out, gamma, eps).
9087
if (!norm_exec_) {
91-
aclnnRmsNormGetWorkspaceSize(t_x_out, t_gamma, eps, t_y_out,
92-
rstd_tensor_, &norm_ws_, &norm_exec_);
88+
aclnnRmsNormGetWorkspaceSize(t_x_out, t_gamma, eps, t_y_out, rstd_tensor_,
89+
&norm_ws_, &norm_exec_);
9390
aclSetAclOpExecutorRepeatable(norm_exec_);
9491
} else {
9592
aclSetInputTensorAddr(norm_exec_, 0, t_x_out, x_out.data());

src/ascend/add_rms_norm/kernel_custom.h

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,22 @@
1010
#include "acl/acl.h"
1111
#include "aclnn/aclnn_base.h"
1212
#include "aclnnop/aclnn_cast.h"
13-
#include "ascend/common.h"
1413
#include "ascend/add_rms_norm/registry.h"
14+
#include "ascend/common.h"
1515
#include "ascend/workspace_pool_.h"
1616
#include "base/add_rms_norm.h"
1717
#include "operator.h"
1818

1919
// Forward-declare the generated AscendC kernel launch function.
2020
// This symbol is provided by the `no_workspace_kernel` static library
21-
// built from `ascend/custom_kernel/csrc/ops/add_rms_norm/op_kernel/add_rms_norm.cpp`
22-
// via `ascendc_library()`.
21+
// built from
22+
// `ascend/custom_kernel/csrc/ops/add_rms_norm/op_kernel/add_rms_norm.cpp` via
23+
// `ascendc_library()`.
2324
extern "C" uint32_t aclrtlaunch_add_rms_norm(
24-
uint32_t blockDim, void* stream,
25-
void* x1, void* x2, void* weight, void* y, void* x_out,
26-
int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign,
27-
int64_t formerNum, int64_t formerLength, int64_t tailLength,
28-
float eps, int64_t dtypeSize);
25+
uint32_t blockDim, void* stream, void* x1, void* x2, void* weight, void* y,
26+
void* x_out, int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign,
27+
int64_t formerNum, int64_t formerLength, int64_t tailLength, float eps,
28+
int64_t dtypeSize);
2929

3030
namespace infini::ops {
3131

@@ -62,8 +62,8 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
6262
assert(static_cast<int64_t>(dim_) == dim_length_align_ &&
6363
"Custom AddRmsNorm kernel requires 32-byte aligned last dimension");
6464

65-
total_rows_ = static_cast<int64_t>(batch_size_) *
66-
static_cast<int64_t>(nhead_);
65+
total_rows_ =
66+
static_cast<int64_t>(batch_size_) * static_cast<int64_t>(nhead_);
6767

6868
// For fp16 input, weight needs fp32 conversion because the custom
6969
// kernel always reads weight as fp32.
@@ -72,16 +72,15 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
7272
if (needs_weight_cast_) {
7373
// Allocate persistent fp32 weight buffer on device.
7474
size_t fp32_bytes = static_cast<size_t>(dim_) * sizeof(float);
75-
aclrtMalloc(&weight_fp32_data_, fp32_bytes,
76-
ACL_MEM_MALLOC_NORMAL_ONLY);
75+
aclrtMalloc(&weight_fp32_data_, fp32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
7776

7877
// AclTensorCache for the cast source (fp16 weight descriptor).
79-
weight_src_cache_ = ascend::AclTensorCache(
80-
{static_cast<int64_t>(dim_)}, ACL_FLOAT16, nullptr);
78+
weight_src_cache_ = ascend::AclTensorCache({static_cast<int64_t>(dim_)},
79+
ACL_FLOAT16, nullptr);
8180

8281
// AclTensorCache for the cast destination (fp32 weight buffer).
83-
weight_dst_cache_ = ascend::AclTensorCache(
84-
{static_cast<int64_t>(dim_)}, ACL_FLOAT, weight_fp32_data_);
82+
weight_dst_cache_ = ascend::AclTensorCache({static_cast<int64_t>(dim_)},
83+
ACL_FLOAT, weight_fp32_data_);
8584
}
8685
}
8786

@@ -105,8 +104,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
105104
const void* cur_weight = gamma.data();
106105

107106
if (cur_weight != last_weight_ptr_) {
108-
auto t_src =
109-
weight_src_cache_.get(const_cast<void*>(cur_weight));
107+
auto t_src = weight_src_cache_.get(const_cast<void*>(cur_weight));
110108
auto t_dst = weight_dst_cache_.get(weight_fp32_data_);
111109

112110
if (!cast_exec_) {
@@ -133,25 +131,17 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
133131
// Block-level tiling: distribute rows across cores.
134132
static constexpr int64_t kMaxBlockDim = 40;
135133
int64_t used_cores = std::min(total_rows_, kMaxBlockDim);
136-
int64_t former_length =
137-
(total_rows_ + used_cores - 1) / used_cores;
134+
int64_t former_length = (total_rows_ + used_cores - 1) / used_cores;
138135
int64_t tail_length = former_length - 1;
139136
int64_t former_num = total_rows_ - tail_length * used_cores;
140137
uint32_t block_dim = static_cast<uint32_t>(used_cores);
141138

142139
// Launch custom AscendC kernel.
143140
aclrtlaunch_add_rms_norm(
144-
block_dim, stream,
145-
const_cast<void*>(x1.data()),
146-
const_cast<void*>(x2.data()),
147-
weight_fp32,
148-
y_out.data(),
149-
x_out.data(),
150-
total_rows_,
151-
static_cast<int64_t>(dim_),
152-
dim_length_align_,
153-
former_num, former_length, tail_length,
154-
eps, dtype_size_);
141+
block_dim, stream, const_cast<void*>(x1.data()),
142+
const_cast<void*>(x2.data()), weight_fp32, y_out.data(), x_out.data(),
143+
total_rows_, static_cast<int64_t>(dim_), dim_length_align_, former_num,
144+
former_length, tail_length, eps, dtype_size_);
155145
}
156146

157147
private:

src/ascend/apply_rotary_pos_emb/kernel.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,14 @@ class Operator<ApplyRotaryPosEmb, Device::Type::kAscend>
4949

5050
// V2 expects cos/sin as `[T, 1, D]`. Input is `[T, D]` — same data,
5151
// different descriptor shape (T*1*D == T*D for contiguous tensors).
52-
cos_cache_ = ascend::AclTensorCache(
53-
{T, 1, D}, acl_dt, const_cast<void*>(cos.data()));
54-
sin_cache_ = ascend::AclTensorCache(
55-
{T, 1, D}, acl_dt, const_cast<void*>(sin.data()));
56-
q_cache_ = ascend::AclTensorCache(
57-
{T, Nq, D}, acl_dt, const_cast<void*>(query_out.data()));
58-
k_cache_ = ascend::AclTensorCache(
59-
{T, Nkv, D}, acl_dt, const_cast<void*>(key_out.data()));
52+
cos_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt,
53+
const_cast<void*>(cos.data()));
54+
sin_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt,
55+
const_cast<void*>(sin.data()));
56+
q_cache_ = ascend::AclTensorCache({T, Nq, D}, acl_dt,
57+
const_cast<void*>(query_out.data()));
58+
k_cache_ = ascend::AclTensorCache({T, Nkv, D}, acl_dt,
59+
const_cast<void*>(key_out.data()));
6060
}
6161

6262
~Operator() {
@@ -105,10 +105,8 @@ class Operator<ApplyRotaryPosEmb, Device::Type::kAscend>
105105
} else {
106106
aclSetInputTensorAddr(v2_exec_, 0, t_q, query_out.data());
107107
aclSetInputTensorAddr(v2_exec_, 1, t_k, key_out.data());
108-
aclSetInputTensorAddr(v2_exec_, 2, t_cos,
109-
const_cast<void*>(cos.data()));
110-
aclSetInputTensorAddr(v2_exec_, 3, t_sin,
111-
const_cast<void*>(sin.data()));
108+
aclSetInputTensorAddr(v2_exec_, 2, t_cos, const_cast<void*>(cos.data()));
109+
aclSetInputTensorAddr(v2_exec_, 3, t_sin, const_cast<void*>(sin.data()));
112110
}
113111

114112
auto& arena = ascend::workspacePool().ensure(stream, v2_ws_);

src/ascend/apply_rotary_pos_emb/kernel_atb.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ class Operator<ApplyRotaryPosEmb, Device::Type::kAscend, 1>
9898

9999
if (query.data() != query_out.data()) {
100100
aclrtMemcpyAsync(query_out.data(),
101-
static_cast<size_t>(T * hiddenQ) * elem_sz,
102-
query.data(),
101+
static_cast<size_t>(T * hiddenQ) * elem_sz, query.data(),
103102
static_cast<size_t>(T * hiddenQ) * elem_sz,
104103
ACL_MEMCPY_DEVICE_TO_DEVICE, stream);
105104
}
@@ -126,9 +125,9 @@ class Operator<ApplyRotaryPosEmb, Device::Type::kAscend, 1>
126125
cos_sin_shape_, acl_dt_, const_cast<void*>(cos.data()), cs_bytes);
127126
atb::Tensor t_sin = ascend::toAtbTensor(
128127
cos_sin_shape_, acl_dt_, const_cast<void*>(sin.data()), cs_bytes);
129-
atb::Tensor t_seqlen = ascend::toAtbTensor(
130-
seqlen_shape_, ACL_INT32, seqlen_dev_,
131-
static_cast<uint64_t>(sizeof(int32_t)));
128+
atb::Tensor t_seqlen =
129+
ascend::toAtbTensor(seqlen_shape_, ACL_INT32, seqlen_dev_,
130+
static_cast<uint64_t>(sizeof(int32_t)));
132131

133132
atb::VariantPack vp;
134133
vp.inTensors = {t_q, t_k, t_cos, t_sin, t_seqlen};

src/ascend/atb_common_.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
#include <vector>
1010

1111
#include "acl/acl.h"
12+
#include "ascend/data_type_.h"
1213
#include "atb/context.h"
1314
#include "atb/operation.h"
1415
#include "atb/types.h"
15-
#include "ascend/data_type_.h"
1616
#include "tensor.h"
1717

1818
namespace infini::ops::ascend {

src/ascend/causal_softmax/kernel.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ template <>
2929
class Operator<CausalSoftmax, Device::Type::kAscend> : public CausalSoftmax {
3030
public:
3131
Operator(const Tensor input, Tensor out)
32-
: CausalSoftmax(input, out),
33-
in_cache_(input),
34-
out_cache_(out) {
32+
: CausalSoftmax(input, out), in_cache_(input), out_cache_(out) {
3533
// Compute temp buffer size — allocated lazily from pool in `operator()`.
3634
size_t n_elems = input.numel();
3735
size_t elem_bytes = kDataTypeToSize.at(dtype_);

src/ascend/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ class AclTensorCache {
7373
public:
7474
AclTensorCache() = default;
7575

76-
// Construct from explicit metadata (for device buffers not wrapped in Tensor).
77-
// Computes contiguous strides from shape.
76+
// Construct from explicit metadata (for device buffers not wrapped in
77+
// Tensor). Computes contiguous strides from shape.
7878
AclTensorCache(std::vector<int64_t> shape, aclDataType dtype, void* data)
7979
: shape_(std::move(shape)), dtype_(dtype) {
8080
strides_.resize(shape_.size());

src/ascend/paged_attention/kernel_atb.h

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
#include <vector>
1111

1212
#include "acl/acl.h"
13+
#include "ascend/atb_common_.h"
14+
#include "ascend/paged_attention/registry.h"
15+
#include "ascend/workspace_pool_.h"
1316
#include "atb/context.h"
1417
#include "atb/infer_op_params.h"
1518
#include "atb/operation.h"
1619
#include "atb/types.h"
17-
#include "ascend/atb_common_.h"
18-
#include "ascend/paged_attention/registry.h"
19-
#include "ascend/workspace_pool_.h"
2020
#include "base/paged_attention.h"
2121
#include "operator.h"
2222

@@ -47,10 +47,10 @@ template <>
4747
class Operator<PagedAttention, Device::Type::kAscend, 0>
4848
: public PagedAttention {
4949
public:
50-
Operator(const Tensor query, const Tensor key_cache,
51-
const Tensor value_cache, const Tensor seq_lens,
52-
const Tensor block_table, int64_t num_heads, int64_t num_kv_heads,
53-
int64_t head_size, double scale, int64_t block_size, Tensor output,
50+
Operator(const Tensor query, const Tensor key_cache, const Tensor value_cache,
51+
const Tensor seq_lens, const Tensor block_table, int64_t num_heads,
52+
int64_t num_kv_heads, int64_t head_size, double scale,
53+
int64_t block_size, Tensor output,
5454
std::optional<Tensor> seq_lens_host = std::nullopt,
5555
std::optional<Tensor> block_table_host = std::nullopt)
5656
: PagedAttention(query, key_cache, value_cache, seq_lens, block_table,
@@ -110,8 +110,7 @@ class Operator<PagedAttention, Device::Type::kAscend, 0>
110110
param.qkScale = static_cast<float>(scale_);
111111

112112
atb::Status s = atb::CreateOperation(param, &op_);
113-
assert(s == atb::NO_ERROR &&
114-
"atb::CreateOperation(PagedAttention) failed");
113+
assert(s == atb::NO_ERROR && "atb::CreateOperation(PagedAttention) failed");
115114
}
116115

117116
~Operator() {
@@ -150,24 +149,23 @@ class Operator<PagedAttention, Device::Type::kAscend, 0>
150149
if (block_table_host.has_value()) {
151150
bt_host_ptr = const_cast<void*>(block_table_host.value().data());
152151
} else {
153-
aclrtMemcpy(bt_host_, bt_host_bytes_, block_table.data(),
154-
bt_host_bytes_, ACL_MEMCPY_DEVICE_TO_HOST);
152+
aclrtMemcpy(bt_host_, bt_host_bytes_, block_table.data(), bt_host_bytes_,
153+
ACL_MEMCPY_DEVICE_TO_HOST);
155154
}
156155

157156
if (seq_lens_host.has_value()) {
158157
sl_host_ptr = const_cast<void*>(seq_lens_host.value().data());
159158
} else {
160-
aclrtMemcpy(sl_host_, sl_host_bytes_, seq_lens.data(),
161-
sl_host_bytes_, ACL_MEMCPY_DEVICE_TO_HOST);
159+
aclrtMemcpy(sl_host_, sl_host_bytes_, seq_lens.data(), sl_host_bytes_,
160+
ACL_MEMCPY_DEVICE_TO_HOST);
162161
}
163162

164163
atb::VariantPack vp = buildVariantPack(
165-
const_cast<void*>(query.data()),
166-
const_cast<void*>(key_cache.data()),
164+
const_cast<void*>(query.data()), const_cast<void*>(key_cache.data()),
167165
const_cast<void*>(value_cache.data()),
168166
const_cast<void*>(block_table.data()),
169-
const_cast<void*>(seq_lens.data()), output.data(),
170-
bt_host_ptr, sl_host_ptr);
167+
const_cast<void*>(seq_lens.data()), output.data(), bt_host_ptr,
168+
sl_host_ptr);
171169

172170
// Setup computes workspace requirements and binds tensor descriptors.
173171
uint64_t ws_size = 0;
@@ -197,9 +195,8 @@ class Operator<PagedAttention, Device::Type::kAscend, 0>
197195
// `aclIntArray*` parameters.
198196
atb::VariantPack buildVariantPack(void* query_data, void* key_cache_data,
199197
void* value_cache_data,
200-
void* block_table_data,
201-
void* seq_lens_data, void* output_data,
202-
void* bt_host_ptr,
198+
void* block_table_data, void* seq_lens_data,
199+
void* output_data, void* bt_host_ptr,
203200
void* sl_host_ptr) const {
204201
int64_t B = query_tnd_shape_[0];
205202
int64_t N = query_tnd_shape_[1];
@@ -214,12 +211,11 @@ class Operator<PagedAttention, Device::Type::kAscend, 0>
214211
int64_t nb = kv_cache_shape_[0];
215212
int64_t bs = kv_cache_shape_[1];
216213
int64_t Nkv = kv_cache_shape_[2];
217-
uint64_t kv_bytes =
218-
static_cast<uint64_t>(nb * bs * Nkv * D) * elem_size_;
219-
atb::Tensor t_key_cache = ascend::toAtbTensor(kv_cache_shape_, acl_dt_,
220-
key_cache_data, kv_bytes);
221-
atb::Tensor t_value_cache = ascend::toAtbTensor(
222-
kv_cache_shape_, acl_dt_, value_cache_data, kv_bytes);
214+
uint64_t kv_bytes = static_cast<uint64_t>(nb * bs * Nkv * D) * elem_size_;
215+
atb::Tensor t_key_cache =
216+
ascend::toAtbTensor(kv_cache_shape_, acl_dt_, key_cache_data, kv_bytes);
217+
atb::Tensor t_value_cache = ascend::toAtbTensor(kv_cache_shape_, acl_dt_,
218+
value_cache_data, kv_bytes);
223219

224220
// Block table [B, max_blocks] — with hostData for `aclIntArray*`.
225221
atb::Tensor t_block_table = ascend::toAtbTensor(

src/ascend/reshape_and_cache/kernel.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ class Operator<ReshapeAndCache, Device::Type::kAscend>
4646

4747
// Flattened K cache view: [total_slots, num_kv_heads, head_size].
4848
// K cache is kv_cache_out[0], starting at offset 0.
49-
kv_k_cache_ = ascend::AclTensorCache(
50-
{total_slots, nkv, hs}, acl_dt, kv_cache_out.data());
49+
kv_k_cache_ = ascend::AclTensorCache({total_slots, nkv, hs}, acl_dt,
50+
kv_cache_out.data());
5151

5252
// V cache is kv_cache_out[1], offset by stride(0) elements.
5353
v_offset_bytes_ = static_cast<size_t>(kv_cache_out.stride(0)) *
@@ -63,8 +63,7 @@ class Operator<ReshapeAndCache, Device::Type::kAscend>
6363
auto stream = static_cast<aclrtStream>(stream_);
6464

6565
void* kv_k_data = kv_cache_out.data();
66-
void* kv_v_data =
67-
static_cast<char*>(kv_cache_out.data()) + v_offset_bytes_;
66+
void* kv_v_data = static_cast<char*>(kv_cache_out.data()) + v_offset_bytes_;
6867

6968
auto t_kv_k = kv_k_cache_.get(kv_k_data);
7069
auto t_kv_v = kv_v_cache_.get(kv_v_data);
@@ -78,16 +77,16 @@ class Operator<ReshapeAndCache, Device::Type::kAscend>
7877
// reuse via aclSetInputTensorAddr does not update the output reference.
7978
uint64_t k_ws = 0;
8079
aclOpExecutor* k_exec = nullptr;
81-
aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_k, 0, t_slot, t_key,
82-
&k_ws, &k_exec);
80+
aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_k, 0, t_slot, t_key, &k_ws,
81+
&k_exec);
8382
auto& k_arena = ascend::workspacePool().ensure(stream, k_ws);
8483
aclnnInplaceIndexCopy(k_arena.buf, k_ws, k_exec, stream);
8584

8685
// V cache scatter: kv_v[slot_mapping[i]] = value[i] along dim 0.
8786
uint64_t v_ws = 0;
8887
aclOpExecutor* v_exec = nullptr;
89-
aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_v, 0, t_slot, t_value,
90-
&v_ws, &v_exec);
88+
aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_v, 0, t_slot, t_value, &v_ws,
89+
&v_exec);
9190
auto& v_arena = ascend::workspacePool().ensure(stream, v_ws);
9291
aclnnInplaceIndexCopy(v_arena.buf, v_ws, v_exec, stream);
9392
}

0 commit comments

Comments (0)