Skip to content

Commit 8cd5fea

Browse files
devmotion and claude committed
Replace FLoops with manual Threads.@spawn in parallel Loader
Drops the FLoops dependency. Parallel iteration in `Loader` is rewritten as a recursive divide-and-conquer over `argiter[lo:hi]` using `Threads.@spawn`, mirroring `Transducers._reduce`: at each level, the right half is spawned and the left half recurses on the current task, with leaves (size <= basesize) processed sequentially. Default `basesize = length ÷ nthreads()` matches the FLoops `ThreadedEx` default, and the base-case check uses `max(basesize, 1)` so small inputs (`numobs < nthreads`) terminate. The outer dispatcher uses `Threads.@spawn` instead of `@async` to avoid making the caller's task non-migratable (per the `@async` docstring's warning about library code). Also moves Transducers from a required dep to a weakdep + extension (`ext/TransducersExt.jl`) and bumps the minimum Julia version to 1.10 (with the matching CI.yml `min-patch` switch). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 89ff3f7 commit 8cd5fea

8 files changed

Lines changed: 98 additions & 109 deletions

File tree

.github/workflows/CI.yml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,21 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
version:
21-
- '1.6'
21+
- 'min-patch'
2222
- '1'
2323
- 'nightly'
2424
os:
2525
- ubuntu-latest
26-
arch:
27-
- x64
2826
include:
2927
- os: windows-latest
3028
version: '1'
31-
arch: x64
3229
- os: macOS-latest
3330
version: '1'
34-
arch: x64
3531
steps:
3632
- uses: actions/checkout@v5
3733
- uses: julia-actions/setup-julia@v2
3834
with:
3935
version: ${{ matrix.version }}
40-
arch: ${{ matrix.arch }}
4136
- uses: julia-actions/cache@v2
4237
- uses: julia-actions/julia-buildpkg@v1
4338
- uses: julia-actions/julia-runtest@v1

Project.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
name = "MLUtils"
22
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
33
authors = ["Carlo Lucibello <carlo.lucibello@gmail.com> and contributors"]
4-
version = "0.4.9"
4+
version = "0.4.10"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
99
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
1010
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
11-
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
1211
MLCore = "c2834f40-e789-41da-a90e-33b280584a8c"
1312
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1413
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -17,14 +16,18 @@ SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
1716
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1817
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1918
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
19+
20+
[weakdeps]
2021
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
2122

23+
[extensions]
24+
TransducersExt = "Transducers"
25+
2226
[compat]
2327
ChainRulesCore = "1.0"
2428
Compat = "4.2"
2529
DataAPI = "1.0"
2630
DelimitedFiles = "1.0"
27-
FLoops = "0.2"
2831
MLCore = "1.0.0"
2932
NNlib = "0.8, 0.9"
3033
ShowCases = "0.1"
@@ -33,4 +36,4 @@ Statistics = "1"
3336
StatsBase = "0.33, 0.34"
3437
Tables = "1.10"
3538
Transducers = "0.4"
36-
julia = "1.6"
39+
julia = "1.10"

ext/TransducersExt.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
module TransducersExt
2+
3+
using MLUtils: DataLoader, numobs, getobs, getobs!, _shuffledata
4+
import Transducers
5+
6+
@inline function _dataloader_foldl1(rf, val, d::DataLoader, data)
7+
if d.shuffle
8+
return _dataloader_foldl2(rf, val, d, _shuffledata(d.rng, data))
9+
else
10+
return _dataloader_foldl2(rf, val, d, data)
11+
end
12+
end
13+
14+
@inline function _dataloader_foldl2(rf, val, d::DataLoader, data)
15+
if d.buffer == false
16+
return _dataloader_foldl3(rf, val, data)
17+
else
18+
return _dataloader_foldl3_buffered(rf, val, data, d.buffer)
19+
end
20+
end
21+
22+
@inline function _dataloader_foldl3(rf, val, data)
23+
for i in 1:numobs(data)
24+
@inbounds x = getobs(data, i)
25+
# TODO: in 1.8 we could @inline this at the callsite,
26+
# optimizer seems to be very sensitive to inlining and
27+
# quite brittle in its capacity to keep this type stable
28+
val = Transducers.@next(rf, val, x)
29+
end
30+
return Transducers.complete(rf, val)
31+
end
32+
33+
@inline function _dataloader_foldl3_buffered(rf, val, data, buf)
34+
for i in 1:numobs(data)
35+
@inbounds x = getobs!(buf, data, i)
36+
val = Transducers.@next(rf, val, x)
37+
end
38+
return Transducers.complete(rf, val)
39+
end
40+
41+
@inline function Transducers.__foldl__(rf, val, d::DataLoader)
42+
d.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
43+
return _dataloader_foldl1(rf, val, d, d._data)
44+
end
45+
46+
end # module

