Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/databricks/sql/auth/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
retry_stop_after_attempts_duration: Optional[float] = None,
retry_delay_default: Optional[float] = None,
retry_dangerous_codes: Optional[List[int]] = None,
respect_server_retry_after_header: Optional[bool] = None,
proxy_auth_method: Optional[str] = None,
pool_connections: Optional[int] = None,
pool_maxsize: Optional[int] = None,
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
)
self.retry_delay_default = retry_delay_default or 5.0
self.retry_dangerous_codes = retry_dangerous_codes or []
self.respect_server_retry_after_header = bool(respect_server_retry_after_header)
self.proxy_auth_method = proxy_auth_method
self.pool_connections = pool_connections or 10
self.pool_maxsize = pool_maxsize or 20
Expand Down
15 changes: 13 additions & 2 deletions src/databricks/sql/auth/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
stop_after_attempts_duration: float,
delay_default: float,
force_dangerous_codes: List[int],
respect_server_retry_after_header: bool = False,
urllib3_kwargs: dict = {},
):
# These values do not change from one command to the next
Expand All @@ -103,6 +104,7 @@ def __init__(
self.stop_after_attempts_duration = stop_after_attempts_duration
self._delay_default = delay_default
self.force_dangerous_codes = force_dangerous_codes
self.respect_server_retry_after_header = respect_server_retry_after_header

# the urllib3 kwargs are a mix of configuration (some of which we override)
# and counters like `total` or `connect` which may change between successive retries
Expand Down Expand Up @@ -202,6 +204,7 @@ def new(
stop_after_attempts_duration=self.stop_after_attempts_duration,
delay_default=self.delay_default,
force_dangerous_codes=self.force_dangerous_codes,
respect_server_retry_after_header=self.respect_server_retry_after_header,
urllib3_kwargs={},
)

Expand Down Expand Up @@ -323,7 +326,9 @@ def get_backoff_time(self) -> float:

return proposed_backoff

def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
def should_retry(
self, method: str, status_code: int, has_retry_after: bool = False
) -> Tuple[bool, str]:
"""This method encapsulates the connector's approach to retries.

We always retry a request unless one of these conditions is met:
Expand Down Expand Up @@ -388,6 +393,12 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
if not self._is_method_retryable(method):
return False, "Only POST requests are retried"

# When respect_server_retry_after_header is enabled, only retry when the
# server explicitly signals it's safe via a Retry-After header. This prevents
# duplicate side effects for non-idempotent operations.
if self.respect_server_retry_after_header and not has_retry_after:
return (False, "respect_server_retry_after_header mode: no Retry-After header present")

# Request failed, was an ExecuteStatement and the command may have reached the server
if (
self.command_type == CommandType.EXECUTE_STATEMENT
Expand Down Expand Up @@ -430,7 +441,7 @@ def is_retry(
Logs a debug message if the request will be retried
"""

should_retry, msg = self.should_retry(method, status_code)
should_retry, msg = self.should_retry(method, status_code, has_retry_after)

if should_retry:
logger.debug(msg)
Expand Down
4 changes: 4 additions & 0 deletions src/databricks/sql/backend/sea/utils/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def __init__(
)
self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0)
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
self._respect_server_retry_after_header = kwargs.get(
"_respect_server_retry_after_header", False
)

# Connection pooling settings
self.max_connections = kwargs.get("max_connections", 10)
Expand All @@ -116,6 +119,7 @@ def __init__(
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
delay_default=self._retry_delay_default,
force_dangerous_codes=self.force_dangerous_codes,
respect_server_retry_after_header=self._respect_server_retry_after_header,
urllib3_kwargs=urllib3_kwargs,
)
else:
Expand Down
4 changes: 4 additions & 0 deletions src/databricks/sql/backend/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def __init__(
" This behaviour is deprecated and will be removed in a future release."
)
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
self._respect_server_retry_after_header = kwargs.get(
"_respect_server_retry_after_header", False
)

additional_transport_args = {}

Expand All @@ -217,6 +220,7 @@ def __init__(
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
delay_default=self._retry_delay_default,
force_dangerous_codes=self.force_dangerous_codes,
respect_server_retry_after_header=self._respect_server_retry_after_header,
urllib3_kwargs=urllib3_kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions src/databricks/sql/common/unified_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def _setup_pool_managers(self):
stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration,
delay_default=self.config.retry_delay_default,
force_dangerous_codes=self.config.retry_dangerous_codes,
respect_server_retry_after_header=self.config.respect_server_retry_after_header,
)

# Initialize the required attributes that DatabricksRetryPolicy expects
Expand Down
1 change: 1 addition & 0 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,7 @@ def build_client_context(server_hostname: str, version: str, **kwargs):
),
retry_delay_default=kwargs.get("_retry_delay_default"),
retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"),
respect_server_retry_after_header=kwargs.get("_respect_server_retry_after_header"),
proxy_auth_method=kwargs.get("_proxy_auth_method"),
pool_connections=kwargs.get("_pool_connections"),
pool_maxsize=kwargs.get("_pool_maxsize"),
Expand Down
89 changes: 86 additions & 3 deletions tests/unit/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@


