Skip to content

Commit a5777bd

Browse files
address comments
1 parent f8c3b04 commit a5777bd

3 files changed

Lines changed: 29 additions & 124 deletions

File tree

src/snowflake/snowpark/session.py

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -763,10 +763,7 @@ def __init__(
763763
_SNOWPARK_PANDAS_HYBRID_EXECUTION_ENABLED, AutoSwitchBackend().get()
764764
)
765765
)
766-
if pandas_hybrid_execution_enabled:
767-
AutoSwitchBackend.enable()
768-
else:
769-
AutoSwitchBackend.disable()
766+
AutoSwitchBackend.put(pandas_hybrid_execution_enabled)
770767

771768
self._thread_store = create_thread_local(
772769
self._conn._thread_safe_session_enabled
@@ -1046,8 +1043,9 @@ def pandas_hybrid_execution_enabled(self) -> bool:
10461043
This can significantly improve performance for operations that are more efficient in pandas than in Snowflake.
10471044
"""
10481045
if not importlib.util.find_spec("modin"):
1049-
# If modin is not installed, always return False
1050-
return False
1046+
raise ImportError(
1047+
"The 'modin' package is required to enable this feature. Please install it first."
1048+
)
10511049

10521050
from modin.config import AutoSwitchBackend
10531051

@@ -1232,16 +1230,14 @@ def dummy_row_pos_optimization_enabled(self, value: bool) -> None:
12321230
def pandas_hybrid_execution_enabled(self, value: bool) -> None:
12331231
"""Set the value for pandas_hybrid_execution_enabled"""
12341232
if not importlib.util.find_spec("modin"):
1235-
# If modin is not installed, treat this method as a no-op.
1236-
return
1233+
raise ImportError(
1234+
"The 'modin' package is required to enable this feature. Please install it first."
1235+
)
12371236

12381237
from modin.config import AutoSwitchBackend
12391238

12401239
if value in [True, False]:
1241-
if value:
1242-
AutoSwitchBackend.enable()
1243-
else:
1244-
AutoSwitchBackend.disable()
1240+
AutoSwitchBackend.put(value)
12451241
else:
12461242
raise ValueError(
12471243
"value for pandas_hybrid_execution_enabled must be True or False!"

tests/integ/modin/hybrid/test_data_movement.py

Lines changed: 0 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -45,18 +45,6 @@ def test_unsupported_movement(session, pandas_df):
4545
assert move_from_result is NotImplemented
4646

4747

48-
@sql_count_checker(query_count=0)
49-
def test_unsupported_movement_using_hybrid_session_param(session, pandas_df):
50-
with config_context(Backend="Snowflake"):
51-
session.pandas_hybrid_execution_enabled = False
52-
snow_df = pd.DataFrame(pandas_df)
53-
mock_qc = MockQueryCompiler()
54-
move_to_result = snow_df._query_compiler.move_to("UnsupportedBackend")
55-
move_from_result = SnowflakeQueryCompiler.move_from(mock_qc)
56-
assert move_to_result is NotImplemented
57-
assert move_from_result is NotImplemented
58-
59-
6048
@pytest.mark.skipif(
6149
sys.version_info.minor >= 12,
6250
reason="snowflake-ml-python for efficient movement is not installed above python 3.12",
@@ -73,23 +61,6 @@ def test_move_to_ray(session, pandas_df):
7361
df_equals(result_df, snow_df)
7462

7563

76-
@pytest.mark.skipif(
77-
sys.version_info.minor >= 12,
78-
reason="snowflake-ml-python for efficient movement is not installed above python 3.12",
79-
)
80-
@sql_count_checker(query_count=9)
81-
def test_move_to_ray_using_hybrid_session_param(session, pandas_df):
82-
with config_context(Backend="Snowflake"):
83-
session.pandas_hybrid_execution_enabled = False
84-
snow_df = pd.DataFrame(pandas_df)
85-
assert snow_df.get_backend() == "Snowflake"
86-
result = snow_df._query_compiler.move_to("Ray")
87-
result_df = pd.DataFrame(query_compiler=result)
88-
assert result_df.get_backend() == "Ray"
89-
assert Backend.get() == "Snowflake"
90-
df_equals(result_df, snow_df)
91-
92-
9364
@pytest.mark.skip(reason="SNOW-2276090")
9465
@sql_count_checker(query_count=4)
9566
def test_move_from_ray(session, pandas_df):

tests/integ/modin/hybrid/test_switch_operations.py

Lines changed: 21 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -395,8 +395,9 @@ def test_tqdm_usage_during_snowflake_to_pandas_switch():
395395
("Series", "transform", (lambda x: x * 2,)), # declared in series_overrides
396396
],
397397
)
398+
@pytest.mark.parametrize("use_session_param", [True, False])
398399
@sql_count_checker(query_count=1)
399-
def test_unimplemented_autoswitches(class_name, method_name, f_args):
400+
def test_unimplemented_autoswitches(class_name, method_name, f_args, use_session_param):
400401
# Unimplemented methods declared via register_*_not_implemented should automatically
401402
# default to local pandas execution.
402403
# This test needs to be modified if any of the APIs in question are ever natively implemented
@@ -405,6 +406,12 @@ def test_unimplemented_autoswitches(class_name, method_name, f_args):
405406
method = getattr(getattr(pd, class_name)(data).move_to("Snowflake"), method_name)
406407
# Attempting to call the method without switching should raise.
407408
with config_context(AutoSwitchBackend=False):
409+
if use_session_param:
410+
from modin.config import AutoSwitchBackend
411+
412+
AutoSwitchBackend.enable()
413+
pd.session.pandas_hybrid_execution_enabled = False
414+
assert pd.session.pandas_hybrid_execution_enabled is False
408415
with pytest.raises(
409416
NotImplementedError, match="Snowpark pandas does not yet support the method"
410417
):
@@ -425,49 +432,6 @@ def test_unimplemented_autoswitches(class_name, method_name, f_args):
425432
assert snow_result == pandas_result
426433

427434

428-
@pytest.mark.parametrize(
429-
"class_name, method_name, f_args",
430-
[
431-
("DataFrame", "to_json", ()), # declared in base_overrides
432-
("Series", "to_json", ()), # declared in base_overrides
433-
("DataFrame", "dot", ([6],)), # declared in dataframe_overrides
434-
("Series", "transform", (lambda x: x * 2,)), # declared in series_overrides
435-
],
436-
)
437-
@sql_count_checker(query_count=1)
438-
def test_unimplemented_autoswitches_using_hybrid_session_param(
439-
class_name, method_name, f_args
440-
):
441-
# Unimplemented methods declared via register_*_not_implemented should automatically
442-
# default to local pandas execution.
443-
# This test needs to be modified if any of the APIs in question are ever natively implemented
444-
# for Snowpark pandas.
445-
data = [1, 2, 3]
446-
method = getattr(getattr(pd, class_name)(data).move_to("Snowflake"), method_name)
447-
# Attempting to call the method without switching should raise.
448-
pd.session.pandas_hybrid_execution_enabled = False
449-
with pytest.raises(
450-
NotImplementedError, match="Snowpark pandas does not yet support the method"
451-
):
452-
method(*f_args)
453-
454-
# Attempting to call the method while switching is enabled should work fine.
455-
pd.session.pandas_hybrid_execution_enabled = True
456-
snow_result = method(*f_args)
457-
pandas_result = getattr(getattr(native_pd, class_name)(data), method_name)(*f_args)
458-
if isinstance(snow_result, (pd.DataFrame, pd.Series)):
459-
assert snow_result.get_backend() == "Pandas"
460-
assert_array_equal(snow_result.to_numpy(), pandas_result.to_numpy())
461-
else:
462-
# Series.to_json will output an extraneous level for the __reduced__ column, but that's OK
463-
# since we don't officially support the method.
464-
# See modin bug: https://github.com/modin-project/modin/issues/7624
465-
if class_name == "Series" and method_name == "to_json":
466-
assert snow_result == '{"__reduced__":{"0":1,"1":2,"2":3}}'
467-
else:
468-
assert snow_result == pandas_result
469-
470-
471435
@sql_count_checker(query_count=0)
472436
def test_to_datetime():
473437
assert Backend.get() == "Snowflake"
@@ -476,14 +440,15 @@ def test_to_datetime():
476440
assert isinstance(result, DatetimeIndex)
477441

478442

443+
@pytest.mark.parametrize("use_session_param", [True, False])
479444
@sql_count_checker(
480445
query_count=11,
481446
join_count=6,
482447
udtf_count=2,
483448
high_count_expected=True,
484449
high_count_reason="tests queries across different execution modes",
485450
)
486-
def test_query_count_no_switch(init_transaction_tables):
451+
def test_query_count_no_switch(init_transaction_tables, use_session_param):
487452
"""
488453
Tests that when there is no switching behavior the query count is the
489454
same under hybrid mode and non-hybrid mode.
@@ -501,50 +466,23 @@ def inner_test(df_in):
501466
hybrid_len = None
502467
with pd.session.query_history() as query_history_orig:
503468
with config_context(AutoSwitchBackend=False, NativePandasMaxRows=10):
469+
if use_session_param:
470+
from modin.config import AutoSwitchBackend
471+
472+
AutoSwitchBackend.enable()
473+
pd.session.pandas_hybrid_execution_enabled = False
474+
assert pd.session.pandas_hybrid_execution_enabled is False
504475
df_result = inner_test(df_transactions)
505476
orig_len = len(df_result)
506477

507478
with pd.session.query_history() as query_history_hybrid:
508479
with config_context(AutoSwitchBackend=True, NativePandasMaxRows=10):
509-
df_result = inner_test(df_transactions)
510-
hybrid_len = len(df_result)
511-
512-
assert orig_len == hybrid_len
513-
assert len(query_history_orig.queries) == len(query_history_hybrid.queries)
514-
515-
516-
@sql_count_checker(
517-
query_count=11,
518-
join_count=6,
519-
udtf_count=2,
520-
high_count_expected=True,
521-
high_count_reason="tests queries across different execution modes",
522-
)
523-
def test_query_count_no_switch_using_hybrid_session_param(init_transaction_tables):
524-
"""
525-
Tests that when there is no switching behavior the query count is the
526-
same under hybrid mode and non-hybrid mode.
527-
"""
528-
529-
def inner_test(df_in):
530-
df_result = df_in[(df_in["REVENUE"] > 123) & (df_in["REVENUE"] < 200)]
531-
df_result["REVENUE_DUPE"] = df_result["REVENUE"]
532-
df_result["COUNT"] = df_result.groupby("DATE")["REVENUE"].transform("count")
533-
return df_result
534-
535-
df_transactions = pd.read_snowflake("REVENUE_TRANSACTIONS")
536-
inner_test(df_transactions)
537-
orig_len = None
538-
hybrid_len = None
539-
with pd.session.query_history() as query_history_orig:
540-
with config_context(NativePandasMaxRows=10):
541-
pd.session.pandas_hybrid_execution_enabled = False
542-
df_result = inner_test(df_transactions)
543-
orig_len = len(df_result)
480+
if use_session_param:
481+
from modin.config import AutoSwitchBackend
544482

545-
with pd.session.query_history() as query_history_hybrid:
546-
with config_context(NativePandasMaxRows=10):
547-
pd.session.pandas_hybrid_execution_enabled = True
483+
AutoSwitchBackend.disable()
484+
pd.session.pandas_hybrid_execution_enabled = True
485+
assert pd.session.pandas_hybrid_execution_enabled is True
548486
df_result = inner_test(df_transactions)
549487
hybrid_len = len(df_result)
550488

0 commit comments

Comments
 (0)