Skip to content

Commit 8a12178

Browse files
committed
More tweaks
1 parent cdf66fb commit 8a12178

File tree

9 files changed

+145
-133
lines changed

9 files changed

+145
-133
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Printf = "1"
5555
Random = "1"
5656
SafeTestsets = "0.1"
5757
ScopedValues = "1.3.0"
58-
Strided = "2"
58+
Strided = "2.3.4"
5959
TensorKitSectors = "0.3.6"
6060
TensorOperations = "5.1"
6161
Test = "1"
@@ -89,3 +89,6 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
8989

9090
[targets]
9191
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "JET", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]
92+
93+
[sources]
94+
Strided = {url = "https://github.com/QuantumKitHub/Strided.jl", rev = "ksh/copyto"}

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using TensorKit.Factorizations
1010
using TensorKit.Strided
1111
using TensorKit.Factorizations: AbstractAlgorithm
1212
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
13-
import TensorKit: randisometry, rand, randn
13+
import TensorKit: randisometry, rand, randn, _copyto!, _add_general_kernel_nonthreaded!, blocktype
1414

1515
using TensorKit: MatrixAlgebraKit
1616

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂,
66
function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
77
return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t))
88
end
9+
function TensorMap{T, S, N₁, N₂, DA}(t::TensorMap{T, S, N₁, N₂, HA}) where {T, S, N₁, N₂, DA <: CuArray{T}, HA <: Array{T}}
10+
return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t))
11+
end
912

1013
# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
1114
function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}}
@@ -101,18 +104,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
101104
return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)]
102105
end
103106

104-
function Base.convert(
105-
TT::Type{CuTensorMap{T, S, N₁, N₂}},
106-
t::AbstractTensorMap{<:Any, S, N₁, N₂}
107-
) where {T, S, N₁, N₂}
108-
if typeof(t) === TT
109-
return t
110-
else
111-
tnew = TT(undef, space(t))
112-
return copy!(tnew, t)
113-
end
114-
end
115-
116107
function LinearAlgebra.isposdef(t::CuTensorMap)
117108
domain(t) == codomain(t) ||
118109
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
@@ -138,10 +129,9 @@ function Base.promote_rule(
138129
return CuTensorMap{T, S, N₁, N₂}
139130
end
140131

141-
TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
132+
TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
142133
CuArray{T, N, CUDA.default_memory}
143134

144-
145135
# CuTensorMap exponentiation:
146136
function TensorKit.exp!(t::CuTensorMap)
147137
domain(t) == codomain(t) ||
@@ -168,3 +158,21 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
168158
return tf
169159
end
170160
end
161+
162+
function TensorKit.add_kernel_nonthreaded!(
163+
tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
164+
)
165+
# preallocate buffers
166+
buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer)
167+
168+
for subtransformer in transformer.data
169+
# Special case without intermediate buffers whenever there is only a single block
170+
if length(subtransformer[1]) == 1
171+
TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
172+
else
173+
cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...)
174+
TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...)
175+
end
176+
end
177+
return nothing
178+
end

src/tensors/abstracttensor.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ storagetype(t) = storagetype(typeof(t))
5353
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
5454
if T isa Union
5555
# attempt to be slightly more specific by promoting unions
56-
Ma = storagetype(T.a)
57-
Mb = storagetype(T.b)
58-
return promote_storagetype(Ma, Mb)
56+
return promote_storagetype(T.a, T.b)
57+
elseif eltype(T) isa Union
58+
# same idea when the element type of T (rather than T itself) is a union: promote its two members
59+
TU = eltype(T)
60+
return promote_storagetype(TU.a, TU.b)
5961
else
6062
# fallback definition by using scalartype
6163
return similarstoragetype(scalartype(T))
@@ -103,11 +105,19 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =
103105

104106
# implement on tensors
105107
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
106-
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
107-
similarstoragetype(storagetype(TT), T)
108+
function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number}
109+
return similarstoragetype(storagetype(TT), T)
110+
end
111+
function similarstoragetype(::Type{<:AbstractTensorMap{T, S, N₁, N₂}}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂}
112+
return similarstoragetype(TA, T)
113+
end
114+
function similarstoragetype(t::AbstractTensorMap{T, S, N₁, N₂}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂}
115+
return similarstoragetype(typeof(t), TA)
116+
end
108117

109118
# implement on arrays
110119
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
120+
similarstoragetype(::Type{A}, ::Type{A}) where {A <: DenseVector{<:Number}} = A
111121
Base.@assume_effects :foldable similarstoragetype(::Type{A}) where {A <: AbstractArray{<:Number}} =
112122
Core.Compiler.return_type(similar, Tuple{A, Int})
113123
Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} =

src/tensors/adjoint.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Base.adjoint(t::AbstractTensorMap) = AdjointTensorMap(t)
2222
space(t::AdjointTensorMap) = adjoint(space(parent(t)))
2323
dim(t::AdjointTensorMap) = dim(parent(t))
2424
storagetype(::Type{AdjointTensorMap{T, S, N₁, N₂, TT}}) where {T, S, N₁, N₂, TT} = storagetype(TT)
25+
similarstoragetype(::AdjointTensorMap{T, S, N₁, N₂, TT}, ::Type{T′}) where {T, S, N₁, N₂, TT, T′ <: Number} = similarstoragetype(TT, T′)
26+
similarstoragetype(::AdjointTensorMap{T, S, N₁, N₂, TT}, ::Type{TA}) where {T, S, N₁, N₂, TT, TA <: DenseVector} = similarstoragetype(TT, TA)
2527

