Skip to content

Commit c934382

Browse files
mtfishmanclaude
andauthored
Add gram_eigh_full / gram_eigh_full_with_pinv factorizations (#174)
## Summary - Adds `gram_eigh_full(A) -> X` for Hermitian positive-semi-definite inputs (`A ≈ X' * X`, rank-first convention matching `LinearAlgebra.cholesky`), plus `gram_eigh_full_with_pinv(A) -> X, Y` that also returns `Y ≈ pinv(X)` so `X * Y ≈ I` on the rank subspace. - Matrix, tensor, and label-entry layers mirroring existing factorizations. - Exposes a small `MatrixAlgebra` helper family (`pow_diag_safe`, `powh_safe`, and `sqrt`/`invsqrt` variants) with `atol`/`rtol` clamping, used to build the gram factorizations and extensible to graded/block-diagonal types via `MAK.diagview`/`MAK.diagonal`. --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 5a0dfab commit c934382

7 files changed

Lines changed: 510 additions & 35 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
3-
version = "0.9.2"
3+
version = "0.9.3"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]

src/MatrixAlgebra.jl

Lines changed: 214 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,31 @@ export eigen,
66
eigvals!!,
77
factorize,
88
factorize!!,
9+
gram_eigh_full,
10+
gram_eigh_full!!,
11+
gram_eigh_full_with_pinv,
12+
gram_eigh_full_with_pinv!!,
13+
invsqrt_diag_safe,
14+
invsqrth_safe,
915
lq,
1016
lq!!,
1117
orth,
1218
orth!!,
1319
polar,
1420
polar!!,
21+
pow_diag_safe,
22+
powh_safe,
1523
qr,
1624
qr!!,
25+
sqrt_diag_safe,
26+
sqrth_safe,
1727
svd,
1828
svd!!,
1929
svdvals,
2030
svdvals!!
2131

2232
import MatrixAlgebraKit as MAK
23-
using LinearAlgebra: LinearAlgebra, norm
33+
using LinearAlgebra: LinearAlgebra, Diagonal, isdiag, norm
2434

