Skip to content

Commit e3bdab4

Browse files
kshyatt and lkdvos authored
fix: improvements to bypass scalar indexing and improve GPU support (QuantumKitHub#375)
* More tweaks for GPU support
* fix typo
* Fix TC once again
* Remove unneeded Adjoint methods
* Remove unneeded TensorMapWithStorage?
* Death to to_cpu
* Remove unneeded similarstoragetype method
* Add in TensorMap constructor
* Restore former braiding tensor methods
* Fix type issue for sortperm
* Remove stale type params
* Apply suggestions from code review

  Co-authored-by: Lukas Devos <ldevos98@gmail.com>
* Fix bad result of suggestion
* Another fix?
* Force inds to move back to the CPU
* Return to glorious scalartype
* Restore DiagonalTensorMap ctor
* Resolve trunc ambiguity
* Remove extra CUDA ctor
* Restore chopped argument
* Also remove no longer needed method
* Remove forced Int eltype
* Get rid of no-op ctor
* Try to resolve ambiguity
* Cover all truncation strategies
* Short-circuit logic in `findtruncated`

  Co-authored-by: Lukas Devos <ldevos98@gmail.com>
* Formatter

---------

Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 02424b3 commit e3bdab4

7 files changed

Lines changed: 161 additions & 166 deletions

File tree

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,18 +101,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
101101
return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)]
102102
end
103103

104-
function Base.convert(
105-
TT::Type{CuTensorMap{T, S, N₁, N₂}},
106-
t::AbstractTensorMap{<:Any, S, N₁, N₂}
107-
) where {T, S, N₁, N₂}
108-
if typeof(t) === TT
109-
return t
110-
else
111-
tnew = TT(undef, space(t))
112-
return copy!(tnew, t)
113-
end
114-
end
115-
116104
function LinearAlgebra.isposdef(t::CuTensorMap)
117105
domain(t) == codomain(t) ||
118106
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
@@ -138,10 +126,9 @@ function Base.promote_rule(
138126
return CuTensorMap{T, S, N₁, N₂}
139127
end
140128

141-
TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
129+
TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
142130
CuArray{T, N, CUDA.default_memory}
143131

144-
145132
# CuTensorMap exponentation:
146133
function TensorKit.exp!(t::CuTensorMap)
147134
domain(t) == codomain(t) ||

ext/TensorKitCUDAExt/truncation.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ function MatrixAlgebraKit.findtruncated(
1010
fill!(v, dim(c))
1111
end
1212

13+
isempty(parent(values)) && return similar(values, Bool)
14+
1315
perm = sortperm(parent(values); strategy.by, strategy.rev)
1416
cumulative_dim = cumsum(Base.permute!(parent(dims), perm))
1517

@@ -36,6 +38,8 @@ function MatrixAlgebraKit.findtruncated(
3638
end
3739
end
3840

41+
isempty(parent(values)) && return similar(values, Bool)
42+
3943
perm = sortperm(parent(values); by = abs, rev = false)
4044
cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm))
4145

@@ -44,6 +48,26 @@ function MatrixAlgebraKit.findtruncated(
4448
return result
4549
end
4650

51+
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy}
52+
# returning a CuSectorVector wrecks things in truncate_{co}domain
53+
# because of scalar indexing
54+
return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
55+
end
56+
57+
for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace))
58+
@eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat)
59+
# returning a CuSectorVector wrecks things in truncate_{co}domain
60+
# because of scalar indexing
61+
return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
62+
end
63+
end
64+
65+
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue)
66+
atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
67+
strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
68+
return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values))
69+
end
70+
4771
# Needed until MatrixAlgebraKit patch hits...
4872
function MatrixAlgebraKit._ind_intersect(A::CuVector{Bool}, B::CuVector{Int})
4973
result = fill!(similar(A), false)

