Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@
- `get_username_password`
- `get_cloud_provider_token`

- Added support for the following scalar functions in `functions.py`:
- `h3_cell_to_boundary`
- `h3_cell_to_parent`
- `h3_cell_to_point`
- `h3_compact_cells`
- `h3_compact_cells_strings`
- `h3_coverage`
- `h3_coverage_strings`
- `h3_get_resolution`
- `h3_grid_disk`
- `h3_grid_distance`

### Snowpark pandas API Updates

#### New Features
Expand Down
10 changes: 10 additions & 0 deletions docs/source/snowpark/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,16 @@ Functions
hex
hex_encode
hour
h3_cell_to_boundary
h3_cell_to_parent
h3_cell_to_point
h3_compact_cells
h3_compact_cells_strings
h3_coverage
h3_coverage_strings
h3_get_resolution
h3_grid_disk
h3_grid_distance
iff
ifnull
in_
Expand Down
276 changes: 254 additions & 22 deletions src/snowflake/snowpark/_functions/scalar_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from snowflake.snowpark._functions.general_functions import (
builtin,
lit,
)


Expand Down Expand Up @@ -321,9 +322,6 @@ def getdate(_emit_ast: bool = True) -> Column:
"""
Returns the current timestamp for the system in the local time zone.

Args:
_emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True.

Returns:
A :class:`~snowflake.snowpark.Column` with the current date and time.

Expand All @@ -342,9 +340,6 @@ def localtime(_emit_ast: bool = True) -> Column:
"""
Returns the current time for the system.

Args:
_emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True.

Returns:
A :class:`~snowflake.snowpark.Column` with the current local time.

Expand All @@ -362,9 +357,6 @@ def systimestamp(_emit_ast: bool = True) -> Column:
"""
Returns the current timestamp for the system.

Args:
_emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True.

Returns:
A :class:`~snowflake.snowpark.Column` with the current system timestamp.

Expand All @@ -383,9 +375,6 @@ def invoker_role(_emit_ast: bool = True) -> Column:
"""
Returns the name of the role that was active when the current stored procedure or user-defined function was called.

Args:
_emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True.

Returns:
Column: A Snowflake `Column` object representing the name of the active role.

Expand All @@ -406,10 +395,6 @@ def invoker_share(_emit_ast: bool = True) -> Column:
Returns the name of the share that directly accessed the table or view where the INVOKER_SHARE
function is invoked, otherwise the function returns None.

Args:
_emit_ast (bool, optional): A flag indicating whether to emit the abstract
syntax tree (AST). Defaults to True.

Returns:
Column: A Snowflake `Column` object representing the name of the active share.

Expand All @@ -428,7 +413,6 @@ def is_application_role_in_session(role_name: str, _emit_ast: bool = True) -> Co

Args:
role_name (str): The name of the application role to check.
_emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True.

