Skip to content

Commit c861c40

Browse files
committed
Working GPU col permute
1 parent 8c1c00f commit c861c40

3 files changed

Lines changed: 17 additions & 2 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,4 +206,11 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix
206206
return C
207207
end
208208

209+
function MatrixAlgebraKit.permute_V_cols!(V, I::ROCVector{Int})
210+
I_ixs = ROCArray(collect(1:size(V, 1)))
211+
c_ixs = map(CartesianIndex, I, I_ixs)
212+
V[c_ixs] .= one(eltype(V))
213+
return V
214+
end
215+
209216
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,4 +191,11 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T
191191
return C
192192
end
193193

194+
function MatrixAlgebraKit.permute_V_cols!(V, I::CuVector{Int})
195+
I_ixs = CuArray(collect(1:size(V, 1)))
196+
c_ixs = map(CartesianIndex, I, I_ixs)
197+
V[c_ixs] .= one(eltype(V))
198+
return V
199+
end
200+
194201
end

src/implementations/eigh.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ function eigh_trunc_no_error!(A, DV, alg::TruncatedAlgorithm)
141141
return DVtrunc
142142
end
143143

144+
permute_V_cols!(V, I::Vector{Int}) = Base.permutecols!!(V, I)
145+
144146
# Diagonal logic
145147
# --------------
146148
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
@@ -153,8 +155,7 @@ function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
153155
diagview(D) .= real.(diagview(A))[I]
154156
end
155157
zero!(V)
156-
Is = [CartesianIndex(ix, I[ix]) for ix in 1:size(A, 1)]
157-
V[Is] .= one(eltype(A))
158+
V = permute_V_cols!(V, I)
158159
return D, V
159160
end
160161

0 commit comments

Comments
 (0)