Skip to content

Commit a303011

Browse files
authored
Fix: unique, user-friendly audit names for custom named dbt tests (#5484)
1 parent 9dce264 commit a303011

File tree

3 files changed

+201
-4
lines changed

3 files changed

+201
-4
lines changed

sqlmesh/dbt/manifest.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
extract_call_names,
6262
jinja_call_arg_name,
6363
)
64+
from sqlglot.helper import ensure_list
6465

6566
if t.TYPE_CHECKING:
6667
from dbt.contracts.graph.manifest import Macro, Manifest
@@ -353,15 +354,17 @@ def _load_tests(self) -> None:
353354
)
354355

355356
test_model = _test_model(node)
357+
node_config = _node_base_config(node)
358+
node_config["name"] = _build_test_name(node, dependencies)
356359

357360
test = TestConfig(
358361
sql=sql,
359362
model_name=test_model,
360363
test_kwargs=node.test_metadata.kwargs if hasattr(node, "test_metadata") else {},
361364
dependencies=dependencies,
362-
**_node_base_config(node),
365+
**node_config,
363366
)
364-
self._tests_per_package[node.package_name][node.name.lower()] = test
367+
self._tests_per_package[node.package_name][node.unique_id] = test
365368
if test_model:
366369
self._tests_by_owner[test_model].append(test)
367370

@@ -741,7 +744,12 @@ def _test_model(node: ManifestNode) -> t.Optional[str]:
741744
attached_node = getattr(node, "attached_node", None)
742745
if attached_node:
743746
pieces = attached_node.split(".")
744-
return pieces[-1] if pieces[0] in ["model", "seed"] else None
747+
if pieces[0] in ["model", "seed"]:
748+
# versioned models have format "model.package.model_name.v1" (4 parts)
749+
if len(pieces) == 4:
750+
return f"{pieces[2]}_{pieces[3]}"
751+
return pieces[-1]
752+
return None
745753

746754
key_name = getattr(node, "file_key_name", None)
747755
if key_name:
@@ -798,3 +806,53 @@ def _strip_jinja_materialization_tags(materialization_jinja: str) -> str:
798806
)
799807

800808
return materialization_jinja.strip()
809+
810+
811+
def _build_test_name(node: ManifestNode, dependencies: Dependencies) -> str:
    """
    Return a user-friendly name for a dbt test.

    The name embeds the test's model/source, column, and argument values, including
    for tests that were given a custom name. Needed because dbt only generates these
    names for tests that do not specify the "name" field in their YAML definition.

    Name structure
    - Model test: [namespace]_[test name]_[model name]_[column name]__[arg values]
    - Source test: [namespace]_source_[test name]_[source name]_[table name]_[column name]__[arg values]

    Args:
        node: The manifest node of the test. Nodes without a `test_metadata`
            attribute are treated as standalone tests and keep their own name.
        dependencies: The test's dependencies. Its `sources` are consulted only
            when no model could be resolved for the test.

    Returns:
        `node.name` unchanged when it is a standalone test or when it already equals
        the auto-generated name; otherwise `node.name` extended with the entity and
        argument suffixes (and, for source tests only, the namespace/"source_" prefix).
    """
    # standalone test
    if not hasattr(node, "test_metadata"):
        return node.name

    metadata = node.test_metadata

    # Resolve the entity under test: prefer the model, fall back to the first source.
    model_name = _test_model(node)
    source_name = None
    if not model_name and dependencies.sources:
        # extract source and table names
        pieces = next(iter(dependencies.sources)).split(".")
        if len(pieces) == 2:
            source_name = "_".join(pieces)
        else:
            source_name = pieces[-1]
    is_source_test = bool(source_name) and not model_name

    entity = model_name or source_name or ""
    entity_suffix = f"_{entity}" if entity else ""

    # Prefix: optional namespace, then the "source_" marker for source tests.
    namespace = getattr(metadata, "namespace", None)
    prefix = f"{namespace}_" if namespace else ""
    if is_source_test:
        prefix += "source_"

    # Argument values in sorted-argument order, sanitized to [0-9a-zA-Z_].
    # The "model" argument is skipped; dict arguments contribute their values only.
    sanitized_values = []
    for arg_name, arg_value in sorted(metadata.kwargs.items()):
        if arg_name == "model":
            continue
        raw_values = list(arg_value.values()) if isinstance(arg_value, dict) else arg_value
        sanitized_values.extend(
            re.sub("[^0-9a-zA-Z_]+", "_", str(item)) for item in ensure_list(raw_values)
        )
    args_suffix = "__".join(sanitized_values)
    if args_suffix:
        args_suffix = f"_{args_suffix}"

    generated_name = f"{prefix}{metadata.name}{entity_suffix}{args_suffix}"
    if node.name == generated_name:
        # The node already carries the auto-generated name; nothing to add.
        return node.name

    # Custom-named test: only source tests keep the prefix.
    custom_prefix = prefix if is_source_test else ""
    return f"{custom_prefix}{node.name}{entity_suffix}{args_suffix}"

sqlmesh/dbt/test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,14 @@ def is_standalone(self) -> bool:
122122
return True
123123

