Skip to content

Commit 59c70a3

Browse files
authored
Merge pull request #119 from forcedotcom/einstein-predict-script
Support einstein predict for script in public SDK
2 parents d9c4e12 + 5894a8d commit 59c70a3

13 files changed

Lines changed: 1088 additions & 1 deletion

README.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,56 @@ datacustomcode run ./payload/entrypoint.py --sf-cli-org myorg
373373
```
374374

375375

376+
## Testing Einstein Predictions
377+
378+
You can use AI models configured in Einstein Studio to score your data while
379+
transforming it. As with the LLM Gateway, there are two flavors: a one-shot
380+
scalar call (`client.einstein_predict`) and a per-row column helper
381+
(`einstein_predict_col`). Below is a sample code example:
382+
383+
```
384+
from datacustomcode.client import Client, einstein_predict_col
385+
from datacustomcode.einstein_predictions.types import PredictionType
from datacustomcode.io.writer.base import WriteMode
from pyspark.sql.functions import col
386+
387+
388+
def main():
389+
client = Client()
390+
df = client.read_dlo("Input__dll")
391+
# einstein_predict_col returns a struct
392+
# {status, response, error_code, error_message} per row, so per-row
393+
# failures don't abort the Spark job. `response` is the prediction
394+
# payload as a JSON string. Pick the field you want with [].
395+
df_scored = df.withColumn(
396+
"prediction__c",
397+
einstein_predict_col(
398+
"my_regression_model", # An AI model in your org
399+
PredictionType.REGRESSION,
400+
{"square_feet": col("square_feet__c"), "beds": col("beds__c")},
401+
)["response"],
402+
)
403+
404+
dlo_name = "Output__dll"
405+
client.write_to_dlo(dlo_name, df_scored, write_mode=WriteMode.APPEND)
406+
407+
# One-shot scalar prediction returns the response payload as a dict
408+
prediction = client.einstein_predict(
409+
"my_regression_model",
410+
PredictionType.REGRESSION,
411+
{"square_feet": 1800, "beds": 3},
412+
)
413+
414+
if __name__ == "__main__":
415+
main()
416+
```
417+
418+
Testing this code locally uses the same External Client App setup described in
419+
[Testing LLM Gateway](#testing-llm-gateway). Once your `myorg` alias is set up,
420+
run:
421+
```
422+
datacustomcode run ./payload/entrypoint.py --sf-cli-org myorg
423+
```
424+
425+
376426
## Docker usage
377427

378428
The SDK provides Docker-based development options that allow you to test your code in an environment that closely resembles Data Cloud's execution environment.

src/datacustomcode/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717
"AuthType",
1818
"Client",
1919
"Credentials",
20+
"DefaultSparkEinsteinPredictions",
2021
"DefaultSparkLLMGateway",
2122
"PrintDataCloudWriter",
2223
"QueryAPIDataCloudReader",
24+
"SparkEinsteinPredictions",
2325
"SparkLLMGateway",
26+
"einstein_predict_col",
2427
"llm_gateway_generate_text_col",
2528
]
2629

@@ -59,4 +62,16 @@ def __getattr__(name: str):
5962
from datacustomcode.client import llm_gateway_generate_text_col
6063

6164
return llm_gateway_generate_text_col
65+
elif name == "SparkEinsteinPredictions":
66+
from datacustomcode.einstein_predictions import SparkEinsteinPredictions
67+
68+
return SparkEinsteinPredictions
69+
elif name == "DefaultSparkEinsteinPredictions":
70+
from datacustomcode.einstein_predictions import DefaultSparkEinsteinPredictions
71+
72+
return DefaultSparkEinsteinPredictions
73+
elif name == "einstein_predict_col":
74+
from datacustomcode.client import einstein_predict_col
75+
76+
return einstein_predict_col
6277
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

src/datacustomcode/client.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
from enum import Enum
1818
from typing import (
1919
TYPE_CHECKING,
20+
Any,
2021
ClassVar,
2122
Dict,
2223
Optional,
2324
Union,
2425
)
2526

2627
from datacustomcode.config import config
28+
from datacustomcode.einstein_predictions_config import spark_einstein_predictions_config
2729
from datacustomcode.file.path.default import DefaultFindFilePath
2830
from datacustomcode.io.reader.base import BaseDataCloudReader
2931
from datacustomcode.llm_gateway_config import spark_llm_gateway_config
@@ -34,6 +36,8 @@
3436

3537
from pyspark.sql import Column, DataFrame as PySparkDataFrame
3638

39+
from datacustomcode.einstein_predictions.spark_base import SparkEinsteinPredictions
40+
from datacustomcode.einstein_predictions.types import PredictionType
3741
from datacustomcode.io.reader.base import BaseDataCloudReader
3842
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
3943
from datacustomcode.llm_gateway.spark_base import SparkLLMGateway
@@ -99,6 +103,70 @@ def llm_gateway_generate_text_col(
99103
return gateway.llm_gateway_generate_text_col(template, values, model_id=model_id)
100104

101105

106+
def _build_spark_einstein_predictions() -> "SparkEinsteinPredictions":
    """Create the :class:`SparkEinsteinPredictions` described by the SDK config.

    Returns:
        The object produced by calling ``to_object()`` on the loaded
        ``spark_einstein_predictions_config`` section.

    Raises:
        RuntimeError: If no ``spark_einstein_predictions_config`` has been loaded.
    """
    section = spark_einstein_predictions_config.spark_einstein_predictions_config
    if section is not None:
        return section.to_object()
    # Nothing was loaded: point the user at the config section they need to add.
    raise RuntimeError(
        "spark_einstein_predictions_config is not configured. "
        "Add a 'spark_einstein_predictions_config' section to config.yaml."
    )
119+
120+
121+
def einstein_predict_col(
    model_api_name: str,
    prediction_type: "PredictionType",
    features: Dict[str, "Column"],
    settings: Optional[Dict[str, Any]] = None,
) -> "Column":
    """Return a Spark Column that runs an Einstein prediction for every row.

    Each row of the resulting Column is a struct ``{status, response,
    error_code, error_message}``. Select the field you need with ``[...]``
    (or ``getField``), for instance ``einstein_predict_col(...)["response"]``;
    ``response`` carries the prediction response payload as a JSON string.
    A failing row reports its problem through ``status`` / ``error_code`` /
    ``error_message``, so one bad row does not abort the whole Spark job.

    Example:

        >>> from pyspark.sql.functions import col
        >>> from datacustomcode.einstein_predictions.types import PredictionType
        >>> result = einstein_predict_col(
        ...     "my_regression_model",
        ...     PredictionType.REGRESSION,
        ...     {"square_feet": col("square_feet__c"), "beds": col("beds__c")},
        ... )
        >>> df.withColumn("prediction__c", result["response"])
        >>> # ...or keep the whole struct and inspect failures:
        >>> df.withColumn("pred", result).select(
        ...     "pred.status", "pred.response", "pred.error_message"
        ... )

    Args:
        model_api_name: API name of the Einstein model to invoke.
        prediction_type: The :class:`PredictionType` of the model.
        features: Mapping from model feature column name to the Spark
            ``Column`` that supplies that feature's value for each row.
        settings: Optional prediction settings forwarded to the model.

    Returns:
        A Spark ``Column`` of ``StructType`` whose fields ``status``,
        ``response``, ``error_code`` and ``error_message`` are all nullable
        strings. A successful row has ``status == "SUCCESS"`` and the
        JSON-serialized prediction payload in ``response``; a failed row has
        ``status == "ERROR"`` and diagnostic detail in the ``error_*`` fields.
    """
    # Delegate to the SparkEinsteinPredictions held by the shared Client.
    spark_predictions = Client()._get_spark_einstein_predictions()
    prediction_column = spark_predictions.einstein_predict_col(
        model_api_name,
        prediction_type,
        features,
        settings=settings,
    )
    return prediction_column
168+
169+
102170
class DataCloudObjectType(Enum):
103171
DLO = "dlo"
104172
DMO = "dmo"
@@ -158,6 +226,8 @@ class Client:
158226
reader: A custom reader to use for reading Data Cloud objects.
159227
writer: A custom writer to use for writing Data Cloud objects.
160228
spark_llm_gateway: Optional custom :class:`SparkLLMGateway`.
229+
spark_einstein_predictions: Optional custom
230+
:class:`SparkEinsteinPredictions`.
161231
162232
Example:
163233
>>> client = Client()
@@ -172,6 +242,7 @@ class Client:
172242
_writer: BaseDataCloudWriter
173243
_file: DefaultFindFilePath
174244
_spark_llm_gateway: Optional[SparkLLMGateway]
245+
_spark_einstein_predictions: Optional[SparkEinsteinPredictions]
175246
_data_layer_history: dict[DataCloudObjectType, set[str]]
176247
_code_type: str
177248

@@ -181,12 +252,14 @@ def __new__(
181252
writer: Optional[BaseDataCloudWriter] = None,
182253
spark_provider: Optional[BaseSparkSessionProvider] = None,
183254
spark_llm_gateway: Optional[SparkLLMGateway] = None,
255+
spark_einstein_predictions: Optional[SparkEinsteinPredictions] = None,
184256
code_type: str = "script",
185257
) -> Client:
186258

187259
if cls._instance is None:
188260
cls._instance = super().__new__(cls)
189261
cls._instance._spark_llm_gateway = spark_llm_gateway
262+
cls._instance._spark_einstein_predictions = spark_einstein_predictions
190263
# Initialize Readers and Writers from config
191264
# and/or provided reader and writer
192265
if reader is None or writer is None:
@@ -358,6 +431,49 @@ def _get_spark_llm_gateway(self) -> SparkLLMGateway:
358431
self._spark_llm_gateway = _build_spark_llm_gateway()
359432
return self._spark_llm_gateway
360433

434+
def einstein_predict(
    self,
    model_api_name: str,
    prediction_type: "PredictionType",
    features: Dict[str, Any],
    settings: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Run a single, one-shot Einstein prediction.

    This is the scalar counterpart of :func:`einstein_predict_col`: the model
    is invoked **once**, not once per row. Reach for the column helper when
    the prediction should fan out across every row of a DataFrame.

    Example:

        >>> from datacustomcode.einstein_predictions.types import PredictionType
        >>> response = Client().einstein_predict(
        ...     "my_regression_model",
        ...     PredictionType.REGRESSION,
        ...     {"square_feet": 1800, "beds": 3},
        ... )

    Args:
        model_api_name: API name of the Einstein model to invoke.
        prediction_type: The :class:`PredictionType` of the model.
        features: Mapping from model feature column name to a single
            scalar value (``str`` / ``float`` / ``bool``).
        settings: Optional prediction settings forwarded to the model.

    Returns:
        The prediction response payload as a plain Python ``dict``.

    Raises:
        EinsteinPredictionsCallError: If the prediction call fails.
    """
    # Resolve the (lazily built) predictions backend, then forward the call.
    spark_predictions = self._get_spark_einstein_predictions()
    return spark_predictions.einstein_predict(
        model_api_name,
        prediction_type,
        features,
        settings=settings,
    )
471+
472+
def _get_spark_einstein_predictions(self) -> SparkEinsteinPredictions:
    """Return this client's ``SparkEinsteinPredictions``, building it on first use.

    When ``_spark_einstein_predictions`` is still ``None`` it is populated
    from ``_build_spark_einstein_predictions()`` and stored on the instance,
    so later calls return the same object.
    """
    predictions = self._spark_einstein_predictions
    if predictions is None:
        predictions = _build_spark_einstein_predictions()
        self._spark_einstein_predictions = predictions
    return predictions
476+
361477
def _validate_data_layer_history_does_not_contain(
362478
self, data_cloud_object_type: DataCloudObjectType
363479
) -> None:

src/datacustomcode/config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ einstein_predictions_config:
2424
options:
2525
credentials_profile: default
2626

27+
spark_einstein_predictions_config:
28+
type_config_name: DefaultSparkEinsteinPredictions
29+
2730
llm_gateway_config:
2831
type_config_name: DefaultLLMGateway
2932
options:

src/datacustomcode/einstein_predictions/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@
1414
# limitations under the License.
1515

1616
from datacustomcode.einstein_predictions.base import EinsteinPredictions
17+
from datacustomcode.einstein_predictions.errors import EinsteinPredictionsCallError
1718
from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions
19+
from datacustomcode.einstein_predictions.spark_base import SparkEinsteinPredictions
20+
from datacustomcode.einstein_predictions.spark_default import (
21+
DefaultSparkEinsteinPredictions,
22+
)
1823

1924
__all__ = [
2025
"DefaultEinsteinPredictions",
26+
"DefaultSparkEinsteinPredictions",
2127
"EinsteinPredictions",
28+
"EinsteinPredictionsCallError",
29+
"SparkEinsteinPredictions",
2230
]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Exceptions raised by Einstein Predictions implementations."""
16+
17+
from __future__ import annotations
18+
19+
from typing import Optional
20+
21+
22+
class EinsteinPredictionsCallError(RuntimeError):
    """Error signalling that an Einstein Predictions call returned an error.

    Besides the usual exception message, the instance keeps the optional
    ``status``, ``error_code`` and ``error_message`` values it was built with
    so callers can inspect them after catching the exception.
    """

    def __init__(
        self,
        message: str,
        *,
        status: Optional[object] = None,
        error_code: Optional[str] = None,
        error_message: Optional[str] = None,
    ) -> None:
        """Store the failure details and initialise the base exception.

        Args:
            message: Human-readable text passed to ``RuntimeError``.
            status: Optional status value associated with the failed call.
            error_code: Optional error code associated with the failed call.
            error_message: Optional detailed error text for the failed call.
        """
        # The detail fields are keyword-only; each defaults to None.
        self.status, self.error_code, self.error_message = (
            status,
            error_code,
            error_message,
        )
        super().__init__(message)

0 commit comments

Comments
 (0)