Skip to content

Commit a0eba90

Browse files
Add concurrent scan combinators with static list of scans
1 parent f568c76 commit a0eba90

1 file changed

Lines changed: 37 additions & 9 deletions

File tree

src/Streamly/Internal/Data/Scanl/Concurrent.hs

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
module Streamly.Internal.Data.Scanl.Concurrent
1010
(
1111
parTeeWith
12+
, parDistributeScanM
1213
, parDistributeScan
14+
, parDemuxScanM
1315
, parDemuxScan
1416
)
1517
where
@@ -19,7 +21,7 @@ where
1921
import Control.Concurrent (newEmptyMVar, takeMVar, throwTo)
2022
import Control.Monad.Catch (throwM)
2123
import Control.Monad.IO.Class (MonadIO(liftIO))
22-
import Data.IORef (newIORef, readIORef)
24+
import Data.IORef (newIORef, readIORef, atomicModifyIORef)
2325
import Fusion.Plugin.Types (Fuse(..))
2426
import Streamly.Internal.Control.Concurrent (MonadAsync)
2527
import Streamly.Internal.Data.Atomics (atomicModifyIORefCAS)
@@ -30,6 +32,7 @@ import Streamly.Internal.Data.SVar.Type (adaptState)
3032
import Streamly.Internal.Data.Tuple.Strict (Tuple3'(..))
3133

3234
import qualified Data.Map.Strict as Map
35+
import qualified Streamly.Internal.Data.Stream as Stream
3336

3437
import Streamly.Internal.Data.Fold.Channel.Type
3538
import Streamly.Internal.Data.Channel.Types
@@ -162,13 +165,13 @@ data ScanState s q db f =
162165
-- >>> import Data.IORef
163166
-- >>> ref <- newIORef [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int]
164167
-- >>> gen = atomicModifyIORef ref (\xs -> ([], xs))
165-
-- >>> Stream.toList $ Scanl.parDistributeScan id gen (Stream.enumerateFromTo 1 10)
168+
-- >>> Stream.toList $ Scanl.parDistributeScanM id gen (Stream.enumerateFromTo 1 10)
166169
-- ...
167170
--
168-
{-# INLINE parDistributeScan #-}
169-
parDistributeScan :: MonadAsync m =>
171+
{-# INLINE parDistributeScanM #-}
172+
parDistributeScanM :: MonadAsync m =>
170173
(Config -> Config) -> m [Scanl m a b] -> Stream m a -> Stream m [b]
171-
parDistributeScan cfg getFolds (Stream sstep state) =
174+
parDistributeScanM cfg getFolds (Stream sstep state) =
172175
Stream step ScanInit
173176

174177
where
@@ -243,6 +246,20 @@ parDistributeScan cfg getFolds (Stream sstep state) =
243246
else return $ Yield outputs (ScanDrain q db running)
244247
step _ ScanStop = return Stop
245248

249+
-- | Like 'parDistributeScanM' but takes a list of static scans.
250+
--
251+
-- >>> xs = [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int]
252+
-- >>> Stream.toList $ Scanl.parDistributeScan id xs (Stream.enumerateFromTo 1 10)
253+
-- ...
254+
{-# INLINE parDistributeScan #-}
255+
parDistributeScan :: MonadAsync m =>
256+
(Config -> Config) -> [Scanl m a b] -> Stream m a -> Stream m [b]
257+
parDistributeScan cfg getFolds stream =
258+
Stream.concatEffect $ do
259+
ref <- liftIO $ newIORef getFolds
260+
let action = liftIO $ atomicModifyIORef ref (\xs -> ([], xs))
261+
return $ parDistributeScanM cfg action stream
262+
246263
{-# ANN type DemuxState Fuse #-}
247264
data DemuxState s q db f =
248265
DemuxInit
@@ -273,17 +290,17 @@ data DemuxState s q db f =
273290
-- >>> getScan k = return (fromJust $ Map.lookup k kv)
274291
-- >>> getKey x = if even x then "even" else "odd"
275292
-- >>> input = Stream.enumerateFromTo 1 10
276-
-- >>> Stream.toList $ Scanl.parDemuxScan id getKey getScan input
293+
-- >>> Stream.toList $ Scanl.parDemuxScanM id getKey getScan input
277294
-- ...
278295
--
279-
{-# INLINE parDemuxScan #-}
280-
parDemuxScan :: (MonadAsync m, Ord k) =>
296+
{-# INLINE parDemuxScanM #-}
297+
parDemuxScanM :: (MonadAsync m, Ord k) =>
281298
(Config -> Config)
282299
-> (a -> k)
283300
-> (k -> m (Scanl m a b))
284301
-> Stream m a
285302
-> Stream m [(k, b)]
286-
parDemuxScan cfg getKey getFold (Stream sstep state) =
303+
parDemuxScanM cfg getKey getFold (Stream sstep state) =
287304
Stream step DemuxInit
288305

289306
where
@@ -368,3 +385,14 @@ parDemuxScan cfg getKey getFold (Stream sstep state) =
368385
return $ Skip (DemuxDrain q db keyToChan1)
369386
else return $ Yield outputs (DemuxDrain q db keyToChan1)
370387
step _ DemuxStop = return Stop
388+
389+
-- | Like 'parDemuxScanM' but the key to scan mapping is static/pure instead of
390+
-- monadic.
391+
{-# INLINE parDemuxScan #-}
392+
parDemuxScan :: (MonadAsync m, Ord k) =>
393+
(Config -> Config)
394+
-> (a -> k)
395+
-> (k -> Scanl m a b)
396+
-> Stream m a
397+
-> Stream m [(k, b)]
398+
parDemuxScan cfg getKey getFold = parDemuxScanM cfg getKey (pure . getFold)

0 commit comments

Comments
 (0)