Skip to content

Commit fdb21ea

Browse files
authored
Fix: Use the dbt source_name for the source method and add package argument support to the ref method. (#935)
1 parent 8361ce0 commit fdb21ea

17 files changed

Lines changed: 126 additions & 100 deletions

sqlmesh/dbt/basemodel.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
GeneralConfig,
2323
QuotingConfig,
2424
SqlStr,
25-
context_for_dependencies,
2625
)
2726
from sqlmesh.dbt.test import TestConfig
2827
from sqlmesh.utils import AttributeDict
@@ -66,6 +65,8 @@ class BaseModelConfig(GeneralConfig):
6665
(eg. 'parquet')
6766
path: The file path of the model
6867
dependencies: The macro, source, var, and ref dependencies used to execute the model and its hooks
68+
name: Name of the model.
69+
package_name: Name of the package that defines the model.
6970
database: Database the model is stored in
7071
schema: Custom schema name added to the model schema name
7172
alias: Relation identifier for this model instead of the filename
@@ -85,6 +86,8 @@ class BaseModelConfig(GeneralConfig):
8586
dependencies: Dependencies = Dependencies()
8687

8788
# DBT configuration fields
89+
name: str = ""
90+
package_name: str = ""
8891
schema_: str = Field("", alias="schema")
8992
database: t.Optional[str] = None
9093
alias: t.Optional[str] = None
@@ -163,7 +166,14 @@ def table_name(self) -> str:
163166
return self.alias or self.path.stem
164167

