99from sqlglot .errors import SqlglotError
1010from sqlglot .optimizer import Scope , build_scope , find_all_in_scope , normalize_identifiers , qualify
1111from sqlglot .optimizer .scope import ScopeType
12+ from sqlglot .schema import ensure_schema
1213
1314if t .TYPE_CHECKING :
1415 from sqlglot .dialects .dialect import DialectType
@@ -133,6 +134,8 @@ def lineage(
133134 copy = copy ,
134135 )
135136
137+ schema = ensure_schema (schema , dialect = dialect )
138+
136139 if not scope :
137140 expression = qualify .qualify (
138141 expression ,
@@ -162,6 +165,7 @@ def lineage(
162165 scope ,
163166 dialect ,
164167 trim_selects = trim_selects ,
168+ schema = schema ,
165169 _cache = cache ,
166170 _scope_meta = scope_meta ,
167171 on_node = on_node ,
@@ -180,6 +184,7 @@ def lineage(
180184 scope ,
181185 dialect ,
182186 trim_selects = trim_selects ,
187+ schema = schema ,
183188 _cache = cache ,
184189 _scope_meta = scope_meta ,
185190 on_node = on_node ,
@@ -197,6 +202,7 @@ def to_node(
197202 source_name : str | None = None ,
198203 reference_node_name : str | None = None ,
199204 trim_selects : bool = True ,
205+ schema : Schema | None = None ,
200206 _cache : dict [tuple , Node ] | None = None ,
201207 _scope_meta : dict [int , tuple [bool , dict [str , exp .Expr ]]] | None = None ,
202208 on_node : t .Callable [[Node ], None ] | None = None ,
@@ -248,6 +254,7 @@ def to_node(
248254 source_name = source_name ,
249255 reference_node_name = reference_node_name ,
250256 trim_selects = trim_selects ,
257+ schema = schema ,
251258 _cache = _cache ,
252259 _scope_meta = _scope_meta ,
253260 on_node = on_node ,
@@ -288,6 +295,7 @@ def to_node(
288295 source_name = source_name ,
289296 reference_node_name = reference_node_name ,
290297 trim_selects = trim_selects ,
298+ schema = schema ,
291299 _cache = _cache ,
292300 _scope_meta = _scope_meta ,
293301 on_node = on_node ,
@@ -337,6 +345,7 @@ def to_node(
337345 dialect = dialect ,
338346 upstream = node ,
339347 trim_selects = trim_selects ,
348+ schema = schema ,
340349 _cache = _cache ,
341350 _scope_meta = _scope_meta ,
342351 on_node = on_node ,
@@ -373,7 +382,18 @@ def to_node(
373382
374383 pivots = scope .pivots
375384 pivot = pivots [0 ] if len (pivots ) == 1 else None
376- pivot_column_mapping = _pivot_column_mapping (pivot ) if pivot else {}
385+ pivot_renames : dict [str , str ] = {}
386+ pivot_column_mapping : dict [str , list [exp .Column ]] = {}
387+
388+ if pivot :
389+ pivot_renames = _pivot_output_renames (pivot , scope , schema )
390+ pivot_column_mapping = _pivot_column_mapping (pivot )
391+ if pivot_renames :
392+ pivot_column_mapping = {
393+ post : pivot_column_mapping [pre ]
394+ for post , pre in pivot_renames .items ()
395+ if pre in pivot_column_mapping
396+ }
377397
378398 for c in source_columns :
379399 table = c .table
@@ -397,6 +417,7 @@ def to_node(
397417 source_name = source_names .get (table ) or source_name ,
398418 reference_node_name = reference_node_name ,
399419 trim_selects = trim_selects ,
420+ schema = schema ,
400421 _cache = _cache ,
401422 _scope_meta = _scope_meta ,
402423 on_node = on_node ,
@@ -412,10 +433,22 @@ def to_node(
412433 # pivoted source -- adapt column to be from the implicit pivoted source.
413434 pivot_parent = pivot .parent
414435 downstream_columns .append (
415- exp .column (c .this , table = pivot_parent .alias_or_name if pivot_parent else "" )
436+ exp .column (
437+ pivot_renames .get (c .name , c .this ),
438+ table = pivot_parent .alias_or_name if pivot_parent else None ,
439+ )
416440 )
417441
418442 for downstream_column in downstream_columns :
443+ if not downstream_column .table :
444+ # Some dialects (e.g. bigquery) don't qualify the IN-list columns,
445+ # but they can only come from the pivoted source
446+ pivot_parent = pivot .parent
447+ downstream_column = exp .column (
448+ downstream_column .this ,
449+ table = pivot_parent .alias_or_name if pivot_parent else None ,
450+ )
451+
419452 table = downstream_column .table
420453 col_source = scope .sources .get (table )
421454 if isinstance (col_source , exp .Table ) and not col_source .db :
@@ -432,6 +465,7 @@ def to_node(
432465 source_name = source_names .get (table ) or source_name ,
433466 reference_node_name = reference_node_name ,
434467 trim_selects = trim_selects ,
468+ schema = schema ,
435469 _cache = _cache ,
436470 _scope_meta = _scope_meta ,
437471 on_node = on_node ,
@@ -466,25 +500,66 @@ def to_node(
466500 return node
467501
468502
def _pre_pivot_column_names(
    source: exp.Expr | None, scope: Scope, schema: Schema | None
) -> list[str]:
    """
    Return the column names exposed by a (UN)PIVOT's source, before the operator applies.

    An empty list means the names could not be determined from the query or the schema.
    """
    # Derived table: the wrapped query's projections are the pre-pivot columns
    if isinstance(source, exp.DerivedTable) and isinstance(source.this, exp.Query):
        return source.this.named_selects

    if not isinstance(source, exp.Table):
        return []

    # Only a table reference without a db part is looked up among the CTE sources
    cte_scope = None if source.db else scope.cte_sources.get(source.name)
    if isinstance(cte_scope, Scope) and isinstance(cte_scope.expression, exp.Query):
        return cte_scope.expression.named_selects

    # Otherwise fall back to the schema, when one was provided
    if schema is None:
        return []

    return list(schema.column_names(source, only_visible=True))


def _pivot_output_renames(
    pivot: exp.Pivot, scope: Scope, schema: Schema | None = None
) -> dict[str, str]:
    """
    Map each (UN)PIVOT output column name to the name it had before being renamed
    by an alias column list (`... AS t(c1, c2, ...)`).

    The alias list renames the operator's whole output positionally, so it can only
    be lined up when the pre-pivot columns are known. They are taken from the
    projections of a derived table or CTE source, or from the schema for a physical
    table.

    Args:
        pivot: The (UN)PIVOT expression to inspect.
        scope: The scope whose CTE sources are consulted for a table source.
        schema: Optional schema used to look up a physical table's columns.

    Returns:
        The result of `pivot.output_columns` for the pre-pivot columns, or an empty
        dict when there is no alias column list or the pre-pivot columns are unknown.
    """
    if not pivot.alias_column_names:
        return {}

    source_columns = _pre_pivot_column_names(pivot.parent, scope, schema)

    # An unexpanded star also makes the alignment unknowable: the positional
    # renames would silently shift
    can_align = bool(source_columns) and "*" not in source_columns
    return pivot.output_columns(source_columns) if can_align else {}
533+
534+
469535def _pivot_column_mapping (pivot : exp .Pivot ) -> dict [str , list [exp .Column ]]:
470536 """Map each (UN)PIVOT output column name to the source columns it's derived from."""
471537 mapping : dict [str , list [exp .Column ]] = {}
472538
473539 if pivot .unpivot :
474- # UNPIVOT(val FOR name IN (a, b)): both the value column(s) and the name column
475- # are derived from the IN-list source columns
476- unpivot_columns = [
477- col
478- for field in pivot .fields
479- for e in field .expressions
480- for col in e .find_all (exp .Column )
540+ # UNPIVOT((v1, v2) FOR name IN ((a1, a2), (b1, b2))): each value column is derived
541+ # positionally from the IN-list entries, and the name column from all of them
542+ value_columns = [
543+ identifier for e in pivot .expressions for identifier in e .find_all (exp .Identifier )
481544 ]
482- for value_column in pivot . expressions :
483- for identifier in value_column .find_all ( exp . Identifier ):
484- mapping [ identifier . name ] = unpivot_columns
545+ for value_column in value_columns :
546+ mapping [ value_column .name ] = []
547+
485548 for field in pivot .fields :
486- if isinstance (field , exp .In ):
487- mapping [field .this .name ] = unpivot_columns
549+ if not isinstance (field , exp .In ):
550+ continue
551+
552+ name_columns = mapping .setdefault (field .this .name , [])
553+ for entry in field .expressions :
554+ entry_columns = list (entry .find_all (exp .Column ))
555+ name_columns .extend (entry_columns )
556+
557+ if len (entry_columns ) == len (value_columns ):
558+ for value_column , column in zip (value_columns , entry_columns ):
559+ mapping [value_column .name ].append (column )
560+ else :
561+ for value_column in value_columns :
562+ mapping [value_column .name ].extend (entry_columns )
488563
489564 return mapping
490565
0 commit comments