@@ -6,6 +6,9 @@ const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂,
# Convert a host-backed `TensorMap` into a `CuTensorMap` by uploading its
# flat data buffer to the GPU; the space structure is carried over unchanged.
function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
    device_data = CuArray{T}(t.data)
    return CuTensorMap{T, S, N₁, N₂}(device_data, space(t))
end
# Constructor-style conversion from a CPU-backed `TensorMap` to a GPU-backed one.
# NOTE(review): the requested device storage type `DA` is not used directly — the
# data is always uploaded as `CuArray{T}`; confirm this agrees with the storage
# parameter of the `CuTensorMap` alias.
function TensorMap{T, S, N₁, N₂, DA}(
        t::TensorMap{T, S, N₁, N₂, HA}
    ) where {T, S, N₁, N₂, DA <: CuArray{T}, HA <: Array{T}}
    device_data = CuArray{T}(t.data)
    return CuTensorMap{T, S, N₁, N₂}(device_data, space(t))
end
912
1013# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
1114function TensorKit. project_symmetric_and_check (:: Type{T} , :: Type{A} , data:: AbstractArray , V:: TensorMapSpace ; tol = sqrt (eps (real (float (eltype (data)))))) where {T, A <: CuVector{T} }
@@ -101,18 +104,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
101104 return isempty (inds) ? zero (scalartype (t)) : @allowscalar @inbounds t. data[only (inds)]
102105end
103106
104- function Base. convert (
105- TT:: Type{CuTensorMap{T, S, N₁, N₂}} ,
106- t:: AbstractTensorMap{<:Any, S, N₁, N₂}
107- ) where {T, S, N₁, N₂}
108- if typeof (t) === TT
109- return t
110- else
111- tnew = TT (undef, space (t))
112- return copy! (tnew, t)
113- end
114- end
115-
116107function LinearAlgebra. isposdef (t:: CuTensorMap )
117108 domain (t) == codomain (t) ||
118109 throw (SpaceMismatch (" `isposdef` requires domain and codomain to be the same" ))
@@ -138,10 +129,9 @@ function Base.promote_rule(
138129 return CuTensorMap{T, S, N₁, N₂}
139130end
140131
# Promoting any two CUDA-backed storage types with matching element type and
# rank yields a `CuArray` parameterized with CUDA's default memory type.
function TensorKit.promote_storage_rule(
        ::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}
    ) where {T, N}
    return CuArray{T, N, CUDA.default_memory}
end
143134
144-
145135# CuTensorMap exponentation:
146136function TensorKit. exp! (t:: CuTensorMap )
147137 domain (t) == codomain (t) ||
@@ -168,3 +158,21 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
168158 return tf
169159 end
170160end
161+
# GPU variant of the non-threaded additive transform kernel for
# `GenericTreeTransformer`s: walks every subtransformer and accumulates the
# transformed blocks of `tsrc` into `tdst`, forwarding α, β and any extra
# backend arguments to the TensorKit worker routines.
function TensorKit.add_kernel_nonthreaded!(
        tdst::CuTensorMap, tsrc::CuTensorMap, p,
        transformer::TensorKit.GenericTreeTransformer, α, β, backend...
    )
    # Scratch buffers shared across all multi-block subtransforms.
    scratch = TensorKit.allocate_buffers(tdst, tsrc, transformer)

    for sub in transformer.data
        if length(sub[1]) > 1
            # Multi-block path: move the coefficient data to the device first.
            device_sub = (CUDA.adapt(CuArray, sub[1]), sub[2:end]...)
            TensorKit._add_transform_multi!(
                tdst, tsrc, p, device_sub, scratch, α, β, backend...
            )
        else
            # Single-block fast path: no intermediate buffers required.
            TensorKit._add_transform_single!(tdst, tsrc, p, sub, α, β, backend...)
        end
    end
    return nothing
end
0 commit comments