@@ -136,15 +136,17 @@ function _project_hermitian_diag_kernel(A, B, ::Val{false})
136136end
137137# COV_EXCL_STOP
138138
139+ const SupportedCuMatrix{T} = Union{AnyCuMatrix{T}, SubArray{T, 2 , <: AnyCuMatrix{T} }}
140+
139141function MatrixAlgebraKit. _project_hermitian_offdiag! (
140- Au:: StridedCuMatrix , Al:: StridedCuMatrix , Bu:: StridedCuMatrix , Bl:: StridedCuMatrix , :: Val{anti}
142+ Au:: SupportedCuMatrix , Al:: SupportedCuMatrix , Bu:: SupportedCuMatrix , Bl:: SupportedCuMatrix , :: Val{anti}
141143 ) where {anti}
142144 thread_dim = 512
143145 block_dim = cld (size (Au, 2 ), thread_dim)
144146 @cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel (Au, Al, Bu, Bl, Val (anti))
145147 return nothing
146148end
147- function MatrixAlgebraKit. _project_hermitian_diag! (A:: StridedCuMatrix , B:: StridedCuMatrix , :: Val{anti} ) where {anti}
149+ function MatrixAlgebraKit. _project_hermitian_diag! (A:: SupportedCuMatrix , B:: SupportedCuMatrix , :: Val{anti} ) where {anti}
148150 thread_dim = 512
149151 block_dim = cld (size (A, 1 ), thread_dim)
150152 @cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel (A, B, Val (anti))
@@ -153,14 +155,14 @@ end
153155
154156# avoids calling the `StridedMatrix` specialization to avoid scalar indexing,
155157# use (allocating) fallback instead until we write a dedicated kernel
156- MatrixAlgebraKit. ishermitian_exact (A:: StridedCuMatrix ) = A == A'
157- MatrixAlgebraKit. ishermitian_approx (A:: StridedCuMatrix ; atol, rtol, kwargs... ) =
158+ MatrixAlgebraKit. ishermitian_exact (A:: SupportedCuMatrix ) = A == A'
159+ MatrixAlgebraKit. ishermitian_approx (A:: SupportedCuMatrix ; atol, rtol, kwargs... ) =
158160 norm (project_antihermitian (A; kwargs... )) ≤ max (atol, rtol * norm (A))
159- MatrixAlgebraKit. isantihermitian_exact (A:: StridedCuMatrix ) = A == - A'
160- MatrixAlgebraKit. isantihermitian_approx (A:: StridedCuMatrix ; atol, rtol, kwargs... ) =
161+ MatrixAlgebraKit. isantihermitian_exact (A:: SupportedCuMatrix ) = A == - A'
162+ MatrixAlgebraKit. isantihermitian_approx (A:: SupportedCuMatrix ; atol, rtol, kwargs... ) =
161163 norm (project_hermitian (A; kwargs... )) ≤ max (atol, rtol * norm (A))
162164
163- function MatrixAlgebraKit. _avgdiff! (A:: StridedCuMatrix , B:: StridedCuMatrix )
165+ function MatrixAlgebraKit. _avgdiff! (A:: SupportedCuMatrix , B:: SupportedCuMatrix )
164166 axes (A) == axes (B) || throw (DimensionMismatch ())
165167 # COV_EXCL_START
166168 function _avgdiff_kernel (A, B)
0 commit comments