-
Notifications
You must be signed in to change notification settings - Fork 60
Expand file tree
/
Copy pathTensorKitAMDGPUExt.jl
More file actions
108 lines (88 loc) · 4.17 KB
/
TensorKitAMDGPUExt.jl
File metadata and controls
108 lines (88 loc) · 4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
module TensorKitAMDGPUExt
using AMDGPU, AMDGPU.rocBLAS, LinearAlgebra
using AMDGPU: @allowscalar
import AMDGPU: rand as rocrand, rand! as rocrand!, randn as rocrandn, randn! as rocrandn!
using TensorKit
using TensorKit.Factorizations
using TensorKit.Strided
using TensorKit.Factorizations: AbstractAlgorithm
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype
using TensorKit.MatrixAlgebraKit
using Random
include("roctensormap.jl")
# Alias for a `DiagonalTensorMap` with scalar type `T` and space type `S` whose
# storage type is fixed to `ROCVector{T, AMDGPU.Mem.HIPBuffer}`.
const ROCDiagonalTensorMap{T, S} = DiagonalTensorMap{T, S, ROCVector{T, AMDGPU.Mem.HIPBuffer}}
"""
    ROCDiagonalTensorMap{T}(undef, domain::S) where {T,S<:IndexSpace}
    # expert mode: select storage type `A`
    DiagonalTensorMap{T,S,A}(undef, domain::S) where {T,S<:IndexSpace,A<:DenseVector{T}}

Construct a `DiagonalTensorMap` with uninitialized data.
"""
function ROCDiagonalTensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T}
    # A diagonal tensor map needs exactly one index on each side and an
    # identical domain and codomain; anything else is rejected up front.
    is_square = numin(V) == numout(V) == 1 && domain(V) == codomain(V)
    if !is_square
        throw(ArgumentError("DiagonalTensorMap requires a space with equal domain and codomain and 2 indices"))
    end
    # Forward to the constructor method that accepts the domain of `V`.
    return ROCDiagonalTensorMap{T}(undef, domain(V))
end
# Construct an uninitialized `ROCDiagonalTensorMap{T}` from a `ProductSpace` `V`.
# `V` must have length 1; its single element is passed on as the space.
# Throws an `ArgumentError` for any other length.
function ROCDiagonalTensorMap{T}(::UndefInitializer, V::ProductSpace) where {T}
    if length(V) != 1
        throw(ArgumentError("DiagonalTensorMap requires `numin(d) == numout(d) == 1`"))
    end
    return ROCDiagonalTensorMap{T}(undef, only(V))
end
# Construct an uninitialized `ROCDiagonalTensorMap{T}` over the index space `V`,
# filling in the space type parameter `S` from the type of `V`.
ROCDiagonalTensorMap{T}(::UndefInitializer, V::S) where {T, S <: IndexSpace} =
    ROCDiagonalTensorMap{T, S}(undef, V)
# Construct an uninitialized `ROCDiagonalTensorMap` over `V` when no scalar type
# is given; the scalar type defaults to `Float64`.
function ROCDiagonalTensorMap(::UndefInitializer, V::IndexSpace)
    return ROCDiagonalTensorMap{Float64}(undef, V)
end
# Wrap an existing `ROCVector` `data` together with the space `V` in a
# `ROCDiagonalTensorMap`, taking `T` from the element type of `data` and `S`
# from the type of `V`.
ROCDiagonalTensorMap(data::ROCVector{T}, V::S) where {T, S} =
    ROCDiagonalTensorMap{T, S}(data, V)
# Build a `ROCDiagonalTensorMap` from a host `Vector`: `data` is first converted
# to a `ROCVector{T}` and then paired with the space `V`.
function ROCDiagonalTensorMap(data::Vector{T}, V::S) where {T, S}
    device_data = ROCVector{T}(data)
    return ROCDiagonalTensorMap{T, S}(device_data, V)
end
# Allocate the `(U, S, Vᴴ)` outputs for `svd_full!` applied to a
# `ROCDiagonalTensorMap` with a `DiagonalAlgorithm`.
# `U` and `Vᴴ` are created with `similar(t, …)`; `S` is an uninitialized
# `ROCDiagonalTensorMap` with the real scalar type of `t`.
function TensorKit.Factorizations.MAK.initialize_output(
        ::typeof(svd_full!), t::ROCDiagonalTensorMap, alg::DiagonalAlgorithm
    )
    fused_cod = fuse(codomain(t))
    fused_dom = fuse(domain(t))
    left = similar(t, codomain(t) ← fused_cod)
    singular = ROCDiagonalTensorMap{real(scalartype(t))}(undef, fused_cod ← fused_dom)
    right = similar(t, fused_dom ← domain(t))
    return left, singular, right
