# In-place scaling with the scalar on the right: forwards to the three-argument
# `mul!(dst, src, α)` method, using `dst` as both destination and source.
function LinearAlgebra.rmul!(dst::StridedView, α::Number)
    return mul!(dst, dst, α)
end
# In-place scaling with the scalar on the left: forwards to the three-argument
# `mul!(dst, α, src)` method, using `dst` as both destination and source.
function LinearAlgebra.lmul!(α::Number, dst::StridedView)
    return mul!(dst, α, dst)
end
44
# Write `α * src` (scalar applied from the left) into `dst` and return `dst`.
# Both views must have the same number of dimensions `N`.
function LinearAlgebra.mul!(
    dst::StridedView{<:Number, N}, α::Number,
    src::StridedView{<:Number, N}
) where {N}
    if α != 1
        # General case: element-wise left multiplication written into `dst`.
        @. dst = α * src
    else
        # A unit factor needs no arithmetic; a plain copy is enough.
        copy!(dst, src)
    end
    return dst
end
# Write `src * α` (scalar applied from the right) into `dst` and return `dst`.
# Both views must have the same number of dimensions `N`.
function LinearAlgebra.mul!(
    dst::StridedView{<:Number, N}, src::StridedView{<:Number, N},
    α::Number
) where {N}
    if α != 1
        # General case: element-wise right multiplication written into `dst`.
        # The operand order `src * α` is kept deliberately.
        @. dst = src * α
    else
        # A unit factor needs no arithmetic; a plain copy is enough.
        copy!(dst, src)
    end
    return dst
end
# Overwrite `Y` with `a * X + Y` element-wise and return `Y`.
# Both views must have the same number of dimensions `N`.
function LinearAlgebra.axpy!(
    a::Number, X::StridedView{<:Number, N},
    Y::StridedView{<:Number, N}
) where {N}
    if a != 1
        @. Y = a * X + Y
    else
        # With a unit coefficient the multiplication is skipped entirely.
        @. Y = X + Y
    end
    return Y