124124
# Check if test has references to other models
125-
other_refs = {ref for ref in self.dependencies.refs if ref != self.model_name}
125+
# For versioned models, refs include version (e.g., "model_name_v1") but model_name may not
126+
self_refs = {self.model_name}
127+
for ref in self.dependencies.refs:
128+
# versioned models end in _vX
129+
if ref.startswith(f"{self.model_name}_v"):
130+
self_refs.add(ref)
131+
132+
other_refs = {ref for ref in self.dependencies.refs if ref not in self_refs}
126133
return bool(other_refs)
127134

128135
@property

tests/dbt/test_test.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
15
from sqlmesh.dbt.test import TestConfig
26

37

@@ -8,3 +12,131 @@ def test_multiline_test_kwarg() -> None:
812
test_kwargs={"test_field": "foo\nbar\n"},
913
)
1014
assert test._kwargs() == 'test_field="foo\nbar"'
15+
16+
17+
@pytest.mark.xdist_group("dbt_manifest")
def test_tests_get_unique_names(tmp_path: Path, create_empty_project) -> None:
    """
    Check the audit names produced for dbt tests in a freshly created project.

    The project's schema.yml declares:
    1. The same test on a model and a source, both with and without a custom test name
    2. The same test on the same model with different args, both with and without a
       custom test name
    3. A versioned model with tests (both built-in and custom named)

    The combined keys of the loaded context's audits and standalone audits must
    match the expected list of names exactly.
    """
    from sqlmesh.utils.yaml import YAML
    from sqlmesh.core.context import Context

    def write_file(path, contents: str) -> None:
        # Small helper so every fixture file is written the same way.
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(contents)

    yaml = YAML()
    project_dir, model_dir = create_empty_project(project_name="local")

    write_file(model_dir / "my_model.sql", "SELECT 1 as id, 'value1' as status")

    # Source with the same not_null test declared with and without a custom name.
    source_definitions = [
        {
            "name": "raw",
            "tables": [
                {
                    "name": "my_source",
                    "columns": [
                        {
                            "name": "id",
                            "data_tests": [
                                {"not_null": {"name": "custom_notnull_name"}},
                                {"not_null": {}},
                            ],
                        }
                    ],
                }
            ],
        }
    ]

    # Plain model: not_null on "id", plus accepted_values on "status" with two
    # different argument lists, each with and without a custom name.
    my_model_definition = {
        "name": "my_model",
        "columns": [
            {
                "name": "id",
                "data_tests": [
                    {"not_null": {"name": "custom_notnull_name"}},
                    {"not_null": {}},
                ],
            },
            {
                "name": "status",
                "data_tests": [
                    {"accepted_values": {"values": ["value1", "value2"]}},
                    {"accepted_values": {"values": ["value1", "value2", "value3"]}},
                    {
                        "accepted_values": {
                            "name": "custom_accepted_values_name",
                            "values": ["value1", "value2"],
                        }
                    },
                    {
                        "accepted_values": {
                            "name": "custom_accepted_values_name",
                            "values": ["value1", "value2", "value3"],
                        }
                    },
                ],
            },
        ],
    }

    # Versioned model (v1 and v2) with built-in and custom named tests.
    versioned_model_definition = {
        "name": "versioned_model",
        "columns": [
            {
                "name": "id",
                "data_tests": [
                    {"not_null": {}},
                    {"not_null": {"name": "custom_versioned_notnull"}},
                ],
            },
            {
                "name": "amount",
                "data_tests": [
                    {"accepted_values": {"values": ["low", "high"]}},
                ],
            },
        ],
        "versions": [
            {"v": 1},
            {"v": 2},
        ],
    }

    schema_yaml = {
        "version": 2,
        "sources": source_definitions,
        "models": [my_model_definition, versioned_model_definition],
    }

    with open(model_dir / "schema.yml", "w", encoding="utf-8") as schema_handle:
        yaml.dump(schema_yaml, schema_handle)

    # Create versioned model files
    for versioned_file_name in ("versioned_model_v1.sql", "versioned_model_v2.sql"):
        write_file(model_dir / versioned_file_name, "SELECT 1 as id, 'low' as amount")

    context = Context(paths=project_dir)

    expected_audit_names = [
        "local.accepted_values_my_model_status__value1__value2",
        "local.accepted_values_my_model_status__value1__value2__value3",
        "local.accepted_values_versioned_model_v1_amount__low__high",
        "local.accepted_values_versioned_model_v2_amount__low__high",
        "local.custom_accepted_values_name_my_model_status__value1__value2",
        "local.custom_accepted_values_name_my_model_status__value1__value2__value3",
        "local.custom_notnull_name_my_model_id",
        "local.custom_versioned_notnull_versioned_model_v1_id",
        "local.custom_versioned_notnull_versioned_model_v2_id",
        "local.not_null_my_model_id",
        "local.not_null_versioned_model_v1_id",
        "local.not_null_versioned_model_v2_id",
        "local.source_custom_notnull_name_raw_my_source_id",
        "local.source_not_null_raw_my_source_id",
    ]

    loaded_audit_names = [
        *context._audits.keys(),
        *context._standalone_audits.keys(),
    ]
    assert sorted(loaded_audit_names) == expected_audit_names

0 commit comments

Comments
 (0)