Skip to content

Commit 20396e6

Browse files
committed
Add column mask support
1 parent dab5794 commit 20396e6

11 files changed

Lines changed: 387 additions & 0 deletions

File tree

dbt/adapters/databricks/impl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,7 @@ def _describe_relation(
984984
results["foreign_key_constraints"] = adapter.execute_macro(
985985
"fetch_foreign_key_constraints", kwargs=kwargs
986986
)
987+
results["column_masks"] = adapter.execute_macro("fetch_column_masks", kwargs=kwargs)
987988
results["show_tblproperties"] = adapter.execute_macro("fetch_tbl_properties", kwargs=kwargs)
988989

989990
kwargs = {"table_name": relation}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from dataclasses import asdict
2+
from typing import ClassVar, Optional
3+
4+
from dbt.adapters.contracts.relation import RelationConfig
5+
from dbt.adapters.databricks.relation_configs.base import (
6+
DatabricksComponentConfig,
7+
DatabricksComponentProcessor,
8+
)
9+
from dbt.adapters.relation_configs.config_base import RelationResults
10+
11+
12+
class ColumnMaskConfig(DatabricksComponentConfig):
13+
# column name -> mask
14+
set_column_masks: dict[str, str]
15+
unset_column_masks: list[str] = []
16+
17+
def get_diff(self, other: "ColumnMaskConfig") -> Optional["ColumnMaskConfig"]:
18+
# Find column masks that need to be unset
19+
unset_column_mask = [
20+
col for col in other.set_column_masks if col not in self.set_column_masks
21+
]
22+
23+
# Find column masks that need to be set or updated
24+
set_column_mask = {
25+
col: mask
26+
for col, mask in self.set_column_masks.items()
27+
if col not in other.set_column_masks or other.set_column_masks[col] != mask
28+
}
29+
30+
if set_column_mask or unset_column_mask:
31+
return ColumnMaskConfig(
32+
set_column_masks=set_column_mask,
33+
unset_column_masks=unset_column_mask,
34+
)
35+
return None
36+
37+
38+
class ColumnMaskProcessor(DatabricksComponentProcessor[ColumnMaskConfig]):
39+
name: ClassVar[str] = "column_masks"
40+
41+
@classmethod
42+
def from_relation_results(cls, results: RelationResults) -> ColumnMaskConfig:
43+
column_masks = results.get("column_masks")
44+
set_column_masks = {}
45+
46+
if column_masks:
47+
for row in column_masks.rows:
48+
set_column_masks[row[0]] = row[1]
49+
50+
return ColumnMaskConfig(set_column_masks=set_column_masks)
51+
52+
@classmethod
53+
def from_relation_config(cls, relation_config: RelationConfig) -> ColumnMaskConfig:
54+
# Extract config from model node
55+
columns = getattr(relation_config, "columns", {})
56+
columns = [
57+
{"name": name, **(col if isinstance(col, dict) else asdict(col))}
58+
for name, col in columns.items()
59+
]
60+
61+
set_column_masks = {}
62+
for col in columns:
63+
extra = col.get("_extra", {})
64+
if extra and "column_mask" in extra:
65+
set_column_masks[col["name"]] = extra["column_mask"]
66+
return ColumnMaskConfig(set_column_masks=set_column_masks)

dbt/adapters/databricks/relation_configs/incremental.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
DatabricksRelationConfigBase,
33
)
44
from dbt.adapters.databricks.relation_configs.column_comments import ColumnCommentsProcessor
5+
from dbt.adapters.databricks.relation_configs.column_mask import ColumnMaskProcessor
56
from dbt.adapters.databricks.relation_configs.comment import CommentProcessor
67
from dbt.adapters.databricks.relation_configs.constraints import ConstraintsProcessor
78
from dbt.adapters.databricks.relation_configs.liquid_clustering import LiquidClusteringProcessor
@@ -13,6 +14,7 @@ class IncrementalTableConfig(DatabricksRelationConfigBase):
1314
config_components = [
1415
CommentProcessor,
1516
ColumnCommentsProcessor,
17+
ColumnMaskProcessor,
1618
ConstraintsProcessor,
1719
TagsProcessor,
1820
TblPropertiesProcessor,
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
{% macro fetch_column_masks(relation) -%}
2+
{% if relation.is_hive_metastore() %}
3+
{{ exceptions.raise_compiler_error("Column masks are not supported for Hive Metastore") }}
4+
{%- endif %}
5+
{% call statement('list_column_masks', fetch_result=True) -%}
6+
{{ fetch_column_masks_sql(relation) }}
7+
{% endcall %}
8+
{% do return(load_result('list_column_masks').table) %}
9+
{%- endmacro -%}
10+
11+
{% macro fetch_column_masks_sql(relation) -%}
12+
SELECT
13+
column_name,
14+
mask_name
15+
FROM `{{ relation.database|lower }}`.information_schema.column_masks
16+
WHERE table_catalog = '{{ relation.database|lower }}'
17+
AND table_schema = '{{ relation.schema|lower }}'
18+
AND table_name = '{{ relation.identifier|lower }}';
19+
{%- endmacro -%}
20+
21+
{% macro apply_column_masks_from_model_columns(relation) -%}
22+
{% if relation.is_hive_metastore() %}
23+
{{ exceptions.raise_compiler_error("Column masks are not supported for Hive Metastore") }}
24+
{%- endif %}
25+
{{ log("Applying column masks from model to relation " ~ relation) }}
26+
{% set columns = model.get('columns', {}) %}
27+
{% for column_name, column_def in columns.items() %}
28+
{% if column_def is mapping and column_def.get('column_mask') %}
29+
{%- call statement('main') -%}
30+
{{ alter_set_column_mask(relation, column_name, column_def.column_mask) }}
31+
{%- endcall -%}
32+
{% endif %}
33+
{% endfor %}
34+
{%- endmacro -%}
35+
36+
{% macro apply_column_masks(relation, column_masks) -%}
37+
{% if relation.is_hive_metastore() %}
38+
{{ exceptions.raise_compiler_error("Column masks are not supported for Hive Metastore") }}
39+
{%- endif %}
40+
{{ log("Applying column masks to relation " ~ relation) }}
41+
{%- if column_masks.unset_column_mask %}
42+
{%- for column in column_masks.unset_column_mask -%}
43+
{%- call statement('main') -%}
44+
{{ alter_drop_column_mask(relation, column) }}
45+
{%- endcall -%}
46+
{%- endfor -%}
47+
{%- endif %}
48+
{%- if column_masks.set_column_mask %}
49+
{%- for column, mask in column_masks.set_column_mask.items() -%}
50+
{%- call statement('main') -%}
51+
{{ alter_set_column_mask(relation, column, mask) }}
52+
{%- endcall -%}
53+
{%- endfor -%}
54+
{%- endif %}
55+
{%- endmacro -%}
56+
57+
{% macro alter_drop_column_mask(relation, column) -%}
58+
ALTER TABLE {{ relation.render() }}
59+
ALTER COLUMN {{ column }}
60+
DROP MASK;
61+
{%- endmacro -%}
62+
63+
{% macro alter_set_column_mask(relation, column, mask) -%}
64+
ALTER TABLE {{ relation.render() }}
65+
ALTER COLUMN {{ column }}
66+
SET MASK {{ mask }};
67+
{%- endmacro -%}
68+
69+

dbt/include/databricks/macros/relations/table/alter.sql

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
{% set tblproperties = configuration_changes.changes.get("tblproperties") %}
88
{% set liquid_clustering = configuration_changes.changes.get("liquid_clustering")%}
99
{% set constraints = configuration_changes.changes.get("constraints") %}
10+
{% set column_masks = configuration_changes.changes.get("column_masks") %}
1011
{% if tags is not none %}
1112
{% do apply_tags(target_relation, tags.set_tags, tags.unset_tags) %}
1213
{%- endif -%}
@@ -25,5 +26,8 @@
2526
{% if constraints %}
2627
{{ apply_constraints(target_relation, constraints) }}
2728
{% endif %}
29+
{% if column_masks %}
30+
{{ apply_column_masks(target_relation, column_masks) }}
31+
{% endif %}
2832
{%- endif -%}
2933
{% endmacro %}

dbt/include/databricks/macros/relations/table/create.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
{{ apply_alter_constraints(target_relation) }}
1414
{{ apply_tags(target_relation, tags) }}
15+
{{ apply_column_masks_from_model_columns(target_relation) }}
1516

1617
{% call statement('merge into target') %}
1718
insert into {{ target_relation }} select * from {{ intermediate_relation }}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
base_model_sql = """
2+
{{ config(
3+
materialized = 'table'
4+
) }}
5+
SELECT 'abc-123' as id, 'password123' as password;
6+
"""
7+
8+
9+
model = """
10+
version: 2
11+
models:
12+
- name: base_model
13+
columns:
14+
- name: id
15+
data_type: string
16+
- name: password
17+
column_mask: password_mask
18+
data_type: string
19+
"""
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
3+
from dbt.tests.util import run_dbt
4+
from tests.functional.adapter.column_masks.fixtures import (
5+
base_model_sql,
6+
model,
7+
)
8+
from tests.functional.adapter.fixtures import MaterializationV2Mixin
9+
10+
11+
class TestColumnMask(MaterializationV2Mixin):
12+
@pytest.fixture(scope="class")
13+
def models(self):
14+
return {
15+
"base_model.sql": base_model_sql,
16+
"schema.yml": model,
17+
}
18+
19+
def test_column_mask(self, project):
20+
# Create the mask function
21+
project.run_sql(
22+
f"CREATE OR REPLACE FUNCTION {project.database}.{project.test_schema}."
23+
"password_mask(password STRING) RETURNS STRING RETURN '*****';"
24+
)
25+
26+
run_dbt(["run"])
27+
28+
# Verify column mask was created
29+
masks = project.run_sql(
30+
f"""
31+
SELECT column_name, mask_name
32+
FROM {project.database}.information_schema.column_masks
33+
""",
34+
fetch="all",
35+
)
36+
37+
assert len(masks) == 1
38+
assert masks[0][0] == "password" # column_name
39+
assert masks[0][1] == f"{project.database}.{project.test_schema}.password_mask" # mask_name
40+
41+
# Verify masked value
42+
result = project.run_sql("SELECT id, password FROM base_model", fetch="one")
43+
assert result[0] == "abc-123"
44+
assert result[1] == "*****" # Masked value should be 5 asterisks
45+
46+
47+
class TestIncrementalColumnMask(TestColumnMask):
48+
@pytest.fixture(scope="class")
49+
def models(self):
50+
return {
51+
"base_model.sql": base_model_sql.replace("table", "incremental"),
52+
"schema.yml": model,
53+
}

tests/functional/adapter/incremental/fixtures.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,3 +850,43 @@ def model(dbt, spark):
850850
to: ref('fk_referenced_to_table')
851851
to_columns: [id, version]
852852
"""
853+
854+
855+
column_mask_sql = """
856+
{{ config(
857+
materialized = 'incremental',
858+
incremental_strategy = 'merge',
859+
unique_key = 'id',
860+
) }}
861+
862+
select cast(1 as bigint) as id, 'hello' as name, 'john.doe@example.com' as email,
863+
'password123' as password
864+
"""
865+
866+
column_mask_name = """
867+
version: 2
868+
869+
models:
870+
- name: column_mask_sql
871+
columns:
872+
- name: id
873+
- name: name
874+
column_mask: full_mask
875+
- name: email
876+
column_mask: full_mask
877+
- name: password
878+
"""
879+
880+
column_mask_password = """
881+
version: 2
882+
883+
models:
884+
- name: column_mask_sql
885+
columns:
886+
- name: id
887+
- name: name
888+
- name: email
889+
column_mask: email_mask
890+
- name: password
891+
column_mask: full_mask
892+
"""
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import pytest
2+
3+
from dbt.tests import util
4+
from tests.functional.adapter.fixtures import MaterializationV2Mixin
5+
from tests.functional.adapter.incremental import fixtures
6+
7+
8+
@pytest.mark.skip_profile("databricks_cluster")
9+
class TestIncrementalColumnMasks(MaterializationV2Mixin):
10+
@pytest.fixture(scope="class")
11+
def models(self):
12+
return {
13+
"column_mask_sql.sql": fixtures.column_mask_sql,
14+
"schema.yml": fixtures.column_mask_name,
15+
}
16+
17+
def test_changing_column_masks(self, project):
18+
# Create the mask functions
19+
project.run_sql(
20+
f"""
21+
CREATE OR REPLACE FUNCTION
22+
{project.database}.{project.test_schema}.full_mask(value STRING)
23+
RETURNS STRING
24+
RETURN '*****';
25+
"""
26+
)
27+
# Masks all characters before the @ symbol
28+
project.run_sql(
29+
f"""
30+
CREATE OR REPLACE FUNCTION
31+
{project.database}.{project.test_schema}.email_mask(value STRING)
32+
RETURNS STRING
33+
RETURN CONCAT(
34+
REPEAT('*', POSITION('@' IN value) - 1),
35+
SUBSTR(value, POSITION('@' IN value))
36+
);
37+
"""
38+
)
39+
40+
# First run with name masked
41+
util.run_dbt(["run"])
42+
masks = project.run_sql(
43+
"SELECT id, name, email, password FROM column_mask_sql",
44+
fetch="all",
45+
)
46+
assert len(masks) == 1
47+
assert masks[0][0] == 1
48+
assert masks[0][1] == "*****" # name (masked)
49+
assert masks[0][2] == "*****" # email (masked)
50+
assert masks[0][3] == "password123" # password (unmasked)
51+
52+
# Update masks and verify changes
53+
util.write_file(fixtures.column_mask_password, "models", "schema.yml")
54+
util.run_dbt(["run"])
55+
56+
result = project.run_sql(
57+
"SELECT id, name, email, password FROM column_mask_sql", fetch="all"
58+
)
59+
assert len(result) == 1
60+
assert result[0][0] == 1
61+
assert result[0][1] == "hello" # name (unmasked)
62+
assert result[0][2] == "********@example.com" # email (partially masked)
63+
assert result[0][3] == "*****" # password (masked)

0 commit comments

Comments
 (0)