Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 91 additions & 66 deletions core/context.go
Comment thread
Hydrocharged marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@ import (
// and may be refreshed at any point, including during the middle of a query. Callers should not assume that
// data stored in contextValues is persisted, and other types of data should not be added to contextValues.
type contextValues struct {
seqs map[string]*sequences.Collection
// TODO: all these collection fields need to be mapped by database name as seqs above
types *typecollection.TypeCollection
funcs *functions.Collection
procs *procedures.Collection
trigs *triggers.Collection
exts *extensions.Collection
casts *casts.Collection
seqs map[string]*sequences.Collection
types map[string]*typecollection.TypeCollection
funcs map[string]*functions.Collection
procs map[string]*procedures.Collection
trigs map[string]*triggers.Collection
exts map[string]*extensions.Collection
casts map[string]*casts.Collection

pgCatalogCache any
runner sql.StatementRunner
Expand Down Expand Up @@ -292,72 +291,75 @@ func GetExtensionsCollectionFromContext(ctx *sql.Context, database string) (*ext
if err != nil {
return nil, err
}
_, root, err := getRootFromContextForDatabase(ctx, database)
if err != nil {
return nil, err
}
if cv.exts == nil {
cv.exts, err = extensions.LoadExtensions(ctx, root)
cv.exts = make(map[string]*extensions.Collection)
}
if len(database) == 0 {
database = ctx.GetCurrentDatabase()
}
if cv.exts[database] == nil {
_, root, err := getRootFromContextForDatabase(ctx, database)
if err != nil {
return nil, err
}
} else if cv.exts.DiffersFrom(ctx, root) {
cv.exts, err = extensions.LoadExtensions(ctx, root)
cv.exts[database], err = extensions.LoadExtensions(ctx, root)
if err != nil {
return nil, err
}
}
return cv.exts, nil
return cv.exts[database], nil
}

// GetFunctionsCollectionFromContext returns the functions collection from the given context. Will always return a
// collection if no error is returned.
func GetFunctionsCollectionFromContext(ctx *sql.Context) (*functions.Collection, error) {
func GetFunctionsCollectionFromContext(ctx *sql.Context, database string) (*functions.Collection, error) {
cv, err := getContextValues(ctx)
if err != nil {
return nil, err
}
_, root, err := GetRootFromContext(ctx)
if err != nil {
return nil, err
}
if cv.funcs == nil {
cv.funcs, err = functions.LoadFunctions(ctx, root)
cv.funcs = make(map[string]*functions.Collection)
}
if len(database) == 0 {
database = ctx.GetCurrentDatabase()
}
if cv.funcs[database] == nil {
_, root, err := getRootFromContextForDatabase(ctx, database)
if err != nil {
return nil, err
}
} else if cv.funcs.DiffersFrom(ctx, root) {
cv.funcs, err = functions.LoadFunctions(ctx, root)
cv.funcs[database], err = functions.LoadFunctions(ctx, root)
if err != nil {
return nil, err
}
}
return cv.funcs, nil
return cv.funcs[database], nil
}

// GetProceduresCollectionFromContext returns the procedures collection from the given context. Will always return a
// collection if no error is returned.
func GetProceduresCollectionFromContext(ctx *sql.Context) (*procedures.Collection, error) {
func GetProceduresCollectionFromContext(ctx *sql.Context, database string) (*procedures.Collection, error) {
cv, err := getContextValues(ctx)
if err != nil {
return nil, err
}
_, root, err := GetRootFromContext(ctx)
if err != nil {
return nil, err
}
if cv.procs == nil {
cv.procs, err = procedures.LoadProcedures(ctx, root)
cv.procs = make(map[string]*procedures.Collection)
}
if len(database) == 0 {
database = ctx.GetCurrentDatabase()
}
if cv.procs[database] == nil {
_, root, err := getRootFromContextForDatabase(ctx, database)
if err != nil {
return nil, err
}
} else if cv.procs.DiffersFrom(ctx, root) {
cv.procs, err = procedures.LoadProcedures(ctx, root)
cv.procs[database], err = procedures.LoadProcedures(ctx, root)
if err != nil {
return nil, err
}
}
return cv.procs, nil
return cv.procs[database], nil
}

// GetSequencesCollectionFromContext returns the given sequence collection from the context for the database
Expand All @@ -371,6 +373,9 @@ func GetSequencesCollectionFromContext(ctx *sql.Context, database string) (*sequ
if cv.seqs == nil {
cv.seqs = make(map[string]*sequences.Collection)
}
if len(database) == 0 {
database = ctx.GetCurrentDatabase()
}
if cv.seqs[database] == nil {
_, root, err := getRootFromContextForDatabase(ctx, database)
if err != nil {
Expand All @@ -391,62 +396,75 @@ func GetTriggersCollectionFromContext(ctx *sql.Context, database string) (*trigg
if err != nil {
return nil, err
}
_, root, err := getRootFromContextForDatabase(ctx, database)
if err != nil {
return nil, err
}
if cv.trigs == nil {
cv.trigs, err = triggers.LoadTriggers(ctx, root)
cv.trigs = make(map[string]*triggers.Collection)
}
if len(database) == 0 {
database = ctx.GetCurrentDatabase()
}
if cv.trigs[database] == nil {
_, root, err := getRootFromContextForDatabase(ctx, database)
if err != nil {
return nil, err
}
} else if cv.trigs.DiffersFrom(ctx, root) {
cv.trigs, err = triggers.LoadTriggers(ctx, root)
cv.trigs[database], err = triggers.LoadTriggers(ctx, root)
if err != nil {
return nil, err
}
}
return cv.trigs, nil
return cv.trigs[database], nil
}

// GetTypesCollectionFromContext returns the given type collection from the context.
// Will always return a collection if no error is returned.
func GetTypesCollectionFromContext(ctx *sql.Context) (*typecollection.TypeCollection, error) {
func GetTypesCollectionFromContext(ctx *sql.Context, database string) (*typecollection.TypeCollection, error) {
cv, err := getContextValues(ctx)
if err != nil {
return nil, err
}
if cv.types == nil {
_, root, err := GetRootFromContext(ctx)
cv.types = make(map[string]*typecollection.TypeCollection)
}
if len(database) == 0 {
database = ctx.GetCurrentDatabase()
}
if cv.types[database] == nil {
_, root, err := getRootFromContextForDatabase(ctx, database)
if err != nil {
return nil, err
}
cv.types, err = typecollection.LoadTypes(ctx, root)
cv.types[database], err = typecollection.LoadTypes(ctx, root)
if err != nil {
return nil, err
}
}
return cv.types, nil
return cv.types[database], nil
}

// GetCastsCollectionFromContext returns the given casts collection from the context.
// Will always return a collection if no error is returned.
func GetCastsCollectionFromContext(ctx *sql.Context) (*casts.Collection, error) {
func GetCastsCollectionFromContext(ctx *sql.Context, database string) (*casts.Collection, error) {
cv, err := getContextValues(ctx)
if err != nil {
return nil, err
}
if cv.casts == nil {
_, root, err := GetRootFromContext(ctx)
cv.casts = make(map[string]*casts.Collection)
}
if len(database) == 0 {
database = ctx.GetCurrentDatabase()
}
if cv.casts[database] == nil {
_, root, err := getRootFromContextForDatabase(ctx, database)
if err != nil {
return nil, err
}
cv.casts, err = casts.LoadCasts(ctx, root)
cv.casts[database], err = casts.LoadCasts(ctx, root)
if err != nil {
return nil, err
}
}
return cv.casts, nil
return cv.casts[database], nil
}

// CloseContextRootFinalizer finalizes any changes persisted within the context by writing them to the working root.
Expand Down Expand Up @@ -490,52 +508,59 @@ func updateSessionRootForDatabase(ctx *sql.Context, db string, cv *contextValues
delete(cv.seqs, db)
}

if cv.funcs != nil && cv.funcs.DiffersFrom(ctx, root) {
retRoot, err := cv.funcs.UpdateRoot(ctx, newRoot)
if cv.funcs != nil && cv.funcs[db] != nil && cv.funcs[db].DiffersFrom(ctx, root) {
retRoot, err := cv.funcs[db].UpdateRoot(ctx, newRoot)
if err != nil {
return err
}
newRoot = retRoot.(*RootValue)
cv.funcs = nil
delete(cv.funcs, db)
}

if cv.procs != nil && cv.procs.DiffersFrom(ctx, root) {
retRoot, err := cv.procs.UpdateRoot(ctx, newRoot)
if cv.procs != nil && cv.procs[db] != nil && cv.procs[db].DiffersFrom(ctx, root) {
retRoot, err := cv.procs[db].UpdateRoot(ctx, newRoot)
if err != nil {
return err
}
newRoot = retRoot.(*RootValue)
cv.procs = nil
delete(cv.procs, db)
}

if cv.trigs != nil && cv.trigs.DiffersFrom(ctx, root) {
retRoot, err := cv.trigs.UpdateRoot(ctx, newRoot)
if cv.trigs != nil && cv.trigs[db] != nil && cv.trigs[db].DiffersFrom(ctx, root) {
retRoot, err := cv.trigs[db].UpdateRoot(ctx, newRoot)
if err != nil {
return err
}
newRoot = retRoot.(*RootValue)
cv.trigs = nil
delete(cv.trigs, db)
}

if cv.exts != nil && cv.exts.DiffersFrom(ctx, root) {
retRoot, err := cv.exts.UpdateRoot(ctx, newRoot)
if cv.exts != nil && cv.exts[db] != nil && cv.exts[db].DiffersFrom(ctx, root) {
retRoot, err := cv.exts[db].UpdateRoot(ctx, newRoot)
if err != nil {
return err
}
newRoot = retRoot.(*RootValue)
cv.exts = nil
delete(cv.exts, db)
}

if cv.types != nil {
retRoot, err := cv.types.UpdateRoot(ctx, newRoot)
if cv.types != nil && cv.types[db] != nil {
retRoot, err := cv.types[db].UpdateRoot(ctx, newRoot)
if err != nil {
return err
}
newRoot = retRoot.(*RootValue)
cv.types = nil
delete(cv.types, db)
}

// TODO: need to be able to persist cv.casts without an empty collection updating the root (no value != empty value)
if cv.casts != nil && cv.casts[db] != nil && cv.casts[db].DiffersFrom(ctx, root) {
retRoot, err := cv.casts[db].UpdateRoot(ctx, newRoot)
if err != nil {
return err
}
newRoot = retRoot.(*RootValue)
delete(cv.casts, db)
}

// Setting the session working root doesn't do a check to see if anything actually changed or not before marking that
// branch state dirty, and dolt only allows a single dirty working set per commit. So it's important here to only
Expand Down
6 changes: 3 additions & 3 deletions core/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ func Init() {
id.RegisterListener(sequenceIDListener{}, id.Section_Table)
typecollection.GetSqlTableFromContext = GetSqlTableFromContext
typecollection.GetSchemaName = GetSchemaName
pgtypes.GetTypesCollectionFromContext = func(ctx *sql.Context) (pgtypes.TypeCollection, error) {
return GetTypesCollectionFromContext(ctx)
pgtypes.GetTypesCollectionFromContext = func(ctx *sql.Context, database string) (pgtypes.TypeCollection, error) {
return GetTypesCollectionFromContext(ctx, database)
}
pgtypes.GetAssignmentCast = func(ctx *sql.Context, sourceType *pgtypes.DoltgresType, targetType *pgtypes.DoltgresType) (pgtypes.Cast, error) {
castsColl, err := GetCastsCollectionFromContext(ctx)
castsColl, err := GetCastsCollectionFromContext(ctx, "")
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion server/analyzer/foreign_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ func validateForeignKeyDefinition(ctx *sql.Context, fkDef sql.ForeignKeyConstrai
var castsColl *casts.Collection
if len(fkDef.Columns) > 0 {
var err error
castsColl, err = core.GetCastsCollectionFromContext(ctx)
// TODO: which database is this supposed to use?
castsColl, err = core.GetCastsCollectionFromContext(ctx, "")
Comment thread
Hydrocharged marked this conversation as resolved.
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions server/analyzer/resolve_routine_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ import (
func ResolveProcedureDefaults(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
switch n := node.(type) {
case *pgnodes.Call:
procCollection, err := core.GetProceduresCollectionFromContext(ctx)
procCollection, err := core.GetProceduresCollectionFromContext(ctx, "")
if err != nil {
return nil, transform.SameTree, err
}
typesCollection, err := core.GetTypesCollectionFromContext(ctx)
typesCollection, err := core.GetTypesCollectionFromContext(ctx, "")
if err != nil {
return nil, transform.SameTree, err
}
Expand Down
6 changes: 5 additions & 1 deletion server/analyzer/resolve_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,11 @@ func resolveType(ctx *sql.Context, db sql.Database, typ *pgtypes.DoltgresType) (
if typ.IsResolvedType() {
return typ, nil
}
typs, err := core.GetTypesCollectionFromContext(ctx)
var dbname string
if db != nil {
dbname = db.Name()
}
typs, err := core.GetTypesCollectionFromContext(ctx, dbname)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion server/doltgres_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ func (h *DoltgresHandler) convertBindParameters(ctx *sql.Context, types []uint32
if err != nil {
return nil, err
}
typeColl, err := core.GetTypesCollectionFromContext(ctx)
typeColl, err := core.GetTypesCollectionFromContext(ctx, "")
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions server/expression/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (array *Array) Children() []sql.Expression {
func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) {
resultTyp := array.coercedType.ArrayBaseType()
values := make([]any, len(array.children))
castsColl, err := core.GetCastsCollectionFromContext(ctx)
castsColl, err := core.GetCastsCollectionFromContext(ctx, "")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -192,7 +192,7 @@ func (array *Array) getTargetType(ctx *sql.Context, children ...sql.Expression)
if schemaName == "" {
schemaName, _ = core.GetCurrentSchema(ctx)
}
if typeColl, tcErr := core.GetTypesCollectionFromContext(ctx); tcErr == nil && typeColl != nil {
if typeColl, tcErr := core.GetTypesCollectionFromContext(ctx, ""); tcErr == nil && typeColl != nil {
if resolved, rErr := typeColl.GetType(ctx, id.NewType(schemaName, targetType.ID.TypeName())); rErr == nil && resolved != nil {
targetType = resolved
}
Expand Down
2 changes: 1 addition & 1 deletion server/expression/assignment_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (ac *AssignmentCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
if err != nil || val == nil {
return val, err
}
castsColl, err := core.GetCastsCollectionFromContext(ctx)
castsColl, err := core.GetCastsCollectionFromContext(ctx, "")
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion server/expression/column_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (expr *ColumnAccess) WithChildren(ctx *sql.Context, children ...sql.Express
return nil, errors.New("column access is only valid for Doltgres types")
}
if !doltgresType.IsResolvedType() {
typeColl, err := core.GetTypesCollectionFromContext(ctx)
typeColl, err := core.GetTypesCollectionFromContext(ctx, "")
if err != nil {
return nil, err
}
Expand Down
Loading
Loading