From acbbb8f0c5f414f5faf6595ead283cc24caf3705 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20M=C3=BCller-Widmann?= Date: Wed, 6 May 2026 17:01:20 +0200 Subject: [PATCH] Replace FLoops with manual Threads.@spawn in parallel Loader MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drops the FLoops dependency. Parallel iteration in `Loader` is rewritten as a recursive divide-and-conquer over `argiter[lo:hi]` using `Threads.@spawn`, mirroring `Transducers._reduce`: at each level, the right half is spawned and the left half recurses on the current task, with leaves (size <= basesize) processed sequentially. Default `basesize = length ÷ nthreads()` matches the FLoops `ThreadedEx` default, and the base-case check uses `max(basesize, 1)` so small inputs (`numobs < nthreads`) terminate. The outer dispatcher uses `Threads.@spawn` instead of `@async` to avoid making the caller's task non-migratable (per the `@async` docstring's warning about library code). Also moves Transducers from a required dep to a weakdep + extension (`ext/TransducersExt.jl`) and bumps the minimum Julia version to 1.10 (with the matching CI.yml `version: 'min'` switch). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/CI.yml | 7 +--- Project.toml | 11 ++++-- ext/TransducersExt.jl | 46 ++++++++++++++++++++++ src/MLUtils.jl | 3 -- src/dataloader.jl | 48 +---------------------- src/parallel.jl | 82 ++++++++++++++++++++-------------------- test/Project.toml | 2 - test/parallel.jl | 8 ---- 8 files changed, 98 insertions(+), 109 deletions(-) create mode 100644 ext/TransducersExt.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 7658a81..3a35807 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,26 +18,21 @@ jobs: fail-fast: false matrix: version: - - '1.6' + - 'min' - '1' - 'nightly' os: - ubuntu-latest - arch: - - x64 include: - os: windows-latest version: '1' - arch: x64 - os: macOS-latest version: '1' - arch: x64 steps: - uses: actions/checkout@v5 - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 diff --git a/Project.toml b/Project.toml index 991aa92..74931f9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,13 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.4.9" +version = "0.4.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" -FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" MLCore = "c2834f40-e789-41da-a90e-33b280584a8c" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -17,14 +16,18 @@ SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + +[weakdeps] 
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" +[extensions] +TransducersExt = "Transducers" + [compat] ChainRulesCore = "1.0" Compat = "4.2" DataAPI = "1.0" DelimitedFiles = "1.0" -FLoops = "0.2" MLCore = "1.0.0" NNlib = "0.8, 0.9" ShowCases = "0.1" @@ -33,4 +36,4 @@ Statistics = "1" StatsBase = "0.33, 0.34" Tables = "1.10" Transducers = "0.4" -julia = "1.6" +julia = "1.10" diff --git a/ext/TransducersExt.jl b/ext/TransducersExt.jl new file mode 100644 index 0000000..e449c5c --- /dev/null +++ b/ext/TransducersExt.jl @@ -0,0 +1,46 @@ +module TransducersExt + +using MLUtils: DataLoader, numobs, getobs, getobs!, _shuffledata +import Transducers + +@inline function _dataloader_foldl1(rf, val, d::DataLoader, data) + if d.shuffle + return _dataloader_foldl2(rf, val, d, _shuffledata(d.rng, data)) + else + return _dataloader_foldl2(rf, val, d, data) + end +end + +@inline function _dataloader_foldl2(rf, val, d::DataLoader, data) + if d.buffer == false + return _dataloader_foldl3(rf, val, data) + else + return _dataloader_foldl3_buffered(rf, val, data, d.buffer) + end +end + +@inline function _dataloader_foldl3(rf, val, data) + for i in 1:numobs(data) + @inbounds x = getobs(data, i) + # TODO: in 1.8 we could @inline this at the callsite, + # optimizer seems to be very sensitive to inlining and + # quite brittle in its capacity to keep this type stable + val = Transducers.@next(rf, val, x) + end + return Transducers.complete(rf, val) +end + +@inline function _dataloader_foldl3_buffered(rf, val, data, buf) + for i in 1:numobs(data) + @inbounds x = getobs!(buf, data, i) + val = Transducers.@next(rf, val, x) + end + return Transducers.complete(rf, val) +end + +@inline function Transducers.__foldl__(rf, val, d::DataLoader) + d.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads")) + return _dataloader_foldl1(rf, val, d, d._data) +end + +end # module diff --git a/src/MLUtils.jl b/src/MLUtils.jl index 9f829a5..877ca72 100644 --- 
a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -3,10 +3,7 @@ module MLUtils using Random using Statistics using ShowCases: ShowLimit -using FLoops: @floop -using FLoops.Transducers: Executor, ThreadedEx import StatsBase: sample -using Transducers using Tables using DataAPI using Base: @propagate_inbounds diff --git a/src/dataloader.jl b/src/dataloader.jl index ebb67e8..59c954c 100644 --- a/src/dataloader.jl +++ b/src/dataloader.jl @@ -195,7 +195,7 @@ end function Base.iterate(d::DataLoader{T,B,:parallel}) where {T,B} @assert d.buffer != false data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data - iter = _eachobsparallel_buffered(d.buffer, data) + iter = _eachobsparallel_buffered(d.buffer, data; channelsize=Threads.nthreads()) obs, state = iterate(iter) return obs, (iter, state) end @@ -213,7 +213,7 @@ end function Base.iterate(d::DataLoader{T,Bool,:parallel}) where {T} @assert d.buffer == false data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data - iter = _eachobsparallel_unbuffered(data) + iter = _eachobsparallel_unbuffered(data; channelsize=Threads.nthreads()) obs, state = iterate(iter) return obs, (iter, state) end @@ -358,47 +358,3 @@ function _expanded_summary(xs::NamedTuple) parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)] "(; " * join(parts, ", ") * ")" end - - -### TRANSDUCERS IMPLEMENTATION ############################# - - -@inline function _dataloader_foldl1(rf, val, d::DataLoader, data) - if d.shuffle - return _dataloader_foldl2(rf, val, d, _shuffledata(d.rng, data)) - else - return _dataloader_foldl2(rf, val, d, data) - end -end - -@inline function _dataloader_foldl2(rf, val, d::DataLoader, data) - if d.buffer == false - return _dataloader_foldl3(rf, val, data) - else - return _dataloader_foldl3_buffered(rf, val, data, d.buffer) - end -end - -@inline function _dataloader_foldl3(rf, val, data) - for i in 1:numobs(data) - @inbounds x = getobs(data, i) - # TODO: in 1.8 we could @inline this at the callsite, - # optimizer seems to 
be very sensitive to inlining and - # quite brittle in its capacity to keep this type stable - val = Transducers.@next(rf, val, x) - end - return Transducers.complete(rf, val) -end - -@inline function _dataloader_foldl3_buffered(rf, val, data, buf) - for i in 1:numobs(data) - @inbounds x = getobs!(buf, data, i) - val = Transducers.@next(rf, val, x) - end - return Transducers.complete(rf, val) -end - -@inline function Transducers.__foldl__(rf, val, d::DataLoader) - d.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads")) - return _dataloader_foldl1(rf, val, d, d._data) -end diff --git a/src/parallel.jl b/src/parallel.jl index 90706f5..b5b393b 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -1,5 +1,5 @@ # """ -# eachobsparallel(data; buffer, executor, channelsize) +# eachobsparallel(data; buffer, channelsize) # Construct a data iterator over observations in container `data`. # It uses available threads as workers to load observations in @@ -19,30 +19,25 @@ # `data`. Setting `buffer = true` means that when using the iterator, an # observation is only valid for the current loop iteration. # You can also pass in a preallocated `buffer = getobs(data, 1)`. -# - `executor = Folds.ThreadedEx()`: task scheduler -# You may specify a different task scheduler which can -# be any `Folds.Executor`. # - `channelsize = Threads.nthreads()`: the number of observations that are prefetched. # Increasing `channelsize` can lead to speedups when per-observation processing # time is irregular but will cause higher memory usage. 
# """ function eachobsparallel( data; - executor::Executor = _default_executor(), buffer::Bool = false, - channelsize = Threads.nthreads()) - if buffer == false - return _eachobsparallel_unbuffered(data, executor; channelsize) + channelsize::Int = Threads.nthreads()) + if buffer + return _eachobsparallel_buffered(buffer, data; channelsize) else - return _eachobsparallel_buffered(buffer, data, executor; channelsize) + return _eachobsparallel_unbuffered(data; channelsize) end end function _eachobsparallel_buffered( buffer, - data, - executor = _default_executor(); - channelsize=Threads.nthreads()) + data; + channelsize::Int) buffers = [buffer] foreach(_ -> push!(buffers, deepcopy(buffer)), 1:channelsize) @@ -52,7 +47,7 @@ function _eachobsparallel_buffered( # each iteration. setup_channel(sz) = RingBuffer(buffers) - return Loader(1:numobs(data); executor, channelsize, setup_channel) do ringbuffer, i + return Loader(1:numobs(data); channelsize, setup_channel) do ringbuffer, i # Internally, `RingBuffer` will `put!` the result in the results channel put!(ringbuffer) do buf getobs!(buf, data, i) @@ -60,25 +55,16 @@ function _eachobsparallel_buffered( end end -function _eachobsparallel_unbuffered(data, - executor = _default_executor(); - channelsize=Threads.nthreads() +function _eachobsparallel_unbuffered(data; + channelsize::Int ) - - return Loader(1:numobs(data); executor, channelsize) do ch, i + return Loader(1:numobs(data); channelsize) do ch, i obs = getobs(data, i) put!(ch, obs) end end -# Unlike DataLoaders.jl, this currently does not use task pools -# since `ThreadedEx` has shown to be more performant. This may -# change in the future. 
-# See PR 33 https://github.com/JuliaML/MLUtils.jl/pull/33 -_default_executor() = ThreadedEx() - - # ## Internals # The `Loader` handles the asynchronous iteration and fills @@ -86,18 +72,16 @@ _default_executor() = ThreadedEx() # """ -# Loader(f, args; executor, channelsize, setup_channel) +# Loader(f, args; channelsize, setup_channel) # Create a threaded iterator that iterates over `(f(arg) for arg in args)` # using threads that prefill a channel of length `channelsize`. -# Note: results may not be returned in the correct order, depending on -# `executor`. +# Note: results may not be returned in the correct order. # """ struct Loader f argiter::AbstractVector - executor::Executor channelsize::Int setup_channel end @@ -105,10 +89,9 @@ end function Loader( f, argiter; - executor=_default_executor(), - channelsize=Threads.nthreads(), + channelsize::Int = Threads.nthreads(), setup_channel = sz -> Channel(sz)) - Loader(f, argiter, executor, channelsize, setup_channel) + Loader(f, argiter, channelsize, setup_channel) end Base.length(loader::Loader) = length(loader.argiter) @@ -121,20 +104,39 @@ end function Base.iterate(loader::Loader) ch = loader.setup_channel(loader.channelsize) - task = @async begin - @floop loader.executor for arg in loader.argiter - try - loader.f(ch, arg) - catch e - close(ch, e) - rethrow() - end + basesize = length(loader.argiter) ÷ Threads.nthreads() + task = Threads.@spawn begin + try + _spawn_foreach(loader.f, ch, loader.argiter, + firstindex(loader.argiter), + lastindex(loader.argiter), + basesize) + catch e + close(ch, e) + rethrow() end end return Base.iterate(loader, LoaderState(task, ch, length(loader.argiter))) end +# Recursive divide-and-conquer over `argiter[lo:hi]`: +# At each level we `@spawn` the right half and recurse on the left half on the current task, then `wait` on the right. +# Leaves of size `<= basesize` are processed sequentially. 
+function _spawn_foreach(f::F, ch, argiter, lo, hi, basesize::Int) where {F} + if hi - lo < max(basesize, 1) + for i in lo:hi + f(ch, argiter[i]) + end + else + mid = (lo + hi) >> 1 + task = Threads.@spawn _spawn_foreach($f, $ch, $argiter, $(mid + 1), $hi, $basesize) + _spawn_foreach(f, ch, argiter, lo, mid, basesize) + wait(task) + end + return nothing +end + function Base.iterate(::Loader, state::LoaderState) if state.remaining == 0 close(state.channel) diff --git a/test/Project.toml b/test/Project.toml index 72b1e61..b83c121 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,8 @@ [deps] -BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/parallel.jl b/test/parallel.jl index 84a808f..ad99b5e 100644 --- a/test/parallel.jl +++ b/test/parallel.jl @@ -7,14 +7,6 @@ @test all(x ∈ 1:10 for x in X_) @test length(unique(X_)) == 10 end - - @testset "With `ThreadedEx`" begin - iter = eachobsparallel(collect(1:10); executor = ThreadedEx()) - @test_nowarn for i in iter end - X_ = collect(iter) - @test all(x ∈ 1:10 for x in X_) - @test length(unique(X_)) == 10 - end end