44 "context"
55 "database/sql"
66 "fmt"
7+ "os"
78 "strings"
89 "time"
910
@@ -106,7 +107,7 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt {
106107// newDDL: DDL encoding the new schema
107108// opts: Additional options to configure the plan generation
108109func GeneratePlan (ctx context.Context , queryable sqldb.Queryable , tempdbFactory tempdb.Factory , newDDL []string , opts ... PlanOpt ) (Plan , error ) {
109- return Generate (ctx , queryable , DDLSchemaSource (newDDL ), append (opts , WithTempDbFactory (tempdbFactory ), WithIncludeSchemas ("public" ))... )
110+ return Generate (ctx , queryable , DDLSchemaSource (newDDL , "" ), append (opts , WithTempDbFactory (tempdbFactory ), WithIncludeSchemas ("public" ))... )
110111}
111112
112113// Generate generates a migration plan to migrate the database to the target schema
@@ -155,9 +156,20 @@ func Generate(
155156 return Plan {}, fmt .Errorf ("generating current schema hash: %w" , err )
156157 }
157158
159+ prePlanDDL := ""
160+ // Prepend pre-plan file statements if available
161+ if ddlSource , ok := targetSchema .(* ddlSchemaSource ); ok && ddlSource .prePlanFile != "" {
162+ content , err := os .ReadFile (ddlSource .prePlanFile )
163+ if err != nil {
164+ return Plan {}, fmt .Errorf ("reading pre-plan file: %w" , err )
165+ }
166+ prePlanDDL = string (content )
167+ }
168+
158169 plan := Plan {
159170 Statements : statements ,
160171 CurrentSchemaHash : hash ,
172+ PrePlanDDL : prePlanDDL ,
161173 }
162174
163175 if planOptions .validatePlan {
@@ -216,11 +228,11 @@ func assertValidPlan(ctx context.Context,
216228 // on the database.
217229 setMaxConnectionsIfNotSet (tempDb .ConnPool , tempDbMaxConnections )
218230
219- if err := setSchemaForEmptyDatabase (ctx , tempDb , currentSchema , planOptions ); err != nil {
231+ if err := setSchemaForEmptyDatabase (ctx , tempDb , currentSchema , planOptions , plan . PrePlanDDL ); err != nil {
220232 return fmt .Errorf ("inserting schema in temporary database: %w" , err )
221233 }
222234
223- if err := executeStatementsIgnoreTimeouts (ctx , tempDb .ConnPool , plan .Statements ); err != nil {
235+ if err := executeStatementsIgnoreTimeouts (ctx , tempDb .ConnPool , plan .Statements , plan . PrePlanDDL ); err != nil {
224236 return fmt .Errorf ("running migration plan: %w" , err )
225237 }
226238
@@ -238,7 +250,7 @@ func setMaxConnectionsIfNotSet(db *sql.DB, defaultMax int) {
238250 }
239251}
240252
241- func setSchemaForEmptyDatabase (ctx context.Context , emptyDb * tempdb.Database , targetSchema schema.Schema , options * planOptions ) error {
253+ func setSchemaForEmptyDatabase (ctx context.Context , emptyDb * tempdb.Database , targetSchema schema.Schema , options * planOptions , prePlanDDL string ) error {
242254 // We can't create invalid indexes. We'll mark them valid in the schema, which should be functionally
243255 // equivalent for the sake of DDL and other statements.
244256 //
@@ -261,7 +273,7 @@ func setSchemaForEmptyDatabase(ctx context.Context, emptyDb *tempdb.Database, ta
261273 if err != nil {
262274 return fmt .Errorf ("building schema diff: %w" , err )
263275 }
264- if err := executeStatementsIgnoreTimeouts (ctx , emptyDb .ConnPool , statements ); err != nil {
276+ if err := executeStatementsIgnoreTimeouts (ctx , emptyDb .ConnPool , statements , prePlanDDL ); err != nil {
265277 return fmt .Errorf ("executing statements: %w\n %# v" , err , pretty .Formatter (statements ))
266278 }
267279 return nil
@@ -290,7 +302,7 @@ func assertMigratedSchemaMatchesTarget(migratedSchema, targetSchema schema.Schem
290302
291303// executeStatementsIgnoreTimeouts executes the statements using the sql connection but ignores any provided timeouts.
292304// This function is currently used to validate migration plans.
293- func executeStatementsIgnoreTimeouts (ctx context.Context , connPool * sql.DB , statements []Statement ) error {
305+ func executeStatementsIgnoreTimeouts (ctx context.Context , connPool * sql.DB , statements []Statement , prePlanDDL string ) error {
294306 conn , err := connPool .Conn (ctx )
295307 if err != nil {
296308 return fmt .Errorf ("getting connection from pool: %w" , err )
@@ -301,6 +313,11 @@ func executeStatementsIgnoreTimeouts(ctx context.Context, connPool *sql.DB, stat
301313 if _ , err := conn .ExecContext (ctx , fmt .Sprintf ("SET SESSION statement_timeout = %d" , (10 * time .Second ).Milliseconds ())); err != nil {
302314 return fmt .Errorf ("setting statement timeout: %w" , err )
303315 }
316+ if prePlanDDL != "" {
317+ if _ , err := conn .ExecContext (ctx , prePlanDDL ); err != nil {
318+ return fmt .Errorf ("executing pre-plan DDL: %w" , err )
319+ }
320+ }
304321 // Due to the way *sql.Db works, when a statement_timeout is set for the session, it will NOT reset
305322 // by default when it's returned to the pool.
306323 //
0 commit comments