Skip to content

Commit 7f020e6

Browse files
committed
final touch: info in SVD
1 parent 55f06a4 commit 7f020e6

2 files changed

Lines changed: 6 additions & 16 deletions

File tree

include/tensor.cuh

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,6 @@ private:
12341234
std::shared_ptr<DTensor<unsigned int>> m_rank; ///< Rank of original matrix
12351235
bool m_computeU = false; ///< Whether to compute U
12361236
bool m_destroyMatrix = true; ///< Whether to sacrifice original matrix
1237-
int m_svdStatus = 0;
12381237

12391238
/**
12401239
* Ensures tensor to factorise contains exactly one matrix, and that matrix is tall.
@@ -1256,10 +1255,6 @@ private:
12561255

12571256
public:
12581257

1259-
int statusCode() {
1260-
// redefinition of getStatus
1261-
return m_svdStatus;
1262-
}
12631258

12641259
/**
12651260
* Constructor.
@@ -1269,7 +1264,7 @@ public:
12691264
*/
12701265
Svd(DTensor<T> &mat,
12711266
bool computeU = false,
1272-
bool destroyMatrix = true) : IStatus() {
1267+
bool destroyMatrix = true) : IStatus(mat.numMats()) {
12731268
checkMatrix(mat);
12741269
m_destroyMatrix = destroyMatrix;
12751270
m_tensor = (destroyMatrix) ? &mat : new DTensor<T>(mat);
@@ -1379,10 +1374,7 @@ inline void Svd<double>::factorise() {
13791374
m_workspace->raw(),
13801375
m_lwork,
13811376
nullptr, // rwork (used only if SVD fails)
1382-
m_info->raw()));
1383-
#ifdef GPUTILS_DEBUG_MODE
1384-
m_svdStatus = std::max(m_svdStatus, (*m_info)(0, 0, 0));
1385-
#endif
1377+
m_info->raw() + i));
13861378
}
13871379
}
13881380

@@ -1409,10 +1401,7 @@ inline void Svd<float>::factorise() {
14091401
m_workspace->raw(),
14101402
m_lwork,
14111403
nullptr, // rwork (used only if SVD fails)
1412-
m_info->raw()));
1413-
#ifdef GPUTILS_DEBUG_MODE
1414-
m_svdStatus = std::max(m_svdStatus, (*m_info)(0, 0, 0));
1415-
#endif
1404+
m_info->raw() + i));
14161405
}
14171406
}
14181407

test/testTensor.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -902,14 +902,15 @@ void singularValuesComputation(float epsilon) {
902902
DTensor<T> B(bData, 8, 3);
903903
Svd<T> svd(B, true, false);
904904
svd.factorise();
905-
EXPECT_EQ(0, svd.info()(0));
905+
for (size_t i = 0; i < B.numMats(); i++) {
906+
EXPECT_EQ(0, svd.info()(i));
907+
}
906908
auto S = svd.singularValues();
907909
EXPECT_NEAR(32.496241123753592, S(0), epsilon); // value from MATLAB
908910
EXPECT_NEAR(0.997152358903242, S(1), epsilon); // value from MATLAB
909911

910912
auto U = svd.leftSingularVectors();
911913
EXPECT_TRUE(U.has_value());
912-
EXPECT_EQ(0, svd.info()(0));
913914
}
914915

915916
TEST_F(SvdTest, singularValuesComputation) {

0 commit comments

Comments
 (0)