Skip to content

Commit c31e2e3

Browse files
Refactor(Optimizer): improve typing coverage of optimizer modules (#7446)
* refactor (optimizer): add various type annotations to public functions * refactor(optimizer): add various type annotations to simplify module * refactor(optimizer): improve documnetation for `OptimizerFn` Protocol * Chore: ran ruff format * refactor (optimizer): going back to original import pattern for `normalized`. the caller site is currently untypted anyway * refactor (optimizer): make `_simplify_integer_cast` generic, optimize function body to avoid redundant/type unsafe double isinstance check * fix: make `S` TypeVar a runtime concrete value for mypc * fix: for some reason mypc don't work well with generics for `_simplify_integer_cast`. trying with overloads * fix: revert `_simplify_integer_cast` to original typing to avoid mypc issues * Fix: Use a Protocol for `eval_boolean`, since it can take any comparable value * Chore: run formatter * Refactor (optimizer): widen some helper functions input types * refactor (optimizer): use `TypeIs` in `simplify_literals` for type narrowing * chore: ruff format * fix: added various return types to `simplify` functions * chore: ruff format * fix: annotate `_simplify_comparison` return type * refactor: add type hints to `extract_interval` and `date_literal` * fix: widen return type of `date_literal` to satisfy mypc
1 parent 4deba62 commit c31e2e3

9 files changed

Lines changed: 116 additions & 79 deletions

sqlglot/optimizer/annotate_types.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,7 @@
2525
BinaryCoercionFunc = t.Callable[
2626
[exp.Expr, exp.Expr], t.Optional[t.Union[exp.DataType, exp.DType]]
2727
]
28-
BinaryCoercions = t.Dict[
29-
t.Tuple[exp.DType, exp.DType],
30-
BinaryCoercionFunc,
31-
]
28+
BinaryCoercions = dict[tuple[exp.DType, exp.DType], BinaryCoercionFunc]
3229

3330
from sqlglot.dialects.dialect import DialectType
3431
from sqlglot.typing import ExprMetadataType
@@ -47,9 +44,9 @@
4744

4845
def annotate_types(
4946
expression: E,
50-
schema: t.Optional[t.Dict | Schema] = None,
47+
schema: dict[str, object] | Schema | None = None,
5148
expression_metadata: t.Optional[ExprMetadataType] = None,
52-
coerces_to: t.Optional[t.Dict[exp.DType, t.Set[exp.DType]]] = None,
49+
coerces_to: dict[exp.DType, set[exp.DType]] | None = None,
5350
dialect: DialectType = None,
5451
overwrite_types: bool = True,
5552
) -> E:

sqlglot/optimizer/isolate_table_selects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
def isolate_table_selects(
1717
expression: E,
18-
schema: t.Optional[t.Dict | Schema] = None,
18+
schema: dict[str, object] | Schema | None = None,
1919
dialect: DialectType = None,
2020
) -> E:
2121
schema = ensure_schema(schema, dialect=dialect)

sqlglot/optimizer/optimize_joins.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from __future__ import annotations
2-
3-
import typing as t
2+
from collections.abc import Iterable
43

54
from sqlglot import exp
65
from sqlglot.helper import tsort
76

87
JOIN_ATTRS = ("on", "side", "kind", "using", "method")
98

109

11-
def optimize_joins(expression):
10+
def optimize_joins(expression: exp.Expr) -> exp.Expr:
1211
"""
1312
Removes cross joins if possible and reorder joins based on predicate dependencies.
1413
@@ -24,8 +23,8 @@ def optimize_joins(expression):
2423
if not _is_reorderable(joins):
2524
continue
2625

27-
references = {}
28-
cross_joins = []
26+
references: dict[str, list[exp.Join]] = {}
27+
cross_joins: list[tuple[str, exp.Join]] = []
2928

3029
for join in joins:
3130
tables = other_table_names(join)
@@ -58,7 +57,7 @@ def optimize_joins(expression):
5857
return expression
5958

6059

61-
def reorder_joins(expression):
60+
def reorder_joins(expression) -> exp.Expr:
6261
"""
6362
Reorder joins by topological sort order based on predicate references.
6463
"""
@@ -82,7 +81,7 @@ def reorder_joins(expression):
8281
return expression
8382

8483

85-
def normalize(expression):
84+
def normalize(expression: exp.Expr) -> exp.Expr:
8685
"""
8786
Remove INNER and OUTER from joins as they are optional.
8887
"""
@@ -101,12 +100,12 @@ def normalize(expression):
101100
return expression
102101

103102

104-
def other_table_names(join: exp.Join) -> t.Set[str]:
103+
def other_table_names(join: exp.Join) -> set[str]:
105104
on = join.args.get("on")
106105
return exp.column_table_names(on, join.alias_or_name) if on else set()
107106

108107

109-
def _is_reorderable(joins: t.List[exp.Join]) -> bool:
108+
def _is_reorderable(joins: Iterable[exp.Join]) -> bool:
110109
"""
111110
Checks if joins can be reordered without changing query semantics.
112111

sqlglot/optimizer/optimizer.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,23 @@
2121
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
2222
from sqlglot.schema import ensure_schema
2323

24-
RULES = (
24+
25+
class OptimizerFn(t.Protocol):
26+
"""Protocol for optimizer rules functions.
27+
28+
An optimizer rule:
29+
30+
- **Must** accept an `Expr` as the first argument
31+
- Can take undefined `*args` and `**kwargs` afterwards
32+
- **Must** return an `Expr`.
33+
Note:
34+
We use `typing.Protocol` here because this is not doable with `collections.abc.Callable`.
35+
"""
36+
37+
def __call__(self, expression: exp.Expr, *args: t.Any, **kwargs: t.Any) -> exp.Expr: ...
38+
39+
40+
RULES: tuple[OptimizerFn, ...] = (
2541
qualify,
2642
pushdown_projections,
2743
normalize,
@@ -41,13 +57,13 @@
4157

4258
def optimize(
4359
expression: str | exp.Expr,
44-
schema: t.Optional[dict | Schema] = None,
45-
db: t.Optional[str | exp.Identifier] = None,
46-
catalog: t.Optional[str | exp.Identifier] = None,
60+
schema: dict[str, object] | Schema | None = None,
61+
db: str | exp.Identifier | None = None,
62+
catalog: str | exp.Identifier | None = None,
4763
dialect: DialectType = None,
48-
rules: Sequence[t.Callable] = RULES,
64+
rules: Sequence[OptimizerFn] = RULES,
4965
sql: t.Optional[str] = None,
50-
**kwargs,
66+
**kwargs: object,
5167
) -> exp.Expr:
5268
"""
5369
Rewrite a sqlglot AST into an optimized form.

sqlglot/optimizer/pushdown_projections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def default_selection(is_agg: bool) -> exp.Alias:
2626

2727
def pushdown_projections(
2828
expression: E,
29-
schema: t.Optional[t.Dict | Schema] = None,
29+
schema: dict[str, object] | Schema | None = None,
3030
remove_unused_selections: bool = True,
3131
dialect: DialectType = None,
3232
) -> E:

sqlglot/optimizer/qualify.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,21 @@
1818
def qualify(
1919
expression: exp.Expr,
2020
dialect: DialectType = None,
21-
db: t.Optional[str] = None,
22-
catalog: t.Optional[str] = None,
23-
schema: t.Optional[dict | Schema] = None,
21+
db: str | None = None,
22+
catalog: str | None = None,
23+
schema: dict[str, object] | Schema | None = None,
2424
expand_alias_refs: bool = True,
2525
expand_stars: bool = True,
26-
infer_schema: t.Optional[bool] = None,
26+
infer_schema: bool | None = None,
2727
isolate_tables: bool = False,
2828
qualify_columns: bool = True,
2929
allow_partial_qualification: bool = False,
3030
validate_qualify_columns: bool = True,
3131
quote_identifiers: bool = True,
3232
identify: bool = True,
3333
canonicalize_table_aliases: bool = False,
34-
on_qualify: t.Optional[t.Callable[[exp.Expr], None]] = None,
35-
sql: t.Optional[str] = None,
34+
on_qualify: t.Callable[[exp.Expr], None] | None = None,
35+
sql: str | None = None,
3636
) -> exp.Expr:
3737
"""
3838
Rewrite sqlglot AST to have normalized and qualified tables and columns.

sqlglot/optimizer/qualify_columns.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
def qualify_columns(
2222
expression: exp.Expr,
23-
schema: dict | Schema,
23+
schema: dict[str, object] | Schema,
2424
expand_alias_refs: bool = True,
2525
expand_stars: bool = True,
2626
infer_schema: t.Optional[bool] = None,

0 commit comments

Comments
 (0)