@@ -389,8 +389,7 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
389389 const torch::Tensor& em_tensor,
390390 const torch::Tensor& dy_tensor,
391391 const torch::Tensor& descriptor_tensor,
392- torch::Tensor& dy_dem_x_tensor,
393- torch::Tensor& dy_dem_tensor) {
392+ torch::Tensor& dy_dem_x_tensor) {
394393 // check input shape
395394 if (dy_tensor.dim () != 4 ) {
396395 throw std::invalid_argument (" Dim of dy_tensor should be 4" );
@@ -399,7 +398,6 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
399398 GetTensorDevice (table_tensor, device);
400399 // flat the tensors
401400 FPTYPE * dy_dem_x = dy_dem_x_tensor.view ({-1 }).data_ptr <FPTYPE >();
402- FPTYPE * dy_dem = dy_dem_tensor.view ({-1 }).data_ptr <FPTYPE >();
403401
404402 const FPTYPE * table = table_tensor.view ({-1 }).data_ptr <FPTYPE >();
405403 const FPTYPE * table_info = table_info_tensor.view ({-1 }).data_ptr <FPTYPE >();
@@ -430,6 +428,54 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
430428 }
431429}
432430
431+ template <typename FPTYPE >
432+ void TabulateFusionSeTTebdGradGradForward (const torch::Tensor& table_tensor,
433+ const torch::Tensor& table_info_tensor,
434+ const torch::Tensor& em_x_tensor,
435+ const torch::Tensor& em_tensor,
436+ const torch::Tensor& dz_dy_dem_x_tensor,
437+ const torch::Tensor& descriptor_tensor,
438+ torch::Tensor& dz_dy_tensor) {
439+ // Check input shape
440+ if (dz_dy_dem_x_tensor.dim () != 3 ) {
441+ throw std::invalid_argument (" Dim of dz_dy_dem_x should be 3" );
442+ }
443+ // get the device
444+ std::string device;
445+ GetTensorDevice (table_tensor, device);
446+ // flat the tensors
447+ FPTYPE * dz_dy = dz_dy_tensor.view ({-1 }).data_ptr <FPTYPE >();
448+
449+ const FPTYPE * table = table_tensor.view ({-1 }).data_ptr <FPTYPE >();
450+ const FPTYPE * table_info = table_info_tensor.view ({-1 }).data_ptr <FPTYPE >();
451+ const FPTYPE * em_x = em_x_tensor.view ({-1 }).data_ptr <FPTYPE >();
452+ const FPTYPE * em = em_tensor.view ({-1 }).data_ptr <FPTYPE >();
453+ const FPTYPE * dz_dy_dem_x = dz_dy_dem_x_tensor.view ({-1 }).data_ptr <FPTYPE >();
454+ const int64_t nloc = em_x_tensor.size (0 );
455+ const int64_t nnei_i = em_x_tensor.size (1 );
456+ const int64_t nnei_j = em_x_tensor.size (2 );
457+ const int64_t last_layer_size = descriptor_tensor.size (3 );
458+ // compute
459+ if (device == " GPU" ) {
460+ #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
461+ deepmd::tabulate_fusion_se_t_tebd_grad_grad_gpu (dz_dy, table, table_info, em_x,
462+ em, dz_dy_dem_x, nloc,
463+ nnei_i, nnei_j, last_layer_size);
464+ #else
465+ throw std::runtime_error (
466+ " The input tensor is on the GPU, but the GPU support for the "
467+ " customized OP library is not enabled." );
468+ #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
469+ TORCH_CHECK (last_layer_size <= 1024 ,
470+ " In the process of model compression, the size of the "
471+ " last layer of embedding net must be less than 1024!" );
472+ } else if (device == " CPU" ) {
473+ deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu (dz_dy, table, table_info, em_x,
474+ em, dz_dy_dem_x, nloc,
475+ nnei_i, nnei_j, last_layer_size);
476+ }
477+ }
478+
433479template <typename FPTYPE >
434480void TabulateFusionSeRForward (const torch::Tensor& table_tensor,
435481 const torch::Tensor& table_info_tensor,
@@ -1107,13 +1153,12 @@ class TabulateFusionSeTTebdOp
11071153 torch::Tensor dy_tensor = grad_output[0 ].contiguous ();
11081154 // allocate output tensors
11091155 torch::Tensor dy_dem_x_tensor = torch::zeros_like (em_x_tensor);
1110- torch::Tensor dy_dem_tensor = torch::zeros_like (em_tensor);
11111156 // compute
11121157 TabulateFusionSeTTebdGradForward<FPTYPE >(
11131158 table_tensor, table_info_tensor, em_x_tensor, em_tensor, dy_tensor,
1114- descriptor_tensor, dy_dem_x_tensor, dy_dem_tensor );
1159+ descriptor_tensor, dy_dem_x_tensor);
11151160
1116- return {at::Tensor (), at::Tensor (), dy_dem_x_tensor, dy_dem_tensor ,
1161+ return {at::Tensor (), at::Tensor (), dy_dem_x_tensor, at::Tensor () ,
11171162 at::Tensor ()};
11181163 }
11191164};
0 commit comments