@@ -9,6 +9,7 @@ using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_
99import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
1111using CUDA
12+ using CUDA: i32
1213using LinearAlgebra
1314using LinearAlgebra: BlasFloat
1415
@@ -58,4 +59,102 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::T
5859 return MatrixAlgebraKit. findtruncated (values, strategy)
5960end
6061
# CUDA kernel for the anti-Hermitian projection of an off-diagonal block pair.
# Each thread handles one column `col` of `Au`; for every row it stores
# `(Au[row, col] - Al[col, row]') / 2` in `Bu[row, col]` and the negated adjoint
# of that value in `Bl[col, row]`.
# NOTE(review): assumes `Al`/`Bl` have the transposed shape of `Au`/`Bu` — the
# kernel itself performs no bounds checks.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
    nrows, ncols = size(Au)
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    if col <= ncols
        @inbounds for row in 1:nrows
            half = (Au[row, col] - adjoint(Al[col, row])) / 2
            Bu[row, col] = half
            Bl[col, row] = -adjoint(half)
        end
    end
    return
end
75+
# CUDA kernel for the Hermitian projection of an off-diagonal block pair.
# Each thread handles one column `j` of `Au`; for every row `i` it stores
# `(Au[i, j] + Al[j, i]') / 2` in `Bu[i, j]` and the adjoint of that value in
# `Bl[j, i]`.
# NOTE(review): assumes `Al`/`Bl` have the transposed shape of `Au`/`Bu` — the
# kernel itself performs no bounds checks.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
    m, n = size(Au)
    # `1i32` (rather than `1`) keeps the global thread index in Int32, matching
    # the `Val{true}` method of this kernel.
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    j > n && return
    for i in 1:m
        @inbounds begin
            val = (Au[i, j] + adjoint(Al[j, i])) / 2
            Bu[i, j] = val
            Bl[j, i] = adjoint(val)
        end
    end
    return
end
89+
# CUDA kernel for the anti-Hermitian projection of a diagonal (square) block.
# Thread `j` processes the strictly-upper entries of column `j`: for each `i < j`
# it stores `(A[i, j] - A[j, i]') / 2` in `B[i, j]` and the negated adjoint in
# `B[j, i]`. The diagonal entry is set via `MatrixAlgebraKit._imimag`
# (presumably the purely imaginary part of `A[j, j]` — confirm in MatrixAlgebraKit).
function _project_hermitian_diag_kernel(A, B, ::Val{true})
    n = size(A, 1)
    # `1i32` (rather than `1`) keeps the global thread index in Int32, so the
    # Int32 loop range below is not mixed with an Int64 upper bound.
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    j > n && return
    @inbounds begin
        for i in 1i32:(j - 1i32)
            val = (A[i, j] - adjoint(A[j, i])) / 2
            B[i, j] = val
            B[j, i] = -adjoint(val)
        end
        B[j, j] = MatrixAlgebraKit._imimag(A[j, j])
    end
    return
end
104+
# CUDA kernel for the Hermitian projection of a diagonal (square) block.
# Thread `col` processes the strictly-upper entries of column `col`: for each
# `row < col` it stores `(A[row, col] + A[col, row]') / 2` in `B[row, col]` and
# its adjoint in `B[col, row]`. The diagonal entry becomes `real(A[col, col])`.
function _project_hermitian_diag_kernel(A, B, ::Val{false})
    dim = size(A, 1)
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    if col <= dim
        @inbounds begin
            for row in 1i32:(col - 1i32)
                sym = (A[row, col] + adjoint(A[col, row])) / 2
                B[row, col] = sym
                B[col, row] = adjoint(sym)
            end
            B[col, col] = real(A[col, col])
        end
    end
    return
end
119+
# Launch the off-diagonal (anti-)Hermitian projection kernel on CUDA matrices.
# `Val(anti)` selects the anti-Hermitian (`true`) or Hermitian (`false`) kernel.
# One thread is launched per column of `Au`, 512 threads per block.
# Returns `nothing`; results are written into `Bu` and `Bl`.
function MatrixAlgebraKit._project_hermitian_offdiag!(
        Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
    ) where {anti}
    # An empty block would yield `cld(0, 512) == 0` blocks, which is not a valid
    # launch configuration; there is nothing to compute in that case anyway.
    isempty(Au) && return nothing
    thread_dim = 512
    block_dim = cld(size(Au, 2), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
    return nothing
end
# Launch the diagonal-block (anti-)Hermitian projection kernel on CUDA matrices.
# `Val(anti)` selects the anti-Hermitian (`true`) or Hermitian (`false`) kernel.
# One thread is launched per row/column index of `A`, 512 threads per block.
# Returns `nothing`; results are written into `B`.
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
    # An empty block would yield `cld(0, 512) == 0` blocks, which is not a valid
    # launch configuration; there is nothing to compute in that case anyway.
    isempty(A) && return nothing
    thread_dim = 512
    block_dim = cld(size(A, 1), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
    return nothing
end
134+
# Exact (tolerance-free) Hermitian check for CUDA matrices, done with a broadcast
# comparison against the lazy adjoint. The squareness guard makes non-square
# input return `false` instead of throwing (or silently broadcasting when one
# dimension is 1).
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) =
    size(A, 1) == size(A, 2) && all(A .== adjoint(A))
# Exact Hermitian check for a `Diagonal` backed by a CUDA vector: every diagonal
# entry must equal its own adjoint. `adjoint.` is applied elementwise; the
# un-dotted `adjoint(A.diag)` is a row vector and would broadcast the comparison
# over all n×n pairs of diagonal entries.
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
    all(A.diag .== adjoint.(A.diag))
137+
# Exact (tolerance-free) anti-Hermitian check for CUDA matrices, done with a
# broadcast comparison against the negated adjoint. The squareness guard makes
# non-square input return `false` instead of throwing (or silently broadcasting
# when one dimension is 1).
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) =
    size(A, 1) == size(A, 2) && all(A .== -adjoint(A))
# Exact anti-Hermitian check for a `Diagonal` backed by a CUDA vector: every
# diagonal entry must equal minus its own adjoint. `adjoint.` is applied
# elementwise; the un-dotted `adjoint(A.diag)` is a row vector and would
# broadcast the comparison over all n×n pairs of diagonal entries.
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
    all(A.diag .== .-adjoint.(A.diag))
140+
# In-place elementwise average/difference of two CUDA matrices.
# On return, `A[k]` holds `(a + b) / 2` and `B[k]` holds `b - a`, where `a` and
# `b` are the original entries of `A` and `B` at linear index `k`.
# Throws `DimensionMismatch` when the axes of `A` and `B` differ.
# Returns the tuple `(A, B)`.
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
    axes(A) == axes(B) || throw(DimensionMismatch())
    # An empty input would yield `cld(0, 512) == 0` blocks, which is not a valid
    # launch configuration; there is nothing to compute in that case anyway.
    isempty(A) && return A, B
    # One thread per linear index; both entries are read before either is written.
    function _avgdiff_kernel(A, B)
        j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
        j > length(A) && return
        @inbounds begin
            a = A[j]
            b = B[j]
            A[j] = (a + b) / 2
            B[j] = b - a
        end
        return
    end
    thread_dim = 512
    block_dim = cld(length(A), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _avgdiff_kernel(A, B)
    return A, B
end
159+
61160end
0 commit comments