Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 89 additions & 14 deletions sqlglot/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlglot.errors import SqlglotError
from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
from sqlglot.optimizer.scope import ScopeType
from sqlglot.schema import ensure_schema

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
Expand Down Expand Up @@ -133,6 +134,8 @@ def lineage(
copy=copy,
)

schema = ensure_schema(schema, dialect=dialect)

if not scope:
expression = qualify.qualify(
expression,
Expand Down Expand Up @@ -162,6 +165,7 @@ def lineage(
scope,
dialect,
trim_selects=trim_selects,
schema=schema,
_cache=cache,
_scope_meta=scope_meta,
on_node=on_node,
Expand All @@ -180,6 +184,7 @@ def lineage(
scope,
dialect,
trim_selects=trim_selects,
schema=schema,
_cache=cache,
_scope_meta=scope_meta,
on_node=on_node,
Expand All @@ -197,6 +202,7 @@ def to_node(
source_name: str | None = None,
reference_node_name: str | None = None,
trim_selects: bool = True,
schema: Schema | None = None,
_cache: dict[tuple, Node] | None = None,
_scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] | None = None,
on_node: t.Callable[[Node], None] | None = None,
Expand Down Expand Up @@ -248,6 +254,7 @@ def to_node(
source_name=source_name,
reference_node_name=reference_node_name,
trim_selects=trim_selects,
schema=schema,
_cache=_cache,
_scope_meta=_scope_meta,
on_node=on_node,
Expand Down Expand Up @@ -288,6 +295,7 @@ def to_node(
source_name=source_name,
reference_node_name=reference_node_name,
trim_selects=trim_selects,
schema=schema,
_cache=_cache,
_scope_meta=_scope_meta,
on_node=on_node,
Expand Down Expand Up @@ -337,6 +345,7 @@ def to_node(
dialect=dialect,
upstream=node,
trim_selects=trim_selects,
schema=schema,
_cache=_cache,
_scope_meta=_scope_meta,
on_node=on_node,
Expand Down Expand Up @@ -373,7 +382,18 @@ def to_node(

pivots = scope.pivots
pivot = pivots[0] if len(pivots) == 1 else None
pivot_column_mapping = _pivot_column_mapping(pivot) if pivot else {}
pivot_renames: dict[str, str] = {}
pivot_column_mapping: dict[str, list[exp.Column]] = {}

if pivot:
pivot_renames = _pivot_output_renames(pivot, scope, schema)
pivot_column_mapping = _pivot_column_mapping(pivot)
if pivot_renames:
pivot_column_mapping = {
post: pivot_column_mapping[pre]
for post, pre in pivot_renames.items()
if pre in pivot_column_mapping
}

for c in source_columns:
table = c.table
Expand All @@ -397,6 +417,7 @@ def to_node(
source_name=source_names.get(table) or source_name,
reference_node_name=reference_node_name,
trim_selects=trim_selects,
schema=schema,
_cache=_cache,
_scope_meta=_scope_meta,
on_node=on_node,
Expand All @@ -412,10 +433,22 @@ def to_node(
# pivoted source -- adapt column to be from the implicit pivoted source.
pivot_parent = pivot.parent
downstream_columns.append(
exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "")
exp.column(
pivot_renames.get(c.name, c.this),
table=pivot_parent.alias_or_name if pivot_parent else None,
)
)

for downstream_column in downstream_columns:
if not downstream_column.table:
# Some dialects (e.g. bigquery) don't qualify the IN-list columns,
# but they can only come from the pivoted source
pivot_parent = pivot.parent
downstream_column = exp.column(
downstream_column.this,
table=pivot_parent.alias_or_name if pivot_parent else None,
)

table = downstream_column.table
col_source = scope.sources.get(table)
if isinstance(col_source, exp.Table) and not col_source.db:
Expand All @@ -432,6 +465,7 @@ def to_node(
source_name=source_names.get(table) or source_name,
reference_node_name=reference_node_name,
trim_selects=trim_selects,
schema=schema,
_cache=_cache,
_scope_meta=_scope_meta,
on_node=on_node,
Expand Down Expand Up @@ -466,25 +500,66 @@ def to_node(
return node


def _pivot_output_renames(
pivot: exp.Pivot, scope: Scope, schema: Schema | None = None
) -> dict[str, str]:
"""
Map each (UN)PIVOT output column name to its pre-rename name, when an alias column
list (`... AS t(c1, c2, ...)`) renames the outputs. The renames are positional over
the operator's full output, so they can only be aligned when the pre-pivot columns
are known: from the projections of a derived table or CTE source, or from the
schema for a physical table.
"""
if not pivot.alias_column_names:
return {}

parent = pivot.parent
pre_pivot_columns: list[str] = []
if isinstance(parent, exp.DerivedTable) and isinstance(parent.this, exp.Query):
pre_pivot_columns = parent.this.named_selects
elif isinstance(parent, exp.Table):
cte_source = scope.cte_sources.get(parent.name) if not parent.db else None
if isinstance(cte_source, Scope) and isinstance(cte_source.expression, exp.Query):
pre_pivot_columns = cte_source.expression.named_selects
elif schema is not None:
pre_pivot_columns = list(schema.column_names(parent, only_visible=True))

# The alignment is also unknowable when the source's projections aren't fully
# expanded (e.g. an unresolved star), since the renames would silently shift
if not pre_pivot_columns or "*" in pre_pivot_columns:
return {}

return pivot.output_columns(pre_pivot_columns)


def _pivot_column_mapping(pivot: exp.Pivot) -> dict[str, list[exp.Column]]:
"""Map each (UN)PIVOT output column name to the source columns it's derived from."""
mapping: dict[str, list[exp.Column]] = {}

if pivot.unpivot:
# UNPIVOT(val FOR name IN (a, b)): both the value column(s) and the name column
# are derived from the IN-list source columns
unpivot_columns = [
col
for field in pivot.fields
for e in field.expressions
for col in e.find_all(exp.Column)
# UNPIVOT((v1, v2) FOR name IN ((a1, a2), (b1, b2))): each value column is derived
# positionally from the IN-list entries, and the name column from all of them
value_columns = [
identifier for e in pivot.expressions for identifier in e.find_all(exp.Identifier)
]
for value_column in pivot.expressions:
for identifier in value_column.find_all(exp.Identifier):
mapping[identifier.name] = unpivot_columns
for value_column in value_columns:
mapping[value_column.name] = []

for field in pivot.fields:
if isinstance(field, exp.In):
mapping[field.this.name] = unpivot_columns
if not isinstance(field, exp.In):
continue

name_columns = mapping.setdefault(field.this.name, [])
for entry in field.expressions:
entry_columns = list(entry.find_all(exp.Column))
name_columns.extend(entry_columns)

if len(entry_columns) == len(value_columns):
for value_column, column in zip(value_columns, entry_columns):
mapping[value_column.name].append(column)
else:
for value_column in value_columns:
mapping[value_column.name].extend(entry_columns)

return mapping

Expand Down
81 changes: 81 additions & 0 deletions tests/test_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,87 @@ def test_unpivot_with_cte(self) -> None:
self.assertEqual(node.downstream[0].name, "SRC.ID")
self.assertEqual(node.downstream[0].downstream[0].name, "SALES.ID")

def test_unpivot_multi_column(self) -> None:
sql = """
SELECT product, semesters, first_half_sales, second_half_sales
FROM produce
UNPIVOT((first_half_sales, second_half_sales) FOR semesters IN ((q1, q2) AS 'semester_1', (q3, q4) AS 'semester_2'))
"""
schema = {
"produce": {
"product": "string",
"q1": "int64",
"q2": "int64",
"q3": "int64",
"q4": "int64",
}
}

node = lineage("first_half_sales", sql, schema=schema, dialect="bigquery")
self.assertEqual([d.name for d in node.downstream], ["produce.q1", "produce.q3"])

node = lineage("second_half_sales", sql, schema=schema, dialect="bigquery")
self.assertEqual([d.name for d in node.downstream], ["produce.q2", "produce.q4"])

node = lineage("semesters", sql, schema=schema, dialect="bigquery")
self.assertEqual(
[d.name for d in node.downstream],
["produce.q1", "produce.q2", "produce.q3", "produce.q4"],
)

node = lineage("product", sql, schema=schema, dialect="bigquery")
self.assertEqual([d.name for d in node.downstream], ["produce.product"])

def test_unpivot_with_alias_columns(self) -> None:
sql = """
WITH src AS (
SELECT empid, dept, jan, feb FROM monthly_sales
)
SELECT m, s, e FROM src UNPIVOT(sales FOR month IN (jan, feb)) AS t(e, d, m, s)
"""
schema = {"monthly_sales": {"empid": "int", "dept": "text", "jan": "int", "feb": "int"}}

for renamed in ("s", "m"):
node = lineage(renamed, sql, schema=schema, dialect="snowflake")
self.assertEqual([d.name for d in node.downstream], ["SRC.JAN", "SRC.FEB"])
self.assertEqual(node.downstream[0].downstream[0].name, "MONTHLY_SALES.JAN")
self.assertEqual(node.downstream[1].downstream[0].name, "MONTHLY_SALES.FEB")

node = lineage("e", sql, schema=schema, dialect="snowflake")
self.assertEqual(node.downstream[0].name, "SRC.EMPID")
self.assertEqual(node.downstream[0].downstream[0].name, "MONTHLY_SALES.EMPID")

# Physical table source: the pre-pivot columns come from the schema
sql = "SELECT m, s, e FROM monthly_sales UNPIVOT(sales FOR month IN (jan, feb)) AS t(e, d, m, s)"

for renamed in ("s", "m"):
node = lineage(renamed, sql, schema=schema, dialect="snowflake")
self.assertEqual(
[d.name for d in node.downstream], ["MONTHLY_SALES.JAN", "MONTHLY_SALES.FEB"]
)

node = lineage("e", sql, schema=schema, dialect="snowflake")
self.assertEqual([d.name for d in node.downstream], ["MONTHLY_SALES.EMPID"])

# Without a schema the star can't be expanded, so the positional renames are
# unknowable and must not be applied (else `d` would misalign to jan/feb)
sql = """
WITH src AS (SELECT * FROM monthly_sales)
SELECT d FROM src UNPIVOT(sales FOR month IN (jan, feb)) AS t(e, d, m, s)
"""
node = lineage("d", sql, dialect="snowflake")
self.assertEqual([d.name for d in node.downstream], ["SRC.D"])

def test_pivot_with_alias_columns(self) -> None:
sql = """
SELECT x FROM (SELECT value, category FROM sample_data) AS sd
PIVOT (SUM(value) FOR category IN ('a', 'b')) AS p(x, y)
"""
node = lineage("x", sql)

self.assertEqual(node.downstream[0].name, "sd.value")
self.assertEqual(node.downstream[0].downstream[0].name, "sample_data.value")

def test_table_udtf_snowflake(self) -> None:
lateral_flatten = """
SELECT f.value:external_id::string AS external_id
Expand Down
Loading