Skip to content

Commit f379dfa

Browse files
committed
fix(oauth): preserve TTL on refresh, age via updated_at, reject fractional floats
Addresses Oracle second-pass review findings on dd3c25f9: - _refresh_access_token: when the refresh response omits expires_in, preserve the prior TTL by computing (expires_at - updated_at) from the existing record instead of clearing expires_at to NULL. Clearing it caused proactive refresh to stop after the first such response, since _is_token_expired returns False for NULL. Falls back to NULL only when the token had no prior expiry (genuine no-expiry providers like GitHub OAuth Apps). New _preserve_prior_ttl() module-level helper handles timezone-naive timestamps and non-positive deltas. - cleanup_expired_tokens: NULL-expiry rows now age via updated_at, not created_at. store_tokens advances updated_at on re-authorization, so using created_at would delete tokens that were re-authorized recently but originally created more than max_age_days ago. - parse_expires_in: tighten validation against int() truncation. Previously expires_in=-0.5 would silently pass as 0 because int(-0.5) == 0 (Python truncates toward zero). Now: sign-check the original numeric before conversion, and reject non-integer floats (RFC 6749 §5.1 specifies integer seconds). Order matters - sign check fires first since 'negative' is the more fundamental violation. Tests: - test_refresh_without_expires_in_clears_expiry replaced by: test_refresh_without_expires_in_preserves_prior_ttl (asserts new TTL ~= prior TTL after refresh) and test_refresh_without_expires_in_no_prior_ttl_stays_none. - test_cleanup_expired_tokens_targets_null_expires_at now asserts the query uses updated_at AND does NOT use created_at. - test_parse_expires_in_negative_raises extended with -0.5 and -3600.7. - New test_parse_expires_in_non_integer_float_raises covers 3600.5, 0.5, 3600.7. Signed-off-by: Jonathan Springer <jps@s390x.com>
1 parent e8176da commit f379dfa

4 files changed

Lines changed: 122 additions & 14 deletions

File tree

mcpgateway/services/oauth_manager.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,6 +1711,13 @@ def parse_expires_in(token_response: Dict[str, Any]) -> Optional[int]:
17111711
# Reject bools (True/False are int subclasses in Python) and any non-scalar types.
17121712
if isinstance(raw, bool) or not isinstance(raw, (int, float, str)):
17131713
raise OAuthError(f"Invalid expires_in from OAuth provider: {raw!r}")
1714+
# Sign-check the original numeric BEFORE int() truncation, otherwise int(-0.5) == 0
1715+
# would bypass the negative check.
1716+
if isinstance(raw, (int, float)) and raw < 0:
1717+
raise OAuthError(f"Invalid expires_in from OAuth provider (negative): {raw}")
1718+
# RFC 6749 §5.1 specifies integer seconds; reject non-integral floats explicitly.
1719+
if isinstance(raw, float) and not raw.is_integer():
1720+
raise OAuthError(f"Invalid expires_in from OAuth provider (non-integer): {raw}")
17141721
try:
17151722
value = int(raw)
17161723
except (TypeError, ValueError) as exc:

mcpgateway/services/token_storage_service.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,48 @@
2929
logger = logging.getLogger(__name__)
3030

3131

