@@ -77,6 +77,7 @@ module Data.Vector.Generic.Mutable (
7777 PrimMonad , PrimState , RealWorld
7878) where
7979
80+ import Control.Monad.ST
8081import Data.Vector.Generic.Mutable.Base
8182import qualified Data.Vector.Generic.Base as V
8283
@@ -89,7 +90,7 @@ import Data.Vector.Fusion.Bundle.Size
8990import Data.Vector.Fusion.Util ( delay_inline )
9091import Data.Vector.Internal.Check
9192
92- import Control.Monad.Primitive ( PrimMonad (.. ), RealWorld , stToPrim )
93+ import Control.Monad.Primitive ( PrimMonad (.. ), stToPrim )
9394
9495import Prelude
9596 ( Ord , Monad , Bool (.. ), Int , Maybe (.. ), Either (.. ), Ordering (.. )
@@ -104,8 +105,7 @@ import Data.Bits ( Bits(shiftR) )
104105-- Internal functions
105106-- ------------------
106107
107- unsafeAppend1 :: (PrimMonad m , MVector v a )
108- => v (PrimState m ) a -> Int -> a -> m (v (PrimState m ) a )
108+ unsafeAppend1 :: (MVector v a ) => v s a -> Int -> a -> ST s (v s a )
109109{-# INLINE_INNER unsafeAppend1 #-}
110110 -- NOTE: The case distinction has to be on the outside because
111111 -- GHC creates a join point for the unsafeWrite even when everything
@@ -120,8 +120,7 @@ unsafeAppend1 v i x
120120 checkIndex Internal i (length v') $ unsafeWrite v' i x
121121 return v'
122122
123- unsafePrepend1 :: (PrimMonad m , MVector v a )
124- => v (PrimState m ) a -> Int -> a -> m (v (PrimState m ) a , Int )
123+ unsafePrepend1 :: (MVector v a ) => v s a -> Int -> a -> ST s (v s a , Int )
125124{-# INLINE_INNER unsafePrepend1 #-}
126125unsafePrepend1 v i x
127126 | i /= 0 = do
@@ -205,7 +204,7 @@ unstream :: (PrimMonad m, MVector v a)
205204 => Bundle u a -> m (v (PrimState m ) a )
206205-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
207206{-# INLINE_FUSED unstream #-}
208- unstream s = munstream (Bundle. lift s)
207+ unstream s = stToPrim $ munstream (Bundle. lift s)
209208
210209-- | Create a new mutable vector and fill it with elements from the monadic
211210-- stream. The vector will grow exponentially if the maximum size of the stream
@@ -241,9 +240,8 @@ munstreamUnknown s
241240 $ unsafeSlice 0 n v'
242241 where
243242 {-# INLINE_INNER put #-}
244- put (v,i) x = do
245- v' <- unsafeAppend1 v i x
246- return (v',i+ 1 )
243+ put (v,i) x = stToPrim $ do v' <- unsafeAppend1 v i x
244+ return (v',i+ 1 )
247245
248246
249247-- | Create a new mutable vector and fill it with elements from the 'Bundle'.
@@ -253,7 +251,7 @@ vunstream :: (PrimMonad m, V.Vector v a)
253251 => Bundle v a -> m (V. Mutable v (PrimState m ) a )
254252-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
255253{-# INLINE_FUSED vunstream #-}
256- vunstream s = vmunstream (Bundle. lift s)
254+ vunstream s = stToPrim $ vmunstream (Bundle. lift s)
257255
258256-- | Create a new mutable vector and fill it with elements from the monadic
259257-- stream. The vector will grow exponentially if the maximum size of the stream
@@ -309,7 +307,7 @@ unstreamR :: (PrimMonad m, MVector v a)
309307 => Bundle u a -> m (v (PrimState m ) a )
310308-- NOTE: replace INLINE_FUSED by INLINE? (also in unstream)
311309{-# INLINE_FUSED unstreamR #-}
312- unstreamR s = munstreamR (Bundle. lift s)
310+ unstreamR s = stToPrim $ munstreamR (Bundle. lift s)
313311
314312-- | Create a new mutable vector and fill it with elements from the monadic
315313-- stream from right to left. The vector will grow exponentially if the maximum
@@ -348,7 +346,7 @@ munstreamRUnknown s
348346 $ unsafeSlice i (n- i) v'
349347 where
350348 {-# INLINE_INNER put #-}
351- put (v,i) x = unsafePrepend1 v i x
349+ put (v,i) x = stToPrim $ unsafePrepend1 v i x
352350
353351-- Length
354352-- ------
@@ -561,10 +559,9 @@ enlarge_delta :: MVector v a => v s a -> Int
561559enlarge_delta v = max (length v) 1
562560
563561-- | Grow a vector logarithmically.
564- enlarge :: (PrimMonad m , MVector v a )
565- => v (PrimState m ) a -> m (v (PrimState m ) a )
562+ enlarge :: (MVector v a ) => v s a -> ST s (v s a )
566563{-# INLINE enlarge #-}
567- enlarge v = stToPrim $ do
564+ enlarge v = do
568565 vnew <- unsafeGrow v by
569566 basicInitialize $ basicUnsafeSlice (length v) by vnew
570567 return vnew
@@ -994,10 +991,10 @@ unsafeMove dst src = check Unsafe "length mismatch" (length dst == length src)
994991accum :: forall m v a b u . (HasCallStack , PrimMonad m , MVector v a )
995992 => (a -> b -> a ) -> v (PrimState m ) a -> Bundle u (Int , b ) -> m ()
996993{-# INLINE accum #-}
997- accum f ! v s = Bundle. mapM_ upd s
994+ accum f ! v s = stToPrim $ Bundle. mapM_ upd s
998995 where
999996 {-# INLINE_INNER upd #-}
1000- upd :: HasCallStack => (Int , b ) -> m ()
997+ upd :: HasCallStack => (Int , b ) -> ST ( PrimState m ) ()
1001998 upd (i,b) = do
1002999 a <- checkIndex Bounds i n $ unsafeRead v i
10031000 unsafeWrite v i (f a b)
@@ -1006,18 +1003,18 @@ accum f !v s = Bundle.mapM_ upd s
10061003update :: forall m v a u . (HasCallStack , PrimMonad m , MVector v a )
10071004 => v (PrimState m ) a -> Bundle u (Int , a ) -> m ()
10081005{-# INLINE update #-}
1009- update ! v s = Bundle. mapM_ upd s
1006+ update ! v s = stToPrim $ Bundle. mapM_ upd s
10101007 where
10111008 {-# INLINE_INNER upd #-}
1012- upd :: HasCallStack => (Int , a ) -> m ()
1009+ upd :: HasCallStack => (Int , a ) -> ST ( PrimState m ) ()
10131010 upd (i,b) = checkIndex Bounds i n $ unsafeWrite v i b
10141011
10151012 ! n = length v
10161013
10171014unsafeAccum :: (PrimMonad m , MVector v a )
10181015 => (a -> b -> a ) -> v (PrimState m ) a -> Bundle u (Int , b ) -> m ()
10191016{-# INLINE unsafeAccum #-}
1020- unsafeAccum f ! v s = Bundle. mapM_ upd s
1017+ unsafeAccum f ! v s = stToPrim $ Bundle. mapM_ upd s
10211018 where
10221019 {-# INLINE_INNER upd #-}
10231020 upd (i,b) = do
@@ -1026,17 +1023,17 @@ unsafeAccum f !v s = Bundle.mapM_ upd s
10261023 ! n = length v
10271024
10281025unsafeUpdate :: (PrimMonad m , MVector v a )
1029- => v (PrimState m ) a -> Bundle u (Int , a ) -> m ()
1026+ => v (PrimState m ) a -> Bundle u (Int , a ) -> m ()
10301027{-# INLINE unsafeUpdate #-}
1031- unsafeUpdate ! v s = Bundle. mapM_ upd s
1028+ unsafeUpdate ! v s = stToPrim $ Bundle. mapM_ upd s
10321029 where
10331030 {-# INLINE_INNER upd #-}
10341031 upd (i,b) = checkIndex Unsafe i n $ unsafeWrite v i b
10351032 ! n = length v
10361033
10371034reverse :: (PrimMonad m , MVector v a ) => v (PrimState m ) a -> m ()
10381035{-# INLINE reverse #-}
1039- reverse ! v = reverse_loop 0 (length v - 1 )
1036+ reverse ! v = stToPrim $ reverse_loop 0 (length v - 1 )
10401037 where
10411038 reverse_loop i j | i < j = do
10421039 unsafeSwap v i j
@@ -1046,11 +1043,11 @@ reverse !v = reverse_loop 0 (length v - 1)
10461043unstablePartition :: forall m v a . (PrimMonad m , MVector v a )
10471044 => (a -> Bool ) -> v (PrimState m ) a -> m Int
10481045{-# INLINE unstablePartition #-}
1049- unstablePartition f ! v = from_left 0 (length v)
1046+ unstablePartition f ! v = stToPrim $ from_left 0 (length v)
10501047 where
10511048 -- NOTE: GHC 6.10.4 panics without the signatures on from_left and
10521049 -- from_right
1053- from_left :: Int -> Int -> m Int
1050+ from_left :: Int -> Int -> ST ( PrimState m ) Int
10541051 from_left i j
10551052 | i == j = return i
10561053 | otherwise = do
@@ -1059,7 +1056,7 @@ unstablePartition f !v = from_left 0 (length v)
10591056 then from_left (i+ 1 ) j
10601057 else from_right i (j- 1 )
10611058
1062- from_right :: Int -> Int -> m Int
1059+ from_right :: Int -> Int -> ST ( PrimState m ) Int
10631060 from_right i j
10641061 | i == j = return i
10651062 | otherwise = do
@@ -1076,7 +1073,8 @@ unstablePartitionBundle :: (PrimMonad m, MVector v a)
10761073 => (a -> Bool ) -> Bundle u a -> m (v (PrimState m ) a , v (PrimState m ) a )
10771074{-# INLINE unstablePartitionBundle #-}
10781075unstablePartitionBundle f s
1079- = case upperBound (Bundle. size s) of
1076+ = stToPrim
1077+ $ case upperBound (Bundle. size s) of
10801078 Just n -> unstablePartitionMax f s n
10811079 Nothing -> partitionUnknown f s
10821080
@@ -1085,7 +1083,7 @@ unstablePartitionMax :: (PrimMonad m, MVector v a)
10851083 -> m (v (PrimState m ) a , v (PrimState m ) a )
10861084{-# INLINE unstablePartitionMax #-}
10871085unstablePartitionMax f s n
1088- = do
1086+ = stToPrim $ do
10891087 v <- checkLength Internal n $ unsafeNew n
10901088 let {-# INLINE_INNER put #-}
10911089 put (i, j) x
@@ -1103,15 +1101,15 @@ partitionBundle :: (PrimMonad m, MVector v a)
11031101 => (a -> Bool ) -> Bundle u a -> m (v (PrimState m ) a , v (PrimState m ) a )
11041102{-# INLINE partitionBundle #-}
11051103partitionBundle f s
1106- = case upperBound (Bundle. size s) of
1104+ = stToPrim
1105+ $ case upperBound (Bundle. size s) of
11071106 Just n -> partitionMax f s n
11081107 Nothing -> partitionUnknown f s
11091108
11101109partitionMax :: (PrimMonad m , MVector v a )
11111110 => (a -> Bool ) -> Bundle u a -> Int -> m (v (PrimState m ) a , v (PrimState m ) a )
11121111{-# INLINE partitionMax #-}
1113- partitionMax f s n
1114- = do
1112+ partitionMax f s n = stToPrim $ do
11151113 v <- checkLength Internal n $ unsafeNew n
11161114
11171115 let {-# INLINE_INNER put #-}
@@ -1136,8 +1134,7 @@ partitionMax f s n
11361134partitionUnknown :: (PrimMonad m , MVector v a )
11371135 => (a -> Bool ) -> Bundle u a -> m (v (PrimState m ) a , v (PrimState m ) a )
11381136{-# INLINE partitionUnknown #-}
1139- partitionUnknown f s
1140- = do
1137+ partitionUnknown f s = stToPrim $ do
11411138 v1 <- unsafeNew 0
11421139 v2 <- unsafeNew 0
11431140 (v1', n1, v2', n2) <- Bundle. foldM' put (v1, 0 , v2, 0 ) s
@@ -1163,15 +1160,16 @@ partitionWithBundle :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
11631160 => (a -> Either b c ) -> Bundle u a -> m (v (PrimState m ) b , v (PrimState m ) c )
11641161{-# INLINE partitionWithBundle #-}
11651162partitionWithBundle f s
1166- = case upperBound (Bundle. size s) of
1163+ = stToPrim
1164+ $ case upperBound (Bundle. size s) of
11671165 Just n -> partitionWithMax f s n
11681166 Nothing -> partitionWithUnknown f s
11691167
11701168partitionWithMax :: (PrimMonad m , MVector v a , MVector v b , MVector v c )
11711169 => (a -> Either b c ) -> Bundle u a -> Int -> m (v (PrimState m ) b , v (PrimState m ) c )
11721170{-# INLINE partitionWithMax #-}
11731171partitionWithMax f s n
1174- = do
1172+ = stToPrim $ do
11751173 v1 <- unsafeNew n
11761174 v2 <- unsafeNew n
11771175 let {-# INLINE_INNER put #-}
@@ -1192,7 +1190,7 @@ partitionWithUnknown :: forall m v u a b c.
11921190 => (a -> Either b c ) -> Bundle u a -> m (v (PrimState m ) b , v (PrimState m ) c )
11931191{-# INLINE partitionWithUnknown #-}
11941192partitionWithUnknown f s
1195- = do
1193+ = stToPrim $ do
11961194 v1 <- unsafeNew 0
11971195 v2 <- unsafeNew 0
11981196 (v1', n1, v2', n2) <- Bundle. foldM' put (v1, 0 , v2, 0 ) s
@@ -1202,14 +1200,14 @@ partitionWithUnknown f s
12021200 where
12031201 put :: (v (PrimState m ) b , Int , v (PrimState m ) c , Int )
12041202 -> a
1205- -> m (v (PrimState m ) b , Int , v (PrimState m ) c , Int )
1203+ -> ST ( PrimState m ) (v (PrimState m ) b , Int , v (PrimState m ) c , Int )
12061204 {-# INLINE_INNER put #-}
12071205 put (v1, i1, v2, i2) x = case f x of
12081206 Left b -> do
1209- v1' <- unsafeAppend1 v1 i1 b
1207+ v1' <- stToPrim $ unsafeAppend1 v1 i1 b
12101208 return (v1', i1+ 1 , v2, i2)
12111209 Right c -> do
1212- v2' <- unsafeAppend1 v2 i2 c
1210+ v2' <- stToPrim $ unsafeAppend1 v2 i2 c
12131211 return (v1, i1, v2', i2+ 1 )
12141212
12151213-- Modifying vectors
0 commit comments