src/MLUtils.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@ module MLUtils
33
using Random
44
using Statistics
55
using ShowCases: ShowLimit
6-
using FLoops: @floop
7-
using FLoops.Transducers: Executor, ThreadedEx
86
import StatsBase: sample
9-
using Transducers
107
using Tables
118
using DataAPI
129
using Base: @propagate_inbounds

src/dataloader.jl

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ end
195195
function Base.iterate(d::DataLoader{T,B,:parallel}) where {T,B}
196196
@assert d.buffer != false
197197
data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data
198-
iter = _eachobsparallel_buffered(d.buffer, data)
198+
iter = _eachobsparallel_buffered(d.buffer, data; channelsize=Threads.nthreads())
199199
obs, state = iterate(iter)
200200
return obs, (iter, state)
201201
end
@@ -213,7 +213,7 @@ end
213213
function Base.iterate(d::DataLoader{T,Bool,:parallel}) where {T}
214214
@assert d.buffer == false
215215
data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data
216-
iter = _eachobsparallel_unbuffered(data)
216+
iter = _eachobsparallel_unbuffered(data; channelsize=Threads.nthreads())
217217
obs, state = iterate(iter)
218218
return obs, (iter, state)
219219
end
@@ -358,47 +358,3 @@ function _expanded_summary(xs::NamedTuple)
358358
parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)]
359359
"(; " * join(parts, ", ") * ")"
360360
end
361-
362-
363-
### TRANSDUCERS IMPLEMENTATION #############################
364-
365-
366-
@inline function _dataloader_foldl1(rf, val, d::DataLoader, data)
367-
if d.shuffle
368-
return _dataloader_foldl2(rf, val, d, _shuffledata(d.rng, data))
369-
else
370-
return _dataloader_foldl2(rf, val, d, data)
371-
end
372-
end
373-
374-
@inline function _dataloader_foldl2(rf, val, d::DataLoader, data)
375-
if d.buffer == false
376-
return _dataloader_foldl3(rf, val, data)
377-
else
378-
return _dataloader_foldl3_buffered(rf, val, data, d.buffer)
379-
end
380-
end
381-
382-
@inline function _dataloader_foldl3(rf, val, data)
383-
for i in 1:numobs(data)
384-
@inbounds x = getobs(data, i)
385-
# TODO: in 1.8 we could @inline this at the callsite,
386-
# optimizer seems to be very sensitive to inlining and
387-
# quite brittle in its capacity to keep this type stable
388-
val = Transducers.@next(rf, val, x)
389-
end
390-
return Transducers.complete(rf, val)
391-
end
392-
393-
@inline function _dataloader_foldl3_buffered(rf, val, data, buf)
394-
for i in 1:numobs(data)
395-
@inbounds x = getobs!(buf, data, i)
396-
val = Transducers.@next(rf, val, x)
397-
end
398-
return Transducers.complete(rf, val)
399-
end
400-
401-
@inline function Transducers.__foldl__(rf, val, d::DataLoader)
402-
d.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
403-
return _dataloader_foldl1(rf, val, d, d._data)
404-
end

src/parallel.jl

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# """
2-
# eachobsparallel(data; buffer, executor, channelsize)
2+
# eachobsparallel(data; buffer, channelsize, basesize)
33

44
# Construct a data iterator over observations in container `data`.
55
# It uses available threads as workers to load observations in
@@ -19,30 +19,25 @@
1919
# `data`. Setting `buffer = true` means that when using the iterator, an
2020
# observation is only valid for the current loop iteration.
2121
# You can also pass in a preallocated `buffer = getobs(data, 1)`.
22-
# - `executor = Folds.ThreadedEx()`: task scheduler
23-
# You may specify a different task scheduler which can
24-
# be any `Folds.Executor`.
2522
# - `channelsize = Threads.nthreads()`: the number of observations that are prefetched.
2623
# Increasing `channelsize` can lead to speedups when per-observation processing
2724
# time is irregular but will cause higher memory usage.
2825
# """
2926
function eachobsparallel(
3027
data;
31-
executor::Executor = _default_executor(),
3228
buffer::Bool = false,
33-
channelsize = Threads.nthreads())
34-
if buffer == false
35-
return _eachobsparallel_unbuffered(data, executor; channelsize)
29+
channelsize::Int = Threads.nthreads())
30+
if buffer
31+
return _eachobsparallel_buffered(buffer, data; channelsize)
3632
else
37-
return _eachobsparallel_buffered(buffer, data, executor; channelsize)
33+
return _eachobsparallel_unbuffered(data; channelsize)
3834
end
3935
end
4036

4137
function _eachobsparallel_buffered(
4238
buffer,
43-
data,
44-
executor = _default_executor();
45-
channelsize=Threads.nthreads())
39+
data;
40+
channelsize::Int)
4641
buffers = [buffer]
4742
foreach(_ -> push!(buffers, deepcopy(buffer)), 1:channelsize)
4843

