|
1 | 1 | {-# LANGUAGE AllowAmbiguousTypes #-} |
| 2 | +{-# LANGUAGE DeriveAnyClass #-} |
2 | 3 | {-# LANGUAGE ExistentialQuantification #-} |
3 | 4 | {-# LANGUAGE FlexibleContexts #-} |
| 5 | +{-# LANGUAGE FlexibleInstances #-} |
4 | 6 | {-# LANGUAGE RankNTypes #-} |
5 | 7 | {-# LANGUAGE ScopedTypeVariables #-} |
6 | 8 | {-# LANGUAGE TupleSections #-} |
7 | 9 | {-# LANGUAGE TypeApplications #-} |
| 10 | +{-# LANGUAGE TypeOperators #-} |
8 | 11 | module SpecHelper where |
9 | 12 |
|
10 | 13 | import Control.Lens ((^?)) |
@@ -35,17 +38,25 @@ import Test.Hspec |
35 | 38 | import Test.Hspec.Wai |
36 | 39 | import Text.Heredoc |
37 | 40 |
|
38 | | -import Data.String (String) |
39 | | -import PostgREST.Config (AppConfig (..), |
40 | | - JSPathExp (..), |
41 | | - LogLevel (..), |
42 | | - OpenAPIMode (..), |
43 | | - Verbosity (..), parseSecret) |
44 | | -import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier (..)) |
45 | | -import Prometheus (Counter, getCounter) |
46 | | -import Protolude hiding (get, toS) |
47 | | -import Protolude.Conv (toS) |
48 | | -import Test.Hspec.Expectations.Contrib (annotate) |
| 41 | +import qualified Data.List as DL |
| 42 | +import Data.String (String) |
| 43 | +import qualified Data.Text as T |
| 44 | +import qualified PostgREST.AppState as AppState |
| 45 | +import PostgREST.Config (AppConfig (..), |
| 46 | + JSPathExp (..), |
| 47 | + LogLevel (..), |
| 48 | + OpenAPIMode (..), |
| 49 | + Verbosity (..), |
| 50 | + parseSecret) |
| 51 | +import qualified PostgREST.Metrics as Metrics |
| 52 | +import PostgREST.Observation (Observation (..)) |
| 53 | +import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier (..)) |
| 54 | +import Prometheus (Counter, |
| 55 | + getCounter) |
| 56 | +import Protolude hiding (get, toS) |
| 57 | +import Protolude.Conv (toS) |
| 58 | +import System.Timeout (timeout) |
| 59 | +import Test.Hspec.Expectations.Contrib (annotate) |
49 | 60 |
|
50 | 61 | filterAndMatchCT :: BS.ByteString -> MatchHeader |
51 | 62 | filterAndMatchCT val = MatchHeader $ \headers _ -> |
@@ -381,3 +392,72 @@ expectCounter :: forall s st m. (KnownSymbol s, HasField s st Counter, MonadIO m |
381 | 392 | expectCounter = expectField @s intCounter |
382 | 393 | where |
383 | 394 | intCounter = ((round @Double @Int) <$>) . getCounter |
| 395 | + |
| 396 | +data TimeoutException = TimeoutException deriving (Show, Exception) |
| 397 | + |
| 398 | +accumulateUntilTimeout :: Int -> (s -> a -> s) -> s -> IO a -> IO s |
| 399 | +accumulateUntilTimeout t f start act = do |
| 400 | + tid <- myThreadId |
| 401 | + -- mask to make sure TimeoutException is not thrown before starting the loop |
| 402 | + mask $ \unmask -> do |
| 403 | + -- start timeout thread unmasking exceptions |
| 404 | + ttid <- forkIOWithUnmask ($ (threadDelay t *> throwTo tid TimeoutException)) |
| 405 | + -- unmask effect |
| 406 | + unmask (fix (\loop accum -> (act >>= loop . f accum) `onTimeout` pure accum) start) |
| 407 | + -- make sure we catch timeout if happens before entering the loop |
| 408 | + `onTimeout` pure start |
| 409 | + -- make sure timer thread is killed on other exceptions |
| 410 | + -- so that it won't throw TimeoutException later |
| 411 | + `onException` killThread ttid |
| 412 | + where |
| 413 | + onTimeout m a = m `catch` \TimeoutException -> a |
| 414 | + |
| 415 | +data ObsChan = ObsChan (Chan Observation) (Chan Observation) |
| 416 | + |
| 417 | +newObsChan :: Chan Observation -> IO ObsChan |
| 418 | +newObsChan = fmap <$> ObsChan <*> dupChan |
| 419 | + |
| 420 | +-- read messages from copy chan and once condition is met drain original to the same point |
| 421 | +-- upon timeout report error and messages remaining in the original chan |
| 422 | +-- that way we report messages since last successful read |
| 423 | +waitForObs :: HasCallStack => ObsChan -> Int -> Text -> (Observation -> Maybe a) -> IO () |
| 424 | +waitForObs (ObsChan orig copy) t msg f = |
| 425 | + timeout t (readUntil copy *> readUntil orig) >>= maybe failTimeout mempty |
| 426 | + where |
| 427 | + failTimeout = takeUntilTimeout decisecond (readChan orig) |
| 428 | + >>= expectationFailure . DL.unlines . fmap show . (failureMessageHeader :) . fmap obsDiagMessage |
| 429 | + failureMessageHeader = "Timeout waiting for " <> msg <> " at " <> loc <> ". Remaining observations:" |
| 430 | + readUntil = void . untilM (pure . not . null . f) . readChan |
| 431 | + loc = fromMaybe "(unknown)" . head $ (T.pack . prettySrcLoc . snd <$> getCallStack callStack) |
| 432 | + -- execute effectful computation until result meets provided condition |
| 433 | + untilM cond m = fix $ \loop -> m >>= \a -> ifM (cond a) (pure a) loop |
| 434 | + -- duplicate the provided channel and construct wairFor function binding both channels |
| 435 | + -- accumulate effecful computation results into a list for specified time |
| 436 | + takeUntilTimeout t' = fmap reverse . accumulateUntilTimeout t' (flip (:)) [] |
| 437 | + obsDiagMessage (HasqlPoolObs o) = show o |
| 438 | + obsDiagMessage o@(DBListenStart host port name channel) = constrName o <> show (host, port, name, channel) |
| 439 | + obsDiagMessage o = constrName o |
| 440 | + decisecond = 100000 |
| 441 | + |
| 442 | +data SpecState = SpecState { |
| 443 | + specAppState :: AppState.AppState, |
| 444 | + specMetrics :: Metrics.MetricsState, |
| 445 | + specObsChan :: ObsChan |
| 446 | +} |
| 447 | + |
| 448 | +-- helpers used to produce observation diagnostics in waitForObs |
| 449 | +constrName :: (HasConstructor (Rep a), Generic a)=> a -> Text |
| 450 | +constrName = genericConstrName . from |
| 451 | + |
| 452 | +class HasConstructor f where |
| 453 | + genericConstrName :: f x -> Text |
| 454 | + |
| 455 | +instance HasConstructor f => HasConstructor (D1 c f) where |
| 456 | + genericConstrName (M1 x) = genericConstrName x |
| 457 | + |
| 458 | +instance (HasConstructor x, HasConstructor y) => HasConstructor (x :+: y) where |
| 459 | + genericConstrName (L1 l) = genericConstrName l |
| 460 | + genericConstrName (R1 r) = genericConstrName r |
| 461 | + |
| 462 | +instance Constructor c => HasConstructor (C1 c f) where |
| 463 | + genericConstrName = T.pack . conName |
0 commit comments