@@ -9,6 +9,7 @@ using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_
99import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
1111using CUDA
12+ using CUDA: i32
1213using LinearAlgebra
1314using LinearAlgebra: BlasFloat
1415
@@ -58,4 +59,106 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::T
5859 return MatrixAlgebraKit. findtruncated (values, strategy)
5960end
6061
62+ # COV_EXCL_START
# GPU kernel for the anti-Hermitian (`Val{true}`) projection of a pair of mirrored
# off-diagonal blocks. Each thread owns one column `col` of `Au` (equivalently one
# row of `Al`): it writes `(Au[row, col] - Al[col, row]') / 2` into `Bu[row, col]`
# and the negated adjoint of that value into `Bl[col, row]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
    nrows, ncols = size(Au)
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    # threads beyond the last column have nothing to do
    if col <= ncols
        @inbounds for row in 1:nrows
            skew = (Au[row, col] - adjoint(Al[col, row])) / 2
            Bu[row, col] = skew
            Bl[col, row] = -adjoint(skew)
        end
    end
    return
end
76+
# GPU kernel for the Hermitian (`Val{false}`) projection of a pair of mirrored
# off-diagonal blocks. Each thread owns one column `j` of `Au` (equivalently one
# row of `Al`): it writes `(Au[i, j] + Al[j, i]') / 2` into `Bu[i, j]` and the
# adjoint of that value into `Bl[j, i]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
    m, n = size(Au)
    # `1i32` (not `1`) keeps the thread-index arithmetic in Int32, consistent with
    # the other kernels in this file
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    # threads beyond the last column have nothing to do
    j > n && return
    for i in 1:m
        @inbounds begin
            val = (Au[i, j] + adjoint(Al[j, i])) / 2
            Bu[i, j] = val
            Bl[j, i] = adjoint(val)
        end
    end
    return
end
90+
# GPU kernel for the anti-Hermitian (`Val{true}`) projection of a diagonal block.
# Each thread owns one index `j`: for every `i < j` it writes
# `(A[i, j] - A[j, i]') / 2` into `B[i, j]` and the negated adjoint of that value
# into `B[j, i]`, then sets the diagonal entry `B[j, j]`.
# Both source entries are read before either destination entry is written.
function _project_hermitian_diag_kernel(A, B, ::Val{true})
    n = size(A, 1)
    # `1i32` (not `1`) keeps the thread-index arithmetic in Int32, so the loop range
    # below is a pure Int32 range, consistent with the `Val{false}` kernel
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    # threads beyond the last row/column have nothing to do
    j > n && return
    @inbounds begin
        for i in 1i32:(j - 1i32)
            val = (A[i, j] - adjoint(A[j, i])) / 2
            B[i, j] = val
            B[j, i] = -adjoint(val)
        end
        # NOTE(review): `_imimag` is defined in MatrixAlgebraKit; presumably it keeps
        # only the imaginary part of the diagonal entry — confirm against its definition
        B[j, j] = MatrixAlgebraKit._imimag(A[j, j])
    end
    return
end
105+
# GPU kernel for the Hermitian (`Val{false}`) projection of a diagonal block.
# Each thread owns one index `col`: for every `row < col` it stores
# `(A[row, col] + A[col, row]') / 2` in `B[row, col]` and the adjoint of that value
# in `B[col, row]`; the diagonal entry becomes `real(A[col, col])`.
function _project_hermitian_diag_kernel(A, B, ::Val{false})
    dim = size(A, 1)
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    # threads beyond the last row/column have nothing to do
    if col <= dim
        @inbounds for row in 1i32:(col - 1i32)
            # read both mirrored entries before writing either destination
            upper = A[row, col]
            lower = A[col, row]
            sym = (upper + adjoint(lower)) / 2
            B[row, col] = sym
            B[col, row] = adjoint(sym)
        end
        @inbounds B[col, col] = real(A[col, col])
    end
    return
