@@ -194,6 +194,7 @@ void tensorSlicingConstructorAxis2() {
194194 EXPECT_EQ (3 , tensSlice.numCols ());
195195 EXPECT_EQ (2 , tensSlice.numMats ());
196196 EXPECT_EQ (tens.raw (), tensSlice.raw ()); // it is indeed a slice
197+ EXPECT_TRUE (tensSlice.ptrMatrices () != nullptr );
197198}
198199
199200TEST_F (TensorTest, tensorSlicingConstructorAxis2) {
@@ -215,6 +216,7 @@ void tensorSlicingConstructorAxis1() {
215216 EXPECT_EQ (2 , tenzSlice.numRows ());
216217 EXPECT_EQ (2 , tenzSlice.numCols ());
217218 EXPECT_EQ (1 , tenzSlice.numMats ());
219+ EXPECT_TRUE (tenzSlice.ptrMatrices () == nullptr );
218220 std::vector<T> expected = {3 , 4 , 5 , 6 };
219221 std::vector<T> tenzSliceDown (4 );
220222 tenzSlice.download (tenzSliceDown);
@@ -229,7 +231,7 @@ TEST_F(TensorTest, tensorSlicingConstructorAxis1) {
229231
230232/* ---------------------------------------
231233 * Tensor: Slicing constructor
232- * axis = 0 (columns )
234+ * axis = 0 (rows )
233235 * --------------------------------------- */
234236
235237TEMPLATE_WITH_TYPE_T
@@ -240,6 +242,7 @@ void tensorSlicingConstructorAxis0() {
240242 EXPECT_EQ (2 , tenzSlice.numRows ());
241243 EXPECT_EQ (1 , tenzSlice.numCols ());
242244 EXPECT_EQ (1 , tenzSlice.numMats ());
245+ EXPECT_TRUE (tenzSlice.ptrMatrices () == nullptr );
243246 std::vector<T> expected = {3 , 4 };
244247 std::vector<T> tenzSliceDown (2 );
245248 tenzSlice.download (tenzSliceDown);
@@ -738,6 +741,46 @@ TEST_F(TensorTest, tensorAddAB) {
738741 tensorAddAB<float >();
739742}
740743
744+ /* ---------------------------------------
745+ * Tensor: slice ptrMatrices
746+ * axis = 2 (matrices)
747+ * --------------------------------------- */
748+
749+ TEMPLATE_WITH_TYPE_T
750+ void tensorSliceAxis2PtrMatrices () {
751+ std::vector<T> dataA = TENSOR_DATA_234A ;
752+ DTensor<T> d_A (dataA, 2 , 3 , 4 );
753+ DTensor<T> d_ASlice (d_A, 2 , 2 , 3 );
754+ EXPECT_TRUE (d_ASlice.ptrMatrices () == d_A.ptrMatrices () + 2 );
755+ }
756+
757+ TEST_F (TensorTest, tensorSliceAxis2PtrMatrices) {
758+ tensorSliceAxis2PtrMatrices<float >();
759+ tensorSliceAxis2PtrMatrices<double >();
760+ tensorSliceAxis2PtrMatrices<int >();
761+ }
762+
763+ /* ---------------------------------------
764+ * Tensor: slice ptrMatrices
765+ * axis = 0 and 1
766+ * --------------------------------------- */
767+
768+ TEMPLATE_WITH_TYPE_T
769+ void tensorSliceAxis01PtrMatrices () {
770+ std::vector<T> dataA = TENSOR_DATA_234A ;
771+ DTensor<T> d_A (dataA, 2 , 3 , 4 );
772+ DTensor<T> d_ASlice0 (d_A, 0 , 0 , 1 );
773+ EXPECT_TRUE (!d_ASlice0.ptrMatrices ());
774+ DTensor<T> d_ASlice1 (d_A, 1 , 0 , 2 );
775+ EXPECT_TRUE (!d_ASlice0.ptrMatrices ());
776+ }
777+
778+ TEST_F (TensorTest, tensorSliceAxis01PtrMatrices) {
779+ tensorSliceAxis01PtrMatrices<float >();
780+ tensorSliceAxis01PtrMatrices<double >();
781+ tensorSliceAxis01PtrMatrices<int >();
782+ }
783+
741784/* ---------------------------------------
742785 * Tensor: getRows
743786 * --------------------------------------- */
0 commit comments