1717from enum import Enum
1818from typing import (
1919 TYPE_CHECKING ,
20+ Any ,
2021 ClassVar ,
2122 Dict ,
2223 Optional ,
2324 Union ,
2425)
2526
2627from datacustomcode .config import config
28+ from datacustomcode .einstein_predictions_config import spark_einstein_predictions_config
2729from datacustomcode .file .path .default import DefaultFindFilePath
2830from datacustomcode .io .reader .base import BaseDataCloudReader
2931from datacustomcode .llm_gateway_config import spark_llm_gateway_config
3436
3537 from pyspark .sql import Column , DataFrame as PySparkDataFrame
3638
39+ from datacustomcode .einstein_predictions .spark_base import SparkEinsteinPredictions
40+ from datacustomcode .einstein_predictions .types import PredictionType
3741 from datacustomcode .io .reader .base import BaseDataCloudReader
3842 from datacustomcode .io .writer .base import BaseDataCloudWriter , WriteMode
3943 from datacustomcode .llm_gateway .spark_base import SparkLLMGateway
@@ -99,6 +103,70 @@ def llm_gateway_generate_text_col(
99103 return gateway .llm_gateway_generate_text_col (template , values , model_id = model_id )
100104
101105
def _build_spark_einstein_predictions() -> "SparkEinsteinPredictions":
    """Create the SDK-configured :class:`SparkEinsteinPredictions`.

    Looks up the ``spark_einstein_predictions_config`` attribute on the
    imported config holder and, when it is set, returns the result of its
    ``to_object()`` call.

    Returns:
        Whatever the configured section's ``to_object()`` produces.

    Raises:
        RuntimeError: If no ``spark_einstein_predictions_config`` has been loaded.
    """
    section = spark_einstein_predictions_config.spark_einstein_predictions_config
    if section is not None:
        return section.to_object()
    # Nothing was loaded: point the user at the config section to add.
    raise RuntimeError(
        "spark_einstein_predictions_config is not configured. Add a "
        "'spark_einstein_predictions_config' section to config.yaml."
    )
119+
120+
def einstein_predict_col(
    model_api_name: str,
    prediction_type: "PredictionType",
    features: Dict[str, "Column"],
    settings: Optional[Dict[str, Any]] = None,
) -> "Column":
    """Return a Spark Column that evaluates an Einstein prediction for each row.

    Every row of the resulting Column is a struct ``{status, response,
    error_code, error_message}``. Select a single field with ``[...]`` (or
    ``getField``), for instance ``einstein_predict_col(...)["response"]``.
    ``response`` carries the prediction response payload as a JSON string.
    A failing row reports its problem through ``status`` / ``error_code`` /
    ``error_message``, so one bad row does not abort the whole Spark job.

    The call is delegated to the :class:`SparkEinsteinPredictions` held by
    the ``Client`` instance.

    Example:

        >>> from datacustomcode.einstein_predictions.types import PredictionType
        >>> result = einstein_predict_col(
        ...     "my_regression_model",
        ...     PredictionType.REGRESSION,
        ...     {"square_feet": col("square_feet__c"), "beds": col("beds__c")},
        ... )
        >>> df.withColumn("prediction__c", result["response"])
        >>> # ...or keep the struct around and inspect failures:
        >>> df.withColumn("pred", result).select(
        ...     "pred.status", "pred.response", "pred.error_message"
        ... )

    Args:
        model_api_name: API name of the Einstein model to invoke.
        prediction_type: The :class:`PredictionType` of the model.
        features: Maps each model feature column name to the Spark ``Column``
            that provides that feature's value for every row.
        settings: Optional prediction settings forwarded to the model.

    Returns:
        A Spark ``Column`` of ``StructType`` whose fields are ``status``,
        ``response``, ``error_code`` and ``error_message`` (all nullable
        strings). A successful row has ``status == "SUCCESS"`` with the
        JSON-serialized prediction payload in ``response``; a failed row has
        ``status == "ERROR"`` with diagnostic detail in the ``error_*``
        fields.
    """
    return (
        Client()
        ._get_spark_einstein_predictions()
        .einstein_predict_col(
            model_api_name, prediction_type, features, settings=settings
        )
    )
168+
169+
102170class DataCloudObjectType (Enum ):
103171 DLO = "dlo"
104172 DMO = "dmo"
@@ -158,6 +226,8 @@ class Client:
158226 reader: A custom reader to use for reading Data Cloud objects.
159227 writer: A custom writer to use for writing Data Cloud objects.
160228 spark_llm_gateway: Optional custom :class:`SparkLLMGateway`.
229+ spark_einstein_predictions: Optional custom
230+ :class:`SparkEinsteinPredictions`.
161231
162232 Example:
163233 >>> client = Client()
@@ -172,6 +242,7 @@ class Client:
172242 _writer : BaseDataCloudWriter
173243 _file : DefaultFindFilePath
174244 _spark_llm_gateway : Optional [SparkLLMGateway ]
245+ _spark_einstein_predictions : Optional [SparkEinsteinPredictions ]
175246 _data_layer_history : dict [DataCloudObjectType , set [str ]]
176247 _code_type : str
177248
@@ -181,12 +252,14 @@ def __new__(
181252 writer : Optional [BaseDataCloudWriter ] = None ,
182253 spark_provider : Optional [BaseSparkSessionProvider ] = None ,
183254 spark_llm_gateway : Optional [SparkLLMGateway ] = None ,
255+ spark_einstein_predictions : Optional [SparkEinsteinPredictions ] = None ,
184256 code_type : str = "script" ,
185257 ) -> Client :
186258
187259 if cls ._instance is None :
188260 cls ._instance = super ().__new__ (cls )
189261 cls ._instance ._spark_llm_gateway = spark_llm_gateway
262+ cls ._instance ._spark_einstein_predictions = spark_einstein_predictions
190263 # Initialize Readers and Writers from config
191264 # and/or provided reader and writer
192265 if reader is None or writer is None :
@@ -358,6 +431,49 @@ def _get_spark_llm_gateway(self) -> SparkLLMGateway:
358431 self ._spark_llm_gateway = _build_spark_llm_gateway ()
359432 return self ._spark_llm_gateway
360433
def einstein_predict(
    self,
    model_api_name: str,
    prediction_type: "PredictionType",
    features: Dict[str, Any],
    settings: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Run a single, one-shot Einstein prediction.

    Scalar counterpart of :func:`einstein_predict_col`: the prediction is
    issued **once**, not once per row. Reach for the column helper instead
    when the prediction should be fanned out across every row of a
    DataFrame.

    Example:

        >>> from datacustomcode.einstein_predictions.types import PredictionType
        >>> response = Client().einstein_predict(
        ...     "my_regression_model",
        ...     PredictionType.REGRESSION,
        ...     {"square_feet": 1800, "beds": 3},
        ... )

    Args:
        model_api_name: API name of the Einstein model to invoke.
        prediction_type: The :class:`PredictionType` of the model.
        features: Maps each model feature column name to one scalar value
            (``str`` / ``float`` / ``bool``).
        settings: Optional prediction settings forwarded to the model.

    Returns:
        The prediction response payload as a plain Python ``dict``.

    Raises:
        EinsteinPredictionsCallError: If the prediction call fails.
    """
    # Resolve the (possibly lazily built) predictions object, then delegate.
    backend = self._get_spark_einstein_predictions()
    return backend.einstein_predict(
        model_api_name, prediction_type, features, settings=settings
    )
471+
def _get_spark_einstein_predictions(self) -> SparkEinsteinPredictions:
    """Return this client's :class:`SparkEinsteinPredictions`, building it on first use.

    When ``_spark_einstein_predictions`` is still ``None``, the object is
    created with ``_build_spark_einstein_predictions()`` and stored on the
    instance so later calls return the same object.
    """
    current = self._spark_einstein_predictions
    if current is None:
        current = _build_spark_einstein_predictions()
        self._spark_einstein_predictions = current
    return current
476+
361477 def _validate_data_layer_history_does_not_contain (
362478 self , data_cloud_object_type : DataCloudObjectType
363479 ) -> None :
0 commit comments