Skip to content

Commit b685539

Browse files
authored
Add an expiration time to cached MCP auth credentials (#1872)
* Remote MCP servers may expire registered auth clients, when this happens the NAT MCP client will need to re-register client credentials. * This adds a new `oauth_client_ttl` configuration attribute to `MCPOAuth2ProviderConfig` * This incorporates changes from PR #1871 ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/resources/contributing/index.md). - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - Any contribution which contains commits that are not Signed-Off will not be accepted. - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. ## Summary by CodeRabbit * **New Features** * Configurable oauth_client_ttl for OAuth2 credential caching (default 270s; 0 disables caching). * **Bug Fixes** * Improved authentication robustness: serialized discovery/registration to avoid races, automatic re-registration on failures, safer handling when credentials or endpoints are missing/expired, and retry logic for registration rejections. * **Documentation** * Documented oauth_client_ttl behavior, defaults, and TTL guidance. * **Tests** * Added tests for credential TTL, cache expiry, and related authentication flows. Authors: - David Gardner (https://github.com/dagardner-nv) - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah) Approvers: - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah) URL: #1872
1 parent cbb5fed commit b685539

5 files changed

Lines changed: 150 additions & 9 deletions

File tree

docs/source/components/auth/mcp-auth/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ authentication:
4949
Configuration options:
5050
- `server_url`: The URL of the MCP server that requires authentication.
5151
- `redirect_uri`: The redirect URI for the OAuth2 flow. This must match the address where your server is accessible from your browser.
52+
- `oauth_client_ttl`: Amount of time, in seconds, to cache OAuth client credentials obtained via Dynamic Client Registration. Some MCP servers will invalidate client credentials after a certain period, requiring this value to match the timeout setting of the server minus a small safety buffer (for example, 30 seconds). After this period elapses, the client re-registers with the authorization server and obtains a new `client_id`. Defaults to `270` seconds. Set to `0` to disable caching (re-register on every authentication attempt).
5253

5354
To view all configuration options for the `mcp_oauth2` authentication provider, run the following command:
5455
```bash

packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import asyncio
1617
import logging
18+
import time
1719
from collections.abc import Awaitable
1820
from collections.abc import Callable
1921
from urllib.parse import urljoin
@@ -316,7 +318,9 @@ def __init__(self, config: MCPOAuth2ProviderConfig, builder=None):
316318

317319
# Client registration
318320
self._registrar = DynamicClientRegistration(config)
321+
self._credentials_cache_time: float | None = None
319322
self._cached_credentials: OAuth2Credentials | None = None
323+
self._discover_register_lock = asyncio.Lock()
320324

321325
# For the OAuth2 flow
322326
self._auth_code_provider = None
@@ -331,12 +335,48 @@ def __init__(self, config: MCPOAuth2ProviderConfig, builder=None):
331335
if self.config.token_storage_object_store:
332336
# Store object store name, will be resolved later when builder context is available
333337
self._token_storage_object_store_name = self.config.token_storage_object_store
334-
logger.info(f"Configured to use object store '{self._token_storage_object_store_name}' for token storage")
338+
logger.info("Configured to use object store '%s' for token storage", self._token_storage_object_store_name)
335339
else:
336340
# Default: use in-memory token storage
337341
from nat.authentication.token_storage import InMemoryTokenStorage
338342
self._token_storage = InMemoryTokenStorage()
339343

344+
def _invalidate_cached_registration(self, reason: str) -> None:
345+
"""Invalidate cached OAuth client registration and auth provider."""
346+
previous_client_id = self._cached_credentials.client_id if self._cached_credentials else None
347+
self._credentials_cache_time = None
348+
self._cached_credentials = None
349+
self._auth_code_provider = None
350+
logger.info("Invalidated cached OAuth2 registration: reason=%s previous_client_id=%s",
351+
reason,
352+
previous_client_id)
353+
354+
def _is_cached_credentials_expired(self) -> bool:
355+
"""Check if cached credentials are expired."""
356+
if self._credentials_cache_time is None:
357+
return True
358+
359+
# `0` means "do not reuse across attempts", not "invalidate within the same attempt".
360+
if self.config.oauth_client_ttl == 0:
361+
return False
362+
363+
return (time.monotonic() - self._credentials_cache_time) >= self.config.oauth_client_ttl
364+
365+
def _is_redirect_uri_registration_error(self, error: Exception) -> bool:
366+
"""Check if error indicates AS rejected redirect URI registration for this client."""
367+
msg = str(error).lower()
368+
return ("redirect uri" in msg and "not registered for client" in msg)
369+
370+
async def _discover_and_register_locked(self,
371+
response: httpx.Response | None = None,
372+
*,
373+
force_refresh: bool = False):
374+
"""Serialize discovery/registration to avoid races across concurrent auth flows."""
375+
async with self._discover_register_lock:
376+
if force_refresh:
377+
self._invalidate_cached_registration(reason="forced-refresh")
378+
await self._discover_and_register(response=response)
379+
340380
def _set_custom_auth_callback(self,
341381
auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
342382
Awaitable[AuthenticatedContext]]):
@@ -364,9 +404,19 @@ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult
364404

365405
response = kwargs.get('response')
366406
if response and response.status_code == 401:
367-
await self._discover_and_register(response=response)
407+
await self._discover_and_register_locked(response=response)
368408

369-
return await self._nat_oauth2_authenticate(user_id=user_id)
409+
try:
410+
return await self._nat_oauth2_authenticate(user_id=user_id)
411+
except RuntimeError as e:
412+
# Some AS deployments intermittently reject authorize requests with
413+
# "redirect URI not registered" for a cached client_id. Force one
414+
# re-registration attempt to self-heal before failing the request.
415+
if self._is_redirect_uri_registration_error(e):
416+
logger.warning("Detected redirect URI registration error; forcing re-registration and retry")
417+
await self._discover_and_register_locked(response=response, force_refresh=True)
418+
return await self._nat_oauth2_authenticate(user_id=user_id)
419+
raise
370420

371421
@property
372422
def _effective_scopes(self) -> list[str]:
@@ -382,12 +432,12 @@ async def _discover_and_register(self, response: httpx.Response | None = None):
382432
self._cached_endpoints, endpoints_changed = await self._discoverer.discover(response=response)
383433
if endpoints_changed:
384434
logger.info("OAuth2 endpoints: %s", self._cached_endpoints)
385-
self._cached_credentials = None # invalidate credentials tied to old AS
386-
self._auth_code_provider = None
435+
self._invalidate_cached_registration(reason="endpoints-changed")
387436
effective_scopes = self._effective_scopes
388437

389438
# Client registration
390-
if not self._cached_credentials:
439+
if (not self._cached_credentials or self.config.oauth_client_ttl == 0 or self._is_cached_credentials_expired()):
440+
self._invalidate_cached_registration(reason="registration-expired")
391441
if self.config.client_id:
392442
# Manual registration mode
393443
self._cached_credentials = OAuth2Credentials(
@@ -400,12 +450,19 @@ async def _discover_and_register(self, response: httpx.Response | None = None):
400450
self._cached_credentials = await self._registrar.register(self._cached_endpoints, effective_scopes)
401451
logger.info("Registered OAuth2 client: %s", self._cached_credentials.client_id)
402452

453+
self._credentials_cache_time = time.monotonic()
454+
403455
async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResult:
404456
"""Perform the OAuth2 flow using MCP-specific authentication flow handler."""
405457
from nat.authentication.oauth2.oauth2_auth_code_flow_provider import OAuth2AuthCodeFlowProvider
406458

407-
if not self._cached_endpoints or not self._cached_credentials:
459+
if (not self._cached_endpoints or not self._cached_credentials or self._is_cached_credentials_expired()):
408460
# if discovery is yet to to be done return empty auth result
461+
logger.warning(
462+
"OAuth2 endpoints or credentials not available or expired for user_id=%s. "
463+
"Discovery and registration must be performed before authentication. "
464+
"Returning empty AuthResult.",
465+
user_id)
409466
return AuthResult(credentials=[], token_expires_at=None, raw={})
410467

411468
endpoints = self._cached_endpoints
@@ -422,8 +479,10 @@ async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResu
422479
logger.info(f"Initialized token storage with object store '{self._token_storage_object_store_name}'")
423480
except Exception as e:
424481
logger.warning(
425-
f"Failed to resolve object store '{self._token_storage_object_store_name}' for token storage: {e}. "
426-
"Falling back to in-memory storage.")
482+
"Failed to resolve object store '%s' for token storage: %s. Falling back to in-memory storage.",
483+
self._token_storage_object_store_name,
484+
e,
485+
)
427486
from nat.authentication.token_storage import InMemoryTokenStorage
428487
self._token_storage = InMemoryTokenStorage()
429488

packages/nvidia_nat_mcp/src/nat/plugins/mcp/auth/auth_provider_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
5656
default_user_id: str | None = Field(default=None, description="Default user ID for authentication")
5757
allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls")
5858

59+
# OAuth client credential caching
60+
oauth_client_ttl: float = Field(default=270.0,
61+
ge=0.0,
62+
description="Amount of time, in seconds, to cache oauth client credentials. "
63+
"Setting this to 0 disables caching.")
64+
5965
# Token storage configuration
6066
token_storage_object_store: str | None = Field(
6167
default=None,

packages/nvidia_nat_mcp/tests/client/test_mcp_auth_provider.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import asyncio
1617
from unittest.mock import AsyncMock
1718
from unittest.mock import MagicMock
1819
from unittest.mock import patch
@@ -752,3 +753,75 @@ async def test_effective_scopes_config_overrides_discovered(self, mock_config):
752753

753754
scopes = provider._effective_scopes
754755
assert scopes == ['config_scope'] # Config should take precedence
756+
757+
@pytest.mark.parametrize("oauth_client_ttl", [0.01, 0.0], ids=["0.01", "disabled"])
758+
async def test_oauth_client_ttl(self, mock_endpoints, oauth_client_ttl):
759+
"""Test that expired oauth_client_ttl causes re-registration with a new client_id."""
760+
config = MCPOAuth2ProviderConfig(
761+
server_url="https://example.com/mcp", # type: ignore
762+
redirect_uri="https://example.com/callback", # type: ignore
763+
enable_dynamic_registration=True,
764+
oauth_client_ttl=oauth_client_ttl,
765+
)
766+
provider = MCPOAuth2Provider(config)
767+
768+
first_credentials = OAuth2Credentials(client_id="first_client_id", client_secret="secret")
769+
second_credentials = OAuth2Credentials(client_id="second_client_id", client_secret="secret")
770+
771+
with patch.object(provider._discoverer, 'discover') as mock_discover:
772+
mock_discover.return_value = (mock_endpoints, False)
773+
774+
with patch.object(provider._registrar, 'register') as mock_register:
775+
mock_register.return_value = first_credentials
776+
777+
await provider._discover_and_register()
778+
779+
assert provider._cached_credentials.client_id == "first_client_id"
780+
assert provider._auth_code_provider is None # not built yet
781+
first_cache_time = provider._credentials_cache_time
782+
783+
# Wait for TTL to expire
784+
await asyncio.sleep(oauth_client_ttl)
785+
786+
mock_register.return_value = second_credentials
787+
await provider._discover_and_register()
788+
789+
assert provider._cached_credentials.client_id == "second_client_id"
790+
assert provider._credentials_cache_time > first_cache_time
791+
assert provider._auth_code_provider is None # reset on re-registration
792+
assert mock_register.call_count == 2
793+
794+
async def test_oauth_client_ttl_not_expired(self, mock_endpoints):
795+
"""Test that credentials are not refreshed when oauth_client_ttl has not elapsed."""
796+
config = MCPOAuth2ProviderConfig(
797+
server_url="https://example.com/mcp", # type: ignore
798+
redirect_uri="https://example.com/callback", # type: ignore
799+
enable_dynamic_registration=True,
800+
oauth_client_ttl=100,
801+
)
802+
provider = MCPOAuth2Provider(config)
803+
804+
first_credentials = OAuth2Credentials(client_id="first_client_id", client_secret="secret")
805+
second_credentials = OAuth2Credentials(client_id="second_client_id", client_secret="secret")
806+
807+
with patch.object(provider._discoverer, 'discover') as mock_discover:
808+
mock_discover.return_value = (mock_endpoints, False)
809+
810+
with patch.object(provider._registrar, 'register') as mock_register:
811+
mock_register.return_value = first_credentials
812+
813+
await provider._discover_and_register()
814+
815+
assert provider._cached_credentials.client_id == "first_client_id"
816+
first_cache_time = provider._credentials_cache_time
817+
818+
# Wait well under the TTL
819+
await asyncio.sleep(0.01)
820+
821+
mock_register.return_value = second_credentials
822+
await provider._discover_and_register()
823+
824+
# Credentials should be unchanged — no re-registration occurred
825+
assert provider._cached_credentials.client_id == "first_client_id"
826+
assert provider._credentials_cache_time == first_cache_time
827+
assert mock_register.call_count == 1

packages/nvidia_nat_mcp/tests/client/test_mcp_token_storage.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import time
1617
from datetime import UTC
1718
from datetime import datetime
1819
from datetime import timedelta
@@ -338,6 +339,7 @@ async def test_token_storage_lazy_resolution(self, mock_config, sample_auth_resu
338339
token_url="https://auth.example.com/token", # type: ignore
339340
)
340341
provider._cached_credentials = OAuth2Credentials(client_id="test", client_secret="secret")
342+
provider._credentials_cache_time = time.time() # A non-none value to indicate credentials are "cached"
341343

342344
# Trigger authentication which should resolve the object store
343345
with patch('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider'

0 commit comments

Comments
 (0)