Skip to content

Commit f516fc2

Browse files
authored
Unbreak mixed type mul on 1.12.4 (#676)
1 parent 0c08596 commit f516fc2

3 files changed

Lines changed: 21 additions & 2 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GPUArrays"
22
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
3-
version = "11.3.3"
3+
version = "11.3.4"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/host/linalg.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# integration with LinearAlgebra stdlib
22

3-
using LinearAlgebra: MulAddMul, wrap, diagm
3+
using LinearAlgebra: MulAddMul, wrap, diagm, BlasReal
44

55
## transpose and adjoint
66

@@ -493,6 +493,16 @@ end
493493
function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::LinearAlgebra.BlasFlag.SyrkHerkGemm) where {T}
494494
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
495495
end
496+
# need to support mixed complex/real types too
497+
#function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{Complex{T}}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::V) where {T<:BlasReal, V<:LinearAlgebra.BlasFlag.SyrkHerkGemm}
498+
# LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
499+
#end
500+
function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{Complex{T}}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::Val{LinearAlgebra.BlasFlag.GEMM}) where T<:Union{Float32, Float64}
501+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
502+
end
503+
function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{Complex{T}}, alpha::Number, beta::Number, val::Val{LinearAlgebra.BlasFlag.GEMM}) where T<:Union{Float32, Float64}
504+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
505+
end
496506
# Julia 1.12 introduced generic_mul! for scalar * array operations
497507
function LinearAlgebra.generic_mul!(C::AbstractGPUVecOrMat, X::AbstractGPUVecOrMat, s::Number, alpha::Number, beta::Number)
498508
if length(C) != length(X)

test/testsuite/linalg.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,15 @@ end
484484
@test compare(mul!, AT, C, f(A), g(B), Ref(T(4)), Ref(T(5)))
485485
@test typeof(AT(rand(T, 3, 3)) * AT(rand(T, 3, 3))) <: AbstractMatrix
486486
end
487+
@testset "$(complex(T)), $(complex(T)), $T gemm C := A * B * a + C * b" for T in filter(T-><:(T, Real) && <:(T, AbstractFloat), eltypes)
488+
Tc = complex(T)
489+
A, B, C = rand(Tc, 4, 4), rand(T, 4, 4), rand(Tc, 4, 4)
490+
491+
@test compare(*, AT, A, B)
492+
@test compare(mul!, AT, C, A, B)
493+
@test compare(mul!, AT, C, A, B, Ref(T(4)), Ref(T(5)))
494+
@test typeof(AT(rand(Tc, 3, 3)) * AT(rand(T, 3, 3))) <: AbstractMatrix
495+
end
487496
end
488497

489498
@testsuite "linalg/norm" (AT, eltypes)->begin

0 commit comments

Comments
 (0)