Skip to content

Commit 8dccbdb

Browse files
committed
feat: add --pre-plan-file flag
1 parent 4b2cea3 commit 8dccbdb

5 files changed

Lines changed: 57 additions & 10 deletions

File tree

cmd/pg-schema-diff/apply_cmd.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,14 @@ func runPlan(ctx context.Context, connConfig *pgx.ConnConfig, plan diff.Plan) er
122122
}
123123
defer conn.Close()
124124

125+
if plan.PrePlanDDL != "" {
126+
fmt.Println(header("Executing pre-plan DDL"))
127+
fmt.Printf("%s\n\n", plan.PrePlanDDL)
128+
if _, err := conn.ExecContext(ctx, plan.PrePlanDDL); err != nil {
129+
return fmt.Errorf("executing pre-plan DDL: %w", err)
130+
}
131+
}
132+
125133
// Due to the way *sql.Db works, when a statement_timeout is set for the session, it will NOT reset
126134
// by default when it's returned to the pool.
127135
//

cmd/pg-schema-diff/plan_cmd.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ type (
7575

7676
schemaSourceFlags struct {
7777
schemaDirs []string
78+
prePlanFile string
7879
targetDatabaseDSN string
7980
}
8081

@@ -147,6 +148,11 @@ func schemaSourceFlagsVar(cmd *cobra.Command, p *schemaSourceFlags) {
147148
if err := cmd.MarkFlagDirname("schema-dir"); err != nil {
148149
panic(err)
149150
}
151+
cmd.Flags().StringVar(&p.prePlanFile, "pre-plan-file", "", "File path to a file containing DDL statements to prepend to the generated plan.")
152+
if err := cmd.MarkFlagFilename("pre-plan-file"); err != nil {
153+
panic(err)
154+
}
155+
150156
cmd.Flags().StringVar(&p.targetDatabaseDSN, "schema-source-dsn", "", "DSN for the database to use as the schema source. Use to generate a diff between the target database and the schema in this database.")
151157

152158
cmd.MarkFlagsMutuallyExclusive("schema-dir", "schema-source-dsn")
@@ -222,7 +228,7 @@ func parsePlanConfig(p planFlags) (planConfig, error) {
222228
func parseSchemaSource(p schemaSourceFlags) (schemaSourceFactory, error) {
223229
if len(p.schemaDirs) > 0 {
224230
return func() (diff.SchemaSource, io.Closer, error) {
225-
schemaSource, err := diff.DirSchemaSource(p.schemaDirs)
231+
schemaSource, err := diff.DirSchemaSource(p.schemaDirs, p.prePlanFile)
226232
if err != nil {
227233
return nil, nil, err
228234
}

pkg/diff/plan.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ type Plan struct {
5959
// plan on running them later, you should verify that the current schema hash matches the current schema hash.
6060
// To get the current schema hash, you can use schema.GetPublicSchemaHash(ctx, conn)
6161
CurrentSchemaHash string
62+
// PrePlanDDL is a string containing DDL statements that should be executed before the plan is applied.
63+
// This can be used for setup operations or preliminary changes that need to occur before the main migration.
64+
PrePlanDDL string
6265
}
6366

6467
// ApplyStatementTimeoutModifier applies the given timeout to all statements that match the given regex

pkg/diff/plan_generator.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"fmt"
7+
"os"
78
"strings"
89
"time"
910

@@ -106,7 +107,7 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt {
106107
// newDDL: DDL encoding the new schema
107108
// opts: Additional options to configure the plan generation
108109
func GeneratePlan(ctx context.Context, queryable sqldb.Queryable, tempdbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) {
109-
return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)
110+
return Generate(ctx, queryable, DDLSchemaSource(newDDL, ""), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)
110111
}
111112

112113
// Generate generates a migration plan to migrate the database to the target schema
@@ -155,9 +156,20 @@ func Generate(
155156
return Plan{}, fmt.Errorf("generating current schema hash: %w", err)
156157
}
157158

159+
prePlanDDL := ""
160+
// Prepend pre-plan file statements if available
161+
if ddlSource, ok := targetSchema.(*ddlSchemaSource); ok && ddlSource.prePlanFile != "" {
162+
content, err := os.ReadFile(ddlSource.prePlanFile)
163+
if err != nil {
164+
return Plan{}, fmt.Errorf("reading pre-plan file: %w", err)
165+
}
166+
prePlanDDL = string(content)
167+
}
168+
158169
plan := Plan{
159170
Statements: statements,
160171
CurrentSchemaHash: hash,
172+
PrePlanDDL: prePlanDDL,
161173
}
162174

163175
if planOptions.validatePlan {
@@ -216,11 +228,11 @@ func assertValidPlan(ctx context.Context,
216228
// on the database.
217229
setMaxConnectionsIfNotSet(tempDb.ConnPool, tempDbMaxConnections)
218230

219-
if err := setSchemaForEmptyDatabase(ctx, tempDb, currentSchema, planOptions); err != nil {
231+
if err := setSchemaForEmptyDatabase(ctx, tempDb, currentSchema, planOptions, plan.PrePlanDDL); err != nil {
220232
return fmt.Errorf("inserting schema in temporary database: %w", err)
221233
}
222234

223-
if err := executeStatementsIgnoreTimeouts(ctx, tempDb.ConnPool, plan.Statements); err != nil {
235+
if err := executeStatementsIgnoreTimeouts(ctx, tempDb.ConnPool, plan.Statements, plan.PrePlanDDL); err != nil {
224236
return fmt.Errorf("running migration plan: %w", err)
225237
}
226238

@@ -238,7 +250,7 @@ func setMaxConnectionsIfNotSet(db *sql.DB, defaultMax int) {
238250
}
239251
}
240252

241-
func setSchemaForEmptyDatabase(ctx context.Context, emptyDb *tempdb.Database, targetSchema schema.Schema, options *planOptions) error {
253+
func setSchemaForEmptyDatabase(ctx context.Context, emptyDb *tempdb.Database, targetSchema schema.Schema, options *planOptions, prePlanDDL string) error {
242254
// We can't create invalid indexes. We'll mark them valid in the schema, which should be functionally
243255
// equivalent for the sake of DDL and other statements.
244256
//
@@ -261,7 +273,7 @@ func setSchemaForEmptyDatabase(ctx context.Context, emptyDb *tempdb.Database, ta
261273
if err != nil {
262274
return fmt.Errorf("building schema diff: %w", err)
263275
}
264-
if err := executeStatementsIgnoreTimeouts(ctx, emptyDb.ConnPool, statements); err != nil {
276+
if err := executeStatementsIgnoreTimeouts(ctx, emptyDb.ConnPool, statements, prePlanDDL); err != nil {
265277
return fmt.Errorf("executing statements: %w\n%# v", err, pretty.Formatter(statements))
266278
}
267279
return nil
@@ -290,7 +302,7 @@ func assertMigratedSchemaMatchesTarget(migratedSchema, targetSchema schema.Schem
290302

291303
// executeStatementsIgnoreTimeouts executes the statements using the sql connection but ignores any provided timeouts.
292304
// This function is currently used to validate migration plans.
293-
func executeStatementsIgnoreTimeouts(ctx context.Context, connPool *sql.DB, statements []Statement) error {
305+
func executeStatementsIgnoreTimeouts(ctx context.Context, connPool *sql.DB, statements []Statement, prePlanDDL string) error {
294306
conn, err := connPool.Conn(ctx)
295307
if err != nil {
296308
return fmt.Errorf("getting connection from pool: %w", err)
@@ -301,6 +313,11 @@ func executeStatementsIgnoreTimeouts(ctx context.Context, connPool *sql.DB, stat
301313
if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION statement_timeout = %d", (10*time.Second).Milliseconds())); err != nil {
302314
return fmt.Errorf("setting statement timeout: %w", err)
303315
}
316+
if prePlanDDL != "" {
317+
if _, err := conn.ExecContext(ctx, prePlanDDL); err != nil {
318+
return fmt.Errorf("executing pre-plan DDL: %w", err)
319+
}
320+
}
304321
// Due to the way *sql.Db works, when a statement_timeout is set for the session, it will NOT reset
305322
// by default when it's returned to the pool.
306323
//

pkg/diff/schema_source.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ type (
3333

3434
ddlSchemaSource struct {
3535
ddl []ddlStatement
36+
prePlanFile string
3637
}
3738
)
3839

3940
// DirSchemaSource returns a SchemaSource that returns a schema based on the provided directories. You must provide a tempDBFactory
4041
// via the WithTempDbFactory option.
41-
func DirSchemaSource(dirs []string) (SchemaSource, error) {
42+
func DirSchemaSource(dirs []string, prePlanFile string) (SchemaSource, error) {
4243
var ddl []ddlStatement
4344
for _, dir := range dirs {
4445
stmts, err := getDDLFromPath(dir)
@@ -50,6 +51,7 @@ func DirSchemaSource(dirs []string) (SchemaSource, error) {
5051
}
5152
return &ddlSchemaSource{
5253
ddl: ddl,
54+
prePlanFile: prePlanFile,
5355
}, nil
5456
}
5557

@@ -86,7 +88,7 @@ func getDDLFromPath(path string) ([]ddlStatement, error) {
8688

8789
// DDLSchemaSource returns a SchemaSource that returns a schema based on the provided DDL. You must provide a tempDBFactory
8890
// via the WithTempDbFactory option.
89-
func DDLSchemaSource(stmts []string) SchemaSource {
91+
func DDLSchemaSource(stmts []string, prePlanFile string) SchemaSource {
9092
var ddl []ddlStatement
9193
for _, stmt := range stmts {
9294
ddl = append(ddl, ddlStatement{
@@ -96,7 +98,7 @@ func DDLSchemaSource(stmts []string) SchemaSource {
9698
)
9799
}
98100

99-
return &ddlSchemaSource{ddl: ddl}
101+
return &ddlSchemaSource{ddl: ddl, prePlanFile: prePlanFile}
100102
}
101103

102104
func (s *ddlSchemaSource) GetSchema(ctx context.Context, deps schemaSourcePlanDeps) (schema.Schema, error) {
@@ -114,6 +116,17 @@ func (s *ddlSchemaSource) GetSchema(ctx context.Context, deps schemaSourcePlanDe
114116
}
115117
}(tempDb.ContextualCloser)
116118

119+
if s.prePlanFile != "" {
120+
prePlanDDL, err := os.ReadFile(s.prePlanFile)
121+
if err != nil {
122+
return schema.Schema{}, fmt.Errorf("opening pre-plan file: %w", err)
123+
}
124+
125+
if _, err := tempDb.ConnPool.ExecContext(ctx, string(prePlanDDL)); err != nil {
126+
return schema.Schema{}, fmt.Errorf("running pre-plan DDL: %w", err)
127+
}
128+
}
129+
117130
for _, ddlStmt := range s.ddl {
118131
if _, err := tempDb.ConnPool.ExecContext(ctx, ddlStmt.stmt); err != nil {
119132
debugInfo := ""

0 commit comments

Comments
 (0)