Skip to content

Commit 3865b60

Browse files
do not fail a spark job if a single row fails on UDF
1 parent 6b1f0da commit 3865b60

6 files changed

Lines changed: 187 additions & 26 deletions

File tree

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,16 @@ from datacustomcode.client import Client, llm_gateway_generate_text_col
316316
def main():
317317
client = Client()
318318
df = client.read_dlo("Input__dll")
319+
# llm_gateway_generate_text_col returns a struct
320+
# {status, response, error_code, error_message} per row, so per-row
321+
# failures don't abort the Spark job. Pick the field you want with [].
319322
df_generated = df.withColumn(
320323
"greeting__c",
321324
llm_gateway_generate_text_col(
322325
"In one sentence, greet {name} from {city}.",
323326
{"name": col("name__c"), "city": col("homecity__c")},
324327
model_id="sfdc_ai__DefaultGPT4Omni", # An AI model in your org
325-
),
328+
)["response"],
326329
)
327330
328331
dlo_name = "Output_dll"

src/datacustomcode/client.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,23 @@ def llm_gateway_generate_text_col(
6262
) -> "Column":
6363
"""Build a Spark Column that runs the LLM Gateway per row.
6464
65+
The returned Column yields a struct ``{status, response, error_code,
66+
error_message}`` for each row. Use ``[...]`` (or ``getField``) to pick the
67+
field you want, e.g. ``llm_gateway_generate_text_col(...)["response"]``.
68+
Per-row failures populate ``status`` / ``error_code`` / ``error_message``
69+
so a single bad row does not abort the whole Spark job.
70+
6571
Example:
6672
67-
>>> df.withColumn(
68-
... "greeting__c",
69-
... llm_gateway_generate_text_col(
70-
... "In one sentence, greet {name} from {city}.",
71-
... {"name": col("name__c"), "city": col("homecity__c")},
72-
... model_id="sfdc_ai__DefaultGPT4Omni",
73-
... ),
73+
>>> result = llm_gateway_generate_text_col(
74+
... "In one sentence, greet {name} from {city}.",
75+
... {"name": col("name__c"), "city": col("homecity__c")},
76+
... model_id="sfdc_ai__DefaultGPT4Omni",
77+
... )
78+
>>> df.withColumn("greeting__c", result["response"])
79+
>>> # …or keep the struct around and inspect failures:
80+
>>> df.withColumn("llm", result).select(
81+
... "llm.status", "llm.response", "llm.error_message"
7482
... )
7583
7684
Args:
@@ -81,7 +89,11 @@ def llm_gateway_generate_text_col(
8189
model_id: LLM model id. Defaults to ``sfdc_ai__DefaultGPT4Omni``.
8290
8391
Returns:
84-
A Spark ``Column`` that, when evaluated, produces the generated text.
92+
A Spark ``Column`` of ``StructType`` with fields ``status``,
93+
``response``, ``error_code``, and ``error_message`` (all nullable
94+
strings). On success, ``status == "SUCCESS"`` and ``response`` holds
95+
the generated text; on failure, ``status == "ERROR"`` and the
96+
``error_*`` fields carry diagnostic detail.
8597
"""
8698
gateway = Client()._get_spark_llm_gateway()
8799
return gateway.llm_gateway_generate_text_col(template, values, model_id=model_id)

src/datacustomcode/llm_gateway/spark_base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,13 @@ def llm_gateway_generate_text_col(
5050
values: Union[Dict[str, "Column"], "Column"],
5151
model_id: Optional[str] = None,
5252
) -> "Column":
53-
"""Build a Spark ``Column`` that invokes the LLM Gateway per row."""
53+
"""Build a Spark ``Column`` that invokes the LLM Gateway per row and
54+
yields a struct ``{status, response, error_code, error_message}``.
55+
56+
Select an individual field, e.g.
57+
``llm_gateway_generate_text_col(...)["response"]``. Returning a struct
58+
means a single failing row doesn't abort the Spark job.
59+
Failing row leaves the rest of the DataFrame intact — callers can
60+
inspect ``status`` / ``error_code`` per row instead of having the
61+
Spark job abort.
62+
"""

