Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions benchmark/Streamly/Benchmark/Data/Scanl/Concurrent.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ mkBench name f =
-- Benchmarks
--------------------------------------------------------------------------------

parDistributeScan :: Int -> Seed -> IO ()
parDistributeScan len seed = do
parDistributeScanM :: Int -> Seed -> IO ()
parDistributeScanM len seed = do
ref <- newIORef [Scanl.latest]
let gen = atomicModifyIORef ref (\xs -> ([], xs))
Scanl.parDistributeScan id gen (source len seed)
Scanl.parDistributeScanM id gen (source len seed)
& Stream.fold Fold.drain

--------------------------------------------------------------------------------
Expand All @@ -53,7 +53,7 @@ parDistributeScan len seed = do
o_1_space_scans :: Int -> [Benchmark]
o_1_space_scans numElements =
[ bgroup "scan"
[ mkBench "parDistributeScan" (parDistributeScan numElements)
[ mkBench "parDistributeScanM" (parDistributeScanM numElements)
]
]

Expand Down
46 changes: 37 additions & 9 deletions src/Streamly/Internal/Data/Scanl/Concurrent.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
module Streamly.Internal.Data.Scanl.Concurrent
(
parTeeWith
, parDistributeScanM
, parDistributeScan
, parDemuxScanM
, parDemuxScan
)
where
Expand All @@ -19,7 +21,7 @@ where
import Control.Concurrent (newEmptyMVar, takeMVar, throwTo)
import Control.Monad.Catch (throwM)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Data.IORef (newIORef, readIORef)
import Data.IORef (newIORef, readIORef, atomicModifyIORef)
import Fusion.Plugin.Types (Fuse(..))
import Streamly.Internal.Control.Concurrent (MonadAsync)
import Streamly.Internal.Data.Atomics (atomicModifyIORefCAS)
Expand All @@ -30,6 +32,7 @@ import Streamly.Internal.Data.SVar.Type (adaptState)
import Streamly.Internal.Data.Tuple.Strict (Tuple3'(..))

import qualified Data.Map.Strict as Map
import qualified Streamly.Internal.Data.Stream as Stream

import Streamly.Internal.Data.Fold.Channel.Type
import Streamly.Internal.Data.Channel.Types
Expand Down Expand Up @@ -162,13 +165,13 @@ data ScanState s q db f =
-- >>> import Data.IORef
-- >>> ref <- newIORef [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int]
-- >>> gen = atomicModifyIORef ref (\xs -> ([], xs))
-- >>> Stream.toList $ Scanl.parDistributeScan id gen (Stream.enumerateFromTo 1 10)
-- >>> Stream.toList $ Scanl.parDistributeScanM id gen (Stream.enumerateFromTo 1 10)
-- ...
--
{-# INLINE parDistributeScan #-}
parDistributeScan :: MonadAsync m =>
{-# INLINE parDistributeScanM #-}
parDistributeScanM :: MonadAsync m =>
(Config -> Config) -> m [Scanl m a b] -> Stream m a -> Stream m [b]
parDistributeScan cfg getFolds (Stream sstep state) =
parDistributeScanM cfg getFolds (Stream sstep state) =
Stream step ScanInit

where
Expand Down Expand Up @@ -243,6 +246,20 @@ parDistributeScan cfg getFolds (Stream sstep state) =
else return $ Yield outputs (ScanDrain q db running)
step _ ScanStop = return Stop

-- | Like 'parDistributeScanM' but takes a list of static scans.
--
-- >>> xs = [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int]
-- >>> Stream.toList $ Scanl.parDistributeScan id xs (Stream.enumerateFromTo 1 10)
-- ...
{-# INLINE parDistributeScan #-}
parDistributeScan :: MonadAsync m =>
(Config -> Config) -> [Scanl m a b] -> Stream m a -> Stream m [b]
parDistributeScan cfg getFolds stream =
Stream.concatEffect $ do
ref <- liftIO $ newIORef getFolds
let action = liftIO $ atomicModifyIORef ref (\xs -> ([], xs))
return $ parDistributeScanM cfg action stream

{-# ANN type DemuxState Fuse #-}
data DemuxState s q db f =
DemuxInit
Expand Down Expand Up @@ -273,17 +290,17 @@ data DemuxState s q db f =
-- >>> getScan k = return (fromJust $ Map.lookup k kv)
-- >>> getKey x = if even x then "even" else "odd"
-- >>> input = Stream.enumerateFromTo 1 10
-- >>> Stream.toList $ Scanl.parDemuxScan id getKey getScan input
-- >>> Stream.toList $ Scanl.parDemuxScanM id getKey getScan input
-- ...
--
{-# INLINE parDemuxScan #-}
parDemuxScan :: (MonadAsync m, Ord k) =>
{-# INLINE parDemuxScanM #-}
parDemuxScanM :: (MonadAsync m, Ord k) =>
(Config -> Config)
-> (a -> k)
-> (k -> m (Scanl m a b))
-> Stream m a
-> Stream m [(k, b)]
parDemuxScan cfg getKey getFold (Stream sstep state) =
parDemuxScanM cfg getKey getFold (Stream sstep state) =
Stream step DemuxInit

where
Expand Down Expand Up @@ -368,3 +385,14 @@ parDemuxScan cfg getKey getFold (Stream sstep state) =
return $ Skip (DemuxDrain q db keyToChan1)
else return $ Yield outputs (DemuxDrain q db keyToChan1)
step _ DemuxStop = return Stop

-- | Like 'parDemuxScanM' but the key to scan mapping is static/pure instead of
-- monadic.
{-# INLINE parDemuxScan #-}
parDemuxScan :: (MonadAsync m, Ord k) =>
(Config -> Config)
-> (a -> k)
-> (k -> Scanl m a b)
-> Stream m a
-> Stream m [(k, b)]
parDemuxScan cfg getKey getFold = parDemuxScanM cfg getKey (pure . getFold)
16 changes: 8 additions & 8 deletions test/Streamly/Test/Data/Scanl/Concurrent.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ parDistributeScan_ScanEnd concOpts = do
inpList = [1..streamLen]
inpStream = Stream.fromList inpList
res1 <-
Scanl.parDistributeScan concOpts gen inpStream
Scanl.parDistributeScanM concOpts gen inpStream
& Stream.concatMap Stream.fromList
& Stream.catMaybes
& Stream.fold Fold.toList
Expand All @@ -65,7 +65,7 @@ parDemuxScan_ScanEnd concOpts = do
inpList = [1..streamLen]
inpStream = Stream.fromList inpList
res <-
Scanl.parDemuxScan concOpts demuxer gen inpStream
Scanl.parDemuxScanM concOpts demuxer gen inpStream
& Stream.concatMap Stream.fromList
& fmap (\x -> (fst x,) <$> snd x)
& Stream.catMaybes
Expand All @@ -81,7 +81,7 @@ parDistributeScan_StreamEnd concOpts = do
inpList = [1..streamLen]
inpStream = Stream.fromList inpList
res1 <-
Scanl.parDistributeScan concOpts gen inpStream
Scanl.parDistributeScanM concOpts gen inpStream
& Stream.concatMap Stream.fromList
& Stream.catMaybes
& Stream.fold Fold.toList
Expand All @@ -95,7 +95,7 @@ parDemuxScan_StreamEnd concOpts = do
inpList = [1..streamLen]
inpStream = Stream.fromList inpList
res <-
Scanl.parDemuxScan concOpts demuxer gen inpStream
Scanl.parDemuxScanM concOpts demuxer gen inpStream
& Stream.concatMap Stream.fromList
& fmap (\x -> (fst x,) <$> snd x)
& Stream.catMaybes
Expand All @@ -110,11 +110,11 @@ main = hspec
$ modifyMaxSuccess (const 10)
#endif
$ describe moduleName $ do
it "parDistributeScan (stream end) (maxBuffer 1)"
it "parDistributeScanM (stream end) (maxBuffer 1)"
$ parDistributeScan_StreamEnd (Scanl.maxBuffer 1)
it "parDistributeScan (scan end) (maxBuffer 1)"
it "parDistributeScanM (scan end) (maxBuffer 1)"
$ parDistributeScan_ScanEnd (Scanl.maxBuffer 1)
it "parDemuxScan (stream end) (maxBuffer 1)"
it "parDemuxScanM (stream end) (maxBuffer 1)"
$ parDemuxScan_StreamEnd (Scanl.maxBuffer 1)
it "parDemuxScan (scan end) (maxBuffer 1)"
it "parDemuxScanM (scan end) (maxBuffer 1)"
$ parDemuxScan_ScanEnd (Scanl.maxBuffer 1)
Loading