diff --git a/integration_tests/tests/test_column_pii_sampling.py b/integration_tests/tests/test_column_pii_sampling.py new file mode 100644 index 000000000..9eba27887 --- /dev/null +++ b/integration_tests/tests/test_column_pii_sampling.py @@ -0,0 +1,132 @@ +import json + +import pytest +from dbt_project import DbtProject + +SENSITIVE_COLUMN = "email" +SAFE_COLUMN = "order_count" + +SAMPLES_QUERY = """ + with latest_elementary_test_result as ( + select id + from {{{{ ref("elementary_test_results") }}}} + where lower(table_name) = lower('{test_id}') + order by created_at desc + limit 1 + ) + + select result_row + from {{{{ ref("test_result_rows") }}}} + where elementary_test_results_id in (select * from latest_elementary_test_result) +""" + +TEST_SAMPLE_ROW_COUNT = 5 + + +@pytest.mark.skip_targets(["clickhouse"]) +def test_column_pii_sampling_enabled(test_id: str, dbt_project: DbtProject): + """Test that PII columns are excluded when column-level PII protection is enabled""" + data = [ + {SENSITIVE_COLUMN: f"user{i}@example.com", SAFE_COLUMN: None} for i in range(10) + ] + + test_result = dbt_project.test( + test_id, + "not_null", + test_args=dict(column_name=SAFE_COLUMN), + data=data, + columns=[ + {"name": SENSITIVE_COLUMN, "config": {"tags": ["pii"]}}, + {"name": SAFE_COLUMN}, + ], + test_vars={ + "enable_elementary_test_materialization": True, + "test_sample_row_count": TEST_SAMPLE_ROW_COUNT, + "disable_samples_on_pii_columns": True, + "pii_column_tags": ["pii"], + }, + ) + assert test_result["status"] == "fail" + + samples = [ + json.loads(row["result_row"]) + for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id)) + ] + + assert len(samples) == TEST_SAMPLE_ROW_COUNT + for sample in samples: + assert SENSITIVE_COLUMN not in sample + assert SAFE_COLUMN in sample + + +@pytest.mark.skip_targets(["clickhouse"]) +def test_column_pii_sampling_disabled(test_id: str, dbt_project: DbtProject): + """Test that all columns are included when column-level PII protection is disabled""" + data = [ + {SENSITIVE_COLUMN: f"user{i}@example.com", SAFE_COLUMN: None} for i in range(10) + ] + + test_result = dbt_project.test( + test_id, + "not_null", + test_args=dict(column_name=SAFE_COLUMN), + data=data, + columns=[ + {"name": SENSITIVE_COLUMN, "config": {"tags": ["pii"]}}, + {"name": SAFE_COLUMN}, + ], + test_vars={ + "enable_elementary_test_materialization": True, + "test_sample_row_count": TEST_SAMPLE_ROW_COUNT, + "disable_samples_on_pii_columns": False, + }, + ) + assert test_result["status"] == "fail" + + samples = [ + json.loads(row["result_row"]) + for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id)) + ] + + assert len(samples) == TEST_SAMPLE_ROW_COUNT + for sample in samples: + assert SENSITIVE_COLUMN in sample + assert SAFE_COLUMN in sample + + +@pytest.mark.skip_targets(["clickhouse"]) +def test_column_pii_sampling_all_columns_pii(test_id: str, dbt_project: DbtProject): + """Test behavior when all columns are tagged as PII""" + data = [ + {SENSITIVE_COLUMN: f"user{i}@example.com", SAFE_COLUMN: i} for i in range(10) + ] + + test_result = dbt_project.test( + test_id, + "not_null", + test_args=dict(column_name=SAFE_COLUMN), + data=data, + columns=[ + {"name": SENSITIVE_COLUMN, "config": {"tags": ["pii"]}}, + {"name": SAFE_COLUMN, "config": {"tags": ["pii"]}}, + ], + test_vars={ + "enable_elementary_test_materialization": True, + "test_sample_row_count": TEST_SAMPLE_ROW_COUNT, + "disable_samples_on_pii_columns": True, + "pii_column_tags": ["pii"], + }, + ) + assert test_result["status"] == "pass" + + samples = [ + json.loads(row["result_row"]) + for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id)) + ] + + assert len(samples) == TEST_SAMPLE_ROW_COUNT + for sample in samples: + assert "_no_non_pii_columns" in sample + assert sample["_no_non_pii_columns"] == 1 + assert SENSITIVE_COLUMN not in sample + assert SAFE_COLUMN not in sample diff --git a/macros/edr/materializations/test/test.sql b/macros/edr/materializations/test/test.sql index a63f89f0a..7c211953e 100644 --- a/macros/edr/materializations/test/test.sql +++ b/macros/edr/materializations/test/test.sql @@ -51,7 +51,8 @@ {% macro handle_dbt_test(flattened_test, materialization_macro) %} {% set result = materialization_macro() %} {% set result_rows = elementary.query_test_result_rows(sample_limit=elementary.get_config_var('test_sample_row_count'), - ignore_passed_tests=true) %} + ignore_passed_tests=true, + flattened_test=flattened_test) %} {% set elementary_test_results_row = elementary.get_dbt_test_result_row(flattened_test, result_rows) %} {% do elementary.cache_elementary_test_results_rows([elementary_test_results_row]) %} {% do return(result) %} @@ -103,7 +104,7 @@ {% do return(new_sql) %} {% endmacro %} -{% macro query_test_result_rows(sample_limit=none, ignore_passed_tests=false) %} +{% macro query_test_result_rows(sample_limit=none, ignore_passed_tests=false, flattened_test=none) %} {% if sample_limit == 0 %} {# performance: no need to run a sql query that we know returns an empty list #} {% do return([]) %} {% endif %} @@ -111,11 +112,35 @@ {% do elementary.debug_log("Skipping sample query because the test passed.") %} {% do return([]) %} {% endif %} + + {% set pii_columns = [] %} + {% if flattened_test %} + {% set pii_columns = elementary.get_pii_columns_from_parent_model(flattened_test) %} + {% endif %} + + {% set select_clause = "*" %} + {% if pii_columns %} + {% set query_to_get_columns %} + with test_results as ( + {{ sql }} + ) + select * from test_results limit 0 + {% endset %} + {% set columns_result = elementary.run_query(query_to_get_columns) %} + {% set all_columns = columns_result.column_names %} + {% set safe_columns = all_columns | reject("in", pii_columns) | list %} + {% if safe_columns %} + {% set select_clause = safe_columns | join(", ") %} + {% else %} + {% set select_clause = "1 as _no_non_pii_columns" %} + {% endif %} + {% endif %} + {% set query %} with test_results as ( {{ sql }} ) - select * from test_results {% if sample_limit is not none %} limit {{ sample_limit }} {% endif %} + select {{ select_clause }} from test_results {% if sample_limit is not none %} limit {{ sample_limit }} {% endif %} {% endset %} {% do return(elementary.agate_to_dicts(elementary.run_query(query))) %} {% endmacro %} diff --git a/macros/edr/system/system_utils/get_config_var.sql b/macros/edr/system/system_utils/get_config_var.sql index 431061811..4a8df400a 100644 --- a/macros/edr/system/system_utils/get_config_var.sql +++ b/macros/edr/system/system_utils/get_config_var.sql @@ -55,6 +55,8 @@ 'disable_skipped_test_alerts': true, 'dbt_artifacts_chunk_size': 5000, 'test_sample_row_count': 5, + 'disable_samples_on_pii_columns': false, + 'pii_column_tags': ['pii', 'personal', 'sensitive'], 'edr_cli_run': false, 'max_int': 2147483647, 'custom_run_started_at': none, diff --git a/macros/edr/system/system_utils/is_pii_column.sql b/macros/edr/system/system_utils/is_pii_column.sql new file mode 100644 index 000000000..a5714eaa1 --- /dev/null +++ b/macros/edr/system/system_utils/is_pii_column.sql @@ -0,0 +1,42 @@ +{% macro get_pii_columns_from_parent_model(flattened_test) %} + {% set pii_columns = [] %} + + {% if not elementary.get_config_var('disable_samples_on_pii_columns') %} + {% do return(pii_columns) %} + {% endif %} + + {% set parent_model_unique_id = elementary.insensitive_get_dict_value(flattened_test, 'parent_model_unique_id') %} + {% set parent_model = elementary.get_node(parent_model_unique_id) %} + + {% if not parent_model %} + {% do return(pii_columns) %} + {% endif %} + + {% set column_nodes = parent_model.get("columns") %} + {% if not column_nodes %} + {% do return(pii_columns) %} + {% endif %} + + {% set pii_column_tags = elementary.get_config_var('pii_column_tags') %} + {% if pii_column_tags is string %} + {% set pii_column_tags = [pii_column_tags] %} + {% endif %} + + {% for column_node in column_nodes.values() %} + {% set config_dict = column_node.get('config', {}) %} + {% set config_tags = config_dict.get('tags', []) %} + {% set global_tags = column_node.get('tags', []) %} + {% set meta_dict = column_node.get('meta', {}) %} + {% set meta_tags = meta_dict.get('tags', []) %} + {% set all_column_tags = config_tags + global_tags + meta_tags %} + + {% for pii_tag in pii_column_tags %} + {% if pii_tag in all_column_tags %} + {% do pii_columns.append(column_node.get('name')) %} + {% break %} + {% endif %} + {% endfor %} + {% endfor %} + + {% do return(pii_columns) %} +{% endmacro %}