Skip to content

Commit 0b283c8

Browse files
kshyattlkdvos
andauthored
Apply suggestions from code review
Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 3ac60b5 commit 0b283c8

3 files changed

Lines changed: 8 additions & 9 deletions

File tree

ext/TensorKitAdaptExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap)
1616
return DiagonalTensorMap(data′, x.domain)
1717
end
1818
function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A}
19-
return BraidingTensor{T}(space(x), x.adjoint)
19+
A = TensorKit.similarstoragetype(A, T)
20+
return BraidingTensor{T, S, A}(space(x), x.adjoint)
2021
end
2122
function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {T′, TA <: DenseArray{T′}, T, S, A}
2223
return BraidingTensor{T′, S, TA}(space(x), x.adjoint)

src/tensors/braidingtensor.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# special (2,2) tensor that implements a standard braiding operation
33
#====================================================================#
44
"""
5-
struct BraidingTensor{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
5+
struct BraidingTensor{T, S <: IndexSpace, A <: DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
66
BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
77
BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool=false) where {T, S, A}
88
@@ -14,7 +14,7 @@ It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and
1414
controls the array type of the braiding tensor used when indexing
1515
and multiplying with other tensors.
1616
"""
17-
struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2}
17+
struct BraidingTensor{T, S, A <: DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
1818
V1::S
1919
V2::S
2020
adjoint::Bool
@@ -133,8 +133,7 @@ end
133133
data_parent = storagetype(b)(undef, prod(d))
134134
data = sreshape(StridedView(data_parent), d)
135135
r = _braiding_factor(f₁, f₂, b.adjoint)
136-
val = isnothing(r) ? zero(eltype(b)) : r
137-
fill_braidingsubblock!(data, val)
136+
isnothing(r) ? zerovector!(data) : fill_braidingsubblock!(data, r)
138137
return data
139138
end
140139

@@ -171,8 +170,7 @@ function fill_braidingblock!(data, b::BraidingTensor, s::Sector)
171170
subblock = StridedView(data, sz, str, off - base_offset)
172171
# without the zero-value, the non-trivial block is not set
173172
# correctly in the GPU case
174-
val = isnothing(r) ? zero(eltype(data)) : r
175-
fill_braidingsubblock!(subblock, val)
173+
isnothing(r) ? zerovector!(subblock) : fill_braidingsubblock!(subblock, r)
176174
end
177175
return data
178176
end

test/setup.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ function force_planar(tsrc::TensorMap{<:Any, ComplexSpace})
115115
undef,
116116
force_planar(codomain(tsrc))
117117
force_planar(domain(tsrc))
118-
)
118+
tdst = similar(tsrc, force_planar(codomain(tsrc)) force_planar(domain(tsrc)))
119119
copyto!(block(tdst, PlanarTrivial()), block(tsrc, Trivial()))
120120
return tdst
121121
end
@@ -124,7 +124,7 @@ function force_planar(tsrc::TensorMap{<:Any, <:GradedSpace})
124124
undef,
125125
force_planar(codomain(tsrc))
126126
force_planar(domain(tsrc))
127-
)
127+
tdst = similar(tsrc, force_planar(codomain(tsrc)) force_planar(domain(tsrc)))
128128
for (c, b) in blocks(tsrc)
129129
copyto!(block(tdst, c PlanarTrivial()), b)
130130
end

0 commit comments

Comments
 (0)