Skip to content

Commit 9e4b3d1

Browse files
authored
perf(optimizer): speed up qualify by ~24% and optimize by ~17% (#7724)
- Add Expression.meta_get(key, default), a non-allocating meta read: the meta property allocates {} on first read, so hot per-node read paths (normalize_identifiers' prune, simplify's FINAL checks) were allocating a dict on every AST node and slowing every subsequent copy(). Convert all read-only .meta.get(...) sites; writes still go through meta.
- Scope._collect: a single tuple-isinstance gate (COLLECTIBLE_TYPES) lets non-collectible nodes skip the classification chain, paid 3x per query (qualify_tables, qualify_columns, validate_qualify_columns).
- walk_in_scope: only CTE/Query nodes can start child scopes, so one isinstance gate replaces up to 4 checks + 2 parent loads per node.
- Scope.local_columns: identity-based set instead of structural-equality set of column nodes.
- parse_identifier: skip the tokenizer/parser round-trip for names matching SAFE_IDENTIFIER_RE (provably identical output).
- _expand_using: early-exit for scopes with no joins / no USING joins, preserving the missing-source validation.
- quote_identifiers: direct walk over Identifier nodes instead of transform(), which only does replacement bookkeeping here.

Qualified SQL is byte-identical on all TPC-H/TPC-DS fixture queries. Pure Python and mypyc release builds both pass the full suite.
1 parent 6ae4b49 commit 9e4b3d1

12 files changed

Lines changed: 94 additions & 44 deletions

File tree

sqlglot/dialects/bigquery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ def normalize_identifier(self, expression: E) -> E:
147147
or (
148148
isinstance(parent, exp.Table)
149149
and parent.db
150-
and (parent.meta.get("quoted_table") or not parent.meta.get("maybe_column"))
150+
and (parent.meta_get("quoted_table") or not parent.meta_get("maybe_column"))
151151
)
152-
or expression.meta.get("is_table")
152+
or expression.meta_get("is_table")
153153
)
154154
if not case_sensitive:
155155
expression.set("this", expression.this.lower())

