Skip to content

Commit 518cb2a

Browse files
kshyattJutholkdvos
authored
Allow BraidingTensor to have a custom storage type (#393)
* Allow BraidingTensor to have a custom storage type * Remove extraneous lines Co-authored-by: Jutho <Jutho@users.noreply.github.com> * More extraneous line removal * A little more cleanup * Coverage * Update src/tensors/braidingtensor.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Use braidingtensortype * Apply suggestions from code review Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Update src/tensors/braidingtensor.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Update Project.toml bump minor version --------- Co-authored-by: Jutho <Jutho@users.noreply.github.com> Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 2876a6f commit 518cb2a

10 files changed

Lines changed: 459 additions & 113 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "TensorKit"
22
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
3-
version = "0.16.4"
3+
version = "0.17.0"
44
authors = ["Jutho Haegeman, Lukas Devos"]
55

66
[deps]

ext/TensorKitAdaptExt.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap)
1515
data′ = adapt(to, x.data)
1616
return DiagonalTensorMap(data′, x.domain)
1717
end
18-
function Adapt.adapt_structure(::Type{TorA}, x::BraidingTensor) where {TorA <: Union{Number, DenseArray{<:Number}}}
19-
return BraidingTensor{scalartype(TorA)}(space(x), x.adjoint)
18+
function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A}
19+
A′ = TensorKit.similarstoragetype(A, T)
20+
return BraidingTensor{T, S, A′}(space(x), x.adjoint)
21+
end
22+
function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {T′, TA <: DenseArray{T′}, T, S, A}
23+
return BraidingTensor{T′, S, TA}(space(x), x.adjoint)
2024
end
2125

2226
end

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@ module TensorKitCUDAExt
33
using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra
44
using CUDA: @allowscalar
55
using cuTENSOR: cuTENSOR
6+
using Strided: StridedViews
67
import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!
8+
using CUDA.KernelAbstractions: @kernel, @index, get_backend
79

810
using TensorKit
911
using TensorKit.Factorizations
1012
using TensorKit.Strided
1113
using TensorKit.Factorizations: AbstractAlgorithm
1214
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
13-
import TensorKit: randisometry, rand, randn
15+
import TensorKit: randisometry, rand, randn, fill_braidingsubblock!
1416

1517
using TensorKit: MatrixAlgebraKit
1618

@@ -19,4 +21,18 @@ using Random
1921
include("cutensormap.jl")
2022
include("truncation.jl")
2123

24+
function TensorKit.fill_braidingsubblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}}
25+
# COV_EXCL_START
26+
# kernels are not reachable by coverage
27+
@kernel function fill_subblock_kernel!(subblock, val)
28+
idx = @index(Global, Cartesian)
29+
idx_val = idx[1] == idx[4] && idx[2] == idx[3] ? val : zero(val)
30+
@inbounds subblock[idx] = idx_val
31+
end
32+
# COV_EXCL_STOP
33+
kernel = fill_subblock_kernel!(get_backend(data))
34+
kernel(data, val; ndrange = size(data))
35+
return data
36+
end
37+
2238
end

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,7 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
168168
return tf
169169
end
170170
end
171+
172+
function TensorKit._add_transform_multi!(tdst::CuTensorMap, tsrc, p, (U, structs_dst, structs_src)::Tuple{<:Array, TD, TS}, buffers, alpha, beta, backend...) where {TD, TS}
173+
return TensorKit._add_transform_multi!(tdst, tsrc, p, (CUDA.Adapt.adapt(CuArray, U), structs_dst, structs_src), buffers, alpha, beta, backend...)
174+
end

