Skip to content

Commit 78f5530

Browse files
committed
Feat(sqlmesh_dbt): Select based on dbt name, not sqlmesh name
1 parent a53e2eb commit 78f5530

File tree

6 files changed

+252
-10
lines changed

6 files changed

+252
-10
lines changed

sqlmesh/core/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ class GenericContext(BaseContext, t.Generic[C]):
348348
load: Whether or not to automatically load all models and macros (default True).
349349
console: The rich instance used for printing out CLI command results.
350350
users: A list of users to make known to SQLMesh.
351+
dbt_mode: A flag to indicate we are running in 'dbt mode' which means that things like
352+
model selections should use the dbt names and not the native SQLMesh names
351353
"""
352354

353355
CONFIG_TYPE: t.Type[C]
@@ -368,6 +370,7 @@ def __init__(
368370
load: bool = True,
369371
users: t.Optional[t.List[User]] = None,
370372
config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
373+
dbt_mode: bool = False,
371374
):
372375
self.configs = (
373376
config
@@ -390,6 +393,7 @@ def __init__(
390393
self._engine_adapter: t.Optional[EngineAdapter] = None
391394
self._linters: t.Dict[str, Linter] = {}
392395
self._loaded: bool = False
396+
self._dbt_mode = dbt_mode
393397

394398
self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))
395399

@@ -2901,6 +2905,7 @@ def _new_selector(
29012905
default_catalog=self.default_catalog,
29022906
dialect=self.default_dialect,
29032907
cache_dir=self.cache_dir,
2908+
dbt_mode=self._dbt_mode,
29042909
)
29052910

29062911
def _register_notification_targets(self) -> None:

sqlmesh/core/selector.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import fnmatch
44
import typing as t
55
from pathlib import Path
6+
from itertools import zip_longest
67

78
from sqlglot import exp
89
from sqlglot.errors import ParseError
@@ -36,6 +37,7 @@ def __init__(
3637
default_catalog: t.Optional[str] = None,
3738
dialect: t.Optional[str] = None,
3839
cache_dir: t.Optional[Path] = None,
40+
dbt_mode: bool = False,
3941
):
4042
self._state_reader = state_reader
4143
self._models = models
@@ -44,6 +46,7 @@ def __init__(
4446
self._default_catalog = default_catalog
4547
self._dialect = dialect
4648
self._git_client = GitClient(context_path)
49+
self._dbt_mode = dbt_mode
4750

4851
if dag is None:
4952
self._dag: DAG[str] = DAG()
@@ -167,13 +170,13 @@ def get_model(fqn: str) -> t.Optional[Model]:
167170
def expand_model_selections(
168171
self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Model]] = None
169172
) -> t.Set[str]:
170-
"""Expands a set of model selections into a set of model names.
173+
"""Expands a set of model selections into a set of model fqns that can be looked up in the Context.
171174
172175
Args:
173176
model_selections: A set of model selections.
174177
175178
Returns:
176-
A set of model names.
179+
A set of model fqns.
177180
"""
178181

179182
node = parse(" | ".join(f"({s})" for s in model_selections))
@@ -194,10 +197,9 @@ def evaluate(node: exp.Expression) -> t.Set[str]:
194197
return {
195198
fqn
196199
for fqn, model in all_models.items()
197-
if fnmatch.fnmatchcase(model.name, node.this)
200+
if fnmatch.fnmatchcase(self._model_name(model), node.this)
198201
}
199-
fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
200-
return {fqn} if fqn in all_models else set()
202+
return self._pattern_to_model_fqns(pattern, all_models)
201203
if isinstance(node, exp.And):
202204
return evaluate(node.left) & evaluate(node.right)
203205
if isinstance(node, exp.Or):
@@ -241,6 +243,59 @@ def evaluate(node: exp.Expression) -> t.Set[str]:
241243

242244
return evaluate(node)
243245

