Skip to content

Commit 5e145d4

Browse files
authored
test: Amazon Sagemaker - add unit tests (#3191)
1 parent 3a8f16b commit 5e145d4

1 file changed

Lines changed: 67 additions & 1 deletion

File tree

integrations/amazon_sagemaker/tests/test_sagemaker.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@
33
from unittest.mock import Mock
44

55
import pytest
6+
import requests
67
from botocore.exceptions import BotoCoreError
78
from haystack.utils.auth import EnvVarSecret
9+
810
from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
9-
from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError
11+
from haystack_integrations.components.generators.amazon_sagemaker.errors import (
12+
AWSConfigurationError,
13+
SagemakerInferenceError,
14+
SagemakerNotReadyError,
15+
)
1016

1117

1218
def test_to_dict(set_env_variables, mock_boto3_session): # noqa: ARG001
@@ -144,6 +150,66 @@ def test_run_with_single_dictionary(set_env_variables, mock_boto3_session): # n
144150
assert response["meta"][0]["other"] == "metadata"
145151

146152

153+
def test_run_raises_on_unexpected_response_type(set_env_variables, mock_boto3_session): # noqa: ARG001
154+
client_mock = Mock()
155+
client_mock.invoke_endpoint.return_value = {"Body": Mock(read=lambda: b'"just-a-string"')}
156+
component = SagemakerGenerator(model="test-model")
157+
component.client = client_mock
158+
with pytest.raises(ValueError, match="Unexpected model response type"):
159+
component.run("What's Natural Language Processing?")
160+
161+
162+
def test_run_with_unknown_generation_key_returns_none_reply(set_env_variables, mock_boto3_session): # noqa: ARG001
163+
client_mock = Mock()
164+
client_mock.invoke_endpoint.return_value = {"Body": Mock(read=lambda: b'{"unknown_key": "value"}')}
165+
component = SagemakerGenerator(model="test-model")
166+
component.client = client_mock
167+
response = component.run("What's Natural Language Processing?")
168+
assert response["replies"] == [None]
169+
assert response["meta"] == [{"unknown_key": "value"}]
170+
171+
172+
def test_run_raises_not_ready_error_on_429(set_env_variables, mock_boto3_session): # noqa: ARG001
173+
client_mock = Mock()
174+
http_response = Mock(status_code=429, text="model is loading")
175+
client_mock.invoke_endpoint.side_effect = requests.HTTPError(response=http_response)
176+
component = SagemakerGenerator(model="test-model")
177+
component.client = client_mock
178+
with pytest.raises(SagemakerNotReadyError, match="Sagemaker model not ready: model is loading"):
179+
component.run("What's Natural Language Processing?")
180+
181+
182+
def test_run_raises_inference_error_on_other_http_error(set_env_variables, mock_boto3_session): # noqa: ARG001
183+
client_mock = Mock()
184+
http_response = Mock(status_code=500, text="internal server error")
185+
client_mock.invoke_endpoint.side_effect = requests.HTTPError(response=http_response)
186+
component = SagemakerGenerator(model="test-model")
187+
component.client = client_mock
188+
with pytest.raises(
189+
SagemakerInferenceError,
190+
match=re.escape(
191+
"SageMaker Inference returned an error. Status code: 500. Response body: internal server error"
192+
),
193+
):
194+
component.run("What's Natural Language Processing?")
195+
196+
197+
def test_run_passes_custom_attributes_and_generation_kwargs(set_env_variables, mock_boto3_session): # noqa: ARG001
198+
client_mock = Mock()
199+
client_mock.invoke_endpoint.return_value = {"Body": Mock(read=lambda: b'{"generated_text": "ok"}')}
200+
component = SagemakerGenerator(
201+
model="test-model",
202+
aws_custom_attributes={"accept_eula": True, "max_retries": 3},
203+
)
204+
component.client = client_mock
205+
component.run("prompt", generation_kwargs={"max_new_tokens": 5})
206+
207+
call_kwargs = client_mock.invoke_endpoint.call_args.kwargs
208+
assert call_kwargs["EndpointName"] == "test-model"
209+
assert call_kwargs["CustomAttributes"] == "accept_eula=true;max_retries=3"
210+
assert '"max_new_tokens": 5' in call_kwargs["Body"]
211+
212+
147213
@pytest.mark.skipif(
148214
(not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)),
149215
reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.",

0 commit comments

Comments
 (0)