Skip to content

Commit a7dc3cd

Browse files
Implement maxResults
1 parent 0658e91 commit a7dc3cd

2 files changed

Lines changed: 117 additions & 34 deletions

File tree

app/hfd.hs

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
{-# LANGUAGE CPP #-}
2-
31
module Main (main) where
42

53
import System.IO (stdout)
@@ -17,23 +15,20 @@ import Options.Applicative
1715
, long
1816
, metavar
1917
, optional
18+
, option
2019
, progDesc
2120
, strArgument
2221
, (<**>)
2322
)
2423
import qualified Options.Applicative as OA
2524
import qualified Streamly.Data.Stream.Prelude as Stream
26-
#if defined(mingw32_HOST_OS) || defined(__MINGW32__)
27-
import qualified Streamly.Data.Array as Array
28-
import qualified Streamly.Data.Unfold as Unfold
29-
import qualified Streamly.Internal.Data.Stream as Stream (unfoldEachEndBy)
30-
import qualified Streamly.Unicode.Stream as Unicode
31-
#endif
3225
import qualified Streamly.FileSystem.Handle as Handle
3326
import qualified Streamly.FileSystem.Path as Path
3427

3528
import Streamly.Coreutils.Find
3629
( FindOptions
30+
, findByteChunked
31+
, maxResults
3732
, parallelInterleaved
3833
, parallelOrdered
3934
, parallelUnordered
@@ -43,15 +38,11 @@ import Streamly.Coreutils.Find
4338
, serialDfs
4439
, serialInterleaved
4540
)
46-
#if !defined(mingw32_HOST_OS) && !defined(__MINGW32__)
47-
import Streamly.Coreutils.Find (findByteChunked)
48-
#else
49-
import Streamly.Coreutils.Find (findChunked)
50-
#endif
5141

5242
data Config = Config
5343
{ cfgTraversal :: FindOptions -> FindOptions
5444
, cfgRoot :: FilePath
45+
, cfgMaxResults :: Maybe Int
5546
}
5647

5748
data Traversal
@@ -76,11 +67,12 @@ toTraversalConfig traversal =
7667
TraversalParallelInterleaved -> parallelInterleaved
7768
TraversalParallelOrdered -> parallelOrdered
7869

79-
mkConfig :: Traversal -> Maybe FilePath -> Config
80-
mkConfig traversal mPath =
70+
mkConfig :: Traversal -> Maybe Int -> Maybe FilePath -> Config
71+
mkConfig traversal mMaxResults mPath =
8172
Config
8273
{ cfgTraversal = toTraversalConfig traversal
8374
, cfgRoot = maybe "." id mPath
75+
, cfgMaxResults = mMaxResults
8476
}
8577

8678
traversalParser :: Parser Traversal
@@ -107,6 +99,11 @@ configParser :: Parser Config
10799
configParser =
108100
mkConfig
109101
<$> traversalParser
102+
<*> optional
103+
(option (OA.eitherReader parsePositiveInt)
104+
(long "max-results"
105+
<> metavar "N"
106+
<> help "Stop after emitting N results"))
110107
<*> optional
111108
(strArgument
112109
(metavar "PATH" <> help "Root path to search"))
@@ -120,22 +117,20 @@ parserInfo =
120117
<> progDesc "A basic fd-like driver for Streamly.Coreutils.Find."
121118
<> header "hfd")
122119

123-
#if defined(mingw32_HOST_OS) || defined(__MINGW32__)
124-
#endif
125-
126120
main :: IO ()
127121
main = do
128122
cfg <- execParser parserInfo
129123
path <- Path.fromString (cfgRoot cfg)
130-
#if !defined(mingw32_HOST_OS) && !defined(__MINGW32__)
124+
let applyConfig opts =
125+
maybe id maxResults (cfgMaxResults cfg) $
126+
cfgTraversal cfg opts
131127
Stream.fold (Handle.writeChunks stdout)
132-
$ findByteChunked (cfgTraversal cfg) path
133-
#else
134-
Stream.fold (Handle.writeWith 32000 stdout)
135-
$ Unicode.encodeUtf8
136-
$ Unicode.decodeUtf16le
137-
$ Stream.unfoldEachEndBy 10 Array.reader
138-
$ fmap Path.toArray
139-
$ Stream.unfoldEach Unfold.fromList
140-
$ findChunked (cfgTraversal cfg) path
141-
#endif
128+
$ findByteChunked applyConfig path
129+
130+
parsePositiveInt :: String -> Either String Int
131+
parsePositiveInt str =
132+
case reads str of
133+
[(n, "")]
134+
| n > 0 -> Right n
135+
| otherwise -> Left "N must be positive"
136+
_ -> Left "N must be an integer"

