diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl index 287f0db7..a5ea4429 100644 --- a/lib/mkl/linalg.jl +++ b/lib/mkl/linalg.jl @@ -97,10 +97,15 @@ function LinearAlgebra.generic_matvecmul!(Y::oneVector, tA::AbstractChar, A::one if T <: onemklFloat && eltype(A) == eltype(B) == T if tA in ('N', 'T', 'C') return gemv!(tA, alpha, A, B, beta, Y) - elseif tA in ('S', 's') + elseif tA in ('S', 's') && T <: Real + # complex symv! is not wrapped; fall through to generic_matmatmul!, + # which can use symm! instead return symv!(tA == 'S' ? 'U' : 'L', alpha, A, B, beta, Y) elseif tA in ('H', 'h') - return hemv!(tA == 'H' ? 'U' : 'L', alpha, A, B, beta, Y) + # hemv! only supports complex eltypes, but a real Hermitian matrix + # is symmetric + fun = T <: Real ? symv! : hemv! + return fun(tA == 'H' ? 'U' : 'L', alpha, A, B, beta, Y) end end end @@ -162,14 +167,20 @@ function LinearAlgebra.generic_matmatmul!( A isa oneStridedArray{T} && B isa oneStridedArray{T} ) return gemm!(tA, tB, α, A, B, β, C) - elseif (tA == 'S' || tA == 's') && tB == 'N' - return symm!('L', tA == 'S' ? 'U' : 'L', α, A, B, β, C) - elseif (tB == 'S' || tB == 's') && tA == 'N' - return symm!('R', tB == 'S' ? 'U' : 'L', α, B, A, β, C) - elseif (tA == 'H' || tA == 'h') && tB == 'N' - return hemm!('L', tA == 'H' ? 'U' : 'L', α, A, B, β, C) - elseif (tB == 'H' || tB == 'h') && tA == 'N' - return hemm!('R', tB == 'H' ? 'U' : 'L', α, B, A, β, C) + elseif T <: onemklFloat && A isa oneStridedArray{T} && B isa oneStridedArray{T} + # hemm! only supports complex eltypes, but a real Hermitian matrix + # is symmetric + if (tA == 'S' || tA == 's') && tB == 'N' + return symm!('L', tA == 'S' ? 'U' : 'L', α, A, B, β, C) + elseif (tB == 'S' || tB == 's') && tA == 'N' + return symm!('R', tB == 'S' ? 'U' : 'L', α, B, A, β, C) + elseif (tA == 'H' || tA == 'h') && tB == 'N' + fun = T <: Real ? symm! : hemm! + return fun('L', tA == 'H' ? 'U' : 'L', α, A, B, β, C) + elseif (tB == 'H' || tB == 'h') && tA == 'N' + fun = T <: Real ? symm! : hemm! + return fun('R', tB == 'H' ? 'U' : 'L', α, B, A, β, C) + end end end