Skip to content

Commit 74bef2f

Browse files
authored
Feat(lineage): add support for UNPIVOT (#7729)
1 parent 32ed149 commit 74bef2f

2 files changed

Lines changed: 84 additions & 20 deletions

File tree

sqlglot/lineage.py

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -372,25 +372,8 @@ def to_node(
372372
}
373373

374374
pivots = scope.pivots
375-
pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
376-
if pivot:
377-
# For each aggregation function, the pivot creates a new column for each field in category
378-
# combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
379-
# b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
380-
# belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
381-
# to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
382-
# in the lineage, so lookup the pivot column name by index and map that with the columns used
383-
# in the aggregation.
384-
#
385-
# Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
386-
pivot_columns = pivot.args["columns"]
387-
pivot_aggs_count = len(pivot.expressions)
388-
389-
pivot_column_mapping = {}
390-
for i, agg in enumerate(pivot.expressions):
391-
agg_cols = list(agg.find_all(exp.Column))
392-
for col_index in range(i, len(pivot_columns), pivot_aggs_count):
393-
pivot_column_mapping[pivot_columns[col_index].name] = agg_cols
375+
pivot = pivots[0] if len(pivots) == 1 else None
376+
pivot_column_mapping = _pivot_column_mapping(pivot) if pivot else {}
394377

395378
for c in source_columns:
396379
table = c.table
@@ -422,7 +405,7 @@ def to_node(
422405
downstream_columns = []
423406

424407
column_name = c.name
425-
if any(column_name == pivot_column.name for pivot_column in pivot_columns):
408+
if column_name in pivot_column_mapping:
426409
downstream_columns.extend(pivot_column_mapping[column_name])
427410
else:
428411
# The column is not in the pivot, so it must be an implicit column of the
@@ -435,6 +418,10 @@ def to_node(
435418
for downstream_column in downstream_columns:
436419
table = downstream_column.table
437420
col_source = scope.sources.get(table)
421+
if isinstance(col_source, exp.Table) and not col_source.db:
422+
# A pivoted CTE reference maps to the raw table in `scope.sources`,
423+
# so recover the CTE's scope to keep tracing through it
424+
col_source = scope.cte_sources.get(col_source.name, col_source)
438425
if isinstance(col_source, Scope):
439426
to_node(
440427
downstream_column.name,
@@ -479,6 +466,48 @@ def to_node(
479466
return node
480467

481468

469+
def _pivot_column_mapping(pivot: exp.Pivot) -> dict[str, list[exp.Column]]:
470+
"""Map each (UN)PIVOT output column name to the source columns it's derived from."""
471+
mapping: dict[str, list[exp.Column]] = {}
472+
473+
if pivot.unpivot:
474+
# UNPIVOT(val FOR name IN (a, b)): both the value column(s) and the name column
475+
# are derived from the IN-list source columns
476+
unpivot_columns = [
477+
col
478+
for field in pivot.fields
479+
for e in field.expressions
480+
for col in e.find_all(exp.Column)
481+
]
482+
for value_column in pivot.expressions:
483+
for identifier in value_column.find_all(exp.Identifier):
484+
mapping[identifier.name] = unpivot_columns
485+
for field in pivot.fields:
486+
if isinstance(field, exp.In):
487+
mapping[field.this.name] = unpivot_columns
488+
489+
return mapping
490+
491+
# For each aggregation function, the pivot creates a new column for each field in category
492+
# combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
493+
# b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
494+
# belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
495+
# to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
496+
# in the lineage, so lookup the pivot column name by index and map that with the columns used
497+
# in the aggregation.
498+
#
499+
# Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
500+
pivot_columns = pivot.args["columns"]
501+
pivot_aggs_count = len(pivot.expressions)
502+
503+
mapping = {}
504+
for i, agg in enumerate(pivot.expressions):
505+
agg_cols = list(agg.find_all(exp.Column))
506+
for col_index in range(i, len(pivot_columns), pivot_aggs_count):
507+
mapping[pivot_columns[col_index].name] = agg_cols
508+
return mapping
509+
510+
482511
class GraphHTML:
483512
"""Node to HTML generator using vis.js.
484513

tests/test_lineage.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,41 @@ def test_pivot_with_implicit_column_of_pivoted_source_and_cte(self) -> None:
685685
self.assertEqual(node.downstream[0].reference_node_name, "t")
686686
self.assertEqual(node.downstream[0].downstream[0].name, "quarterly_sales.empid")
687687

688+
def test_unpivot(self) -> None:
689+
sql = """
690+
SELECT id, metric_name, score
691+
FROM sales UNPIVOT (score FOR metric_name IN (jan, feb))
692+
"""
693+
schema = {"sales": {"id": "int", "jan": "int", "feb": "int"}}
694+
695+
node = lineage("score", sql, schema=schema, dialect="snowflake")
696+
self.assertEqual([d.name for d in node.downstream], ["SALES.JAN", "SALES.FEB"])
697+
698+
node = lineage("metric_name", sql, schema=schema, dialect="snowflake")
699+
self.assertEqual([d.name for d in node.downstream], ["SALES.JAN", "SALES.FEB"])
700+
701+
node = lineage("id", sql, schema=schema, dialect="snowflake")
702+
self.assertEqual([d.name for d in node.downstream], ["SALES.ID"])
703+
704+
def test_unpivot_with_cte(self) -> None:
705+
sql = """
706+
WITH src AS (
707+
SELECT id, jan, feb FROM sales
708+
)
709+
SELECT id, metric_name, score
710+
FROM src UNPIVOT (score FOR metric_name IN (jan, feb))
711+
"""
712+
schema = {"sales": {"id": "int", "jan": "int", "feb": "int"}}
713+
714+
node = lineage("score", sql, schema=schema, dialect="snowflake")
715+
self.assertEqual([d.name for d in node.downstream], ["SRC.JAN", "SRC.FEB"])
716+
self.assertEqual(node.downstream[0].downstream[0].name, "SALES.JAN")
717+
self.assertEqual(node.downstream[1].downstream[0].name, "SALES.FEB")
718+
719+
node = lineage("id", sql, schema=schema, dialect="snowflake")
720+
self.assertEqual(node.downstream[0].name, "SRC.ID")
721+
self.assertEqual(node.downstream[0].downstream[0].name, "SALES.ID")
722+
688723
def test_table_udtf_snowflake(self) -> None:
689724
lateral_flatten = """
690725
SELECT f.value:external_id::string AS external_id

0 commit comments

Comments
 (0)