Skip to content

Commit 0fa5c49

Browse files
authored
Reduce cache footprint by decoupling degeneracy-dependent data (#387)
1 parent fd8dd7f commit 0fa5c49

File tree

18 files changed

+554
-330
lines changed

18 files changed

+554
-330
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "TensorKit"
22
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
3-
authors = ["Jutho Haegeman, Lukas Devos"]
43
version = "0.16.3"
4+
authors = ["Jutho Haegeman, Lukas Devos"]
55

66
[deps]
7+
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
78
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
@@ -41,6 +42,7 @@ CUDA = "5.9"
4142
ChainRulesCore = "1"
4243
ChainRulesTestUtils = "1"
4344
Combinatorics = "1"
45+
Dictionaries = "0.4"
4446
FiniteDifferences = "0.12"
4547
GPUArrays = "11.3.1"
4648
JET = "0.9, 0.10, 0.11"

docs/src/lib/tensors.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ In `TensorMap` instances, all data is gathered in a single `AbstractVector`, whi
9797

9898
To obtain information about the structure of the data, you can use:
9999
```@docs
100-
fusionblockstructure(::AbstractTensorMap)
101100
dim(::AbstractTensorMap)
102101
blocksectors(::AbstractTensorMap)
103102
hasblock(::AbstractTensorMap, ::Sector)

ext/TensorKitMooncakeExt/utility.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ end
6262
Mooncake.tangent_type(::Type{<:VectorSpace}) = Mooncake.NoTangent
6363
Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent
6464

65-
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any}
65+
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.sectorstructure), Any}
66+
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.degeneracystructure), Any}
6667

6768
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple}
6869
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any}

src/TensorKit.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ const TO = TensorOperations
118118

119119
using MatrixAlgebraKit
120120

121+
using Dictionaries: Dictionaries, Dictionary, Indices, gettoken, gettokenvalue
121122
using LRUCache
122123
using OhMyThreads
123124
using ScopedValues
@@ -200,6 +201,23 @@ include("fusiontrees/fusiontrees.jl")
200201
#-------------------------------------------
201202
include("spaces/vectorspaces.jl")
202203

204+
# ElementarySpace types
205+
include("spaces/cartesianspace.jl")
206+
include("spaces/complexspace.jl")
207+
include("spaces/generalspace.jl")
208+
include("spaces/gradedspace.jl")
209+
include("spaces/planarspace.jl")
210+
211+
# CompositeSpace types
212+
include("spaces/productspace.jl")
213+
include("spaces/deligne.jl")
214+
215+
# HomSpace
216+
include("spaces/homspace.jl")
217+
218+
# Derived information
219+
include("spaces/structure.jl")
220+
203221
# Multithreading settings
204222
#-------------------------
205223
const TRANSFORMER_THREADS = Ref(1)

src/auxiliary/dicts.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,25 @@ function Base.:(==)(d1::SortedVectorDict, d2::SortedVectorDict)
263263
end
264264
return true
265265
end
266+
267+
"""
268+
Hashed(value, hashfunction = Base.hash, isequal = Base.isequal)
269+
270+
Wrapper struct to alter the `hash` and `isequal` implementations of a given value.
271+
This is useful in the contexts of dictionaries, where you either want to customize the hashfunction,
272+
or consider various values as equal with a different notion of equality.
273+
"""
274+
struct Hashed{T, H <: Function, E <: Function}
275+
val::T
276+
hashf::H
277+
eqf::E
278+
end
279+
280+
Hashed(val, hashf = Base.hash, eqf = Base.isequal) =
281+
Hashed{typeof(val), typeof(hashf), typeof(eqf)}(val, hashf, eqf)
282+
283+
Base.parent(h::Hashed) = h.val
284+
Base.hash(h::Hashed, seed::UInt) = h.hashf(parent(h), seed)
285+
# Note: requires the equality functions to be equal to avoid asymmetric results
286+
Base.isequal(h1::Hashed{<:Any, <:Any, E}, h2::Hashed{<:Any, <:Any, E}) where {E} =
287+
h1.eqf(parent(h1), parent(h2))

src/fusiontrees/braiding_manipulations.jl

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -299,44 +299,44 @@ Base.@assume_effects :foldable function _fsdicttype(::Type{T}) where {I, N₁, N
299299
return Pair{FusionTreeBlock{I, N₁, N₂, Tuple{F₁, F₂}}, Matrix{E}}
300300
end
301301

302-
@cached function fsbraid(key::K)::_fsdicttype(K) where {I, N₁, N₂, K <: FSPBraidKey{I, N₁, N₂}}
303-
((f₁, f₂), (p1, p2), (l1, l2)) = key
304-
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
305-
levels = (l1..., reverse(l2)...)
306-
(f, f0), coeff1 = repartition((f₁, f₂), N₁ + N₂)
307-
f′, coeff2 = braid(f, p, levels)
308-
(f₁′, f₂′), coeff3 = repartition((f′, f0), N₁)
309-
return (f₁′, f₂′) => coeff1 * coeff2 * coeff3
310-
end
311-
@cached function fsbraid(key::K)::_fsdicttype(K) where {I, N₁, N₂, K <: FSBBraidKey{I, N₁, N₂}}
312-
src, (p1, p2), (l1, l2) = key
302+
@cached function fsbraid(key::K)::_fsdicttype(K) where {I, N₁, N₂, K <: Union{FSPBraidKey{I, N₁, N₂}, FSBBraidKey{I, N₁, N₂}}}
303+
if K <: FSPBraidKey
304+
((f₁, f₂), (p1, p2), (l1, l2)) = key
305+
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
306+
levels = (l1..., reverse(l2)...)
307+
(f, f0), coeff1 = repartition((f₁, f₂), N₁ + N₂)
308+
f′, coeff2 = braid(f, p, levels)
309+
(f₁′, f₂′), coeff3 = repartition((f′, f0), N₁)
310+
return (f₁′, f₂′) => coeff1 * coeff2 * coeff3
313311

314-
p = linearizepermutation(p1, p2, numout(src), numin(src))
315-
levels = (l1..., reverse(l2)...)
312+
else
313+
src, (p1, p2), (l1, l2) = key
316314

317-
dst, U = repartition(src, numind(src))
315+
p = linearizepermutation(p1, p2, numout(src), numin(src))
316+
levels = (l1..., reverse(l2)...)
318317

319-
for s in permutation2swaps(p)
320-
inv = levels[s] > levels[s + 1]
321-
dst, U_tmp = artin_braid(dst, s; inv)
322-
U = U_tmp * U
323-
l = levels[s]
324-
levels = TupleTools.setindex(levels, levels[s + 1], s)
325-
levels = TupleTools.setindex(levels, l, s + 1)
326-
end
318+
dst, U = repartition(src, numind(src))
327319

328-
if N₂ == 0
329-
return dst => U
330-
else
331-
dst, U_tmp = repartition(dst, N₁)
332-
U = U_tmp * U
333-
return dst => U
320+
for s in permutation2swaps(p)
321+
inv = levels[s] > levels[s + 1]
322+
dst, U_tmp = artin_braid(dst, s; inv)
323+
U = U_tmp * U
324+
l = levels[s]
325+
levels = TupleTools.setindex(levels, levels[s + 1], s)
326+
levels = TupleTools.setindex(levels, l, s + 1)
327+
end
328+
329+
if N₂ == 0
330+
return dst => U
331+
else
332+
dst, U_tmp = repartition(dst, N₁)
333+
U = U_tmp * U
334+
return dst => U
335+
end
334336
end
335337
end
336338

337-
CacheStyle(::typeof(fsbraid), k::FSPBraidKey{I}) where {I} =
338-
FusionStyle(I) isa UniqueFusion ? NoCache() : GlobalLRUCache()
339-
CacheStyle(::typeof(fsbraid), k::FSBBraidKey{I}) where {I} =
339+
CacheStyle(::typeof(fsbraid), k::Union{FSPBraidKey{I}, FSBBraidKey{I}}) where {I} =
340340
FusionStyle(I) isa UniqueFusion ? NoCache() : GlobalLRUCache()
341341

342342
"""

src/spaces/gradedspace.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,21 @@ function Base.:(==)(V₁::GradedSpace, V₂::GradedSpace)
191191
return sectortype(V₁) == sectortype(V₂) && (V₁.dims == V₂.dims) && V₁.dual == V₂.dual
192192
end
193193

194+
function sectorhash(V::GradedSpace{I, NTuple{N, Int}}, h::UInt) where {I, N}
195+
return hash(iszero.(V.dims), hash(isdual(V), h))
196+
end
197+
function sectorequal(V₁::GradedSpace{I, D}, V₂::GradedSpace{I, D}) where {I, N, D <: NTuple{N, Int}}
198+
return isdual(V₁) == isdual(V₂) && all(zip(V₁.dims, V₂.dims)) do (d₁, d₂)
199+
return iszero(d₁) == iszero(d₂)
200+
end
201+
end
202+
function sectorhash(V::GradedSpace{I, <:SectorDict}, h::UInt) where {I}
203+
return hash(keys(V.dims), hash(isdual(V), h))
204+
end
205+
function sectorequal(V₁::GradedSpace{I, D}, V₂::GradedSpace{I, D}) where {I, D <: SectorDict}
206+
return isdual(V₁) == isdual(V₂) && keys(V₁.dims) == keys(V₂.dims)
207+
end
208+
194209
Base.summary(io::IO, V::GradedSpace) = print(io, type_repr(typeof(V)))
195210

196211
function Base.show(io::IO, V::GradedSpace)

0 commit comments

Comments
 (0)