class TestRetry:
@pytest.fixture()
def retry_policy(self) -> DatabricksRetryPolicy:
return DatabricksRetryPolicy(
def _make_retry_policy(self, **overrides) -> DatabricksRetryPolicy:
defaults = dict(
delay_min=1,
delay_max=30,
stop_after_attempts_count=3,
stop_after_attempts_duration=900,
delay_default=2,
force_dangerous_codes=[],
)
defaults.update(overrides)
return DatabricksRetryPolicy(**defaults)

@pytest.fixture()
def retry_policy(self) -> DatabricksRetryPolicy:
return self._make_retry_policy()

@pytest.fixture()
def error_history(self) -> RequestHistory:
Expand Down Expand Up @@ -84,6 +89,84 @@ def test_excessive_retry_attempts_error(self, t_mock, retry_policy):
# Internally urllib3 calls the increment function generating a new instance for every retry
retry_policy = retry_policy.increment()

def test_respect_server_retry_after__retries_with_retry_after(self):
"""429 + Retry-After header → should retry"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True)
assert should_retry is True

def test_respect_server_retry_after__no_retry_without_retry_after(self):
"""429 without Retry-After header → no retry"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
should_retry, msg = policy.should_retry("POST", 429, has_retry_after=False)
assert should_retry is False
assert "respect_server_retry_after_header" in msg

def test_respect_server_retry_after__no_retry_503_without_header(self):
"""503 without Retry-After header → no retry"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
should_retry, msg = policy.should_retry("POST", 503, has_retry_after=False)
assert should_retry is False
assert "respect_server_retry_after_header" in msg

def test_respect_server_retry_after__overrides_dangerous_codes(self):
"""force_dangerous_codes=[500] + no Retry-After → no retry in respect_server_retry_after_header mode"""
policy = self._make_retry_policy(
force_dangerous_codes=[500], respect_server_retry_after_header=True
)
policy._retry_start_time = time.time()
policy.command_type = CommandType.EXECUTE_STATEMENT
should_retry, msg = policy.should_retry("POST", 500, has_retry_after=False)
assert should_retry is False
assert "respect_server_retry_after_header" in msg

def test_respect_server_retry_after__non_retryable_codes_unaffected(self):
"""401/403/501 still don't retry even with Retry-After header"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
for code in [401, 403, 501]:
should_retry, msg = policy.should_retry(
"POST", code, has_retry_after=True
)
assert should_retry is False, f"Code {code} should never retry"

def test_default_mode_unchanged(self, retry_policy):
"""respect_server_retry_after_header=False preserves existing behavior — 429 retries without header"""
retry_policy._retry_start_time = time.time()
retry_policy.command_type = CommandType.OTHER
should_retry, msg = retry_policy.should_retry(
"POST", 429, has_retry_after=False
)
assert should_retry is True

def test_respect_server_retry_after__survives_new(self):
"""urllib3 calls .new() between retries to create a fresh policy instance.
Verify that respect_server_retry_after_header is carried over and still enforced."""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
new_policy = policy.new()
assert new_policy.respect_server_retry_after_header is True
# The new instance should still block retries without Retry-After
should_retry, msg = new_policy.should_retry("POST", 429, has_retry_after=False)
assert should_retry is False
assert "respect_server_retry_after_header" in msg

def test_respect_server_retry_after__execute_statement_with_retry_after(self):
"""EXECUTE_STATEMENT + 429 + Retry-After header → retry"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.EXECUTE_STATEMENT
should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True)
assert should_retry is True

def test_404_does_not_retry_for_any_command_type(self, retry_policy):
"""Test that 404 never retries for any CommandType"""
retry_policy._retry_start_time = time.time()
Expand Down
41 changes: 27 additions & 14 deletions tests/unit/test_unified_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def client_context(self):
context.retry_stop_after_attempts_duration = 300.0
context.retry_delay_default = 5.0
context.retry_dangerous_codes = []
context.respect_server_retry_after_header = False
context.proxy_auth_method = None
context.pool_connections = 10
context.pool_maxsize = 20
Expand All @@ -48,16 +49,19 @@ def http_client(self, client_context):
"""Create UnifiedHttpClient instance."""
return UnifiedHttpClient(client_context)

@pytest.mark.parametrize("status_code,path", [
(429, "reason.response"),
(503, "reason.response"),
(500, "direct_response"),
])
@pytest.mark.parametrize(
"status_code,path",
[
(429, "reason.response"),
(503, "reason.response"),
(500, "direct_response"),
],
)
def test_max_retry_error_with_status_codes(self, http_client, status_code, path):
"""Test MaxRetryError with various status codes and response paths."""
mock_pool = Mock()
max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com")

if path == "reason.response":
max_retry_error.reason = Mock()
max_retry_error.reason.response = Mock()
Expand All @@ -79,12 +83,21 @@ def test_max_retry_error_with_status_codes(self, http_client, status_code, path)
assert "http-code" in error.context
assert error.context["http-code"] == status_code

@pytest.mark.parametrize("setup_func", [
lambda e: None, # No setup - error with no attributes
lambda e: setattr(e, "reason", None), # reason=None
lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None
lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))), # No status attr
])
@pytest.mark.parametrize(
"setup_func",
[
lambda e: None, # No setup - error with no attributes
lambda e: setattr(e, "reason", None), # reason=None
lambda e: (
setattr(e, "reason", Mock()),
setattr(e.reason, "response", None),
), # reason.response=None
lambda e: (
setattr(e, "reason", Mock()),
setattr(e.reason, "response", Mock(spec=[])),
), # No status attr
],
)
def test_max_retry_error_missing_status(self, http_client, setup_func):
"""Test MaxRetryError without status code (no crash, empty context)."""
mock_pool = Mock()
Expand All @@ -104,12 +117,12 @@ def test_max_retry_error_prefers_reason_response(self, http_client):
"""Test that e.reason.response.status is preferred over e.response.status."""
mock_pool = Mock()
max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com")

# Set both structures with different status codes
max_retry_error.reason = Mock()
max_retry_error.reason.response = Mock()
max_retry_error.reason.response.status = 429 # Should use this

max_retry_error.response = Mock()
max_retry_error.response.status = 500 # Should be ignored

Expand Down
Loading