165168
@property
166-
def model_name(self) -> str:
169+
def config_name(self) -> str:
170+
"""
171+
Get the model's config name (package_name.name)
172+
"""
173+
return f"{self.package_name}.{self.name}"
174+
175+
@property
176+
def sql_name(self) -> str:
167177
"""
168178
Get the sqlmesh model name
169179
@@ -206,7 +216,7 @@ def model_function(self) -> AttributeDict[str, t.Any]:
206216

207217
def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]:
208218
"""Get common sqlmesh model parameters"""
209-
model_context = context_for_dependencies(context, self.dependencies)
219+
model_context = context.context_for_dependencies(self.dependencies)
210220
jinja_macros = model_context.jinja_macros.trim(self.dependencies.macros)
211221
jinja_macros.global_objs.update(
212222
{
@@ -228,8 +238,8 @@ def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]:
228238
"audits": [(test.name, {}) for test in self.tests],
229239
"columns": column_types_to_sqlmesh(self.columns) or None,
230240
"column_descriptions_": column_descriptions_to_sqlmesh(self.columns) or None,
231-
"depends_on": {context.refs[ref] for ref in self.dependencies.refs}.union(
232-
{context.sources[source].source_name for source in self.dependencies.sources}
241+
"depends_on": {model.sql_name for model in model_context.refs.values()}.union(
242+
{source.sql_name for source in model_context.sources.values()}
233243
),
234244
"jinja_macros": jinja_macros,
235245
"path": self.path,

sqlmesh/dbt/builtin.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,9 @@ def var(name: str, default: t.Optional[str] = None) -> str:
152152

153153

154154
def generate_ref(refs: t.Dict[str, t.Any]) -> t.Callable:
155-
156-
# TODO support package name
157155
def ref(package: str, name: t.Optional[str] = None) -> t.Optional[BaseRelation]:
158-
name = name or package
159-
relation_info = refs.get(name)
156+
ref_name = f"{package}.{name}" if name else package
157+
relation_info = refs.get(ref_name)
160158
if relation_info is None:
161159
return None
162160

sqlmesh/dbt/common.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818

1919
T = t.TypeVar("T", bound="GeneralConfig")
2020

21-
if t.TYPE_CHECKING:
22-
from sqlmesh.dbt.context import DbtContext
23-
2421
PROJECT_FILENAME = "dbt_project.yml"
2522

2623
JINJA_ONLY = {
@@ -179,34 +176,6 @@ def dict(self, *args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
179176
return out
180177

181178

182-
def context_for_dependencies(context: DbtContext, dependencies: Dependencies) -> DbtContext:
183-
dependency_context = context.copy()
184-
185-
models = {}
186-
seeds = {}
187-
sources = {}
188-
189-
for ref in dependencies.refs:
190-
if ref in context.seeds:
191-
seeds[ref] = context.seeds[ref]
192-
elif ref in context.models:
193-
models[ref] = context.models[ref]
194-
else:
195-
raise ConfigError(f"Model '{ref}' was not found.")
196-
197-
for source in dependencies.sources:
198-
if source in context.sources:
199-
sources[source] = context.sources[source]
200-
else:
201-
raise ConfigError(f"Source '{source}' was not found.")
202-
203-
dependency_context.sources = sources
204-
dependency_context.seeds = seeds
205-
dependency_context.models = models
206-
207-
return dependency_context
208-
209-
210179
def extract_jinja_config(input: str) -> t.Tuple[str, str]:
211180
def jinja_end(sql: str, start: int) -> int:
212181
cursor = start

sqlmesh/dbt/context.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
if t.TYPE_CHECKING:
1515
from jinja2 import Environment
1616

17+
from sqlmesh.dbt.basemodel import Dependencies
1718
from sqlmesh.dbt.model import ModelConfig
1819
from sqlmesh.dbt.seed import SeedConfig
1920
from sqlmesh.dbt.source import SourceConfig
@@ -40,7 +41,7 @@ class DbtContext:
4041
_models: t.Dict[str, ModelConfig] = field(default_factory=dict)
4142
_seeds: t.Dict[str, SeedConfig] = field(default_factory=dict)
4243
_sources: t.Dict[str, SourceConfig] = field(default_factory=dict)
43-
_refs: t.Dict[str, str] = field(default_factory=dict)
44+
_refs: t.Dict[str, t.Union[ModelConfig, SeedConfig]] = field(default_factory=dict)
4445

4546
_target: t.Optional[TargetConfig] = None
4647

@@ -128,9 +129,12 @@ def add_sources(self, sources: t.Dict[str, SourceConfig]) -> None:
128129
self._jinja_environment = None
129130

130131
@property
131-
def refs(self) -> t.Dict[str, str]:
132+
def refs(self) -> t.Dict[str, t.Union[ModelConfig, SeedConfig]]:
132133
if not self._refs:
133-
self._refs = {k: v.model_name for k, v in {**self._seeds, **self._models}.items()} # type: ignore
134+
# Refs can be called with or without package name.
135+
for model in {**self._seeds, **self._models}.values(): # type: ignore
136+
self._refs[model.name] = model
137+
self._refs[model.config_name] = model
134138
return self._refs
135139

136140
@property
@@ -162,14 +166,46 @@ def jinja_environment(self) -> Environment:
162166

163167
@property
164168
def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]:
165-
refs: t.Dict[str, t.Union[ModelConfig, SeedConfig]] = {**self.models, **self.seeds}
166169
output: t.Dict[str, JinjaGlobalAttribute] = {
167170
"vars": AttributeDict(self.variables),
168-
"refs": AttributeDict({k: v.relation_info for k, v in refs.items()}),
171+
"refs": AttributeDict({k: v.relation_info for k, v in self.refs.items()}),
169172
"sources": AttributeDict({k: v.relation_info for k, v in self.sources.items()}),
170173
}
171174
if self.project_name is not None:
172175
output["project_name"] = self.project_name
173176
if self._target is not None:
174177
output["target"] = self._target.attribute_dict()
175178
return output
179+
180+
def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext:
181+
from sqlmesh.dbt.model import ModelConfig
182+
from sqlmesh.dbt.seed import SeedConfig
183+
184+
dependency_context = self.copy()
185+
186+
models = {}
187+
seeds = {}
188+
sources = {}
189+
190+
for ref in dependencies.refs:
191+
model = self.refs.get(ref)
192+
if model:
193+
if isinstance(model, SeedConfig):
194+
seeds[ref] = t.cast(SeedConfig, model)
195+
else:
196+
models[ref] = t.cast(ModelConfig, model)
197+
else:
198+
raise ConfigError(f"Model '{ref}' was not found.")
199+
200+
for source in dependencies.sources:
201+
if source in self.sources:
202+
sources[source] = self.sources[source]
203+
else:
204+
raise ConfigError(f"Source '{source}' was not found.")
205+
206+
dependency_context.sources = sources
207+
dependency_context.seeds = seeds
208+
dependency_context.models = models
209+
dependency_context._refs = {**dependency_context._seeds, **dependency_context._models} # type: ignore
210+
211+
return dependency_context

sqlmesh/dbt/loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _load_models(
9999

100100
models.update(
101101
{
102-
model.model_name: cache.get_or_load_model(
102+
model.sql_name: cache.get_or_load_model(
103103
model.path, lambda: self._to_sqlmesh(model, context)
104104
)
105105
for model in package_models.values()
@@ -139,7 +139,7 @@ def _load_project(self) -> Project:
139139

140140
@classmethod
141141
def _to_sqlmesh(cls, config: BMC, context: DbtContext) -> Model:
142-
logger.debug("Converting '%s' to sqlmesh format", config.model_name)
142+
logger.debug("Converting '%s' to sqlmesh format", config.sql_name)
143143
return config.to_sqlmesh(context)
144144

145145
def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]:

sqlmesh/dbt/manifest.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def __init__(
6161
self._sources_per_package: t.Dict[str, SourceConfigs] = defaultdict(dict)
6262
self._macros_per_package: t.Dict[str, MacroConfigs] = defaultdict(dict)
6363

64+
self._tests_by_owner: t.Dict[str, t.List[TestConfig]] = defaultdict(list)
65+
6466
def tests(self, package_name: t.Optional[str] = None) -> TestConfigs:
6567
self._load_all()
6668
return self._tests_per_package[package_name or self._project_name]
@@ -109,7 +111,7 @@ def _load_sources(self) -> None:
109111
**source_dict,
110112
)
111113
self._sources_per_package[source.package_name][
112-
source_config.source_name
114+
source_config.config_name
113115
] = source_config
114116

115117
def _load_macros(self) -> None:
@@ -162,20 +164,26 @@ def _load_tests(self) -> None:
162164
)
163165
continue
164166

165-
self._tests_per_package[node.package_name][node.name.lower()] = TestConfig(
167+
test = TestConfig(
166168
sql=node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql, # type: ignore
167169
owner=test_owner,
168170
test_kwargs=node.test_metadata.kwargs if hasattr(node, "test_metadata") else {},
169171
dependencies=dependencies,
170172
**_node_base_config(node),
171173
)
174+
self._tests_per_package[node.package_name][node.name.lower()] = test
175+
self._tests_by_owner[test_owner].append(test)
172176

173177
def _load_models_and_seeds(self) -> None:
174178
for node in self._manifest.nodes.values():
175179
if node.resource_type not in ("model", "seed"):
176180
continue
177181

178182
macro_references = _macro_references(self._manifest, node)
183+
tests = (
184+
self._tests_by_owner[node.name]
185+
+ self._tests_by_owner[f"{node.package_name}.{node.name}"]
186+
)
179187

180188
if node.resource_type == "model":
181189
self._models_per_package[node.package_name][node.name] = ModelConfig(
@@ -185,13 +193,13 @@ def _load_models_and_seeds(self) -> None:
185193
refs=_refs(node),
186194
sources=_sources(node),
187195
),
188-
tests=_tests_for_node(node, self._tests_per_package[node.package_name]),
196+
tests=tests,
189197
**_node_base_config(node),
190198
)
191199
else:
192200
self._seeds_per_package[node.package_name][node.name] = SeedConfig(
193201
dependencies=Dependencies(macros=macro_references),
194-
tests=_tests_for_node(node, self._tests_per_package[node.package_name]),
202+
tests=tests,
195203
**_node_base_config(node),
196204
)
197205

@@ -259,9 +267,9 @@ def _macro_references(
259267

260268
def _refs(node: ManifestNode) -> t.Set[str]:
261269
if DBT_VERSION >= (1, 5):
262-
return {r.name for r in node.refs} # type: ignore
270+
return {f"{r.package}.{r.name}" if r.package else r.name for r in node.refs} # type: ignore
263271
else:
264-
return {r[1] if len(r) > 1 else r[0] for r in node.refs} # type: ignore
272+
return {".".join(r) for r in node.refs} # type: ignore
265273

266274

267275
def _sources(node: ManifestNode) -> t.Set[str]:
@@ -309,10 +317,6 @@ def _node_base_config(node: ManifestNode) -> t.Dict[str, t.Any]:
309317
}
310318

311319

312-
def _tests_for_node(node: ManifestNode, tests: t.Dict[str, TestConfig]) -> t.List[TestConfig]:
313-
return [test for test in tests.values() if test.owner == node.name]
314-
315-
316320
def _convert_jinja_test_to_macro(test_jinja: str) -> str:
317321
TEST_TAG_REGEX = "\s*{%\s*test\s+"
318322
ENDTEST_REGEX = "{%\s*endtest\s*%}"

sqlmesh/dbt/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,13 @@ def model_kind(self, target: TargetConfig) -> ModelKind:
152152
and strategy not in INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES
153153
):
154154
raise ConfigError(
155-
f"{self.model_name}: SQLMesh IncrementalByUniqueKey is not compatible with '{strategy}'"
155+
f"{self.sql_name}: SQLMesh IncrementalByUniqueKey is not compatible with '{strategy}'"
156156
f" incremental strategy. Supported strategies include {collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES)}."
157157
)
158158
return IncrementalByUniqueKeyKind(unique_key=self.unique_key, **incremental_kwargs)
159159

160160
raise ConfigError(
161-
f"{self.model_name}: Incremental materialization requires either a "
161+
f"{self.sql_name}: Incremental materialization requires either a "
162162
f"time_column {collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES)}) or a "
163163
f"unique_key ({collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES.union(['none']))}) configuration."
164164
)
@@ -205,10 +205,10 @@ def to_sqlmesh(self, context: DbtContext) -> Model:
205205
optional_kwargs[field] = field_val
206206

207207
if not context.target:
208-
raise ConfigError(f"Target required to load '{self.model_name}' into SQLMesh.")
208+
raise ConfigError(f"Target required to load '{self.sql_name}' into SQLMesh.")
209209

210210
return create_sql_model(
211-
self.model_name,
211+
self.sql_name,
212212
expressions[0],
213213
dialect=dialect,
214214
kind=self.model_kind(context.target),

sqlmesh/dbt/seed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class SeedConfig(BaseModelConfig):
2121
def to_sqlmesh(self, context: DbtContext) -> Model:
2222
"""Converts the dbt seed into a SQLMesh model."""
2323
return create_seed_model(
24-
self.model_name,
24+
self.sql_name,
2525
SeedKind(path=self.path.absolute()),
2626
**self.sqlmesh_model_kwargs(context),
2727
)

sqlmesh/dbt/source.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
class SourceConfig(GeneralConfig):
1515
"""
1616
Args:
17-
config_name: The schema.table_name names declared in source config
18-
name: The name of the source or table
17+
name: The name of the table
18+
source_name: The name of the source that defines the table
1919
database: Name of the database where the table is stored. By default, the project's target database is used.
2020
schema: The schema name as stored in the database. If not specified, the source name is used.
2121
identifier: The table name as stored in the database. If not specified, the source table name is used
@@ -28,11 +28,9 @@ class SourceConfig(GeneralConfig):
2828
columns: Columns within the source
2929
"""
3030

31-
# sqlmesh fields
32-
config_name: str = ""
33-
3431
# DBT configuration fields
35-
name: t.Optional[str] = None
32+
name: str = ""
33+
source_name_: str = Field("", alias="source_name")
3634
database: t.Optional[str] = None
3735
schema_: t.Optional[str] = Field(None, alias="schema")
3836
identifier: t.Optional[str] = None
@@ -54,7 +52,11 @@ def table_name(self) -> t.Optional[str]:
5452
return self.identifier or self.name
5553

5654
@property
57-
def source_name(self) -> str:
55+
def config_name(self) -> str:
56+
return f"{self.source_name_}.{self.name}"
57+
58+
@property
59+
def sql_name(self) -> str:
5860
return ".".join(part for part in (self.schema_, self.table_name) if part)
5961

6062
@property

0 commit comments

Comments
 (0)