Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 49 additions & 20 deletions sqlglot/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,25 +372,8 @@ def to_node(
}

pivots = scope.pivots
pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
if pivot:
# For each aggregation function, the pivot creates a new column for each field in category
# combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
# b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
# belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
# to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
# in the lineage, so lookup the pivot column name by index and map that with the columns used
# in the aggregation.
#
# Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
pivot_columns = pivot.args["columns"]
pivot_aggs_count = len(pivot.expressions)

pivot_column_mapping = {}
for i, agg in enumerate(pivot.expressions):
agg_cols = list(agg.find_all(exp.Column))
for col_index in range(i, len(pivot_columns), pivot_aggs_count):
pivot_column_mapping[pivot_columns[col_index].name] = agg_cols
pivot = pivots[0] if len(pivots) == 1 else None
pivot_column_mapping = _pivot_column_mapping(pivot) if pivot else {}

for c in source_columns:
table = c.table
Expand Down Expand Up @@ -422,7 +405,7 @@ def to_node(
downstream_columns = []

column_name = c.name
if any(column_name == pivot_column.name for pivot_column in pivot_columns):
if column_name in pivot_column_mapping:
downstream_columns.extend(pivot_column_mapping[column_name])
else:
# The column is not in the pivot, so it must be an implicit column of the
Expand All @@ -435,6 +418,10 @@ def to_node(
for downstream_column in downstream_columns:
table = downstream_column.table
col_source = scope.sources.get(table)
if isinstance(col_source, exp.Table) and not col_source.db:
# A pivoted CTE reference maps to the raw table in `scope.sources`,
# so recover the CTE's scope to keep tracing through it
col_source = scope.cte_sources.get(col_source.name, col_source)
if isinstance(col_source, Scope):
to_node(
downstream_column.name,
Expand Down Expand Up @@ -479,6 +466,48 @@ def to_node(
return node


def _pivot_column_mapping(pivot: exp.Pivot) -> dict[str, list[exp.Column]]:
"""Map each (UN)PIVOT output column name to the source columns it's derived from."""
mapping: dict[str, list[exp.Column]] = {}

if pivot.unpivot:
# UNPIVOT(val FOR name IN (a, b)): both the value column(s) and the name column
# are derived from the IN-list source columns
unpivot_columns = [
col
for field in pivot.fields
for e in field.expressions
for col in e.find_all(exp.Column)
]
for value_column in pivot.expressions:
for identifier in value_column.find_all(exp.Identifier):
mapping[identifier.name] = unpivot_columns
for field in pivot.fields:
if isinstance(field, exp.In):
mapping[field.this.name] = unpivot_columns

Comment on lines +469 to +488
return mapping

# For each aggregation function, the pivot creates a new column for each field in category
# combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
# b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
# belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
# to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
# in the lineage, so lookup the pivot column name by index and map that with the columns used
# in the aggregation.
#
# Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
pivot_columns = pivot.args["columns"]
pivot_aggs_count = len(pivot.expressions)

mapping = {}
for i, agg in enumerate(pivot.expressions):
agg_cols = list(agg.find_all(exp.Column))
for col_index in range(i, len(pivot_columns), pivot_aggs_count):
mapping[pivot_columns[col_index].name] = agg_cols
return mapping


class GraphHTML:
"""Node to HTML generator using vis.js.

Expand Down
35 changes: 35 additions & 0 deletions tests/test_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,41 @@ def test_pivot_with_implicit_column_of_pivoted_source_and_cte(self) -> None:
self.assertEqual(node.downstream[0].reference_node_name, "t")
self.assertEqual(node.downstream[0].downstream[0].name, "quarterly_sales.empid")

def test_unpivot(self) -> None:
sql = """
SELECT id, metric_name, score
FROM sales UNPIVOT (score FOR metric_name IN (jan, feb))
"""
schema = {"sales": {"id": "int", "jan": "int", "feb": "int"}}

node = lineage("score", sql, schema=schema, dialect="snowflake")
self.assertEqual([d.name for d in node.downstream], ["SALES.JAN", "SALES.FEB"])

node = lineage("metric_name", sql, schema=schema, dialect="snowflake")
self.assertEqual([d.name for d in node.downstream], ["SALES.JAN", "SALES.FEB"])

node = lineage("id", sql, schema=schema, dialect="snowflake")
self.assertEqual([d.name for d in node.downstream], ["SALES.ID"])

def test_unpivot_with_cte(self) -> None:
sql = """
WITH src AS (
SELECT id, jan, feb FROM sales
)
SELECT id, metric_name, score
FROM src UNPIVOT (score FOR metric_name IN (jan, feb))
"""
schema = {"sales": {"id": "int", "jan": "int", "feb": "int"}}

node = lineage("score", sql, schema=schema, dialect="snowflake")
self.assertEqual([d.name for d in node.downstream], ["SRC.JAN", "SRC.FEB"])
self.assertEqual(node.downstream[0].downstream[0].name, "SALES.JAN")
self.assertEqual(node.downstream[1].downstream[0].name, "SALES.FEB")

node = lineage("id", sql, schema=schema, dialect="snowflake")
self.assertEqual(node.downstream[0].name, "SRC.ID")
self.assertEqual(node.downstream[0].downstream[0].name, "SALES.ID")

def test_table_udtf_snowflake(self) -> None:
lateral_flatten = """
SELECT f.value:external_id::string AS external_id
Expand Down
Loading