|
1 | 1 | # integration with LinearAlgebra stdlib |
2 | 2 |
|
3 | | -using LinearAlgebra: MulAddMul, wrap, diagm |
| 3 | +using LinearAlgebra: MulAddMul, wrap, diagm, BlasReal |
4 | 4 |
|
5 | 5 | ## transpose and adjoint |
6 | 6 |
|
|
# Matrix product entry point for GPU arrays whose operands all share the
# element type `T`: hand the call straight to `generic_matmatmul!`.
# The flag `val` takes part in dispatch only and is not forwarded.
function LinearAlgebra.generic_matmatmul_wrapper!(
    C::AbstractGPUMatrix{T},
    tA::AbstractChar,
    tB::AbstractChar,
    A::AbstractGPUVecOrMat{T},
    B::AbstractGPUVecOrMat{T},
    alpha::Number,
    beta::Number,
    val::LinearAlgebra.BlasFlag.SyrkHerkGemm,
) where {T}
    return LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
end
| 496 | + # mixed complex/real element types (complex A with real B, or real A with complex B) need their own methods too |
| 497 | + #function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{Complex{T}}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::V) where {T<:BlasReal, V<:LinearAlgebra.BlasFlag.SyrkHerkGemm} |
| 498 | + # LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) |
| 499 | + #end |
# Mixed element types, complex `A` times real `B` into a complex `C`, for the
# GEMM flag: hand the call straight to `generic_matmatmul!`.
# `BlasReal` is LinearAlgebra's alias for `Union{Float32, Float64}`.
# The flag `val` takes part in dispatch only and is not forwarded.
function LinearAlgebra.generic_matmatmul_wrapper!(
    C::AbstractGPUMatrix{Complex{T}},
    tA::AbstractChar,
    tB::AbstractChar,
    A::AbstractGPUVecOrMat{Complex{T}},
    B::AbstractGPUVecOrMat{T},
    alpha::Number,
    beta::Number,
    val::Val{LinearAlgebra.BlasFlag.GEMM},
) where {T<:BlasReal}
    return LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
end
# Mixed element types, real `A` times complex `B` into a complex `C`, for the
# GEMM flag: hand the call straight to `generic_matmatmul!`.
# `BlasReal` is LinearAlgebra's alias for `Union{Float32, Float64}`.
# The flag `val` takes part in dispatch only and is not forwarded.
function LinearAlgebra.generic_matmatmul_wrapper!(
    C::AbstractGPUMatrix{Complex{T}},
    tA::AbstractChar,
    tB::AbstractChar,
    A::AbstractGPUVecOrMat{T},
    B::AbstractGPUVecOrMat{Complex{T}},
    alpha::Number,
    beta::Number,
    val::Val{LinearAlgebra.BlasFlag.GEMM},
) where {T<:BlasReal}
    return LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
end
496 | 506 | # Julia 1.12 introduced generic_mul! for scalar * array operations |
497 | 507 | function LinearAlgebra.generic_mul!(C::AbstractGPUVecOrMat, X::AbstractGPUVecOrMat, s::Number, alpha::Number, beta::Number) |
498 | 508 | if length(C) != length(X) |
|
0 commit comments