Skip to content

Commit e829fae

Browse files
committed
DTensor slice constructor: further testing
1 parent 2258c1d commit e829fae

1 file changed

Lines changed: 26 additions & 4 deletions

File tree

test/testTensor.cu

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -747,16 +747,38 @@ TEST_F(TensorTest, tensorAddAB) {
747747
* --------------------------------------- */
748748

749749
TEMPLATE_WITH_TYPE_T
750-
void tensorSlicePtrMatrices() {
750+
void tensorSliceAxis2PtrMatrices() {
751751
std::vector<T> dataA = TENSOR_DATA_234A;
752752
DTensor<T> d_A(dataA, 2, 3, 4);
753753
DTensor<T> d_ASlice(d_A, 2, 2, 3);
754754
EXPECT_TRUE(d_ASlice.ptrMatrices() == d_A.ptrMatrices() + 2);
755755
}
756756

757-
TEST_F(TensorTest, tensorSlicePtrMatrices) {
758-
tensorSlicePtrMatrices<float>();
759-
tensorSlicePtrMatrices<double>();
757+
TEST_F(TensorTest, tensorSliceAxis2PtrMatrices) {
758+
tensorSliceAxis2PtrMatrices<float>();
759+
tensorSliceAxis2PtrMatrices<double>();
760+
tensorSliceAxis2PtrMatrices<int>();
761+
}
762+
763+
/* ---------------------------------------
764+
* Tensor: slice ptrMatrices
765+
* axis = 0 and 1
766+
* --------------------------------------- */
767+
768+
TEMPLATE_WITH_TYPE_T
769+
void tensorSliceAxis01PtrMatrices() {
770+
std::vector<T> dataA = TENSOR_DATA_234A;
771+
DTensor<T> d_A(dataA, 2, 3, 4);
772+
DTensor<T> d_ASlice0(d_A, 0, 0, 1);
773+
EXPECT_TRUE(!d_ASlice0.ptrMatrices());
774+
DTensor<T> d_ASlice1(d_A, 1, 0, 2);
775+
EXPECT_TRUE(!d_ASlice0.ptrMatrices());
776+
}
777+
778+
TEST_F(TensorTest, tensorSliceAxis01PtrMatrices) {
779+
tensorSliceAxis01PtrMatrices<float>();
780+
tensorSliceAxis01PtrMatrices<double>();
781+
tensorSliceAxis01PtrMatrices<int>();
760782
}
761783

762784
/* ---------------------------------------

0 commit comments

Comments
 (0)