Skip to content

Commit d2fb67f

Browse files
committed
feat: apply all function settings as transaction-scoped settings
1 parent d7246b4 commit d2fb67f

8 files changed

Lines changed: 113 additions & 62 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
88
### Added
99

1010
- #2887, Add Preference `max-affected` to limit affected resources - @taimoorzaeem
11+
- #3061, Apply all function settings as transaction-scoped settings - @taimoorzaeem
1112

1213
### Fixed
1314

src/PostgREST/App.hs

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,21 @@ import qualified PostgREST.Query as Query
4343
import qualified PostgREST.Response as Response
4444
import qualified PostgREST.Unix as Unix (installSignalHandlers)
4545

46-
import PostgREST.ApiRequest (Action (..), ApiRequest (..),
47-
Mutation (..), Target (..))
48-
import PostgREST.AppState (AppState)
49-
import PostgREST.Auth (AuthResult (..))
50-
import PostgREST.Config (AppConfig (..))
51-
import PostgREST.Config.PgVersion (PgVersion (..))
52-
import PostgREST.Error (Error)
53-
import PostgREST.Query (DbHandler)
54-
import PostgREST.Response.Performance (ServerTiming (..),
55-
serverTimingHeader)
56-
import PostgREST.SchemaCache (SchemaCache (..))
57-
import PostgREST.SchemaCache.Routine (Routine (..))
58-
import PostgREST.Version (docsVersion, prettyVersion)
46+
import PostgREST.ApiRequest (Action (..),
47+
ApiRequest (..),
48+
Mutation (..), Target (..))
49+
import PostgREST.AppState (AppState)
50+
import PostgREST.Auth (AuthResult (..))
51+
import PostgREST.Config (AppConfig (..))
52+
import PostgREST.Config.PgVersion (PgVersion (..))
53+
import PostgREST.Error (Error)
54+
import PostgREST.Query (DbHandler)
55+
import PostgREST.Response.Performance (ServerTiming (..),
56+
serverTimingHeader)
57+
import PostgREST.SchemaCache (SchemaCache (..))
58+
import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier (..))
59+
import PostgREST.SchemaCache.Routine (Routine (..))
60+
import PostgREST.Version (docsVersion, prettyVersion)
5961