end
# Allocate the output for `svd_vals!` applied to a `ROCTensorMap`: an
# uninitialized `ROCDiagonalTensorMap` with the real scalar type of `t`, over
# the infimum of the fused codomain and the fused domain of `t`.
function TensorKit.Factorizations.MAK.initialize_output(
        ::typeof(svd_vals!), t::ROCTensorMap, alg::AbstractAlgorithm
    )
    inner_space = infimum(fuse(codomain(t)), fuse(domain(t)))
    Treal = real(scalartype(t))
    return ROCDiagonalTensorMap{Treal}(undef, inner_space)
end
# Allocate the `(U, S, Vᴴ)` outputs for `svd_compact!` applied to a
# `ROCTensorMap`. All three share one inner space: the infimum of the fused
# codomain and the fused domain of `t`. `S` is an uninitialized
# `ROCDiagonalTensorMap` with the real scalar type of `t`.
function TensorKit.Factorizations.MAK.initialize_output(
        ::typeof(svd_compact!), t::ROCTensorMap, ::AbstractAlgorithm
    )
    inner_space = infimum(fuse(codomain(t)), fuse(domain(t)))
    left = similar(t, codomain(t) ← inner_space)
    singular = ROCDiagonalTensorMap{real(scalartype(t))}(undef, inner_space)
    right = similar(t, inner_space ← domain(t))
    return left, singular, right
end
# Allocate the `(D, V)` outputs for `eigh_full!` applied to a `ROCTensorMap`.
# `D` is an uninitialized `ROCDiagonalTensorMap` with the real scalar type of
# `t` over the fused domain; `V` is created with `similar(t, …)`.
function TensorKit.Factorizations.MAK.initialize_output(
        ::typeof(eigh_full!), t::ROCTensorMap, ::AbstractAlgorithm
    )
    fused_dom = fuse(domain(t))
    diag_out = ROCDiagonalTensorMap{real(scalartype(t))}(undef, fused_dom)
    vecs_out = similar(t, codomain(t) ← fused_dom)
    return diag_out, vecs_out
end
# Allocate the `(D, V)` outputs for `eig_full!` applied to a `ROCTensorMap`.
# Both outputs use the complex scalar type of `t`: `D` is an uninitialized
# `ROCDiagonalTensorMap` over the fused domain, `V` is created with
# `similar(t, Tc, …)`.
function TensorKit.Factorizations.MAK.initialize_output(
        ::typeof(eig_full!), t::ROCTensorMap, ::AbstractAlgorithm
    )
    Tc = complex(scalartype(t))
    fused_dom = fuse(domain(t))
    diag_out = ROCDiagonalTensorMap{Tc}(undef, fused_dom)
    vecs_out = similar(t, Tc, codomain(t) ← fused_dom)
    return diag_out, vecs_out
end
# Allocate the output for `eigh_vals!` applied to a `ROCTensorMap`: an
# uninitialized `ROCDiagonalTensorMap` over the fused domain of `t`, with the
# real scalar type of `t` (the same element type `eigh_full!` uses for `D`).
function TensorKit.Factorizations.MAK.initialize_output(::typeof(eigh_vals!), t::ROCTensorMap, alg::AbstractAlgorithm)
    V_D = fuse(domain(t))
    # Use the real scalar type defined here; there is no `Tc` in this method.
    T = real(scalartype(t))
    return ROCDiagonalTensorMap{T}(undef, V_D)
end
# Allocate the output for `eig_vals!` applied to a `ROCTensorMap`: an
# uninitialized `ROCDiagonalTensorMap` over the fused domain of `t`, with the
# complex scalar type of `t`.
function TensorKit.Factorizations.MAK.initialize_output(
        ::typeof(eig_vals!), t::ROCTensorMap, alg::AbstractAlgorithm
    )
    fused_dom = fuse(domain(t))
    Tcomplex = complex(scalartype(t))
    return ROCDiagonalTensorMap{Tcomplex}(undef, fused_dom)
end
# TODO
# add VectorInterface extensions for proper AMDGPU promotion
# `promote_add` method for two strided ROC matrix types with element types `Tx`
# and `Ty` and scalar coefficients `α`, `β` (both default to `One()`).
# Returns whatever `Base.promote_op` reports for `add` called with the types
# `Tx`, `Ty`, `Tα` and `Tβ`.
function TensorKit.VectorInterface.promote_add(
        TA::Type{<:AMDGPU.StridedROCMatrix{Tx}},
        TB::Type{<:AMDGPU.StridedROCMatrix{Ty}},
        α::Tα = TensorKit.VectorInterface.One(),
        β::Tβ = TensorKit.VectorInterface.One()
    ) where {Tx, Ty, Tα, Tβ}
    return Base.promote_op(add, Tx, Ty, Tα, Tβ)
end
end