Skip to content
6 changes: 6 additions & 0 deletions integration_tests/tests/dbt_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def test(
materialization: str = "table", # Only relevant if as_model=True
test_vars: Optional[dict] = None,
elementary_enabled: bool = True,
model_config: Optional[Dict[str, Any]] = None,
*,
multiple_results: Literal[False] = False,
) -> Dict[str, Any]:
Expand All @@ -128,6 +129,7 @@ def test(
materialization: str = "table", # Only relevant if as_model=True
test_vars: Optional[dict] = None,
elementary_enabled: bool = True,
model_config: Optional[Dict[str, Any]] = None,
*,
multiple_results: Literal[True],
) -> List[Dict[str, Any]]:
Expand All @@ -146,6 +148,7 @@ def test(
materialization: str = "table", # Only relevant if as_model=True
test_vars: Optional[dict] = None,
elementary_enabled: bool = True,
model_config: Optional[Dict[str, Any]] = None,
*,
multiple_results: bool = False,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
Expand All @@ -161,6 +164,9 @@ def test(
test_args = test_args or {}
table_yaml: Dict[str, Any] = {"name": test_id}

if model_config:
table_yaml.update(model_config)

if columns:
table_yaml["columns"] = columns

Expand Down
110 changes: 110 additions & 0 deletions integration_tests/tests/test_sampling_pii.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import json

import pytest
from dbt_project import DbtProject

COLUMN_NAME = "value"


SAMPLES_QUERY = """
with latest_elementary_test_result as (
select id
from {{{{ ref("elementary_test_results") }}}}
where lower(table_name) = lower('{test_id}')
order by created_at desc, id desc
limit 1
)

select result_row
from {{{{ ref("test_result_rows") }}}}
where elementary_test_results_id in (select * from latest_elementary_test_result)
"""

TEST_SAMPLE_ROW_COUNT = 7


@pytest.mark.skip_targets(["clickhouse"])
def test_sampling_pii_disabled(test_id: str, dbt_project: DbtProject):
"""Test that PII-tagged tables don't upload samples even when tests fail"""
null_count = 50
data = [{COLUMN_NAME: None} for _ in range(null_count)]

test_result = dbt_project.test(
test_id,
"not_null",
dict(column_name=COLUMN_NAME),
data=data,
as_model=True,
model_config={"config": {"tags": ["pii"]}},
test_vars={
"enable_elementary_test_materialization": True,
"test_sample_row_count": TEST_SAMPLE_ROW_COUNT,
"disable_samples_on_pii_tables": True,
"pii_table_tags": ["pii", "sensitive"],
},
)
assert test_result["status"] == "fail"

samples = [
json.loads(row["result_row"])
for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
]
assert len(samples) == 0


@pytest.mark.skip_targets(["clickhouse"])
def test_sampling_non_pii_enabled(test_id: str, dbt_project: DbtProject):
"""Test that non-PII tables still collect samples normally"""
null_count = 50
data = [{COLUMN_NAME: None} for _ in range(null_count)]

test_result = dbt_project.test(
test_id,
"not_null",
dict(column_name=COLUMN_NAME),
data=data,
as_model=True,
model_config={"config": {"tags": ["normal"]}},
test_vars={
"enable_elementary_test_materialization": True,
"test_sample_row_count": TEST_SAMPLE_ROW_COUNT,
"disable_samples_on_pii_tables": True,
"pii_table_tags": ["pii", "sensitive"],
},
)
assert test_result["status"] == "fail"

samples = [
json.loads(row["result_row"])
for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
]
assert len(samples) == TEST_SAMPLE_ROW_COUNT


@pytest.mark.skip_targets(["clickhouse"])
def test_sampling_pii_feature_disabled(test_id: str, dbt_project: DbtProject):
"""Test that when PII feature is disabled, PII tables still collect samples"""
null_count = 50
data = [{COLUMN_NAME: None} for _ in range(null_count)]

test_result = dbt_project.test(
test_id,
"not_null",
dict(column_name=COLUMN_NAME),
data=data,
as_model=True,
model_config={"config": {"tags": ["pii"]}},
test_vars={
"enable_elementary_test_materialization": True,
"test_sample_row_count": TEST_SAMPLE_ROW_COUNT,
"disable_samples_on_pii_tables": False,
"pii_table_tags": ["pii", "sensitive"],
},
)
assert test_result["status"] == "fail"

samples = [
json.loads(row["result_row"])
for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id))
]
assert len(samples) == TEST_SAMPLE_ROW_COUNT
6 changes: 5 additions & 1 deletion macros/edr/materializations/test/test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@

{% macro handle_dbt_test(flattened_test, materialization_macro) %}
{% set result = materialization_macro() %}
{% set result_rows = elementary.query_test_result_rows(sample_limit=elementary.get_config_var('test_sample_row_count'),
{% set sample_limit = elementary.get_config_var('test_sample_row_count') %}
{% if elementary.is_pii_table(flattened_test) %}
{% set sample_limit = 0 %}
{% endif %}
{% set result_rows = elementary.query_test_result_rows(sample_limit=sample_limit,
ignore_passed_tests=true) %}
{% set elementary_test_results_row = elementary.get_dbt_test_result_row(flattened_test, result_rows) %}
{% do elementary.cache_elementary_test_results_rows([elementary_test_results_row]) %}
Expand Down
4 changes: 3 additions & 1 deletion macros/edr/system/system_utils/get_config_var.sql
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@
},
'include_other_warehouse_specific_columns': false,
'fail_on_zero': false,
'anomaly_exclude_metrics': none
'anomaly_exclude_metrics': none,
'disable_samples_on_pii_tables': false,
'pii_table_tags': ['pii']
} %}
{{- return(default_config) -}}
{%- endmacro -%}
Expand Down
14 changes: 14 additions & 0 deletions macros/edr/system/system_utils/is_pii_table.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{% macro is_pii_table(flattened_test) %}
{% set disable_samples_on_pii_tables = elementary.get_config_var('disable_samples_on_pii_tables') %}
{% if not disable_samples_on_pii_tables %}
{% do return(false) %}
{% endif %}
Comment thread
arbiv marked this conversation as resolved.

{% set pii_table_tags = elementary.get_config_var('pii_table_tags') %}
{% set model_tags = elementary.insensitive_get_dict_value(flattened_test, 'model_tags', []) %}

{% set intersection = elementary.lists_intersection(model_tags, pii_table_tags) %}
{% set is_pii = intersection | length > 0 %}

{% do return(is_pii) %}
{% endmacro %}
Loading