Skip to content

Commit d34542d

Browse files
author
zhangyue
committed
feat(ascend-custom): add bf16 support + Google-style identifier renames
bf16 was silently producing garbage / NaN on impl 1 (`rms_norm`) and impl 2 (`add_rms_norm`): the kernels only instantiated `<half>` and `<float>`, and the launchers mapped bf16 to the fp32 byte-size path, so bf16 weight was read as if it were fp32 and the fp16 output cast used `CAST_ROUND` (fp16-only alias). Kernel dispatch: - `op_kernel/rms_norm.cpp` / `op_kernel/add_rms_norm.cpp`: add a `KernelXxx<bfloat16_t>` instantiation; dispatch in the `extern "C"` entry is now `switch (static_cast<infini::ops::DataType>(dtypeCode))` (shared enum forwarded from the launcher via `int64_t`). The fp16/bf16 branch uses `CAST_RINT` for the fp32 → T writeback — defined for both `half` and `bfloat16_t` destinations, whereas `CAST_ROUND` is a `half`-specific alias. Launchers (`kernel_custom.h`): - Store `DataType dtype_` (replaces the old `int64_t dtype_size_` which collapsed fp16 and bf16 onto the same code). - Use `ascend::ToAclDtype(dtype_)` and `kDataTypeToSize.at(dtype_)` instead of hand-rolled ternaries (consistent with the rest of the Ascend backend). - Forward `static_cast<int64_t>(dtype_)` as the kernel's `dtypeCode`. - `extern "C" aclrtlaunch_*` forward-decl parameters renamed to `snake_case`; the function name itself is generated by `ascendc_add_operator(OP_NAME …)` and carries `// NOLINTNEXTLINE(readability-identifier-naming)` so `clang-tidy` accepts it. Identifier naming (Google C++ Style): - `op_kernel/*.cpp` members `snake_case_`, params / locals `snake_case`, constants `kPascalCase` (was `BUFFER_NUM` / `dimLength` / `inQueueX1` / `blockRows`, etc. — inherited from the `vllm-ascend` sample style). Verified: `pytest tests/test_rms_norm.py tests/test_add_rms_norm.py --devices ascend` → 144 passed / 0 failed (fp32 / fp16 / bf16 × both ops × full shape + stride matrix).
1 parent d9b1e09 commit d34542d

4 files changed

Lines changed: 396 additions & 355 deletions

File tree

src/ascend/add_rms_norm/kernel_custom.h

Lines changed: 33 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,14 @@
1919
// This symbol is provided by the `no_workspace_kernel` static library
2020
// built from `ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp`
2121
// via `ascendc_library()`.
22+
// `aclrtlaunch_*` symbol name is generated by `ascendc_library()` /
23+
// `ascendc_add_operator()` and cannot be `PascalCase`d.
24+
// NOLINTNEXTLINE(readability-identifier-naming)
2225
extern "C" uint32_t aclrtlaunch_add_rms_norm(
23-
uint32_t blockDim, void* stream, void* x1, void* x2, void* weight, void* y,
24-
void* x_out, int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign,
25-
int64_t formerNum, int64_t formerLength, int64_t tailLength, float eps,
26-
int64_t dtypeSize);
26+
uint32_t block_dim, void* stream, void* x1, void* x2, void* weight, void* y,
27+
void* x_out, int64_t total_rows, int64_t dim_length,
28+
int64_t dim_length_align, int64_t former_num, int64_t former_length,
29+
int64_t tail_length, float eps, int64_t dtype_code);
2730

2831
namespace infini::ops {
2932

@@ -50,36 +53,36 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
5053
public:
5154
Operator(const Tensor input, const Tensor other, const Tensor weight,
5255
float eps, Tensor out, Tensor residual_out)
53-
: AddRmsNorm(input, other, weight, eps, out, residual_out) {
54-
// Dtype size in bytes.
55-
dtype_size_ = (input.dtype() == DataType::kFloat16) ? 2 : 4;
56-
57-
// Alignment check (32-byte boundary).
58-
int64_t align_elems = 32 / dtype_size_;
56+
: AddRmsNorm(input, other, weight, eps, out, residual_out),
57+
dtype_{input.dtype()} {
58+
assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 ||
59+
dtype_ == DataType::kFloat32) &&
60+
"`AddRmsNorm` custom kernel: `input` must be `fp16`, `bf16`, or "
61+
"`fp32`.");
62+
63+
// 32-byte alignment on the last dimension — kernel relies on aligned
64+
// `DataCopyPad` loads/stores.
65+
int64_t align_elems = 32 / static_cast<int64_t>(kDataTypeToSize.at(dtype_));
5966
dim_length_align_ =
6067
((static_cast<int64_t>(dim_) + align_elems - 1) / align_elems) *
6168
align_elems;
6269
assert(static_cast<int64_t>(dim_) == dim_length_align_ &&
63-
"`AddRmsNorm`: custom kernel requires 32-byte aligned last "
64-
"dimension.");
70+
"`AddRmsNorm` custom kernel: last dimension must be 32-byte "
71+
"aligned.");
6572

