Skip to content

Commit 9972a65

Browse files
kshyatt (Katharine Hyatt)
authored and committed
Basic tensor support for AMDGPU
1 parent 34ac960 commit 9972a65

7 files changed

Lines changed: 779 additions & 4 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2020

2121
[weakdeps]
2222
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
23+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
2324
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2425
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2526
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
@@ -28,6 +29,7 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2829

2930
[extensions]
3031
TensorKitAdaptExt = "Adapt"
32+
TensorKitAMDGPUExt = "AMDGPU"
3133
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
3234
TensorKitChainRulesCoreExt = "ChainRulesCore"
3335
TensorKitFiniteDifferencesExt = "FiniteDifferences"
@@ -38,6 +40,7 @@ projects = ["test"]
3840

3941
[compat]
4042
Adapt = "4"
43+
AMDGPU = "2"
4144
CUDA = "5.9"
4245
ChainRulesCore = "1"
4346
Dictionaries = "0.4"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Package extension module for TensorKit that is keyed on the AMDGPU weak dependency.
# It only brings the required names into scope and includes the file holding the
# AMDGPU-backed tensor map definitions.
module TensorKitAMDGPUExt

# GPU backend and the linear-algebra libraries it wraps
using AMDGPU
using AMDGPU.rocBLAS
using AMDGPU.rocSOLVER
using LinearAlgebra
using AMDGPU: @allowscalar
# AMDGPU's random generators are imported under `roc*` aliases so that methods can be
# added to them without clashing with the `Base` names imported below.
import AMDGPU: rand as rocrand, rand! as rocrand!, randn as rocrandn, randn! as rocrandn!

# TensorKit and the internals this extension relies on
using TensorKit
using TensorKit.Factorizations
using Strided
using MatrixAlgebraKit
using MatrixAlgebraKit: AbstractAlgorithm
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
import TensorKit: randisometry
using Base: rand, randn

using Random

include("roctensormap.jl")

end
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# `TensorMap` whose storage is a `ROCVector{T}` backed by a HIP buffer.
const ROCTensorMap{T, S, N₁, N₂} =
    TensorMap{T, S, N₁, N₂, ROCVector{T, AMDGPU.Mem.HIPBuffer}}

# `ROCTensorMap` with no domain indices (N₂ = 0).
const ROCTensor{T, S, N} = ROCTensorMap{T, S, N, 0}

# `AdjointTensorMap` wrapping a `ROCTensorMap` parent.
# NOTE(review): the parent is written with the same (N₁, N₂) order as the adjoint
# itself; if `AdjointTensorMap` expects a parent with swapped index counts this alias
# would only match when N₁ == N₂ — confirm against the `AdjointTensorMap` definition.
const AdjointROCTensorMap{T, S, N₁, N₂} =
    AdjointTensorMap{T, S, N₁, N₂, ROCTensorMap{T, S, N₁, N₂}}
5+
6+
# Construct a `ROCTensorMap` from an existing `TensorMap` with arbitrary storage type
# `A`: the data vector is converted to a `ROCArray{T}` and the space of `t` is reused.
function ROCTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
    device_data = ROCArray{T}(t.data)
    return ROCTensorMap{T, S, N₁, N₂}(device_data, space(t))
end
9+
10+
# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
11+
# Project `data` onto the symmetric structure of `V` using a host-side tensor, check
# that the projection reproduces the input to within `tol`, and return a tensor map
# with storage type `A` holding the projected data.
#
# Throws `ArgumentError` when the dense input and the dense form of the projected
# tensor differ by more than `tol` (absolute tolerance).
function TensorKit.project_symmetric_and_check(
        ::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace;
        tol = sqrt(eps(real(float(eltype(data)))))
    ) where {T, A <: ROCVector{T}}
    # The projection runs on a tensor with plain `Vector{T}` storage.
    host_tensor = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V)
    host_tensor = TensorKit.project_symmetric!(host_tensor, Array(data))

    # Verify the result: compare the input, reshaped to the tensor dimensions,
    # against the dense array form of the projected tensor.
    dense_input = Array(reshape(data, dims(host_tensor)))
    dense_projected = convert(Array, host_tensor)
    if !isapprox(dense_input, dense_projected; atol = tol)
        throw(ArgumentError("Data has non-zero elements at incompatible positions"))
    end

    # Convert the projected data to the requested storage type.
    return TensorKit.TensorMapWithStorage{T, A}(A(host_tensor.data), V)
