Skip to content

Commit fecb81d

Browse files
committed
More tweaks
1 parent 3dfc99f commit fecb81d

8 files changed

Lines changed: 108 additions & 109 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ OhMyThreads = "0.8.0"
5353
Printf = "1"
5454
Random = "1"
5555
ScopedValues = "1.3.0"
56-
Strided = "2"
56+
Strided = "2.3.4"
5757
TensorKitSectors = "0.3.7"
5858
TensorOperations = "5.1"
5959
TupleTools = "1.5"

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂,
66
function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
77
return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t))
88
end
9+
function TensorMap{T, S, N₁, N₂, DA}(t::TensorMap{T, S, N₁, N₂, HA}) where {T, S, N₁, N₂, DA <: CuArray{T}, HA <: Array{T}}
10+
return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t))
11+
end
912

1013
# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
1114
function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}}
@@ -101,18 +104,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
101104
return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)]
102105
end
103106

104-
function Base.convert(
105-
TT::Type{CuTensorMap{T, S, N₁, N₂}},
106-
t::AbstractTensorMap{<:Any, S, N₁, N₂}
107-
) where {T, S, N₁, N₂}
108-
if typeof(t) === TT
109-
return t
110-
else
111-
tnew = TT(undef, space(t))
112-
return copy!(tnew, t)
113-
end
114-
end
115-
116107
function LinearAlgebra.isposdef(t::CuTensorMap)
117108
domain(t) == codomain(t) ||
118109
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
@@ -138,10 +129,9 @@ function Base.promote_rule(
138129
return CuTensorMap{T, S, N₁, N₂}
139130
end
140131

141-
TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
132+
TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
142133
CuArray{T, N, CUDA.default_memory}
143134

144-
145135
# CuTensorMap exponentation:
146136
function TensorKit.exp!(t::CuTensorMap)
147137
domain(t) == codomain(t) ||

src/tensors/abstracttensor.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ storagetype(t) = storagetype(typeof(t))
5353
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
5454
if T isa Union
5555
# attempt to be slightly more specific by promoting unions
56-
Ma = storagetype(T.a)
57-
Mb = storagetype(T.b)
58-
return promote_storagetype(Ma, Mb)
56+
return promote_storagetype(T.a, T.b)
57+
elseif eltype(T) isa Union
58+
# attempt to be slightly more specific by promoting unions
59+
TU = eltype(T)
60+
return promote_storagetype(TU.a, TU.b)
5961
else
6062
# fallback definition by using scalartype
6163
return similarstoragetype(scalartype(T))
@@ -103,11 +105,19 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =
103105

104106
# implement on tensors
105107
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
106-
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
107-
similarstoragetype(storagetype(TT), T)
108+
function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number}
109+
return similarstoragetype(storagetype(TT), T)
110+
end
111+
function similarstoragetype(::Type{<:AbstractTensorMap{T, S, N₁, N₂}}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂}
112+
return similarstoragetype(TA, T)
113+
end
114+
function similarstoragetype(t::AbstractTensorMap{T, S, N₁, N₂}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂}
115+
return similarstoragetype(typeof(t), TA)
116+
end
108117

109118
# implement on arrays
110119
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
120+
similarstoragetype(::Type{A}, ::Type{A}) where {A <: DenseVector{<:Number}} = A
111121
Base.@assume_effects :foldable similarstoragetype(::Type{A}) where {A <: AbstractArray{<:Number}} =
112122
Core.Compiler.return_type(similar, Tuple{A, Int})
113123
Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} =

src/tensors/adjoint.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Base.adjoint(t::AbstractTensorMap) = AdjointTensorMap(t)
2222
space(t::AdjointTensorMap) = adjoint(space(parent(t)))
2323
dim(t::AdjointTensorMap) = dim(parent(t))
2424
storagetype(::Type{AdjointTensorMap{T, S, N₁, N₂, TT}}) where {T, S, N₁, N₂, TT} = storagetype(TT)
25+
similarstoragetype(::AdjointTensorMap{T, S, N₁, N₂, TT}, ::Type{T′}) where {T, S, N₁, N₂, TT, T′ <: Number} = similarstoragetype(TT, T′)
26+
similarstoragetype(::AdjointTensorMap{T, S, N₁, N₂, TT}, ::Type{TA}) where {T, S, N₁, N₂, TT, TA <: DenseVector} = similarstoragetype(TT, TA)
2527

2628
# Blocks and subblocks
2729
#----------------------

