66
77from datacustomcode .llm_gateway import DefaultSparkLLMGateway , LLMGatewayCallError
88from datacustomcode .llm_gateway .spark_default import (
9+ _STATUS_ERROR ,
10+ _STATUS_SUCCESS ,
911 _build_underlying_gateway ,
1012 _invoke_llm_gateway ,
13+ _invoke_llm_gateway_as_struct ,
1114)
1215from datacustomcode .llm_gateway .types .generate_text_response import GenerateTextResponse
1316
@@ -134,7 +137,12 @@ def test_dict_values_built_into_struct_and_wrapped_in_udf(
134137 row .asDict .return_value = {"name" : "Ada" , "city" : "London" }
135138 out = udf_fn (row )
136139
137- assert out == "row-out"
140+ assert out == {
141+ "status" : _STATUS_SUCCESS ,
142+ "response" : "row-out" ,
143+ "error_code" : None ,
144+ "error_message" : None ,
145+ }
138146 sent = mock_inner .generate_text .call_args .args [0 ]
139147 assert sent .prompt == "Greet Ada from London."
140148 assert sent .model_name == "test-model"
@@ -158,7 +166,7 @@ def test_column_values_passed_through_without_struct(self, mock_struct, mock_udf
158166
159167 @patch ("pyspark.sql.functions.udf" )
160168 @patch ("pyspark.sql.functions.struct" )
161- def test_udf_returns_empty_for_null_row (self , mock_struct , mock_udf ):
169+ def test_udf_returns_error_struct_for_null_row (self , mock_struct , mock_udf ):
162170 mock_struct .return_value = MagicMock ()
163171 mock_udf .return_value = MagicMock ()
164172 mock_inner = MagicMock ()
@@ -167,9 +175,37 @@ def test_udf_returns_empty_for_null_row(self, mock_struct, mock_udf):
167175 gateway .llm_gateway_generate_text_col ("template" , {"placeholder" : MagicMock ()})
168176
169177 udf_fn = mock_udf .call_args .args [0 ]
170- assert udf_fn (None ) == ""
178+ out = udf_fn (None )
179+ assert out ["status" ] == _STATUS_ERROR
180+ assert out ["response" ] is None
181+ assert "null" in out ["error_message" ].lower ()
171182 mock_inner .generate_text .assert_not_called ()
172183
@patch("pyspark.sql.functions.udf")
@patch("pyspark.sql.functions.struct")
def test_udf_returns_error_struct_on_http_error(self, mock_struct, mock_udf):
    """An HTTP error for one row comes back as a ``status="ERROR"`` struct.

    The per-row function reports the failure through the struct fields
    instead of raising, so one bad row does not abort the Spark job.
    """
    # Arrange: patched Spark helpers and a gateway stub returning a 503 error.
    mock_udf.return_value = MagicMock()
    mock_struct.return_value = MagicMock()
    failing_gateway = MagicMock()
    failing_gateway.generate_text.return_value = _error_response(
        status_code=503, error_code="UNAVAILABLE"
    )
    spark_gateway = DefaultSparkLLMGateway(llm_gateway=failing_gateway)

    # Act: build the column, then call the function that was handed to udf().
    columns = {"name": MagicMock()}
    spark_gateway.llm_gateway_generate_text_col("Greet {name}", columns)
    per_row_fn = mock_udf.call_args.args[0]
    input_row = MagicMock()
    input_row.asDict.return_value = {"name": "Ada"}
    result = per_row_fn(input_row)

    # Assert: the failure is described in the struct, with no response text.
    assert result["status"] == _STATUS_ERROR
    assert result["response"] is None
    assert result["error_code"] == "UNAVAILABLE"
    assert result["error_message"] is not None
208+
173209
174210class TestInvokeLLMGateway :
175211
@@ -207,6 +243,46 @@ def test_raises_llm_gateway_call_error_on_error_response(self):
207243 assert "UNAVAILABLE" in str (excinfo .value )
208244
209245
class TestInvokeLLMGatewayAsStruct:
    """Tests for ``_invoke_llm_gateway_as_struct``, the non-raising
    counterpart of ``_invoke_llm_gateway`` used by the per-row UDF.

    SUCCESS and ERROR outcomes share one struct shape, so callers can
    select the same fields in either case.
    """

    def test_success_returns_success_struct(self):
        """A successful response fills ``response`` and leaves error fields unset."""
        gateway_stub = MagicMock()
        gateway_stub.generate_text.return_value = _success_response("howdy")
        expected = {
            "status": _STATUS_SUCCESS,
            "response": "howdy",
            "error_code": None,
            "error_message": None,
        }

        result = _invoke_llm_gateway_as_struct(gateway_stub, "prompt", "model")

        assert result == expected

    def test_error_returns_error_struct_without_raising(self):
        """An error response is reported in the struct rather than raised."""
        gateway_stub = MagicMock()
        gateway_stub.generate_text.return_value = _error_response(
            status_code=503, error_code="UNAVAILABLE"
        )

        result = _invoke_llm_gateway_as_struct(gateway_stub, "prompt", "model")

        assert result["status"] == _STATUS_ERROR
        assert result["response"] is None
        assert result["error_code"] == "UNAVAILABLE"
        assert result["error_message"] is not None

    def test_uses_default_model_when_none(self):
        """Passing ``None`` as the model sends the default model name."""
        gateway_stub = MagicMock()
        gateway_stub.generate_text.return_value = _success_response("ok")

        _invoke_llm_gateway_as_struct(gateway_stub, "prompt", None)

        # Inspect the request object the helper passed to the gateway.
        request = gateway_stub.generate_text.call_args.args[0]
        assert request.model_name == "sfdc_ai__DefaultGPT4Omni"
284+
285+
210286class TestDefaultSparkLLMGatewayGenerateTextErrorHandling :
211287 """The scalar generate_text path raises when the underlying gateway errors."""
212288
0 commit comments