Skip to content

Commit 688463b

Browse files
feat: add Iceberg WAP branch read via VERSION_REF time travel
Extend TimeTravelConfig with version_ref and branch parameters, map Spark Iceberg branch/tag reader options, and emit Snowflake AT(VERSION_REF => ...) for named-ref reads. Tags continue to work through version_tag but now use the unified VERSION_REF SQL form.
1 parent e2e4074 commit 688463b

6 files changed

Lines changed: 235 additions & 56 deletions

File tree

src/snowflake/snowpark/_internal/utils.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1946,6 +1946,23 @@ def remove_comments(sql_query: str, uuids: List[str]) -> str:
19461946
)
19471947

19481948

1949+
def _validate_iceberg_named_version_ref(
1950+
value: Optional[str], param_name: str
1951+
) -> Optional[str]:
1952+
"""Validate a non-empty Iceberg tag/branch/version-ref name."""
1953+
if value is None:
1954+
return None
1955+
if not isinstance(value, str):
1956+
raise ValueError(
1957+
f"'{param_name}' must be a string Iceberg name, "
1958+
f"got {type(value).__name__}."
1959+
)
1960+
stripped = value.strip()
1961+
if not stripped:
1962+
raise ValueError(f"'{param_name}' must be a non-empty Iceberg name.")
1963+
return stripped
1964+
1965+
19491966
class TimeTravelConfig(NamedTuple):
19501967
"""Configuration for time travel operations."""
19511968

@@ -1957,6 +1974,8 @@ class TimeTravelConfig(NamedTuple):
19571974
stream: Optional[str] = None
19581975
version: Optional[int] = None
19591976
version_tag: Optional[str] = None
1977+
version_ref: Optional[str] = None
1978+
branch: Optional[str] = None
19601979

