11{-# LANGUAGE AllowAmbiguousTypes #-}
2+ {-# LANGUAGE DeriveAnyClass #-}
23{-# LANGUAGE ExistentialQuantification #-}
34{-# LANGUAGE FlexibleContexts #-}
5+ {-# LANGUAGE FlexibleInstances #-}
6+ {-# LANGUAGE LambdaCase #-}
7+ {-# LANGUAGE RankNTypes #-}
48{-# LANGUAGE ScopedTypeVariables #-}
59{-# LANGUAGE TupleSections #-}
610{-# LANGUAGE TypeApplications #-}
11+ {-# LANGUAGE TypeOperators #-}
712module ObsHelper where
813
9- import qualified Data.ByteString.Base64 as B64 (decodeLenient )
10- import qualified Data.ByteString.Char8 as BS
11- import qualified Data.ByteString.Lazy as BL
12- import qualified Jose.Jwa as JWT
13- import qualified Jose.Jws as JWT
14- import qualified Jose.Jwt as JWT
14+ import qualified Data.ByteString as BS
15+ import qualified Data.ByteString.Base64 as B64
16+ import qualified Data.ByteString.Lazy as BL
17+ import qualified Data.List as DL
18+ import Data.List.NonEmpty (fromList )
19+ import Data.String (String )
20+ import qualified Data.Text as T
21+ import qualified Jose.Jwa as JWT
22+ import qualified Jose.Jws as JWT
23+ import qualified Jose.Jwt as JWT
24+ import Network.HTTP.Types
25+ import qualified PostgREST.AppState as AppState
26+ import PostgREST.Config (AppConfig (.. ),
27+ JSPathExp (.. ),
28+ LogLevel (.. ),
29+ OpenAPIMode (.. ),
30+ Verbosity (.. ),
31+ parseSecret )
32+ import qualified PostgREST.Metrics as Metrics
33+ import PostgREST.Observation (Observation (.. ))
34+ import Prometheus (Counter , getCounter )
35+ import Protolude hiding (get , toS )
36+ import System.Timeout (timeout )
37+ import Test.Hspec
38+ import Test.Hspec.Expectations.Contrib (annotate )
39+
40+ -- helpers used to produce observation diagnostics in waitForObs
41+ class HasConstructor f where
42+ genericConstrName :: f x -> Text
43+
44+ instance HasConstructor f => HasConstructor (D1 c f ) where
45+ genericConstrName (M1 x) = genericConstrName x
46+
47+ instance (HasConstructor x , HasConstructor y ) => HasConstructor (x :+: y ) where
48+ genericConstrName (L1 l) = genericConstrName l
49+ genericConstrName (R1 r) = genericConstrName r
50+
51+ instance Constructor c => HasConstructor (C1 c f ) where
52+ genericConstrName = T. pack . conName
53+
54+ data SpecState = SpecState {
55+ specAppState :: AppState. AppState ,
56+ specMetrics :: Metrics. MetricsState ,
57+ specObsChan :: ObsChan
58+ }
1559
16- import PostgREST.Config (AppConfig (.. ), JSPathExp (.. ),
17- LogLevel (.. ), OpenAPIMode (.. ),
18- Verbosity (.. ), parseSecret )
60+ data StateCheck st m = forall a . StateCheck (st -> (String , m a )) (a -> a -> Expectation )
1961
20- import Data.List.NonEmpty (fromList )
21- import Data.String (String )
22- import Prometheus (Counter , getCounter )
23- import Test.Hspec.Expectations.Contrib (annotate )
62+ data TimeoutException = TimeoutException deriving (Show , Exception )
2463
25- import Network.HTTP.Types
26- import Protolude
27- import Test.Hspec
28- import Test.Hspec.Wai
64+ data ObsChan = ObsChan (Chan Observation ) (Chan Observation )
2965
66+ constrName :: (HasConstructor (Rep a ), Generic a )=> a -> Text
67+ constrName = genericConstrName . from
3068
3169baseCfg :: AppConfig
3270baseCfg = let secret = encodeUtf8 " reallyreallyreallyreallyverysafe" in
@@ -109,18 +147,12 @@ generateJWT claims =
109147 either mempty JWT. unJwt $ JWT. hmacEncode JWT. HS256 generateSecret (BL. toStrict claims)
110148
111149-- state check helpers
112-
113- data StateCheck st m = forall a . StateCheck (st -> (String , m a )) (a -> a -> Expectation )
114-
115150stateCheck :: (Show a , Eq a ) => (c -> m a ) -> (st -> (String , c )) -> (a -> a ) -> StateCheck st m
116151stateCheck extractValue extractComponent expect = StateCheck (second extractValue . extractComponent) (flip shouldBe . expect)
117152
118153expectField :: forall s st a c m . (KnownSymbol s , Show a , Eq a , HasField s st c ) => (c -> m a ) -> (a -> a ) -> StateCheck st m
119154expectField extractValue = stateCheck extractValue ((symbolVal (Proxy @ s ),) . getField @ s )
120155
121- checkState :: (Traversable t ) => t (StateCheck st (WaiSession st )) -> WaiSession st b -> WaiSession st ()
122- checkState checks act = getState >>= flip (`checkState'` checks) act
123-
124156checkState' :: (Traversable t , MonadIO m ) => st -> t (StateCheck st m ) -> m b -> m ()
125157checkState' initialState checks act = do
126158 expectations <- traverse (\ (StateCheck g expect) -> let (msg, m) = g initialState in m >>= createExpectation msg m . expect) checks
@@ -133,3 +165,48 @@ expectCounter :: forall s st m. (KnownSymbol s, HasField s st Counter, MonadIO m
133165expectCounter = expectField @ s intCounter
134166 where
135167 intCounter = ((round @ Double @ Int ) <$> ) . getCounter
168+
169+ accumulateUntilTimeout :: Int -> (s -> a -> s ) -> s -> IO a -> IO s
170+ accumulateUntilTimeout t f start act = do
171+ tid <- myThreadId
172+ -- mask to make sure TimeoutException is not thrown before starting the loop
173+ mask $ \ unmask -> do
174+ -- start timeout thread unmasking exceptions
175+ ttid <- forkIOWithUnmask ($ (threadDelay t *> throwTo tid TimeoutException ))
176+ -- unmask effect
177+ unmask (fix (\ loop accum -> (act >>= loop . f accum) `onTimeout` pure accum) start)
178+ -- make sure we catch timeout if happens before entering the loop
179+ `onTimeout` pure start
180+ -- make sure timer thread is killed on other exceptions
181+ -- so that it won't throw TimeoutException later
182+ `onException` killThread ttid
183+ where
184+ onTimeout m a = m `catch` \ TimeoutException -> a
185+
186+ newObsChan :: Chan Observation -> IO ObsChan
187+ newObsChan = fmap <$> ObsChan <*> dupChan
188+
189+ -- read messages from copy chan and once condition is met drain original to the same point
190+ -- upon timeout report error and messages remaining in the original chan
191+ -- that way we report messages since last successful read
192+ waitForObs :: HasCallStack => ObsChan -> Int -> Text -> (Observation -> Maybe a ) -> IO ()
193+ waitForObs (ObsChan orig copy) t msg f =
194+ timeout t (readUntil copy *> readUntil orig) >>= maybe failTimeout mempty
195+ where
196+ failTimeout = takeUntilTimeout decisecond (readChan orig)
197+ >>= expectationFailure . DL. unlines . fmap show . (failureMessageHeader : ) . fmap obsDiagMessage
198+ failureMessageHeader = " Timeout waiting for " <> msg <> " at " <> loc <> " . Remaining observations:"
199+ readUntil = void . untilM (pure . not . null . f) . readChan
200+ loc = fromMaybe " (unknown)" . head $ (T. pack . prettySrcLoc . snd <$> getCallStack callStack)
201+ -- execute effectful computation until result meets provided condition
202+ untilM cond m = fix $ \ loop -> m >>= \ a -> ifM (cond a) (pure a) loop
203+ -- duplicate the provided channel and construct wairFor function binding both channels
204+ -- accumulate effecful computation results into a list for specified time
205+ takeUntilTimeout t' = fmap reverse . accumulateUntilTimeout t' (flip (:) ) []
206+ decisecond = 100000
207+
208+ obsDiagMessage :: Observation -> Text
209+ obsDiagMessage = \ case
210+ (HasqlPoolObs o) -> show o
211+ o@ (DBListenStart host port name channel) -> constrName o <> show (host, port, name, channel)
212+ o -> constrName o
0 commit comments