diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu index 8ea0f08d2..efd546a59 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -999,7 +999,10 @@ RaggedShape Transpose(RaggedShape &src, Array1 *value_indexes) { K2_CHECK_GT(src.NumAxes(), 2); ContextPtr c = src.Context(); int32_t src_dim0 = src.Dim0(), src_tot_size1 = src.TotSize(1); - if (src_dim0 <= 0) return src; + if (src_dim0 <= 0) { + if (value_indexes) *value_indexes = Array1(c, 0); + return src; + } int32_t src_dim1 = src_tot_size1 / src_dim0; K2_CHECK_EQ(src_tot_size1 % src_dim0, 0) << "Transpose(): all dims on axis 0 must be the same.\n" diff --git a/k2/csrc/ragged_test.cu b/k2/csrc/ragged_test.cu index 9efda15dc..0bc74543d 100644 --- a/k2/csrc/ragged_test.cu +++ b/k2/csrc/ragged_test.cu @@ -763,6 +763,62 @@ template void TestTransposeRagged() { ContextPtr cpu = GetCpuContext(); // will be used to copy data for (auto &context : {GetCpuContext(), GetCudaContext()}) { + // empty case, fsavec with a empty fsa + { + const std::vector row_splits1_vec = {0, 0}; + const std::vector row_splits2_vec = {0}; + Array1 row_splits1(context, row_splits1_vec); + Array1 row_splits2(context, row_splits2_vec); + RaggedShape src_shape = + RaggedShape3(&row_splits1, nullptr, -1, &row_splits2, nullptr, -1); + ASSERT_EQ(src_shape.Dim0(), 1); + ASSERT_EQ(src_shape.TotSize(1), 0); + + Array1 values_array(context, 0); + ASSERT_EQ(values_array.Dim(), src_shape.NumElements()); + + Ragged ragged(src_shape, values_array); + Ragged ans = Transpose(ragged); + RaggedShape shape = ans.shape; + // Check shape + ASSERT_EQ(shape.Dim0(), 0); + ASSERT_EQ(shape.TotSize(1), 0); + CheckArrayData(shape.RowSplits(1), std::vector({0})); + CheckArrayData(shape.RowSplits(2), std::vector({0})); + K2_CHECK_EQ(shape.RowIds(1).Dim(), 0); + K2_CHECK_EQ(shape.RowIds(2).Dim(), 0); + // Check values + K2_CHECK_EQ(ans.values.Dim(), 0); + } + + // empty case, fsavec without any fsa + { + const std::vector row_splits1_vec = {0}; + const std::vector row_splits2_vec = {0}; + Array1 row_splits1(context, row_splits1_vec); + Array1 row_splits2(context, row_splits2_vec); + RaggedShape src_shape = + RaggedShape3(&row_splits1, nullptr, -1, &row_splits2, nullptr, -1); + ASSERT_EQ(src_shape.Dim0(), 0); + ASSERT_EQ(src_shape.TotSize(1), 0); + + Array1 values_array(context, 0); + ASSERT_EQ(values_array.Dim(), src_shape.NumElements()); + + Ragged ragged(src_shape, values_array); + Ragged ans = Transpose(ragged); + RaggedShape shape = ans.shape; + // Check shape + ASSERT_EQ(shape.Dim0(), 0); + ASSERT_EQ(shape.TotSize(1), 0); + CheckArrayData(shape.RowSplits(1), std::vector({0})); + CheckArrayData(shape.RowSplits(2), std::vector({0})); + K2_CHECK_EQ(shape.RowIds(1).Dim(), 0); + K2_CHECK_EQ(shape.RowIds(2).Dim(), 0); + // Check values + K2_CHECK_EQ(ans.values.Dim(), 0); + } + { const std::vector row_splits1_vec = {0, 2, 4, 6}; const std::vector row_splits2_vec = {0, 3, 4, 7, 8, 10, 12}; diff --git a/k2/python/csrc/torch/torch_util.h b/k2/python/csrc/torch/torch_util.h index da1f3e98e..1611ab93d 100644 --- a/k2/python/csrc/torch/torch_util.h +++ b/k2/python/csrc/torch/torch_util.h @@ -119,8 +119,11 @@ Array1 FromTorch(torch::Tensor &tensor) { K2_CHECK_EQ(tensor.scalar_type(), ToScalarType::value) << "Expected scalar type: " << ToScalarType::value << ". Given: " << tensor.scalar_type(); - K2_CHECK_EQ(tensor.strides()[0], 1) - << "Expected stride: 1. Given: " << tensor.strides()[0]; + // Some empty tensor may have stride not equal to 1, e.g., tensor returned by + // clone() method, it is valid here, so we won't check its strieds. + if (tensor.numel()) + K2_CHECK_EQ(tensor.strides()[0], 1) + << "Expected stride: 1. Given: " << tensor.strides()[0]; auto region = NewRegion(tensor); Array1 ans(tensor.numel(), region, 0);