Skip to content

Commit d698dfd

Browse files
committed
add finitedifferences support
1 parent 085a947 commit d698dfd

2 files changed

Lines changed: 20 additions & 24 deletions

File tree

ext/TensorKitFiniteDifferencesExt.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module TensorKitFiniteDifferencesExt
22

33
using TensorKit
4-
using TensorKit: sqrtdim, invsqrtdim
4+
using TensorKit: sqrtdim, invsqrtdim, SectorVector
55
using VectorInterface: scale!
66
using FiniteDifferences
77

@@ -31,6 +31,25 @@ function FiniteDifferences.to_vec(t::DiagonalTensorMap)
3131
return x_vec, DiagonalTensorMap_from_vec
3232
end
3333

34+
function FiniteDifferences.to_vec(v::SectorVector{T, <:Sector}) where {T}
35+
v_normalized = similar(v)
36+
for (c, b) in pairs(v)
37+
scale!(v_normalized[c], b, sqrtdim(c))
38+
end
39+
vec = parent(v_normalized)
40+
vec_real = T <: Real ? vec : collect(reinterpret(real(T), vec))
41+
42+
function from_vec(x_real)
43+
x = T <: Real ? x_real : reinterpret(T, x_real)
44+
v_result = SectorVector(x, v.structure)
45+
for (c, b) in pairs(v_result)
46+
scale!(b, invsqrtdim(c))
47+
end
48+
return v_result
49+
end
50+
return vec_real, from_vec
51+
end
52+
3453
end
3554

3655
# TODO: Investigate why the approach below doesn't work

test/autodiff/ad.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,29 +36,6 @@ function ChainRulesTestUtils.test_approx(
3636
return nothing
3737
end
3838

39-
# make sure that norms are computed correctly:
40-
function FiniteDifferences.to_vec(t::SectorDict)
41-
T = scalartype(valtype(t))
42-
vec = mapreduce(vcat, t; init = T[]) do (c, b)
43-
return reshape(b, :) .* sqrt(dim(c))
44-
end
45-
vec_real = T <: Real ? vec : collect(reinterpret(real(T), vec))
46-
47-
function from_vec(x_real)
48-
x = T <: Real ? x_real : reinterpret(T, x_real)
49-
ctr = 0
50-
return SectorDict(
51-
c => (
52-
n = length(b);
53-
b′ = reshape(view(x, ctr .+ (1:n)), size(b)) ./ sqrt(dim(c));
54-
ctr += n;
55-
b′
56-
) for (c, b) in t
57-
)
58-
end
59-
return vec_real, from_vec
60-
end
61-
6239
# Float32 and finite differences don't mix well
6340
precision(::Type{<:Union{Float32, Complex{Float32}}}) = 1.0e-2
6441
precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-5

0 commit comments

Comments
 (0)