Skip to content

Commit f923471

Browse files
committed
Major overhaul
1 parent 4b0b658 commit f923471

19 files changed

Lines changed: 1906 additions & 2091 deletions

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 128 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
using MatrixAlgebraKit: svd_compact_pullback!
2+
13
# Factorizations rules
24
# --------------------
35
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
4-
trunc::TensorKit.TruncationScheme=TensorKit.NoTruncation(),
5-
p::Real=2,
6+
trunc::TensorKit.TruncationScheme=TensorKit.notrunc(),
67
alg::Union{TensorKit.SVD,TensorKit.SDD}=TensorKit.SDD())
7-
U, Σ, V⁺, truncerr = tsvd(t; trunc=TensorKit.NoTruncation(), p=p, alg=alg)
8+
U, Σ, V⁺, truncerr = tsvd(t; trunc=TensorKit.notrunc(), alg)
89

9-
if !(trunc isa TensorKit.NoTruncation) && !isempty(blocksectors(t))
10+
if !(trunc == TensorKit.notrunc()) && !isempty(blocksectors(t))
1011
Σdata = TensorKit.SectorDict(c => diag(b) for (c, b) in blocks(Σ))
1112

12-
truncdim = TensorKit._compute_truncdim(Σdata, trunc, p)
13-
truncerr = TensorKit._compute_truncerr(Σdata, truncdim, p)
13+
truncdim = TensorKit._compute_truncdim(Σdata, trunc; p=2)
14+
truncerr = TensorKit._compute_truncerr(Σdata, truncdim; p=2)
1415

