diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index d04ee6a926..d17335ad16 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -1999,6 +1999,23 @@ def remove_comments(sql_query: str, uuids: List[str]) -> str: ) +def _validate_iceberg_named_version_ref( + value: Optional[str], param_name: str +) -> Optional[str]: + """Validate a non-empty Iceberg tag/branch/version-ref name.""" + if value is None: + return None + if not isinstance(value, str): + raise ValueError( + f"'{param_name}' must be a string Iceberg name, " + f"got {type(value).__name__}." + ) + stripped = value.strip() + if not stripped: + raise ValueError(f"'{param_name}' must be a non-empty Iceberg name.") + return stripped + + class TimeTravelConfig(NamedTuple): """Configuration for time travel operations.""" @@ -2010,6 +2027,8 @@ class TimeTravelConfig(NamedTuple): stream: Optional[str] = None version: Optional[int] = None version_tag: Optional[str] = None + version_ref: Optional[str] = None + branch: Optional[str] = None @staticmethod def validate_and_normalize_params( @@ -2021,6 +2040,8 @@ def validate_and_normalize_params( stream: Optional[str] = None, version: Optional[int] = None, version_tag: Optional[str] = None, + version_ref: Optional[str] = None, + branch: Optional[str] = None, ) -> Optional["TimeTravelConfig"]: """ Validates and normalizes time travel parameters. @@ -2042,9 +2063,31 @@ def validate_and_normalize_params( Raises: ValueError: If parameters are invalid. """ + version_tag = _validate_iceberg_named_version_ref(version_tag, "version_tag") + version_ref = _validate_iceberg_named_version_ref(version_ref, "version_ref") + branch = _validate_iceberg_named_version_ref(branch, "branch") + + named_ref_count = sum( + arg is not None for arg in (version_tag, version_ref, branch) + ) + if named_ref_count > 1: + raise ValueError( + "Exactly one of 'version_tag', 'version_ref', or 'branch' may be " + "provided for Iceberg named-ref time travel." + ) + time_travel_arg_count = sum( arg is not None - for arg in (statement, offset, timestamp, stream, version, version_tag) + for arg in ( + statement, + offset, + timestamp, + stream, + version, + version_tag, + version_ref, + branch, + ) ) # Validate mode @@ -2079,32 +2122,25 @@ def validate_and_normalize_params( f"'version' must be an int Iceberg snapshot id, got {type(version).__name__}." ) - # version_tag (Iceberg tag name, mapped to Snowflake's - # ``AT(VERSION_TAG => '')`` grammar) only works with 'at' mode — - # Iceberg tag reads are positional (bound to a specific snapshot), - # not range-of-time, so ``BEFORE`` has no meaning. - if version_tag is not None and time_travel_mode.lower() != "at": - raise ValueError( - "Iceberg version_tag time travel can only be used with " - "time_travel_mode='at', not 'before'." - ) - - # Validate version_tag type — Iceberg tag names are strings. Empty - # strings are invalid. - if version_tag is not None: - if not isinstance(version_tag, str): + # Named Iceberg refs (tag / branch / version_ref) map to Snowflake's + # ``AT(VERSION_REF => '')`` grammar and only work with 'at' mode. + for param_name, param_value in ( + ("version_tag", version_tag), + ("version_ref", version_ref), + ("branch", branch), + ): + if param_value is not None and time_travel_mode.lower() != "at": raise ValueError( - f"'version_tag' must be a string Iceberg tag name, " - f"got {type(version_tag).__name__}." + f"Iceberg {param_name} time travel can only be used with " + "time_travel_mode='at', not 'before'." ) - if not version_tag: - raise ValueError("'version_tag' must be a non-empty Iceberg tag name.") # Validate exactly one parameter is provided if time_travel_arg_count != 1: raise ValueError( "Exactly one of 'statement', 'offset', 'timestamp', 'stream', " - "'version', or 'version_tag' must be provided." + "'version', 'version_tag', 'version_ref', or 'branch' must be " + "provided." ) # Normalize timestamp @@ -2140,6 +2176,8 @@ def validate_and_normalize_params( stream=stream, version=version, version_tag=version_tag, + version_ref=version_ref, + branch=branch, ) def generate_sql_clause(self) -> str: @@ -2150,11 +2188,12 @@ def generate_sql_clause(self) -> str: Returns: SQL clause like " AT (TIMESTAMP => TO_TIMESTAMP_NTZ('...'))", " AT (VERSION => 1234567890)" for Iceberg snapshot id time travel, - or " AT (VERSION_TAG => 'release_v1')" for Iceberg tag time - travel. + or " AT (VERSION_REF => 'audit-branch')" for Iceberg tag/branch + time travel. Note on escaping: string-valued parameters (``statement``, - ``stream``, ``version_tag``, ``timestamp``) are embedded inside + ``stream``, ``version_tag``, ``version_ref``, ``branch``, + ``timestamp``) are embedded inside single-quoted SQL literals via the existing ``str_to_sql`` helper in ``analyzer.datatype_mapper`` so embedded ``'``, ``\\`` and newline characters are properly escaped. This keeps the @@ -2181,8 +2220,12 @@ def generate_sql_clause(self) -> str: clause += f"(STREAM => {str_to_sql(self.stream)})" elif self.version is not None: clause += f"(VERSION => {self.version})" + elif self.version_ref is not None: + clause += f"(VERSION_REF => {str_to_sql(self.version_ref)})" + elif self.branch is not None: + clause += f"(VERSION_REF => {str_to_sql(self.branch)})" elif self.version_tag is not None: - clause += f"(VERSION_TAG => {str_to_sql(self.version_tag)})" + clause += f"(VERSION_REF => {str_to_sql(self.version_tag)})" elif self.timestamp is not None: if self.timestamp_type is not None: timestamp_type = self.timestamp_type.upper() diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 16f4b641de..84e998127f 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -164,12 +164,17 @@ def _extract_time_travel_from_options(options: dict) -> dict: Special handling for 'VERSION_TAG' / 'VERSION-TAG' (Iceberg tag name) — both aliases map to the internal ``version_tag`` time travel parameter - and emit ``AT(VERSION_TAG => '')`` on the Snowflake side (see - Spark Iceberg's ``VERSION AS OF ''`` reader path): - - Automatically set time_travel_mode to 'at' - (tag reads are positional — bound to a specific snapshot — not - range-of-time) - - Cannot be used with time_travel_mode='before' (raises error) + and emit ``AT(VERSION_REF => '')`` on the Snowflake side. + + Special handling for 'VERSION_REF' / 'VERSION-REF' (Iceberg tag or + branch name) — maps to the internal ``version_ref`` parameter. + + Special handling for 'BRANCH' (Spark Iceberg WAP branch read) — maps + to the internal ``branch`` parameter and emits + ``AT(VERSION_REF => '')``. + + Special handling for 'TAG' (Spark Iceberg ``SparkReadOptions.TAG``) — + maps to the internal ``version_tag`` parameter. """ result = {} excluded_keys = set() @@ -220,27 +225,64 @@ def _extract_time_travel_from_options(options: dict) -> dict: ) result["time_travel_mode"] = "at" - # Handle Iceberg tag (``version_tag`` / ``version-tag``). Both aliases - # route to the internal ``version_tag`` parameter and emit - # ``AT(VERSION_TAG => '')`` server-side. Auto-sets mode='at'. - version_tag_value = options.get("VERSION_TAG") - version_tag_source = "version_tag" - if version_tag_value is None: - version_tag_value = options.get("VERSION-TAG") - version_tag_source = "version-tag" - if version_tag_value is not None: + def _set_named_ref_option( + result: dict, + *, + param_name: str, + option_source: str, + raw_value: Any, + ) -> None: if ( "TIME_TRAVEL_MODE" in options and options["TIME_TRAVEL_MODE"].lower() == "before" ): raise ValueError( - f"Cannot use '{version_tag_source}' option with " - "time_travel_mode='before'. Iceberg tag time travel only " - "supports time_travel_mode='at'." + f"Cannot use '{option_source}' option with " + "time_travel_mode='before'. Iceberg named-ref time travel " + "only supports time_travel_mode='at'." ) - result["version_tag"] = str(version_tag_value) + result[param_name] = str(raw_value) result["time_travel_mode"] = "at" + # Handle Iceberg tag (``version_tag`` / ``version-tag`` / Spark ``tag``). + version_tag_value = options.get("VERSION_TAG") + version_tag_source = "version_tag" + if version_tag_value is None: + version_tag_value = options.get("VERSION-TAG") + version_tag_source = "version-tag" + if version_tag_value is None: + version_tag_value = options.get("TAG") + version_tag_source = "tag" + if version_tag_value is not None: + _set_named_ref_option( + result, + param_name="version_tag", + option_source=version_tag_source, + raw_value=version_tag_value, + ) + + version_ref_value = options.get("VERSION_REF") + version_ref_source = "version_ref" + if version_ref_value is None: + version_ref_value = options.get("VERSION-REF") + version_ref_source = "version-ref" + if version_ref_value is not None: + _set_named_ref_option( + result, + param_name="version_ref", + option_source=version_ref_source, + raw_value=version_ref_value, + ) + + branch_value = options.get("BRANCH") + if branch_value is not None: + _set_named_ref_option( + result, + param_name="branch", + option_source="branch", + raw_value=branch_value, + ) + for option_key, param_name in _TIME_TRAVEL_OPTIONS_PARAMS_MAP.items(): if option_key in options and option_key not in excluded_keys: result[param_name] = options[option_key] @@ -671,6 +713,8 @@ def table( # still pass them without us advertising the surface. version = kwargs.pop("version", None) version_tag = kwargs.pop("version_tag", None) + version_ref = kwargs.pop("version_ref", None) + branch = kwargs.pop("branch", None) if kwargs: raise TypeError( f"table() got unexpected keyword arguments: {sorted(kwargs)}" @@ -700,15 +744,23 @@ def table( time_travel_mode is not None or version is not None or version_tag is not None + or version_ref is not None + or branch is not None ): - # If version / version_tag is provided without mode, default to - # 'at' — snapshot ids and tag reads only make sense with AT - # (symmetric with the as-of-timestamp option handling). + # If version / named-ref params are provided without mode, + # default to 'at'. effective_mode = ( time_travel_mode if time_travel_mode else ( - "at" if (version is not None or version_tag is not None) else None + "at" + if ( + version is not None + or version_tag is not None + or version_ref is not None + or branch is not None + ) + else None ) ) time_travel_params = { @@ -720,6 +772,8 @@ def table( "stream": stream, "version": version, "version_tag": version_tag, + "version_ref": version_ref, + "branch": branch, } else: # if time_travel_mode is not provided, extract time travel config from options diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index cd5ce3ecfb..62a6bae1c6 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2780,6 +2780,8 @@ def table( # still pass them without us advertising the surface. version = kwargs.pop("version", None) version_tag = kwargs.pop("version_tag", None) + version_ref = kwargs.pop("version_ref", None) + branch = kwargs.pop("branch", None) if kwargs: raise TypeError( f"table() got unexpected keyword arguments: {sorted(kwargs)}" @@ -2823,6 +2825,8 @@ def table( stream=stream, version=version, version_tag=version_tag, + version_ref=version_ref, + branch=branch, ) # Replace API call origin for table set_api_call_source(t, "Session.table") diff --git a/src/snowflake/snowpark/table.py b/src/snowflake/snowpark/table.py index fc863a6434..8d5deefa86 100644 --- a/src/snowflake/snowpark/table.py +++ b/src/snowflake/snowpark/table.py @@ -303,6 +303,8 @@ def __init__( # still pass them without us advertising the surface. version = kwargs.pop("version", None) version_tag = kwargs.pop("version_tag", None) + version_ref = kwargs.pop("version_ref", None) + branch = kwargs.pop("branch", None) if kwargs: raise TypeError( f"Table() got unexpected keyword arguments: {sorted(kwargs)}" @@ -336,6 +338,8 @@ def __init__( stream=stream, version=version, version_tag=version_tag, + version_ref=version_ref, + branch=branch, ) snowflake_table_plan = SnowflakeTable( diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index c0ca0be038..faac805bdd 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -8601,3 +8601,21 @@ def test_iceberg_version_tag_time_travel_dataframe_reader_option(session): session.read.option("version-tag", tag_name).table(table_fqn).collect() ) assert via_kwarg == via_option == via_hyphen_option + + +@pytest.mark.skip( + reason=( + "Requires a Snowflake-managed Iceberg table with a WAP branch and " + "FEATURE_ICEBERG_TIME_TRAVEL enabled on the account. Tested manually." + ) +) +def test_iceberg_branch_time_travel_dataframe_reader_option(session): + """End-to-end: Spark Iceberg ``branch`` option maps to VERSION_REF.""" + table_fqn = "ICEBERG_GAP_TEST_HORIZON.TESTSCHEMA.snapshot_demo" + branch_name = "audit_branch" + + via_branch = session.read.option("branch", branch_name).table(table_fqn).collect() + via_version_ref = ( + session.read.option("version_ref", branch_name).table(table_fqn).collect() + ) + assert via_branch == via_version_ref diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 141b5b9b59..7950f3f593 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -993,7 +993,7 @@ def test_time_travel_version_tag(): """Test Iceberg tag time travel via ``version_tag`` parameter. Covers SQL generation, validation, and the ``mode='at'``-only restriction. - Verifies the SQL matches the Snowflake ``AT(VERSION_TAG => '')`` + Verifies the SQL matches the Snowflake ``AT(VERSION_REF => '')`` grammar — the released tag-only form of Iceberg time travel (Spark Iceberg's ``VERSION AS OF ''`` for tag reads). """ @@ -1003,7 +1003,7 @@ def test_time_travel_version_tag(): ) assert config.version_tag == "release_v1" assert config.time_travel_mode == "at" - assert config.generate_sql_clause() == " AT (VERSION_TAG => 'release_v1')" + assert config.generate_sql_clause() == " AT (VERSION_REF => 'release_v1')" # Hyphens and dots in tag names round-trip through the SQL clause — # Iceberg allows these in tag names (e.g. ``snapshot-2023-01-01``, @@ -1011,19 +1011,19 @@ def test_time_travel_version_tag(): config_hyphen = TimeTravelConfig.validate_and_normalize_params( time_travel_mode="at", version_tag="audit-tag" ) - assert config_hyphen.generate_sql_clause() == " AT (VERSION_TAG => 'audit-tag')" + assert config_hyphen.generate_sql_clause() == " AT (VERSION_REF => 'audit-tag')" config_dotted = TimeTravelConfig.validate_and_normalize_params( time_travel_mode="at", version_tag="snapshot-2023-01-01" ) assert ( config_dotted.generate_sql_clause() - == " AT (VERSION_TAG => 'snapshot-2023-01-01')" + == " AT (VERSION_REF => 'snapshot-2023-01-01')" ) # Direct construction also generates the right SQL. direct = TimeTravelConfig(time_travel_mode="AT", version_tag="EOM_JULY_2025") - assert direct.generate_sql_clause() == " AT (VERSION_TAG => 'EOM_JULY_2025')" + assert direct.generate_sql_clause() == " AT (VERSION_REF => 'EOM_JULY_2025')" # version_tag + 'before' is invalid (tag reads are positional — bound # to a specific snapshot — not range-of-time). @@ -1056,9 +1056,7 @@ def test_time_travel_version_tag(): ) # Non-string version_tag is rejected. - with pytest.raises( - ValueError, match="'version_tag' must be a string Iceberg tag name" - ): + with pytest.raises(ValueError, match="'version_tag' must be a string Iceberg name"): TimeTravelConfig.validate_and_normalize_params( time_travel_mode="at", version_tag=123 ) @@ -1091,7 +1089,7 @@ def test_time_travel_string_literal_escaping(): config = TimeTravelConfig.validate_and_normalize_params( time_travel_mode="at", version_tag="release_'s" ) - assert config.generate_sql_clause() == " AT (VERSION_TAG => 'release_''s')" + assert config.generate_sql_clause() == " AT (VERSION_REF => 'release_''s')" # An attempted injection payload — the closing quote and the # injected DROP must be fully neutralized by the doubled quotes. @@ -1100,7 +1098,7 @@ def test_time_travel_string_literal_escaping(): ) assert ( config_injection.generate_sql_clause() - == " AT (VERSION_TAG => 'x''); DROP TABLE foo; --')" + == " AT (VERSION_REF => 'x''); DROP TABLE foo; --')" ) # statement — same escape applies (pre-existing pattern hardened @@ -1139,7 +1137,7 @@ def test_time_travel_string_literal_escaping(): ) # Single backslash in the input → doubled in the SQL. assert ( - config_backslash.generate_sql_clause() == " AT (VERSION_TAG => 'weird\\\\name')" + config_backslash.generate_sql_clause() == " AT (VERSION_REF => 'weird\\\\name')" ) config_newline = TimeTravelConfig.validate_and_normalize_params( @@ -1149,7 +1147,7 @@ def test_time_travel_string_literal_escaping(): # the SQL text (so the literal stays on one line for Snowflake's # parser). assert ( - config_newline.generate_sql_clause() == " AT (VERSION_TAG => 'line1\\nline2')" + config_newline.generate_sql_clause() == " AT (VERSION_REF => 'line1\\nline2')" ) @@ -1158,8 +1156,8 @@ def test_extract_time_travel_version_tag_option(): Both ``version_tag`` and ``version-tag`` aliases map to the internal ``version_tag`` time-travel parameter and auto-set - ``time_travel_mode='at'`` (``AT(VERSION_TAG => '')`` is the only - valid form — tag reads are positional). + ``time_travel_mode='at'`` (``AT(VERSION_REF => '')`` is the unified + named-ref form for tag reads). """ from snowflake.snowpark.dataframe_reader import _extract_time_travel_from_options @@ -1193,3 +1191,70 @@ def test_extract_time_travel_version_tag_option(): _extract_time_travel_from_options( {"VERSION-TAG": "release_v1", "TIME_TRAVEL_MODE": "before"} ) + + +def test_time_travel_version_ref_and_branch(): + """Iceberg tag/branch reads emit Snowflake ``AT(VERSION_REF => ...)``.""" + config_ref = TimeTravelConfig.validate_and_normalize_params( + time_travel_mode="at", version_ref="historical-snapshot" + ) + assert config_ref.version_ref == "historical-snapshot" + assert ( + config_ref.generate_sql_clause() == " AT (VERSION_REF => 'historical-snapshot')" + ) + + config_branch = TimeTravelConfig.validate_and_normalize_params( + time_travel_mode="at", branch="audit-branch" + ) + assert config_branch.branch == "audit-branch" + assert config_branch.generate_sql_clause() == " AT (VERSION_REF => 'audit-branch')" + + with pytest.raises( + ValueError, + match="Exactly one of 'version_tag', 'version_ref', or 'branch'", + ): + TimeTravelConfig.validate_and_normalize_params( + time_travel_mode="at", + version_tag="tag1", + branch="audit-branch", + ) + + +def test_extract_time_travel_branch_and_version_ref_options(): + from snowflake.snowpark.dataframe_reader import _extract_time_travel_from_options + + assert _extract_time_travel_from_options({"BRANCH": "audit-branch"}) == { + "time_travel_mode": "at", + "branch": "audit-branch", + } + assert _extract_time_travel_from_options({"TAG": "release_v1"}) == { + "time_travel_mode": "at", + "version_tag": "release_v1", + } + assert _extract_time_travel_from_options({"VERSION_REF": "wap_ref"}) == { + "time_travel_mode": "at", + "version_ref": "wap_ref", + } + assert _extract_time_travel_from_options({"VERSION-REF": "wap_ref"}) == { + "time_travel_mode": "at", + "version_ref": "wap_ref", + } + + with pytest.raises( + ValueError, + match=r"Cannot use 'branch' option with time_travel_mode='before'", + ): + _extract_time_travel_from_options( + {"BRANCH": "audit-branch", "TIME_TRAVEL_MODE": "before"} + ) + + +def test_time_travel_version_ref_validation(): + with pytest.raises(ValueError, match="'version_ref' must be a non-empty"): + TimeTravelConfig.validate_and_normalize_params( + time_travel_mode="at", version_ref="" + ) + with pytest.raises(ValueError, match="'branch' must be a string Iceberg name"): + TimeTravelConfig.validate_and_normalize_params( + time_travel_mode="at", branch=123 + )