2323
2424import logging
2525import os
26+ import re
2627import sys
2728import threading
29+ import time
2830import warnings
29- from dataclasses import dataclass
30- from typing import List , Optional , TYPE_CHECKING
31+ from dataclasses import dataclass , field
32+ from typing import Iterable , List , Optional , TYPE_CHECKING
3133
3234import pyspark .sql .connect .proto as pb2
3335from pyspark .errors import PySparkAssertionError
3436from pyspark .errors .exceptions .connect import SparkConnectException
37+ from pyspark .sql .connect .shell .progress import (
38+ ProgressHandler as _ProgressHandlerBase ,
39+ StageInfo ,
40+ )
3541
3642if TYPE_CHECKING :
3743 from pyspark .sql .connect .session import SparkSession
3844
3945_logger = logging .getLogger (__name__ )
4046
4147
48+ @dataclass (frozen = True )
49+ class StageProgress :
50+ stage_id : int
51+ num_tasks : int
52+ num_completed_tasks : int
53+ num_bytes_read : int
54+ done : bool
55+
56+
57+ @dataclass (frozen = True )
58+ class ProgressSnapshot :
59+ """Immutable snapshot of the latest Connect operation progress for the UI."""
60+
61+ operation_id : Optional [str ]
62+ total_tasks : int
63+ completed_tasks : int
64+ inflight_tasks : int
65+ bytes_read : int
66+ elapsed_seconds : float
67+ done : bool
68+ stages : List [StageProgress ] = field (default_factory = list )
69+
70+
71+ class _UIProgressHandler (_ProgressHandlerBase ):
72+ """``ProgressHandler`` that retains the latest progress event for the Flask UI.
73+
74+ The Connect client fires events from a network thread; the UI thread reads
75+ ``snapshot()`` from the Flask request handler. A lock guards the mutable state.
76+ """
77+
78+ def __init__ (self ) -> None :
79+ self ._lock = threading .Lock ()
80+ self ._operation_id : Optional [str ] = None
81+ self ._started_at : Optional [float ] = None
82+ self ._last_seen_at : float = 0.0
83+ self ._stages : List [StageInfo ] = []
84+ self ._inflight_tasks : int = 0
85+ self ._done : bool = True
86+
87+ def __call__ (
88+ self ,
89+ stages : Optional [Iterable [StageInfo ]],
90+ inflight_tasks : int ,
91+ operation_id : Optional [str ],
92+ done : bool ,
93+ ) -> None :
94+ now = time .monotonic ()
95+ with self ._lock :
96+ if operation_id and operation_id != self ._operation_id :
97+ self ._operation_id = operation_id
98+ self ._started_at = now
99+ self ._stages = list (stages or [])
100+ self ._inflight_tasks = inflight_tasks
101+ self ._last_seen_at = now
102+ self ._done = done
103+
104+ def snapshot (self ) -> ProgressSnapshot :
105+ with self ._lock :
106+ stages = [
107+ StageProgress (
108+ stage_id = s .stage_id ,
109+ num_tasks = s .num_tasks ,
110+ num_completed_tasks = s .num_completed_tasks ,
111+ num_bytes_read = s .num_bytes_read ,
112+ done = s .done ,
113+ )
114+ for s in self ._stages
115+ ]
116+ if self ._started_at is None :
117+ elapsed = 0.0
118+ elif self ._done :
119+ # Freeze at the last progress-event time so completed-query stats don't keep
120+ # ticking after the query ends.
121+ elapsed = self ._last_seen_at - self ._started_at
122+ else :
123+ # Wall-clock-ish: events from the server can be sparse (one per several
124+ # seconds), so compute against ``now`` rather than the last event time --
125+ # otherwise the bar appears frozen between events.
126+ elapsed = time .monotonic () - self ._started_at
127+ return ProgressSnapshot (
128+ operation_id = self ._operation_id ,
129+ total_tasks = sum (s .num_tasks for s in stages ),
130+ completed_tasks = sum (s .num_completed_tasks for s in stages ),
131+ inflight_tasks = self ._inflight_tasks ,
132+ bytes_read = sum (s .num_bytes_read for s in stages ),
133+ elapsed_seconds = elapsed ,
134+ done = self ._done ,
135+ stages = stages ,
136+ )
137+
138+
42139@dataclass (frozen = True )
43140class SqlExecutionSummary :
44141 execution_id : int
@@ -49,6 +146,98 @@ class SqlExecutionSummary:
49146 completion_time_ms : Optional [int ]
50147 error_message : Optional [str ]
51148 job_ids : List [int ]
149+ query_id : Optional [str ] = None
150+ details : Optional [str ] = None
151+ stage_count : int = 0
152+ # Best-effort parse of ``description`` -- Connect sets the SQL job description to
153+ # ``"Spark Connect - <protobuf text dump of ExecutePlanRequest>"`` truncated to ~128 chars,
154+ # which surfaces user_id, the request session_id, and the leading plan operation.
155+ user_id : Optional [str ] = None
156+ request_session_id : Optional [str ] = None
157+ operation : Optional [str ] = None
158+
159+
160+ _PARSE_USER_ID = re .compile (r'user_id:\s*"([^"]*)"' )
161+ _PARSE_SESSION_ID = re .compile (r'session_id:\s*"([^"]*)"' )
162+ _PARSE_PLAN_OPEN = re .compile (r"plan\s*\{\s*(?:root|command)\s*\{" )
163+ _PARSE_FIELD_OPEN = re .compile (r"\s*([a-z_][a-z0-9_]*)\s*\{" )
164+
165+
166+ def _skip_balanced_block (text : str , start : int ) -> int :
167+ """Return the index just past the ``}`` that closes a ``{`` already opened before ``start``.
168+
169+ Returns ``-1`` if the block is unterminated (e.g. truncated mid-stream).
170+ """
171+ depth = 1
172+ i = start
173+ while i < len (text ) and depth > 0 :
174+ c = text [i ]
175+ if c == "{" :
176+ depth += 1
177+ elif c == "}" :
178+ depth -= 1
179+ i += 1
180+ return i if depth == 0 else - 1
181+
182+
183+ def _parse_connect_proto_text (
184+ text : str ,
185+ ) -> "tuple[Optional[str], Optional[str], Optional[str]]" :
186+ """Extract ``(user_id, session_id, operation)`` from a Connect proto-text dump.
187+
188+ Handles both the short SQL job description (~128 chars, prefixed with ``"Spark Connect - "``)
189+ and the longer ``callSite.long`` dump (~2048 chars, the bare ``ExecutePlanRequest`` text).
190+ Both share the same structure -- only the truncation depth differs.
191+ """
192+ if not text or "user_context" not in text or "plan {" not in text :
193+ return (None , None , None )
194+ user = _PARSE_USER_ID .search (text )
195+ sess = _PARSE_SESSION_ID .search (text )
196+
197+ operation : Optional [str ] = None
198+ plan_open = _PARSE_PLAN_OPEN .search (text )
199+ if plan_open is not None :
200+ pos = plan_open .end ()
201+ # Inside `root { ... }` (or `command { ... }`) the proto text always begins with a
202+ # `common { ... }` metadata block. Skip past it by counting braces -- the block has
203+ # nested children so a simple regex won't suffice.
204+ common_open = re .match (r"\s*common\s*\{" , text [pos :])
205+ if common_open is not None :
206+ past_common = _skip_balanced_block (text , pos + common_open .end ())
207+ if past_common == - 1 :
208+ pos = - 1 # truncated mid-common, can't find the real op
209+ else :
210+ pos = past_common
211+ if pos != - 1 :
212+ op_match = _PARSE_FIELD_OPEN .match (text , pos )
213+ if op_match is not None :
214+ operation = op_match .group (1 )
215+
216+ return (
217+ user .group (1 ) if user else None ,
218+ sess .group (1 ) if sess else None ,
219+ operation ,
220+ )
221+
222+
223+ def _parse_description (
224+ description : str , details : Optional [str ]
225+ ) -> "tuple[Optional[str], Optional[str], Optional[str]]" :
226+ """Return ``(user_id, session_id, operation)`` for a Connect SQL execution.
227+
228+ Prefers the long ``details`` (the ``callSite.long`` of the Connect job, ~2048 chars) for
229+ ``operation`` because the short ``description`` is truncated to ~128 chars and almost
230+ always cuts off mid-``common`` block. Falls back to ``description`` if ``details`` is
231+ missing or non-Connect.
232+ """
233+ user , sess , op = _parse_connect_proto_text (details or "" )
234+ if user is None and sess is None and op is None :
235+ user , sess , op = _parse_connect_proto_text (description )
236+ elif op is None :
237+ # Details parsed but operation wasn't extractable; try description as a last resort.
238+ _ , _ , fallback_op = _parse_connect_proto_text (description )
239+ op = fallback_op
240+ return (user , sess , op )
52241
53242
54243def _verify_ui_response (client : "object" , resp : "pb2.ListSqlExecutionsResponse" ) -> None :
@@ -121,21 +310,31 @@ def list_sql_executions(
121310 if resp is None :
122311 raise SparkConnectException ("Invalid state during retry exception handling." )
123312
124- return [
125- SqlExecutionSummary (
126- execution_id = e .execution_id ,
127- root_execution_id = e .root_execution_id ,
128- description = e .description ,
129- status = _status_name (e .status ),
130- submission_time_ms = e .submission_time_ms ,
131- completion_time_ms = (
132- e .completion_time_ms if e .HasField ("completion_time_ms" ) else None
133- ),
134- error_message = e .error_message if e .HasField ("error_message" ) else None ,
135- job_ids = list (e .job_ids ),
313+ summaries = []
314+ for e in resp .executions :
315+ details = e .details if e .HasField ("details" ) else None
316+ user_id , request_session_id , operation = _parse_description (e .description , details )
317+ summaries .append (
318+ SqlExecutionSummary (
319+ execution_id = e .execution_id ,
320+ root_execution_id = e .root_execution_id ,
321+ description = e .description ,
322+ status = _status_name (e .status ),
323+ submission_time_ms = e .submission_time_ms ,
324+ completion_time_ms = (
325+ e .completion_time_ms if e .HasField ("completion_time_ms" ) else None
326+ ),
327+ error_message = e .error_message if e .HasField ("error_message" ) else None ,
328+ job_ids = list (e .job_ids ),
329+ query_id = e .query_id if e .HasField ("query_id" ) else None ,
330+ details = e .details if e .HasField ("details" ) else None ,
331+ stage_count = e .stage_count ,
332+ user_id = user_id ,
333+ request_session_id = request_session_id ,
334+ operation = operation ,
335+ )
136336 )
137- for e in resp .executions
138- ]
337+ return summaries
139338
140339
141340def start_in_background (
@@ -165,11 +364,20 @@ def start_in_background(
165364 )
166365 return None
167366
367+ progress_handler = _UIProgressHandler ()
368+ try :
369+ spark .registerProgressHandler (progress_handler )
370+ except Exception as e : # noqa: BLE001 - non-fatal: UI still works without progress
371+ _logger .warning ("Failed to register Connect UI progress handler: %s" , e )
372+ progress_handler = None # type: ignore[assignment]
373+
168374 # 0 = ephemeral. Each SparkSession gets its own UI, so collisions on a fixed
169375 # port would otherwise force the second session onto a different one anyway.
170376 bind_port = port if port is not None else 0
171377 try :
172- app = make_app (spark , refresh_seconds = refresh_seconds )
378+ app = make_app (
379+ spark , refresh_seconds = refresh_seconds , progress_handler = progress_handler
380+ )
173381 server = make_server (host , bind_port , app , threaded = True )
174382 except OSError as e :
175383 warnings .warn (
0 commit comments