1516
SVDdata = TensorKit.SectorDict(c => (block(U, c), Σc, block(V⁺, c))
1617
for (c, Σc) in Σdata)
@@ -23,12 +24,11 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
2324
function tsvd!_pullback(ΔUSVϵ)
2425
ΔU, ΔΣ, ΔV⁺, = unthunk.(ΔUSVϵ)
2526
Δt = similar(t)
26-
for (c, b) in blocks(Δt)
27-
Uc, Σc, V⁺c = block(U, c), block(Σ, c), block(V⁺, c)
28-
ΔUc, ΔΣc, ΔV⁺c = block(ΔU, c), block(ΔΣ, c), block(ΔV⁺, c)
29-
Σdc = view(Σc, diagind(Σc))
30-
ΔΣdc = (ΔΣc isa AbstractZero) ? ΔΣc : view(ΔΣc, diagind(ΔΣc))
31-
svd_pullback!(b, Uc, Σdc, V⁺c, ΔUc, ΔΣdc, ΔV⁺c)
27+
foreachblock(Δt) do (c, b)
28+
USVᴴc = block(U, c), block(Σ, c), block(V⁺, c)
29+
ΔUSVᴴc = block(ΔU, c), block(ΔΣ, c), block(ΔV⁺, c)
30+
svd_compact_pullback!(b, USVᴴc, ΔUSVᴴc)
31+
return nothing
3232
end
3333
return NoTangent(), Δt
3434
end
@@ -187,122 +187,122 @@ end
187187
# Other implementation considerations for GPU compatibility:
188188
# no scalar indexing, lots of broadcasting and views
189189
#
190-
function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
191-
Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
192-
tol::Real=default_pullback_gaugetol(S))
193-
194-
# Basic size checks and determination
195-
m, n = size(U, 1), size(Vd, 2)
196-
size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
197-
p = -1
198-
if !(ΔU isa AbstractZero)
199-
m == size(ΔU, 1) || throw(DimensionMismatch())
200-
p = size(ΔU, 2)
201-
end
202-
if !(ΔVd isa AbstractZero)
203-
n == size(ΔVd, 2) || throw(DimensionMismatch())
204-
if p == -1
205-
p = size(ΔVd, 1)
206-
else
207-
p == size(ΔVd, 1) || throw(DimensionMismatch())
208-
end
209-
end
210-
if !(ΔS isa AbstractZero)
211-
if p == -1
212-
p = length(ΔS)
213-
else
214-
p == length(ΔS) || throw(DimensionMismatch())
215-
end
216-
end
217-
Up = view(U, :, 1:p)
218-
Vp = view(Vd, 1:p, :)'
219-
Sp = view(S, 1:p)
220-
221-
# rank
222-
r = searchsortedlast(S, tol; rev=true)
223-
224-
# compute antihermitian part of projection of ΔU and ΔV onto U and V
225-
# also already subtract this projection from ΔU and ΔV
226-
if !(ΔU isa AbstractZero)
227-
UΔU = Up' * ΔU
228-
aUΔU = rmul!(UΔU - UΔU', 1 / 2)
229-
if m > p
230-
ΔU -= Up * UΔU
231-
end
232-
else
233-
aUΔU = fill!(similar(U, (p, p)), 0)
234-
end
235-
if !(ΔVd isa AbstractZero)
236-
VΔV = Vp' * ΔVd'
237-
aVΔV = rmul!(VΔV - VΔV', 1 / 2)
238-
if n > p
239-
ΔVd -= VΔV' * Vp'
240-
end
241-
else
242-
aVΔV = fill!(similar(Vd, (p, p)), 0)
243-
end
244-
245-
# check whether cotangents arise from gauge-invariance objective function
246-
mask = abs.(Sp' .- Sp) .< tol
247-
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
248-
if p > r
249-
rprange = (r + 1):p
250-
Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
251-
Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
252-
end
253-
Δgauge < tol ||
254-
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
255-
256-
UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
257-
(aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
258-
if !(ΔS isa ZeroTangent)
259-
UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
260-
# in principle, ΔS is real, but maybe not if coming from an anyonic tensor
261-
end
262-
mul!(ΔA, Up, UdΔAV * Vp')
263-
264-
if r > p # contribution from truncation
265-
Ur = view(U, :, (p + 1):r)
266-
Vr = view(Vd, (p + 1):r, :)'
267-
Sr = view(S, (p + 1):r)
268-
269-
if !(ΔU isa AbstractZero)
270-
UrΔU = Ur' * ΔU
271-
if m > r
272-
ΔU -= Ur * UrΔU # subtract this part from ΔU
273-
end
274-
else
275-
UrΔU = fill!(similar(U, (r - p, p)), 0)
276-
end
277-
if !(ΔVd isa AbstractZero)
278-
VrΔV = Vr' * ΔVd'
279-
if n > r
280-
ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
281-
end
282-
else
283-
VrΔV = fill!(similar(Vd, (r - p, p)), 0)
284-
end
285-
286-
X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+
287-
(UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
288-
Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .-
289-
(UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
290-
291-
# ΔA += Ur * X * Vp' + Up * Y' * Vr'
292-
mul!(ΔA, Ur, X * Vp', 1, 1)
293-
mul!(ΔA, Up * Y', Vr', 1, 1)
294-
end
295-
296-
if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
297-
# ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp'
298-
mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1)
299-
end
300-
if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
301-
# ΔA += U * (safe_inv.(Sp, tol) .* ΔVd)
302-
mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1)
303-
end
304-
return ΔA
305-
end
190+
# function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
191+
# Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
192+
# tol::Real=default_pullback_gaugetol(S))
193+
194+
# # Basic size checks and determination
195+
# m, n = size(U, 1), size(Vd, 2)
196+
# size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
197+
# p = -1
198+
# if !(ΔU isa AbstractZero)
199+
# m == size(ΔU, 1) || throw(DimensionMismatch())
200+
# p = size(ΔU, 2)
201+
# end
202+
# if !(ΔVd isa AbstractZero)
203+
# n == size(ΔVd, 2) || throw(DimensionMismatch())
204+
# if p == -1
205+
# p = size(ΔVd, 1)
206+
# else
207+
# p == size(ΔVd, 1) || throw(DimensionMismatch())
208+
# end
209+
# end
210+
# if !(ΔS isa AbstractZero)
211+
# if p == -1
212+
# p = length(ΔS)
213+
# else
214+
# p == length(ΔS) || throw(DimensionMismatch())
215+
# end
216+
# end
217+
# Up = view(U, :, 1:p)
218+
# Vp = view(Vd, 1:p, :)'
219+
# Sp = view(S, 1:p)
220+
221+
# # rank
222+
# r = searchsortedlast(S, tol; rev=true)
223+
224+
# # compute antihermitian part of projection of ΔU and ΔV onto U and V
225+
# # also already subtract this projection from ΔU and ΔV
226+
# if !(ΔU isa AbstractZero)
227+
# UΔU = Up' * ΔU
228+
# aUΔU = rmul!(UΔU - UΔU', 1 / 2)
229+
# if m > p
230+
# ΔU -= Up * UΔU
231+
# end
232+
# else
233+
# aUΔU = fill!(similar(U, (p, p)), 0)
234+
# end
235+
# if !(ΔVd isa AbstractZero)
236+
# VΔV = Vp' * ΔVd'
237+
# aVΔV = rmul!(VΔV - VΔV', 1 / 2)
238+
# if n > p
239+
# ΔVd -= VΔV' * Vp'
240+
# end
241+
# else
242+
# aVΔV = fill!(similar(Vd, (p, p)), 0)
243+
# end
244+
245+
# # check whether cotangents arise from gauge-invariance objective function
246+
# mask = abs.(Sp' .- Sp) .< tol
247+
# Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
248+
# if p > r
249+
# rprange = (r + 1):p
250+
# Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
251+
# Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
252+
# end
253+
# Δgauge < tol ||
254+
# @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
255+
256+
# UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
257+
# (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
258+
# if !(ΔS isa ZeroTangent)
259+
# UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
260+
# # in principle, ΔS is real, but maybe not if coming from an anyonic tensor
261+
# end
262+
# mul!(ΔA, Up, UdΔAV * Vp')
263+
264+
# if r > p # contribution from truncation
265+
# Ur = view(U, :, (p + 1):r)
266+
# Vr = view(Vd, (p + 1):r, :)'
267+
# Sr = view(S, (p + 1):r)
268+
269+
# if !(ΔU isa AbstractZero)
270+
# UrΔU = Ur' * ΔU
271+
# if m > r
272+
# ΔU -= Ur * UrΔU # subtract this part from ΔU
273+
# end
274+
# else
275+
# UrΔU = fill!(similar(U, (r - p, p)), 0)
276+
# end
277+
# if !(ΔVd isa AbstractZero)
278+
# VrΔV = Vr' * ΔVd'
279+
# if n > r
280+
# ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
281+
# end
282+
# else
283+
# VrΔV = fill!(similar(Vd, (r - p, p)), 0)
284+
# end
285+
286+
# X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+
287+
# (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
288+
# Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .-
289+
# (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
290+
291+
# # ΔA += Ur * X * Vp' + Up * Y' * Vr'
292+
# mul!(ΔA, Ur, X * Vp', 1, 1)
293+
# mul!(ΔA, Up * Y', Vr', 1, 1)
294+
# end
295+
296+
# if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
297+
# # ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp'
298+
# mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1)
299+
# end
300+
# if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
301+
# # ΔA += U * (safe_inv.(Sp, tol) .* ΔVd)
302+
# mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1)
303+
# end
304+
# return ΔA
305+
# end
306306

307307
function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
308308
tol::Real=default_pullback_gaugetol(D))

src/TensorKit.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ export TruncationScheme
3131
export SpaceMismatch, SectorMismatch, IndexError # error types
3232

3333
# general vector space methods
34-
export space, field, dual, dim, reduceddim, dims, fuse, flip, isdual, oplus,
34+
export space, field, dual, dim, reduceddim, dims, fuse, flip, isdual, oplus, ominus,
3535
insertleftunit, insertrightunit, removeunit
3636

3737
# partial order for vector spaces
@@ -47,7 +47,7 @@ export ZNSpace, SU2Irrep, U1Irrep, CU1Irrep
4747
# bendleft, bendright, foldleft, foldright, cycleclockwise, cycleanticlockwise
4848

4949
# some unicode
50-
export ⊕, ⊗, ×, ⊠, ℂ, ℝ, ℤ, ←, →, ≾, ≿, ≅, ≺, ≻
50+
export ⊕, ⊗, ⊖, ×, ⊠, ℂ, ℝ, ℤ, ←, →, ≾, ≿, ≅, ≺, ≻
5151
export ℤ₂, ℤ₃, ℤ₄, U₁, SU, SU₂, CU₁
5252
export fℤ₂, fU₁, fSU₂
5353
export ℤ₂Space, ℤ₃Space, ℤ₄Space, U₁Space, CU₁Space, SU₂Space
@@ -70,8 +70,8 @@ export inner, dot, norm, normalize, normalize!, tr
7070

7171
# factorizations
7272
export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby!
73-
export leftorth, rightorth, leftnull, rightnull,
74-
leftorth!, rightorth!, leftnull!, rightnull!,
73+
export leftorth, rightorth, leftnull, rightnull, leftpolar, rightpolar,
74+
leftorth!, rightorth!, leftnull!, rightnull!, leftpolar!, rightpolar!,
7575
tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!,
7676
isposdef, isposdef!, ishermitian, isisometry, sylvester, rank, cond
7777
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition,
@@ -200,11 +200,13 @@ include("tensors/tensoroperations.jl")
200200
include("tensors/treetransformers.jl")
201201
include("tensors/indexmanipulations.jl")
202202
include("tensors/diagonal.jl")
203-
include("tensors/truncation.jl")
204-
include("tensors/matrixalgebrakit.jl")
205-
include("tensors/factorizations.jl")
206203
include("tensors/braidingtensor.jl")
207204

205+
include("tensors/factorizations/factorizations.jl")
206+
using .Factorizations
207+
# include("tensors/factorizations/matrixalgebrakit.jl")
208+
# include("tensors/truncation.jl")
209+
208210
# # Planar macros and related functionality
209211
# #-----------------------------------------
210212
@nospecialize

src/auxiliary/deprecate.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
11
import Base: transpose
22

33
#! format: off
4-
Base.@deprecate(permute(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false),
5-
permute(t, (p1, p2); copy=copy))
6-
Base.@deprecate(transpose(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false),
7-
transpose(t, (p1, p2); copy=copy))
8-
Base.@deprecate(braid(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple, levels; copy::Bool=false),
9-
braid(t, (p1, p2), levels; copy=copy))
10-
11-
Base.@deprecate(tsvd(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
12-
tsvd(t, (p₁, p₂); kwargs...))
13-
Base.@deprecate(leftorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
14-
leftorth(t, (p₁, p₂); kwargs...))
15-
Base.@deprecate(rightorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
16-
rightorth(t, (p₁, p₂); kwargs...))
17-
Base.@deprecate(leftnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
18-
leftnull(t, (p₁, p₂); kwargs...))
19-
Base.@deprecate(rightnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
20-
rightnull(t, (p₁, p₂); kwargs...))
21-
Base.@deprecate(LinearAlgebra.eigen(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
22-
LinearAlgebra.eigen(t, (p₁, p₂); kwargs...), false)
23-
Base.@deprecate(eig(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
24-
eig(t, (p₁, p₂); kwargs...))
25-
Base.@deprecate(eigh(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
26-
eigh(t, (p₁, p₂); kwargs...))
4+
# Base.@deprecate(permute(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false),
5+
# permute(t, (p1, p2); copy=copy))
6+
# Base.@deprecate(transpose(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false),
7+
# transpose(t, (p1, p2); copy=copy))
8+
# Base.@deprecate(braid(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple, levels; copy::Bool=false),
9+
# braid(t, (p1, p2), levels; copy=copy))
10+
11+
# Base.@deprecate(tsvd(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
12+
# tsvd(t, (p₁, p₂); kwargs...))
13+
# Base.@deprecate(leftorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
14+
# leftorth(t, (p₁, p₂); kwargs...))
15+
# Base.@deprecate(rightorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
16+
# rightorth(t, (p₁, p₂); kwargs...))
17+
# Base.@deprecate(leftnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
18+
# leftnull(t, (p₁, p₂); kwargs...))
19+
# Base.@deprecate(rightnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
20+
# rightnull(t, (p₁, p₂); kwargs...))
21+
# Base.@deprecate(LinearAlgebra.eigen(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
22+
# LinearAlgebra.eigen(t, (p₁, p₂); kwargs...), false)
23+
# Base.@deprecate(eig(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
24+
# eig(t, (p₁, p₂); kwargs...))
25+
# Base.@deprecate(eigh(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...),
26+
# eigh(t, (p₁, p₂); kwargs...))
2727

2828
for f in (:rand, :randn, :zeros, :ones)
2929
@eval begin

src/spaces/vectorspaces.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,4 +406,3 @@ have the same value.
406406
function supremum(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace}
407407
return supremum(supremum(V₁, V₂), V₃...)
408408
end
409-

0 commit comments

Comments
 (0)