Skip to content

Commit 5378c2e

Browse files
kshyattlkdvos
authored andcommitted
Support Subarray{<:Adjoint{<:GPUMatrix}}
1 parent cea4178 commit 5378c2e

2 files changed

Lines changed: 19 additions & 15 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,31 +112,33 @@ function _project_hermitian_diag_kernel(A, B, ::Val{false})
112112
end
113113
# COV_EXCL_STOP
114114

115+
const SupportedROCMatrix{T} = Union{AnyROCMatrix{T}, SubArray{T, 2, <:AnyROCMatrix{T}}}
116+
115117
function MatrixAlgebraKit._project_hermitian_offdiag!(
116-
Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti}
118+
Au::SupportedROCMatrix, Al::SupportedROCMatrix, Bu::SupportedROCMatrix, Bl::SupportedROCMatrix, ::Val{anti}
117119
) where {anti}
118120
thread_dim = 512
119121
block_dim = cld(size(Au, 2), thread_dim)
120122
@roc groupsize = thread_dim gridsize = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
121123
return nothing
122124
end
123-
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti}
125+
function MatrixAlgebraKit._project_hermitian_diag!(A::SupportedROCMatrix, B::SupportedROCMatrix, ::Val{anti}) where {anti}
124126
thread_dim = 512
125127
block_dim = cld(size(A, 1), thread_dim)
126128
@roc groupsize = thread_dim gridsize = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
127129
return nothing
128130
end
129131

130-
# avoids calling the `StridedMatrix` specialization to avoid scalar indexing,
132+
# avoids calling the `SupportedMatrix` specialization to avoid scalar indexing,
131133
# use (allocating) fallback instead until we write a dedicated kernel
132-
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = A == A'
133-
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) =
134+
MatrixAlgebraKit.ishermitian_exact(A::SupportedROCMatrix) = A == A'
135+
MatrixAlgebraKit.ishermitian_approx(A::SupportedROCMatrix; atol, rtol, kwargs...) =
134136
norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
135-
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = A == -A'
136-
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) =
137+
MatrixAlgebraKit.isantihermitian_exact(A::SupportedROCMatrix) = A == -A'
138+
MatrixAlgebraKit.isantihermitian_approx(A::SupportedROCMatrix; atol, rtol, kwargs...) =
137139
norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
138140

139-
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
141+
function MatrixAlgebraKit._avgdiff!(A::SupportedROCMatrix, B::SupportedROCMatrix)
140142
axes(A) == axes(B) || throw(DimensionMismatch())
141143
# COV_EXCL_START
142144
function _avgdiff_kernel(A, B)

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,17 @@ function _project_hermitian_diag_kernel(A, B, ::Val{false})
136136
end
137137
# COV_EXCL_STOP
138138

139+
const SupportedCuMatrix{T} = Union{AnyCuMatrix{T}, SubArray{T, 2, <:AnyCuMatrix{T}}}
140+
139141
function MatrixAlgebraKit._project_hermitian_offdiag!(
140-
Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
142+
Au::SupportedCuMatrix, Al::SupportedCuMatrix, Bu::SupportedCuMatrix, Bl::SupportedCuMatrix, ::Val{anti}
141143
) where {anti}
142144
thread_dim = 512
143145
block_dim = cld(size(Au, 2), thread_dim)
144146
@cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
145147
return nothing
146148
end
147-
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
149+
function MatrixAlgebraKit._project_hermitian_diag!(A::SupportedCuMatrix, B::SupportedCuMatrix, ::Val{anti}) where {anti}
148150
thread_dim = 512
149151
block_dim = cld(size(A, 1), thread_dim)
150152
@cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
@@ -153,14 +155,14 @@ end
153155

154156
# avoids calling the `StridedMatrix` specialization to avoid scalar indexing,
155157
# use (allocating) fallback instead until we write a dedicated kernel
156-
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = A == A'
157-
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) =
158+
MatrixAlgebraKit.ishermitian_exact(A::SupportedCuMatrix) = A == A'
159+
MatrixAlgebraKit.ishermitian_approx(A::SupportedCuMatrix; atol, rtol, kwargs...) =
158160
norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
159-
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = A == -A'
160-
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) =
161+
MatrixAlgebraKit.isantihermitian_exact(A::SupportedCuMatrix) = A == -A'
162+
MatrixAlgebraKit.isantihermitian_approx(A::SupportedCuMatrix; atol, rtol, kwargs...) =
161163
norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
162164

163-
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
165+
function MatrixAlgebraKit._avgdiff!(A::SupportedCuMatrix, B::SupportedCuMatrix)
164166
axes(A) == axes(B) || throw(DimensionMismatch())
165167
# COV_EXCL_START
166168
function _avgdiff_kernel(A, B)

0 commit comments

Comments
 (0)