src/datacustomcode/llm_gateway/spark_default.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,16 @@
2828
from pyspark.sql import Column
2929

3030
from datacustomcode.llm_gateway.base import LLMGateway
31+
from datacustomcode.llm_gateway.types.generate_text_response import (
32+
GenerateTextResponse,
33+
)
3134

3235

3336
_DEFAULT_LLM_MODEL_ID = "sfdc_ai__DefaultGPT4Omni"
3437

38+
_STATUS_SUCCESS = "SUCCESS"
39+
_STATUS_ERROR = "ERROR"
40+
3541

3642
class DefaultSparkLLMGateway(SparkLLMGateway):
3743

@@ -60,29 +66,50 @@ def llm_gateway_generate_text_col(
6066
values: Union[Dict[str, "Column"], "Column"],
6167
model_id: Optional[str] = None,
6268
) -> "Column":
63-
69+
"""Build a per-row UDF that returns a struct ``{status, response,
70+
error_code, error_message}`` so per-row failures do not abort the
71+
Spark job. Callers select the field they want, e.g.
72+
``llm_gateway_generate_text_col(...)["response"]``.
73+
"""
6474
from pyspark.sql.functions import struct, udf
65-
from pyspark.sql.types import StringType
75+
from pyspark.sql.types import (
76+
StringType,
77+
StructField,
78+
StructType,
79+
)
6680

6781
if isinstance(values, dict):
6882
values_col = struct(*[v.alias(k) for k, v in values.items()])
6983
else:
7084
values_col = values
7185

7286
gateway = self._llm_gateway
87+
result_schema = StructType(
88+
[
89+
StructField("status", StringType(), True),
90+
StructField("response", StringType(), True),
91+
StructField("error_code", StringType(), True),
92+
StructField("error_message", StringType(), True),
93+
]
94+
)
7395

74-
def _generate(values_row: Any) -> str:
96+
def _generate(values_row: Any) -> Dict[str, Optional[str]]:
7597
if values_row is None:
76-
return ""
98+
return {
99+
"status": _STATUS_ERROR,
100+
"response": None,
101+
"error_code": None,
102+
"error_message": "values column was null for this row",
103+
}
77104
subs = (
78105
values_row.asDict()
79106
if hasattr(values_row, "asDict")
80107
else dict(values_row)
81108
)
82109
prompt = template.format(**subs)
83-
return _invoke_llm_gateway(gateway, prompt, model_id)
110+
return _invoke_llm_gateway_as_struct(gateway, prompt, model_id)
84111

85-
return udf(_generate, StringType())(values_col)
112+
return udf(_generate, result_schema)(values_col)
86113

87114

88115
def _build_underlying_gateway() -> "LLMGateway":
@@ -97,22 +124,33 @@ def _build_underlying_gateway() -> "LLMGateway":
97124
return cfg.to_object()
98125

99126

100-
def _invoke_llm_gateway(
127+
def _call_llm_gateway(
101128
gateway: "LLMGateway",
102129
prompt: str,
103130
model_id: Optional[str],
104-
) -> str:
105-
from datacustomcode.llm_gateway.errors import LLMGatewayCallError
131+
) -> "GenerateTextResponse":
132+
"""Build the request and dispatch it to the underlying gateway."""
106133
from datacustomcode.llm_gateway.types.generate_text_request_builder import (
107134
GenerateTextRequestBuilder,
108135
)
109136

