Skip to content

Commit 71b8d7c

Browse files
mtfishmanclaude
andauthored
Balanced gram_eigh_full convention, add one and operator-construction primitives (#177)
## Summary - `gram_eigh_full` and `gram_eigh_full_with_pinv` flipped to the balanced `A ≈ X * X'` convention (was right-Gram `A ≈ X' * X`). Breaking, but lands as a patch bump since `gram_eigh_full` is new in v0.9.4 with no known downstream users. - `TensorAlgebra.one(A, codomain, domain)` for identity operator tensors. Not exported (clashes with `Base.one`). - `similar_map(prototype, [T,] codomain_axes, domain_axes)` for allocating linear-map-shaped arrays. `NamedDimsArrays.similar_operator` routes through this (ITensor/NamedDimsArrays.jl#229). - `projectto!` and `checked_projectto!` for projecting into a restricted subspace. - `project_map` and `checked_project_map` as `similar_map` + `projectto!` allocators. - `trivialrange(::Type{<:AbstractUnitRange})` for the identity range under `tensor_product_axis`. Defaults to `Base.OneTo(1)`, overloaded by downstream packages (for example a charge-0 graded range). - The `!!` (may-destroy-input) factorization variants in `MatrixAlgebra` are no longer exported and no longer appear in the `gram_eigh_full` docstrings, treating them as internal. They remain callable as qualified `MatrixAlgebra.svd!!`-style functions. Nothing outside TensorAlgebra used them. --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent 0895aa9 commit 71b8d7c

17 files changed

Lines changed: 338 additions & 75 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
3-
version = "0.9.4"
3+
version = "0.9.5"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]

src/MatrixAlgebra.jl

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,25 @@
11
module MatrixAlgebra
22

33
export eigen,
4-
eigen!!,
54
eigvals,
6-
eigvals!!,
75
factorize,
8-
factorize!!,
96
gram_eigh_full,
10-
gram_eigh_full!!,
117
gram_eigh_full_with_pinv,
12-
gram_eigh_full_with_pinv!!,
138
invsqrt_diag_safe,
149
invsqrth_safe,
1510
lq,
16-
lq!!,
1711
orth,
18-
orth!!,
1912
polar,
20-
polar!!,
2113
pow_diag_safe,
2214
powh_safe,
2315
qr,
24-
qr!!,
2516
sqrt_diag_safe,
2617
sqrth_safe,
2718
svd,
28-
svd!!,
29-
svdvals,
30-
svdvals!!
19+
svdvals
3120

32-
import MatrixAlgebraKit as MAK
3321
using LinearAlgebra: LinearAlgebra, Diagonal, isdiag, norm
22+
using MatrixAlgebraKit: MatrixAlgebraKit as MAK
3423

3524
for (f, f_full, f_compact) in (
3625
(:qr, :qr_full, :qr_compact),
@@ -206,24 +195,23 @@ for (gram, gram_with_pinv, eigh_full) in (
206195
@eval begin
207196
function $gram(A::AbstractMatrix; alg = nothing, kwargs...)
208197
D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg))
209-
return sqrth_safe(D; kwargs...) * V'
198+
return V * sqrth_safe(D; kwargs...)
210199
end
211200
function $gram_with_pinv(A::AbstractMatrix; alg = nothing, kwargs...)
212201
D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg))
213-
return sqrth_safe(D; kwargs...) * V', V * invsqrth_safe(D; kwargs...)
202+
return V * sqrth_safe(D; kwargs...), invsqrth_safe(D; kwargs...) * V'
214203
end
215204
end
216205
end
217206

