diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 612055329c..843d717238 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -12,6 +12,7 @@ import logging import re from collections import Counter, defaultdict +from dataclasses import dataclass, field import typing import uuid from collections.abc import Hashable, Iterable, Mapping, Sequence @@ -23,6 +24,7 @@ Callable, List, Literal, + NamedTuple, Optional, TypeVar, Union, @@ -32,6 +34,7 @@ ) import modin.pandas as pd +from modin.config import AutoSwitchBackend from modin.pandas import Series, DataFrame from modin.pandas.base import BasePandasDataset import numpy as np @@ -499,9 +502,116 @@ HYBRID_ALL_EXPENSIVE_METHODS = ( HYBRID_HIGH_OVERHEAD_METHODS + HYBRID_ITERATIVE_STYLE_METHODS ) -# Set of (class name, method name) tuples for methods that are wholly unimplemented by + + +class MethodKey(NamedTuple): + """Named tuple for method registry keys.""" + + api_cls_name: Optional[str] + method_name: str + + +@dataclass +class UnsupportedKwargsRule: + """Rule defining which kwargs should trigger auto-switching.""" + + # List of conditions that can be either: + # - tuple[Callable, str]: (condition_function, reason) for complex conditions + # - tuple[str, Any]: (argument_name, unsupported_value) for simple value checks + unsupported_conditions: List[Union[Tuple[Callable, str], Tuple[str, Any]]] = field( + default_factory=list + ) + + +# Set of MethodKey objects for methods that are wholly unimplemented by # Snowpark pandas. This list is populated by the register_*_not_implemented decorators. -HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS: Set[Tuple[str, str]] = set() +HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS: Set[MethodKey] = set() + +# Global registry for kwargs-based switching rules +# Format: {MethodKey(class_name, method_name): UnsupportedKwargsRule} +HYBRID_SWITCH_FOR_UNSUPPORTED_PARAMS: dict[MethodKey, UnsupportedKwargsRule] = {} + +# POC: concat method rule will be registered via enhanced decorator in general_overrides.py + + +def register_query_compiler_method_not_implemented( + api_cls_name: str, + method_name: str, + unsupported_kwargs: Optional["UnsupportedKwargsRule"] = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Decorator for SnowflakeQueryCompiler methods with kwargs-based auto-switching. + + Registers pre-op switching and (when auto-switching is disabled) raises a + NotImplementedError if any unsupported-parameter predicate evaluates True. + + Args: + api_cls_name: Frontend class name (e.g., "BasePandasDataset", "Series", "DataFrame", "None"). + method_name: Method name to register. + unsupported_kwargs: UnsupportedKwargsRule for kwargs-based auto-switching. + If None, method is treated as completely unimplemented. + """ + reg_key = MethodKey(api_cls_name, method_name) + + # register the method in the hybrid switch for unsupported params + if unsupported_kwargs is None: + HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(reg_key) + else: + HYBRID_SWITCH_FOR_UNSUPPORTED_PARAMS[reg_key] = unsupported_kwargs + + # Local import to avoid import cycles (as in original) + from modin.core.storage_formats.pandas.query_compiler_caster import ( + register_function_for_pre_op_switch, + ) + + register_function_for_pre_op_switch( + class_name=api_cls_name, + backend="Snowflake", + method=method_name, + arg_based=unsupported_kwargs is not None, + ) + + def decorator(query_compiler_method: Callable[..., Any]) -> Callable[..., Any]: + @functools.wraps(query_compiler_method) + def wrapper(self: "SnowflakeQueryCompiler", *args: Any, **kwargs: Any) -> Any: + # Fast path: if auto-switching is enabled or there are no rules, call directly. + if AutoSwitchBackend.get() or unsupported_kwargs is None: + return query_compiler_method(self, *args, **kwargs) + + # Present kwargs as a read-only mapping to predicates. + arguments = MappingProxyType(kwargs) + + # Check if any condition triggers unsupported behavior + if SnowflakeQueryCompiler._has_unsupported_kwargs( + api_cls_name, method_name, arguments + ): + # Get specific reason and build error + reason = SnowflakeQueryCompiler._get_unsupported_kwargs_reason( + api_cls_name, method_name, arguments + ) + if reason: + ErrorMessage.not_implemented_with_reason(method_name, reason) + else: + # Fallback to generic error + params_str = ( + ", ".join(f"{k}={repr(v)}" for k, v in arguments.items()) + or "provided arguments" + ) + raise NotImplementedError( + f"Snowpark pandas doesn't support '{method_name}' with the parameter combination: {params_str}. " + f"This combination of parameters is not supported in Snowflake and requires pandas execution. " + f"However, auto-switching is disabled, so this operation cannot fall back to pandas. Enable " + f"auto-switching with " + f"'from modin.config import AutoSwitchBackend; AutoSwitchBackend.enable()' " + f"to use pandas for unsupported operations." + ) + + return query_compiler_method(self, *args, **kwargs) + + return wrapper + + return decorator + T = TypeVar("T", bound=Callable[..., Any]) @@ -871,6 +981,102 @@ def _are_dtypes_compatible_with_snowflake(cls, compiler: BaseQueryCompiler) -> b return False return True + @classmethod + def _has_unsupported_kwargs( + cls, + api_cls_name: Optional[str], + operation: Optional[str], + arguments: MappingProxyType[str, Any], + ) -> bool: + """ + Check if method call contains unsupported kwargs that require pandas backend. + + Args: + api_cls_name: Class name (DataFrame, Series, BasePandasDataset, None for top-level functions) + operation: Method name + arguments: Method arguments including kwargs + + Returns: + True if unsupported kwargs detected and auto-switch should occur + """ + if not operation: + return False + + rule = HYBRID_SWITCH_FOR_UNSUPPORTED_PARAMS.get( + MethodKey(api_cls_name, operation) + ) + if rule is None: + return False + + # Check custom conditions + for condition in rule.unsupported_conditions: + try: + if callable(condition[0]): + # tuple[Callable, str]: (condition_function, reason) + condition_func, _ = condition + if condition_func(arguments): + return True + else: + # tuple[str, Any]: (argument_name, unsupported_value) + arg_name, unsupported_value = condition + if arguments.get(arg_name) == unsupported_value: + return True + except Exception: + # Any error - skip this condition + continue + + return False + + @classmethod + def _get_unsupported_kwargs_reason( + cls, + api_cls_name: Optional[str], + operation: Optional[str], + arguments: MappingProxyType[str, Any], + ) -> Optional[str]: + """ + Get the specific reason why kwargs are unsupported. + + Args: + api_cls_name: Class name (DataFrame, Series, BasePandasDataset, None for top-level functions) + operation: Method name + arguments: Method arguments including kwargs + + Returns: + The specific reason string if unsupported kwargs detected, None otherwise + """ + if not operation: + return None + + rule = HYBRID_SWITCH_FOR_UNSUPPORTED_PARAMS.get( + MethodKey(api_cls_name, operation) + ) + if rule is None: + return None + + # Check custom conditions and return the specific reason + for condition in rule.unsupported_conditions: + try: + if len(condition) == 2 and callable(condition[0]): + # tuple[Callable, str]: (condition_function, reason) + condition_func, reason = condition + if condition_func(arguments): + return reason + elif len(condition) == 2: + # tuple[str, Any]: (argument_name, unsupported_value) + arg_name, unsupported_value = condition + if ( + isinstance(arg_name, str) + and arguments.get(arg_name) == unsupported_value + ): + # Auto-generate error message for simple value checks + return f"{arg_name}={repr(unsupported_value)} is not supported" + except Exception: + # Any error - skip this condition + continue + + return None + @classmethod def _transfer_threshold(cls) -> int: return SnowflakePandasTransferThreshold.get() @@ -916,9 +1122,21 @@ def stay_cost( ) -> Optional[int]: if ( self._is_in_memory_init(api_cls_name, operation, arguments) - or (api_cls_name, operation) in HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS + or MethodKey(api_cls_name, operation) + in HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS ): return QCCoercionCost.COST_IMPOSSIBLE + + # POC: Check for unsupported kwargs that require pandas backend + if arguments and SnowflakeQueryCompiler._has_unsupported_kwargs( + api_cls_name, operation, arguments + ): + WarningMessage.single_warning( + f"Method '{operation}' with specified parameters requires pandas backend. " + f"Automatically switching for execution." + ) + + return QCCoercionCost.COST_IMPOSSIBLE # Strongly discourage the use of these methods in snowflake if operation in HYBRID_ALL_EXPENSIVE_METHODS: return QCCoercionCost.COST_HIGH @@ -953,10 +1171,31 @@ def move_to_me_cost( if ( cls._is_in_memory_init(api_cls_name, operation, arguments) or not cls._are_dtypes_compatible_with_snowflake(other_qc) - or (api_cls_name, operation) in HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS + or ( + operation is not None + and MethodKey(api_cls_name, operation) + in HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS + ) ): return QCCoercionCost.COST_IMPOSSIBLE + # POC: Check for unsupported kwargs that require pandas backend + if arguments and cls._has_unsupported_kwargs( + api_cls_name, operation, arguments + ): + from snowflake.snowpark.modin.plugin.utils.warning_message import ( + WarningMessage, + ) + + WarningMessage.single_warning( + f"Method '{operation}' with specified parameters requires pandas backend. " + f"Automatically switching for execution." + ) + return QCCoercionCost.COST_IMPOSSIBLE + + if (api_cls_name, operation) in HYBRID_SWITCH_FOR_UNSUPPORTED_PARAMS: + return QCCoercionCost.COST_ZERO + if arguments is not None and ( ( ( @@ -1949,6 +2188,56 @@ def _to_snowpark_dataframe_from_snowpark_pandas_dataframe( col_mapper=rename_mapper ) + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="to_csv", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: "float_format" in kwargs + and kwargs.get("float_format") + is not TO_CSV_DEFAULTS["float_format"], + "custom float_format is not supported", + ), + ( + lambda kwargs: "mode" in kwargs + and kwargs.get("mode") is not TO_CSV_DEFAULTS["mode"], + "non-default mode is not supported", + ), + ( + lambda kwargs: "encoding" in kwargs + and kwargs.get("encoding") is not TO_CSV_DEFAULTS["encoding"], + "custom encoding is not supported", + ), + ( + lambda kwargs: "quoting" in kwargs + and kwargs.get("quoting") is not TO_CSV_DEFAULTS["quoting"], + "custom quoting is not supported", + ), + ( + lambda kwargs: "quotechar" in kwargs + and kwargs.get("quotechar") is not TO_CSV_DEFAULTS["quotechar"], + "custom quotechar is not supported", + ), + ( + lambda kwargs: "lineterminator" in kwargs + and kwargs.get("lineterminator") + is not TO_CSV_DEFAULTS["lineterminator"], + "custom lineterminator is not supported", + ), + ( + lambda kwargs: "doublequote" in kwargs + and kwargs.get("doublequote") is not TO_CSV_DEFAULTS["doublequote"], + "custom doublequote is not supported", + ), + ( + lambda kwargs: "decimal" in kwargs + and kwargs.get("decimal") is not TO_CSV_DEFAULTS["decimal"], + "custom decimal is not supported", + ), + ] + ), + ) def to_csv_with_snowflake(self, **kwargs: Any) -> None: """ Write data to a csv file in snowflake stage. @@ -2353,6 +2642,23 @@ def _shift_index(self, periods: int, freq: Any) -> "SnowflakeQueryCompiler": # # TODO: SNOW-1023324, implement shifting index only. ErrorMessage.not_implemented("shifting index values not yet supported.") + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="shift", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("suffix") is not None, + "suffix parameter is not supported", + ), + ( + lambda kwargs: "periods" in kwargs + and not isinstance(kwargs.get("periods"), int), + "non-integer periods are not supported", + ), + ] + ), + ) def shift( self, periods: Union[int, Sequence[int]] = 1, @@ -2755,6 +3061,32 @@ def binary_op( ) return self._maybe_set_relaxed_qc(qc, relaxed_query_compiler) + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="binary_op", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("level") is not None, + "level parameter is not supported", + ), + ( + lambda kwargs: kwargs.get("fill_value") is not None + and not is_scalar(kwargs.get("fill_value")), + "non-scalar fill_value is not supported", + ), + ( + lambda kwargs: kwargs.get("squeeze_self") is not False, + "squeeze_self=True is not supported", + ), + ( + lambda kwargs: kwargs.get("op") is not None + and not BinaryOp.is_binary_op_supported(kwargs.get("op")), + "unsupported binary operation", + ), + ] + ), + ) def _binary_op_internal( self, op: str, @@ -3013,6 +3345,18 @@ def any( False, "any", axis=axis, _bool_only=bool_only, skipna=skipna ) + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="reindex", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: "axis" in kwargs and kwargs.get("axis") != 0, + "axis=1 (column reindexing) is not supported", + ), + ] + ), + ) def reindex( self, axis: int, @@ -3940,6 +4284,30 @@ def first_last_valid_index( ) return None + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="sort_index", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: "axis" in kwargs and kwargs.get("axis") != 0, + "axis=1 (column sorting) is not supported", + ), + ( + lambda kwargs: kwargs.get("key") is not None, + "key parameter for custom sorting is not supported", + ), + ( + lambda kwargs: kwargs.get("level") is not None, + "level parameter for multiindex sorting is not supported", + ), + ( + lambda kwargs, self: self._modin_frame.is_multiindex(), + "multiindex operations are not supported", + ), + ] + ), + ) def sort_index( self, *, @@ -4030,23 +4398,35 @@ def sort_columns_by_row_values( self, rows: IndexLabel, ascending: bool = True, **kwargs: Any ) -> None: """ - Reorder the columns based on the lexicographic order of the given rows. - - Args: - rows : label or list of labels - The row or rows to sort by. - ascending : bool, default: True - Sort in ascending order (True) or descending order (False). - **kwargs : dict - Serves the compatibility purpose. Does not affect the result. + Reorder the columns based on the lexicographic order of the given rows. - Returns: - New QueryCompiler that contains result of the sort. + Args: + rows : label or list of labels + The row or rows to sort by. + ascending : bool, default: True + Sort in ascending order (True) or descending order (False). + **kwargs : dict + Serves the compatibility purpose. Does not affect the result. + o + Returns: + New QueryCompiler that contains result of the sort. """ ErrorMessage.not_implemented( "Snowpark pandas sort_values API doesn't yet support axis == 1" ) + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="sort_rows_by_column_values", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("key") is not None, + "key parameter is not supported", + ), + ] + ), + ) def sort_rows_by_column_values( self, columns: list[Hashable], @@ -7760,6 +8140,18 @@ def dataframe_to_datetime( ) ) + @register_query_compiler_method_not_implemented( + api_cls_name="Series", + method_name="to_timedelta", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("errors") != "raise", + "only errors='raise' is supported, not 'ignore' or 'coerce'", + ), + ] + ), + ) def to_timedelta( self, unit: str = "ns", @@ -7956,10 +8348,6 @@ def concat( NOTE: Original column level names are lost and result column index has only one level. """ - if levels is not None: - raise NotImplementedError( - "Snowpark pandas doesn't support 'levels' argument in concat API" - ) frames = [self._modin_frame] + [o._modin_frame for o in other] for frame in frames: self._raise_not_implemented_error_for_timedelta(frame=frame) @@ -8138,6 +8526,18 @@ def concat( qc._attrs = copy.deepcopy(self._attrs) return qc + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="cumsum", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + "axis", + 1, + ), # Simple value check - auto-generates "axis=1 is not supported" + ] + ), + ) def cumsum( self, axis: int = 0, skipna: bool = True, *args: Any, **kwargs: Any ) -> "SnowflakeQueryCompiler": @@ -8156,9 +8556,7 @@ def cumsum( SnowflakeQueryCompiler instance with cumulative sum of Series or DataFrame. """ self._raise_not_implemented_error_for_timedelta() - - if axis == 1: - ErrorMessage.not_implemented("cumsum with axis=1 is not supported yet") + # POC: axis=1 is now handled by the decorator's kwargs-based switching system cumagg_col_to_expr_map = get_cumagg_col_to_expr_map_axis0(self, sum_, skipna) return SnowflakeQueryCompiler( @@ -8254,10 +8652,7 @@ def melt( Notes: melt does not yet handle multiindex or ignore index """ - if col_level is not None: - raise NotImplementedError( - "Snowpark Pandas doesn't support 'col_level' argument in melt API" - ) + # Note: col_level parameter handling is managed by the decorator's unsupported_if condition if self._modin_frame.is_multiindex(axis=1): raise NotImplementedError( "Snowpark Pandas doesn't support multiindex columns in melt API" @@ -9944,6 +10339,67 @@ def pivot( sort=True, ) + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="pivot_table", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ("observed", True), # Simple value check + ( + lambda kwargs: kwargs.get("sort", True) is not True, + "sort=False is not supported", + ), + ( + lambda kwargs: bool( + kwargs.get("index") + and ( + not isinstance(kwargs.get("index"), str) + and not all( + [isinstance(v, str) for v in kwargs.get("index")] + ) + and None not in kwargs.get("index") + ) + ), + "non-string index values are not supported", + ), + ( + lambda kwargs: bool( + kwargs.get("columns") + and ( + not isinstance(kwargs.get("columns"), str) + and not all( + [isinstance(v, str) for v in kwargs.get("columns")] + ) + and None not in kwargs.get("columns") + ) + ), + "non-string column values are not supported", + ), + ( + lambda kwargs: bool( + kwargs.get("values") + and ( + not isinstance(kwargs.get("values"), str) + and not all( + [isinstance(v, str) for v in kwargs.get("values")] + ) + and None not in kwargs.get("values") + ) + ), + "non-string values parameter is not supported", + ), + ( + lambda kwargs: isinstance(kwargs.get("aggfunc"), dict) + and any( + not isinstance(af, str) + for af in kwargs.get("aggfunc", {}).values() + ) + and kwargs.get("index") is None, + "complex aggfunc dictionary with no index is not supported", + ), + ] + ), + ) def pivot_table( self, index: Any, @@ -10065,10 +10521,7 @@ def pivot_table( raise ValueError( "Margins not supported if list of aggregation functions" ) - elif index is None: - raise NotImplementedError( - "Not implemented index is None and list of aggregation functions." - ) + # The case where index is None is now handled by the decorator's unsupported_if_kwargs # Duplicate pivot column and index are not allowed, but duplicate aggregation values are supported. index_and_data_column_pandas_labels = ( @@ -10478,6 +10931,18 @@ def nunique_index(self, dropna: bool) -> int: .iloc[0, 0] ) + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="nunique", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("axis") == 1, + "axis=1 (column-wise unique count) is not supported", + ), + ] + ), + ) def nunique( self, axis: Axis, dropna: bool, **kwargs: Any ) -> "SnowflakeQueryCompiler": @@ -15214,6 +15679,22 @@ def rolling_max( agg_kwargs=dict(numeric_only=numeric_only), ) + @register_query_compiler_method_not_implemented( + api_cls_name="BasePandasDataset", + method_name="rolling_corr", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("other") is None, + "other parameter is required", + ), + ( + lambda kwargs: kwargs.get("pairwise") is True, + "pairwise=True is not supported", + ), + ] + ), + ) def rolling_corr( self, fold_axis: Union[int, str], @@ -18981,6 +19462,22 @@ def dt_to_timestamp(self) -> None: "Snowpark pandas doesn't yet support the method 'Series.dt.to_timestamp'" ) + @register_query_compiler_method_not_implemented( + api_cls_name="Series", + method_name="dt_tz_localize", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("ambiguous") != "raise", + "only ambiguous='raise' is supported", + ), + ( + lambda kwargs: kwargs.get("nonexistent") != "raise", + "only nonexistent='raise' is supported", + ), + ] + ), + ) def dt_tz_localize( self, tz: Union[str, tzinfo], @@ -19054,6 +19551,22 @@ def dt_tz_convert( ) ) + @register_query_compiler_method_not_implemented( + api_cls_name="Series", + method_name="dt_ceil", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("ambiguous") != "raise", + "only ambiguous='raise' is supported", + ), + ( + lambda kwargs: kwargs.get("nonexistent") != "raise", + "only nonexistent='raise' is supported", + ), + ] + ), + ) def dt_ceil( self, freq: Frequency, @@ -19137,6 +19650,22 @@ def ceil_func(column: SnowparkColumn) -> SnowparkColumn: ) ) + @register_query_compiler_method_not_implemented( + api_cls_name="Series", + method_name="dt_round", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("ambiguous") != "raise", + "only ambiguous='raise' is supported", + ), + ( + lambda kwargs: kwargs.get("nonexistent") != "raise", + "only nonexistent='raise' is supported", + ), + ] + ), + ) def dt_round( self, freq: Frequency, @@ -19298,6 +19827,22 @@ def round_func(column: SnowparkColumn) -> SnowparkColumn: ) ) + @register_query_compiler_method_not_implemented( + api_cls_name="Series", + method_name="dt_floor", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("ambiguous") != "raise", + "only ambiguous='raise' is supported", + ), + ( + lambda kwargs: kwargs.get("nonexistent") != "raise", + "only nonexistent='raise' is supported", + ), + ] + ), + ) def dt_floor( self, freq: Frequency, @@ -19396,6 +19941,18 @@ def normalize_column(column: SnowparkColumn) -> SnowparkColumn: ) ) + @register_query_compiler_method_not_implemented( + api_cls_name="Series", + method_name="dt_month_name", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("locale") is not None, + "locale parameter is not supported", + ), + ] + ), + ) def dt_month_name( self, locale: Optional[str] = None, include_index: bool = False ) -> "SnowflakeQueryCompiler": @@ -19430,6 +19987,18 @@ def month_name_func(column: SnowparkColumn) -> SnowparkColumn: ) ) + @register_query_compiler_method_not_implemented( + api_cls_name="Series", + method_name="dt_day_name", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("locale") is not None, + "locale parameter is not supported", + ), + ] + ), + ) def dt_day_name( self, locale: Optional[str] = None, include_index: bool = False ) -> "SnowflakeQueryCompiler": @@ -21013,6 +21582,46 @@ def groupby_unique( ) ) + @register_query_compiler_method_not_implemented( + api_cls_name="Series", + method_name="hist", + unsupported_kwargs=UnsupportedKwargsRule( + unsupported_conditions=[ + ( + lambda kwargs: kwargs.get("by") is not None, + "by parameter is not supported", + ), + ( + lambda kwargs: kwargs.get("xlabelsize") is not None, + "xlabelsize parameter is not supported", + ), + ( + lambda kwargs: kwargs.get("xrot") is not None, + "xrot parameter is not supported", + ), + ( + lambda kwargs: kwargs.get("ylabelsize") is not None, + "ylabelsize parameter is not supported", + ), + ( + lambda kwargs: kwargs.get("yrot") is not None, + "yrot parameter is not supported", + ), + ( + lambda kwargs: kwargs.get("figsize") is not None, + "figsize parameter is not supported", + ), + ( + lambda kwargs: kwargs.get("backend") is not None, + "backend parameter is not supported", + ), + ( + lambda kwargs: kwargs.get("legend") is True, + "legend=True is not supported", + ), + ] + ), + ) def hist_on_series( self, by: object = None, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 3f65c2913d..27add85022 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -69,6 +69,7 @@ from snowflake.snowpark.modin.plugin._typing import ListLike from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS, + MethodKey, ) from snowflake.snowpark.modin.plugin._internal.utils import new_snow_series from snowflake.snowpark.modin.plugin.extensions.utils import ( @@ -97,9 +98,20 @@ def register_base_not_implemented(): + """ + POC: Enhanced decorator for BasePandasDataset methods with kwargs-based auto-switching. + + Args: + unsupported_kwargs: UnsupportedKwargsRule for kwargs-based auto-switching. + If None, method is completely unimplemented (original behavior). + """ + def decorator(base_method: Any): name = base_method.__name__ - HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(("BasePandasDataset", name)) + + HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add( + MethodKey("BasePandasDataset", name) + ) register_function_for_pre_op_switch( class_name="BasePandasDataset", backend="Snowflake", method=name ) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py index 732a79cfa2..5d5711d921 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py @@ -89,6 +89,7 @@ from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( SnowflakeQueryCompiler, HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS, + MethodKey, ) from snowflake.snowpark.modin.plugin.extensions.index import Index from snowflake.snowpark.modin.plugin.extensions.snow_partition_iterator import ( @@ -136,7 +137,7 @@ def register_dataframe_not_implemented(): def decorator(base_method: Any): func = dataframe_not_implemented()(base_method) name = base_method.__name__ - HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(("DataFrame", name)) + HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(MethodKey("DataFrame", name)) register_function_for_pre_op_switch( class_name="DataFrame", backend="Snowflake", method=name ) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index 19ee80a9d2..1d87d6b3d0 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -59,6 +59,7 @@ ) from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS, + MethodKey, ) from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike from snowflake.snowpark.modin.plugin.extensions.snow_partition_iterator import ( @@ -102,7 +103,7 @@ def decorator(base_method: Any): if isinstance(base_method, property) else base_method.__name__ ) - HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(("Series", name)) + HYBRID_SWITCH_FOR_UNIMPLEMENTED_METHODS.add(MethodKey("Series", name)) register_function_for_pre_op_switch( class_name="Series", backend="Snowflake", method=name ) diff --git a/src/snowflake/snowpark/modin/plugin/utils/error_message.py b/src/snowflake/snowpark/modin/plugin/utils/error_message.py index 9d3b71f946..6f0fda1fd6 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/error_message.py +++ b/src/snowflake/snowpark/modin/plugin/utils/error_message.py @@ -162,6 +162,26 @@ def not_implemented(cls, message: str) -> NoReturn: # pragma: no cover logger.debug(f"NotImplementedError: {message}") raise NotImplementedError(message) + @classmethod + def not_implemented_with_reason( + cls, method_name: str, reason: str + ) -> NoReturn: # pragma: no cover + """ + Raise NotImplementedError with specific reason. + + Args: + method_name: Name of the method that's not implemented + reason: Specific reason why the parameters are not supported + """ + message = ( + f"Snowpark pandas {method_name} does not yet support the parameter combination because {reason}. " + f"Enable auto-switching with 'from modin.config import AutoSwitchBackend; AutoSwitchBackend.enable()' " + f"to use pandas for unsupported operations." + ) + + logger.debug(f"NotImplementedError: {message}") + raise NotImplementedError(message) + @classmethod def not_implemented_for_timedelta(cls, method: str) -> NoReturn: ErrorMessage.not_implemented( diff --git a/test_cumsum_POC.py b/test_cumsum_POC.py new file mode 100644 index 0000000000..1ce070f640 --- /dev/null +++ b/test_cumsum_POC.py @@ -0,0 +1,35 @@ +# Setup +from tests.parameters import CONNECTION_PARAMETERS +from snowflake.snowpark import Session +import modin.pandas as pd +from modin.config import context as config_context + +from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( + HYBRID_SWITCH_FOR_UNSUPPORTED_PARAMS, +) + +session = Session.builder.configs(CONNECTION_PARAMETERS).create() +pd.session = session + +# Verify rule registration +cumsum_rule = HYBRID_SWITCH_FOR_UNSUPPORTED_PARAMS.get(("BasePandasDataset", "cumsum")) +# print(cumsum_rule) + +# Test autoswitch = False, should raise NotImplementedError +df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) +snow_df = df.move_to("Snowflake") +with config_context(AutoSwitchBackend=False): + try: + snow_df.cumsum(axis=1) + raise AssertionError("Should have raised NotImplementedError") + except NotImplementedError as e: + # print(e) + assert "requires pandas backend" in str( + e + ), "Should mention 'requires pandas backend'" + +# Test autoswitch = True, should switch to Pandas +with config_context(AutoSwitchBackend=True): + result = snow_df.cumsum(axis=1) + # print(result.get_backend()) + # print(result) diff --git a/test_real_cumsum_POC.py b/test_real_cumsum_POC.py new file mode 100644 index 0000000000..1a3e0144b6 --- /dev/null +++ b/test_real_cumsum_POC.py @@ -0,0 +1,140 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) + + +def test_kwargs_based_auto_switching(): + """ + Test kwargs-based auto-switching with cumsum method frontend override. + + This test validates that the enhanced decorator approach works for methods + with parameter limitations, demonstrating cost-based backend switching. + """ + try: + from tests.parameters import CONNECTION_PARAMETERS + from snowflake.snowpark import Session + import modin.pandas as pd + from modin.config import context as config_context + from types import MappingProxyType + from modin.core.storage_formats.base.query_compiler import QCCoercionCost + + from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( + SnowflakeQueryCompiler, + HYBRID_SWITCH_FOR_UNSUPPORTED_PARAMS, + ) + + session = Session.builder.configs(CONNECTION_PARAMETERS).create() + pd.session = session + + # Verify kwargs rule registration + cumsum_rule = HYBRID_SWITCH_FOR_UNSUPPORTED_PARAMS.get( + ("BasePandasDataset", "cumsum") + ) + assert cumsum_rule is not None, "cumsum rule should be registered" + assert ( + len(cumsum_rule.unsupported_conditions) > 0 + ), "axis condition should be registered" + + # Test kwargs detection logic + args_axis0 = MappingProxyType({"axis": 0, "skipna": True}) + args_axis1 = MappingProxyType({"axis": 1, "skipna": True}) + + should_switch_axis0 = SnowflakeQueryCompiler._has_unsupported_kwargs( + "BasePandasDataset", "cumsum", args_axis0 + ) + should_switch_axis1 = SnowflakeQueryCompiler._has_unsupported_kwargs( + "BasePandasDataset", "cumsum", args_axis1 + ) + + assert should_switch_axis0 is False, "axis=0 should not trigger switching" + assert should_switch_axis1 is True, "axis=1 should trigger switching" + + # Test cost calculation integration + df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + cost_axis0 = SnowflakeQueryCompiler.move_to_me_cost( + df._query_compiler, + api_cls_name="BasePandasDataset", + operation="cumsum", + arguments=args_axis0, + ) + + cost_axis1 = SnowflakeQueryCompiler.move_to_me_cost( + df._query_compiler, + api_cls_name="BasePandasDataset", + operation="cumsum", + arguments=args_axis1, + ) + + assert ( + cost_axis0 != QCCoercionCost.COST_IMPOSSIBLE + ), "axis=0 should allow movement to Snowflake" + + assert ( + cost_axis1 == QCCoercionCost.COST_IMPOSSIBLE + ), "axis=1 should prevent movement to Snowflake" + + # Test method execution + result_axis0 = df.cumsum(axis=0) + expected_axis0 = [[1, 4], [3, 9], [6, 15]] + assert ( + result_axis0.values.tolist() == expected_axis0 + ), "axis=0 result should be correct" + + # Test auto-switching behavior + snow_df = df.move_to("Snowflake") + assert snow_df.get_backend() == "Snowflake", "DataFrame should be on Snowflake" + + stay_cost_axis1 = snow_df._query_compiler.stay_cost( + api_cls_name="BasePandasDataset", operation="cumsum", arguments=args_axis1 + ) + assert ( + stay_cost_axis1 == QCCoercionCost.COST_IMPOSSIBLE + ), "Should prevent staying on Snowflake" + + # Test error message when auto-switching is disabled + with config_context(AutoSwitchBackend=False): + try: + snow_df.cumsum(axis=1) + raise AssertionError("Should have raised NotImplementedError") + except NotImplementedError as e: + error_msg = str(e) + # print(f"Error message when auto-switching disabled: {error_msg}") + + assert ( + "is not supported" in error_msg + ), f"Error should mention 'is not supported': {error_msg}" + assert ( + "axis=1" in error_msg + ), f"Error should mention 'axis=1': {error_msg}" + + # Test auto-switching behavior when enabled + with config_context(AutoSwitchBackend=True): + result_axis1 = snow_df.cumsum(axis=1) + expected_axis1 = [[1, 5], [2, 7], [3, 9]] + + assert ( + result_axis1.get_backend() == "Pandas" + ), "Should auto-switch to Pandas backend" + assert ( + result_axis1.values.tolist() == expected_axis1 + ), "axis=1 result should be correct" + + session.close() + return True + + except Exception: + import traceback + + traceback.print_exc() + return False + + +if __name__ == "__main__": + success = test_kwargs_based_auto_switching() + sys.exit(0 if success else 1)