246+
def _model_fqn(self, model: Model) -> str:
247+
if self._dbt_mode:
248+
dbt_fqn = model.dbt_fqn
249+
if dbt_fqn is None:
250+
raise SQLMeshError("Expecting dbt node information to be populated; it wasnt")
251+
return dbt_fqn
252+
return model.fqn
253+
254+
def _model_name(self, model: Model) -> str:
255+
if self._dbt_mode:
256+
# dbt always matches on the fqn, not the name
257+
return self._model_fqn(model)
258+
return model.name
259+
260+
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
261+
# note: all_models should be keyed by sqlmesh fqn, not dbt fqn
262+
if not self._dbt_mode:
263+
fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
264+
return {fqn} if fqn in all_models else set()
265+
266+
# a pattern like "staging.customers" should match a model called "jaffle_shop.staging.customers"
267+
# but not a model called "jaffle_shop.customers.staging"
268+
# also a pattern like "aging" should not match "staging" so we need to consider components; not substrings
269+
pattern_components = pattern.split(".")
270+
first_pattern_component = pattern_components[0]
271+
matches = set()
272+
for fqn, model in all_models.items():
273+
if not model.dbt_fqn:
274+
continue
275+
276+
dbt_fqn_components = model.dbt_fqn.split(".")
277+
try:
278+
starting_idx = dbt_fqn_components.index(first_pattern_component)
279+
except ValueError:
280+
continue
281+
for pattern_component, fqn_component in zip_longest(
282+
pattern_components, dbt_fqn_components[starting_idx:]
283+
):
284+
if pattern_component and not fqn_component:
285+
# the pattern still goes but we have run out of fqn components to match; no match
286+
break
287+
if fqn_component and not pattern_component:
288+
# all elements of the pattern have matched elements of the fqn; match
289+
matches.add(fqn)
290+
break
291+
if pattern_component != fqn_component:
292+
# the pattern explicitly doesnt match a component; no match
293+
break
294+
else:
295+
# called if no explicit break, indicating all components of the pattern matched all components of the fqn
296+
matches.add(fqn)
297+
return matches
298+
244299

245300
class SelectorDialect(Dialect):
246301
IDENTIFIERS_CAN_START_WITH_DIGIT = True

sqlmesh_dbt/operations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ def create(
250250
paths=[project_dir],
251251
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
252252
load=True,
253+
# dbt mode enables selectors to use dbt model fqn's rather than SQLMesh model names
254+
dbt_mode=True,
253255
)
254256

255257
dbt_loader = sqlmesh_context._loaders[0]

tests/dbt/cli/test_list.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
1919

2020

2121
def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
22-
result = invoke_cli(["list", "--select", "main.raw_customers+"])
22+
result = invoke_cli(["list", "--select", "raw_customers+"])
2323

2424
assert result.exit_code == 0
2525
assert not result.exception
@@ -34,7 +34,7 @@ def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Resul
3434

3535
def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
3636
# single exclude
37-
result = invoke_cli(["list", "--select", "main.raw_customers+", "--exclude", "main.orders"])
37+
result = invoke_cli(["list", "--select", "raw_customers+", "--exclude", "orders"])
3838

3939
assert result.exit_code == 0
4040
assert not result.exception
@@ -49,8 +49,8 @@ def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..
4949

5050
# multiple exclude
5151
for args in (
52-
["--select", "main.stg_orders+", "--exclude", "main.customers", "--exclude", "main.orders"],
53-
["--select", "main.stg_orders+", "--exclude", "main.customers main.orders"],
52+
["--select", "stg_orders+", "--exclude", "customers", "--exclude", "orders"],
53+
["--select", "stg_orders+", "--exclude", "customers orders"],
5454
):
5555
result = invoke_cli(["list", *args])
5656
assert result.exit_code == 0

