diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a3ed0f16b..e439f8473a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Release History +## 1.41.0 (YYYY-MM-DD) + +### Snowpark Python API Updates + +#### New Features + +- Added a new function `service` in `snowflake.snowpark.functions` that allows users to create a callable representing a Snowpark Container Services (SPCS) service. + ## 1.40.0 (YYYY-MM-DD) ### Snowpark Python API Updates diff --git a/docs/source/snowpark/functions.rst b/docs/source/snowpark/functions.rst index baaf9beefe..b49e001954 100644 --- a/docs/source/snowpark/functions.rst +++ b/docs/source/snowpark/functions.rst @@ -414,6 +414,7 @@ Functions seq4 seq8 sequence + service sha1 sha2 sin diff --git a/src/conftest.py b/src/conftest.py index cb859391c1..acc91767af 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -97,6 +97,7 @@ def pytest_collection_modifyitems(config, items): disabled_doctests = [ "ai_classify", "model", + "service", ] # Add any test names that should be skipped for item in items: # identify doctest items diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 6749972980..5700ae2fc9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -28,6 +28,7 @@ like_expression, list_agg, model_expression, + service_expression, named_arguments_function, order_expression, range_statement, @@ -74,6 +75,7 @@ ListAgg, Literal, ModelExpression, + ServiceExpression, MultipleExpression, NamedExpression, NamedFunctionExpression, @@ -430,6 +432,16 @@ def analyze( ], ) + if isinstance(expr, ServiceExpression): + return service_expression( + expr.service_name, + expr.method_name, + [ + self.to_sql_try_avoid_cast(c, df_aliased_col_name_to_real_col_name) + for c in expr.children + ], + ) + if isinstance(expr, FunctionExpression): if expr.api_call_source is not None: self.session._conn._telemetry_client.send_function_usage_telemetry( diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 92f9ab2344..3b84ad61ee 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -278,6 +278,14 @@ def model_expression( return f"{MODEL}{LEFT_PARENTHESIS}{model_args_str}{RIGHT_PARENTHESIS}{EXCLAMATION_MARK}{method_name}{LEFT_PARENTHESIS}{COMMA.join(children)}{RIGHT_PARENTHESIS}" +def service_expression( + service_name: str, + method_name: str, + children: List[str], +) -> str: + return f"{service_name}{EXCLAMATION_MARK}{method_name}{LEFT_PARENTHESIS}{COMMA.join(children)}{RIGHT_PARENTHESIS}" + + def function_expression(name: str, children: List[str], is_distinct: bool) -> str: return ( name diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 05e8f9e672..d95dcdc95a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -575,6 +575,29 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION +class ServiceExpression(Expression): + def __init__( + self, + service_name: str, + method_name: str, + arguments: List[Expression], + ) -> None: + super().__init__() + self.service_name = service_name + self.method_name = method_name + self.children = arguments + + def dependent_column_names(self) -> Optional[AbstractSet[str]]: + return derive_dependent_columns(*self.children) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.FUNCTION + + class FunctionExpression(Expression): def __init__( self, diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index b5a0f33ba4..236d5c0ff8 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -181,6 +181,7 @@ ListAgg, Literal, ModelExpression, + ServiceExpression, MultipleExpression, Star, NamedFunctionExpression, @@ -10746,6 +10747,30 @@ def _call_model( ) +def _call_service( + service_name: str, + method_name: str, + *args, + _emit_ast: bool = True, +) -> Column: + if _emit_ast: + _ast = build_function_expr("service", [service_name, method_name, *args]) + else: + _ast = None + + args_list = parse_positional_args_to_list(*args) + expressions = [Column._to_expr(arg) for arg in args_list] + return Column( + ServiceExpression( + service_name, + method_name, + expressions, + ), + _ast=_ast, + _emit_ast=_emit_ast, + ) + + @publicapi def model( model_name: str, @@ -10775,6 +10800,44 @@ def model( ) +@publicapi +def service( + service_name: str, + _emit_ast: bool = True, +) -> Callable: + """ + Creates a service function that can be used to call a service method. + + Args: + service_name: The name of the service to call. + + Example:: + + >>> service_instance = service("TESTSCHEMA_SNOWPARK_PYTHON.FORECAST_MODEL_SERVICE") + >>> # Prepare a DataFrame with the ten expected features + >>> df = session.create_dataframe( + ... [ + ... (0.038076, 0.050680, 0.061696, 0.021872, -0.044223, -0.034821, -0.043401, -0.002592, 0.019907, -0.017646), + ... ], + ... schema=["age", "sex", "bmi", "bp", "s1", "s2", "s3", "s4", "s5", "s6"], + ... ) + >>> # Invoke the model's predict method exposed by the service + >>> result_df = df.select( + ... service_instance("predict")(col("age"), col("sex"), col("bmi"), col("bp"), col("s1"), col("s2"), col("s3"), col("s4"), col("s5"), col("s6"))["output_feature_0"] + ... ) + >>> result_df.show() + ------------------------------------------------------ + |"TESTSCHEMA_SNOWPARK_PYTHON.FORECAST_MODEL_SERV... | + ------------------------------------------------------ + |220.2223358154297 | + ------------------------------------------------------ + + """ + return lambda method_name: lambda *args: _call_service( + service_name, method_name, *args, _emit_ast=_emit_ast + ) + + # Add these alias for user code migration call_builtin = call_function collect_set = array_unique_agg