@@ -58,4 +58,102 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::T
5858 return MatrixAlgebraKit. findtruncated (values, strategy)
5959end
6060
# Device kernel (anti = true): one thread per column `col` of `Au`.
# For every row it forms half the difference between `Au[row, col]` and the
# adjoint of the mirrored entry `Al[col, row]`, stores it in `Bu[row, col]`,
# and stores its negated adjoint in the mirrored position `Bl[col, row]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
    nrows, ncols = size(Au)
    # Global 1-based thread index along x, kept in Int32 arithmetic.
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    if col <= ncols
        for row in 1:nrows
            @inbounds begin
                half_diff = (Au[row, col] - adjoint(Al[col, row])) / 2
                Bu[row, col] = half_diff
                Bl[col, row] = -adjoint(half_diff)
            end
        end
    end
    return
end
74+
# Device kernel (anti = false): one thread per column `j` of `Au`.
# For every row it forms half the sum of `Au[i, j]` and the adjoint of the
# mirrored entry `Al[j, i]`, stores it in `Bu[i, j]`, and stores its adjoint
# in the mirrored position `Bl[j, i]`.
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
    m, n = size(Au)
    # Use the Int32 literal `1i32` (as the Val{true} method does) so the thread
    # index is not promoted to Int64 by a plain `1`.
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    # Threads past the last column have nothing to do.
    j > n && return
    for i in 1:m
        @inbounds begin
            val = (Au[i, j] + adjoint(Al[j, i])) / 2
            Bu[i, j] = val
            Bl[j, i] = adjoint(val)
        end
    end
    return
end
88+
# Device kernel (anti = true): one thread per column `j` of `A`.
# For each row `i < j` it writes half the difference between `A[i, j]` and
# the adjoint of `A[j, i]` to `B[i, j]`, and the negated adjoint of that value
# to `B[j, i]`. Each (i, j) pair with i < j is touched by thread `j` only.
function _project_hermitian_diag_kernel(A, B, ::Val{true})
    n = size(A, 1)
    # Use the Int32 literal `1i32` (as the Val{false} method does) so that `j`
    # and the loop range below stay in Int32 instead of being promoted to Int64.
    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    # Threads past the last column have nothing to do.
    j > n && return
    @inbounds begin
        for i in 1i32:(j - 1i32)
            val = (A[i, j] - adjoint(A[j, i])) / 2
            B[i, j] = val
            B[j, i] = -adjoint(val)
        end
        # NOTE(review): `_imimag` is defined in MatrixAlgebraKit; presumably it
        # keeps only the imaginary part of the diagonal entry -- confirm there.
        B[j, j] = MatrixAlgebraKit._imimag(A[j, j])
    end
    return
end
103+
# Device kernel (anti = false): one thread per column `col` of `A`.
# Entries above the diagonal in that column are replaced by half the sum of
# the entry and the adjoint of its mirror image; the mirrored entry of `B`
# receives the adjoint of that value, and the diagonal entry keeps only its
# real part.
function _project_hermitian_diag_kernel(A, B, ::Val{false})
    dim = size(A, 1)
    # Global 1-based thread index along x, kept in Int32 arithmetic.
    col = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    if col <= dim
        @inbounds begin
            for row in 1i32:(col - 1i32)
                half_sum = (A[row, col] + adjoint(A[col, row])) / 2
                B[row, col] = half_sum
                B[col, row] = adjoint(half_sum)
            end
            B[col, col] = real(A[col, col])
        end
    end
    return
end
118+
# Host-side launcher for `_project_hermitian_offdiag_kernel` on CUDA matrices.
# Launches one thread per column of `Au` (512 threads per block) and forwards
# the `anti` flag to the kernel as a `Val`. Always returns `nothing`.
function MatrixAlgebraKit._project_hermitian_offdiag!(
        Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
    ) where {anti}
    # An empty block would give `blocks = 0`, which CUDA rejects at launch time;
    # there is nothing to compute in that case anyway.
    isempty(Au) && return nothing
    thread_dim = 512
    block_dim = cld(size(Au, 2), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
    return nothing
end
# Host-side launcher for `_project_hermitian_diag_kernel` on CUDA matrices.
# Launches one thread per index along the first dimension of `A` (512 threads
# per block) and forwards the `anti` flag to the kernel as a `Val`.
# Always returns `nothing`.
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
    # A zero-sized first dimension would give `blocks = 0`, which CUDA rejects
    # at launch time; there is nothing to compute in that case anyway.
    size(A, 1) == 0 && return nothing
    thread_dim = 512
    block_dim = cld(size(A, 1), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
    return nothing
end
133+
# Exact Hermiticity check for a CUDA matrix: every entry must equal the
# corresponding entry of the adjoint. Non-square matrices are never Hermitian;
# checking the shape first also prevents the broadcast from throwing on an
# m×n input (or silently expanding a 1×n one).
function MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix)
    size(A, 1) == size(A, 2) || return false
    return all(A .== adjoint(A))
end

# Exact Hermiticity check for a diagonal matrix backed by a CUDA vector: each
# diagonal entry must equal its own adjoint. `adjoint` is applied elementwise;
# taking the adjoint of the whole vector would yield a 1×n row and make `.==`
# compare every pair of diagonal entries instead of each entry with itself.
function MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T}
    d = A.diag
    return all(d .== adjoint.(d))
end
136+
# Exact anti-Hermiticity check for a CUDA matrix: every entry must equal the
# negated corresponding entry of the adjoint. Non-square matrices are never
# anti-Hermitian; checking the shape first also prevents the broadcast from
# throwing on an m×n input (or silently expanding a 1×n one).
function MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix)
    size(A, 1) == size(A, 2) || return false
    return all(A .== -adjoint(A))
end

# Exact anti-Hermiticity check for a diagonal matrix backed by a CUDA vector:
# each diagonal entry must equal the negation of its own adjoint. `adjoint` is
# applied elementwise; taking the adjoint of the whole vector would yield a
# 1×n row and make `.==` compare every pair of diagonal entries.
function MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T}
    d = A.diag
    return all(d .== .-adjoint.(d))
end
139+
# In-place average/difference of two CUDA matrices with identical axes.
# On return, `A` holds the elementwise mean `(A + B) / 2` and `B` holds the
# elementwise difference `B - A`, both computed from the original values.
# Throws `DimensionMismatch` when the axes differ. Returns `(A, B)`.
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
    axes(A) == axes(B) ||
        throw(DimensionMismatch("axes of A $(axes(A)) and B $(axes(B)) must match"))
    # An empty input would give `blocks = 0`, which CUDA rejects at launch time;
    # there is nothing to compute in that case anyway.
    if isempty(A)
        return A, B
    end
    # One thread per linear index; both values are read before either is written.
    function _avgdiff_kernel(A, B)
        j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
        j > length(A) && return
        @inbounds begin
            a = A[j]
            b = B[j]
            A[j] = (a + b) / 2
            B[j] = b - a
        end
        return
    end
    thread_dim = 512
    block_dim = cld(length(A), thread_dim)
    @cuda threads = thread_dim blocks = block_dim _avgdiff_kernel(A, B)
    return A, B
end
158+
61159end
0 commit comments