diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 4b84f7ba79..6db18a09f0 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -372,25 +372,8 @@ def to_node( } pivots = scope.pivots - pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None - if pivot: - # For each aggregation function, the pivot creates a new column for each field in category - # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a, - # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum' - # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs - # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest - # in the lineage, so lookup the pivot column name by index and map that with the columns used - # in the aggregation. - # - # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b') - pivot_columns = pivot.args["columns"] - pivot_aggs_count = len(pivot.expressions) - - pivot_column_mapping = {} - for i, agg in enumerate(pivot.expressions): - agg_cols = list(agg.find_all(exp.Column)) - for col_index in range(i, len(pivot_columns), pivot_aggs_count): - pivot_column_mapping[pivot_columns[col_index].name] = agg_cols + pivot = pivots[0] if len(pivots) == 1 else None + pivot_column_mapping = _pivot_column_mapping(pivot) if pivot else {} for c in source_columns: table = c.table @@ -422,7 +405,7 @@ def to_node( downstream_columns = [] column_name = c.name - if any(column_name == pivot_column.name for pivot_column in pivot_columns): + if column_name in pivot_column_mapping: downstream_columns.extend(pivot_column_mapping[column_name]) else: # The column is not in the pivot, so it must be an implicit column of the @@ -435,6 +418,10 @@ def to_node( for downstream_column in downstream_columns: table = downstream_column.table col_source = scope.sources.get(table) + if isinstance(col_source, exp.Table) and 
def _pivot_column_mapping(pivot: exp.Pivot) -> dict[str, list[exp.Column]]:
    """Map each (UN)PIVOT output column name to the source columns it's derived from.

    Args:
        pivot: The PIVOT or UNPIVOT expression to inspect.

    Returns:
        A dict keyed by output column name whose values are the `exp.Column` nodes
        found in the pivot expression that feed that output column. Several keys may
        share the very same list object, so callers should treat the values as
        read-only. The dict is empty when the pivot exposes no output columns.
    """
    mapping: dict[str, list[exp.Column]] = {}

    if pivot.unpivot:
        # UNPIVOT(val FOR name IN (a, b)): both the value column(s) and the name column
        # are derived from the IN-list source columns
        unpivot_columns = [
            col
            for field in pivot.fields
            for e in field.expressions
            for col in e.find_all(exp.Column)
        ]

        # Every identifier found in the value expression(s) becomes an output column
        for value_column in pivot.expressions:
            for identifier in value_column.find_all(exp.Identifier):
                mapping[identifier.name] = unpivot_columns

        # The name column is the left-hand side of each `FOR <name> IN (...)` field
        for field in pivot.fields:
            if isinstance(field, exp.In):
                mapping[field.this.name] = unpivot_columns

        return mapping

    # For each aggregation function, the pivot creates a new column for each field in category
    # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
    # b_value_sum, b. Because of this stepwise layout the aggfunc 'sum(value) as value_sum'
    # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
    # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
    # in the lineage, so look up the pivot column name by index and map that with the columns used
    # in the aggregation.
    #
    # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')

    # The "columns" arg may be absent (or None); fall back to an empty list so that such a
    # pivot yields an empty mapping instead of raising a KeyError.
    pivot_columns = pivot.args.get("columns") or []
    pivot_aggs_count = len(pivot.expressions)

    for i, agg in enumerate(pivot.expressions):
        agg_cols = list(agg.find_all(exp.Column))
        # Aggregation `i` owns every `pivot_aggs_count`-th output column, starting at index `i`
        for col_index in range(i, len(pivot_columns), pivot_aggs_count):
            mapping[pivot_columns[col_index].name] = agg_cols

    return mapping
"SALES.JAN") + self.assertEqual(node.downstream[1].downstream[0].name, "SALES.FEB") + + node = lineage("id", sql, schema=schema, dialect="snowflake") + self.assertEqual(node.downstream[0].name, "SRC.ID") + self.assertEqual(node.downstream[0].downstream[0].name, "SALES.ID") + def test_table_udtf_snowflake(self) -> None: lateral_flatten = """ SELECT f.value:external_id::string AS external_id