Skip to content

Commit 22dcb2f

Browse files
committed
fix
1 parent f02a082 commit 22dcb2f

7 files changed

Lines changed: 45 additions & 54 deletions

File tree

paddle/phi/core/tensor_utils.cc

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ template phi::dtype::complex<double> GetValue(const DenseTensor* x);
994994
template <typename T>
995995
std::vector<T> GetVectorFromTensor(const DenseTensor* x) {
996996
std::vector<T> vec_new_data;
997-
if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT32) {
997+
if (x->dtype() == DataType::INT32) {
998998
auto* data = x->data<int>();
999999
DenseTensor cpu_attr_tensor;
10001000
if (x->place().GetType() != phi::AllocationType::CPU) {
@@ -1004,7 +1004,7 @@ std::vector<T> GetVectorFromTensor(const DenseTensor* x) {
10041004
data = cpu_attr_tensor.data<int>();
10051005
}
10061006
vec_new_data = std::vector<T>(data, data + x->numel());
1007-
} else if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT64) {
1007+
} else if (x->dtype() == DataType::INT64) {
10081008
auto* data = x->data<int64_t>();
10091009
DenseTensor cpu_attr_tensor;
10101010
if (x->place().GetType() != phi::AllocationType::CPU) {
@@ -1018,7 +1018,7 @@ std::vector<T> GetVectorFromTensor(const DenseTensor* x) {
10181018
} else {
10191019
PADDLE_THROW(common::errors::InvalidArgument(
10201020
"The dtype of Tensor must be int32 or int64, but received: %s",
1021-
phi::TransToProtoVarType(x->dtype())));
1021+
x->dtype()));
10221022
}
10231023
return vec_new_data;
10241024
}
@@ -1046,20 +1046,18 @@ std::vector<T> _GetVectorFromTensor(const DenseTensor* x) {
10461046

10471047
template <>
10481048
std::vector<float> GetVectorFromTensor<float>(const DenseTensor* x) {
1049-
if (phi::TransToProtoVarType(x->dtype()) != ProtoDataType::FP32) {
1049+
if (x->dtype() != DataType::FLOAT32) {
10501050
PADDLE_THROW(common::errors::InvalidArgument(
1051-
"The dtype of Tensor must be float32, but received: %s",
1052-
phi::TransToProtoVarType(x->dtype())));
1051+
"The dtype of Tensor must be float32, but received: %s", x->dtype()));
10531052
}
10541053
return _GetVectorFromTensor<float>(x);
10551054
}
10561055

10571056
template <>
10581057
std::vector<double> GetVectorFromTensor<double>(const DenseTensor* x) {
1059-
if (phi::TransToProtoVarType(x->dtype()) != ProtoDataType::FP64) {
1058+
if (x->dtype() != DataType::FLOAT64) {
10601059
PADDLE_THROW(common::errors::InvalidArgument(
1061-
"The dtype of Tensor must be float64, but received: %s",
1062-
phi::TransToProtoVarType(x->dtype())));
1060+
"The dtype of Tensor must be float64, but received: %s", x->dtype()));
10631061
}
10641062
return _GetVectorFromTensor<double>(x);
10651063
}

paddle/phi/infermeta/multiary.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2849,11 +2849,12 @@ void FusionGroupInferMeta(const std::vector<const MetaTensor*>& ins,
28492849
}
28502850