6673
total_rows_ =
6774
static_cast<int64_t>(batch_size_) * static_cast<int64_t>(nhead_);
6875

69-
// For `float16` input, `weight` needs fp32 conversion because the custom
70-
// kernel always reads `weight` as `float32`.
71-
needs_weight_cast_ = (dtype_size_ == 2);
72-
73-
if (needs_weight_cast_) {
74-
// Allocate persistent fp32 `weight` buffer on device.
76+
// The custom kernel always reads `weight` as fp32. fp16 / bf16 inputs
77+
// trigger a lazy cast in `operator()` (guarded by `last_weight_ptr_`
78+
// so that the cast runs only when the weight pointer changes — model
79+
// weights are typically fixed after loading).
80+
if (dtype_ != DataType::kFloat32) {
7581
size_t fp32_bytes = static_cast<size_t>(dim_) * sizeof(float);
7682
aclrtMalloc(&weight_fp32_data_, fp32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
7783

78-
// `AclTensorCache` for the cast source (`float16` `weight` descriptor).
79-
weight_src_cache_ = ascend::AclTensorCache({static_cast<int64_t>(dim_)},
80-
ACL_FLOAT16, nullptr);
81-
82-
// `AclTensorCache` for the cast destination (`float32` `weight` buffer).
84+
weight_src_cache_ = ascend::AclTensorCache(
85+
{static_cast<int64_t>(dim_)}, ascend::ToAclDtype(dtype_), nullptr);
8386
weight_dst_cache_ = ascend::AclTensorCache({static_cast<int64_t>(dim_)},
8487
ACL_FLOAT, weight_fp32_data_);
8588
}
@@ -99,15 +102,13 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
99102
float eps, Tensor out, Tensor residual_out) const override {
100103
auto stream = static_cast<aclrtStream>(stream_);
101104

102-
// Determine `float32` `weight` pointer.
103105
void* weight_fp32;
104106

105-
if (needs_weight_cast_) {
106-
// Only re-cast when the `weight` data pointer changes. Model weights
107-
// are fixed after loading, so this typically runs once on the first
108-
// call and is skipped on all subsequent calls.
107+
if (dtype_ != DataType::kFloat32) {
109108
const void* cur_weight = weight.data();
110109

110+
// Model weights are fixed after loading, so the cast typically runs
111+
// once on the first call and is skipped on all subsequent calls.
111112
if (cur_weight != last_weight_ptr_) {
112113
auto t_src = weight_src_cache_.get(const_cast<void*>(cur_weight));
113114
auto t_dst = weight_dst_cache_.get(weight_fp32_data_);
@@ -129,36 +130,33 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
129130

130131
weight_fp32 = weight_fp32_data_;
131132
} else {
132-
// `input` is `float32` — `weight` is already `float32`.
133133
weight_fp32 = const_cast<void*>(weight.data());
134134
}
135135

136-
// Block-level tiling: distribute rows across cores.
136+
// Block-level tiling. Ascend 910B has 20–40 AIV cores; over-subscribing
137+
// is safe (runtime multiplexes) but wastes one weight load per block.
137138
static constexpr int64_t kMaxBlockDim = 40;
138139
int64_t used_cores = std::min(total_rows_, kMaxBlockDim);
139140
int64_t former_length = (total_rows_ + used_cores - 1) / used_cores;
140141
int64_t tail_length = former_length - 1;
141142
int64_t former_num = total_rows_ - tail_length * used_cores;
142143
uint32_t block_dim = static_cast<uint32_t>(used_cores);
143144

144-
// Launch custom AscendC kernel.
145145
aclrtlaunch_add_rms_norm(block_dim, stream, const_cast<void*>(input.data()),
146146
const_cast<void*>(other.data()), weight_fp32,
147147
out.data(), residual_out.data(), total_rows_,
148148
static_cast<int64_t>(dim_), dim_length_align_,
149149
former_num, former_length, tail_length, eps,
150-
dtype_size_);
150+
static_cast<int64_t>(dtype_));
151151
}
152152

153153
private:
154-
int64_t dtype_size_;
154+
DataType dtype_;
155155

156156
int64_t dim_length_align_;
157157

158158
int64_t total_rows_;
159159

160-
bool needs_weight_cast_;
161-
162160
void* weight_fp32_data_ = nullptr;
163161

164162
mutable ascend::AclTensorCache weight_src_cache_;

0 commit comments

Comments (0)