Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 67 additions & 24 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,23 @@ def remove_comments(sql_query: str, uuids: List[str]) -> str:
)


def _validate_iceberg_named_version_ref(
value: Optional[str], param_name: str
) -> Optional[str]:
"""Validate a non-empty Iceberg tag/branch/version-ref name."""
if value is None:
return None
if not isinstance(value, str):
raise ValueError(
f"'{param_name}' must be a string Iceberg name, "
f"got {type(value).__name__}."
)
stripped = value.strip()
if not stripped:
raise ValueError(f"'{param_name}' must be a non-empty Iceberg name.")
return stripped


class TimeTravelConfig(NamedTuple):
"""Configuration for time travel operations."""

Expand All @@ -2010,6 +2027,8 @@ class TimeTravelConfig(NamedTuple):
stream: Optional[str] = None
version: Optional[int] = None
version_tag: Optional[str] = None
version_ref: Optional[str] = None
branch: Optional[str] = None

@staticmethod
def validate_and_normalize_params(
Expand All @@ -2021,6 +2040,8 @@ def validate_and_normalize_params(
stream: Optional[str] = None,
version: Optional[int] = None,
version_tag: Optional[str] = None,
version_ref: Optional[str] = None,
branch: Optional[str] = None,
) -> Optional["TimeTravelConfig"]:
"""
Validates and normalizes time travel parameters.
Expand All @@ -2042,9 +2063,31 @@ def validate_and_normalize_params(
Raises:
ValueError: If parameters are invalid.
"""
version_tag = _validate_iceberg_named_version_ref(version_tag, "version_tag")
version_ref = _validate_iceberg_named_version_ref(version_ref, "version_ref")
branch = _validate_iceberg_named_version_ref(branch, "branch")

named_ref_count = sum(
arg is not None for arg in (version_tag, version_ref, branch)
)
if named_ref_count > 1:
raise ValueError(
"Exactly one of 'version_tag', 'version_ref', or 'branch' may be "
"provided for Iceberg named-ref time travel."
)

time_travel_arg_count = sum(
arg is not None
for arg in (statement, offset, timestamp, stream, version, version_tag)
for arg in (
statement,
offset,
timestamp,
stream,
version,
version_tag,
version_ref,
branch,
)
)

# Validate mode
Expand Down Expand Up @@ -2079,32 +2122,25 @@ def validate_and_normalize_params(
f"'version' must be an int Iceberg snapshot id, got {type(version).__name__}."
)

# version_tag (Iceberg tag name, mapped to Snowflake's
# ``AT(VERSION_TAG => '<name>')`` grammar) only works with 'at' mode —
# Iceberg tag reads are positional (bound to a specific snapshot),
# not range-of-time, so ``BEFORE`` has no meaning.
if version_tag is not None and time_travel_mode.lower() != "at":
raise ValueError(
"Iceberg version_tag time travel can only be used with "
"time_travel_mode='at', not 'before'."
)

# Validate version_tag type — Iceberg tag names are strings. Empty
# strings are invalid.
if version_tag is not None:
if not isinstance(version_tag, str):
# Named Iceberg refs (tag / branch / version_ref) map to Snowflake's
# ``AT(VERSION_REF => '<name>')`` grammar and only work with 'at' mode.
for param_name, param_value in (
("version_tag", version_tag),
("version_ref", version_ref),
("branch", branch),
):
if param_value is not None and time_travel_mode.lower() != "at":
raise ValueError(
f"'version_tag' must be a string Iceberg tag name, "
f"got {type(version_tag).__name__}."
f"Iceberg {param_name} time travel can only be used with "
"time_travel_mode='at', not 'before'."
)
if not version_tag:
raise ValueError("'version_tag' must be a non-empty Iceberg tag name.")

# Validate exactly one parameter is provided
if time_travel_arg_count != 1:
raise ValueError(
"Exactly one of 'statement', 'offset', 'timestamp', 'stream', "
"'version', or 'version_tag' must be provided."
"'version', 'version_tag', 'version_ref', or 'branch' must be "
"provided."
)

# Normalize timestamp
Expand Down Expand Up @@ -2140,6 +2176,8 @@ def validate_and_normalize_params(
stream=stream,
version=version,
version_tag=version_tag,
version_ref=version_ref,
branch=branch,
)

def generate_sql_clause(self) -> str:
Expand All @@ -2150,11 +2188,12 @@ def generate_sql_clause(self) -> str:
Returns:
SQL clause like " AT (TIMESTAMP => TO_TIMESTAMP_NTZ('...'))",
" AT (VERSION => 1234567890)" for Iceberg snapshot id time travel,
or " AT (VERSION_TAG => 'release_v1')" for Iceberg tag time
travel.
or " AT (VERSION_REF => 'audit-branch')" for Iceberg tag/branch
time travel.

Note on escaping: string-valued parameters (``statement``,
``stream``, ``version_tag``, ``timestamp``) are embedded inside
``stream``, ``version_tag``, ``version_ref``, ``branch``,
``timestamp``) are embedded inside
single-quoted SQL literals via the existing ``str_to_sql``
helper in ``analyzer.datatype_mapper`` so embedded ``'``, ``\\``
and newline characters are properly escaped. This keeps the
Expand All @@ -2181,8 +2220,12 @@ def generate_sql_clause(self) -> str:
clause += f"(STREAM => {str_to_sql(self.stream)})"
elif self.version is not None:
clause += f"(VERSION => {self.version})"
elif self.version_ref is not None:
clause += f"(VERSION_REF => {str_to_sql(self.version_ref)})"
elif self.branch is not None:
clause += f"(VERSION_REF => {str_to_sql(self.branch)})"
elif self.version_tag is not None:
clause += f"(VERSION_TAG => {str_to_sql(self.version_tag)})"
clause += f"(VERSION_REF => {str_to_sql(self.version_tag)})"
Comment thread
sfc-gh-igarish marked this conversation as resolved.
elif self.timestamp is not None:
if self.timestamp_type is not None:
timestamp_type = self.timestamp_type.upper()
Expand Down
100 changes: 77 additions & 23 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,17 @@ def _extract_time_travel_from_options(options: dict) -> dict:

