Skip to content

Commit d298ab8

Browse files
authored
Feat: Support sensors for external tables in Airflow (#1536)
1 parent f2e9ed8 commit d298ab8

10 files changed

Lines changed: 260 additions & 7 deletions

File tree

sqlmesh/core/model/definition.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,53 @@ def render_post_statements(
363363
"""
364364
return []
365365

366+
def render_signals(
367+
self,
368+
*,
369+
start: t.Optional[TimeLike] = None,
370+
end: t.Optional[TimeLike] = None,
371+
execution_time: t.Optional[TimeLike] = None,
372+
) -> t.List[t.Dict[str, str | int | float | bool]]:
373+
"""Renders external; signals defined for this model.
374+
375+
Args:
376+
start: The start datetime to render. Defaults to epoch start.
377+
end: The end datetime to render. Defaults to epoch start.
378+
execution_time: The date/time reference to use for execution time.
379+
380+
Returns:
381+
The list of rendered signals, each a mapping of signal key to its rendered value.
382+
"""
383+
384+
def _create_renderer(expression: exp.Expression) -> ExpressionRenderer:
385+
return ExpressionRenderer(
386+
expression,
387+
self.dialect,
388+
[],
389+
path=self._path,
390+
jinja_macro_registry=self.jinja_macros,
391+
python_env=self.python_env,
392+
only_execution_time=False,
393+
)
394+
395+
def _render(e: exp.Expression) -> str | int | float | bool:
396+
rendered_exprs = _create_renderer(e).render(
397+
start=start, end=end, execution_time=execution_time
398+
)
399+
if len(rendered_exprs) != 1:
400+
raise SQLMeshError(f"Expected one expression but got {len(rendered_exprs)}")
401+
402+
rendered = rendered_exprs[0]
403+
if rendered.is_int:
404+
return int(rendered.this)
405+
if rendered.is_number:
406+
return float(rendered.this)
407+
if isinstance(rendered, (exp.Literal, exp.Boolean)):
408+
return rendered.this
409+
return rendered.sql(dialect=self.dialect)
410+
411+
return [{t.this.name: _render(t.expression) for t in signal} for signal in self.signals]
412+
366413
def ctas_query(self, **render_kwarg: t.Any) -> exp.Subqueryable:
367414
"""Return a dummy query to do a CTAS.
368415
@@ -780,6 +827,8 @@ def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str:
780827
else:
781828
raise SQLMeshError(f"Unexpected audit name '{audit_name}'.")
782829

830+
metadata.extend([s.sql(comments=True) for s in self.signals])
831+
783832
# Add comments from the query.
784833
if self.is_sql:
785834
rendered_query = self.render_query()
@@ -1957,4 +2006,5 @@ def _refs_to_sql(values: t.Any) -> exp.Expression:
19572006
"table_properties_": lambda value: value,
19582007
"session_properties_": lambda value: value,
19592008
"allow_partials": exp.convert,
2009+
"signals": lambda values: exp.Tuple(expressions=values),
19602010
}

sqlmesh/core/model/meta.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlmesh.core.model.common import (
1212
bool_validator,
1313
parse_expressions,
14+
parse_properties,
1415
properties_validator,
1516
)
1617
from sqlmesh.core.model.kind import (
@@ -57,6 +58,7 @@ class ModelMeta(_Node, extra="allow"):
5758
table_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="table_properties")
5859
session_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="session_properties")
5960
allow_partials: bool = False
61+
signals: t.List[exp.Tuple] = []
6062

6163
_table_properties: t.Dict[str, exp.Expression] = {}
6264