28512851
for (size_t j = 0; j < num_outs; ++j) {
2852-
if (outs_dtype[j] == phi::TransToProtoVarType(DataType::FLOAT16)) {
2852+
DataType out_dtype = TransToPhiDataType(outs_dtype[j]);
2853+
if (out_dtype == DataType::FLOAT16) {
28532854
outs[j]->set_dtype(DataType::FLOAT16);
2854-
} else if (outs_dtype[j] == phi::TransToProtoVarType(DataType::FLOAT32)) {
2855+
} else if (out_dtype == DataType::FLOAT32) {
28552856
outs[j]->set_dtype(DataType::FLOAT32);
2856-
} else if (outs_dtype[j] == phi::TransToProtoVarType(DataType::FLOAT64)) {
2857+
} else if (out_dtype == DataType::FLOAT64) {
28572858
outs[j]->set_dtype(DataType::FLOAT64);
28582859
}
28592860
}

paddle/phi/kernels/cpu/pyramid_hash_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,8 @@ void CPUPyramidHashOPKernel(const Context& dev_ctx,
228228
if (iter != iter_end) {
229229
exit(1);
230230
}
231-
auto weight_type = TransToProtoVarType(_blobs_0->dtype());
232-
if (_is_training == 0 && weight_type != ProtoDataType::INT8) {
231+
auto weight_type = _blobs_0->dtype();
232+
if (_is_training == 0 && weight_type != DataType::INT8) {
233233
funcs::axpy_noadd(
234234
top_data, top_data, top->dims()[0] * top->dims()[1], _drop_out_percent);
235235
}

paddle/phi/kernels/cpu/tdm_sampler_kernel.cc

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
262262
DenseTensor *out,
263263
DenseTensor *labels,
264264
DenseTensor *mask) {
265-
const auto &input_type = TransToProtoVarType(x.dtype());
265+
const auto &input_type = x.dtype();
266266
bool input_type_match =
267-
input_type == ProtoDataType::INT32 || input_type == ProtoDataType::INT64;
267+
input_type == DataType::INT32 || input_type == DataType::INT64;
268268
PADDLE_ENFORCE_EQ(input_type_match,
269269
true,
270270
common::errors::InvalidArgument(
@@ -274,9 +274,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
274274
DataTypeToString(DataType::INT32),
275275
DataTypeToString(DataType::INT64)));
276276

277-
const auto &travel_type = TransToProtoVarType(travel.dtype());
278-
bool travel_type_match = travel_type == ProtoDataType::INT32 ||
279-
travel_type == ProtoDataType::INT64;
277+
const auto &travel_type = travel.dtype();
278+
bool travel_type_match =
279+
travel_type == DataType::INT32 || travel_type == DataType::INT64;
280280
PADDLE_ENFORCE_EQ(travel_type_match,
281281
true,
282282
common::errors::InvalidArgument(
@@ -286,9 +286,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
286286
DataTypeToString(DataType::INT32),
287287
DataTypeToString(DataType::INT64)));
288288

289-
const auto &layer_type = TransToProtoVarType(layer.dtype());
289+
const auto &layer_type = layer.dtype();
290290
bool layer_type_match =
291-
layer_type == ProtoDataType::INT32 || layer_type == ProtoDataType::INT64;
291+
layer_type == DataType::INT32 || layer_type == DataType::INT64;
292292
PADDLE_ENFORCE_EQ(layer_type_match,
293293
true,
294294
common::errors::InvalidArgument(
@@ -305,10 +305,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
305305
DataTypeToString(travel.dtype()),
306306
DataTypeToString(layer.dtype())));
307307

308-
auto output_type = static_cast<ProtoDataType>(dtype);
308+
auto output_type = TransToPhiDataType(dtype);
309309

310-
if (travel_type == ProtoDataType::INT32 &&
311-
output_type == ProtoDataType::INT32) {
310+
if (travel_type == DataType::INT32 && output_type == DataType::INT32) {
312311
TDMSamplerInner<T, Context, int, int>(dev_ctx,
313312
x,
314313
travel,
@@ -320,8 +319,7 @@ void TDMSamplerKernel(const Context &dev_ctx,
320319
out,
321320
labels,
322321
mask);
323-
} else if (travel_type == ProtoDataType::INT64 &&
324-
output_type == ProtoDataType::INT32) {
322+
} else if (travel_type == DataType::INT64 && output_type == DataType::INT32) {
325323
TDMSamplerInner<T, Context, int64_t, int>(dev_ctx,
326324
x,
327325
travel,
@@ -333,8 +331,7 @@ void TDMSamplerKernel(const Context &dev_ctx,
333331
out,
334332
labels,
335333
mask);
336-
} else if (travel_type == ProtoDataType::INT32 &&
337-
output_type == ProtoDataType::INT64) {
334+
} else if (travel_type == DataType::INT32 && output_type == DataType::INT64) {
338335
TDMSamplerInner<T, Context, int, int64_t>(dev_ctx,
339336
x,
340337
travel,
@@ -346,8 +343,7 @@ void TDMSamplerKernel(const Context &dev_ctx,
346343
out,
347344
labels,
348345
mask);
349-
} else if (travel_type == ProtoDataType::INT64 &&
350-
output_type == ProtoDataType::INT64) {
346+
} else if (travel_type == DataType::INT64 && output_type == DataType::INT64) {
351347
TDMSamplerInner<T, Context, int64_t, int64_t>(dev_ctx,
352348
x,
353349
travel,

paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,14 @@ static void MutableMultiTypeData(std::vector<DenseTensor*>* var,
2828
const std::vector<int>& data_type,
2929
const Context& dev_ctx) {
3030
for (size_t i = 0; i < var->size(); i++) {
31-
if (data_type[i] == phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
31+
DataType dtype = TransToPhiDataType(data_type[i]);
32+
if (dtype == DataType::FLOAT32) {
3233
dev_ctx.template Alloc<float>((*var)[i],
3334
(*var)[i]->numel() * sizeof(float));
34-
} else if (data_type[i] ==
35-
phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
36-
dev_ctx.template Alloc<phi::float16>(
37-
(*var)[i], (*var)[i]->numel() * sizeof(phi::float16));
38-
} else if (data_type[i] ==
39-
phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
35+
} else if (dtype == DataType::FLOAT16) {
36+
dev_ctx.template Alloc<float16>((*var)[i],
37+
(*var)[i]->numel() * sizeof(float16));
38+
} else if (dtype == DataType::FLOAT64) {
4039
dev_ctx.template Alloc<double>((*var)[i],
4140
(*var)[i]->numel() * sizeof(double));
4241
}
@@ -66,25 +65,23 @@ void FusionGroupKernel(const Context& dev_ctx,
6665
args.push_back(&n);
6766
std::vector<const void*> ptrs(num_ins + num_outs);
6867
for (size_t i = 0; i < num_ins; ++i) {
69-
if (inputs_dtype[i] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
70-
ptrs[i] = ins[i]->data<phi::float16>();
71-
} else if (inputs_dtype[i] ==
72-
phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
68+
DataType input_dtype = TransToPhiDataType(inputs_dtype[i]);
69+
if (input_dtype == DataType::FLOAT16) {
70+
ptrs[i] = ins[i]->data<float16>();
71+
} else if (input_dtype == DataType::FLOAT32) {
7372
ptrs[i] = ins[i]->data<float>();
74-
} else if (inputs_dtype[i] ==
75-
phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
73+
} else if (input_dtype == DataType::FLOAT64) {
7674
ptrs[i] = ins[i]->data<double>();
7775
}
7876
args.push_back(&ptrs[i]);
7977
}
8078
for (size_t j = 0; j < num_outs; ++j) {
81-
if (outs_dtype[j] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
82-
ptrs[num_ins + j] = outs[j]->data<phi::float16>();
83-
} else if (outs_dtype[j] ==
84-
phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
79+
DataType out_dtype = TransToPhiDataType(outs_dtype[j]);
80+
if (out_dtype == DataType::FLOAT16) {
81+
ptrs[num_ins + j] = outs[j]->data<float16>();
82+
} else if (out_dtype == DataType::FLOAT32) {
8583
ptrs[num_ins + j] = outs[j]->data<float>();
86-
} else if (outs_dtype[j] ==
87-
phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
84+
} else if (out_dtype == DataType::FLOAT64) {
8885
ptrs[num_ins + j] = outs[j]->data<double>();
8986
}
9087
args.push_back(&ptrs[num_ins + j]);

paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -489,14 +489,13 @@ void RunKernel(const phi::OneDNNContext& dev_ctx,
489489
std::shared_ptr<dnnl::memory> h0_memory_p, weight_h_memory_p,
490490
weight_x_memory_p;
491491

492-
if (phi::TransToProtoVarType(weight_h.dtype()) == phi::ProtoDataType::FP32) {
492+
if (weight_h.dtype() == DataType::FLOAT32) {
493493
h0_memory_p = handler.template AcquireH0Memory<float>(h0.get_ptr());
494494
weight_x_memory_p =
495495
handler.template AcquireWeightXMemory<float>(&weight_x, origin_mode);
496496
weight_h_memory_p =
497497
handler.template AcquireWeightHMemory<float>(&weight_h, origin_mode);
498-
} else if (phi::TransToProtoVarType(weight_h.dtype()) ==
499-
phi::ProtoDataType::BF16) {
498+
} else if (weight_h.dtype() == DataType::BFLOAT16) {
500499
h0_memory_p = handler.template AcquireH0Memory<phi::bfloat16>(h0.get_ptr());
501500
weight_x_memory_p = handler.template AcquireWeightXMemory<phi::bfloat16>(
502501
&weight_x, origin_mode);

paddle/phi/kernels/onednn/matmul_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,9 +449,9 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
449449
const DenseTensor *input_y,
450450
const engine &onednn_engine) {
451451
std::string key = funcs::CreateKey(dev_ctx,
452-
phi::TransToProtoVarType(input_x->dtype()),
452+
TransToProtoVarType(input_x->dtype()),
453453
vectorize(input_x->dims()),
454-
phi::TransToProtoVarType(input_y->dtype()),
454+
TransToProtoVarType(input_y->dtype()),
455455
vectorize(input_y->dims()),
456456
dev_ctx.GetOutputsName("Out")[0]);
457457
key = funcs::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);

0 commit comments

Comments
 (0)