diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py index 34b250b3b..ae924e64c 100644 --- a/demos/dqx_demo_library.py +++ b/demos/dqx_demo_library.py @@ -1481,4 +1481,62 @@ def safe_parse_json(col): # explode warnings warnings_df = valid_and_quarantine_df.select(F.explode(F.col("dq_warnings")).alias("dq")).select(F.expr("dq.*")) -display(warnings_df) \ No newline at end of file +display(warnings_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Advanced: Variable Substitution +# MAGIC +# MAGIC DQX supports variable substitution in declarative check definitions (YAML, JSON, or Delta tables). +# MAGIC This allows you to parameterize your rules and inject values at **load time** via the `variables` parameter in `load_checks`. +# MAGIC +# MAGIC ### Example Usage +# MAGIC +# MAGIC 1. Define a rule with `{{ placeholder }}` syntax. +# MAGIC 2. Pass a dictionary of variables when loading the rules. + +# COMMAND ---------- + +from databricks.labs.dqx.config import WorkspaceFileChecksStorageConfig + +# Save to a temporary file + +# Define parameterized checks +parameterized_checks_yaml = """ +- criticality: error + name: "threshold_check_{{ threshold_name }}" + check: + function: is_not_greater_than + arguments: + column: "{{ target_column }}" + limit: "{{ max_value }}" +""" + +# Save to a temporary file +# demo_file_directory is defined at the beginning of this notebook +temp_checks_path = os.path.join(demo_file_directory, "parameterized_checks.yml") +with open(temp_checks_path, "w") as f: + f.write(parameterized_checks_yaml) + +dq_engine = DQEngine(WorkspaceClient()) + +# Load checks with variable resolution +# Resolution happens during the load process +resolved_checks = dq_engine.load_checks( + config=WorkspaceFileChecksStorageConfig(location=temp_checks_path), + variables={ + "threshold_name": "critical", + "target_column": "col1", + "max_value": 100 + } +) + +# The resolved checks now have the values injected +# Note: DQEngine internally converts string numbers to their appropriate types if needed during validation or apply +print(yaml.dump(resolved_checks)) + +# Apply the resolved checks to a DataFrame +data = spark.createDataFrame([[50], [150]], "col1: int") +result_df = dq_engine.apply_checks_by_metadata(data, resolved_checks) +display(result_df) \ No newline at end of file diff --git a/docs/dqx/docs/guide/additional_configuration.mdx b/docs/dqx/docs/guide/additional_configuration.mdx index 69701e6fb..e078f9902 100644 --- a/docs/dqx/docs/guide/additional_configuration.mdx +++ b/docs/dqx/docs/guide/additional_configuration.mdx @@ -171,3 +171,80 @@ from pyspark.sql import functions as F skipped = checked_df.select(F.explode("_errors").alias("e")).filter(F.col("e.skipped") == True) ``` + +## Defining default variables for substitution + +DQX allows you to define engine-level defaults for variables used in declarative check definitions (YAML, JSON, or Delta tables). These defaults are automatically applied during `load_checks` and `save_checks` unless overridden by the per-call `variables` parameter. 
+ + + + ```python + from databricks.labs.dqx.engine import DQEngine + from databricks.labs.dqx.config import ExtraParams, FileChecksStorageConfig, TableChecksStorageConfig + from databricks.sdk import WorkspaceClient + + # Initialize engine with default variables + dq_engine = DQEngine( + WorkspaceClient(), + extra_params=ExtraParams( + variables={ + "min_temp": 0, + "max_temp": 50, + "region": "GLOBAL" + } + ) + ) + + # Load checks - uses 'min_temp' and 'max_temp' from defaults, + # but overrides 'region' specifically for this call. + resolved_checks = dq_engine.load_checks( + config=FileChecksStorageConfig(location="checks.yml"), + variables={"region": "EMEA"}, + ) + + # Save checks - resolves variables before computing fingerprints and persisting. + # Uses 'min_temp' and 'max_temp' from defaults, overrides 'region' for this call. + dq_engine.save_checks( + checks=checks, + config=TableChecksStorageConfig(location="catalog.schema.checks_table"), + variables={"region": "EMEA"}, + ) + ``` + + + + +Variable substitution is not currently supported in DQX installable workflows. Variables can be defined and stored as YAML in the configuration file but will not be applied during workflow execution. + +Variable substitution is only available when defining checks declaratively (as dictionaries or in files/tables). It is not supported when using DQX classes (e.g., `DQRowRule`) directly. + + +## Overwriting run metadata + +By default, DQX automatically generates a unique `run_id` for each engine instance and uses the current timestamp as the `run_time`. You can manually overwrite these values using `ExtraParams` if you need to align DQX results with external systems or re-run checks for a specific historical point in time. + + + + ```python + from databricks.labs.dqx.engine import DQEngine + from databricks.labs.dqx.config import ExtraParams + from databricks.sdk import WorkspaceClient + + extra_params = ExtraParams( + run_id_overwrite="custom-execution-id-123", + run_time_overwrite="2024-01-01T12:00:00Z" + ) + + dq_engine = DQEngine(WorkspaceClient(), extra_params=extra_params) + ``` + + + You can set the following fields in the [configuration file](/docs/installation/#configuration-file) to overwrite the run metadata when using DQX workflows: + ```yaml + extra_params: + run_id_overwrite: custom-execution-id-123 + run_time_overwrite: 2024-01-01T12:00:00Z + ``` + + + diff --git a/docs/dqx/docs/guide/quality_checks_definition.mdx b/docs/dqx/docs/guide/quality_checks_definition.mdx index ffd270393..ff11678d8 100644 --- a/docs/dqx/docs/guide/quality_checks_definition.mdx +++ b/docs/dqx/docs/guide/quality_checks_definition.mdx @@ -720,6 +720,80 @@ Example checks saved in a Delta or Lakebase table (compact format — `for_each_ If `run_config_name` is not provided, "default" is used. Typically, the input table or job name is used for run config name to establish a one-to-one mapping between tables or jobs and checks. +## Variable Substitution + +DQX supports variable substitution in declarative check definitions (YAML, JSON, or Delta tables). This allows you to parameterize your quality rules and inject values at **load time** or **save time** from engine-level defaults and/or via the `variables` parameter in `load_checks` or `save_checks`. + +### Syntax and Scope + +Placeholders are defined using the `{{ variable_name }}` syntax. 
Variable substitution is supported in **all string values** within the check definitions, including: +- `name` +- `filter` +- `check` function arguments (`arguments`) and column names (`for_each_column`) +- any other top-level or nested string field + + +The `criticality` field only accepts fixed values (`error` or `warn`). Do not use variable placeholders for `criticality` — the resolved value must be a valid criticality and substituting it defeats the purpose of having an explicit severity level in the check definition. + + +### Resolution + +Variables are resolved when checks are loaded or saved via the engine. To resolve variables, pass a dictionary to the `variables` parameter of `load_checks` or `save_checks`. User can decide whether to provide variables when loading or saving checks. + + +When using `save_checks` with variables, placeholders are resolved **before** computing rule fingerprints and persisting. This ensures that stored checks and their fingerprints reflect the actual resolved check logic. Without resolving at save time, fingerprints would be computed on unresolved `{{ }}` placeholders, causing a mismatch between the fingerprints stored in the checks table and those recorded in the summary metrics and per-row detailed results tables. + + + +Variable substitution is only available when defining checks declaratively (as dictionaries or in files/tables). It is not supported when using DQX classes (e.g., `DQRowRule`) directly. + + +```python +import yaml +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.config import FileChecksStorageConfig, TableChecksStorageConfig +from databricks.sdk import WorkspaceClient + +dq_engine = DQEngine(WorkspaceClient()) + +# Define checks with variable placeholders +checks = yaml.safe_load(""" + - criticality: error + check: + function: is_in_range + arguments: + column: temperature + min_limit: "{{ min_temp }}" + max_limit: "{{ max_temp }}" + filter: "region = '{{ region }}'" +""") + +variables = { + "min_temp": 0, + "max_temp": 100, + "region": "EMEA", +} + +# Load checks from file with variable resolution +resolved_checks = dq_engine.load_checks( + config=FileChecksStorageConfig(location="checks.yml"), + variables=variables, +) + +# Or resolve variables when saving checks (ensures fingerprints are consistent) +dq_engine.save_checks( + checks=checks, + config=TableChecksStorageConfig(location="catalog.schema.checks_table"), + variables=variables, +) +``` + +## Default Variables + +In addition to specifying variables during the load or save process, you can define engine-level defaults using the `ExtraParams` class. These constants are automatically applied to all checks unless explicitly overridden. + +For technical details and configuration examples, see [Default Variables](/docs/guide/additional_configuration#defining-default-variables-for-substitution) in the Additional Configuration guide. + ## Validating syntax of quality checks You can validate the syntax of checks loaded from a storage system or checks defined programmatically before applying them. 
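+
+For instance, a minimal sketch of validating checks right after variable resolution (assuming a local
+`checks.yml` that uses `{{ }}` placeholders; if the placeholders are left unresolved, fields such as
+`criticality` will not contain a valid value and validation reports an error):
+
+```python
+from databricks.labs.dqx.engine import DQEngine
+from databricks.labs.dqx.config import FileChecksStorageConfig
+from databricks.sdk import WorkspaceClient
+
+dq_engine = DQEngine(WorkspaceClient())
+
+# Resolve placeholders at load time, then validate the resolved checks before applying them
+checks = dq_engine.load_checks(
+    config=FileChecksStorageConfig(location="checks.yml"),
+    variables={"min_temp": 0, "max_temp": 100, "region": "EMEA"},
+)
+
+status = dq_engine.validate_checks(checks)
+assert not status.has_errors, status.errors
+```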
diff --git a/docs/dqx/docs/guide/quality_checks_storage.mdx b/docs/dqx/docs/guide/quality_checks_storage.mdx index 1f330073c..c56a36038 100644 --- a/docs/dqx/docs/guide/quality_checks_storage.mdx +++ b/docs/dqx/docs/guide/quality_checks_storage.mdx @@ -180,6 +180,12 @@ If you create checks as a list of DQRule objects, you can convert them using the # also works for absolute and relative workspace paths if invoked from Databricks notebook or job checks: list[dict] = dq_engine.load_checks(config=FileChecksStorageConfig(location="checks.yml")) + # load checks from a local file with variable substitution + checks: list[dict] = dq_engine.load_checks( + FileChecksStorageConfig(location="checks.yml"), + variables={"threshold": 100, "column_name": "total_amount"} + ) + # load checks from arbitrary workspace location using absolute path checks: list[dict] = dq_engine.load_checks(config=WorkspaceFileChecksStorageConfig(location="/Shared/App1/checks.yml")) diff --git a/docs/dqx/docs/reference/engine.mdx b/docs/dqx/docs/reference/engine.mdx index e88c19487..2177e28d1 100644 --- a/docs/dqx/docs/reference/engine.mdx +++ b/docs/dqx/docs/reference/engine.mdx @@ -62,8 +62,8 @@ The following table outlines the available methods of the `DQEngine` and their f | `validate_checks` | Validates declarative checks (list of dict metadata): expected shape, argument types where the check function has annotations, unknown argument names, and required parameters of each check function’s signature. | `checks`: List of checks to validate; `custom_check_functions`: (optional) Dictionary of custom check functions that can be used; `validate_custom_check_functions`: (optional) If True, validates custom check functions (defaults to True). | Yes | | `get_invalid` | Retrieves records from the DataFrame that violate data quality checks (records with warnings and errors). | `df`: Input DataFrame. | Yes | | `get_valid` | Retrieves records from the DataFrame that pass all data quality checks. | `df`: Input DataFrame. | Yes | -| `load_checks` | Loads quality rules (checks) from storage backend. Multiple storage backends are supported including tables, files, workspace files, or installation-managed sources inferred from run config. | `config`: Configuration for loading checks from a storage backend, e.g., `FileChecksStorageConfig` (local YAML/JSON file or workspace file), `WorkspaceFileChecksStorageConfig` (workspace file with absolute path), `VolumeFileChecksStorageConfig` (Unity Catalog Volume YAML/JSON), `TableChecksStorageConfig` (table), `InstallationChecksStorageConfig` (installation-managed backend using `checks_location` in run config). | Yes (only with `FileChecksStorageConfig`) | -| `save_checks` | Saves quality rules (checks) to a storage backend. Multiple storage backends are supported including tables, files, workspace files, or installation-managed targets inferred from run config. | `checks`: List of checks defined as dictionary; `config`: Configuration for saving checks in a storage backend, e.g., `FileChecksStorageConfig` (local YAML/JSON file or workspace file), `WorkspaceFileChecksStorageConfig` (workspace file with absolute path), `VolumeFileChecksStorageConfig` (Unity Catalog Volume YAML/JSON), `TableChecksStorageConfig` (table), `InstallationChecksStorageConfig` (installation-managed backend using `checks_location` in run config). | Yes (only with `FileChecksStorageConfig`) | +| `load_checks` | Loads quality rules (checks) from storage backend. 
Multiple storage backends are supported including tables, files, workspace files, or installation-managed sources inferred from run config. | `config`: Configuration for loading checks from a storage backend, e.g., `FileChecksStorageConfig` (local YAML/JSON file or workspace file), `WorkspaceFileChecksStorageConfig` (workspace file with absolute path), `VolumeFileChecksStorageConfig` (Unity Catalog Volume YAML/JSON), `TableChecksStorageConfig` (table), `InstallationChecksStorageConfig` (installation-managed backend using `checks_location` in run config); `variables`: (optional) dictionary of variables for [variable substitution](/docs/guide/quality_checks_definition/#variable-substitution). | Yes (only with `FileChecksStorageConfig`) | +| `save_checks` | Saves quality rules (checks) to a storage backend. Multiple storage backends are supported including tables, files, workspace files, or installation-managed targets inferred from run config. Variables are resolved before computing fingerprints and persisting. | `checks`: List of checks defined as dictionary; `config`: Configuration for saving checks in a storage backend, e.g., `FileChecksStorageConfig` (local YAML/JSON file or workspace file), `WorkspaceFileChecksStorageConfig` (workspace file with absolute path), `VolumeFileChecksStorageConfig` (Unity Catalog Volume YAML/JSON), `TableChecksStorageConfig` (table), `InstallationChecksStorageConfig` (installation-managed backend using `checks_location` in run config); `variables`: (optional) dictionary of variables for [variable substitution](/docs/guide/quality_checks_definition/#variable-substitution). | Yes (only with `FileChecksStorageConfig`) | | `save_results_in_table` | Saves DataFrames as tables using Unity Catalog table references or storage paths. Supports both batch and streaming writes. For streaming DataFrames, returns a StreamingQuery that can be used to monitor or wait for completion. For batch DataFrames, data is written synchronously and None is returned. | `output_df`: (optional) DataFrame containing the output data (batch or streaming); `quarantine_df`: (optional) DataFrame containing invalid data (batch or streaming); `observation`: (optional) Spark Observation tracking summary metrics; `output_config`: `OutputConfig` with location (table name or storage path), mode, format, options, and optional trigger (supports `partition_by` or `cluster_by`, only one applies;); `quarantine_config`: (optional) `OutputConfig` with location (table name or storage path), mode, format, options, and optional trigger (supports `partition_by` or `cluster_by`, only one applies;); `metrics_config`: (optional) `OutputConfig` with location for summary metrics; `rule_set_fingerprint`: (optional) SHA-256 fingerprint of the rule set used for this run, included in summary metrics when metrics_config is provided; `run_config_name`: Name of the run config to use; `install_folder`: (optional) Installation folder where DQX is installed (only required for custom folder); `assume_user`: (optional) If True, assume user installation, otherwise global. | No | | `save_summary_metrics` | Saves quality checking summary metrics to a Delta table. 
| `observed_metrics`: `dict[str, Any]` Collected summary metrics from Spark Observation; `metrics_config`: `OutputConfig` object with the table name, output mode, and options for the summary metrics data; `input_config`: (optional) `InputConfig` object with the table name for reading the input data; `output_config`: (optional) `OutputConfig` object with the table name for the output data (supports `partition_by` or `cluster_by`, only one applies); `quarantine_config`: (optional) `OutputConfig` object with the table name for the quarantine data (supports `partition_by` or `cluster_by`, only one applies); `checks_location`: (optional) Location where checks are stored; `rule_set_fingerprint`: (optional) SHA-256 fingerprint of the rule set used for this run. | No | | `get_streaming_metrics_listener` | Gets a streaming metrics listener for writing metrics to an output table. Only required when using streaming DataFrames. | `metrics_config`: `OutputConfig` object with the table name, output mode, and options for the summary metrics data; `input_config`: (optional) `InputConfig` object with the table name for reading the input data; `output_config`: (optional) `OutputConfig` object with the table name for the output data (supports `partition_by` or `cluster_by`, only one applies); `quarantine_config`: (optional) `OutputConfig` object with the table name for the quarantine data (supports `partition_by` or `cluster_by`, only one applies); `checks_location`: (optional) checks location; `rule_set_fingerprint`: (optional) SHA-256 fingerprint of the rule set used for this run; `target_query_id`: (optional) Query ID of the specific streaming query to monitor, if provided, metrics will be collected only for this query. | No | diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index c663889d5..478c2b1d5 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -20,6 +20,10 @@ When you define checks **declaratively** (YAML, JSON, or list of dicts), check a You can explore the implementation details of the check functions [here](https://github.com/databrickslabs/dqx/blob/v0.13.0/src/databricks/labs/dqx/check_funcs.py). + +All declarative check definitions (YAML, JSON, or Delta tables) support **variable substitution** for string-based fields using the `{{ variable_name }}` syntax. This allows for dynamic parameterization of column names, thresholds, and filters at load time. See the [User Guide](/docs/guide/quality_checks_definition/#variable-substitution) for more details. + + ## Row-level checks reference Row-level checks are applied to each row in a PySpark DataFrame. The quality check results are reported for individual rows in the result columns. 
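+
+For instance, a minimal sketch of a single row-level check applied via `apply_checks_by_metadata`
+(an active `spark` session, as in a Databricks notebook, and the default `_errors` result column
+are assumed here):
+
+```python
+import yaml
+from databricks.labs.dqx.engine import DQEngine
+from databricks.sdk import WorkspaceClient
+
+dq_engine = DQEngine(WorkspaceClient())
+
+checks = yaml.safe_load("""
+- criticality: error
+  check:
+    function: is_not_null
+    arguments:
+      column: id
+""")
+
+df = spark.createDataFrame([[1], [None]], "id: int")
+
+# Each input row is returned with its per-row check results appended in the result columns
+checked_df = dq_engine.apply_checks_by_metadata(df, checks)
+checked_df.select("id", "_errors").show(truncate=False)
+```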
diff --git a/src/databricks/labs/dqx/base.py b/src/databricks/labs/dqx/base.py index 7a4fdf8f7..7a5c6a8d1 100644 --- a/src/databricks/labs/dqx/base.py +++ b/src/databricks/labs/dqx/base.py @@ -2,11 +2,14 @@ from collections.abc import Callable from functools import cached_property from typing import final + from pyspark.sql import DataFrame, Observation + +from databricks.labs.dqx.__about__ import __version__ from databricks.labs.dqx.checks_validator import ChecksValidationStatus from databricks.labs.dqx.rule import DQRule +from databricks.labs.dqx.utils import VariableValue from databricks.sdk import WorkspaceClient -from databricks.labs.dqx.__about__ import __version__ class DQEngineBase(abc.ABC): @@ -175,14 +178,19 @@ def get_valid(self, df: DataFrame) -> DataFrame: @staticmethod @abc.abstractmethod - def load_checks_from_local_file(filepath: str) -> list[dict]: + def load_checks_from_local_file(filepath: str, variables: dict[str, VariableValue] | None = None) -> list[dict]: """ Load DQ rules (checks) from a local JSON or YAML file. The returned checks can be used as input to *apply_checks_by_metadata*. + **Security note:** variable values substituted into **sql_expression** checks are + not sanitized. Callers must ensure that variable values come from trusted sources. + Args: filepath: Path to a file containing checks definitions. + variables: Optional mapping of placeholder names to replacement values. Replaces placeholders + in all string values of the check definitions before returning. Returns: List of DQ rules (checks). diff --git a/src/databricks/labs/dqx/config.py b/src/databricks/labs/dqx/config.py index ae6b9986f..04a268f86 100644 --- a/src/databricks/labs/dqx/config.py +++ b/src/databricks/labs/dqx/config.py @@ -4,6 +4,7 @@ from databricks.labs.dqx.checks_serializer import SerializerFactory from databricks.labs.dqx.errors import InvalidConfigError, InvalidParameterError +from databricks.labs.dqx.utils import VariableValue __all__ = [ "WorkspaceConfig", @@ -217,6 +218,7 @@ class ExtraParams: run_time_overwrite: str | None = None run_id_overwrite: str | None = None suppress_skipped: bool = False + variables: dict[str, VariableValue] = field(default_factory=dict) @dataclass diff --git a/src/databricks/labs/dqx/engine.py b/src/databricks/labs/dqx/engine.py index e48d78aa0..f3a3ce53a 100644 --- a/src/databricks/labs/dqx/engine.py +++ b/src/databricks/labs/dqx/engine.py @@ -51,7 +51,7 @@ from databricks.labs.dqx.telemetry import telemetry_logger, log_telemetry, log_dataframe_telemetry from databricks.sdk import WorkspaceClient from databricks.labs.dqx.errors import InvalidCheckError, InvalidConfigError, InvalidParameterError -from databricks.labs.dqx.utils import list_tables, safe_strip_file_from_path +from databricks.labs.dqx.utils import list_tables, safe_strip_file_from_path, resolve_variables, VariableValue from databricks.labs.dqx.io import is_one_time_trigger logger = logging.getLogger(__name__) @@ -342,19 +342,25 @@ def get_valid(self, df: DataFrame) -> DataFrame: ) @staticmethod - def load_checks_from_local_file(filepath: str) -> list[dict]: + def load_checks_from_local_file(filepath: str, variables: dict[str, VariableValue] | None = None) -> list[dict]: """ Load DQ rules (checks) from a local JSON or YAML file. The returned checks can be used as input to *apply_checks_by_metadata*. + **Security note:** variable values substituted into **sql_expression** checks are + not sanitized. Callers must ensure that variable values come from trusted sources. 
+ Args: filepath: Path to a file containing checks definitions. + variables: Optional mapping of placeholder names to replacement values. Replaces placeholders + in all string values of the check definitions before returning. Returns: List of DQ rules. """ - return FileChecksStorageHandler().load(FileChecksStorageConfig(location=filepath)) + checks = FileChecksStorageHandler().load(FileChecksStorageConfig(location=filepath)) + return resolve_variables(checks=checks, variables=variables) @staticmethod def save_checks_in_local_file(checks: list[dict], filepath: str): @@ -579,8 +585,9 @@ def __init__( ): super().__init__(workspace_client) + self._extra_params = extra_params or ExtraParams() self.spark = SparkSession.builder.getOrCreate() if spark is None else spark - self._engine = engine or DQEngineCore(workspace_client, spark, extra_params, observer) + self._engine = engine or DQEngineCore(workspace_client, spark, self._extra_params, observer) self._config_serializer = config_serializer or ConfigSerializer(workspace_client) self._checks_handler_factory: BaseChecksStorageHandlerFactory = ( checks_handler_factory or ChecksStorageHandlerFactory(self.ws, self.spark) @@ -652,7 +659,9 @@ def apply_checks_by_metadata( summary metrics. Summary metrics are returned by any `DQEngine` with an `observer` specified. """ log_dataframe_telemetry(self.ws, self.spark, df) - return self._engine.apply_checks_by_metadata(df, checks, custom_check_functions, ref_dfs) + return self._engine.apply_checks_by_metadata( + df=df, checks=checks, custom_check_functions=custom_check_functions, ref_dfs=ref_dfs + ) @telemetry_logger("engine", "apply_checks_by_metadata_and_split") def apply_checks_by_metadata_and_split( @@ -681,7 +690,9 @@ def apply_checks_by_metadata_and_split( quality summary metrics. Summary metrics are returned by any `DQEngine` with an `observer` specified. """ log_dataframe_telemetry(self.ws, self.spark, df) - return self._engine.apply_checks_by_metadata_and_split(df, checks, custom_check_functions, ref_dfs) + return self._engine.apply_checks_by_metadata_and_split( + df=df, checks=checks, custom_check_functions=custom_check_functions, ref_dfs=ref_dfs + ) @telemetry_logger("engine", "apply_checks_and_save_in_table") def apply_checks_and_save_in_table( @@ -847,7 +858,9 @@ def apply_checks_by_metadata_and_save_in_table( quarantine_streaming_query = None if quarantine_config: - check_result = self.apply_checks_by_metadata_and_split(df, checks, custom_check_functions, ref_dfs) + check_result = self.apply_checks_by_metadata_and_split( + df=df, checks=checks, custom_check_functions=custom_check_functions, ref_dfs=ref_dfs + ) if self._engine.observer: good_df, bad_df, batch_observation = check_result else: @@ -856,7 +869,9 @@ def apply_checks_by_metadata_and_save_in_table( quarantine_streaming_query = save_dataframe_as_table(bad_df, quarantine_config) target_streaming_query = quarantine_streaming_query else: - check_result = self.apply_checks_by_metadata(df, checks, custom_check_functions, ref_dfs) + check_result = self.apply_checks_by_metadata( + df=df, checks=checks, custom_check_functions=custom_check_functions, ref_dfs=ref_dfs + ) if self._engine.observer: checked_df, batch_observation = check_result else: @@ -1035,7 +1050,11 @@ def validate_checks( Returns: ChecksValidationStatus indicating the validation result. 
""" - return DQEngineCore.validate_checks(checks, custom_check_functions, validate_custom_check_functions) + return DQEngineCore.validate_checks( + checks=checks, + custom_check_functions=custom_check_functions, + validate_custom_check_functions=validate_custom_check_functions, + ) def get_invalid(self, df: DataFrame) -> DataFrame: """ @@ -1168,7 +1187,9 @@ def save_results_in_table( ) @telemetry_logger("engine", "load_checks") - def load_checks(self, config: BaseChecksStorageConfig) -> list[dict]: + def load_checks( + self, config: BaseChecksStorageConfig, variables: dict[str, VariableValue] | None = None + ) -> list[dict]: """Load DQ rules (checks) from the storage backend described by *config*. This method delegates to a storage handler selected by the factory @@ -1183,8 +1204,16 @@ def load_checks(self, config: BaseChecksStorageConfig) -> list[dict]: - *InstallationChecksStorageConfig* (installation directory); - *VolumeFileChecksStorageConfig* (Unity Catalog volume file); + Per-call *variables* are merged with engine-level defaults from + *ExtraParams.variables* (per-call values take precedence on conflict). + + **Security note:** variable values substituted into **sql_expression** checks are + not sanitized. Callers must ensure that variable values come from trusted sources. + Args: config: Configuration object describing the storage backend. + variables: Optional mapping of placeholder names to replacement values. Replaces placeholders + in all string values of the check definitions before returning. Returns: List of DQ rules (checks) represented as dictionaries. @@ -1193,10 +1222,31 @@ def load_checks(self, config: BaseChecksStorageConfig) -> list[dict]: InvalidConfigError: If the configuration type is unsupported. """ handler = self._checks_handler_factory.create(config) - return handler.load(config) + checks = handler.load(config) + merged_variables = self._merge_variables(variables) + return resolve_variables(checks=checks, variables=merged_variables) + + def _merge_variables(self, per_call: dict[str, VariableValue] | None) -> dict[str, VariableValue] | None: + """Merge engine-level default variables with per-call overrides. + + Per-call values take precedence over engine-level defaults. + """ + defaults = self._extra_params.variables + if not defaults and not per_call: + return None + if not defaults: + return per_call + if not per_call: + return defaults + return {**defaults, **per_call} @telemetry_logger("engine", "save_checks") - def save_checks(self, checks: list[dict], config: BaseChecksStorageConfig) -> None: + def save_checks( + self, + checks: list[dict], + config: BaseChecksStorageConfig, + variables: dict[str, VariableValue] | None = None, + ) -> None: """Persist DQ rules (checks) to the storage backend described by *config*. The appropriate storage handler is resolved from the configuration @@ -1212,9 +1262,16 @@ def save_checks(self, checks: list[dict], config: BaseChecksStorageConfig) -> No - *InstallationChecksStorageConfig* (installation directory); - *VolumeFileChecksStorageConfig* (Unity Catalog volume file); + Per-call *variables* are merged with engine-level defaults from + *ExtraParams.variables* (per-call values take precedence on conflict). + Variables are resolved before computing fingerprints and persisting, + ensuring that stored checks and their fingerprints are consistent. + Args: checks: List of DQ rules (checks) to save (as dictionaries). config: Configuration object describing the storage backend and write options. 
+ variables: Optional mapping of placeholder names to replacement values. Replaces placeholders + in all string values of the check definitions before saving. Returns: None @@ -1222,8 +1279,10 @@ def save_checks(self, checks: list[dict], config: BaseChecksStorageConfig) -> No Raises: InvalidConfigError: If the configuration type is unsupported. """ + merged_variables = self._merge_variables(variables) + resolved_checks = resolve_variables(checks=checks, variables=merged_variables) handler = self._checks_handler_factory.create(config) - handler.save(checks, config) + handler.save(resolved_checks, config) @telemetry_logger("engine", "save_summary_metrics") def save_summary_metrics( diff --git a/src/databricks/labs/dqx/utils.py b/src/databricks/labs/dqx/utils.py index a1a43a699..5e3a0ba92 100644 --- a/src/databricks/labs/dqx/utils.py +++ b/src/databricks/labs/dqx/utils.py @@ -6,10 +6,11 @@ from decimal import Decimal from enum import Enum from importlib.util import find_spec -from typing import Any +from typing import Any, TypeVar, overload from fnmatch import fnmatch from pathlib import Path + from pyspark.sql import Column from pyspark.sql.types import StructType @@ -29,9 +30,17 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") + + COLUMN_NORMALIZE_EXPRESSION = re.compile("[^a-zA-Z0-9]+") COLUMN_PATTERN = re.compile(r"Column<'(.*?)(?: AS (\w+))?'>$", re.DOTALL) INVALID_COLUMN_NAME_PATTERN = re.compile(r"[\s,;{}\(\)\n\t=]+") +_UNRESOLVED_PLACEHOLDER_PATTERN = re.compile(r"\{\{[^}]*\}\}") +_SCALAR_VARIABLE_TYPES = (str, int, float, bool, Decimal, datetime.date, datetime.datetime, datetime.time) + +VariableValue = str | int | float | bool | Decimal | datetime.date | datetime.datetime | datetime.time +"""Supported scalar types for variable substitution values.""" def get_column_name_or_alias( @@ -544,6 +553,132 @@ def missing_required_packages(packages: list[str]) -> bool: return not all(find_spec(spec) for spec in packages) +def _replace_template(text: str, variables: dict[str, str]) -> str: + """Replace **{{ key }}** placeholders in *text* with values from *variables*. + + Uses a single-pass regex substitution. + Tolerates whitespace inside braces (e.g. **{{ key }}**, **{{key}}**). + Logs a warning if any unresolved **{{ ... }}** placeholders remain after substitution. + + Args: + text: Input string potentially containing **{{ key }}** placeholders. + variables: Pre-stringified mapping of placeholder names to values. + + Returns: + String with all matching placeholders replaced. + """ + if not variables: + if _UNRESOLVED_PLACEHOLDER_PATTERN.search(text): + logger.warning(f"Unresolved placeholder found: '{text}'") + return text + + def _resolve(match_obj: re.Match[str]) -> str: + key = match_obj.group(0).strip("{} \t") + if key in variables: + return variables[key] + unresolved.append(key) + return match_obj.group(0) + + unresolved: list[str] = [] + output = _UNRESOLVED_PLACEHOLDER_PATTERN.sub(_resolve, text) + if unresolved: + logger.warning( + f"Unresolved placeholders found: {unresolved}. " + f"They may be resolved at runtime for certain checks (e.g. sql_query)." + ) + return output + + +@overload +def _substitute_variables(obj: str, variables: dict[str, str]) -> str: ... + + +@overload +def _substitute_variables(obj: list[T], variables: dict[str, str]) -> list[T]: ... + + +@overload +def _substitute_variables(obj: dict[str, T], variables: dict[str, str]) -> dict[str, T]: ... + + +@overload +def _substitute_variables(obj: T, variables: dict[str, str]) -> T: ... 
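+
+# NOTE: the overloads above exist only for static type checking; they declare that strings map to
+# strings and containers keep their container type. The runtime behaviour is defined by the
+# implementation that follows.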
+ + +def _substitute_variables(obj: Any, variables: dict[str, str]) -> Any: + """Recursively replace **{{ key }}** placeholders in all string values within *obj*. + + Traverses dicts, lists, and strings. Non-string/non-collection values are + returned unchanged. Dict keys are not substituted. + + Args: + obj: A string, dict, list, or other value to process. + variables: Pre-stringified mapping of placeholder names to values. + + Returns: + A new object with all string values having placeholders replaced. + """ + if isinstance(obj, str): + return _replace_template(obj, variables) + if isinstance(obj, dict): + return {k: _substitute_variables(v, variables) for k, v in obj.items()} + if isinstance(obj, list): + return [_substitute_variables(item, variables) for item in obj] + return obj + + +def _validate_variable_types(variables: dict[str, VariableValue]) -> None: + """Raise :class:`InvalidParameterError` if any variable value is not a supported scalar type.""" + for key, val in variables.items(): + if not isinstance(val, _SCALAR_VARIABLE_TYPES): + raise InvalidParameterError( + f"Variable '{key}' has unsupported type '{type(val).__name__}'. " + f"Only scalar types are supported: str, int, float, bool, Decimal, " + f"datetime.date, datetime.datetime, datetime.time." + ) + + +def resolve_variables(checks: list[dict], variables: dict[str, VariableValue] | None) -> list[dict]: + """Resolve variable substitution in check definitions. + + Replaces placeholders in all string values of *checks* with the corresponding values + from *variables*. + + Variable values must be scalar types (e.g. *str*, *int*, *float*, *bool*, *Decimal*, + *datetime.date*, *datetime.datetime*, *datetime.time*). Non-string scalars are + converted to strings via *str()* in the substituted string. Collection type + variables (e.g. *list*, *dict*, *set*, etc.) are rejected with + *databricks.labs.dqx.errors.InvalidParameterError* because their string representation + is rarely meaningful in SQL or column expressions. + + Logs a warning for any placeholders that remain unresolved after substitution + (e.g. misspelled variable names). + + Note: + Variable values substituted into *sql_expression* checks are not sanitized and are + passed directly to *F.expr()*. Callers must **ensure variable values come from trusted + sources** to prevent SQL injection. + + Args: + checks: List of check definition dictionaries (metadata format). + variables: Mapping of placeholder names to scalar replacement values. + If *None* or empty the checks are returned unchanged. + + Returns: + A new list of check dicts with placeholders resolved, or the original list + when no substitution is needed. + + Raises: + InvalidParameterError: If any variable value is not a supported scalar type. + """ + if not variables: + return checks + + _validate_variable_types(variables) + str_variables = {k: str(v) for k, v in variables.items()} + return _substitute_variables(checks, str_variables) + + def get_file_extension(file_path: str | os.PathLike) -> str: """ Extract file extension from a file path. 
diff --git a/tests/integration/test_save_and_load_checks_from_table.py b/tests/integration/test_save_and_load_checks_from_table.py index 745f920a3..a506902fc 100644 --- a/tests/integration/test_save_and_load_checks_from_table.py +++ b/tests/integration/test_save_and_load_checks_from_table.py @@ -12,6 +12,7 @@ TableChecksStorageConfig, InstallationChecksStorageConfig, BaseChecksStorageConfig, + ExtraParams, ) from databricks.labs.dqx.engine import DQEngine from databricks.labs.dqx.errors import InvalidConfigError, UnsafeSqlQueryError @@ -677,3 +678,58 @@ def test_save_idempotency_overwrite_mode(ws, make_schema, make_random, spark): checks = engine.load_checks(config=TableChecksStorageConfig(location=table_name)) assert checks == EXPECTED_CHECKS_FROM_TABLE_LOAD[1:], "Idempotency guard must prevent duplicate overwrite" + + +def test_save_and_load_checks_from_table_with_variables(ws, make_schema, make_random, spark): + """Save checks with {{ }} placeholders resolved via engine-level + per-call variables, then load and apply.""" + catalog_name = TEST_CATALOG + schema_name = make_schema(catalog_name=catalog_name).name + table_name = f"{catalog_name}.{schema_name}.{make_random(10).lower()}" + + checks_with_placeholders = [ + { + "criticality": "{{ crit }}", + "name": "{{ col1 }}_null_check", + "check": { + "function": "is_not_null", + "arguments": {"column": "{{ col1 }}"}, + }, + }, + { + "criticality": "warn", + "name": "{{ col2 }}_not_empty_check", + "check": { + "function": "is_not_null_and_not_empty", + "arguments": {"column": "{{ col2 }}"}, + }, + "filter": "{{ filter_col }} IS NOT NULL", + }, + ] + + # Engine-level defaults; per-call override: crit "warn" -> "error" + extra_params = ExtraParams(variables={"crit": "warn", "col1": "a", "col2": "b", "filter_col": "a"}) + engine = DQEngine(ws, spark, extra_params=extra_params) + + config = TableChecksStorageConfig(location=table_name) + engine.save_checks(checks_with_placeholders, config=config, variables={"crit": "error"}) + + # Load — checks are already resolved, no variables needed + loaded = engine.load_checks(config=config) + + expected = [ + { + "name": "a_null_check", + "criticality": "error", + "check": {"function": "is_not_null", "arguments": {"column": "a"}}, + }, + { + "name": "b_not_empty_check", + "criticality": "warn", + "check": {"function": "is_not_null_and_not_empty", "arguments": {"column": "b"}}, + "filter": "a IS NOT NULL", + }, + ] + assert loaded == expected, "Variable substitution did not resolve correctly after table roundtrip." 
+ + # Verify the resolved checks are valid and can be applied end-to-end + assert not engine.validate_checks(loaded).has_errors diff --git a/tests/unit/test_checks_validation.py b/tests/unit/test_checks_validation.py index 89723a214..7c61a8e74 100644 --- a/tests/unit/test_checks_validation.py +++ b/tests/unit/test_checks_validation.py @@ -1,5 +1,5 @@ from pyspark.sql.functions import col -from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.engine import DQEngine, DQEngineCore def dummy_func(column): @@ -486,3 +486,60 @@ def test_is_in_range_float_arguments(): ] status = DQEngine.validate_checks(checks) assert not status.has_errors + + +def test_validate_checks_with_variables(tmp_path): + checks_yaml = """ + - criticality: "{{ crit }}" + check: + function: is_not_null + arguments: + column: "{{ col }}" + """ + checks_file = tmp_path / "checks.yml" + checks_file.write_text(checks_yaml, encoding="utf-8") + checks = DQEngineCore.load_checks_from_local_file(str(checks_file), variables={"crit": "error", "col": "b"}) + + status = DQEngine.validate_checks(checks) + assert not status.has_errors + + +def test_validate_checks_with_variables_invalid_after_substitution(tmp_path): + checks_yaml = """ + - criticality: "{{ crit }}" + check: + function: is_not_null + arguments: + column: b + """ + checks_file = tmp_path / "checks.yml" + checks_file.write_text(checks_yaml, encoding="utf-8") + checks = DQEngineCore.load_checks_from_local_file(str(checks_file), variables={"crit": "not_a_valid_criticality"}) + + status = DQEngine.validate_checks(checks) + expected_error = ( + "Invalid 'criticality' value: 'not_a_valid_criticality'. Expected 'warn' or 'error'. " + "Check details: {'criticality': 'not_a_valid_criticality', " + "'check': {'function': 'is_not_null', 'arguments': {'column': 'b'}}}" + ) + assert status.errors[0] == expected_error + + +def test_validate_checks_without_variables_fails_on_placeholders(): + checks = [ + { + "criticality": "{{ crit }}", + "check": { + "function": "is_not_null", + "arguments": {"column": "b"}, + }, + }, + ] + + status = DQEngine.validate_checks(checks) + expected_error = ( + "Invalid 'criticality' value: '{{ crit }}'. Expected 'warn' or 'error'. 
" + "Check details: {'criticality': '{{ crit }}', " + "'check': {'function': 'is_not_null', 'arguments': {'column': 'b'}}}" + ) + assert status.errors[0] == expected_error diff --git a/tests/unit/test_load_checks.py b/tests/unit/test_load_checks.py index c0bdb2bd3..22a5e111f 100644 --- a/tests/unit/test_load_checks.py +++ b/tests/unit/test_load_checks.py @@ -1,15 +1,21 @@ +import logging from unittest.mock import create_autospec import pytest +from pyspark.sql import SparkSession + +from databricks.labs.dqx.checks_storage import ( + BaseChecksStorageHandlerFactory, + ChecksStorageHandler, + VolumeFileChecksStorageHandler, +) +from databricks.labs.dqx.config import FileChecksStorageConfig, VolumeFileChecksStorageConfig, ExtraParams +from databricks.labs.dqx.engine import DQEngine, DQEngineCore +from databricks.labs.dqx.errors import InvalidCheckError, CheckDownloadError, InvalidConfigError from databricks.sdk import WorkspaceClient from databricks.sdk.errors import NotFound from databricks.sdk.service.files import DownloadResponse -from databricks.labs.dqx.checks_storage import VolumeFileChecksStorageHandler -from databricks.labs.dqx.config import VolumeFileChecksStorageConfig -from databricks.labs.dqx.engine import DQEngineCore -from databricks.labs.dqx.errors import InvalidCheckError, CheckDownloadError, InvalidConfigError - def test_load_checks_from_local_file_json(make_local_check_file_as_json, expected_checks): file = make_local_check_file_as_json @@ -84,3 +90,421 @@ def test_file_download_contents_read_none(): with pytest.raises(NotFound, match="No contents at Unity Catalog volume path"): handler.load(VolumeFileChecksStorageConfig(location="/Volumes/catalog/schema/volume/test_path.yml")) + + +def test_load_checks_from_local_file_with_variables(tmp_path): + content = """- criticality: "{{ crit }}" + check: + function: is_not_null + arguments: + column: "{{ col }}" +""" + file_path = tmp_path / "checks.yml" + file_path.write_text(content, encoding="utf-8") + + checks = DQEngineCore.load_checks_from_local_file(str(file_path), variables={"crit": "error", "col": "id"}) + + assert checks == [ + {"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "id"}}}, + ] + + +def test_load_checks_from_local_file_variables_none(tmp_path): + content = """- criticality: error + check: + function: is_not_null + arguments: + column: id +""" + file_path = tmp_path / "checks.yml" + file_path.write_text(content, encoding="utf-8") + + checks = DQEngineCore.load_checks_from_local_file(str(file_path), variables=None) + + assert checks == [ + {"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "id"}}}, + ] + + +def test_load_checks_from_local_file_variables_empty(tmp_path): + content = """- criticality: error + check: + function: is_not_null + arguments: + column: id +""" + file_path = tmp_path / "checks.yml" + file_path.write_text(content, encoding="utf-8") + + checks = DQEngineCore.load_checks_from_local_file(str(file_path), variables={}) + + assert checks == [ + {"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "id"}}}, + ] + + +def test_load_checks_with_variables(): + ws = create_autospec(WorkspaceClient) + mock_spark = create_autospec(SparkSession) + + raw_checks = [ + {"criticality": "{{ crit }}", "check": {"function": "is_not_null", "arguments": {"column": "{{ col }}"}}} + ] + + mock_factory = create_autospec(BaseChecksStorageHandlerFactory) + mock_handler = create_autospec(ChecksStorageHandler) + 
mock_factory.create.return_value = mock_handler + mock_handler.load.return_value = raw_checks + + engine = DQEngine(ws, spark=mock_spark, checks_handler_factory=mock_factory) + config = FileChecksStorageConfig(location="checks.yml") + + checks = engine.load_checks(config, variables={"crit": "error", "col": "id"}) + + assert checks == [ + {"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "id"}}}, + ] + + +def test_load_checks_variables_none(): + ws = create_autospec(WorkspaceClient) + mock_spark = create_autospec(SparkSession) + + raw_checks = [{"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "id"}}}] + + mock_factory = create_autospec(BaseChecksStorageHandlerFactory) + mock_handler = create_autospec(ChecksStorageHandler) + mock_factory.create.return_value = mock_handler + mock_handler.load.return_value = raw_checks + + engine = DQEngine(ws, spark=mock_spark, checks_handler_factory=mock_factory) + config = FileChecksStorageConfig(location="checks.yml") + + checks = engine.load_checks(config, variables=None) + + assert checks == raw_checks + + +def test_load_checks_from_local_file_unresolved_placeholder(tmp_path, caplog): + content = """- criticality: error + check: + function: is_not_null + arguments: + column: "{{ col }}" +""" + file_path = tmp_path / "checks.yml" + file_path.write_text(content, encoding="utf-8") + + with caplog.at_level(logging.WARNING): + checks = DQEngineCore.load_checks_from_local_file(str(file_path), variables={"other": "value"}) + + assert checks[0]["check"]["arguments"]["column"] == "{{ col }}" + assert any("Unresolved placeholder" in msg for msg in caplog.messages) + + +def test_load_checks_with_engine_default_variables(): + ws = create_autospec(WorkspaceClient) + mock_spark = create_autospec(SparkSession) + + raw_checks = [ + {"criticality": "{{ crit }}", "check": {"function": "is_not_null", "arguments": {"column": "{{ col }}"}}} + ] + + mock_factory = create_autospec(BaseChecksStorageHandlerFactory) + mock_handler = create_autospec(ChecksStorageHandler) + mock_factory.create.return_value = mock_handler + mock_handler.load.return_value = raw_checks + + extra_params = ExtraParams(variables={"crit": "error", "col": "default_col"}) + engine = DQEngine(ws, spark=mock_spark, checks_handler_factory=mock_factory, extra_params=extra_params) + config = FileChecksStorageConfig(location="checks.yml") + + checks = engine.load_checks(config) + + assert checks == [ + {"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "default_col"}}}, + ] + + +def test_load_checks_per_call_overrides_engine_defaults(): + ws = create_autospec(WorkspaceClient) + mock_spark = create_autospec(SparkSession) + + raw_checks = [ + {"criticality": "{{ crit }}", "check": {"function": "is_not_null", "arguments": {"column": "{{ col }}"}}} + ] + + mock_factory = create_autospec(BaseChecksStorageHandlerFactory) + mock_handler = create_autospec(ChecksStorageHandler) + mock_factory.create.return_value = mock_handler + mock_handler.load.return_value = raw_checks + + extra_params = ExtraParams(variables={"crit": "warn", "col": "default_col"}) + engine = DQEngine(ws, spark=mock_spark, checks_handler_factory=mock_factory, extra_params=extra_params) + config = FileChecksStorageConfig(location="checks.yml") + + checks = engine.load_checks(config, variables={"crit": "error"}) + + assert checks == [ + {"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "default_col"}}}, + ] + + 
+def test_extra_params_variables_substitution_and_overrides(tmp_path): + ws = create_autospec(WorkspaceClient) + mock_spark = create_autospec(SparkSession) + + checks_yaml = """ + - criticality: error + name: "id_check" + check: + function: is_not_null + arguments: + column: "{{ target_col }}" + user_metadata: + env: "{{ environment }}" + rule_id: "{{ nested_var }}" + """ + checks_file = tmp_path / "checks_extra.yml" + checks_file.write_text(checks_yaml, encoding="utf-8") + + raw_checks = DQEngineCore.load_checks_from_local_file(str(checks_file)) + mock_factory = create_autospec(BaseChecksStorageHandlerFactory) + mock_handler = create_autospec(ChecksStorageHandler) + mock_factory.create.return_value = mock_handler + mock_handler.load.return_value = raw_checks + + extra_params = ExtraParams(variables={"target_col": "id", "environment": "dev", "nested_var": "old"}) + engine = DQEngine(ws, spark=mock_spark, checks_handler_factory=mock_factory, extra_params=extra_params) + config = FileChecksStorageConfig(location=str(checks_file)) + + checks = engine.load_checks(config, variables={"environment": "prod", "nested_var": "new"}) + + assert checks[0]["check"]["arguments"]["column"] == "id" + assert checks[0]["user_metadata"]["env"] == "prod" + assert checks[0]["user_metadata"]["rule_id"] == "new" + + +def test_load_checks_by_metadata_and_split_with_variables(tmp_path): + + checks_yaml = """ + - criticality: error + name: "{{ col }}_null_check" + check: + function: is_not_null_and_not_empty + arguments: + column: "{{ col }}" + - criticality: warn + check: + function: sql_expression + arguments: + expression: "{{ expr_col }} > {{ threshold }}" + """ + checks_file = tmp_path / "checks.yml" + checks_file.write_text(checks_yaml, encoding="utf-8") + checks = DQEngineCore.load_checks_from_local_file( + str(checks_file), variables={"col": "b", "expr_col": "a", "threshold": 1} + ) + + assert checks == [ + { + "criticality": "error", + "name": "b_null_check", + "check": { + "function": "is_not_null_and_not_empty", + "arguments": {"column": "b"}, + }, + }, + { + "criticality": "warn", + "check": { + "function": "sql_expression", + "arguments": {"expression": "a > 1"}, + }, + }, + ] + + +def test_load_checks_sql_query_no_variables(tmp_path, caplog): + checks_yaml = """ + - criticality: error + check: + function: sql_query + arguments: + query: "SELECT id, COUNT(*) > 0 AS condition FROM {{ input_view }} GROUP BY id" + merge_columns: + - id + """ + checks_file = tmp_path / "checks.yml" + checks_file.write_text(checks_yaml, encoding="utf-8") + + with caplog.at_level(logging.WARNING): + checks = DQEngineCore.load_checks_from_local_file(str(checks_file)) + + assert not any("input_view" in msg for msg in caplog.messages) + + assert checks == [ + { + "criticality": "error", + "check": { + "function": "sql_query", + "arguments": { + "query": "SELECT id, COUNT(*) > 0 AS condition FROM {{ input_view }} GROUP BY id", + "merge_columns": ["id"], + }, + }, + }, + ] + + +def test_load_checks_sql_query_with_variables(tmp_path, caplog): + checks_yaml = """ + - criticality: "{{ crit }}" + name: "count_check" + check: + function: sql_query + arguments: + query: "SELECT id, COUNT(*) > 0 AS condition FROM {{ input_view }} GROUP BY id" + merge_columns: + - id + """ + checks_file = tmp_path / "checks.yml" + checks_file.write_text(checks_yaml, encoding="utf-8") + + with caplog.at_level(logging.WARNING): + checks = DQEngineCore.load_checks_from_local_file(str(checks_file), variables={"crit": "error"}) + + assert checks == [ + { + 
"criticality": "error", + "name": "count_check", + "check": { + "function": "sql_query", + "arguments": { + "query": "SELECT id, COUNT(*) > 0 AS condition FROM {{ input_view }} GROUP BY id", + "merge_columns": ["id"], + }, + }, + }, + ] + # {{ input_view }} is left unresolved — it is resolved at runtime by sql_query itself + assert any("input_view" in msg for msg in caplog.messages) + + +def test_save_checks_with_variables(): + ws = create_autospec(WorkspaceClient) + mock_spark = create_autospec(SparkSession) + + raw_checks = [ + {"criticality": "{{ crit }}", "check": {"function": "is_not_null", "arguments": {"column": "{{ col }}"}}} + ] + + mock_factory = create_autospec(BaseChecksStorageHandlerFactory) + mock_handler = create_autospec(ChecksStorageHandler) + mock_factory.create.return_value = mock_handler + + engine = DQEngine(ws, spark=mock_spark, checks_handler_factory=mock_factory) + config = FileChecksStorageConfig(location="checks.yml") + + engine.save_checks(raw_checks, config, variables={"crit": "error", "col": "id"}) + + mock_handler.save.assert_called_once_with( + [{"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "id"}}}], + config, + ) + + +def test_save_checks_variables_none(): + ws = create_autospec(WorkspaceClient) + mock_spark = create_autospec(SparkSession) + + raw_checks = [{"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "id"}}}] + + mock_factory = create_autospec(BaseChecksStorageHandlerFactory) + mock_handler = create_autospec(ChecksStorageHandler) + mock_factory.create.return_value = mock_handler + + engine = DQEngine(ws, spark=mock_spark, checks_handler_factory=mock_factory) + config = FileChecksStorageConfig(location="checks.yml") + + engine.save_checks(raw_checks, config, variables=None) + + mock_handler.save.assert_called_once_with(raw_checks, config) + + +def test_save_checks_with_engine_default_variables(): + ws = create_autospec(WorkspaceClient) + mock_spark = create_autospec(SparkSession) + + raw_checks = [ + {"criticality": "{{ crit }}", "check": {"function": "is_not_null", "arguments": {"column": "{{ col }}"}}} + ] + + mock_factory = create_autospec(BaseChecksStorageHandlerFactory) + mock_handler = create_autospec(ChecksStorageHandler) + mock_factory.create.return_value = mock_handler + + extra_params = ExtraParams(variables={"crit": "error", "col": "default_col"}) + engine = DQEngine(ws, spark=mock_spark, checks_handler_factory=mock_factory, extra_params=extra_params) + config = FileChecksStorageConfig(location="checks.yml") + + engine.save_checks(raw_checks, config) + + mock_handler.save.assert_called_once_with( + [{"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "default_col"}}}], + config, + ) + + +def test_save_checks_per_call_overrides_engine_defaults(): + ws = create_autospec(WorkspaceClient) + mock_spark = create_autospec(SparkSession) + + raw_checks = [ + {"criticality": "{{ crit }}", "check": {"function": "is_not_null", "arguments": {"column": "{{ col }}"}}} + ] + + mock_factory = create_autospec(BaseChecksStorageHandlerFactory) + mock_handler = create_autospec(ChecksStorageHandler) + mock_factory.create.return_value = mock_handler + + extra_params = ExtraParams(variables={"crit": "warn", "col": "default_col"}) + engine = DQEngine(ws, spark=mock_spark, checks_handler_factory=mock_factory, extra_params=extra_params) + config = FileChecksStorageConfig(location="checks.yml") + + engine.save_checks(raw_checks, config, variables={"crit": 
"error"}) + + mock_handler.save.assert_called_once_with( + [{"criticality": "error", "check": {"function": "is_not_null", "arguments": {"column": "default_col"}}}], + config, + ) + + +def test_load_checks_by_metadata_with_variables_name_and_filter(tmp_path): + + checks_yaml = """ + - criticality: error + name: "{{ col }}_greater_than_{{ threshold }}" + check: + function: sql_expression + arguments: + expression: "{{ col }} > {{ threshold }}" + filter: "{{ filter_col }} IS NOT NULL" + """ + checks_file = tmp_path / "checks.yml" + checks_file.write_text(checks_yaml, encoding="utf-8") + checks = DQEngineCore.load_checks_from_local_file( + str(checks_file), variables={"col": "a", "threshold": 1, "filter_col": "a"} + ) + + assert checks == [ + { + "criticality": "error", + "name": "a_greater_than_1", + "check": { + "function": "sql_expression", + "arguments": {"expression": "a > 1"}, + }, + "filter": "a IS NOT NULL", + } + ] diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 9cec5a554..ccd3e366c 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,5 @@ -from datetime import date, datetime +import logging +from datetime import date, datetime, time from decimal import Decimal from enum import Enum from typing import Any @@ -20,6 +21,7 @@ safe_strip_file_from_path, missing_required_packages, get_file_extension, + resolve_variables, ) from databricks.labs.dqx.rule import normalize_bound_args from databricks.labs.dqx.errors import InvalidParameterError, InvalidConfigError @@ -529,3 +531,319 @@ def test_get_file_extension_with_path_object(): """Test get_file_extension function with Path object.""" file_path = Path("/path/to/file.json") assert get_file_extension(file_path) == ".json" + + +def test_resolve_variables_replaces_all_string_fields(): + checks = [ + { + "criticality": "error", + "name": "{{ col }}_not_null", + "check": { + "function": "is_not_null", + "arguments": {"column": "{{ col }}"}, + }, + "filter": "{{ filter_col }} = 'active'", + } + ] + variables = {"col": "email", "filter_col": "status"} + result = resolve_variables(checks, variables) + + assert result[0]["name"] == "email_not_null" + assert result[0]["check"]["arguments"]["column"] == "email" + assert result[0]["filter"] == "status = 'active'" + + +def test_resolve_variables_empty_variables(): + checks = [{"name": "{{ x }}"}] + result = resolve_variables(checks, {}) + assert result is checks # same object, no copy + assert result[0]["name"] == "{{ x }}" + + +def test_resolve_variables_non_string_values_converted(): + checks = [ + { + "check": { + "function": "sql_expression", + "arguments": {"expression": "{{ col }} > {{ threshold }}"}, + }, + } + ] + variables = {"col": "age", "threshold": 18} + result = resolve_variables(checks, variables) + assert result[0]["check"]["arguments"]["expression"] == "age > 18" + + +def test_resolve_variables_does_not_mutate_original(): + checks = [ + { + "name": "{{ col }}_check", + "check": { + "function": "is_not_null", + "arguments": {"column": "{{ col }}"}, + }, + } + ] + variables = {"col": "name"} + resolve_variables(checks, variables) + + # Original must be unchanged + assert checks[0]["name"] == "{{ col }}_check" + assert checks[0]["check"]["arguments"]["column"] == "{{ col }}" + + +def test_resolve_variables_nested_dicts(): + checks = [ + { + "check": { + "function": "sql_expression", + "arguments": { + "expression": "{{ col }} IS NOT NULL", + }, + }, + "user_metadata": {"owner": "{{ team }}"}, + } + ] + variables = {"col": "id", "team": 
"data-eng"} + result = resolve_variables(checks, variables) + + assert result[0]["check"]["arguments"]["expression"] == "id IS NOT NULL" + assert result[0]["user_metadata"]["owner"] == "data-eng" + + +def test_resolve_variables_partial_replacement(): + checks = [{"name": "{{ p1 }}_greater_than_{{ threshold }}"}] + variables = {"p1": "column1", "threshold": 10} + result = resolve_variables(checks, variables) + assert result[0]["name"] == "column1_greater_than_10" + + +def test_resolve_variables_unresolved_placeholder_warning(caplog): + checks = [{"name": "{{ resolved }}_{{ unresolved }}"}] + variables = {"resolved": "ok"} + with caplog.at_level(logging.WARNING, logger="databricks.labs.dqx.utils"): + result = resolve_variables(checks, variables) + + assert result[0]["name"] == "ok_{{ unresolved }}" + assert any("Unresolved placeholder" in msg for msg in caplog.messages) + + +def test_resolve_variables_whitespace_tolerance(): + checks = [ + {"a": "{{x}}", "b": "{{ x }}", "c": "{{ x }}"}, + ] + variables = {"x": "val"} + result = resolve_variables(checks, variables) + assert result[0]["a"] == "val" + assert result[0]["b"] == "val" + assert result[0]["c"] == "val" + + +def test_resolve_variables_non_string_dict_values_untouched(): + checks = [ + { + "criticality": "error", + "check": { + "function": "is_in_list", + "arguments": {"column": "{{ col }}", "allowed": [1, 2, 3]}, + }, + } + ] + variables = {"col": "status"} + result = resolve_variables(checks, variables) + assert result[0]["check"]["arguments"]["column"] == "status" + assert result[0]["check"]["arguments"]["allowed"] == [1, 2, 3] + assert result[0]["criticality"] == "error" + + +def test_resolve_variables_for_each_column(): + checks = [ + { + "criticality": "error", + "check": { + "function": "is_not_null", + "for_each_column": ["{{ col1 }}", "{{ col2 }}"], + }, + } + ] + variables = {"col1": "first_name", "col2": "last_name"} + result = resolve_variables(checks, variables) + assert result[0]["check"]["for_each_column"] == ["first_name", "last_name"] + + +def test_resolve_variables_multiple_checks(): + checks = [ + { + "name": "{{ col }}_not_null", + "check": {"function": "is_not_null", "arguments": {"column": "{{ col }}"}}, + }, + { + "name": "{{ col2 }}_not_empty", + "check": {"function": "is_not_empty", "arguments": {"column": "{{ col2 }}"}}, + }, + ] + variables = {"col": "a", "col2": "b"} + result = resolve_variables(checks, variables) + assert result[0]["name"] == "a_not_null" + assert result[0]["check"]["arguments"]["column"] == "a" + assert result[1]["name"] == "b_not_empty" + assert result[1]["check"]["arguments"]["column"] == "b" + + +def test_resolve_variables_empty_checks_list(): + result = resolve_variables([], {"col": "x"}) + assert result == [] + + +def test_resolve_variables_empty_string_value(): + checks = [{"name": "prefix_{{ col }}_suffix"}] + result = resolve_variables(checks, {"col": ""}) + assert result[0]["name"] == "prefix__suffix" + + +def test_resolve_variables_value_contains_braces(): + """Variable value itself contains {{ }} — should NOT be re-expanded.""" + checks = [{"expr": "{{ col }}"}] + result = resolve_variables(checks, {"col": "{{ other }}"}) + assert result[0]["expr"] == "{{ other }}" + + +def test_resolve_variables_key_with_regex_special_chars(): + """Variable keys with regex metacharacters must be escaped properly.""" + checks = [{"name": "{{ col.name }}_check", "filter": "{{ col+1 }} > 0"}] + variables = {"col.name": "revenue", "col+1": "amount"} + result = resolve_variables(checks, variables) 
+ assert result[0]["name"] == "revenue_check" + assert result[0]["filter"] == "amount > 0" + + +def test_resolve_variables_same_placeholder_repeated_in_string(): + checks = [{"expr": "{{ x }} + {{ x }}"}] + result = resolve_variables(checks, {"x": "col"}) + assert result[0]["expr"] == "col + col" + + +def test_resolve_variables_deeply_nested(): + checks = [{"a": {"b": {"c": {"d": "{{ v }}"}}}}] + result = resolve_variables(checks, {"v": "deep"}) + assert result[0]["a"]["b"]["c"]["d"] == "deep" + + +def test_resolve_variables_value_with_backslash(): + """Backslashes in values should be treated literally (no regex group refs).""" + checks = [{"path": "{{ p }}"}] + result = resolve_variables(checks, {"p": r"C:\Users\test"}) + assert result[0]["path"] == r"C:\Users\test" + + +def test_resolve_variables_rejects_list_value(): + checks = [{"check": {"arguments": {"column": "{{ col }}"}}}] + with pytest.raises(InvalidParameterError, match="unsupported type 'list'"): + resolve_variables(checks, {"col": ["a", "b"]}) + + +def test_resolve_variables_rejects_dict_value(): + checks = [{"check": {"arguments": {"column": "{{ col }}"}}}] + with pytest.raises(InvalidParameterError, match="unsupported type 'dict'"): + resolve_variables(checks, {"col": {"nested": "value"}}) + + +def test_resolve_variables_accepts_decimal_value(): + checks = [{"expr": "col > {{ threshold }}"}] + result = resolve_variables(checks, {"threshold": Decimal("3.14")}) + assert result[0]["expr"] == "col > 3.14" + + +def test_resolve_variables_accepts_bool_value(): + checks = [{"expr": "{{ flag }}"}] + result = resolve_variables(checks, {"flag": True}) + assert result[0]["expr"] == "True" + + +def test_resolve_variables_false_bool(): + checks = [{"expr": "{{ flag }}"}] + result = resolve_variables(checks, {"flag": False}) + assert result[0]["expr"] == "False" + + +def test_resolve_variables_rejects_none_value(): + checks = [{"col": "{{ col }}"}] + with pytest.raises(InvalidParameterError, match="unsupported type 'NoneType'"): + resolve_variables(checks, {"col": None}) + + +def test_resolve_variables_rejects_set_value(): + checks = [{"col": "{{ col }}"}] + with pytest.raises(InvalidParameterError, match="unsupported type 'set'"): + resolve_variables(checks, {"col": {1, 2}}) + + +def test_resolve_variables_rejects_tuple_value(): + checks = [{"col": "{{ col }}"}] + with pytest.raises(InvalidParameterError, match="unsupported type 'tuple'"): + resolve_variables(checks, {"col": (1, 2)}) + + +def test_resolve_variables_dict_keys_not_substituted(): + checks = [{"{{ col }}": "value", "other": "{{ col }}"}] + result = resolve_variables(checks, {"col": "replaced"}) + assert "{{ col }}" in result[0] + assert result[0]["{{ col }}"] == "value" + assert result[0]["other"] == "replaced" + + +def test_resolve_variables_nan(): + checks = [{"expr": "{{ val }}"}] + result = resolve_variables(checks, {"val": float("nan")}) + assert result[0]["expr"] == "nan" + + +def test_resolve_variables_inf(): + checks = [{"expr": "{{ val }}"}] + result = resolve_variables(checks, {"val": float("inf")}) + assert result[0]["expr"] == "inf" + + +def test_resolve_variables_multiple_unresolved_warns(caplog): + checks = [{"expr": "{{ a }} and {{ b }}"}] + with caplog.at_level(logging.WARNING): + result = resolve_variables(checks, {"a": "x"}) + assert result[0]["expr"] == "x and {{ b }}" + assert any("Unresolved placeholder" in msg for msg in caplog.messages) + + +def test_resolve_variables_none_vars_no_warning(caplog): + checks = [{"col": "{{ x }}"}] + with 
caplog.at_level(logging.WARNING): + result = resolve_variables(checks, None) + assert result[0]["col"] == "{{ x }}" + assert not any("Unresolved placeholder" in msg for msg in caplog.messages) + + with caplog.at_level(logging.WARNING): + result = resolve_variables(checks, {}) + assert result[0]["col"] == "{{ x }}" + assert not any("Unresolved placeholder" in msg for msg in caplog.messages) + + +def test_resolve_variables_unicode_values(): + checks = [{"col": "{{ col }}"}] + result = resolve_variables(checks, {"col": "prénom"}) + assert result[0]["col"] == "prénom" + + +def test_resolve_variables_accepts_date(): + checks = [{"expr": "date > '{{ d }}'"}] + result = resolve_variables(checks, {"d": date(2024, 1, 15)}) + assert result[0]["expr"] == "date > '2024-01-15'" + + +def test_resolve_variables_accepts_datetime(): + checks = [{"expr": "ts > '{{ ts }}'"}] + result = resolve_variables(checks, {"ts": datetime(2024, 1, 15, 10, 30)}) + assert "2024-01-15" in result[0]["expr"] + + +def test_resolve_variables_accepts_time(): + checks = [{"expr": "t > '{{ t }}'"}] + result = resolve_variables(checks, {"t": time(10, 30)}) + assert result[0]["expr"] == "t > '10:30:00'"
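+
+
+# Illustrative sketch: resolve_variables takes a single flat mapping, so engine
+# defaults and per-call overrides (exercised in the engine tests above) can be
+# combined with a plain dict merge before calling it. The variable names used
+# here ("env", "col") are hypothetical and not taken from the library.
+def test_resolve_variables_with_merged_defaults_and_overrides():
+    defaults = {"env": "dev", "col": "id"}
+    per_call = {"env": "prod"}
+    checks = [{"name": "{{ env }}_{{ col }}_not_null"}]
+    # per-call values take precedence over defaults
+    result = resolve_variables(checks, {**defaults, **per_call})
+    assert result[0]["name"] == "prod_id_not_null"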