Skip to content

Commit 2d26cdd

Browse files
committed
Execute functions in ST where possible
Idea is to save work for GHC. Instead of having >>= as type class method for some polymorphic `m` we work in ST which GHC knows how to compile and optimize very well. No changes to exposed API Reduces allocation during compilation of Data.Vector.Generic.Mutable by ~20%. No measurable changes in other modules
1 parent 7ad0e10 commit 2d26cdd

1 file changed

Lines changed: 37 additions & 39 deletions

File tree

vector/src/Data/Vector/Generic/Mutable.hs

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ module Data.Vector.Generic.Mutable (
7777
PrimMonad, PrimState, RealWorld
7878
) where
7979

80+
import Control.Monad.ST
8081
import Data.Vector.Generic.Mutable.Base
8182
import qualified Data.Vector.Generic.Base as V
8283

@@ -89,7 +90,7 @@ import Data.Vector.Fusion.Bundle.Size
8990
import Data.Vector.Fusion.Util ( delay_inline )
9091
import Data.Vector.Internal.Check
9192

92-
import Control.Monad.Primitive ( PrimMonad(..), RealWorld, stToPrim )
93+
import Control.Monad.Primitive ( PrimMonad(..), stToPrim )
9394

9495
import Prelude
9596
( Ord, Monad, Bool(..), Int, Maybe(..), Either(..), Ordering(..)
@@ -104,8 +105,7 @@ import Data.Bits ( Bits(shiftR) )
104105
-- Internal functions
105106
-- ------------------
106107

107-
unsafeAppend1 :: (PrimMonad m, MVector v a)
108-
=> v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
108+
unsafeAppend1 :: (MVector v a) => v s a -> Int -> a -> ST s (v s a)
109109
{-# INLINE_INNER unsafeAppend1 #-}
110110
-- NOTE: The case distinction has to be on the outside because
111111
-- GHC creates a join point for the unsafeWrite even when everything
@@ -120,8 +120,7 @@ unsafeAppend1 v i x
120120
checkIndex Internal i (length v') $ unsafeWrite v' i x
121121
return v'
122122

123-
unsafePrepend1 :: (PrimMonad m, MVector v a)
124-
=> v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
123+
unsafePrepend1 :: (MVector v a) => v s a -> Int -> a -> ST s (v s a, Int)
125124
{-# INLINE_INNER unsafePrepend1 #-}
126125
unsafePrepend1 v i x
127126
| i /= 0 = do
@@ -205,7 +204,7 @@ unstream :: (PrimMonad m, MVector v a)
205204
=> Bundle u a -> m (v (PrimState m) a)
206205
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
207206
{-# INLINE_FUSED unstream #-}
208-
unstream s = munstream (Bundle.lift s)
207+
unstream s = stToPrim $ munstream (Bundle.lift s)
209208

210209
-- | Create a new mutable vector and fill it with elements from the monadic
211210
-- stream. The vector will grow exponentially if the maximum size of the stream
@@ -241,9 +240,8 @@ munstreamUnknown s
241240
$ unsafeSlice 0 n v'
242241
where
243242
{-# INLINE_INNER put #-}
244-
put (v,i) x = do
245-
v' <- unsafeAppend1 v i x
246-
return (v',i+1)
243+
put (v,i) x = stToPrim $ do v' <- unsafeAppend1 v i x
244+
return (v',i+1)
247245

248246

249247
-- | Create a new mutable vector and fill it with elements from the 'Bundle'.
@@ -253,7 +251,7 @@ vunstream :: (PrimMonad m, V.Vector v a)
253251
=> Bundle v a -> m (V.Mutable v (PrimState m) a)
254252
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
255253
{-# INLINE_FUSED vunstream #-}
256-
vunstream s = vmunstream (Bundle.lift s)
254+
vunstream s = stToPrim $ vmunstream (Bundle.lift s)
257255

258256
-- | Create a new mutable vector and fill it with elements from the monadic
259257
-- stream. The vector will grow exponentially if the maximum size of the stream
@@ -309,7 +307,7 @@ unstreamR :: (PrimMonad m, MVector v a)
309307
=> Bundle u a -> m (v (PrimState m) a)
310308
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstream)
311309
{-# INLINE_FUSED unstreamR #-}
312-
unstreamR s = munstreamR (Bundle.lift s)
310+
unstreamR s = stToPrim $ munstreamR (Bundle.lift s)
313311

314312
-- | Create a new mutable vector and fill it with elements from the monadic
315313
-- stream from right to left. The vector will grow exponentially if the maximum
@@ -348,7 +346,7 @@ munstreamRUnknown s
348346
$ unsafeSlice i (n-i) v'
349347
where
350348
{-# INLINE_INNER put #-}
351-
put (v,i) x = unsafePrepend1 v i x
349+
put (v,i) x = stToPrim $ unsafePrepend1 v i x
352350

353351
-- Length
354352
-- ------
@@ -561,10 +559,9 @@ enlarge_delta :: MVector v a => v s a -> Int
561559
enlarge_delta v = max (length v) 1
562560

563561
-- | Grow a vector logarithmically.
564-
enlarge :: (PrimMonad m, MVector v a)
565-
=> v (PrimState m) a -> m (v (PrimState m) a)
562+
enlarge :: (MVector v a) => v s a -> ST s (v s a)
566563
{-# INLINE enlarge #-}
567-
enlarge v = stToPrim $ do
564+
enlarge v = do
568565
vnew <- unsafeGrow v by
569566
basicInitialize $ basicUnsafeSlice (length v) by vnew
570567
return vnew
@@ -994,10 +991,10 @@ unsafeMove dst src = check Unsafe "length mismatch" (length dst == length src)
994991
accum :: forall m v a b u. (HasCallStack, PrimMonad m, MVector v a)
995992
=> (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
996993
{-# INLINE accum #-}
997-
accum f !v s = Bundle.mapM_ upd s
994+
accum f !v s = stToPrim $ Bundle.mapM_ upd s
998995
where
999996
{-# INLINE_INNER upd #-}
1000-
upd :: HasCallStack => (Int, b) -> m ()
997+
upd :: HasCallStack => (Int, b) -> ST (PrimState m) ()
1001998
upd (i,b) = do
1002999
a <- checkIndex Bounds i n $ unsafeRead v i
10031000
unsafeWrite v i (f a b)
@@ -1006,18 +1003,18 @@ accum f !v s = Bundle.mapM_ upd s
10061003
update :: forall m v a u. (HasCallStack, PrimMonad m, MVector v a)
10071004
=> v (PrimState m) a -> Bundle u (Int, a) -> m ()
10081005
{-# INLINE update #-}
1009-
update !v s = Bundle.mapM_ upd s
1006+
update !v s = stToPrim $ Bundle.mapM_ upd s
10101007
where
10111008
{-# INLINE_INNER upd #-}
1012-
upd :: HasCallStack => (Int, a) -> m ()
1009+
upd :: HasCallStack => (Int, a) -> ST (PrimState m) ()
10131010
upd (i,b) = checkIndex Bounds i n $ unsafeWrite v i b
10141011

10151012
!n = length v
10161013

10171014
unsafeAccum :: (PrimMonad m, MVector v a)
10181015
=> (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
10191016
{-# INLINE unsafeAccum #-}
1020-
unsafeAccum f !v s = Bundle.mapM_ upd s
1017+
unsafeAccum f !v s = stToPrim $ Bundle.mapM_ upd s
10211018
where
10221019
{-# INLINE_INNER upd #-}
10231020
upd (i,b) = do
@@ -1026,17 +1023,17 @@ unsafeAccum f !v s = Bundle.mapM_ upd s
10261023
!n = length v
10271024

10281025
unsafeUpdate :: (PrimMonad m, MVector v a)
1029-
=> v (PrimState m) a -> Bundle u (Int, a) -> m ()
1026+
=> v (PrimState m) a -> Bundle u (Int, a) -> m ()
10301027
{-# INLINE unsafeUpdate #-}
1031-
unsafeUpdate !v s = Bundle.mapM_ upd s
1028+
unsafeUpdate !v s = stToPrim $ Bundle.mapM_ upd s
10321029
where
10331030
{-# INLINE_INNER upd #-}
10341031
upd (i,b) = checkIndex Unsafe i n $ unsafeWrite v i b
10351032
!n = length v
10361033

10371034
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
10381035
{-# INLINE reverse #-}
1039-
reverse !v = reverse_loop 0 (length v - 1)
1036+
reverse !v = stToPrim $ reverse_loop 0 (length v - 1)
10401037
where
10411038
reverse_loop i j | i < j = do
10421039
unsafeSwap v i j
@@ -1046,11 +1043,11 @@ reverse !v = reverse_loop 0 (length v - 1)
10461043
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
10471044
=> (a -> Bool) -> v (PrimState m) a -> m Int
10481045
{-# INLINE unstablePartition #-}
1049-
unstablePartition f !v = from_left 0 (length v)
1046+
unstablePartition f !v = stToPrim $ from_left 0 (length v)
10501047
where
10511048
-- NOTE: GHC 6.10.4 panics without the signatures on from_left and
10521049
-- from_right
1053-
from_left :: Int -> Int -> m Int
1050+
from_left :: Int -> Int -> ST (PrimState m) Int
10541051
from_left i j
10551052
| i == j = return i
10561053
| otherwise = do
@@ -1059,7 +1056,7 @@ unstablePartition f !v = from_left 0 (length v)
10591056
then from_left (i+1) j
10601057
else from_right i (j-1)
10611058

1062-
from_right :: Int -> Int -> m Int
1059+
from_right :: Int -> Int -> ST (PrimState m) Int
10631060
from_right i j
10641061
| i == j = return i
10651062
| otherwise = do
@@ -1076,7 +1073,8 @@ unstablePartitionBundle :: (PrimMonad m, MVector v a)
10761073
=> (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
10771074
{-# INLINE unstablePartitionBundle #-}
10781075
unstablePartitionBundle f s
1079-
= case upperBound (Bundle.size s) of
1076+
= stToPrim
1077+
$ case upperBound (Bundle.size s) of
10801078
Just n -> unstablePartitionMax f s n
10811079
Nothing -> partitionUnknown f s
10821080

@@ -1085,7 +1083,7 @@ unstablePartitionMax :: (PrimMonad m, MVector v a)
10851083
-> m (v (PrimState m) a, v (PrimState m) a)
10861084
{-# INLINE unstablePartitionMax #-}
10871085
unstablePartitionMax f s n
1088-
= do
1086+
= stToPrim $ do
10891087
v <- checkLength Internal n $ unsafeNew n
10901088
let {-# INLINE_INNER put #-}
10911089
put (i, j) x
@@ -1103,15 +1101,15 @@ partitionBundle :: (PrimMonad m, MVector v a)
11031101
=> (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
11041102
{-# INLINE partitionBundle #-}
11051103
partitionBundle f s
1106-
= case upperBound (Bundle.size s) of
1104+
= stToPrim
1105+
$ case upperBound (Bundle.size s) of
11071106
Just n -> partitionMax f s n
11081107
Nothing -> partitionUnknown f s
11091108

11101109
partitionMax :: (PrimMonad m, MVector v a)
11111110
=> (a -> Bool) -> Bundle u a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
11121111
{-# INLINE partitionMax #-}
1113-
partitionMax f s n
1114-
= do
1112+
partitionMax f s n = stToPrim $ do
11151113
v <- checkLength Internal n $ unsafeNew n
11161114

11171115
let {-# INLINE_INNER put #-}
@@ -1136,8 +1134,7 @@ partitionMax f s n
11361134
partitionUnknown :: (PrimMonad m, MVector v a)
11371135
=> (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
11381136
{-# INLINE partitionUnknown #-}
1139-
partitionUnknown f s
1140-
= do
1137+
partitionUnknown f s = stToPrim $ do
11411138
v1 <- unsafeNew 0
11421139
v2 <- unsafeNew 0
11431140
(v1', n1, v2', n2) <- Bundle.foldM' put (v1, 0, v2, 0) s
@@ -1163,15 +1160,16 @@ partitionWithBundle :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
11631160
=> (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c)
11641161
{-# INLINE partitionWithBundle #-}
11651162
partitionWithBundle f s
1166-
= case upperBound (Bundle.size s) of
1163+
= stToPrim
1164+
$ case upperBound (Bundle.size s) of
11671165
Just n -> partitionWithMax f s n
11681166
Nothing -> partitionWithUnknown f s
11691167

11701168
partitionWithMax :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
11711169
=> (a -> Either b c) -> Bundle u a -> Int -> m (v (PrimState m) b, v (PrimState m) c)
11721170
{-# INLINE partitionWithMax #-}
11731171
partitionWithMax f s n
1174-
= do
1172+
= stToPrim $ do
11751173
v1 <- unsafeNew n
11761174
v2 <- unsafeNew n
11771175
let {-# INLINE_INNER put #-}
@@ -1192,7 +1190,7 @@ partitionWithUnknown :: forall m v u a b c.
11921190
=> (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c)
11931191
{-# INLINE partitionWithUnknown #-}
11941192
partitionWithUnknown f s
1195-
= do
1193+
= stToPrim $ do
11961194
v1 <- unsafeNew 0
11971195
v2 <- unsafeNew 0
11981196
(v1', n1, v2', n2) <- Bundle.foldM' put (v1, 0, v2, 0) s
@@ -1202,14 +1200,14 @@ partitionWithUnknown f s
12021200
where
12031201
put :: (v (PrimState m) b, Int, v (PrimState m) c, Int)
12041202
-> a
1205-
-> m (v (PrimState m) b, Int, v (PrimState m) c, Int)
1203+
-> ST (PrimState m) (v (PrimState m) b, Int, v (PrimState m) c, Int)
12061204
{-# INLINE_INNER put #-}
12071205
put (v1, i1, v2, i2) x = case f x of
12081206
Left b -> do
1209-
v1' <- unsafeAppend1 v1 i1 b
1207+
v1' <- stToPrim $ unsafeAppend1 v1 i1 b
12101208
return (v1', i1+1, v2, i2)
12111209
Right c -> do
1212-
v2' <- unsafeAppend1 v2 i2 c
1210+
v2' <- stToPrim $ unsafeAppend1 v2 i2 c
12131211
return (v1, i1, v2', i2+1)
12141212

12151213
-- Modifying vectors

0 commit comments

Comments
 (0)