Skip to content

Commit 0a6dd13

Browse files
author
Katharine Hyatt
committed
Ensure const prop for gaugefix
1 parent 7d2befd commit 0a6dd13

2 files changed

Lines changed: 8 additions & 8 deletions

File tree

src/common/gauge.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function gaugefix!(V::AbstractMatrix)
77
return V
88
end
99

10-
function gaugefix!(::Val{:full}, U, S, Vᴴ, m::Int, n::Int)
10+
function gaugefix!(::typeof(svd_full!), U, S, Vᴴ, m::Int, n::Int)
1111
for j in 1:max(m, n)
1212
if j <= min(m, n)
1313
u = view(U, :, j)
@@ -28,7 +28,7 @@ function gaugefix!(::Val{:full}, U, S, Vᴴ, m::Int, n::Int)
2828
return (U, S, Vᴴ)
2929
end
3030

31-
function gaugefix!(::Val{:compact}, U, S, Vᴴ, m::Int, n::Int)
31+
function gaugefix!(::typeof(svd_compact!), U, S, Vᴴ, m::Int, n::Int)
3232
for j in 1:size(U, 2)
3333
u = view(U, :, j)
3434
v = view(Vᴴ, j, :)
@@ -39,7 +39,7 @@ function gaugefix!(::Val{:compact}, U, S, Vᴴ, m::Int, n::Int)
3939
return (U, S, Vᴴ)
4040
end
4141

42-
function gaugefix!(::Val{:trunc}, U, S, Vᴴ, m::Int, n::Int)
42+
function gaugefix!(::typeof(svd_trunc!), U, S, Vᴴ, m::Int, n::Int)
4343
for j in 1:min(m, n)
4444
u = view(U, :, j)
4545
v = view(Vᴴ, j, :)

src/implementations/svd.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
101101
S[i, 1] = zero(eltype(S))
102102
end
103103
# TODO: make this controllable using a `gaugefix` keyword argument
104-
gaugefix!(Val(:full), U, S, Vᴴ, m, n)
104+
gaugefix!(svd_full!, U, S, Vᴴ, m, n)
105105
return USVᴴ
106106
end
107107

@@ -126,7 +126,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
126126
throw(ArgumentError("Unsupported SVD algorithm"))
127127
end
128128
# TODO: make this controllable using a `gaugefix` keyword argument
129-
gaugefix!(Val(:compact), U, S, Vᴴ, size(A)...)
129+
gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...)
130130
return USVᴴ
131131
end
132132

@@ -227,7 +227,7 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgor
227227
diagview(S) .= view(S, 1:minmn, 1)
228228
view(S, 2:minmn, 1) .= zero(eltype(S))
229229
# TODO: make this controllable using a `gaugefix` keyword argument
230-
gaugefix!(Val(:full), U, S, Vᴴ, m, n)
230+
gaugefix!(svd_full!, U, S, Vᴴ, m, n)
231231
return USVᴴ
232232
end
233233

@@ -236,7 +236,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
236236
U, S, Vᴴ = USVᴴ
237237
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
238238
# TODO: make this controllable using a `gaugefix` keyword argument
239-
gaugefix!(Val(:trunc), U, S, Vᴴ, size(A)...)
239+
gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...)
240240
return truncate!(svd_trunc!, USVᴴ, alg.trunc)
241241
end
242242

@@ -255,7 +255,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
255255
throw(ArgumentError("Unsupported SVD algorithm"))
256256
end
257257
# TODO: make this controllable using a `gaugefix` keyword argument
258-
gaugefix!(Val(:compact), U, S, Vᴴ, size(A)...)
258+
gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...)
259259
return USVᴴ
260260
end
261261
_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x)))

0 commit comments

Comments
 (0)