Skip to content

Commit 234cf55

Browse files
authored
Buffered matrix multiplication (#83)
* Buffered matrix multiplication * Add fallbacks for buffered operations
1 parent cee7868 commit 234cf55

4 files changed

Lines changed: 125 additions & 12 deletions

File tree

src/interface.jl

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,46 @@ end
262262

263263
buffer_for(::Function, args::Vararg{Type,N}) where {N} = nothing
264264

265+
function mutable_buffered_operate_to_fallback(::NotMutable, buffer, output, op::Function, args...)
266+
throw(
267+
ArgumentError(
268+
"Cannot call `mutable_buffered_operate_to!(::$(typeof(buffer)), ::$(typeof(output)), $op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(output))` cannot be modifed to equal the result of the operation. Use `buffered_operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
269+
),
270+
)
271+
end
272+
273+
function mutable_buffered_operate_to_fallback(::IsMutable, buffer, output, op::Function, args...)
274+
error(
275+
"`mutable_buffered_operate_to!(::$(typeof(buffer)), ::$(typeof(output)), $op, ::",
276+
join(typeof.(args), ", ::"),
277+
")` is not implemented.",
278+
)
279+
end
280+
281+
function mutable_buffered_operate_to_fallback(
282+
buffer,
283+
output,
284+
op::Function,
285+
args::Vararg{Any,N},
286+
) where {N}
287+
return mutable_buffered_operate_to_fallback(
288+
mutability(output, op, args...),
289+
buffer,
290+
output,
291+
op,
292+
args...
293+
)
294+
end
295+
296+
function mutable_buffered_operate_to_fallback(
297+
::Nothing,
298+
output,
299+
op::Function,
300+
args::Vararg{Any,N},
301+
) where {N}
302+
return mutable_operate_to!(output, op, args...)
303+
end
304+
265305
"""
266306
mutable_buffered_operate_to!(buffer, output, op::Function, args...)
267307
@@ -270,12 +310,49 @@ possibly modifying `buffer`. Can only be called if
270310
`mutability(output, op, args...)` returns `true`.
271311
"""
272312
function mutable_buffered_operate_to!(
273-
::Nothing,
313+
buffer,
274314
output,
275315
op::Function,
276316
args::Vararg{Any,N},
277317
) where {N}
278-
return mutable_operate_to!(output, op, args...)
318+
return mutable_buffered_operate_to_fallback(buffer, output, op, args...)
319+
end
320+
321+
function mutable_buffered_operate_fallback(::NotMutable, buffer, op::Function, args...)
322+
throw(
323+
ArgumentError(
324+
"Cannot call `mutable_buffered_operate!(::$(typeof(buffer)), $op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(args[1]))` cannot be modifed to equal the result of the operation. Use `buffered_operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
325+
),
326+
)
327+
end
328+
329+
function mutable_buffered_operate_fallback(::IsMutable, buffer, op::Function, args...)
330+
error(
331+
"`mutable_buffered_operate!(::$(typeof(buffer)), $op, ::",
332+
join(typeof.(args), ", ::"),
333+
")` is not implemented.",
334+
)
335+
end
336+
337+
function mutable_buffered_operate_fallback(
338+
buffer,
339+
op::Function,
340+
args::Vararg{Any,N},
341+
) where {N}
342+
return mutable_buffered_operate_fallback(
343+
mutability(args[1], op, args...),
344+
buffer,
345+
op,
346+
args...
347+
)
348+
end
349+
350+
function mutable_buffered_operate_fallback(
351+
::Nothing,
352+
op::Function,
353+
args::Vararg{Any,N},
354+
) where {N}
355+
return mutable_operate!(op, args...)
279356
end
280357

281358
"""
@@ -286,8 +363,8 @@ possibly modifying `buffer`. Can only be called if
286363
`mutability(args[1], op, args...)` returns `true`.
287364
"""
288365
function mutable_buffered_operate! end
289-
function mutable_buffered_operate!(::Nothing, op::Function, args::Vararg{Any,N}) where {N}
290-
return mutable_operate!(op, args...)
366+
function mutable_buffered_operate!(buffer, op::Function, args::Vararg{Any,N}) where {N}
367+
return mutable_buffered_operate_fallback(buffer, op, args...)
291368
end
292369

293370
"""

src/linear_algebra.jl

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,16 +257,15 @@ function _dim_check(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
257257
end
258258
end
259259

260-
function _add_mul_array(C::Vector, A::AbstractMatrix, B::AbstractVector)
260+
function _add_mul_array(buffer, C::Vector, A::AbstractMatrix, B::AbstractVector)
261261
Astride = size(A, 1)
262262
# We need a buffer to hold the intermediate multiplication.
263-
mul_buffer = buffer_for(add_mul, eltype(C), eltype(A), eltype(B))
264263
@inbounds begin
265264
for k in eachindex(B)
266265
aoffs = (k - 1) * Astride
267266
b = B[k]
268267
for i in Base.OneTo(size(A, 1))
269-
C[i] = buffered_operate!(mul_buffer, add_mul, C[i], A[aoffs+i], b)
268+
C[i] = buffered_operate!(buffer, add_mul, C[i], A[aoffs+i], b)
270269
end
271270
end
272271
end # @inbounds
@@ -275,28 +274,46 @@ end
275274

276275
# This is incorrect if `C` is `LinearAlgebra.Symmetric` as we modify twice the
277276
# same diagonal element.
278-
function _add_mul_array(C::Matrix, A::AbstractMatrix, B::AbstractMatrix)
279-
mul_buffer = buffer_for(add_mul, eltype(C), eltype(A), eltype(B))
277+
function _add_mul_array(buffer, C::Matrix, A::AbstractMatrix, B::AbstractMatrix)
280278
@inbounds begin
281279
for i = 1:size(A, 1), j = 1:size(B, 2)
282280
Ctmp = C[i, j]
283281
for k = 1:size(A, 2)
284-
Ctmp = buffered_operate!(mul_buffer, add_mul, Ctmp, A[i, k], B[k, j])
282+
Ctmp = buffered_operate!(buffer, add_mul, Ctmp, A[i, k], B[k, j])
285283
end
286284
C[i, j] = Ctmp
287285
end
288286
end # @inbounds
289287
return C
290288
end
291289

292-
function mutable_operate!(
290+
function mutable_buffered_operate!(
291+
buffer,
293292
::typeof(add_mul),
294293
C::VecOrMat,
295294
A::AbstractMatrix,
296295
B::AbstractVecOrMat,
297296
)
298297
_dim_check(C, A, B)
299-
_add_mul_array(C, A, B)
298+
_add_mul_array(buffer, C, A, B)
299+
end
300+
301+
function buffer_for(
302+
::typeof(add_mul),
303+
::Type{<:VecOrMat{S}},
304+
::Type{<:AbstractMatrix{T}},
305+
::Type{<:AbstractVecOrMat{U}},
306+
) where {S,T,U}
307+
return buffer_for(add_mul, S, T, U)
308+
end
309+
function mutable_operate!(
310+
::typeof(add_mul),
311+
C::VecOrMat,
312+
A::AbstractMatrix,
313+
B::AbstractVecOrMat,
314+
)
315+
buffer = buffer_for(add_mul, typeof(C), typeof(A), typeof(B))
316+
return mutable_buffered_operate!(buffer, add_mul, C, A, B)
300317
end
301318

302319
function mutable_operate!(::typeof(zero), C::Union{Vector,Matrix})

test/interface.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ end
2020
"Cannot call `mutable_operate!(+, ::$Int, ::$Int)` as objects of type `$Int` cannot be modifed to equal the result of the operation. Use `operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
2121
)
2222
@test_throws err MA.mutable_operate!(+, 0, 0)
23+
err = ArgumentError(
24+
"Cannot call `mutable_buffered_operate_to!(::$Int, ::$Int, +, ::$Int, ::$Int)` as objects of type `$Int` cannot be modifed to equal the result of the operation. Use `buffered_operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
25+
)
26+
@test_throws err MA.mutable_buffered_operate_to!(0, 0, +, 0, 0)
27+
err = ArgumentError(
28+
"Cannot call `mutable_buffered_operate!(::$Int, +, ::$Int, ::$Int)` as objects of type `$Int` cannot be modifed to equal the result of the operation. Use `buffered_operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
29+
)
30+
@test_throws err MA.mutable_buffered_operate!(0, +, 0, 0)
2331
x = DummyMutable()
2432
err = ErrorException(
2533
"`mutable_operate_to!(::DummyMutable, +, ::DummyMutable, ::DummyMutable)` is not implemented yet.",
@@ -29,4 +37,12 @@ end
2937
"`mutable_operate!(+, ::DummyMutable, ::DummyMutable)` is not implemented yet.",
3038
)
3139
@test_throws err MA.mutable_operate!(+, x, x)
40+
err = ErrorException(
41+
"`mutable_buffered_operate_to!(::DummyMutable, ::DummyMutable, +, ::DummyMutable, ::DummyMutable)` is not implemented.",
42+
)
43+
@test_throws err MA.mutable_buffered_operate_to!(x, x, +, x, x)
44+
err = ErrorException(
45+
"`mutable_buffered_operate!(::DummyMutable, +, ::DummyMutable, ::DummyMutable)` is not implemented.",
46+
)
47+
@test_throws err MA.mutable_buffered_operate!(x, +, x, x)
3248
end

test/matmul.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ end
131131
alloc_test(() -> MA.operate_fallback!(MA.IsMutable(), MA.add_mul, y, A, x), n)
132132
alloc_test(() -> MA.operate!(MA.add_mul, y, A, x), n)
133133
alloc_test(() -> MA.mutable_operate!(MA.add_mul, y, A, x), n)
134+
# Apparently, all allocations were on creating the buffer since this is allocation free:
135+
buffer = MA.buffer_for(MA.add_mul, typeof(y), typeof(A), typeof(x))
136+
alloc_test(() -> MA.mutable_buffered_operate!(buffer, MA.add_mul, y, A, x), 0)
134137
end
135138
@testset "matrix-matrix product" begin
136139
A = [1 2 3; 4 5 6; 6 8 9]

0 commit comments

Comments
 (0)