Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,11 @@ def __init__(
# Metadata/Attributes for the plan
self._attributes: Optional[List[Attribute]] = None
self.table_reference = self.entity.name
if self.entity.time_travel_config is not None:
if self.entity.iceberg_changes_config is not None:
self.table_reference += (
self.entity.iceberg_changes_config.generate_sql_clause()
)
elif self.entity.time_travel_config is not None:
self.table_reference += self.entity.time_travel_config.generate_sql_clause()

def __deepcopy__(self, memodict={}) -> "SelectableEntity": # noqa: B006
Expand Down
9 changes: 7 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,8 +1068,13 @@ def large_local_relation_plan(

def table(self, table_name: str, source_plan: LogicalPlan) -> SnowflakePlan:
table_reference = table_name
if isinstance(source_plan, SnowflakeTable) and source_plan.time_travel_config:
table_reference += source_plan.time_travel_config.generate_sql_clause()
if isinstance(source_plan, SnowflakeTable):
if source_plan.iceberg_changes_config:
table_reference += (
source_plan.iceberg_changes_config.generate_sql_clause()
)
elif source_plan.time_travel_config:
table_reference += source_plan.time_travel_config.generate_sql_clause()

return self.query(project_statement([], table_reference), source_plan)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from snowflake.snowpark._internal.utils import TimeTravelConfig
from snowflake.snowpark._internal.utils import IcebergChangesConfig, TimeTravelConfig

from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
Expand Down Expand Up @@ -97,10 +97,12 @@ def __init__(
session: "Session",
is_temp_table_for_cleanup: bool = False,
time_travel_config: Optional[TimeTravelConfig] = None,
iceberg_changes_config: Optional[IcebergChangesConfig] = None,
) -> None:
super().__init__()
self.name = name
self.time_travel_config = time_travel_config
self.iceberg_changes_config = iceberg_changes_config
# When `is_temp_table_for_cleanup` is True, it's a temp table
# generated by Snowpark (currently only df.cache_result) under the hood
# and users are not aware of it.
Expand Down
83 changes: 83 additions & 0 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,6 +2148,89 @@ def generate_sql_clause(self) -> str:
return clause


class IcebergChangesConfig(NamedTuple):
"""Configuration for Iceberg incremental reads via Snowflake ``CHANGES``.

Spark Iceberg exposes incremental reads through::

spark.read.format("iceberg")
.option("start-snapshot-id", S1)
.option("end-snapshot-id", S2) # optional
.load("table")

Snowflake translates this to::

SELECT * FROM table
CHANGES (INFORMATION => APPEND_ONLY)
AT (VERSION => S1)
[ END (VERSION => S2) ]

When ``end_version`` is omitted, Snowflake uses the current snapshot as
the end of the change interval (same semantics as omitting ``END`` on
generic ``CHANGES`` queries).
"""

start_version: int
end_version: Optional[int] = None
information: str = "APPEND_ONLY"

@staticmethod
def _coerce_snapshot_id(value: object, option_name: str) -> int:
try:
snapshot_id = int(value) # type: ignore[arg-type]
except (TypeError, ValueError):
raise ValueError(
f"'{option_name}' must be a 64-bit integer Iceberg snapshot id, "
f"got {value!r}."
) from None
if isinstance(snapshot_id, bool):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a question about this part of code:
It looks like after the first try block, it is impossible for snapshot_id to be a bool value here, can you tell me what is the scenario this check trying to guard?

raise ValueError(
f"'{option_name}' must be a 64-bit integer Iceberg snapshot id, "
f"got {type(value).__name__}."
)
return snapshot_id

@staticmethod
def validate_and_normalize_params(
start_snapshot_id: Optional[int] = None,
end_snapshot_id: Optional[int] = None,
information: str = "APPEND_ONLY",
) -> Optional["IcebergChangesConfig"]:
if start_snapshot_id is None and end_snapshot_id is None:
return None
if start_snapshot_id is None:
raise ValueError(
"Iceberg incremental read requires 'start-snapshot-id'; "
"'end-snapshot-id' cannot be used alone."
)
start = IcebergChangesConfig._coerce_snapshot_id(
start_snapshot_id, "start-snapshot-id"
)
end = None
if end_snapshot_id is not None:
end = IcebergChangesConfig._coerce_snapshot_id(
end_snapshot_id, "end-snapshot-id"
)
info = information.upper()
if info not in ("APPEND_ONLY", "DEFAULT"):
raise ValueError(
"Iceberg incremental read 'information' must be 'APPEND_ONLY' "
f"or 'DEFAULT', got {information!r}."
)
return IcebergChangesConfig(
start_version=start, end_version=end, information=info
)

def generate_sql_clause(self) -> str:
clause = (
f" CHANGES (INFORMATION => {self.information}) "
f"AT (VERSION => {self.start_version})"
)
if self.end_version is not None:
clause += f" END (VERSION => {self.end_version})"
return clause


def get_line_numbers(
commented_sql_query: str,
child_uuids: List[str],
Expand Down
104 changes: 104 additions & 0 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,81 @@ def _extract_time_travel_from_options(options: dict) -> dict:
return result


def _get_reader_option(options: dict, *keys: str):
"""Case-insensitive lookup for a reader option key."""
for key in keys:
for option_key, value in options.items():
if option_key.upper() == key.upper():
return value
return None


def _extract_iceberg_changes_from_options(options: dict) -> dict:
"""Extract Spark Iceberg incremental-read options from a reader dict.

Maps ``start-snapshot-id`` / ``end-snapshot-id`` (and underscore
variants) to internal ``start_snapshot_id`` / ``end_snapshot_id``
kwargs consumed by :meth:`Session.table`.
"""
start = _get_reader_option(options, "start-snapshot-id", "start_snapshot_id")
end = _get_reader_option(options, "end-snapshot-id", "end_snapshot_id")
if start is None and end is None:
return {}
if start is None:
raise ValueError(
"Iceberg incremental read requires 'start-snapshot-id'; "
"'end-snapshot-id' cannot be used alone."
)
try:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part looks to be the same logic as _coerce_snapshot_id introduced in utils, is it possible to reuse?

start_id = int(start)
except (TypeError, ValueError):
raise ValueError(
"'start-snapshot-id' must be a 64-bit integer Iceberg snapshot id, "
f"got {start!r}."
) from None
end_id = None
if end is not None:
try:
end_id = int(end)
except (TypeError, ValueError):
raise ValueError(
"'end-snapshot-id' must be a 64-bit integer Iceberg snapshot id, "
f"got {end!r}."
) from None
return {"start_snapshot_id": start_id, "end_snapshot_id": end_id}


def _reader_options_conflict_with_incremental_read(options: dict) -> list[str]:
"""Return reader option keys that cannot coexist with incremental read."""
incremental_keys = {
"start-snapshot-id",
"start_snapshot_id",
"end-snapshot-id",
"end_snapshot_id",
}
if not any(
k.upper().replace("_", "-") in {x.replace("_", "-") for x in incremental_keys}
for k in options
):
return []
blocked = []
for key in options:
upper = key.upper()
if upper in incremental_keys or upper.replace("_", "-") in {
x.replace("_", "-") for x in incremental_keys
}:
continue
if upper in _TIME_TRAVEL_OPTIONS_PARAMS_MAP or upper in (
"SNAPSHOT-ID",
"SNAPSHOT_ID",
"AS-OF-TIMESTAMP",
"VERSION_TAG",
"VERSION-TAG",
):
blocked.append(key)
return blocked


class DataFrameReader:
"""Provides methods to load data in various supported formats from a Snowflake
stage to a :class:`DataFrame`. The paths provided to the DataFrameReader must refer
Expand Down Expand Up @@ -671,11 +746,29 @@ def table(
# still pass them without us advertising the surface.
version = kwargs.pop("version", None)
version_tag = kwargs.pop("version_tag", None)
start_snapshot_id = kwargs.pop("start_snapshot_id", None)
end_snapshot_id = kwargs.pop("end_snapshot_id", None)
if kwargs:
raise TypeError(
f"table() got unexpected keyword arguments: {sorted(kwargs)}"
)

changes_from_options = _extract_iceberg_changes_from_options(self._cur_options)
if changes_from_options:
conflicting = _reader_options_conflict_with_incremental_read(
self._cur_options
)
if conflicting:
raise ValueError(
"Cannot combine Iceberg incremental read "
"('start-snapshot-id' / 'end-snapshot-id') with time travel "
f"options on the same read; found {conflicting!r}."
)
if start_snapshot_id is None:
start_snapshot_id = changes_from_options["start_snapshot_id"]
if end_snapshot_id is None:
end_snapshot_id = changes_from_options.get("end_snapshot_id")

# AST.
stmt = None
if _emit_ast and self._ast is not None:
Expand All @@ -697,6 +790,17 @@ def table(
ast.stream.value = stream

if (
start_snapshot_id is not None
or end_snapshot_id is not None
or changes_from_options
):
table = self._session.table(
name,
_emit_ast=False,
start_snapshot_id=start_snapshot_id,
end_snapshot_id=end_snapshot_id,
)
elif (
time_travel_mode is not None
or version is not None
or version_tag is not None
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2780,6 +2780,8 @@ def table(
# still pass them without us advertising the surface.
version = kwargs.pop("version", None)
version_tag = kwargs.pop("version_tag", None)
start_snapshot_id = kwargs.pop("start_snapshot_id", None)
end_snapshot_id = kwargs.pop("end_snapshot_id", None)
if kwargs:
raise TypeError(
f"table() got unexpected keyword arguments: {sorted(kwargs)}"
Expand Down Expand Up @@ -2823,6 +2825,8 @@ def table(
stream=stream,
version=version,
version_tag=version_tag,
start_snapshot_id=start_snapshot_id,
end_snapshot_id=end_snapshot_id,
)
# Replace API call origin for table
set_api_call_source(t, "Session.table")
Expand Down
23 changes: 23 additions & 0 deletions src/snowflake/snowpark/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from snowflake.snowpark._internal.telemetry import add_api_call, set_api_call_source
from snowflake.snowpark._internal.type_utils import ColumnOrLiteral
from snowflake.snowpark._internal.utils import (
IcebergChangesConfig,
publicapi,
TimeTravelConfig,
)
Expand Down Expand Up @@ -303,11 +304,18 @@ def __init__(
# still pass them without us advertising the surface.
version = kwargs.pop("version", None)
version_tag = kwargs.pop("version_tag", None)
start_snapshot_id = kwargs.pop("start_snapshot_id", None)
end_snapshot_id = kwargs.pop("end_snapshot_id", None)
if kwargs:
raise TypeError(
f"Table() got unexpected keyword arguments: {sorted(kwargs)}"
)

iceberg_changes_config = IcebergChangesConfig.validate_and_normalize_params(
start_snapshot_id=start_snapshot_id,
end_snapshot_id=end_snapshot_id,
)

if _ast_stmt is None and session is not None and _emit_ast:
_ast_stmt = session._ast_batch.bind()
ast = with_src_position(_ast_stmt.expr.table, _ast_stmt)
Expand Down Expand Up @@ -337,12 +345,19 @@ def __init__(
version=version,
version_tag=version_tag,
)
if iceberg_changes_config is not None and time_travel_config is not None:
raise ValueError(
"Cannot combine Iceberg incremental read "
"('start-snapshot-id' / 'end-snapshot-id') with time travel "
"options on the same read."
)

snowflake_table_plan = SnowflakeTable(
table_name,
session=session,
is_temp_table_for_cleanup=is_temp_table_for_cleanup,
time_travel_config=time_travel_config,
iceberg_changes_config=iceberg_changes_config,
)
if session.sql_simplifier_enabled:
plan = session._analyzer.create_select_statement(
Expand All @@ -358,6 +373,7 @@ def __init__(
self.table_name: str = table_name #: The table name
self._is_temp_table_for_cleanup = is_temp_table_for_cleanup
self._time_travel_config = time_travel_config
self._iceberg_changes_config = iceberg_changes_config

# By default, the set the initial API call to say 'Table.__init__' since
# people could instantiate a table directly. This value is overwritten when
Expand All @@ -368,6 +384,13 @@ def _copy_without_ast(self):
kwargs = {}
if self._time_travel_config:
kwargs.update(self._time_travel_config._asdict())
if self._iceberg_changes_config:
kwargs.update(
{
"start_snapshot_id": self._iceberg_changes_config.start_version,
"end_snapshot_id": self._iceberg_changes_config.end_version,
}
)

return Table(
self.table_name,
Expand Down
Loading
Loading