@@ -262,9 +262,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
262262 DenseTensor *out,
263263 DenseTensor *labels,
264264 DenseTensor *mask) {
265- const auto &input_type = TransToProtoVarType ( x.dtype () );
265+ const auto &input_type = x.dtype ();
266266 bool input_type_match =
267- input_type == ProtoDataType ::INT32 || input_type == ProtoDataType ::INT64;
267+ input_type == DataType ::INT32 || input_type == DataType ::INT64;
268268 PADDLE_ENFORCE_EQ (input_type_match,
269269 true ,
270270 common::errors::InvalidArgument (
@@ -274,9 +274,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
274274 DataTypeToString (DataType::INT32),
275275 DataTypeToString (DataType::INT64)));
276276
277- const auto &travel_type = TransToProtoVarType ( travel.dtype () );
278- bool travel_type_match = travel_type == ProtoDataType::INT32 ||
279- travel_type == ProtoDataType ::INT64;
277+ const auto &travel_type = travel.dtype ();
278+ bool travel_type_match =
279+ travel_type == DataType::INT32 || travel_type == DataType ::INT64;
280280 PADDLE_ENFORCE_EQ (travel_type_match,
281281 true ,
282282 common::errors::InvalidArgument (
@@ -286,9 +286,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
286286 DataTypeToString (DataType::INT32),
287287 DataTypeToString (DataType::INT64)));
288288
289- const auto &layer_type = TransToProtoVarType ( layer.dtype () );
289+ const auto &layer_type = layer.dtype ();
290290 bool layer_type_match =
291- layer_type == ProtoDataType ::INT32 || layer_type == ProtoDataType ::INT64;
291+ layer_type == DataType ::INT32 || layer_type == DataType ::INT64;
292292 PADDLE_ENFORCE_EQ (layer_type_match,
293293 true ,
294294 common::errors::InvalidArgument (
@@ -305,10 +305,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
305305 DataTypeToString (travel.dtype ()),
306306 DataTypeToString (layer.dtype ())));
307307
308- auto output_type = static_cast <ProtoDataType> (dtype);
308+ auto output_type = TransToPhiDataType (dtype);
309309
310- if (travel_type == ProtoDataType::INT32 &&
311- output_type == ProtoDataType::INT32) {
310+ if (travel_type == DataType::INT32 && output_type == DataType::INT32) {
312311 TDMSamplerInner<T, Context, int , int >(dev_ctx,
313312 x,
314313 travel,
@@ -320,8 +319,7 @@ void TDMSamplerKernel(const Context &dev_ctx,
320319 out,
321320 labels,
322321 mask);
323- } else if (travel_type == ProtoDataType::INT64 &&
324- output_type == ProtoDataType::INT32) {
322+ } else if (travel_type == DataType::INT64 && output_type == DataType::INT32) {
325323 TDMSamplerInner<T, Context, int64_t , int >(dev_ctx,
326324 x,
327325 travel,
@@ -333,8 +331,7 @@ void TDMSamplerKernel(const Context &dev_ctx,
333331 out,
334332 labels,
335333 mask);
336- } else if (travel_type == ProtoDataType::INT32 &&
337- output_type == ProtoDataType::INT64) {
334+ } else if (travel_type == DataType::INT32 && output_type == DataType::INT64) {
338335 TDMSamplerInner<T, Context, int , int64_t >(dev_ctx,
339336 x,
340337 travel,
@@ -346,8 +343,7 @@ void TDMSamplerKernel(const Context &dev_ctx,
346343 out,
347344 labels,
348345 mask);
349- } else if (travel_type == ProtoDataType::INT64 &&
350- output_type == ProtoDataType::INT64) {
346+ } else if (travel_type == DataType::INT64 && output_type == DataType::INT64) {
351347 TDMSamplerInner<T, Context, int64_t , int64_t >(dev_ctx,
352348 x,
353349 travel,
0 commit comments