src/planar/preprocessors.jl

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,23 @@ _add_adjoint(ex) = Expr(TO.prime, ex)
8383
# spaces from the rest of the expression. Construct the explicit BraidingTensor objects and
8484
# insert them in the expression.
8585
function _construct_braidingtensors(ex)
86+
function filter_f(expr)
87+
if TO.istensor(expr)
88+
return _remove_adjoint(TO.decomposetensor(expr)[1]) !=
89+
elseif TO.istensorexpr(expr)
90+
return any(filter_f, expr.args)
91+
else
92+
return false
93+
end
94+
end
95+
function extract_tensors(tensor_ex)
96+
if TO.istensor(tensor_ex)
97+
return [TO.decomposetensor(tensor_ex)[1]]
98+
elseif TO.istensorexpr(tensor_ex)
99+
return collect(Iterators.flatmap(extract_tensors, filter(filter_f, tensor_ex.args)))
100+
end
101+
end
102+
# get storagetype
86103
ex isa Expr || return ex
87104
if ex.head == :macrocall && ex.args[1] == Symbol("@notensor")
88105
return ex
@@ -104,7 +121,9 @@ function _construct_braidingtensors(ex)
104121
)
105122
end
106123
end
107-
newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap)
124+
# if this is a definition, the lhs tensor is NOT yet defined
125+
no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, rhs.args)); init = Symbol[])
126+
newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap, no_τ_ex)
108127
success ||
109128
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))
110129
pre = Expr(
@@ -115,7 +134,8 @@ function _construct_braidingtensors(ex)
115134
elseif TO.istensorexpr(ex)
116135
preargs = Vector{Any}()
117136
indexmap = Dict{Any, Any}()
118-
newex, success = _construct_braidingtensors!(ex, preargs, indexmap)
137+
no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, ex.args)); init = Symbol[])
138+
newex, success = _construct_braidingtensors!(ex, preargs, indexmap, no_τ_ex)
119139
success ||
120140
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))
121141
pre = Expr(
@@ -128,7 +148,7 @@ function _construct_braidingtensors(ex)
128148
end
129149
end
130150

131-
function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression
151+
function _construct_braidingtensors!(ex, preargs, indexmap, non_braiding) # ex is guaranteed to be a single tensor expression
132152
if TO.isscalarexpr(ex)
133153
# ex could be tensorscalar call with more braiding tensors
134154
return _construct_braidingtensors(ex), true
@@ -163,7 +183,9 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
163183
end
164184
if foundV1 && foundV2
165185
s = gensym()
166-
constructex = Expr(:call, GlobalRef(TensorKit, :BraidingTensor), V1, V2)
186+
storageex = Expr(:call, GlobalRef(TensorKit, :promote_storagetype), non_braiding...)
187+
braidingex = Expr(:call, GlobalRef(TensorKit, :braidingtensortype), V1, V2, storageex)
188+
constructex = Expr(:call, braidingex, V1, V2)
167189
push!(preargs, Expr(:(=), s, constructex))
168190
obj = _is_adjoint(obj) ? _add_adjoint(s) : s
169191
success = true
@@ -196,7 +218,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
196218
newargs = Vector{Any}(undef, length(args))
197219
success = true
198220
for i in 1:length(ex.args)
199-
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap)
221+
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap, non_braiding)
200222
success = success && successa
201223
end
202224
newex = Expr(ex.head, newargs...)
@@ -212,7 +234,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
212234
for i in 2:length(ex.args)
213235
successes[i] && continue
214236
newargs[i], successa = _construct_braidingtensors!(
215-
args[i], preargs, indexmap
237+
args[i], preargs, indexmap, non_braiding
216238
)
217239
successes[i] = successa
218240
end
@@ -232,7 +254,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
232254
indices = [TO.getindices(arg) for arg in args]
233255
for i in 2:length(ex.args)
234256
indexmapa = copy(indexmap)
235-
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa)
257+
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa, non_braiding)
236258
for l in indices[i]
237259
if !haskey(indexmap, l) && haskey(indexmapa, l)
238260
indexmap[l] = indexmapa[l]
@@ -243,10 +265,10 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
243265
newex = Expr(ex.head, newargs...)
244266
return newex, success
245267
elseif isexpr(ex, :call) && ex.args[1] == :/ && length(ex.args) == 3
246-
newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap)
268+
newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap, non_braiding)
247269
return Expr(:call, :/, newarg, ex.args[3]), success
248270
elseif isexpr(ex, :call) && ex.args[1] == :\ && length(ex.args) == 3
249-
newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap)
271+
newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap, non_braiding)
250272
return Expr(:call, :\, ex.args[2], newarg), success
251273
else
252274
error("unexpected expression $ex")

src/tensors/braidingtensor.jl

