-
Notifications
You must be signed in to change notification settings - Fork 137
Implement column-level PII protection for sample collection #833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
5d4dd69
3ba7a44
a159aa7
9fa0eb1
6014bac
d481311
c45cba3
4907483
949b357
06e9451
3b6854b
057a2cc
89e49de
c5b9f45
8cc4050
812d9c2
25ddd67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| import json | ||
|
|
||
| import pytest | ||
| from dbt_project import DbtProject | ||
|
|
||
| SENSITIVE_COLUMN = "email" | ||
| SAFE_COLUMN = "order_count" | ||
|
|
||
| SAMPLES_QUERY = """ | ||
| with latest_elementary_test_result as ( | ||
| select id | ||
| from {{{{ ref("elementary_test_results") }}}} | ||
| where lower(table_name) = lower('{test_id}') | ||
| order by created_at desc | ||
| limit 1 | ||
| ) | ||
|
|
||
| select result_row | ||
| from {{{{ ref("test_result_rows") }}}} | ||
| where elementary_test_results_id in (select * from latest_elementary_test_result) | ||
| """ | ||
|
|
||
| TEST_SAMPLE_ROW_COUNT = 5 | ||
|
|
||
|
|
||
| @pytest.mark.skip_targets(["clickhouse"]) | ||
| def test_column_pii_sampling_enabled(test_id: str, dbt_project: DbtProject): | ||
| """Test that PII columns are excluded when column-level PII protection is enabled""" | ||
| data = [ | ||
| {SENSITIVE_COLUMN: f"user{i}@example.com", SAFE_COLUMN: None} for i in range(10) | ||
| ] | ||
|
|
||
| test_result = dbt_project.test( | ||
| test_id, | ||
| "not_null", | ||
| test_args=dict(column_name=SAFE_COLUMN), | ||
| data=data, | ||
| columns=[ | ||
| {"name": SENSITIVE_COLUMN, "config": {"tags": ["pii"]}}, | ||
| {"name": SAFE_COLUMN}, | ||
| ], | ||
| test_vars={ | ||
| "enable_elementary_test_materialization": True, | ||
| "test_sample_row_count": TEST_SAMPLE_ROW_COUNT, | ||
| "disable_samples_on_pii_columns": True, | ||
| "pii_column_tags": ["pii"], | ||
| }, | ||
| ) | ||
| assert test_result["status"] == "fail" | ||
|
|
||
| samples = [ | ||
| json.loads(row["result_row"]) | ||
| for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id)) | ||
| ] | ||
|
|
||
| assert len(samples) == TEST_SAMPLE_ROW_COUNT | ||
| for sample in samples: | ||
| assert SENSITIVE_COLUMN not in sample | ||
| assert SAFE_COLUMN in sample | ||
|
|
||
|
|
||
| @pytest.mark.skip_targets(["clickhouse"]) | ||
| def test_column_pii_sampling_disabled(test_id: str, dbt_project: DbtProject): | ||
| """Test that all columns are included when column-level PII protection is disabled""" | ||
| data = [ | ||
| {SENSITIVE_COLUMN: f"user{i}@example.com", SAFE_COLUMN: None} for i in range(10) | ||
| ] | ||
|
|
||
| test_result = dbt_project.test( | ||
| test_id, | ||
| "not_null", | ||
| test_args=dict(column_name=SAFE_COLUMN), | ||
| data=data, | ||
| columns=[ | ||
| {"name": SENSITIVE_COLUMN, "config": {"tags": ["pii"]}}, | ||
| {"name": SAFE_COLUMN}, | ||
| ], | ||
| test_vars={ | ||
| "enable_elementary_test_materialization": True, | ||
| "test_sample_row_count": TEST_SAMPLE_ROW_COUNT, | ||
| "disable_samples_on_pii_columns": False, | ||
| }, | ||
| ) | ||
| assert test_result["status"] == "fail" | ||
|
|
||
| samples = [ | ||
| json.loads(row["result_row"]) | ||
| for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id)) | ||
| ] | ||
|
|
||
| assert len(samples) == TEST_SAMPLE_ROW_COUNT | ||
| for sample in samples: | ||
| assert SENSITIVE_COLUMN in sample | ||
| assert SAFE_COLUMN in sample | ||
|
|
||
|
|
||
| @pytest.mark.skip_targets(["clickhouse"]) | ||
| def test_column_pii_sampling_all_columns_pii(test_id: str, dbt_project: DbtProject): | ||
| """Test behavior when all columns are tagged as PII""" | ||
| data = [ | ||
| {SENSITIVE_COLUMN: f"user{i}@example.com", SAFE_COLUMN: i} for i in range(10) | ||
| ] | ||
|
|
||
| test_result = dbt_project.test( | ||
| test_id, | ||
| "not_null", | ||
| test_args=dict(column_name=SAFE_COLUMN), | ||
| data=data, | ||
| columns=[ | ||
| {"name": SENSITIVE_COLUMN, "config": {"tags": ["pii"]}}, | ||
| {"name": SAFE_COLUMN, "config": {"tags": ["pii"]}}, | ||
| ], | ||
| test_vars={ | ||
| "enable_elementary_test_materialization": True, | ||
| "test_sample_row_count": TEST_SAMPLE_ROW_COUNT, | ||
| "disable_samples_on_pii_columns": True, | ||
| "pii_column_tags": ["pii"], | ||
| }, | ||
| ) | ||
| assert test_result["status"] == "pass" | ||
|
|
||
| samples = [ | ||
| json.loads(row["result_row"]) | ||
| for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id)) | ||
| ] | ||
|
|
||
| assert len(samples) == TEST_SAMPLE_ROW_COUNT | ||
| for sample in samples: | ||
| assert "_no_non_pii_columns" in sample | ||
| assert sample["_no_non_pii_columns"] == 1 | ||
| assert SENSITIVE_COLUMN not in sample | ||
| assert SAFE_COLUMN not in sample | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| import json | ||
|
|
||
| import pytest | ||
| from dbt_project import DbtProject | ||
|
|
||
| COLUMN_NAME = "sensitive_data" | ||
|
|
||
| SAMPLES_QUERY = """ | ||
| with latest_elementary_test_result as ( | ||
| select id | ||
| from {{ ref("elementary_test_results") }} | ||
| where lower(table_name) = lower('{test_id}') | ||
| order by created_at desc | ||
| limit 1 | ||
| ) | ||
|
|
||
| select result_row | ||
| from {{ ref("test_result_rows") }} | ||
| where elementary_test_results_id in (select * from latest_elementary_test_result) | ||
| """ | ||
|
|
||
|
|
||
| @pytest.mark.skip_targets(["clickhouse"]) | ||
| def test_disable_samples_config_prevents_sampling( | ||
| test_id: str, dbt_project: DbtProject | ||
| ): | ||
| null_count = 20 | ||
| data = [{COLUMN_NAME: None} for _ in range(null_count)] | ||
|
|
||
| columns = [ | ||
| { | ||
| "name": COLUMN_NAME, | ||
| "config": {"disable_samples": True}, | ||
| "tests": [{"not_null": {}}], | ||
| } | ||
| ] | ||
|
|
||
| test_result = dbt_project.test( | ||
| test_id, | ||
| "not_null", | ||
| columns=columns, | ||
| data=data, | ||
| test_vars={ | ||
| "enable_elementary_test_materialization": True, | ||
| "test_sample_row_count": 5, | ||
| }, | ||
| ) | ||
| assert test_result["status"] == "fail" | ||
|
|
||
| samples = [ | ||
| json.loads(row["result_row"]) | ||
| for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id)) | ||
| ] | ||
| assert len(samples) == 0 | ||
|
|
||
|
|
||
| @pytest.mark.skip_targets(["clickhouse"]) | ||
| def test_disable_samples_false_allows_sampling(test_id: str, dbt_project: DbtProject): | ||
| null_count = 20 | ||
| data = [{COLUMN_NAME: None} for _ in range(null_count)] | ||
|
|
||
| columns = [ | ||
| { | ||
| "name": COLUMN_NAME, | ||
| "config": {"disable_samples": False}, | ||
| "tests": [{"not_null": {}}], | ||
| } | ||
| ] | ||
|
|
||
| test_result = dbt_project.test( | ||
| test_id, | ||
| "not_null", | ||
| columns=columns, | ||
| data=data, | ||
| test_vars={ | ||
| "enable_elementary_test_materialization": True, | ||
| "test_sample_row_count": 5, | ||
| }, | ||
| ) | ||
| assert test_result["status"] == "fail" | ||
|
|
||
| samples = [ | ||
| json.loads(row["result_row"]) | ||
| for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id)) | ||
| ] | ||
| assert len(samples) == 5 | ||
| assert all([row == {COLUMN_NAME: None} for row in samples]) | ||
|
|
||
|
|
||
| @pytest.mark.skip_targets(["clickhouse"]) | ||
| def test_disable_samples_config_overrides_pii_tags( | ||
| test_id: str, dbt_project: DbtProject | ||
| ): | ||
| null_count = 20 | ||
| data = [{COLUMN_NAME: None} for _ in range(null_count)] | ||
|
|
||
| columns = [ | ||
| { | ||
| "name": COLUMN_NAME, | ||
| "config": {"disable_samples": True, "tags": ["pii"]}, | ||
| "tests": [{"not_null": {}}], | ||
| } | ||
| ] | ||
|
|
||
| test_result = dbt_project.test( | ||
| test_id, | ||
| "not_null", | ||
| columns=columns, | ||
| data=data, | ||
| test_vars={ | ||
| "enable_elementary_test_materialization": True, | ||
| "test_sample_row_count": 5, | ||
| "disable_samples_on_pii_columns": True, | ||
| }, | ||
| ) | ||
| assert test_result["status"] == "fail" | ||
|
|
||
| samples = [ | ||
| json.loads(row["result_row"]) | ||
| for row in dbt_project.run_query(SAMPLES_QUERY.format(test_id=test_id)) | ||
| ] | ||
| assert len(samples) == 0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,7 +51,8 @@ | |
| {% macro handle_dbt_test(flattened_test, materialization_macro) %} | ||
| {% set result = materialization_macro() %} | ||
| {% set result_rows = elementary.query_test_result_rows(sample_limit=elementary.get_config_var('test_sample_row_count'), | ||
| ignore_passed_tests=true) %} | ||
| ignore_passed_tests=true, | ||
| flattened_test=flattened_test) %} | ||
| {% set elementary_test_results_row = elementary.get_dbt_test_result_row(flattened_test, result_rows) %} | ||
| {% do elementary.cache_elementary_test_results_rows([elementary_test_results_row]) %} | ||
| {% do return(result) %} | ||
|
|
@@ -103,23 +104,70 @@ | |
| {% do return(new_sql) %} | ||
| {% endmacro %} | ||
|
|
||
| {% macro query_test_result_rows(sample_limit=none, ignore_passed_tests=false) %} | ||
| {% macro query_test_result_rows(sample_limit=none, ignore_passed_tests=false, flattened_test=none) %} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the flattened_test param needed? |
||
| {% if sample_limit == 0 %} {# performance: no need to run a sql query that we know returns an empty list #} | ||
| {% do return([]) %} | ||
| {% endif %} | ||
| {% if ignore_passed_tests and elementary.did_test_pass() %} | ||
| {% do elementary.debug_log("Skipping sample query because the test passed.") %} | ||
| {% do return([]) %} | ||
| {% endif %} | ||
|
|
||
| {% if flattened_test and elementary.is_sampling_disabled_for_column(flattened_test) %} | ||
| {% do elementary.debug_log("Skipping sample query because disable_samples is true for this column.") %} | ||
|
arbiv marked this conversation as resolved.
Outdated
|
||
| {% do return([]) %} | ||
| {% endif %} | ||
|
|
||
| {% set pii_columns = [] %} | ||
| {% if flattened_test %} | ||
| {% set pii_columns = elementary.get_pii_columns_from_parent_model(flattened_test) %} | ||
| {% endif %} | ||
|
|
||
| {% set select_clause = "*" %} | ||
| {% if pii_columns %} | ||
| {% set query_to_get_columns %} | ||
| with test_results as ( | ||
| {{ sql }} | ||
| ) | ||
| select * from test_results limit 0 | ||
| {% endset %} | ||
| {% set columns_result = elementary.run_query(query_to_get_columns) %} | ||
| {% set all_columns = columns_result.column_names %} | ||
| {% set safe_columns = all_columns | reject("in", pii_columns) | list %} | ||
| {% if safe_columns %} | ||
| {% set select_clause = safe_columns | join(", ") %} | ||
| {% else %} | ||
| {% set select_clause = "1 as _no_non_pii_columns" %} | ||
| {% endif %} | ||
| {% endif %} | ||
|
|
||
| {% set query %} | ||
| with test_results as ( | ||
| {{ sql }} | ||
| ) | ||
| select * from test_results {% if sample_limit is not none %} limit {{ sample_limit }} {% endif %} | ||
| select {{ select_clause }} from test_results {% if sample_limit is not none %} limit {{ sample_limit }} {% endif %} | ||
| {% endset %} | ||
| {% do return(elementary.agate_to_dicts(elementary.run_query(query))) %} | ||
| {% endmacro %} | ||
|
|
||
| {% macro is_sampling_disabled_for_column(flattened_test) %} | ||
| {% set test_column_name = elementary.insensitive_get_dict_value(flattened_test, 'test_column_name') %} | ||
| {% set parent_model_unique_id = elementary.insensitive_get_dict_value(flattened_test, 'parent_model_unique_id') %} | ||
|
|
||
| {% if not test_column_name or not parent_model_unique_id %} | ||
| {% do return(false) %} | ||
| {% endif %} | ||
|
|
||
| {% set parent_model = elementary.get_node(parent_model_unique_id) %} | ||
| {% if parent_model and parent_model.get('columns') %} | ||
| {% set column_config = parent_model.get('columns', {}).get(test_column_name, {}).get('config', {}) %} | ||
| {% set disable_samples = elementary.safe_get_with_default(column_config, 'disable_samples', false) %} | ||
| {% do return(disable_samples) %} | ||
| {% endif %} | ||
|
|
||
| {% do return(false) %} | ||
| {% endmacro %} | ||
|
|
||
| {% macro cache_elementary_test_results_rows(elementary_test_results_rows) %} | ||
| {% do elementary.get_cache("elementary_test_results").update({model.unique_id: elementary_test_results_rows}) %} | ||
| {% endmacro %} | ||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,42 @@ | ||||
| {% macro get_pii_columns_from_parent_model(flattened_test) %} | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's rename the file to match the macro? |
||||
| {% set pii_columns = [] %} | ||||
|
|
||||
| {% if not elementary.get_config_var('disable_samples_on_pii_columns') %} | ||||
| {% do return(pii_columns) %} | ||||
| {% endif %} | ||||
|
|
||||
| {% set parent_model_unique_id = elementary.insensitive_get_dict_value(flattened_test, 'parent_model_unique_id') %} | ||||
| {% set parent_model = elementary.get_node(parent_model_unique_id) %} | ||||
|
|
||||
| {% if not parent_model %} | ||||
| {% do return(pii_columns) %} | ||||
| {% endif %} | ||||
|
|
||||
| {% set column_nodes = parent_model.get("columns") %} | ||||
| {% if not column_nodes %} | ||||
| {% do return(pii_columns) %} | ||||
| {% endif %} | ||||
|
|
||||
| {% set pii_column_tags = elementary.get_config_var('pii_column_tags') %} | ||||
| {% if pii_column_tags is string %} | ||||
| {% set pii_column_tags = [pii_column_tags] %} | ||||
| {% endif %} | ||||
|
|
||||
| {% for column_node in column_nodes.values() %} | ||||
| {% set config_dict = column_node.get('config', {}) %} | ||||
| {% set config_tags = config_dict.get('tags', []) %} | ||||
| {% set global_tags = column_node.get('tags', []) %} | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename to column tags |
||||
| {% set meta_dict = column_node.get('meta', {}) %} | ||||
| {% set meta_tags = meta_dict.get('tags', []) %} | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should also consider the case where the model itself has a PII tag.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is implemented here:
|
||||
| {% set all_column_tags = config_tags + global_tags + meta_tags %} | ||||
|
|
||||
| {% for pii_tag in pii_column_tags %} | ||||
| {% if pii_tag in all_column_tags %} | ||||
| {% do pii_columns.append(column_node.get('name')) %} | ||||
| {% break %} | ||||
| {% endif %} | ||||
| {% endfor %} | ||||
| {% endfor %} | ||||
|
|
||||
| {% do return(pii_columns) %} | ||||
| {% endmacro %} | ||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Configuration variable name inconsistency
Line 45 uses
disable_samples_on_pii_tags: Truewhich matches the incorrect variable name in the macro, but according to the PR objectives it should bedisable_samples_on_pii_columns. This test will need to be updated when the macro is fixed.The test logic and assertions are correct for the intended behavior.
🤖 Prompt for AI Agents