Skip to content

Commit 98f8e52

Browse files
taimoorzaeemsteve-chavez
authored andcommitted
refactor: remove auth and logging middleware
This commit removes auth middleware for it hides side effects and obscures logic. The auth operations are now done in its own stage in the request-response cycle. It also removes the logging middleware because now we instead use observation module to log the response.
1 parent 0bba1d2 commit 98f8e52

9 files changed

Lines changed: 164 additions & 146 deletions

File tree

postgrest.cabal

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ library
7070
PostgREST.Listener
7171
PostgREST.Logger
7272
PostgREST.MainTx
73+
PostgREST.Logger.Apache
7374
PostgREST.MediaType
7475
PostgREST.Metrics
7576
PostgREST.Network
@@ -116,6 +117,7 @@ library
116117
, directory >= 1.2.6 && < 1.4
117118
, either >= 4.4.1 && < 5.1
118119
, extra >= 1.7.0 && < 2.0
120+
, fast-logger >= 3.2.0 && < 3.3
119121
, fuzzyset >= 0.2.4 && < 0.3
120122
, hasql >= 1.9 && <= 1.9.3.1
121123
, hasql-dynamic-statements >= 0.3.1 && <= 0.3.1.8
@@ -147,7 +149,6 @@ library
147149
, time >= 1.6 && < 1.15
148150
, unordered-containers >= 0.2.8 && < 0.3
149151
, unix-compat >= 0.5.4 && < 0.8
150-
, vault >= 0.3.1.5 && < 0.4
151152
, vector >= 0.11 && < 0.14
152153
, wai >= 3.2.1 && < 3.3
153154
, wai-cors >= 0.2.5 && < 0.3

src/PostgREST/ApiRequest.hs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ module PostgREST.ApiRequest
88
( ApiRequest(..)
99
, userApiRequest
1010
, userPreferences
11+
, userBearerAuth
1112
) where
1213

1314
import qualified Data.CaseInsensitive as CI
@@ -16,13 +17,15 @@ import qualified Data.List.NonEmpty as NonEmptyList
1617
import qualified Data.Set as S
1718
import qualified Data.Text.Encoding as T
1819

19-
import Data.List (lookup)
20-
import Data.Ranged.Ranges (emptyRange, rangeIntersection,
21-
rangeIsEmpty)
22-
import Network.HTTP.Types.Header (RequestHeaders, hCookie)
23-
import Network.Wai (Request (..))
24-
import Network.Wai.Parse (parseHttpAccept)
25-
import Web.Cookie (parseCookies)
20+
import Data.List (lookup)
21+
import Data.Ranged.Ranges (emptyRange, rangeIntersection,
22+
rangeIsEmpty)
23+
import Network.HTTP.Types.Header (RequestHeaders,
24+
hAuthorization, hCookie)
25+
import Network.Wai (Request (..))
26+
import Network.Wai.Middleware.HttpAuth (extractBearerAuth)
27+
import Network.Wai.Parse (parseHttpAccept)
28+
import Web.Cookie (parseCookies)
2629

2730
import PostgREST.ApiRequest.Payload (getPayload)
2831
import PostgREST.ApiRequest.QueryParams (QueryParams (..))
@@ -114,6 +117,10 @@ userApiRequest conf prefs req reqBody = do
114117
userPreferences :: AppConfig -> Request -> TimezoneNames -> Preferences.Preferences
115118
userPreferences conf req timezones = Preferences.fromHeaders (configDbTxAllowOverride conf) timezones $ requestHeaders req
116119

120+
-- | Obtains the Bearer Auth
121+
userBearerAuth :: Request -> Maybe ByteString
122+
userBearerAuth req = extractBearerAuth =<< lookup hAuthorization (requestHeaders req)
123+
117124
getResource :: AppConfig -> [Text] -> Either ApiRequestError Resource
118125
getResource AppConfig{configOpenApiMode, configDbRootSpec} = \case
119126
[] ->

