Skip to content

Commit f0c5c50

Browse files
kshyattlkdvos
andauthored
Try to do SVD truncation on GPU with _ind_intersect (#148)
* Try to do SVD truncation on GPU with _ind_intersect * Use fixer GPUArrays branch * ind_intersect via filter * Fix ambiguity * Update src/implementations/truncation.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Revert "Use fixer GPUArrays branch" This reverts commit fa9fa80. --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent e1ea618 commit f0c5c50

3 files changed

Lines changed: 15 additions & 8 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,8 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix
167167
return C
168168
end
169169

170-
# TODO: intersect on GPU arrays is not working
171-
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B)
172-
MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B))
173-
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
170+
# TODO: intersect doesn't work on GPU
171+
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) =
172+
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
174173

175174
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,8 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T
191191
return C
192192
end
193193

194-
# TODO: intersect on GPU arrays is not working
195-
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B)
196-
MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B))
197-
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
194+
# TODO: intersect doesn't work on GPU
195+
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) =
196+
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
198197

199198
end

src/implementations/truncation.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ function _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector)
126126
end
127127
_ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A)
128128
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
129+
130+
# when one of the ind selections is a unitrange, filter is more efficient than intersect
131+
# since we know both selections only contain unique entries
132+
# (This is also more GPU-friendly!)
133+
_ind_intersect(A::AbstractUnitRange{Int}, B::AbstractUnitRange{Int}) = intersect(A, B)
134+
_ind_intersect(A::AbstractVector{Int}, B::AbstractUnitRange{Int}) = filter(in(B), A)
135+
_ind_intersect(A::AbstractUnitRange{Int}, B::AbstractVector{Int}) = _ind_intersect(B, A)
136+
137+
# when all else fails, call intersect
129138
_ind_intersect(A, B) = intersect(A, B)
130139

131140
# Truncation error

0 commit comments

Comments
 (0)