Skip to content

Commit 836b0f5

Browse files
rnjtranjanharendra-kumar
authored andcommitted
Implement intersectBySorted API
1 parent 66fe3d1 commit 836b0f5

5 files changed

Lines changed: 125 additions & 11 deletions

File tree

benchmark/Streamly/Benchmark/Prelude/Serial/NestedStream.hs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,30 @@ joinInnerMap val1 val2 _ =
436436
(fmap toKvMap (mkStreamLen val1))
437437
(fmap toKvMap (mkStreamLen val2))
438438

439+
{-# INLINE intersectBy #-}
440+
intersectBy :: Int -> Int -> Int -> IO ()
441+
intersectBy val1 val2 _ =
442+
S.drain $
443+
Internal.intersectBy (==)
444+
(fmap toKvMap (mkStreamLen val1))
445+
(fmap toKvMap (mkStreamLen val2))
446+
447+
{-# INLINE intersectBySorted #-}
448+
intersectBySorted :: Int -> Int -> Int -> IO ()
449+
intersectBySorted val1 val2 _ =
450+
S.drain $
451+
Internal.intersectBySorted compare
452+
(fmap toKvMap (mkStreamLen val1))
453+
(fmap toKvMap (mkStreamLen val2))
454+
439455
o_n_heap_buffering :: Int -> [Benchmark]
440456
o_n_heap_buffering value =
441457
[ bgroup "buffered"
442458
[
443459
benchIOSrc1 "joinInner" (joinInner sqrtVal sqrtVal)
444460
, benchIOSrc1 "joinInnerMap" (joinInnerMap sqrtVal sqrtVal)
461+
, benchIOSrc1 "intersectBy" (intersectBy sqrtVal sqrtVal)
462+
, benchIOSrc1 "intersectBySorted" (intersectBySorted sqrtVal sqrtVal)
445463
]
446464
]
447465

src/Streamly/Internal/Data/Stream/IsStream/Top.hs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ module Streamly.Internal.Data.Stream.IsStream.Top
2828
-- | These are not exactly set operations because streams are not
2929
-- necessarily sets, they may have duplicated elements.
3030
, intersectBy
31-
, mergeIntersectBy
31+
, intersectBySorted
3232
, differenceBy
3333
, mergeDifferenceBy
3434
, unionBy
@@ -65,6 +65,7 @@ import Streamly.Internal.Data.Stream.IsStream.Common (concatM)
6565
import Streamly.Internal.Data.Stream.IsStream.Type
6666
(IsStream(..), adapt, foldl', fromList)
6767
import Streamly.Internal.Data.Stream.Serial (SerialT)
68+
--import Streamly.Internal.Data.Stream.StreamD (fromStreamD, toStreamD)
6869
import Streamly.Internal.Data.Time.Units (NanoSecond64(..), toRelTime64)
6970

7071
import qualified Data.List as List
@@ -79,6 +80,7 @@ import qualified Streamly.Internal.Data.Stream.IsStream.Expand as Stream
7980
import qualified Streamly.Internal.Data.Stream.IsStream.Reduce as Stream
8081
import qualified Streamly.Internal.Data.Stream.IsStream.Transform as Stream
8182
import qualified Streamly.Internal.Data.Stream.IsStream.Type as IsStream
83+
import qualified Streamly.Internal.Data.Stream.StreamD as StreamD
8284

8385
import Prelude hiding (filter, zipWith, concatMap, concat)
8486

@@ -540,11 +542,12 @@ intersectBy eq s1 s2 =
540542
--
541543
-- Time: O(m+n)
542544
--
543-
-- /Unimplemented/
544-
{-# INLINE mergeIntersectBy #-}
545-
mergeIntersectBy :: -- (IsStream t, Monad m) =>
545+
-- /Pre-release/
546+
{-# INLINE intersectBySorted #-}
547+
intersectBySorted :: (IsStream t, MonadIO m, Eq a) =>
546548
(a -> a -> Ordering) -> t m a -> t m a -> t m a
547-
mergeIntersectBy _eq _s1 _s2 = undefined
549+
intersectBySorted eq s1 =
550+
IsStream.fromStreamD . StreamD.intersectBySorted eq (IsStream.toStreamD s1) . IsStream.toStreamD
548551

549552
-- Roughly leftJoin s1 s2 = s1 `difference` s2 + s1 `intersection` s2
550553

src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ module Streamly.Internal.Data.Stream.StreamD.Nesting
142142
-- | Opposite to compact in ArrayStream
143143
, splitInnerBy
144144
, splitInnerBySuffix
145+
, intersectBySorted
145146
)
146147
where
147148

@@ -482,6 +483,59 @@ mergeBy
482483
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
483484
mergeBy cmp = mergeByM (\a b -> return $ cmp a b)
484485

486+
-------------------------------------------------------------------------------
487+
-- Intersection of sorted streams ---------------------------------------------
488+
-------------------------------------------------------------------------------
489+
{-# INLINE_NORMAL intersectBySorted #-}
490+
intersectBySorted
491+
:: (MonadIO m, Eq a)
492+
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
493+
intersectBySorted cmp (Stream stepa ta) (Stream stepb tb) =
494+
Stream step (Just ta, Just tb, Nothing, Nothing, Nothing)
495+
496+
where
497+
{-# INLINE_LATE step #-}
498+
499+
-- step 1
500+
step gst (Just sa, sb, Nothing, b, Nothing) = do
501+
r <- stepa gst sa
502+
return $ case r of
503+
Yield a sa' -> Skip (Just sa', sb, Just a, b, Nothing)
504+
Skip sa' -> Skip (Just sa', sb, Nothing, b, Nothing)
505+
Stop -> Stop
506+
507+
-- step 2
508+
step gst (sa, Just sb, a, Nothing, Nothing) = do
509+
r <- stepb gst sb
510+
return $ case r of
511+
Yield b sb' -> Skip (sa, Just sb', a, Just b, Nothing)
512+
Skip sb' -> Skip (sa, Just sb', a, Nothing, Nothing)
513+
Stop -> Stop
514+
515+
-- step 3
516+
-- both the values are available compare it
517+
step _ (sa, sb, Just a, Just b, Nothing) = do
518+
let res = cmp a b
519+
return $ case res of
520+
GT -> Skip (sa, sb, Just a, Nothing, Nothing)
521+
LT -> Skip (sa, sb, Nothing, Just b, Nothing)
522+
EQ -> Yield a (sa, sb, Nothing, Just a, Just b) -- step 4
523+
524+
-- step 4
525+
-- Matching element
526+
step gst (Just sa, Just sb, Nothing, Just _, Just b) = do
527+
r1 <- stepa gst sa
528+
return $ case r1 of
529+
Yield a' sa' -> do
530+
if a' == b -- match with prev a
531+
then Yield a' (Just sa', Just sb, Nothing, Just b, Just b) --step 1
532+
else Skip (Just sa', Just sb, Just a', Nothing, Nothing)
533+
534+
Skip sa' -> Skip (Just sa', Just sb, Nothing, Nothing, Nothing)
535+
Stop -> Stop
536+
537+
step _ (_, _, _, _, _) = return Stop
538+
485539
------------------------------------------------------------------------------
486540
-- Combine N Streams - unfoldMany
487541
------------------------------------------------------------------------------

