Skip to content

Commit 547cdd3

Browse files
authored
Adapt extension (#45)
* add Adapt extension * add adapt tests * add missing constructor
1 parent dc6696a commit 547cdd3

5 files changed

Lines changed: 60 additions & 1 deletion

File tree

Project.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
1515
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1616
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1717

18+
[weakdeps]
19+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
20+
21+
[extensions]
22+
BlockTensorKitAdaptExt = "Adapt"
23+
1824
[compat]
25+
Adapt = "4"
1926
Aqua = "0.8"
2027
BlockArrays = "1"
2128
Combinatorics = "1"
@@ -34,6 +41,7 @@ VectorInterface = "0.4.8, 0.5"
3441
julia = "1.10"
3542

3643
[extras]
44+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3745
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3846
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
3947
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -42,4 +50,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4250
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
4351

4452
[targets]
45-
test = ["Test", "TestExtras", "Random", "Combinatorics", "SafeTestsets", "Aqua"]
53+
test = ["Test", "TestExtras", "Random", "Combinatorics", "SafeTestsets", "Aqua", "Adapt"]

ext/BlockTensorKitAdaptExt.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module BlockTensorKitAdaptExt
2+
3+
using TensorKit
4+
using BlockTensorKit
5+
using Adapt
6+
7+
function Adapt.adapt_structure(to, x::BlockTensorMap)
8+
data′ = map(adapt(to), x.data)
9+
return BlockTensorMap(data′, space(x))
10+
end
11+
12+
function Adapt.adapt_structure(to, x::SparseBlockTensorMap)
13+
data′ = Dict(I => adapt(to, v) for (I, v) in x.data)
14+
return SparseBlockTensorMap(data′, space(x))
15+
end
16+
17+
end

src/tensors/sparseblocktensor.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ function SparseBlockTensorMap{TT}(
4646
) where {E, S, N₁, N₂, N, TT <: AbstractTensorMap{E, S, N₁, N₂}}
4747
return SparseBlockTensorMap{TT, E, S, N₁, N₂, N}(data, space)
4848
end
49+
function SparseBlockTensorMap(
50+
data::Dict{CartesianIndex{N}, TT}, space::TensorMapSumSpace{S, N₁, N₂}
51+
) where {E, S, N₁, N₂, N, TT <: AbstractTensorMap{E, S, N₁, N₂}}
52+
return SparseBlockTensorMap{TT}(data, space)
53+
end
4954

5055
function sparseblocktensormaptype(::Type{S}, N₁::Int, N₂::Int, ::Type{T}) where {S, T}
5156
TT = tensormaptype(S, N₁, N₂, T)

test/abstracttensor/blocktensor.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using TensorKit
44
using BlockTensorKit
55
using Random
66
using Combinatorics
7+
using Adapt
78

89
Vtr = (
910
SumSpace(ℂ^3),
@@ -84,6 +85,19 @@ end
8485
end
8586
end
8687

88+
@testset "Adapt" begin
89+
W = V1 V2 V3 V4 V5
90+
t1 = rand(Float32, W)
91+
for T in (Float64, ComplexF64)
92+
t2 = @testinferred adapt(Vector{T}, t1)
93+
@test t2 isa BlockTensorMap
94+
@test scalartype(t2) == T
95+
@test storagetype(t2) == Vector{T}
96+
@test space(t1) == space(t2)
97+
@test norm(t1) norm(t2)
98+
end
99+
end
100+
87101
@testset "Basic linear algebra" begin
88102
W = V1 V2 V3 V4 V5
89103
for T in (Float32, ComplexF64)

test/abstracttensor/sparseblocktensor.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using TensorKit
33
using BlockTensorKit
44
using Random
55
using Combinatorics
6+
using Adapt
67

78
Vtr = (
89
SumSpace(ℂ^3),
@@ -90,6 +91,20 @@ end
9091
end
9192
end
9293

94+
@testset "Adapt" begin
95+
W = V1 V2 V3 V4 V5
96+
t1 = sprand(Float32, W, 0.5)
97+
for T in (Float64, ComplexF64)
98+
t2 = @testinferred adapt(Vector{T}, t1)
99+
@test t2 isa SparseBlockTensorMap
100+
@test scalartype(t2) == T
101+
@test storagetype(t2) == Vector{T}
102+
@test space(t1) == space(t2)
103+
@test norm(t1) norm(t2)
104+
end
105+
end
106+
107+
93108
@testset "Basic linear algebra" begin
94109
W = V1 V2 V3 V4 V5
95110
for T in (Float32, ComplexF64)

0 commit comments

Comments
 (0)