diff --git a/aws/logs_monitoring/settings.py b/aws/logs_monitoring/settings.py index 605732e73..2900e326d 100644 --- a/aws/logs_monitoring/settings.py +++ b/aws/logs_monitoring/settings.py @@ -250,24 +250,31 @@ def is_api_key_valid(): # Validate the API key logger.debug("Validating the Datadog API key") - with requests.Session() as s: - retries = requests.adapters.Retry( - total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504] - ) + try: + with requests.Session() as s: + retries = requests.adapters.Retry( + total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504] + ) - s.mount("http://", requests.adapters.HTTPAdapter(max_retries=retries)) - s.mount("https://", requests.adapters.HTTPAdapter(max_retries=retries)) + s.mount("http://", requests.adapters.HTTPAdapter(max_retries=retries)) + s.mount("https://", requests.adapters.HTTPAdapter(max_retries=retries)) - validation_res = s.get( - "{}/api/v1/validate?api_key={}".format(DD_API_URL, DD_API_KEY), - verify=(not DD_SKIP_SSL_VALIDATION), - timeout=10, - ) - if not validation_res.ok: - logger.error( - f"Datadog API key validation failed (HTTP {validation_res.status_code}). Verify your API key is correct and DD_SITE matches your Datadog account region (current: {DD_SITE}). See: https://docs.datadoghq.com/getting_started/site/" + validation_res = s.get( + "{}/api/v1/validate?api_key={}".format(DD_API_URL, DD_API_KEY), + verify=(not DD_SKIP_SSL_VALIDATION), + timeout=10, ) - return False + if not validation_res.ok: + logger.error( + f"Datadog API key validation failed (HTTP {validation_res.status_code}). Verify your API key is correct and DD_SITE matches your Datadog account region (current: {DD_SITE}). See: https://docs.datadoghq.com/getting_started/site/" + ) + return False + except requests.exceptions.RequestException as e: + logger.warning( + f"Could not validate Datadog API key due to a network error: {e}. " + "Proceeding without validation." + ) + return False return True diff --git a/aws/logs_monitoring/tests/test_settings.py b/aws/logs_monitoring/tests/test_settings.py new file mode 100644 index 000000000..f353889f1 --- /dev/null +++ b/aws/logs_monitoring/tests/test_settings.py @@ -0,0 +1,63 @@ +import unittest +from unittest.mock import MagicMock, patch + +from settings import is_api_key_valid + +VALID_API_KEY = "11111111111111111111111111111111" + + +# For the integration tests to work because of other tests set sys.modules["requests"] as a MagicMock. +class _FakeNetworkError(Exception): + pass + + +class TestIsApiKeyValid(unittest.TestCase): + @patch("settings.DD_API_KEY", VALID_API_KEY) + @patch("settings.requests.Session") + def test_valid_api_key(self, mock_session_cls): + mock_response = MagicMock() + mock_response.ok = True + mock_session_cls.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + self.assertTrue(is_api_key_valid()) + + @patch("settings.DD_API_KEY", "") + def test_empty_api_key(self): + with self.assertRaises(Exception): + is_api_key_valid() + + @patch("settings.DD_API_KEY", "shortapikey") + def test_invalid_api_key_format(self): + with self.assertRaises(Exception): + is_api_key_valid() + + @patch("settings.DD_API_KEY", VALID_API_KEY) + @patch("settings.logger") + @patch("settings.requests.exceptions.RequestException", _FakeNetworkError) + @patch("settings.requests.Session") + def test_on_connection_exception(self, mock_session_cls, mock_logger): + mock_session_cls.return_value.__enter__.return_value.get.side_effect = ( + _FakeNetworkError("DNS resolution failed") + ) + result = is_api_key_valid() + self.assertFalse(result) + mock_logger.warning.assert_called_once() + self.assertIn("network error", mock_logger.warning.call_args[0][0].lower()) + + @patch("settings.DD_API_KEY", VALID_API_KEY) + @patch("settings.logger") + @patch("settings.requests.exceptions.RequestException", _FakeNetworkError) + @patch("settings.requests.Session") + def test_on_timeout_exception(self, mock_session_cls, mock_logger): + mock_session_cls.return_value.__enter__.return_value.get.side_effect = ( + _FakeNetworkError("Request timed out") + ) + result = is_api_key_valid() + self.assertFalse(result) + mock_logger.warning.assert_called_once() + self.assertIn("network error", mock_logger.warning.call_args[0][0].lower()) + + +if __name__ == "__main__": + unittest.main()