end
19+
20+
# Generate `AMDGPU.zeros` and `AMDGPU.ones` methods that return a `ROCTensorMap`
# filled with `zero(T)` / `one(T)` respectively.
#
# Accepted call forms (for both functions):
#   f(codomain[, domain])        -- element type defaults to Float64
#   f(T, codomain[, domain])
#   f(V::TensorMapSpace)         -- element type defaults to Float64
#   f(T, V::TensorMapSpace)
# When `domain` is omitted it defaults to `one(codomain)`.
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
    @eval begin
        # Combine `codomain` and `domain` into a single space with `←` and forward.
        function AMDGPU.$fname(
                codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain)
            ) where {S <: IndexSpace}
            return AMDGPU.$fname(codomain ← domain)
        end
        function AMDGPU.$fname(
                ::Type{T}, codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain)
            ) where {T, S <: IndexSpace}
            return AMDGPU.$fname(T, codomain ← domain)
        end

        # Fill in the default element type.
        AMDGPU.$fname(V::TensorMapSpace) = AMDGPU.$fname(Float64, V)

        # Implementation: allocate uninitialized storage, then fill it.
        function AMDGPU.$fname(::Type{T}, V::TensorMapSpace) where {T}
            t = ROCTensorMap{T}(undef, V)
            fill!(t, $felt(T))
            return t
        end
    end
end
42+
43+
# Generate `rocrand` / `rocrandn` methods (and their in-place `!` counterparts) that
# produce a `ROCTensorMap` with randomly filled data.
#
# Accepted call forms (for both functions):
#   f([rng,] [T,] V::TensorMapSpace)
#   f(codomain[, domain])
#   f(T, codomain[, domain])
#   f(rng, T, codomain[, domain])
# `T` defaults to Float64, `rng` to `Random.default_rng()`, and `domain` to
# `one(codomain)`.
for randfun in (:rocrand, :rocrandn)
    randfun! = Symbol(randfun, :!)
    @eval begin
        # converting `codomain` and `domain` into `HomSpace`
        function $randfun(
                codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain),
            ) where {S <: IndexSpace}
            return $randfun(codomain ← domain)
        end
        function $randfun(
                ::Type{T}, codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain),
            ) where {T, S <: IndexSpace}
            return $randfun(T, codomain ← domain)
        end
        function $randfun(
                rng::Random.AbstractRNG, ::Type{T},
                codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain),
            ) where {T, S <: IndexSpace}
            return $randfun(rng, T, codomain ← domain)
        end

        # filling in default eltype
        $randfun(V::TensorMapSpace) = $randfun(Float64, V)
        function $randfun(rng::Random.AbstractRNG, V::TensorMapSpace)
            return $randfun(rng, Float64, V)
        end

        # filling in default rng
        # NOTE(review): `Random.default_rng()` is the standard-library default
        # generator; whether the in-place AMDGPU generator accepts it for device
        # arrays is not visible here — confirm.
        function $randfun(::Type{T}, V::TensorMapSpace) where {T}
            return $randfun(Random.default_rng(), T, V)
        end

        # implementation: allocate uninitialized storage, then fill it in place
        function $randfun(
                rng::Random.AbstractRNG, ::Type{T},
                V::TensorMapSpace
            ) where {T}
            t = ROCTensorMap{T}(undef, V)
            $randfun!(rng, t)
            return t
        end

        # in-place variant: randomize the underlying data vector and return `t`
        function $randfun!(rng::Random.AbstractRNG, t::ROCTensorMap)
            $randfun!(rng, t.data)
            return t
        end
    end
end
94+
95+
# Scalar implementation
96+
#-----------------------
97+
# Return the scalar value of a tensor map with no codomain and no domain indices.
# Yields `zero(scalartype(t))` when the data holds no non-zero entry; otherwise reads
# the single non-zero entry (`only` throws if there is more than one).
function TensorKit.scalar(t::ROCTensorMap{T, S, 0, 0}) where {T, S}
    nonzero_positions = findall(!iszero, t.data)
    isempty(nonzero_positions) && return zero(scalartype(t))
    # Element access is wrapped in `@allowscalar` for the device array.
    return @allowscalar @inbounds t.data[only(nonzero_positions)]
