@@ -74,13 +74,51 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
7474 return C
7575end
7676
77+ # Slow non-inlined functions for throwing the error without messing up the caller
78+ @noinline function _matmul_size_error (mC, nC, mA, nA, mB, nB, At, Bt)
79+ if At == ' N'
80+ Anames = " first" , " second"
81+ else
82+ Anames = " second" , " first"
83+ end
84+ if Bt == ' N'
85+ Bnames = " first" , " second"
86+ else
87+ Bnames = " second" , " first"
88+ end
89+ nA == mB ||
90+ throw (DimensionMismatch (" $(Anames[2 ]) dimension of A, $nA , does not match the $(Bnames[1 ]) dimension of B, $mB " ))
91+ mA == mC ||
92+ throw (DimensionMismatch (" $(Anames[1 ]) dimension of A, $mA , does not match the first dimension of C, $mC " ))
93+ nB == nC ||
94+ throw (DimensionMismatch (" $(Bnames[2 ]) dimension of B, $nB , does not match the second dimension of C, $nC " ))
95+ # unreachable
96+ throw (DimensionMismatch (" Unknown dimension mismatch" ))
97+ end
98+
99+ @inline function _matmul_size (C, A, B, :: Val{At} , :: Val{Bt} ) where {At,Bt}
100+ mC = size (C, 1 )
101+ nC = size (C, 2 )
102+ mA = size (A, 1 )
103+ nA = size (A, 2 )
104+ mB = size (B, 1 )
105+ nB = size (B, 2 )
106+
107+ _mA, _nA = At == ' N' ? (mA, nA) : (nA, mA)
108+ _mB, _nB = Bt == ' N' ? (mB, nB) : (nB, mB)
109+
110+ if (_nA != _mB) | (_mA != mC) | (_nB != nC)
111+ _matmul_size_error (mC, nC, _mA, _nA, _mB, _nB, At, Bt)
112+ end
113+ return mC, nC, mA, nA, mB, nB
114+ end
115+
116+ @inline _matmul_size_AB (C, A, B) = _matmul_size (C, A, B, Val (' N' ), Val (' N' ))
117+ @inline _matmul_size_AtB (C, A, B) = _matmul_size (C, A, B, Val (' T' ), Val (' N' ))
118+ @inline _matmul_size_ABt (C, A, B) = _matmul_size (C, A, B, Val (' N' ), Val (' T' ))
119+
77120function _spmatmul! (C, A, B, α, β)
78- size (A, 2 ) == size (B, 1 ) ||
79- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the first dimension of B, $(size (B,1 )) " ))
80- size (A, 1 ) == size (C, 1 ) ||
81- throw (DimensionMismatch (" first dimension of A, $(size (A,1 )) , does not match the first dimension of C, $(size (C,1 )) " ))
82- size (B, 2 ) == size (C, 2 ) ||
83- throw (DimensionMismatch (" second dimension of B, $(size (B,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
121+ _matmul_size_AB (C, A, B)
84122 nzv = nonzeros (A)
85123 rv = rowvals (A)
86124 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
@@ -95,12 +133,7 @@ function _spmatmul!(C, A, B, α, β)
95133end
96134
97135function _At_or_Ac_mul_B! (tfun:: Function , C, A, B, α, β)
98- size (A, 2 ) == size (C, 1 ) ||
99- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the first dimension of C, $(size (C,1 )) " ))
100- size (A, 1 ) == size (B, 1 ) ||
101- throw (DimensionMismatch (" first dimension of A, $(size (A,1 )) , does not match the first dimension of B, $(size (B,1 )) " ))
102- size (B, 2 ) == size (C, 2 ) ||
103- throw (DimensionMismatch (" second dimension of B, $(size (B,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
136+ _matmul_size_AtB (C, A, B)
104137 nzv = nonzeros (A)
105138 rv = rowvals (A)
106139 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
@@ -127,13 +160,7 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB
127160 return C
128161end
129162function _spmul! (C:: StridedMatrix , X:: DenseMatrixUnion , A:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
130- mX, nX = size (X)
131- nX == size (A, 1 ) ||
132- throw (DimensionMismatch (" second dimension of X, $nX , does not match the first dimension of A, $(size (A,1 )) " ))
133- mX == size (C, 1 ) ||
134- throw (DimensionMismatch (" first dimension of X, $mX , does not match the first dimension of C, $(size (C,1 )) " ))
135- size (A, 2 ) == size (C, 2 ) ||
136- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
163+ _matmul_size_AB (C, X, A)
137164 rv = rowvals (A)
138165 nzv = nonzeros (A)
139166 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
@@ -146,13 +173,7 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2
146173 end
147174end
148175function _spmul! (C:: StridedMatrix , X:: AdjOrTrans{<:Any,<:DenseMatrixUnion} , A:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
149- mX, nX = size (X)
150- nX == size (A, 1 ) ||
151- throw (DimensionMismatch (" second dimension of X, $nX , does not match the first dimension of A, $(size (A,1 )) " ))
152- mX == size (C, 1 ) ||
153- throw (DimensionMismatch (" first dimension of X, $mX , does not match the first dimension of C, $(size (C,1 )) " ))
154- size (A, 2 ) == size (C, 2 ) ||
155- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
176+ _matmul_size_AB (C, X, A)
156177 rv = rowvals (A)
157178 nzv = nonzeros (A)
158179 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
@@ -164,13 +185,7 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S
164185end
165186
166187function _A_mul_Bt_or_Bc! (tfun:: Function , C:: StridedMatrix , A:: AbstractMatrix , B:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
167- mA, nA = size (A)
168- nA == size (B, 2 ) ||
169- throw (DimensionMismatch (" second dimension of A, $nA , does not match the second dimension of B, $(size (B,2 )) " ))
170- mA == size (C, 1 ) ||
171- throw (DimensionMismatch (" first dimension of A, $mA , does not match the first dimension of C, $(size (C,1 )) " ))
172- size (B, 1 ) == size (C, 2 ) ||
173- throw (DimensionMismatch (" first dimension of B, $(size (B,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
188+ _matmul_size_ABt (C, A, B)
174189 rv = rowvals (B)
175190 nzv = nonzeros (B)
176191 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
0 commit comments