Lines changed: 78 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,72 +2,78 @@
22
# special (2,2) tensor that implements a standard braiding operation
33
#====================================================================#
44
"""
5-
struct BraidingTensor{T,S<:IndexSpace} <: AbstractTensorMap{T, S, 2, 2}
5+
struct BraidingTensor{T, S <: IndexSpace, A <: DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
66
BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
7+
BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool=false) where {T, S, A}
78
89
Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
910
braids the first input over the second input; its inverse can be obtained as the adjoint.
1011
1112
It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and
12-
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`.
13+
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. The storage type `TA`
14+
controls the array type of the braiding tensor used when indexing
15+
and multiplying with other tensors.
1316
"""
14-
struct BraidingTensor{T, S} <: AbstractTensorMap{T, S, 2, 2}
17+
struct BraidingTensor{T, S, A <: DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
1518
V1::S
1619
V2::S
1720
adjoint::Bool
18-
function BraidingTensor{T, S}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
19-
for a in sectors(V1)
20-
for b in sectors(V2)
21-
for c in (a b)
22-
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
23-
throw(ArgumentError("Cannot define a braiding between $a and $b"))
24-
end
25-
end
21+
function BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}}
22+
for a in sectors(V1), b in sectors(V2), c in (a b)
23+
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
24+
throw(ArgumentError("Cannot define a braiding between $a and $b"))
2625
end
27-
return new{T, S}(V1, V2, adjoint)
26+
return new{T, S, A}(V1, V2, adjoint)
2827
# partial construction: only construct rowr and colr when needed
2928
end
3029
end
3130
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
32-
return BraidingTensor{T, S}(V1, V2, adjoint)
31+
return braidingtensortype(S, T)(V1, V2, adjoint)
3332
end
34-
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) where {T}
35-
return BraidingTensor{T}(promote(V1, V2)..., adjoint)
33+
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
34+
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
35+
return BraidingTensor{T}(V1, V2, adjoint)
3636
end
3737
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false)
3838
return BraidingTensor(promote(V1, V2)..., adjoint)
3939
end
40-
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
41-
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
42-
return BraidingTensor{T, S}(V1, V2, adjoint)
43-
end
4440
function BraidingTensor(V::HomSpace, adjoint::Bool = false)
4541
domain(V) == reverse(codomain(V)) ||
4642
throw(SpaceMismatch("Cannot define a braiding on $V"))
4743
return BraidingTensor(V[2], V[1], adjoint)
4844
end
45+
function BraidingTensor{T, S, A}(V::HomSpace, adjoint::Bool = false) where {T, S, A}
46+
domain(V) == reverse(codomain(V)) ||
47+
throw(SpaceMismatch("Cannot define a braiding on $V"))
48+
return BraidingTensor{T, S, A}(V[2], V[1], adjoint)
49+
end
4950
function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T}
5051
domain(V) == reverse(codomain(V)) ||
5152
throw(SpaceMismatch("Cannot define a braiding on $V"))
5253
return BraidingTensor{T}(V[2], V[1], adjoint)
5354
end
54-
function Base.adjoint(b::BraidingTensor{T, S}) where {T, S}
55-
return BraidingTensor{T, S}(b.V1, b.V2, !b.adjoint)
56-
end
5755

58-
space(b::BraidingTensor) = b.adjoint ? b.V1 b.V2 b.V2 b.V1 : b.V2 b.V1 b.V1 b.V2
56+
function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A}
57+
return BraidingTensor{T, S, A}(b.V1, b.V2, !b.adjoint)
58+
end
5959

60-
# specializations to ignore the storagetype of BraidingTensor
61-
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B)
62-
promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A)
63-
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A)
60+
# these are here to make the preprocessing for `@planar` expressions less painful
61+
function braidingtensortype(::Type{S}, ::Type{TorA}) where {S <: IndexSpace, TorA}
62+
A = similarstoragetype(TorA)
63+
return BraidingTensor{scalartype(A), S, A}
64+
end
65+
braidingtensortype(V::S, ::Type{TorA}) where {S <: IndexSpace, TorA} = braidingtensortype(S, TorA)
66+
braidingtensortype(V1::S, V2::S, ::Type{TorA}) where {S <: IndexSpace, TorA} = braidingtensortype(S, TorA)
67+
function braidingtensortype(V1::IndexSpace, V2::IndexSpace, ::Type{TorA}) where {TorA}
68+
S = promote(V1, V2)
69+
return braidingtensortype(S..., TorA)
70+
end
71+
function braidingtensortype(V::HomSpace, ::Type{TorA}) where {TorA}
72+
return braidingtensortype(spacetype(V), TorA)
73+
end
6474

65-
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} =
66-
similarstoragetype(B, T)
67-
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} =
68-
similarstoragetype(A, T)
69-
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} =
70-
similarstoragetype(A, T)
75+
storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A
76+
space(b::BraidingTensor) = b.adjoint ? b.V1 b.V2 b.V2 b.V1 : b.V2 b.V1 b.V1 b.V2
7177

7278
function Base.getindex(b::BraidingTensor)
7379
sectortype(b) === Trivial || throw(SectorMismatch())
@@ -99,6 +105,13 @@ function _braiding_factor(f₁, f₂, inv::Bool = false)
99105
return r
100106
end
101107

