Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion integrations/amazon_sagemaker/tests/test_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
from unittest.mock import Mock

import pytest
import requests
from botocore.exceptions import BotoCoreError
from haystack.utils.auth import EnvVarSecret

from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError
from haystack_integrations.components.generators.amazon_sagemaker.errors import (
AWSConfigurationError,
SagemakerInferenceError,
SagemakerNotReadyError,
)


def test_to_dict(set_env_variables, mock_boto3_session): # noqa: ARG001
Expand Down Expand Up @@ -144,6 +150,66 @@ def test_run_with_single_dictionary(set_env_variables, mock_boto3_session): # n
assert response["meta"][0]["other"] == "metadata"


def test_run_raises_on_unexpected_response_type(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """Check that ``run`` raises ``ValueError`` when the endpoint body is a bare JSON string."""
    # The mocked endpoint body decodes to a plain JSON string.
    payload = Mock(read=lambda: b'"just-a-string"')
    fake_client = Mock()
    fake_client.invoke_endpoint.return_value = {"Body": payload}

    generator = SagemakerGenerator(model="test-model")
    # Swap in the mock so no real endpoint is invoked.
    generator.client = fake_client

    with pytest.raises(ValueError, match="Unexpected model response type"):
        generator.run("What's Natural Language Processing?")


def test_run_with_unknown_generation_key_returns_none_reply(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """Check that a response dict with only an unknown key gives a ``None`` reply and keeps the dict as meta."""
    payload = Mock(read=lambda: b'{"unknown_key": "value"}')
    fake_client = Mock()
    fake_client.invoke_endpoint.return_value = {"Body": payload}

    generator = SagemakerGenerator(model="test-model")
    generator.client = fake_client

    result = generator.run("What's Natural Language Processing?")

    # No reply text is extracted, and the whole response dict is returned as metadata.
    assert result["replies"] == [None]
    assert result["meta"] == [{"unknown_key": "value"}]


def test_run_raises_not_ready_error_on_429(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """Check that an ``HTTPError`` with status 429 is raised as ``SagemakerNotReadyError``."""
    # Mocked HTTP response attached to the error the client raises.
    throttled_response = Mock(status_code=429, text="model is loading")
    fake_client = Mock()
    fake_client.invoke_endpoint.side_effect = requests.HTTPError(response=throttled_response)

    generator = SagemakerGenerator(model="test-model")
    generator.client = fake_client

    # The response text is expected to appear in the error message.
    with pytest.raises(SagemakerNotReadyError, match="Sagemaker model not ready: model is loading"):
        generator.run("What's Natural Language Processing?")


def test_run_raises_inference_error_on_other_http_error(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """Check that an ``HTTPError`` with status 500 is raised as ``SagemakerInferenceError``."""
    failing_response = Mock(status_code=500, text="internal server error")
    fake_client = Mock()
    fake_client.invoke_endpoint.side_effect = requests.HTTPError(response=failing_response)

    generator = SagemakerGenerator(model="test-model")
    generator.client = fake_client

    # The message contains regex metacharacters (dots), so escape it before
    # handing it to ``match``, which treats its argument as a regular expression.
    expected_message = re.escape(
        "SageMaker Inference returned an error. Status code: 500. Response body: internal server error"
    )
    with pytest.raises(SagemakerInferenceError, match=expected_message):
        generator.run("What's Natural Language Processing?")


def test_run_passes_custom_attributes_and_generation_kwargs(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """Check the keyword arguments ``run`` passes to ``invoke_endpoint``.

    Covers the endpoint name, the serialized custom attributes, and the
    presence of the generation kwargs in the request body.
    """
    payload = Mock(read=lambda: b'{"generated_text": "ok"}')
    fake_client = Mock()
    fake_client.invoke_endpoint.return_value = {"Body": payload}

    custom_attributes = {"accept_eula": True, "max_retries": 3}
    generator = SagemakerGenerator(model="test-model", aws_custom_attributes=custom_attributes)
    generator.client = fake_client

    generator.run("prompt", generation_kwargs={"max_new_tokens": 5})

    # Inspect the keyword arguments of the recorded call.
    sent = fake_client.invoke_endpoint.call_args.kwargs
    assert sent["EndpointName"] == "test-model"
    assert sent["CustomAttributes"] == "accept_eula=true;max_retries=3"
    assert '"max_new_tokens": 5' in sent["Body"]


@pytest.mark.skipif(
(not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)),
reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.",
Expand Down
Loading