
Commit 33308a4

Authored and committed by zhangyue
feat(ascend): add multi-implementation variants and ATB operators
Add alternative implementations with registries:

- AddRmsNorm: decomposed (0), fused aclnnAddRmsNorm (1), custom AscendC (2)
- RmsNorm: ACLNN (0), custom AscendC (1)
- RotaryEmbedding: ACLNN (0), ATB Rope (1)
- ReshapeAndCache: ACLNN (0), ScatterPaKvCache (1), ATB (2)
- Swiglu: decomposed (0), fused aclnnSwiGlu (1)
- SiluAndMul: fused aclnnSwiGlu (0), registry (1)
- PagedAttention: ATB (0)
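Each operator gains a registry keyed by an integer implementation index, so a variant can be selected per call without recompiling. A minimal sketch of what such index-keyed dispatch can look like, assuming a hypothetical `AddRmsNormBase` interface and factory map (the names here are illustrative, not the project's actual registry API):

#include <functional>
#include <map>
#include <memory>
#include <stdexcept>

// Hypothetical base class standing in for the operator interface.
struct AddRmsNormBase {
  virtual ~AddRmsNormBase() = default;
};

// Registry: implementation index -> factory. Each
// Operator<AddRmsNorm, Device, Index> specialization would register
// itself here once at startup.
using Factory = std::function<std::unique_ptr<AddRmsNormBase>()>;

std::map<int, Factory>& addRmsNormRegistry() {
  static std::map<int, Factory> registry;
  return registry;
}

// Resolve an implementation index (e.g. 0 = decomposed, 1 = fused,
// 2 = custom AscendC) to a concrete operator instance.
std::unique_ptr<AddRmsNormBase> makeAddRmsNorm(int implementation_index) {
  auto it = addRmsNormRegistry().find(implementation_index);
  if (it == addRmsNormRegistry().end()) {
    throw std::out_of_range("unknown AddRmsNorm implementation index");
  }
  return it->second();
}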
1 parent 0c7a31a commit 33308a4

15 files changed

Lines changed: 1714 additions & 0 deletions

File tree

src/ascend/add_rms_norm/kernel.h

Lines changed: 137 additions & 0 deletions
#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_
#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_

#include <vector>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_add.h"
#include "aclnn_rms_norm.h"
#include "ascend/common.h"
#include "ascend/add_rms_norm/registry.h"
#include "ascend/workspace_pool_.h"
#include "operator.h"

namespace infini::ops {

// Decomposed implementation: aclnnAdd + aclnnRmsNorm.
//
// The fused aclnnAddRmsNorm API has ~200 us host-side launch overhead that
// dominates small-tensor dispatch. Decomposing into two fast ACLNN calls
// reduces host dispatch from ~224 us to ~56 us (4x faster) with negligible
// NPU-side impact for inference tensor sizes.
template <>
class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
 public:
  Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps,
           Tensor y_out, Tensor x_out)
      : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out),
        x1_cache_(x1),
        x2_cache_(x2),
        gamma_cache_(gamma),
        y_out_cache_(y_out),
        x_out_cache_(x_out) {
    // Alpha scalar for aclnnAdd (x_out = x1 + 1.0 * x2).
    alpha_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT);

    // aclnnRmsNorm writes rstd as a required side output.
    // Size computed here; buffer obtained from pool in `operator()`.
    rstd_shape_ = {static_cast<int64_t>(batch_size_),
                   static_cast<int64_t>(nhead_)};
    rstd_size_ = batch_size_ * nhead_ * sizeof(float);
  }

  ~Operator() {
    if (add_exec_) aclDestroyAclOpExecutor(add_exec_);
    if (norm_exec_) aclDestroyAclOpExecutor(norm_exec_);
    aclDestroyScalar(alpha_);
    if (rstd_tensor_) aclDestroyTensor(rstd_tensor_);
  }

  void operator()(const Tensor x1, const Tensor x2, const Tensor gamma,
                  float eps, Tensor y_out, Tensor x_out) const override {
    auto t_x1 = x1_cache_.get(const_cast<void*>(x1.data()));
    auto t_x2 = x2_cache_.get(const_cast<void*>(x2.data()));
    auto t_gamma = gamma_cache_.get(const_cast<void*>(gamma.data()));
    auto t_y_out = y_out_cache_.get(y_out.data());
    auto t_x_out = x_out_cache_.get(x_out.data());
    auto stream = static_cast<aclrtStream>(stream_);

    // Step 1: x_out = x1 + x2.
    if (!add_exec_) {
      aclnnAddGetWorkspaceSize(t_x1, t_x2, alpha_, t_x_out, &add_ws_,
                               &add_exec_);
      aclSetAclOpExecutorRepeatable(add_exec_);
    } else {
      aclSetInputTensorAddr(add_exec_, 0, t_x1,
                            const_cast<void*>(x1.data()));
      aclSetInputTensorAddr(add_exec_, 1, t_x2,
                            const_cast<void*>(x2.data()));
      aclSetOutputTensorAddr(add_exec_, 0, t_x_out, x_out.data());
    }
    auto& add_arena = ascend::workspacePool().ensure(stream, add_ws_);
    aclnnAdd(add_arena.buf, add_ws_, add_exec_, stream);

    // Obtain shared rstd buffer from pool.
    auto& rstd_arena =
        ascend::workspacePool().ensure(stream, rstd_size_, "temp");

    // Lazily create rstd tensor descriptor on first call.
    if (!rstd_tensor_) {
      rstd_tensor_ = aclCreateTensor(
          rstd_shape_.data(), 2, ACL_FLOAT,
          /*strides=*/nullptr, 0, ACL_FORMAT_ND, rstd_shape_.data(), 2,
          rstd_arena.buf);
    } else {
      aclSetRawTensorAddr(rstd_tensor_, rstd_arena.buf);
    }

    // Step 2: y_out = rms_norm(x_out, gamma, eps).
    if (!norm_exec_) {
      aclnnRmsNormGetWorkspaceSize(t_x_out, t_gamma, eps, t_y_out,
                                   rstd_tensor_, &norm_ws_, &norm_exec_);
      aclSetAclOpExecutorRepeatable(norm_exec_);
    } else {
      aclSetInputTensorAddr(norm_exec_, 0, t_x_out, x_out.data());
      aclSetInputTensorAddr(norm_exec_, 1, t_gamma,
                            const_cast<void*>(gamma.data()));
      aclSetOutputTensorAddr(norm_exec_, 0, t_y_out, y_out.data());
      aclSetOutputTensorAddr(norm_exec_, 1, rstd_tensor_, rstd_arena.buf);
    }
    auto& norm_arena = ascend::workspacePool().ensure(stream, norm_ws_);
    aclnnRmsNorm(norm_arena.buf, norm_ws_, norm_exec_, stream);
  }

 private:
  mutable ascend::AclTensorCache x1_cache_;
  mutable ascend::AclTensorCache x2_cache_;
  mutable ascend::AclTensorCache gamma_cache_;
  mutable ascend::AclTensorCache y_out_cache_;
  mutable ascend::AclTensorCache x_out_cache_;

  float alpha_storage_ = 1.0f;
  aclScalar* alpha_ = nullptr;

  std::vector<int64_t> rstd_shape_;
  uint64_t rstd_size_ = 0;
  mutable aclTensor* rstd_tensor_ = nullptr;

  mutable aclOpExecutor* add_exec_ = nullptr;
  mutable uint64_t add_ws_ = 0;
  mutable aclOpExecutor* norm_exec_ = nullptr;
  mutable uint64_t norm_ws_ = 0;
};

}  // namespace infini::ops

#endif  // INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_
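Both ACLNN steps above follow the same caching pattern: the first call plans the op via `*GetWorkspaceSize` and marks the executor repeatable; every later call only rebinds device addresses and relaunches. A minimal sketch of that pattern for the add step, assuming descriptors and the alpha scalar were created elsewhere and shapes never change across calls (`launchCachedAdd` and the workspace callback are illustrative helpers, not part of this codebase):

#include <cstdint>
#include <functional>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_add.h"

// Issue aclnnAdd through a cached repeatable executor. Shapes and dtypes
// must match the first call; only device addresses may change afterwards.
void launchCachedAdd(aclTensor* t_a, void* a_ptr,
                     aclTensor* t_b, void* b_ptr,
                     aclTensor* t_out, void* out_ptr,
                     aclScalar* alpha,
                     aclOpExecutor*& exec, uint64_t& ws_bytes,
                     const std::function<void*(uint64_t)>& ensure_workspace,
                     aclrtStream stream) {
  if (!exec) {
    // First call: plan once (host-side shape inference and tiling), then
    // mark the executor reusable so the planning cost is paid exactly once.
    aclnnAddGetWorkspaceSize(t_a, t_b, alpha, t_out, &ws_bytes, &exec);
    aclSetAclOpExecutorRepeatable(exec);
  } else {
    // Later calls: rebind device addresses without re-planning.
    aclSetInputTensorAddr(exec, 0, t_a, a_ptr);
    aclSetInputTensorAddr(exec, 1, t_b, b_ptr);
    aclSetOutputTensorAddr(exec, 0, t_out, out_ptr);
  }
  // ensure_workspace stands in for ascend::workspacePool().ensure(...).
  aclnnAdd(ensure_workspace(ws_bytes), ws_bytes, exec, stream);
}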
src/ascend/add_rms_norm/kernel_custom.h

Lines changed: 182 additions & 0 deletions

#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_
#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_

#ifdef INFINI_HAS_CUSTOM_ADD_RMS_NORM

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_cast.h"
#include "ascend/common.h"
#include "ascend/add_rms_norm/registry.h"
#include "ascend/workspace_pool_.h"
#include "base/add_rms_norm.h"
#include "operator.h"

// Forward-declare the generated AscendC kernel launch function.
// This symbol is provided by the `no_workspace_kernel` static library
// built from
// `ascend/custom_kernel/csrc/ops/add_rms_norm/op_kernel/add_rms_norm.cpp`
// via `ascendc_library()`.
extern "C" uint32_t aclrtlaunch_add_rms_norm(
    uint32_t blockDim, void* stream,
    void* x1, void* x2, void* weight, void* y, void* x_out,
    int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign,
    int64_t formerNum, int64_t formerLength, int64_t tailLength,
    float eps, int64_t dtypeSize);

namespace infini::ops {

// Custom AscendC fused AddRmsNorm kernel (implementation index 2).
//
// A single-kernel implementation that computes x_out = x1 + x2 followed by
// y = rms_norm(x_out, gamma, eps) in one launch, avoiding both the decomposed
// aclnnAdd + aclnnRmsNorm calls (index 0) and the fused aclnnAddRmsNorm call
// (index 1). Adapted from the custom RmsNorm kernel (RmsNorm implementation
// index 1).
//
// Select via `implementation_index=2` in Python:
//   infini.ops.add_rms_norm(x1, x2, gamma, eps, y_out, x_out,
//                           implementation_index=2, stream=s)
//
// Requirements:
// - The input's last dimension must be 32-byte aligned (divisible by 16 for
//   fp16 or by 8 for fp32). All standard LLM hidden dimensions satisfy this.
// - The weight must have the same dtype as the input.
// - The custom kernel binary must be linked (`BUILD_CUSTOM_KERNEL=ON`).
template <>
class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
 public:
  Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps,
           Tensor y_out, Tensor x_out)
      : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out) {
    // Dtype size in bytes.
    dtype_size_ = (x1.dtype() == DataType::kFloat16) ? 2 : 4;

    // Alignment check (32-byte boundary).
    int64_t align_elems = 32 / dtype_size_;
    dim_length_align_ =
        ((static_cast<int64_t>(dim_) + align_elems - 1) / align_elems) *
        align_elems;
    assert(static_cast<int64_t>(dim_) == dim_length_align_ &&
           "Custom AddRmsNorm kernel requires 32-byte aligned last dimension");

    total_rows_ = static_cast<int64_t>(batch_size_) *
                  static_cast<int64_t>(nhead_);

    // For fp16 input, the weight needs an fp32 conversion because the
    // custom kernel always reads the weight as fp32.
    needs_weight_cast_ = (dtype_size_ == 2);

    if (needs_weight_cast_) {
      // Allocate a persistent fp32 weight buffer on device.
      size_t fp32_bytes = static_cast<size_t>(dim_) * sizeof(float);
      aclrtMalloc(&weight_fp32_data_, fp32_bytes,
                  ACL_MEM_MALLOC_NORMAL_ONLY);

      // AclTensorCache for the cast source (fp16 weight descriptor).
      weight_src_cache_ = ascend::AclTensorCache(
          {static_cast<int64_t>(dim_)}, ACL_FLOAT16, nullptr);

      // AclTensorCache for the cast destination (fp32 weight buffer).
      weight_dst_cache_ = ascend::AclTensorCache(
          {static_cast<int64_t>(dim_)}, ACL_FLOAT, weight_fp32_data_);
    }
  }

  ~Operator() {
    if (!ascend::isAclRuntimeAlive()) return;
    if (cast_exec_) aclDestroyAclOpExecutor(cast_exec_);
    if (weight_fp32_data_) aclrtFree(weight_fp32_data_);
  }

  void operator()(const Tensor x1, const Tensor x2, const Tensor gamma,
                  float eps, Tensor y_out, Tensor x_out) const override {
    auto stream = static_cast<aclrtStream>(stream_);

    // Determine the fp32 weight pointer.
    void* weight_fp32;

    if (needs_weight_cast_) {
      // Only re-cast when the weight data pointer changes. Model weights
      // are fixed after loading, so this typically runs once on the first
      // call and is skipped on all subsequent calls.
      const void* cur_weight = gamma.data();

      if (cur_weight != last_weight_ptr_) {
        auto t_src =
            weight_src_cache_.get(const_cast<void*>(cur_weight));
        auto t_dst = weight_dst_cache_.get(weight_fp32_data_);

        if (!cast_exec_) {
          aclnnCastGetWorkspaceSize(t_src, ACL_FLOAT, t_dst, &cast_ws_,
                                    &cast_exec_);
          aclSetAclOpExecutorRepeatable(cast_exec_);
        } else {
          aclSetInputTensorAddr(cast_exec_, 0, t_src,
                                const_cast<void*>(cur_weight));
          aclSetOutputTensorAddr(cast_exec_, 0, t_dst, weight_fp32_data_);
        }

        auto& arena = ascend::workspacePool().ensure(stream, cast_ws_);
        aclnnCast(arena.buf, cast_ws_, cast_exec_, stream);
        last_weight_ptr_ = cur_weight;
      }

      weight_fp32 = weight_fp32_data_;
    } else {
      // Input is fp32; the weight is already fp32.
      weight_fp32 = const_cast<void*>(gamma.data());
    }

    // Block-level tiling: distribute rows across cores.
    static constexpr int64_t kMaxBlockDim = 40;
    int64_t used_cores = std::min(total_rows_, kMaxBlockDim);
    int64_t former_length =
        (total_rows_ + used_cores - 1) / used_cores;
    int64_t tail_length = former_length - 1;
    int64_t former_num = total_rows_ - tail_length * used_cores;
    uint32_t block_dim = static_cast<uint32_t>(used_cores);

    // Launch the custom AscendC kernel.
    aclrtlaunch_add_rms_norm(
        block_dim, stream,
        const_cast<void*>(x1.data()),
        const_cast<void*>(x2.data()),
        weight_fp32,
        y_out.data(),
        x_out.data(),
        total_rows_,
        static_cast<int64_t>(dim_),
        dim_length_align_,
        former_num, former_length, tail_length,
        eps, dtype_size_);
  }

 private:
  int64_t dtype_size_;
  int64_t dim_length_align_;
  int64_t total_rows_;

  bool needs_weight_cast_;
  void* weight_fp32_data_ = nullptr;
  mutable ascend::AclTensorCache weight_src_cache_;
  mutable ascend::AclTensorCache weight_dst_cache_;
  mutable const void* last_weight_ptr_ = nullptr;
  mutable aclOpExecutor* cast_exec_ = nullptr;
  mutable uint64_t cast_ws_ = 0;
};

}  // namespace infini::ops

#endif  // INFINI_HAS_CUSTOM_ADD_RMS_NORM
#endif  // INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_
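The former/tail tiling above splits `total_rows_` rows over at most 40 cores so that per-core row counts differ by at most one: `former_num` cores take `former_length` rows, the rest take one fewer. A standalone sketch of the same arithmetic with a worked example (`splitRows` is an illustrative name, not part of the codebase):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Same former/tail row split as the kernel launch above: former_num cores
// process former_length rows each, the remaining cores one row fewer.
void splitRows(int64_t total_rows, int64_t max_cores,
               int64_t& used_cores, int64_t& former_num,
               int64_t& former_length, int64_t& tail_length) {
  used_cores = std::min(total_rows, max_cores);
  former_length = (total_rows + used_cores - 1) / used_cores;  // ceil
  tail_length = former_length - 1;
  former_num = total_rows - tail_length * used_cores;
}

int main() {
  int64_t used, former_num, former_len, tail_len;
  // Example: 100 rows on up to 40 cores.
  splitRows(100, 40, used, former_num, former_len, tail_len);
  // 20 cores take 3 rows, 20 cores take 2 rows: 20*3 + 20*2 = 100.
  assert(former_num * former_len + (used - former_num) * tail_len == 100);
  std::printf("cores=%lld: %lld x %lld rows + %lld x %lld rows\n",
              (long long)used, (long long)former_num, (long long)former_len,
              (long long)(used - former_num), (long long)tail_len);
  return 0;
}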
