diff --git a/go/cmd/dolt/commands/engine/sqlengine.go b/go/cmd/dolt/commands/engine/sqlengine.go index 1c72f6c0183..5bc26d2e0db 100644 --- a/go/cmd/dolt/commands/engine/sqlengine.go +++ b/go/cmd/dolt/commands/engine/sqlengine.go @@ -97,6 +97,9 @@ type SqlEngineConfig struct { // Intended for embedded-driver use-cases that need to influence dbfactory / storage open behavior. DBLoadParams map[string]interface{} + // ProviderFactory controls how the DatabaseProvider is created. If nil, sqle.DoltProviderFactory is used. + ProviderFactory sqle.ProviderFactory + FatalBehavior dherrors.FatalBehavior } @@ -181,12 +184,31 @@ func NewSqlEngine( locations = append(locations, nil) } + factory := config.ProviderFactory + if factory == nil { + factory = sqle.DoltProviderFactory{} + } + b := env.GetDefaultInitBranch(mrEnv.Config()) - pro, err := sqle.NewDoltDatabaseProviderWithDatabases(b, mrEnv.FileSystem(), all, locations, config.EngineOverrides) + engineProvider, err := factory.NewProvider(b, mrEnv.FileSystem(), all, locations, config.EngineOverrides) if err != nil { return nil, err } - pro = pro.WithRemoteDialer(mrEnv.RemoteDialProvider()) + + // Extract the underlying *DoltDatabaseProvider for Dolt-specific configuration. For the + // default DoltProviderFactory the result IS a *DoltDatabaseProvider; custom factories + // (e.g. Doltgres) return a wrapper and implement DoltProviderUnwrapper to expose it. + var pro *sqle.DoltDatabaseProvider + switch p := engineProvider.(type) { + case *sqle.DoltDatabaseProvider: + pro = p + case sqle.DoltProviderUnwrapper: + pro = p.UnderlyingDoltProvider() + default: + return nil, fmt.Errorf("provider %T must be or wrap a *sqle.DoltDatabaseProvider", engineProvider) + } + + pro.SetRemoteDialer(mrEnv.RemoteDialProvider()) if config != nil && len(config.DBLoadParams) > 0 { pro.SetDBLoadParams(config.DBLoadParams) } @@ -200,7 +222,7 @@ func NewSqlEngine( sqlEngine := &SqlEngine{} // Create the engine - engine := gms.New(analyzer.NewBuilder(pro).AddOverrides(config.EngineOverrides).Build(), &gms.Config{ + engine := gms.New(analyzer.NewBuilder(engineProvider).AddOverrides(config.EngineOverrides).Build(), &gms.Config{ IsReadOnly: config.IsReadOnly, IsServerLocked: config.IsServerLocked, }).WithBackgroundThreads(bThreads) @@ -372,7 +394,7 @@ func applySystemVariables(vars sql.SystemVariableRegistry, cfg SystemVariables) func (se *SqlEngine) InitStats(ctx context.Context) error { // configuring stats depends on sessionBuilder // sessionBuilder needs ref to statsProv - pro := se.GetUnderlyingEngine().Analyzer.Catalog.DbProvider.(*sqle.DoltDatabaseProvider) + pro := se.GetUnderlyingEngine().Analyzer.Catalog.DbProvider.(dsess.DoltDatabaseProvider) sqlCtx, err := se.NewLocalContext(ctx) if err != nil { return err @@ -386,8 +408,10 @@ func (se *SqlEngine) InitStats(ctx context.Context) error { _, memOnly, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsMemoryOnly) sc.SetMemOnly(memOnly.(int8) == 1) - pro.InitDatabaseHooks = append(pro.InitDatabaseHooks, statspro.NewInitDatabaseHook(sc)) - pro.DropDatabaseHooks = append(pro.DropDatabaseHooks, statspro.NewDropDatabaseHook(sc)) + if adder, ok := se.GetUnderlyingEngine().Analyzer.Catalog.DbProvider.(sqle.DatabaseHookRegistrar); ok { + adder.AddInitDatabaseHook(statspro.NewInitDatabaseHook(sc)) + adder.AddDropDatabaseHook(statspro.NewDropDatabaseHook(sc)) + } var sqlDbs []sql.Database for _, db := range dbs { diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 9936218887e..368426d4278 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -86,6 +86,9 @@ type Config struct { Controller *svcs.Controller ProtocolListenerFactory server.ProtocolListenerFunc MCP *MCPConfig + + // ProviderFactory controls how the DatabaseProvider is instantiated + ProviderFactory sqle.ProviderFactory } // Serve starts a MySQL-compatible server. Returns any errors that were encountered. @@ -277,6 +280,7 @@ func ConfigureServices( BinlogReplicaController: binlogreplication.DoltBinlogReplicaController, SkipRootUserInitialization: cfg.SkipRootUserInit, EngineOverrides: cfg.ServerConfig.Overrides(), + ProviderFactory: cfg.ProviderFactory, } return nil }, diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index ee53bda4ee5..521ae1afd53 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -1525,13 +1525,6 @@ func (db Database) getTable(ctx *sql.Context, root doltdb.RootValue, tableName s } } - t, tblExists, err := db.checkForPgCatalogTable(ctx, tableName) - if err != nil { - return nil, false, err - } else if tblExists { - return t, tblExists, nil - } - tblName, tbl, tblExists, err := db.resolveUserTable(ctx, root, doltdb.TableName{Schema: db.schemaName, Name: tableName}) if err != nil { return nil, false, err @@ -1572,27 +1565,6 @@ func (db Database) getTable(ctx *sql.Context, root doltdb.RootValue, tableName s return table, true, nil } -// checkForPgCatalogTable checks if the table is of pg_catalog schema -// when the schema is not defined and the table name start with 'pg_'. -func (db Database) checkForPgCatalogTable(ctx *sql.Context, tableName string) (sql.Table, bool, error) { - if resolve.UseSearchPath && db.schemaName == "" && strings.HasPrefix(strings.ToLower(tableName), "pg_") { - sdb, foundSch, err := db.GetSchema(ctx, "pg_catalog") - if err != nil { - return nil, false, err - } - if foundSch { - tbl, foundTbl, err := sdb.GetTableInsensitive(ctx, tableName) - if err != nil { - return nil, false, err - } - if foundTbl { - return tbl, foundTbl, nil - } - } - } - return nil, false, nil -} - // resolveUserTable returns the table with the given name from the root given. The table name is resolved in a // case-insensitive manner. The table is returned along with its case-sensitive matched name. An error is returned if // no such table exists. @@ -1864,13 +1836,6 @@ func (db Database) DropTable(ctx *sql.Context, tableName string) error { // dropTable drops the table with the baseName given, without any business logic checks func (db Database) dropTable(ctx *sql.Context, tableName string) error { - _, tblExists, err := db.checkForPgCatalogTable(ctx, tableName) - if err != nil { - return err - } else if tblExists { - return sql.ErrDropTableNotSupported.New("pg_catalog") - } - ds := dsess.DSessFromSess(ctx.Session) if _, ok := ds.GetTemporaryTable(ctx, db.Name(), tableName); ok { ds.DropTemporaryTable(ctx, db.Name(), tableName) @@ -2331,11 +2296,7 @@ func (db Database) GetSchema(ctx *sql.Context, schemaName string) (sql.DatabaseS for _, schema := range schemas { if strings.EqualFold(schema.Name, schemaName) { db.schemaName = schema.Name - handledSchema, err := HandleSchema(ctx, schemaName, db) - if err != nil { - return nil, false, err - } - return handledSchema, true, nil + return db, true, nil } } @@ -2349,12 +2310,6 @@ func (db Database) GetSchema(ctx *sql.Context, schemaName string) (sql.DatabaseS return nil, false, nil } -// HandleSchema is used by Doltgres to intercept a database for the purposes of system tables. In Dolt, this just -// returns the given database. -var HandleSchema = func(ctx *sql.Context, schemaName string, db Database) (sql.DatabaseSchema, error) { - return db, nil -} - // AllSchemas implements sql.SchemaDatabase func (db Database) AllSchemas(ctx *sql.Context) ([]sql.DatabaseSchema, error) { if !resolve.UseSearchPath { @@ -2375,11 +2330,7 @@ func (db Database) AllSchemas(ctx *sql.Context) ([]sql.DatabaseSchema, error) { for i, schema := range schemas { sdb := db sdb.schemaName = schema.Name - handledDb, err := HandleSchema(ctx, schema.Name, sdb) - if err != nil { - return nil, err - } - dbSchemas[i] = handledDb + dbSchemas[i] = sdb } // For doltgres, the information_schema database should be a schema. diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 1f6525af4cc..ce79860700e 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -87,6 +87,29 @@ type DoltDatabaseProvider struct { InitDatabaseHooks []InitDatabaseHook } +// ProviderFactory creates a sql.DatabaseProvider for use as the engine's analyzer catalog +// provider. +type ProviderFactory interface { + NewProvider(defaultBranch string, fs filesys.Filesys, databases []dsess.SqlDatabase, locations []filesys.Filesys, overrides sql.EngineOverrides) (sql.DatabaseProvider, error) +} + +// DoltProviderUnwrapper is an optional interface for sql.DatabaseProvider implementations +// that wrap a *DoltDatabaseProvider. NewSqlEngine uses it to access the underlying +// provider for Dolt-specific configuration (hooks, dialer, etc.) that is not part of +// the sql.DatabaseProvider interface. +type DoltProviderUnwrapper interface { + UnderlyingDoltProvider() *DoltDatabaseProvider +} + +// DoltProviderFactory is the default ProviderFactory used by Dolt. +type DoltProviderFactory struct{} + +var _ ProviderFactory = DoltProviderFactory{} + +func (DoltProviderFactory) NewProvider(defaultBranch string, fs filesys.Filesys, databases []dsess.SqlDatabase, locations []filesys.Filesys, overrides sql.EngineOverrides) (sql.DatabaseProvider, error) { + return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, databases, locations, overrides) +} + type remoteDialerWithGitCacheRoot struct { dbfactory.GRPCDialProvider root string @@ -106,6 +129,7 @@ var _ sql.CollatedDatabaseProvider = (*DoltDatabaseProvider)(nil) var _ sql.ExternalStoredProcedureProvider = (*DoltDatabaseProvider)(nil) var _ sql.TableFunctionProvider = (*DoltDatabaseProvider)(nil) var _ dsess.DoltDatabaseProvider = (*DoltDatabaseProvider)(nil) +var _ DatabaseHookRegistrar = (*DoltDatabaseProvider)(nil) func (p *DoltDatabaseProvider) DefaultBranch() string { return p.defaultBranch @@ -220,11 +244,9 @@ func (p *DoltDatabaseProvider) WithDbFactoryUrl(url string) *DoltDatabaseProvide return &cp } -// WithRemoteDialer returns a copy of this provider with the dialer provided -func (p *DoltDatabaseProvider) WithRemoteDialer(provider dbfactory.GRPCDialProvider) *DoltDatabaseProvider { - cp := *p - cp.remoteDialer = provider - return &cp +// SetRemoteDialer sets the remote dialer on this provider in place and returns it. +func (p *DoltDatabaseProvider) SetRemoteDialer(provider dbfactory.GRPCDialProvider) { + p.remoteDialer = provider } // SetDBLoadParams sets optional DB load params for newly created / registered databases. The provided map is cloned. @@ -784,6 +806,15 @@ func validateDBName(dbName string) error { type InitDatabaseHook func(ctx *sql.Context, pro *DoltDatabaseProvider, name string, env *env.DoltEnv, db dsess.SqlDatabase) error type DropDatabaseHook func(ctx *sql.Context, name string) +// DatabaseHookRegistrar is implemented by database providers that support registering +// hooks for database init and drop lifecycle events. +type DatabaseHookRegistrar interface { + // AddInitDatabaseHook adds an InitDatabaseHook that runs whenever a database is created. + AddInitDatabaseHook(InitDatabaseHook) + // AddDropDatabaseHook adds a DropDatabaseHook that runs whenever a database is dropped. + AddDropDatabaseHook(DropDatabaseHook) +} + // NewConfigureReplicationDatabaseHook sets up the hooks to push to a remote to replicate a newly created database. // // For a new database, this hook diff --git a/go/libraries/doltcore/sqle/statspro/controller.go b/go/libraries/doltcore/sqle/statspro/controller.go index 55470e6b61f..f91323652c5 100644 --- a/go/libraries/doltcore/sqle/statspro/controller.go +++ b/go/libraries/doltcore/sqle/statspro/controller.go @@ -67,7 +67,7 @@ func (k tableIndexesKey) String() string { type StatsController struct { logger *logrus.Logger - pro *sqle.DoltDatabaseProvider + pro dsess.DoltDatabaseProvider bgThreads *sql.BackgroundThreads statsBackingDb filesys.Filesys diff --git a/go/libraries/doltcore/sqle/statspro/listener.go b/go/libraries/doltcore/sqle/statspro/listener.go index 3eb29c4f0c3..b2d696e2559 100644 --- a/go/libraries/doltcore/sqle/statspro/listener.go +++ b/go/libraries/doltcore/sqle/statspro/listener.go @@ -149,7 +149,7 @@ func (sc *StatsController) Restart(ctx *sql.Context) error { } // Init should only be called once -func (sc *StatsController) Init(ctx context.Context, pro *sqle.DoltDatabaseProvider, ctxGen ctxFactory, dbs []sql.Database) error { +func (sc *StatsController) Init(ctx context.Context, pro dsess.DoltDatabaseProvider, ctxGen ctxFactory, dbs []sql.Database) error { sc.pro = pro ctxGenWrap := func(ctx context.Context) (*sql.Context, error) {