Skip to content

Commit 3d7f2b6

Browse files
committed
Better version
1 parent 0a354b2 commit 3d7f2b6

6 files changed

Lines changed: 1315 additions & 106 deletions

File tree

python/pyspark/sql/connect/proto/base_pb2.py

Lines changed: 9 additions & 9 deletions
Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/base_pb2.pyi

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4759,6 +4759,9 @@ class ListSqlExecutionsResponse(google.protobuf.message.Message):
47594759
COMPLETION_TIME_MS_FIELD_NUMBER: builtins.int
47604760
ERROR_MESSAGE_FIELD_NUMBER: builtins.int
47614761
JOB_IDS_FIELD_NUMBER: builtins.int
4762+
QUERY_ID_FIELD_NUMBER: builtins.int
4763+
DETAILS_FIELD_NUMBER: builtins.int
4764+
STAGE_COUNT_FIELD_NUMBER: builtins.int
47624765
execution_id: builtins.int
47634766
root_execution_id: builtins.int
47644767
description: builtins.str
@@ -4774,6 +4777,14 @@ class ListSqlExecutionsResponse(google.protobuf.message.Message):
47744777
"""Job IDs associated with this execution. Job statuses are not included; clients can look
47754778
them up via the existing /api/v1/applications/{appId}/jobs REST endpoint if needed.
47764779
"""
4780+
query_id: builtins.str
4781+
"""UUID assigned by SQLExecution; null for executions recovered from old event logs."""
4782+
details: builtins.str
4783+
"""Long form of the call site for the executing SQL/DataFrame operation. For Connect
4784+
executions this is set to a redacted, abbreviated rendering of the ExecutePlanRequest.
4785+
"""
4786+
stage_count: builtins.int
4787+
"""Number of Spark stages associated with this execution."""
47774788
def __init__(
47784789
self,
47794790
*,
@@ -4785,39 +4796,60 @@ class ListSqlExecutionsResponse(google.protobuf.message.Message):
47854796
completion_time_ms: builtins.int | None = ...,
47864797
error_message: builtins.str | None = ...,
47874798
job_ids: collections.abc.Iterable[builtins.int] | None = ...,
4799+
query_id: builtins.str | None = ...,
4800+
details: builtins.str | None = ...,
4801+
stage_count: builtins.int = ...,
47884802
) -> None: ...
47894803
def HasField(
47904804
self,
47914805
field_name: typing_extensions.Literal[
47924806
"_completion_time_ms",
47934807
b"_completion_time_ms",
4808+
"_details",
4809+
b"_details",
47944810
"_error_message",
47954811
b"_error_message",
4812+
"_query_id",
4813+
b"_query_id",
47964814
"completion_time_ms",
47974815
b"completion_time_ms",
4816+
"details",
4817+
b"details",
47984818
"error_message",
47994819
b"error_message",
4820+
"query_id",
4821+
b"query_id",
48004822
],
48014823
) -> builtins.bool: ...
48024824
def ClearField(
48034825
self,
48044826
field_name: typing_extensions.Literal[
48054827
"_completion_time_ms",
48064828
b"_completion_time_ms",
4829+
"_details",
4830+
b"_details",
48074831
"_error_message",
48084832
b"_error_message",
4833+
"_query_id",
4834+
b"_query_id",
48094835
"completion_time_ms",
48104836
b"completion_time_ms",
48114837
"description",
48124838
b"description",
4839+
"details",
4840+
b"details",
48134841
"error_message",
48144842
b"error_message",
48154843
"execution_id",
48164844
b"execution_id",
48174845
"job_ids",
48184846
b"job_ids",
4847+
"query_id",
4848+
b"query_id",
48194849
"root_execution_id",
48204850
b"root_execution_id",
4851+
"stage_count",
4852+
b"stage_count",
48214853
"status",
48224854
b"status",
48234855
"submission_time_ms",
@@ -4830,9 +4862,17 @@ class ListSqlExecutionsResponse(google.protobuf.message.Message):
48304862
oneof_group: typing_extensions.Literal["_completion_time_ms", b"_completion_time_ms"],
48314863
) -> typing_extensions.Literal["completion_time_ms"] | None: ...
48324864
@typing.overload
4865+
def WhichOneof(
4866+
self, oneof_group: typing_extensions.Literal["_details", b"_details"]
4867+
) -> typing_extensions.Literal["details"] | None: ...
4868+
@typing.overload
48334869
def WhichOneof(
48344870
self, oneof_group: typing_extensions.Literal["_error_message", b"_error_message"]
48354871
) -> typing_extensions.Literal["error_message"] | None: ...
4872+
@typing.overload
4873+
def WhichOneof(
4874+
self, oneof_group: typing_extensions.Literal["_query_id", b"_query_id"]
4875+
) -> typing_extensions.Literal["query_id"] | None: ...
48364876

48374877
SESSION_ID_FIELD_NUMBER: builtins.int
48384878
SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int

python/pyspark/sql/connect/ui/__init__.py

Lines changed: 225 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,119 @@
2323

2424
import logging
2525
import os
26+
import re
2627
import sys
2728
import threading
29+
import time
2830
import warnings
29-
from dataclasses import dataclass
30-
from typing import List, Optional, TYPE_CHECKING
31+
from dataclasses import dataclass, field
32+
from typing import Iterable, List, Optional, TYPE_CHECKING
3133

3234
import pyspark.sql.connect.proto as pb2
3335
from pyspark.errors import PySparkAssertionError
3436
from pyspark.errors.exceptions.connect import SparkConnectException
37+
from pyspark.sql.connect.shell.progress import (
38+
ProgressHandler as _ProgressHandlerBase,
39+
StageInfo,
40+
)
3541

3642
if TYPE_CHECKING:
3743
from pyspark.sql.connect.session import SparkSession
3844

3945
_logger = logging.getLogger(__name__)
4046

4147

48+
@dataclass(frozen=True)
class StageProgress:
    """Immutable per-stage progress counters handed to the UI.

    Instances are frozen, so a value built for one snapshot cannot be mutated
    afterwards by whoever reads it.
    """

    # Identifier of the stage these counters belong to.
    stage_id: int
    # Task count reported for the stage.
    num_tasks: int
    # How many of those tasks have been reported as completed.
    num_completed_tasks: int
    # Bytes-read counter reported for the stage.
    num_bytes_read: int
    # Whether the stage was reported as finished.
    done: bool
55+
56+
57+
@dataclass(frozen=True)
class ProgressSnapshot:
    """Frozen view of the most recent Connect operation progress, for the UI.

    The aggregate counters summarise the per-stage entries in ``stages``.
    """

    # Operation the numbers belong to; ``None`` when no operation id is known.
    operation_id: Optional[str]
    # Aggregate task counters.
    total_tasks: int
    completed_tasks: int
    inflight_tasks: int
    # Aggregate bytes-read counter.
    bytes_read: int
    # Seconds attributed to the operation when the snapshot was taken.
    elapsed_seconds: float
    # Whether the operation was reported as finished.
    done: bool
    # Per-stage breakdown; defaults to a fresh empty list per instance.
    stages: List[StageProgress] = field(default_factory=list)
69+
70+
71+
class _UIProgressHandler(_ProgressHandlerBase):
    """``ProgressHandler`` that keeps only the newest progress event for the Flask UI.

    Progress events are delivered on a network thread owned by the Connect client,
    whereas ``snapshot()`` is called from the Flask request handler on the UI thread.
    All reads and writes of the mutable fields therefore happen under ``self._lock``.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # Latest event payload.
        self._stages: List[StageInfo] = []
        self._inflight_tasks: int = 0
        self._done: bool = True
        # Identity and timing of the operation currently being tracked.
        self._operation_id: Optional[str] = None
        self._started_at: Optional[float] = None
        self._last_seen_at: float = 0.0

    def __call__(
        self,
        stages: Optional[Iterable[StageInfo]],
        inflight_tasks: int,
        operation_id: Optional[str],
        done: bool,
    ) -> None:
        """Record one progress event, replacing whatever was stored before.

        A non-empty ``operation_id`` that differs from the tracked one marks the start
        of a new operation and resets the start timestamp.
        """
        timestamp = time.monotonic()
        with self._lock:
            is_new_operation = bool(operation_id) and operation_id != self._operation_id
            if is_new_operation:
                self._operation_id = operation_id
                self._started_at = timestamp
            # Copy the iterable so later mutation by the caller cannot affect us.
            self._stages = list(stages) if stages else []
            self._inflight_tasks = inflight_tasks
            self._last_seen_at = timestamp
            self._done = done

    def snapshot(self) -> ProgressSnapshot:
        """Return an immutable copy of the latest recorded progress."""
        with self._lock:
            stage_views: List[StageProgress] = []
            total = 0
            completed = 0
            bytes_read = 0
            for info in self._stages:
                stage_views.append(
                    StageProgress(
                        stage_id=info.stage_id,
                        num_tasks=info.num_tasks,
                        num_completed_tasks=info.num_completed_tasks,
                        num_bytes_read=info.num_bytes_read,
                        done=info.done,
                    )
                )
                total += info.num_tasks
                completed += info.num_completed_tasks
                bytes_read += info.num_bytes_read

            started = self._started_at
            if started is None:
                # No operation has been observed yet.
                elapsed = 0.0
            elif self._done:
                # Pin the duration to the final event so a finished query's stats stop
                # ticking once it has ended.
                elapsed = self._last_seen_at - started
            else:
                # Server events can be several seconds apart; measuring against the
                # current clock keeps the displayed time moving between events.
                elapsed = time.monotonic() - started

            return ProgressSnapshot(
                operation_id=self._operation_id,
                total_tasks=total,
                completed_tasks=completed,
                inflight_tasks=self._inflight_tasks,
                bytes_read=bytes_read,
                elapsed_seconds=elapsed,
                done=self._done,
                stages=stage_views,
            )
137+
138+
42139
@dataclass(frozen=True)
43140
class SqlExecutionSummary:
44141
execution_id: int
@@ -49,6 +146,98 @@ class SqlExecutionSummary:
49146
completion_time_ms: Optional[int]
50147
error_message: Optional[str]
51148
job_ids: List[int]
149+
query_id: Optional[str] = None
150+
details: Optional[str] = None
151+
stage_count: int = 0
152+
# Best-effort parse of ``description`` -- Connect sets the SQL job description to
153+
# ``"Spark Connect - <protobuf text dump of ExecutePlanRequest>"`` truncated to ~128 chars,
154+
# which surfaces user_id, the request session_id, and the leading plan operation.
155+
user_id: Optional[str] = None
156+
request_session_id: Optional[str] = None
157+
operation: Optional[str] = None
158+
159+
160+
_PARSE_USER_ID = re.compile(r'user_id:\s*"([^"]*)"')
161+
_PARSE_SESSION_ID = re.compile(r'session_id:\s*"([^"]*)"')
162+
_PARSE_PLAN_OPEN = re.compile(r"plan\s*\{\s*(?:root|command)\s*\{")
163+
_PARSE_FIELD_OPEN = re.compile(r"\s*([a-z_][a-z0-9_]*)\s*\{")
164+
165+
166+
def _skip_balanced_block(text: str, start: int) -> int:
167+
"""Return the index just past the ``}`` that closes a ``{`` already opened before ``start``.
168+
169+
Returns ``-1`` if the block is unterminated (e.g. truncated mid-stream).
170+
"""
171+
depth = 1
172+
i = start
173+
while i < len(text) and depth > 0:
174+
c = text[i]
175+
if c == "{":
176+
depth += 1
177+
elif c == "}":
178+
depth -= 1
179+
i += 1
180+
return i if depth == 0 else -1
181+
182+
183+
def _parse_connect_proto_text(
    text: str,
) -> "tuple[Optional[str], Optional[str], Optional[str]]":
    """Pull ``(user_id, session_id, operation)`` out of a Connect proto-text dump.

    Works on the short SQL job description (~128 chars, prefixed with
    ``"Spark Connect - "``) as well as the longer ``callSite.long`` dump (~2048 chars,
    the bare ``ExecutePlanRequest`` text); the two differ only in how early they are
    truncated. Any element that cannot be found is ``None``, and text that does not
    look like a Connect dump yields ``(None, None, None)``.
    """
    if not text:
        return (None, None, None)
    if "user_context" not in text or "plan {" not in text:
        return (None, None, None)

    user_match = _PARSE_USER_ID.search(text)
    session_match = _PARSE_SESSION_ID.search(text)
    user_id = user_match.group(1) if user_match else None
    session_id = session_match.group(1) if session_match else None

    operation: Optional[str] = None
    plan_match = _PARSE_PLAN_OPEN.search(text)
    if plan_match is not None:
        cursor = plan_match.end()
        # The body of `root { ... }` / `command { ... }` normally starts with a
        # `common { ... }` metadata block. It has nested children, so it is skipped by
        # brace counting rather than with a regex.
        common_match = re.compile(r"\s*common\s*\{").match(text, cursor)
        if common_match is not None:
            # -1 here means the dump was cut off inside `common`, so the real
            # operation is not present in the text at all.
            cursor = _skip_balanced_block(text, common_match.end())
        if cursor != -1:
            field_match = _PARSE_FIELD_OPEN.match(text, cursor)
            if field_match is not None:
                operation = field_match.group(1)

    return (user_id, session_id, operation)
221+
222+
223+
def _parse_description(
    description: str, details: Optional[str]
) -> "tuple[Optional[str], Optional[str], Optional[str]]":
    """Return ``(user_id, session_id, operation)`` for a Connect SQL execution.

    ``details`` (the ``callSite.long`` of the Connect job, ~2048 chars) is tried first
    because the short ``description`` is cut at ~128 chars and almost always ends
    inside the ``common`` block, hiding the operation. ``description`` is consulted
    when ``details`` is missing or non-Connect, or when ``details`` parsed but did not
    yield an operation.
    """
    from_details = _parse_connect_proto_text(details or "")
    if all(part is None for part in from_details):
        # Nothing usable in ``details``: rely on ``description`` entirely.
        user_id, session_id, operation = _parse_connect_proto_text(description)
        return (user_id, session_id, operation)

    user_id, session_id, operation = from_details
    if operation is None:
        # ``details`` gave ids but no operation; borrow the operation from
        # ``description`` as a last resort.
        operation = _parse_connect_proto_text(description)[2]
    return (user_id, session_id, operation)
52241

53242

54243
def _verify_ui_response(client: "object", resp: "pb2.ListSqlExecutionsResponse") -> None:
@@ -121,21 +310,31 @@ def list_sql_executions(
121310
if resp is None:
122311
raise SparkConnectException("Invalid state during retry exception handling.")
123312

124-
return [
125-
SqlExecutionSummary(
126-
execution_id=e.execution_id,
127-
root_execution_id=e.root_execution_id,
128-
description=e.description,
129-
status=_status_name(e.status),
130-
submission_time_ms=e.submission_time_ms,
131-
completion_time_ms=(
132-
e.completion_time_ms if e.HasField("completion_time_ms") else None
133-
),
134-
error_message=e.error_message if e.HasField("error_message") else None,
135-
job_ids=list(e.job_ids),
313+
summaries = []
314+
for e in resp.executions:
315+
details = e.details if e.HasField("details") else None
316+
user_id, request_session_id, operation = _parse_description(e.description, details)
317+
summaries.append(
318+
SqlExecutionSummary(
319+
execution_id=e.execution_id,
320+
root_execution_id=e.root_execution_id,
321+
description=e.description,
322+
status=_status_name(e.status),
323+
submission_time_ms=e.submission_time_ms,
324+
completion_time_ms=(
325+
e.completion_time_ms if e.HasField("completion_time_ms") else None
326+
),
327+
error_message=e.error_message if e.HasField("error_message") else None,
328+
job_ids=list(e.job_ids),
329+
query_id=e.query_id if e.HasField("query_id") else None,
330+
details=e.details if e.HasField("details") else None,
331+
stage_count=e.stage_count,
332+
user_id=user_id,
333+
request_session_id=request_session_id,
334+
operation=operation,
335+
)
136336
)
137-
for e in resp.executions
138-
]
337+
return summaries
139338

140339

141340
def start_in_background(
@@ -165,11 +364,20 @@ def start_in_background(
165364
)
166365
return None
167366

367+
progress_handler = _UIProgressHandler()
368+
try:
369+
spark.registerProgressHandler(progress_handler)
370+
except Exception as e: # noqa: BLE001 - non-fatal: UI still works without progress
371+
_logger.warning("Failed to register Connect UI progress handler: %s", e)
372+
progress_handler = None # type: ignore[assignment]
373+
168374
# 0 = ephemeral. Each SparkSession gets its own UI, so collisions on a fixed
169375
# port would otherwise force the second session onto a different one anyway.
170376
bind_port = port if port is not None else 0
171377
try:
172-
app = make_app(spark, refresh_seconds=refresh_seconds)
378+
app = make_app(
379+
spark, refresh_seconds=refresh_seconds, progress_handler=progress_handler
380+
)
173381
server = make_server(host, bind_port, app, threaded=True)
174382
except OSError as e:
175383
warnings.warn(

0 commit comments

Comments
 (0)