end
120+ # COV_EXCL_STOP
121+
# CUDA method of `_project_hermitian_offdiag!`: launches
# `_project_hermitian_offdiag_kernel` with one thread per column of `Au`, in blocks
# of 512 threads. `Val(anti)` selects the anti-Hermitian (`true`) or Hermitian
# (`false`) kernel. The outputs `Bu` and `Bl` are written in place; returns `nothing`.
function MatrixAlgebraKit._project_hermitian_offdiag!(
        Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
    ) where {anti}
    # An empty block has no entries to project, and `cld(0, thread_dim) == 0` would
    # request a launch with zero blocks, which is not a valid launch configuration.
    isempty(Au) && return nothing
    thread_dim = 512
    block_dim = cld(size(Au, 2), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
    return nothing
end
# CUDA method of `_project_hermitian_diag!`: launches
# `_project_hermitian_diag_kernel` with one thread per row index of `A`, in blocks
# of 512 threads. `Val(anti)` selects the anti-Hermitian (`true`) or Hermitian
# (`false`) kernel. The output `B` is written in place; returns `nothing`.
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
    # Nothing to project for an empty block, and `cld(0, thread_dim) == 0` would
    # request a launch with zero blocks, which is not a valid launch configuration.
    size(A, 1) == 0 && return nothing
    thread_dim = 512
    block_dim = cld(size(A, 1), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
    return nothing
end
136+
# Exact Hermitian test for CUDA matrices: `true` iff `A` is square and every entry
# equals the corresponding entry of `adjoint(A)`.
# The squareness check comes first because broadcasting `A .== adjoint(A)` on a
# non-square input would either throw or (for 1×n / n×1 inputs) silently expand
# both operands to n×n and compare unrelated entries.
function MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix)
    return size(A, 1) == size(A, 2) && all(A .== adjoint(A))
end
# Exact Hermitian test for a `Diagonal` backed by a CUDA vector: `true` iff every
# diagonal entry equals its own adjoint.
# `adjoint.` is applied elementwise on purpose: `adjoint(A.diag)` is a 1×n row, and
# broadcasting it against the length-n vector would compare all n×n pairs of
# entries instead of each entry with itself.
function MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T}
    return all(A.diag .== adjoint.(A.diag))
end
139+
# Exact anti-Hermitian test for CUDA matrices: `true` iff `A` is square and every
# entry equals the corresponding entry of `-adjoint(A)`.
# The squareness check comes first because broadcasting `A .== -adjoint(A)` on a
# non-square input would either throw or (for 1×n / n×1 inputs) silently expand
# both operands to n×n and compare unrelated entries.
function MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix)
    return size(A, 1) == size(A, 2) && all(A .== -adjoint(A))
end
# Exact anti-Hermitian test for a `Diagonal` backed by a CUDA vector: `true` iff
# every diagonal entry equals the negative of its own adjoint.
# `adjoint.` is applied elementwise on purpose: `adjoint(A.diag)` is a 1×n row, and
# broadcasting it against the length-n vector would compare all n×n pairs of
# entries instead of each entry with itself.
function MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T}
    return all(A.diag .== .-adjoint.(A.diag))
end
142+
# CUDA method of `_avgdiff!`: in place, overwrites each entry of `A` with the
# average `(a + b) / 2` and the matching entry of `B` with the difference `b - a`,
# where `a` and `b` are the original entries of `A` and `B`.
# Throws `DimensionMismatch` when the axes of `A` and `B` differ.
# Returns the tuple `(A, B)`.
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
    axes(A) == axes(B) || throw(DimensionMismatch())
    # Nothing to update for empty inputs, and `cld(0, thread_dim) == 0` would
    # request a launch with zero blocks, which is not a valid launch configuration.
    isempty(A) && return A, B
    # COV_EXCL_START
    # One thread per linear index `j`; both inputs are read before either is written.
    function _avgdiff_kernel(A, B)
        j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
        # threads beyond the last element have nothing to do
        j > length(A) && return
        @inbounds begin
            a = A[j]
            b = B[j]
            A[j] = (a + b) / 2
            B[j] = b - a
        end
        return
    end
    # COV_EXCL_STOP
    thread_dim = 512
    block_dim = cld(length(A), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _avgdiff_kernel(A, B)
    return A, B
end
163+
61164end
0 commit comments