@@ -52,63 +47,51 @@ function _eachobsparallel_buffered(
5247
# each iteration.
5348
setup_channel(sz) = RingBuffer(buffers)
5449

55-
return Loader(1:numobs(data); executor, channelsize, setup_channel) do ringbuffer, i
50+
return Loader(1:numobs(data); channelsize, setup_channel) do ringbuffer, i
5651
# Internally, `RingBuffer` will `put!` the result in the results channel
5752
put!(ringbuffer) do buf
5853
getobs!(buf, data, i)
5954
end
6055
end
6156
end
6257

63-
function _eachobsparallel_unbuffered(data,
64-
executor = _default_executor();
65-
channelsize=Threads.nthreads()
58+
function _eachobsparallel_unbuffered(data;
59+
channelsize::Int
6660
)
67-
68-
return Loader(1:numobs(data); executor, channelsize) do ch, i
61+
return Loader(1:numobs(data); channelsize) do ch, i
6962
obs = getobs(data, i)
7063
put!(ch, obs)
7164
end
7265
end
7366

7467

75-
# Unlike DataLoaders.jl, this currently does not use task pools
76-
# since `ThreadedEx` has shown to be more performant. This may
77-
# change in the future.
78-
# See PR 33 https://github.com/JuliaML/MLUtils.jl/pull/33
79-
_default_executor() = ThreadedEx()
80-
81-
8268
# ## Internals
8369

8470
# The `Loader` handles the asynchronous iteration and fills
8571
# a result channel.
8672

8773

8874
# """
89-
# Loader(f, args; executor, channelsize, setup_channel)
75+
# Loader(f, args; channelsize, setup_channel)
9076

9177
# Create a threaded iterator that iterates over `(f(arg) for arg in args)`
9278
# using threads that prefill a channel of length `channelsize`.
9379

94-
# Note: results may not be returned in the correct order, depending on
95-
# `executor`.
80+
# Note: results may not be returned in the correct order.
9681
# """
9782
struct Loader
9883
f
9984
argiter::AbstractVector
100-
executor::Executor
10185
channelsize::Int
10286
setup_channel
10387
end
10488

10589
function Loader(
10690
f,
10791
argiter;
108-
executor=_default_executor(),
109-
channelsize=Threads.nthreads(),
92+
channelsize::Int = Threads.nthreads(),
11093
setup_channel = sz -> Channel(sz))
111-
Loader(f, argiter, executor, channelsize, setup_channel)
94+
Loader(f, argiter, channelsize, setup_channel)
11295
end
11396

11497
Base.length(loader::Loader) = length(loader.argiter)
@@ -121,20 +104,39 @@ end
121104

122105
function Base.iterate(loader::Loader)
123106
ch = loader.setup_channel(loader.channelsize)
124-
task = @async begin
125-
@floop loader.executor for arg in loader.argiter
126-
try
127-
loader.f(ch, arg)
128-
catch e
129-
close(ch, e)
130-
rethrow()
131-
end
107+
basesize = length(loader.argiter) ÷ Threads.nthreads()
108+
task = Threads.@spawn begin
109+
try
110+
_spawn_foreach(loader.f, ch, loader.argiter,
111+
firstindex(loader.argiter),
112+
lastindex(loader.argiter),
113+
basesize)
114+
catch e
115+
close(ch, e)
116+
rethrow()
132117
end
133118
end
134119

135120
return Base.iterate(loader, LoaderState(task, ch, length(loader.argiter)))
136121
end
137122

123+
# Recursive divide-and-conquer over `argiter[lo:hi]`:
124+
# At each level we `@spawn` the right half and recurse on the left half on the current task, then `wait` on the right.
125+
# Leaves of size `<= basesize` are processed sequentially.
126+
function _spawn_foreach(f::F, ch, argiter, lo, hi, basesize::Int) where {F}
127+
if hi - lo < max(basesize, 1)
128+
for i in lo:hi
129+
f(ch, argiter[i])
130+
end
131+
else
132+
mid = (lo + hi) >> 1
133+
task = Threads.@spawn _spawn_foreach($f, $ch, $argiter, $(mid + 1), $hi, $basesize)
134+
_spawn_foreach(f, ch, argiter, lo, mid, basesize)
135+
wait(task)
136+
end
137+
return nothing
138+
end
139+
138140
function Base.iterate(::Loader, state::LoaderState)
139141
if state.remaining == 0
140142
close(state.channel)

test/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
[deps]
2-
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
32
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
43
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
54
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
65
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
7-
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
86
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
97
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
108
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

test/parallel.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,6 @@
77
@test all(x ∈ 1:10 for x in X_)
88
@test length(unique(X_)) == 10
99
end
10-
11-
@testset "With `ThreadedEx`" begin
12-
iter = eachobsparallel(collect(1:10); executor = ThreadedEx())
13-
@test_nowarn for i in iter end
14-
X_ = collect(iter)
15-
@test all(x ∈ 1:10 for x in X_)
16-
@test length(unique(X_)) == 10
17-
end
1810
end
1911

2012

0 commit comments

Comments
 (0)