diff --git a/internal/db/declarative/declarative.go b/internal/db/declarative/declarative.go index 2a0454d01d..f7c8ca0021 100644 --- a/internal/db/declarative/declarative.go +++ b/internal/db/declarative/declarative.go @@ -235,10 +235,9 @@ func WriteDeclarativeSchemas(output diff.DeclarativeOutput, fsys afero.Fs) error return err } } - // When pg-delta has its own config section, the declarative path is the single - // source of truth there; do not overwrite [db.migrations] schema_paths. - if utils.IsPgDeltaEnabled() && utils.Config.Experimental.PgDelta != nil && - len(utils.Config.Experimental.PgDelta.DeclarativeSchemaPath) > 0 { + // When pg-delta is enabled, the declarative directory (default or configured) + // is the source of truth; do not overwrite [db.migrations] schema_paths. + if utils.IsPgDeltaEnabled() { return nil } utils.Config.Db.Migrations.SchemaPaths = []string{ diff --git a/internal/db/declarative/declarative_test.go b/internal/db/declarative/declarative_test.go index 73b6f473aa..229e6ffe79 100644 --- a/internal/db/declarative/declarative_test.go +++ b/internal/db/declarative/declarative_test.go @@ -48,6 +48,34 @@ func TestWriteDeclarativeSchemas(t *testing.T) { assert.Contains(t, string(cfg), `"database"`) } +func TestWriteDeclarativeSchemasSkipsConfigUpdateWhenPgDeltaEnabled(t *testing.T) { + fsys := afero.NewMemMapFs() + originalConfig := "[db]\n" + require.NoError(t, afero.WriteFile(fsys, utils.ConfigPath, []byte(originalConfig), 0644)) + original := utils.Config.Experimental.PgDelta + utils.Config.Experimental.PgDelta = &config.PgDeltaConfig{Enabled: true} + t.Cleanup(func() { + utils.Config.Experimental.PgDelta = original + }) + + output := diff.DeclarativeOutput{ + Files: []diff.DeclarativeFile{ + {Path: "schemas/public/tables/users.sql", SQL: "create table users(id bigint);"}, + }, + } + + err := WriteDeclarativeSchemas(output, fsys) + require.NoError(t, err) + + users, err := afero.ReadFile(fsys, filepath.Join(utils.DeclarativeDir, "schemas", "public", "tables", "users.sql")) + require.NoError(t, err) + assert.Equal(t, "create table users(id bigint);", string(users)) + + cfg, err := afero.ReadFile(fsys, utils.ConfigPath) + require.NoError(t, err) + assert.Equal(t, originalConfig, string(cfg)) +} + func TestTryCacheMigrationsCatalogWritesPrefixedCache(t *testing.T) { fsys := afero.NewMemMapFs() original := utils.Config.Experimental.PgDelta @@ -146,6 +174,38 @@ func TestWriteDeclarativeSchemasUsesConfiguredDir(t *testing.T) { assert.Contains(t, string(cfg), `db/decl`) } +func TestWriteDeclarativeSchemasSkipsConfigUpdateForPgDeltaCustomDir(t *testing.T) { + fsys := afero.NewMemMapFs() + originalConfig := "[db]\n" + require.NoError(t, afero.WriteFile(fsys, utils.ConfigPath, []byte(originalConfig), 0644)) + original := utils.Config.Experimental.PgDelta + utils.Config.Experimental.PgDelta = &config.PgDeltaConfig{ + Enabled: true, + DeclarativeSchemaPath: filepath.Join(utils.SupabaseDirPath, "db", "decl"), + } + t.Cleanup(func() { + utils.Config.Experimental.PgDelta = original + }) + + output := diff.DeclarativeOutput{ + Files: []diff.DeclarativeFile{ + {Path: "cluster/roles.sql", SQL: "create role app;"}, + }, + } + + err := WriteDeclarativeSchemas(output, fsys) + require.NoError(t, err) + + rolesPath := filepath.Join(utils.SupabaseDirPath, "db", "decl", "cluster", "roles.sql") + roles, err := afero.ReadFile(fsys, rolesPath) + require.NoError(t, err) + assert.Equal(t, "create role app;", string(roles)) + + cfg, err := afero.ReadFile(fsys, utils.ConfigPath) + require.NoError(t, err) + assert.Equal(t, originalConfig, string(cfg)) +} + func TestWriteDeclarativeSchemasRejectsUnsafePath(t *testing.T) { // Export paths must stay within supabase/declarative to prevent traversal. fsys := afero.NewMemMapFs() diff --git a/internal/db/diff/templates/pgdelta.ts b/internal/db/diff/templates/pgdelta.ts index 37995c491c..234c91ab06 100644 --- a/internal/db/diff/templates/pgdelta.ts +++ b/internal/db/diff/templates/pgdelta.ts @@ -21,7 +21,14 @@ const target = Deno.env.get("TARGET"); const includedSchemas = Deno.env.get("INCLUDED_SCHEMAS"); if (includedSchemas) { - supabase.filter = { schema: includedSchemas.split(",") }; + const schemas = includedSchemas.split(","); + const schemaFilter = { + or: [{ "*/schema": schemas }, { "schema/name": schemas }], + }; + // CompositionPattern `and` is valid FilterDSL; Deno's structural typing is strict on `or` branches. + supabase.filter = { + and: [supabase.filter!, schemaFilter], + } as typeof supabase.filter; } const formatOptionsRaw = Deno.env.get("FORMAT_OPTIONS"); diff --git a/internal/db/diff/templates/pgdelta_declarative_export.ts b/internal/db/diff/templates/pgdelta_declarative_export.ts index cdb59924f2..dead372a70 100644 --- a/internal/db/diff/templates/pgdelta_declarative_export.ts +++ b/internal/db/diff/templates/pgdelta_declarative_export.ts @@ -22,20 +22,23 @@ async function resolveInput(ref: string | undefined) { const source = Deno.env.get("SOURCE"); const target = Deno.env.get("TARGET"); supabase.filter = { - // Also allow dropped extensions from migrations to be capted in the declarative schema export + // Also allow dropped extensions from migrations to be captured in the declarative schema export // TODO: fix upstream bug into pgdelta supabase integration or: [ - ...supabase.filter.or, - { type: "extension", operation: "drop", scope: "object" }, + ...supabase.filter!.or!, + { objectType: "extension", operation: "drop", scope: "object" }, ], }; const includedSchemas = Deno.env.get("INCLUDED_SCHEMAS"); if (includedSchemas) { - const schemaFilter = { schema: includedSchemas.split(",") }; - supabase.filter = supabase.filter - ? { and: [supabase.filter, schemaFilter] } - : schemaFilter; + const schemas = includedSchemas.split(","); + const schemaFilter = { + or: [{ "*/schema": schemas }, { "schema/name": schemas }], + }; + supabase.filter = { + and: [supabase.filter!, schemaFilter], + } as unknown as typeof supabase.filter; } const formatOptionsRaw = Deno.env.get("FORMAT_OPTIONS"); diff --git a/pkg/config/auth.go b/pkg/config/auth.go index 82c708e37c..c1795c8971 100644 --- a/pkg/config/auth.go +++ b/pkg/config/auth.go @@ -163,6 +163,7 @@ type ( SigningKeysPath string `toml:"signing_keys_path" json:"signing_keys_path"` SigningKeys []JWK `toml:"-" json:"-"` Passkey *passkey `toml:"passkey" json:"passkey"` + Webauthn *webauthn `toml:"webauthn" json:"webauthn"` RateLimit rateLimit `toml:"rate_limit" json:"rate_limit"` Captcha *captcha `toml:"captcha" json:"captcha"` @@ -380,7 +381,10 @@ type ( } passkey struct { - Enabled bool `toml:"enabled" json:"enabled"` + Enabled bool `toml:"enabled" json:"enabled"` + } + + webauthn struct { RpDisplayName string `toml:"rp_display_name" json:"rp_display_name"` RpId string `toml:"rp_id" json:"rp_id"` RpOrigins []string `toml:"rp_origins" json:"rp_origins"` @@ -418,6 +422,9 @@ func (a *auth) ToUpdateAuthConfigBody() v1API.UpdateAuthConfigBody { if a.Passkey != nil { a.Passkey.toAuthConfigBody(&body) } + if a.Webauthn != nil { + a.Webauthn.toAuthConfigBody(&body) + } a.Hook.toAuthConfigBody(&body) a.MFA.toAuthConfigBody(&body) a.Sessions.toAuthConfigBody(&body) @@ -442,6 +449,7 @@ func (a *auth) FromRemoteAuthConfig(remoteConfig v1API.AuthConfigResponse) { prc := ValOrDefault(remoteConfig.PasswordRequiredCharacters, "") a.PasswordRequirements = NewPasswordRequirement(v1API.UpdateAuthConfigBodyPasswordRequiredCharacters(prc)) a.Passkey.fromAuthConfig(remoteConfig) + a.Webauthn.fromAuthConfig(remoteConfig) a.RateLimit.fromAuthConfig(remoteConfig) if s := a.Email.Smtp; s != nil && s.Enabled { a.RateLimit.EmailSent = cast.IntToUint(ValOrDefault(remoteConfig.RateLimitEmailSent, 0)) @@ -502,11 +510,7 @@ func (c *captcha) fromAuthConfig(remoteConfig v1API.AuthConfigResponse) { } func (p passkey) toAuthConfigBody(body *v1API.UpdateAuthConfigBody) { - if body.PasskeyEnabled = cast.Ptr(p.Enabled); p.Enabled { - body.WebauthnRpDisplayName = nullable.NewNullableWithValue(p.RpDisplayName) - body.WebauthnRpId = nullable.NewNullableWithValue(p.RpId) - body.WebauthnRpOrigins = nullable.NewNullableWithValue(strings.Join(p.RpOrigins, ",")) - } + body.PasskeyEnabled = cast.Ptr(p.Enabled) } func (p *passkey) fromAuthConfig(remoteConfig v1API.AuthConfigResponse) { @@ -514,15 +518,25 @@ func (p *passkey) fromAuthConfig(remoteConfig v1API.AuthConfigResponse) { if p == nil { return } - // Ignore disabled passkey fields to minimise config diff - if p.Enabled { - p.RpDisplayName = ValOrDefault(remoteConfig.WebauthnRpDisplayName, "") - p.RpId = ValOrDefault(remoteConfig.WebauthnRpId, "") - p.RpOrigins = strToArr(ValOrDefault(remoteConfig.WebauthnRpOrigins, "")) - } p.Enabled = remoteConfig.PasskeyEnabled } +func (w webauthn) toAuthConfigBody(body *v1API.UpdateAuthConfigBody) { + body.WebauthnRpDisplayName = nullable.NewNullableWithValue(w.RpDisplayName) + body.WebauthnRpId = nullable.NewNullableWithValue(w.RpId) + body.WebauthnRpOrigins = nullable.NewNullableWithValue(strings.Join(w.RpOrigins, ",")) +} + +func (w *webauthn) fromAuthConfig(remoteConfig v1API.AuthConfigResponse) { + // When local config is not set, we assume platform defaults should not change + if w == nil { + return + } + w.RpDisplayName = ValOrDefault(remoteConfig.WebauthnRpDisplayName, "") + w.RpId = ValOrDefault(remoteConfig.WebauthnRpId, "") + w.RpOrigins = strToArr(ValOrDefault(remoteConfig.WebauthnRpOrigins, "")) +} + func (h hook) toAuthConfigBody(body *v1API.UpdateAuthConfigBody) { // When local config is not set, we assume platform defaults should not change if hook := h.BeforeUserCreated; hook != nil { diff --git a/pkg/config/auth_test.go b/pkg/config/auth_test.go index 65f0066da9..61ba5b429c 100644 --- a/pkg/config/auth_test.go +++ b/pkg/config/auth_test.go @@ -215,8 +215,8 @@ func TestCaptchaDiff(t *testing.T) { func TestPasskeyConfigMapping(t *testing.T) { t.Run("serializes passkey config to update body", func(t *testing.T) { c := newWithDefaults() - c.Passkey = &passkey{ - Enabled: true, + c.Passkey = &passkey{Enabled: true} + c.Webauthn = &webauthn{ RpDisplayName: "Supabase CLI", RpId: "localhost", RpOrigins: []string{ @@ -235,14 +235,9 @@ func TestPasskeyConfigMapping(t *testing.T) { assert.Equal(t, "http://127.0.0.1:3000,https://localhost:3000", ValOrDefault(body.WebauthnRpOrigins, "")) }) - t.Run("does not serialize rp fields when passkey is disabled", func(t *testing.T) { + t.Run("does not serialize rp fields when webauthn is undefined", func(t *testing.T) { c := newWithDefaults() - c.Passkey = &passkey{ - Enabled: false, - RpDisplayName: "Supabase CLI", - RpId: "localhost", - RpOrigins: []string{"http://127.0.0.1:3000"}, - } + c.Passkey = &passkey{Enabled: false} // Run test body := c.ToUpdateAuthConfigBody() // Check result @@ -257,12 +252,27 @@ func TestPasskeyConfigMapping(t *testing.T) { assert.Error(t, err) }) - t.Run("hydrates passkey config from remote", func(t *testing.T) { + t.Run("serializes webauthn fields independently of passkey", func(t *testing.T) { c := newWithDefaults() - c.Passkey = &passkey{ - Enabled: true, + c.Webauthn = &webauthn{ + RpDisplayName: "Supabase CLI", + RpId: "localhost", + RpOrigins: []string{"http://127.0.0.1:3000"}, } // Run test + body := c.ToUpdateAuthConfigBody() + // Check result + assert.Nil(t, body.PasskeyEnabled) + assert.Equal(t, "Supabase CLI", ValOrDefault(body.WebauthnRpDisplayName, "")) + assert.Equal(t, "localhost", ValOrDefault(body.WebauthnRpId, "")) + assert.Equal(t, "http://127.0.0.1:3000", ValOrDefault(body.WebauthnRpOrigins, "")) + }) + + t.Run("hydrates passkey and webauthn config from remote", func(t *testing.T) { + c := newWithDefaults() + c.Passkey = &passkey{Enabled: true} + c.Webauthn = &webauthn{} + // Run test c.FromRemoteAuthConfig(v1API.AuthConfigResponse{ PasskeyEnabled: true, WebauthnRpDisplayName: nullable.NewNullableWithValue("Supabase CLI"), @@ -272,12 +282,14 @@ func TestPasskeyConfigMapping(t *testing.T) { // Check result if assert.NotNil(t, c.Passkey) { assert.True(t, c.Passkey.Enabled) - assert.Equal(t, "Supabase CLI", c.Passkey.RpDisplayName) - assert.Equal(t, "localhost", c.Passkey.RpId) + } + if assert.NotNil(t, c.Webauthn) { + assert.Equal(t, "Supabase CLI", c.Webauthn.RpDisplayName) + assert.Equal(t, "localhost", c.Webauthn.RpId) assert.Equal(t, []string{ "http://127.0.0.1:3000", "https://localhost:3000", - }, c.Passkey.RpOrigins) + }, c.Webauthn.RpOrigins) } }) @@ -292,6 +304,7 @@ func TestPasskeyConfigMapping(t *testing.T) { }) // Check result assert.Nil(t, c.Passkey) + assert.Nil(t, c.Webauthn) }) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 90d81741b1..1a2a780d6d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -262,9 +262,13 @@ func (a *auth) Clone() auth { } if copy.Passkey != nil { passkey := *a.Passkey - passkey.RpOrigins = slices.Clone(a.Passkey.RpOrigins) copy.Passkey = &passkey } + if copy.Webauthn != nil { + webauthn := *a.Webauthn + webauthn.RpOrigins = slices.Clone(a.Webauthn.RpOrigins) + copy.Webauthn = &webauthn + } copy.External = maps.Clone(a.External) if a.Email.Smtp != nil { mailer := *a.Email.Smtp @@ -640,6 +644,8 @@ func (c *config) Load(path string, fsys fs.FS, overrides ...ConfigEditor) error c.Db.Image = pg14 case 15: c.Db.Image = pg15 + case 17: + c.Db.Image = pg17 } if c.Db.MajorVersion > 14 { if version, err := fs.ReadFile(fsys, builder.PostgresVersionPath); err == nil { @@ -921,21 +927,22 @@ func (c *config) Validate(fsys fs.FS) error { return errors.Errorf("failed to decode signing keys: %w", err) } } - if c.Auth.Passkey != nil { - if c.Auth.Passkey.Enabled { - if len(c.Auth.Passkey.RpId) == 0 { - return errors.New("Missing required field in config: auth.passkey.rp_id") - } - if len(c.Auth.Passkey.RpOrigins) == 0 { - return errors.New("Missing required field in config: auth.passkey.rp_origins") - } - if err := assertEnvLoaded(c.Auth.Passkey.RpId); err != nil { - return errors.Errorf("Invalid config for auth.passkey.rp_id: %v", err) - } - for i, origin := range c.Auth.Passkey.RpOrigins { - if err := assertEnvLoaded(origin); err != nil { - return errors.Errorf("Invalid config for auth.passkey.rp_origins[%d]: %v", i, err) - } + if c.Auth.Passkey != nil && c.Auth.Passkey.Enabled { + if c.Auth.Webauthn == nil { + return errors.New("Missing required config section: auth.webauthn (required when auth.passkey.enabled is true)") + } + if len(c.Auth.Webauthn.RpId) == 0 { + return errors.New("Missing required field in config: auth.webauthn.rp_id") + } + if len(c.Auth.Webauthn.RpOrigins) == 0 { + return errors.New("Missing required field in config: auth.webauthn.rp_origins") + } + if err := assertEnvLoaded(c.Auth.Webauthn.RpId); err != nil { + return errors.Errorf("Invalid config for auth.webauthn.rp_id: %v", err) + } + for i, origin := range c.Auth.Webauthn.RpOrigins { + if err := assertEnvLoaded(origin); err != nil { + return errors.Errorf("Invalid config for auth.webauthn.rp_origins[%d]: %v", i, err) } } } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 6957331161..f019b7cbce 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -74,7 +74,7 @@ func TestConfigParsing(t *testing.T) { // Run test assert.Error(t, config.Load("", fsys)) }) - t.Run("config file with passkey settings", func(t *testing.T) { + t.Run("config file with passkey and webauthn settings", func(t *testing.T) { config := NewConfig() fsys := fs.MapFS{ "supabase/config.toml": &fs.MapFile{Data: []byte(` @@ -83,6 +83,7 @@ enabled = true site_url = "http://127.0.0.1:3000" [auth.passkey] enabled = true +[auth.webauthn] rp_display_name = "Supabase CLI" rp_id = "localhost" rp_origins = ["http://127.0.0.1:3000", "https://localhost:3000"] @@ -93,15 +94,56 @@ rp_origins = ["http://127.0.0.1:3000", "https://localhost:3000"] // Check result if assert.NotNil(t, config.Auth.Passkey) { assert.True(t, config.Auth.Passkey.Enabled) - assert.Equal(t, "Supabase CLI", config.Auth.Passkey.RpDisplayName) - assert.Equal(t, "localhost", config.Auth.Passkey.RpId) + } + if assert.NotNil(t, config.Auth.Webauthn) { + assert.Equal(t, "Supabase CLI", config.Auth.Webauthn.RpDisplayName) + assert.Equal(t, "localhost", config.Auth.Webauthn.RpId) assert.Equal(t, []string{ "http://127.0.0.1:3000", "https://localhost:3000", - }, config.Auth.Passkey.RpOrigins) + }, config.Auth.Webauthn.RpOrigins) } }) + t.Run("webauthn section without passkey loads successfully", func(t *testing.T) { + config := NewConfig() + fsys := fs.MapFS{ + "supabase/config.toml": &fs.MapFile{Data: []byte(` +[auth] +enabled = true +site_url = "http://127.0.0.1:3000" +[auth.webauthn] +rp_display_name = "Supabase CLI" +rp_id = "localhost" +rp_origins = ["http://127.0.0.1:3000"] +`)}, + } + // Run test + assert.NoError(t, config.Load("", fsys)) + // Check result + assert.Nil(t, config.Auth.Passkey) + if assert.NotNil(t, config.Auth.Webauthn) { + assert.Equal(t, "localhost", config.Auth.Webauthn.RpId) + } + }) + + t.Run("passkey enabled requires webauthn section", func(t *testing.T) { + config := NewConfig() + fsys := fs.MapFS{ + "supabase/config.toml": &fs.MapFile{Data: []byte(` +[auth] +enabled = true +site_url = "http://127.0.0.1:3000" +[auth.passkey] +enabled = true +`)}, + } + // Run test + err := config.Load("", fsys) + // Check result + assert.ErrorContains(t, err, "Missing required config section: auth.webauthn") + }) + t.Run("passkey enabled requires rp_id", func(t *testing.T) { config := NewConfig() fsys := fs.MapFS{ @@ -111,13 +153,14 @@ enabled = true site_url = "http://127.0.0.1:3000" [auth.passkey] enabled = true +[auth.webauthn] rp_origins = ["http://127.0.0.1:3000"] `)}, } // Run test err := config.Load("", fsys) // Check result - assert.ErrorContains(t, err, "Missing required field in config: auth.passkey.rp_id") + assert.ErrorContains(t, err, "Missing required field in config: auth.webauthn.rp_id") }) t.Run("passkey enabled requires rp_origins", func(t *testing.T) { @@ -129,13 +172,14 @@ enabled = true site_url = "http://127.0.0.1:3000" [auth.passkey] enabled = true +[auth.webauthn] rp_id = "localhost" `)}, } // Run test err := config.Load("", fsys) // Check result - assert.ErrorContains(t, err, "Missing required field in config: auth.passkey.rp_origins") + assert.ErrorContains(t, err, "Missing required field in config: auth.webauthn.rp_origins") }) t.Run("parses experimental pgdelta config", func(t *testing.T) { diff --git a/pkg/config/constants.go b/pkg/config/constants.go index 08d572d2da..ec61d93b6a 100644 --- a/pkg/config/constants.go +++ b/pkg/config/constants.go @@ -12,6 +12,7 @@ const ( pg13 = "supabase/postgres:13.3.0" pg14 = "supabase/postgres:14.1.0.89" pg15 = "supabase/postgres:15.8.1.085" + pg17 = "supabase/postgres:17.6.1.106" deno1 = "supabase/edge-runtime:v1.68.4" ) diff --git a/pkg/config/templates/config.toml b/pkg/config/templates/config.toml index 2909f82230..97ed4e5665 100644 --- a/pkg/config/templates/config.toml +++ b/pkg/config/templates/config.toml @@ -180,6 +180,9 @@ password_requirements = "" # Configure passkey sign-ins. # [auth.passkey] # enabled = false + +# Configure WebAuthn relying party settings (required when passkey is enabled). +# [auth.webauthn] # rp_display_name = "Supabase" # rp_id = "localhost" # rp_origins = ["http://127.0.0.1:3000"] diff --git a/pkg/parser/state.go b/pkg/parser/state.go index f32a671315..47775390d1 100644 --- a/pkg/parser/state.go +++ b/pkg/parser/state.go @@ -46,14 +46,40 @@ func (s *ReadyState) Next(r rune, data []byte) State { case 'c': fallthrough case 'C': - offset := len(data) - len(BEGIN_ATOMIC) - if offset >= 0 && strings.EqualFold(string(data[offset:]), BEGIN_ATOMIC) { + if isBeginAtomic(data) { return &AtomicState{prev: s, delimiter: []byte(END_ATOMIC)} } } return s } +func isBeginAtomic(data []byte) bool { + offset := len(data) - len(BEGIN_ATOMIC) + if offset < 0 || !strings.EqualFold(string(data[offset:]), BEGIN_ATOMIC) { + return false + } + if offset > 0 { + r, _ := utf8.DecodeLastRune(data[:offset]) + if isIdentifierRune(r) { + return false + } + } + prefix := bytes.TrimRightFunc(data[:offset], unicode.IsSpace) + offset = len(prefix) - len("BEGIN") + if offset < 0 || !strings.EqualFold(string(prefix[offset:]), "BEGIN") { + return false + } + if offset == 0 { + return true + } + r, _ := utf8.DecodeLastRune(prefix[:offset]) + return !isIdentifierRune(r) +} + +func isIdentifierRune(r rune) bool { + return unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' +} + // Opened a line comment type CommentState struct{} diff --git a/pkg/parser/state_test.go b/pkg/parser/state_test.go index bae10fe190..ad6db9d26a 100644 --- a/pkg/parser/state_test.go +++ b/pkg/parser/state_test.go @@ -167,4 +167,44 @@ END ;`} checkSplit(t, sql) }) + + t.Run("ignores atomic in identifiers", func(t *testing.T) { + names := []string{ + "fn_atomic", + "atomic_fn", + "my_atomic_thing", + "xatomicx", + "fn_ATomiC", + } + for _, name := range names { + t.Run(name, func(t *testing.T) { + sql := []string{ + `CREATE OR REPLACE FUNCTION ` + name + `() +RETURNS void LANGUAGE plpgsql AS $$ +BEGIN + NULL; +END; +$$;`, + ` +SELECT 1;`, + } + checkSplit(t, sql) + }) + } + }) + + t.Run("does not treat schema-qualified atomic function names as begin atomic", func(t *testing.T) { + sql := []string{`CREATE OR REPLACE FUNCTION public.atomic_example() +RETURNS INTEGER +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN 1; +END; +$$;`, + ` +GRANT EXECUTE ON FUNCTION public.atomic_example() TO authenticated;`, + } + checkSplit(t, sql) + }) }