end
101+
102+
# Convert an arbitrary tensor map with matching space type and index counts to the
# requested `ROCTensorMap` type. A tensor that already has exactly that type is
# returned as-is; otherwise a new tensor is allocated over `space(t)` and `t` is
# copied into it.
function Base.convert(
        TT::Type{ROCTensorMap{T, S, N₁, N₂}},
        t::AbstractTensorMap{<:Any, S, N₁, N₂}
    ) where {T, S, N₁, N₂}
    typeof(t) === TT && return t
    converted = TT(undef, space(t))
    return copy!(converted, t)
end
113+
114+
# Test whether `t` is positive definite.
# Throws `SpaceMismatch` unless domain and codomain coincide. Returns `false` when
# the space type does not have a Euclidean inner product, or when any block is not
# hermitian or not positive definite; returns `true` otherwise.
function LinearAlgebra.isposdef(t::ROCTensorMap)
    if !(domain(t) == codomain(t))
        throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
    end
    InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false
    for (_, blk) in blocks(t)
        # explicit hermiticity check before wrapping the block in `Hermitian`
        MatrixAlgebraKit.ishermitian(blk) || return false
        isposdef(Hermitian(blk)) || return false
    end
    return true
end
126+
127+
# Promotion rule for two `ROCTensorMap` types sharing the space type `S` and index
# counts `N₁`, `N₂`: the result is the `ROCTensorMap` whose scalar type is
# `promote_add` of the two scalar types.
function Base.promote_rule(
        ::Type{<:TT₁},
        ::Type{<:TT₂}
    ) where {
        S, N₁, N₂, TTT₁, TTT₂,
        TT₁ <: ROCTensorMap{TTT₁, S, N₁, N₂},
        TT₂ <: ROCTensorMap{TTT₂, S, N₁, N₂},
    }
    return ROCTensorMap{TensorKit.VectorInterface.promote_add(TTT₁, TTT₂), S, N₁, N₂}
end
138+
139+
# ROCTensorMap exponentiation:
140+
# In-place exponential of `t`: every block is overwritten with the exponential of its
# `Hermitian` wrapper. Errors when domain and codomain differ, and throws
# `ArgumentError` when `t` is not hermitian.
function TensorKit.exp!(t::ROCTensorMap)
    if !(domain(t) == codomain(t))
        error("Exponential of a tensor only exist when domain == codomain.")
    end
    if !MatrixAlgebraKit.ishermitian(t)
        throw(ArgumentError("`exp!` is currently only supported on hermitian AMDGPU tensors"))
    end
    for (_, blk) in blocks(t)
        exp_blk = Base.exp(Hermitian(blk))
        # `parent` unwraps the result before it is copied back into the block
        copy!(blk, parent(exp_blk))
    end
    return t
end
149+
150+
# functions that don't map ℝ to (a subset of) ℝ
151+
# Generate `Base.f(t::ROCTensorMap)` for each listed function. Each method requires
# domain == codomain (else `SpaceMismatch`) and a hermitian `t` (else
# `ArgumentError`), and returns a new tensor with complex floating-point scalar type
# whose blocks hold `f` applied to the `Hermitian` wrapper of the blocks of `t`.
for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
    sf = string(f)
    @eval function Base.$f(t::ROCTensorMap)
        if !(domain(t) == codomain(t))
            throw(SpaceMismatch("`$($sf)` of a tensor only exists when domain == codomain"))
        end
        if !MatrixAlgebraKit.ishermitian(t)
            throw(ArgumentError("`$($sf)` is currently only supported on hermitian AMDGPU tensors"))
        end
        # The output always uses the complex floating-point version of the scalar type.
        result = similar(t, complex(float(scalartype(t))))
        for (sector, blk) in blocks(t)
            copy!(block(result, sector), parent($f(Hermitian(blk))))
        end
        return result
    end
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ name = "TensorKitTests"
33
[deps]
44
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
55
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
6+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
67
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
78
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
89
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

0 commit comments

Comments
 (0)