Skip to content

Commit 1253e7e

Browse files
SNOW-3484790: initialize aggregation functions list during SCOS init (#4217)
1 parent 318190c commit 1253e7e

4 files changed

Lines changed: 531 additions & 83 deletions

File tree

src/snowflake/snowpark/context.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""Context module for Snowpark."""
77
import logging
88
import sys
9-
from typing import Callable, Optional
9+
from typing import Any, Callable, Optional
1010

1111
import snowflake.snowpark
1212
import threading
@@ -45,6 +45,11 @@
4545
set()
4646
) # lower cased names of aggregation functions, used in sql simplification
4747
_aggregation_function_set_lock = threading.RLock()
48+
_aggregation_function_prefetch_state: dict[str, Any] = {
49+
"lock": threading.RLock(),
50+
"event": None,
51+
"job": None,
52+
}
4853

4954
# Hardcoded fallback for system built-in aggregation functions.
5055
# Used when the dynamic query fails to retrieve the list from the database.
@@ -62,35 +67,36 @@
6267
"ai_agg",
6368
"ai_summarize_agg",
6469
"any_value",
65-
"approximate_count_distinct",
66-
"approximate_jaccard_index",
67-
"approximate_similarity",
6870
"approx_count_distinct",
6971
"approx_percentile",
7072
"approx_percentile_accumulate",
7173
"approx_percentile_combine",
7274
"approx_top_k",
7375
"approx_top_k_accumulate",
7476
"approx_top_k_combine",
75-
"arrayagg",
77+
"approximate_count_distinct",
78+
"approximate_jaccard_index",
79+
"approximate_similarity",
7680
"array_agg",
7781
"array_union_agg",
7882
"array_unique_agg",
83+
"arrayagg",
7984
"avg",
80-
"bitandagg",
85+
"bit_and_agg",
86+
"bit_andagg",
87+
"bit_or_agg",
88+
"bit_oragg",
89+
"bit_xor_agg",
90+
"bit_xoragg",
8191
"bitand_agg",
92+
"bitandagg",
93+
"bitmap_and_agg",
8294
"bitmap_construct_agg",
8395
"bitmap_or_agg",
84-
"bitoragg",
8596
"bitor_agg",
86-
"bitxoragg",
97+
"bitoragg",
8798
"bitxor_agg",
88-
"bit_andagg",
89-
"bit_and_agg",
90-
"bit_oragg",
91-
"bit_or_agg",
92-
"bit_xoragg",
93-
"bit_xor_agg",
99+
"bitxoragg",
94100
"booland_agg",
95101
"boolor_agg",
96102
"boolxor_agg",
@@ -115,12 +121,12 @@
115121
"max_by",
116122
"median",
117123
"min",
124+
"min_by",
118125
"minhash",
119126
"minhash_combine",
120-
"min_by",
121127
"mode",
122-
"objectagg",
123128
"object_agg",
129+
"objectagg",
124130
"percentile_cont",
125131
"percentile_disc",
126132
"regr_avgx",
@@ -133,27 +139,29 @@
133139
"regr_sxy",
134140
"regr_syy",
135141
"skew",
142+
"st_intersection_agg_geography_internal",
143+
"st_union_agg_geography_internal",
136144
"stddev",
137145
"stddev_pop",
138146
"stddev_samp",
139-
"st_intersection_agg_geography_internal",
140-
"st_union_agg_geography_internal",
141147
"sum",
142148
"sum_internal",
143149
"sum_internal_real",
144150
"sum_real",
151+
"summarize_agg",
152+
"var_pop",
153+
"var_samp",
145154
"variance",
146155
"variance_pop",
147156
"variance_samp",
148-
"var_pop",
149-
"var_samp",
150157
"vector_avg",
151158
"vector_max",
152159
"vector_min",
153160
"vector_sum",
154161
]
155162
)
156163

164+
157165
_cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session.
158166

159167
# Following are internal-only global flags, used to enable development features.

src/snowflake/snowpark/session.py

Lines changed: 117 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections import defaultdict
1818
from functools import reduce
1919
from logging import getLogger
20-
from threading import RLock
20+
from threading import Event, RLock
2121
from types import ModuleType
2222
from typing import (
2323
TYPE_CHECKING,
@@ -858,6 +858,7 @@ def __init__(
858858
self._client_telemetry = EventTableTelemetry(session=self)
859859

860860
self._ast_batch = AstBatch(self)
861+
self._start_async_aggregation_prefetch_if_needed()
861862

862863
_logger.info("Snowpark Session information: %s", self._session_info)
863864

@@ -5045,53 +5046,129 @@ def _execute_sproc_internal(
50455046

50465047
def _retrieve_aggregation_function_list(self) -> None:
50475048
"""Retrieve the list of aggregation functions which will later be used in sql simplifier."""
5048-
if (
5049-
not (
5050-
context._is_snowpark_connect_compatible_mode
5051-
and context._snowpark_connect_flatten_select_after_sort
5052-
)
5053-
or context._aggregation_function_set
5049+
if not (
5050+
context._is_snowpark_connect_compatible_mode
5051+
and context._snowpark_connect_flatten_select_after_sort
50545052
):
50555053
return
50565054

5055+
with context._aggregation_function_set_lock:
5056+
if context._aggregation_function_set:
5057+
return
5058+
50575059
retrieved_set = set()
5060+
system_fetch_succeeded = False
5061+
prefetch_state = context._aggregation_function_prefetch_state
5062+
5063+
# Atomically claim the async job. The claiming thread creates an Event so concurrent
5064+
# threads can wait on it rather than issuing redundant sync queries.
5065+
# AsyncJob.result() is not thread-safe — the underlying connector cursor mutates
5066+
# shared state (_result, _rownumber, _prefetch_hook) during result fetching, so only
5067+
# one thread may call it. The lock is held only for the pointer swap and event setup
5068+
# (nanoseconds), not the network call itself.
5069+
with prefetch_state["lock"]:
5070+
job, prefetch_state["job"] = prefetch_state["job"], None
5071+
if job is not None:
5072+
fetch_event = Event()
5073+
prefetch_state["event"] = fetch_event
5074+
wait_event = None
5075+
elif prefetch_state["event"] is not None:
5076+
fetch_event = None
5077+
wait_event = prefetch_state["event"]
5078+
else:
5079+
fetch_event = None
5080+
wait_event = None
5081+
5082+
if wait_event is not None:
5083+
# The query typically finishes in ~5s; 20s gives ample headroom while
5084+
# bounding the hang in the rare case the winner thread dies before its
5085+
# finally block runs (e.g. os._exit, interpreter shutdown).
5086+
wait_event.wait(timeout=20)
5087+
with context._aggregation_function_set_lock:
5088+
if context._aggregation_function_set:
5089+
return
5090+
# Winner failed or timed out; fall through to sync query.
50585091

5059-
# User-defined aggregation functions
50605092
try:
5061-
retrieved_set.update(
5062-
{
5063-
r[0].lower()
5064-
for r in self.sql(
5065-
"""select function_name from information_schema.functions where is_aggregate = 'YES'"""
5066-
).collect()
5067-
}
5068-
)
5069-
except Exception as e:
5070-
_logger.debug(
5071-
"Unable to get user-defined aggregation functions: %s",
5072-
e,
5073-
)
5093+
if job is not None:
5094+
try:
5095+
retrieved_set.update({r[0].lower() for r in job.result()})
5096+
system_fetch_succeeded = True
5097+
except Exception as e:
5098+
_logger.debug(
5099+
"Unable to use async aggregation function prefetch: %s",
5100+
e,
5101+
)
5102+
else:
5103+
_logger.debug(
5104+
"Async aggregation function prefetch job is unavailable; using sync fallback."
5105+
)
50745106

5075-
# System built-in aggregation functions
5076-
try:
5077-
retrieved_set.update(
5078-
{
5079-
r[0].lower()
5080-
for r in self.sql(
5081-
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'"""
5082-
).collect()
5083-
}
5084-
)
5085-
except Exception as e:
5086-
_logger.debug(
5087-
"Unable to get system aggregation functions, "
5088-
"falling back to hardcoded list: %s",
5089-
e,
5090-
)
5091-
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)
5107+
# Sync fallback query.
5108+
if not system_fetch_succeeded:
5109+
try:
5110+
retrieved_set.update(
5111+
{
5112+
r[0].lower()
5113+
for r in self._conn.run_query(
5114+
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
5115+
_is_internal=True,
5116+
)["data"]
5117+
}
5118+
)
5119+
system_fetch_succeeded = True
5120+
except Exception as e:
5121+
_logger.debug(
5122+
"Unable to get aggregation functions via sync fallback query: %s",
5123+
e,
5124+
)
50925125

