File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11module TensorKitFiniteDifferencesExt
22
33using TensorKit
4- using TensorKit: sqrtdim, invsqrtdim
4+ using TensorKit: sqrtdim, invsqrtdim, SectorVector
55using VectorInterface: scale!
66using FiniteDifferences
77
@@ -31,6 +31,25 @@ function FiniteDifferences.to_vec(t::DiagonalTensorMap)
3131 return x_vec, DiagonalTensorMap_from_vec
3232end
3333
34+ function FiniteDifferences. to_vec (v:: SectorVector{T, <:Sector} ) where {T}
35+ v_normalized = similar (v)
36+ for (c, b) in pairs (v)
37+ scale! (v_normalized[c], b, sqrtdim (c))
38+ end
39+ vec = parent (v_normalized)
40+ vec_real = T <: Real ? vec : collect (reinterpret (real (T), vec))
41+
42+ function from_vec (x_real)
43+ x = T <: Real ? x_real : reinterpret (T, x_real)
44+ v_result = SectorVector (x, v. structure)
45+ for (c, b) in pairs (v_result)
46+ scale! (b, invsqrtdim (c))
47+ end
48+ return v_result
49+ end
50+ return vec_real, from_vec
51+ end
52+
3453end
3554
3655# TODO : Investigate why the approach below doesn't work
Original file line number Diff line number Diff line change @@ -36,29 +36,6 @@ function ChainRulesTestUtils.test_approx(
3636 return nothing
3737end
3838
39- # make sure that norms are computed correctly:
40- function FiniteDifferences. to_vec (t:: SectorDict )
41- T = scalartype (valtype (t))
42- vec = mapreduce (vcat, t; init = T[]) do (c, b)
43- return reshape (b, :) .* sqrt (dim (c))
44- end
45- vec_real = T <: Real ? vec : collect (reinterpret (real (T), vec))
46-
47- function from_vec (x_real)
48- x = T <: Real ? x_real : reinterpret (T, x_real)
49- ctr = 0
50- return SectorDict (
51- c => (
52- n = length (b);
53- b′ = reshape (view (x, ctr .+ (1 : n)), size (b)) ./ sqrt (dim (c));
54- ctr += n;
55- b′
56- ) for (c, b) in t
57- )
58- end
59- return vec_real, from_vec
60- end
61-
6239# Float32 and finite differences don't mix well
6340precision (:: Type{<:Union{Float32, Complex{Float32}}} ) = 1.0e-2
6441precision (:: Type{<:Union{Float64, Complex{Float64}}} ) = 1.0e-5
You can’t perform that action at this time.
0 commit comments