6062
import qualified Data.ByteString.Char8 as BS
6163
import qualified Data.List as L
@@ -170,43 +172,44 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
170172
case (iAction, iTarget) of
171173
(ActionRead headersOnly, TargetIdent identifier) -> do
172174
(planTime', wrPlan) <- withTiming $ liftEither $ Plan.wrappedReadPlan identifier conf sCache apiReq
173-
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq
175+
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.wrTxMode wrPlan) mempty $ Query.readQuery wrPlan conf apiReq
174176
(respTime', pgrst) <- withTiming $ liftEither $ Response.readResponse wrPlan headersOnly identifier apiReq resultSet
175177
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst
176178

177179
(ActionMutate MutationCreate, TargetIdent identifier) -> do
178180
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationCreate apiReq identifier conf sCache
179-
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf
181+
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.createQuery mrPlan apiReq conf
180182
(respTime', pgrst) <- withTiming $ liftEither $ Response.createResponse identifier mrPlan apiReq resultSet
181183
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst
182184

183185
(ActionMutate MutationUpdate, TargetIdent identifier) -> do
184186
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationUpdate apiReq identifier conf sCache
185-
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf
187+
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.updateQuery mrPlan apiReq conf
186188
(respTime', pgrst) <- withTiming $ liftEither $ Response.updateResponse mrPlan apiReq resultSet
187189
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst
188190

189191
(ActionMutate MutationSingleUpsert, TargetIdent identifier) -> do
190192
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationSingleUpsert apiReq identifier conf sCache
191-
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf
193+
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.singleUpsertQuery mrPlan apiReq conf
192194
(respTime', pgrst) <- withTiming $ liftEither $ Response.singleUpsertResponse mrPlan apiReq resultSet
193195
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst
194196

195197
(ActionMutate MutationDelete, TargetIdent identifier) -> do
196198
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationDelete apiReq identifier conf sCache
197-
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf
199+
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) mempty $ Query.deleteQuery mrPlan apiReq conf
198200
(respTime', pgrst) <- withTiming $ liftEither $ Response.deleteResponse mrPlan apiReq resultSet
199201
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst
200202

201-
(ActionInvoke invMethod, TargetProc identifier _) -> do
203+
(ActionInvoke invMethod, TargetProc identifier@(QualifiedIdentifier _ proname) _) -> do
204+
let setting = [(y,z) | (x,y,z) <- funcSettings, x == encodeUtf8 proname]
202205
(planTime', cPlan) <- withTiming $ liftEither $ Plan.callReadPlan identifier conf sCache apiReq invMethod
203-
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdTimeout $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
206+
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (Plan.crTxMode cPlan) setting $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
204207
(respTime', pgrst) <- withTiming $ liftEither $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet
205208
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst
206209

207210
(ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do
208211
(planTime', iPlan) <- withTiming $ liftEither $ Plan.inspectPlan apiReq
209-
(txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema
212+
(txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl (Plan.ipTxmode iPlan) mempty $ Query.openApiQuery sCache pgVer conf tSchema
210213
(respTime', pgrst) <- withTiming $ liftEither $ Response.openApiResponse (T.decodeUtf8 prettyVersion, docsVersion) headersOnly oaiResult conf sCache iSchema iNegotiatedByProfile
211214
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst
212215

@@ -230,9 +233,10 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
230233
where
231234
roleSettings = fromMaybe mempty (HM.lookup authRole $ configRoleSettings conf)
232235
roleIsoLvl = HM.findWithDefault SQL.ReadCommitted authRole $ configRoleIsoLvl conf
233-
runQuery isoLvl timeout mode query =
236+
funcSettings = dbFuncSettings sCache
237+
runQuery isoLvl mode funcSet query =
234238
runDbHandler appState conf isoLvl mode authenticated prepared $ do
235-
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq timeout
239+
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) funcSet apiReq
236240
Query.runPreReq conf
237241
query
238242

src/PostgREST/Config/Database.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ module PostgREST.Config.Database
88
, RoleSettings
99
, RoleIsolationLvl
1010
, TimezoneNames
11+
, FuncSettings
1112
, toIsolationLevel
1213
) where
1314

@@ -31,6 +32,7 @@ import Protolude
3132
type RoleSettings = (HM.HashMap ByteString (HM.HashMap ByteString ByteString))
3233
type RoleIsolationLvl = HM.HashMap ByteString SQL.IsolationLevel
3334
type TimezoneNames = Set ByteString -- cache timezone names for prefer timezone=
35+
type FuncSettings = [(ByteString,ByteString,ByteString)]
3436

3537
toIsolationLevel :: (Eq a, IsString a) => a -> SQL.IsolationLevel
3638
toIsolationLevel a = case a of

src/PostgREST/Query.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,12 @@ optionalRollback AppConfig{..} ApiRequest{iPreferences=Preferences{..}} = do
247247

248248
-- | Set transaction scoped settings
249249
setPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> BS.ByteString -> [(ByteString, ByteString)] ->
250-
ApiRequest -> Maybe Text -> DbHandler ()
251-
setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
250+
[(ByteString,ByteString)] -> ApiRequest -> DbHandler ()
251+
setPgLocals AppConfig{..} claims role roleSettings funcSetting ApiRequest{..} = lift $
252252
SQL.statement mempty $ SQL.dynamicallyParameterized
253253
-- To ensure `GRANT SET ON PARAMETER <superuser_setting> TO authenticator` works, the role settings must be set before the impersonated role.
254254
-- Otherwise the GRANT SET would have to be applied to the impersonated role. See https://github.com/PostgREST/postgrest/issues/3045
255-
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ timeoutSql ++ appSettingsSql))
255+
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ funcSettingSql ++ appSettingsSql))
256256
HD.noResult configDbPreparedStatements
257257
where
258258
methodSql = setConfigWithConstantName ("request.method", iMethod)
@@ -264,7 +264,7 @@ setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
264264
roleSettingsSql = setConfigWithDynamicName <$> roleSettings
265265
appSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> configAppSettings)
266266
timezoneSql = maybe mempty (\(PreferTimezone tz) -> [setConfigWithConstantName ("timezone", tz)]) $ preferTimezone iPreferences
267-
timeoutSql = maybe mempty ((\t -> [setConfigWithConstantName ("statement_timeout", t)]) . encodeUtf8) tout
267+
funcSettingSql = setConfigWithDynamicName <$> funcSetting
268268
searchPathSql =
269269
let schemas = escapeIdentList (iSchema : configDbExtraSearchPath) in
270270
setConfigWithConstantName ("search_path", schemas)

src/PostgREST/SchemaCache.hs

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,14 @@ import Contravariant.Extras (contrazip2)
4343
import Text.InterpolatedString.Perl6 (q)
4444

4545
import PostgREST.Config (AppConfig (..))
46-
import PostgREST.Config.Database (TimezoneNames,
46+
import PostgREST.Config.Database (FuncSettings,
47+
TimezoneNames,
4748
pgVersionStatement,
4849
toIsolationLevel)
4950
import PostgREST.Config.PgVersion (PgVersion, pgVersion100,
5051
pgVersion110,
51-
pgVersion120)
52+
pgVersion120,
53+
pgVersion150)
5254
import PostgREST.SchemaCache.Identifiers (AccessSet, FieldName,
5355
QualifiedIdentifier (..),
5456
RelIdentifier (..),
@@ -74,24 +76,25 @@ import qualified PostgREST.MediaType as MediaType
7476

7577
import Protolude
7678

77-
7879
data SchemaCache = SchemaCache
7980
{ dbTables :: TablesMap
8081
, dbRelationships :: RelationshipsMap
8182
, dbRoutines :: RoutineMap
8283
, dbRepresentations :: RepresentationsMap
8384
, dbMediaHandlers :: MediaHandlerMap
8485
, dbTimezones :: TimezoneNames
86+
, dbFuncSettings :: FuncSettings
8587
}
8688

8789
instance JSON.ToJSON SchemaCache where
88-
toJSON (SchemaCache tabs rels routs reps _ _) = JSON.object [
90+
toJSON (SchemaCache tabs rels routs reps _ _ _) = JSON.object [
8991
"dbTables" .= JSON.toJSON tabs
9092
, "dbRelationships" .= JSON.toJSON rels
9193
, "dbRoutines" .= JSON.toJSON routs
9294
, "dbRepresentations" .= JSON.toJSON reps
9395
, "dbMediaHandlers" .= JSON.emptyArray
9496
, "dbTimezones" .= JSON.emptyArray
97+
, "dbFuncSettings" .= JSON.emptyArray
9598
]
9699

97100
-- | A view foreign key or primary key dependency detected on its source table
@@ -145,6 +148,7 @@ querySchemaCache AppConfig{..} = do
145148
reps <- SQL.statement schemas $ dataRepresentations prepared
146149
mHdlers <- SQL.statement schemas $ mediaHandlers pgVer prepared
147150
tzones <- SQL.statement mempty $ timezones prepared
151+
funSets <- SQL.statement mempty $ funcSettings pgVer prepared
148152
_ <-
149153
let sleepCall = SQL.Statement "select pg_sleep($1)" (param HE.int4) HD.noResult prepared in
150154
whenJust configInternalSCSleep (`SQL.statement` sleepCall) -- only used for testing
@@ -159,6 +163,7 @@ querySchemaCache AppConfig{..} = do
159163
, dbRepresentations = reps
160164
, dbMediaHandlers = HM.union mHdlers initialMediaHandlers -- the custom handlers will override the initial ones
161165
, dbTimezones = tzones
166+
, dbFuncSettings = funSets
162167
}
163168
where
164169
schemas = toList configDbSchemas
@@ -195,6 +200,7 @@ removeInternal schemas dbStruct =
195200
, dbRepresentations = dbRepresentations dbStruct -- no need to filter, not directly exposed through the API
196201
, dbMediaHandlers = dbMediaHandlers dbStruct
197202
, dbTimezones = dbTimezones dbStruct
203+
, dbFuncSettings = dbFuncSettings dbStruct
198204
}
199205
where
200206
hasInternalJunction ComputedRelationship{} = False
@@ -297,7 +303,6 @@ decodeFuncs =
297303
<*> (parseVolatility <$> column HD.char)
298304
<*> column HD.bool
299305
<*> nullableColumn (toIsolationLevel <$> HD.text)
300-
<*> nullableColumn HD.text
301306

302307
addKey :: Routine -> (QualifiedIdentifier, Routine)
303308
addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd)
@@ -431,8 +436,7 @@ funcsSqlQuery pgVer = [q|
431436
bt.oid <> bt.base as rettype_is_composite_alias,
432437
p.provolatile,
433438
p.provariadic > 0 as hasvariadic,
434-
lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level,
435-
lower((regexp_split_to_array((regexp_split_to_array(timeout_config, '='))[2], ','))[1]) AS statement_timeout
439+
lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level
436440
FROM pg_proc p
437441
LEFT JOIN arguments a ON a.oid = p.oid
438442
JOIN pg_namespace pn ON pn.oid = p.pronamespace
@@ -442,7 +446,6 @@ funcsSqlQuery pgVer = [q|
442446
LEFT JOIN pg_class comp ON comp.oid = t.typrelid
443447
LEFT JOIN pg_description as d ON d.objoid = p.oid
444448
LEFT JOIN LATERAL unnest(proconfig) iso_config ON iso_config like 'default_transaction_isolation%'
445-
LEFT JOIN LATERAL unnest(proconfig) timeout_config ON timeout_config like 'statement_timeout%'
446449
WHERE t.oid <> 'trigger'::regtype AND COALESCE(a.callable, true)
447450
|] <> (if pgVer >= pgVersion110 then "AND prokind = 'f'" else "AND NOT (proisagg OR proiswindow)")
448451

@@ -1203,6 +1206,34 @@ timezones = SQL.Statement sql HE.noParams decodeTimezones
12031206
decodeTimezones :: HD.Result TimezoneNames
12041207
decodeTimezones = S.fromList . map encodeUtf8 <$> HD.rowList (column HD.text)
12051208

1209+
funcSettings :: PgVersion -> Bool -> SQL.Statement () FuncSettings
1210+
funcSettings pgVer = SQL.Statement sql HE.noParams rows
1211+
where
1212+
sql = [q|
1213+
WITH
1214+
func_setting AS (
1215+
SELECT p.proname, unnest(p.proconfig) AS setting
1216+
FROM pg_proc p
1217+
),
1218+
kv_settings AS (
1219+
SELECT
1220+
proname,
1221+
substr(setting, 1, strpos(setting, '=') - 1) as key,
1222+
lower(substr(setting, strpos(setting, '=') + 1)) as value
1223+
FROM func_setting
1224+
)
1225+
SELECT
1226+
proname, kv.key AS key, kv.value AS value
1227+
FROM kv_settings kv
1228+
JOIN pg_settings ps ON ps.name = kv.key |] <>
1229+
(if pgVer >= pgVersion150
1230+
then "and (ps.context = 'user' or has_parameter_privilege(current_user::regrole::oid, ps.name, 'set'));"
1231+
else "and ps.context = 'user';")
1232+
1233+
rows :: HD.Result FuncSettings
1234+
rows = HD.rowList $ (,,) <$> (encodeUtf8 <$> column HD.text) <*> (encodeUtf8 <$> column HD.text) <*> (encodeUtf8 <$> column HD.text)
1235+
1236+
12061237
param :: HE.Value a -> HE.Params a
12071238
param = HE.param . HE.nonNullable
12081239

src/PostgREST/SchemaCache/Routine.hs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,11 @@ data Routine = Function
5858
, pdVolatility :: FuncVolatility
5959
, pdHasVariadic :: Bool
6060
, pdIsoLvl :: Maybe SQL.IsolationLevel
61-
, pdTimeout :: Maybe Text
6261
}
6362
deriving (Eq, Show, Generic)
6463
-- need to define JSON manually bc SQL.IsolationLevel doesn't have a JSON instance(and we can't define one for that type without getting a compiler error)
6564
instance JSON.ToJSON Routine where
66-
toJSON (Function sch nam desc params ret vol hasVar _ tout) = JSON.object
65+
toJSON (Function sch nam desc params ret vol hasVar _) = JSON.object
6766
[
6867
"pdSchema" .= sch
6968
, "pdName" .= nam
@@ -72,7 +71,6 @@ instance JSON.ToJSON Routine where
7271
, "pdReturnType" .= JSON.toJSON ret
7372
, "pdVolatility" .= JSON.toJSON vol
7473
, "pdHasVariadic" .= JSON.toJSON hasVar
75-
, "pdTimeout" .= tout
7674
]
7775

7876
data RoutineParam = RoutineParam
@@ -86,10 +84,10 @@ data RoutineParam = RoutineParam
8684

8785
-- Order by least number of params in the case of overloaded functions
8886
instance Ord Routine where
89-
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 tout1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 tout2
87+
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2
9088
| schema1 == schema2 && name1 == name2 && length prms1 < length prms2 = LT
9189
| schema2 == schema2 && name1 == name2 && length prms1 > length prms2 = GT
92-
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, tout1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, tout2)
90+
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2)
9391

9492
-- | A map of all procs, all of which can be overloaded(one entry will have more than one Routine).
9593
-- | It uses a HashMap for a faster lookup.

test/io/fixtures.sql

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,9 @@ $$ language sql set statement_timeout = '4s';
198198
create function get_postgres_version() returns int as $$
199199
select current_setting('server_version_num')::int;
200200
$$ language sql;
201+
202+
GRANT SET ON PARAMETER log_min_duration_sample TO postgrest_test_anonymous;
203+
204+
create or replace function log_min_duration_test() returns text as $$
205+
select current_setting('log_min_duration_sample',false);
206+
$$ language sql set log_min_duration_sample = '5s';

0 commit comments

Comments
 (0)