sqlglot/expressions/builders.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Ide
319319
Returns:
320320
The identifier ast node.
321321
"""
322+
if isinstance(name, str) and SAFE_IDENTIFIER_RE.match(name):
323+
# Simple names parse to a single unquoted identifier in all dialects, so we can
324+
# avoid the tokenizer/parser round-trip for them.
325+
return Identifier(this=name, quoted=False)
326+
322327
try:
323328
expression = maybe_parse(name, dialect=dialect, into=Identifier)
324329
except (ParseError, TokenError):
@@ -822,7 +827,7 @@ def replace_tables(
822827
mapping = {normalize_table_name(k, dialect=dialect): v for k, v in mapping.items()}
823828

824829
def _replace_tables(node: Expr) -> Expr:
825-
if isinstance(node, Table) and node.meta.get("replace") is not False:
830+
if isinstance(node, Table) and node.meta_get("replace") is not False:
826831
original = normalize_table_name(node, dialect=dialect)
827832
new_name = mapping.get(original)
828833

sqlglot/expressions/core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ def is_leaf(self) -> bool:
240240
def meta(self) -> dict[str, t.Any]:
241241
raise NotImplementedError
242242

243+
def meta_get(self, key: str, default: t.Any = None) -> t.Any:
244+
raise NotImplementedError
245+
243246
def __deepcopy__(self, memo: t.Any) -> Expr:
244247
raise NotImplementedError
245248

@@ -990,6 +993,11 @@ def meta(self) -> dict[str, t.Any]:
990993
self._meta = {}
991994
return self._meta
992995

996+
def meta_get(self, key: str, default: t.Any = None) -> t.Any:
997+
"""Reads a meta value without allocating the meta dict (unlike the `meta` property)."""
998+
meta = self._meta
999+
return meta.get(key, default) if meta is not None else default
1000+
9931001
def __deepcopy__(self, memo: t.Any) -> Expr:
9941002
root = self.__class__()
9951003
stack: list[tuple[Expr, Expr]] = [(self, root)]

sqlglot/generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4554,7 +4554,7 @@ def function_fallback_sql(self, expression: exp.Func) -> str:
45544554
args.append(arg_value)
45554555

45564556
if self.dialect.PRESERVE_ORIGINAL_NAMES:
4557-
name = (expression._meta and expression.meta.get("name")) or expression.sql_name()
4557+
name = expression.meta_get("name") or expression.sql_name()
45584558
else:
45594559
name = expression.sql_name()
45604560

@@ -5258,7 +5258,7 @@ def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, te
52585258
)
52595259
return self.sql(this)
52605260

5261-
if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
5261+
if self.IGNORE_NULLS_IN_FUNC and not expression.meta_get("inline"):
52625262
if self.IGNORE_NULLS_BEFORE_ORDER:
52635263
# The first modifier here will be the one closest to the AggFunc's arg
52645264
mods = sorted(

sqlglot/generators/bigquery.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def _levenshtein_sql(self: BigQueryGenerator, expression: exp.Levenshtein) -> st
201201

202202

203203
def _json_extract_sql(self: BigQueryGenerator, expression: JSON_EXTRACT_TYPE) -> str:
204-
name = (expression._meta and expression.meta.get("name")) or expression.sql_name()
204+
name = expression.meta_get("name") or expression.sql_name()
205205
upper = name.upper()
206206

207207
dquote_escaping = upper in DQUOTES_ESCAPING_JSON_FUNCTIONS
@@ -565,7 +565,7 @@ def mod_sql(self, expression: exp.Mod) -> str:
565565
)
566566

567567
def column_parts(self, expression: exp.Column) -> str:
568-
if expression.meta.get("quoted_column"):
568+
if expression.meta_get("quoted_column"):
569569
# If a column reference is of the form `dataset.table`.name, we need
570570
# to preserve the quoted table path, otherwise the reference breaks
571571
table_parts = ".".join(p.name for p in expression.parts[:-1])
@@ -583,7 +583,7 @@ def table_parts(self, expression: exp.Table) -> str:
583583
#
584584
# - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x.y` -> cross join
585585
# - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x`.`y` -> implicit unnest
586-
if expression.meta.get("quoted_table"):
586+
if expression.meta_get("quoted_table"):
587587
table_parts = ".".join(p.name for p in expression.parts)
588588
return self.sql(exp.Identifier(this=table_parts, quoted=True))
589589

sqlglot/optimizer/annotate_types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def annotate_scope(self, scope: Scope) -> None:
400400
isinstance(source, Scope)
401401
and isinstance(source.expression, exp.Query)
402402
and (
403-
source.expression.meta.get("query_type") or exp.DType.UNKNOWN.into_expr()
403+
source.expression.meta_get("query_type") or exp.DType.UNKNOWN.into_expr()
404404
).is_type(exp.DType.STRUCT)
405405
):
406406
self._set_type(table_column, source.expression.meta["query_type"])
@@ -497,7 +497,7 @@ def _annotate_expression(
497497
else:
498498
self._set_type(expr, exp.DType.UNKNOWN)
499499

500-
if expr.is_type(exp.DType.JSON) and (dot_parts := expr.meta.get("dot_parts")):
500+
if expr.is_type(exp.DType.JSON) and (dot_parts := expr.meta_get("dot_parts")):
501501
# JSON dot access is case sensitive across all dialects, so we need to undo the normalization.
502502
i = iter(dot_parts)
503503
parent = expr.parent
@@ -740,7 +740,7 @@ def _annotate_binary(self, expression: B) -> B:
740740
self._annotate_by_args(expression, left, right)
741741

742742
if isinstance(expression, exp.Is) or (
743-
left.meta.get("nonnull") is True and right.meta.get("nonnull") is True
743+
left.meta_get("nonnull") is True and right.meta_get("nonnull") is True
744744
):
745745
expression.meta["nonnull"] = True
746746

@@ -752,7 +752,7 @@ def _annotate_unary(self, expression: E) -> E:
752752
else:
753753
self._set_type(expression, expression.this.type)
754754

755-
if expression.this.meta.get("nonnull") is True:
755+
if expression.this.meta_get("nonnull") is True:
756756
expression.meta["nonnull"] = True
757757

758758
return expression

sqlglot/optimizer/normalize_identifiers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def normalize_identifiers(expression, dialect=None, store_original_column_identi
6363
if isinstance(expression, str):
6464
expression = exp.parse_identifier(expression, dialect=dialect)
6565

66-
for node in expression.walk(prune=lambda n: bool(n.meta.get("case_sensitive"))):
67-
if not node.meta.get("case_sensitive"):
66+
for node in expression.walk(prune=lambda n: bool(n.meta_get("case_sensitive"))):
67+
if not node.meta_get("case_sensitive"):
6868
if store_original_column_identifiers and isinstance(node, exp.Column):
6969
# TODO: This does not handle non-column cases, e.g PARSE_JSON(...).key
7070
parent = node
@@ -73,6 +73,7 @@ def normalize_identifiers(expression, dialect=None, store_original_column_identi
7373

7474
node.meta["dot_parts"] = [p.name for p in parent.parts]
7575

76-
dialect.normalize_identifier(node)
76+
if isinstance(node, exp.Identifier):
77+
dialect.normalize_identifier(node)
7778

7879
return expression

sqlglot/optimizer/qualify_columns.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ def validate_qualify_columns(expression: E, sql: str | None = None) -> E:
124124
if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
125125
column = scope.external_columns[0]
126126
for_table = f" for table: '{column.table}'" if column.table else ""
127-
line = column.this.meta.get("line")
128-
col = column.this.meta.get("col")
129-
start = column.this.meta.get("start")
130-
end = column.this.meta.get("end")
127+
line = column.this.meta_get("line")
128+
col = column.this.meta_get("col")
129+
start = column.this.meta_get("start")
130+
end = column.this.meta_get("end")
131131

132132
error_msg = f"Column '{column.name}' could not be resolved{for_table}."
133133
if line and col:
@@ -142,10 +142,10 @@ def validate_qualify_columns(expression: E, sql: str | None = None) -> E:
142142

143143
if all_unqualified_columns:
144144
first_column = all_unqualified_columns[0]
145-
line = first_column.this.meta.get("line")
146-
col = first_column.this.meta.get("col")
147-
start = first_column.this.meta.get("start")
148-
end = first_column.this.meta.get("end")
145+
line = first_column.this.meta_get("line")
146+
col = first_column.this.meta_get("col")
147+
start = first_column.this.meta_get("start")
148+
end = first_column.this.meta_get("end")
149149

150150
error_msg = f"Ambiguous column '{first_column.name}'"
151151
if line and col:
@@ -204,6 +204,9 @@ def _update_source_columns(source_name: str) -> None:
204204
columns[column_name] = source_name
205205

206206
joins = list(scope.find_all(exp.Join))
207+
if not joins:
208+
return {}
209+
207210
names = {join.alias_or_name for join in joins}
208211
ordered = [key for key in scope.selected_sources if key not in names]
209212

@@ -213,6 +216,9 @@ def _update_source_columns(source_name: str) -> None:
213216
# Mapping of automatically joined column names to an ordered set of source names (dict).
214217
column_tables: dict[str, dict[str, t.Any]] = {}
215218

219+
if not any(join.args.get("using") for join in joins):
220+
return column_tables
221+
216222
for source_name in ordered:
217223
_update_source_columns(source_name)
218224

@@ -954,9 +960,15 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expr) -> None:
954960

955961
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
956962
"""Makes sure all identifiers that need to be quoted are quoted."""
957-
return expression.transform(
958-
Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
959-
)
963+
dialect = Dialect.get_or_raise(dialect)
964+
965+
# `quote_identifier` only mutates identifiers in place, so we avoid `transform` here
966+
# because its node replacement machinery is wasteful for this case.
967+
for node in expression.walk():
968+
if isinstance(node, exp.Identifier):
969+
dialect.quote_identifier(node, identify=identify)
970+
971+
return expression
960972

961973

962974
def pushdown_cte_alias_columns(scope: Scope) -> None:

sqlglot/optimizer/scope.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@
2121

2222
ROW_LEVEL_AGG_FUNCS = (exp.Count,)
2323

24+
# The node types `Scope._collect` classifies, roughly ordered by frequency. `exp.Query`
25+
# covers subqueries, and `exp.UDTF` covers laterals.
26+
COLLECTIBLE_TYPES = (
27+
exp.Column,
28+
exp.Dot,
29+
exp.Table,
30+
exp.Query,
31+
exp.UDTF,
32+
exp.CTE,
33+
exp.Star,
34+
exp.TableColumn,
35+
exp.JoinHint,
36+
)
37+
2438

2539
class ScopeType(Enum):
2640
ROOT = auto()
@@ -170,7 +184,9 @@ def _collect(self) -> None:
170184
self._column_index = set()
171185

172186
for node in self.walk():
173-
if node is self.expression:
187+
# Most nodes (identifiers, literals, operators etc.) aren't collectible, so a
188+
# single isinstance gate lets them skip the classification chain below.
189+
if node is self.expression or not isinstance(node, COLLECTIBLE_TYPES):
174190
continue
175191

176192
if isinstance(node, exp.Dot) and node.is_star:
@@ -460,8 +476,10 @@ def local_columns(self) -> list[exp.Column]:
460476
list[exp.Column]: Column instances that reference sources in the current scope.
461477
"""
462478
if self._local_columns is None:
463-
external_columns = set(self.external_columns)
464-
self._local_columns = [c for c in self.columns if c not in external_columns]
479+
# Compare nodes by identity: structural equality would conflate distinct
480+
# column nodes that happen to look the same, and is much more expensive.
481+
external_column_ids = {id(c) for c in self.external_columns}
482+
self._local_columns = [c for c in self.columns if id(c) not in external_column_ids]
465483

