Skip to content

Commit e9a788a

Browse files
committed
Split up tests and address other comments
1 parent 8ffea38 commit e9a788a

4 files changed

Lines changed: 188 additions & 131 deletions

File tree

src/interface/orthnull.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ left_orth_alg(alg::LeftOrthAlgorithm) = alg
443443
left_orth_alg(alg::QRAlgorithms) = LeftOrthViaQR(alg)
444444
left_orth_alg(alg::PolarAlgorithms) = LeftOrthViaPolar(alg)
445445
left_orth_alg(alg::SVDAlgorithms) = LeftOrthViaSVD(alg)
446+
left_orth_alg(alg::DiagonalAlgorithm) = LeftOrthViaSVD(alg)
446447
left_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftOrthViaSVD(alg)
447448

448449
"""
@@ -478,6 +479,7 @@ right_orth_alg(alg::RightOrthAlgorithm) = alg
478479
right_orth_alg(alg::LQAlgorithms) = RightOrthViaLQ(alg)
479480
right_orth_alg(alg::PolarAlgorithms) = RightOrthViaPolar(alg)
480481
right_orth_alg(alg::SVDAlgorithms) = RightOrthViaSVD(alg)
482+
right_orth_alg(alg::DiagonalAlgorithm) = RightOrthViaSVD(alg)
481483
right_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightOrthViaSVD(alg)
482484

483485
"""

test/orthnull.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
1818
TestSuite.seed_rng!(123)
1919
if T BLASFloats
2020
if CUDA.functional()
21-
TestSuite.test_orthnull(CuMatrix{T}, (m, n))
22-
n == m && TestSuite.test_orthnull(Diagonal{T, CuVector{T}}, m)
21+
TestSuite.test_orthnull(CuMatrix{T}, (m, n); test_nullity = false)
22+
n == m && TestSuite.test_orthnull(Diagonal{T, CuVector{T}}, m; test_orthnull = false)
2323
end
2424
if AMDGPU.functional()
25-
TestSuite.test_orthnull(ROCMatrix{T}, (m, n))
26-
n == m && TestSuite.test_orthnull(Diagonal{T, ROCVector{T}}, m)
25+
TestSuite.test_orthnull(ROCMatrix{T}, (m, n); test_nullity = false)
26+
n == m && TestSuite.test_orthnull(Diagonal{T, ROCVector{T}}, m; test_orthnull = false)
2727
end
2828
end
2929
if !is_buildkite
3030
TestSuite.test_orthnull(T, (m, n))
31-
#AT = Diagonal{T, Vector{T}}
32-
#TestSuite.test_orthnull(AT, m) # not supported
31+
AT = Diagonal{T, Vector{T}}
32+
TestSuite.test_orthnull(AT, m; test_orthnull = false)
3333
end
3434
end

test/testsuite/TestSuite.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module TestSuite
1111
using Test
1212
using MatrixAlgebraKit
1313
using MatrixAlgebraKit: diagview
14-
using LinearAlgebra: Diagonal, norm, istriu, istril
14+
using LinearAlgebra: Diagonal, norm, istriu, istril, I
1515
using Random, StableRNGs
1616
using AMDGPU, CUDA
1717

@@ -69,6 +69,13 @@ is_positive(alg::MatrixAlgebraKit.ROCSOLVER_HouseholderQR) = alg.positive
6969
is_positive(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_positive(alg.qr_alg)
7070
is_pivoted(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_pivoted(alg.qr_alg)
7171

72+
isleftcomplete(V, N) = V * V' + N * N' I
73+
isleftcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isleftcomplete(collect(V), collect(N))
74+
isleftcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isleftcomplete(collect(V), collect(N))
75+
isrightcomplete(Vᴴ, Nᴴ) = Vᴴ' * Vᴴ + Nᴴ' * Nᴴ I
76+
isrightcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isrightcomplete(collect(V), collect(N))
77+
isrightcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isrightcomplete(collect(V), collect(N))
78+
7279
include("qr.jl")
7380
include("lq.jl")
7481
include("polar.jl")

0 commit comments

Comments
 (0)