|
17 | 17 | from collections import defaultdict |
18 | 18 | from functools import reduce |
19 | 19 | from logging import getLogger |
20 | | -from threading import RLock |
| 20 | +from threading import Event, RLock |
21 | 21 | from types import ModuleType |
22 | 22 | from typing import ( |
23 | 23 | TYPE_CHECKING, |
@@ -858,6 +858,7 @@ def __init__( |
858 | 858 | self._client_telemetry = EventTableTelemetry(session=self) |
859 | 859 |
|
860 | 860 | self._ast_batch = AstBatch(self) |
| 861 | + self._start_async_aggregation_prefetch_if_needed() |
861 | 862 |
|
862 | 863 | _logger.info("Snowpark Session information: %s", self._session_info) |
863 | 864 |
|
@@ -5045,53 +5046,129 @@ def _execute_sproc_internal( |
5045 | 5046 |
|
def _retrieve_aggregation_function_list(self) -> None:
    """Populate ``context._aggregation_function_set`` for the SQL simplifier.

    Does nothing unless both the Snowpark Connect compatible mode and the
    flatten-select-after-sort flag are enabled, or when the shared set is
    already populated.

    Sources are tried in order:
      1. the async prefetch job stored in
         ``context._aggregation_function_prefetch_state`` (claimed by at
         most one thread);
      2. a synchronous ``show functions`` query;
      3. the hardcoded ``context._KNOWN_AGGREGATION_FUNCTIONS`` list, used
         only when both queries fail.

    A thread that finds the job already claimed waits (bounded) on the
    claimer's event and re-checks the shared set before running its own
    synchronous query.
    """
    if (
        not context._is_snowpark_connect_compatible_mode
        or not context._snowpark_connect_flatten_select_after_sort
    ):
        return

    with context._aggregation_function_set_lock:
        if context._aggregation_function_set:
            return

    names = set()
    fetched = False
    state = context._aggregation_function_prefetch_state
    query = """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'"""

    # Exactly one thread takes ownership of the async job and publishes an
    # Event that later arrivals can wait on instead of issuing redundant
    # queries. Per the original author's note, AsyncJob.result() is not
    # thread-safe (the connector cursor mutates shared state while fetching),
    # so only the claiming thread may call it. The lock covers just the
    # hand-off below, never the network call.
    done_event = None  # set only by the thread that claimed the job
    pending_event = None  # set only for threads that must wait on a claimer
    with state["lock"]:
        claimed_job = state["job"]
        state["job"] = None
        if claimed_job is not None:
            done_event = Event()
            state["event"] = done_event
        else:
            pending_event = state["event"]

    if pending_event is not None:
        # The query typically finishes in ~5s; 20s leaves headroom while
        # bounding the wait if the claiming thread dies before it can signal
        # (e.g. os._exit, interpreter shutdown).
        pending_event.wait(timeout=20)
        with context._aggregation_function_set_lock:
            if context._aggregation_function_set:
                return
        # The claimer failed or the wait timed out: continue to the sync query.

    try:
        if claimed_job is None:
            _logger.debug(
                "Async aggregation function prefetch job is unavailable; using sync fallback."
            )
        else:
            try:
                prefetched = {row[0].lower() for row in claimed_job.result()}
            except Exception as exc:
                _logger.debug(
                    "Unable to use async aggregation function prefetch: %s",
                    exc,
                )
            else:
                names |= prefetched
                fetched = True

        # Synchronous query, used when the prefetch was absent or failed.
        if not fetched:
            try:
                response = self._conn.run_query(query, _is_internal=True)
                names |= {row[0].lower() for row in response["data"]}
                fetched = True
            except Exception as exc:
                _logger.debug(
                    "Unable to get aggregation functions via sync fallback query: %s",
                    exc,
                )

        # Last resort: the local hardcoded list, only when metadata retrieval failed.
        if not fetched:
            names.update(context._KNOWN_AGGREGATION_FUNCTIONS)

        with context._aggregation_function_set_lock:
            context._aggregation_function_set.update(names)
    finally:
        # Signal only after the shared set has been published so that waiters
        # observe the populated set as soon as they wake. Running in `finally`
        # means waiters are also released when a BaseException (such as
        # KeyboardInterrupt) propagates, rather than blocking until timeout.
        if done_event is not None:
            done_event.set()
| 5138 | + |
def _start_async_aggregation_prefetch_if_needed(self) -> None:
    """Submit the async aggregation-function metadata query if none is pending.

    No query is submitted when any of the following holds:
      * the Snowpark Connect compatible mode or the flatten-select-after-sort
        flag is disabled;
      * ``context._aggregation_function_set`` is already populated;
      * a prefetch job is already stored in the shared prefetch state;
      * another thread has claimed a job and has not yet signalled completion.

    On success the resulting async job is stored under
    ``context._aggregation_function_prefetch_state["job"]``. Failures are
    logged at debug level and leave the slot set to ``None``.
    """
    if (
        not context._is_snowpark_connect_compatible_mode
        or not context._snowpark_connect_flatten_select_after_sort
    ):
        return

    state = context._aggregation_function_prefetch_state
    query = """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'"""

    # The prefetch lock is held for the whole check-and-submit sequence,
    # including the submission call itself.
    with state["lock"]:
        with context._aggregation_function_set_lock:
            already_populated = bool(context._aggregation_function_set)
        if already_populated:
            return

        if state["job"] is not None:
            return

        # A thread that claimed an earlier job is still publishing its
        # results; do not start another async query until it has finished.
        in_flight = state["event"]
        if in_flight is not None and not in_flight.is_set():
            return

        try:
            submitted = self._conn.execute_async_and_notify_query_listener(
                query,
                _is_internal=True,
            )
            state["job"] = self.create_async_job(submitted["queryId"])
        except Exception as exc:  # pragma: no cover
            _logger.debug(
                "Unable to start async aggregation metadata prefetch: %s",
                exc,
            )
            state["job"] = None
5095 | 5172 |
|
5096 | 5173 | def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: |
5097 | 5174 | """ |
|
0 commit comments