Skip to content

Commit 2db3a8b

Browse files
committed
update kernels funcs
1 parent b334e9c commit 2db3a8b

6 files changed

Lines changed: 38 additions & 36 deletions

File tree

backends/mlu/kernels/funcs/elementwise_utils.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ inline void GetReduceAxesAndDstDims(const int axis,
120120

121121
template <typename T>
122122
void MLUOpTensorKernel(const Context& dev_ctx,
123-
const phi::DenseTensor& x,
124-
const phi::DenseTensor& y,
123+
const DenseTensor& x,
124+
const DenseTensor& y,
125125
int axis,
126126
const cnnlOpTensorDesc_t op_tensor_type,
127-
phi::DenseTensor* out) {
127+
DenseTensor* out) {
128128
PADDLE_ENFORCE_EQ((op_tensor_type == CNNL_OP_TENSOR_ADD) ||
129129
(op_tensor_type == CNNL_OP_TENSOR_SUB) ||
130130
(op_tensor_type == CNNL_OP_TENSOR_MUL),
@@ -241,10 +241,10 @@ inline void MLUBinary<POW>(const Context& dev_ctx,
241241

242242
template <BINARY_FUNCTOR Functor, typename T>
243243
void MLUBinaryOp(const Context& dev_ctx,
244-
const phi::DenseTensor& x,
245-
const phi::DenseTensor& y,
244+
const DenseTensor& x,
245+
const DenseTensor& y,
246246
int axis,
247-
phi::DenseTensor* out) {
247+
DenseTensor* out) {
248248
dev_ctx.template Alloc<T>(out);
249249
Tensor x_t, y_t;
250250
x_t = x;
@@ -319,8 +319,8 @@ inline void MLUUnary<RECIPROCAL>(const Context& dev_ctx,
319319

320320
template <UNARY_FUNCTOR Functor, typename Tin, typename Tout = Tin>
321321
void MLUUnaryOp(const Context& dev_ctx,
322-
const phi::DenseTensor& x,
323-
phi::DenseTensor* out) {
322+
const DenseTensor& x,
323+
DenseTensor* out) {
324324
dev_ctx.template Alloc<Tout>(out);
325325

326326
MLUCnnlTensorDesc x_desc(x, CNNL_LAYOUT_ARRAY, ToCnnlDataType<Tin>());
@@ -342,12 +342,12 @@ enum MINMAX_GRAD_FUNCTOR {
342342
};
343343
template <MINMAX_GRAD_FUNCTOR Functor, typename Tin, typename Tout = Tin>
344344
void MLUMinMaxGradHelper(const Context& dev_ctx,
345-
const phi::DenseTensor& x,
346-
const phi::DenseTensor& y,
347-
const phi::DenseTensor& dout,
345+
const DenseTensor& x,
346+
const DenseTensor& y,
347+
const DenseTensor& dout,
348348
int axis,
349-
phi::DenseTensor* dx,
350-
phi::DenseTensor* dy) {
349+
DenseTensor* dx,
350+
DenseTensor* dy) {
351351
const auto& x_dims = x.dims();
352352
const auto& y_dims = y.dims();
353353
axis =

backends/mlu/kernels/funcs/logic_op.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ namespace custom_kernel {
2121

2222
template <typename Context>
2323
void MLULogicOp(const Context& dev_ctx,
24-
const phi::DenseTensor& x,
25-
const phi::DenseTensor& y,
24+
const DenseTensor& x,
25+
const DenseTensor& y,
2626
const std::string& logic_name,
27-
phi::DenseTensor* out) {
27+
DenseTensor* out) {
2828
dev_ctx.template Alloc<bool>(out);
2929

3030
MLUCnnlTensorDesc input_x(x, CNNL_LAYOUT_ARRAY, ToCnnlDataType(x.dtype()));

backends/mlu/kernels/funcs/mlu_baseop.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
namespace custom_kernel {
2626

2727
using Tensor = phi::DenseTensor;
28+
using DenseTensor = phi::DenseTensor;
29+
using DenseTensorMeta = phi::DenseTensorMeta;
30+
using Scalar = phi::Scalar;
31+
using DDim = phi::DDim;
2832
using Context = phi::CustomContext;
2933
using DataType = phi::DataType;
3034
using DataLayout = phi::DataLayout;

backends/mlu/kernels/funcs/mlu_funcs.h

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ namespace custom_kernel {
2525
*/
2626
template <typename Context>
2727
inline void TensorCopy(const Context& dev_ctx,
28-
const phi::DenseTensor& src,
28+
const DenseTensor& src,
2929
bool blocking,
30-
phi::DenseTensor* dst,
30+
DenseTensor* dst,
3131
const phi::Place& dst_place = phi::CustomPlace()) {
3232
dev_ctx.Wait();
3333
auto* src_ptr = src.data();
@@ -103,7 +103,7 @@ template <typename T>
103103
inline void TensorFromVector(const phi::CustomContext& ctx,
104104
const std::vector<T>& src,
105105
const phi::CustomContext& dev_ctx,
106-
phi::DenseTensor* dst) {
106+
DenseTensor* dst) {
107107
auto dst_place = dev_ctx.GetPlace();
108108
C_Device_st device{dst_place.GetDeviceId()};
109109
auto src_ptr = static_cast<const void*>(src.data());
@@ -128,7 +128,7 @@ template <>
128128
inline void TensorFromVector<bool>(const phi::CustomContext& ctx,
129129
const std::vector<bool>& src,
130130
const phi::CustomContext& dev_ctx,
131-
phi::DenseTensor* dst) {
131+
DenseTensor* dst) {
132132
// vector<bool> has no data() member, use array instead.
133133
// See details:
134134
// https://stackoverflow.com/questions/46115669/why-does-stdvectorbool-have-no-data/46115714
@@ -166,7 +166,7 @@ template <typename T>
166166
inline void TensorFromVector(const phi::CustomContext& ctx,
167167
const std::vector<T>& src,
168168
const phi::CPUContext& dev_ctx,
169-
phi::DenseTensor* dst) {
169+
DenseTensor* dst) {
170170
auto dst_place = dev_ctx.GetPlace();
171171
C_Device_st device{dst_place.GetDeviceId()};
172172
auto src_ptr = static_cast<const void*>(src.data());
@@ -191,7 +191,7 @@ template <>
191191
inline void TensorFromVector<bool>(const phi::CustomContext& ctx,
192192
const std::vector<bool>& src,
193193
const phi::CPUContext& dev_ctx,
194-
phi::DenseTensor* dst) {
194+
DenseTensor* dst) {
195195
auto dst_place = dev_ctx.GetPlace();
196196
PADDLE_THROW(phi::errors::Unimplemented(
197197
"TensorFromVector on %s is not supported.", dst_place));
@@ -202,7 +202,7 @@ void TensorFromArray(const phi::CustomContext& ctx,
202202
const T* src,
203203
const size_t& array_size,
204204
const phi::CustomContext& dev_ctx,
205-
phi::DenseTensor* dst) {
205+
DenseTensor* dst) {
206206
auto dst_place = dev_ctx.GetPlace();
207207
C_Device_st device{dst_place.GetDeviceId()};
208208
auto src_ptr = static_cast<const void*>(src);
@@ -227,7 +227,7 @@ void TensorFromArray(const phi::CustomContext& ctx,
227227
*/
228228
template <typename T>
229229
inline void TensorToVector(const phi::CustomContext& ctx,
230-
const phi::DenseTensor& src,
230+
const DenseTensor& src,
231231
const phi::CustomContext& dev_ctx,
232232
std::vector<T>* dst) {
233233
auto src_ptr = static_cast<const void*>(src.data<T>());
@@ -251,7 +251,7 @@ inline void TensorToVector(const phi::CustomContext& ctx,
251251

252252
template <>
253253
inline void TensorToVector<bool>(const phi::CustomContext& ctx,
254-
const phi::DenseTensor& src,
254+
const DenseTensor& src,
255255
const phi::CustomContext& dev_ctx,
256256
std::vector<bool>* dst) {
257257
auto src_ptr = static_cast<const void*>(src.data<bool>());
@@ -359,11 +359,10 @@ inline void ExtractNCDWH(const phi::DDim& dims,
359359

360360
template <typename T>
361361
inline std::vector<T> get_new_data_from_tensor(
362-
const phi::CustomContext& dev_ctx,
363-
const phi::DenseTensor* new_data_tensor) {
362+
const phi::CustomContext& dev_ctx, const DenseTensor* new_data_tensor) {
364363
std::vector<T> vec_new_data;
365364
auto place = new_data_tensor->place();
366-
phi::DenseTensor cpu_starts_tensor;
365+
DenseTensor cpu_starts_tensor;
367366
if (place.GetType() == phi::AllocationType::CUSTOM) {
368367
// if tensor on CUSTOM place, do memcpy to host
369368
cpu_starts_tensor.Resize(new_data_tensor->dims());
@@ -381,22 +380,21 @@ inline std::vector<T> get_new_data_from_tensor(
381380
}
382381

383382
template <typename T>
384-
inline phi::DenseTensor ReshapeToMatrix(const phi::DenseTensor& src,
385-
T num_col_dims) {
383+
inline DenseTensor ReshapeToMatrix(const DenseTensor& src, T num_col_dims) {
386384
int rank = src.dims().size();
387385
PADDLE_ENFORCE_GE(
388386
rank,
389387
2,
390388
phi::errors::InvalidArgument(
391389
"'ReshapeToMatrix()' is only used for flatten high rank "
392-
"tensors to matrixs. The dimensions of phi::DenseTensor must be "
390+
"tensors to matrixs. The dimensions of DenseTensor must be "
393391
"greater or equal than 2. "
394-
"But received dimensions of phi::DenseTensor is %d",
392+
"But received dimensions of DenseTensor is %d",
395393
rank));
396394
if (rank == 2) {
397395
return src;
398396
}
399-
phi::DenseTensor res;
397+
DenseTensor res;
400398
res = src;
401399
res.Resize(phi::flatten_to_2d(src.dims(), num_col_dims));
402400
return res;

backends/mlu/kernels/funcs/range_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ void ArangeRawKernel(const Context& dev_ctx,
2424
const T start_value,
2525
const T end_value,
2626
const T step_value,
27-
phi::DenseTensor* out) {
27+
DenseTensor* out) {
2828
int64_t size = 0;
2929
GetSize(start_value, end_value, step_value, &size);
3030

backends/mlu/kernels/funcs/reduce_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ namespace custom_kernel {
2121

2222
template <typename T, typename Context>
2323
void MLUReduceOp(const Context& dev_ctx,
24-
const phi::DenseTensor& x,
24+
const DenseTensor& x,
2525
const std::vector<int64_t>& axes,
2626
bool keep_dim,
2727
bool reduce_all,
2828
const std::string& reduce_name,
29-
phi::DenseTensor* out) {
29+
DenseTensor* out) {
3030
dev_ctx.template Alloc<T>(out);
3131
if (x.dims().size() == 0) {
3232
TensorCopy(dev_ctx, x, true, out);

0 commit comments

Comments
 (0)