@@ -372,25 +372,8 @@ def to_node(
372372 }
373373
374374 pivots = scope .pivots
375- pivot = pivots [0 ] if len (pivots ) == 1 and not pivots [0 ].unpivot else None
376- if pivot :
377- # For each aggregation function, the pivot creates a new column for each field in category
378- # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
379- # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
380- # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
381- # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
382- # in the lineage, so lookup the pivot column name by index and map that with the columns used
383- # in the aggregation.
384- #
385- # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
386- pivot_columns = pivot .args ["columns" ]
387- pivot_aggs_count = len (pivot .expressions )
388-
389- pivot_column_mapping = {}
390- for i , agg in enumerate (pivot .expressions ):
391- agg_cols = list (agg .find_all (exp .Column ))
392- for col_index in range (i , len (pivot_columns ), pivot_aggs_count ):
393- pivot_column_mapping [pivot_columns [col_index ].name ] = agg_cols
375+ pivot = pivots [0 ] if len (pivots ) == 1 else None
376+ pivot_column_mapping = _pivot_column_mapping (pivot ) if pivot else {}
394377
395378 for c in source_columns :
396379 table = c .table
@@ -422,7 +405,7 @@ def to_node(
422405 downstream_columns = []
423406
424407 column_name = c .name
425- if any ( column_name == pivot_column . name for pivot_column in pivot_columns ) :
408+ if column_name in pivot_column_mapping :
426409 downstream_columns .extend (pivot_column_mapping [column_name ])
427410 else :
428411 # The column is not in the pivot, so it must be an implicit column of the
@@ -435,6 +418,10 @@ def to_node(
435418 for downstream_column in downstream_columns :
436419 table = downstream_column .table
437420 col_source = scope .sources .get (table )
421+ if isinstance (col_source , exp .Table ) and not col_source .db :
422+ # A pivoted CTE reference maps to the raw table in `scope.sources`,
423+ # so recover the CTE's scope to keep tracing through it
424+ col_source = scope .cte_sources .get (col_source .name , col_source )
438425 if isinstance (col_source , Scope ):
439426 to_node (
440427 downstream_column .name ,
@@ -479,6 +466,48 @@ def to_node(
479466 return node
480467
481468
469+ def _pivot_column_mapping (pivot : exp .Pivot ) -> dict [str , list [exp .Column ]]:
470+ """Map each (UN)PIVOT output column name to the source columns it's derived from."""
471+ mapping : dict [str , list [exp .Column ]] = {}
472+
473+ if pivot .unpivot :
474+ # UNPIVOT(val FOR name IN (a, b)): both the value column(s) and the name column
475+ # are derived from the IN-list source columns
476+ unpivot_columns = [
477+ col
478+ for field in pivot .fields
479+ for e in field .expressions
480+ for col in e .find_all (exp .Column )
481+ ]
482+ for value_column in pivot .expressions :
483+ for identifier in value_column .find_all (exp .Identifier ):
484+ mapping [identifier .name ] = unpivot_columns
485+ for field in pivot .fields :
486+ if isinstance (field , exp .In ):
487+ mapping [field .this .name ] = unpivot_columns
488+
489+ return mapping
490+
491+ # For each aggregation function, the pivot creates a new column for each field in category
492+ # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
493+ # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
494+ # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
495+ # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
496+ # in the lineage, so lookup the pivot column name by index and map that with the columns used
497+ # in the aggregation.
498+ #
499+ # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
500+ pivot_columns = pivot .args ["columns" ]
501+ pivot_aggs_count = len (pivot .expressions )
502+
503+ mapping = {}
504+ for i , agg in enumerate (pivot .expressions ):
505+ agg_cols = list (agg .find_all (exp .Column ))
506+ for col_index in range (i , len (pivot_columns ), pivot_aggs_count ):
507+ mapping [pivot_columns [col_index ].name ] = agg_cols
508+ return mapping
509+
510+
482511class GraphHTML :
483512 """Node to HTML generator using vis.js.
484513
0 commit comments