1+ using MatrixAlgebraKit: svd_compact_pullback!
2+
13# Factorizations rules
24# --------------------
35function ChainRulesCore. rrule (:: typeof (TensorKit. tsvd!), t:: AbstractTensorMap ;
4- trunc:: TensorKit.TruncationScheme = TensorKit. NoTruncation (),
5- p:: Real = 2 ,
6+ trunc:: TensorKit.TruncationScheme = TensorKit. notrunc (),
67 alg:: Union{TensorKit.SVD,TensorKit.SDD} = TensorKit. SDD ())
7- U, Σ, V⁺, truncerr = tsvd (t; trunc= TensorKit. NoTruncation (), p = p, alg = alg)
8+ U, Σ, V⁺, truncerr = tsvd (t; trunc= TensorKit. notrunc (), alg)
89
9- if ! (trunc isa TensorKit. NoTruncation ) && ! isempty (blocksectors (t))
10+ if ! (trunc == TensorKit. notrunc () ) && ! isempty (blocksectors (t))
1011 Σdata = TensorKit. SectorDict (c => diag (b) for (c, b) in blocks (Σ))
1112
12- truncdim = TensorKit. _compute_truncdim (Σdata, trunc, p )
13- truncerr = TensorKit. _compute_truncerr (Σdata, truncdim, p )
13+ truncdim = TensorKit. _compute_truncdim (Σdata, trunc; p = 2 )
14+ truncerr = TensorKit. _compute_truncerr (Σdata, truncdim; p = 2 )
1415
1516 SVDdata = TensorKit. SectorDict (c => (block (U, c), Σc, block (V⁺, c))
1617 for (c, Σc) in Σdata)
@@ -23,12 +24,11 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
2324 function tsvd!_pullback (ΔUSVϵ)
2425 ΔU, ΔΣ, ΔV⁺, = unthunk .(ΔUSVϵ)
2526 Δt = similar (t)
26- for (c, b) in blocks (Δt)
27- Uc, Σc, V⁺c = block (U, c), block (Σ, c), block (V⁺, c)
28- ΔUc, ΔΣc, ΔV⁺c = block (ΔU, c), block (ΔΣ, c), block (ΔV⁺, c)
29- Σdc = view (Σc, diagind (Σc))
30- ΔΣdc = (ΔΣc isa AbstractZero) ? ΔΣc : view (ΔΣc, diagind (ΔΣc))
31- svd_pullback! (b, Uc, Σdc, V⁺c, ΔUc, ΔΣdc, ΔV⁺c)
27+ foreachblock (Δt) do (c, b)
28+ USVᴴc = block (U, c), block (Σ, c), block (V⁺, c)
29+ ΔUSVᴴc = block (ΔU, c), block (ΔΣ, c), block (ΔV⁺, c)
30+ svd_compact_pullback! (b, USVᴴc, ΔUSVᴴc)
31+ return nothing
3232 end
3333 return NoTangent (), Δt
3434 end
@@ -187,122 +187,122 @@ end
187187# Other implementation considerations for GPU compatibility:
188188# no scalar indexing, lots of broadcasting and views
189189#
"""
    svd_pullback!(ΔA, U, S, Vd, ΔU, ΔS, ΔVd; tol=default_pullback_gaugetol(S))

Overwrite `ΔA` with the cotangent of the matrix whose singular value decomposition is
given by `U`, `S` and `Vd`, computed from the cotangents `ΔU`, `ΔS` and `ΔVd` of those
factors, and return `ΔA`.

# Arguments
- `ΔA::AbstractMatrix`: output buffer, filled through `mul!`.
- `U::AbstractMatrix`, `S::AbstractVector`, `Vd::AbstractMatrix`: factors with
  `size(U, 2) == size(Vd, 1) == length(S) == min(size(U, 1), size(Vd, 2))`. `S` is
  searched with `searchsortedlast(S, tol; rev=true)`, so it must be sorted in
  non-increasing order.
- `ΔU`, `ΔS`, `ΔVd`: cotangents, each either an `AbstractZero` or an array. The non-zero
  ones must agree on the number `p` of retained singular values (`size(ΔU, 2)`,
  `length(ΔS)`, `size(ΔVd, 1)`); `p` may be smaller than `length(S)`.

# Keywords
- `tol::Real`: threshold used to determine the numerical rank, to regularise the
  inverses computed by `safe_inv`, and to decide when to emit the gauge warning.

# Throws
- `DimensionMismatch` if the sizes of the factors or of the cotangents are inconsistent.

A warning is logged when the cotangents have components along gauge directions
(degenerate or numerically zero singular values) that exceed `tol`.
"""
function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
                       Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
                       tol::Real=default_pullback_gaugetol(S))

    # Basic size checks and determination
    m, n = size(U, 1), size(Vd, 2)
    size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
    p = -1 # sentinel: number of retained singular values not yet known
    if !(ΔU isa AbstractZero)
        m == size(ΔU, 1) || throw(DimensionMismatch())
        p = size(ΔU, 2)
    end
    if !(ΔVd isa AbstractZero)
        n == size(ΔVd, 2) || throw(DimensionMismatch())
        if p == -1
            p = size(ΔVd, 1)
        else
            p == size(ΔVd, 1) || throw(DimensionMismatch())
        end
    end
    if !(ΔS isa AbstractZero)
        if p == -1
            p = length(ΔS)
        else
            p == length(ΔS) || throw(DimensionMismatch())
        end
    end

    # All cotangents are zero, so `p` could not be determined and the result is zero.
    # Returning here avoids allocating work arrays with the sentinel dimension `-1`.
    if p == -1
        return fill!(ΔA, 0)
    end

    Up = view(U, :, 1:p)
    Vp = view(Vd, 1:p, :)'
    Sp = view(S, 1:p)

    # rank: number of leading singular values that are not smaller than `tol`
    r = searchsortedlast(S, tol; rev=true)

    # compute antihermitian part of projection of ΔU and ΔV onto U and V
    # also already subtract this projection from ΔU and ΔV
    if !(ΔU isa AbstractZero)
        UΔU = Up' * ΔU
        aUΔU = rmul!(UΔU - UΔU', 1 / 2)
        if m > p
            ΔU -= Up * UΔU # rebinds the local name; the caller's array is not mutated
        end
    else
        aUΔU = fill!(similar(U, (p, p)), 0)
    end
    if !(ΔVd isa AbstractZero)
        VΔV = Vp' * ΔVd'
        aVΔV = rmul!(VΔV - VΔV', 1 / 2)
        if n > p
            ΔVd -= VΔV' * Vp' # rebinds the local name; the caller's array is not mutated
        end
    else
        aVΔV = fill!(similar(Vd, (p, p)), 0)
    end

    # check whether cotangents arise from gauge-invariance objective function
    mask = abs.(Sp' .- Sp) .< tol
    Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
    if p > r
        rprange = (r + 1):p
        Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
        Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
    end
    Δgauge < tol ||
        @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

    UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
            (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
    # Use the same `AbstractZero` test as the size determination above, so that every
    # zero-tangent type skips the diagonal contribution.
    if !(ΔS isa AbstractZero)
        UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
        # in principle, ΔS is real, but maybe not if coming from an anyonic tensor
    end
    mul!(ΔA, Up, UdΔAV * Vp')

    if r > p # contribution from truncation
        Ur = view(U, :, (p + 1):r)
        Vr = view(Vd, (p + 1):r, :)'
        Sr = view(S, (p + 1):r)

        if !(ΔU isa AbstractZero)
            UrΔU = Ur' * ΔU
            if m > r
                ΔU -= Ur * UrΔU # subtract this part from ΔU
            end
        else
            UrΔU = fill!(similar(U, (r - p, p)), 0)
        end
        if !(ΔVd isa AbstractZero)
            VrΔV = Vr' * ΔVd'
            if n > r
                ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
            end
        else
            VrΔV = fill!(similar(Vd, (r - p, p)), 0)
        end

        X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+
                         (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
        Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .-
                         (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))

        # ΔA += Ur * X * Vp' + Up * Y' * Vr'
        mul!(ΔA, Ur, X * Vp', 1, 1)
        mul!(ΔA, Up * Y', Vr', 1, 1)
    end

    # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
    if m > max(r, p) && !(ΔU isa AbstractZero)
        # ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp'
        mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1)
    end
    # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
    if n > max(r, p) && !(ΔVd isa AbstractZero)
        # ΔA += U * (safe_inv.(Sp, tol) .* ΔVd)
        mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1)
    end
    return ΔA
end
190+ # function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
191+ # Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
192+ # tol::Real=default_pullback_gaugetol(S))
193+
194+ # # Basic size checks and determination
195+ # m, n = size(U, 1), size(Vd, 2)
196+ # size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
197+ # p = -1
198+ # if !(ΔU isa AbstractZero)
199+ # m == size(ΔU, 1) || throw(DimensionMismatch())
200+ # p = size(ΔU, 2)
201+ # end
202+ # if !(ΔVd isa AbstractZero)
203+ # n == size(ΔVd, 2) || throw(DimensionMismatch())
204+ # if p == -1
205+ # p = size(ΔVd, 1)
206+ # else
207+ # p == size(ΔVd, 1) || throw(DimensionMismatch())
208+ # end
209+ # end
210+ # if !(ΔS isa AbstractZero)
211+ # if p == -1
212+ # p = length(ΔS)
213+ # else
214+ # p == length(ΔS) || throw(DimensionMismatch())
215+ # end
216+ # end
217+ # Up = view(U, :, 1:p)
218+ # Vp = view(Vd, 1:p, :)'
219+ # Sp = view(S, 1:p)
220+
221+ # # rank
222+ # r = searchsortedlast(S, tol; rev=true)
223+
224+ # # compute antihermitian part of projection of ΔU and ΔV onto U and V
225+ # # also already subtract this projection from ΔU and ΔV
226+ # if !(ΔU isa AbstractZero)
227+ # UΔU = Up' * ΔU
228+ # aUΔU = rmul!(UΔU - UΔU', 1 / 2)
229+ # if m > p
230+ # ΔU -= Up * UΔU
231+ # end
232+ # else
233+ # aUΔU = fill!(similar(U, (p, p)), 0)
234+ # end
235+ # if !(ΔVd isa AbstractZero)
236+ # VΔV = Vp' * ΔVd'
237+ # aVΔV = rmul!(VΔV - VΔV', 1 / 2)
238+ # if n > p
239+ # ΔVd -= VΔV' * Vp'
240+ # end
241+ # else
242+ # aVΔV = fill!(similar(Vd, (p, p)), 0)
243+ # end
244+
245+ # # check whether cotangents arise from gauge-invariance objective function
246+ # mask = abs.(Sp' .- Sp) .< tol
247+ # Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
248+ # if p > r
249+ # rprange = (r + 1):p
250+ # Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
251+ # Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
252+ # end
253+ # Δgauge < tol ||
254+ # @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
255+
256+ # UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
257+ # (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
258+ # if !(ΔS isa ZeroTangent)
259+ # UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
260+ # # in principle, ΔS is real, but maybe not if coming from an anyonic tensor
261+ # end
262+ # mul!(ΔA, Up, UdΔAV * Vp')
263+
264+ # if r > p # contribution from truncation
265+ # Ur = view(U, :, (p + 1):r)
266+ # Vr = view(Vd, (p + 1):r, :)'
267+ # Sr = view(S, (p + 1):r)
268+
269+ # if !(ΔU isa AbstractZero)
270+ # UrΔU = Ur' * ΔU
271+ # if m > r
272+ # ΔU -= Ur * UrΔU # subtract this part from ΔU
273+ # end
274+ # else
275+ # UrΔU = fill!(similar(U, (r - p, p)), 0)
276+ # end
277+ # if !(ΔVd isa AbstractZero)
278+ # VrΔV = Vr' * ΔVd'
279+ # if n > r
280+ # ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
281+ # end
282+ # else
283+ # VrΔV = fill!(similar(Vd, (r - p, p)), 0)
284+ # end
285+
286+ # X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+
287+ # (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
288+ # Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .-
289+ # (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
290+
291+ # # ΔA += Ur * X * Vp' + Up * Y' * Vr'
292+ # mul!(ΔA, Ur, X * Vp', 1, 1)
293+ # mul!(ΔA, Up * Y', Vr', 1, 1)
294+ # end
295+
296+ # if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
297+ # # ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp'
298+ # mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1)
299+ # end
300+ # if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
301+ # # ΔA += U * (safe_inv.(Sp, tol) .* ΔVd)
302+ # mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1)
303+ # end
304+ # return ΔA
305+ # end
306306
307307function eig_pullback! (ΔA:: AbstractMatrix , D:: AbstractVector , V:: AbstractMatrix , ΔD, ΔV;
308308 tol:: Real = default_pullback_gaugetol (D))
0 commit comments