Skip to content

Commit 255ba00

Browse files
author
zhangyue
committed
style: fix code convention violations (round 2)
- C4: lowercase assert message starts (workspace_pool_, rms_norm, rotary_embedding)
- C4: remove trailing period from workspace_pool_ assert
- C9: add blank line between SlotKey struct members
- G4: backtick-fence identifiers in comments across 12 files
- G4: backtick-fence identifiers in assert messages (flash_attention, rotary_embedding)
- P1: remove duplicate `import re` in generate_wrappers.py
- P4: add blank lines around control flow in test_flash_attention.py
1 parent ae2d77e commit 255ba00

20 files changed

Lines changed: 52 additions & 45 deletions

scripts/generate_wrappers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ def _find_vector_tensor_params(op_name):
107107
"""Return a set of parameter names declared as `std::vector<Tensor>` in
108108
the base header.
109109
"""
110-
import re
111-
112110
source = (_BASE_DIR / f"{op_name}.h").read_text()
113111

114112
return set(re.findall(r"std::vector<Tensor>\s+(\w+)", source))

src/ascend/add/kernel.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ class Operator<Add, Device::Type::kAscend> : public Add {
2020
in_cache_(input),
2121
oth_cache_(other),
2222
out_cache_(out) {
23-
// aclCreateScalar stores the pointer rather than copying the value, so
24-
// alpha_storage_* must remain alive for the lifetime of alpha_.
23+
// `aclCreateScalar` stores the pointer rather than copying the value, so
24+
// `alpha_storage_*` must remain alive for the lifetime of `alpha_`.
2525
// The alpha scalar type must match the tensor dtype: use int64 for integer
2626
// dtypes and float for floating-point dtypes.
2727
if (ascend::isIntegerDtype(input.dtype())) {
@@ -71,8 +71,9 @@ class Operator<Add, Device::Type::kAscend> : public Add {
7171
mutable uint64_t ws_size_ = 0;
7272

7373
float alpha_float_storage_ =
74-
1.0f; // stable address for aclCreateScalar (float)
75-
int64_t alpha_int_storage_ = 1; // stable address for aclCreateScalar (int)
74+
1.0f; // Stable address for `aclCreateScalar` (float).
75+
int64_t alpha_int_storage_ =
76+
1; // Stable address for `aclCreateScalar` (int).
7677
aclScalar* alpha_ = nullptr;
7778
};
7879

src/ascend/add_rms_norm/kernel.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
namespace infini::ops {
1616

17-
// Decomposed implementation: aclnnAdd + aclnnRmsNorm.
17+
// Decomposed implementation: `aclnnAdd` + `aclnnRmsNorm`.
1818
//
19-
// The fused aclnnAddRmsNorm API has ~200 us host-side launch overhead that
19+
// The fused `aclnnAddRmsNorm` API has ~200 us host-side launch overhead that
2020
// dominates small-tensor dispatch. Decomposing into two fast ACLNN calls
2121
// reduces host dispatch from ~224 us to ~56 us (4x faster) with negligible
2222
// NPU-side impact for inference tensor sizes.
@@ -31,10 +31,10 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
3131
gamma_cache_(gamma),
3232
y_out_cache_(y_out),
3333
x_out_cache_(x_out) {
34-
// Alpha scalar for aclnnAdd (x_out = x1 + 1.0 * x2).
34+
// Alpha scalar for `aclnnAdd` (x_out = x1 + 1.0 * x2).
3535
alpha_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT);
3636

37-
// aclnnRmsNorm writes rstd as a required side output.
37+
// `aclnnRmsNorm` writes `rstd` as a required side output.
3838
// Size computed here; buffer obtained from pool in `operator()`.
3939
rstd_shape_ = {static_cast<int64_t>(batch_size_),
4040
static_cast<int64_t>(nhead_)};

src/ascend/add_rms_norm/kernel_custom.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ namespace infini::ops {
3333
//
3434
// A single-kernel implementation that computes x_out = x1 + x2 followed by
3535
// y = rms_norm(x_out, gamma, eps) in one launch, avoiding the decomposed
36-
// aclnnAdd + aclnnRmsNorm calls (index 0) or the fused aclnnAddRmsNorm call
37-
// (index 1). Migrated from the custom RmsNorm kernel (index 1 of RmsNorm).
36+
// `aclnnAdd` + `aclnnRmsNorm` calls (index 0) or the fused `aclnnAddRmsNorm`
37+
// call (index 1). Migrated from the custom RmsNorm kernel (index 1 of
38+
// RmsNorm).
3839
//
3940
// Select via `implementation_index=2` in Python:
4041
// infini.ops.add_rms_norm(x1, x2, gamma, eps, y_out, x_out,
@@ -59,8 +60,9 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
5960
dim_length_align_ =
6061
((static_cast<int64_t>(dim_) + align_elems - 1) / align_elems) *
6162
align_elems;
62-
assert(static_cast<int64_t>(dim_) == dim_length_align_ &&
63-
"Custom AddRmsNorm kernel requires 32-byte aligned last dimension");
63+
assert(
64+
static_cast<int64_t>(dim_) == dim_length_align_ &&
65+
"custom `AddRmsNorm` kernel requires 32-byte aligned last dimension");
6466

6567
total_rows_ =
6668
static_cast<int64_t>(batch_size_) * static_cast<int64_t>(nhead_);

src/ascend/add_rms_norm/kernel_fused.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
namespace infini::ops {
1515

16-
// Fused implementation via aclnnAddRmsNorm (implementation index 1).
16+
// Fused implementation via `aclnnAddRmsNorm` (implementation index 1).
1717
//
1818
// Computes x_out = x1 + x2 and y_out = rms_norm(x_out, gamma, eps) in a
1919
// single CANN launch. The fused API has higher host-side launch overhead
20-
// (~200 us) compared to the decomposed aclnnAdd + aclnnRmsNorm path (~39 us),
21-
// but may offer better NPU-side efficiency for large tensors where kernel
20+
// (~200 us) compared to the decomposed `aclnnAdd` + `aclnnRmsNorm` path (~39
21+
// us), but may offer better NPU-side efficiency for large tensors where kernel
2222
// fusion reduces memory traffic.
2323
//
2424
// Select via `implementation_index=1` in Python:
@@ -34,7 +34,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 1> : public AddRmsNorm {
3434
gamma_cache_(gamma),
3535
y_out_cache_(y_out),
3636
x_out_cache_(x_out) {
37-
// aclnnAddRmsNorm requires rstdOut to have the same ndim as x1, with
37+
// `aclnnAddRmsNorm` requires `rstdOut` to have the same ndim as x1, with
3838
// the last gamma.ndim() dimensions set to 1. For example:
3939
// x1 shape(2, 32, 128), gamma shape(128) -> rstdOut shape(2, 32, 1)
4040
// x1 shape(64, 128), gamma shape(128) -> rstdOut shape(64, 1)

src/ascend/causal_softmax/kernel.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@ class Operator<CausalSoftmax, Device::Type::kAscend> : public CausalSoftmax {
6464
mstrides.data(), 0, ACL_FORMAT_ND,
6565
mshape.data(), mshape.size(), mask_buf_);
6666

67-
// Scalar -inf for the masked-fill step. aclCreateScalar stores the pointer
68-
// rather than copying, so neg_inf_storage_ must stay alive with the object.
67+
// Scalar -inf for the masked-fill step. `aclCreateScalar` stores the
68+
// pointer rather than copying, so `neg_inf_storage_` must stay alive with
69+
// the object.
6970
neg_inf_ = aclCreateScalar(&neg_inf_storage_, ACL_FLOAT);
70-
// Workspaces are allocated lazily on first operator() call.
71+
// Workspaces are allocated lazily on first `operator()` call.
7172
}
7273

7374
~Operator() {

src/ascend/reshape_and_cache/kernel.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515

1616
namespace infini::ops {
1717

18-
// Device-side scatter via aclnnInplaceIndexCopy.
18+
// Device-side scatter via `aclnnInplaceIndexCopy`.
1919
//
2020
// The previous implementation copied slot_mapping D2H (aclrtSynchronizeStream),
2121
// then issued per-token D2D memcpy in a host loop. For batch=256, this meant
22-
// ~100 us sync + ~500 us host loop overhead. aclnnInplaceIndexCopy performs
22+
// ~100 us sync + ~500 us host loop overhead. `aclnnInplaceIndexCopy` performs
2323
// the scatter entirely on the NPU with two ACLNN calls (one for K, one for V),
2424
// eliminating all D2H synchronisation and host-side loops.
2525
//
@@ -72,7 +72,7 @@ class Operator<ReshapeAndCache, Device::Type::kAscend>
7272
auto t_slot = slot_cache_.get(const_cast<void*>(slot_mapping.data()));
7373

7474
// K cache scatter: kv_k[slot_mapping[i]] = key[i] along dim 0.
75-
// Executor caching is not used here because aclnnInplaceIndexCopy is an
75+
// Executor caching is not used here because `aclnnInplaceIndexCopy` is an
7676
// inplace operation where self is both input and output; the executor
7777
// reuse via aclSetInputTensorAddr does not update the output reference.
7878
uint64_t k_ws = 0;

src/ascend/rms_norm/kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Operator<RmsNorm, Device::Type::kAscend> : public RmsNorm {
2222
in_cache_(input),
2323
weight_cache_(weight),
2424
out_cache_(out) {
25-
// aclnnRmsNorm writes rstd as a required side output.
25+
// `aclnnRmsNorm` writes `rstd` as a required side output.
2626
// Size computed here; buffer obtained from pool in `operator()`.
2727
rstd_shape_ = {static_cast<int64_t>(batch_size_),
2828
static_cast<int64_t>(nhead_)};

src/ascend/rms_norm/kernel_custom.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Operator<RmsNorm, Device::Type::kAscend, 1> : public RmsNorm {
5858
((static_cast<int64_t>(dim_) + align_elems - 1) / align_elems) *
5959
align_elems;
6060
assert(static_cast<int64_t>(dim_) == dim_length_align_ &&
61-
"Custom RmsNorm kernel requires 32-byte aligned last dimension");
61+
"custom `RmsNorm` kernel requires 32-byte aligned last dimension");
6262

6363
total_rows_ =
6464
static_cast<int64_t>(batch_size_) * static_cast<int64_t>(nhead_);

src/ascend/rotary_embedding/kernel.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
namespace infini::ops {
1919

20-
// Rotary position embedding via aclnnApplyRotaryPosEmbV2.
20+
// Rotary position embedding via `aclnnApplyRotaryPosEmbV2`.
2121
//
2222
// V2 handles Q and K simultaneously in a single inplace call (layout=4, TND).
2323
// The `rotaryMode` parameter accepts "half", "interleave", or "quarter", but
@@ -42,12 +42,13 @@ class Operator<RotaryEmbedding, Device::Type::kAscend>
4242
: RotaryEmbedding(positions, query, key, cos_sin_cache, head_size,
4343
rotary_dim, is_neox_style, query_out, key_out) {
4444
assert(rotary_dim == head_size &&
45-
"Ascend `RotaryEmbedding` requires rotary_dim == head_size "
45+
"ascend `RotaryEmbedding` requires `rotary_dim` == `head_size` "
4646
"(partial rotation not supported)");
4747
assert(is_neox_style &&
48-
"Ascend `RotaryEmbedding` requires neox style — "
49-
"aclnnApplyRotaryPosEmbV2 rotaryMode only supports \"half\"; "
50-
"\"interleave\" and \"quarter\" return ACLNN_ERR_PARAM_INVALID");
48+
"ascend `RotaryEmbedding` requires neox style — "
49+
"`aclnnApplyRotaryPosEmbV2` `rotaryMode` only supports "
50+
"\"half\"; \"interleave\" and \"quarter\" return "
51+
"`ACLNN_ERR_PARAM_INVALID`");
5152

5253
const int64_t max_seq_len = cos_sin_cache.size(0);
5354
const int64_t D = head_size_;
@@ -101,7 +102,7 @@ class Operator<RotaryEmbedding, Device::Type::kAscend>
101102
const int64_t Nkv = num_kv_heads_;
102103
aclDataType acl_dt = ascend::toAclDtype(query.dtype());
103104

104-
// Gathered cos/sin buffers [T, D] — filled by aclnnIndexSelect each call.
105+
// Gathered cos/sin buffers [T, D] — filled by `aclnnIndexSelect` each call.
105106
size_t gathered_bytes = static_cast<size_t>(T * D) * elem_sz;
106107
aclrtMalloc(&cos_dev_, gathered_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
107108
aclrtMalloc(&sin_dev_, gathered_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
@@ -147,7 +148,7 @@ class Operator<RotaryEmbedding, Device::Type::kAscend>
147148
const int64_t Nkv = key.size(1);
148149
const int64_t D = head_size;
149150

150-
// Step 1: Gather cos/sin by positions via aclnnIndexSelect (async).
151+
// Step 1: Gather cos/sin by positions via `aclnnIndexSelect` (async).
151152
{
152153
auto t_cos_table = cos_table_cache_.get(cos_table_dev_);
153154
auto t_sin_table = sin_table_cache_.get(sin_table_dev_);

0 commit comments

Comments (0)