Special handling for 'VERSION_TAG' / 'VERSION-TAG' (Iceberg tag name) —
both aliases map to the internal ``version_tag`` time travel parameter
and emit ``AT(VERSION_TAG => '<name>')`` on the Snowflake side (see
Spark Iceberg's ``VERSION AS OF '<tag_name>'`` reader path):
- Automatically set time_travel_mode to 'at'
(tag reads are positional — bound to a specific snapshot — not
range-of-time)
- Cannot be used with time_travel_mode='before' (raises error)
and emit ``AT(VERSION_REF => '<name>')`` on the Snowflake side.

Special handling for 'VERSION_REF' / 'VERSION-REF' (Iceberg tag or
branch name) — maps to the internal ``version_ref`` parameter.

Special handling for 'BRANCH' (Spark Iceberg WAP branch read) — maps
to the internal ``branch`` parameter and emits
``AT(VERSION_REF => '<name>')``.

Special handling for 'TAG' (Spark Iceberg ``SparkReadOptions.TAG``) —
maps to the internal ``version_tag`` parameter.
"""
result = {}
excluded_keys = set()
Expand Down Expand Up @@ -220,27 +225,64 @@ def _extract_time_travel_from_options(options: dict) -> dict:
)
result["time_travel_mode"] = "at"

# Handle Iceberg tag (``version_tag`` / ``version-tag``). Both aliases
# route to the internal ``version_tag`` parameter and emit
# ``AT(VERSION_TAG => '<name>')`` server-side. Auto-sets mode='at'.
version_tag_value = options.get("VERSION_TAG")
version_tag_source = "version_tag"
if version_tag_value is None:
version_tag_value = options.get("VERSION-TAG")
version_tag_source = "version-tag"
if version_tag_value is not None:
def _set_named_ref_option(
result: dict,
*,
param_name: str,
option_source: str,
raw_value: Any,
) -> None:
if (
"TIME_TRAVEL_MODE" in options
and options["TIME_TRAVEL_MODE"].lower() == "before"
):
raise ValueError(
f"Cannot use '{version_tag_source}' option with "
"time_travel_mode='before'. Iceberg tag time travel only "
"supports time_travel_mode='at'."
f"Cannot use '{option_source}' option with "
"time_travel_mode='before'. Iceberg named-ref time travel "
"only supports time_travel_mode='at'."
)
result["version_tag"] = str(version_tag_value)
result[param_name] = str(raw_value)
result["time_travel_mode"] = "at"

# Handle Iceberg tag (``version_tag`` / ``version-tag`` / Spark ``tag``).
version_tag_value = options.get("VERSION_TAG")
version_tag_source = "version_tag"
if version_tag_value is None:
version_tag_value = options.get("VERSION-TAG")
version_tag_source = "version-tag"
if version_tag_value is None:
version_tag_value = options.get("TAG")
version_tag_source = "tag"
if version_tag_value is not None:
_set_named_ref_option(
result,
param_name="version_tag",
option_source=version_tag_source,
raw_value=version_tag_value,
)

version_ref_value = options.get("VERSION_REF")
version_ref_source = "version_ref"
if version_ref_value is None:
version_ref_value = options.get("VERSION-REF")
version_ref_source = "version-ref"
if version_ref_value is not None:
_set_named_ref_option(
result,
param_name="version_ref",
option_source=version_ref_source,
raw_value=version_ref_value,
)

branch_value = options.get("BRANCH")
if branch_value is not None:
_set_named_ref_option(
result,
param_name="branch",
option_source="branch",
raw_value=branch_value,
)

for option_key, param_name in _TIME_TRAVEL_OPTIONS_PARAMS_MAP.items():
if option_key in options and option_key not in excluded_keys:
result[param_name] = options[option_key]
Expand Down Expand Up @@ -671,6 +713,8 @@ def table(
# still pass them without us advertising the surface.
version = kwargs.pop("version", None)
version_tag = kwargs.pop("version_tag", None)
version_ref = kwargs.pop("version_ref", None)
branch = kwargs.pop("branch", None)
if kwargs:
raise TypeError(
f"table() got unexpected keyword arguments: {sorted(kwargs)}"
Expand Down Expand Up @@ -700,15 +744,23 @@ def table(
time_travel_mode is not None
or version is not None
or version_tag is not None
or version_ref is not None
or branch is not None
):
# If version / version_tag is provided without mode, default to
# 'at' — snapshot ids and tag reads only make sense with AT
# (symmetric with the as-of-timestamp option handling).
# If version / named-ref params are provided without mode,
# default to 'at'.
effective_mode = (
time_travel_mode
if time_travel_mode
else (
"at" if (version is not None or version_tag is not None) else None
"at"
if (
version is not None
or version_tag is not None
or version_ref is not None
or branch is not None
)
else None
)
)
time_travel_params = {
Expand All @@ -720,6 +772,8 @@ def table(
"stream": stream,
"version": version,
"version_tag": version_tag,
"version_ref": version_ref,
"branch": branch,
}
else:
# if time_travel_mode is not provided, extract time travel config from options
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2780,6 +2780,8 @@ def table(
# still pass them without us advertising the surface.
version = kwargs.pop("version", None)
version_tag = kwargs.pop("version_tag", None)
version_ref = kwargs.pop("version_ref", None)
branch = kwargs.pop("branch", None)
if kwargs:
raise TypeError(
f"table() got unexpected keyword arguments: {sorted(kwargs)}"
Expand Down Expand Up @@ -2823,6 +2825,8 @@ def table(
stream=stream,
version=version,
version_tag=version_tag,
version_ref=version_ref,
branch=branch,
)
# Replace API call origin for table
set_api_call_source(t, "Session.table")
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def __init__(
# still pass them without us advertising the surface.
version = kwargs.pop("version", None)
version_tag = kwargs.pop("version_tag", None)
version_ref = kwargs.pop("version_ref", None)
branch = kwargs.pop("branch", None)
if kwargs:
raise TypeError(
f"Table() got unexpected keyword arguments: {sorted(kwargs)}"
Expand Down Expand Up @@ -336,6 +338,8 @@ def __init__(
stream=stream,
version=version,
version_tag=version_tag,
version_ref=version_ref,
branch=branch,
)

snowflake_table_plan = SnowflakeTable(
Expand Down
18 changes: 18 additions & 0 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8601,3 +8601,21 @@ def test_iceberg_version_tag_time_travel_dataframe_reader_option(session):
session.read.option("version-tag", tag_name).table(table_fqn).collect()
)
assert via_kwarg == via_option == via_hyphen_option


@pytest.mark.skip(
    reason=(
        "Requires a Snowflake-managed Iceberg table with a WAP branch and "
        "FEATURE_ICEBERG_TIME_TRAVEL enabled on the account. Tested manually."
    )
)
def test_iceberg_branch_time_travel_dataframe_reader_option(session):
    """Reading via the ``branch`` option matches reading via ``version_ref``.

    Both reader options are given the same branch name against the same
    table, and the collected rows are expected to be equal.
    """
    target_table = "ICEBERG_GAP_TEST_HORIZON.TESTSCHEMA.snapshot_demo"
    ref_name = "audit_branch"

    def _collect_with(option_key):
        # Same table and same ref name each time; only the option key varies.
        reader = session.read.option(option_key, ref_name)
        return reader.table(target_table).collect()

    rows_from_branch = _collect_with("branch")
    rows_from_version_ref = _collect_with("version_ref")
    assert rows_from_branch == rows_from_version_ref
Loading
Loading