
Commit b7880e6

zhangyue authored and committed
style(pr66): sweep assert-message periods + comment backticks
Addresses inline review comments on #66 (reviewer: Ziminli) across all PR-touched files:

- C4: strip trailing periods from assert messages; lowercase the sentence-starting word when it is bare English (e.g. "Ascend ..." → "ascend ..."); leave backticked identifiers untouched.
- G4: backtick `RmsNorm` in the kernel_custom.h header comment; backtick `aclnn` / `cos_sin_cache` / `infini.ops.add_rms_norm(...)` in kernel comments that were still running raw text.
- C2: rename the `aclrtlaunch_add_rms_norm` / `aclrtlaunch_rms_norm` forward-declaration parameter names from AscendC internals (`x1, x2, y, x_out`) to the base-header semantic names (`input, residual, weight, out, residual_out`). The extern "C" symbol is name-blind, so the AscendC kernel .cpp can keep its local names — the wrapper .h just presents the public contract.
- Pre-gathered rotary test: drop the hardcoded `implementation_index=(0, 1)` parametrize; let conftest auto-inject and skip impl 2 inline (the impl 2 kernel asserts `!pre_gathered_`).

Verified locally (`--gpu-id 3/4/5 --local`):

- test_add_rms_norm.py: 108 passed
- test_rms_norm.py: 72 passed
- test_rotary_embedding.py: 88 passed, 16 skipped (expected: impl 2 + pre_gathered, impl 0 + non-neox)
1 parent 053c907 commit b7880e6

11 files changed

Lines changed: 58 additions & 60 deletions


src/ascend/add_rms_norm/kernel_custom.h

Lines changed: 4 additions & 4 deletions
@@ -22,8 +22,8 @@
 // `ascendc_add_operator()` and cannot be `PascalCase`d.
 // NOLINTNEXTLINE(readability-identifier-naming)
 extern "C" uint32_t aclrtlaunch_add_rms_norm(
-    uint32_t block_dim, void* stream, void* x1, void* x2, void* weight, void* y,
-    void* x_out, int64_t total_rows, int64_t dim_length,
+    uint32_t block_dim, void* stream, void* input, void* residual, void* weight,
+    void* out, void* residual_out, int64_t total_rows, int64_t dim_length,
     int64_t dim_length_align, int64_t former_num, int64_t former_length,
     int64_t tail_length, float eps, int64_t dtype_code);

@@ -57,7 +57,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
     assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
             dtype_ == DataType::kFloat32) &&
            "`AddRmsNorm` custom kernel: `input` must be `fp16`, `bf16`, or "
-           "`fp32`.");
+           "`fp32`");

     // 32-byte alignment on the last dimension — kernel relies on aligned
     // `DataCopyPad` loads/stores.
@@ -67,7 +67,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
         align_elems;
     assert(static_cast<int64_t>(dim_) == dim_length_align_ &&
            "`AddRmsNorm` custom kernel: last dimension must be 32-byte "
-           "aligned.");
+           "aligned");

     total_rows_ =
         static_cast<int64_t>(batch_size_) * static_cast<int64_t>(nhead_);
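
The C2 parameter rename above is safe because C linkage identifies a function by symbol name alone; parameter names are not part of the declaration's type or of the linked symbol. A minimal standalone sketch of that property, using hypothetical names rather than anything from this PR:

#include <cstdint>
#include <cstdio>

// Wrapper-header view: semantic names for callers (hypothetical example).
extern "C" std::uint32_t launch_demo_kernel(void* input, void* out,
                                            std::int64_t rows);

// Definition (normally in the kernel .cpp) keeps its local names.
// Same symbol `launch_demo_kernel`; the linker never sees `x1` / `y`.
extern "C" std::uint32_t launch_demo_kernel(void* x1, void* y,
                                            std::int64_t rows) {
  std::printf("rows=%lld\n", static_cast<long long>(rows));
  return x1 == y ? 1u : 0u;
}

int main() {
  int a = 0;
  int b = 0;
  return static_cast<int>(launch_demo_kernel(&a, &b, 8));
}

Both declarations refer to the same symbol, so the wrapper header can advertise `input` / `out` while the kernel source keeps its internal spellings.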

src/ascend/add_rms_norm/kernel_fused.h

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ namespace infini::ops {
 // large tensors where kernel fusion reduces memory traffic.
 //
 // Select via `implementation_index=1` in Python:
-// infini.ops.add_rms_norm(..., implementation_index=1, stream=s)
+// `infini.ops.add_rms_norm(..., implementation_index=1, stream=s)`.
 template <>
 class Operator<AddRmsNorm, Device::Type::kAscend, 1> : public AddRmsNorm {
  public:

src/ascend/rms_norm/kernel_custom.h

Lines changed: 14 additions & 13 deletions
@@ -23,28 +23,29 @@
 // `ascendc_add_operator()` and cannot be `PascalCase`d.
 // NOLINTNEXTLINE(readability-identifier-naming)
 extern "C" uint32_t aclrtlaunch_rms_norm(
-    uint32_t block_dim, void* stream, void* x, void* weight, void* y,
+    uint32_t block_dim, void* stream, void* input, void* weight, void* out,
     int64_t total_rows, int64_t dim_length, int64_t dim_length_align,
     int64_t former_num, int64_t former_length, int64_t tail_length, float eps,
     int64_t dtype_code);

 namespace infini::ops {

-// Custom AscendC fused RmsNorm kernel (implementation index 1).
+// Custom AscendC fused `RmsNorm` kernel (implementation index 1).
 //
-// A single-kernel implementation that computes RMSNorm in one launch, avoiding
-// the 5-sub-op decomposition of `aclnnRmsNorm` (index 0). Uses `Sqrt` +
-// scalar division instead of `Rsqrt` for higher precision (~1e-7 fp32 error
-// vs ~0.2% with `Rsqrt`).
+// A single-kernel implementation that computes `RMSNorm` in one launch,
+// avoiding the 5-sub-op decomposition of `aclnnRmsNorm` (index 0). Uses
+// `Sqrt` + scalar division instead of `Rsqrt` for higher precision (~1e-7
+// `fp32` error vs ~0.2% with `Rsqrt`).
 //
 // Select via `implementation_index=1` in Python:
-// infini.ops.rms_norm(input, weight, eps, out, implementation_index=1,
-// stream=s)
+// `infini.ops.rms_norm(input, weight, eps, out, implementation_index=1,
+// stream=s)`.
 //
 // Requirements:
-// - Input last dimension must be 32-byte aligned (divisible by 16 for fp16
-//   or 8 for fp32). All standard LLM hidden dimensions satisfy this.
-// - Weight must have the same dtype as input.
+// - Input last dimension must be 32-byte aligned (divisible by 16 for
+//   `fp16` or 8 for `fp32`). All standard LLM hidden dimensions satisfy
+//   this.
+// - `weight` must have the same dtype as `input`.
 // - The custom kernel binary must be linked (`BUILD_ASCEND_CUSTOM=ON`).
 template <>
 class Operator<RmsNorm, Device::Type::kAscend, 1> : public RmsNorm {
@@ -54,7 +55,7 @@ class Operator<RmsNorm, Device::Type::kAscend, 1> : public RmsNorm {
     assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
             dtype_ == DataType::kFloat32) &&
            "`RmsNorm` custom kernel: `input` must be `fp16`, `bf16`, or "
-           "`fp32`.");
+           "`fp32`");

     // 32-byte alignment on the last dimension — kernel relies on aligned
     // `DataCopyPad` loads/stores.
@@ -63,7 +64,7 @@ class Operator<RmsNorm, Device::Type::kAscend, 1> : public RmsNorm {
         ((static_cast<int64_t>(dim_) + align_elems - 1) / align_elems) *
         align_elems;
     assert(static_cast<int64_t>(dim_) == dim_length_align_ &&
-           "`RmsNorm` custom kernel: last dimension must be 32-byte aligned.");
+           "`RmsNorm` custom kernel: last dimension must be 32-byte aligned");

     total_rows_ =
         static_cast<int64_t>(batch_size_) * static_cast<int64_t>(nhead_);
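
For concreteness, the 32-byte alignment requirement stated in the header comment reduces to the round-up check visible in the last hunk: 32 bytes hold 16 `fp16` elements or 8 `fp32` elements, so the last dimension must be divisible by that element count. A self-contained sketch of the same arithmetic (helper name and sample sizes are illustrative, not PR code):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Round `dim` up to the element count that fills 32 bytes and require the
// round-up to be a no-op, mirroring the kernel's assert.
bool IsLastDimAligned(std::int64_t dim, std::int64_t elem_bytes) {
  const std::int64_t align_elems = 32 / elem_bytes;  // 16 for fp16, 8 for fp32
  const std::int64_t dim_align =
      ((dim + align_elems - 1) / align_elems) * align_elems;
  return dim == dim_align;
}

int main() {
  assert(IsLastDimAligned(4096, 2));   // fp16, hidden=4096: ok
  assert(IsLastDimAligned(4096, 4));   // fp32: ok
  assert(!IsLastDimAligned(4100, 2));  // 4100 % 16 != 0: would trip the assert
  std::puts("alignment checks passed");
  return 0;
}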

src/ascend/rotary_embedding/kernel.h

Lines changed: 6 additions & 6 deletions
@@ -52,15 +52,15 @@ class Operator<RotaryEmbedding, Device::Type::kAscend>
       max_seq_len_{cos_sin_cache.size(0)},
       elem_sz_{cos_sin_cache.element_size()} {
     assert(rotary_dim == head_size &&
-           "Ascend `RotaryEmbedding`: `rotary_dim` must equal `head_size` "
-           "(partial rotation is not implemented in this wrapper).");
+           "ascend `RotaryEmbedding`: `rotary_dim` must equal `head_size` "
+           "(partial rotation is not implemented in this wrapper)");
     assert(is_neox_style &&
-           "Ascend `RotaryEmbedding`: `is_neox_style` must be `true` — "
+           "ascend `RotaryEmbedding`: `is_neox_style` must be `true` — "
            "this wrapper only plumbs `rotaryMode=\"half\"` through "
-           "`aclnnApplyRotaryPosEmbV2`.");
+           "`aclnnApplyRotaryPosEmbV2`");
     assert(has_key_ &&
-           "Ascend `RotaryEmbedding` (impl 0): `key` is required — "
-           "`aclnnApplyRotaryPosEmbV2` always rotates Q and K together.");
+           "ascend `RotaryEmbedding` (impl 0): `key` is required — "
+           "`aclnnApplyRotaryPosEmbV2` always rotates Q and K together");

     // Resolve optional out buffers; when omitted, RoPE writes back in place
     // on `query` / `key` — vLLM-style inplace semantics.

src/ascend/rotary_embedding/kernel_atb.h

Lines changed: 10 additions & 9 deletions
@@ -56,7 +56,8 @@ namespace infini::ops {
 // * `rotaryCoeff=2` when `is_neox_style=true` (half split + cat)
 // * `rotaryCoeff=head_size` when `is_neox_style=false` (interleave)
 // Partial rotary (`rotary_dim < head_size`) is not supported by either
-// the aclnn or ATB fused APIs; callers must pad to `head_size` upstream.
+// the `aclnn` or ATB fused APIs; callers must pad to `head_size`
+// upstream.
 template <>
 class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
     : public RotaryEmbedding {
@@ -71,11 +72,11 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
                        is_neox_style, rotary_dim, query_out, key_out,
                        pre_gathered) {
     assert(rotary_dim == head_size &&
-           "Ascend `RotaryEmbedding` (ATB): `rotary_dim` must equal "
-           "`head_size` — ATB `RopeParam` does not support partial rotary.");
+           "ascend `RotaryEmbedding` (ATB): `rotary_dim` must equal "
+           "`head_size` — ATB `RopeParam` does not support partial rotary");
     assert(has_key_ &&
-           "Ascend `RotaryEmbedding` (ATB): `key` is required — ATB "
-           "`RopeParam` always rotates Q and K together.");
+           "ascend `RotaryEmbedding` (ATB): `key` is required — ATB "
+           "`RopeParam` always rotates Q and K together");

     const int64_t head_dim = head_size_;
     const size_t elem_sz = cos_sin_cache.element_size();
@@ -101,7 +102,7 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
     aclrtMalloc(&cos_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
     aclrtMalloc(&sin_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);

-    // Upload the initial cos_sin_cache. `cos_sin_cache_data_` memorizes
+    // Upload the initial `cos_sin_cache`. `cos_sin_cache_data_` memorizes
     // the source pointer; if the caller later hands in a different buffer,
     // `operator()` re-runs the upload.
     UploadCosSinCache(cos_sin_cache);
@@ -142,7 +143,7 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
     param.cosFormat = 0;  // Inference mode.
     atb::Status s = atb::CreateOperation(param, &op_);

-    assert(s == atb::NO_ERROR && "`atb::CreateOperation(Rope)` failed.");
+    assert(s == atb::NO_ERROR && "`atb::CreateOperation(Rope)` failed");
   }

   ~Operator() {
@@ -295,7 +296,7 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>
     uint64_t ws_size = 0;
     atb::Status s = op_->Setup(vp, ws_size, ctx);

-    assert(s == atb::NO_ERROR && "ATB Rope `Setup` failed.");
+    assert(s == atb::NO_ERROR && "ATB Rope `Setup` failed");

     uint8_t* ws_ptr = nullptr;

@@ -306,7 +307,7 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 1>

     s = op_->Execute(vp, ws_ptr, ws_size, ctx);

-    assert(s == atb::NO_ERROR && "ATB Rope `Execute` failed.");
+    assert(s == atb::NO_ERROR && "ATB Rope `Execute` failed");
   }

  private:
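
The `cos_sin_cache` upload comment touched above describes a pointer-identity cache: the wrapper remembers the host pointer of the last uploaded table and re-uploads only when the caller hands in a different buffer. A rough sketch of that pattern, with a std::vector standing in for the `aclrtMalloc` / `aclrtMemcpy` device-memory calls (hypothetical class, not PR code):

#include <cstddef>
#include <cstring>
#include <vector>

class CosSinTableCache {
 public:
  // Returns true when an upload actually happened.
  bool MaybeUpload(const void* src, std::size_t bytes) {
    if (src == cached_src_ && bytes == device_copy_.size()) {
      return false;  // same buffer as the last upload: skip the copy
    }
    device_copy_.resize(bytes);                    // stands in for aclrtMalloc
    std::memcpy(device_copy_.data(), src, bytes);  // stands in for aclrtMemcpy
    cached_src_ = src;
    return true;
  }

 private:
  const void* cached_src_ = nullptr;
  std::vector<unsigned char> device_copy_;
};

The trade-off is that pointer equality is cheap but assumes callers do not mutate the previously uploaded buffer in place, which matches the contract stated in the kernel comment.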

src/ascend/rotary_embedding/kernel_sincos_cache.h

Lines changed: 6 additions & 6 deletions
@@ -51,12 +51,12 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 2>
                        pre_gathered),
       max_seq_len_{cos_sin_cache.size(0)} {
     assert(has_key_ &&
-           "Ascend `RotaryEmbedding` (`aclnnRopeWithSinCosCache`): `key` is "
-           "required — this fused API always rotates Q and K together.");
+           "ascend `RotaryEmbedding` (`aclnnRopeWithSinCosCache`): `key` is "
+           "required — this fused API always rotates Q and K together");
     assert(!pre_gathered_ &&
-           "Ascend `RotaryEmbedding` (`aclnnRopeWithSinCosCache`): "
+           "ascend `RotaryEmbedding` (`aclnnRopeWithSinCosCache`): "
            "`pre_gathered` is not supported — use implementation index 0 or "
-           "1 for the pre-gathered fast path.");
+           "1 for the pre-gathered fast path");

     // Resolve optional out buffers (inplace on `query` / `key` when omitted).
     // Non-const so `.data()` returns a writable `void*`.
@@ -143,7 +143,7 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 2>
     auto ret = aclnnRopeWithSinCosCacheGetWorkspaceSize(
         t_pos, t_q, t_k, t_cache, /*mropeSection=*/nullptr, head_size,
         is_neox_style, t_q_out, t_k_out, &ws_size, &executor);
-    assert(ret == 0 && "`aclnnRopeWithSinCosCacheGetWorkspaceSize` failed.");
+    assert(ret == 0 && "`aclnnRopeWithSinCosCacheGetWorkspaceSize` failed");

     void* ws_buf = nullptr;

@@ -153,7 +153,7 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 2>
     }

     ret = aclnnRopeWithSinCosCache(ws_buf, ws_size, executor, stream);
-    assert(ret == 0 && "`aclnnRopeWithSinCosCache` failed.");
+    assert(ret == 0 && "`aclnnRopeWithSinCosCache` failed");
   }

  private:

src/base/add_rms_norm.h

Lines changed: 4 additions & 5 deletions
@@ -22,12 +22,11 @@ class AddRmsNorm : public Operator<AddRmsNorm> {
       batch_size_{ndim_ == 2 ? input.size(-2) : input.size(-3)},
       nhead_{ndim_ == 2 ? 1 : input.size(-2)} {
     assert(input.dtype() == residual.dtype() &&
-           "`AddRmsNorm`: `input` and `residual` must have the same dtype.");
+           "`AddRmsNorm`: `input` and `residual` must have the same dtype");
     assert(input.dtype() == out.dtype() &&
-           "`AddRmsNorm`: `input` and `out` must have the same dtype.");
-    assert(
-        input.dtype() == residual_out.dtype() &&
-        "`AddRmsNorm`: `input` and `residual_out` must have the same dtype.");
+           "`AddRmsNorm`: `input` and `out` must have the same dtype");
+    assert(input.dtype() == residual_out.dtype() &&
+           "`AddRmsNorm`: `input` and `residual_out` must have the same dtype");
   }

   AddRmsNorm(Tensor input, Tensor residual, const Tensor weight, float eps)

src/base/linear.h

Lines changed: 3 additions & 3 deletions
@@ -41,13 +41,13 @@ class Linear : public Operator<Linear> {
       : Linear{input, weight, bias, /*trans_a=*/false, /*trans_b=*/true, out} {
     assert(weight.ndim() >= 2 &&
            "`Linear`: `weight` must have at least 2 dims "
-           "`[..., out_features, in_features]`.");
+           "`[..., out_features, in_features]`");
     assert(weight.size(-1) == input.size(-1) &&
            "`Linear`: `weight.shape[-1]` must equal `input.shape[-1]` "
-           "(`in_features`).");
+           "(`in_features`)");
     assert(weight.size(-2) == out.size(-1) &&
            "`Linear`: `weight.shape[-2]` must equal `out.shape[-1]` "
-           "(`out_features`).");
+           "(`out_features`)");
   }

   // Deprecated — use `(input, weight, bias, out)` overload.

src/base/rotary_embedding.h

Lines changed: 5 additions & 5 deletions
@@ -40,25 +40,25 @@ class RotaryEmbedding : public Operator<RotaryEmbedding> {
       has_key_{key.has_value()},
       pre_gathered_{pre_gathered} {
     assert(positions.dtype() == DataType::kInt64 &&
-           "`RotaryEmbedding`: `positions` must be `int64` (vLLM convention).");
+           "`RotaryEmbedding`: `positions` must be `int64` (vLLM convention)");

     assert((query.ndim() == 2 || query.ndim() == 3) &&
            "`RotaryEmbedding`: `query` must be 2D `[T, Nq * head_size]` or 3D "
-           "`[T, Nq, head_size]`.");
+           "`[T, Nq, head_size]`");

     // TODO: relax once an MLA-capable Ascend impl lands. The signature keeps
     // `std::optional<Tensor> key` for vLLM-API compatibility, but all current
     // Ascend impls assume `key` is present and rotate Q and K together.
     assert(key.has_value() &&
            "`RotaryEmbedding`: `key` is required; the `key = None` (MLA) path "
-           "is not yet implemented on any backend.");
+           "is not yet implemented on any backend");

     assert((key->ndim() == 2 || key->ndim() == 3) &&
            "`RotaryEmbedding`: `key` must be 2D `[T, Nkv * head_size]` or 3D "
-           "`[T, Nkv, head_size]`.");
+           "`[T, Nkv, head_size]`");

     assert(rotary_dim <= head_size &&
-           "`RotaryEmbedding`: `rotary_dim` must be `<= head_size`.");
+           "`RotaryEmbedding`: `rotary_dim` must be `<= head_size`");
   }

   virtual void operator()(const Tensor positions, const Tensor query,

src/base/silu_and_mul.h

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ class SiluAndMul : public Operator<SiluAndMul> {
       is_input_contiguous_{input.IsContiguous()},
       is_out_contiguous_{out.IsContiguous()} {
     assert(input_dtype_ == out_dtype_ &&
-           "`SiluAndMul`: `input` and `out` must have the same dtype.");
+           "`SiluAndMul`: `input` and `out` must have the same dtype");
   }

   SiluAndMul(const Tensor input, Tensor out) : SiluAndMul{input, -1, out} {}
