Skip to content

Commit 4c1b671

Browse files
authored
Feat: Introduce API endpoint for column lineage (#691)
* Introduce API endpoint for column lineage * PR feedback * Use to_column instead of parse_one
1 parent 86f2278 commit 4c1b671

4 files changed

Lines changed: 83 additions & 3 deletions

File tree

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"requests",
4444
"rich",
4545
"ruamel.yaml",
46-
"sqlglot>=11.4.4",
46+
"sqlglot>=11.5.3",
4747
],
4848
extras_require={
4949
"dev": [
@@ -76,10 +76,10 @@
7676
"types-requests==2.28.8",
7777
],
7878
"web": [
79-
"fastapi==0.89.1",
79+
"fastapi==0.95.0",
8080
"hyperscript==0.0.1",
8181
"pyarrow==11.0.0",
82-
"uvicorn==0.20.0",
82+
"uvicorn==0.21.1",
8383
],
8484
"snowflake": [
8585
"snowflake-connector-python[pandas]",

tests/web/test_main.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,18 @@ def test_get_environments(project_context: Context) -> None:
469469
"expiration_ts": None,
470470
}
471471
}
472+
473+
474+
def test_get_lineage(web_sushi_context: Context) -> None:
    """Request the lineage of ``sushi.top_waiters.revenue`` and check the returned graph."""
    expected_graph = {
        "sushi.top_waiters": {"revenue": {"sushi.waiter_revenue_by_day": ["revenue"]}},
        "sushi.waiter_revenue_by_day": {
            "revenue": {"sushi.items": ["price"], "sushi.order_items": ["quantity"]}
        },
        "sushi.items": {"price": {}},
        "sushi.order_items": {"quantity": {}},
    }
    query = {"model": "sushi.top_waiters", "column": "revenue"}

    response = client.get("/api/lineage", params=query)

    assert response.status_code == 200
    assert response.json() == expected_graph

web/server/api/endpoints/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
environments,
88
events,
99
files,
10+
lineage,
1011
models,
1112
plan,
1213
)
@@ -21,5 +22,6 @@
2122
api_router.include_router(plan.router, prefix="/plan")
2223
api_router.include_router(environments.router, prefix="/environments")
2324
api_router.include_router(events.router, prefix="/events")
25+
api_router.include_router(lineage.router, prefix="/lineage")
2426
api_router.include_router(models.router, prefix="/models")
2527
api_router.include_router(context.router, prefix="/context")
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
import collections
4+
import traceback
5+
import typing as t
6+
7+
from fastapi import APIRouter, Depends, HTTPException
8+
from sqlglot import exp
9+
from sqlglot.lineage import Node, lineage
10+
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
11+
12+
from sqlmesh.core.context import Context
13+
from web.server.settings import get_loaded_context
14+
15+
router = APIRouter()
16+
17+
18+
def _get_table(node: Node) -> str:
    """Return the table/source name that a lineage node belongs to.

    A node whose expression is an ``exp.Table`` is named through
    ``exp.table_name``; any other node is identified by its alias.
    """
    expression = node.expression
    return exp.table_name(expression) if isinstance(expression, exp.Table) else node.alias
24+
25+
26+
def _process_downstream(downstream: t.List[Node]) -> t.Dict[str, t.List[str]]:
    """Group the column names of downstream nodes by their table/source.

    Args:
        downstream: Lineage nodes directly downstream of some node.

    Returns:
        A ``defaultdict(list)`` mapping each table/source name (as resolved by
        ``_get_table``) to the column names found for it, in iteration order.
    """
    columns_by_table = collections.defaultdict(list)
    for child in downstream:
        column_name = exp.to_column(child.name).name
        columns_by_table[_get_table(child)].append(column_name)
    return columns_by_table
34+
35+
36+
@router.get("/")
async def column_lineage(
    column: str,
    model: str,
    context: Context = Depends(get_loaded_context),
) -> t.Dict[str, t.Dict[str, t.Dict[str, t.List[str]]]]:
    """Get a column's lineage

    Args:
        column: Name of the column whose lineage is requested.
        model: Name of the model (a key of ``context.models``) that owns the column.
        context: Loaded context, injected through ``get_loaded_context``.

    Returns:
        A mapping of ``table -> column -> {source table -> [source columns]}``
        with one entry per node visited while walking the lineage. The first
        visited node is keyed by ``model``; the others by ``_get_table``.

    Raises:
        HTTPException: With status 422 when the lineage cannot be built, for
            example because of an unknown model or a failure inside
            ``sqlglot.lineage.lineage``.
    """
    try:
        root = lineage(
            column=column,
            sql=context.models[model].render_query(),
            # A distinct loop name avoids shadowing the ``model`` parameter.
            sources={
                upstream: context.models[upstream].render_query()
                for upstream in context.dag.upstream(model)
            },
        )
    except Exception:
        # NOTE(review): the response detail exposes the full traceback to API
        # clients; kept as-is to preserve the existing response contract.
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=traceback.format_exc()
        )

    graph: t.Dict[str, t.Dict[str, t.Dict[str, t.List[str]]]] = {}
    for index, node in enumerate(root.walk()):
        # The first walked node is the requested column itself, so it is keyed
        # by the requested model name rather than by the node's own source.
        table = model if index == 0 else _get_table(node)
        column_name = exp.to_column(node.name).name
        # Merge into the table's entry instead of replacing it, so several
        # columns of the same table can coexist in the result.
        graph.setdefault(table, {})[column_name] = _process_downstream(node.downstream)
    return graph

0 commit comments

Comments
 (0)