5093-
with context._aggregation_function_set_lock:
5094-
context._aggregation_function_set.update(retrieved_set)
5126+
# Fallback to the local hardcoded list only when metadata retrieval fails.
5127+
if not system_fetch_succeeded:
5128+
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)
5129+
5130+
with context._aggregation_function_set_lock:
5131+
context._aggregation_function_set.update(retrieved_set)
5132+
finally:
5133+
# Signal after _aggregation_function_set is published so waiters see
5134+
# the populated set immediately upon waking. Also fires on BaseException
5135+
# (e.g. KeyboardInterrupt) so waiters are never left blocking until timeout.
5136+
if fetch_event is not None:
5137+
fetch_event.set()
5138+
5139+
def _start_async_aggregation_prefetch_if_needed(self) -> None:
5140+
"""Start aggregation metadata prefetch only when not already in progress."""
5141+
if not (
5142+
context._is_snowpark_connect_compatible_mode
5143+
and context._snowpark_connect_flatten_select_after_sort
5144+
):
5145+
return
5146+
prefetch_state = context._aggregation_function_prefetch_state
5147+
with prefetch_state["lock"]:
5148+
with context._aggregation_function_set_lock:
5149+
if context._aggregation_function_set:
5150+
return
5151+
if prefetch_state["job"] is not None:
5152+
return
5153+
# A winner thread has already claimed the async job and is still publishing results.
5154+
# Do not start a new async query while that in-flight fetch is unfinished.
5155+
if (
5156+
prefetch_state["event"] is not None
5157+
and not prefetch_state["event"].is_set()
5158+
):
5159+
return
5160+
try:
5161+
result = self._conn.execute_async_and_notify_query_listener(
5162+
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
5163+
_is_internal=True,
5164+
)
5165+
prefetch_state["job"] = self.create_async_job(result["queryId"])
5166+
except Exception as e: # pragma: no cover
5167+
_logger.debug(
5168+
"Unable to start async aggregation metadata prefetch: %s",
5169+
e,
5170+
)
5171+
prefetch_state["job"] = None
50955172

50965173
def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame:
50975174
"""

0 commit comments

Comments
 (0)