11# """
2- # eachobsparallel(data; buffer, executor, channelsize )
2+ # eachobsparallel(data; buffer, channelsize, basesize )
33
44# Construct a data iterator over observations in container `data`.
55# It uses available threads as workers to load observations in
1919# `data`. Setting `buffer = true` means that when using the iterator, an
2020# observation is only valid for the current loop iteration.
2121# You can also pass in a preallocated `buffer = getobs(data, 1)`.
22- # - `executor = Folds.ThreadedEx()`: task scheduler
23- # You may specify a different task scheduler which can
24- # be any `Folds.Executor`.
2522# - `channelsize = Threads.nthreads()`: the number of observations that are prefetched.
2623# Increasing `channelsize` can lead to speedups when per-observation processing
2724# time is irregular but will cause higher memory usage.
2825# """
2926function eachobsparallel (
3027 data;
31- executor:: Executor = _default_executor (),
3228 buffer:: Bool = false ,
33- channelsize = Threads. nthreads ())
34- if buffer == false
35- return _eachobsparallel_unbuffered (data, executor ; channelsize)
29+ channelsize:: Int = Threads. nthreads ())
30+ if buffer
31+ return _eachobsparallel_buffered (buffer, data ; channelsize)
3632 else
37- return _eachobsparallel_buffered (buffer, data, executor ; channelsize)
33+ return _eachobsparallel_unbuffered ( data; channelsize)
3834 end
3935end
4036
# Buffered variant: observations are written in-place into a fixed pool of
# reusable buffers via `getobs!`, cycled through a `RingBuffer`.
function _eachobsparallel_buffered(
        buffer,
        data;
        channelsize::Int)
    # Pool of `channelsize + 1` buffers: the one passed in plus fresh copies,
    # so every in-flight observation owns its own buffer.
    pool = [buffer]
    for _ in 1:channelsize
        push!(pool, deepcopy(buffer))
    end

    # `Loader` passes the requested channel size here, but the ring buffer's
    # capacity is already fixed by the pool it wraps, so the argument is
    # ignored.
    makechannel(_) = RingBuffer(pool)

    return Loader(1:numobs(data); channelsize, setup_channel = makechannel) do rb, i
        # `put!` on a `RingBuffer` hands us a free buffer to fill in-place
        # and then publishes the result on the channel.
        put!(rb) do buf
            getobs!(buf, data, i)
        end
    end
end
6257
# Unbuffered variant: every observation is freshly allocated with `getobs`
# before being put on the results channel.
function _eachobsparallel_unbuffered(data; channelsize::Int)
    return Loader(1:numobs(data); channelsize) do ch, i
        put!(ch, getobs(data, i))
    end
end
7366
7467
75- # Unlike DataLoaders.jl, this currently does not use task pools
76- # since `ThreadedEx` has shown to be more performant. This may
77- # change in the future.
78- # See PR 33 https://github.com/JuliaML/MLUtils.jl/pull/33
79- _default_executor () = ThreadedEx ()
80-
81-
8268# ## Internals
8369
8470# The `Loader` handles the asynchronous iteration and fills
8571# a result channel.
8672
8773
8874# """
89- # Loader(f, args; executor, channelsize, setup_channel)
75+ # Loader(f, args; channelsize, setup_channel)
9076
9177# Create a threaded iterator that iterates over `(f(arg) for arg in args)`
9278# using threads that prefill a channel of length `channelsize`.
9379
94- # Note: results may not be returned in the correct order, depending on
95- # `executor`.
80+ # Note: results may not be returned in the correct order.
9681# """
9782struct Loader
9883 f
9984 argiter:: AbstractVector
100- executor:: Executor
10185 channelsize:: Int
10286 setup_channel
10387end
10488
10589function Loader (
10690 f,
10791 argiter;
108- executor= _default_executor (),
109- channelsize= Threads. nthreads (),
92+ channelsize:: Int = Threads. nthreads (),
11093 setup_channel = sz -> Channel (sz))
111- Loader (f, argiter, executor, channelsize, setup_channel)
94+ Loader (f, argiter, channelsize, setup_channel)
11295end
11396
11497Base. length (loader:: Loader ) = length (loader. argiter)
@@ -121,20 +104,39 @@ end
121104
# Start iteration: create the results channel, spawn the producer task that
# fills it, and delegate to the stateful `iterate` method with a fresh
# `LoaderState`.
function Base.iterate(loader::Loader)
    ch = loader.setup_channel(loader.channelsize)
    # Rough chunk size so the divide-and-conquer producer bottoms out in
    # about one leaf per thread. May be 0 for short iterators; the
    # callee clamps it to at least 1.
    basesize = length(loader.argiter) ÷ Threads.nthreads()
    task = Threads.@spawn begin
        try
            _spawn_foreach(loader.f, ch, loader.argiter,
                firstindex(loader.argiter),
                lastindex(loader.argiter),
                basesize)
        catch e
            # Closing the channel with the exception makes the consumer's
            # next take rethrow it instead of blocking forever.
            close(ch, e)
            rethrow()
        end
    end

    # The consumer counts down `remaining` in `LoaderState` and closes the
    # channel itself once all results have been taken (see the stateful
    # `iterate` method below).
    return Base.iterate(loader, LoaderState(task, ch, length(loader.argiter)))
end
137122
# Divide-and-conquer parallel foreach over `argiter[lo:hi]`.
# At every level the upper half is off-loaded to a freshly spawned task
# while the current task keeps splitting the lower half; slices of at most
# `basesize` elements are processed sequentially. `f` is called as
# `f(ch, element)` exactly once per element.
function _spawn_foreach(f::F, ch, argiter, lo, hi, basesize::Int) where {F}
    leafsize = max(basesize, 1)  # guard against `basesize == 0`
    if hi - lo < leafsize
        # Small enough: run this slice on the current task.
        idx = lo
        while idx <= hi
            f(ch, argiter[idx])
            idx += 1
        end
    else
        mid = (lo + hi) >> 1
        # Spawn the right half, recurse into the left half here, then join.
        right = Threads.@spawn _spawn_foreach($f, $ch, $argiter, $(mid + 1), $hi, $basesize)
        _spawn_foreach(f, ch, argiter, lo, mid, basesize)
        wait(right)
    end
    return nothing
end
139+
138140function Base. iterate (:: Loader , state:: LoaderState )
139141 if state. remaining == 0
140142 close (state. channel)
0 commit comments