Skip to content

Commit acb439c

Browse files
committed
Cleanup and comments
1 parent fa065bd commit acb439c

4 files changed

Lines changed: 32 additions & 38 deletions

File tree

ext/TensorKitAdaptExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap)
1616
return DiagonalTensorMap(data′, x.domain)
1717
end
1818
function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A}
19-
return BraidingTensor(space(x), TensorKit.similarstoragetype(A, T), x.adjoint)
19+
return BraidingTensor{T}(space(x), x.adjoint)
2020
end
21-
function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {TA <: DenseArray{<:Number}, T, S, A}
22-
return BraidingTensor(space(x), TA, x.adjoint)
21+
function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {T′, TA <: DenseArray{T′}, T, S, A}
22+
return BraidingTensor{T′, S, TA}(space(x), x.adjoint)
2323
end
2424

2525
end

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using TensorKit.Factorizations
1414
using TensorKit.Strided
1515
using TensorKit.Factorizations: AbstractAlgorithm
1616
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
17-
import TensorKit: randisometry, rand, randn, _set_subblock!
17+
import TensorKit: randisometry, rand, randn, fill_braidingsubblock!
1818

1919
using TensorKit: MatrixAlgebraKit
2020

@@ -23,15 +23,14 @@ using Random
2323
include("cutensormap.jl")
2424
include("truncation.jl")
2525

26-
function TensorKit._set_subblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}}
26+
function TensorKit.fill_braidingsubblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}}
2727
@kernel function fill_subblock_kernel!(subblock, val)
2828
idx = @index(Global, Cartesian)
29-
@inbounds subblock[idx[1], idx[2], idx[2], idx[1]] = val
29+
idx_val = idx[1] == idx[4] && idx[2] == idx[3] ? val : zero(val)
30+
@inbounds subblock[idx] = idx_val
3031
end
3132
kernel = fill_subblock_kernel!(KernelAbstractions.get_backend(data))
32-
d1 = size(data, 1)
33-
d2 = size(data, 2)
34-
kernel(data, val; ndrange = (d1, d2))
33+
kernel(data, val; ndrange = size(data))
3534
return data
3635
end
3736

