Skip to content

Commit f4e2f8f

Browse files
committed
More GPU fixes
1 parent 6f99876 commit f4e2f8f

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5757

5858
[targets]
5959
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
60+
61+
[sources]
62+
CUDA = {url="https://github.com/JuliaGPU/CUDA.jl", rev="master"}

test/testsuite/ad_utils.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ function remove_eighgauge_dependence!(
2929
end
3030

3131
function stabilize_eigvals!(D::AbstractVector)
32-
absD = abs.(D)
33-
p = invperm(sortperm(absD)) # rank of abs(D)
32+
absD = collect(abs.(D))
33+
p = invperm(sortperm(collect(absD))) # rank of abs(D)
3434
# account for exact degeneracies in absolute value when having complex conjugate pairs
3535
for i in 1:(length(D) - 1)
3636
if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially
@@ -41,9 +41,7 @@ function stabilize_eigvals!(D::AbstractVector)
4141
# rescale eigenvalues so that they lie on distinct radii in the complex plane
4242
# that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n
4343
radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n
44-
for i in 1:length(D)
45-
D[i] = sign(D[i]) * radii[p[i]]
46-
end
44+
D .= sign.(D) .* radii[p]
4745
return D
4846
end
4947
function make_eig_matrix(T, sz)

0 commit comments

Comments
 (0)