Skip to content

Commit fa065bd

Browse files
committed
Allow BraidingTensor to have a custom storage type
1 parent 34ac960 commit fa065bd

7 files changed

Lines changed: 419 additions & 85 deletions

File tree

ext/TensorKitAdaptExt.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap)
1515
data′ = adapt(to, x.data)
1616
return DiagonalTensorMap(data′, x.domain)
1717
end
18-
function Adapt.adapt_structure(::Type{TorA}, x::BraidingTensor) where {TorA <: Union{Number, DenseArray{<:Number}}}
19-
return BraidingTensor{scalartype(TorA)}(space(x), x.adjoint)
18+
function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A}
19+
return BraidingTensor(space(x), TensorKit.similarstoragetype(A, T), x.adjoint)
20+
end
21+
function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {TA <: DenseArray{<:Number}, T, S, A}
22+
return BraidingTensor(space(x), TA, x.adjoint)
2023
end
2124

2225
end

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@ module TensorKitCUDAExt
33
using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra
44
using CUDA: @allowscalar
55
using cuTENSOR: cuTENSOR
6+
using Strided: StridedViews
67
import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!
78

9+
using CUDA: KernelAbstractions
10+
using CUDA.KernelAbstractions: @kernel, @index
11+
812
using TensorKit
913
using TensorKit.Factorizations
1014
using TensorKit.Strided
1115
using TensorKit.Factorizations: AbstractAlgorithm
1216
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
13-
import TensorKit: randisometry, rand, randn
17+
import TensorKit: randisometry, rand, randn, _set_subblock!
1418

1519
using TensorKit: MatrixAlgebraKit
1620

@@ -19,4 +23,16 @@ using Random
1923
include("cutensormap.jl")
2024
include("truncation.jl")
2125

26+
function TensorKit._set_subblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}}
27+
@kernel function fill_subblock_kernel!(subblock, val)
28+
idx = @index(Global, Cartesian)
29+
@inbounds subblock[idx[1], idx[2], idx[2], idx[1]] = val
30+
end
31+
kernel = fill_subblock_kernel!(KernelAbstractions.get_backend(data))
32+
d1 = size(data, 1)
33+
d2 = size(data, 2)
34+
kernel(data, val; ndrange = (d1, d2))
35+
return data
36+
end
37+
2238
end

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,23 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
168168
return tf
169169
end
170170
end
171+
172+
173+
function TensorKit.add_kernel_nonthreaded!(
174+
::TensorKit.FusionStyle,
175+
tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
176+
)
177+
# preallocate buffers
178+
buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer)
179+
180+
for subtransformer in transformer.data
181+
# Special case without intermediate buffers whenever there is only a single block
182+
if length(subtransformer[1]) == 1
183+
TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
184+
else
185+
cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...)
186+
TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...)
187+
end
188+
end
189+
return nothing
190+
end

src/tensors/braidingtensor.jl

Lines changed: 68 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,73 +2,67 @@
22
# special (2,2) tensor that implements a standard braiding operation
33
#====================================================================#
44
"""
5-
struct BraidingTensor{T,S<:IndexSpace} <: AbstractTensorMap{T, S, 2, 2}
6-
BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
5+
struct BraidingTensor{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
6+
BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool=false) where {S<:IndexSpace, A <: DenseVector{<:Number}}
77
88
Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
99
braids the first input over the second input; its inverse can be obtained as the adjoint.
1010
1111
It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and
12-
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`.
12+
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. The storage type `TA`
13+
controls the array type of the braiding tensor used when indexing
14+
and multiplying with other tensors.
1315
"""
14-
struct BraidingTensor{T, S} <: AbstractTensorMap{T, S, 2, 2}
16+
struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2}
1517
V1::S
1618
V2::S
1719
adjoint::Bool
18-
function BraidingTensor{T, S}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
19-
for a in sectors(V1)
20-
for b in sectors(V2)
21-
for c in (a b)
22-
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
23-
throw(ArgumentError("Cannot define a braiding between $a and $b"))
24-
end
25-
end
20+
function BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}}
21+
for a in sectors(V1), b in sectors(V2), c in (a b)
22+
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
23+
throw(ArgumentError("Cannot define a braiding between $a and $b"))
2624
end
27-
return new{T, S}(V1, V2, adjoint)
25+
return new{T, S, A}(V1, V2, adjoint)
2826
# partial construction: only construct rowr and colr when needed
2927
end
3028
end
3129
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
32-
return BraidingTensor{T, S}(V1, V2, adjoint)
30+
return BraidingTensor{T, S, Vector{T}}(V1, V2, adjoint)
3331
end
34-
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) where {T}
35-
return BraidingTensor{T}(promote(V1, V2)..., adjoint)
32+
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
33+
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
34+
return BraidingTensor{T, S, Vector{T}}(V1, V2, adjoint)
35+
end
36+
# necessary due to HomSpace ctor below
37+
function BraidingTensor(V1::S, V2::S, A, adjoint::Bool = false) where {S <: IndexSpace}
38+
T = eltype(A)
39+
return BraidingTensor{T, S, A}(V1, V2, adjoint)
3640
end
3741
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false)
3842
return BraidingTensor(promote(V1, V2)..., adjoint)
3943
end
40-
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
41-
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
42-
return BraidingTensor{T, S}(V1, V2, adjoint)
43-
end
4444
function BraidingTensor(V::HomSpace, adjoint::Bool = false)
4545
domain(V) == reverse(codomain(V)) ||
4646
throw(SpaceMismatch("Cannot define a braiding on $V"))
4747
return BraidingTensor(V[2], V[1], adjoint)
4848
end
49+
function BraidingTensor(V::HomSpace, A, adjoint::Bool = false)
50+
domain(V) == reverse(codomain(V)) ||
51+
throw(SpaceMismatch("Cannot define a braiding on $V"))
52+
return BraidingTensor(V[2], V[1], A, adjoint)
53+
end
4954
function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T}
5055
domain(V) == reverse(codomain(V)) ||
5156
throw(SpaceMismatch("Cannot define a braiding on $V"))
5257
return BraidingTensor{T}(V[2], V[1], adjoint)
5358
end
54-
function Base.adjoint(b::BraidingTensor{T, S}) where {T, S}
55-
return BraidingTensor{T, S}(b.V1, b.V2, !b.adjoint)
59+
function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A}
60+
return BraidingTensor{T, S, A}(b.V1, b.V2, !b.adjoint)
5661
end
5762