src/tensors/abstracttensor.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ storagetype(t) = storagetype(typeof(t))
5353
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
5454
if T isa Union
5555
# attempt to be slightly more specific by promoting unions
56-
Ma = storagetype(T.a)
57-
Mb = storagetype(T.b)
58-
return promote_storagetype(Ma, Mb)
56+
return promote_storagetype(T.a, T.b)
5957
else
6058
# fallback definition by using scalartype
6159
return similarstoragetype(scalartype(T))
@@ -103,8 +101,9 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =
103101

104102
# implement on tensors
105103
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
106-
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
107-
similarstoragetype(storagetype(TT), T)
104+
function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number}
105+
return similarstoragetype(storagetype(TT), T)
106+
end
108107

109108
# implement on arrays
110109
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A

src/tensors/adjoint.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ Base.@propagate_inbounds function subblock(t::AdjointTensorMap, (f₁, f₂)::Tu
5050
return permutedims(conj(data), (domainind(tp)..., codomainind(tp)...))
5151
end
5252

53-
to_cpu(t::AdjointTensorMap) = adjoint(to_cpu(adjoint(t)))
54-
5553
# Show
5654
#------
5755
function Base.showarg(io::IO, t::AdjointTensorMap, toplevel::Bool)

src/tensors/tensor.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac
2121
end
2222
return TensorMap{T, S, N₁, N₂, A}(data, space)
2323
end
24-
2524
# constructors from data
2625
function TensorMap{T, S, N₁, N₂, A}(
2726
data::A, space::TensorMapSpace{S, N₁, N₂}
@@ -34,6 +33,7 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac
3433
return new{T, S, N₁, N₂, A}(data, space)
3534
end
3635
end
36+
TensorMap{T, S, N₁, N₂, A}(t::TensorMap{T, S, N₁, N₂}) where {T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} = TensorMap(A(t.data), space(t))
3737

3838
"""
3939
Tensor{T, S, N, A<:DenseVector{T}} = TensorMap{T, S, N, 0, A}
@@ -407,11 +407,6 @@ for randf in (:rand, :randn, :randexp, :randisometry)
407407
end
408408
end
409409

410-
# Moving arbitrary TensorMaps to CPU
411-
#-----------------------------
412-
to_cpu(t::TensorMapWithStorage{T, Vector{T}}) where {T} = t # no op
413-
to_cpu(t::TensorMap) = convert(TensorMapWithStorage{scalartype(t), similarstoragetype(scalartype(t))}, t)
414-
415410
# Efficient copy constructors
416411
#-----------------------------
417412
Base.copy(t::TensorMap) = typeof(t)(copy(t.data), t.space)

test/amd/tensors.jl

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ for V in spacelist
9797
for T in (Int, Float32, ComplexF64)
9898
t = @constinferred AMDGPU.rand(T, W)
9999
d = convert(Dict, t)
100-
@test TensorKit.to_cpu(t) == convert(TensorMap, d)
100+
@test adapt(Array, t) == convert(TensorMap, d)
101101
end
102102
end
103103
symmetricbraiding && @timedtestset "Basic linear algebra" begin
@@ -189,10 +189,10 @@ for V in spacelist
189189
t = AMDGPU.rand(T, W)
190190
t2 = @constinferred AMDGPU.rand!(similar(t))
191191
α = rand(T)
192-
@test norm(t, 2) ≈ norm(TensorKit.to_cpu(t), 2)
193-
@test dot(t2, t) ≈ dot(TensorKit.to_cpu(t2), TensorKit.to_cpu(t))
194-
@test TensorKit.to_cpu(α * t) ≈ α * TensorKit.to_cpu(t)
195-
@test TensorKit.to_cpu(t + t) ≈ 2 * TensorKit.to_cpu(t)
192+
@test norm(t, 2) ≈ norm(adapt(Array, t), 2)
193+
@test dot(t2, t) ≈ dot(adapt(Array, t2), adapt(Array, t))
194+
@test adapt(Array, α * t) ≈ α * adapt(Array, t)
195+
@test adapt(Array, t + t) ≈ 2 * adapt(Array, t)
196196
end
197197
end
198198
@timedtestset "Real and imaginary parts" begin
@@ -202,17 +202,17 @@ for V in spacelist
202202

203203
tr = @constinferred real(t)
204204
@test scalartype(tr) <: Real
205-
@test real(TensorKit.to_cpu(t)) == TensorKit.to_cpu(tr)
205+
@test real(adapt(Array, t)) == adapt(Array, tr)
206206
@test storagetype(tr) == ROCVector{real(T), AMDGPU.Mem.HIPBuffer}
207207

208208
ti = @constinferred imag(t)
209209
@test scalartype(ti) <: Real
210-
@test imag(TensorKit.to_cpu(t)) == TensorKit.to_cpu(ti)
210+
@test imag(adapt(Array, t)) == adapt(Array, ti)
211211
@test storagetype(ti) == ROCVector{real(T), AMDGPU.Mem.HIPBuffer}
212212

213213
tc = @inferred complex(t)
214214
@test scalartype(tc) <: Complex
215-
@test complex(TensorKit.to_cpu(t)) == TensorKit.to_cpu(tc)
215+
@test complex(adapt(Array, t)) == adapt(Array, tc)
216216
@test storagetype(tc) == ROCVector{complex(T), AMDGPU.Mem.HIPBuffer}
217217

218218
tc2 = @inferred complex(tr, ti)
@@ -275,13 +275,13 @@ for V in spacelist
275275
p1 = ntuple(n -> p[n], k)
276276
p2 = ntuple(n -> p[k + n], 5 - k)
277277
dt2 = AMDGPU.@allowscalar permute(t, (p1, p2))
278-
ht2 = permute(TensorKit.to_cpu(t), (p1, p2))
279-
@test ht2 == TensorKit.to_cpu(dt2)
278+
ht2 = permute(adapt(Array, t), (p1, p2))
279+
@test ht2 == adapt(Array, dt2)
280280
end
281281

282282
dt3 = AMDGPU.@allowscalar repartition(t, k)
283-
ht3 = repartition(TensorKit.to_cpu(t), k)
284-
@test ht3 == TensorKit.to_cpu(dt3)
283+
ht3 = repartition(adapt(Array, t), k)
284+
@test ht3 == adapt(Array, dt3)
285285
end
286286
end
287287
symmetricbraiding && @timedtestset "Full trace: test self-consistency" begin
@@ -339,10 +339,10 @@ for V in spacelist
339339
@tensor dHrA12[a, s1, s2, c] := drhoL[a, a'] * conj(dA1[a', t1, b]) *
340340
dA2[b, t2, c'] * drhoR[c', c] *
341341
dH[s1, s2, t1, t2]
342-
@tensor hHrA12[a, s1, s2, c] := TensorKit.to_cpu(drhoL)[a, a'] * conj(TensorKit.to_cpu(dA1)[a', t1, b]) *
343-
TensorKit.to_cpu(dA2)[b, t2, c'] * TensorKit.to_cpu(drhoR)[c', c] *
344-
TensorKit.to_cpu(dH)[s1, s2, t1, t2]
345-
@test TensorKit.to_cpu(dHrA12) ≈ hHrA12
342+
@tensor hHrA12[a, s1, s2, c] := adapt(Array, drhoL)[a, a'] * conj(adapt(Array, dA1)[a', t1, b]) *
343+
adapt(Array, dA2)[b, t2, c'] * adapt(Array, drhoR)[c', c] *
344+
adapt(Array, dH)[s1, s2, t1, t2]
345+
@test adapt(Array, dHrA12) ≈ hHrA12
346346
end
347347
end=# # doesn't yet work because of AdjointTensor
348348
BraidingStyle(I) isa HasBraiding && @timedtestset "Index flipping: test flipping inverse" begin
@@ -422,31 +422,31 @@ for V in spacelist
422422
t1 = AMDGPU.rand(T, W1, W1)
423423
t2 = AMDGPU.rand(T, W2, W2)
424424
t = AMDGPU.rand(T, W1, W2)
425-
ht1 = TensorKit.to_cpu(t1)
426-
ht2 = TensorKit.to_cpu(t2)
427-
ht = TensorKit.to_cpu(t)
428-
@test TensorKit.to_cpu(t1 * t) ≈ ht1 * ht
429-
@test TensorKit.to_cpu(t1' * t) ≈ ht1' * ht
430-
@test TensorKit.to_cpu(t2 * t') ≈ ht2 * ht'
431-
@test TensorKit.to_cpu(t2' * t') ≈ ht2' * ht'
425+
ht1 = adapt(Array, t1)
426+
ht2 = adapt(Array, t2)
427+
ht = adapt(Array, t)
428+
@test adapt(Array, t1 * t) ≈ ht1 * ht
429+
@test adapt(Array, t1' * t) ≈ ht1' * ht
430+
@test adapt(Array, t2 * t') ≈ ht2 * ht'
431+
@test adapt(Array, t2' * t') ≈ ht2' * ht'
432432

433433
#=AMDGPU.@allowscalar begin
434-
@test TensorKit.to_cpu(inv(t1)) ≈ inv(ht1)
435-
@test TensorKit.to_cpu(pinv(t)) ≈ pinv(ht)
434+
@test adapt(Array, inv(t1)) ≈ inv(ht1)
435+
@test adapt(Array, pinv(t)) ≈ pinv(ht)
436436
437437
if T == Float32 || T == ComplexF32
438438
continue
439439
end
440440
441-
@test TensorKit.to_cpu(t1 \ t) ≈ ht1 \ ht
442-
@test TensorKit.to_cpu(t1' \ t) ≈ ht1' \ ht
443-
@test TensorKit.to_cpu(t2 \ t') ≈ ht2 \ ht'
444-
@test TensorKit.to_cpu(t2' \ t') ≈ ht2' \ ht'
441+
@test adapt(Array, t1 \ t) ≈ ht1 \ ht
442+
@test adapt(Array, t1' \ t) ≈ ht1' \ ht
443+
@test adapt(Array, t2 \ t') ≈ ht2 \ ht'
444+
@test adapt(Array, t2' \ t') ≈ ht2' \ ht'
445445
446-
@test TensorKit.to_cpu(t2 / t) ≈ ht2 / ht
447-
@test TensorKit.to_cpu(t2' / t) ≈ ht2' / ht
448-
@test TensorKit.to_cpu(t1 / t') ≈ ht1 / ht'
449-
@test TensorKit.to_cpu(t1' / t') ≈ ht1' / ht'
446+
@test adapt(Array, t2 / t) ≈ ht2 / ht
447+
@test adapt(Array, t2' / t) ≈ ht2' / ht
448+
@test adapt(Array, t1 / t') ≈ ht1 / ht'
449+
@test adapt(Array, t1' / t') ≈ ht1' / ht'
450450
end=#
451451
end
452452
end
@@ -456,11 +456,11 @@ for V in spacelist
456456
#=t = project_hermitian!(AMDGPU.randn(T, W, W))
457457
s = dim(W)
458458
@test (@constinferred sqrt(t))^2 ≈ t
459-
@test TensorKit.to_cpu(sqrt(t)) ≈ sqrt(TensorKit.to_cpu(t))
459+
@test adapt(Array, sqrt(t)) ≈ sqrt(adapt(Array, t))
460460
expt = @constinferred exp(t)
461-
@test TensorKit.to_cpu(expt) ≈ exp(TensorKit.to_cpu(t))
461+
@test adapt(Array, expt) ≈ exp(adapt(Array, t))
462462
@test exp(@constinferred log(project_hermitian!(expt))) ≈ expt
463-
@test TensorKit.to_cpu(log(project_hermitian!(expt))) ≈ log(TensorKit.to_cpu(expt))
463+
@test adapt(Array, log(project_hermitian!(expt))) ≈ log(adapt(Array, expt))
464464
465465
@test (@constinferred cos(t))^2 + (@constinferred sin(t))^2 ≈
466466
id(storagetype(t), W)

0 commit comments

Comments
 (0)