218207
"""
219208
gram_eigh_full(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X
220-
gram_eigh_full!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X
221209
222210
Gram factorization of a Hermitian positive semi-definite matrix via its
223-
eigendecomposition: returns `X = sqrth_safe(D; atol, rtol) * V'` such
224-
that `A ≈ X' * X`, where `A = V * D * V'`. Eigenvalues below `tol` (see
225-
[`pow_diag_safe`](@ref)) are clamped to zero. The `!!` variant may
226-
destroy `A`.
211+
eigendecomposition (balanced eigh): returns `X = V * sqrth_safe(D; atol, rtol)`
212+
such that `A ≈ X * X'`, where `A = V * D * V'`. The square-root of `D` is
213+
absorbed symmetrically into the two factors of the eigendecomposition.
214+
Eigenvalues below `tol` (see [`pow_diag_safe`](@ref)) are clamped to zero.
227215
228216
## Keyword arguments
229217
@@ -242,22 +230,21 @@ julia> A = B' * B;
242230
243231
julia> X = gram_eigh_full(A);
244232
245-
julia> X' * X ≈ A
233+
julia> X * X' ≈ A
246234
true
247235
```
248236
249237
See also [`gram_eigh_full_with_pinv`](@ref).
250238
"""
251-
gram_eigh_full, gram_eigh_full!!
239+
gram_eigh_full
252240

253241
"""
254242
gram_eigh_full_with_pinv(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y
255-
gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y
256243
257244
Like [`gram_eigh_full`](@ref), but additionally returns
258-
`Y = V * invsqrth_safe(D; atol, rtol) ≈ pinv(X)` so that `X * Y ≈ I` on
259-
the rank subspace. Eigenvalues below `tol` are clamped to zero in both
260-
factors. The `!!` variant may destroy `A`.
245+
`Y = invsqrth_safe(D; atol, rtol) * V' ≈ pinv(X)`, a left inverse of `X`
246+
on the rank subspace: `Y * X ≈ I`. Eigenvalues below `tol` are clamped to
247+
zero in both factors.
261248
262249
## Keyword arguments
263250
@@ -278,14 +265,14 @@ julia> A = B' * B;
278265
279266
julia> X, Y = gram_eigh_full_with_pinv(A);
280267
281-
julia> X' * X ≈ A
268+
julia> X * X' ≈ A
282269
true
283270
284-
julia> X * Y ≈ I
271+
julia> Y * X ≈ I
285272
true
286273
```
287274
"""
288-
gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!!
275+
gram_eigh_full_with_pinv
289276

