@@ -6,21 +6,31 @@ export eigen,
66 eigvals!!,
77 factorize,
88 factorize!!,
9+ gram_eigh_full,
10+ gram_eigh_full!!,
11+ gram_eigh_full_with_pinv,
12+ gram_eigh_full_with_pinv!!,
13+ invsqrt_diag_safe,
14+ invsqrth_safe,
915 lq,
1016 lq!!,
1117 orth,
1218 orth!!,
1319 polar,
1420 polar!!,
21+ pow_diag_safe,
22+ powh_safe,
1523 qr,
1624 qr!!,
25+ sqrt_diag_safe,
26+ sqrth_safe,
1727 svd,
1828 svd!!,
1929 svdvals,
2030 svdvals!!
2131
2232import MatrixAlgebraKit as MAK
23- using LinearAlgebra: LinearAlgebra, norm
33+ using LinearAlgebra: LinearAlgebra, Diagonal, isdiag, norm
2434
2535for (f, f_full, f_compact) in (
2636 (:qr , :qr_full , :qr_compact ),
@@ -74,6 +84,209 @@ for (eigvals, eigh_vals, eig_vals) in
7484 end
7585end
7686
87+ function _clamp_kwargs_doc (arg:: AbstractString )
88+ return join (
89+ (
90+ " - `atol::Real`: absolute clamping threshold. Default `0`." ,
91+ " - `rtol::Real`: relative clamping threshold. Default `eps(real(eltype($arg )))^(3//4)` when `atol = 0`, else `0`." ,
92+ ), " \n "
93+ )
94+ end
95+
96+ """
97+ pow_diag_safe(D::AbstractMatrix, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p
98+
99+ Raise a diagonal-structured matrix `D` to the power `p`. Diagonal entries
100+ `d` of `MAK.diagview(D)` with `abs(d) < tol` are clamped to zero before
101+ exponentiation, where `tol = max(atol, rtol * maximum(abs, diagview(D)))`.
102+ Negative `d` above `tol` cause `d^p` to error for fractional `p` (e.g.
103+ `p = 1//2`) and pass through for integer `p`, so the operation itself
104+ enforces the PSD precondition per-power. Errors if `isdiag(D)` is `false`.
105+
106+ The implementation extracts entries via `MAK.diagview` and rebuilds via
107+ `MAK.diagonal`, so types extending those (e.g. graded or block diagonal)
108+ automatically extend [`sqrt_diag_safe`](@ref), [`invsqrt_diag_safe`](@ref),
109+ and the [`powh_safe`](@ref) family.
110+
111+ ## Keyword arguments
112+
113+ $(_clamp_kwargs_doc (" D" ))
114+ """
115+ function pow_diag_safe (
116+ D:: AbstractMatrix , p;
117+ atol = zero (real (eltype (D))),
118+ rtol = iszero (atol) ? eps (real (eltype (D)))^ (3 // 4 ) :
119+ zero (real (eltype (D)))
120+ )
121+ isdiag (D) || throw (
122+ ArgumentError (" pow_diag_safe requires a diagonal-structured matrix" )
123+ )
124+ σ = MAK. diagview (D)
125+ tol = max (atol, rtol * maximum (abs, σ; init = zero (real (eltype (D)))))
126+ return MAK. diagonal (map (d -> abs (d) < tol ? zero (d) : real (d)^ p, σ))
127+ end
128+
129+ """
130+ sqrt_diag_safe(D::AbstractMatrix; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(1//2)
131+
132+ Square root of a diagonal-structured matrix `D`, equivalent to
133+ `pow_diag_safe(D, 1//2; atol, rtol)`.
134+
135+ ## Keyword arguments
136+
137+ $(_clamp_kwargs_doc (" D" ))
138+ """
139+ sqrt_diag_safe (D:: AbstractMatrix ; kwargs... ) = pow_diag_safe (D, 1 // 2 ; kwargs... )
140+
141+ """
142+ invsqrt_diag_safe(D::AbstractMatrix; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(-1//2)
143+
144+ Inverse square root of a diagonal-structured matrix `D`, treating diagonal
145+ entries below tolerance as zero (Moore-Penrose convention). Equivalent to
146+ `pow_diag_safe(D, -1//2; atol, rtol)`.
147+
148+ ## Keyword arguments
149+
150+ $(_clamp_kwargs_doc (" D" ))
151+ """
152+ invsqrt_diag_safe (D:: AbstractMatrix ; kwargs... ) = pow_diag_safe (D, - 1 // 2 ; kwargs... )
153+
154+ """
155+ powh_safe(M::AbstractMatrix, p; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^p
156+
157+ Raise an approximately Hermitian positive semi-definite matrix to the
158+ power `p`. For diagonal-structured `M` (`isdiag(M) == true`), dispatches
159+ to [`pow_diag_safe`](@ref) and skips the eigendecomposition. Otherwise,
160+ computes via `M = V * D * V'` as `V * pow_diag_safe(D, p; atol, rtol) * V'`.
161+
162+ ## Keyword arguments
163+
164+ - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
165+
166+ $(_clamp_kwargs_doc (" M" ))
167+ """
168+ function powh_safe (M:: AbstractMatrix , p; alg = nothing , kwargs... )
169+ isdiag (M) && return pow_diag_safe (M, p; kwargs... )
170+ D, V = MAK. eigh_full (M, MAK. select_algorithm (MAK. eigh_full, M, alg))
171+ return V * pow_diag_safe (D, p; kwargs... ) * V'
172+ end
173+
174+ """
175+ sqrth_safe(M::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(1//2)
176+
177+ Square root of an approximately Hermitian positive semi-definite matrix.
178+ Equivalent to `powh_safe(M, 1//2; alg, atol, rtol)`.
179+
180+ ## Keyword arguments
181+
182+ - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
183+
184+ $(_clamp_kwargs_doc (" M" ))
185+ """
186+ sqrth_safe (M:: AbstractMatrix ; kwargs... ) = powh_safe (M, 1 // 2 ; kwargs... )
187+
188+ """
189+ invsqrth_safe(M::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(-1//2)
190+
191+ Inverse square root of an approximately Hermitian positive semi-definite
192+ matrix. Equivalent to `powh_safe(M, -1//2; alg, atol, rtol)`.
193+
194+ ## Keyword arguments
195+
196+ - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
197+
198+ $(_clamp_kwargs_doc (" M" ))
199+ """
200+ invsqrth_safe (M:: AbstractMatrix ; kwargs... ) = powh_safe (M, - 1 // 2 ; kwargs... )
201+
202+ for (gram, gram_with_pinv, eigh_full) in (
203+ (:gram_eigh_full , :gram_eigh_full_with_pinv , :eigh_full ),
204+ (:gram_eigh_full!! , :gram_eigh_full_with_pinv!! , :eigh_full! ),
205+ )
206+ @eval begin
207+ function $gram (A:: AbstractMatrix ; alg = nothing , kwargs... )
208+ D, V = MAK.$ eigh_full (A, MAK. select_algorithm (MAK.$ eigh_full, A, alg))
209+ return sqrth_safe (D; kwargs... ) * V'
210+ end
211+ function $gram_with_pinv (A:: AbstractMatrix ; alg = nothing , kwargs... )
212+ D, V = MAK.$ eigh_full (A, MAK. select_algorithm (MAK.$ eigh_full, A, alg))
213+ return sqrth_safe (D; kwargs... ) * V' , V * invsqrth_safe (D; kwargs... )
214+ end
215+ end
216+ end
217+
218+ """
219+ gram_eigh_full(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X
220+ gram_eigh_full!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X
221+
222+ Gram factorization of a Hermitian positive semi-definite matrix via its
223+ eigendecomposition: returns `X = sqrth_safe(D; atol, rtol) * V'` such
224+ that `A ≈ X' * X`, where `A = V * D * V'`. Eigenvalues below `tol` (see
225+ [`pow_diag_safe`](@ref)) are clamped to zero. The `!!` variant may
226+ destroy `A`.
227+
228+ ## Keyword arguments
229+
230+ - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
231+
232+ $(_clamp_kwargs_doc (" A" ))
233+
234+ # Examples
235+
236+ ```jldoctest
237+ julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full
238+
239+ julia> B = [1.0 0.5; 0.5 2.0];
240+
241+ julia> A = B' * B;
242+
243+ julia> X = gram_eigh_full(A);
244+
245+ julia> X' * X ≈ A
246+ true
247+ ```
248+
249+ See also [`gram_eigh_full_with_pinv`](@ref).
250+ """
251+ gram_eigh_full, gram_eigh_full!!
252+
253+ """
254+ gram_eigh_full_with_pinv(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y
255+ gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y
256+
257+ Like [`gram_eigh_full`](@ref), but additionally returns
258+ `Y = V * invsqrth_safe(D; atol, rtol) ≈ pinv(X)` so that `X * Y ≈ I` on
259+ the rank subspace. Eigenvalues below `tol` are clamped to zero in both
260+ factors. The `!!` variant may destroy `A`.
261+
262+ ## Keyword arguments
263+
264+ - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
265+
266+ $(_clamp_kwargs_doc (" A" ))
267+
268+ # Examples
269+
270+ ```jldoctest
271+ julia> using LinearAlgebra: I
272+
273+ julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full_with_pinv
274+
275+ julia> B = [1.0 0.5; 0.5 2.0];
276+
277+ julia> A = B' * B;
278+
279+ julia> X, Y = gram_eigh_full_with_pinv(A);
280+
281+ julia> X' * X ≈ A
282+ true
283+
284+ julia> X * Y ≈ I
285+ true
286+ ```
287+ """
288+ gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!!
289+
77290for (svd, svd_trunc, svd_full, svd_compact) in (
78291 (:svd , :svd_trunc , :svd_full , :svd_compact ),
79292 (:svd!! , :svd_trunc! , :svd_full! , :svd_compact! ),
0 commit comments