src/PostgREST/App.hs

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Some of its functionality includes:
99
- Producing HTTP Headers according to RFCs.
1010
- Content Negotiation
1111
-}
12+
{-# LANGUAGE NamedFieldPuns #-}
1213
{-# LANGUAGE RecordWildCards #-}
1314
{-# LANGUAGE ScopedTypeVariables #-}
1415
{-# LANGUAGE ViewPatterns #-}
@@ -24,7 +25,6 @@ import System.IO.Error (ioeGetErrorType)
2425
import Control.Monad.Except (liftEither)
2526
import Control.Monad.Extra (whenJust)
2627
import Data.Either.Combinators (mapLeft, whenLeft)
27-
import Data.Maybe (fromJust)
2828
import Data.String (IsString (..))
2929
import Network.Wai.Handler.Warp (defaultSettings, setHost,
3030
setOnException, setPort,
@@ -33,6 +33,7 @@ import Network.Wai.Handler.Warp (defaultSettings, setHost,
3333
import qualified Data.Text.Encoding as T
3434
import qualified Network.Wai as Wai
3535
import qualified Network.Wai.Handler.Warp as Warp
36+
import qualified Network.Wai.Header as WaiHeader
3637

3738
import qualified PostgREST.Admin as Admin
3839
import qualified PostgREST.ApiRequest as ApiRequest
@@ -41,7 +42,6 @@ import qualified PostgREST.Auth as Auth
4142
import qualified PostgREST.Cors as Cors
4243
import qualified PostgREST.Error as Error
4344
import qualified PostgREST.Listener as Listener
44-
import qualified PostgREST.Logger as Logger
4545
import qualified PostgREST.MainTx as MainTx
4646
import qualified PostgREST.Plan as Plan
4747
import qualified PostgREST.Query as Query
@@ -51,7 +51,7 @@ import qualified PostgREST.Unix as Unix (installSignalHandlers)
5151
import PostgREST.ApiRequest (ApiRequest (..))
5252
import PostgREST.AppState (AppState)
5353
import PostgREST.Auth.Types (AuthResult (..))
54-
import PostgREST.Config (AppConfig (..), LogLevel (..))
54+
import PostgREST.Config (AppConfig (..))
5555
import PostgREST.Error (Error)
5656
import PostgREST.Network (resolveSocketToAddress)
5757
import PostgREST.Observation (Observation (..))
@@ -72,11 +72,9 @@ import qualified Network.Socket as NS
7272
import PostgREST.Unix (createAndBindDomainSocket)
7373
import Protolude hiding (Handler)
7474

75-
type Handler = ExceptT Error
76-
7775
run :: AppState -> IO ()
7876
run appState = do
79-
conf@AppConfig{..} <- AppState.getConfig appState
77+
conf <- AppState.getConfig appState
8078

8179
AppState.schemaCacheLoader appState -- Loads the initial SchemaCache
8280
(mainSocket, adminSocket) <- initSockets conf
@@ -89,7 +87,7 @@ run appState = do
8987

9088
Admin.runAdmin appState adminSocket mainSocket (serverSettings conf)
9189

92-
let app = postgrest configLogLevel appState (AppState.schemaCacheLoader appState)
90+
let app = postgrest appState (AppState.schemaCacheLoader appState)
9391

9492
do
9593
address <- resolveSocketToAddress mainSocket
@@ -122,48 +120,59 @@ serverSettings AppConfig{..} =
122120
& setServerName ("postgrest/" <> prettyVersion)
123121

124122
-- | PostgREST application
125-
postgrest :: LogLevel -> AppState.AppState -> IO () -> Wai.Application
126-
postgrest logLevel appState connWorker =
123+
postgrest :: AppState.AppState -> IO () -> Wai.Application
124+
postgrest appState connWorker =
127125
traceHeaderMiddleware appState .
128-
Cors.middleware appState .
129-
Auth.middleware appState .
130-
Logger.middleware logLevel Auth.getRole $
131-
-- fromJust can be used, because the auth middleware will **always** add
132-
-- some AuthResult to the vault.
126+
Cors.middleware appState $
133127
\req respond -> do
134128
appConf@AppConfig{..} <- AppState.getConfig appState -- the config must be read again because it can reload
135-
case fromJust $ Auth.getResult req of
136-
Left err -> respond $ Error.errorResponseFor configClientErrorVerbosity err
137-
Right authResult -> do
138-
maybeSchemaCache <- AppState.getSchemaCache appState
139-
140-
let
141-
eitherResponse :: IO (Either Error Wai.Response)
142-
eitherResponse =
143-
runExceptT $ postgrestResponse appState appConf maybeSchemaCache authResult req
144-
145-
response <- either (Error.errorResponseFor configClientErrorVerbosity) identity <$> eitherResponse
146-
-- Launch the connWorker when the connection is down. The postgrest
147-
-- function can respond successfully (with a stale schema cache) before
148-
-- the connWorker is done. However, when there's an empty schema cache
149-
-- postgrest responds with the error `PGRST002`; this means that the schema
150-
-- cache is still loading, so we don't launch the connWorker here because
151-
-- it would duplicate the loading process, e.g. https://github.com/PostgREST/postgrest/issues/3704
152-
-- TODO: this process may be unnecessary when the Listener is enabled. Revisit once https://github.com/PostgREST/postgrest/issues/1766 is done
153-
when (isServiceUnavailable response && isJust maybeSchemaCache) connWorker
154-
resp <- do
155-
delay <- AppState.getNextDelay appState
156-
return $ addRetryHint delay response
157-
respond resp
129+
maybeSchemaCache <- AppState.getSchemaCache appState
130+
131+
let observer = AppState.getObserver appState
132+
bearerAuth = ApiRequest.userBearerAuth req
133+
134+
response <- do
135+
authResultE <- runExceptT $ withTiming appConf $
136+
liftIO (Auth.getAuthResult appState bearerAuth) >>= liftEither
137+
138+
case authResultE of
139+
Left err -> do
140+
let resp = Error.errorResponseFor configClientErrorVerbosity err
141+
observer $ genResponseObs Nothing req resp
142+
pure resp
143+
144+
Right (jwtTime, authResult@AuthResult{..}) -> do
145+
resp <- either (Error.errorResponseFor configClientErrorVerbosity) identity <$>
146+
runExceptT (postgrestResponse appState appConf maybeSchemaCache jwtTime authResult req)
147+
observer $ genResponseObs (Just authRole) req resp
148+
pure resp
149+
150+
-- Launch the connWorker when the connection is down. The postgrest
151+
-- function can respond successfully (with a stale schema cache) before
152+
-- the connWorker is done. However, when there's an empty schema cache
153+
-- postgrest responds with the error `PGRST002`; this means that the schema
154+
-- cache is still loading, so we don't launch the connWorker here because
155+
-- it would duplicate the loading process, e.g. https://github.com/PostgREST/postgrest/issues/3704
156+
-- TODO: this process may be unnecessary when the Listener is enabled. Revisit once https://github.com/PostgREST/postgrest/issues/1766 is done
157+
when (isServiceUnavailable response && isJust maybeSchemaCache) connWorker
158+
delay <- AppState.getNextDelay appState
159+
respond $ addRetryHint delay response
160+
where
161+
-- TODO WaiHeader.contentLength does a lookup everytime, see: https://hackage.haskell.org/package/wai-extra-3.1.17/docs/src/Network.Wai.Header.html#contentLength
162+
-- It might be possible to gain some perf by returning the response length from `postgrestResponse`. We calculate the length manually on Response.hs.
163+
genResponseObs :: Maybe ByteString -> Wai.Request -> Wai.Response -> Observation
164+
genResponseObs user req resp =
165+
ResponseObs user req (Wai.responseStatus resp) (WaiHeader.contentLength $ Wai.responseHeaders resp)
158166

159167
postgrestResponse
160168
:: AppState.AppState
161169
-> AppConfig
162170
-> Maybe SchemaCache
171+
-> Maybe Double
163172
-> AuthResult
164173
-> Wai.Request
165-
-> Handler IO Wai.Response
166-
postgrestResponse appState conf@AppConfig{..} maybeSchemaCache authResult@AuthResult{..} req = do
174+
-> ExceptT Error IO Wai.Response
175+
postgrestResponse appState conf@AppConfig{..} maybeSchemaCache jwtTime authResult@AuthResult{..} req = do
167176
let observer = AppState.getObserver appState
168177

169178
sCache <-
@@ -174,20 +183,18 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache authResult@AuthRe
174183
lift $ observer SchemaCacheEmptyObs
175184
throwError Error.NoSchemaCacheError
176185

177-
body <- lift $ Wai.strictRequestBody req
186+
let prefs = ApiRequest.userPreferences conf req (dbTimezones sCache)
178187

179-
let jwtTime = if configServerTimingEnabled then Auth.getJwtDur req else Nothing
180-
timezones = dbTimezones sCache
181-
prefs = ApiRequest.userPreferences conf req timezones
188+
body <- lift $ Wai.strictRequestBody req
182189

183-
(parseTime, apiReq@ApiRequest{..}) <- withTiming $ liftEither . mapLeft Error.ApiRequestErr $ ApiRequest.userApiRequest conf prefs req body
184-
(planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache
190+
(parseTime, apiReq@ApiRequest{..}) <- withTiming conf $ liftEither . mapLeft Error.ApiRequestErr $ ApiRequest.userApiRequest conf prefs req body
191+
(planTime, plan) <- withTiming conf $ liftEither $ Plan.actionPlan iAction conf apiReq sCache
185192

186193
let mainQ = Query.mainQuery plan conf apiReq authResult configDbPreRequest
187194
tx = MainTx.mainTx mainQ conf authResult apiReq plan sCache
188195
obsQuery s = when configLogQuery $ observer $ QueryObs mainQ s
189196

190-
(txTime, txResult) <- withTiming $ do
197+
(txTime, txResult) <- withTiming conf $ do
191198
case tx of
192199
MainTx.NoDbTx r -> pure r
193200
MainTx.DbTx{..} -> do
@@ -200,7 +207,7 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache authResult@AuthRe
200207
lift $ whenLeft eitherResp $ obsQuery . Error.status
201208
liftEither eitherResp
202209

203-
(respTime, resp) <- withTiming $ do
210+
(respTime, resp) <- withTiming conf $ do
204211
let response = Response.actionResponse txResult apiReq (T.decodeUtf8 prettyVersion, docsVersion) conf sCache
205212
status' = either Error.status Response.pgrstStatus response
206213

@@ -224,14 +231,14 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache authResult@AuthRe
224231
varyHeaderPresent :: [HTTP.Header] -> Bool
225232
varyHeaderPresent = any (\(h, _v) -> h == HTTP.hVary)
226233

227-
withTiming :: Handler IO a -> Handler IO (Maybe Double, a)
228-
withTiming f = if configServerTimingEnabled
229-
then do
230-
(t, r) <- timeItT f
231-
pure (Just t, r)
232-
else do
233-
r <- f
234-
pure (Nothing, r)
234+
withTiming :: AppConfig -> ExceptT e IO a -> ExceptT e IO (Maybe Double, a)
235+
withTiming AppConfig{configServerTimingEnabled} f = if configServerTimingEnabled
236+
then do
237+
(t, r) <- timeItT f
238+
pure (Just t, r)
239+
else do
240+
r <- f
241+
pure (Nothing, r)
235242

236243
traceHeaderMiddleware :: AppState -> Wai.Middleware
237244
traceHeaderMiddleware appState app req respond = do

src/PostgREST/Auth.hs

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
{-# LANGUAGE RecordWildCards #-}
21
{-|
32
Module : PostgREST.Auth
43
Description : PostgREST authentication functions.
@@ -12,66 +11,28 @@ In the test suite there is an example of simple login function that can be used
1211
very simple authentication system inside the PostgreSQL database.
1312
-}
1413
module PostgREST.Auth
15-
( getResult
16-
, getJwtDur
17-
, getRole
18-
, middleware
19-
) where
20-
21-
import qualified Data.ByteString as BS
22-
import qualified Data.Vault.Lazy as Vault
23-
import qualified Network.HTTP.Types.Header as HTTP
24-
import qualified Network.Wai as Wai
25-
import qualified Network.Wai.Middleware.HttpAuth as Wai
26-
27-
import Data.List (lookup)
28-
import PostgREST.TimeIt (timeItT)
29-
import System.IO.Unsafe (unsafePerformIO)
14+
( getAuthResult )
15+
where
3016

3117
import PostgREST.AppState (AppState, getConfig, getJwtCacheState,
3218
getTime)
3319
import PostgREST.Auth.Jwt (parseClaims)
3420
import PostgREST.Auth.JwtCache (lookupJwtCache)
35-
import PostgREST.Auth.Types (AuthResult (..))
36-
import PostgREST.Config (AppConfig (..))
37-
import PostgREST.Error (Error (..))
21+
import PostgREST.Auth.Types (AuthResult)
22+
import PostgREST.Error (Error)
3823

3924
import Protolude
4025

41-
-- | Validate authorization header
42-
-- Parse and store JWT claims for future use in the request.
43-
middleware :: AppState -> Wai.Middleware
44-
middleware appState app req respond = do
45-
conf@AppConfig{..} <- getConfig appState
26+
-- | Perform authentication and authorization
27+
-- Parse JWT and return AuthResult
28+
getAuthResult :: AppState -> Maybe ByteString -> IO (Either Error AuthResult)
29+
getAuthResult appState token = do
30+
conf <- getConfig appState
4631
time <- getTime appState
4732

48-
let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
49-
parseJwt = runExceptT $ lookupJwtCache jwtCacheState token >>= parseClaims conf time
50-
jwtCacheState = getJwtCacheState appState
51-
52-
-- If ServerTimingEnabled -> calculate JWT validation time
53-
req' <- if configServerTimingEnabled then do
54-
(dur, authResult) <- timeItT parseJwt
55-
pure $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
56-
else do
57-
authResult <- parseJwt
58-
pure $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
59-
60-
app req' respond
61-
62-
authResultKey :: Vault.Key (Either Error AuthResult)
63-
authResultKey = unsafePerformIO Vault.newKey
64-
{-# NOINLINE authResultKey #-}
65-
66-
getResult :: Wai.Request -> Maybe (Either Error AuthResult)
67-
getResult = Vault.lookup authResultKey . Wai.vault
68-
69-
jwtDurKey :: Vault.Key Double
70-
jwtDurKey = unsafePerformIO Vault.newKey
71-
{-# NOINLINE jwtDurKey #-}
72-
73-
getJwtDur :: Wai.Request -> Maybe Double
74-
getJwtDur = Vault.lookup jwtDurKey . Wai.vault
33+
let jwtCacheState = getJwtCacheState appState
34+
parseJwt = runExceptT $ do
35+
claims <- lookupJwtCache jwtCacheState token
36+
parseClaims conf time claims
7537

76-
getRole :: Wai.Request -> Maybe BS.ByteString
77-
getRole req = authRole <$> (rightToMaybe =<< getResult req)
38+
parseJwt

0 commit comments

Comments
 (0)