@@ -230,6 +232,28 @@ def _refs_validator(cls, vs: t.Any, values: t.Dict[str, t.Any]) -> t.List[exp.Ex
230232

231233
return refs
232234

235+
@field_validator("signals", mode="before")
236+
@field_validator_v1_args
237+
def _signals_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any:
238+
if v is None:
239+
return []
240+
241+
if isinstance(v, str):
242+
dialect = values.get("dialect")
243+
v = d.parse_one(v, dialect=dialect)
244+
245+
if isinstance(v, (exp.Array, exp.Paren, exp.Tuple)):
246+
tuples: t.List[exp.Expression] = (
247+
[v.unnest()] if isinstance(v, exp.Paren) else v.expressions
248+
)
249+
signals = [parse_properties(cls, t, values) for t in tuples]
250+
elif isinstance(v, list):
251+
signals = [parse_properties(cls, t, values) for t in v]
252+
else:
253+
raise ConfigError(f"Unexpected signals '{v}'")
254+
255+
return signals
256+
233257
@model_validator(mode="before")
234258
def _pre_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
235259
grain = values.pop("grain", None)

sqlmesh/schedulers/airflow/dag_generator.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from airflow.models import BaseOperator, baseoperator
1010
from airflow.operators.empty import EmptyOperator
1111
from airflow.operators.python import PythonOperator
12+
from airflow.sensors.base import BaseSensorOperator
1213

1314
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo
1415
from sqlmesh.core.notification_target import NotificationTarget
@@ -21,7 +22,10 @@
2122
)
2223
from sqlmesh.schedulers.airflow import common, util
2324
from sqlmesh.schedulers.airflow.operators import targets
24-
from sqlmesh.schedulers.airflow.operators.hwm_sensor import HighWaterMarkSensor
25+
from sqlmesh.schedulers.airflow.operators.hwm_sensor import (
26+
HighWaterMarkExternalSensor,
27+
HighWaterMarkSensor,
28+
)
2529
from sqlmesh.schedulers.airflow.operators.notification import (
2630
BaseNotificationOperatorProvider,
2731
)
@@ -58,12 +62,16 @@ def __init__(
5862
engine_operator_args: t.Optional[t.Dict[str, t.Any]],
5963
ddl_engine_operator: t.Type[BaseOperator],
6064
ddl_engine_operator_args: t.Optional[t.Dict[str, t.Any]],
65+
external_table_sensor_factory: t.Optional[
66+
t.Callable[[t.Dict[str, t.Any]], BaseSensorOperator]
67+
],
6168
snapshots: t.Dict[SnapshotId, Snapshot],
6269
):
6370
self._engine_operator = engine_operator
6471
self._engine_operator_args = engine_operator_args or {}
6572
self._ddl_engine_operator = ddl_engine_operator
6673
self._ddl_engine_operator_args = ddl_engine_operator_args or {}
74+
self._external_table_sensor_factory = external_table_sensor_factory
6775
self._snapshots = snapshots
6876

6977
def generate_cadence_dags(self, snapshots: t.Iterable[SnapshotIdLike]) -> t.List[DAG]:
@@ -506,10 +514,12 @@ def _create_snapshot_evaluator_operator(
506514
task_id=task_id,
507515
)
508516

509-
def _create_hwm_sensors(self, snapshot: Snapshot) -> t.List[HighWaterMarkSensor]:
510-
output = []
517+
def _create_hwm_sensors(self, snapshot: Snapshot) -> t.List[BaseSensorOperator]:
518+
output: t.List[BaseSensorOperator] = []
519+
depends_on = snapshot.node.depends_on
511520
for upstream_snapshot_id in snapshot.parents:
512521
upstream_snapshot = self._snapshots[upstream_snapshot_id]
522+
depends_on.discard(upstream_snapshot.name)
513523
if not upstream_snapshot.is_symbolic and not upstream_snapshot.is_seed:
514524
output.append(
515525
HighWaterMarkSensor(
@@ -518,6 +528,16 @@ def _create_hwm_sensors(self, snapshot: Snapshot) -> t.List[HighWaterMarkSensor]
518528
task_id=f"{sanitize_name(upstream_snapshot.name)}_{upstream_snapshot.version}_high_water_mark_sensor",
519529
)
520530
)
531+
532+
if self._external_table_sensor_factory and snapshot.model.signals:
533+
output.append(
534+
HighWaterMarkExternalSensor(
535+
snapshot=snapshot,
536+
external_table_sensor_factory=self._external_table_sensor_factory,
537+
task_id="external_high_water_mark_sensor",
538+
)
539+
)
540+
521541
return output
522542