end
32- function LinearAlgebra. axpby! (a:: Number , X:: StridedView{<:Number,N} ,
33- b:: Number , Y:: StridedView{<:Number,N} ) where {N}
38+ function LinearAlgebra. axpby! (
39+ a:: Number , X:: StridedView{<:Number, N} ,
40+ b:: Number , Y:: StridedView{<:Number, N}
41+ ) where {N}
3442 if b == 1
3543 axpy! (a, X, Y)
3644 elseif b == 0
@@ -41,9 +49,11 @@ function LinearAlgebra.axpby!(a::Number, X::StridedView{<:Number,N},
4149 return Y
4250end
4351
44- function LinearAlgebra. mul! (C:: StridedView{T,2} ,
45- A:: StridedView{<:Any,2} , B:: StridedView{<:Any,2} ,
46- α:: Number = true , β:: Number = false ) where {T}
52+ function LinearAlgebra. mul! (
53+ C:: StridedView{T, 2} ,
54+ A:: StridedView{<:Any, 2} , B:: StridedView{<:Any, 2} ,
55+ α:: Number = true , β:: Number = false
56+ ) where {T}
4757 if ! (eltype (C) <: LinearAlgebra.BlasFloat && eltype (A) == eltype (B) == eltype (C))
4858 return __mul! (C, A, B, α, β)
4959 end
@@ -62,7 +72,7 @@ function LinearAlgebra.mul!(C::StridedView{T,2},
6272 return C
6373end
6474
65- function isblasmatrix (A:: StridedView{T,2} ) where {T<: LinearAlgebra.BlasFloat }
75+ function isblasmatrix (A:: StridedView{T, 2} ) where {T <: LinearAlgebra.BlasFloat }
6676 if A. op == identity
6777 return stride (A, 1 ) == 1 || stride (A, 2 ) == 1
6878 elseif A. op == conj
@@ -71,7 +81,7 @@ function isblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
7181 return false
7282 end
7383end
74- function getblasmatrix (A:: StridedView{T,2} ) where {T<: LinearAlgebra.BlasFloat }
84+ function getblasmatrix (A:: StridedView{T, 2} ) where {T <: LinearAlgebra.BlasFloat }
7585 if A. op == identity
7686 if stride (A, 1 ) == 1
7787 return A, ' N'
@@ -84,8 +94,10 @@ function getblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
8494end
8595
8696# here we will have C.op == :identity && stride(C,1) < stride(C,2)
87- function _mul! (C:: StridedView{T,2} , A:: StridedView{T,2} , B:: StridedView{T,2} ,
88- α:: Number , β:: Number ) where {T<: LinearAlgebra.BlasFloat }
97+ function _mul! (
98+ C:: StridedView{T, 2} , A:: StridedView{T, 2} , B:: StridedView{T, 2} ,
99+ α:: Number , β:: Number
100+ ) where {T <: LinearAlgebra.BlasFloat }
89101 if stride (C, 1 ) == 1 && isblasmatrix (A) && isblasmatrix (B)
90102 nthreads = use_threaded_mul () ? get_num_threads () : 1
91103 _threaded_blas_mul! (C, A, B, α, β, nthreads)
@@ -94,41 +106,53 @@ function _mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2},
94106 end
95107end
96108
# Recursive, task-parallel driver around `LinearAlgebra.BLAS.gemm!`.
#
# Arguments:
#   C, A, B  - two-dimensional views sharing the BLAS element type `T`;
#              `C` receives the result.
#   α, β     - scalars, converted to `T` before being handed to `gemm!`.
#   nthreads - thread budget available to this call.
#
# While more than one thread is budgeted and `C` holds at least 1024 entries,
# the longer dimension of `C` is cut at a multiple of 8 close to its midpoint.
# One part is handled by a spawned task, the other by the current task, and the
# thread budget is divided between the two. Otherwise a single `gemm!` call
# does the work.
#
# Throws `DimensionMismatch` when the outer sizes of `A` and `B` do not match
# the size of `C`.
function _threaded_blas_mul!(
    C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2},
    α::Number, β::Number,
    nthreads
) where {T <: LinearAlgebra.BlasFloat}
    rows, cols = size(C)
    (rows == size(A, 1) && cols == size(B, 2)) || throw(DimensionMismatch())

    # Base case: no spare threads, or the problem is too small to be worth splitting.
    if nthreads == 1 || rows * cols < 1024
        Ablas, transA = getblasmatrix(A)
        Bblas, transB = getblasmatrix(B)
        return LinearAlgebra.BLAS.gemm!(
            transA, transB, convert(T, α), Ablas, Bblas, convert(T, β), C
        )
    end

    # The spawned task gets half of the thread budget (rounded down); the
    # current task keeps the remainder.
    spawned_threads = nthreads >> 1
    local_threads = nthreads - spawned_threads

    if rows > cols
        # Split along the rows of C (and correspondingly of A).
        cut = round(Int, rows / 16) * 8
        task = Threads.@spawn _threaded_blas_mul!(
            C[1:($cut), :], A[1:($cut), :], B, α, β,
            $spawned_threads
        )
        _threaded_blas_mul!(
            C[(cut + 1):rows, :], A[(cut + 1):rows, :], B, α, β,
            local_threads
        )
    else
        # Split along the columns of C (and correspondingly of B).
        cut = round(Int, cols / 16) * 8
        task = Threads.@spawn _threaded_blas_mul!(
            C[:, 1:($cut)], A, B[:, 1:($cut)], α, β,
            $spawned_threads
        )
        _threaded_blas_mul!(
            C[:, (cut + 1):cols], A, B[:, (cut + 1):cols], α, β,
            local_threads
        )
    end

    # Block until the spawned part has finished before handing C back.
    wait(task)
    return C
end
128150
129151# This implementation is faster than LinearAlgebra.generic_matmatmul
130- function __mul! (C:: StridedView{<:Any,2} , A:: StridedView{<:Any,2} , B:: StridedView{<:Any,2} ,
131- α:: Number , β:: Number )
152+ function __mul! (
153+ C:: StridedView{<:Any, 2} , A:: StridedView{<:Any, 2} , B:: StridedView{<:Any, 2} ,
154+ α:: Number , β:: Number
155+ )
132156 (size (C, 1 ) == size (A, 1 ) && size (C, 2 ) == size (B, 2 ) && size (A, 2 ) == size (B, 1 )) ||
133157 throw (DimensionMismatch (" A has size $(size (A)) , B has size $(size (B)) , C has size $(size (C)) " ))
134158
0 commit comments