@@ -346,8 +346,8 @@ void TabulateFusionSeTTebdForward(const torch::Tensor& table_tensor,
346346 if (table_tensor.dim () != 2 ) {
347347 throw std::invalid_argument (" Dim of table should be 2" );
348348 }
349- if (em_x_tensor.dim () != 3 ) {
350- throw std::invalid_argument (" Dim of em_x should be 3 " );
349+ if (em_x_tensor.dim () != 2 ) {
350+ throw std::invalid_argument (" Dim of em_x should be 2 " );
351351 }
352352 if (em_tensor.dim () != 3 ) {
353353 throw std::invalid_argument (" Dim of em should be 3" );
@@ -363,9 +363,9 @@ void TabulateFusionSeTTebdForward(const torch::Tensor& table_tensor,
363363 const FPTYPE* em_x = em_x_tensor.view ({-1 }).data_ptr <FPTYPE>();
364364 const FPTYPE* em = em_tensor.view ({-1 }).data_ptr <FPTYPE>();
365365
366- const int64_t nloc = em_x_tensor .size (0 );
367- const int64_t nnei_i = em_x_tensor .size (1 );
368- const int64_t nnei_j = em_x_tensor .size (2 );
366+ const int64_t nloc = em_tensor .size (0 );
367+ const int64_t nnei_i = em_tensor .size (1 );
368+ const int64_t nnei_j = em_tensor .size (2 );
369369 // compute
370370 if (device == " GPU" ) {
371371#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -405,9 +405,9 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
405405 const FPTYPE* em = em_tensor.view ({-1 }).data_ptr <FPTYPE>();
406406 const FPTYPE* dy = dy_tensor.view ({-1 }).data_ptr <FPTYPE>();
407407
408- const int64_t nloc = em_x_tensor .size (0 );
409- const int64_t nnei_i = em_x_tensor .size (1 );
410- const int64_t nnei_j = em_x_tensor .size (2 );
408+ const int64_t nloc = em_tensor .size (0 );
409+ const int64_t nnei_i = em_tensor .size (1 );
410+ const int64_t nnei_j = em_tensor .size (2 );
411411 const int64_t last_layer_size = descriptor_tensor.size (3 );
412412
413413 // compute
@@ -451,9 +451,9 @@ void TabulateFusionSeTTebdGradGradForward(const torch::Tensor& table_tensor,
451451 const FPTYPE* em_x = em_x_tensor.view ({-1 }).data_ptr <FPTYPE>();
452452 const FPTYPE* em = em_tensor.view ({-1 }).data_ptr <FPTYPE>();
453453 const FPTYPE* dz_dy_dem_x = dz_dy_dem_x_tensor.view ({-1 }).data_ptr <FPTYPE>();
454- const int64_t nloc = em_x_tensor .size (0 );
455- const int64_t nnei_i = em_x_tensor .size (1 );
456- const int64_t nnei_j = em_x_tensor .size (2 );
454+ const int64_t nloc = em_tensor .size (0 );
455+ const int64_t nnei_i = em_tensor .size (1 );
456+ const int64_t nnei_j = em_tensor .size (2 );
457457 const int64_t last_layer_size = descriptor_tensor.size (3 );
458458 // compute
459459 if (device == " GPU" ) {
@@ -1113,7 +1113,7 @@ class TabulateFusionSeTTebdOp
11131113 .dtype (table_tensor.dtype ())
11141114 .device (table_tensor.device ());
11151115 torch::Tensor descriptor_tensor = torch::empty (
1116- {em_x_tensor .size (0 ), em_x_tensor .size (1 ), em_x_tensor .size (2 ), last_layer_size},
1116+ {em_tensor .size (0 ), em_tensor .size (1 ), em_tensor .size (2 ), last_layer_size},
11171117 options);
11181118 // compute
11191119 TabulateFusionSeTTebdForward<FPTYPE>(table_tensor, table_info_tensor,
0 commit comments