19611980
@staticmethod
19621981
def validate_and_normalize_params(
@@ -1968,6 +1987,8 @@ def validate_and_normalize_params(
19681987
stream: Optional[str] = None,
19691988
version: Optional[int] = None,
19701989
version_tag: Optional[str] = None,
1990+
version_ref: Optional[str] = None,
1991+
branch: Optional[str] = None,
19711992
) -> Optional["TimeTravelConfig"]:
19721993
"""
19731994
Validates and normalizes time travel parameters.
@@ -1989,9 +2010,31 @@ def validate_and_normalize_params(
19892010
Raises:
19902011
ValueError: If parameters are invalid.
19912012
"""
2013+
version_tag = _validate_iceberg_named_version_ref(version_tag, "version_tag")
2014+
version_ref = _validate_iceberg_named_version_ref(version_ref, "version_ref")
2015+
branch = _validate_iceberg_named_version_ref(branch, "branch")
2016+
2017+
named_ref_count = sum(
2018+
arg is not None for arg in (version_tag, version_ref, branch)
2019+
)
2020+
if named_ref_count > 1:
2021+
raise ValueError(
2022+
"Exactly one of 'version_tag', 'version_ref', or 'branch' may be "
2023+
"provided for Iceberg named-ref time travel."
2024+
)
2025+
19922026
time_travel_arg_count = sum(
19932027
arg is not None
1994-
for arg in (statement, offset, timestamp, stream, version, version_tag)
2028+
for arg in (
2029+
statement,
2030+
offset,
2031+
timestamp,
2032+
stream,
2033+
version,
2034+
version_tag,
2035+
version_ref,
2036+
branch,
2037+
)
19952038
)
19962039

19972040
# Validate mode
@@ -2026,32 +2069,25 @@ def validate_and_normalize_params(
20262069
f"'version' must be an int Iceberg snapshot id, got {type(version).__name__}."
20272070
)
20282071

2029-
# version_tag (Iceberg tag name, mapped to Snowflake's
2030-
# ``AT(VERSION_TAG => '<name>')`` grammar) only works with 'at' mode —
2031-
# Iceberg tag reads are positional (bound to a specific snapshot),
2032-
# not range-of-time, so ``BEFORE`` has no meaning.
2033-
if version_tag is not None and time_travel_mode.lower() != "at":
2034-
raise ValueError(
2035-
"Iceberg version_tag time travel can only be used with "
2036-
"time_travel_mode='at', not 'before'."
2037-
)
2038-
2039-
# Validate version_tag type — Iceberg tag names are strings. Empty
2040-
# strings are invalid.
2041-
if version_tag is not None:
2042-
if not isinstance(version_tag, str):
2072+
# Named Iceberg refs (tag / branch / version_ref) map to Snowflake's
2073+
# ``AT(VERSION_REF => '<name>')`` grammar and only work with 'at' mode.
2074+
for param_name, param_value in (
2075+
("version_tag", version_tag),
2076+
("version_ref", version_ref),
2077+
("branch", branch),
2078+
):
2079+
if param_value is not None and time_travel_mode.lower() != "at":
20432080
raise ValueError(
2044-
f"'version_tag' must be a string Iceberg tag name, "
2045-
f"got {type(version_tag).__name__}."
2081+
f"Iceberg {param_name} time travel can only be used with "
2082+
"time_travel_mode='at', not 'before'."
20462083
)
2047-
if not version_tag:
2048-
raise ValueError("'version_tag' must be a non-empty Iceberg tag name.")
20492084

20502085
# Validate exactly one parameter is provided
20512086
if time_travel_arg_count != 1:
20522087
raise ValueError(
20532088
"Exactly one of 'statement', 'offset', 'timestamp', 'stream', "
2054-
"'version', or 'version_tag' must be provided."
2089+
"'version', 'version_tag', 'version_ref', or 'branch' must be "
2090+
"provided."
20552091
)
20562092

20572093
# Normalize timestamp
@@ -2087,6 +2123,8 @@ def validate_and_normalize_params(
20872123
stream=stream,
20882124
version=version,
20892125
version_tag=version_tag,
2126+
version_ref=version_ref,
2127+
branch=branch,
20902128
)
20912129

20922130
def generate_sql_clause(self) -> str:
@@ -2097,11 +2135,12 @@ def generate_sql_clause(self) -> str:
20972135
Returns:
20982136
SQL clause like " AT (TIMESTAMP => TO_TIMESTAMP_NTZ('...'))",
20992137
" AT (VERSION => 1234567890)" for Iceberg snapshot id time travel,
2100-
or " AT (VERSION_TAG => 'release_v1')" for Iceberg tag time
2101-
travel.
2138+
or " AT (VERSION_REF => 'audit-branch')" for Iceberg tag/branch
2139+
time travel.
21022140
21032141
Note on escaping: string-valued parameters (``statement``,
2104-
``stream``, ``version_tag``, ``timestamp``) are embedded inside
2142+
``stream``, ``version_tag``, ``version_ref``, ``branch``,
2143+
``timestamp``) are embedded inside
21052144
single-quoted SQL literals via the existing ``str_to_sql``
21062145
helper in ``analyzer.datatype_mapper`` so embedded ``'``, ``\\``
21072146
and newline characters are properly escaped. This keeps the
@@ -2128,8 +2167,12 @@ def generate_sql_clause(self) -> str:
21282167
clause += f"(STREAM => {str_to_sql(self.stream)})"
21292168
elif self.version is not None:
21302169
clause += f"(VERSION => {self.version})"
2170+
elif self.version_ref is not None:
2171+
clause += f"(VERSION_REF => {str_to_sql(self.version_ref)})"
2172+
elif self.branch is not None:
2173+
clause += f"(VERSION_REF => {str_to_sql(self.branch)})"
21312174
elif self.version_tag is not None:
2132-
clause += f"(VERSION_TAG => {str_to_sql(self.version_tag)})"
2175+
clause += f"(VERSION_REF => {str_to_sql(self.version_tag)})"
21332176
elif self.timestamp is not None:
21342177
if self.timestamp_type is not None:
21352178
timestamp_type = self.timestamp_type.upper()

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 77 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,17 @@ def _extract_time_travel_from_options(options: dict) -> dict:
164164
165165
Special handling for 'VERSION_TAG' / 'VERSION-TAG' (Iceberg tag name) —
166166
both aliases map to the internal ``version_tag`` time travel parameter
167-
and emit ``AT(VERSION_TAG => '<name>')`` on the Snowflake side (see
168-
Spark Iceberg's ``VERSION AS OF '<tag_name>'`` reader path):
169-
- Automatically set time_travel_mode to 'at'
170-
(tag reads are positional — bound to a specific snapshot — not
171-
range-of-time)
172-
- Cannot be used with time_travel_mode='before' (raises error)
167+
and emit ``AT(VERSION_REF => '<name>')`` on the Snowflake side.
168+
169+
Special handling for 'VERSION_REF' / 'VERSION-REF' (Iceberg tag or
170+
branch name) — maps to the internal ``version_ref`` parameter.
171+
172+
Special handling for 'BRANCH' (Spark Iceberg WAP branch read) — maps
173+
to the internal ``branch`` parameter and emits
174+
``AT(VERSION_REF => '<name>')``.
175+
176+
Special handling for 'TAG' (Spark Iceberg ``SparkReadOptions.TAG``) —
177+
maps to the internal ``version_tag`` parameter.
173178
"""
174179
result = {}
175180
excluded_keys = set()
@@ -220,27 +225,64 @@ def _extract_time_travel_from_options(options: dict) -> dict:
220225
)
221226
result["time_travel_mode"] = "at"
222227

223-
# Handle Iceberg tag (``version_tag`` / ``version-tag``). Both aliases
224-
# route to the internal ``version_tag`` parameter and emit
225-
# ``AT(VERSION_TAG => '<name>')`` server-side. Auto-sets mode='at'.
226-
version_tag_value = options.get("VERSION_TAG")
227-
version_tag_source = "version_tag"
228-
if version_tag_value is None:
229-
version_tag_value = options.get("VERSION-TAG")
230-
version_tag_source = "version-tag"
231-
if version_tag_value is not None:
228+
def _set_named_ref_option(
229+
result: dict,
230+
*,
231+
param_name: str,
232+
option_source: str,
233+
raw_value: Any,
234+
) -> None:
232235
if (
233236
"TIME_TRAVEL_MODE" in options
234237
and options["TIME_TRAVEL_MODE"].lower() == "before"
235238
):
236239
raise ValueError(
237-
f"Cannot use '{version_tag_source}' option with "
238-
"time_travel_mode='before'. Iceberg tag time travel only "
239-
"supports time_travel_mode='at'."
240+
f"Cannot use '{option_source}' option with "
241+
"time_travel_mode='before'. Iceberg named-ref time travel "
242+
"only supports time_travel_mode='at'."
240243
)
241-
result["version_tag"] = str(version_tag_value)
244+
result[param_name] = str(raw_value)
242245
result["time_travel_mode"] = "at"
243246

247+
# Handle Iceberg tag (``version_tag`` / ``version-tag`` / Spark ``tag``).
248+
version_tag_value = options.get("VERSION_TAG")
249+
version_tag_source = "version_tag"
250+
if version_tag_value is None:
251+
version_tag_value = options.get("VERSION-TAG")
252+
version_tag_source = "version-tag"
253+
if version_tag_value is None:
254+
version_tag_value = options.get("TAG")
255+
version_tag_source = "tag"
256+
if version_tag_value is not None:
257+
_set_named_ref_option(
258+
result,
259+
param_name="version_tag",
260+
option_source=version_tag_source,
261+
raw_value=version_tag_value,
262+
)
263+
264+
version_ref_value = options.get("VERSION_REF")
265+
version_ref_source = "version_ref"
266+
if version_ref_value is None:
267+
version_ref_value = options.get("VERSION-REF")
268+
version_ref_source = "version-ref"
269+
if version_ref_value is not None:
270+
_set_named_ref_option(
271+
result,
272+
param_name="version_ref",
273+
option_source=version_ref_source,
274+
raw_value=version_ref_value,
275+
)
276+
277+
branch_value = options.get("BRANCH")
278+
if branch_value is not None:
279+
_set_named_ref_option(
280+
result,
281+
param_name="branch",
282+
option_source="branch",
283+
raw_value=branch_value,
284+
)
285+
244286
for option_key, param_name in _TIME_TRAVEL_OPTIONS_PARAMS_MAP.items():
245287
if option_key in options and option_key not in excluded_keys:
246288
result[param_name] = options[option_key]
@@ -671,6 +713,8 @@ def table(
671713
# still pass them without us advertising the surface.
672714
version = kwargs.pop("version", None)
673715
version_tag = kwargs.pop("version_tag", None)
716+
version_ref = kwargs.pop("version_ref", None)
717+
branch = kwargs.pop("branch", None)
674718
if kwargs:
675719
raise TypeError(
676720
f"table() got unexpected keyword arguments: {sorted(kwargs)}"
@@ -700,15 +744,23 @@ def table(
700744
time_travel_mode is not None
701745
or version is not None
702746
or version_tag is not None
747+
or version_ref is not None
748+
or branch is not None
703749
):
704-
# If version / version_tag is provided without mode, default to
705-
# 'at' — snapshot ids and tag reads only make sense with AT
706-
# (symmetric with the as-of-timestamp option handling).
750+
# If version / named-ref params are provided without mode,
751+
# default to 'at'.
707752
effective_mode = (
708753
time_travel_mode
709754
if time_travel_mode
710755
else (
711-
"at" if (version is not None or version_tag is not None) else None
756+
"at"
757+
if (
758+
version is not None
759+
or version_tag is not None
760+
or version_ref is not None
761+
or branch is not None
762+
)
763+
else None
712764
)
713765
)
714766
time_travel_params = {
@@ -720,6 +772,8 @@ def table(
720772
"stream": stream,
721773
"version": version,
722774
"version_tag": version_tag,
775+
"version_ref": version_ref,
776+
"branch": branch,
723777
}
724778
else:
725779
# if time_travel_mode is not provided, extract time travel config from options

src/snowflake/snowpark/session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2780,6 +2780,8 @@ def table(
27802780
# still pass them without us advertising the surface.
27812781
version = kwargs.pop("version", None)
27822782
version_tag = kwargs.pop("version_tag", None)
2783+
version_ref = kwargs.pop("version_ref", None)
2784+
branch = kwargs.pop("branch", None)
27832785
if kwargs:
27842786
raise TypeError(
27852787
f"table() got unexpected keyword arguments: {sorted(kwargs)}"
@@ -2823,6 +2825,8 @@ def table(
28232825
stream=stream,
28242826
version=version,
28252827
version_tag=version_tag,
2828+
version_ref=version_ref,
2829+
branch=branch,
28262830
)
28272831
# Replace API call origin for table
28282832
set_api_call_source(t, "Session.table")

src/snowflake/snowpark/table.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,8 @@ def __init__(
303303
# still pass them without us advertising the surface.
304304
version = kwargs.pop("version", None)
305305
version_tag = kwargs.pop("version_tag", None)
306+
version_ref = kwargs.pop("version_ref", None)
307+
branch = kwargs.pop("branch", None)
306308
if kwargs:
307309
raise TypeError(
308310
f"Table() got unexpected keyword arguments: {sorted(kwargs)}"
@@ -336,6 +338,8 @@ def __init__(
336338
stream=stream,
337339
version=version,
338340
version_tag=version_tag,
341+
version_ref=version_ref,
342+
branch=branch,
339343
)
340344

341345
snowflake_table_plan = SnowflakeTable(

tests/integ/test_dataframe.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8490,3 +8490,21 @@ def test_iceberg_version_tag_time_travel_dataframe_reader_option(session):
84908490
session.read.option("version-tag", tag_name).table(table_fqn).collect()
84918491
)
84928492
assert via_kwarg == via_option == via_hyphen_option
8493+
8494+
8495+
@pytest.mark.skip(
8496+
reason=(
8497+
"Requires a Snowflake-managed Iceberg table with a WAP branch and "
8498+
"FEATURE_ICEBERG_TIME_TRAVEL enabled on the account. Tested manually."
8499+
)
8500+
)
8501+
def test_iceberg_branch_time_travel_dataframe_reader_option(session):
8502+
"""End-to-end: Spark Iceberg ``branch`` option maps to VERSION_REF."""
8503+
table_fqn = "ICEBERG_GAP_TEST_HORIZON.TESTSCHEMA.snapshot_demo"
8504+
branch_name = "audit_branch"
8505+
8506+
via_branch = session.read.option("branch", branch_name).table(table_fqn).collect()
8507+
via_version_ref = (
8508+
session.read.option("version_ref", branch_name).table(table_fqn).collect()
8509+
)
8510+
assert via_branch == via_version_ref

0 commit comments

Comments
 (0)