Skip to content

Commit 6b1f0da

Browse files
raise error when llm gateway call fails
1 parent 240db70 commit 6b1f0da

13 files changed

Lines changed: 107 additions & 95 deletions

File tree

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ def main():
322322
"In one sentence, greet {name} from {city}.",
323323
{"name": col("name__c"), "city": col("homecity__c")},
324324
model_id="sfdc_ai__DefaultGPT4Omni", # An AI model in your org
325-
max_tokens=100,
326325
),
327326
)
328327

src/datacustomcode/client.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def llm_gateway_generate_text_col(
5959
template: str,
6060
values: Union[Dict[str, "Column"], "Column"],
6161
model_id: Optional[str] = None,
62-
max_tokens: Optional[int] = None,
6362
) -> "Column":
6463
"""Build a Spark Column that runs the LLM Gateway per row.
6564
@@ -71,7 +70,6 @@ def llm_gateway_generate_text_col(
7170
... "In one sentence, greet {name} from {city}.",
7271
... {"name": col("name__c"), "city": col("homecity__c")},
7372
... model_id="sfdc_ai__DefaultGPT4Omni",
74-
... max_tokens=100,
7573
... ),
7674
... )
7775
@@ -81,15 +79,12 @@ def llm_gateway_generate_text_col(
8179
values: Either a mapping from placeholder name to Spark ``Column``, or
8280
a single ``Column`` whose value is already a struct.
8381
model_id: LLM model id. Defaults to ``sfdc_ai__DefaultGPT4Omni``.
84-
max_tokens: Maximum tokens to generate. Defaults to 200.
8582
8683
Returns:
8784
A Spark ``Column`` that, when evaluated, produces the generated text.
8885
"""
8986
gateway = Client()._get_spark_llm_gateway()
90-
return gateway.llm_gateway_generate_text_col(
91-
template, values, model_id=model_id, max_tokens=max_tokens
92-
)
87+
return gateway.llm_gateway_generate_text_col(template, values, model_id=model_id)
9388

9489

9590
class DataCloudObjectType(Enum):
@@ -150,9 +145,7 @@ class Client:
150145
finder: Find a file path
151146
reader: A custom reader to use for reading Data Cloud objects.
152147
writer: A custom writer to use for writing Data Cloud objects.
153-
spark_llm_gateway: Optional custom :class:`SparkLLMGateway`. When
154-
omitted, the gateway is lazily resolved from
155-
``spark_llm_gateway_config``.
148+
spark_llm_gateway: Optional custom :class:`SparkLLMGateway`.
156149
157150
Example:
158151
>>> client = Client()
@@ -292,7 +285,6 @@ def llm_gateway_generate_text(
292285
self,
293286
prompt: str,
294287
model_id: Optional[str] = None,
295-
max_tokens: Optional[int] = None,
296288
) -> str:
297289
"""Issue a one-shot LLM Gateway call. This is the scalar counterpart to
298290
:func:`llm_gateway_generate_text_col`: it runs **once** — not per row.
@@ -310,15 +302,13 @@ def llm_gateway_generate_text(
310302
``{field}`` substitution is performed on this string.
311303
model_id: LLM model id to target. Defaults to
312304
``sfdc_ai__DefaultGPT4Omni`` when ``None``.
313-
max_tokens: Hard upper bound on the number of tokens the model
314-
may generate. Defaults to 200 when ``None``.
315305
316306
Returns:
317307
The generated text as a plain Python ``str``; empty when the
318308
gateway response carries no generated text.
319309
"""
320310
return self._get_spark_llm_gateway().llm_gateway_generate_text(
321-
prompt, model_id=model_id, max_tokens=max_tokens
311+
prompt, model_id=model_id
322312
)
323313

324314
def _get_spark_llm_gateway(self) -> SparkLLMGateway:

src/datacustomcode/llm_gateway/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515

1616
from datacustomcode.llm_gateway.base import LLMGateway
1717
from datacustomcode.llm_gateway.default import DefaultLLMGateway
18+
from datacustomcode.llm_gateway.errors import LLMGatewayCallError
1819
from datacustomcode.llm_gateway.spark_base import SparkLLMGateway
1920
from datacustomcode.llm_gateway.spark_default import DefaultSparkLLMGateway
2021

2122
__all__ = [
2223
"DefaultLLMGateway",
2324
"DefaultSparkLLMGateway",
2425
"LLMGateway",
26+
"LLMGatewayCallError",
2527
"SparkLLMGateway",
2628
]

src/datacustomcode/llm_gateway/default.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:
3434

3535
payload: Dict[str, Any] = {"prompt": request.prompt}
3636