523543

sqlmesh/schedulers/airflow/integration.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from airflow import DAG
88
from airflow.models import BaseOperator, TaskInstance, Variable
99
from airflow.operators.python import PythonOperator
10+
from airflow.sensors.base import BaseSensorOperator
1011
from airflow.utils.session import provide_session
1112
from sqlalchemy.orm import Session
1213

@@ -58,6 +59,7 @@ class SQLMeshAirflow:
5859
deletion from Airflow. Default: 1 hour.
5960
plan_application_dag_ttl: Determines the time-to-live period for finished plan application DAGs.
6061
Once this period is exceeded, finished plan application DAGs are deleted by the janitor. Default: 2 days.
62+
external_table_sensor_factory: A factory function that creates a sensor operator for a given signal payload.
6163
"""
6264

6365
def __init__(
@@ -68,6 +70,9 @@ def __init__(
6870
ddl_engine_operator_args: t.Optional[t.Dict[str, t.Any]] = None,
6971
janitor_interval: timedelta = timedelta(hours=1),
7072
plan_application_dag_ttl: timedelta = timedelta(days=2),
73+
external_table_sensor_factory: t.Optional[
74+
t.Callable[[t.Dict[str, t.Any]], BaseSensorOperator]
75+
] = None,
7176
):
7277
if isinstance(engine_operator, str):
7378
if not ddl_engine_operator:
@@ -89,6 +94,7 @@ def __init__(
8994
self._ddl_engine_operator_args = ddl_engine_operator_args or {}
9095
self._janitor_interval = janitor_interval
9196
self._plan_application_dag_ttl = plan_application_dag_ttl
97+
self._external_table_sensor_factory = external_table_sensor_factory
9298

9399
@property
94100
def dags(self) -> t.List[DAG]:
@@ -109,6 +115,7 @@ def dags(self) -> t.List[DAG]:
109115
self._engine_operator_args,
110116
self._ddl_engine_operator,
111117
self._ddl_engine_operator_args,
118+
self._external_table_sensor_factory,
112119
stored_snapshots,
113120
)
114121

sqlmesh/schedulers/airflow/operators/hwm_sensor.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from datetime import datetime
44

55
from airflow.models import DagRun
6-
from airflow.sensors.base import BaseSensorOperator
6+
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
77
from airflow.utils.context import Context
88

99
from sqlmesh.core.snapshot import Snapshot, SnapshotTableInfo
1010
from sqlmesh.schedulers.airflow import util
11-
from sqlmesh.utils.date import to_datetime
11+
from sqlmesh.utils.date import now, to_datetime
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -63,3 +63,33 @@ def _compute_target_high_water_mark(
6363
target_prev = to_datetime(target_snapshot.node.cron_floor(target_date))
6464
this_prev = to_datetime(self.this_snapshot.node.cron_floor(target_date))
6565
return min(target_prev, this_prev)
66+
67+
68+
class HighWaterMarkExternalSensor(BaseSensorOperator):
69+
def __init__(
70+
self,
71+
snapshot: Snapshot,
72+
external_table_sensor_factory: t.Callable[[t.Dict[str, t.Any]], BaseSensorOperator],
73+
poke_interval: float = 60.0,
74+
timeout: float = 7.0 * 24.0 * 60.0 * 60.0, # 7 days
75+
mode: str = "reschedule",
76+
**kwargs: t.Any,
77+
):
78+
super().__init__(
79+
poke_interval=poke_interval,
80+
timeout=timeout,
81+
mode=mode,
82+
**kwargs,
83+
)
84+
self.snapshot = snapshot
85+
self.external_table_sensor_factory = external_table_sensor_factory
86+
87+
def poke(self, context: Context) -> t.Union[bool, PokeReturnValue]:
88+
dag_run = context["dag_run"]
89+
signals = self.snapshot.model.render_signals(
90+
start=to_datetime(dag_run.data_interval_start),
91+
end=to_datetime(dag_run.data_interval_end),
92+
execution_time=now(minute_floor=False),
93+
)
94+
delegates = [self.external_table_sensor_factory(signal) for signal in signals]
95+
return all(d.poke(context) for d in delegates)

sqlmesh/utils/date.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def date_dict(
214214
kwargs[f"{prefix}_ts"] = dt.isoformat()
215215
kwargs[f"{prefix}_epoch"] = millis / 1000
216216
kwargs[f"{prefix}_millis"] = millis
217+
kwargs[f"{prefix}_hour"] = dt.hour
217218
return kwargs
218219

219220

tests/core/test_model.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2282,3 +2282,68 @@ def test_model_allow_partials():
22822282
assert model.allow_partials
22832283

22842284
assert "allow_partials TRUE" in model.render_definition()[0].sql()
2285+
2286+
2287+
def test_signals():
2288+
expressions = d.parse(
2289+
"""
2290+
MODEL (
2291+
name db.table,
2292+
signals [
2293+
(
2294+
table_name = 'table_a',
2295+
ds = @end_ds,
2296+
),
2297+
(
2298+
table_name = 'table_b',
2299+
ds = @end_ds,
2300+
hour = @end_hour,
2301+
),
2302+
(
2303+
bool_key = True,
2304+
int_key = 1,
2305+
float_key = 1.0,
2306+
string_key = 'string',
2307+
)
2308+
],
2309+
);
2310+
SELECT 1;
2311+
"""
2312+
)
2313+
2314+
model = load_sql_based_model(expressions)
2315+
assert model.signals == [
2316+
exp.Tuple(
2317+
expressions=[
2318+
exp.to_column("table_name").eq("table_a"),
2319+
exp.to_column("ds").eq(d.MacroVar(this="end_ds")),
2320+
]
2321+
),
2322+
exp.Tuple(
2323+
expressions=[
2324+
exp.to_column("table_name").eq("table_b"),
2325+
exp.to_column("ds").eq(d.MacroVar(this="end_ds")),
2326+
exp.to_column("hour").eq(d.MacroVar(this="end_hour")),
2327+
]
2328+
),
2329+
exp.Tuple(
2330+
expressions=[
2331+
exp.to_column("bool_key").eq(True),
2332+
exp.to_column("int_key").eq(1),
2333+
exp.to_column("float_key").eq(1.0),
2334+
exp.to_column("string_key").eq("string"),
2335+
]
2336+
),
2337+
]
2338+
2339+
rendered_signals = model.render_signals(start="2023-01-01", end="2023-01-02 15:00:00")
2340+
assert rendered_signals == [
2341+
{"table_name": "table_a", "ds": "2023-01-02"},
2342+
{"table_name": "table_b", "ds": "2023-01-02", "hour": 14},
2343+
{"bool_key": True, "int_key": 1, "float_key": 1.0, "string_key": "string"},
2344+
]
2345+
2346+
assert (
2347+
"signals ((table_name = 'table_a', ds = @end_ds), (table_name = 'table_b', ds = @end_ds, hour = @end_hour), (bool_key = TRUE, int_key = 1, float_key = 1.0, string_key = 'string')"
2348+
in model.render_definition()[0].sql()
2349+
)

tests/core/test_snapshot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def test_json(snapshot: Snapshot):
120120
"references": [],
121121
"hash_raw_query": False,
122122
"allow_partials": False,
123+
"signals": [],
123124
},
124125
"audits": [],
125126
"name": "name",

0 commit comments

Comments
 (0)