From 92b4d7087a6ee046538248676be82720245dad78 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 7 May 2026 11:35:14 -0700 Subject: [PATCH 01/14] use local list instead of fetch from snowflake --- src/snowflake/snowpark/context.py | 42 +++++++++++++++++-------------- src/snowflake/snowpark/session.py | 17 +------------ 2 files changed, 24 insertions(+), 35 deletions(-) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index a111839050..e2f9a465bc 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -57,9 +57,6 @@ "ai_agg", "ai_summarize_agg", "any_value", - "approximate_count_distinct", - "approximate_jaccard_index", - "approximate_similarity", "approx_count_distinct", "approx_percentile", "approx_percentile_accumulate", @@ -67,32 +64,38 @@ "approx_top_k", "approx_top_k_accumulate", "approx_top_k_combine", - "arrayagg", + "approximate_count_distinct", + "approximate_jaccard_index", + "approximate_similarity", "array_agg", "array_union_agg", "array_unique_agg", + "arrayagg", "avg", - "bitandagg", + "bit_and_agg", + "bit_andagg", + "bit_or_agg", + "bit_oragg", + "bit_xor_agg", + "bit_xoragg", "bitand_agg", + "bitandagg", + "bitmap_and_agg", "bitmap_construct_agg", "bitmap_or_agg", - "bitoragg", "bitor_agg", - "bitxoragg", + "bitoragg", "bitxor_agg", - "bit_andagg", - "bit_and_agg", - "bit_oragg", - "bit_or_agg", - "bit_xoragg", - "bit_xor_agg", + "bitxoragg", "booland_agg", "boolor_agg", "boolxor_agg", "corr", "count", + "count(*)", "count_if", "count_internal", + "count_internal(*)", "covar_pop", "covar_samp", "datasketches_hll", @@ -110,12 +113,12 @@ "max_by", "median", "min", + "min_by", "minhash", "minhash_combine", - "min_by", "mode", - "objectagg", "object_agg", + "objectagg", "percentile_cont", "percentile_disc", "regr_avgx", @@ -128,20 +131,21 @@ "regr_sxy", "regr_syy", "skew", + "st_intersection_agg_geography_internal", + "st_union_agg_geography_internal", "stddev", "stddev_pop", 
"stddev_samp", - "st_intersection_agg_geography_internal", - "st_union_agg_geography_internal", "sum", "sum_internal", "sum_internal_real", "sum_real", + "summarize_agg", + "var_pop", + "var_samp", "variance", "variance_pop", "variance_samp", - "var_pop", - "var_samp", "vector_avg", "vector_max", "vector_min", diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 457f28f95b..c6da487c7d 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5073,22 +5073,7 @@ def _retrieve_aggregation_function_list(self) -> None: ) # System built-in aggregation functions - try: - retrieved_set.update( - { - r[0].lower() - for r in self.sql( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect() - } - ) - except Exception as e: - _logger.debug( - "Unable to get system aggregation functions, " - "falling back to hardcoded list: %s", - e, - ) - retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) with context._aggregation_function_set_lock: context._aggregation_function_set.update(retrieved_set) From 937650ac96d3668d008cd53fe8ffa17190551a3f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 11 May 2026 16:14:36 -0700 Subject: [PATCH 02/14] test --- src/snowflake/snowpark/context.py | 1 + src/snowflake/snowpark/session.py | 115 +++++++++++++++++++++++++----- 2 files changed, 100 insertions(+), 16 deletions(-) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index e2f9a465bc..240672a571 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -153,6 +153,7 @@ ] ) + _cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session. # Following are internal-only global flags, used to enable development features. 
diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index c6da487c7d..a2b4cec3dd 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -856,8 +856,11 @@ def __init__( self._dataframe_profiler = DataframeProfiler(session=self) self._catalog = None self._client_telemetry = EventTableTelemetry(session=self) + self._system_agg_function_prefetch_job: Optional[AsyncJob] = None + self._user_agg_function_prefetch_job: Optional[AsyncJob] = None self._ast_batch = AstBatch(self) + self._start_async_aggregation_prefetch() _logger.info("Snowpark Session information: %s", self._session_info) @@ -5056,28 +5059,108 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set = set() - # User-defined aggregation functions - try: - retrieved_set.update( - { - r[0].lower() - for r in self.sql( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect() - } - ) - except Exception as e: - _logger.debug( - "Unable to get user-defined aggregation functions: %s", - e, - ) + # User-defined aggregation functions. + # If init has already issued the async query, wait and use it. + # Otherwise, execute synchronously now for select-statement correctness. 
+ if self._user_agg_function_prefetch_job is not None: + try: + retrieved_set.update( + { + r[0].lower() + for r in self._user_agg_function_prefetch_job.result() + } + ) + except Exception as e: + _logger.debug( + "Unable to use async user-defined aggregation function prefetch: %s", + e, + ) + finally: + self._user_agg_function_prefetch_job = None + else: + try: + retrieved_set.update( + { + r[0].lower() + for r in self.sql( + """select function_name from information_schema.functions where is_aggregate = 'YES'""" + ).collect() + } + ) + except Exception as e: + _logger.debug( + "Unable to get user-defined aggregation functions: %s", + e, + ) - # System built-in aggregation functions + # System aggregation functions from metadata query. + if self._system_agg_function_prefetch_job is not None: + try: + retrieved_set.update( + { + r[0].lower() + for r in self._system_agg_function_prefetch_job.result() + } + ) + except Exception as e: + _logger.debug( + "Unable to use async system aggregation function prefetch: %s", + e, + ) + finally: + self._system_agg_function_prefetch_job = None + else: + try: + retrieved_set.update( + { + r[0].lower() + for r in self.sql( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" + ).collect() + } + ) + except Exception as e: + _logger.debug( + "Unable to get system aggregation functions: %s", + e, + ) + + # Keep hardcoded fallback behavior. 
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) with context._aggregation_function_set_lock: context._aggregation_function_set.update(retrieved_set) + def _start_async_aggregation_prefetch(self) -> None: + """Issue async prefetch query for aggregation metadata once.""" + if not ( + context._is_snowpark_connect_compatible_mode + and context._snowpark_connect_flatten_select_after_sort + ): + return + + try: + self._user_agg_function_prefetch_job = self.sql( + """select function_name from information_schema.functions where is_aggregate = 'YES'""" + ).collect_nowait() + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async user-defined aggregation metadata prefetch: %s", + e, + ) + self._user_agg_function_prefetch_job = None + + try: + self._system_agg_function_prefetch_job = self.sql( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" + ).collect_nowait() + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async system aggregation metadata prefetch: %s", + e, + ) + self._system_agg_function_prefetch_job = None + def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ Returns a DataFrame representing the results of a directory table query on the specified stage. 
From 1650ef6415afd78bbc8df075f12285c63c641c1b Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 12 May 2026 10:08:37 -0700 Subject: [PATCH 03/14] async update --- src/snowflake/snowpark/session.py | 63 +++++++++++++++++++------------ 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a2b4cec3dd..55c656efcb 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -860,7 +860,7 @@ def __init__( self._user_agg_function_prefetch_job: Optional[AsyncJob] = None self._ast_batch = AstBatch(self) - self._start_async_aggregation_prefetch() + self._start_async_aggregation_prefetch_if_needed() _logger.info("Snowpark Session information: %s", self._session_info) @@ -5058,6 +5058,7 @@ def _retrieve_aggregation_function_list(self) -> None: return retrieved_set = set() + system_fetch_succeeded = False # User-defined aggregation functions. # If init has already issued the async query, wait and use it. @@ -5102,6 +5103,7 @@ def _retrieve_aggregation_function_list(self) -> None: for r in self._system_agg_function_prefetch_job.result() } ) + system_fetch_succeeded = True except Exception as e: _logger.debug( "Unable to use async system aggregation function prefetch: %s", @@ -5119,47 +5121,58 @@ def _retrieve_aggregation_function_list(self) -> None: ).collect() } ) + system_fetch_succeeded = True except Exception as e: _logger.debug( "Unable to get system aggregation functions: %s", e, ) - # Keep hardcoded fallback behavior. - retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + # Fallback to the local hardcoded list only when both metadata fetches fail. 
+ if not system_fetch_succeeded: + retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) with context._aggregation_function_set_lock: context._aggregation_function_set.update(retrieved_set) - def _start_async_aggregation_prefetch(self) -> None: - """Issue async prefetch query for aggregation metadata once.""" + def _start_async_aggregation_prefetch_if_needed(self) -> None: + """Start aggregation metadata prefetch only when not already in progress.""" if not ( context._is_snowpark_connect_compatible_mode and context._snowpark_connect_flatten_select_after_sort ): return + if context._aggregation_function_set: + return + if ( + self._user_agg_function_prefetch_job is not None + and self._system_agg_function_prefetch_job is not None + ): + return - try: - self._user_agg_function_prefetch_job = self.sql( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect_nowait() - except Exception as e: # pragma: no cover - _logger.debug( - "Unable to start async user-defined aggregation metadata prefetch: %s", - e, - ) - self._user_agg_function_prefetch_job = None + if self._user_agg_function_prefetch_job is None: + try: + self._user_agg_function_prefetch_job = self.sql( + """select function_name from information_schema.functions where is_aggregate = 'YES'""" + ).collect_nowait() + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async user-defined aggregation metadata prefetch: %s", + e, + ) + self._user_agg_function_prefetch_job = None - try: - self._system_agg_function_prefetch_job = self.sql( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect_nowait() - except Exception as e: # pragma: no cover - _logger.debug( - "Unable to start async system aggregation metadata prefetch: %s", - e, - ) - self._system_agg_function_prefetch_job = None + if self._system_agg_function_prefetch_job is None: + try: + self._system_agg_function_prefetch_job = self.sql( + """show 
functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" + ).collect_nowait() + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async system aggregation metadata prefetch: %s", + e, + ) + self._system_agg_function_prefetch_job = None def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ From 09b3946c908f27d9429ec71163973c46d96b68d0 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 12 May 2026 11:07:15 -0700 Subject: [PATCH 04/14] add test --- src/snowflake/snowpark/session.py | 34 ++++++++---- tests/integ/test_simplifier_suite.py | 79 ++++++++++++++++++++++++++++ tests/unit/test_session.py | 48 +++++++++++------ 3 files changed, 136 insertions(+), 25 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 55c656efcb..441bbb23d0 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5083,9 +5083,10 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set.update( { r[0].lower() - for r in self.sql( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect() + for r in self._conn.run_query( + """select function_name from information_schema.functions where is_aggregate = 'YES'""", + _is_internal=True, + )["data"] } ) except Exception as e: @@ -5116,9 +5117,10 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set.update( { r[0].lower() - for r in self.sql( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect() + for r in self._conn.run_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + _is_internal=True, + )["data"] } ) system_fetch_succeeded = True @@ -5152,9 +5154,9 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: if self._user_agg_function_prefetch_job is None: try: - self._user_agg_function_prefetch_job = self.sql( + self._user_agg_function_prefetch_job = 
self._submit_internal_async_prefetch_query( """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect_nowait() + ) except Exception as e: # pragma: no cover _logger.debug( "Unable to start async user-defined aggregation metadata prefetch: %s", @@ -5164,9 +5166,9 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: if self._system_agg_function_prefetch_job is None: try: - self._system_agg_function_prefetch_job = self.sql( + self._system_agg_function_prefetch_job = self._submit_internal_async_prefetch_query( """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect_nowait() + ) except Exception as e: # pragma: no cover _logger.debug( "Unable to start async system aggregation metadata prefetch: %s", @@ -5174,6 +5176,18 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: ) self._system_agg_function_prefetch_job = None + def _submit_internal_async_prefetch_query(self, query: str) -> Optional[AsyncJob]: + """Submit a prefetch query as internal async and return an AsyncJob handle.""" + try: + result = self._conn.execute_async_and_notify_query_listener( + query, + _is_internal=True, + ) + return self.create_async_job(result["queryId"]) + except Exception as e: # pragma: no cover + _logger.debug("Unable to submit internal async prefetch query: %s", e) + return None + def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ Returns a DataFrame representing the results of a directory table query on the specified stage. 
diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index b446347d51..76ed0502c4 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2519,3 +2519,82 @@ def test_retrieving_aggregation_funcs(session, monkeypatch): assert not context._aggregation_function_set session._retrieve_aggregation_function_list() assert not context._aggregation_function_set + + +def test_internal_async_aggregation_prefetch_submission(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + session._user_agg_function_prefetch_job = None + session._system_agg_function_prefetch_job = None + + call_kwargs = [] + + def _fake_execute_async(query, **kwargs): + call_kwargs.append(kwargs) + return {"queryId": f"qid_{len(call_kwargs)}"} + + monkeypatch.setattr( + session._conn, "execute_async_and_notify_query_listener", _fake_execute_async + ) + session._start_async_aggregation_prefetch_if_needed() + + assert len(call_kwargs) == 2 + assert all(kwargs.get("_is_internal") is True for kwargs in call_kwargs) + assert session._user_agg_function_prefetch_job.query_id == "qid_1" + assert session._system_agg_function_prefetch_job.query_id == "qid_2" + + +def test_aggregation_fallback_used_when_system_source_fails(session, monkeypatch): + import snowflake.snowpark.context as context + + class _FakeAsyncJob: + def __init__(self, rows=None, error=None) -> None: + self._rows = rows + self._error = error + + def result(self): + if self._error is not None: + raise self._error + return self._rows + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, 
"_aggregation_function_set", set()) + session._user_agg_function_prefetch_job = _FakeAsyncJob(rows=[("SUM",)]) + session._system_agg_function_prefetch_job = _FakeAsyncJob( + error=RuntimeError("system fetch failed") + ) + + session._retrieve_aggregation_function_list() + + assert "sum" in context._aggregation_function_set + assert "sum_internal" in context._aggregation_function_set + + +def test_internal_sync_aggregation_fallback_submission(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + session._user_agg_function_prefetch_job = None + session._system_agg_function_prefetch_job = None + + call_kwargs = [] + + def _fake_run_query(query, **kwargs): + call_kwargs.append(kwargs) + if "information_schema.functions" in query: + return {"data": [("SUM",)]} + return {"data": [("AVG",)]} + + monkeypatch.setattr(session._conn, "run_query", _fake_run_query) + session._retrieve_aggregation_function_list() + + assert len(call_kwargs) == 2 + assert all(kwargs.get("_is_internal") is True for kwargs in call_kwargs) + assert "sum" in context._aggregation_function_set + assert "avg" in context._aggregation_function_set diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0349618659..6b44195b0e 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -818,34 +818,34 @@ def test_retrieve_aggregation_function_list_handles_user_defined_error(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = 
set() - mock_df = MagicMock() - call_count = [0] - - def sql_side_effect(query, **kwargs): - call_count[0] += 1 - if call_count[0] == 1: + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + if "information_schema.functions" in query: raise RuntimeError("user-defined query failed") - mock_df.collect.return_value = [["SUM"], ["AVG"]] - return mock_df + return {"data": [["SUM"], ["AVG"]]} - with mock.patch.object(session, "sql", side_effect=sql_side_effect): + with mock.patch.object( + fake_server_connection, "run_query", side_effect=run_query_side_effect + ): session._retrieve_aggregation_function_list() assert "sum" in ctx._aggregation_function_set assert "avg" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set def test_retrieve_aggregation_function_list_handles_system_error(): - """When querying system aggregation functions fails, the method falls back - to the hardcoded _KNOWN_AGGREGATION_FUNCTIONS set.""" + """When system aggregation metadata retrieval fails, hardcoded fallback applies.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -853,20 +853,29 @@ def test_retrieve_aggregation_function_list_handles_system_error(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() - mock_df = MagicMock() - mock_df.collect.side_effect = RuntimeError("system query failed") + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + if "show functions" in query: + raise 
RuntimeError("system query failed") + return {"data": [["SUM"]]} - with mock.patch.object(session, "sql", return_value=mock_df): + with mock.patch.object( + fake_server_connection, "run_query", side_effect=run_query_side_effect + ): session._retrieve_aggregation_function_list() + assert "sum" in ctx._aggregation_function_set assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set @@ -880,17 +889,26 @@ def test_retrieve_aggregation_function_list_handles_both_errors(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + raise RuntimeError("query failed") + with mock.patch.object( - session, "sql", side_effect=RuntimeError("query failed") + fake_server_connection, + "run_query", + side_effect=run_query_side_effect, ): session._retrieve_aggregation_function_list() assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set From c606f3f846b4db56aba2c6edbf4420e68f285b3a Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 12 May 2026 16:17:55 -0700 Subject: [PATCH 05/14] update change --- src/snowflake/snowpark/session.py | 101 +++++++-------------------- tests/integ/test_simplifier_suite.py | 45 ++++++------ tests/unit/test_session.py | 37 +++++----- 3 files 
changed, 69 insertions(+), 114 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 441bbb23d0..f943881e7c 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -856,8 +856,7 @@ def __init__( self._dataframe_profiler = DataframeProfiler(session=self) self._catalog = None self._client_telemetry = EventTableTelemetry(session=self) - self._system_agg_function_prefetch_job: Optional[AsyncJob] = None - self._user_agg_function_prefetch_job: Optional[AsyncJob] = None + self._agg_function_prefetch_job: Optional[AsyncJob] = None self._ast_batch = AstBatch(self) self._start_async_aggregation_prefetch_if_needed() @@ -5060,59 +5059,27 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set = set() system_fetch_succeeded = False - # User-defined aggregation functions. - # If init has already issued the async query, wait and use it. - # Otherwise, execute synchronously now for select-statement correctness. - if self._user_agg_function_prefetch_job is not None: + # Try async result first if prefetch was already started. + if self._agg_function_prefetch_job is not None: try: retrieved_set.update( - { - r[0].lower() - for r in self._user_agg_function_prefetch_job.result() - } - ) - except Exception as e: - _logger.debug( - "Unable to use async user-defined aggregation function prefetch: %s", - e, - ) - finally: - self._user_agg_function_prefetch_job = None - else: - try: - retrieved_set.update( - { - r[0].lower() - for r in self._conn.run_query( - """select function_name from information_schema.functions where is_aggregate = 'YES'""", - _is_internal=True, - )["data"] - } - ) - except Exception as e: - _logger.debug( - "Unable to get user-defined aggregation functions: %s", - e, - ) - - # System aggregation functions from metadata query. 
- if self._system_agg_function_prefetch_job is not None: - try: - retrieved_set.update( - { - r[0].lower() - for r in self._system_agg_function_prefetch_job.result() - } + {r[0].lower() for r in self._agg_function_prefetch_job.result()} ) system_fetch_succeeded = True except Exception as e: _logger.debug( - "Unable to use async system aggregation function prefetch: %s", + "Unable to use async aggregation function prefetch: %s", e, ) finally: - self._system_agg_function_prefetch_job = None + self._agg_function_prefetch_job = None else: + _logger.debug( + "Async aggregation function prefetch job is unavailable; using sync fallback." + ) + + # Sync fallback query. + if not system_fetch_succeeded: try: retrieved_set.update( { @@ -5126,11 +5093,11 @@ def _retrieve_aggregation_function_list(self) -> None: system_fetch_succeeded = True except Exception as e: _logger.debug( - "Unable to get system aggregation functions: %s", + "Unable to get aggregation functions via sync fallback query: %s", e, ) - # Fallback to the local hardcoded list only when both metadata fetches fail. + # Fallback to the local hardcoded list only when metadata retrieval fails. 
if not system_fetch_succeeded: retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) @@ -5146,35 +5113,21 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: return if context._aggregation_function_set: return - if ( - self._user_agg_function_prefetch_job is not None - and self._system_agg_function_prefetch_job is not None - ): + if self._agg_function_prefetch_job is not None: return - if self._user_agg_function_prefetch_job is None: - try: - self._user_agg_function_prefetch_job = self._submit_internal_async_prefetch_query( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ) - except Exception as e: # pragma: no cover - _logger.debug( - "Unable to start async user-defined aggregation metadata prefetch: %s", - e, - ) - self._user_agg_function_prefetch_job = None - - if self._system_agg_function_prefetch_job is None: - try: - self._system_agg_function_prefetch_job = self._submit_internal_async_prefetch_query( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ) - except Exception as e: # pragma: no cover - _logger.debug( - "Unable to start async system aggregation metadata prefetch: %s", - e, - ) - self._system_agg_function_prefetch_job = None + try: + self._agg_function_prefetch_job = self._submit_internal_async_prefetch_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' +union +select function_name from information_schema.functions where is_aggregate = 'YES'""" + ) + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async aggregation metadata prefetch: %s", + e, + ) + self._agg_function_prefetch_job = None def _submit_internal_async_prefetch_query(self, query: str) -> Optional[AsyncJob]: """Submit a prefetch query as internal async and return an AsyncJob handle.""" diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 76ed0502c4..976035b1c7 100644 --- 
a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2527,27 +2527,29 @@ def test_internal_async_aggregation_prefetch_submission(session, monkeypatch): monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._user_agg_function_prefetch_job = None - session._system_agg_function_prefetch_job = None + session._agg_function_prefetch_job = None - call_kwargs = [] + calls = [] def _fake_execute_async(query, **kwargs): - call_kwargs.append(kwargs) - return {"queryId": f"qid_{len(call_kwargs)}"} + calls.append((query, kwargs)) + return {"queryId": "qid_combined"} monkeypatch.setattr( session._conn, "execute_async_and_notify_query_listener", _fake_execute_async ) session._start_async_aggregation_prefetch_if_needed() - assert len(call_kwargs) == 2 - assert all(kwargs.get("_is_internal") is True for kwargs in call_kwargs) - assert session._user_agg_function_prefetch_job.query_id == "qid_1" - assert session._system_agg_function_prefetch_job.query_id == "qid_2" + assert len(calls) == 1 + assert calls[0][1].get("_is_internal") is True + assert "show functions" in calls[0][0] + assert "information_schema.functions" in calls[0][0] + assert session._agg_function_prefetch_job.query_id == "qid_combined" -def test_aggregation_fallback_used_when_system_source_fails(session, monkeypatch): +def test_aggregation_fallback_not_used_when_combined_async_succeeds( + session, monkeypatch +): import snowflake.snowpark.context as context class _FakeAsyncJob: @@ -2563,15 +2565,12 @@ def result(self): monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._user_agg_function_prefetch_job = _FakeAsyncJob(rows=[("SUM",)]) 
- session._system_agg_function_prefetch_job = _FakeAsyncJob( - error=RuntimeError("system fetch failed") - ) + session._agg_function_prefetch_job = _FakeAsyncJob(rows=[("SUM",)]) session._retrieve_aggregation_function_list() assert "sum" in context._aggregation_function_set - assert "sum_internal" in context._aggregation_function_set + assert "sum_internal" not in context._aggregation_function_set def test_internal_sync_aggregation_fallback_submission(session, monkeypatch): @@ -2580,21 +2579,19 @@ def test_internal_sync_aggregation_fallback_submission(session, monkeypatch): monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._user_agg_function_prefetch_job = None - session._system_agg_function_prefetch_job = None + session._agg_function_prefetch_job = None - call_kwargs = [] + calls = [] def _fake_run_query(query, **kwargs): - call_kwargs.append(kwargs) - if "information_schema.functions" in query: - return {"data": [("SUM",)]} + calls.append((query, kwargs)) return {"data": [("AVG",)]} monkeypatch.setattr(session._conn, "run_query", _fake_run_query) session._retrieve_aggregation_function_list() - assert len(call_kwargs) == 2 - assert all(kwargs.get("_is_internal") is True for kwargs in call_kwargs) - assert "sum" in context._aggregation_function_set + assert len(calls) == 1 + assert calls[0][1].get("_is_internal") is True + assert "show functions" in calls[0][0] + assert "information_schema.functions" not in calls[0][0] assert "avg" in context._aggregation_function_set diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6b44195b0e..4b508c6dc9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -808,9 +808,8 @@ def test_infer_is_return_table_uses_internal_describe(): assert mocked_run_query.call_count == 1 -def 
test_retrieve_aggregation_function_list_handles_user_defined_error(): - """When querying user-defined aggregation functions fails, the error is - swallowed and the method continues to query system functions.""" +def test_retrieve_aggregation_function_list_handles_async_error(): + """When async metadata prefetch fails, sync internal fallback is used.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -825,10 +824,13 @@ def test_retrieve_aggregation_function_list_handles_user_defined_error(): ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + fake_async_job = MagicMock() + fake_async_job.result.side_effect = RuntimeError("async query failed") + session._agg_function_prefetch_job = fake_async_job + def run_query_side_effect(query, **kwargs): assert kwargs.get("_is_internal") is True - if "information_schema.functions" in query: - raise RuntimeError("user-defined query failed") + assert "show functions" in query return {"data": [["SUM"], ["AVG"]]} with mock.patch.object( @@ -844,8 +846,8 @@ def run_query_side_effect(query, **kwargs): ctx._aggregation_function_set = original_agg_set -def test_retrieve_aggregation_function_list_handles_system_error(): - """When system aggregation metadata retrieval fails, hardcoded fallback applies.""" +def test_retrieve_aggregation_function_list_handles_sync_error(): + """When sync metadata query fails, hardcoded fallback applies.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -862,16 +864,14 @@ def test_retrieve_aggregation_function_list_handles_system_error(): def run_query_side_effect(query, **kwargs): assert kwargs.get("_is_internal") is True - if "show functions" in query: - raise RuntimeError("system query failed") - return {"data": [["SUM"]]} + assert "show functions" in query + raise RuntimeError("sync query failed") with mock.patch.object( fake_server_connection, 
"run_query", side_effect=run_query_side_effect ): session._retrieve_aggregation_function_list() - assert "sum" in ctx._aggregation_function_set assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) finally: ctx._is_snowpark_connect_compatible_mode = original_compat @@ -879,9 +879,8 @@ def run_query_side_effect(query, **kwargs): ctx._aggregation_function_set = original_agg_set -def test_retrieve_aggregation_function_list_handles_both_errors(): - """When both aggregation function queries fail, the hardcoded fallback - set is still populated.""" +def test_retrieve_aggregation_function_list_uses_single_internal_sync_query(): + """Sync fallback executes exactly one internal metadata query.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -896,9 +895,12 @@ def test_retrieve_aggregation_function_list_handles_both_errors(): ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + called_queries = [] + def run_query_side_effect(query, **kwargs): + called_queries.append(query) assert kwargs.get("_is_internal") is True - raise RuntimeError("query failed") + return {"data": [["SUM"]]} with mock.patch.object( fake_server_connection, @@ -907,7 +909,10 @@ def run_query_side_effect(query, **kwargs): ): session._retrieve_aggregation_function_list() - assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) + assert len(called_queries) == 1 + assert "show functions" in called_queries[0] + assert "information_schema.functions" not in called_queries[0] + assert "sum" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat ctx._snowpark_connect_flatten_select_after_sort = original_flatten From 971ce2ea2dce6ac8ecee94393c5d4ffbbf3b46c7 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 14 May 2026 16:28:00 -0700 Subject: [PATCH 06/14] address comment --- src/snowflake/snowpark/session.py | 18 
++++++++++++++++++ tests/integ/test_simplifier_suite.py | 2 +- tests/unit/test_session.py | 2 +- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index f943881e7c..3e2e538720 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5077,6 +5077,24 @@ def _retrieve_aggregation_function_list(self) -> None: _logger.debug( "Async aggregation function prefetch job is unavailable; using sync fallback." ) + try: + retrieved_set.update( + { + r[0].lower() + for r in self._conn.run_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' +union +select function_name from information_schema.functions where is_aggregate = 'YES'""", + _is_internal=True, + )["data"] + } + ) + system_fetch_succeeded = True + except Exception as e: + _logger.debug( + "Unable to get aggregation functions via sync union query: %s", + e, + ) # Sync fallback query. if not system_fetch_succeeded: diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 976035b1c7..7d91228a56 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2593,5 +2593,5 @@ def _fake_run_query(query, **kwargs): assert len(calls) == 1 assert calls[0][1].get("_is_internal") is True assert "show functions" in calls[0][0] - assert "information_schema.functions" not in calls[0][0] + assert "information_schema.functions" in calls[0][0] assert "avg" in context._aggregation_function_set diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4b508c6dc9..e79071342b 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -911,7 +911,7 @@ def run_query_side_effect(query, **kwargs): assert len(called_queries) == 1 assert "show functions" in called_queries[0] - assert "information_schema.functions" not in called_queries[0] + assert "information_schema.functions" in called_queries[0] 
assert "sum" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat From b49e48f27fee38c9e91c74156adeca062facbe8f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 15 May 2026 16:49:48 +0000 Subject: [PATCH 07/14] add lock for async job object --- src/snowflake/snowpark/session.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 3e2e538720..2690dc8bfe 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -17,7 +17,7 @@ from collections import defaultdict from functools import reduce from logging import getLogger -from threading import RLock +from threading import Lock, RLock from types import ModuleType from typing import ( TYPE_CHECKING, @@ -857,6 +857,8 @@ def __init__( self._catalog = None self._client_telemetry = EventTableTelemetry(session=self) self._agg_function_prefetch_job: Optional[AsyncJob] = None + # Guards the one-time atomic claim of _agg_function_prefetch_job. + self._agg_function_prefetch_lock = Lock() self._ast_batch = AstBatch(self) self._start_async_aggregation_prefetch_if_needed() @@ -5059,11 +5061,19 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set = set() system_fetch_succeeded = False - # Try async result first if prefetch was already started. - if self._agg_function_prefetch_job is not None: + # Try async result first if prefetch was already started. Atomically claim the job so + # that only one thread ever calls job.result(). AsyncJob.result() is not thread-safe — + # the underlying connector cursor mutates shared state (e.g. _result, _rownumber, + # _prefetch_hook) during result fetching, so concurrent calls would cause torn reads. + # The lock is held only for the pointer swap (nanoseconds), not the network call itself. 
+ # In the worst case a second thread finds the job already claimed and falls back to the + # sync query path, matching pre-optimization performance. + with self._agg_function_prefetch_lock: + job, self._agg_function_prefetch_job = self._agg_function_prefetch_job, None + if job is not None: try: retrieved_set.update( - {r[0].lower() for r in self._agg_function_prefetch_job.result()} + {r[0].lower() for r in job.result()} ) system_fetch_succeeded = True except Exception as e: @@ -5071,8 +5081,6 @@ def _retrieve_aggregation_function_list(self) -> None: "Unable to use async aggregation function prefetch: %s", e, ) - finally: - self._agg_function_prefetch_job = None else: _logger.debug( "Async aggregation function prefetch job is unavailable; using sync fallback." From 73ee4e055998ac7ac90cd6c69d8f7e4f75953ebf Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 15 May 2026 18:37:47 +0000 Subject: [PATCH 08/14] make waiting threads reuse async job result instead of issuing redundant sync queries --- src/snowflake/snowpark/session.py | 42 +++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 2690dc8bfe..d92ec3a015 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -17,7 +17,7 @@ from collections import defaultdict from functools import reduce from logging import getLogger -from threading import Lock, RLock +from threading import Event, Lock, RLock from types import ModuleType from typing import ( TYPE_CHECKING, @@ -859,6 +859,9 @@ def __init__( self._agg_function_prefetch_job: Optional[AsyncJob] = None # Guards the one-time atomic claim of _agg_function_prefetch_job. self._agg_function_prefetch_lock = Lock() + # Set by the thread that claimed the async job once it finishes (success or failure), + # so other threads can wait instead of issuing redundant sync queries. 
+ self._agg_function_fetch_event: Optional[Event] = None self._ast_batch = AstBatch(self) self._start_async_aggregation_prefetch_if_needed() @@ -5061,15 +5064,34 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set = set() system_fetch_succeeded = False - # Try async result first if prefetch was already started. Atomically claim the job so - # that only one thread ever calls job.result(). AsyncJob.result() is not thread-safe — - # the underlying connector cursor mutates shared state (e.g. _result, _rownumber, - # _prefetch_hook) during result fetching, so concurrent calls would cause torn reads. - # The lock is held only for the pointer swap (nanoseconds), not the network call itself. - # In the worst case a second thread finds the job already claimed and falls back to the - # sync query path, matching pre-optimization performance. + # Atomically claim the async job. The claiming thread creates an Event so concurrent + # threads can wait on it rather than issuing redundant sync queries. + # AsyncJob.result() is not thread-safe — the underlying connector cursor mutates + # shared state (_result, _rownumber, _prefetch_hook) during result fetching, so only + # one thread may call it. The lock is held only for the pointer swap and event setup + # (nanoseconds), not the network call itself. with self._agg_function_prefetch_lock: job, self._agg_function_prefetch_job = self._agg_function_prefetch_job, None + if job is not None: + fetch_event = Event() + self._agg_function_fetch_event = fetch_event + wait_event = None + elif self._agg_function_fetch_event is not None: + fetch_event = None + wait_event = self._agg_function_fetch_event + else: + fetch_event = None + wait_event = None + + if wait_event is not None: + # The query typically finishes in ~5s; 20s gives ample headroom while + # bounding the hang in the rare case the winner thread dies before its + # finally block runs (e.g. os._exit, interpreter shutdown). 
+ wait_event.wait(timeout=20) + if context._aggregation_function_set: + return + # Winner failed or timed out; fall through to sync query. + if job is not None: try: retrieved_set.update( @@ -5081,6 +5103,10 @@ def _retrieve_aggregation_function_list(self) -> None: "Unable to use async aggregation function prefetch: %s", e, ) + finally: + # Always unblock waiting threads regardless of success, failure, or + # BaseException (e.g. KeyboardInterrupt). + fetch_event.set() else: _logger.debug( "Async aggregation function prefetch job is unavailable; using sync fallback." From a47adc68465bc501c857316214ce34dc8926b04a Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 15 May 2026 19:33:29 +0000 Subject: [PATCH 09/14] fix event-before-publish race and add thread-safety tests --- src/snowflake/snowpark/session.py | 115 ++++++------ tests/integ/test_simplifier_suite.py | 104 +++++++++++ tests/unit/test_session.py | 261 +++++++++++++++++++++++++++ 3 files changed, 424 insertions(+), 56 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index d92ec3a015..4e4487bdad 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5092,69 +5092,72 @@ def _retrieve_aggregation_function_list(self) -> None: return # Winner failed or timed out; fall through to sync query. - if job is not None: - try: - retrieved_set.update( - {r[0].lower() for r in job.result()} - ) - system_fetch_succeeded = True - except Exception as e: + try: + if job is not None: + try: + retrieved_set.update( + {r[0].lower() for r in job.result()} + ) + system_fetch_succeeded = True + except Exception as e: + _logger.debug( + "Unable to use async aggregation function prefetch: %s", + e, + ) + else: _logger.debug( - "Unable to use async aggregation function prefetch: %s", - e, + "Async aggregation function prefetch job is unavailable; using sync fallback." 
) - finally: - # Always unblock waiting threads regardless of success, failure, or - # BaseException (e.g. KeyboardInterrupt). - fetch_event.set() - else: - _logger.debug( - "Async aggregation function prefetch job is unavailable; using sync fallback." - ) - try: - retrieved_set.update( - { - r[0].lower() - for r in self._conn.run_query( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' + try: + retrieved_set.update( + { + r[0].lower() + for r in self._conn.run_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' union select function_name from information_schema.functions where is_aggregate = 'YES'""", - _is_internal=True, - )["data"] - } - ) - system_fetch_succeeded = True - except Exception as e: - _logger.debug( - "Unable to get aggregation functions via sync union query: %s", - e, - ) + _is_internal=True, + )["data"] + } + ) + system_fetch_succeeded = True + except Exception as e: + _logger.debug( + "Unable to get aggregation functions via sync union query: %s", + e, + ) - # Sync fallback query. - if not system_fetch_succeeded: - try: - retrieved_set.update( - { - r[0].lower() - for r in self._conn.run_query( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", - _is_internal=True, - )["data"] - } - ) - system_fetch_succeeded = True - except Exception as e: - _logger.debug( - "Unable to get aggregation functions via sync fallback query: %s", - e, - ) + # Sync fallback query. + if not system_fetch_succeeded: + try: + retrieved_set.update( + { + r[0].lower() + for r in self._conn.run_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + _is_internal=True, + )["data"] + } + ) + system_fetch_succeeded = True + except Exception as e: + _logger.debug( + "Unable to get aggregation functions via sync fallback query: %s", + e, + ) - # Fallback to the local hardcoded list only when metadata retrieval fails. 
- if not system_fetch_succeeded: - retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + # Fallback to the local hardcoded list only when metadata retrieval fails. + if not system_fetch_succeeded: + retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) - with context._aggregation_function_set_lock: - context._aggregation_function_set.update(retrieved_set) + with context._aggregation_function_set_lock: + context._aggregation_function_set.update(retrieved_set) + finally: + # Signal after _aggregation_function_set is published so waiters see + # the populated set immediately upon waking. Also fires on BaseException + # (e.g. KeyboardInterrupt) so waiters are never left blocking until timeout. + if fetch_event is not None: + fetch_event.set() def _start_async_aggregation_prefetch_if_needed(self) -> None: """Start aggregation metadata prefetch only when not already in progress.""" diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 7d91228a56..3a3587365d 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2595,3 +2595,107 @@ def _fake_run_query(query, **kwargs): assert "show functions" in calls[0][0] assert "information_schema.functions" in calls[0][0] assert "avg" in context._aggregation_function_set + + +def test_concurrent_retrieve_agg_waiters_no_sync_query(session, monkeypatch): + """Concurrent calls to _retrieve_aggregation_function_list must result in zero + sync queries from waiters — they reuse the winner's async result via the Event.""" + import threading + import time + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + session._agg_function_fetch_event = None + + job_may_proceed = threading.Event() + waiter_count = [0] + waiter_count_lock 
= threading.Lock() + + class SlowFakeAsyncJob: + def result(self): + job_may_proceed.wait() + return [("SUM",), ("AVG",)] + + session._agg_function_prefetch_job = SlowFakeAsyncJob() + + sync_query_calls = [] + original_run_query = session._conn.run_query + + def counting_run_query(query, **kwargs): + if kwargs.get("_is_internal") and "show functions" in query: + sync_query_calls.append(query) + return original_run_query(query, **kwargs) + + monkeypatch.setattr(session._conn, "run_query", counting_run_query) + + errors = [] + + def run_winner(): + try: + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + def run_waiter(): + try: + with waiter_count_lock: + waiter_count[0] += 1 + if waiter_count[0] == 2: + job_may_proceed.set() + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + winner = threading.Thread(target=run_winner) + waiters = [threading.Thread(target=run_waiter) for _ in range(2)] + + winner.start() + time.sleep(0.05) # give winner time to claim job and set fetch_event + for w in waiters: + w.start() + winner.join(timeout=15) + for w in waiters: + w.join(timeout=15) + + assert not errors + assert "sum" in context._aggregation_function_set + assert "avg" in context._aggregation_function_set + assert len(sync_query_calls) == 0, ( + f"Expected 0 sync queries from waiters, got {len(sync_query_calls)}" + ) + + +def test_concurrent_retrieve_agg_event_set_after_context_published(session, monkeypatch): + """The fetch_event must be set only after _aggregation_function_set is published — + waiters must see a non-empty set the moment they wake up.""" + import snowflake.snowpark.context as context + from threading import Event as _Event + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + 
session._agg_function_fetch_event = None + + class _FakeAsyncJob: + def result(self): + return [("SUM",)] + + session._agg_function_prefetch_job = _FakeAsyncJob() + + snapshot_at_set = [] + original_event_set = _Event.set + + def patched_set(self_event): + snapshot_at_set.append(frozenset(context._aggregation_function_set)) + original_event_set(self_event) + + monkeypatch.setattr(_Event, "set", patched_set) + + session._retrieve_aggregation_function_list() + + assert snapshot_at_set, "fetch_event.set() was never called" + assert snapshot_at_set[0], ( + "fetch_event fired before _aggregation_function_set was published" + ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index e79071342b..470c2a89cc 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -917,3 +917,264 @@ def run_query_side_effect(query, **kwargs): ctx._is_snowpark_connect_compatible_mode = original_compat ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set + + +def _make_agg_session(): + """Create a minimal Session backed by a mock ServerConnection.""" + import snowflake.snowpark.context as ctx + + fake_conn = mock.create_autospec(ServerConnection) + fake_conn._thread_safe_session_enabled = True + session = Session(fake_conn) + session._agg_function_prefetch_job = None + session._agg_function_fetch_event = None + return session, fake_conn, ctx + + +def _ctx_setup(ctx): + ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True + ctx._aggregation_function_set = set() + + +def _ctx_restore(ctx, orig): + ctx._is_snowpark_connect_compatible_mode = orig[0] + ctx._snowpark_connect_flatten_select_after_sort = orig[1] + ctx._aggregation_function_set = orig[2] + + +def test_retrieve_agg_concurrent_waiters_see_result_not_sync_query(): + """Waiting threads read from the populated set after the winner finishes — + they must NOT issue a sync query of their 
own.""" + import threading + + session, fake_conn, ctx = _make_agg_session() + orig = ( + ctx._is_snowpark_connect_compatible_mode, + ctx._snowpark_connect_flatten_select_after_sort, + ctx._aggregation_function_set, + ) + try: + _ctx_setup(ctx) + + # The winner's job blocks until we signal it, giving waiters time to + # enter the lock, grab wait_event, and call wait_event.wait(). + job_may_proceed = threading.Event() + waiters_are_waiting = threading.Event() + waiter_count = [0] + waiter_count_lock = threading.Lock() + + class SlowAsyncJob: + def result(self): + # Signal that the winner is inside result() so the main thread + # can start the waiter threads now. + job_may_proceed.wait() + return [("SUM",), ("AVG",)] + + session._agg_function_prefetch_job = SlowAsyncJob() + + sync_query_calls = [] + + def run_query_side_effect(query, **kwargs): + sync_query_calls.append(query) + return {"data": []} + + fake_conn.run_query.side_effect = run_query_side_effect + + errors = [] + + def run_winner(): + try: + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + def run_waiter(): + try: + with waiter_count_lock: + waiter_count[0] += 1 + if waiter_count[0] == 2: + # Both waiters have registered — let the winner proceed. + job_may_proceed.set() + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + winner = threading.Thread(target=run_winner) + waiters = [threading.Thread(target=run_waiter) for _ in range(2)] + + winner.start() + # Give the winner time to claim the job and create the fetch_event. + import time; time.sleep(0.05) + for w in waiters: + w.start() + # Waiters increment count and release winner once both are registered. 
+ winner.join(timeout=10) + for w in waiters: + w.join(timeout=10) + + assert not errors + assert "sum" in ctx._aggregation_function_set + assert "avg" in ctx._aggregation_function_set + # Waiters must NOT have issued sync queries — they should have returned + # after seeing the populated set. + assert len(sync_query_calls) == 0, ( + f"Expected 0 sync queries, got {len(sync_query_calls)}: {sync_query_calls}" + ) + finally: + _ctx_restore(ctx, orig) + + +def test_retrieve_agg_event_set_after_context_published(): + """fetch_event.set() must be called only after _aggregation_function_set is + populated — the original bug was setting the event before publishing.""" + import threading + + session, fake_conn, ctx = _make_agg_session() + orig = ( + ctx._is_snowpark_connect_compatible_mode, + ctx._snowpark_connect_flatten_select_after_sort, + ctx._aggregation_function_set, + ) + try: + _ctx_setup(ctx) + + publish_order = [] + + original_update = ctx._aggregation_function_set_lock.__class__ # noqa: unused + + class TrackingAsyncJob: + def result(self): + return [("SUM",)] + + session._agg_function_prefetch_job = TrackingAsyncJob() + + # Patch Event.set to record what's in the context set at the moment it fires. + from threading import Event as _Event + + original_set = _Event.set + + def patched_set(self_event): + publish_order.append(("event_set", frozenset(ctx._aggregation_function_set))) + original_set(self_event) + + with mock.patch.object(_Event, "set", patched_set): + session._retrieve_aggregation_function_list() + + assert publish_order, "fetch_event.set() was never called" + # At the moment event fires the set must already contain the result. 
+ _, snapshot = publish_order[0] + assert "sum" in snapshot, ( + f"fetch_event fired before context was populated; snapshot={snapshot}" + ) + finally: + _ctx_restore(ctx, orig) + + +def test_retrieve_agg_waiters_fall_through_on_winner_failure(): + """When the winner's async job fails, waiters fall through to sync query + rather than hanging or returning an empty set.""" + import threading + import time + + session, fake_conn, ctx = _make_agg_session() + orig = ( + ctx._is_snowpark_connect_compatible_mode, + ctx._snowpark_connect_flatten_select_after_sort, + ctx._aggregation_function_set, + ) + try: + _ctx_setup(ctx) + + job_may_proceed = threading.Event() + waiter_registered = threading.Event() + + class FailingAsyncJob: + def result(self): + job_may_proceed.wait() + raise RuntimeError("async job failed") + + session._agg_function_prefetch_job = FailingAsyncJob() + + sync_query_calls = [] + + def run_query_side_effect(query, **kwargs): + sync_query_calls.append(query) + return {"data": [("COUNT",)]} + + fake_conn.run_query.side_effect = run_query_side_effect + + errors = [] + + def run_winner(): + try: + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + def run_waiter(): + try: + waiter_registered.set() + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + winner = threading.Thread(target=run_winner) + waiter = threading.Thread(target=run_waiter) + + winner.start() + time.sleep(0.05) # give winner time to claim the job and set fetch_event + waiter.start() + waiter_registered.wait(timeout=5) + time.sleep(0.05) # give waiter time to reach wait_event.wait() + job_may_proceed.set() # let the winner fail + + winner.join(timeout=10) + waiter.join(timeout=10) + + assert not errors + # Winner failed → waiter fell through to sync query → count in set. 
+ assert "count" in ctx._aggregation_function_set + finally: + _ctx_restore(ctx, orig) + + +def test_retrieve_agg_event_always_set_on_base_exception(): + """fetch_event.set() fires even when a BaseException escapes the async job, + so waiters are never left blocking until timeout.""" + import threading + + session, fake_conn, ctx = _make_agg_session() + orig = ( + ctx._is_snowpark_connect_compatible_mode, + ctx._snowpark_connect_flatten_select_after_sort, + ctx._aggregation_function_set, + ) + try: + _ctx_setup(ctx) + + class KeyboardInterruptJob: + def result(self): + raise KeyboardInterrupt() + + session._agg_function_prefetch_job = KeyboardInterruptJob() + + event_was_set = [] + + from threading import Event as _Event + + original_set = _Event.set + + def patched_set(self_event): + event_was_set.append(True) + original_set(self_event) + + with mock.patch.object(_Event, "set", patched_set): + try: + session._retrieve_aggregation_function_list() + except KeyboardInterrupt: + pass + + assert event_was_set, "fetch_event.set() was not called despite BaseException" + finally: + _ctx_restore(ctx, orig) From d3dc48741eb78a9317736337fd035997267fe384 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 15 May 2026 22:10:01 +0000 Subject: [PATCH 10/14] use RLock, inline prefetch helper, clean up tests --- src/snowflake/snowpark/session.py | 22 +-- tests/unit/test_session.py | 313 ++++++++++-------------------- 2 files changed, 104 insertions(+), 231 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 4e4487bdad..42b9cfa6eb 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -17,7 +17,7 @@ from collections import defaultdict from functools import reduce from logging import getLogger -from threading import Event, Lock, RLock +from threading import Event, RLock from types import ModuleType from typing import ( TYPE_CHECKING, @@ -858,7 +858,7 @@ def __init__( self._client_telemetry = 
EventTableTelemetry(session=self) self._agg_function_prefetch_job: Optional[AsyncJob] = None # Guards the one-time atomic claim of _agg_function_prefetch_job. - self._agg_function_prefetch_lock = Lock() + self._agg_function_prefetch_lock = RLock() # Set by the thread that claimed the async job once it finishes (success or failure), # so other threads can wait instead of issuing redundant sync queries. self._agg_function_fetch_event: Optional[Event] = None @@ -5172,11 +5172,13 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: return try: - self._agg_function_prefetch_job = self._submit_internal_async_prefetch_query( + result = self._conn.execute_async_and_notify_query_listener( """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' union -select function_name from information_schema.functions where is_aggregate = 'YES'""" +select function_name from information_schema.functions where is_aggregate = 'YES'""", + _is_internal=True, ) + self._agg_function_prefetch_job = self.create_async_job(result["queryId"]) except Exception as e: # pragma: no cover _logger.debug( "Unable to start async aggregation metadata prefetch: %s", @@ -5184,18 +5186,6 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: ) self._agg_function_prefetch_job = None - def _submit_internal_async_prefetch_query(self, query: str) -> Optional[AsyncJob]: - """Submit a prefetch query as internal async and return an AsyncJob handle.""" - try: - result = self._conn.execute_async_and_notify_query_listener( - query, - _is_internal=True, - ) - return self.create_async_job(result["queryId"]) - except Exception as e: # pragma: no cover - _logger.debug("Unable to submit internal async prefetch query: %s", e) - return None - def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ Returns a DataFrame representing the results of a directory table query on the specified stage. 
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 470c2a89cc..d547800ac5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -919,262 +919,145 @@ def run_query_side_effect(query, **kwargs): ctx._aggregation_function_set = original_agg_set -def _make_agg_session(): - """Create a minimal Session backed by a mock ServerConnection.""" + +def test_retrieve_agg_event_set_after_context_published(monkeypatch): + """fetch_event.set() must be called only after _aggregation_function_set is + populated — the original bug was setting the event before publishing.""" import snowflake.snowpark.context as ctx + from threading import Event as _Event fake_conn = mock.create_autospec(ServerConnection) fake_conn._thread_safe_session_enabled = True session = Session(fake_conn) - session._agg_function_prefetch_job = None session._agg_function_fetch_event = None - return session, fake_conn, ctx + monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(ctx, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(ctx, "_aggregation_function_set", set()) -def _ctx_setup(ctx): - ctx._is_snowpark_connect_compatible_mode = True - ctx._snowpark_connect_flatten_select_after_sort = True - ctx._aggregation_function_set = set() + class TrackingAsyncJob: + def result(self): + return [("SUM",)] + session._agg_function_prefetch_job = TrackingAsyncJob() -def _ctx_restore(ctx, orig): - ctx._is_snowpark_connect_compatible_mode = orig[0] - ctx._snowpark_connect_flatten_select_after_sort = orig[1] - ctx._aggregation_function_set = orig[2] + publish_order = [] + original_set = _Event.set + def patched_set(self_event): + publish_order.append(("event_set", frozenset(ctx._aggregation_function_set))) + original_set(self_event) -def test_retrieve_agg_concurrent_waiters_see_result_not_sync_query(): - """Waiting threads read from the populated set after the winner finishes — - they must NOT issue a sync query of 
their own.""" - import threading + with mock.patch.object(_Event, "set", patched_set): + session._retrieve_aggregation_function_list() - session, fake_conn, ctx = _make_agg_session() - orig = ( - ctx._is_snowpark_connect_compatible_mode, - ctx._snowpark_connect_flatten_select_after_sort, - ctx._aggregation_function_set, + assert publish_order, "fetch_event.set() was never called" + # At the moment event fires the set must already contain the result. + _, snapshot = publish_order[0] + assert "sum" in snapshot, ( + f"fetch_event fired before context was populated; snapshot={snapshot}" ) - try: - _ctx_setup(ctx) - - # The winner's job blocks until we signal it, giving waiters time to - # enter the lock, grab wait_event, and call wait_event.wait(). - job_may_proceed = threading.Event() - waiters_are_waiting = threading.Event() - waiter_count = [0] - waiter_count_lock = threading.Lock() - - class SlowAsyncJob: - def result(self): - # Signal that the winner is inside result() so the main thread - # can start the waiter threads now. - job_may_proceed.wait() - return [("SUM",), ("AVG",)] - session._agg_function_prefetch_job = SlowAsyncJob() - - sync_query_calls = [] - - def run_query_side_effect(query, **kwargs): - sync_query_calls.append(query) - return {"data": []} - - fake_conn.run_query.side_effect = run_query_side_effect - - errors = [] - - def run_winner(): - try: - session._retrieve_aggregation_function_list() - except Exception as e: - errors.append(e) - - def run_waiter(): - try: - with waiter_count_lock: - waiter_count[0] += 1 - if waiter_count[0] == 2: - # Both waiters have registered — let the winner proceed. - job_may_proceed.set() - session._retrieve_aggregation_function_list() - except Exception as e: - errors.append(e) - - winner = threading.Thread(target=run_winner) - waiters = [threading.Thread(target=run_waiter) for _ in range(2)] - - winner.start() - # Give the winner time to claim the job and create the fetch_event. 
- import time; time.sleep(0.05) - for w in waiters: - w.start() - # Waiters increment count and release winner once both are registered. - winner.join(timeout=10) - for w in waiters: - w.join(timeout=10) - - assert not errors - assert "sum" in ctx._aggregation_function_set - assert "avg" in ctx._aggregation_function_set - # Waiters must NOT have issued sync queries — they should have returned - # after seeing the populated set. - assert len(sync_query_calls) == 0, ( - f"Expected 0 sync queries, got {len(sync_query_calls)}: {sync_query_calls}" - ) - finally: - _ctx_restore(ctx, orig) - -def test_retrieve_agg_event_set_after_context_published(): - """fetch_event.set() must be called only after _aggregation_function_set is - populated — the original bug was setting the event before publishing.""" +def test_retrieve_agg_waiters_fall_through_on_winner_failure(monkeypatch): + """When the winner's async job fails, waiters fall through to sync query + rather than hanging or returning an empty set.""" import threading + import time + import snowflake.snowpark.context as ctx - session, fake_conn, ctx = _make_agg_session() - orig = ( - ctx._is_snowpark_connect_compatible_mode, - ctx._snowpark_connect_flatten_select_after_sort, - ctx._aggregation_function_set, - ) - try: - _ctx_setup(ctx) - - publish_order = [] - - original_update = ctx._aggregation_function_set_lock.__class__ # noqa: unused - - class TrackingAsyncJob: - def result(self): - return [("SUM",)] + fake_conn = mock.create_autospec(ServerConnection) + fake_conn._thread_safe_session_enabled = True + session = Session(fake_conn) + session._agg_function_fetch_event = None - session._agg_function_prefetch_job = TrackingAsyncJob() + monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(ctx, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(ctx, "_aggregation_function_set", set()) - # Patch Event.set to record what's in the context set at the moment it fires. 
- from threading import Event as _Event + job_may_proceed = threading.Event() + waiter_registered = threading.Event() - original_set = _Event.set + class FailingAsyncJob: + def result(self): + job_may_proceed.wait() + raise RuntimeError("async job failed") - def patched_set(self_event): - publish_order.append(("event_set", frozenset(ctx._aggregation_function_set))) - original_set(self_event) + session._agg_function_prefetch_job = FailingAsyncJob() - with mock.patch.object(_Event, "set", patched_set): - session._retrieve_aggregation_function_list() + sync_query_calls = [] - assert publish_order, "fetch_event.set() was never called" - # At the moment event fires the set must already contain the result. - _, snapshot = publish_order[0] - assert "sum" in snapshot, ( - f"fetch_event fired before context was populated; snapshot={snapshot}" - ) - finally: - _ctx_restore(ctx, orig) + def run_query_side_effect(query, **kwargs): + sync_query_calls.append(query) + return {"data": [("COUNT",)]} + fake_conn.run_query.side_effect = run_query_side_effect -def test_retrieve_agg_waiters_fall_through_on_winner_failure(): - """When the winner's async job fails, waiters fall through to sync query - rather than hanging or returning an empty set.""" - import threading - import time + errors = [] - session, fake_conn, ctx = _make_agg_session() - orig = ( - ctx._is_snowpark_connect_compatible_mode, - ctx._snowpark_connect_flatten_select_after_sort, - ctx._aggregation_function_set, - ) - try: - _ctx_setup(ctx) + def run_winner(): + try: + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) - job_may_proceed = threading.Event() - waiter_registered = threading.Event() + def run_waiter(): + try: + waiter_registered.set() + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) - class FailingAsyncJob: - def result(self): - job_may_proceed.wait() - raise RuntimeError("async job failed") + winner = 
threading.Thread(target=run_winner) + waiter = threading.Thread(target=run_waiter) - session._agg_function_prefetch_job = FailingAsyncJob() + winner.start() + time.sleep(0.05) # give winner time to claim the job and set fetch_event + waiter.start() + waiter_registered.wait(timeout=5) + time.sleep(0.05) # give waiter time to reach wait_event.wait() + job_may_proceed.set() # let the winner fail - sync_query_calls = [] + winner.join(timeout=10) + waiter.join(timeout=10) - def run_query_side_effect(query, **kwargs): - sync_query_calls.append(query) - return {"data": [("COUNT",)]} - - fake_conn.run_query.side_effect = run_query_side_effect - - errors = [] - - def run_winner(): - try: - session._retrieve_aggregation_function_list() - except Exception as e: - errors.append(e) - - def run_waiter(): - try: - waiter_registered.set() - session._retrieve_aggregation_function_list() - except Exception as e: - errors.append(e) - - winner = threading.Thread(target=run_winner) - waiter = threading.Thread(target=run_waiter) - - winner.start() - time.sleep(0.05) # give winner time to claim the job and set fetch_event - waiter.start() - waiter_registered.wait(timeout=5) - time.sleep(0.05) # give waiter time to reach wait_event.wait() - job_may_proceed.set() # let the winner fail - - winner.join(timeout=10) - waiter.join(timeout=10) - - assert not errors - # Winner failed → waiter fell through to sync query → count in set. - assert "count" in ctx._aggregation_function_set - finally: - _ctx_restore(ctx, orig) + assert not errors + # Winner failed → waiter fell through to sync query → count in set. 
+ assert "count" in ctx._aggregation_function_set -def test_retrieve_agg_event_always_set_on_base_exception(): +def test_retrieve_agg_event_always_set_on_base_exception(monkeypatch): """fetch_event.set() fires even when a BaseException escapes the async job, so waiters are never left blocking until timeout.""" - import threading - - session, fake_conn, ctx = _make_agg_session() - orig = ( - ctx._is_snowpark_connect_compatible_mode, - ctx._snowpark_connect_flatten_select_after_sort, - ctx._aggregation_function_set, - ) - try: - _ctx_setup(ctx) + import snowflake.snowpark.context as ctx + from threading import Event as _Event - class KeyboardInterruptJob: - def result(self): - raise KeyboardInterrupt() + fake_conn = mock.create_autospec(ServerConnection) + fake_conn._thread_safe_session_enabled = True + session = Session(fake_conn) + session._agg_function_fetch_event = None - session._agg_function_prefetch_job = KeyboardInterruptJob() + monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(ctx, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(ctx, "_aggregation_function_set", set()) - event_was_set = [] + class KeyboardInterruptJob: + def result(self): + raise KeyboardInterrupt() - from threading import Event as _Event + session._agg_function_prefetch_job = KeyboardInterruptJob() - original_set = _Event.set + event_was_set = [] + original_set = _Event.set - def patched_set(self_event): - event_was_set.append(True) - original_set(self_event) + def patched_set(self_event): + event_was_set.append(True) + original_set(self_event) - with mock.patch.object(_Event, "set", patched_set): - try: - session._retrieve_aggregation_function_list() - except KeyboardInterrupt: - pass + with mock.patch.object(_Event, "set", patched_set): + try: + session._retrieve_aggregation_function_list() + except KeyboardInterrupt: + pass - assert event_was_set, "fetch_event.set() was not called despite BaseException" - finally: - 
_ctx_restore(ctx, orig) + assert event_was_set, "fetch_event.set() was not called despite BaseException" From abaae6327c45ee39057b4c929192e8580b71c5a5 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 18 May 2026 15:15:54 -0700 Subject: [PATCH 11/14] fix lint --- src/snowflake/snowpark/session.py | 5 +---- tests/integ/test_simplifier_suite.py | 18 ++++++++++-------- tests/unit/test_session.py | 9 ++++----- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 42b9cfa6eb..9bdf7872b5 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5095,9 +5095,7 @@ def _retrieve_aggregation_function_list(self) -> None: try: if job is not None: try: - retrieved_set.update( - {r[0].lower() for r in job.result()} - ) + retrieved_set.update({r[0].lower() for r in job.result()}) system_fetch_succeeded = True except Exception as e: _logger.debug( @@ -5170,7 +5168,6 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: return if self._agg_function_prefetch_job is not None: return - try: result = self._conn.execute_async_and_notify_query_listener( """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 3a3587365d..32043bc8f5 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2662,14 +2662,16 @@ def run_waiter(): assert not errors assert "sum" in context._aggregation_function_set assert "avg" in context._aggregation_function_set - assert len(sync_query_calls) == 0, ( - f"Expected 0 sync queries from waiters, got {len(sync_query_calls)}" - ) + assert ( + len(sync_query_calls) == 0 + ), f"Expected 0 sync queries from waiters, got {len(sync_query_calls)}" -def test_concurrent_retrieve_agg_event_set_after_context_published(session, monkeypatch): +def 
test_concurrent_retrieve_agg_event_set_after_context_published( + session, monkeypatch +): """The fetch_event must be set only after _aggregation_function_set is published — - waiters must see a non-empty set the moment they wake up.""" + waiters must see a non-empty set the moment they wake up""" import snowflake.snowpark.context as context from threading import Event as _Event @@ -2696,6 +2698,6 @@ def patched_set(self_event): session._retrieve_aggregation_function_list() assert snapshot_at_set, "fetch_event.set() was never called" - assert snapshot_at_set[0], ( - "fetch_event fired before _aggregation_function_set was published" - ) + assert snapshot_at_set[ + 0 + ], "fetch_event fired before _aggregation_function_set was published" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d547800ac5..0d651204d4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -919,7 +919,6 @@ def run_query_side_effect(query, **kwargs): ctx._aggregation_function_set = original_agg_set - def test_retrieve_agg_event_set_after_context_published(monkeypatch): """fetch_event.set() must be called only after _aggregation_function_set is populated — the original bug was setting the event before publishing.""" @@ -954,9 +953,9 @@ def patched_set(self_event): assert publish_order, "fetch_event.set() was never called" # At the moment event fires the set must already contain the result. 
_, snapshot = publish_order[0] - assert "sum" in snapshot, ( - f"fetch_event fired before context was populated; snapshot={snapshot}" - ) + assert ( + "sum" in snapshot + ), f"fetch_event fired before context was populated; snapshot={snapshot}" def test_retrieve_agg_waiters_fall_through_on_winner_failure(monkeypatch): @@ -1028,7 +1027,7 @@ def run_waiter(): def test_retrieve_agg_event_always_set_on_base_exception(monkeypatch): """fetch_event.set() fires even when a BaseException escapes the async job, - so waiters are never left blocking until timeout.""" + so waiters are never left blocking until timeout""" import snowflake.snowpark.context as ctx from threading import Event as _Event From ad2d18eee9b6dd7c6e0df275ea6085d57bd2925f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 18 May 2026 16:39:13 -0700 Subject: [PATCH 12/14] remove functions with (*) --- src/snowflake/snowpark/context.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index 240672a571..f8d6030efc 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -92,10 +92,8 @@ "boolxor_agg", "corr", "count", - "count(*)", "count_if", "count_internal", - "count_internal(*)", "covar_pop", "covar_samp", "datasketches_hll", From 90b836c7f36ca1c404fe270ab9e3e4d4ea637499 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 19 May 2026 16:40:13 -0700 Subject: [PATCH 13/14] fix logic --- src/snowflake/snowpark/context.py | 7 ++- src/snowflake/snowpark/session.py | 78 +++++++++++----------------- tests/integ/test_simplifier_suite.py | 32 ++++++++---- tests/unit/test_session.py | 19 ++++--- 4 files changed, 70 insertions(+), 66 deletions(-) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index d495188776..812f29d502 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -6,7 +6,7 @@ """Context module for Snowpark.""" import logging 
import sys -from typing import Callable, Optional +from typing import Any, Callable, Optional import snowflake.snowpark import threading @@ -45,6 +45,11 @@ set() ) # lower cased names of aggregation functions, used in sql simplification _aggregation_function_set_lock = threading.RLock() +_aggregation_function_prefetch_state: dict[str, Any] = { + "lock": threading.RLock(), + "event": None, + "job": None, +} # Hardcoded fallback for system built-in aggregation functions. # Used when the dynamic query fails to retrieve the list from the database. diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 9bdf7872b5..010d957893 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -856,12 +856,6 @@ def __init__( self._dataframe_profiler = DataframeProfiler(session=self) self._catalog = None self._client_telemetry = EventTableTelemetry(session=self) - self._agg_function_prefetch_job: Optional[AsyncJob] = None - # Guards the one-time atomic claim of _agg_function_prefetch_job. - self._agg_function_prefetch_lock = RLock() - # Set by the thread that claimed the async job once it finishes (success or failure), - # so other threads can wait instead of issuing redundant sync queries. - self._agg_function_fetch_event: Optional[Event] = None self._ast_batch = AstBatch(self) self._start_async_aggregation_prefetch_if_needed() @@ -5063,6 +5057,7 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set = set() system_fetch_succeeded = False + prefetch_state = context._aggregation_function_prefetch_state # Atomically claim the async job. The claiming thread creates an Event so concurrent # threads can wait on it rather than issuing redundant sync queries. @@ -5070,15 +5065,15 @@ def _retrieve_aggregation_function_list(self) -> None: # shared state (_result, _rownumber, _prefetch_hook) during result fetching, so only # one thread may call it. 
The lock is held only for the pointer swap and event setup # (nanoseconds), not the network call itself. - with self._agg_function_prefetch_lock: - job, self._agg_function_prefetch_job = self._agg_function_prefetch_job, None + with prefetch_state["lock"]: + job, prefetch_state["job"] = prefetch_state["job"], None if job is not None: fetch_event = Event() - self._agg_function_fetch_event = fetch_event + prefetch_state["event"] = fetch_event wait_event = None - elif self._agg_function_fetch_event is not None: + elif prefetch_state["event"] is not None: fetch_event = None - wait_event = self._agg_function_fetch_event + wait_event = prefetch_state["event"] else: fetch_event = None wait_event = None @@ -5106,24 +5101,6 @@ def _retrieve_aggregation_function_list(self) -> None: _logger.debug( "Async aggregation function prefetch job is unavailable; using sync fallback." ) - try: - retrieved_set.update( - { - r[0].lower() - for r in self._conn.run_query( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' -union -select function_name from information_schema.functions where is_aggregate = 'YES'""", - _is_internal=True, - )["data"] - } - ) - system_fetch_succeeded = True - except Exception as e: - _logger.debug( - "Unable to get aggregation functions via sync union query: %s", - e, - ) # Sync fallback query. 
if not system_fetch_succeeded: @@ -5164,24 +5141,31 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: and context._snowpark_connect_flatten_select_after_sort ): return - if context._aggregation_function_set: - return - if self._agg_function_prefetch_job is not None: - return - try: - result = self._conn.execute_async_and_notify_query_listener( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' -union -select function_name from information_schema.functions where is_aggregate = 'YES'""", - _is_internal=True, - ) - self._agg_function_prefetch_job = self.create_async_job(result["queryId"]) - except Exception as e: # pragma: no cover - _logger.debug( - "Unable to start async aggregation metadata prefetch: %s", - e, - ) - self._agg_function_prefetch_job = None + prefetch_state = context._aggregation_function_prefetch_state + with prefetch_state["lock"]: + if context._aggregation_function_set: + return + if prefetch_state["job"] is not None: + return + # A winner thread has already claimed the async job and is still publishing results. + # Do not start a new async query while that in-flight fetch is unfinished. 
+ if ( + prefetch_state["event"] is not None + and not prefetch_state["event"].is_set() + ): + return + try: + result = self._conn.execute_async_and_notify_query_listener( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + _is_internal=True, + ) + prefetch_state["job"] = self.create_async_job(result["queryId"]) + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async aggregation metadata prefetch: %s", + e, + ) + prefetch_state["job"] = None def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 32043bc8f5..616d1d928a 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2522,12 +2522,15 @@ def test_retrieving_aggregation_funcs(session, monkeypatch): def test_internal_async_aggregation_prefetch_submission(session, monkeypatch): + from threading import Event + import snowflake.snowpark.context as context monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._agg_function_prefetch_job = None + context._aggregation_function_prefetch_state["job"] = None + context._aggregation_function_prefetch_state["event"] = None calls = [] @@ -2543,8 +2546,16 @@ def _fake_execute_async(query, **kwargs): assert len(calls) == 1 assert calls[0][1].get("_is_internal") is True assert "show functions" in calls[0][0] - assert "information_schema.functions" in calls[0][0] - assert session._agg_function_prefetch_job.query_id == "qid_combined" + assert "information_schema.functions" not in calls[0][0] + assert ( + context._aggregation_function_prefetch_state["job"].query_id == "qid_combined" + ) + + # Another session start during in-flight fetch should not submit another async query. 
+ context._aggregation_function_prefetch_state["job"] = None + context._aggregation_function_prefetch_state["event"] = Event() + session._start_async_aggregation_prefetch_if_needed() + assert len(calls) == 1 def test_aggregation_fallback_not_used_when_combined_async_succeeds( @@ -2565,7 +2576,7 @@ def result(self): monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._agg_function_prefetch_job = _FakeAsyncJob(rows=[("SUM",)]) + context._aggregation_function_prefetch_state["job"] = _FakeAsyncJob(rows=[("SUM",)]) session._retrieve_aggregation_function_list() @@ -2579,7 +2590,8 @@ def test_internal_sync_aggregation_fallback_submission(session, monkeypatch): monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._agg_function_prefetch_job = None + context._aggregation_function_prefetch_state["job"] = None + context._aggregation_function_prefetch_state["event"] = None calls = [] @@ -2593,7 +2605,7 @@ def _fake_run_query(query, **kwargs): assert len(calls) == 1 assert calls[0][1].get("_is_internal") is True assert "show functions" in calls[0][0] - assert "information_schema.functions" in calls[0][0] + assert "information_schema.functions" not in calls[0][0] assert "avg" in context._aggregation_function_set @@ -2607,7 +2619,7 @@ def test_concurrent_retrieve_agg_waiters_no_sync_query(session, monkeypatch): monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._agg_function_fetch_event = None + context._aggregation_function_prefetch_state["event"] = 
None job_may_proceed = threading.Event() waiter_count = [0] @@ -2618,7 +2630,7 @@ def result(self): job_may_proceed.wait() return [("SUM",), ("AVG",)] - session._agg_function_prefetch_job = SlowFakeAsyncJob() + context._aggregation_function_prefetch_state["job"] = SlowFakeAsyncJob() sync_query_calls = [] original_run_query = session._conn.run_query @@ -2678,13 +2690,13 @@ def test_concurrent_retrieve_agg_event_set_after_context_published( monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._agg_function_fetch_event = None + context._aggregation_function_prefetch_state["event"] = None class _FakeAsyncJob: def result(self): return [("SUM",)] - session._agg_function_prefetch_job = _FakeAsyncJob() + context._aggregation_function_prefetch_state["job"] = _FakeAsyncJob() snapshot_at_set = [] original_event_set = _Event.set diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0d651204d4..235f2ff130 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -823,10 +823,11 @@ def test_retrieve_aggregation_function_list_handles_async_error(): ctx._is_snowpark_connect_compatible_mode = True ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + ctx._aggregation_function_prefetch_state["event"] = None fake_async_job = MagicMock() fake_async_job.result.side_effect = RuntimeError("async query failed") - session._agg_function_prefetch_job = fake_async_job + ctx._aggregation_function_prefetch_state["job"] = fake_async_job def run_query_side_effect(query, **kwargs): assert kwargs.get("_is_internal") is True @@ -894,6 +895,8 @@ def test_retrieve_aggregation_function_list_uses_single_internal_sync_query(): ctx._is_snowpark_connect_compatible_mode = True ctx._snowpark_connect_flatten_select_after_sort = True 
ctx._aggregation_function_set = set() + ctx._aggregation_function_prefetch_state["event"] = None + ctx._aggregation_function_prefetch_state["job"] = None called_queries = [] @@ -911,7 +914,7 @@ def run_query_side_effect(query, **kwargs): assert len(called_queries) == 1 assert "show functions" in called_queries[0] - assert "information_schema.functions" in called_queries[0] + assert "information_schema.functions" not in called_queries[0] assert "sum" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat @@ -928,7 +931,7 @@ def test_retrieve_agg_event_set_after_context_published(monkeypatch): fake_conn = mock.create_autospec(ServerConnection) fake_conn._thread_safe_session_enabled = True session = Session(fake_conn) - session._agg_function_fetch_event = None + ctx._aggregation_function_prefetch_state["event"] = None monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(ctx, "_snowpark_connect_flatten_select_after_sort", True) @@ -938,7 +941,7 @@ class TrackingAsyncJob: def result(self): return [("SUM",)] - session._agg_function_prefetch_job = TrackingAsyncJob() + ctx._aggregation_function_prefetch_state["job"] = TrackingAsyncJob() publish_order = [] original_set = _Event.set @@ -968,7 +971,7 @@ def test_retrieve_agg_waiters_fall_through_on_winner_failure(monkeypatch): fake_conn = mock.create_autospec(ServerConnection) fake_conn._thread_safe_session_enabled = True session = Session(fake_conn) - session._agg_function_fetch_event = None + ctx._aggregation_function_prefetch_state["event"] = None monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(ctx, "_snowpark_connect_flatten_select_after_sort", True) @@ -982,7 +985,7 @@ def result(self): job_may_proceed.wait() raise RuntimeError("async job failed") - session._agg_function_prefetch_job = FailingAsyncJob() + ctx._aggregation_function_prefetch_state["job"] = FailingAsyncJob() sync_query_calls = [] @@ 
-1034,7 +1037,7 @@ def test_retrieve_agg_event_always_set_on_base_exception(monkeypatch): fake_conn = mock.create_autospec(ServerConnection) fake_conn._thread_safe_session_enabled = True session = Session(fake_conn) - session._agg_function_fetch_event = None + ctx._aggregation_function_prefetch_state["event"] = None monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(ctx, "_snowpark_connect_flatten_select_after_sort", True) @@ -1044,7 +1047,7 @@ class KeyboardInterruptJob: def result(self): raise KeyboardInterrupt() - session._agg_function_prefetch_job = KeyboardInterruptJob() + ctx._aggregation_function_prefetch_state["job"] = KeyboardInterruptJob() event_was_set = [] original_set = _Event.set From b951239d507d7fda5979b40ad7ad2695956b33de Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 21 May 2026 13:59:02 -0700 Subject: [PATCH 14/14] add lock to protect --- src/snowflake/snowpark/session.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 010d957893..c8bdc0ca5a 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5046,15 +5046,16 @@ def _execute_sproc_internal( def _retrieve_aggregation_function_list(self) -> None: """Retrieve the list of aggregation functions which will later be used in sql simplifier.""" - if ( - not ( - context._is_snowpark_connect_compatible_mode - and context._snowpark_connect_flatten_select_after_sort - ) - or context._aggregation_function_set + if not ( + context._is_snowpark_connect_compatible_mode + and context._snowpark_connect_flatten_select_after_sort ): return + with context._aggregation_function_set_lock: + if context._aggregation_function_set: + return + retrieved_set = set() system_fetch_succeeded = False prefetch_state = context._aggregation_function_prefetch_state @@ -5083,8 +5084,9 @@ def _retrieve_aggregation_function_list(self) 
-> None: # bounding the hang in the rare case the winner thread dies before its # finally block runs (e.g. os._exit, interpreter shutdown). wait_event.wait(timeout=20) - if context._aggregation_function_set: - return + with context._aggregation_function_set_lock: + if context._aggregation_function_set: + return # Winner failed or timed out; fall through to sync query. try: @@ -5143,8 +5145,9 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: return prefetch_state = context._aggregation_function_prefetch_state with prefetch_state["lock"]: - if context._aggregation_function_set: - return + with context._aggregation_function_set_lock: + if context._aggregation_function_set: + return if prefetch_state["job"] is not None: return # A winner thread has already claimed the async job and is still publishing results.