Skip to content

Commit f3ba8e4

Browse files
authored
Feat(lineage): more UNPIVOT lineage improvements (#7736)
1 parent ffd3a14 commit f3ba8e4

2 files changed

Lines changed: 170 additions & 14 deletions

File tree

sqlglot/lineage.py

Lines changed: 89 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sqlglot.errors import SqlglotError
1010
from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
1111
from sqlglot.optimizer.scope import ScopeType
12+
from sqlglot.schema import ensure_schema
1213

1314
if t.TYPE_CHECKING:
1415
from sqlglot.dialects.dialect import DialectType
@@ -133,6 +134,8 @@ def lineage(
133134
copy=copy,
134135
)
135136

137+
schema = ensure_schema(schema, dialect=dialect)
138+
136139
if not scope:
137140
expression = qualify.qualify(
138141
expression,
@@ -162,6 +165,7 @@ def lineage(
162165
scope,
163166
dialect,
164167
trim_selects=trim_selects,
168+
schema=schema,
165169
_cache=cache,
166170
_scope_meta=scope_meta,
167171
on_node=on_node,
@@ -180,6 +184,7 @@ def lineage(
180184
scope,
181185
dialect,
182186
trim_selects=trim_selects,
187+
schema=schema,
183188
_cache=cache,
184189
_scope_meta=scope_meta,
185190
on_node=on_node,
@@ -197,6 +202,7 @@ def to_node(
197202
source_name: str | None = None,
198203
reference_node_name: str | None = None,
199204
trim_selects: bool = True,
205+
schema: Schema | None = None,
200206
_cache: dict[tuple, Node] | None = None,
201207
_scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] | None = None,
202208
on_node: t.Callable[[Node], None] | None = None,
@@ -248,6 +254,7 @@ def to_node(
248254
source_name=source_name,
249255
reference_node_name=reference_node_name,
250256
trim_selects=trim_selects,
257+
schema=schema,
251258
_cache=_cache,
252259
_scope_meta=_scope_meta,
253260
on_node=on_node,
@@ -288,6 +295,7 @@ def to_node(
288295
source_name=source_name,
289296
reference_node_name=reference_node_name,
290297
trim_selects=trim_selects,
298+
schema=schema,
291299
_cache=_cache,
292300
_scope_meta=_scope_meta,
293301
on_node=on_node,
@@ -337,6 +345,7 @@ def to_node(
337345
dialect=dialect,
338346
upstream=node,
339347
trim_selects=trim_selects,
348+
schema=schema,
340349
_cache=_cache,
341350
_scope_meta=_scope_meta,
342351
on_node=on_node,
@@ -373,7 +382,18 @@ def to_node(
373382

374383
pivots = scope.pivots
375384
pivot = pivots[0] if len(pivots) == 1 else None
376-
pivot_column_mapping = _pivot_column_mapping(pivot) if pivot else {}
385+
pivot_renames: dict[str, str] = {}
386+
pivot_column_mapping: dict[str, list[exp.Column]] = {}
387+
388+
if pivot:
389+
pivot_renames = _pivot_output_renames(pivot, scope, schema)
390+
pivot_column_mapping = _pivot_column_mapping(pivot)
391+
if pivot_renames:
392+
pivot_column_mapping = {
393+
post: pivot_column_mapping[pre]
394+
for post, pre in pivot_renames.items()
395+
if pre in pivot_column_mapping
396+
}
377397

378398
for c in source_columns:
379399
table = c.table
@@ -397,6 +417,7 @@ def to_node(
397417
source_name=source_names.get(table) or source_name,
398418
reference_node_name=reference_node_name,
399419
trim_selects=trim_selects,
420+
schema=schema,
400421
_cache=_cache,
401422
_scope_meta=_scope_meta,
402423
on_node=on_node,
@@ -412,10 +433,22 @@ def to_node(
412433
# pivoted source -- adapt column to be from the implicit pivoted source.
413434
pivot_parent = pivot.parent
414435
downstream_columns.append(
415-
exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "")
436+
exp.column(
437+
pivot_renames.get(c.name, c.this),
438+
table=pivot_parent.alias_or_name if pivot_parent else None,
439+
)
416440
)
417441

418442
for downstream_column in downstream_columns:
443+
if not downstream_column.table:
444+
# Some dialects (e.g. bigquery) don't qualify the IN-list columns,
445+
# but they can only come from the pivoted source
446+
pivot_parent = pivot.parent
447+
downstream_column = exp.column(
448+
downstream_column.this,
449+
table=pivot_parent.alias_or_name if pivot_parent else None,
450+
)
451+
419452
table = downstream_column.table
420453
col_source = scope.sources.get(table)
421454
if isinstance(col_source, exp.Table) and not col_source.db:
@@ -432,6 +465,7 @@ def to_node(
432465
source_name=source_names.get(table) or source_name,
433466
reference_node_name=reference_node_name,
434467
trim_selects=trim_selects,
468+
schema=schema,
435469
_cache=_cache,
436470
_scope_meta=_scope_meta,
437471
on_node=on_node,
@@ -466,25 +500,66 @@ def to_node(
466500
return node
467501

468502

503+
def _pivot_output_renames(
504+
pivot: exp.Pivot, scope: Scope, schema: Schema | None = None
505+
) -> dict[str, str]:
506+
"""
507+
Map each (UN)PIVOT output column name to its pre-rename name, when an alias column
508+
list (`... AS t(c1, c2, ...)`) renames the outputs. The renames are positional over
509+
the operator's full output, so they can only be aligned when the pre-pivot columns
510+
are known: from the projections of a derived table or CTE source, or from the
511+
schema for a physical table.
512+
"""
513+
if not pivot.alias_column_names:
514+
return {}
515+
516+
parent = pivot.parent
517+
pre_pivot_columns: list[str] = []
518+
if isinstance(parent, exp.DerivedTable) and isinstance(parent.this, exp.Query):
519+
pre_pivot_columns = parent.this.named_selects
520+
elif isinstance(parent, exp.Table):
521+
cte_source = scope.cte_sources.get(parent.name) if not parent.db else None
522+
if isinstance(cte_source, Scope) and isinstance(cte_source.expression, exp.Query):
523+
pre_pivot_columns = cte_source.expression.named_selects
524+
elif schema is not None:
525+
pre_pivot_columns = list(schema.column_names(parent, only_visible=True))
526+
527+
# The alignment is also unknowable when the source's projections aren't fully
528+
# expanded (e.g. an unresolved star), since the renames would silently shift
529+
if not pre_pivot_columns or "*" in pre_pivot_columns:
530+
return {}
531+
532+
return pivot.output_columns(pre_pivot_columns)
533+
534+
469535
def _pivot_column_mapping(pivot: exp.Pivot) -> dict[str, list[exp.Column]]:
470536
"""Map each (UN)PIVOT output column name to the source columns it's derived from."""
471537
mapping: dict[str, list[exp.Column]] = {}
472538

473539
if pivot.unpivot:
474-
# UNPIVOT(val FOR name IN (a, b)): both the value column(s) and the name column
475-
# are derived from the IN-list source columns
476-
unpivot_columns = [
477-
col
478-
for field in pivot.fields
479-
for e in field.expressions
480-
for col in e.find_all(exp.Column)
540+
# UNPIVOT((v1, v2) FOR name IN ((a1, a2), (b1, b2))): each value column is derived
541+
# positionally from the IN-list entries, and the name column from all of them
542+
value_columns = [
543+
identifier for e in pivot.expressions for identifier in e.find_all(exp.Identifier)
481544
]
482-
for value_column in pivot.expressions:
483-
for identifier in value_column.find_all(exp.Identifier):
484-
mapping[identifier.name] = unpivot_columns
545+
for value_column in value_columns:
546+
mapping[value_column.name] = []
547+
485548
for field in pivot.fields:
486-
if isinstance(field, exp.In):
487-
mapping[field.this.name] = unpivot_columns
549+
if not isinstance(field, exp.In):
550+
continue
551+
552+
name_columns = mapping.setdefault(field.this.name, [])
553+
for entry in field.expressions:
554+
entry_columns = list(entry.find_all(exp.Column))
555+
name_columns.extend(entry_columns)
556+
557+
if len(entry_columns) == len(value_columns):
558+
for value_column, column in zip(value_columns, entry_columns):
559+
mapping[value_column.name].append(column)
560+
else:
561+
for value_column in value_columns:
562+
mapping[value_column.name].extend(entry_columns)
488563

489564
return mapping
490565

tests/test_lineage.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,87 @@ def test_unpivot_with_cte(self) -> None:
720720
self.assertEqual(node.downstream[0].name, "SRC.ID")
721721
self.assertEqual(node.downstream[0].downstream[0].name, "SALES.ID")
722722

723+
def test_unpivot_multi_column(self) -> None:
724+
sql = """
725+
SELECT product, semesters, first_half_sales, second_half_sales
726+
FROM produce
727+
UNPIVOT((first_half_sales, second_half_sales) FOR semesters IN ((q1, q2) AS 'semester_1', (q3, q4) AS 'semester_2'))
728+
"""
729+
schema = {
730+
"produce": {
731+
"product": "string",
732+
"q1": "int64",
733+
"q2": "int64",
734+
"q3": "int64",
735+
"q4": "int64",
736+
}
737+
}
738+
739+
node = lineage("first_half_sales", sql, schema=schema, dialect="bigquery")
740+
self.assertEqual([d.name for d in node.downstream], ["produce.q1", "produce.q3"])
741+
742+
node = lineage("second_half_sales", sql, schema=schema, dialect="bigquery")
743+
self.assertEqual([d.name for d in node.downstream], ["produce.q2", "produce.q4"])
744+
745+
node = lineage("semesters", sql, schema=schema, dialect="bigquery")
746+
self.assertEqual(
747+
[d.name for d in node.downstream],
748+
["produce.q1", "produce.q2", "produce.q3", "produce.q4"],
749+
)
750+
751+
node = lineage("product", sql, schema=schema, dialect="bigquery")
752+
self.assertEqual([d.name for d in node.downstream], ["produce.product"])
753+
754+
def test_unpivot_with_alias_columns(self) -> None:
755+
sql = """
756+
WITH src AS (
757+
SELECT empid, dept, jan, feb FROM monthly_sales
758+
)
759+
SELECT m, s, e FROM src UNPIVOT(sales FOR month IN (jan, feb)) AS t(e, d, m, s)
760+
"""
761+
schema = {"monthly_sales": {"empid": "int", "dept": "text", "jan": "int", "feb": "int"}}
762+
763+
for renamed in ("s", "m"):
764+
node = lineage(renamed, sql, schema=schema, dialect="snowflake")
765+
self.assertEqual([d.name for d in node.downstream], ["SRC.JAN", "SRC.FEB"])
766+
self.assertEqual(node.downstream[0].downstream[0].name, "MONTHLY_SALES.JAN")
767+
self.assertEqual(node.downstream[1].downstream[0].name, "MONTHLY_SALES.FEB")
768+
769+
node = lineage("e", sql, schema=schema, dialect="snowflake")
770+
self.assertEqual(node.downstream[0].name, "SRC.EMPID")
771+
self.assertEqual(node.downstream[0].downstream[0].name, "MONTHLY_SALES.EMPID")
772+
773+
# Physical table source: the pre-pivot columns come from the schema
774+
sql = "SELECT m, s, e FROM monthly_sales UNPIVOT(sales FOR month IN (jan, feb)) AS t(e, d, m, s)"
775+
776+
for renamed in ("s", "m"):
777+
node = lineage(renamed, sql, schema=schema, dialect="snowflake")
778+
self.assertEqual(
779+
[d.name for d in node.downstream], ["MONTHLY_SALES.JAN", "MONTHLY_SALES.FEB"]
780+
)
781+
782+
node = lineage("e", sql, schema=schema, dialect="snowflake")
783+
self.assertEqual([d.name for d in node.downstream], ["MONTHLY_SALES.EMPID"])
784+
785+
# Without a schema the star can't be expanded, so the positional renames are
786+
# unknowable and must not be applied (else `d` would misalign to jan/feb)
787+
sql = """
788+
WITH src AS (SELECT * FROM monthly_sales)
789+
SELECT d FROM src UNPIVOT(sales FOR month IN (jan, feb)) AS t(e, d, m, s)
790+
"""
791+
node = lineage("d", sql, dialect="snowflake")
792+
self.assertEqual([d.name for d in node.downstream], ["SRC.D"])
793+
794+
def test_pivot_with_alias_columns(self) -> None:
795+
sql = """
796+
SELECT x FROM (SELECT value, category FROM sample_data) AS sd
797+
PIVOT (SUM(value) FOR category IN ('a', 'b')) AS p(x, y)
798+
"""
799+
node = lineage("x", sql)
800+
801+
self.assertEqual(node.downstream[0].name, "sd.value")
802+
self.assertEqual(node.downstream[0].downstream[0].name, "sample_data.value")
803+
723804
def test_table_udtf_snowflake(self) -> None:
724805
lateral_flatten = """
725806
SELECT f.value:external_id::string AS external_id

0 commit comments

Comments
 (0)