src/Streamly/Coreutils/Find.hs

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,12 @@ module Streamly.Coreutils.Find
8585

8686
-- * Options
8787
, FindOptions
88+
, maxResults
8889
)
8990
where
9091

92+
import Data.Function ((&))
93+
import Data.Functor.Identity (runIdentity)
9194
import Data.Maybe (fromJust)
9295
import Data.Word (Word8)
9396
import Streamly.Data.Array (Array)
@@ -100,12 +103,17 @@ import System.IO (stdout, hSetBuffering, BufferMode(LineBuffering))
100103
import qualified Streamly.Data.Stream.Prelude as Stream
101104
import qualified Streamly.Data.Array as Array
102105
import qualified Streamly.FileSystem.DirIO as DirIO
103-
import qualified Streamly.Internal.Data.Array as GArray (compactMax')
106+
import qualified Streamly.Internal.Data.Array as GArray
107+
( compactMax'
108+
, read
109+
, unsafeSliceOffLen
110+
)
104111
import qualified Streamly.Internal.Data.Stream as Stream
105112
( unfoldEachEndBy
106113
, concatIterate
107114
, bfsConcatIterate
108115
, altBfsConcatIterate
116+
, postscanlMaybe
109117
)
110118
import qualified Streamly.Data.StreamK as StreamK
111119
import qualified Streamly.Internal.Data.StreamK as StreamK
@@ -125,6 +133,11 @@ import qualified Streamly.Internal.FileSystem.Posix.ReadDir as Dir
125133
import qualified Streamly.Unicode.Stream as Stream
126134
#endif
127135

136+
import Streamly.Internal.Data.Scanl (Step(..), Scanl(..))
137+
import qualified Streamly.Internal.Data.Scanl as Scanl
138+
import qualified Streamly.Internal.Data.Fold as Fold
139+
import qualified Streamly.Internal.Data.Array as Array
140+
128141
--
129142
-- Running on a sample directory tree the concurrent rust "fd" tool took 150 ms
130143
-- (real time). On the same tree the fastest variant using Haskell streamly
@@ -147,10 +160,17 @@ data FindTraversal
147160
| FindParallelInterleaved
148161
| FindParallelOrdered
149162

150-
newtype FindOptions = FindOptions {findTraversal :: FindTraversal}
163+
data FindOptions = FindOptions
164+
{ findTraversal :: FindTraversal
165+
, findMaxResults :: Maybe Int
166+
}
151167

152168
defaultConfig :: FindOptions
153-
defaultConfig = FindOptions FindSerialDfs
169+
defaultConfig =
170+
FindOptions
171+
{ findTraversal = FindSerialDfs
172+
, findMaxResults = Nothing
173+
}
154174

155175
serialDfs :: FindOptions -> FindOptions
156176
serialDfs cfg = cfg {findTraversal = FindSerialDfs}
@@ -176,6 +196,9 @@ parallelInterleaved cfg = cfg {findTraversal = FindParallelInterleaved}
176196
parallelOrdered :: FindOptions -> FindOptions
177197
parallelOrdered cfg = cfg {findTraversal = FindParallelOrdered}
178198

199+
maxResults :: Int -> FindOptions -> FindOptions
200+
maxResults n cfg = cfg {findMaxResults = Just (max 0 n)}
201+
179202
{-# INLINE recReadOpts #-}
180203
recReadOpts :: ReadOptions -> ReadOptions
181204
{-# INLINE reEncode #-}
@@ -196,12 +219,70 @@ recReadOpts =
196219
reEncode = id
197220
#endif
198221

222+
data Counts = Counts !Int !Int deriving Show
223+
224+
{-# INLINE countStep #-}
225+
countStep :: Monad m => Counts -> Word8 -> m (Step Counts (Either Int Int))
226+
countStep (Counts l c) ch =
227+
let l1 = if ch == 10 then l - 1 else l
228+
in if l1 == 0
229+
then return $ Done $ Left (c + 1)
230+
else return $ Partial $ Counts l1 (c + 1)
231+
232+
{-# INLINE countExtract #-}
233+
countExtract :: Monad m => Counts -> m (Either a Int)
234+
countExtract (Counts l _) = return $ Right l
235+
236+
{-# INLINE count #-}
237+
count :: Monad m => Int -> Fold.Fold m Word8 (Either Int Int)
238+
count l = Fold.foldtM' countStep (return $ Partial (Counts l 0 )) countExtract
239+
240+
-- XXX Scanl is an awkward abstraction for the case when we are emitting every
241+
-- element and just need to transform the elements using the state. We need a
242+
-- smapM instead for this case. In the scan we are forced to use a Maybe and
243+
-- then catMaybe unnecessarily to store the elements. Because only in the
244+
-- initial state we do not have an element.
245+
--
246+
{-# INLINE scanStep #-}
247+
scanStep :: Monad m =>
248+
(Int, Maybe (Array Word8))
249+
-> Array Word8
250+
-> m (Step (Int, Maybe (Array Word8)) (Maybe (Array Word8)))
251+
scanStep (n, _) arr = do
252+
r <- Array.read arr & Stream.fold (count n)
253+
case r of
254+
Left len -> return $ Done $ Just (Array.unsafeSliceOffLen 0 len arr)
255+
Right cnt ->
256+
if cnt /= 0
257+
then return $ Partial (cnt, Just arr)
258+
else return $ Done (Just arr)
259+
260+
{-# INLINE scanExtract #-}
261+
scanExtract :: Monad m => (Int, Maybe (Array Word8)) -> m (Maybe (Array Word8))
262+
scanExtract (_, arr) = return arr
263+
264+
{-# INLINE scanFinal #-}
265+
scanFinal :: Monad m => (Int, Maybe (Array Word8)) -> m (Maybe (Array Word8))
266+
scanFinal (_, arr) = return arr
267+
268+
{-# INLINE takeN #-}
269+
takeN :: Int -> Stream IO (Array Word8) -> Stream IO (Array Word8)
270+
takeN n
271+
| n <= 0 = const Stream.nil
272+
| otherwise =
273+
Stream.postscanlMaybe
274+
(Scanl
275+
scanStep
276+
(return (Partial (n, Nothing)))
277+
scanExtract
278+
scanFinal)
279+
199280
#if !defined(mingw32_HOST_OS) && !defined(__MINGW32__)
200281
-- Fastest implementation, only works for posix as of now.
201282
findByteChunked :: (FindOptions -> FindOptions) -> Path -> Stream IO (Array Word8)
202283
findByteChunked f path =
203-
Stream.catRights $
204-
case findTraversal (f defaultConfig) of
284+
transform $ Stream.catRights $
285+
case findTraversal opts of
205286
FindSerialDfs ->
206287
Stream.concatIterate streamDirMaybe -- 154 ms
207288
$ Stream.fromPure (Left [path])
@@ -232,6 +313,11 @@ findByteChunked f path =
232313

233314
where
234315

316+
{-# INLINE transform #-}
317+
transform s = maybe s (\n -> takeN n s) (findMaxResults opts)
318+
319+
opts = f defaultConfig
320+
235321
concatIterateWith combine =
236322
StreamK.toStream
237323
. StreamK.concatIterateWith combine (StreamK.fromStream . streamDir)
@@ -253,6 +339,7 @@ findByteChunked f path =
253339
-- Faster than the find implementation below
254340
findChunked :: (FindOptions -> FindOptions) -> Path -> Stream IO [Path]
255341
findChunked f path =
342+
-- XXX implement maxResults
256343
Stream.catRights $
257344
case findTraversal (f defaultConfig) of
258345
FindSerialDfs ->
@@ -300,6 +387,7 @@ findChunked f path =
300387

301388
find :: (FindOptions -> FindOptions) -> Path -> Stream IO Path
302389
find f path =
390+
-- XXX implement maxResults
303391
Stream.catRights $
304392
case findTraversal (f defaultConfig) of
305393
FindSerialDfs ->

0 commit comments

Comments
 (0)