|
3 | 3 | from unittest.mock import Mock |
4 | 4 |
|
5 | 5 | import pytest |
| 6 | +import requests |
6 | 7 | from botocore.exceptions import BotoCoreError |
7 | 8 | from haystack.utils.auth import EnvVarSecret |
| 9 | + |
8 | 10 | from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator |
9 | | -from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError |
| 11 | +from haystack_integrations.components.generators.amazon_sagemaker.errors import ( |
| 12 | + AWSConfigurationError, |
| 13 | + SagemakerInferenceError, |
| 14 | + SagemakerNotReadyError, |
| 15 | +) |
10 | 16 |
|
11 | 17 |
|
12 | 18 | def test_to_dict(set_env_variables, mock_boto3_session): # noqa: ARG001 |
@@ -144,6 +150,66 @@ def test_run_with_single_dictionary(set_env_variables, mock_boto3_session): # n |
144 | 150 | assert response["meta"][0]["other"] == "metadata" |
145 | 151 |
|
146 | 152 |
|
def test_run_raises_on_unexpected_response_type(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """run() must raise ValueError when the endpoint body is a bare JSON string."""
    generator = SagemakerGenerator(model="test-model")

    # Stub the SageMaker runtime client so no real endpoint is contacted.
    body_stream = Mock(read=lambda: b'"just-a-string"')
    sagemaker_client = Mock()
    sagemaker_client.invoke_endpoint.return_value = {"Body": body_stream}
    generator.client = sagemaker_client

    expected_error = pytest.raises(ValueError, match="Unexpected model response type")
    with expected_error:
        generator.run("What's Natural Language Processing?")
| 160 | + |
| 161 | + |
def test_run_with_unknown_generation_key_returns_none_reply(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """A payload holding only an unrecognised key yields a ``None`` reply and the payload as meta."""
    generator = SagemakerGenerator(model="test-model")

    # Stub the SageMaker runtime client so no real endpoint is contacted.
    body_stream = Mock(read=lambda: b'{"unknown_key": "value"}')
    sagemaker_client = Mock()
    sagemaker_client.invoke_endpoint.return_value = {"Body": body_stream}
    generator.client = sagemaker_client

    result = generator.run("What's Natural Language Processing?")

    replies, meta = result["replies"], result["meta"]
    assert replies == [None]
    assert meta == [{"unknown_key": "value"}]
| 170 | + |
| 171 | + |
def test_run_raises_not_ready_error_on_429(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """An HTTPError carrying status 429 must surface as SagemakerNotReadyError."""
    generator = SagemakerGenerator(model="test-model")

    # Fake HTTP response attached to the error raised by the stubbed client.
    loading_response = Mock()
    loading_response.status_code = 429
    loading_response.text = "model is loading"

    sagemaker_client = Mock()
    sagemaker_client.invoke_endpoint.side_effect = requests.HTTPError(response=loading_response)
    generator.client = sagemaker_client

    expected_error = pytest.raises(
        SagemakerNotReadyError,
        match="Sagemaker model not ready: model is loading",
    )
    with expected_error:
        generator.run("What's Natural Language Processing?")
| 180 | + |
| 181 | + |
def test_run_raises_inference_error_on_other_http_error(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """An HTTPError carrying status 500 must surface as SagemakerInferenceError."""
    generator = SagemakerGenerator(model="test-model")

    # Fake HTTP response attached to the error raised by the stubbed client.
    failing_response = Mock()
    failing_response.status_code = 500
    failing_response.text = "internal server error"

    sagemaker_client = Mock()
    sagemaker_client.invoke_endpoint.side_effect = requests.HTTPError(response=failing_response)
    generator.client = sagemaker_client

    # The message contains regex metacharacters ('.'), so it is escaped before matching.
    expected_message = "SageMaker Inference returned an error. Status code: 500. Response body: internal server error"
    with pytest.raises(SagemakerInferenceError, match=re.escape(expected_message)):
        generator.run("What's Natural Language Processing?")
| 195 | + |
| 196 | + |
def test_run_passes_custom_attributes_and_generation_kwargs(set_env_variables, mock_boto3_session):  # noqa: ARG001
    """run() must hand the model name, custom attributes and generation kwargs to ``invoke_endpoint``."""
    custom_attributes = {"accept_eula": True, "max_retries": 3}
    generator = SagemakerGenerator(model="test-model", aws_custom_attributes=custom_attributes)

    # Stub the SageMaker runtime client so the outgoing call can be inspected.
    body_stream = Mock(read=lambda: b'{"generated_text": "ok"}')
    sagemaker_client = Mock()
    sagemaker_client.invoke_endpoint.return_value = {"Body": body_stream}
    generator.client = sagemaker_client

    generator.run("prompt", generation_kwargs={"max_new_tokens": 5})

    # call_args unpacks into (positional args, keyword args) of the last call.
    _, sent_kwargs = sagemaker_client.invoke_endpoint.call_args
    assert sent_kwargs["EndpointName"] == "test-model"
    assert sent_kwargs["CustomAttributes"] == "accept_eula=true;max_retries=3"
    assert '"max_new_tokens": 5' in sent_kwargs["Body"]
| 211 | + |
| 212 | + |
147 | 213 | @pytest.mark.skipif( |
148 | 214 | (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), |
149 | 215 | reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", |
|
0 commit comments