Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ export scalar, add!, contract!
# truncation schemes
export notrunc, truncerr, truncdim, truncspace, truncbelow

# cache management
export empty_globalcaches!

# Imports
#---------
using TupleTools
Expand Down Expand Up @@ -134,6 +137,7 @@ using PackageExtensionCompat
# Auxiliary files
#-----------------
include("auxiliary/auxiliary.jl")
include("auxiliary/caches.jl")
include("auxiliary/dicts.jl")
include("auxiliary/iterators.jl")
include("auxiliary/linalg.jl")
Expand Down
155 changes: 155 additions & 0 deletions src/auxiliary/caches.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
const GLOBAL_CACHES = Pair{Symbol,Any}[]
function empty_globalcaches!()
foreach(empty! ∘ last, GLOBAL_CACHES)
return nothing
end

function global_cache_info(io::IO=stdout)
for (name, cache) in GLOBAL_CACHES
println(io, name, ":\t", LRUCache.cache_info(cache))
end
end

abstract type CacheStyle end
struct NoCache <: CacheStyle end
struct TaskLocalCache{D<:AbstractDict} <: CacheStyle end
struct GlobalLRUCache <: CacheStyle end

const DEFAULT_GLOBALCACHE_SIZE = Ref(10^4)

function CacheStyle(args...)
return GlobalLRUCache()
end

macro cached(ex)
Meta.isexpr(ex, :function) ||
error("cached macro can only be used on function definitions")
fcall = ex.args[1]
if Meta.isexpr(fcall, :where)
hasparams = true
params = fcall.args[2:end]
fcall = fcall.args[1]
else
hasparams = false
end
if Meta.isexpr(fcall, :(::))
typed = true
typeex = fcall.args[2]
fcall = fcall.args[1]
else
typed = false
end
Meta.isexpr(fcall, :call) ||
error("cached macro can only be used on function definitions")
fname = fcall.args[1]
fargs = fcall.args[2:end]
fargnames = map(fargs) do arg
if Meta.isexpr(arg, :(::))
return arg.args[1]
else
return arg
end
end
_fbody = ex.args[2]

# actual implenetation, with underscore name
_fname = Symbol(:_, fname)
_fcall = Expr(:call, _fname, fargs...)
if hasparams
_fcall = Expr(:where, _fcall, params...)
end
_fex = Expr(:function, _fcall, _fbody)

# implementation that chooses the cache style
newfcall = fcall
if hasparams
newfcall = Expr(:where, newfcall, params...)
end
cachestylevar = gensym(:cachestyle)
cachestyleex = Expr(:(=), cachestylevar,
Expr(:call, :CacheStyle, fname, fargnames...))
newfbody = Expr(:block,
cachestyleex,
Expr(:call, fname, fargnames..., cachestylevar))
newfex = Expr(:function, newfcall, newfbody)

# nocache implementation
fnocachecall = Expr(:call, fname, fargs..., :(::NoCache))
if hasparams
fnocachecall = Expr(:where, fnocachecall, params...)
end
fnocachebody = Expr(:call, _fname, fargnames...)
if typed
T = gensym(:T)
fnocachebody = Expr(:block, Expr(:(=), T, typeex), Expr(:(::), fnocachebody, T))
end
fnocacheex = Expr(:function, fnocachecall, fnocachebody)

# tasklocal cache implementation
Dvar = gensym(:D)
flocalcachecall = Expr(:call, fname, fargs..., :(::TaskLocalCache{$Dvar}))
if hasparams
flocalcachecall = Expr(:where, flocalcachecall, params..., Dvar)
else
flocalcachecall = Expr(:where, flocalcachecall, Dvar)
end
localcachename = Symbol(:_tasklocal_, fname, :_cache)
cachevar = gensym(:cache)
getlocalcacheex = :($cachevar::$Dvar = get!(task_local_storage(), $localcachename) do
return $Dvar()
end)
valvar = gensym(:val)
if length(fargnames) == 1
key = fargnames[1]
else
key = Expr(:tuple, fargnames...)
end
getvalex = :(get!($cachevar, $key) do
return $_fname($(fargnames...))
end)
if typed
T = gensym(:T)
flocalcachebody = Expr(:block,
getlocalcacheex,
Expr(:(=), T, typeex),
Expr(:(=), Expr(:(::), valvar, T), getvalex),
Expr(:return, valvar))
else
flocalcachebody = Expr(:block,
getlocalcacheex,
Expr(:(=), valvar, getvalex),
Expr(:return, valvar))
end
flocalcacheex = Expr(:function, flocalcachecall, flocalcachebody)