Returns:
A :class:`~snowflake.snowpark.Column` indicating whether the specified application role is active in the current session.
Expand All @@ -452,7 +436,6 @@ def is_database_role_in_session(

Args:
role_name (ColumnOrName): The name of the database role to check. Can be a string or a Column.
_emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True.

Returns:
Column: A Snowflake `Column` object representing the result of the check.
Expand All @@ -476,7 +459,6 @@ def is_granted_to_invoker_role(role_name: str, _emit_ast: bool = True) -> Column

Args:
role_name (str): The name of the role to check.
_emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True.

Returns:
Column: A Snowflake `Column` object representing the result of the check.
Expand All @@ -498,7 +480,6 @@ def is_role_in_session(role: ColumnOrName, _emit_ast: bool = True) -> Column:

Args:
role (ColumnOrName): A Column or column name containing the role name to check.
_emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True.

Returns:
Column: A Snowflake `Column` object representing the result of the check.
Expand All @@ -520,8 +501,6 @@ def getvariable(name: str, _emit_ast: bool = True) -> Column:

Args:
name (str): The name of the session variable to retrieve.
_emit_ast (bool, optional): A flag indicating whether to emit the abstract syntax tree (AST).
Defaults to True.

Returns:
Column: A Snowflake `Column` object representing the value of the specified session variable.
Expand All @@ -531,3 +510,256 @@ def getvariable(name: str, _emit_ast: bool = True) -> Column:
>>> assert result[0]["RESULT"] is None
"""
return builtin("getvariable", _emit_ast=_emit_ast)(name)


@publicapi
def h3_cell_to_boundary(cell_id: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Computes the boundary of each H3 cell, expressed as a GeoJSON polygon.

    Args:
        cell_id (ColumnOrName): A Column or column name holding the H3 cell IDs.

    Returns:
        Column: The GeoJSON polygon string describing the boundary of each H3 cell.

    Example::
        >>> df = session.create_dataframe([613036919424548863, 577023702256844799], schema=["cell_id"])
        >>> result = df.select(h3_cell_to_boundary(df["cell_id"]).alias("boundary")).collect()
        >>> len(result) == 2
        True
    """
    func_name = "h3_cell_to_boundary"
    # Resolve a plain column name into a Column before building the call.
    cell_col = _to_col_if_str(cell_id, func_name)
    return builtin(func_name, _emit_ast=_emit_ast)(cell_col)


@publicapi
def h3_cell_to_parent(
    cell_id: ColumnOrName, target_resolution: ColumnOrName, _emit_ast: bool = True
) -> Column:
    """
    Looks up the parent of each H3 cell at the requested target resolution.

    Args:
        cell_id (ColumnOrName): A Column or column name holding the H3 cell IDs.
        target_resolution (ColumnOrName): A Column or column name holding the
            target resolution levels.

    Returns:
        Column: The parent H3 cell at the target resolution.

    Example::
        >>> from snowflake.snowpark.functions import col
        >>> df = session.create_dataframe([[613036919424548863, 7], [608533319805566975, 6]], schema=["cell_id", "target_resolution"])
        >>> df.select(h3_cell_to_parent(col("cell_id"), col("target_resolution")).alias("parent_cell")).collect()
        [Row(PARENT_CELL=608533319805566975), Row(PARENT_CELL=604029720295636991)]
    """
    func_name = "h3_cell_to_parent"
    # Convert both arguments in declaration order: cell first, then resolution.
    call_args = [
        _to_col_if_str(arg, func_name) for arg in (cell_id, target_resolution)
    ]
    return builtin(func_name, _emit_ast=_emit_ast)(*call_args)


@publicapi
def h3_cell_to_point(cell_id: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Computes the center point of each H3 cell as a GeoJSON Point object.

    Args:
        cell_id (ColumnOrName): A Column or column name holding the H3 cell IDs.

    Returns:
        Column: GeoJSON Point objects for the center points of the H3 cells.

    Example::
        >>> import json
        >>> df = session.create_dataframe([613036919424548863], schema=["cell_id"])
        >>> result = df.select(h3_cell_to_point(df["cell_id"]).alias("POINT")).collect()
        >>> assert len(result) == 1
        >>> point_json = json.loads(result[0]["POINT"])
        >>> assert point_json["type"] == "Point"
        >>> assert "coordinates" in point_json
        >>> assert len(point_json["coordinates"]) == 2
    """
    func_name = "h3_cell_to_point"
    cell_col = _to_col_if_str(cell_id, func_name)
    return builtin(func_name, _emit_ast=_emit_ast)(cell_col)


@publicapi
def h3_compact_cells(array_of_cell_ids: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Compacts an array of H3 cell IDs, merging cells of the same resolution into
    their parent cells when possible.

    Args:
        array_of_cell_ids (ColumnOrName): A Column or column name holding the
            array of H3 cell IDs to compact.

    Returns:
        Column: An array of compacted H3 cell IDs.

    Example::
        >>> df = session.create_dataframe([
        ...     [[622236750562230271, 622236750562263039, 622236750562295807, 622236750562328575, 622236750562361343, 622236750562394111, 622236750562426879, 622236750558396415]]
        ... ], schema=["cell_ids"])
        >>> df.select(h3_compact_cells(df["cell_ids"]).alias("compacted")).collect()
        [Row(COMPACTED='[\\n 622236750558396415,\\n 617733150935089151\\n]')]
    """
    func_name = "h3_compact_cells"
    cells_col = _to_col_if_str(array_of_cell_ids, func_name)
    return builtin(func_name, _emit_ast=_emit_ast)(cells_col)


@publicapi
def h3_compact_cells_strings(
    array_of_cell_ids: ColumnOrName, _emit_ast: bool = True
) -> Column:
    """
    Compacts an array of H3 cell ID strings by dropping redundant cells that are
    already covered by their parent cells at coarser resolutions.

    Args:
        array_of_cell_ids (ColumnOrName): A Column or column name holding the
            array of H3 cell ID strings to compact.

    Returns:
        Column: The compacted array of H3 cell ID strings.

    Example::

        >>> df = session.create_dataframe([[
        ...     ['8a2a10705507fff', '8a2a1070550ffff', '8a2a10705517fff', '8a2a1070551ffff',
        ...      '8a2a10705527fff', '8a2a1070552ffff', '8a2a10705537fff', '8a2a10705cdffff']
        ... ]], schema=["cell_ids"])
        >>> df.select(h3_compact_cells_strings("cell_ids").alias("compacted")).collect()
        [Row(COMPACTED='[\\n "8a2a10705cdffff",\\n "892a1070553ffff"\\n]')]
    """
    func_name = "h3_compact_cells_strings"
    cells_col = _to_col_if_str(array_of_cell_ids, func_name)
    return builtin(func_name, _emit_ast=_emit_ast)(cells_col)


@publicapi
def h3_coverage(
    geography_expression: ColumnOrName,
    target_resolution: ColumnOrName,
    _emit_ast: bool = True,
) -> Column:
    """
    Returns an array of H3 cell IDs that cover the given geography at the specified resolution.

    Args:
        geography_expression (ColumnOrName): A Column or column name holding a GEOGRAPHY object.
        target_resolution (ColumnOrName): A Column or column name holding the
            target H3 resolution (0-15). To pass a constant, wrap it with ``lit``
            as in the example below.

    Returns:
        Column: An array of H3 cell IDs as integers (see the example output); use
        ``h3_coverage_strings`` for the string form of the identifiers.

    Example::
        >>> from snowflake.snowpark.functions import to_geography, lit
        >>> df = session.create_dataframe([
        ...     ["POLYGON((-122.481889 37.826683,-122.479487 37.808548,-122.474150 37.808904,-122.476510 37.826935,-122.481889 37.826683))"]
        ... ], schema=["polygon_wkt"])
        >>> result = df.select(h3_coverage(to_geography(df["polygon_wkt"]), lit(8)).alias("h3_cells")).collect()
        >>> result
        [Row(H3_CELLS='[\\n 613196571539931135,\\n 613196571542028287,\\n 613196571548319743,\\n 613196571550416895,\\n 613196571560902655,\\n 613196571598651391\\n]')]
    """
    geography_col = _to_col_if_str(geography_expression, "h3_coverage")
    resolution_col = _to_col_if_str(target_resolution, "h3_coverage")
    return builtin("h3_coverage", _emit_ast=_emit_ast)(geography_col, resolution_col)


@publicapi
def h3_coverage_strings(
    geography_expression: ColumnOrName,
    target_resolution: Union[ColumnOrName, int],
    _emit_ast: bool = True,
) -> Column:
    """
    Returns an array of H3 cell identifiers as strings that cover the given geography at the specified resolution.

    Args:
        geography_expression (ColumnOrName): The GEOGRAPHY to cover.
        target_resolution (ColumnOrName, int): The H3 resolution level (0-15).
            A Column is used as is, a ``str`` is treated as a column name, and
            any other value (e.g. an ``int``) is wrapped as a literal.

    Returns:
        Column: An array of H3 cell identifiers as strings.

    Example::

        >>> from snowflake.snowpark.functions import to_geography
        >>> df = session.create_dataframe([
        ...     "POLYGON((-122.481889 37.826683,-122.479487 37.808548,-122.474150 37.808904,-122.476510 37.826935,-122.481889 37.826683))"
        ... ], schema=["geo_wkt"])
        >>> df.select(h3_coverage_strings(to_geography(df["geo_wkt"]), 8).alias("h3_cells")).collect()
        [Row(H3_CELLS='[\\n "8828308701fffff",\\n "8828308703fffff",\\n "8828308709fffff",\\n "882830870bfffff",\\n "8828308715fffff",\\n "8828308739fffff"\\n]')]
    """
    geo_col = _to_col_if_str(geography_expression, "h3_coverage_strings")
    if isinstance(target_resolution, Column):
        res_col = target_resolution
    elif isinstance(target_resolution, str):
        # The signature advertises ColumnOrName, so a str names a column (as in
        # h3_coverage); it must not be sent to the server as a string literal.
        res_col = _to_col_if_str(target_resolution, "h3_coverage_strings")
    else:
        # Plain Python values such as an int resolution become SQL literals.
        res_col = lit(target_resolution)
    return builtin("h3_coverage_strings", _emit_ast=_emit_ast)(geo_col, res_col)


@publicapi
def h3_get_resolution(cell_id: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Extracts the resolution of each H3 cell ID.

    Args:
        cell_id (ColumnOrName): A Column or column name holding the H3 cell ID.

    Returns:
        Column: The resolution of the H3 cell ID.

    Example::

        >>> df = session.create_dataframe([617540519050084351, 617540519050084352], schema=["cell_id"])
        >>> df.select(h3_get_resolution(df["cell_id"]).alias("resolution")).collect()
        [Row(RESOLUTION=9), Row(RESOLUTION=9)]
    """
    func_name = "h3_get_resolution"
    cell_col = _to_col_if_str(cell_id, func_name)
    return builtin(func_name, _emit_ast=_emit_ast)(cell_col)


@publicapi
def h3_grid_disk(
    cell_id: ColumnOrName, k_value: ColumnOrName, _emit_ast: bool = True
) -> Column:
    """
    Collects the H3 cell IDs lying within k distance of the origin cell.

    Args:
        cell_id (ColumnOrName): A Column or column name holding the H3 cell ID
            used as the center of the disk.
        k_value (ColumnOrName): A Column or column name holding the distance
            (number of rings) from the center cell.

    Returns:
        Column: An array of H3 cell IDs within the specified distance.

    Example::

        >>> df = session.create_dataframe([[617540519050084351, 1], [617540519050084351, 2]], schema=["cell_id", "k_value"])
        >>> df.select(h3_grid_disk("cell_id", "k_value").alias("grid_disk")).collect()
        [Row(GRID_DISK='[\\n 617540519050084351,\\n 617540519051657215,\\n 617540519050608639,\\n 617540519050870783,\\n 617540519050346495,\\n 617540519051395071,\\n 617540519051132927\\n]'), Row(GRID_DISK='[\\n 617540519050084351,\\n 617540519051657215,\\n 617540519050608639,\\n 617540519050870783,\\n 617540519050346495,\\n 617540519051395071,\\n 617540519051132927,\\n 617540519048249343,\\n 617540519048773631,\\n 617540519089143807,\\n 617540519088095231,\\n 617540519107756031,\\n 617540519108018175,\\n 617540519104086015,\\n 617540519103561727,\\n 617540519046414335,\\n 617540519047462911,\\n 617540519044579327,\\n 617540519044317183\\n]')]
    """
    func_name = "h3_grid_disk"
    # Convert in declaration order: center cell first, then the ring count.
    center_col = _to_col_if_str(cell_id, func_name)
    rings_col = _to_col_if_str(k_value, func_name)
    return builtin(func_name, _emit_ast=_emit_ast)(center_col, rings_col)


@publicapi
def h3_grid_distance(
    cell_id_1: ColumnOrName, cell_id_2: ColumnOrName, _emit_ast: bool = True
) -> Column:
    """
    Computes the grid distance between two H3 cell IDs.

    Args:
        cell_id_1 (ColumnOrName): A Column or column name holding the first H3 cell ID.
        cell_id_2 (ColumnOrName): A Column or column name holding the second H3 cell ID.

    Returns:
        Column: The grid distance between the two H3 cells.

    Example::
        >>> df = session.create_dataframe([[617540519103561727, 617540519052967935]], schema=["cell_id_1", "cell_id_2"])
        >>> df.select(h3_grid_distance(df["cell_id_1"], df["cell_id_2"]).alias("distance")).collect()
        [Row(DISTANCE=5)]
    """
    func_name = "h3_grid_distance"
    # Use fresh locals rather than rebinding the parameters.
    first_cell = _to_col_if_str(cell_id_1, func_name)
    second_cell = _to_col_if_str(cell_id_2, func_name)
    return builtin(func_name, _emit_ast=_emit_ast)(first_cell, second_cell)