Skip to content

Commit 715dd78

Browse files
committed
Helper function to get sizes of all input matrices and outline error throwing function
1 parent 1193147 commit 715dd78

1 file changed

Lines changed: 48 additions & 33 deletions

File tree

src/linalg.jl

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,51 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
7474
return C
7575
end
7676

77+
# Slow non-inlined functions for throwing the error without messing up the caller
78+
@noinline function _matmul_size_error(mC, nC, mA, nA, mB, nB, At, Bt)
79+
if At == 'N'
80+
Anames = "first", "second"
81+
else
82+
Anames = "second", "first"
83+
end
84+
if Bt == 'N'
85+
Bnames = "first", "second"
86+
else
87+
Bnames = "second", "first"
88+
end
89+
nA == mB ||
90+
throw(DimensionMismatch("$(Anames[2]) dimension of A, $nA, does not match the $(Bnames[1]) dimension of B, $mB"))
91+
mA == mC ||
92+
throw(DimensionMismatch("$(Anames[1]) dimension of A, $mA, does not match the first dimension of C, $mC"))
93+
nB == nC ||
94+
throw(DimensionMismatch("$(Bnames[2]) dimension of B, $nB, does not match the second dimension of C, $nC"))
95+
# unreachable
96+
throw(DimensionMismatch("Unknown dimension mismatch"))
97+
end
98+
99+
@inline function _matmul_size(C, A, B, ::Val{At}, ::Val{Bt}) where {At,Bt}
100+
mC = size(C, 1)
101+
nC = size(C, 2)
102+
mA = size(A, 1)
103+
nA = size(A, 2)
104+
mB = size(B, 1)
105+
nB = size(B, 2)
106+
107+
_mA, _nA = At == 'N' ? (mA, nA) : (nA, mA)
108+
_mB, _nB = Bt == 'N' ? (mB, nB) : (nB, mB)
109+
110+
if (_nA != _mB) | (_mA != mC) | (_nB != nC)
111+
_matmul_size_error(mC, nC, _mA, _nA, _mB, _nB, At, Bt)
112+
end
113+
return mC, nC, mA, nA, mB, nB
114+
end
115+
116+
@inline _matmul_size_AB(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('N'))
117+
@inline _matmul_size_AtB(C, A, B) = _matmul_size(C, A, B, Val('T'), Val('N'))
118+
@inline _matmul_size_ABt(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('T'))
119+
77120
function _spmatmul!(C, A, B, α, β)
78-
size(A, 2) == size(B, 1) ||
79-
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))"))
80-
size(A, 1) == size(C, 1) ||
81-
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of C, $(size(C,1))"))
82-
size(B, 2) == size(C, 2) ||
83-
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
121+
_matmul_size_AB(C, A, B)
84122
nzv = nonzeros(A)
85123
rv = rowvals(A)
86124
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
@@ -95,12 +133,7 @@ function _spmatmul!(C, A, B, α, β)
95133
end
96134

97135
function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
98-
size(A, 2) == size(C, 1) ||
99-
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of C, $(size(C,1))"))
100-
size(A, 1) == size(B, 1) ||
101-
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of B, $(size(B,1))"))
102-
size(B, 2) == size(C, 2) ||
103-
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
136+
_matmul_size_AtB(C, A, B)
104137
nzv = nonzeros(A)
105138
rv = rowvals(A)
106139
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
@@ -127,13 +160,7 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB
127160
return C
128161
end
129162
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
130-
mX, nX = size(X)
131-
nX == size(A, 1) ||
132-
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
133-
mX == size(C, 1) ||
134-
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
135-
size(A, 2) == size(C, 2) ||
136-
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
163+
_matmul_size_AB(C, X, A)
137164
rv = rowvals(A)
138165
nzv = nonzeros(A)
139166
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
@@ -146,13 +173,7 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2
146173
end
147174
end
148175
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
149-
mX, nX = size(X)
150-
nX == size(A, 1) ||
151-
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
152-
mX == size(C, 1) ||
153-
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
154-
size(A, 2) == size(C, 2) ||
155-
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
176+
_matmul_size_AB(C, X, A)
156177
rv = rowvals(A)
157178
nzv = nonzeros(A)
158179
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
@@ -164,13 +185,7 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S
164185
end
165186

166187
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
167-
mA, nA = size(A)
168-
nA == size(B, 2) ||
169-
throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))"))
170-
mA == size(C, 1) ||
171-
throw(DimensionMismatch("first dimension of A, $mA, does not match the first dimension of C, $(size(C,1))"))
172-
size(B, 1) == size(C, 2) ||
173-
throw(DimensionMismatch("first dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
188+
_matmul_size_ABt(C, A, B)
174189
rv = rowvals(B)
175190
nzv = nonzeros(B)
176191
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)

0 commit comments

Comments
 (0)