src/tensors/braidingtensor.jl

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
#====================================================================#
44
"""
55
struct BraidingTensor{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
6-
BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool=false) where {S<:IndexSpace, A <: DenseVector{<:Number}}
6+
BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
7+
BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool=false) where {T, S, A}
78
89
Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
910
braids the first input over the second input; its inverse can be obtained as the adjoint.
@@ -27,16 +28,11 @@ struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2}
2728
end
2829
end
2930
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
30-
return BraidingTensor{T, S, Vector{T}}(V1, V2, adjoint)
31+
return BraidingTensor{T, S, similarstoragetype{T}}(V1, V2, adjoint)
3132
end
3233
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
3334
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
34-
return BraidingTensor{T, S, Vector{T}}(V1, V2, adjoint)
35-
end
36-
# necessary due to HomSpace ctor below
37-
function BraidingTensor(V1::S, V2::S, A, adjoint::Bool = false) where {S <: IndexSpace}
38-
T = eltype(A)
39-
return BraidingTensor{T, S, A}(V1, V2, adjoint)
35+
return BraidingTensor{T}(V1, V2, adjoint)
4036
end
4137
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false)
4238
return BraidingTensor(promote(V1, V2)..., adjoint)
@@ -46,10 +42,10 @@ function BraidingTensor(V::HomSpace, adjoint::Bool = false)
4642
throw(SpaceMismatch("Cannot define a braiding on $V"))
4743
return BraidingTensor(V[2], V[1], adjoint)
4844
end
49-
function BraidingTensor(V::HomSpace, A, adjoint::Bool = false)
45+
function BraidingTensor{T, S, A}(V::HomSpace, adjoint::Bool = false) where {T, S, A}
5046
domain(V) == reverse(codomain(V)) ||
5147
throw(SpaceMismatch("Cannot define a braiding on $V"))
52-
return BraidingTensor(V[2], V[1], A, adjoint)
48+
return BraidingTensor{T, S, A}(V[2], V[1], adjoint)
5349
end
5450
function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T}
5551
domain(V) == reverse(codomain(V)) ||
@@ -93,7 +89,8 @@ function _braiding_factor(f₁, f₂, inv::Bool = false)
9389
return r
9490
end
9591

96-
function _set_subblock!(data, val)
92+
# generates scalar indexing errors on GPU
93+
function fill_braidingsubblock!(data, val)
9794
f(I) = ((I[1] == I[4]) & (I[2] == I[3])) * val
9895
return data .= f.(CartesianIndices(data))
9996
end
@@ -117,10 +114,8 @@ end
117114
n2 = d[3] * d[4]
118115
data_parent = storagetype(b)(undef, prod(d))
119116
data = sreshape(StridedView(data_parent), d)
120-
fill!(data, zero(eltype(b)))
121-
122117
r = _braiding_factor(f₁, f₂, b.adjoint)
123-
!isnothing(r) && _set_subblock!(data, r)
118+
!isnothing(r) && fill_braidingsubblock!(data, r)
124119
return data
125120
end
126121

@@ -132,28 +127,33 @@ Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)
132127

133128
Base.complex(b::BraidingTensor{<:Complex}) = b
134129
function Base.complex(b::BraidingTensor{T, S, A}) where {T, S, A}
135-
Ac = similarstoragetype(A, complex(T))
136-
return BraidingTensor(space(b), Ac, b.adjoint)
130+
Tc = complex(T)
131+
Ac = similarstoragetype(A, Tc)
132+
return BraidingTensor{Tc, S, Ac}(space(b), b.adjoint)
137133
end
138134

139-
function _trivial_subblock!(data, b::BraidingTensor)
135+
# Trivial
136+
function fill_braidingblock!(data, b::BraidingTensor, s::Trivial)
140137
V1, V2 = codomain(b)
141138
d1, d2 = dim(V1), dim(V2)
142139
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
143-
_set_subblock!(subblock, one(eltype(b)))
140+
fill_braidingsubblock!(subblock, one(eltype(b)))
144141
return data
145142
end
146143

147-
function _nontrivial_subblock!(data, b::BraidingTensor, s::Sector)
144+
# Nontrivial
145+
function fill_braidingblock!(data, b::BraidingTensor, s::Sector)
148146
base_offset = first(blockstructure(b)[s][2]) - 1
149147

150148
for ((f₁, f₂), (sz, str, off)) in pairs(subblockstructure(space(b)))
151149
(f₁.coupled == f₂.coupled == s) || continue
152150
r = _braiding_factor(f₁, f₂, b.adjoint)
153-
isnothing(r) && continue
154151
# change offset to account for single block
155152
subblock = StridedView(data, sz, str, off - base_offset)
156-
_set_subblock!(subblock, r)
153+
# without the zero-value, the non-trivial block is not set
154+
# correctly in the GPU case
155+
val = isnothing(r) ? zero(eltype(data)) : r
156+
fill_braidingsubblock!(subblock, val)
157157
end
158158
return data
159159
end
@@ -169,13 +169,8 @@ function block(b::BraidingTensor, s::Sector)
169169
data = reshape(storagetype(b)(undef, m * n), (m, n))
170170

171171
m * n == 0 && return data # s ∉ blocksectors(b)
172-
fill!(data, zero(eltype(b)))
173172

174-
if sectortype(b) === Trivial
175-
return _trivial_subblock!(data, b)
176-
else
177-
return _nontrivial_subblock!(data, b, s)
178-
end
173+
return fill_braidingblock!(data, b, s)
179174
end
180175

181176
# Index manipulations

test/cuda/planar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ for V in spacelist
1515
@timedtestset "Braiding tensor + CUDA with symmetry: $Istr" verbose = true begin
1616
W = V[1] V[2] V[2] V[1]
1717
T = isreal(sectortype(W)) ? Float64 : ComplexF64
18-
t1 = @constinferred BraidingTensor(W, CuVector{T, CUDA.DeviceMemory})
18+
t1 = @constinferred BraidingTensor{T, spacetype(V[2]), CuVector{T, CUDA.DeviceMemory}}(W)
1919
@test space(t1) == W
2020
@test codomain(t1) == codomain(W)
2121
@test domain(t1) == domain(W)
2222
@test scalartype(t1) == (isreal(sectortype(W)) ? Float64 : ComplexF64)
2323
@test storagetype(t1) == CuVector{scalartype(t1), CUDA.DeviceMemory}
24-
t2 = @constinferred BraidingTensor(W, CuVector{ComplexF64, CUDA.DeviceMemory})
24+
t2 = @constinferred BraidingTensor{ComplexF64, spacetype(V[2]), CuVector{ComplexF64, CUDA.DeviceMemory}}(W)
2525
@test scalartype(t2) == ComplexF64
2626
@test storagetype(t2) == CuVector{ComplexF64, CUDA.DeviceMemory}
2727
t3 = @testinferred adapt(storagetype(t2), t1)

0 commit comments

Comments
 (0)