Skip to content

Commit 2011261

Browse files
committed
allocating path for GPU
1 parent 1e0610d commit 2011261

2 files changed

Lines changed: 23 additions & 8 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,17 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::Strid
128128
end
129129

130130
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
131-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))
131+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
132+
all(A.diag .== adjoint(A.diag))
133+
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; kwargs...) =
134+
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)
132135

133-
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = all(A .== -adjoint(A))
134-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
136+
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) =
137+
all(A .== -adjoint(A))
138+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
139+
all(A.diag .== -adjoint(A.diag))
140+
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; kwargs...) =
141+
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
135142

136143
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
137144
axes(A) == axes(B) || throw(DimensionMismatch())

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,19 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::Stride
134134
return nothing
135135
end
136136

137-
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = all(A .== adjoint(A))
138-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))
139-
140-
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = all(A .== -adjoint(A))
141-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
137+
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) =
138+
all(A .== adjoint(A))
139+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
140+
all(A.diag .== adjoint(A.diag))
141+
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; kwargs...) =
142+
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)
143+
144+
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) =
145+
all(A .== -adjoint(A))
146+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
147+
all(A.diag .== -adjoint(A.diag))
148+
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; kwargs...) =
149+
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
142150

143151
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
144152
axes(A) == axes(B) || throw(DimensionMismatch())

0 commit comments

Comments
 (0)