@@ -85,9 +85,12 @@ module Streamly.Coreutils.Find
8585
8686 -- * Options
8787 , FindOptions
88+ , maxResults
8889 )
8990where
9091
92+ import Data.Function ((&) )
93+ import Data.Functor.Identity (runIdentity )
9194import Data.Maybe (fromJust )
9295import Data.Word (Word8 )
9396import Streamly.Data.Array (Array )
@@ -100,12 +103,17 @@ import System.IO (stdout, hSetBuffering, BufferMode(LineBuffering))
100103import qualified Streamly.Data.Stream.Prelude as Stream
101104import qualified Streamly.Data.Array as Array
102105import qualified Streamly.FileSystem.DirIO as DirIO
103- import qualified Streamly.Internal.Data.Array as GArray (compactMax' )
106+ import qualified Streamly.Internal.Data.Array as GArray
107+ ( compactMax'
108+ , read
109+ , unsafeSliceOffLen
110+ )
104111import qualified Streamly.Internal.Data.Stream as Stream
105112 ( unfoldEachEndBy
106113 , concatIterate
107114 , bfsConcatIterate
108115 , altBfsConcatIterate
116+ , postscanlMaybe
109117 )
110118import qualified Streamly.Data.StreamK as StreamK
111119import qualified Streamly.Internal.Data.StreamK as StreamK
@@ -125,6 +133,11 @@ import qualified Streamly.Internal.FileSystem.Posix.ReadDir as Dir
125133import qualified Streamly.Unicode.Stream as Stream
126134#endif
127135
136+ import Streamly.Internal.Data.Scanl (Step (.. ), Scanl (.. ))
137+ import qualified Streamly.Internal.Data.Scanl as Scanl
138+ import qualified Streamly.Internal.Data.Fold as Fold
139+ import qualified Streamly.Internal.Data.Array as Array
140+
128141--
129142-- Running on a sample directory tree the concurrent rust "fd" tool took 150 ms
130143-- (real time). On the same tree the fastest variant using Haskell streamly
@@ -147,10 +160,17 @@ data FindTraversal
147160 | FindParallelInterleaved
148161 | FindParallelOrdered
149162
-- | Options accepted by the find functions in this module.
data FindOptions =
    FindOptions
        { findTraversal :: FindTraversal
          -- ^ Traversal strategy selected for walking the directory tree.
        , findMaxResults :: Maybe Int
          -- ^ Optional cap on the output. 'Nothing' leaves the output
          -- untouched; @Just n@ makes 'findByteChunked' pass its output
          -- through 'takeN' with that @n@.
        }
151167
-- | Baseline options: 'FindSerialDfs' traversal and no limit on the number of
-- results.
defaultConfig :: FindOptions
defaultConfig =
    FindOptions {findMaxResults = Nothing, findTraversal = FindSerialDfs}
154174
-- | Set the traversal to 'FindSerialDfs', leaving every other option as it is.
serialDfs :: FindOptions -> FindOptions
serialDfs opts = opts {findTraversal = FindSerialDfs}
@@ -176,6 +196,9 @@ parallelInterleaved cfg = cfg {findTraversal = FindParallelInterleaved}
-- Set the traversal to 'FindParallelOrdered', leaving every other option as
-- it is.
parallelOrdered :: FindOptions -> FindOptions
parallelOrdered opts = opts {findTraversal = FindParallelOrdered}
178198
-- | Cap the output at the given number of results. A negative argument is
-- clamped to zero before it is stored in 'findMaxResults'.
--
-- In the code visible here only 'findByteChunked' consults this option;
-- 'findChunked' and 'find' are still marked as not implementing it.
maxResults :: Int -> FindOptions -> FindOptions
maxResults limit opts = opts {findMaxResults = Just clamped}

    where

    -- Never store a negative limit.
    clamped = max 0 limit
201+
179202{-# INLINE recReadOpts #-}
180203recReadOpts :: ReadOptions -> ReadOptions
181204{-# INLINE reEncode #-}
@@ -196,12 +219,70 @@ recReadOpts =
196219reEncode = id
197220#endif
198221
-- | Accumulator of the newline-counting fold ('count'). Both fields are
-- strict.
data Counts =
    Counts
        !Int -- newlines still to be seen before the fold terminates
        !Int -- bytes consumed so far
    deriving Show
223+
-- | Step function of the 'count' fold. Consumes one byte: a newline (byte 10)
-- decrements the number of newlines still wanted, and every byte increments
-- the consumed-byte counter. When the number of wanted newlines reaches zero
-- the fold is 'Done' with @Left len@, where @len@ is the number of bytes
-- consumed including the current one; otherwise it continues with the
-- updated 'Counts'.
{-# INLINE countStep #-}
countStep :: Monad m => Counts -> Word8 -> m (Step Counts (Either Int Int))
countStep (Counts wanted consumed) byte
    | wanted' == 0 = pure (Done (Left consumed'))
    | otherwise = pure (Partial (Counts wanted' consumed'))

    where

    -- The current byte is always counted as consumed.
    consumed' = consumed + 1

    -- 10 is the ASCII line feed.
    wanted'
        | byte == 10 = wanted - 1
        | otherwise = wanted
231+
-- | Extract function of the 'count' fold, used when the input ends before
-- the fold terminates by itself: reports, as a 'Right', how many newlines
-- were still wanted. The byte counter is discarded.
{-# INLINE countExtract #-}
countExtract :: Monad m => Counts -> m (Either a Int)
countExtract (Counts wanted _) = pure (Right wanted)
235+
-- | A fold over bytes that looks for the given number of newlines.
--
-- * @Left len@: the wanted number of newlines was reached after consuming
--   @len@ bytes (the last newline included).
-- * @Right k@: the input ended first; @k@ newlines were still wanted.
{-# INLINE count #-}
count :: Monad m => Int -> Fold.Fold m Word8 (Either Int Int)
count wanted = Fold.foldtM' countStep start countExtract

    where

    -- Nothing consumed yet, all newlines still wanted.
    start = pure (Partial (Counts wanted 0))
239+
-- XXX Scanl is an awkward abstraction for the case where we emit every
-- element and only need to transform it using the state; an smapM-like
-- operation would fit this case better. With a scan we are forced to store
-- the element in a Maybe and then strip it again with catMaybes, solely
-- because the initial state does not hold an element yet.
--
-- | Scan step used by 'takeN'. The state pairs the number of newlines still
-- wanted with the most recent chunk; the chunk held in the incoming state is
-- ignored.
--
-- The incoming chunk is run through 'count':
--
-- * If the wanted newlines are reached inside the chunk, the scan is 'Done'
--   with the chunk sliced to end right after the last wanted newline.
-- * If the chunk ends with newlines still wanted, the scan continues with
--   the reduced count and the whole chunk.
-- * If 'count' reports zero newlines still wanted at the end of the chunk,
--   the scan is 'Done' with the whole chunk.
{-# INLINE scanStep #-}
scanStep :: Monad m =>
       (Int, Maybe (Array Word8))
    -> Array Word8
    -> m (Step (Int, Maybe (Array Word8)) (Maybe (Array Word8)))
scanStep (wanted, _) chunk = do
    outcome <- Stream.fold (count wanted) (Array.read chunk)
    pure $ case outcome of
        Left len -> Done (Just (Array.unsafeSliceOffLen 0 len chunk))
        Right 0 -> Done (Just chunk)
        Right left -> Partial (left, Just chunk)
259+
-- | Extract function of the 'takeN' scan: yields the chunk held in the
-- state, which is 'Nothing' only in the initial state.
{-# INLINE scanExtract #-}
scanExtract :: Monad m => (Int, Maybe (Array Word8)) -> m (Maybe (Array Word8))
scanExtract (_, latest) = pure latest
263+
-- | Finalizer of the 'takeN' scan: yields the chunk held in the state, just
-- like 'scanExtract'.
{-# INLINE scanFinal #-}
scanFinal :: Monad m => (Int, Maybe (Array Word8)) -> m (Maybe (Array Word8))
scanFinal (_, latest) = pure latest
267+
-- | Limit a stream of byte chunks to the first @n@ newline-terminated
-- entries.
--
-- For @n <= 0@ the input is ignored and an empty stream is returned.
-- Otherwise the chunks are run through a scan built from 'scanStep',
-- 'scanExtract' and 'scanFinal', starting with @n@ newlines wanted and no
-- chunk stored. The scan finishes on the chunk containing the @n@-th newline,
-- with that chunk cut right after the newline.
{-# INLINE takeN #-}
takeN :: Int -> Stream IO (Array Word8) -> Stream IO (Array Word8)
takeN n =
    if n > 0
    then Stream.postscanlMaybe limiter
    else const Stream.nil

    where

    limiter = Scanl scanStep (pure (Partial (n, Nothing))) scanExtract scanFinal
279+
199280#if !defined(mingw32_HOST_OS) && !defined(__MINGW32__)
200281-- Fastest implementation, only works for posix as of now.
201282findByteChunked :: (FindOptions -> FindOptions ) -> Path -> Stream IO (Array Word8 )
202283findByteChunked f path =
203- Stream. catRights $
204- case findTraversal (f defaultConfig) of
284+ transform $ Stream. catRights $
285+ case findTraversal opts of
205286 FindSerialDfs ->
206287 Stream. concatIterate streamDirMaybe -- 154 ms
207288 $ Stream. fromPure (Left [path])
@@ -232,6 +313,11 @@ findByteChunked f path =
232313
233314 where
234315
316+ {-# INLINE transform #-}
317+ transform s = maybe s (\ n -> takeN n s) (findMaxResults opts)
318+
319+ opts = f defaultConfig
320+
235321 concatIterateWith combine =
236322 StreamK. toStream
237323 . StreamK. concatIterateWith combine (StreamK. fromStream . streamDir)
@@ -253,6 +339,7 @@ findByteChunked f path =
253339-- Faster than the find implementation below
254340findChunked :: (FindOptions -> FindOptions ) -> Path -> Stream IO [Path ]
255341findChunked f path =
342+ -- XXX implement maxResults
256343 Stream. catRights $
257344 case findTraversal (f defaultConfig) of
258345 FindSerialDfs ->
@@ -300,6 +387,7 @@ findChunked f path =
300387
301388find :: (FindOptions -> FindOptions ) -> Path -> Stream IO Path
302389find f path =
390+ -- XXX implement maxResults
303391 Stream. catRights $
304392 case findTraversal (f defaultConfig) of
305393 FindSerialDfs ->
0 commit comments