Skip to content

Commit 7bee958

Browse files
committed
fix
1 parent b0a412a commit 7bee958

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

backends/gcu/common/gcu_op_runner.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ using DataType = phi::DataType;
3333
using Place = phi::Place;
3434
using CPUPlace = phi::CPUPlace;
3535
using CPUContext = phi::CPUContext;
36+
using CustomContext = phi::CustomContext;
3637
using phi::DataTypeToString;
3738
using TensorNameMap = std::map<std::string, std::vector<std::string>>;
3839
using TensorValueMap = std::map<std::string, std::vector<DenseTensor*>>;

backends/gcu/kernels/embedding_kernel.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ void EmbeddingKernel(const Context& dev_ctx,
6868
meta_info.strides = DenseTensorMeta::calc_strides(meta_info.dims);
6969
pad_tensor.set_meta(meta_info);
7070
zero_tensor.set_meta(meta_info);
71-
custom_kernel::FullKernel<int32_t, phi::CustomContext>(
71+
custom_kernel::FullKernel<int32_t, CustomContext>(
7272
dev_ctx, shape, phi::Scalar(padding_idx), x.dtype(), &pad_tensor);
73-
custom_kernel::FullKernel<T, phi::CustomContext>(
73+
custom_kernel::FullKernel<T, CustomContext>(
7474
dev_ctx, shape, phi::Scalar(0), x.dtype(), &zero_tensor);
7575
meta_info.dtype = DataType::BOOL;
7676
mask_tensor.set_meta(meta_info);
@@ -87,10 +87,10 @@ void EmbeddingKernel(const Context& dev_ctx,
8787
x_brd.set_meta(pad_tensor.meta());
8888
dev_ctx.Alloc(&x_brd, x_brd.dtype());
8989
custom_kernel::Broadcast(dev_ctx, x_expand, &x_brd);
90-
custom_kernel::EqualKernel<bool, phi::CustomContext>(
90+
custom_kernel::EqualKernel<bool, CustomContext>(
9191
dev_ctx, pad_tensor, x_brd, &mask_tensor);
9292
pad_tensor.set_meta(out->meta());
93-
custom_kernel::WhereKernel<T, phi::CustomContext>(
93+
custom_kernel::WhereKernel<T, CustomContext>(
9494
dev_ctx, mask_tensor, zero_tensor, *out, &pad_tensor);
9595
*out = pad_tensor;
9696

0 commit comments

Comments
 (0)