|
| 1 | +"""Spark session management and capability probing. |
| 2 | +
|
| 3 | +This module centralizes obtaining a Spark session and detecting per-session |
| 4 | +capabilities that vary across Spark builds (e.g. classic Spark 4.0 vs some |
| 5 | +Spark Connect builds). |
| 6 | +
|
| 7 | +All PySpark imports are lazy so that ``import tablespec.session`` succeeds even |
| 8 | +when PySpark is not installed. Callers that actually invoke a session or probe |
| 9 | +require PySpark at call time. |
| 10 | +""" |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +import logging |
| 15 | +from typing import TYPE_CHECKING |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from pyspark.sql import SparkSession |
| 19 | + |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Cache of per-session capabilities, keyed by id(spark).
# Each value maps a capability name (e.g. "try_to_timestamp_with_format")
# to a bool; get_capabilities() reads and fills this dict.
_session_capabilities: dict[int, dict[str, bool]] = {}
| 24 | + |
| 25 | + |
def _probe_try_to_timestamp_with_format(spark: object) -> bool:
    """Probe whether ``F.try_to_timestamp(col, F.lit(fmt))`` works on this session.

    This expression works on classic Spark 4.0 but may not on some Spark Connect
    builds. A one-row DataFrame is created, the expression is applied, and the
    result is collected so the expression is actually evaluated rather than
    only constructed.

    PySpark is imported lazily so the module imports cleanly without PySpark.

    Args:
    ----
        spark: Session-like object exposing ``createDataFrame``.

    Returns:
    -------
        True if the expression evaluated without raising, False on any failure
        (including PySpark not being importable). Failures are logged at DEBUG
        level together with their traceback.

    """
    try:
        from pyspark.sql import functions as F  # noqa: N812

        df = spark.createDataFrame([("2020-01-01",)], ["d"])  # type: ignore[attr-defined]
        expr = F.try_to_timestamp(df["d"], F.lit("yyyy-MM-dd"))
        df.select(expr).collect()
    except Exception:
        # Best-effort probe: any failure is reported as "unsupported", but the
        # reason is recorded so an unexpected False can be diagnosed instead of
        # vanishing silently.
        logging.getLogger(__name__).debug(
            "try_to_timestamp with format probe failed; treating capability as unsupported",
            exc_info=True,
        )
        return False
    else:
        return True
| 45 | + |
| 46 | + |
def get_capabilities(spark: object) -> dict[str, bool]:
    """Return capability flags for the given Spark session, with caching.

    The result is cached keyed by ``id(spark)``. On a cache miss the relevant
    probes are run; on a cache hit the cached result is returned without
    re-probing. The cache entry is discarded when the session object is
    garbage-collected, so a later session that happens to receive the same
    ``id()`` is probed afresh instead of inheriting stale flags.

    Args:
    ----
        spark: The Spark session to inspect.

    Returns:
    -------
        Mapping of capability name to whether it is supported. The same dict
        object is returned on every cache hit for a given live session.

    """
    import weakref

    key = id(spark)
    cached = _session_capabilities.get(key)
    if cached is not None:
        return cached

    capabilities = {
        "try_to_timestamp_with_format": _probe_try_to_timestamp_with_format(spark),
    }
    _session_capabilities[key] = capabilities

    try:
        # id() is only unique among objects that are alive at the same time.
        # Evict the entry once the session dies so its id cannot be reused to
        # look up another session's capabilities.
        finalizer = weakref.finalize(spark, _session_capabilities.pop, key, None)
    except TypeError:
        # Object does not support weak references; keep the entry cached
        # without automatic eviction.
        pass
    else:
        # Nothing useful to clean up at interpreter shutdown.
        finalizer.atexit = False
    return capabilities
| 64 | + |
| 65 | + |
def get_session(app_name: str = "tablespec", backend: str = "spark") -> SparkSession:
    """Return the active Spark session, creating one if none is active.

    Args:
    ----
        app_name: Application name handed to ``SparkSessionFactory.create_session``
            when a new session has to be created.
        backend: Session backend; ``"spark"`` is the only accepted value.

    Returns:
    -------
        The session reported by ``SparkSession.getActiveSession()`` when there
        is one, otherwise the session built by ``SparkSessionFactory``.

    Raises:
    ------
        ValueError: If ``backend`` is anything other than ``"spark"``.

    """
    if backend == "spark":
        # Imported here rather than at module level so the module can be
        # imported without PySpark installed.
        from pyspark.sql import SparkSession

        from tablespec.spark_factory import SparkSessionFactory

        try:
            active = SparkSession.getActiveSession()
        except Exception:
            # A failing lookup is handled the same way as "no active session".
            active = None
        return active if active is not None else SparkSessionFactory.create_session(app_name)

    msg = f"Unknown backend: {backend}"
    raise ValueError(msg)
| 95 | + |
| 96 | + |
# Public API of this module; the probe helper and the cache stay private.
__all__ = [
    "get_capabilities",
    "get_session",
]
0 commit comments