Skip to content

Commit ffb0119

Browse files
committed
Comments
1 parent 44d3cc8 commit ffb0119

3 files changed

Lines changed: 5 additions & 8 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ using CUDA: i32
1313
using LinearAlgebra
1414
using LinearAlgebra: BlasFloat
1515

16-
using CUDA: i32
17-
1816
include("yacusolver.jl")
1917

2018
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}

src/implementations/eig.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm)
137137
D, V = DV
138138
if alg isa GPU_Simple
139139
isempty(alg.kwargs) ||
140-
throw(ArgumentError("GPU_Simple (geev) does not accept any keyword arguments"))
140+
@warn "GPU_Simple (geev) does not accept any keyword arguments"
141141
_gpu_geev!(A, D.diag, V)
142142
end
143143
# TODO: make this controllable using a `gaugefix` keyword argument
@@ -149,9 +149,8 @@ function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm)
149149
check_input(eig_vals!, A, D, alg)
150150
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
151151
if alg isa GPU_Simple
152-
# TODO filter out nothing kwargs
153-
#isempty(alg.kwargs) ||
154-
# throw(ArgumentError("GPU_Simple (geev) does not accept any keyword arguments"))
152+
isempty(alg.kwargs) ||
153+
@warn "GPU_Simple (geev) does not accept any keyword arguments"
155154
_gpu_geev!(A, D, V)
156155
end
157156
return D

test/cuda/projections.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
1212
m = 54
1313
noisefactor = eps(real(T))^(3 / 4)
1414
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
15-
A = CuArray(randn(rng, T, m, m))
15+
A = CuArray(randn(rng, T, m, m))
1616
Ah = (A + A') / 2
1717
Aa = (A - A') / 2
1818
Ac = copy(A)
@@ -69,7 +69,7 @@ end
6969
# test that W is closer to A then any other isometry
7070
for k in 1:10
7171
δA = CuArray(randn(rng, T, m, n))
72-
W = project_isometric(A, alg)
72+
W = project_isometric(A, alg)
7373
W2 = project_isometric(A + δA / 100, alg)
7474
@test norm(A - W2) > norm(A - W)
7575
end

0 commit comments

Comments
 (0)