diff --git a/CHANGELOG.md b/CHANGELOG.md index 0aa47071de..2191aecac2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,18 @@ - `get_username_password` - `get_cloud_provider_token` +- Added support for the following scalar functions in `functions.py`: + - `h3_cell_to_boundary` + - `h3_cell_to_parent` + - `h3_cell_to_point` + - `h3_compact_cells` + - `h3_compact_cells_strings` + - `h3_coverage` + - `h3_coverage_strings` + - `h3_get_resolution` + - `h3_grid_disk` + - `h3_grid_distance` + ### Snowpark pandas API Updates #### New Features diff --git a/docs/source/snowpark/functions.rst b/docs/source/snowpark/functions.rst index 21da5f49db..9ae84e13f8 100644 --- a/docs/source/snowpark/functions.rst +++ b/docs/source/snowpark/functions.rst @@ -229,6 +229,16 @@ Functions hex hex_encode hour + h3_cell_to_boundary + h3_cell_to_parent + h3_cell_to_point + h3_compact_cells + h3_compact_cells_strings + h3_coverage + h3_coverage_strings + h3_get_resolution + h3_grid_disk + h3_grid_distance iff ifnull in_ diff --git a/src/snowflake/snowpark/_functions/scalar_functions.py b/src/snowflake/snowpark/_functions/scalar_functions.py index b8945328d7..924feccf95 100644 --- a/src/snowflake/snowpark/_functions/scalar_functions.py +++ b/src/snowflake/snowpark/_functions/scalar_functions.py @@ -13,6 +13,7 @@ ) from snowflake.snowpark._functions.general_functions import ( builtin, + lit, ) @@ -321,9 +322,6 @@ def getdate(_emit_ast: bool = True) -> Column: """ Returns the current timestamp for the system in the local time zone. - Args: - _emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True. - Returns: A :class:`~snowflake.snowpark.Column` with the current date and time. @@ -342,9 +340,6 @@ def localtime(_emit_ast: bool = True) -> Column: """ Returns the current time for the system. - Args: - _emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True. - Returns: A :class:`~snowflake.snowpark.Column` with the current local time. @@ -362,9 +357,6 @@ def systimestamp(_emit_ast: bool = True) -> Column: """ Returns the current timestamp for the system. - Args: - _emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True. - Returns: A :class:`~snowflake.snowpark.Column` with the current system timestamp. @@ -383,9 +375,6 @@ def invoker_role(_emit_ast: bool = True) -> Column: """ Returns the name of the role that was active when the current stored procedure or user-defined function was called. - Args: - _emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True. - Returns: Column: A Snowflake `Column` object representing the name of the active role. @@ -406,10 +395,6 @@ def invoker_share(_emit_ast: bool = True) -> Column: Returns the name of the share that directly accessed the table or view where the INVOKER_SHARE function is invoked, otherwise the function returns None. - Args: - _emit_ast (bool, optional): A flag indicating whether to emit the abstract - syntax tree (AST). Defaults to True. - Returns: Column: A Snowflake `Column` object representing the name of the active share. @@ -428,7 +413,6 @@ def is_application_role_in_session(role_name: str, _emit_ast: bool = True) -> Co Args: role_name (str): The name of the application role to check. - _emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True. Returns: A :class:`~snowflake.snowpark.Column` indicating whether the specified application role is active in the current session. @@ -452,7 +436,6 @@ def is_database_role_in_session( Args: role_name (ColumnOrName): The name of the database role to check. Can be a string or a Column. - _emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True. Returns: Column: A Snowflake `Column` object representing the result of the check. @@ -476,7 +459,6 @@ def is_granted_to_invoker_role(role_name: str, _emit_ast: bool = True) -> Column Args: role_name (str): The name of the role to check. - _emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True. Returns: Column: A Snowflake `Column` object representing the result of the check. @@ -498,7 +480,6 @@ def is_role_in_session(role: ColumnOrName, _emit_ast: bool = True) -> Column: Args: role (ColumnOrName): A Column or column name containing the role name to check. - _emit_ast (bool, optional): Whether to emit the abstract syntax tree (AST). Defaults to True. Returns: Column: A Snowflake `Column` object representing the result of the check. @@ -520,8 +501,6 @@ def getvariable(name: str, _emit_ast: bool = True) -> Column: Args: name (str): The name of the session variable to retrieve. - _emit_ast (bool, optional): A flag indicating whether to emit the abstract syntax tree (AST). - Defaults to True. Returns: Column: A Snowflake `Column` object representing the value of the specified session variable. @@ -531,3 +510,256 @@ def getvariable(name: str, _emit_ast: bool = True) -> Column: >>> assert result[0]["RESULT"] is None """ return builtin("getvariable", _emit_ast=_emit_ast)(name) + + +@publicapi +def h3_cell_to_boundary(cell_id: ColumnOrName, _emit_ast: bool = True) -> Column: + """Returns the boundary of an H3 cell as a GeoJSON polygon. + + Args: + cell_id (ColumnOrName): The H3 cell IDs. + + Returns: + Column: The boundary of the H3 cell as a GeoJSON polygon string. + + Example:: + >>> df = session.create_dataframe([613036919424548863, 577023702256844799], schema=["cell_id"]) + >>> result = df.select(h3_cell_to_boundary(df["cell_id"]).alias("boundary")).collect() + >>> len(result) == 2 + True + """ + c = _to_col_if_str(cell_id, "h3_cell_to_boundary") + return builtin("h3_cell_to_boundary", _emit_ast=_emit_ast)(c) + + +@publicapi +def h3_cell_to_parent( + cell_id: ColumnOrName, target_resolution: ColumnOrName, _emit_ast: bool = True +) -> Column: + """Returns the parent H3 cell at the specified target resolution. + + Args: + cell_id (ColumnOrName): The H3 cell IDs. + target_resolution (ColumnOrName) : The target resolution levels. + + Returns: + Column: The parent H3 cell at the target resolution. + + Example:: + >>> from snowflake.snowpark.functions import col + >>> df = session.create_dataframe([[613036919424548863, 7], [608533319805566975, 6]], schema=["cell_id", "target_resolution"]) + >>> df.select(h3_cell_to_parent(col("cell_id"), col("target_resolution")).alias("parent_cell")).collect() + [Row(PARENT_CELL=608533319805566975), Row(PARENT_CELL=604029720295636991)] + """ + cell_id_c = _to_col_if_str(cell_id, "h3_cell_to_parent") + target_resolution_c = _to_col_if_str(target_resolution, "h3_cell_to_parent") + return builtin("h3_cell_to_parent", _emit_ast=_emit_ast)( + cell_id_c, target_resolution_c + ) + + +@publicapi +def h3_cell_to_point(cell_id: ColumnOrName, _emit_ast: bool = True) -> Column: + """ + Returns the center point of an H3 cell as a GeoJSON Point object. + + Args: + cell_id (ColumnOrName): The H3 cell IDs. + + Returns: + Column: GeoJSON Point objects representing the center points of the H3 cells. + + Example:: + >>> import json + >>> df = session.create_dataframe([613036919424548863], schema=["cell_id"]) + >>> result = df.select(h3_cell_to_point(df["cell_id"]).alias("POINT")).collect() + >>> assert len(result) == 1 + >>> point_json = json.loads(result[0]["POINT"]) + >>> assert point_json["type"] == "Point" + >>> assert "coordinates" in point_json + >>> assert len(point_json["coordinates"]) == 2 + """ + c = _to_col_if_str(cell_id, "h3_cell_to_point") + return builtin("h3_cell_to_point", _emit_ast=_emit_ast)(c) + + +@publicapi +def h3_compact_cells(array_of_cell_ids: ColumnOrName, _emit_ast: bool = True) -> Column: + """ + Returns a compacted array of H3 cell IDs by merging cells at the same resolution into their parent cells when possible. + + Args: + array_of_cell_ids (ColumnOrName): An array of H3 cell IDs to be compacted. + + Returns: + Column: An array of compacted H3 cell IDs. + + Example:: + >>> df = session.create_dataframe([ + ... [[622236750562230271, 622236750562263039, 622236750562295807, 622236750562328575, 622236750562361343, 622236750562394111, 622236750562426879, 622236750558396415]] + ... ], schema=["cell_ids"]) + >>> df.select(h3_compact_cells(df["cell_ids"]).alias("compacted")).collect() + [Row(COMPACTED='[\\n 622236750558396415,\\n 617733150935089151\\n]')] + """ + c = _to_col_if_str(array_of_cell_ids, "h3_compact_cells") + return builtin("h3_compact_cells", _emit_ast=_emit_ast)(c) + + +@publicapi +def h3_compact_cells_strings( + array_of_cell_ids: ColumnOrName, _emit_ast: bool = True +) -> Column: + """ + Returns a compacted array of H3 cell IDs by removing redundant cells that are covered by their parent cells at coarser resolutions. + + Args: + array_of_cell_ids (ColumnOrName): An array of H3 cell ID strings to be compacted. + + Returns: + Column: The compacted array of H3 cell ID strings. + + Example:: + + >>> df = session.create_dataframe([[ + ... ['8a2a10705507fff', '8a2a1070550ffff', '8a2a10705517fff', '8a2a1070551ffff', + ... '8a2a10705527fff', '8a2a1070552ffff', '8a2a10705537fff', '8a2a10705cdffff'] + ... ]], schema=["cell_ids"]) + >>> df.select(h3_compact_cells_strings("cell_ids").alias("compacted")).collect() + [Row(COMPACTED='[\\n "8a2a10705cdffff",\\n "892a1070553ffff"\\n]')] + """ + c = _to_col_if_str(array_of_cell_ids, "h3_compact_cells_strings") + return builtin("h3_compact_cells_strings", _emit_ast=_emit_ast)(c) + + +@publicapi +def h3_coverage( + geography_expression: ColumnOrName, + target_resolution: ColumnOrName, + _emit_ast: bool = True, +) -> Column: + """ + Returns an array of H3 cell IDs that cover the given geography at the specified resolution. + + Args: + geography_expression (ColumnOrName) : A GEOGRAPHY object. + target_resolution (ColumnOrName) : The target H3 resolution (0-15). + + Returns: + Column: An array of H3 cell IDs as strings. + + Example:: + >>> from snowflake.snowpark.functions import to_geography, lit + >>> df = session.create_dataframe([ + ... ["POLYGON((-122.481889 37.826683,-122.479487 37.808548,-122.474150 37.808904,-122.476510 37.826935,-122.481889 37.826683))"] + ... ], schema=["polygon_wkt"]) + >>> result = df.select(h3_coverage(to_geography(df["polygon_wkt"]), lit(8)).alias("h3_cells")).collect() + >>> result + [Row(H3_CELLS='[\\n 613196571539931135,\\n 613196571542028287,\\n 613196571548319743,\\n 613196571550416895,\\n 613196571560902655,\\n 613196571598651391\\n]')] + """ + geography_col = _to_col_if_str(geography_expression, "h3_coverage") + resolution_col = _to_col_if_str(target_resolution, "h3_coverage") + return builtin("h3_coverage", _emit_ast=_emit_ast)(geography_col, resolution_col) + + +@publicapi +def h3_coverage_strings( + geography_expression: ColumnOrName, + target_resolution: Union[ColumnOrName, int], + _emit_ast: bool = True, +) -> Column: + """ + Returns an array of H3 cell identifiers as strings that cover the given geography at the specified resolution. + + Args: + geography_expression (ColumnOrName): The GEOGRAPHY to cover. + target_resolution (ColumnOrName, int): The H3 resolution level (0-15). + + Returns: + Column: An array of H3 cell identifiers as strings. + + Example:: + + >>> from snowflake.snowpark.functions import to_geography + >>> df = session.create_dataframe([ + ... "POLYGON((-122.481889 37.826683,-122.479487 37.808548,-122.474150 37.808904,-122.476510 37.826935,-122.481889 37.826683))" + ... ], schema=["geo_wkt"]) + >>> df.select(h3_coverage_strings(to_geography(df["geo_wkt"]), 8).alias("h3_cells")).collect() + [Row(H3_CELLS='[\\n "8828308701fffff",\\n "8828308703fffff",\\n "8828308709fffff",\\n "882830870bfffff",\\n "8828308715fffff",\\n "8828308739fffff"\\n]')] + """ + geo_col = _to_col_if_str(geography_expression, "h3_coverage_strings") + res_col = ( + target_resolution + if isinstance(target_resolution, Column) + else lit(target_resolution) + ) + return builtin("h3_coverage_strings", _emit_ast=_emit_ast)(geo_col, res_col) + + +@publicapi +def h3_get_resolution(cell_id: ColumnOrName, _emit_ast: bool = True) -> Column: + """ + Returns the resolution of an H3 cell ID. + + Args: + cell_id (ColumnOrName): The H3 cell ID. + + Returns: + Column: The resolution of the H3 cell ID. + + Example:: + + >>> df = session.create_dataframe([617540519050084351, 617540519050084352], schema=["cell_id"]) + >>> df.select(h3_get_resolution(df["cell_id"]).alias("resolution")).collect() + [Row(RESOLUTION=9), Row(RESOLUTION=9)] + """ + c = _to_col_if_str(cell_id, "h3_get_resolution") + return builtin("h3_get_resolution", _emit_ast=_emit_ast)(c) + + +@publicapi +def h3_grid_disk( + cell_id: ColumnOrName, k_value: ColumnOrName, _emit_ast: bool = True +) -> Column: + """ + Returns an array of H3 cell IDs within k distance of the origin cell. + + Args: + cell_id (ColumnOrName): The H3 cell ID as the center of the disk. + k_value (ColumnOrName): The distance (number of rings) from the center cell. + + Returns: + Column: An array of H3 cell IDs within the specified distance. + + Example:: + + >>> df = session.create_dataframe([[617540519050084351, 1], [617540519050084351, 2]], schema=["cell_id", "k_value"]) + >>> df.select(h3_grid_disk("cell_id", "k_value").alias("grid_disk")).collect() + [Row(GRID_DISK='[\\n 617540519050084351,\\n 617540519051657215,\\n 617540519050608639,\\n 617540519050870783,\\n 617540519050346495,\\n 617540519051395071,\\n 617540519051132927\\n]'), Row(GRID_DISK='[\\n 617540519050084351,\\n 617540519051657215,\\n 617540519050608639,\\n 617540519050870783,\\n 617540519050346495,\\n 617540519051395071,\\n 617540519051132927,\\n 617540519048249343,\\n 617540519048773631,\\n 617540519089143807,\\n 617540519088095231,\\n 617540519107756031,\\n 617540519108018175,\\n 617540519104086015,\\n 617540519103561727,\\n 617540519046414335,\\n 617540519047462911,\\n 617540519044579327,\\n 617540519044317183\\n]')] + """ + cell_id_col = _to_col_if_str(cell_id, "h3_grid_disk") + k_value_col = _to_col_if_str(k_value, "h3_grid_disk") + return builtin("h3_grid_disk", _emit_ast=_emit_ast)(cell_id_col, k_value_col) + + +@publicapi +def h3_grid_distance( + cell_id_1: ColumnOrName, cell_id_2: ColumnOrName, _emit_ast: bool = True +) -> Column: + """ + Returns the grid distance between two H3 cell IDs. + + Args: + cell_id_1 (ColumnOrName): The first H3 cell ID column or value. + cell_id_2 (ColumnOrName): The second H3 cell ID column or value. + + Returns: + Column: The grid distance between the two H3 cells. + + Example:: + >>> df = session.create_dataframe([[617540519103561727, 617540519052967935]], schema=["cell_id_1", "cell_id_2"]) + >>> df.select(h3_grid_distance(df["cell_id_1"], df["cell_id_2"]).alias("distance")).collect() + [Row(DISTANCE=5)] + """ + cell_id_1 = _to_col_if_str(cell_id_1, "h3_grid_distance") + cell_id_2 = _to_col_if_str(cell_id_2, "h3_grid_distance") + return builtin("h3_grid_distance", _emit_ast=_emit_ast)(cell_id_1, cell_id_2)