2535
for (f, f_full, f_compact) in (
2636
(:qr, :qr_full, :qr_compact),
@@ -74,6 +84,209 @@ for (eigvals, eigh_vals, eig_vals) in
7484
end
7585
end
7686

87+
function _clamp_kwargs_doc(arg::AbstractString)
88+
return join(
89+
(
90+
" - `atol::Real`: absolute clamping threshold. Default `0`.",
91+
" - `rtol::Real`: relative clamping threshold. Default `eps(real(eltype($arg)))^(3//4)` when `atol = 0`, else `0`.",
92+
), "\n"
93+
)
94+
end
95+
96+
"""
97+
pow_diag_safe(D::AbstractMatrix, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p
98+
99+
Raise a diagonal-structured matrix `D` to the power `p`. Diagonal entries
100+
`d` of `MAK.diagview(D)` with `abs(d) < tol` are clamped to zero before
101+
exponentiation, where `tol = max(atol, rtol * maximum(abs, diagview(D)))`.
102+
Negative `d` above `tol` cause `d^p` to error for fractional `p` (e.g.
103+
`p = 1//2`) and pass through for integer `p`, so the operation itself
104+
enforces the PSD precondition per-power. Errors if `isdiag(D)` is `false`.
105+
106+
The implementation extracts entries via `MAK.diagview` and rebuilds via
107+
`MAK.diagonal`, so types extending those (e.g. graded or block diagonal)
108+
automatically extend [`sqrt_diag_safe`](@ref), [`invsqrt_diag_safe`](@ref),
109+
and the [`powh_safe`](@ref) family.
110+
111+
## Keyword arguments
112+
113+
$(_clamp_kwargs_doc("D"))
114+
"""
115+
function pow_diag_safe(
116+
D::AbstractMatrix, p;
117+
atol = zero(real(eltype(D))),
118+
rtol = iszero(atol) ? eps(real(eltype(D)))^(3 // 4) :
119+
zero(real(eltype(D)))
120+
)
121+
isdiag(D) || throw(
122+
ArgumentError("pow_diag_safe requires a diagonal-structured matrix")
123+
)
124+
σ = MAK.diagview(D)
125+
tol = max(atol, rtol * maximum(abs, σ; init = zero(real(eltype(D)))))
126+
return MAK.diagonal(map(d -> abs(d) < tol ? zero(d) : real(d)^p, σ))
127+
end
128+
129+
"""
130+
sqrt_diag_safe(D::AbstractMatrix; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(1//2)
131+
132+
Square root of a diagonal-structured matrix `D`, equivalent to
133+
`pow_diag_safe(D, 1//2; atol, rtol)`.
134+
135+
## Keyword arguments
136+
137+
$(_clamp_kwargs_doc("D"))
138+
"""
139+
sqrt_diag_safe(D::AbstractMatrix; kwargs...) = pow_diag_safe(D, 1 // 2; kwargs...)
140+
141+
"""
142+
invsqrt_diag_safe(D::AbstractMatrix; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(-1//2)
143+
144+
Inverse square root of a diagonal-structured matrix `D`, treating diagonal
145+
entries below tolerance as zero (Moore-Penrose convention). Equivalent to
146+
`pow_diag_safe(D, -1//2; atol, rtol)`.
147+
148+
## Keyword arguments
149+
150+
$(_clamp_kwargs_doc("D"))
151+
"""
152+
invsqrt_diag_safe(D::AbstractMatrix; kwargs...) = pow_diag_safe(D, -1 // 2; kwargs...)
153+
154+
"""
155+
powh_safe(M::AbstractMatrix, p; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^p
156+
157+
Raise an approximately Hermitian positive semi-definite matrix to the
158+
power `p`. For diagonal-structured `M` (`isdiag(M) == true`), dispatches
159+
to [`pow_diag_safe`](@ref) and skips the eigendecomposition. Otherwise,
160+
computes via `M = V * D * V'` as `V * pow_diag_safe(D, p; atol, rtol) * V'`.
161+
162+
## Keyword arguments
163+
164+
- `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
165+
166+
$(_clamp_kwargs_doc("M"))
167+
"""
168+
function powh_safe(M::AbstractMatrix, p; alg = nothing, kwargs...)
169+
isdiag(M) && return pow_diag_safe(M, p; kwargs...)
170+
D, V = MAK.eigh_full(M, MAK.select_algorithm(MAK.eigh_full, M, alg))
171+
return V * pow_diag_safe(D, p; kwargs...) * V'
172+
end
173+
174+
"""
175+
sqrth_safe(M::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(1//2)
176+
177+
Square root of an approximately Hermitian positive semi-definite matrix.
178+
Equivalent to `powh_safe(M, 1//2; alg, atol, rtol)`.
179+
180+
## Keyword arguments
181+
182+
- `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
183+
184+
$(_clamp_kwargs_doc("M"))
185+
"""
186+
sqrth_safe(M::AbstractMatrix; kwargs...) = powh_safe(M, 1 // 2; kwargs...)
187+
188+
"""
189+
invsqrth_safe(M::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(-1//2)
190+
191+
Inverse square root of an approximately Hermitian positive semi-definite
192+
matrix. Equivalent to `powh_safe(M, -1//2; alg, atol, rtol)`.
193+
194+
## Keyword arguments
195+
196+
- `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
197+
198+
$(_clamp_kwargs_doc("M"))
199+
"""
200+
invsqrth_safe(M::AbstractMatrix; kwargs...) = powh_safe(M, -1 // 2; kwargs...)
201+
202+
for (gram, gram_with_pinv, eigh_full) in (
203+
(:gram_eigh_full, :gram_eigh_full_with_pinv, :eigh_full),
204+
(:gram_eigh_full!!, :gram_eigh_full_with_pinv!!, :eigh_full!),
205+
)
206+
@eval begin
207+
function $gram(A::AbstractMatrix; alg = nothing, kwargs...)
208+
D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg))
209+
return sqrth_safe(D; kwargs...) * V'
210+
end
211+
function $gram_with_pinv(A::AbstractMatrix; alg = nothing, kwargs...)
212+
D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg))
213+
return sqrth_safe(D; kwargs...) * V', V * invsqrth_safe(D; kwargs...)
214+
end
215+
end
216+
end
217+
218+
"""
219+
gram_eigh_full(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X
220+
gram_eigh_full!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X
221+
222+
Gram factorization of a Hermitian positive semi-definite matrix via its
223+
eigendecomposition: returns `X = sqrth_safe(D; atol, rtol) * V'` such
224+
that `A ≈ X' * X`, where `A = V * D * V'`. Eigenvalues below `tol` (see
225+
[`pow_diag_safe`](@ref)) are clamped to zero. The `!!` variant may
226+
destroy `A`.
227+
228+
## Keyword arguments
229+
230+
- `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
231+
232+
$(_clamp_kwargs_doc("A"))
233+
234+
# Examples
235+
236+
```jldoctest
237+
julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full
238+
239+
julia> B = [1.0 0.5; 0.5 2.0];
240+
241+
julia> A = B' * B;
242+
243+
julia> X = gram_eigh_full(A);
244+
245+
julia> X' * X ≈ A
246+
true
247+
```
248+
249+
See also [`gram_eigh_full_with_pinv`](@ref).
250+
"""
251+
gram_eigh_full, gram_eigh_full!!
252+
253+
"""
254+
gram_eigh_full_with_pinv(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y
255+
gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y
256+
257+
Like [`gram_eigh_full`](@ref), but additionally returns
258+
`Y = V * invsqrth_safe(D; atol, rtol) ≈ pinv(X)` so that `X * Y ≈ I` on
259+
the rank subspace. Eigenvalues below `tol` are clamped to zero in both
260+
factors. The `!!` variant may destroy `A`.
261+
262+
## Keyword arguments
263+
264+
- `alg`: forwarded to `MatrixAlgebraKit.eigh_full`.
265+
266+
$(_clamp_kwargs_doc("A"))
267+
268+
# Examples
269+
270+
```jldoctest
271+
julia> using LinearAlgebra: I
272+
273+
julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full_with_pinv
274+
275+
julia> B = [1.0 0.5; 0.5 2.0];
276+
277+
julia> A = B' * B;
278+
279+
julia> X, Y = gram_eigh_full_with_pinv(A);
280+
281+
julia> X' * X ≈ A
282+
true
283+
284+
julia> X * Y ≈ I
285+
true
286+
```
287+
"""
288+
gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!!
289+
77290
for (svd, svd_trunc, svd_full, svd_compact) in (
78291
(:svd, :svd_trunc, :svd_full, :svd_compact),
79292
(:svd!!, :svd_trunc!, :svd_full!, :svd_compact!),

src/TensorAlgebra.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module TensorAlgebra
22

3-
export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar,
4-
lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals
3+
export contract, contract!, eigen, eigvals, factorize, gram_eigh_full,
4+
gram_eigh_full_with_pinv, left_null, left_orth, left_polar, lq, qr,
5+
right_null, right_orth, right_polar, orth, polar, svd, svdvals
56

67
if VERSION >= v"1.11.0-DEV.469"
78
eval(Meta.parse("public contractopadd!, matricizeop"))

0 commit comments

Comments
 (0)