63+
storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A
5864
space(b::BraidingTensor) = b.adjoint ? b.V1 b.V2 b.V2 b.V1 : b.V2 b.V1 b.V1 b.V2
5965

60-
# specializations to ignore the storagetype of BraidingTensor
61-
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B)
62-
promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A)
63-
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A)
64-
65-
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} =
66-
similarstoragetype(B, T)
67-
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} =
68-
similarstoragetype(A, T)
69-
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} =
70-
similarstoragetype(A, T)
71-
7266
function Base.getindex(b::BraidingTensor)
7367
sectortype(b) === Trivial || throw(SectorMismatch())
7468
(V1, V2) = domain(b)
@@ -99,6 +93,12 @@ function _braiding_factor(f₁, f₂, inv::Bool = false)
9993
return r
10094
end
10195

96+
function _set_subblock!(data, val)
97+
f(I) = ((I[1] == I[4]) & (I[2] == I[3])) * val
98+
return data .= f.(CartesianIndices(data))
99+
end
100+
101+
102102
@inline function subblock(
103103
b::BraidingTensor, (f₁, f₂)::Tuple{FusionTree{I, 2}, FusionTree{I, 2}}
104104
) where {I <: Sector}
@@ -115,15 +115,12 @@ end
115115
d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...)
116116
n1 = d[1] * d[2]
117117
n2 = d[3] * d[4]
118-
data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d)
118+
data_parent = storagetype(b)(undef, prod(d))
119+
data = sreshape(StridedView(data_parent), d)
119120
fill!(data, zero(eltype(b)))
120121

121122
r = _braiding_factor(f₁, f₂, b.adjoint)
122-
if !isnothing(r)
123-
@inbounds for i in axes(data, 1), j in axes(data, 2)
124-
data[i, j, j, i] = r
125-
end
126-
end
123+
!isnothing(r) && _set_subblock!(data, r)
127124
return data
128125
end
129126

@@ -134,33 +131,20 @@ TensorMap(b::BraidingTensor) = copy!(similar(b), b)
134131
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)
135132

136133
Base.complex(b::BraidingTensor{<:Complex}) = b
137-
function Base.complex(b::BraidingTensor)
138-
return BraidingTensor{complex(scalartype(b))}(space(b), b.adjoint)
134+
function Base.complex(b::BraidingTensor{T, S, A}) where {T, S, A}
135+
Ac = similarstoragetype(A, complex(T))
136+
return BraidingTensor(space(b), Ac, b.adjoint)
139137
end
140138

141-
function block(b::BraidingTensor, s::Sector)
142-
I = sectortype(b)
143-
I == typeof(s) || throw(SectorMismatch())
144-
145-
# TODO: probably always square?
146-
m = blockdim(codomain(b), s)
147-
n = blockdim(domain(b), s)
148-
data = Matrix{eltype(b)}(undef, (m, n))
149-
150-
length(data) == 0 && return data # s ∉ blocksectors(b)
151-
152-
data = fill!(data, zero(eltype(b)))
153-
139+
function _trivial_subblock!(data, b::BraidingTensor)
154140
V1, V2 = codomain(b)
155-
if sectortype(b) === Trivial
156-
d1, d2 = dim(V1), dim(V2)
157-
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
158-
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
159-
subblock[i, j, j, i] = one(eltype(b))
160-
end
161-
return data
162-
end
141+
d1, d2 = dim(V1), dim(V2)
142+
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
143+
_set_subblock!(subblock, one(eltype(b)))
144+
return data
145+
end
163146

147+
function _nontrivial_subblock!(data, b::BraidingTensor, s::Sector)
164148
base_offset = first(blockstructure(b)[s][2]) - 1
165149

166150
for ((f₁, f₂), (sz, str, off)) in pairs(subblockstructure(space(b)))
@@ -169,14 +153,31 @@ function block(b::BraidingTensor, s::Sector)
169153
isnothing(r) && continue
170154
# change offset to account for single block
171155
subblock = StridedView(data, sz, str, off - base_offset)
172-
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
173-
subblock[i, j, j, i] = r
174-
end
156+
_set_subblock!(subblock, r)
175157
end
176-
177158
return data
178159
end
179160

161+
function block(b::BraidingTensor, s::Sector)
162+
I = sectortype(b)
163+
I == typeof(s) || throw(SectorMismatch())
164+
165+
# TODO: probably always square?
166+
m = blockdim(codomain(b), s)
167+
n = blockdim(domain(b), s)
168+
169+
data = reshape(storagetype(b)(undef, m * n), (m, n))
170+
171+
m * n == 0 && return data # s ∉ blocksectors(b)
172+
fill!(data, zero(eltype(b)))
173+
174+
if sectortype(b) === Trivial
175+
return _trivial_subblock!(data, b)
176+
else
177+
return _nontrivial_subblock!(data, b, s)
178+
end
179+
end
180+
180181
# Index manipulations
181182
# -------------------
182183
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false

0 commit comments

Comments
 (0)