diff --git a/src/Streamly/Internal/Data/Fold.hs b/src/Streamly/Internal/Data/Fold.hs index bee0dd6eb2..8799615532 100644 --- a/src/Streamly/Internal/Data/Fold.hs +++ b/src/Streamly/Internal/Data/Fold.hs @@ -199,6 +199,14 @@ module Streamly.Internal.Data.Fold , foldChunks , duplicate + -- * Exceptions + , handle + , onException + , bracket + , before + , after + , finally + -- * Running Folds , initialize , runStep @@ -209,8 +217,12 @@ module Streamly.Internal.Data.Fold ) where +import Control.Exception (mask_) import Control.Monad (void) +import Control.Monad.Catch (Exception, MonadCatch) import Control.Monad.IO.Class (MonadIO(..)) +import Control.Monad.Trans.Control (liftBaseOp_) +import Data.Functor (($>)) import Data.Functor.Identity (Identity(..)) import Data.Int (Int64) import Data.Map.Strict (Map) @@ -223,9 +235,11 @@ import Prelude scanl, scanl1, replicate, concatMap, mconcat, foldMap, unzip, span, splitAt, break, mapM) +import qualified Control.Monad.Catch as MC import qualified Data.Map.Strict as Map import qualified Prelude +import Streamly.Internal.Data.IORef (newFinalizedIORef, runIORefFinalizer) import Streamly.Internal.Data.Pipe.Types (Pipe (..), PipeState(..)) import Streamly.Internal.Data.Fold.Types import Streamly.Internal.Data.Strict @@ -1605,6 +1619,156 @@ lchunksInRange low high (Fold step1 initial1 extract1) (Fold step2 initial2 extract2) = undefined -} +------------------------------------------------------------------------------ +-- Exceptions +------------------------------------------------------------------------------ + +-- | Exception handling states of a fold +data HandleExc s f1 f2 = InitDone !s | InitFailed !f1 | StepFailed !f2 + +-- | @handle initHandler stepHandler fold@ produces a new fold from a given +-- fold. The new fold executes the original @fold@, if an exception occurs +-- when initializing the fold then @initHandler@ is executed and fold resulting +-- from that starts execution. If an exception occurs while executing the +-- @step@ function of a fold then the @stephandler@ is executed and we start +-- executing the fold resulting from that. +-- +-- The exception is caught and handled, not rethrown. If the exception handler +-- itself throws an exception that exception is thrown. +-- +-- /Internal/ +-- +{-# INLINE handle #-} +handle :: (MonadCatch m, Exception e) + => (e -> m (Fold m a b)) + -> (e -> Fold m a b -> m (Fold m a b)) + -> Fold m a b + -> Fold m a b +handle initH stepH (Fold step1 initial1 extract1) = Fold step initial extract + + where + + initial = fmap InitDone initial1 `MC.catch` (fmap InitFailed . initH) + + step (InitDone s) a = + let f = Fold step1 (return s) extract1 + in fmap InitDone (step1 s a) + `MC.catch` (\e -> fmap StepFailed (stepH e f)) + step (InitFailed (Fold step2 initial2 extract2)) a = do + s <- initial2 + s1 <- step2 s a + return $ InitFailed $ Fold step2 (return s1) extract2 + step (StepFailed (Fold step2 initial2 extract2)) a = do + s <- initial2 + s1 <- step2 s a + return $ StepFailed $ Fold step2 (return s1) extract2 + + extract (InitDone s) = extract1 s + extract (InitFailed (Fold _ initial2 extract2)) = initial2 >>= extract2 + extract (StepFailed (Fold _ initial2 extract2)) = initial2 >>= extract2 + +-- | @onException action fold@ runs @action@ whenever the fold throws an +-- exception. The action is executed on any exception whether it is in +-- initial, step or extract action of the fold. +-- +-- The exception is not caught, simply rethrown. If the @action@ itself +-- throws an exception that exception is thrown instead of the original +-- exception. +-- +-- /Internal/ +-- +{-# INLINE onException #-} +onException :: MonadCatch m => m x -> Fold m a b -> Fold m a b +onException action (Fold step1 initial1 extract1) = Fold step initial extract + + where + + initial = initial1 `MC.onException` action + step s a = step1 s a `MC.onException` action + extract s = extract1 s `MC.onException` action + +-- XXX we cannot use a bracketed fold for scan, because extract would release +-- the resource. This can be fixed when we have terminating folds, we can +-- release the resource on Stop instead of on extract. +-- +-- | @bracket before after between@ runs @before@ and invokes @between@ using +-- its output, then runs the fold generated by @between@. If the fold ends +-- normally, due to an exception or if it is garbage collected prematurely then +-- @after@ is run with the output of @before@ as argument. +-- +-- If @before@ or @after@ throw an exception that exception is thrown. +-- +-- /Internal/ +-- +{-# INLINE bracket #-} +bracket :: (MonadAsync m, MonadCatch m) + => m x -> (x -> m c) -> (x -> m (Fold m a b)) -> Fold m a b +bracket bef aft bet = Fold step initial extract + + where + + initial = do + (r, ref) <- liftBaseOp_ mask_ $ do + r <- bef + ref <- newFinalizedIORef (aft r) + return (r, ref) + fld <- bet r + return $ Tuple' ref fld + + step (Tuple' ref (Fold step1 initial1 extract1)) a = do + s <- initial1 + s1 <- step1 s a `MC.onException` runIORefFinalizer ref + return $ Tuple' ref $ Fold step1 (return s1) extract1 + + extract (Tuple' ref (Fold _ initial1 extract1)) = do + runIORefFinalizer ref + initial1 >>= extract1 + +-- | Run a side effect whenever the fold stops normally, aborts due to an +-- exception or is garbage collected. +-- +-- /Internal/ +-- +{-# INLINE finally #-} +finally :: (MonadAsync m, MonadCatch m) => m b -> Fold m a b -> Fold m a b +finally aft (Fold step1 initial1 extract1) = Fold step initial extract + + where + + initial = do + ref <- newFinalizedIORef aft + Tuple' ref <$> initial1 + + step (Tuple' ref s) a = do + s1 <- step1 s a `MC.onException` runIORefFinalizer ref + return $ Tuple' ref s1 + + extract (Tuple' ref s) = do + runIORefFinalizer ref + extract1 s + +-- | Run a side effect before the fold consumes its first element. +-- +-- /Internal/ +-- +{-# INLINE before #-} +before :: Monad m => m x -> Fold m a b -> Fold m a b +before effect (Fold step1 initial1 extract1) = Fold step1 initial extract1 + + where + + initial = effect *> initial1 + +-- | Run a side effect after the fold stops normally. Please use 'finally' if +-- you need a guarantee that the action runs even if the fold is garbage +-- collected. +-- +-- /Internal/ +-- +{-# INLINE after #-} +after :: Monad m => m x -> Fold m a b -> Fold m a b +after effect = mapM (effect $>) + ------------------------------------------------------------------------------ -- Fold to a Parallel SVar ------------------------------------------------------------------------------ diff --git a/src/Streamly/Internal/Data/IORef.hs b/src/Streamly/Internal/Data/IORef.hs new file mode 100644 index 0000000000..7b6aa887f7 --- /dev/null +++ b/src/Streamly/Internal/Data/IORef.hs @@ -0,0 +1,58 @@ +{-# LANGUAGE FlexibleContexts #-} + +-- | +-- Module : Streamly.Internal.Data.IORef +-- Copyright : (c) 2019 Composewell Technologies +-- License : BSD3 +-- Maintainer : streamly@composewell.com +-- Stability : experimental +-- Portability : GHC +-- +-- +module Streamly.Internal.Data.IORef + ( + newFinalizedIORef + , runIORefFinalizer + , clearIORefFinalizer + ) +where + +import Control.Monad (void) +import Control.Monad.IO.Class (MonadIO(..)) +import Control.Monad.Trans.Control (MonadBaseControl) +import Data.IORef (newIORef, readIORef, mkWeakIORef, writeIORef, IORef) + +import Streamly.Internal.Data.SVar + +-- | Create an IORef holding a finalizer that is called automatically when the +-- IORef is garbage collected. The IORef can be written to with a 'Nothing' +-- value to deactivate the finalizer. +newFinalizedIORef :: (MonadIO m, MonadBaseControl IO m) + => m a -> m (IORef (Maybe (IO ()))) +newFinalizedIORef finalizer = do + mrun <- captureMonadState + ref <- liftIO $ newIORef $ Just $ liftIO $ void $ do + _ <- runInIO mrun finalizer + return () + let finalizer1 = do + res <- readIORef ref + case res of + Nothing -> return () + Just f -> f + _ <- liftIO $ mkWeakIORef ref finalizer1 + return ref + +-- | Run the finalizer stored in an IORef and deactivate it so that it is run +-- only once. +-- +runIORefFinalizer :: MonadIO m => IORef (Maybe (IO ())) -> m () +runIORefFinalizer ref = liftIO $ do + res <- readIORef ref + case res of + Nothing -> return () + Just f -> writeIORef ref Nothing >> f + +-- | Deactivate the finalizer stored in an IORef without running it. +-- +clearIORefFinalizer :: MonadIO m => IORef (Maybe (IO ())) -> m () +clearIORefFinalizer ref = liftIO $ writeIORef ref Nothing diff --git a/src/Streamly/Internal/Data/Stream/StreamD.hs b/src/Streamly/Internal/Data/Stream/StreamD.hs index 3778b5d64e..9d78f02ea7 100644 --- a/src/Streamly/Internal/Data/Stream/StreamD.hs +++ b/src/Streamly/Internal/Data/Stream/StreamD.hs @@ -299,9 +299,6 @@ module Streamly.Internal.Data.Stream.StreamD , the -- * Exceptions - , newFinalizedIORef - , runIORefFinalizer - , clearIORefFinalizer , gbracket , before , after @@ -335,7 +332,7 @@ import Control.Monad.Trans.Control (MonadBaseControl, liftBaseOp_) import Data.Bits (shiftR, shiftL, (.|.), (.&.)) import Data.Functor.Identity (Identity(..)) import Data.Int (Int64) -import Data.IORef (newIORef, readIORef, mkWeakIORef, writeIORef, IORef) +import Data.IORef (newIORef, readIORef, mkWeakIORef, writeIORef) import Data.Maybe (fromJust, isJust, isNothing) import Data.Word (Word32) import Foreign.Ptr (Ptr) @@ -371,6 +368,7 @@ import Streamly.Internal.Data.Time.Units import Streamly.Internal.Data.Unfold.Types (Unfold(..)) import Streamly.Internal.Data.Strict (Tuple3'(..)) +import Streamly.Internal.Data.IORef import Streamly.Internal.Data.Stream.StreamD.Type import Streamly.Internal.Data.SVar import Streamly.Internal.Data.Stream.SVar (fromConsumer, pushToFold) @@ -3202,39 +3200,6 @@ gbracket bef exc aft fexc fnormal = Skip s -> return $ Skip (GBracketException (Stream step1 s)) Stop -> return Stop --- | Create an IORef holding a finalizer that is called automatically when the --- IORef is garbage collected. The IORef can be written to with a 'Nothing' --- value to deactivate the finalizer. -newFinalizedIORef :: (MonadIO m, MonadBaseControl IO m) - => m a -> m (IORef (Maybe (IO ()))) -newFinalizedIORef finalizer = do - mrun <- captureMonadState - ref <- liftIO $ newIORef $ Just $ liftIO $ void $ do - _ <- runInIO mrun finalizer - return () - let finalizer1 = do - res <- readIORef ref - case res of - Nothing -> return () - Just f -> f - _ <- liftIO $ mkWeakIORef ref finalizer1 - return ref - --- | Run the finalizer stored in an IORef and deactivate it so that it is run --- only once. --- -runIORefFinalizer :: MonadIO m => IORef (Maybe (IO ())) -> m () -runIORefFinalizer ref = liftIO $ do - res <- readIORef ref - case res of - Nothing -> return () - Just f -> writeIORef ref Nothing >> f - --- | Deactivate the finalizer stored in an IORef without running it. --- -clearIORefFinalizer :: MonadIO m => IORef (Maybe (IO ())) -> m () -clearIORefFinalizer ref = liftIO $ writeIORef ref Nothing - data GbracketIOState s1 s2 v wref = GBracketIOInit | GBracketIONormal s1 v wref @@ -3301,6 +3266,8 @@ gbracketIO bef exc aft fexc fnormal = Skip s -> return $ Skip (GBracketIOException (Stream step1 s)) Stop -> return Stop +-- Same as nilM action <> stream +-- -- | Run a side effect before the stream yields its first element. {-# INLINE_NORMAL before #-} before :: Monad m => m b -> Stream m a -> Stream m a @@ -3318,6 +3285,8 @@ before action (Stream step state) = Stream step' Nothing Skip s -> return $ Skip (Just s) Stop -> return Stop +-- Same as stream <> nilM action +-- -- | Run a side effect whenever the stream stops normally. {-# INLINE_NORMAL after #-} after :: Monad m => m b -> Stream m a -> Stream m a diff --git a/src/Streamly/Internal/Data/Unfold.hs b/src/Streamly/Internal/Data/Unfold.hs index 24c1b8d0d1..11cc1d626c 100644 --- a/src/Streamly/Internal/Data/Unfold.hs +++ b/src/Streamly/Internal/Data/Unfold.hs @@ -132,6 +132,7 @@ module Streamly.Internal.Data.Unfold where import Control.Exception (Exception, mask_) +import Control.Monad.Catch (MonadCatch) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Trans.Control (MonadBaseControl, liftBaseOp_) import Data.Void (Void) @@ -146,8 +147,9 @@ import Streamly.Internal.Data.Stream.StreamD.Type (pattern Stream) #endif import Streamly.Internal.Data.Unfold.Types (Unfold(..)) import Streamly.Internal.Data.Fold.Types (Fold(..)) +import Streamly.Internal.Data.IORef + (newFinalizedIORef, runIORefFinalizer, clearIORefFinalizer) import Streamly.Internal.Data.SVar (defState, MonadAsync) -import Control.Monad.Catch (MonadCatch) import qualified Prelude import qualified Control.Monad.Catch as MC @@ -794,7 +796,7 @@ gbracketIO bef exc aft (Unfold estep einject) (Unfold step1 inject1) = -- the registration of 'aft' atomic. See comment in 'D.gbracketIO'. (r, ref) <- liftBaseOp_ mask_ $ do r <- bef x - ref <- D.newFinalizedIORef (aft r) + ref <- newFinalizedIORef (aft r) return (r, ref) s <- inject1 r return $ Right (s, r, ref) @@ -807,10 +809,10 @@ gbracketIO bef exc aft (Unfold estep einject) (Unfold step1 inject1) = Yield x s -> return $ Yield x (Right (s, v, ref)) Skip s -> return $ Skip (Right (s, v, ref)) Stop -> do - D.runIORefFinalizer ref + runIORefFinalizer ref return Stop Left e -> do - D.clearIORefFinalizer ref + clearIORefFinalizer ref r <- einject (v, e) return $ Skip (Left r) step (Left st) = do @@ -893,7 +895,7 @@ afterIO action (Unfold step1 inject1) = Unfold step inject inject x = do s <- inject1 x - ref <- D.newFinalizedIORef (action x) + ref <- newFinalizedIORef (action x) return (s, ref) {-# INLINE_LATE step #-} @@ -903,7 +905,7 @@ afterIO action (Unfold step1 inject1) = Unfold step inject Yield x s -> return $ Yield x (s, ref) Skip s -> return $ Skip (s, ref) Stop -> do - D.runIORefFinalizer ref + runIORefFinalizer ref return Stop {-# INLINE_NORMAL _onException #-} @@ -979,17 +981,17 @@ finallyIO action (Unfold step1 inject1) = Unfold step inject inject x = do s <- inject1 x - ref <- D.newFinalizedIORef (action x) + ref <- newFinalizedIORef (action x) return (s, ref) {-# INLINE_LATE step #-} step (st, ref) = do - res <- step1 st `MC.onException` D.runIORefFinalizer ref + res <- step1 st `MC.onException` runIORefFinalizer ref case res of Yield x s -> return $ Yield x (s, ref) Skip s -> return $ Skip (s, ref) Stop -> do - D.runIORefFinalizer ref + runIORefFinalizer ref return Stop {-# INLINE_NORMAL _bracket #-} @@ -1048,19 +1050,19 @@ bracketIO bef aft (Unfold step1 inject1) = Unfold step inject -- the registration of 'aft' atomic. See comment in 'D.gbracketIO'. (r, ref) <- liftBaseOp_ mask_ $ do r <- bef x - ref <- D.newFinalizedIORef (aft r) + ref <- newFinalizedIORef (aft r) return (r, ref) s <- inject1 r return (s, ref) {-# INLINE_LATE step #-} step (st, ref) = do - res <- step1 st `MC.onException` D.runIORefFinalizer ref + res <- step1 st `MC.onException` runIORefFinalizer ref case res of Yield x s -> return $ Yield x (s, ref) Skip s -> return $ Skip (s, ref) Stop -> do - D.runIORefFinalizer ref + runIORefFinalizer ref return Stop -- | When unfolding if an exception occurs, unfold the exception using the diff --git a/streamly.cabal b/streamly.cabal index 845db4c334..36cbc2b740 100644 --- a/streamly.cabal +++ b/streamly.cabal @@ -310,6 +310,7 @@ library -- Internal modules , Streamly.Internal.BaseCompat , Streamly.Internal.Control.Monad + , Streamly.Internal.Data.IORef , Streamly.Internal.Data.Strict , Streamly.Internal.Data.Atomics , Streamly.Internal.Data.Time