diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 6db18a09f0..c649c80a87 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -9,6 +9,7 @@ from sqlglot.errors import SqlglotError from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify from sqlglot.optimizer.scope import ScopeType +from sqlglot.schema import ensure_schema if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType @@ -133,6 +134,8 @@ def lineage( copy=copy, ) + schema = ensure_schema(schema, dialect=dialect) + if not scope: expression = qualify.qualify( expression, @@ -162,6 +165,7 @@ def lineage( scope, dialect, trim_selects=trim_selects, + schema=schema, _cache=cache, _scope_meta=scope_meta, on_node=on_node, @@ -180,6 +184,7 @@ def lineage( scope, dialect, trim_selects=trim_selects, + schema=schema, _cache=cache, _scope_meta=scope_meta, on_node=on_node, @@ -197,6 +202,7 @@ def to_node( source_name: str | None = None, reference_node_name: str | None = None, trim_selects: bool = True, + schema: Schema | None = None, _cache: dict[tuple, Node] | None = None, _scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] | None = None, on_node: t.Callable[[Node], None] | None = None, @@ -248,6 +254,7 @@ def to_node( source_name=source_name, reference_node_name=reference_node_name, trim_selects=trim_selects, + schema=schema, _cache=_cache, _scope_meta=_scope_meta, on_node=on_node, @@ -288,6 +295,7 @@ def to_node( source_name=source_name, reference_node_name=reference_node_name, trim_selects=trim_selects, + schema=schema, _cache=_cache, _scope_meta=_scope_meta, on_node=on_node, @@ -337,6 +345,7 @@ def to_node( dialect=dialect, upstream=node, trim_selects=trim_selects, + schema=schema, _cache=_cache, _scope_meta=_scope_meta, on_node=on_node, @@ -373,7 +382,18 @@ def to_node( pivots = scope.pivots pivot = pivots[0] if len(pivots) == 1 else None - pivot_column_mapping = _pivot_column_mapping(pivot) if pivot else {} + pivot_renames: dict[str, str] = {} + pivot_column_mapping: dict[str, list[exp.Column]] = {} + + if pivot: + pivot_renames = _pivot_output_renames(pivot, scope, schema) + pivot_column_mapping = _pivot_column_mapping(pivot) + if pivot_renames: + pivot_column_mapping = { + post: pivot_column_mapping[pre] + for post, pre in pivot_renames.items() + if pre in pivot_column_mapping + } for c in source_columns: table = c.table @@ -397,6 +417,7 @@ def to_node( source_name=source_names.get(table) or source_name, reference_node_name=reference_node_name, trim_selects=trim_selects, + schema=schema, _cache=_cache, _scope_meta=_scope_meta, on_node=on_node, @@ -412,10 +433,22 @@ def to_node( # pivoted source -- adapt column to be from the implicit pivoted source. pivot_parent = pivot.parent downstream_columns.append( - exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "") + exp.column( + pivot_renames.get(c.name, c.this), + table=pivot_parent.alias_or_name if pivot_parent else None, + ) ) for downstream_column in downstream_columns: + if not downstream_column.table: + # Some dialects (e.g. bigquery) don't qualify the IN-list columns, + # but they can only come from the pivoted source + pivot_parent = pivot.parent + downstream_column = exp.column( + downstream_column.this, + table=pivot_parent.alias_or_name if pivot_parent else None, + ) + table = downstream_column.table col_source = scope.sources.get(table) if isinstance(col_source, exp.Table) and not col_source.db: @@ -432,6 +465,7 @@ def to_node( source_name=source_names.get(table) or source_name, reference_node_name=reference_node_name, trim_selects=trim_selects, + schema=schema, _cache=_cache, _scope_meta=_scope_meta, on_node=on_node, @@ -466,25 +500,66 @@ def to_node( return node +def _pivot_output_renames( + pivot: exp.Pivot, scope: Scope, schema: Schema | None = None +) -> dict[str, str]: + """ + Map each (UN)PIVOT output column name to its pre-rename name, when an alias column + list (`... AS t(c1, c2, ...)`) renames the outputs. The renames are positional over + the operator's full output, so they can only be aligned when the pre-pivot columns + are known: from the projections of a derived table or CTE source, or from the + schema for a physical table. + """ + if not pivot.alias_column_names: + return {} + + parent = pivot.parent + pre_pivot_columns: list[str] = [] + if isinstance(parent, exp.DerivedTable) and isinstance(parent.this, exp.Query): + pre_pivot_columns = parent.this.named_selects + elif isinstance(parent, exp.Table): + cte_source = scope.cte_sources.get(parent.name) if not parent.db else None + if isinstance(cte_source, Scope) and isinstance(cte_source.expression, exp.Query): + pre_pivot_columns = cte_source.expression.named_selects + elif schema is not None: + pre_pivot_columns = list(schema.column_names(parent, only_visible=True)) + + # The alignment is also unknowable when the source's projections aren't fully + # expanded (e.g. an unresolved star), since the renames would silently shift + if not pre_pivot_columns or "*" in pre_pivot_columns: + return {} + + return pivot.output_columns(pre_pivot_columns) + + def _pivot_column_mapping(pivot: exp.Pivot) -> dict[str, list[exp.Column]]: """Map each (UN)PIVOT output column name to the source columns it's derived from.""" mapping: dict[str, list[exp.Column]] = {} if pivot.unpivot: - # UNPIVOT(val FOR name IN (a, b)): both the value column(s) and the name column - # are derived from the IN-list source columns - unpivot_columns = [ - col - for field in pivot.fields - for e in field.expressions - for col in e.find_all(exp.Column) + # UNPIVOT((v1, v2) FOR name IN ((a1, a2), (b1, b2))): each value column is derived + # positionally from the IN-list entries, and the name column from all of them + value_columns = [ + identifier for e in pivot.expressions for identifier in e.find_all(exp.Identifier) ] - for value_column in pivot.expressions: - for identifier in value_column.find_all(exp.Identifier): - mapping[identifier.name] = unpivot_columns + for value_column in value_columns: + mapping[value_column.name] = [] + for field in pivot.fields: - if isinstance(field, exp.In): - mapping[field.this.name] = unpivot_columns + if not isinstance(field, exp.In): + continue + + name_columns = mapping.setdefault(field.this.name, []) + for entry in field.expressions: + entry_columns = list(entry.find_all(exp.Column)) + name_columns.extend(entry_columns) + + if len(entry_columns) == len(value_columns): + for value_column, column in zip(value_columns, entry_columns): + mapping[value_column.name].append(column) + else: + for value_column in value_columns: + mapping[value_column.name].extend(entry_columns) return mapping diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 5ddb8fccd5..9d3272cc99 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -720,6 +720,87 @@ def test_unpivot_with_cte(self) -> None: self.assertEqual(node.downstream[0].name, "SRC.ID") self.assertEqual(node.downstream[0].downstream[0].name, "SALES.ID") + def test_unpivot_multi_column(self) -> None: + sql = """ + SELECT product, semesters, first_half_sales, second_half_sales + FROM produce + UNPIVOT((first_half_sales, second_half_sales) FOR semesters IN ((q1, q2) AS 'semester_1', (q3, q4) AS 'semester_2')) + """ + schema = { + "produce": { + "product": "string", + "q1": "int64", + "q2": "int64", + "q3": "int64", + "q4": "int64", + } + } + + node = lineage("first_half_sales", sql, schema=schema, dialect="bigquery") + self.assertEqual([d.name for d in node.downstream], ["produce.q1", "produce.q3"]) + + node = lineage("second_half_sales", sql, schema=schema, dialect="bigquery") + self.assertEqual([d.name for d in node.downstream], ["produce.q2", "produce.q4"]) + + node = lineage("semesters", sql, schema=schema, dialect="bigquery") + self.assertEqual( + [d.name for d in node.downstream], + ["produce.q1", "produce.q2", "produce.q3", "produce.q4"], + ) + + node = lineage("product", sql, schema=schema, dialect="bigquery") + self.assertEqual([d.name for d in node.downstream], ["produce.product"]) + + def test_unpivot_with_alias_columns(self) -> None: + sql = """ + WITH src AS ( + SELECT empid, dept, jan, feb FROM monthly_sales + ) + SELECT m, s, e FROM src UNPIVOT(sales FOR month IN (jan, feb)) AS t(e, d, m, s) + """ + schema = {"monthly_sales": {"empid": "int", "dept": "text", "jan": "int", "feb": "int"}} + + for renamed in ("s", "m"): + node = lineage(renamed, sql, schema=schema, dialect="snowflake") + self.assertEqual([d.name for d in node.downstream], ["SRC.JAN", "SRC.FEB"]) + self.assertEqual(node.downstream[0].downstream[0].name, "MONTHLY_SALES.JAN") + self.assertEqual(node.downstream[1].downstream[0].name, "MONTHLY_SALES.FEB") + + node = lineage("e", sql, schema=schema, dialect="snowflake") + self.assertEqual(node.downstream[0].name, "SRC.EMPID") + self.assertEqual(node.downstream[0].downstream[0].name, "MONTHLY_SALES.EMPID") + + # Physical table source: the pre-pivot columns come from the schema + sql = "SELECT m, s, e FROM monthly_sales UNPIVOT(sales FOR month IN (jan, feb)) AS t(e, d, m, s)" + + for renamed in ("s", "m"): + node = lineage(renamed, sql, schema=schema, dialect="snowflake") + self.assertEqual( + [d.name for d in node.downstream], ["MONTHLY_SALES.JAN", "MONTHLY_SALES.FEB"] + ) + + node = lineage("e", sql, schema=schema, dialect="snowflake") + self.assertEqual([d.name for d in node.downstream], ["MONTHLY_SALES.EMPID"]) + + # Without a schema the star can't be expanded, so the positional renames are + # unknowable and must not be applied (else `d` would misalign to jan/feb) + sql = """ + WITH src AS (SELECT * FROM monthly_sales) + SELECT d FROM src UNPIVOT(sales FOR month IN (jan, feb)) AS t(e, d, m, s) + """ + node = lineage("d", sql, dialect="snowflake") + self.assertEqual([d.name for d in node.downstream], ["SRC.D"]) + + def test_pivot_with_alias_columns(self) -> None: + sql = """ + SELECT x FROM (SELECT value, category FROM sample_data) AS sd + PIVOT (SUM(value) FOR category IN ('a', 'b')) AS p(x, y) + """ + node = lineage("x", sql) + + self.assertEqual(node.downstream[0].name, "sd.value") + self.assertEqual(node.downstream[0].downstream[0].name, "sample_data.value") + def test_table_udtf_snowflake(self) -> None: lateral_flatten = """ SELECT f.value:external_id::string AS external_id