From 31255514e8e3b1402d2e171451c49f2e9167c536 Mon Sep 17 00:00:00 2001 From: andreas-neumann_data Date: Thu, 21 May 2026 03:03:01 +0000 Subject: [PATCH 1/3] [CONNECT][SDP] Add Python APIs for Auto CDC SCD Type 1 --- python/pyspark/pipelines/__init__.py | 2 + python/pyspark/pipelines/api.py | 139 +++++++++++++++++- python/pyspark/pipelines/flow.py | 40 ++++- .../pipelines/graph_element_registry.py | 6 +- .../spark_connect_graph_element_registry.py | 47 +++++- .../tests/local_graph_element_registry.py | 10 +- .../tests/test_graph_element_registry.py | 80 +++++++++- 7 files changed, 316 insertions(+), 8 deletions(-) diff --git a/python/pyspark/pipelines/__init__.py b/python/pyspark/pipelines/__init__.py index d93320e963766..bd41c9ecd6b2e 100644 --- a/python/pyspark/pipelines/__init__.py +++ b/python/pyspark/pipelines/__init__.py @@ -16,6 +16,7 @@ # from pyspark.pipelines.api import ( append_flow, + create_auto_cdc_flow, create_streaming_table, materialized_view, table, @@ -25,6 +26,7 @@ __all__ = [ "append_flow", + "create_auto_cdc_flow", "create_streaming_table", "materialized_view", "table", diff --git a/python/pyspark/pipelines/api.py b/python/pyspark/pipelines/api.py index e6bae4f832d51..c1849824e236c 100644 --- a/python/pyspark/pipelines/api.py +++ b/python/pyspark/pipelines/api.py @@ -14,12 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Callable, Dict, List, Optional, Union, overload +from typing import Callable, Dict, List, Literal, Optional, Union, overload from pyspark.errors import PySparkTypeError from pyspark.pipelines.graph_element_registry import get_active_graph_element_registry from pyspark.pipelines.type_error_utils import validate_optional_list_of_str_arg -from pyspark.pipelines.flow import Flow, QueryFunction +from pyspark.pipelines.flow import AutoCdcFlow, Flow, QueryFunction from pyspark.pipelines.source_code_location import ( get_caller_source_code_location, ) @@ -29,6 +29,8 @@ TemporaryView, Sink, ) +from pyspark.sql import Column +from pyspark.sql import functions as F from pyspark.sql.types import StructType @@ -525,3 +527,136 @@ def create_sink( comment=None, ) get_active_graph_element_registry().register_output(sink) + + +def create_auto_cdc_flow( + target: str, + source: str, + keys: Union[List[str], List[Column]], + sequence_by: Union[str, Column], + apply_as_deletes: Optional[Union[str, Column]] = None, + apply_as_truncates: Optional[Union[str, Column]] = None, + column_list: Optional[Union[List[str], List[Column]]] = None, + except_column_list: Optional[Union[List[str], List[Column]]] = None, + stored_as_scd_type: Optional[Literal[1, "1"]] = None, + name: Optional[str] = None, + ignore_null_updates_column_list: Optional[Union[List[str], List[Column]]] = None, + ignore_null_updates_except_column_list: Optional[Union[List[str], List[Column]]] = None, +) -> None: + """ + Create an Auto CDC flow into the target table from the Change Data Capture (CDC) source. + Target table must have already been created using create_streaming_table function. Only one + of column_list and except_column_list can be specified. + + Example: + create_auto_cdc_flow( + target = "target", + source = "source", + keys = ["key"], + sequence_by = "sequence_expr", + ignore_null_updates_column_list = ["value"], + column_list = ["key", "value"], + ) + + Note that for keys, sequence_by, column_list, except_column_list, + ignore_null_updates_column_list, and ignore_null_updates_except_column_list the arguments + have to be column identifiers without qualifiers, e.g. they cannot be + col("sourceTable.keyId"). + + :param target: The name of the target table that receives the Auto CDC flow. + :param source: The name of the CDC source to stream from. + :param keys: The column or combination of columns that uniquely identify a row in the source \ + data. This is used to identify which CDC events apply to specific records in the target \ + table. These keys also identify records in the target table, e.g., if there exists a record \ + for given keys and the CDC source has an UPSERT operation for the same keys, we will update \ + the existing record. At least one key must be provided. This should be a list of column \ + identifiers without qualifiers, expressed as either Python strings or Pyspark Columns. + :param sequence_by: An expression that we use to order the source data. This can be expressed \ + as either a Python string or Pyspark Expression. + :param apply_as_deletes: Delete condition for the merged operation. This should be a string of \ + expression e.g. "operation = 'DELETE'" + :param apply_as_truncates: Truncate condition for the merged operation. This should be a string \ + expression e.g. "operation = 'TRUNCATE'" + :param column_list: Columns that will be included in the output table. This should be a list \ + of column identifiers without qualifiers, expressed as either Python strings or Pyspark \ + Column. Only one of column_list and except_column_list can be specified. + :param except_column_list: Columns that will be excluded in the output table. This should be a \ + list of column identifiers without qualifiers, expressed as either Python strings or Pyspark \ + Column. Only one of column_list and except_column_list can be specified. When this is \ + specified, all columns in the dataframe of the target table except those in this list will \ + be in the output table. + :param stored_as_scd_type: The SCD type for the target table. Only 1 (or "1") is supported. \ + When not specified the server default applies. + :param name: The name of the flow for this create_auto_cdc_flow command. When unspecified this \ + will build a "default flow" with name equal to the target name. + :param ignore_null_updates_column_list: Subset of columns to ignore null values in during \ + updates. When a source row has a null for one of these columns, the existing value in the \ + target is preserved. Only one of ignore_null_updates_column_list and \ + ignore_null_updates_except_column_list can be specified. + :param ignore_null_updates_except_column_list: Columns excluded from null-update ignoring. \ + All other columns will have null values ignored during updates. Only one of \ + ignore_null_updates_column_list and ignore_null_updates_except_column_list can be specified. + """ + keys = _normalize_column_list(keys) + + column_list = _normalize_optional_column_list(column_list) + except_column_list = _normalize_optional_column_list(except_column_list) + ignore_null_updates_column_list = _normalize_optional_column_list( + ignore_null_updates_column_list + ) + ignore_null_updates_except_column_list = _normalize_optional_column_list( + ignore_null_updates_except_column_list + ) + + if isinstance(sequence_by, str): + sequence_by = F.expr(sequence_by) + + if isinstance(apply_as_deletes, str): + apply_as_deletes = F.expr(apply_as_deletes) + + if isinstance(apply_as_truncates, str): + apply_as_truncates = F.expr(apply_as_truncates) + + if stored_as_scd_type is not None and str(stored_as_scd_type) != "1": + raise PySparkTypeError( + errorClass="NOT_EXPECTED_TYPE", + messageParameters={ + "arg_name": "stored_as_scd_type", + "expected_type": "Literal[1, '1']", + "arg_type": type(stored_as_scd_type).__name__, + }, + ) + + source_code_location = get_caller_source_code_location(stacklevel = 1) + + flow = AutoCdcFlow( + name = name, + target = target, + source = source, + keys = keys, + sequence_by = sequence_by, + apply_as_deletes = apply_as_deletes, + apply_as_truncates = apply_as_truncates, + column_list = column_list, + except_column_list = except_column_list, + stored_as_scd_type = stored_as_scd_type, + ignore_null_updates_column_list = ignore_null_updates_column_list, + ignore_null_updates_except_column_list = ignore_null_updates_except_column_list, + source_code_location = source_code_location, + ) + + get_active_graph_element_registry().register_auto_cdc_flow(flow) + + +def _normalize_optional_column_list( + column_list: Optional[Union[List[str], List[Column]]], +) -> Optional[List[Column]]: + if column_list is None: + return None + return _normalize_column_list(column_list) + + +def _normalize_column_list( + column_list: List[Union[str, Column]], +) -> List[Column]: + return [F.col(c) if isinstance(c, str) else c for c in column_list] diff --git a/python/pyspark/pipelines/flow.py b/python/pyspark/pipelines/flow.py index 7c499c0b36221..bcb91da4044c3 100644 --- a/python/pyspark/pipelines/flow.py +++ b/python/pyspark/pipelines/flow.py @@ -15,9 +15,10 @@ # limitations under the License. # from dataclasses import dataclass -from typing import Callable, Dict +from typing import Callable, Dict, List, Literal, Optional from pyspark.sql import DataFrame +from pyspark.sql import Column from pyspark.pipelines.source_code_location import SourceCodeLocation QueryFunction = Callable[[], DataFrame] @@ -41,3 +42,40 @@ class Flow: spark_conf: Dict[str, str] source_code_location: SourceCodeLocation func: QueryFunction + + +@dataclass(frozen=True) +class AutoCdcFlow: + """Definition of an Auto CDC flow in a pipeline dataflow graph. + + An Auto CDC flow applies Change Data Capture (CDC) events from a source to a target + streaming table. + + :param name: Optional name of the flow. When None, defaults to the target name. + :param target: The name of the target streaming table. + :param source: The name of the CDC source to stream from. + :param keys: Column(s) that uniquely identify a row in source and target data. + :param sequence_by: Expression used to order the source data. + :param apply_as_deletes: Optional delete condition for the merged operation. + :param apply_as_truncates: Optional truncate condition for the merged operation. + :param column_list: Optional columns to include in the output table. + :param except_column_list: Optional columns to exclude from the output table. + :param stored_as_scd_type: Optional SCD type for the target table. Only 1 is supported. + :param ignore_null_updates_column_list: Subset of columns to ignore null in updates. + :param ignore_null_updates_except_column_list: Columns excluded from null-ignore in updates. + :param source_code_location: The location of the source code that created this flow. + """ + + name: Optional[str] + target: str + source: str + keys: List[Column] + sequence_by: Column + apply_as_deletes: Optional[Column] + apply_as_truncates: Optional[Column] + column_list: Optional[List[Column]] + except_column_list: Optional[List[Column]] + stored_as_scd_type: Optional[Literal[1, "1"]] + ignore_null_updates_column_list: Optional[List[Column]] + ignore_null_updates_except_column_list: Optional[List[Column]] + source_code_location: SourceCodeLocation diff --git a/python/pyspark/pipelines/graph_element_registry.py b/python/pyspark/pipelines/graph_element_registry.py index 8e311fc2ca98e..4eddabaabda0e 100644 --- a/python/pyspark/pipelines/graph_element_registry.py +++ b/python/pyspark/pipelines/graph_element_registry.py @@ -19,7 +19,7 @@ from pathlib import Path from pyspark.pipelines.output import Output -from pyspark.pipelines.flow import Flow +from pyspark.pipelines.flow import AutoCdcFlow, Flow from contextlib import contextmanager from contextvars import ContextVar from typing import Generator, Optional @@ -42,6 +42,10 @@ def register_output(self, output: Output) -> None: def register_flow(self, flow: Flow) -> None: """Add the given flow to the registry.""" + @abstractmethod + def register_auto_cdc_flow(self, flow: AutoCdcFlow) -> None: + """Add the given Auto CDC flow to the registry.""" + @abstractmethod def register_sql(self, sql_text: str, file_path: Path) -> None: """Register a string containing SQL statements the dataflow graph. diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index ab88317908302..80c04e9504e32 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -27,12 +27,12 @@ StreamingTable, TemporaryView, ) -from pyspark.pipelines.flow import Flow +from pyspark.pipelines.flow import AutoCdcFlow, Flow from pyspark.pipelines.graph_element_registry import GraphElementRegistry from pyspark.pipelines.source_code_location import SourceCodeLocation from pyspark.sql.connect.types import pyspark_types_to_proto_types from pyspark.sql.types import StructType -from typing import Any, cast +from typing import Any, List, Optional, cast import pyspark.sql.connect.proto as pb2 from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context @@ -133,6 +133,49 @@ def register_flow(self, flow: Flow) -> None: command.pipeline_command.define_flow.CopyFrom(inner_command) self._client.execute_command(command) + def register_auto_cdc_flow(self, flow: AutoCdcFlow) -> None: + from pyspark.sql.connect.column import Column as ConnectColumn + + def to_plan(col: Column) -> Any: + return cast(ConnectColumn, col).to_plan(self._client) + + def to_plans(cols: Optional[List[Column]]) -> list: + return [] if cols is None else [to_plan(c) for c in cols] + + auto_cdc_details = pb2.PipelineCommand.DefineFlow.AutoCdcFlowDetails( + source=flow.source, + keys=to_plans(flow.keys), + sequence_by=to_plan(flow.sequence_by), + column_list=to_plans(flow.column_list), + except_column_list=to_plans(flow.except_column_list), + ignore_null_updates_column_list=to_plans(flow.ignore_null_updates_column_list), + ignore_null_updates_except_column_list=to_plans( + flow.ignore_null_updates_except_column_list + ), + ) + if flow.stored_as_scd_type is not None: + auto_cdc_details.stored_as_scd_type = ( + pb2.PipelineCommand.DefineFlow.SCDType.SCD_TYPE_1 + ) + if flow.apply_as_deletes is not None: + auto_cdc_details.apply_as_deletes.CopyFrom(to_plan(flow.apply_as_deletes)) + if flow.apply_as_truncates is not None: + auto_cdc_details.apply_as_truncates.CopyFrom(to_plan(flow.apply_as_truncates)) + + inner_command = pb2.PipelineCommand.DefineFlow( + dataflow_graph_id=self._dataflow_graph_id, + target_dataset_name=flow.target, + auto_cdc_flow_details=auto_cdc_details, + sql_conf={}, + source_code_location=source_code_location_to_proto(flow.source_code_location), + ) + if flow.name is not None: + inner_command.flow_name = flow.name + + command = pb2.Command() + command.pipeline_command.define_flow.CopyFrom(inner_command) + self._client.execute_command(command) + def register_sql(self, sql_text: str, file_path: Path) -> None: inner_command = pb2.PipelineCommand.DefineSqlGraphElements( dataflow_graph_id=self._dataflow_graph_id, diff --git a/python/pyspark/pipelines/tests/local_graph_element_registry.py b/python/pyspark/pipelines/tests/local_graph_element_registry.py index 0e22641930b9a..3b9ea15a1ed6b 100644 --- a/python/pyspark/pipelines/tests/local_graph_element_registry.py +++ b/python/pyspark/pipelines/tests/local_graph_element_registry.py @@ -20,7 +20,7 @@ from typing import List, Sequence from pyspark.pipelines.output import Output -from pyspark.pipelines.flow import Flow +from pyspark.pipelines.flow import AutoCdcFlow, Flow from pyspark.pipelines.graph_element_registry import GraphElementRegistry @@ -34,6 +34,7 @@ class LocalGraphElementRegistry(GraphElementRegistry): def __init__(self) -> None: self._outputs: List[Output] = [] self._flows: List[Flow] = [] + self._auto_cdc_flows: List[AutoCdcFlow] = [] self._sql_files: List[SqlFile] = [] def register_output(self, output: Output) -> None: @@ -42,6 +43,9 @@ def register_output(self, output: Output) -> None: def register_flow(self, flow: Flow) -> None: self._flows.append(flow) + def register_auto_cdc_flow(self, flow: AutoCdcFlow) -> None: + self._auto_cdc_flows.append(flow) + def register_sql(self, sql_text: str, file_path: Path) -> None: self._sql_files.append(SqlFile(sql_text, file_path)) @@ -53,6 +57,10 @@ def outputs(self) -> Sequence[Output]: def flows(self) -> Sequence[Flow]: return self._flows + @property + def auto_cdc_flows(self) -> Sequence[AutoCdcFlow]: + return self._auto_cdc_flows + @property def sql_files(self) -> Sequence[SqlFile]: return self._sql_files diff --git a/python/pyspark/pipelines/tests/test_graph_element_registry.py b/python/pyspark/pipelines/tests/test_graph_element_registry.py index 1e6fcf224a0ac..e78c627356d11 100644 --- a/python/pyspark/pipelines/tests/test_graph_element_registry.py +++ b/python/pyspark/pipelines/tests/test_graph_element_registry.py @@ -17,9 +17,10 @@ import unittest -from pyspark.errors import PySparkException +from pyspark.errors import PySparkException, PySparkTypeError from pyspark.pipelines.graph_element_registry import graph_element_registration_context from pyspark import pipelines as dp +from pyspark.pipelines.flow import AutoCdcFlow from pyspark.pipelines.output import Sink from pyspark.pipelines.tests.local_graph_element_registry import LocalGraphElementRegistry from typing import cast @@ -97,6 +98,83 @@ def flow2(): self.assertEqual(sink_obj.options["key1"], "value1") assert sink_obj.source_code_location.filename.endswith("test_graph_element_registry.py") + def test_create_auto_cdc_flow(self): + from pyspark.sql.connect.functions.builtin import col, expr + + registry = LocalGraphElementRegistry() + with graph_element_registration_context(registry): + dp.create_streaming_table("target") + dp.create_auto_cdc_flow( + target="target", + source="source", + keys=[col("key")], + sequence_by=expr("seq"), + ) + + self.assertEqual(len(registry.outputs), 1) + self.assertEqual(len(registry.auto_cdc_flows), 1) + + flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0]) + self.assertEqual(flow.target, "target") + self.assertEqual(flow.source, "source") + self.assertIsNone(flow.name) + assert flow.source_code_location.filename.endswith("test_graph_element_registry.py") + + def test_create_auto_cdc_flow_with_all_args(self): + from pyspark.sql.connect.functions.builtin import col, expr + + registry = LocalGraphElementRegistry() + with graph_element_registration_context(registry): + dp.create_streaming_table("tgt") + dp.create_auto_cdc_flow( + target="tgt", + source="src", + keys=[col("id")], + sequence_by=expr("ts"), + apply_as_deletes=expr("op = 'DELETE'"), + apply_as_truncates=expr("op = 'TRUNCATE'"), + column_list=[col("id"), col("val")], + ignore_null_updates_column_list=[col("val")], + stored_as_scd_type=1, + name="my_flow", + ) + + flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0]) + self.assertEqual(flow.name, "my_flow") + self.assertIsNotNone(flow.ignore_null_updates_column_list) + self.assertEqual(flow.stored_as_scd_type, 1) + + def test_create_auto_cdc_flow_stored_as_scd_type_string(self): + from pyspark.sql.connect.functions.builtin import col, expr + + registry = LocalGraphElementRegistry() + with graph_element_registration_context(registry): + dp.create_auto_cdc_flow( + target="t", + source="s", + keys=[col("k")], + sequence_by=expr("seq"), + stored_as_scd_type="1", + ) + + flow = cast(AutoCdcFlow, registry.auto_cdc_flows[0]) + self.assertEqual(flow.stored_as_scd_type, "1") + + def test_create_auto_cdc_flow_invalid_scd_type(self): + from pyspark.sql.connect.functions.builtin import col, expr + + registry = LocalGraphElementRegistry() + with graph_element_registration_context(registry): + with self.assertRaises(PySparkTypeError) as ctx: + dp.create_auto_cdc_flow( + target="t", + source="s", + keys=[col("k")], + sequence_by=expr("seq"), + stored_as_scd_type=2, # type: ignore[arg-type] + ) + self.assertEqual(ctx.exception.getCondition(), "NOT_EXPECTED_TYPE") + def test_definition_without_graph_element_registry(self): for decorator in [dp.table, dp.temporary_view, dp.materialized_view]: with self.assertRaises(PySparkException) as context: From c531ae128672fdccce33e53e1bcb6da2f2e288e3 Mon Sep 17 00:00:00 2001 From: andreas-neumann_data Date: Thu, 21 May 2026 20:46:45 +0000 Subject: [PATCH 2/3] [CONNECT][SDP] Minor style cleanups in Auto CDC Python API - Remove spaces around = in keyword arguments (PEP 8) - Fix type hint: List[Union[str, Column]] -> Union[List[str], List[Column]] - Reorder imports and collapse unnecessary line continuations Co-authored-by: Isaac --- python/pyspark/pipelines/api.py | 30 +++++++++---------- .../spark_connect_graph_element_registry.py | 17 +++++------ 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/python/pyspark/pipelines/api.py b/python/pyspark/pipelines/api.py index c1849824e236c..9fc613bc72ec4 100644 --- a/python/pyspark/pipelines/api.py +++ b/python/pyspark/pipelines/api.py @@ -627,22 +627,22 @@ def create_auto_cdc_flow( }, ) - source_code_location = get_caller_source_code_location(stacklevel = 1) + source_code_location = get_caller_source_code_location(stacklevel=1) flow = AutoCdcFlow( - name = name, - target = target, - source = source, - keys = keys, - sequence_by = sequence_by, - apply_as_deletes = apply_as_deletes, - apply_as_truncates = apply_as_truncates, - column_list = column_list, - except_column_list = except_column_list, - stored_as_scd_type = stored_as_scd_type, - ignore_null_updates_column_list = ignore_null_updates_column_list, - ignore_null_updates_except_column_list = ignore_null_updates_except_column_list, - source_code_location = source_code_location, + name=name, + target=target, + source=source, + keys=keys, + sequence_by=sequence_by, + apply_as_deletes=apply_as_deletes, + apply_as_truncates=apply_as_truncates, + column_list=column_list, + except_column_list=except_column_list, + stored_as_scd_type=stored_as_scd_type, + ignore_null_updates_column_list=ignore_null_updates_column_list, + ignore_null_updates_except_column_list=ignore_null_updates_except_column_list, + source_code_location=source_code_location, ) get_active_graph_element_registry().register_auto_cdc_flow(flow) @@ -657,6 +657,6 @@ def _normalize_optional_column_list( def _normalize_column_list( - column_list: List[Union[str, Column]], + column_list: Union[List[str], List[Column]], ) -> List[Column]: return [F.col(c) if isinstance(c, str) else c for c in column_list] diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index 80c04e9504e32..3be13fdfc6755 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -17,8 +17,13 @@ from pathlib import Path from pyspark.errors import PySparkTypeError -from pyspark.sql import SparkSession +from pyspark.sql import SparkSession, Column from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame +from pyspark.sql.connect.types import pyspark_types_to_proto_types +from pyspark.sql.types import StructType +from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context +from pyspark.pipelines.flow import AutoCdcFlow, Flow +from pyspark.pipelines.graph_element_registry import GraphElementRegistry from pyspark.pipelines.output import ( Output, MaterializedView, @@ -27,14 +32,10 @@ StreamingTable, TemporaryView, ) -from pyspark.pipelines.flow import AutoCdcFlow, Flow -from pyspark.pipelines.graph_element_registry import GraphElementRegistry from pyspark.pipelines.source_code_location import SourceCodeLocation -from pyspark.sql.connect.types import pyspark_types_to_proto_types -from pyspark.sql.types import StructType from typing import Any, List, Optional, cast + import pyspark.sql.connect.proto as pb2 -from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context class SparkConnectGraphElementRegistry(GraphElementRegistry): @@ -154,9 +155,7 @@ def to_plans(cols: Optional[List[Column]]) -> list: ), ) if flow.stored_as_scd_type is not None: - auto_cdc_details.stored_as_scd_type = ( - pb2.PipelineCommand.DefineFlow.SCDType.SCD_TYPE_1 - ) + auto_cdc_details.stored_as_scd_type = pb2.PipelineCommand.DefineFlow.SCDType.SCD_TYPE_1 if flow.apply_as_deletes is not None: auto_cdc_details.apply_as_deletes.CopyFrom(to_plan(flow.apply_as_deletes)) if flow.apply_as_truncates is not None: From ab9fd5ae67834f7966884a4003a4ebe4c562246d Mon Sep 17 00:00:00 2001 From: andreas-neumann_data Date: Fri, 22 May 2026 14:02:27 +0000 Subject: [PATCH 3/3] [CONNECT][SDP] Fix and improve Auto CDC flow tests - Move inline imports to module level - Fix assertNone -> assertIsNone - Fix assertEqual(stored_as_scd_type, "1") -> assertIsNone for default case - Add missing assertions for optional fields in test_create_auto_cdc_flow Co-authored-by: Isaac --- .../pipelines/tests/test_graph_element_registry.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/python/pyspark/pipelines/tests/test_graph_element_registry.py b/python/pyspark/pipelines/tests/test_graph_element_registry.py index e78c627356d11..42749c98bf51e 100644 --- a/python/pyspark/pipelines/tests/test_graph_element_registry.py +++ b/python/pyspark/pipelines/tests/test_graph_element_registry.py @@ -23,6 +23,7 @@ from pyspark.pipelines.flow import AutoCdcFlow from pyspark.pipelines.output import Sink from pyspark.pipelines.tests.local_graph_element_registry import LocalGraphElementRegistry +from pyspark.sql.connect.functions.builtin import col, expr from typing import cast @@ -99,8 +100,6 @@ def flow2(): assert sink_obj.source_code_location.filename.endswith("test_graph_element_registry.py") def test_create_auto_cdc_flow(self): - from pyspark.sql.connect.functions.builtin import col, expr - registry = LocalGraphElementRegistry() with graph_element_registration_context(registry): dp.create_streaming_table("target") @@ -118,11 +117,14 @@ def test_create_auto_cdc_flow(self): self.assertEqual(flow.target, "target") self.assertEqual(flow.source, "source") self.assertIsNone(flow.name) + self.assertIsNone(flow.ignore_null_updates_column_list) + self.assertIsNone(flow.ignore_null_updates_except_column_list) + self.assertIsNone(flow.stored_as_scd_type) + self.assertIsNone(flow.apply_as_deletes) + self.assertIsNone(flow.apply_as_truncates) assert flow.source_code_location.filename.endswith("test_graph_element_registry.py") def test_create_auto_cdc_flow_with_all_args(self): - from pyspark.sql.connect.functions.builtin import col, expr - registry = LocalGraphElementRegistry() with graph_element_registration_context(registry): dp.create_streaming_table("tgt") @@ -145,8 +147,6 @@ def test_create_auto_cdc_flow_with_all_args(self): self.assertEqual(flow.stored_as_scd_type, 1) def test_create_auto_cdc_flow_stored_as_scd_type_string(self): - from pyspark.sql.connect.functions.builtin import col, expr - registry = LocalGraphElementRegistry() with graph_element_registration_context(registry): dp.create_auto_cdc_flow( @@ -161,8 +161,6 @@ def test_create_auto_cdc_flow_stored_as_scd_type_string(self): self.assertEqual(flow.stored_as_scd_type, "1") def test_create_auto_cdc_flow_invalid_scd_type(self): - from pyspark.sql.connect.functions.builtin import col, expr - registry = LocalGraphElementRegistry() with graph_element_registration_context(registry): with self.assertRaises(PySparkTypeError) as ctx: