Skip to content

Commit 405896c

Browse files
authored
add airflow depends_on_past (#940)
1 parent f5d4d45 commit 405896c

File tree

8 files changed

+226
-28
lines changed

8 files changed

+226
-28
lines changed

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ ignore_missing_imports = True
7272
[mypy-langchain.*]
7373
ignore_missing_imports = True
7474

75+
[mypy-pytest_lazyfixture.*]
76+
ignore_missing_imports = True
77+
7578
[autoflake]
7679
in-place = True
7780
expand-star-imports = True

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
"PyGithub",
7979
"pytest",
8080
"pytest-asyncio",
81+
"pytest-lazy-fixture",
8182
"pytest-mock",
8283
"pyspark>=3.4.0",
8384
"pytz",

sqlmesh/schedulers/airflow/dag_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def _create_backfill_tasks(
349349
snapshot_end_task = EmptyOperator(
350350
task_id=f"snapshot_backfill__{snapshot.name}__{snapshot.identifier}__end"
351351
)
352-
if snapshot.is_incremental_by_unique_key:
352+
if snapshot.depends_on_past:
353353
baseoperator.chain(snapshot_start_task, *tasks, snapshot_end_task)
354354
else:
355355
snapshot_start_task >> tasks >> snapshot_end_task

sqlmesh/schedulers/airflow/plan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def create_plan_dag_spec(
6262
intervals=state_sync.get_snapshot_intervals(all_snapshots.values()),
6363
start=request.environment.start_at,
6464
end=end,
65-
latest=end,
65+
latest=now(),
6666
is_dev=request.is_dev,
6767
restatements=request.restatements,
6868
)

tests/conftest.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import datetime
34
import typing as t
45
from pathlib import Path
56
from shutil import rmtree
@@ -13,15 +14,39 @@
1314
from sqlglot.helper import ensure_list
1415

1516
from sqlmesh.core.context import Context
17+
from sqlmesh.core.engine_adapter.base import EngineAdapter
1618
from sqlmesh.core.model import Model
1719
from sqlmesh.core.plan import BuiltInPlanEvaluator, Plan
1820
from sqlmesh.core.snapshot import Snapshot
1921
from sqlmesh.utils import random_id
20-
from sqlmesh.utils.date import TimeLike
22+
from sqlmesh.utils.date import TimeLike, to_date, to_ds
2123

2224
pytest_plugins = ["tests.common_fixtures"]
2325

2426

27+
class SushiDataValidator:
28+
def __init__(self, engine_adapter: EngineAdapter):
29+
self.engine_adapter = engine_adapter
30+
31+
@classmethod
32+
def from_context(cls, context: Context):
33+
return cls(engine_adapter=context.engine_adapter)
34+
35+
def validate(self, model_name: str, start: TimeLike, end: TimeLike) -> t.Dict[t.Any, t.Any]:
36+
if model_name == "sushi.customer_revenue_lifetime":
37+
query = "SELECT ds, count(*) AS the_count FROM sushi.customer_revenue_lifetime group by 1 order by 2 desc, 1 desc"
38+
results = self.engine_adapter.fetchdf(query).to_dict()
39+
start_date, end_date = to_date(start), to_date(end)
40+
num_days_diff = (end_date - start_date).days + 1
41+
assert len(results["ds"]) == num_days_diff
42+
assert list(results["ds"].values()) == [
43+
to_ds(end_date - datetime.timedelta(days=x)) for x in range(num_days_diff)
44+
]
45+
return results
46+
else:
47+
raise NotImplementedError(f"Unknown model_name: {model_name}")
48+
49+
2550
# Ignore all local config files
2651
@pytest.fixture(scope="session", autouse=True)
2752
def ignore_local_config_files():
@@ -142,6 +167,16 @@ def random_name() -> t.Callable:
142167
return lambda: f"generated_{random_id()}"
143168

144169

170+
@pytest.fixture
171+
def sushi_data_validator(sushi_context: Context) -> SushiDataValidator:
172+
return SushiDataValidator.from_context(sushi_context)
173+
174+
175+
@pytest.fixture
176+
def sushi_fixed_date_data_validator(sushi_context_fixed_date: Context) -> SushiDataValidator:
177+
return SushiDataValidator.from_context(sushi_context_fixed_date)
178+
179+
145180
def delete_cache(project_paths: str | t.List[str]) -> None:
146181
for path in ensure_list(project_paths):
147182
try:

tests/core/test_integration.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import datetime
21
import typing as t
32
from collections import Counter
43

@@ -26,7 +25,8 @@
2625
SnapshotInfoLike,
2726
SnapshotTableInfo,
2827
)
29-
from sqlmesh.utils.date import TimeLike, to_date, to_ds, to_timestamp, yesterday
28+
from sqlmesh.utils.date import TimeLike, to_date, to_ds, to_timestamp
29+
from tests.conftest import SushiDataValidator
3030

3131

3232
@pytest.fixture(autouse=True)
@@ -655,18 +655,15 @@ def test_multi(mocker):
655655

656656
@pytest.mark.integration
657657
@pytest.mark.core_integration
658-
def test_incremental_time_self_reference(mocker: MockerFixture, sushi_context: Context):
658+
def test_incremental_time_self_reference(
659+
mocker: MockerFixture, sushi_context: Context, sushi_data_validator: SushiDataValidator
660+
):
661+
start_date, end_date = to_date("1 week ago"), to_date("yesterday")
659662
df = sushi_context.engine_adapter.fetchdf("SELECT MIN(ds) FROM sushi.customer_revenue_lifetime")
660-
assert df.iloc[0, 0] == to_ds("1 week ago")
663+
assert df.iloc[0, 0] == to_ds(start_date)
661664
df = sushi_context.engine_adapter.fetchdf("SELECT MAX(ds) FROM sushi.customer_revenue_lifetime")
662-
assert df.iloc[0, 0] == to_ds("yesterday")
663-
query_get_date_and_count = "SELECT ds, count(*) AS the_count FROM sushi.customer_revenue_lifetime group by 1 order by 2 desc, 1 desc"
664-
results = sushi_context.engine_adapter.fetchdf(query_get_date_and_count).to_dict()
665-
# Validate that both rows increase over time and all days are present
666-
assert len(results["ds"]) == 7
667-
assert list(results["ds"].values()) == [
668-
to_ds(yesterday() - datetime.timedelta(days=x)) for x in range(7)
669-
]
665+
assert df.iloc[0, 0] == to_ds(end_date)
666+
results = sushi_data_validator.validate("sushi.customer_revenue_lifetime", start_date, end_date)
670667
plan = sushi_context.plan(
671668
restate_models=["sushi.customer_revenue_lifetime", "sushi.customer_revenue_by_day"],
672669
no_prompts=True,
@@ -709,7 +706,9 @@ def test_incremental_time_self_reference(mocker: MockerFixture, sushi_context: C
709706
"sushi.customer_revenue_by_day": 1,
710707
}
711708
# Validate that the results are the same as before the restate
712-
assert results == sushi_context.engine_adapter.fetchdf(query_get_date_and_count).to_dict()
709+
assert results == sushi_data_validator.validate(
710+
"sushi.customer_revenue_lifetime", start_date, end_date
711+
)
713712

714713

715714
def initial_add(context: Context, environment: str):

tests/schedulers/airflow/test_end_to_end.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from sqlmesh.core.context import Context
66
from sqlmesh.schedulers.airflow.client import AirflowClient
7-
from sqlmesh.utils.date import now, yesterday_ds
7+
from sqlmesh.utils.date import now
8+
from tests.conftest import SushiDataValidator
89

910

1011
@pytest.fixture(autouse=True)
@@ -19,12 +20,13 @@ def get_receiver_dag() -> None:
1920
@pytest.mark.integration
2021
@pytest.mark.airflow_integration
2122
def test_sushi(mocker: MockerFixture, is_docker: bool):
22-
start = yesterday_ds()
23+
start = "1 week ago"
2324
end = now()
2425
latest = end
2526

2627
airflow_config = "airflow_config_docker" if is_docker else "airflow_config"
2728
context = Context(paths="./examples/sushi", config=airflow_config)
29+
data_validator = SushiDataValidator.from_context(context)
2830

2931
context.plan(
3032
environment="test_dev",
@@ -36,6 +38,8 @@ def test_sushi(mocker: MockerFixture, is_docker: bool):
3638
auto_apply=True,
3739
)
3840

41+
data_validator.validate("sushi.customer_revenue_lifetime", start, end)
42+
3943
# Ensure that the plan has been applied successfully.
4044
no_change_plan = context.plan(
4145
environment="test_dev_two",

0 commit comments

Comments (0)