Skip to content

Commit 30ebaa8

Browse files
committed
minor: making retry_if_exception more defensive
1 parent 87c4cca commit 30ebaa8

2 files changed

Lines changed: 50 additions & 7 deletions

File tree

pyathena/util.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import logging
55
import re
6-
from typing import Any, Callable, Iterable, Optional, Pattern, Tuple
6+
from typing import Any, Callable, Iterable, Optional, Pattern, Tuple, cast
77

88
import tenacity
99
from tenacity import after_log, retry_if_exception, stop_after_attempt, wait_exponential
@@ -172,12 +172,18 @@ def retry_api_call(
172172
Only retries on AWS exceptions listed in the RetryConfig.exceptions.
173173
Does not retry on client errors or non-AWS exceptions.
174174
"""
175+
176+
def _extract_code(ex: BaseException) -> Optional[str]:
177+
resp = cast(Optional[dict[str, Any]], getattr(ex, "response", None))
178+
err = cast(Optional[dict[str, Any]], (resp or {}).get("Error"))
179+
return cast(Optional[str], (err or {}).get("Code"))
180+
181+
def _is_retryable(ex: BaseException) -> bool:
182+
code = _extract_code(ex)
183+
return code is not None and code in config.exceptions
184+
175185
retry = tenacity.Retrying(
176-
retry=retry_if_exception(
177-
lambda e: getattr(e, "response", {}).get("Error", {}).get("Code") in config.exceptions
178-
if e
179-
else False
180-
),
186+
retry=retry_if_exception(_is_retryable),
181187
stop=stop_after_attempt(config.attempt),
182188
wait=wait_exponential(
183189
multiplier=config.multiplier,

tests/pyathena/test_util.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# -*- coding: utf-8 -*-
2+
from typing import Any
3+
24
import pytest
35

46
from pyathena import DataError
5-
from pyathena.util import parse_output_location, strtobool
7+
from pyathena.util import RetryConfig, parse_output_location, retry_api_call, strtobool
68

79

810
def test_parse_output_location():
@@ -25,3 +27,38 @@ def test_strtobool():
2527

2628
for n in no:
2729
assert not strtobool(n)
30+
31+
32+
class _WithCodeError(Exception):
33+
def __init__(self, code: int) -> None:
34+
super().__init__(f"error:{code}")
35+
self.response = {"Error": {"Code": code}}
36+
37+
38+
class _NoResponseError(Exception):
39+
def __init__(self) -> None:
40+
super().__init__("error")
41+
self.response = None
42+
43+
44+
def _test_retry(ex: Exception) -> None:
45+
calls = {"n": 0}
46+
47+
def fn() -> Any:
48+
calls["n"] += 1
49+
raise ex
50+
51+
cfg = RetryConfig(attempt=1, max_delay=1)
52+
53+
with pytest.raises(type(ex)):
54+
retry_api_call(fn, config=cfg)
55+
56+
assert calls["n"] == 1
57+
58+
59+
def test_retry_api_call():
60+
_test_retry(_WithCodeError(500))
61+
62+
63+
def test_retry_api_call_with_none_error():
64+
_test_retry(_NoResponseError())

0 commit comments

Comments
 (0)