Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def __init__(
self._profiler_collector = ConnectProfilerCollector()

self._progress_handlers: List[ProgressHandler] = []
self._execution_info_callbacks: "List[Callable[[str, ExecutionInfo], None]]" = []

self._zstd_module = _import_zstandard_if_available()
self._plan_compression_threshold: Optional[int] = None # Will be fetched lazily
Expand Down Expand Up @@ -792,6 +793,25 @@ def register_progress_handler(self, handler: ProgressHandler) -> None:
def clear_progress_handlers(self) -> None:
self._progress_handlers.clear()

def register_execution_info_callback(
self, cb: "Callable[[str, ExecutionInfo], None]"
) -> None:
if cb not in self._execution_info_callbacks:
self._execution_info_callbacks.append(cb)

def remove_execution_info_callback(
self, cb: "Callable[[str, ExecutionInfo], None]"
) -> None:
if cb in self._execution_info_callbacks:
self._execution_info_callbacks.remove(cb)

def _fire_execution_info(self, operation_id: str, ei: ExecutionInfo) -> None:
for cb in self._execution_info_callbacks:
try:
cb(operation_id, ei)
except Exception:
pass

def remove_progress_handler(self, handler: ProgressHandler) -> None:
"""
Remove a progress handler from the list of registered handlers.
Expand Down Expand Up @@ -1039,6 +1059,7 @@ def to_table(

# Create a query execution object.
ei = ExecutionInfo(metrics, observed_metrics)
self._fire_execution_info(req.operation_id, ei)
assert table is not None
return table, schema, ei

Expand Down Expand Up @@ -1075,6 +1096,7 @@ def to_pandas(
)
assert table is not None
ei = ExecutionInfo(metrics, observed_metrics)
self._fire_execution_info(req.operation_id, ei)

schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
assert schema is not None and isinstance(schema, StructType)
Expand Down Expand Up @@ -1220,6 +1242,7 @@ def execute_command(
)
# Create a query execution object.
ei = ExecutionInfo(metrics, observed_metrics)
self._fire_execution_info(req.operation_id, ei)
if data is not None:
return (data.to_pandas(), properties, ei)
else:
Expand Down
18 changes: 13 additions & 5 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

287 changes: 287 additions & 0 deletions python/pyspark/sql/connect/proto/base_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4630,3 +4630,290 @@ class GetStatusResponse(google.protobuf.message.Message):
) -> None: ...

global___GetStatusResponse = GetStatusResponse

class ListSqlExecutionsRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

SESSION_ID_FIELD_NUMBER: builtins.int
USER_CONTEXT_FIELD_NUMBER: builtins.int
CLIENT_TYPE_FIELD_NUMBER: builtins.int
CLIENT_OBSERVED_SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
OFFSET_FIELD_NUMBER: builtins.int
LENGTH_FIELD_NUMBER: builtins.int
session_id: builtins.str
"""(Required) Spark session for the user identified by user_context.user_id."""
@property
def user_context(self) -> global___UserContext:
"""(Required) user_context.user_id and session_id identify a unique remote Spark session."""
client_type: builtins.str
"""(Optional) Client information for logging only; not interpreted by the server."""
client_observed_server_side_session_id: builtins.str
"""(Optional) Server-side generated idempotency key from a previous response. The server uses
this to validate that the server-side session has not changed since the client last saw it.
"""
offset: builtins.int
"""(Optional) Pagination. Negative offsets are clamped to 0. A length <= 0 or larger than the
server-side maximum is clamped by the server.
"""
length: builtins.int
def __init__(
self,
*,
session_id: builtins.str = ...,
user_context: global___UserContext | None = ...,
client_type: builtins.str | None = ...,
client_observed_server_side_session_id: builtins.str | None = ...,
offset: builtins.int = ...,
length: builtins.int = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_client_observed_server_side_session_id",
b"_client_observed_server_side_session_id",
"_client_type",
b"_client_type",
"client_observed_server_side_session_id",
b"client_observed_server_side_session_id",
"client_type",
b"client_type",
"user_context",
b"user_context",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_client_observed_server_side_session_id",
b"_client_observed_server_side_session_id",
"_client_type",
b"_client_type",
"client_observed_server_side_session_id",
b"client_observed_server_side_session_id",
"client_type",
b"client_type",
"length",
b"length",
"offset",
b"offset",
"session_id",
b"session_id",
"user_context",
b"user_context",
],
) -> None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal[
"_client_observed_server_side_session_id", b"_client_observed_server_side_session_id"
],
) -> typing_extensions.Literal["client_observed_server_side_session_id"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"]
) -> typing_extensions.Literal["client_type"] | None: ...

global___ListSqlExecutionsRequest = ListSqlExecutionsRequest

class ListSqlExecutionsResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

class _SqlExecutionStatus:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType

class _SqlExecutionStatusEnumTypeWrapper(
google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
ListSqlExecutionsResponse._SqlExecutionStatus.ValueType
],
builtins.type,
): # noqa: F821
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
SQL_EXECUTION_STATUS_UNSPECIFIED: (
ListSqlExecutionsResponse._SqlExecutionStatus.ValueType
) # 0
SQL_EXECUTION_STATUS_RUNNING: ListSqlExecutionsResponse._SqlExecutionStatus.ValueType # 1
SQL_EXECUTION_STATUS_COMPLETED: ListSqlExecutionsResponse._SqlExecutionStatus.ValueType # 2
SQL_EXECUTION_STATUS_FAILED: ListSqlExecutionsResponse._SqlExecutionStatus.ValueType # 3

class SqlExecutionStatus(_SqlExecutionStatus, metaclass=_SqlExecutionStatusEnumTypeWrapper): ...
SQL_EXECUTION_STATUS_UNSPECIFIED: ListSqlExecutionsResponse.SqlExecutionStatus.ValueType # 0
SQL_EXECUTION_STATUS_RUNNING: ListSqlExecutionsResponse.SqlExecutionStatus.ValueType # 1
SQL_EXECUTION_STATUS_COMPLETED: ListSqlExecutionsResponse.SqlExecutionStatus.ValueType # 2
SQL_EXECUTION_STATUS_FAILED: ListSqlExecutionsResponse.SqlExecutionStatus.ValueType # 3

class SqlExecutionSummary(google.protobuf.message.Message):
"""Lightweight summary of a single SQL execution. Plan strings and per-node metrics are
intentionally omitted here -- list responses must stay cheap. A future GetSqlExecution RPC
can return the heavy fields for a single execution_id.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

EXECUTION_ID_FIELD_NUMBER: builtins.int
ROOT_EXECUTION_ID_FIELD_NUMBER: builtins.int
DESCRIPTION_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
SUBMISSION_TIME_MS_FIELD_NUMBER: builtins.int
COMPLETION_TIME_MS_FIELD_NUMBER: builtins.int
ERROR_MESSAGE_FIELD_NUMBER: builtins.int
JOB_IDS_FIELD_NUMBER: builtins.int
QUERY_ID_FIELD_NUMBER: builtins.int
DETAILS_FIELD_NUMBER: builtins.int
STAGE_COUNT_FIELD_NUMBER: builtins.int
execution_id: builtins.int
root_execution_id: builtins.int
description: builtins.str
status: global___ListSqlExecutionsResponse.SqlExecutionStatus.ValueType
submission_time_ms: builtins.int
completion_time_ms: builtins.int
"""Unset while the execution is still running."""
error_message: builtins.str
@property
def job_ids(
self,
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Job IDs associated with this execution. Job statuses are not included; clients can look
them up via the existing /api/v1/applications/{appId}/jobs REST endpoint if needed.
"""
query_id: builtins.str
"""UUID assigned by SQLExecution; null for executions recovered from old event logs."""
details: builtins.str
"""Long form of the call site for the executing SQL/DataFrame operation. For Connect
executions this is set to a redacted, abbreviated rendering of the ExecutePlanRequest.
"""
stage_count: builtins.int
"""Number of Spark stages associated with this execution."""
def __init__(
self,
*,
execution_id: builtins.int = ...,
root_execution_id: builtins.int = ...,
description: builtins.str = ...,
status: global___ListSqlExecutionsResponse.SqlExecutionStatus.ValueType = ...,
submission_time_ms: builtins.int = ...,
completion_time_ms: builtins.int | None = ...,
error_message: builtins.str | None = ...,
job_ids: collections.abc.Iterable[builtins.int] | None = ...,
query_id: builtins.str | None = ...,
details: builtins.str | None = ...,
stage_count: builtins.int = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_completion_time_ms",
b"_completion_time_ms",
"_details",
b"_details",
"_error_message",
b"_error_message",
"_query_id",
b"_query_id",
"completion_time_ms",
b"completion_time_ms",
"details",
b"details",
"error_message",
b"error_message",
"query_id",
b"query_id",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_completion_time_ms",
b"_completion_time_ms",
"_details",
b"_details",
"_error_message",
b"_error_message",
"_query_id",
b"_query_id",
"completion_time_ms",
b"completion_time_ms",
"description",
b"description",
"details",
b"details",
"error_message",
b"error_message",
"execution_id",
b"execution_id",
"job_ids",
b"job_ids",
"query_id",
b"query_id",
"root_execution_id",
b"root_execution_id",
"stage_count",
b"stage_count",
"status",
b"status",
"submission_time_ms",
b"submission_time_ms",
],
) -> None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal["_completion_time_ms", b"_completion_time_ms"],
) -> typing_extensions.Literal["completion_time_ms"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_details", b"_details"]
) -> typing_extensions.Literal["details"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_error_message", b"_error_message"]
) -> typing_extensions.Literal["error_message"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_query_id", b"_query_id"]
) -> typing_extensions.Literal["query_id"] | None: ...

SESSION_ID_FIELD_NUMBER: builtins.int
SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
EXECUTIONS_FIELD_NUMBER: builtins.int
TOTAL_COUNT_FIELD_NUMBER: builtins.int
session_id: builtins.str
"""Session id of the session for which executions were requested."""
server_side_session_id: builtins.str
"""Server-side generated idempotency key that the client can use to assert that the
server-side session has not changed.
"""
@property
def executions(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
global___ListSqlExecutionsResponse.SqlExecutionSummary
]:
"""Page of executions, in the order returned by the underlying status store."""
total_count: builtins.int
"""Total number of executions known to the server, regardless of pagination."""
def __init__(
self,
*,
session_id: builtins.str = ...,
server_side_session_id: builtins.str = ...,
executions: collections.abc.Iterable[global___ListSqlExecutionsResponse.SqlExecutionSummary]
| None = ...,
total_count: builtins.int = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"executions",
b"executions",
"server_side_session_id",
b"server_side_session_id",
"session_id",
b"session_id",
"total_count",
b"total_count",
],
) -> None: ...

global___ListSqlExecutionsResponse = ListSqlExecutionsResponse
Loading