feat(ascend-custom): add bf16 support + Google-style identifier renames
bf16 was silently producing garbage / NaN on impl 1 (`rms_norm`) and
impl 2 (`add_rms_norm`): the kernels only instantiated `<half>` and
`<float>`, and the launchers mapped bf16 onto the fp32 byte-size path,
so the bf16 weight tensor was read as if it were fp32, and the fp16
output cast used `CAST_ROUND` (an fp16-only alias).
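For illustration, a minimal sketch of the old byte-size mapping (the
exact ternary and the `DataType` member names are assumptions based on
the description above, not copied from the tree):

```cpp
#include <cstdint>

enum class DataType : int64_t { kFloat32, kFloat16, kBFloat16 };  // members assumed

// BEFORE (sketch): size keyed only on fp16, so bf16 fell through to
// the 4-byte branch and the kernel strode over bf16 weights as if
// they were fp32, producing the garbage / NaN.
int64_t OldDtypeSize(DataType dtype) {
  return dtype == DataType::kFloat16 ? 2 : 4;
}

// AFTER: the launcher stores the DataType itself (see the launcher
// section below), so fp16 and bf16 stay distinguishable all the way
// into the kernel.
```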
Kernel dispatch:
- `op_kernel/rms_norm.cpp` / `op_kernel/add_rms_norm.cpp`: add a
`KernelXxx<bfloat16_t>` instantiation; dispatch in the `extern "C"`
entry is now `switch (static_cast<infini::ops::DataType>(dtypeCode))`
(shared enum forwarded from the launcher via `int64_t`). The
fp16/bf16 branch uses `CAST_RINT` for the fp32 → T writeback, a
round mode defined for both `half` and `bfloat16_t` destinations,
whereas `CAST_ROUND` is a `half`-specific alias.
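The resulting entry point, sketched for `rms_norm` (the
`KernelRmsNorm` class name, the parameter list, and the enum member
names are illustrative; the `switch` on the forwarded code is the
actual pattern):

```cpp
#include "kernel_operator.h"  // Ascend C device-side API
#include "ops/data_type.h"    // shared infini::ops::DataType (path assumed)

extern "C" __global__ __aicore__ void rms_norm_custom(
    GM_ADDR x, GM_ADDR gamma, GM_ADDR y,
    int64_t dtype_code, int64_t dim_length) {
  switch (static_cast<infini::ops::DataType>(dtype_code)) {
    case infini::ops::DataType::kFloat32: {
      KernelRmsNorm<float> op;
      op.Init(x, gamma, y, dim_length);
      op.Process();
      break;
    }
    case infini::ops::DataType::kFloat16: {
      KernelRmsNorm<half> op;
      op.Init(x, gamma, y, dim_length);
      op.Process();
      break;
    }
    case infini::ops::DataType::kBFloat16: {  // new instantiation
      KernelRmsNorm<bfloat16_t> op;
      op.Init(x, gamma, y, dim_length);
      op.Process();
      break;
    }
    default:
      break;  // launcher only forwards the three codes above
  }
}
```

Inside `KernelRmsNorm<T>`, the writeback is then roughly
`AscendC::Cast(y_local, acc_fp32, AscendC::RoundMode::CAST_RINT,
count)`, which compiles for both `half` and `bfloat16_t` destinations.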
Launchers (`kernel_custom.h`):
- Store `DataType dtype_` (replaces the old `int64_t dtype_size_` which
collapsed fp16 and bf16 onto the same code).
- Use `ascend::ToAclDtype(dtype_)` and `kDataTypeToSize.at(dtype_)`
instead of hand-rolled ternaries (consistent with the rest of the
Ascend backend).
- Forward `static_cast<int64_t>(dtype_)` as the kernel's `dtypeCode`.
- `extern "C" aclrtlaunch_*` forward-decl parameters renamed to
`snake_case`; the function name itself is generated by
`ascendc_add_operator(OP_NAME …)` and carries
`// NOLINTNEXTLINE(readability-identifier-naming)` so `clang-tidy`
accepts it.
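A condensed host-side sketch (member names such as `block_dim_`, the
`Launch` signature, and the generated-launcher parameter list are
assumptions; `ToAclDtype` / `kDataTypeToSize` are the existing
backend helpers named above):

```cpp
// kernel_custom.h (sketch); stream typed void* here to stand in for
// aclrtStream without pulling in the acl headers. Includes for
// ascend::ToAclDtype / kDataTypeToSize elided (backend headers).

// NOLINTNEXTLINE(readability-identifier-naming)  name generated by
// ascendc_add_operator(OP_NAME ...), not ours to rename
extern "C" uint32_t aclrtlaunch_rms_norm_custom(
    uint32_t block_dim, void* stream, void* x, void* gamma, void* y,
    int64_t dtype_code, int64_t dim_length);

class RmsNormLauncher {
 public:
  uint32_t Launch(void* stream, void* x, void* gamma, void* y,
                  int64_t dim_length) const {
    // shared lookup tables replace the hand-rolled ternaries
    const auto acl_dtype = ascend::ToAclDtype(dtype_);
    const auto elem_size = kDataTypeToSize.at(dtype_);
    (void)acl_dtype;  // used for view / workspace setup, elided here
    (void)elem_size;
    // forward the enum value itself, not a byte size
    return aclrtlaunch_rms_norm_custom(
        block_dim_, stream, x, gamma, y,
        static_cast<int64_t>(dtype_), dim_length);
  }

 private:
  infini::ops::DataType dtype_;  // was: int64_t dtype_size_
  uint32_t block_dim_ = 1;
};
```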
Identifier naming (Google C++ Style):
- `op_kernel/*.cpp` members `snake_case_`, params / locals `snake_case`,
constants `kPascalCase` (was `BUFFER_NUM` / `dimLength` / `inQueueX1`
/ `blockRows`, etc., inherited from the `vllm-ascend` sample style).
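Representative rename pairs (which bucket each old name fell into is
inferred from its role in the kernels):

```cpp
#include <cstdint>

// BUFFER_NUM -> kBufferNum    (constexpr constant: kPascalCase)
// dimLength  -> dim_length    (param / local: snake_case)
// blockRows  -> block_rows    (param / local: snake_case)
// inQueueX1  -> in_queue_x1_  (class member: snake_case_)
constexpr int32_t kBufferNum = 2;  // e.g. the renamed queue depth
```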
Verified: `pytest tests/test_rms_norm.py tests/test_add_rms_norm.py
--devices ascend` → 144 passed / 0 failed (fp32 / fp16 / bf16 ×
both ops × the full shape-and-stride matrix).