@@ -369,16 +369,18 @@ void TabulateFusionSeTTebdForward(const torch::Tensor& table_tensor,
369369 // compute
370370 if (device == " GPU" ) {
371371#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
372- deepmd::tabulate_fusion_se_t_tebd_gpu (descriptor, table, table_info, em_x, em,
373- nloc, nnei_i, nnei_j, last_layer_size);
372+ deepmd::tabulate_fusion_se_t_tebd_gpu (descriptor, table, table_info, em_x,
373+ em, nloc, nnei_i, nnei_j,
374+ last_layer_size);
374375#else
375376 throw std::runtime_error (
376377 " The input tensor is on the GPU, but the GPU support for the "
377378 " customized OP library is not enabled." );
378379#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
379380 } else if (device == " CPU" ) {
380- deepmd::tabulate_fusion_se_t_tebd_cpu (descriptor, table, table_info, em_x, em,
381- nloc, nnei_i, nnei_j, last_layer_size);
381+ deepmd::tabulate_fusion_se_t_tebd_cpu (descriptor, table, table_info, em_x,
382+ em, nloc, nnei_i, nnei_j,
383+ last_layer_size);
382384 }
383385}
384386
@@ -414,28 +416,29 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
414416 if (device == " GPU" ) {
415417#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
416418 deepmd::tabulate_fusion_se_t_tebd_grad_gpu (dy_dem_x, table, table_info,
417- em_x, em, dy, nloc, nnei_i, nnei_j,
418- last_layer_size);
419+ em_x, em, dy, nloc, nnei_i,
420+ nnei_j, last_layer_size);
419421#else
420422 throw std::runtime_error (
421423 " The input tensor is on the GPU, but the GPU support for the "
422424 " customized OP library is not enabled." );
423425#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
424426 } else if (device == " CPU" ) {
425427 deepmd::tabulate_fusion_se_t_tebd_grad_cpu (dy_dem_x, table, table_info,
426- em_x, em, dy, nloc, nnei_i, nnei_j,
427- last_layer_size);
428+ em_x, em, dy, nloc, nnei_i,
429+ nnei_j, last_layer_size);
428430 }
429431}
430432
431433template <typename FPTYPE >
432- void TabulateFusionSeTTebdGradGradForward (const torch::Tensor& table_tensor,
433- const torch::Tensor& table_info_tensor,
434- const torch::Tensor& em_x_tensor,
435- const torch::Tensor& em_tensor,
436- const torch::Tensor& dz_dy_dem_x_tensor,
437- const torch::Tensor& descriptor_tensor,
438- torch::Tensor& dz_dy_tensor) {
434+ void TabulateFusionSeTTebdGradGradForward (
435+ const torch::Tensor& table_tensor,
436+ const torch::Tensor& table_info_tensor,
437+ const torch::Tensor& em_x_tensor,
438+ const torch::Tensor& em_tensor,
439+ const torch::Tensor& dz_dy_dem_x_tensor,
440+ const torch::Tensor& descriptor_tensor,
441+ torch::Tensor& dz_dy_tensor) {
439442 // Check input shape
440443 if (dz_dy_dem_x_tensor.dim () != 3 ) {
441444 throw std::invalid_argument (" Dim of dz_dy_dem_x should be 3" );
@@ -458,9 +461,9 @@ void TabulateFusionSeTTebdGradGradForward(const torch::Tensor& table_tensor,
458461 // compute
459462 if (device == " GPU" ) {
460463#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
461- deepmd::tabulate_fusion_se_t_tebd_grad_grad_gpu (dz_dy, table, table_info, em_x,
462- em, dz_dy_dem_x, nloc,
463- nnei_i, nnei_j, last_layer_size);
464+ deepmd::tabulate_fusion_se_t_tebd_grad_grad_gpu (
465+ dz_dy, table, table_info, em_x, em, dz_dy_dem_x, nloc, nnei_i, nnei_j ,
466+ last_layer_size);
464467#else
465468 throw std::runtime_error (
466469 " The input tensor is on the GPU, but the GPU support for the "
@@ -470,9 +473,9 @@ void TabulateFusionSeTTebdGradGradForward(const torch::Tensor& table_tensor,
470473 " In the process of model compression, the size of the "
471474 " last layer of embedding net must be less than 1024!" );
472475 } else if (device == " CPU" ) {
473- deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu (dz_dy, table, table_info, em_x,
474- em, dz_dy_dem_x, nloc,
475- nnei_i, nnei_j, last_layer_size);
476+ deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu (
477+ dz_dy, table, table_info, em_x, em, dz_dy_dem_x, nloc, nnei_i, nnei_j ,
478+ last_layer_size);
476479 }
477480}
478481
@@ -1112,13 +1115,14 @@ class TabulateFusionSeTTebdOp
11121115 auto options = torch::TensorOptions ()
11131116 .dtype (table_tensor.dtype ())
11141117 .device (table_tensor.device ());
1115- torch::Tensor descriptor_tensor = torch::empty (
1116- {em_tensor.size (0 ), em_tensor.size (1 ), em_tensor.size (2 ), last_layer_size},
1117- options);
1118+ torch::Tensor descriptor_tensor =
1119+ torch::empty ({em_tensor.size (0 ), em_tensor.size (1 ), em_tensor.size (2 ),
1120+ last_layer_size},
1121+ options);
11181122 // compute
11191123 TabulateFusionSeTTebdForward<FPTYPE >(table_tensor, table_info_tensor,
1120- em_x_tensor, em_tensor, last_layer_size,
1121- descriptor_tensor);
1124+ em_x_tensor, em_tensor,
1125+ last_layer_size, descriptor_tensor);
11221126 // save data
11231127 ctx->save_for_backward ({table_tensor, table_info_tensor, em_x_tensor,
11241128 em_tensor, descriptor_tensor});
@@ -1202,8 +1206,8 @@ std::vector<torch::Tensor> tabulate_fusion_se_t_tebd(
12021206 const torch::Tensor& em_x_tensor,
12031207 const torch::Tensor& em_tensor,
12041208 int64_t last_layer_size) {
1205- return TabulateFusionSeTTebdOp::apply (table_tensor, table_info_tensor,
1206- em_x_tensor, em_tensor, last_layer_size);
1209+ return TabulateFusionSeTTebdOp::apply (
1210+ table_tensor, table_info_tensor, em_x_tensor, em_tensor, last_layer_size);
12071211}
12081212
12091213std::vector<torch::Tensor> tabulate_fusion_se_r (
0 commit comments