Skip to content

Commit 62af328

Browse files
committed
fix(auth): re-mint on any refresh failure
1 parent 499f607 commit 62af328

2 files changed

Lines changed: 106 additions & 24 deletions

File tree

hotdata/_auth.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -177,33 +177,44 @@ def bearer_value(self):
177177
return self._jwt
178178

179179
def _mint(self, params):
180-
# Returns True on success. A non-200 from a refresh returns False so the
181-
# caller can re-mint from the API token; a non-200 from an api_token
182-
# mint raises TokenExchangeError.
180+
# Returns True on success. The refresh path is best-effort: ANY failure
181+
# -- a non-200, a transport error, or a malformed/missing-token body --
182+
# returns False so the caller re-mints from the held API token. An
183+
# api_token mint instead raises TokenExchangeError on any failure, since
184+
# there is no further fallback.
183185
params["client_id"] = _CLIENT_ID
184-
pool = self._pool or _pool_from_config(self._config) # reuses ssl_ca_cert/cert/proxy
185-
host = self._config.host.rstrip("/") # read host lazily -- may be set post-construct
186-
resp = pool.request(
187-
"POST",
188-
f"{host}/v1/auth/jwt",
189-
body=urlencode(params),
190-
headers={"Content-Type": "application/x-www-form-urlencoded"},
191-
timeout=_TIMEOUT,
192-
)
193-
if resp.status != 200:
194-
if params["grant_type"] == "refresh_token":
195-
return False # let caller re-mint from the API token
196-
raise TokenExchangeError(
197-
f"token exchange failed: {resp.status} {resp.data[:200]!r}"
198-
)
186+
is_refresh = params["grant_type"] == "refresh_token"
199187
try:
188+
pool = self._pool or _pool_from_config(self._config) # reuses ssl_ca_cert/cert/proxy
189+
host = self._config.host.rstrip("/") # read host lazily -- may be set post-construct
190+
resp = pool.request(
191+
"POST",
192+
f"{host}/v1/auth/jwt",
193+
body=urlencode(params),
194+
headers={"Content-Type": "application/x-www-form-urlencoded"},
195+
timeout=_TIMEOUT,
196+
)
197+
if resp.status != 200:
198+
raise TokenExchangeError(
199+
f"token exchange failed: {resp.status} {resp.data[:200]!r}"
200+
)
200201
data = json.loads(resp.data)
201-
except (ValueError, TypeError) as exc:
202-
raise TokenExchangeError(
203-
f"token exchange returned a non-JSON body: {resp.data[:200]!r}"
204-
) from exc
205-
self._jwt = data["access_token"]
206-
self._exp = time.time() + data.get("expires_in", 300)
202+
token = data["access_token"]
203+
expires_in = float(data.get("expires_in", 300))
204+
except (
205+
TokenExchangeError,
206+
urllib3.exceptions.HTTPError,
207+
ValueError,
208+
TypeError,
209+
KeyError,
210+
) as exc:
211+
if is_refresh:
212+
return False # let caller re-mint from the API token
213+
if isinstance(exc, TokenExchangeError):
214+
raise
215+
raise TokenExchangeError(f"token exchange failed: {exc!r}") from exc
216+
self._jwt = token
217+
self._exp = time.time() + expires_in
207218
self._refresh = data.get("refresh_token") or self._refresh
208219
return True
209220

tests/test_auth.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,3 +503,74 @@ def test_non_json_success_body_raises_token_exchange_error() -> None:
503503

504504
with pytest.raises(TokenExchangeError):
505505
mgr.bearer_value()
506+
507+
508+
def test_missing_access_token_raises_token_exchange_error() -> None:
509+
"""A 200 with valid JSON but no ``access_token`` (e.g. a misrouted endpoint
510+
returning some other JSON document) must surface as a TokenExchangeError,
511+
not a bare KeyError."""
512+
pool = _FakePool([_FakeResponse(200, {"token_type": "Bearer"})])
513+
mgr = _TokenManager("hd_secret_token", _config(), pool=pool)
514+
515+
with pytest.raises(TokenExchangeError):
516+
mgr.bearer_value()
517+
518+
519+
# --------------------------------------------------------------------------
520+
# Refresh that fails by *raising* (not just a non-200) still re-mints
521+
# --------------------------------------------------------------------------
522+
523+
524+
def test_refresh_raising_falls_back_to_api_token_mint() -> None:
525+
"""The refresh step is best-effort: if it fails in *any* way -- not just a
526+
non-200, but a malformed/non-JSON body or a transport error -- the manager
527+
must drop the refresh token and re-mint from the held API token rather than
528+
letting the exception escape ``bearer_value()``."""
529+
short_lived = _mint_response(
530+
access_token="eyJ.short.jwt",
531+
refresh_token="rt_doomed",
532+
expires_in=_LEEWAY - 5,
533+
)
534+
# Refresh returns 200 but a non-JSON body -> would raise inside _mint.
535+
refresh_garbage = _FakeResponse(200, b"<html>oops</html>")
536+
remint = _mint_response(access_token="eyJ.reminted.jwt", expires_in=300)
537+
pool = _FakePool([short_lived, refresh_garbage, remint])
538+
mgr = _TokenManager("hd_secret_token", _config(), pool=pool)
539+
540+
assert mgr.bearer_value() == "eyJ.short.jwt"
541+
# Second call: refresh raises internally -> fall back to api_token mint.
542+
assert mgr.bearer_value() == "eyJ.reminted.jwt"
543+
544+
assert len(pool.calls) == 3
545+
assert _form(pool.calls[1]["body"])["grant_type"] == ["refresh_token"]
546+
assert _form(pool.calls[2]["body"])["grant_type"] == ["api_token"]
547+
548+
549+
# --------------------------------------------------------------------------
550+
# auth_settings() reads the token exactly once (no double bearer_value())
551+
# --------------------------------------------------------------------------
552+
553+
554+
def test_auth_settings_reads_token_once(monkeypatch: pytest.MonkeyPatch) -> None:
555+
"""``auth_settings()`` must resolve the bearer token a single time, not
556+
once for the null-check and again for the value -- otherwise it acquires the
557+
manager lock twice per request and a concurrent ``api_key`` reset between the
558+
two reads could yield ``'Bearer ' + None``."""
559+
pool = _FakePool([_mint_response(access_token="eyJ.once.jwt")])
560+
cfg = _config()
561+
mgr = _TokenManager("hd_secret_token", cfg, pool=pool)
562+
cfg._token_manager = mgr
563+
564+
count = {"n": 0}
565+
real = mgr.bearer_value
566+
567+
def counting() -> str:
568+
count["n"] += 1
569+
return real()
570+
571+
monkeypatch.setattr(mgr, "bearer_value", counting)
572+
573+
auth = cfg.auth_settings()
574+
575+
assert _bearer_from(auth) == "Bearer eyJ.once.jwt"
576+
assert count["n"] == 1

0 commit comments

Comments
 (0)