Skip to content

Commit cede8d4

Browse files
committed
update sdaa kernels
1 parent b334e9c commit cede8d4

127 files changed

Lines changed: 1942 additions & 1977 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

backends/sdaa/kernels/abs_kernel.cc

File mode changed: 100755 → 100644
Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,9 @@ namespace custom_kernel {
2222

2323
template <typename Context>
2424
void AbsGrad(const Context& dev_ctx,
25-
const phi::DenseTensor& x,
26-
const phi::DenseTensor& dout,
27-
phi::DenseTensor* dx) {
25+
const DenseTensor& x,
26+
const DenseTensor& dout,
27+
DenseTensor* dx) {
2828
tecodnnHandle_t tecodnnHandle = GetHandleFromCTX(dev_ctx);
2929
int num = static_cast<int>(x.numel());
3030
std::vector<int> dims = {1, 1, 1, num};
@@ -46,9 +46,7 @@ void AbsGrad(const Context& dev_ctx,
4646
}
4747

4848
template <typename T, typename Context>
49-
void AbsKernel(const Context& dev_ctx,
50-
const phi::DenseTensor& x,
51-
phi::DenseTensor* out) {
49+
void AbsKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
5250
VLOG(4) << "Call SDAA AbsKernel";
5351
dev_ctx.template Alloc<T>(out);
5452

@@ -57,9 +55,9 @@ void AbsKernel(const Context& dev_ctx,
5755

5856
template <typename T, typename Context>
5957
void AbsGradKernel(const Context& dev_ctx,
60-
const phi::DenseTensor& x,
61-
const phi::DenseTensor& dout,
62-
phi::DenseTensor* dx) {
58+
const DenseTensor& x,
59+
const DenseTensor& dout,
60+
DenseTensor* dx) {
6361
VLOG(4) << "Call SDAA AbsGradKernel";
6462
dev_ctx.template Alloc<T>(dx);
6563

backends/sdaa/kernels/accuracy_kernel.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,12 +21,12 @@ namespace custom_kernel {
2121

2222
template <typename T, typename Context>
2323
void AccuracyRawKernel(const Context& dev_ctx,
24-
const phi::DenseTensor& inference,
25-
const phi::DenseTensor& indices,
26-
const phi::DenseTensor& label,
27-
phi::DenseTensor* accuracy,
28-
phi::DenseTensor* correct,
29-
phi::DenseTensor* total) {
24+
const DenseTensor& inference,
25+
const DenseTensor& indices,
26+
const DenseTensor& label,
27+
DenseTensor* accuracy,
28+
DenseTensor* correct,
29+
DenseTensor* total) {
3030
VLOG(4) << "Call sdaa Accuracy kernel";
3131
dev_ctx.template Alloc<T>(accuracy);
3232
dev_ctx.template Alloc<T>(correct);

0 commit comments

Comments
 (0)