2628
# Blocks and subblocks
2729
#----------------------

src/tensors/braidingtensor.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,10 @@ function block(b::BraidingTensor, s::Sector)
145145
# TODO: probably always square?
146146
m = blockdim(codomain(b), s)
147147
n = blockdim(domain(b), s)
148-
data = Matrix{eltype(b)}(undef, (m, n))
148+
data = zeros(eltype(b), (m, n))
149149

150150
length(data) == 0 && return data # s ∉ blocksectors(b)
151151

152-
data = fill!(data, zero(eltype(b)))
153-
154152
V1, V2 = codomain(b)
155153
if sectortype(b) === Trivial
156154
d1, d2 = dim(V1), dim(V2)
@@ -182,12 +180,15 @@ end
182180
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
183181
function add_transform!(
184182
tdst::AbstractTensorMap,
185-
tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple,
183+
tsrc::BraidingTensor{T, S},
184+
(p₁, p₂)::Index2Tuple,
186185
fusiontreetransform,
187186
α::Number, β::Number, backend::AbstractBackend...
188-
)
187+
) where {T, S}
188+
tsrc_map = similar(tdst, storagetype(tdst), space(tsrc))
189+
copy!(tsrc_map, tsrc)
189190
return add_transform!(
190-
tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
191+
tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β,
191192
backend...
192193
)
193194
end
@@ -287,11 +288,15 @@ function planarcontract!(
287288
backend, allocator
288289
)
289290
# special case only defined for contracting 2 indices
290-
length(oindB) == length(cindB) == 2 ||
291+
if !(length(oindB) == length(cindB) == 2)
292+
# FIXME (horrible): densify B into a TensorMap, then convert it to a storage type derived from A
293+
tB′ = TensorMap(B)
294+
tB = TensorMapWithStorage{eltype(B), similarstoragetype(A, eltype(B)), spacetype(tB′), numout(tB′), numin(tB′)}(tB′)
291295
return planarcontract!(
292-
C, A, (oindA, cindA), TensorMap(B), (cindB, oindB), (p1, p2),
293-
α, β, backend, allocator
294-
)
296+
C, A, (oindA, cindA), tB, (cindB, oindB), (p1, p2),
297+
α, β, backend, allocator
298+
)
299+
end
295300

296301
codA, domA = codomainind(A), domainind(A)
297302
codB, domB = codomainind(B), domainind(B)

src/tensors/diagonal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ end
280280
# ----------------
281281
function TO.tensoradd_type(TC, A::DiagonalTensorMap, ::Index2Tuple{1, 1}, ::Bool)
282282
M = similarstoragetype(A, TC)
283-
return DiagonalTensorMap{TC, spacetype(A), M}
283+
return DiagonalTensorMap{scalartype(M), spacetype(A), M}
284284
end
285285

286286
function TO.tensorcontract_type(

src/tensors/indexmanipulations.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ for (operation, manipulation) in (
1717
$promote_op(::Type{T}, ::Type{I}) where {T <: Number, I <: Sector} =
1818
sectorscalartype(I) <: Integer ? T :
1919
sectorscalartype(I) <: Real ? float(T) : complex(T)
20+
$promote_op(::Type{TA}, ::Type{I}) where {TA <: DenseVector, I <: Sector} =
21+
similarstoragetype(TA, $promote_op(eltype(TA), I))
2022
# TODO: currently the manipulations all use sectorscalartype, change to:
2123
# $manipulation_scalartype(I) <: Integer ? T :
2224
# $manipulation_scalartype(I) <: Real ? float(T) : complex(T)
@@ -342,11 +344,11 @@ See also [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i})
342344
"""
343345
function insertleftunit(
344346
t::AbstractTensorMap, ::Val{i} = Val(numind(t) + 1);
345-
copy::Bool = false, conj::Bool = false, dual::Bool = false
347+
copy::Bool = false, conj::Bool = false, dual::Bool = false,
346348
) where {i}
347349
W = insertleftunit(space(t), Val(i); conj, dual)
348350
if t isa TensorMap
349-
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
351+
return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W)
350352
else
351353
tdst = similar(t, W)
352354
for (c, b) in blocks(t)
@@ -371,11 +373,11 @@ See also [`insertleftunit`](@ref insertleftunit(::AbstractTensorMap, ::Val{i}) w
371373
"""
372374
function insertrightunit(
373375
t::AbstractTensorMap, ::Val{i} = Val(numind(t));
374-
copy::Bool = false, conj::Bool = false, dual::Bool = false
376+
copy::Bool = false, conj::Bool = false, dual::Bool = false,
375377
) where {i}
376378
W = insertrightunit(space(t), Val(i); conj, dual)
377379
if t isa TensorMap
378-
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
380+
return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W)
379381
else
380382
tdst = similar(t, W)
381383
for (c, b) in blocks(t)
@@ -400,7 +402,7 @@ and [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) wher
400402
function removeunit(t::AbstractTensorMap, ::Val{i}; copy::Bool = false) where {i}
401403
W = removeunit(space(t), Val(i))
402404
if t isa TensorMap
403-
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
405+
return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W)
404406
else
405407
tdst = similar(t, W)
406408
for (c, b) in blocks(t)

0 commit comments

Comments
 (0)