From e3811f8c15b37771cc145db901cf7e0535e7917a Mon Sep 17 00:00:00 2001 From: kkedziak Date: Tue, 27 May 2025 17:07:00 +0200 Subject: [PATCH 1/2] Retries --- solnlib/splunk_rest_client.py | 35 +++++++----- tests/unit/conftest.py | 77 +++++++++++++++++++++++++++ tests/unit/test_splunk_rest_client.py | 66 +++++++++++++++++++++++ 3 files changed, 164 insertions(+), 14 deletions(-) create mode 100644 tests/unit/conftest.py diff --git a/solnlib/splunk_rest_client.py b/solnlib/splunk_rest_client.py index 7e858248..4bd65268 100644 --- a/solnlib/splunk_rest_client.py +++ b/solnlib/splunk_rest_client.py @@ -71,6 +71,8 @@ def _request_handler(context): 'cert_file': string 'pool_connections', int, 'pool_maxsize', int, + 'max_retries': int, + 'retry_status_codes': list, } :type content: dict """ @@ -103,24 +105,29 @@ def _request_handler(context): cert = None retries = Retry( - total=MAX_REQUEST_RETRIES, + total=context.get("max_retries", MAX_REQUEST_RETRIES), backoff_factor=0.3, - status_forcelist=[500, 502, 503, 504], + status_forcelist=context.get("retry_status_codes", [500, 502, 503, 504]), allowed_methods=["GET", "POST", "PUT", "DELETE"], raise_on_status=False, ) - if context.get("pool_connections", 0): - logging.info("Use HTTP connection pooling") - session = requests.Session() - adapter = requests.adapters.HTTPAdapter( - max_retries=retries, - pool_connections=context.get("pool_connections", 10), - pool_maxsize=context.get("pool_maxsize", 10), - ) - session.mount("https://", adapter) - req_func = session.request - else: - req_func = requests.request + + adapter_args = { + "max_retries": retries, + } + + # By default, pool_connections and pool_maxsize are set to 10 in urllib3 + if "pool_connections" in context: + adapter_args["pool_connections"] = context["pool_connections"] + if "pool_maxsize" in context: + adapter_args["pool_maxsize"] = context["pool_maxsize"] + + session = requests.Session() + adapter = requests.adapters.HTTPAdapter(**adapter_args) + session.mount("http://", adapter) + session.mount("https://", adapter) + + req_func = session.request def request(url, message, **kwargs): """ diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 00000000..8b3bf665 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,77 @@ +import json +import socket +from contextlib import closing +from http.server import BaseHTTPRequestHandler, HTTPServer +from threading import Thread + +import pytest + + +@pytest.fixture(scope="session") +def http_mock_server(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = s.getsockname()[1] + + class Mock: + def __init__(self, host, port): + self.host = host + self.port = port + self.get_func = None + + def get(self, func): + self.get_func = func + return func + + mock = Mock("localhost", port) + + class RequestArg: + def __init__(self): + self.headers = { + "Content-Type": "application/json", + } + self.response_code = 200 + + def send_header(self, key, value): + self.headers[key] = value + + def send_response(self, code): + self.response_code = code + + class Handler(BaseHTTPRequestHandler): + def do_GET(self): + if mock.get_func is None: + self.send_response(404) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"error": "Not Found"}).encode("utf-8")) + return + + request = RequestArg() + response = mock.get_func(request) + + self.send_response(request.response_code) + + for key, value in request.headers.items(): + self.send_header(key, value) + + self.end_headers() + + if isinstance(response, dict): + response = json.dumps(response) + + self.wfile.write(response.encode("utf-8")) + + server_address = ("", mock.port) + httpd = HTTPServer(server_address, Handler) + + thread = Thread(target=httpd.serve_forever) + thread.setDaemon(True) + thread.start() + + yield mock + + httpd.shutdown() + httpd.server_close() + thread.join() diff --git a/tests/unit/test_splunk_rest_client.py b/tests/unit/test_splunk_rest_client.py index de43dbdd..25d9d10d 100644 --- a/tests/unit/test_splunk_rest_client.py +++ b/tests/unit/test_splunk_rest_client.py @@ -17,6 +17,8 @@ from unittest import mock import pytest +from splunklib.binding import HTTPError + from solnlib.splunk_rest_client import MAX_REQUEST_RETRIES from requests.exceptions import ConnectionError @@ -109,3 +111,67 @@ def test_request_retry(http_conn_pool, http_resp, mock_get_splunkd_access_info): http_conn_pool.side_effect = side_effects with pytest.raises(ConnectionError): rest_client.get("test") + + +@pytest.mark.parametrize("error_code", [429, 500, 503]) +def test_request_throttling(http_mock_server, error_code): + @http_mock_server.get + def throttling(request): + """Mock endpoint to simulate request throttling. + + The endpoint will return an error status code for the first 5 + requests, and a 200 status code for subsequent requests. + """ + number = getattr(throttling, "call_count", 0) + throttling.call_count = number + 1 + + if number < 2: + request.send_response(error_code) + request.send_header("Retry-After", "1") + return {"error": f"Error {number}"} + + return {"content": "Success"} + + rest_client = SplunkRestClient( + "msg_name_1", + "session_key", + "_", + scheme="http", + host="localhost", + port=http_mock_server.port, + ) + + resp = rest_client.get("test") + assert resp.status == 200 + assert resp.body.read().decode("utf-8") == '{"content": "Success"}' + + +@pytest.mark.parametrize("error_code", [429, 500, 503]) +def test_request_throttling_exceeded(http_mock_server, error_code): + @http_mock_server.get + def throttling(request): + """Mock endpoint to simulate request throttling. + + The endpoint will always return an error status code. + """ + number = getattr(throttling, "call_count", 0) + throttling.call_count = number + 1 + + request.send_response(error_code) + request.send_header("Retry-After", "1") + return {"error": f"Error {number}"} + + rest_client = SplunkRestClient( + "msg_name_1", + "session_key", + "_", + scheme="http", + host="localhost", + port=http_mock_server.port, + ) + + with pytest.raises(HTTPError) as ex: + rest_client.get("test") + + assert ex.value.status == error_code + assert ex.value.body.decode("utf-8") == '{"error": "Error 5"}' From 6b9ec57fc2c0d7cb1eda411e993ea14888256ede Mon Sep 17 00:00:00 2001 From: kkedziak Date: Wed, 28 May 2025 10:44:53 +0200 Subject: [PATCH 2/2] Separate instances --- solnlib/splunk_rest_client.py | 38 ++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/solnlib/splunk_rest_client.py b/solnlib/splunk_rest_client.py index 4bd65268..1f44ef58 100644 --- a/solnlib/splunk_rest_client.py +++ b/solnlib/splunk_rest_client.py @@ -104,28 +104,30 @@ def _request_handler(context): else: cert = None - retries = Retry( - total=context.get("max_retries", MAX_REQUEST_RETRIES), - backoff_factor=0.3, - status_forcelist=context.get("retry_status_codes", [500, 502, 503, 504]), - allowed_methods=["GET", "POST", "PUT", "DELETE"], - raise_on_status=False, - ) + def adapter(): + retries = Retry( + total=context.get("max_retries", MAX_REQUEST_RETRIES), + backoff_factor=0.3, + status_forcelist=context.get("retry_status_codes", [500, 502, 503, 504]), + allowed_methods=["GET", "POST", "PUT", "DELETE"], + raise_on_status=False, + ) - adapter_args = { - "max_retries": retries, - } + adapter_args = { + "max_retries": retries, + } + + # By default, pool_connections and pool_maxsize are set to 10 in urllib3 + if "pool_connections" in context: + adapter_args["pool_connections"] = context["pool_connections"] + if "pool_maxsize" in context: + adapter_args["pool_maxsize"] = context["pool_maxsize"] - # By default, pool_connections and pool_maxsize are set to 10 in urllib3 - if "pool_connections" in context: - adapter_args["pool_connections"] = context["pool_connections"] - if "pool_maxsize" in context: - adapter_args["pool_maxsize"] = context["pool_maxsize"] + return requests.adapters.HTTPAdapter(**adapter_args) session = requests.Session() - adapter = requests.adapters.HTTPAdapter(**adapter_args) - session.mount("http://", adapter) - session.mount("https://", adapter) + session.mount("http://", adapter()) + session.mount("https://", adapter()) req_func = session.request