streamly.cabal

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ extra-source-files:
101101
test/Streamly/Test/Data/Array/Prim/Pinned.hs
102102
test/Streamly/Test/Data/Array/Foreign.hs
103103
test/Streamly/Test/Data/Array/Stream/Foreign.hs
104-
test/Streamly/Test/Data/Parser/ParserD.hs
104+
test/Streamly/Test/Data/Parser/ParserD.hs
105105
test/Streamly/Test/FileSystem/Event.hs
106106
test/Streamly/Test/FileSystem/Event/Common.hs
107107
test/Streamly/Test/FileSystem/Event/Darwin.hs

test/Streamly/Test/Prelude/Top.hs

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
module Main (main) where
1+
module Main (main)
2+
where
23

3-
import Data.List (sort)
4+
import Data.List (intersect, sort)
45
import Test.QuickCheck
56
( Gen
67
, Property
@@ -65,10 +66,45 @@ joinInnerMap =
6566
]
6667
assert (v1 == v2)
6768

68-
-------------------------------------------------------------------------------
69-
-- Main
70-
-------------------------------------------------------------------------------
69+
intersectBy :: Property
70+
intersectBy =
71+
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
72+
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
73+
monadicIO $ action (sort ls0) (sort ls1)
74+
75+
where
7176

77+
action ls0 ls1 = do
78+
v1 <-
79+
run
80+
$ S.toList
81+
$ Top.intersectBy
82+
(==)
83+
(S.fromList ls0)
84+
(S.fromList ls1)
85+
let v2 = intersect ls0 ls1
86+
assert (v1 == sort v2)
87+
88+
intersectBySorted :: Property
89+
intersectBySorted =
90+
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
91+
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
92+
monadicIO $ action (sort ls0) (sort ls1)
93+
94+
where
95+
96+
action ls0 ls1 = do
97+
v1 <-
98+
run
99+
$ S.toList
100+
$ Top.intersectBySorted
101+
compare
102+
(S.fromList ls0)
103+
(S.fromList ls1)
104+
let v2 = intersect ls0 ls1
105+
assert (v1 == sort v2)
106+
107+
-------------------------------------------------------------------------------
72108
moduleName :: String
73109
moduleName = "Prelude.Top"
74110

@@ -79,3 +115,6 @@ main = hspec $ do
79115

80116
prop "joinInner" Main.joinInner
81117
prop "joinInnerMap" Main.joinInnerMap
118+
-- intersect
119+
prop "intersectBy" Main.intersectBy
120+
prop "intersectBySorted" Main.intersectBySorted

0 commit comments

Comments
 (0)