Skip to content

Commit fb4dff6

Browse files
authored
Fix: Use version-specific compilation error for dbt projects (#1055)
1 parent 98ef71f commit fb4dff6

File tree

4 files changed

+26
-13
lines changed

4 files changed

+26
-13
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,16 @@ def __init__(
157157
identifier=quote_param,
158158
)
159159

160-
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
160+
def get_relation(
161+
self, database: t.Optional[str], schema: str, identifier: str
162+
) -> t.Optional[BaseRelation]:
161163
relations_list = self.list_relations(database, schema)
162164
matching_relations = [
163165
r
164166
for r in relations_list
165-
if r.identifier == identifier and r.schema == schema and r.database == database
167+
if r.identifier == identifier
168+
and r.schema == schema
169+
and (r.database == database or database is None)
166170
]
167171
return seq_get(matching_relations, 0)
168172

sqlmesh/dbt/builtin.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,22 @@
1414

1515
from sqlmesh.core.engine_adapter import EngineAdapter
1616
from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter
17+
from sqlmesh.dbt.util import DBT_VERSION
1718
from sqlmesh.utils import AttributeDict, yaml
1819
from sqlmesh.utils.errors import ConfigError, MacroEvalError
1920
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReturnVal
2021

2122

2223
class Exceptions:
2324
def raise_compiler_error(self, msg: str) -> None:
24-
from dbt.exceptions import CompilationError
25+
if DBT_VERSION >= (1, 4):
26+
from dbt.exceptions import CompilationError
2527

26-
raise CompilationError(msg)
28+
raise CompilationError(msg)
29+
else:
30+
from dbt.exceptions import CompilationException # type: ignore
31+
32+
raise CompilationException(msg)
2733

2834
def warn(self, msg: str) -> str:
2935
print(msg)

sqlmesh/dbt/manifest.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from dbt.config.renderer import DbtProjectYamlRenderer, ProfileRenderer
1515
from dbt.parser.manifest import ManifestLoader
1616
from dbt.tracking import do_not_track
17-
from dbt.version import get_installed_version
1817

1918
from sqlmesh.dbt.basemodel import Dependencies
2019
from sqlmesh.dbt.macros import MACRO_OVERRIDES
@@ -23,6 +22,7 @@
2322
from sqlmesh.dbt.seed import SeedConfig
2423
from sqlmesh.dbt.source import SourceConfig
2524
from sqlmesh.dbt.test import TestConfig
25+
from sqlmesh.dbt.util import DBT_VERSION
2626
from sqlmesh.utils.errors import ConfigError
2727
from sqlmesh.utils.jinja import MacroInfo, MacroReference
2828

@@ -310,11 +310,6 @@ def _model_node_id(model_name: str, package: str) -> str:
310310
return f"model.{package}.{model_name}"
311311

312312

313-
def _get_dbt_version() -> t.Tuple[int, int]:
314-
dbt_version = get_installed_version()
315-
return (int(dbt_version.major or "0"), int(dbt_version.minor or "0"))
316-
317-
318313
def _test_owner(node: ManifestNode) -> t.Optional[str]:
319314
attached_node = getattr(node, "attached_node", None)
320315
if attached_node:
@@ -357,6 +352,3 @@ def _convert_jinja_test_to_macro(test_jinja: str) -> str:
357352

358353
macro = "{% macro test_" + test_jinja[match.span()[-1] :]
359354
return re.sub(ENDTEST_REGEX, "{% endmacro %}", macro)
360-
361-
362-
DBT_VERSION = _get_dbt_version()

sqlmesh/dbt/util.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import typing as t
2+
13
import agate
24
import pandas as pd
5+
from dbt.version import get_installed_version
36

47

58
def pandas_to_agate(df: pd.DataFrame) -> agate.Table:
@@ -9,3 +12,11 @@ def pandas_to_agate(df: pd.DataFrame) -> agate.Table:
912
from dbt.clients.agate_helper import table_from_data
1013

1114
return table_from_data(df.to_dict(orient="records"), df.columns.tolist())
15+
16+
17+
def _get_dbt_version() -> t.Tuple[int, int]:
18+
dbt_version = get_installed_version()
19+
return (int(dbt_version.major or "0"), int(dbt_version.minor or "0"))
20+
21+
22+
DBT_VERSION = _get_dbt_version()

0 commit comments

Comments
 (0)