1+ from __future__ import annotations
2+
3+ import typing as t
14from collections import defaultdict
25
36from sqlglot import alias , exp
710from sqlglot .errors import OptimizeError
811from sqlglot .helper import seq_get
912
13+ if t .TYPE_CHECKING :
14+ from sqlglot ._typing import E
15+ from sqlglot .schema import Schema
16+ from sqlglot .dialects .dialect import DialectType
17+
1018# Sentinel value that means an outer query selecting ALL columns
1119SELECT_ALL = object ()
1220
@@ -16,7 +24,12 @@ def default_selection(is_agg: bool) -> exp.Alias:
1624 return alias (exp .Max (this = exp .Literal .number (1 )) if is_agg else "1" , "_" )
1725
1826
19- def pushdown_projections (expression , schema = None , remove_unused_selections = True ):
27+ def pushdown_projections (
28+ expression : E ,
29+ schema : t .Optional [t .Dict | Schema ] = None ,
30+ remove_unused_selections : bool = True ,
31+ dialect : DialectType = None ,
32+ ) -> E :
2033 """
2134 Rewrite sqlglot AST to remove unused columns projections.
2235
@@ -34,9 +47,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
3447 sqlglot.Expression: optimized expression
3548 """
3649 # Map of Scope to all columns being selected by outer queries.
37- schema = ensure_schema (schema )
38- source_column_alias_count = {}
39- referenced_columns = defaultdict (set )
50+ schema = ensure_schema (schema , dialect = dialect )
51+ source_column_alias_count : t . Dict [ exp . Expression | Scope , int ] = {}
52+ referenced_columns : t . DefaultDict [ Scope , t . Set [ str | object ]] = defaultdict (set )
4053
4154 # We build the scope tree (which is traversed in DFS postorder), then iterate
4255 # over the result in reverse order. This should ensure that the set of selected
@@ -69,12 +82,12 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
6982 if scope .expression .args .get ("by_name" ):
7083 referenced_columns [right ] = referenced_columns [left ]
7184 else :
72- referenced_columns [right ] = [
85+ referenced_columns [right ] = {
7386 right .expression .selects [i ].alias_or_name
7487 for i , select in enumerate (left .expression .selects )
7588 if SELECT_ALL in parent_selections
7689 or select .alias_or_name in parent_selections
77- ]
90+ }
7891
7992 if isinstance (scope .expression , exp .Select ):
8093 if remove_unused_selections :
0 commit comments