@@ -68,9 +68,9 @@ void EmbeddingKernel(const Context& dev_ctx,
6868 meta_info.strides = DenseTensorMeta::calc_strides (meta_info.dims );
6969 pad_tensor.set_meta (meta_info);
7070 zero_tensor.set_meta (meta_info);
71- custom_kernel::FullKernel<int32_t , phi:: CustomContext>(
71+ custom_kernel::FullKernel<int32_t , CustomContext>(
7272 dev_ctx, shape, phi::Scalar (padding_idx), x.dtype (), &pad_tensor);
73- custom_kernel::FullKernel<T, phi:: CustomContext>(
73+ custom_kernel::FullKernel<T, CustomContext>(
7474 dev_ctx, shape, phi::Scalar (0 ), x.dtype (), &zero_tensor);
7575 meta_info.dtype = DataType::BOOL;
7676 mask_tensor.set_meta (meta_info);
@@ -87,10 +87,10 @@ void EmbeddingKernel(const Context& dev_ctx,
8787 x_brd.set_meta (pad_tensor.meta ());
8888 dev_ctx.Alloc (&x_brd, x_brd.dtype ());
8989 custom_kernel::Broadcast (dev_ctx, x_expand, &x_brd);
90- custom_kernel::EqualKernel<bool , phi:: CustomContext>(
90+ custom_kernel::EqualKernel<bool , CustomContext>(
9191 dev_ctx, pad_tensor, x_brd, &mask_tensor);
9292 pad_tensor.set_meta (out->meta ());
93- custom_kernel::WhereKernel<T, phi:: CustomContext>(
93+ custom_kernel::WhereKernel<T, CustomContext>(
9494 dev_ctx, mask_tensor, zero_tensor, *out, &pad_tensor);
9595 *out = pad_tensor;
9696
0 commit comments