32+
def _preserve_prior_ttl(token_record: OAuthToken) -> Optional[int]:
33+
"""Compute the token's prior TTL in seconds, or ``None`` if not derivable.
34+
35+
Used when an OAuth refresh response omits ``expires_in`` but the token
36+
previously had a finite lifetime - the gateway preserves the original
37+
issuance TTL by computing ``expires_at - updated_at`` from the existing
38+
record. Returns ``None`` when either timestamp is missing or the difference
39+
is non-positive (clock skew or already-expired records).
40+
41+
Args:
42+
token_record: Existing OAuth token row, before the refresh applies.
43+
44+
Returns:
45+
Positive integer seconds of prior TTL, or ``None``.
46+
47+
Examples:
48+
>>> from types import SimpleNamespace
49+
>>> from datetime import datetime, timedelta, timezone
50+
>>> issued = datetime(2026, 1, 1, tzinfo=timezone.utc)
51+
>>> rec = SimpleNamespace(expires_at=issued + timedelta(hours=1), updated_at=issued)
52+
>>> _preserve_prior_ttl(rec)
53+
3600
54+
>>> _preserve_prior_ttl(SimpleNamespace(expires_at=None, updated_at=issued)) is None
55+
True
56+
>>> _preserve_prior_ttl(SimpleNamespace(expires_at=issued, updated_at=issued + timedelta(hours=1))) is None
57+
True
58+
"""
59+
prev_expires_at = token_record.expires_at
60+
prev_updated_at = token_record.updated_at
61+
if prev_expires_at is None or prev_updated_at is None:
62+
return None
63+
# Normalize naive timestamps to UTC for the subtraction.
64+
if prev_expires_at.tzinfo is None:
65+
prev_expires_at = prev_expires_at.replace(tzinfo=timezone.utc)
66+
if prev_updated_at.tzinfo is None:
67+
prev_updated_at = prev_updated_at.replace(tzinfo=timezone.utc)
68+
prev_ttl = int((prev_expires_at - prev_updated_at).total_seconds())
69+
if prev_ttl <= 0:
70+
return None
71+
return prev_ttl
72+
73+
3274
class TokenStorageService:
3375
"""Manages OAuth token storage and retrieval.
3476
@@ -330,15 +372,30 @@ def normalize_resource(url: str, *, preserve_query: bool = False) -> str | None:
330372
# Update the token record
331373
token_record.access_token = encrypted_access
332374
token_record.refresh_token = encrypted_refresh
375+
now = datetime.now(timezone.utc)
333376
if expires_in is not None:
334-
token_record.expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
377+
token_record.expires_at = now + timedelta(seconds=expires_in)
335378
else:
336-
logger.info(
337-
"No expires_in on refresh response for gateway %s; clearing local expiry",
338-
SecurityValidator.sanitize_log_message(token_record.gateway_id),
339-
)
340-
token_record.expires_at = None
341-
token_record.updated_at = datetime.now(timezone.utc)
379+
# Refresh response omitted expires_in. If the token previously had a finite
380+
# expiry, preserve the prior TTL (expires_at - updated_at) so proactive
381+
# refresh keeps working - clearing it outright would cause _is_token_expired
382+
# to return False forever and stop the refresh loop. If there was no prior
383+
# expiry, leave it as None (provider-level "no known lifetime").
384+
preserved_ttl = _preserve_prior_ttl(token_record)
385+
if preserved_ttl is not None:
386+
logger.info(
387+
"No expires_in on refresh response for gateway %s; preserving prior TTL of %d seconds",
388+
SecurityValidator.sanitize_log_message(token_record.gateway_id),
389+
preserved_ttl,
390+
)
391+
token_record.expires_at = now + timedelta(seconds=preserved_ttl)
392+
else:
393+
logger.info(
394+
"No expires_in on refresh response for gateway %s; no prior TTL to preserve",
395+
SecurityValidator.sanitize_log_message(token_record.gateway_id),
396+
)
397+
token_record.expires_at = None
398+
token_record.updated_at = now
342399

343400
self.db.commit()
344401
logger.info(f"Successfully refreshed token for gateway {token_record.gateway_id}, user {token_record.app_user_email}")
@@ -500,14 +557,18 @@ async def cleanup_expired_tokens(self, max_age_days: int = 30) -> int:
500557
1. Tokens whose ``expires_at`` is older than the cutoff (the original
501558
"expired more than N days ago" behaviour).
502559
2. Tokens with ``expires_at IS NULL`` (provider omitted ``expires_in``)
503-
whose ``created_at`` is older than the cutoff. ``NULL < <cutoff>``
560+
whose ``updated_at`` is older than the cutoff. ``NULL < <cutoff>``
504561
evaluates to ``NULL`` in SQL three-valued logic, so without this
505-
branch those rows would never age out.
562+
branch those rows would never age out. ``updated_at`` (rather than
563+
``created_at``) is the right freshness signal because
564+
``store_tokens`` advances it on re-authorization, so a recently
565+
re-authorized token isn't deleted just because its original row was
566+
old.
506567
507568
Args:
508569
max_age_days: Maximum age of tokens to keep, measured from
509570
``expires_at`` for tokens with a known expiry and from
510-
``created_at`` for tokens with no provider-supplied expiry.
571+
``updated_at`` for tokens with no provider-supplied expiry.
511572
512573
Returns:
513574
Number of tokens cleaned up
@@ -526,7 +587,7 @@ async def cleanup_expired_tokens(self, max_age_days: int = 30) -> int:
526587

527588
stale_filter = or_(
528589
OAuthToken.expires_at < cutoff_date,
529-
and_(OAuthToken.expires_at.is_(None), OAuthToken.created_at < cutoff_date),
590+
and_(OAuthToken.expires_at.is_(None), OAuthToken.updated_at < cutoff_date),
530591
)
531592
result = self.db.execute(delete(OAuthToken).where(stale_filter))
532593
count = result.rowcount

tests/unit/mcpgateway/services/test_oauth_manager.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,9 @@ def test_parse_expires_in_explicit_null_returns_none():
6969
assert parse_expires_in({"expires_in": None}) is None
7070

7171

72-
@pytest.mark.parametrize("bad_value", [-1, -3600, "-1"])
72+
@pytest.mark.parametrize("bad_value", [-1, -3600, "-1", -0.5, -3600.7])
7373
def test_parse_expires_in_negative_raises(bad_value):
74+
"""Negative numerics raise even when int() would silently truncate (-0.5 -> 0)."""
7475
with pytest.raises(OAuthError, match="negative"):
7576
parse_expires_in({"expires_in": bad_value})
7677

@@ -81,6 +82,13 @@ def test_parse_expires_in_garbage_string_raises(bad_value):
8182
parse_expires_in({"expires_in": bad_value})
8283

8384

85+
@pytest.mark.parametrize("bad_value", [3600.5, 0.5, 3600.7])
86+
def test_parse_expires_in_non_integer_float_raises(bad_value):
87+
"""RFC 6749 §5.1 specifies integer seconds; non-integer floats are rejected."""
88+
with pytest.raises(OAuthError, match="non-integer"):
89+
parse_expires_in({"expires_in": bad_value})
90+
91+
8492
@pytest.mark.parametrize("bad_value", [True, False, [3600], {"seconds": 3600}, object()])
8593
def test_parse_expires_in_non_scalar_raises(bad_value):
8694
with pytest.raises(OAuthError, match="Invalid expires_in"):

tests/unit/mcpgateway/services/test_token_storage_service.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,15 +339,44 @@ async def test_refresh_success(service, mock_db):
339339

340340

341341
@pytest.mark.asyncio
342-
async def test_refresh_without_expires_in_clears_expiry(service, mock_db):
342+
async def test_refresh_without_expires_in_preserves_prior_ttl(service, mock_db):
343+
"""Refresh response missing expires_in: preserve the prior TTL so proactive refresh keeps working."""
343344
gw = MagicMock(oauth_config={"token_url": "https://token", "client_id": "cid"}, url="https://gw.com")
344345
mock_db.query.return_value.filter.return_value.first.return_value = gw
345346
mock_oauth_manager = MagicMock()
346347
mock_oauth_manager.refresh_token = AsyncMock(return_value={"access_token": "new_access", "refresh_token": "new_refresh"})
348+
349+
# Token issued 100 seconds ago with a 1-hour TTL.
347350
record = _make_token_record()
351+
issued = datetime.now(tz=timezone.utc) - timedelta(seconds=100)
352+
record.updated_at = issued
353+
record.expires_at = issued + timedelta(seconds=3600)
354+
355+
with patch("mcpgateway.services.oauth_manager.OAuthManager", return_value=mock_oauth_manager):
356+
result = await service._refresh_access_token(record)
357+
358+
assert result == "new_access"
359+
# Prior TTL preserved: new expires_at should be ~3600s after the refresh moment (now).
348360
assert record.expires_at is not None
361+
delta = (record.expires_at - record.updated_at).total_seconds()
362+
assert 3599 <= delta <= 3601
363+
364+
365+
@pytest.mark.asyncio
366+
async def test_refresh_without_expires_in_no_prior_ttl_stays_none(service, mock_db):
367+
"""Refresh response missing expires_in AND no prior TTL: expires_at stays None."""
368+
gw = MagicMock(oauth_config={"token_url": "https://token", "client_id": "cid"}, url="https://gw.com")
369+
mock_db.query.return_value.filter.return_value.first.return_value = gw
370+
mock_oauth_manager = MagicMock()
371+
mock_oauth_manager.refresh_token = AsyncMock(return_value={"access_token": "new_access", "refresh_token": "new_refresh"})
372+
373+
# Token had no prior expiry (e.g. GitHub OAuth Apps).
374+
record = _make_token_record()
375+
record.expires_at = None
376+
349377
with patch("mcpgateway.services.oauth_manager.OAuthManager", return_value=mock_oauth_manager):
350378
result = await service._refresh_access_token(record)
379+
351380
assert result == "new_access"
352381
assert record.expires_at is None
353382

@@ -568,7 +597,10 @@ async def test_cleanup_expired_tokens_targets_null_expires_at(service, mock_db):
568597
delete_stmt = mock_db.execute.call_args.args[0]
569598
rendered = str(delete_stmt.compile(compile_kwargs={"literal_binds": True})).lower()
570599
assert "expires_at is null" in rendered
571-
assert "created_at" in rendered
600+
# NULL-expires_at rows must be aged out by updated_at (re-auth advances it),
601+
# not created_at (which would delete recently re-authorized tokens).
602+
assert "updated_at" in rendered
603+
assert "created_at" not in rendered
572604

573605

574606
# ---------- token_type validation in get_user_token ----------

0 commit comments

Comments
 (0)