Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
__all__ = [
"ChannelBuilder",
"DefaultChannelBuilder",
"PathAwareChannelBuilder",
"RpcDeadlines",
"SparkConnectClient",
]
Expand All @@ -32,6 +33,7 @@
import threading
import os
import copy
import collections
import platform
import urllib.parse
import uuid
Expand All @@ -40,6 +42,7 @@
import traceback
import weakref
from typing import (
Callable,
Iterable,
Iterator,
Optional,
Expand Down Expand Up @@ -546,6 +549,191 @@ def toChannel(self) -> grpc.Channel:
return self._secure_channel(self.endpoint, creds)


class _ClientCallDetails(
collections.namedtuple(
"_ClientCallDetails",
("method", "timeout", "metadata", "credentials", "wait_for_ready", "compression"),
),
grpc.ClientCallDetails,
):
pass


class _PathPrefixInterceptor(
grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor,
):
"""Prepends a fixed path prefix to every gRPC method, so an ingress doing
path-based routing can dispatch to the right Spark Connect driver."""

def __init__(self, prefix: str) -> None:
self._prefix = prefix

def _rewrite(self, details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
return _ClientCallDetails(
method=self._prefix + details.method,
timeout=details.timeout,
metadata=details.metadata,
credentials=details.credentials,
wait_for_ready=getattr(details, "wait_for_ready", None),
compression=getattr(details, "compression", None),
)

def intercept_unary_unary(
self,
continuation: Callable[[grpc.ClientCallDetails, Any], Any],
client_call_details: grpc.ClientCallDetails,
request: Any,
) -> Any:
return continuation(self._rewrite(client_call_details), request)

def intercept_unary_stream(
self,
continuation: Callable[[grpc.ClientCallDetails, Any], Any],
client_call_details: grpc.ClientCallDetails,
request: Any,
) -> Any:
return continuation(self._rewrite(client_call_details), request)

def intercept_stream_unary(
self,
continuation: Callable[[grpc.ClientCallDetails, Any], Any],
client_call_details: grpc.ClientCallDetails,
request_iterator: Iterator[Any],
) -> Any:
return continuation(self._rewrite(client_call_details), request_iterator)

def intercept_stream_stream(
self,
continuation: Callable[[grpc.ClientCallDetails, Any], Any],
client_call_details: grpc.ClientCallDetails,
request_iterator: Iterator[Any],
) -> Any:
return continuation(self._rewrite(client_call_details), request_iterator)


class PathAwareChannelBuilder(DefaultChannelBuilder):
"""
Channel builder that extends :class:`DefaultChannelBuilder` with support for
ingress path-based routing.

In addition to the standard form (`sc://host[:port][/;params]`), it accepts a
path-routed form `sc://gateway/<prefix>[:<port>][/;params]`, where the path
component is an ingress route prefix and an optional trailing `:<port>` on the
final path segment is the connection port (used only when the netloc carries no
port of its own). When a path prefix is present it is prepended to every gRPC
method via an interceptor, so an HTTPS ingress can dispatch by URL to the right
Spark Connect driver, and TLS is enabled implicitly when the resolved port is
443 and `;use_ssl=` was not specified.

Examples
--------
>>> cb = PathAwareChannelBuilder("sc://host1/path1:443")
>>> cb.endpoint, cb.path_prefix, cb.secure
('host1:443', '/path1', True)
"""

def __init__(self, url: str, channelOptions: Optional[List[Tuple[str, Any]]] = None) -> None:
"""
Constructs a new channel builder. This is used to create the proper GRPC channel from
the connection string.

Parameters
----------
url : str
Spark Connect connection string
channelOptions: list of tuple, optional
Additional options that can be passed to the GRPC channel construction.
"""
# Strip the ingress route prefix (and optional trailing :port) from the path,
# then let DefaultChannelBuilder parse the remaining path-free URL. This reuses
# its scheme check, params/hostname/port parsing (including IPv6 handling and the
# SPARK_TESTING ephemeral-port resolution) and channel construction.
base_url, path_prefix = self._split_path_prefix(url)
super().__init__(base_url, channelOptions=channelOptions)
self._path_prefix = path_prefix
if path_prefix:
# A resolved port of 443 implies a standard HTTPS ingress; enable TLS unless
# ;use_ssl= was set explicitly. This lives under the same condition as the
# prefix, so a prefix-less URL never gets a port-implied TLS (or, worse, a
# path-derived port 443 spoken in plaintext).
if self._port == 443 and ChannelBuilder.PARAM_USE_SSL not in self._params:
self.set(ChannelBuilder.PARAM_USE_SSL, "true")
self.add_interceptor(_PathPrefixInterceptor(path_prefix))

@staticmethod
def _split_path_prefix(url: str) -> Tuple[str, str]:
"""Split a path-routed `sc://` URL into a path-free URL for
:class:`DefaultChannelBuilder` and the extracted ingress route prefix (empty
when there is none). Raises for a malformed path port or a route prefix that
contains a `;` parameter separator.
"""
if url[:5] != "sc://":
# Defer to DefaultChannelBuilder to raise the standard scheme error.
return url, ""
# Rewrite the URL to use http as the scheme so we can reuse Python's parser.
parsed = urllib.parse.urlparse("http" + url[2:])
path = parsed.path
if not path or path == "/":
return url, ""

# A trailing slash may remain when the path is combined with the standard
# `/;params` form (e.g. `sc://gateway/app/driver:443/;token=abc` parses to path
# `/app/driver:443/`); strip it so the port on the final segment is recognized.
prefix = path.rstrip("/")
path_port: Optional[int] = None
last_segment = prefix.rsplit("/", 1)[-1]
if ":" in last_segment:
head, _, port_str = prefix.rpartition(":")
# Validate the path-derived port the way urlparse validates the netloc port,
# so `driver:99999`, `driver:-1` and `driver:4_43` are not silently accepted.
if not (port_str.isdigit() and 0 <= int(port_str) <= 65535):
raise PySparkValueError(
errorClass="INVALID_CONNECT_URL",
messageParameters={
"detail": f"Port '{port_str}' in the path of '{url}' must be an "
"integer between 0 and 65535. Please update the URL to follow the "
"correct format, e.g., 'sc://host/prefix:port'.",
},
)
path_port = int(port_str)
prefix = head

prefix = "/" + prefix.strip("/")
if prefix == "/":
# A bare `:port` with no route prefix is not a path-routed URL; drop both the
# empty prefix and the meaningless port and defer to the standard parser.
prefix = ""
path_port = None
elif ";" in prefix:
# urlparse only strips `;params` from the final path segment, so a
# `;key=value` elsewhere would otherwise be silently swallowed into the route
# prefix (and leak into the :path of every RPC). `;` is never legal here.
raise PySparkValueError(
errorClass="INVALID_CONNECT_URL",
messageParameters={
"detail": f"The path prefix '{prefix}' must not contain ';'. A "
"';key=value' parameter is only recognized after the final path "
"segment, e.g., 'sc://host/prefix:port/;key=value'.",
},
)

# Rebuild a path-free URL for DefaultChannelBuilder. The path-derived port is
# folded into the netloc only when the netloc carries no port of its own.
netloc = parsed.netloc
if path_port is not None and parsed.port is None:
netloc = f"{netloc}:{path_port}"
base_url = f"sc://{netloc}/;{parsed.params}" if parsed.params else f"sc://{netloc}"
return base_url, prefix

@property
def path_prefix(self) -> str:
"""The ingress path prefix prepended to every gRPC method, or '' if unset."""
return self._path_prefix


class PlanObservedMetrics(ObservedMetrics):
def __init__(self, name: str, metrics: List[pb2.Expression.Literal], keys: List[str]):
self._name = name
Expand Down
8 changes: 4 additions & 4 deletions python/pyspark/sql/connect/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.dataframe import DataFrame as ParentDataFrame
from pyspark.sql.connect.logging import logger
from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
from pyspark.sql.connect.conf import RuntimeConf
from pyspark.sql.connect.plan import (
SQL,
Expand Down Expand Up @@ -126,7 +126,7 @@ class Builder:

def __init__(self) -> None:
self._options: Dict[str, Any] = {}
self._channel_builder: Optional[DefaultChannelBuilder] = None
self._channel_builder: Optional[ChannelBuilder] = None
self._hook_factories: list["Callable[[SparkSession], SparkSession.Hook]"] = []

@overload
Expand Down Expand Up @@ -159,7 +159,7 @@ def appName(self, name: str) -> "SparkSession.Builder":
def remote(self, location: str = "sc://localhost") -> "SparkSession.Builder":
return self.config("spark.remote", location)

def channelBuilder(self, channelBuilder: DefaultChannelBuilder) -> "SparkSession.Builder":
def channelBuilder(self, channelBuilder: ChannelBuilder) -> "SparkSession.Builder":
"""Uses custom :class:`ChannelBuilder` implementation, when there is a need
to customize the behavior for creation of GRPC connections.

Expand Down Expand Up @@ -279,7 +279,7 @@ def on_execute_plan(self, request: "pb2.ExecutePlanRequest") -> "pb2.ExecutePlan

def __init__(
self,
connection: Union[str, DefaultChannelBuilder],
connection: Union[str, ChannelBuilder],
userId: Optional[str] = None,
hook_factories: Optional[list["Callable[[SparkSession], Hook]"]] = None,
) -> None:
Expand Down
Loading