src/tensors/braidingtensor.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,15 @@ end
182182
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
183183
function add_transform!(
184184
tdst::AbstractTensorMap,
185-
tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple,
185+
tsrc::BraidingTensor{T, S},
186+
(p₁, p₂)::Index2Tuple,
186187
fusiontreetransform,
187188
α::Number, β::Number, backend::AbstractBackend...
188-
)
189+
) where {T, S}
190+
tsrc_map = similar(tdst, storagetype(tdst), space(tsrc))
191+
copy!(tsrc_map, tsrc)
189192
return add_transform!(
190-
tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
193+
tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β,
191194
backend...
192195
)
193196
end
@@ -287,11 +290,15 @@ function planarcontract!(
287290
backend, allocator
288291
)
289292
# special case only defined for contracting 2 indices
290-
length(oindB) == length(cindB) == 2 ||
293+
if !(length(oindB) == length(cindB) == 2)
294+
# horrible!!!!!
295+
tB′ = TensorMap(B)
296+
tB = TensorMapWithStorage{eltype(B), similarstoragetype(A, eltype(B)), spacetype(tB′), numout(tB′), numin(tB′)}(tB′)
291297
return planarcontract!(
292-
C, A, (oindA, cindA), TensorMap(B), (cindB, oindB), (p1, p2),
293-
α, β, backend, allocator
294-
)
298+
C, A, (oindA, cindA), tB, (cindB, oindB), (p1, p2),
299+
α, β, backend, allocator
300+
)
301+
end
295302

296303
codA, domA = codomainind(A), domainind(A)
297304
codB, domB = codomainind(B), domainind(B)

src/tensors/diagonal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ end
280280
# ----------------
281281
function TO.tensoradd_type(TC, A::DiagonalTensorMap, ::Index2Tuple{1, 1}, ::Bool)
282282
M = similarstoragetype(A, TC)
283-
return DiagonalTensorMap{TC, spacetype(A), M}
283+
return DiagonalTensorMap{scalartype(M), spacetype(A), M}
284284
end
285285

286286
function TO.tensorcontract_type(

src/tensors/indexmanipulations.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ for (operation, manipulation) in (
1717
$promote_op(::Type{T}, ::Type{I}) where {T <: Number, I <: Sector} =
1818
sectorscalartype(I) <: Integer ? T :
1919
sectorscalartype(I) <: Real ? float(T) : complex(T)
20+
$promote_op(::Type{TA}, ::Type{I}) where {TA <: DenseVector, I <: Sector} =
21+
similarstoragetype(TA, $promote_op(eltype(TA), I))
2022
# TODO: currently the manipulations all use sectorscalartype, change to:
2123
# $manipulation_scalartype(I) <: Integer ? T :
2224
# $manipulation_scalartype(I) <: Real ? float(T) : complex(T)
@@ -342,11 +344,11 @@ See also [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i})
342344
"""
343345
function insertleftunit(
344346
t::AbstractTensorMap, ::Val{i} = Val(numind(t) + 1);
345-
copy::Bool = false, conj::Bool = false, dual::Bool = false
347+
copy::Bool = false, conj::Bool = false, dual::Bool = false,
346348
) where {i}
347349
W = insertleftunit(space(t), Val(i); conj, dual)
348350
if t isa TensorMap
349-
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
351+
return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W)
350352
else
351353
tdst = similar(t, W)
352354
for (c, b) in blocks(t)
@@ -371,11 +373,11 @@ See also [`insertleftunit`](@ref insertleftunit(::AbstractTensorMap, ::Val{i}) w
371373
"""
372374
function insertrightunit(
373375
t::AbstractTensorMap, ::Val{i} = Val(numind(t));
374-
copy::Bool = false, conj::Bool = false, dual::Bool = false
376+
copy::Bool = false, conj::Bool = false, dual::Bool = false,
375377
) where {i}
376378
W = insertrightunit(space(t), Val(i); conj, dual)
377379
if t isa TensorMap
378-
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
380+
return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W)
379381
else
380382
tdst = similar(t, W)
381383
for (c, b) in blocks(t)
@@ -400,7 +402,7 @@ and [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) wher
400402
function removeunit(t::AbstractTensorMap, ::Val{i}; copy::Bool = false) where {i}
401403
W = removeunit(space(t), Val(i))
402404
if t isa TensorMap
403-
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
405+
return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W)
404406
else
405407
tdst = similar(t, W)
406408
for (c, b) in blocks(t)

0 commit comments

Comments
 (0)