110-
builder = (
137+
request = (
111138
GenerateTextRequestBuilder()
112139
.set_prompt(prompt)
113140
.set_model(model_id or _DEFAULT_LLM_MODEL_ID)
141+
.build()
114142
)
115-
response = gateway.generate_text(builder.build())
143+
return gateway.generate_text(request)
144+
145+
146+
def _invoke_llm_gateway(
147+
gateway: "LLMGateway",
148+
prompt: str,
149+
model_id: Optional[str],
150+
) -> str:
151+
from datacustomcode.llm_gateway.errors import LLMGatewayCallError
152+
153+
response = _call_llm_gateway(gateway, prompt, model_id)
116154
if response.is_error:
117155
raise LLMGatewayCallError(
118156
f"LLM Gateway call failed: status_code={response.status_code}, "
@@ -123,3 +161,24 @@ def _invoke_llm_gateway(
123161
error_message=str(response.data) if response.data else None,
124162
)
125163
return response.text
164+
165+
166+
def _invoke_llm_gateway_as_struct(
167+
gateway: "LLMGateway",
168+
prompt: str,
169+
model_id: Optional[str],
170+
) -> Dict[str, Optional[str]]:
171+
response = _call_llm_gateway(gateway, prompt, model_id)
172+
if response.is_error:
173+
return {
174+
"status": _STATUS_ERROR,
175+
"response": None,
176+
"error_code": response.error_code or None,
177+
"error_message": str(response.data) if response.data else None,
178+
}
179+
return {
180+
"status": _STATUS_SUCCESS,
181+
"response": response.text,
182+
"error_code": None,
183+
"error_message": None,
184+
}

src/datacustomcode/templates/script/payload/entrypoint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def main():
1616
You can use your AI models configured in Salesforce to generate column
1717
values. See README.md for how to test locally before deploying to Data Cloud.
1818
19-
Example:
19+
Example (the per-row helper returns a struct
20+
``{status, response, error_code, error_message}`` — pick the field you
21+
want with ``[...]``):
2022
2123
>>> from datacustomcode.client import llm_gateway_generate_text_col
2224
df_generated = df.withColumn(
@@ -25,7 +27,7 @@ def main():
2527
... "In one sentence, greet {name} from {city}.",
2628
... {"name": col("name__c"), "city": col("homecity__c")},
2729
... model_id="sfdc_ai__DefaultGPT4Omni",
28-
... ),
30+
... )["response"],
2931
... )
3032
3133
You can also invoke the LLM with a literal plain text prompt — no

tests/test_spark_llm_gateway.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66

77
from datacustomcode.llm_gateway import DefaultSparkLLMGateway, LLMGatewayCallError
88
from datacustomcode.llm_gateway.spark_default import (
9+
_STATUS_ERROR,
10+
_STATUS_SUCCESS,
911
_build_underlying_gateway,
1012
_invoke_llm_gateway,
13+
_invoke_llm_gateway_as_struct,
1114
)
1215
from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse
1316

