Skip to content

Commit 6d87fa2

Browse files
committed
bug fix: use em_tensor.size
1 parent d243804 commit 6d87fa2

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

source/op/pt/tabulate_multi_device.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,8 @@ void TabulateFusionSeTTebdForward(const torch::Tensor& table_tensor,
346346
if (table_tensor.dim() != 2) {
347347
throw std::invalid_argument("Dim of table should be 2");
348348
}
349-
if (em_x_tensor.dim() != 3) {
350-
throw std::invalid_argument("Dim of em_x should be 3");
349+
if (em_x_tensor.dim() != 2) {
350+
throw std::invalid_argument("Dim of em_x should be 2");
351351
}
352352
if (em_tensor.dim() != 3) {
353353
throw std::invalid_argument("Dim of em should be 3");
@@ -363,9 +363,9 @@ void TabulateFusionSeTTebdForward(const torch::Tensor& table_tensor,
363363
const FPTYPE* em_x = em_x_tensor.view({-1}).data_ptr<FPTYPE>();
364364
const FPTYPE* em = em_tensor.view({-1}).data_ptr<FPTYPE>();
365365

366-
const int64_t nloc = em_x_tensor.size(0);
367-
const int64_t nnei_i = em_x_tensor.size(1);
368-
const int64_t nnei_j = em_x_tensor.size(2);
366+
const int64_t nloc = em_tensor.size(0);
367+
const int64_t nnei_i = em_tensor.size(1);
368+
const int64_t nnei_j = em_tensor.size(2);
369369
// compute
370370
if (device == "GPU") {
371371
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -405,9 +405,9 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
405405
const FPTYPE* em = em_tensor.view({-1}).data_ptr<FPTYPE>();
406406
const FPTYPE* dy = dy_tensor.view({-1}).data_ptr<FPTYPE>();
407407

408-
const int64_t nloc = em_x_tensor.size(0);
409-
const int64_t nnei_i = em_x_tensor.size(1);
410-
const int64_t nnei_j = em_x_tensor.size(2);
408+
const int64_t nloc = em_tensor.size(0);
409+
const int64_t nnei_i = em_tensor.size(1);
410+
const int64_t nnei_j = em_tensor.size(2);
411411
const int64_t last_layer_size = descriptor_tensor.size(3);
412412

413413
// compute
@@ -451,9 +451,9 @@ void TabulateFusionSeTTebdGradGradForward(const torch::Tensor& table_tensor,
451451
const FPTYPE* em_x = em_x_tensor.view({-1}).data_ptr<FPTYPE>();
452452
const FPTYPE* em = em_tensor.view({-1}).data_ptr<FPTYPE>();
453453
const FPTYPE* dz_dy_dem_x = dz_dy_dem_x_tensor.view({-1}).data_ptr<FPTYPE>();
454-
const int64_t nloc = em_x_tensor.size(0);
455-
const int64_t nnei_i = em_x_tensor.size(1);
456-
const int64_t nnei_j = em_x_tensor.size(2);
454+
const int64_t nloc = em_tensor.size(0);
455+
const int64_t nnei_i = em_tensor.size(1);
456+
const int64_t nnei_j = em_tensor.size(2);
457457
const int64_t last_layer_size = descriptor_tensor.size(3);
458458
// compute
459459
if (device == "GPU") {
@@ -1113,7 +1113,7 @@ class TabulateFusionSeTTebdOp
11131113
.dtype(table_tensor.dtype())
11141114
.device(table_tensor.device());
11151115
torch::Tensor descriptor_tensor = torch::empty(
1116-
{em_x_tensor.size(0), em_x_tensor.size(1), em_x_tensor.size(2), last_layer_size},
1116+
{em_tensor.size(0), em_tensor.size(1), em_tensor.size(2), last_layer_size},
11171117
options);
11181118
// compute
11191119
TabulateFusionSeTTebdForward<FPTYPE>(table_tensor, table_info_tensor,

0 commit comments

Comments
 (0)