Skip to content

Commit 5f75ba3

Browse files
committed
Add scheduler support
1 parent 68437b9 commit 5f75ba3

3 files changed

Lines changed: 35 additions & 0 deletions

File tree

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ version = "0.14.5"
77
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
10+
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
1011
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
13+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1214
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1315
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
1416
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
@@ -33,8 +35,10 @@ FiniteDifferences = "0.12"
3335
LRUCache = "1.0.2"
3436
LinearAlgebra = "1"
3537
MatrixAlgebraKit = "0.1.1"
38+
OhMyThreads = "0.8.0"
3639
PackageExtensionCompat = "1"
3740
Random = "1"
41+
ScopedValues = "1.3.0"
3842
SparseArrays = "1"
3943
Strided = "2"
4044
TensorKitSectors = "0.1"

src/TensorKit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ const TO = TensorOperations
103103
using MatrixAlgebraKit: MatrixAlgebraKit as MAK
104104

105105
using LRUCache
106+
using OhMyThreads
107+
using ScopedValues
106108

107109
using TensorKitSectors
108110
import TensorKitSectors: dim, BraidingStyle, FusionStyle, ,
@@ -186,6 +188,7 @@ include("spaces/vectorspaces.jl")
186188
#-------------------------------------
187189
# general definitions
188190
include("tensors/abstracttensor.jl")
191+
include("tensors/backends.jl")
189192
include("tensors/blockiterator.jl")
190193
include("tensors/tensor.jl")
191194
include("tensors/adjoint.jl")

src/tensors/backends.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Scheduler implementation
2+
# ------------------------
3+
function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
4+
return if scheduler == OhMyThreads.Implementation.NotGiven() && isempty(kwargs)
5+
Threads.nthreads() > 1 ? SerialScheduler() : DynamicScheduler()
6+
else
7+
OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...)
8+
end
9+
end
10+
11+
"""
12+
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())
13+
14+
The default scheduler used when looping over different blocks in the matrix representation of a
15+
tensor.
16+
For controlling this value, see also [`set_blockscheduler`](@ref) and [`with_blockscheduler`](@ref).
17+
"""
18+
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())
19+
20+
"""
21+
with_blockscheduler(f, [scheduler]; kwargs...)
22+
23+
Run `f` in a scope where the `blockscheduler` is determined by `scheduler' and `kwargs...`.
24+
"""
25+
@inline function with_blockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();
26+
kwargs...)
27+
@with blockscheduler => select_scheduler(scheduler; kwargs...) f()
28+
end

0 commit comments

Comments
 (0)