@@ -134,7 +137,12 @@ def test_dict_values_built_into_struct_and_wrapped_in_udf(
134137
row.asDict.return_value = {"name": "Ada", "city": "London"}
135138
out = udf_fn(row)
136139

137-
assert out == "row-out"
140+
assert out == {
141+
"status": _STATUS_SUCCESS,
142+
"response": "row-out",
143+
"error_code": None,
144+
"error_message": None,
145+
}
138146
sent = mock_inner.generate_text.call_args.args[0]
139147
assert sent.prompt == "Greet Ada from London."
140148
assert sent.model_name == "test-model"
@@ -158,7 +166,7 @@ def test_column_values_passed_through_without_struct(self, mock_struct, mock_udf
158166

159167
@patch("pyspark.sql.functions.udf")
160168
@patch("pyspark.sql.functions.struct")
161-
def test_udf_returns_empty_for_null_row(self, mock_struct, mock_udf):
169+
def test_udf_returns_error_struct_for_null_row(self, mock_struct, mock_udf):
162170
mock_struct.return_value = MagicMock()
163171
mock_udf.return_value = MagicMock()
164172
mock_inner = MagicMock()
@@ -167,9 +175,37 @@ def test_udf_returns_empty_for_null_row(self, mock_struct, mock_udf):
167175
gateway.llm_gateway_generate_text_col("template", {"placeholder": MagicMock()})
168176

169177
udf_fn = mock_udf.call_args.args[0]
170-
assert udf_fn(None) == ""
178+
out = udf_fn(None)
179+
assert out["status"] == _STATUS_ERROR
180+
assert out["response"] is None
181+
assert "null" in out["error_message"].lower()
171182
mock_inner.generate_text.assert_not_called()
172183

184+
@patch("pyspark.sql.functions.udf")
185+
@patch("pyspark.sql.functions.struct")
186+
def test_udf_returns_error_struct_on_http_error(self, mock_struct, mock_udf):
187+
"""Per-row HTTP errors are returned as ``status="ERROR"`` structs so
188+
one bad row does not abort the Spark job."""
189+
mock_struct.return_value = MagicMock()
190+
mock_udf.return_value = MagicMock()
191+
mock_inner = MagicMock()
192+
mock_inner.generate_text.return_value = _error_response(
193+
status_code=503, error_code="UNAVAILABLE"
194+
)
195+
gateway = DefaultSparkLLMGateway(llm_gateway=mock_inner)
196+
197+
gateway.llm_gateway_generate_text_col("Greet {name}", {"name": MagicMock()})
198+
199+
udf_fn = mock_udf.call_args.args[0]
200+
row = MagicMock()
201+
row.asDict.return_value = {"name": "Ada"}
202+
out = udf_fn(row)
203+
204+
assert out["status"] == _STATUS_ERROR
205+
assert out["response"] is None
206+
assert out["error_code"] == "UNAVAILABLE"
207+
assert out["error_message"] is not None
208+
173209

174210
class TestInvokeLLMGateway:
175211

@@ -207,6 +243,46 @@ def test_raises_llm_gateway_call_error_on_error_response(self):
207243
assert "UNAVAILABLE" in str(excinfo.value)
208244

209245

246+
class TestInvokeLLMGatewayAsStruct:
247+
"""Non-raising variant of ``_invoke_llm_gateway`` used by the per-row UDF.
248+
Both SUCCESS and ERROR cases land in the same struct shape so callers can
249+
select fields uniformly."""
250+
251+
def test_success_returns_success_struct(self):
252+
mock_inner = MagicMock()
253+
mock_inner.generate_text.return_value = _success_response("howdy")
254+
255+
out = _invoke_llm_gateway_as_struct(mock_inner, "prompt", "model")
256+
257+
assert out == {
258+
"status": _STATUS_SUCCESS,
259+
"response": "howdy",
260+
"error_code": None,
261+
"error_message": None,
262+
}
263+
264+
def test_error_returns_error_struct_without_raising(self):
265+
mock_inner = MagicMock()
266+
mock_inner.generate_text.return_value = _error_response(
267+
status_code=503, error_code="UNAVAILABLE"
268+
)
269+
270+
out = _invoke_llm_gateway_as_struct(mock_inner, "prompt", "model")
271+
272+
assert out["status"] == _STATUS_ERROR
273+
assert out["response"] is None
274+
assert out["error_code"] == "UNAVAILABLE"
275+
assert out["error_message"] is not None
276+
277+
def test_uses_default_model_when_none(self):
278+
mock_inner = MagicMock()
279+
mock_inner.generate_text.return_value = _success_response("ok")
280+
281+
_invoke_llm_gateway_as_struct(mock_inner, "prompt", None)
282+
sent = mock_inner.generate_text.call_args.args[0]
283+
assert sent.model_name == "sfdc_ai__DefaultGPT4Omni"
284+
285+
210286
class TestDefaultSparkLLMGatewayGenerateTextErrorHandling:
211287
"""The scalar generate_text path raises when the underlying gateway errors."""
212288

0 commit comments

Comments
 (0)