Skip to content

Commit b04c799

Browse files
Add fairCross to Stream module
1 parent fe2778b commit b04c799

2 files changed

Lines changed: 71 additions & 5 deletions

File tree

core/src/Streamly/Internal/Data/Stream/Nesting.hs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,11 +1022,6 @@ schedForM = flip schedMapM
10221022
schedFor :: Monad m => Stream m a -> (a -> Stream m b) -> Stream m b
10231023
schedFor = flip schedMap
10241024

1025-
data FairUnfoldState o i =
1026-
FairUnfoldInit o ([i] -> [i])
1027-
| FairUnfoldNext o ([i] -> [i]) [i]
1028-
| FairUnfoldDrain ([i] -> [i]) [i]
1029-
10301025
-- | Similar to 'fairUnfoldEach' but scheduling is independent of the output.
10311026
--
10321027
-- >>> :{

core/src/Streamly/Internal/Data/Stream/Type.hs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ module Streamly.Internal.Data.Stream.Type
110110
, crossApplySnd
111111
, crossWith
112112
, cross
113+
, FairUnfoldState (..)
114+
, fairCrossWithM
115+
, fairCrossWith
116+
, fairCross
113117
, loop -- forEach
114118
, loopBy
115119

@@ -1298,6 +1302,59 @@ crossApply (Stream stepa statea) (Stream stepb stateb) =
12981302
Stop -> Skip (Left os))
12991303
(stepb (adaptState gst) st)
13001304

1305+
-- This is shared by all fairUnfold, fairConcat combinators.
1306+
data FairUnfoldState o i =
1307+
FairUnfoldInit o ([i] -> [i])
1308+
| FairUnfoldNext o ([i] -> [i]) [i]
1309+
| FairUnfoldDrain ([i] -> [i]) [i]
1310+
1311+
-- XXX will it perform better if we write it in the same way as crossApply?
1312+
-- crossApply is faster than unfoldCross in equation solving benchmarks.
1313+
1314+
-- | Like 'fairCrossWith' but with monadic function argument.
1315+
--
1316+
{-# INLINE_NORMAL fairCrossWithM #-}
1317+
fairCrossWithM :: Monad m =>
1318+
(a -> b -> m c) -> Stream m a -> Stream m b -> Stream m c
1319+
fairCrossWithM f (Stream step1 state1) (Stream step2 state2) =
1320+
Stream step (FairUnfoldInit state1 id)
1321+
1322+
where
1323+
1324+
{-# INLINE_LATE step #-}
1325+
step gst (FairUnfoldInit o ls) = do
1326+
r <- step1 (adaptState gst) o
1327+
return $ case r of
1328+
Yield b o' -> Skip (FairUnfoldNext o' id (ls [(b,state2)]))
1329+
Skip o' -> Skip (FairUnfoldInit o' ls)
1330+
Stop -> Skip (FairUnfoldDrain id (ls []))
1331+
1332+
step _ (FairUnfoldNext o ys []) =
1333+
return $ Skip (FairUnfoldInit o ys)
1334+
1335+
step gst (FairUnfoldNext o ys ((b,st):ls)) = do
1336+
r <- step2 (adaptState gst) st
1337+
case r of
1338+
Yield c s ->
1339+
f b c >>= \x ->
1340+
return $ Yield x (FairUnfoldNext o (ys . ((b, s) :)) ls)
1341+
Skip s -> return $ Skip (FairUnfoldNext o ys ((b,s) : ls))
1342+
Stop -> return $ Skip (FairUnfoldNext o ys ls)
1343+
1344+
step _ (FairUnfoldDrain ys []) =
1345+
case ys [] of
1346+
[] -> return Stop
1347+
xs -> return $ Skip (FairUnfoldDrain id xs)
1348+
1349+
step gst (FairUnfoldDrain ys ((b,st):ls)) = do
1350+
r <- step2 (adaptState gst) st
1351+
case r of
1352+
Yield c s ->
1353+
f b c >>= \x ->
1354+
return $ Yield x (FairUnfoldDrain (ys . ((b,s) :)) ls)
1355+
Skip s -> return $ Skip (FairUnfoldDrain ys ((b,s) : ls))
1356+
Stop -> return $ Skip (FairUnfoldDrain ys ls)
1357+
13011358
{-# INLINE_NORMAL crossApplySnd #-}
13021359
crossApplySnd :: Functor f => Stream f a -> Stream f b -> Stream f b
13031360
crossApplySnd (Stream stepa statea) (Stream stepb stateb) =
@@ -1380,6 +1437,14 @@ instance Applicative f => Applicative (Stream f) where
13801437
crossWith :: Monad m => (a -> b -> c) -> Stream m a -> Stream m b -> Stream m c
13811438
crossWith f m1 m2 = fmap f m1 `crossApply` m2
13821439

1440+
-- | Like 'crossWith' but interleaves the outer and inner loops fairly. See
1441+
-- 'fairConcatFor' for more details.
1442+
--
1443+
{-# INLINE fairCrossWith #-}
1444+
fairCrossWith :: Monad m =>
1445+
(a -> b -> c) -> Stream m a -> Stream m b -> Stream m c
1446+
fairCrossWith f = fairCrossWithM (\a b -> return $ f a b)
1447+
13831448
-- | Given a @Stream m a@ and @Stream m b@ generate a stream with all possible
13841449
-- combinations of the tuple @(a, b)@.
13851450
--
@@ -1399,6 +1464,12 @@ crossWith f m1 m2 = fmap f m1 `crossApply` m2
13991464
cross :: Monad m => Stream m a -> Stream m b -> Stream m (a, b)
14001465
cross = crossWith (,)
14011466

1467+
-- | Like 'cross' but interleaves the outer and inner loops fairly. See
1468+
-- 'fairConcatFor' for more details.
1469+
{-# INLINE fairCross #-}
1470+
fairCross :: Monad m => Stream m a -> Stream m b -> Stream m (a, b)
1471+
fairCross = fairCrossWith (,)
1472+
14021473
-- crossWith/cross should ideally use Stream m b as the first stream, because
14031474
-- we are transforming Stream m a using that. We provide loop with arguments
14041475
-- flipped.

0 commit comments

Comments
 (0)