# # global cache implementation
fglobalcachecall = Expr(:call, fname, fargs..., :(::GlobalLRUCache))
if hasparams
fglobalcachecall = Expr(:where, fglobalcachecall, params...)
end
globalcachename = Symbol(:GLOBAL_, uppercase(string(fname)), :_CACHE)
getglobalcachex = Expr(:(=), cachevar, globalcachename)
if typed
T = gensym(:T)
fglobalcachebody = Expr(:block,
getglobalcachex,
Expr(:(=), T, typeex),
Expr(:(=), Expr(:(::), valvar, T), getvalex),
Expr(:return, valvar))
else
fglobalcachebody = Expr(:block,
getglobalcachex,
Expr(:(=), valvar, getvalex),
Expr(:return, valvar))
end
fglobalcacheex = Expr(:function, fglobalcachecall, fglobalcachebody)
fglobalcachedef = Expr(:const,
Expr(:(=), globalcachename,
:(LRU{Any,Any}(; maxsize=DEFAULT_GLOBALCACHE_SIZE[]))))
fglobalcacheregister = Expr(:call, :push!, :GLOBAL_CACHES,
:($(QuoteNode(globalcachename)) => $globalcachename))

# # total expression
return esc(Expr(:block, _fex, newfex, fnocacheex, flocalcacheex,
fglobalcachedef, fglobalcacheregister, fglobalcacheex))
end
97 changes: 37 additions & 60 deletions src/fusiontrees/manipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,6 @@ function _recursive_repartition(f₁::FusionTree{I,N₁},
end
end

# transpose double fusion tree
const transposecache = LRU{Any,Any}(; maxsize=10^5)
const usetransposecache = Ref{Bool}(true)

"""
transpose(f₁::FusionTree{I}, f₂::FusionTree{I},
p1::NTuple{N₁, Int}, p2::NTuple{N₂, Int}) where {I, N₁, N₂}
Expand All @@ -548,28 +544,24 @@ function Base.transpose(f₁::FusionTree{I}, f₂::FusionTree{I},
@assert length(f₁) + length(f₂) == N
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
@assert iscyclicpermutation(p)
if usetransposecache[]
T = sectorscalartype(I)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
D = fusiontreedict(I){Tuple{F₁,F₂},T}
return _get_transpose(D, (f₁, f₂, p1, p2))
else
return _transpose((f₁, f₂, p1, p2))
end
return fstranspose((f₁, f₂, p1, p2))
end

@noinline function _get_transpose(::Type{D}, @nospecialize(key)) where {D}
d::D = get!(transposecache, key) do
return _transpose(key)
end
return d
end
const FSTransposeKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
IndexTuple{N₁},IndexTuple{N₂}}

const TransposeKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
IndexTuple{N₁},IndexTuple{N₂}}
function _fsdicttype(I, N₁, N₂)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
T = sectorscalartype(I)
return fusiontreedict(I){Tuple{F₁,F₂},T}
end

function _transpose((f₁, f₂, p1, p2)::TransposeKey{I,N₁,N₂}) where {I<:Sector,N₁,N₂}
@cached function fstranspose(key::FSTransposeKey{I,N₁,N₂})::_fsdicttype(I, N₁,
N₂) where {I<:Sector,
N₁,
N₂}
f₁, f₂, p1, p2 = key
N = N₁ + N₂
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
newtrees = repartition(f₁, f₂, N₁)
Expand Down Expand Up @@ -611,6 +603,14 @@ function _transpose((f₁, f₂, p1, p2)::TransposeKey{I,N₁,N₂}) where {I<:S
return newtrees
end

function CacheStyle(::typeof(fstranspose), k::FSTransposeKey{I}) where {I<:Sector}
if FusionStyle(I) isa UniqueFusion
return NoCache()
else
return GlobalLRUCache()
end
end

# COMPOSITE DUALITY MANIPULATIONS PART 2: Planar traces
#-------------------------------------------------------------------
# -> composite manipulations that depend on the duality (rigidity) and pivotal structure
Expand Down Expand Up @@ -1015,10 +1015,6 @@ function permute(f::FusionTree{I,N}, p::NTuple{N,Int}) where {I<:Sector,N}
end

# braid double fusion tree
const braidcache = LRU{Any,Any}(; maxsize=10^5)
const usebraidcache_abelian = Ref{Bool}(false)
const usebraidcache_nonabelian = Ref{Bool}(true)

"""
braid(f₁::FusionTree{I}, f₂::FusionTree{I},
levels1::IndexTuple, levels2::IndexTuple,
Expand All @@ -1043,42 +1039,15 @@ function braid(f₁::FusionTree{I}, f₂::FusionTree{I},
@assert length(f₁) + length(f₂) == N₁ + N₂
@assert length(f₁) == length(levels1) && length(f₂) == length(levels2)
@assert TupleTools.isperm((p1..., p2...))
if FusionStyle(f₁) isa UniqueFusion &&
BraidingStyle(f₁) isa SymmetricBraiding
if usebraidcache_abelian[]
T = Int # do we hardcode this ?
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
D = SingletonDict{Tuple{F₁,F₂},T}
return _get_braid(D, (f₁, f₂, levels1, levels2, p1, p2))
else
return _braid((f₁, f₂, levels1, levels2, p1, p2))
end
else
if usebraidcache_nonabelian[]
T = sectorscalartype(I)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
D = FusionTreeDict{Tuple{F₁,F₂},T}
return _get_braid(D, (f₁, f₂, levels1, levels2, p1, p2))
else
return _braid((f₁, f₂, levels1, levels2, p1, p2))
end
end
return fsbraid((f₁, f₂, levels1, levels2, p1, p2))
end
const FSBraidKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
IndexTuple,IndexTuple,
IndexTuple{N₁},IndexTuple{N₂}}

@noinline function _get_braid(::Type{D}, @nospecialize(key)) where {D}
d::D = get!(braidcache, key) do
return _braid(key)
end
return d
end

const BraidKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
IndexTuple,IndexTuple,
IndexTuple{N₁},IndexTuple{N₂}}

function _braid((f₁, f₂, l1, l2, p1, p2)::BraidKey{I,N₁,N₂}) where {I<:Sector,N₁,N₂}
@cached function fsbraid(key::FSBraidKey{I,N₁,N₂})::_fsdicttype(I, N₁,
N₂) where {I<:Sector,N₁,N₂}
(f₁, f₂, l1, l2, p1, p2) = key
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
levels = (l1..., reverse(l2)...)
local newtrees
Expand All @@ -1097,6 +1066,14 @@ function _braid((f₁, f₂, l1, l2, p1, p2)::BraidKey{I,N₁,N₂}) where {I<:S
return newtrees
end

function CacheStyle(::typeof(fsbraid), k::FSBraidKey{I}) where {I<:Sector}
if FusionStyle(I) isa UniqueFusion
return NoCache()
else
return GlobalLRUCache()
end
end

"""
permute(f₁::FusionTree{I}, f₂::FusionTree{I},
p1::NTuple{N₁, Int}, p2::NTuple{N₂, Int}) where {I, N₁, N₂}
Expand Down
51 changes: 11 additions & 40 deletions src/spaces/homspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,18 +233,17 @@ struct FusionBlockStructure{I,N,F₁,F₂}
fusiontreeindices::FusionTreeDict{Tuple{F₁,F₂},Int}
end

abstract type CacheStyle end
struct NoCache <: CacheStyle end
struct TaskLocalCache{D<:AbstractDict} <: CacheStyle end
struct GlobalLRUCache <: CacheStyle end

function CacheStyle(I::Type{<:Sector})
return GlobalLRUCache()
function fusionblockstructuretype(W::HomSpace)
Comment thread
lkdvos marked this conversation as resolved.
N₁ = length(codomain(W))
N₂ = length(domain(W))
N = N₁ + N₂
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
return FusionBlockStructure{I,N,F₁,F₂}
end

fusionblockstructure(W::HomSpace) = fusionblockstructure(W, CacheStyle(sectortype(W)))

function fusionblockstructure(W::HomSpace, ::NoCache)
@cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W)
codom = codomain(W)
dom = domain(W)
N₁ = length(codom)
Expand Down Expand Up @@ -317,36 +316,8 @@ function _subblock_strides(subsz, sz, str)
return Strided.StridedViews._computereshapestrides(subsz, sz_simplify...)
end

function fusionblockstructure(W::HomSpace, ::TaskLocalCache{D}) where {D}
cache::D = get!(task_local_storage(), :_local_tensorstructure_cache) do
return D()
end
N₁ = length(codomain(W))
N₂ = length(domain(W))
N = N₁ + N₂
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
structure::FusionBlockStructure{I,N,F₁,F₂} = get!(cache, W) do
return fusionblockstructure(W, NoCache())
end
return structure
end

const GLOBAL_FUSIONBLOCKSTRUCTURE_CACHE = LRU{Any,Any}(; maxsize=10^4)
# 10^4 different tensor spaces should be enough for most purposes
function fusionblockstructure(W::HomSpace, ::GlobalLRUCache)
cache = GLOBAL_FUSIONBLOCKSTRUCTURE_CACHE
N₁ = length(codomain(W))
N₂ = length(domain(W))
N = N₁ + N₂
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
structure::FusionBlockStructure{I,N,F₁,F₂} = get!(cache, W) do
return fusionblockstructure(W, NoCache())
end
return structure
function CacheStyle(::typeof(fusionblockstructure), W::HomSpace)
return GlobalLRUCache()
end

# Diagonal ranges
Expand Down
Loading
Loading