Skip to content

Commit 55a5beb

Browse files
author
zhangyue
committed
style: apply ruff format and clang-format to all modified files
1 parent be48553 commit 55a5beb

File tree

19 files changed

+926
-887
lines changed

19 files changed

+926
-887
lines changed

src/ascend/custom_kernel/csrc/ops.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313

1414
namespace ascend_kernel {
1515

16-
at::Tensor rms_norm(const at::Tensor &input, const at::Tensor &weight,
16+
at::Tensor rms_norm(const at::Tensor& input, const at::Tensor& weight,
1717
double eps);
1818

19-
} // namespace ascend_kernel
19+
} // namespace ascend_kernel
2020

21-
#endif // OPS_H
21+
#endif // OPS_H

src/ascend/custom_kernel/csrc/ops/add_rms_norm/op_host/add_rms_norm.cpp

Lines changed: 116 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -5,140 +5,126 @@
55
* SPDX-License-Identifier: BSD-3-Clause
66
*/
77

#include "aclrtlaunch_add_rms_norm.h"
#include "tiling/platform/platform_ascendc.h"
#include "torch_kernel_helper.h"

namespace ascend_kernel {

/// Fused add + RMS-norm over the last dimension of `x1`/`x2`.
///
/// Computes (on the Ascend AI vector cores, via the `add_rms_norm` AscendC
/// kernel) two outputs with the same shape and dtype as `x1`:
///   - y    : the RMS-normalized sum, scaled by `weight`
///   - xout : presumably the raw sum x1 + x2 — TODO confirm against the
///            device-side kernel, which is not visible here.
///
/// @param x1, x2  Input tensors; must share shape and dtype (fp16 or fp32).
/// @param weight  1-D scale, length equal to the last dim of `x1`; always
///                passed to the kernel as fp32.
/// @param eps     Stability epsilon, narrowed to float for the kernel.
/// @return {y, xout}, each reshaped back to `x1.sizes()`.
/// @throws c10::Error via TORCH_CHECK on shape/dtype/capacity violations.
std::vector<at::Tensor> add_rms_norm(const at::Tensor& x1, const at::Tensor& x2,
                                     const at::Tensor& weight, double eps) {
  // Input validation.
  TORCH_CHECK(x1.dim() > 0, "add_rms_norm: x1 must have at least 1 dimension");
  TORCH_CHECK(x1.sizes() == x2.sizes(),
              "add_rms_norm: x1 and x2 must have the same shape");
  TORCH_CHECK(x1.scalar_type() == x2.scalar_type(),
              "add_rms_norm: x1 and x2 must have the same dtype");
  TORCH_CHECK(x1.scalar_type() == at::kHalf || x1.scalar_type() == at::kFloat,
              "add_rms_norm: only float16 and float32 are supported, got ",
              x1.scalar_type());
  TORCH_CHECK(weight.dim() == 1, "add_rms_norm: weight must be 1-dimensional");
  TORCH_CHECK(weight.size(0) == x1.size(-1), "add_rms_norm: weight size (",
              weight.size(0), ") must match input last dim (", x1.size(-1),
              ")");

  // Treat the input as a 2-D problem: totalRows rows of dimLength elements.
  int64_t dimLength = x1.size(-1);
  int64_t totalRows = x1.numel() / dimLength;

  // Degenerate (empty) input: nothing to launch.
  if (totalRows == 0 || dimLength == 0) {
    return {at::empty_like(x1), at::empty_like(x1)};
  }

  at::Tensor inp1 = x1.contiguous();
  at::Tensor inp2 = x2.contiguous();
  int64_t dtypeSize = inp1.element_size();

  // Hardware parameters: vector-core count and Unified Buffer (UB) size.
  auto ascendc_platform =
      platform_ascendc::PlatformAscendCManager::GetInstance();
  int64_t coreNum = static_cast<int64_t>(ascendc_platform->GetCoreNumAiv());
  uint64_t ubSize;
  ascendc_platform->GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
  int64_t ubSizeLimit = static_cast<int64_t>(ubSize);

  // Alignment (32-byte boundary): round the row length up to a whole number
  // of 32-byte units for the chosen dtype.
  int64_t alignElements = 32 / dtypeSize;
  int64_t dimLengthAlign =
      ((dimLength + alignElements - 1) / alignElements) * alignElements;

  // UB capacity check — bytes-per-element budget across all on-chip buffers:
  // fp16: inQ_x1(×2×2) + inQ_x2(×2×2) + outQ_y(×2×2) + outQ_xout(×2×2)
  //       + fp32Buf1(×4) + fp32Buf2(×4) + weight(×4) = 16 + 12 = 28
  // fp32: inQ_x1(×2×4) + inQ_x2(×2×4) + outQ_y(×2×4) + outQ_xout(×2×4)
  //       + weight(×4) = 32 + 4 = 36
  int64_t bufferCoefficient = (dtypeSize == 2) ? 28 : 36;
  // 1024 bytes reserved — presumably kernel scratch/overhead; TODO confirm.
  int64_t maxDimLength = (ubSizeLimit - 1024) / bufferCoefficient;
  int64_t fpAlignElements = 32 / 4;
  maxDimLength = (maxDimLength / fpAlignElements) * fpAlignElements;
  TORCH_CHECK(dimLengthAlign <= maxDimLength, "add_rms_norm: dimLength ",
              dimLength, " (aligned ", dimLengthAlign,
              ") exceeds UB capacity (max ", maxDimLength, ")");

  // Padding: zero-pad each row out to the aligned length when needed.
  at::Tensor kernelInput1;
  at::Tensor kernelInput2;

  if (dimLength != dimLengthAlign) {
    kernelInput1 = inp1.reshape({totalRows, dimLength});
    kernelInput1 =
        at::constant_pad_nd(kernelInput1, {0, dimLengthAlign - dimLength}, 0.0);
    kernelInput1 = kernelInput1.contiguous();

    kernelInput2 = inp2.reshape({totalRows, dimLength});
    kernelInput2 =
        at::constant_pad_nd(kernelInput2, {0, dimLengthAlign - dimLength}, 0.0);
    kernelInput2 = kernelInput2.contiguous();
  } else {
    kernelInput1 = inp1.reshape({totalRows, dimLengthAlign}).contiguous();
    kernelInput2 = inp2.reshape({totalRows, dimLengthAlign}).contiguous();
  }

  at::Tensor kernelOutputY = at::empty_like(kernelInput1);
  at::Tensor kernelOutputXOut = at::empty_like(kernelInput1);

  // Weight: always pass as fp32, padded to `dimLengthAlign`.
  at::Tensor weightFloat = weight.contiguous().to(at::kFloat);

  if (dimLength != dimLengthAlign) {
    weightFloat =
        at::constant_pad_nd(weightFloat, {0, dimLengthAlign - dimLength}, 0.0);
  }

  weightFloat = weightFloat.contiguous();

  // Block-level tiling (distribute rows across cores): `formerNum` cores
  // take `formerLength` rows each, the rest take `tailLength` rows.
  int64_t usedCoreNum = std::min(totalRows, coreNum);
  int64_t formerLength = (totalRows + usedCoreNum - 1) / usedCoreNum;
  int64_t tailLength = formerLength - 1;
  int64_t formerNum = totalRows - tailLength * usedCoreNum;
  uint32_t blockDim = static_cast<uint32_t>(usedCoreNum);

  // All EXEC_KERNEL_CMD args must be lvalues.
  float epsFloat = static_cast<float>(eps);
  int64_t dtypeSizeVal = dtypeSize;

  EXEC_KERNEL_CMD(add_rms_norm, blockDim, kernelInput1, kernelInput2,
                  weightFloat, kernelOutputY, kernelOutputXOut, totalRows,
                  dimLength, dimLengthAlign, formerNum, formerLength,
                  tailLength, epsFloat, dtypeSizeVal);

  // Remove padding and reshape back to original shape.
  at::Tensor outputY = kernelOutputY;
  at::Tensor outputXOut = kernelOutputXOut;

  if (dimLength != dimLengthAlign) {
    outputY = outputY.narrow(-1, 0, dimLength).contiguous();
    outputXOut = outputXOut.narrow(-1, 0, dimLength).contiguous();
  }

  outputY = outputY.reshape(x1.sizes());
  outputXOut = outputXOut.reshape(x1.sizes());

  return {outputY, outputXOut};
}

}  // namespace ascend_kernel

0 commit comments

Comments
 (0)