466484
return self._local_columns
467485

@@ -950,11 +968,17 @@ def walk_in_scope(
950968

951969
yield node
952970

953-
if node is not expression and (
954-
isinstance(node, exp.CTE)
955-
or (isinstance(node.parent, (exp.From, exp.Join)) and _is_derived_table(node))
956-
or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
957-
or isinstance(node, exp.UNWRAPPED_QUERIES)
971+
# Only CTEs and Queries can start child scopes; checking that first lets all
972+
# other nodes (the vast majority) skip the rest of the boundary checks.
973+
if (
974+
node is not expression
975+
and isinstance(node, (exp.CTE, exp.Query))
976+
and (
977+
isinstance(node, exp.CTE)
978+
or (isinstance(node.parent, (exp.From, exp.Join)) and _is_derived_table(node))
979+
or isinstance(node.parent, exp.UDTF)
980+
or isinstance(node, exp.UNWRAPPED_QUERIES)
981+
)
958982
):
959983
if isinstance(node, (exp.Subquery, exp.UDTF)):
960984
for key in ("joins", "laterals", "pivots"):

sqlglot/optimizer/simplify.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,9 @@ def simplify(
596596
joins: list[exp.Join] = []
597597

598598
for node in expression.walk(
599-
prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
599+
prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta_get(FINAL))
600600
):
601-
if node.meta.get(FINAL):
601+
if node.meta_get(FINAL):
602602
continue
603603

604604
# group by expressions cannot be simplified, for example
@@ -687,10 +687,10 @@ def _simplify(
687687
original.replace(node)
688688

689689
for n in node.iter_expressions(reverse=True):
690-
if n.meta.get(FINAL):
690+
if n.meta_get(FINAL):
691691
raise
692692
pre_transformation_stack.extend(
693-
n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
693+
n for n in node.iter_expressions(reverse=True) if not n.meta_get(FINAL)
694694
)
695695
post_transformation_stack.append((node, parent))
696696

@@ -937,7 +937,7 @@ def remove_complements(self, expression: object, root: bool = True) -> object:
937937
ops = set(expression.flatten())
938938
for op in ops:
939939
if isinstance(op, exp.Not) and op.this in ops:
940-
if expression.meta.get("nonnull") is True:
940+
if expression.meta_get("nonnull") is True:
941941
return exp.false() if isinstance(expression, exp.And) else exp.true()
942942

943943
return expression

0 commit comments

Comments
 (0)