@@ -22,9 +22,9 @@ namespace custom_kernel {
2222
2323template <typename Context>
2424void AbsGrad (const Context& dev_ctx,
25- const phi:: DenseTensor& x,
26- const phi:: DenseTensor& dout,
27- phi:: DenseTensor* dx) {
25+ const DenseTensor& x,
26+ const DenseTensor& dout,
27+ DenseTensor* dx) {
2828 tecodnnHandle_t tecodnnHandle = GetHandleFromCTX (dev_ctx);
2929 int num = static_cast <int >(x.numel ());
3030 std::vector<int > dims = {1 , 1 , 1 , num};
@@ -46,9 +46,7 @@ void AbsGrad(const Context& dev_ctx,
4646}
4747
4848template <typename T, typename Context>
49- void AbsKernel (const Context& dev_ctx,
50- const phi::DenseTensor& x,
51- phi::DenseTensor* out) {
49+ void AbsKernel (const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
5250 VLOG (4 ) << " Call SDAA AbsKernel" ;
5351 dev_ctx.template Alloc <T>(out);
5452
@@ -57,9 +55,9 @@ void AbsKernel(const Context& dev_ctx,
5755
5856template <typename T, typename Context>
5957void AbsGradKernel (const Context& dev_ctx,
60- const phi:: DenseTensor& x,
61- const phi:: DenseTensor& dout,
62- phi:: DenseTensor* dx) {
58+ const DenseTensor& x,
59+ const DenseTensor& dout,
60+ DenseTensor* dx) {
6361 VLOG (4 ) << " Call SDAA AbsGradKernel" ;
6462 dev_ctx.template Alloc <T>(dx);
6563
0 commit comments