From c3ba797130fc7ad5e26ddc5fc5008745a704a7d3 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Wed, 27 May 2026 11:27:20 -0700 Subject: [PATCH 1/3] Refactoring to clean up Doltgres integration --- go/cmd/dolt/commands/engine/sqlengine.go | 39 ++++++++++++++++--- go/cmd/dolt/commands/sqlserver/server.go | 4 ++ go/libraries/doltcore/sqle/database.go | 34 +++++++++------- .../doltcore/sqle/database_provider.go | 31 ++++++++++++--- .../doltcore/sqle/statspro/controller.go | 2 +- .../doltcore/sqle/statspro/listener.go | 2 +- 6 files changed, 84 insertions(+), 28 deletions(-) diff --git a/go/cmd/dolt/commands/engine/sqlengine.go b/go/cmd/dolt/commands/engine/sqlengine.go index 1c72f6c0183..772bfd89c7b 100644 --- a/go/cmd/dolt/commands/engine/sqlengine.go +++ b/go/cmd/dolt/commands/engine/sqlengine.go @@ -97,6 +97,9 @@ type SqlEngineConfig struct { // Intended for embedded-driver use-cases that need to influence dbfactory / storage open behavior. DBLoadParams map[string]interface{} + // ProviderFactory controls how the DatabaseProvider is created. If nil, sqle.DoltProviderFactory is used. + ProviderFactory sqle.ProviderFactory + FatalBehavior dherrors.FatalBehavior } @@ -181,12 +184,31 @@ func NewSqlEngine( locations = append(locations, nil) } + factory := config.ProviderFactory + if factory == nil { + factory = sqle.DoltProviderFactory{} + } + b := env.GetDefaultInitBranch(mrEnv.Config()) - pro, err := sqle.NewDoltDatabaseProviderWithDatabases(b, mrEnv.FileSystem(), all, locations, config.EngineOverrides) + engineProvider, err := factory.NewProvider(b, mrEnv.FileSystem(), all, locations, config.EngineOverrides) if err != nil { return nil, err } - pro = pro.WithRemoteDialer(mrEnv.RemoteDialProvider()) + + // Extract the underlying *DoltDatabaseProvider for Dolt-specific configuration. For the + // default DoltProviderFactory the result IS a *DoltDatabaseProvider; custom factories + // (e.g. Doltgres) return a wrapper and implement DoltProviderUnwrapper to expose it. + var pro *sqle.DoltDatabaseProvider + switch p := engineProvider.(type) { + case *sqle.DoltDatabaseProvider: + pro = p + case sqle.DoltProviderUnwrapper: + pro = p.UnderlyingDoltProvider() + default: + return nil, fmt.Errorf("provider %T must be or wrap a *sqle.DoltDatabaseProvider", engineProvider) + } + + pro.SetRemoteDialer(mrEnv.RemoteDialProvider()) if config != nil && len(config.DBLoadParams) > 0 { pro.SetDBLoadParams(config.DBLoadParams) } @@ -200,7 +222,7 @@ func NewSqlEngine( sqlEngine := &SqlEngine{} // Create the engine - engine := gms.New(analyzer.NewBuilder(pro).AddOverrides(config.EngineOverrides).Build(), &gms.Config{ + engine := gms.New(analyzer.NewBuilder(engineProvider).AddOverrides(config.EngineOverrides).Build(), &gms.Config{ IsReadOnly: config.IsReadOnly, IsServerLocked: config.IsServerLocked, }).WithBackgroundThreads(bThreads) @@ -372,7 +394,7 @@ func applySystemVariables(vars sql.SystemVariableRegistry, cfg SystemVariables) func (se *SqlEngine) InitStats(ctx context.Context) error { // configuring stats depends on sessionBuilder // sessionBuilder needs ref to statsProv - pro := se.GetUnderlyingEngine().Analyzer.Catalog.DbProvider.(*sqle.DoltDatabaseProvider) + pro := se.GetUnderlyingEngine().Analyzer.Catalog.DbProvider.(dsess.DoltDatabaseProvider) sqlCtx, err := se.NewLocalContext(ctx) if err != nil { return err @@ -386,8 +408,13 @@ func (se *SqlEngine) InitStats(ctx context.Context) error { _, memOnly, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsMemoryOnly) sc.SetMemOnly(memOnly.(int8) == 1) - pro.InitDatabaseHooks = append(pro.InitDatabaseHooks, statspro.NewInitDatabaseHook(sc)) - pro.DropDatabaseHooks = append(pro.DropDatabaseHooks, statspro.NewDropDatabaseHook(sc)) + if adder, ok := se.GetUnderlyingEngine().Analyzer.Catalog.DbProvider.(interface { + AddInitDatabaseHook(sqle.InitDatabaseHook) + AddDropDatabaseHook(sqle.DropDatabaseHook) + }); ok { + adder.AddInitDatabaseHook(statspro.NewInitDatabaseHook(sc)) + adder.AddDropDatabaseHook(statspro.NewDropDatabaseHook(sc)) + } var sqlDbs []sql.Database for _, db := range dbs { diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 9936218887e..368426d4278 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -86,6 +86,9 @@ type Config struct { Controller *svcs.Controller ProtocolListenerFactory server.ProtocolListenerFunc MCP *MCPConfig + + // ProviderFactory controls how the DatabaseProvider is instantiated + ProviderFactory sqle.ProviderFactory } // Serve starts a MySQL-compatible server. Returns any errors that were encountered. @@ -277,6 +280,7 @@ func ConfigureServices( BinlogReplicaController: binlogreplication.DoltBinlogReplicaController, SkipRootUserInitialization: cfg.SkipRootUserInit, EngineOverrides: cfg.ServerConfig.Overrides(), + ProviderFactory: cfg.ProviderFactory, } return nil }, diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index ee53bda4ee5..56ee04b1eb5 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -85,6 +85,15 @@ type Database struct { revName string editOpts editor.Options revType dsess.RevisionType + + // SchemaWrap, when non-nil, is called by GetSchema and AllSchemas to transform the + // returned sql.DatabaseSchema before passing it to the caller. The first argument is + // the original requested schema name (preserving case); the second is the Database + // value with schemaName set to the stored (canonical) name. + // TODO: This is currently used by Doltgres to wrap databases, but we should be able + // to completely remove this in a future refactoring if Doltgres overrides + // GetTableInsensitive(). + SchemaWrap func(requestedName string, db Database) sql.DatabaseSchema } var _ dsess.SqlDatabase = Database{} @@ -2331,11 +2340,7 @@ func (db Database) GetSchema(ctx *sql.Context, schemaName string) (sql.DatabaseS for _, schema := range schemas { if strings.EqualFold(schema.Name, schemaName) { db.schemaName = schema.Name - handledSchema, err := HandleSchema(ctx, schemaName, db) - if err != nil { - return nil, false, err - } - return handledSchema, true, nil + return db.applySchemaWrap(schemaName, db), true, nil } } @@ -2343,16 +2348,19 @@ func (db Database) GetSchema(ctx *sql.Context, schemaName string) (sql.DatabaseS // We create it explicitly for new databases. if strings.EqualFold(schemaName, "public") { db.schemaName = "public" - return db, true, nil + return db.applySchemaWrap(schemaName, db), true, nil } return nil, false, nil } -// HandleSchema is used by Doltgres to intercept a database for the purposes of system tables. In Dolt, this just -// returns the given database. -var HandleSchema = func(ctx *sql.Context, schemaName string, db Database) (sql.DatabaseSchema, error) { - return db, nil +// applySchemaWrap calls SchemaWrap if set, otherwise returns sdb as-is. requestedName is +// the original schema name from the caller (before EqualFold normalization). +func (db Database) applySchemaWrap(requestedName string, sdb Database) sql.DatabaseSchema { + if db.SchemaWrap != nil { + return db.SchemaWrap(requestedName, sdb) + } + return sdb } // AllSchemas implements sql.SchemaDatabase @@ -2375,11 +2383,7 @@ func (db Database) AllSchemas(ctx *sql.Context) ([]sql.DatabaseSchema, error) { for i, schema := range schemas { sdb := db sdb.schemaName = schema.Name - handledDb, err := HandleSchema(ctx, schema.Name, sdb) - if err != nil { - return nil, err - } - dbSchemas[i] = handledDb + dbSchemas[i] = db.applySchemaWrap(schema.Name, sdb) } // For doltgres, the information_schema database should be a schema. diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 1f6525af4cc..b2187be19e1 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -87,6 +87,29 @@ type DoltDatabaseProvider struct { InitDatabaseHooks []InitDatabaseHook } +// ProviderFactory creates a sql.DatabaseProvider for use as the engine's analyzer catalog +// provider. +type ProviderFactory interface { + NewProvider(defaultBranch string, fs filesys.Filesys, databases []dsess.SqlDatabase, locations []filesys.Filesys, overrides sql.EngineOverrides) (sql.DatabaseProvider, error) +} + +// DoltProviderUnwrapper is an optional interface for sql.DatabaseProvider implementations +// that wrap a *DoltDatabaseProvider. NewSqlEngine uses it to access the underlying +// provider for Dolt-specific configuration (hooks, dialer, etc.) that is not part of +// the sql.DatabaseProvider interface. +type DoltProviderUnwrapper interface { + UnderlyingDoltProvider() *DoltDatabaseProvider +} + +// DoltProviderFactory is the default ProviderFactory used by Dolt. +type DoltProviderFactory struct{} + +var _ ProviderFactory = DoltProviderFactory{} + +func (DoltProviderFactory) NewProvider(defaultBranch string, fs filesys.Filesys, databases []dsess.SqlDatabase, locations []filesys.Filesys, overrides sql.EngineOverrides) (sql.DatabaseProvider, error) { + return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, databases, locations, overrides) +} + type remoteDialerWithGitCacheRoot struct { dbfactory.GRPCDialProvider root string @@ -220,11 +243,9 @@ func (p *DoltDatabaseProvider) WithDbFactoryUrl(url string) *DoltDatabaseProvide return &cp } -// WithRemoteDialer returns a copy of this provider with the dialer provided -func (p *DoltDatabaseProvider) WithRemoteDialer(provider dbfactory.GRPCDialProvider) *DoltDatabaseProvider { - cp := *p - cp.remoteDialer = provider - return &cp +// SetRemoteDialer sets the remote dialer on this provider in place and returns it. +func (p *DoltDatabaseProvider) SetRemoteDialer(provider dbfactory.GRPCDialProvider) { + p.remoteDialer = provider } // SetDBLoadParams sets optional DB load params for newly created / registered databases. The provided map is cloned. diff --git a/go/libraries/doltcore/sqle/statspro/controller.go b/go/libraries/doltcore/sqle/statspro/controller.go index 55470e6b61f..f91323652c5 100644 --- a/go/libraries/doltcore/sqle/statspro/controller.go +++ b/go/libraries/doltcore/sqle/statspro/controller.go @@ -67,7 +67,7 @@ func (k tableIndexesKey) String() string { type StatsController struct { logger *logrus.Logger - pro *sqle.DoltDatabaseProvider + pro dsess.DoltDatabaseProvider bgThreads *sql.BackgroundThreads statsBackingDb filesys.Filesys diff --git a/go/libraries/doltcore/sqle/statspro/listener.go b/go/libraries/doltcore/sqle/statspro/listener.go index 3eb29c4f0c3..b2d696e2559 100644 --- a/go/libraries/doltcore/sqle/statspro/listener.go +++ b/go/libraries/doltcore/sqle/statspro/listener.go @@ -149,7 +149,7 @@ func (sc *StatsController) Restart(ctx *sql.Context) error { } // Init should only be called once -func (sc *StatsController) Init(ctx context.Context, pro *sqle.DoltDatabaseProvider, ctxGen ctxFactory, dbs []sql.Database) error { +func (sc *StatsController) Init(ctx context.Context, pro dsess.DoltDatabaseProvider, ctxGen ctxFactory, dbs []sql.Database) error { sc.pro = pro ctxGenWrap := func(ctx context.Context) (*sql.Context, error) { From 1e2e85b4ff5de799075525eae6b47e98aa2b3028 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Mon, 1 Jun 2026 11:43:56 -0700 Subject: [PATCH 2/3] Adding a named interface: DatabaseHookRegistrar --- go/cmd/dolt/commands/engine/sqlengine.go | 5 +---- go/libraries/doltcore/sqle/database_provider.go | 10 ++++++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/go/cmd/dolt/commands/engine/sqlengine.go b/go/cmd/dolt/commands/engine/sqlengine.go index 772bfd89c7b..5bc26d2e0db 100644 --- a/go/cmd/dolt/commands/engine/sqlengine.go +++ b/go/cmd/dolt/commands/engine/sqlengine.go @@ -408,10 +408,7 @@ func (se *SqlEngine) InitStats(ctx context.Context) error { _, memOnly, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsMemoryOnly) sc.SetMemOnly(memOnly.(int8) == 1) - if adder, ok := se.GetUnderlyingEngine().Analyzer.Catalog.DbProvider.(interface { - AddInitDatabaseHook(sqle.InitDatabaseHook) - AddDropDatabaseHook(sqle.DropDatabaseHook) - }); ok { + if adder, ok := se.GetUnderlyingEngine().Analyzer.Catalog.DbProvider.(sqle.DatabaseHookRegistrar); ok { adder.AddInitDatabaseHook(statspro.NewInitDatabaseHook(sc)) adder.AddDropDatabaseHook(statspro.NewDropDatabaseHook(sc)) } diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index b2187be19e1..ce79860700e 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -129,6 +129,7 @@ var _ sql.CollatedDatabaseProvider = (*DoltDatabaseProvider)(nil) var _ sql.ExternalStoredProcedureProvider = (*DoltDatabaseProvider)(nil) var _ sql.TableFunctionProvider = (*DoltDatabaseProvider)(nil) var _ dsess.DoltDatabaseProvider = (*DoltDatabaseProvider)(nil) +var _ DatabaseHookRegistrar = (*DoltDatabaseProvider)(nil) func (p *DoltDatabaseProvider) DefaultBranch() string { return p.defaultBranch @@ -805,6 +806,15 @@ func validateDBName(dbName string) error { type InitDatabaseHook func(ctx *sql.Context, pro *DoltDatabaseProvider, name string, env *env.DoltEnv, db dsess.SqlDatabase) error type DropDatabaseHook func(ctx *sql.Context, name string) +// DatabaseHookRegistrar is implemented by database providers that support registering +// hooks for database init and drop lifecycle events. +type DatabaseHookRegistrar interface { + // AddInitDatabaseHook adds an InitDatabaseHook that runs whenever a database is created. + AddInitDatabaseHook(InitDatabaseHook) + // AddDropDatabaseHook adds a DropDatabaseHook that runs whenever a database is dropped. + AddDropDatabaseHook(DropDatabaseHook) +} + // NewConfigureReplicationDatabaseHook sets up the hooks to push to a remote to replicate a newly created database. // // For a new database, this hook From 789010e1ff49d9df832089240ac5a096677e78a6 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Mon, 1 Jun 2026 12:36:23 -0700 Subject: [PATCH 3/3] Remove SchemaWrap function pointer --- go/libraries/doltcore/sqle/database.go | 59 ++------------------------ 1 file changed, 3 insertions(+), 56 deletions(-) diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 56ee04b1eb5..521ae1afd53 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -85,15 +85,6 @@ type Database struct { revName string editOpts editor.Options revType dsess.RevisionType - - // SchemaWrap, when non-nil, is called by GetSchema and AllSchemas to transform the - // returned sql.DatabaseSchema before passing it to the caller. The first argument is - // the original requested schema name (preserving case); the second is the Database - // value with schemaName set to the stored (canonical) name. - // TODO: This is currently used by Doltgres to wrap databases, but we should be able - // to completely remove this in a future refactoring if Doltgres overrides - // GetTableInsensitive(). - SchemaWrap func(requestedName string, db Database) sql.DatabaseSchema } var _ dsess.SqlDatabase = Database{} @@ -1534,13 +1525,6 @@ func (db Database) getTable(ctx *sql.Context, root doltdb.RootValue, tableName s } } - t, tblExists, err := db.checkForPgCatalogTable(ctx, tableName) - if err != nil { - return nil, false, err - } else if tblExists { - return t, tblExists, nil - } - tblName, tbl, tblExists, err := db.resolveUserTable(ctx, root, doltdb.TableName{Schema: db.schemaName, Name: tableName}) if err != nil { return nil, false, err @@ -1581,27 +1565,6 @@ func (db Database) getTable(ctx *sql.Context, root doltdb.RootValue, tableName s return table, true, nil } -// checkForPgCatalogTable checks if the table is of pg_catalog schema -// when the schema is not defined and the table name start with 'pg_'. -func (db Database) checkForPgCatalogTable(ctx *sql.Context, tableName string) (sql.Table, bool, error) { - if resolve.UseSearchPath && db.schemaName == "" && strings.HasPrefix(strings.ToLower(tableName), "pg_") { - sdb, foundSch, err := db.GetSchema(ctx, "pg_catalog") - if err != nil { - return nil, false, err - } - if foundSch { - tbl, foundTbl, err := sdb.GetTableInsensitive(ctx, tableName) - if err != nil { - return nil, false, err - } - if foundTbl { - return tbl, foundTbl, nil - } - } - } - return nil, false, nil -} - // resolveUserTable returns the table with the given name from the root given. The table name is resolved in a // case-insensitive manner. The table is returned along with its case-sensitive matched name. An error is returned if // no such table exists. @@ -1873,13 +1836,6 @@ func (db Database) DropTable(ctx *sql.Context, tableName string) error { // dropTable drops the table with the baseName given, without any business logic checks func (db Database) dropTable(ctx *sql.Context, tableName string) error { - _, tblExists, err := db.checkForPgCatalogTable(ctx, tableName) - if err != nil { - return err - } else if tblExists { - return sql.ErrDropTableNotSupported.New("pg_catalog") - } - ds := dsess.DSessFromSess(ctx.Session) if _, ok := ds.GetTemporaryTable(ctx, db.Name(), tableName); ok { ds.DropTemporaryTable(ctx, db.Name(), tableName) @@ -2340,7 +2296,7 @@ func (db Database) GetSchema(ctx *sql.Context, schemaName string) (sql.DatabaseS for _, schema := range schemas { if strings.EqualFold(schema.Name, schemaName) { db.schemaName = schema.Name - return db.applySchemaWrap(schemaName, db), true, nil + return db, true, nil } } @@ -2348,21 +2304,12 @@ func (db Database) GetSchema(ctx *sql.Context, schemaName string) (sql.DatabaseS // We create it explicitly for new databases. if strings.EqualFold(schemaName, "public") { db.schemaName = "public" - return db.applySchemaWrap(schemaName, db), true, nil + return db, true, nil } return nil, false, nil } -// applySchemaWrap calls SchemaWrap if set, otherwise returns sdb as-is. requestedName is -// the original schema name from the caller (before EqualFold normalization). -func (db Database) applySchemaWrap(requestedName string, sdb Database) sql.DatabaseSchema { - if db.SchemaWrap != nil { - return db.SchemaWrap(requestedName, sdb) - } - return sdb -} - // AllSchemas implements sql.SchemaDatabase func (db Database) AllSchemas(ctx *sql.Context) ([]sql.DatabaseSchema, error) { if !resolve.UseSearchPath { @@ -2383,7 +2330,7 @@ func (db Database) AllSchemas(ctx *sql.Context) ([]sql.DatabaseSchema, error) { for i, schema := range schemas { sdb := db sdb.schemaName = schema.Name - dbSchemas[i] = db.applySchemaWrap(schema.Name, sdb) + dbSchemas[i] = sdb } // For doltgres, the information_schema database should be a schema.