Skip to content

Commit cc5f0e0

Browse files
authored
Proof out explicit schema injection without search_path (#798)
Here, follow up #794 to proof out the injection of an explicit schema into SQL queries so that putting River in a custom schema doesn't require the use of `search_path`. We use the newly introduced sqlc template package to prefix table names like so: -- name: JobCountByState :one SELECT count(*) FROM /* TEMPLATE: schema */river_job WHERE state = @State; The main advantage of this approach compared to `search_path` is that connections no longer require that their `search_path` be set with an explicit value to work. `search_path` can be set or misset, but because table names are prefixed with the right schema name already, queries still go through. This is especially useful in the context of PgBouncer, where a valid `search_path` setting can't be guaranteed. Notably, this change doesn't bring in enough new testing to prove that River really works with explicit schema injection, so for the time being this setting remains completely internal. What I'd like to try and do is get some basic infrastructure like this in place first, then prove it out by starting to move the test suite over to schema-specific tests. By virtue of doing that we'd be putting the entire load of the River test suite on the new paths, which should give us high confidence that it's all working as intended.
1 parent e85ae6a commit cc5f0e0

90 files changed

Lines changed: 2114 additions & 1220 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

client.go

Lines changed: 88 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,13 @@ type Config struct {
282282
// Defaults to DefaultRetryPolicy.
283283
RetryPolicy ClientRetryPolicy
284284

285+
// schema is a non-standard schema where River tables are located. All table
286+
// references in database queries will use this value as a prefix.
287+
//
288+
// Defaults to empty, which causes the client to look for tables using the
289+
// setting of Postgres `search_path`.
290+
schema string
291+
285292
// SkipUnknownJobCheck is a flag to control whether the client should skip
286293
// checking to see if a registered worker exists in the client's worker bundle
287294
// for a job arg prior to insertion.
@@ -376,6 +383,7 @@ func (c *Config) WithDefaults() *Config {
376383
ReindexerSchedule: c.ReindexerSchedule,
377384
RescueStuckJobsAfter: valutil.ValOrDefault(c.RescueStuckJobsAfter, rescueAfter),
378385
RetryPolicy: retryPolicy,
386+
schema: c.schema,
379387
SkipUnknownJobCheck: c.SkipUnknownJobCheck,
380388
Test: c.Test,
381389
TestOnly: c.TestOnly,
@@ -681,7 +689,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
681689
// uses listen/notify. Instead, each service polls for changes it's
682690
// interested in. e.g. Elector polls to see if leader has expired.
683691
if !config.PollOnly {
684-
client.notifier = notifier.New(archetype, driver.GetListener())
692+
client.notifier = notifier.New(archetype, driver.GetListener(config.schema))
685693
client.services = append(client.services, client.notifier)
686694
}
687695
} else {
@@ -718,6 +726,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
718726
CancelledJobRetentionPeriod: config.CancelledJobRetentionPeriod,
719727
CompletedJobRetentionPeriod: config.CompletedJobRetentionPeriod,
720728
DiscardedJobRetentionPeriod: config.DiscardedJobRetentionPeriod,
729+
Schema: config.schema,
721730
Timeout: config.JobCleanerTimeout,
722731
}, driver.GetExecutor())
723732
maintenanceServices = append(maintenanceServices, jobCleaner)
@@ -728,6 +737,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
728737
jobRescuer := maintenance.NewRescuer(archetype, &maintenance.JobRescuerConfig{
729738
ClientRetryPolicy: config.RetryPolicy,
730739
RescueAfter: config.RescueStuckJobsAfter,
740+
Schema: config.schema,
731741
WorkUnitFactoryFunc: func(kind string) workunit.WorkUnitFactory {
732742
if workerInfo, ok := config.Workers.workersMap[kind]; ok {
733743
return workerInfo.workUnitFactory
@@ -743,6 +753,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
743753
jobScheduler := maintenance.NewJobScheduler(archetype, &maintenance.JobSchedulerConfig{
744754
Interval: config.schedulerInterval,
745755
NotifyInsert: client.maybeNotifyInsertForQueues,
756+
Schema: config.schema,
746757
}, driver.GetExecutor())
747758
maintenanceServices = append(maintenanceServices, jobScheduler)
748759
client.testSignals.jobScheduler = &jobScheduler.TestSignals
@@ -763,6 +774,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
763774
{
764775
queueCleaner := maintenance.NewQueueCleaner(archetype, &maintenance.QueueCleanerConfig{
765776
RetentionPeriod: maintenance.QueueRetentionPeriodDefault,
777+
Schema: config.schema,
766778
}, driver.GetExecutor())
767779
maintenanceServices = append(maintenanceServices, queueCleaner)
768780
client.testSignals.queueCleaner = &queueCleaner.TestSignals
@@ -774,7 +786,10 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
774786
scheduleFunc = config.ReindexerSchedule.Next
775787
}
776788

777-
reindexer := maintenance.NewReindexer(archetype, &maintenance.ReindexerConfig{ScheduleFunc: scheduleFunc}, driver.GetExecutor())
789+
reindexer := maintenance.NewReindexer(archetype, &maintenance.ReindexerConfig{
790+
ScheduleFunc: scheduleFunc,
791+
Schema: config.schema,
792+
}, driver.GetExecutor())
778793
maintenanceServices = append(maintenanceServices, reindexer)
779794
client.testSignals.reindexer = &reindexer.TestSignals
780795
}
@@ -1249,14 +1264,18 @@ func (c *Client[TTx]) jobCancel(ctx context.Context, exec riverdriver.Executor,
12491264
ID: jobID,
12501265
CancelAttemptedAt: c.baseService.Time.NowUTC(),
12511266
ControlTopic: string(notifier.NotificationTopicControl),
1267+
Schema: c.config.schema,
12521268
})
12531269
}
12541270

12551271
// JobDelete deletes the job with the given ID from the database, returning the
12561272
// deleted row if it was deleted. Jobs in the running state are not deleted,
12571273
// instead returning rivertype.ErrJobRunning.
12581274
func (c *Client[TTx]) JobDelete(ctx context.Context, id int64) (*rivertype.JobRow, error) {
1259-
return c.driver.GetExecutor().JobDelete(ctx, id)
1275+
return c.driver.GetExecutor().JobDelete(ctx, &riverdriver.JobDeleteParams{
1276+
ID: id,
1277+
Schema: c.config.schema,
1278+
})
12601279
}
12611280

12621281
// JobDelete deletes the job with the given ID from the database, returning the
@@ -1266,20 +1285,29 @@ func (c *Client[TTx]) JobDelete(ctx context.Context, id int64) (*rivertype.JobRo
12661285
// until the transaction commits, and if the transaction rolls back, so too is
12671286
// the deleted job.
12681287
func (c *Client[TTx]) JobDeleteTx(ctx context.Context, tx TTx, id int64) (*rivertype.JobRow, error) {
1269-
return c.driver.UnwrapExecutor(tx).JobDelete(ctx, id)
1288+
return c.driver.UnwrapExecutor(tx).JobDelete(ctx, &riverdriver.JobDeleteParams{
1289+
ID: id,
1290+
Schema: c.config.schema,
1291+
})
12701292
}
12711293

12721294
// JobGet fetches a single job by its ID. Returns the up-to-date JobRow for the
12731295
// specified jobID if it exists. Returns ErrNotFound if the job doesn't exist.
12741296
func (c *Client[TTx]) JobGet(ctx context.Context, id int64) (*rivertype.JobRow, error) {
1275-
return c.driver.GetExecutor().JobGetByID(ctx, id)
1297+
return c.driver.GetExecutor().JobGetByID(ctx, &riverdriver.JobGetByIDParams{
1298+
ID: id,
1299+
Schema: c.config.schema,
1300+
})
12761301
}
12771302

12781303
// JobGetTx fetches a single job by its ID, within a transaction. Returns the
12791304
// up-to-date JobRow for the specified jobID if it exists. Returns ErrNotFound
12801305
// if the job doesn't exist.
12811306
func (c *Client[TTx]) JobGetTx(ctx context.Context, tx TTx, id int64) (*rivertype.JobRow, error) {
1282-
return c.driver.UnwrapExecutor(tx).JobGetByID(ctx, id)
1307+
return c.driver.UnwrapExecutor(tx).JobGetByID(ctx, &riverdriver.JobGetByIDParams{
1308+
ID: id,
1309+
Schema: c.config.schema,
1310+
})
12831311
}
12841312

12851313
// JobRetry updates the job with the given ID to make it immediately available
@@ -1291,7 +1319,10 @@ func (c *Client[TTx]) JobGetTx(ctx context.Context, tx TTx, id int64) (*rivertyp
12911319
// MaxAttempts is also incremented by one if the job has already exhausted its
12921320
// max attempts.
12931321
func (c *Client[TTx]) JobRetry(ctx context.Context, id int64) (*rivertype.JobRow, error) {
1294-
return c.driver.GetExecutor().JobRetry(ctx, id)
1322+
return c.driver.GetExecutor().JobRetry(ctx, &riverdriver.JobRetryParams{
1323+
ID: id,
1324+
Schema: c.config.schema,
1325+
})
12951326
}
12961327

12971328
// JobRetryTx updates the job with the given ID to make it immediately available
@@ -1308,7 +1339,10 @@ func (c *Client[TTx]) JobRetry(ctx context.Context, id int64) (*rivertype.JobRow
13081339
// MaxAttempts is also incremented by one if the job has already exhausted its
13091340
// max attempts.
13101341
func (c *Client[TTx]) JobRetryTx(ctx context.Context, tx TTx, id int64) (*rivertype.JobRow, error) {
1311-
return c.driver.UnwrapExecutor(tx).JobRetry(ctx, id)
1342+
return c.driver.UnwrapExecutor(tx).JobRetry(ctx, &riverdriver.JobRetryParams{
1343+
ID: id,
1344+
Schema: c.config.schema,
1345+
})
13121346
}
13131347

13141348
// ID returns the unique ID of this client as set in its config or
@@ -1574,7 +1608,10 @@ func (c *Client[TTx]) validateParamsAndInsertMany(ctx context.Context, tx riverd
15741608
// by the PeriodicJobEnqueuer.
15751609
func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) {
15761610
return c.insertManyShared(ctx, tx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) {
1577-
results, err := c.pilot.JobInsertMany(ctx, tx, insertParams)
1611+
results, err := c.pilot.JobInsertMany(ctx, tx, &riverdriver.JobInsertFastManyParams{
1612+
Jobs: insertParams,
1613+
Schema: c.config.schema,
1614+
})
15781615
if err != nil {
15791616
return nil, err
15801617
}
@@ -1744,7 +1781,10 @@ func (c *Client[TTx]) insertManyFast(ctx context.Context, tx riverdriver.Executo
17441781
}
17451782

17461783
results, err := c.insertManyShared(ctx, tx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) {
1747-
count, err := tx.JobInsertFastManyNoReturning(ctx, insertParams)
1784+
count, err := tx.JobInsertFastManyNoReturning(ctx, &riverdriver.JobInsertFastManyParams{
1785+
Jobs: insertParams,
1786+
Schema: c.config.schema,
1787+
})
17481788
if err != nil {
17491789
return nil, err
17501790
}
@@ -1786,8 +1826,9 @@ func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdr
17861826
}
17871827

17881828
err := tx.NotifyMany(ctx, &riverdriver.NotifyManyParams{
1789-
Topic: string(notifier.NotificationTopicInsert),
17901829
Payload: payloads,
1830+
Schema: c.config.schema,
1831+
Topic: string(notifier.NotificationTopicInsert),
17911832
})
17921833
if err != nil {
17931834
c.baseService.Logger.ErrorContext(
@@ -1817,6 +1858,7 @@ func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriv
18171858

18181859
err = tx.NotifyMany(ctx, &riverdriver.NotifyManyParams{
18191860
Payload: []string{string(payload)},
1861+
Schema: c.config.schema,
18201862
Topic: string(notifier.NotificationTopicControl),
18211863
})
18221864
if err != nil {
@@ -1863,6 +1905,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr
18631905
QueueEventCallback: c.subscriptionManager.distributeQueueEvent,
18641906
RetryPolicy: c.config.RetryPolicy,
18651907
SchedulerInterval: c.config.schedulerInterval,
1908+
Schema: c.config.schema,
18661909
StaleProducerRetentionPeriod: 5 * time.Minute,
18671910
Workers: c.config.Workers,
18681911
})
@@ -1979,7 +2022,10 @@ func (c *Client[TTx]) Queues() *QueueBundle { return c.queues }
19792022
// The provided context is used for the underlying Postgres query and can be
19802023
// used to cancel the operation or apply a timeout.
19812024
func (c *Client[TTx]) QueueGet(ctx context.Context, name string) (*rivertype.Queue, error) {
1982-
return c.driver.GetExecutor().QueueGet(ctx, name)
2025+
return c.driver.GetExecutor().QueueGet(ctx, &riverdriver.QueueGetParams{
2026+
Name: name,
2027+
Schema: c.config.schema,
2028+
})
19832029
}
19842030

19852031
// QueueGetTx returns the queue with the given name. If the queue has not recently
@@ -1988,7 +2034,10 @@ func (c *Client[TTx]) QueueGet(ctx context.Context, name string) (*rivertype.Que
19882034
// The provided context is used for the underlying Postgres query and can be
19892035
// used to cancel the operation or apply a timeout.
19902036
func (c *Client[TTx]) QueueGetTx(ctx context.Context, tx TTx, name string) (*rivertype.Queue, error) {
1991-
return c.driver.UnwrapExecutor(tx).QueueGet(ctx, name)
2037+
return c.driver.UnwrapExecutor(tx).QueueGet(ctx, &riverdriver.QueueGetParams{
2038+
Name: name,
2039+
Schema: c.config.schema,
2040+
})
19922041
}
19932042

19942043
// QueueListResult is the result of a job list operation. It contains a list of
@@ -2014,7 +2063,10 @@ func (c *Client[TTx]) QueueList(ctx context.Context, params *QueueListParams) (*
20142063
params = NewQueueListParams()
20152064
}
20162065

2017-
queues, err := c.driver.GetExecutor().QueueList(ctx, int(params.paginationCount))
2066+
queues, err := c.driver.GetExecutor().QueueList(ctx, &riverdriver.QueueListParams{
2067+
Limit: int(params.paginationCount),
2068+
Schema: c.config.schema,
2069+
})
20182070
if err != nil {
20192071
return nil, err
20202072
}
@@ -2038,7 +2090,10 @@ func (c *Client[TTx]) QueueListTx(ctx context.Context, tx TTx, params *QueueList
20382090
params = NewQueueListParams()
20392091
}
20402092

2041-
queues, err := c.driver.UnwrapExecutor(tx).QueueList(ctx, int(params.paginationCount))
2093+
queues, err := c.driver.UnwrapExecutor(tx).QueueList(ctx, &riverdriver.QueueListParams{
2094+
Limit: int(params.paginationCount),
2095+
Schema: c.config.schema,
2096+
})
20422097
if err != nil {
20432098
return nil, err
20442099
}
@@ -2064,7 +2119,10 @@ func (c *Client[TTx]) QueuePause(ctx context.Context, name string, opts *QueuePa
20642119
}
20652120
defer tx.Rollback(ctx)
20662121

2067-
if err := tx.QueuePause(ctx, name); err != nil {
2122+
if err := tx.QueuePause(ctx, &riverdriver.QueuePauseParams{
2123+
Name: name,
2124+
Schema: c.config.schema,
2125+
}); err != nil {
20682126
return err
20692127
}
20702128

@@ -2089,7 +2147,10 @@ func (c *Client[TTx]) QueuePause(ctx context.Context, name string, opts *QueuePa
20892147
func (c *Client[TTx]) QueuePauseTx(ctx context.Context, tx TTx, name string, opts *QueuePauseOpts) error {
20902148
executorTx := c.driver.UnwrapExecutor(tx)
20912149

2092-
if err := executorTx.QueuePause(ctx, name); err != nil {
2150+
if err := executorTx.QueuePause(ctx, &riverdriver.QueuePauseParams{
2151+
Name: name,
2152+
Schema: c.config.schema,
2153+
}); err != nil {
20932154
return err
20942155
}
20952156

@@ -2119,7 +2180,10 @@ func (c *Client[TTx]) QueueResume(ctx context.Context, name string, opts *QueueP
21192180
}
21202181
defer tx.Rollback(ctx)
21212182

2122-
if err := tx.QueueResume(ctx, name); err != nil {
2183+
if err := tx.QueueResume(ctx, &riverdriver.QueueResumeParams{
2184+
Name: name,
2185+
Schema: c.config.schema,
2186+
}); err != nil {
21232187
return err
21242188
}
21252189

@@ -2145,7 +2209,10 @@ func (c *Client[TTx]) QueueResume(ctx context.Context, name string, opts *QueueP
21452209
func (c *Client[TTx]) QueueResumeTx(ctx context.Context, tx TTx, name string, opts *QueuePauseOpts) error {
21462210
executorTx := c.driver.UnwrapExecutor(tx)
21472211

2148-
if err := executorTx.QueueResume(ctx, name); err != nil {
2212+
if err := executorTx.QueueResume(ctx, &riverdriver.QueueResumeParams{
2213+
Name: name,
2214+
Schema: c.config.schema,
2215+
}); err != nil {
21492216
return err
21502217
}
21512218

@@ -2215,8 +2282,9 @@ func (c *Client[TTx]) queueUpdate(ctx context.Context, executorTx riverdriver.Ex
22152282
}
22162283

22172284
if err := executorTx.NotifyMany(ctx, &riverdriver.NotifyManyParams{
2218-
Topic: string(notifier.NotificationTopicControl),
22192285
Payload: []string{string(payload)},
2286+
Schema: c.config.schema,
2287+
Topic: string(notifier.NotificationTopicControl),
22202288
}); err != nil {
22212289
return nil, err
22222290
}

0 commit comments

Comments
 (0)