tests/dbt/cli/test_selectors.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import typing as t
22
import pytest
33
from sqlmesh_dbt import selectors
4+
from sqlmesh.core.context import Context
5+
from pathlib import Path
46

57

68
@pytest.mark.parametrize(
@@ -77,3 +79,181 @@ def test_split_unions_and_intersections(
7779
expression: str, expected: t.Tuple[t.List[str], t.List[str]]
7880
):
7981
assert selectors._split_unions_and_intersections(expression) == expected
82+
83+
84+
@pytest.mark.parametrize(
85+
"dbt_select,expected",
86+
[
87+
(["aging"], set()),
88+
(
89+
["staging"],
90+
{
91+
'"jaffle_shop"."main"."stg_customers"',
92+
'"jaffle_shop"."main"."stg_orders"',
93+
'"jaffle_shop"."main"."stg_payments"',
94+
},
95+
),
96+
(["staging.stg_customers"], {'"jaffle_shop"."main"."stg_customers"'}),
97+
(["stg_customers.staging"], set()),
98+
(
99+
["+customers"],
100+
{
101+
'"jaffle_shop"."main"."customers"',
102+
'"jaffle_shop"."main"."stg_customers"',
103+
'"jaffle_shop"."main"."stg_orders"',
104+
'"jaffle_shop"."main"."stg_payments"',
105+
'"jaffle_shop"."main"."raw_customers"',
106+
'"jaffle_shop"."main"."raw_orders"',
107+
'"jaffle_shop"."main"."raw_payments"',
108+
},
109+
),
110+
(["customers+"], {'"jaffle_shop"."main"."customers"'}),
111+
(
112+
["customers+", "stg_orders"],
113+
{'"jaffle_shop"."main"."customers"', '"jaffle_shop"."main"."stg_orders"'},
114+
),
115+
(["tag:agg"], {'"jaffle_shop"."main"."agg_orders"'}),
116+
(
117+
["staging.stg_customers", "tag:agg"],
118+
{
119+
'"jaffle_shop"."main"."stg_customers"',
120+
'"jaffle_shop"."main"."agg_orders"',
121+
},
122+
),
123+
(
124+
["+tag:agg"],
125+
{
126+
'"jaffle_shop"."main"."agg_orders"',
127+
'"jaffle_shop"."main"."orders"',
128+
'"jaffle_shop"."main"."stg_orders"',
129+
'"jaffle_shop"."main"."stg_payments"',
130+
'"jaffle_shop"."main"."raw_orders"',
131+
'"jaffle_shop"."main"."raw_payments"',
132+
},
133+
),
134+
(
135+
["tag:agg+"],
136+
{
137+
'"jaffle_shop"."main"."agg_orders"',
138+
},
139+
),
140+
],
141+
)
142+
def test_select_by_dbt_names(
143+
jaffle_shop_duckdb: Path,
144+
jaffle_shop_duckdb_context: Context,
145+
dbt_select: t.List[str],
146+
expected: t.Set[str],
147+
):
148+
(jaffle_shop_duckdb / "models" / "agg_orders.sql").write_text("""
149+
{{ config(tags=["agg"]) }}
150+
select order_date, count(*) as num_orders from {{ ref('orders') }}
151+
""")
152+
153+
ctx = jaffle_shop_duckdb_context
154+
ctx.load()
155+
assert '"jaffle_shop"."main"."agg_orders"' in ctx.models
156+
157+
selector = ctx._new_selector()
158+
assert selector._dbt_mode
159+
160+
sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=[])
161+
assert sqlmesh_selector
162+
163+
assert selector.expand_model_selections([sqlmesh_selector]) == expected
164+
165+
166+
@pytest.mark.parametrize(
167+
"dbt_exclude,expected",
168+
[
169+
(["jaffle_shop"], set()),
170+
(
171+
["staging"],
172+
{
173+
'"jaffle_shop"."main"."agg_orders"',
174+
'"jaffle_shop"."main"."customers"',
175+
'"jaffle_shop"."main"."orders"',
176+
'"jaffle_shop"."main"."raw_customers"',
177+
'"jaffle_shop"."main"."raw_orders"',
178+
'"jaffle_shop"."main"."raw_payments"',
179+
},
180+
),
181+
(["+customers"], {'"jaffle_shop"."main"."orders"', '"jaffle_shop"."main"."agg_orders"'}),
182+
(
183+
["+tag:agg"],
184+
{
185+
'"jaffle_shop"."main"."customers"',
186+
'"jaffle_shop"."main"."stg_customers"',
187+
'"jaffle_shop"."main"."raw_customers"',
188+
},
189+
),
190+
],
191+
)
192+
def test_exclude_by_dbt_names(
193+
jaffle_shop_duckdb: Path,
194+
jaffle_shop_duckdb_context: Context,
195+
dbt_exclude: t.List[str],
196+
expected: t.Set[str],
197+
):
198+
(jaffle_shop_duckdb / "models" / "agg_orders.sql").write_text("""
199+
{{ config(tags=["agg"]) }}
200+
select order_date, count(*) as num_orders from {{ ref('orders') }}
201+
""")
202+
203+
ctx = jaffle_shop_duckdb_context
204+
ctx.load()
205+
assert '"jaffle_shop"."main"."agg_orders"' in ctx.models
206+
207+
selector = ctx._new_selector()
208+
assert selector._dbt_mode
209+
210+
sqlmesh_selector = selectors.to_sqlmesh(dbt_select=[], dbt_exclude=dbt_exclude)
211+
assert sqlmesh_selector
212+
213+
assert selector.expand_model_selections([sqlmesh_selector]) == expected
214+
215+
216+
@pytest.mark.parametrize(
217+
"dbt_select,dbt_exclude,expected",
218+
[
219+
(["jaffle_shop"], ["jaffle_shop"], set()),
220+
(
221+
["staging"],
222+
["stg_customers"],
223+
{
224+
'"jaffle_shop"."main"."stg_orders"',
225+
'"jaffle_shop"."main"."stg_payments"',
226+
},
227+
),
228+
(
229+
["staging.stg_customers", "tag:agg"],
230+
["tag:agg"],
231+
{
232+
'"jaffle_shop"."main"."stg_customers"',
233+
},
234+
),
235+
],
236+
)
237+
def test_selection_and_exclusion_by_dbt_names(
238+
jaffle_shop_duckdb: Path,
239+
jaffle_shop_duckdb_context: Context,
240+
dbt_select: t.List[str],
241+
dbt_exclude: t.List[str],
242+
expected: t.Set[str],
243+
):
244+
(jaffle_shop_duckdb / "models" / "agg_orders.sql").write_text("""
245+
{{ config(tags=["agg"]) }}
246+
select order_date, count(*) as num_orders from {{ ref('orders') }}
247+
""")
248+
249+
ctx = jaffle_shop_duckdb_context
250+
ctx.load()
251+
assert '"jaffle_shop"."main"."agg_orders"' in ctx.models
252+
253+
selector = ctx._new_selector()
254+
assert selector._dbt_mode
255+
256+
sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=dbt_exclude)
257+
assert sqlmesh_selector
258+
259+
assert selector.expand_model_selections([sqlmesh_selector]) == expected

tests/dbt/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def jaffle_shop_duckdb(copy_to_temp_path: t.Callable[..., t.List[Path]]) -> t.It
9999
@pytest.fixture
100100
def jaffle_shop_duckdb_context(jaffle_shop_duckdb: Path) -> Context:
101101
init_project_if_required(jaffle_shop_duckdb)
102-
return Context(paths=[jaffle_shop_duckdb])
102+
return Context(paths=[jaffle_shop_duckdb], dbt_mode=True)
103103

104104

105105
@pytest.fixture()

0 commit comments

Comments
 (0)