@@ -747,16 +747,38 @@ TEST_F(TensorTest, tensorAddAB) {
747747 * --------------------------------------- */
748748
749749TEMPLATE_WITH_TYPE_T
750- void tensorSlicePtrMatrices () {
750+ void tensorSliceAxis2PtrMatrices () {
751751 std::vector<T> dataA = TENSOR_DATA_234A ;
752752 DTensor<T> d_A (dataA, 2 , 3 , 4 );
753753 DTensor<T> d_ASlice (d_A, 2 , 2 , 3 );
754754 EXPECT_TRUE (d_ASlice.ptrMatrices () == d_A.ptrMatrices () + 2 );
755755}
756756
757- TEST_F (TensorTest, tensorSlicePtrMatrices) {
758- tensorSlicePtrMatrices<float >();
759- tensorSlicePtrMatrices<double >();
757+ TEST_F (TensorTest, tensorSliceAxis2PtrMatrices) {
758+ tensorSliceAxis2PtrMatrices<float >();
759+ tensorSliceAxis2PtrMatrices<double >();
760+ tensorSliceAxis2PtrMatrices<int >();
761+ }
762+
763+ /* ---------------------------------------
764+ * Tensor: slice ptrMatrices
765+ * axis = 0 and 1
766+ * --------------------------------------- */
767+
768+ TEMPLATE_WITH_TYPE_T
769+ void tensorSliceAxis01PtrMatrices () {
770+ std::vector<T> dataA = TENSOR_DATA_234A ;
771+ DTensor<T> d_A (dataA, 2 , 3 , 4 );
772+ DTensor<T> d_ASlice0 (d_A, 0 , 0 , 1 );
773+ EXPECT_TRUE (!d_ASlice0.ptrMatrices ());
774+ DTensor<T> d_ASlice1 (d_A, 1 , 0 , 2 );
775+ EXPECT_TRUE (!d_ASlice0.ptrMatrices ());
776+ }
777+
778+ TEST_F (TensorTest, tensorSliceAxis01PtrMatrices) {
779+ tensorSliceAxis01PtrMatrices<float >();
780+ tensorSliceAxis01PtrMatrices<double >();
781+ tensorSliceAxis01PtrMatrices<int >();
760782}
761783
762784/* ---------------------------------------
0 commit comments