37-
if request.max_tokens is not None:
38-
payload["max_tokens"] = request.max_tokens
3937
if request.localization:
4038
payload["localization"] = request.localization
4139
if request.tags:
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Exceptions raised by LLM Gateway implementations."""
16+
17+
from __future__ import annotations
18+
19+
from typing import Optional
20+
21+
22+
class LLMGatewayCallError(RuntimeError):
23+
"""Raised when an LLM Gateway call returns an error."""
24+
25+
def __init__(
26+
self,
27+
message: str,
28+
*,
29+
status: Optional[object] = None,
30+
error_code: Optional[str] = None,
31+
error_message: Optional[str] = None,
32+
) -> None:
33+
super().__init__(message)
34+
self.status = status
35+
self.error_code = error_code
36+
self.error_message = error_message

src/datacustomcode/llm_gateway/spark_base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def llm_gateway_generate_text(
4040
self,
4141
prompt: str,
4242
model_id: Optional[str] = None,
43-
max_tokens: Optional[int] = None,
4443
) -> str:
4544
"""Issue a one-shot LLM Gateway call and return the generated text."""
4645

@@ -50,6 +49,5 @@ def llm_gateway_generate_text_col(
5049
template: str,
5150
values: Union[Dict[str, "Column"], "Column"],
5251
model_id: Optional[str] = None,
53-
max_tokens: Optional[int] = None,
5452
) -> "Column":
5553
"""Build a Spark ``Column`` that invokes the LLM Gateway per row."""

src/datacustomcode/llm_gateway/spark_default.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232

3333
_DEFAULT_LLM_MODEL_ID = "sfdc_ai__DefaultGPT4Omni"
34-
_DEFAULT_LLM_MAX_TOKENS = 200
3534

3635

3736
class DefaultSparkLLMGateway(SparkLLMGateway):
@@ -52,16 +51,14 @@ def llm_gateway_generate_text(
5251
self,
5352
prompt: str,
5453
model_id: Optional[str] = None,
55-
max_tokens: Optional[int] = None,
5654
) -> str:
57-
return _invoke_llm_gateway(self._llm_gateway, prompt, model_id, max_tokens)
55+
return _invoke_llm_gateway(self._llm_gateway, prompt, model_id)
5856

5957
def llm_gateway_generate_text_col(
6058
self,
6159
template: str,
6260
values: Union[Dict[str, "Column"], "Column"],
6361
model_id: Optional[str] = None,
64-
max_tokens: Optional[int] = None,
6562
) -> "Column":
6663

6764
from pyspark.sql.functions import struct, udf
@@ -83,7 +80,7 @@ def _generate(values_row: Any) -> str:
8380
else dict(values_row)
8481
)
8582
prompt = template.format(**subs)
86-
return _invoke_llm_gateway(gateway, prompt, model_id, max_tokens)
83+
return _invoke_llm_gateway(gateway, prompt, model_id)
8784

8885
return udf(_generate, StringType())(values_col)
8986

@@ -104,8 +101,8 @@ def _invoke_llm_gateway(
104101
gateway: "LLMGateway",
105102
prompt: str,
106103
model_id: Optional[str],
107-
max_tokens: Optional[int],
108104
) -> str:
105+
from datacustomcode.llm_gateway.errors import LLMGatewayCallError
109106
from datacustomcode.llm_gateway.types.generate_text_request_builder import (
110107
GenerateTextRequestBuilder,
111108
)
@@ -114,6 +111,15 @@ def _invoke_llm_gateway(
114111
GenerateTextRequestBuilder()
115112
.set_prompt(prompt)
116113
.set_model(model_id or _DEFAULT_LLM_MODEL_ID)
117-
.set_max_tokens(max_tokens or _DEFAULT_LLM_MAX_TOKENS)
118114
)
119-
return gateway.generate_text(builder.build()).text
115+
response = gateway.generate_text(builder.build())
116+
if response.is_error:
117+
raise LLMGatewayCallError(
118+
f"LLM Gateway call failed: status_code={response.status_code}, "
119+
f"error_code={response.error_code!r}, "
120+
f"message={response.data!r}",
121+
status=response.status_code,
122+
error_code=response.error_code or None,
123+
error_message=str(response.data) if response.data else None,
124+
)
125+
return response.text

src/datacustomcode/llm_gateway/types/generate_text_request.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,6 @@ class GenerateTextRequest(BaseModel):
4040
)
4141
model_name: str = Field(..., min_length=1, description="Name of the model to use")
4242
prompt: str = Field(..., description="Input prompt")
43-
max_tokens: Optional[int] = Field(
44-
default=None,
45-
ge=1,
46-
description=(
47-
"Maximum number of tokens to generate. If None, server default applies."
48-
),
49-
)
5043
localization: Optional[Dict[str, Any]] = Field(
5144
default=None, description="Localization settings"
5245
)

src/datacustomcode/llm_gateway/types/generate_text_request_builder.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ class GenerateTextRequestBuilder:
2626
def __init__(self) -> None:
2727
self._prompt = ""
2828
self._model_name = ""
29-
self._max_tokens: Optional[int] = None
3029
self._localization: Optional[Dict[str, Any]] = None
3130
self._tags: Optional[Dict[str, Any]] = None
3231

@@ -38,10 +37,6 @@ def set_model(self, model_name: str) -> "GenerateTextRequestBuilder":
3837
self._model_name = model_name
3938
return self
4039

41-
def set_max_tokens(self, max_tokens: int) -> "GenerateTextRequestBuilder":
42-
self._max_tokens = max_tokens
43-
return self
44-
4540
def set_localization(
4641
self,
4742
localization: Optional[Dict[str, Any]] = None,
@@ -80,7 +75,6 @@ def build(self) -> GenerateTextRequest:
8075
request = GenerateTextRequest(
8176
prompt=self._prompt,
8277
model_name=self._model_name,
83-
max_tokens=self._max_tokens,
8478
localization=self._localization,
8579
tags=self._tags,
8680
)

src/datacustomcode/templates/script/payload/entrypoint.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def main():
2525
... "In one sentence, greet {name} from {city}.",
2626
... {"name": col("name__c"), "city": col("homecity__c")},
2727
... model_id="sfdc_ai__DefaultGPT4Omni",
28-
... max_tokens=100,
2928
... ),
3029
... )
3130
@@ -35,7 +34,7 @@ def main():
3534
Example:
3635
3736
>>> generated_text = client.llm_gateway_generate_text(
38-
... prompt, model_id, max_tokens
37+
... prompt, model_id
3938
... )
4039
"""
4140

0 commit comments

Comments
 (0)