290277
for (svd, svd_trunc, svd_full, svd_compact) in (
291278
(:svd, :svd_trunc, :svd_full, :svd_compact),

src/TensorAlgebra.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ include("contract/allocate_output.jl")
2222
include("contract/contract_matricize.jl")
2323
include("factorizations.jl")
2424
include("matrixfunctions.jl")
25+
include("similar_map.jl")
26+
include("projectto.jl")
2527
include("linearbroadcasted.jl")
2628

2729
end

src/blockedtuple.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ function Base.map(f, bt::AbstractBlockTuple)
110110
return widened_constructorof(typeof(bt))(t, Val(blocklengths(bt)))
111111
end
112112

113+
function Base.invperm(bt::AbstractBlockTuple)
114+
return widened_constructorof(typeof(bt))(invperm(Tuple(bt)), Val(blocklengths(bt)))
115+
end
116+
113117
function Base.show(io::IO, bt::AbstractBlockTuple)
114118
return print(io, nameof(typeof(bt)), blocks(bt))
115119
end

src/factorizations.jl

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131
for f in (
3232
:qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth,
3333
:factorize, :eigen, :eigvals, :svd, :svdvals, :left_null, :right_null,
34-
:gram_eigh_full, :gram_eigh_full_with_pinv,
34+
:gram_eigh_full, :gram_eigh_full_with_pinv, :one,
3535
)
3636
@eval begin
3737
function $f(
@@ -312,8 +312,11 @@ function svd!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs
312312
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
313313
axes_codomain, axes_domain = blocks(axes(A)[biperm])
314314
axes_U = tuplemortar((axes_codomain, (axes(U, 2),)))
315+
axes_S = tuplemortar(((axes(S, 1),), (axes(S, 2),)))
315316
axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain))
316-
return unmatricize(style, U, axes_U), S, unmatricize(style, Vᴴ, axes_Vᴴ)
317+
return unmatricize(style, U, axes_U),
318+
unmatricize(style, S, axes_S),
319+
unmatricize(style, Vᴴ, axes_Vᴴ)
317320
end
318321
function svd!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
319322
return svd!!(FusionStyle(A), A, ndims_codomain; kwargs...)
@@ -443,7 +446,9 @@ end
443446
444447
Gram factorization of a generic N-dimensional array, interpreting it as a
445448
Hermitian positive semi-definite linear map from the domain to the codomain
446-
dimensions. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg).
449+
dimensions. Returns `X` such that `A ≈ X * X'` (contracted on the rank leg),
450+
i.e. the codomain axes of `X` match the codomain axes of `A` and `X` has a
451+
single trailing rank axis.
447452
448453
## Keyword arguments
449454
@@ -462,7 +467,7 @@ julia> A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d));
462467
463468
julia> X = gram_eigh_full(A, (:a, :b, :c, :d), (:a, :b), (:c, :d));
464469
465-
julia> A ≈ contract((:a, :b, :c, :d), conj(X), (:r, :a, :b), X, (:r, :c, :d))
470+
julia> A ≈ contract((:a, :b, :c, :d), X, (:a, :b, :r), conj(X), (:c, :d, :r))
466471
true
467472
```
468473
@@ -478,7 +483,7 @@ function gram_eigh_full!!(
478483
X = MatrixAlgebra.gram_eigh_full!!(A_mat; kwargs...)
479484
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
480485
axes_codomain = first(blocks(axes(A)[biperm]))
481-
axes_X = tuplemortar(((axes(X, 1),), axes_codomain))
486+
axes_X = tuplemortar((axes_codomain, (axes(X, 2),)))
482487
return unmatricize(style, X, axes_X)
483488
end
484489
function gram_eigh_full!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
@@ -501,7 +506,9 @@ end
501506
gram_eigh_full_with_pinv(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y
502507
503508
Like [`gram_eigh_full`](@ref), but additionally returns `Y ≈ pinv(X)` such
504-
that `X * Y ≈ I` on the rank subspace.
509+
that `Y * X ≈ I` on the rank subspace (a left inverse). The codomain axes
510+
of `X` match the codomain axes of `A`; `Y` has a leading rank axis followed
511+
by the codomain axes.
505512
506513
## Keyword arguments
507514
@@ -522,10 +529,10 @@ julia> A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d));
522529
523530
julia> X, Y = gram_eigh_full_with_pinv(A, (:a, :b, :c, :d), (:a, :b), (:c, :d));
524531
525-
julia> A ≈ contract((:a, :b, :c, :d), conj(X), (:r, :a, :b), X, (:r, :c, :d))
532+
julia> A ≈ contract((:a, :b, :c, :d), X, (:a, :b, :r), conj(X), (:c, :d, :r))
526533
true
527534
528-
julia> contract((:r, :s), X, (:r, :a, :b), Y, (:a, :b, :s)) ≈ I
535+
julia> contract((:r, :s), Y, (:r, :a, :b), X, (:a, :b, :s)) ≈ I
529536
true
530537
```
531538
@@ -540,8 +547,8 @@ function gram_eigh_full_with_pinv!!(
540547
X, Y = MatrixAlgebra.gram_eigh_full_with_pinv!!(A_mat; kwargs...)
541548
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
542549
axes_codomain = first(blocks(axes(A)[biperm]))
543-
axes_X = tuplemortar(((axes(X, 1),), axes_codomain))
544-
axes_Y = tuplemortar((axes_codomain, (axes(Y, 2),)))
550+
axes_X = tuplemortar((axes_codomain, (axes(X, 2),)))
551+
axes_Y = tuplemortar(((axes(Y, 1),), conj.(axes_codomain)))
545552
return unmatricize(style, X, axes_X), unmatricize(style, Y, axes_Y)
546553
end
547554
function gram_eigh_full_with_pinv!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
@@ -556,3 +563,55 @@ end
556563
function gram_eigh_full_with_pinv(A::AbstractArray, ndims_codomain::Val; kwargs...)
557564
return gram_eigh_full_with_pinv!!(copy(A), ndims_codomain; kwargs...)
558565
end
566+
567+
"""
568+
TensorAlgebra.one(A::AbstractArray, labels_A, labels_codomain, labels_domain) -> Id
569+
TensorAlgebra.one(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}) -> Id
570+
TensorAlgebra.one(A::AbstractArray, ndims_codomain::Val) -> Id
571+
TensorAlgebra.one(A::AbstractArray, biperm::AbstractBlockPermutation{2}) -> Id
572+
573+
Construct the identity operator tensor whose shape mirrors `A`, interpreted as a
574+
linear map from the domain to the codomain dimensions. The codomain and domain
575+
partition is specified either via labels or directly through a bi-permutation;
576+
fused codomain and domain sizes must match. `A` is treated as a shape prototype
577+
and is not mutated.
578+
579+
Not exported, since exporting would clash with the implicit `Base.one`. Qualify
580+
as `TensorAlgebra.one(A, ...)`.
581+
582+
See also `MatrixAlgebraKit.one!`.
583+
584+
# Examples
585+
586+
```jldoctest
587+
julia> using LinearAlgebra: I
588+
589+
julia> using TensorAlgebra: TensorAlgebra, matricize
590+
591+
julia> A = randn(2, 3, 2, 3);
592+
593+
julia> Id = TensorAlgebra.one(A, (:a, :b, :c, :d), (:a, :b), (:c, :d));
594+
595+
julia> matricize(Id, Val(2)) ≈ I
596+
true
597+
```
598+
"""
599+
one
600+
601+
function one!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
602+
A_mat = matricize(style, A, ndims_codomain)
603+
MatrixAlgebraKit.one!(A_mat)
604+
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
605+
axes_codomain, axes_domain = blocks(axes(A)[biperm])
606+
return unmatricize(style, A_mat, axes_codomain, axes_domain)
607+
end
608+
function one!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
609+
return one!!(FusionStyle(A), A, ndims_codomain; kwargs...)
610+
end
611+
612+
function one(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
613+
return one!!(style, copy(A), ndims_codomain; kwargs...)
614+
end
615+
function one(A::AbstractArray, ndims_codomain::Val; kwargs...)
616+
return one!!(copy(A), ndims_codomain; kwargs...)
617+
end

src/linearbroadcasted.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import Base.Broadcast as BC
2-
import LinearAlgebra as LA
1+
using Base.Broadcast: Broadcast as BC
2+
using LinearAlgebra: LinearAlgebra as LA
33

44
# TermInterface-like interface.
55
iscall(x) = false

src/matricize.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ function tensor_product_fusionstyle(r1::AbstractUnitRange, r2::AbstractUnitRange
7777
return FusionStyle(FusionStyle(r1), FusionStyle(r2))
7878
end
7979

80+
"""
81+
TensorAlgebra.trivialrange(R::Type{<:AbstractUnitRange})
82+
TensorAlgebra.trivialrange(r::AbstractUnitRange)
83+
84+
Return the identity range for `tensor_product_axis` on ranges of type `R`,
85+
i.e. a one-dimensional range `t` for which fusing `t` with any other range
86+
of the same family leaves that range unchanged. Defaults to `Base.OneTo(1)`;
87+
downstream packages overload the type-level method to return their own
88+
identity (for example, a charge-0 one-dimensional sector for a graded range).
89+
"""
90+
trivialrange(r::AbstractUnitRange) = trivialrange(typeof(r))
91+
trivialrange(::Type{<:AbstractUnitRange}) = Base.OneTo(1)
92+
8093
function fused_axis(
8194
style::FusionStyle, side::Val{:codomain}, a::AbstractArray,
8295
axes_codomain::Tuple{Vararg{AbstractUnitRange}},

src/permutedimsadd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import StridedViews as SV
21
using FunctionImplementations: permuteddims
32
using Strided: Strided
3+
using StridedViews: StridedViews as SV
44

55
# Specify if an array is on CPU. This is helpful for backends that don't support
66
# operations on GPU, such as Strided.jl.

src/projectto.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
projectto!(dest, src) -> dest
3+
4+
Project `src` into the restricted space of `dest` without checking which
5+
components may have been projected out. Defaults to `copyto!`. See
6+
[`checked_projectto!`](@ref) for a checked version.
7+
"""
8+
projectto!(dest, src) = copyto!(dest, src)
9+
10+
"""
11+
checked_projectto!(dest, src; kwargs...) -> dest
12+
13+
Project `src` into the restricted space of `dest` via [`projectto!`](@ref)
14+
and verify via `isapprox(src, dest; kwargs...)` that the discarded
15+
component is within tolerance. Keyword arguments are forwarded to
16+
`isapprox`. The default tolerances are subject to change in future
17+
versions.
18+
"""
19+
function checked_projectto!(dest, src; kwargs...)
20+
projectto!(dest, src)
21+
isapprox(src, dest; kwargs...) ||
22+
throw(InexactError(:checked_projectto!, typeof(dest), src))
23+
return dest
24+
end
25+
26+
"""
27+
project_map(raw, codomain_axes, domain_axes) -> dest
28+
29+
Allocate a map-shaped array via [`similar_map`](@ref) and project `raw`
30+
into it with [`projectto!`](@ref). See [`checked_project_map`](@ref) for
31+
a checked version.
32+
"""
33+
function project_map(raw, codomain_axes, domain_axes)
34+
return projectto!(similar_map(raw, codomain_axes, domain_axes), raw)
35+
end
36+
37+
"""
38+
checked_project_map(raw, codomain_axes, domain_axes; kwargs...) -> dest
39+
40+
Allocate a map-shaped array via [`similar_map`](@ref) and project `raw`
41+
into it with [`checked_projectto!`](@ref). Keyword arguments are forwarded
42+
to [`checked_projectto!`](@ref).
43+
"""
44+
function checked_project_map(raw, codomain_axes, domain_axes; kwargs...)
45+
return checked_projectto!(
46+
similar_map(raw, codomain_axes, domain_axes), raw; kwargs...
47+
)
48+
end

src/similar_map.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""
2+
similar_map(prototype, [T,] codomain_axes, domain_axes) -> M
3+
4+
Allocate an array shaped as a linear map from `domain_axes` to
5+
`codomain_axes` with element type `T` (defaulting to `eltype(prototype)`),
6+
using `prototype` to determine the array backend. Defaults to
7+
`similar(prototype, T, (codomain_axes..., conj.(domain_axes)...))`.
8+
9+
# Examples
10+
11+
```jldoctest
12+
julia> using TensorAlgebra: similar_map
13+
14+
julia> cod, dom = (Base.OneTo(2), Base.OneTo(3)), (Base.OneTo(4), Base.OneTo(5));
15+
16+
julia> M = similar_map(randn(3), Float32, cod, dom);
17+
18+
julia> eltype(M), size(M)
19+
(Float32, (2, 3, 4, 5))
20+
```
21+
"""
22+
function similar_map(prototype, ::Type{T}, codomain_axes, domain_axes) where {T}
23+
return similar(prototype, T, (codomain_axes..., conj.(domain_axes)...))
24+
end
25+
function similar_map(prototype, codomain_axes, domain_axes)
26+
return similar_map(prototype, eltype(prototype), codomain_axes, domain_axes)
27+
end

0 commit comments

Comments
 (0)