108+
# generates scalar indexing errors on GPU
109+
function fill_braidingsubblock!(data, val)
110+
f(I) = ((I[1] == I[4]) & (I[2] == I[3])) * val
111+
return data .= f.(CartesianIndices(data))
112+
end
113+
114+
102115
@inline function subblock(
103116
b::BraidingTensor, (f₁, f₂)::Tuple{FusionTree{I, 2}, FusionTree{I, 2}}
104117
) where {I <: Sector}
@@ -113,17 +126,10 @@ end
113126
throw(SectorMismatch())
114127
end
115128
d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...)
116-
n1 = d[1] * d[2]
117-
n2 = d[3] * d[4]
118-
data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d)
119-
fill!(data, zero(eltype(b)))
120-
129+
data_parent = storagetype(b)(undef, prod(d))
130+
data = sreshape(StridedView(data_parent), d)
121131
r = _braiding_factor(f₁, f₂, b.adjoint)
122-
if !isnothing(r)
123-
@inbounds for i in axes(data, 1), j in axes(data, 2)
124-
data[i, j, j, i] = r
125-
end
126-
end
132+
isnothing(r) ? zerovector!(data) : fill_braidingsubblock!(data, r)
127133
return data
128134
end
129135

@@ -134,49 +140,50 @@ TensorMap(b::BraidingTensor) = copy!(similar(b), b)
134140
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)
135141

136142
Base.complex(b::BraidingTensor{<:Complex}) = b
137-
function Base.complex(b::BraidingTensor)
138-
return BraidingTensor{complex(scalartype(b))}(space(b), b.adjoint)
143+
function Base.complex(b::BraidingTensor{T, S, A}) where {T, S, A}
144+
Tc = complex(T)
145+
Ac = similarstoragetype(A, Tc)
146+
return BraidingTensor{Tc, S, Ac}(space(b), b.adjoint)
139147
end
140148

141-
function block(b::BraidingTensor, s::Sector)
142-
I = sectortype(b)
143-
I == typeof(s) || throw(SectorMismatch())
144-
145-
# TODO: probably always square?
146-
m = blockdim(codomain(b), s)
147-
n = blockdim(domain(b), s)
148-
data = Matrix{eltype(b)}(undef, (m, n))
149-
150-
length(data) == 0 && return data # s ∉ blocksectors(b)
151-
152-
data = fill!(data, zero(eltype(b)))
153-
149+
# Trivial
150+
function fill_braidingblock!(data, b::BraidingTensor, s::Trivial)
154151
V1, V2 = codomain(b)
155-
if sectortype(b) === Trivial
156-
d1, d2 = dim(V1), dim(V2)
157-
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
158-
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
159-
subblock[i, j, j, i] = one(eltype(b))
160-
end
161-
return data
162-
end
152+
d1, d2 = dim(V1), dim(V2)
153+
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
154+
fill_braidingsubblock!(subblock, one(eltype(b)))
155+
return data
156+
end
163157

158+
# Nontrivial
159+
function fill_braidingblock!(data, b::BraidingTensor, s::Sector)
164160
base_offset = first(blockstructure(b)[s][2]) - 1
165161

166162
for ((f₁, f₂), (sz, str, off)) in pairs(subblockstructure(space(b)))
167163
(f₁.coupled == f₂.coupled == s) || continue
168164
r = _braiding_factor(f₁, f₂, b.adjoint)
169-
isnothing(r) && continue
170165
# change offset to account for single block
171166
subblock = StridedView(data, sz, str, off - base_offset)
172-
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
173-
subblock[i, j, j, i] = r
174-
end
167+
isnothing(r) ? zerovector!(subblock) : fill_braidingsubblock!(subblock, r)
175168
end
176-
177169
return data
178170
end
179171

172+
function block(b::BraidingTensor, s::Sector)
173+
I = sectortype(b)
174+
I == typeof(s) || throw(SectorMismatch())
175+
176+
# TODO: probably always square?
177+
m = blockdim(codomain(b), s)
178+
n = blockdim(domain(b), s)
179+
180+
data = reshape(storagetype(b)(undef, m * n), (m, n))
181+
182+
m * n == 0 && return data # s ∉ blocksectors(b)
183+
184+
return fill_braidingblock!(data, b, s)
185+
end
186+
180187
# Index manipulations
181188
# -------------------
182189
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false

0 commit comments

Comments
 (0)