From 87e9c89e1b3b3cea685222185c4b343143f4274f Mon Sep 17 00:00:00 2001 From: Harendra Kumar Date: Sat, 17 May 2025 19:27:12 +0530 Subject: [PATCH] Add concurrent scan combinators with static list of scans --- .../Benchmark/Data/Scanl/Concurrent.hs | 8 ++-- .../Internal/Data/Scanl/Concurrent.hs | 46 +++++++++++++++---- test/Streamly/Test/Data/Scanl/Concurrent.hs | 16 +++---- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/benchmark/Streamly/Benchmark/Data/Scanl/Concurrent.hs b/benchmark/Streamly/Benchmark/Data/Scanl/Concurrent.hs index e5f8a2c02b..4645af4a52 100644 --- a/benchmark/Streamly/Benchmark/Data/Scanl/Concurrent.hs +++ b/benchmark/Streamly/Benchmark/Data/Scanl/Concurrent.hs @@ -39,11 +39,11 @@ mkBench name f = -- Benchmarks -------------------------------------------------------------------------------- -parDistributeScan :: Int -> Seed -> IO () -parDistributeScan len seed = do +parDistributeScanM :: Int -> Seed -> IO () +parDistributeScanM len seed = do ref <- newIORef [Scanl.latest] let gen = atomicModifyIORef ref (\xs -> ([], xs)) - Scanl.parDistributeScan id gen (source len seed) + Scanl.parDistributeScanM id gen (source len seed) & Stream.fold Fold.drain -------------------------------------------------------------------------------- @@ -53,7 +53,7 @@ parDistributeScan len seed = do o_1_space_scans :: Int -> [Benchmark] o_1_space_scans numElements = [ bgroup "scan" - [ mkBench "parDistributeScan" (parDistributeScan numElements) + [ mkBench "parDistributeScanM" (parDistributeScanM numElements) ] ] diff --git a/src/Streamly/Internal/Data/Scanl/Concurrent.hs b/src/Streamly/Internal/Data/Scanl/Concurrent.hs index bc02456ba5..4675630fdb 100644 --- a/src/Streamly/Internal/Data/Scanl/Concurrent.hs +++ b/src/Streamly/Internal/Data/Scanl/Concurrent.hs @@ -9,7 +9,9 @@ module Streamly.Internal.Data.Scanl.Concurrent ( parTeeWith + , parDistributeScanM , parDistributeScan + , parDemuxScanM , parDemuxScan ) where @@ -19,7 +21,7 @@ where import Control.Concurrent (newEmptyMVar, takeMVar, throwTo) import Control.Monad.Catch (throwM) import Control.Monad.IO.Class (MonadIO(liftIO)) -import Data.IORef (newIORef, readIORef) +import Data.IORef (newIORef, readIORef, atomicModifyIORef) import Fusion.Plugin.Types (Fuse(..)) import Streamly.Internal.Control.Concurrent (MonadAsync) import Streamly.Internal.Data.Atomics (atomicModifyIORefCAS) @@ -30,6 +32,7 @@ import Streamly.Internal.Data.SVar.Type (adaptState) import Streamly.Internal.Data.Tuple.Strict (Tuple3'(..)) import qualified Data.Map.Strict as Map +import qualified Streamly.Internal.Data.Stream as Stream import Streamly.Internal.Data.Fold.Channel.Type import Streamly.Internal.Data.Channel.Types @@ -162,13 +165,13 @@ data ScanState s q db f = -- >>> import Data.IORef -- >>> ref <- newIORef [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int] -- >>> gen = atomicModifyIORef ref (\xs -> ([], xs)) --- >>> Stream.toList $ Scanl.parDistributeScan id gen (Stream.enumerateFromTo 1 10) +-- >>> Stream.toList $ Scanl.parDistributeScanM id gen (Stream.enumerateFromTo 1 10) -- ... -- -{-# INLINE parDistributeScan #-} -parDistributeScan :: MonadAsync m => +{-# INLINE parDistributeScanM #-} +parDistributeScanM :: MonadAsync m => (Config -> Config) -> m [Scanl m a b] -> Stream m a -> Stream m [b] -parDistributeScan cfg getFolds (Stream sstep state) = +parDistributeScanM cfg getFolds (Stream sstep state) = Stream step ScanInit where @@ -243,6 +246,20 @@ parDistributeScan cfg getFolds (Stream sstep state) = else return $ Yield outputs (ScanDrain q db running) step _ ScanStop = return Stop +-- | Like 'parDistributeScanM' but takes a list of static scans. +-- +-- >>> xs = [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int] +-- >>> Stream.toList $ Scanl.parDistributeScan id xs (Stream.enumerateFromTo 1 10) +-- ... +{-# INLINE parDistributeScan #-} +parDistributeScan :: MonadAsync m => + (Config -> Config) -> [Scanl m a b] -> Stream m a -> Stream m [b] +parDistributeScan cfg getFolds stream = + Stream.concatEffect $ do + ref <- liftIO $ newIORef getFolds + let action = liftIO $ atomicModifyIORef ref (\xs -> ([], xs)) + return $ parDistributeScanM cfg action stream + {-# ANN type DemuxState Fuse #-} data DemuxState s q db f = DemuxInit @@ -273,17 +290,17 @@ data DemuxState s q db f = -- >>> getScan k = return (fromJust $ Map.lookup k kv) -- >>> getKey x = if even x then "even" else "odd" -- >>> input = Stream.enumerateFromTo 1 10 --- >>> Stream.toList $ Scanl.parDemuxScan id getKey getScan input +-- >>> Stream.toList $ Scanl.parDemuxScanM id getKey getScan input -- ... -- -{-# INLINE parDemuxScan #-} -parDemuxScan :: (MonadAsync m, Ord k) => +{-# INLINE parDemuxScanM #-} +parDemuxScanM :: (MonadAsync m, Ord k) => (Config -> Config) -> (a -> k) -> (k -> m (Scanl m a b)) -> Stream m a -> Stream m [(k, b)] -parDemuxScan cfg getKey getFold (Stream sstep state) = +parDemuxScanM cfg getKey getFold (Stream sstep state) = Stream step DemuxInit where @@ -368,3 +385,14 @@ parDemuxScan cfg getKey getFold (Stream sstep state) = return $ Skip (DemuxDrain q db keyToChan1) else return $ Yield outputs (DemuxDrain q db keyToChan1) step _ DemuxStop = return Stop + +-- | Like 'parDemuxScanM' but the key to scan mapping is static/pure instead of +-- monadic. +{-# INLINE parDemuxScan #-} +parDemuxScan :: (MonadAsync m, Ord k) => + (Config -> Config) + -> (a -> k) + -> (k -> Scanl m a b) + -> Stream m a + -> Stream m [(k, b)] +parDemuxScan cfg getKey getFold = parDemuxScanM cfg getKey (pure . getFold) diff --git a/test/Streamly/Test/Data/Scanl/Concurrent.hs b/test/Streamly/Test/Data/Scanl/Concurrent.hs index 3cecbe4e6f..8b3507231b 100644 --- a/test/Streamly/Test/Data/Scanl/Concurrent.hs +++ b/test/Streamly/Test/Data/Scanl/Concurrent.hs @@ -47,7 +47,7 @@ parDistributeScan_ScanEnd concOpts = do inpList = [1..streamLen] inpStream = Stream.fromList inpList res1 <- - Scanl.parDistributeScan concOpts gen inpStream + Scanl.parDistributeScanM concOpts gen inpStream & Stream.concatMap Stream.fromList & Stream.catMaybes & Stream.fold Fold.toList @@ -65,7 +65,7 @@ parDemuxScan_ScanEnd concOpts = do inpList = [1..streamLen] inpStream = Stream.fromList inpList res <- - Scanl.parDemuxScan concOpts demuxer gen inpStream + Scanl.parDemuxScanM concOpts demuxer gen inpStream & Stream.concatMap Stream.fromList & fmap (\x -> (fst x,) <$> snd x) & Stream.catMaybes @@ -81,7 +81,7 @@ parDistributeScan_StreamEnd concOpts = do inpList = [1..streamLen] inpStream = Stream.fromList inpList res1 <- - Scanl.parDistributeScan concOpts gen inpStream + Scanl.parDistributeScanM concOpts gen inpStream & Stream.concatMap Stream.fromList & Stream.catMaybes & Stream.fold Fold.toList @@ -95,7 +95,7 @@ parDemuxScan_StreamEnd concOpts = do inpList = [1..streamLen] inpStream = Stream.fromList inpList res <- - Scanl.parDemuxScan concOpts demuxer gen inpStream + Scanl.parDemuxScanM concOpts demuxer gen inpStream & Stream.concatMap Stream.fromList & fmap (\x -> (fst x,) <$> snd x) & Stream.catMaybes @@ -110,11 +110,11 @@ main = hspec $ modifyMaxSuccess (const 10) #endif $ describe moduleName $ do - it "parDistributeScan (stream end) (maxBuffer 1)" + it "parDistributeScanM (stream end) (maxBuffer 1)" $ parDistributeScan_StreamEnd (Scanl.maxBuffer 1) - it "parDistributeScan (scan end) (maxBuffer 1)" + it "parDistributeScanM (scan end) (maxBuffer 1)" $ parDistributeScan_ScanEnd (Scanl.maxBuffer 1) - it "parDemuxScan (stream end) (maxBuffer 1)" + it "parDemuxScanM (stream end) (maxBuffer 1)" $ parDemuxScan_StreamEnd (Scanl.maxBuffer 1) - it "parDemuxScan (scan end) (maxBuffer 1)" + it "parDemuxScanM (scan end) (maxBuffer 1)" $ parDemuxScan_ScanEnd (Scanl.maxBuffer 1)