Skip to content

Commit 4eee604

Browse files
feat: add opt-in shared OAuth token cache via object store
Add `oauth_token_store` to `FastApiFrontEndConfig` (optional, unset by default). When set, the OAuth token obtained during the authorization-code flow is persisted in the specified object store (e.g. Redis), so that a WebSocket reconnecting to a different pod does not have to repeat authentication. When unset, the existing in-memory dict is used unchanged.

Signed-off-by: Patrick Chin <8509935+thepatrickchin@users.noreply.github.com>
1 parent 61671c5 commit 4eee604

File tree

4 files changed

+190
-20
lines changed

4 files changed

+190
-20
lines changed

packages/nvidia_nat_core/src/nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import asyncio
17+
import json
1718
import logging
1819
import secrets
1920
import time
@@ -32,7 +33,10 @@
3233
from nat.data_models.authentication import AuthFlowType
3334
from nat.data_models.authentication import AuthProviderBaseConfig
3435
from nat.data_models.interactive import _HumanPromptOAuthConsent
36+
from nat.data_models.object_store import NoSuchKeyError
3537
from nat.front_ends.fastapi.message_handler import WebSocketMessageHandler
38+
from nat.object_store.interfaces import ObjectStore
39+
from nat.object_store.models import ObjectStoreItem
3640

3741
logger = logging.getLogger(__name__)
3842

@@ -55,15 +59,15 @@ def __init__(self,
5559
web_socket_message_handler: WebSocketMessageHandler,
5660
auth_timeout_seconds: float = 300.0,
5761
return_url: str | None = None,
58-
token_store: dict | None = None,
62+
token_store: dict | ObjectStore | None = None,
5963
session_id: str | None = None):
6064

6165
self._add_flow_cb: Callable[[str, FlowState], Awaitable[None]] = add_flow_cb
6266
self._remove_flow_cb: Callable[[str], Awaitable[None]] = remove_flow_cb
6367
self._web_socket_message_handler: WebSocketMessageHandler = web_socket_message_handler
6468
self._auth_timeout_seconds: float = auth_timeout_seconds
6569
self._return_url: str | None = return_url
66-
self._token_store: dict | None = token_store
70+
self._token_store: dict | ObjectStore | None = token_store
6771
self._session_id: str | None = session_id
6872

6973
async def authenticate(
@@ -137,7 +141,7 @@ async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProvider
137141
raise ValueError("Redirect-based authentication (use_redirect_auth=True) requires a return URL, "
138142
"but none was configured. Pass return_url when constructing the flow handler.")
139143

140-
cached = self._get_cached_token(config)
144+
cached = await self._get_cached_token(config)
141145
if cached is not None:
142146
logger.debug("OAuth token cache hit for client_id=%s", config.client_id)
143147
return cached
@@ -174,7 +178,7 @@ async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProvider
174178
metadata={
175179
"expires_at": token.get("expires_at"), "raw_token": token
176180
})
177-
self._store_token(config, ctx)
181+
await self._store_token(config, ctx)
178182
return ctx
179183

180184
def _token_cache_key(self, config: OAuth2AuthCodeFlowProviderConfig) -> str | None:
@@ -183,24 +187,73 @@ def _token_cache_key(self, config: OAuth2AuthCodeFlowProviderConfig) -> str | No
183187
return None
184188
return f"{self._session_id}:{config.client_id}:{config.token_url}"
185189

186-
def _get_cached_token(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext | None:
190+
async def _get_cached_token(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext | None:
    """Return a cached, non-expired token for *config*, or None.

    The token store is either a plain ``dict`` holding ``(ctx, expires_at)``
    tuples, or an ``ObjectStore`` holding JSON documents with the keys
    ``headers``, ``query_params``, ``cookies``, ``body``, ``metadata`` and
    ``expires_at``.

    The object-store cache is treated as best-effort: an unreadable, corrupt
    or expired entry is evicted where possible and reported as a miss, and a
    backend failure is logged and reported as a miss, so the caller falls back
    to running the OAuth flow instead of failing.

    Args:
        config: OAuth2 authorization-code provider configuration; used to
            derive the cache key.

    Returns:
        The cached ``AuthenticatedContext``, or ``None`` when caching is
        unavailable, nothing is cached, or the cached entry is expired,
        corrupt or unreadable.
    """
    key = self._token_cache_key(config)
    if key is None or self._token_store is None:
        return None

    # An entry is considered expired this many seconds before its recorded
    # expiry time; the same margin applies to both store types.
    expiry_buffer_seconds = 60

    if isinstance(self._token_store, dict):
        entry = self._token_store.get(key)
        if entry is None:
            return None
        ctx, expires_at = entry
        if expires_at is not None and time.time() >= expires_at - expiry_buffer_seconds:
            del self._token_store[key]
            return None
        return ctx

    # ObjectStore path
    try:
        item = await self._token_store.get_object(key)
    except NoSuchKeyError:
        return None
    except Exception:
        # Any other backend failure must not abort authentication: degrade
        # to a cache miss, mirroring how _store_token handles write failures.
        logger.warning("Failed to read OAuth token from object store for key %s", key, exc_info=True)
        return None

    try:
        payload = json.loads(item.data)
        if not isinstance(payload, dict):
            # Valid JSON that is not an object (null, list, ...) is just as
            # unusable as undecodable bytes.
            raise TypeError("cached OAuth token payload is not a JSON object")
        expires_at = payload.get("expires_at")
        # A non-numeric expires_at raises TypeError here and is handled as corrupt.
        expired = expires_at is not None and time.time() >= expires_at - expiry_buffer_seconds
    except (ValueError, TypeError, AttributeError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError; TypeError
        # and AttributeError cover an item whose data is missing or mistyped.
        logger.warning("OAuth token cache entry for key %s is corrupt; evicting", key)
        await self._evict_object_store_entry(key)
        return None

    if expired:
        await self._evict_object_store_entry(key)
        return None

    return AuthenticatedContext(
        headers=payload.get("headers"),
        query_params=payload.get("query_params"),
        cookies=payload.get("cookies"),
        body=payload.get("body"),
        metadata=payload.get("metadata"),
    )


async def _evict_object_store_entry(self, key: str) -> None:
    """Best-effort removal of *key* from the object-store token cache."""
    try:
        await self._token_store.delete_object(key)
    except NoSuchKeyError:
        # Already gone (for example removed between our read and this delete).
        pass
    except Exception:
        logger.warning("Failed to evict OAuth token cache entry for key %s", key, exc_info=True)
199234

200-
def _store_token(self, config: OAuth2AuthCodeFlowProviderConfig, ctx: AuthenticatedContext) -> None:
235+
async def _store_token(self, config: OAuth2AuthCodeFlowProviderConfig, ctx: AuthenticatedContext) -> None:
    """Cache *ctx* for *config* if caching is available.

    With a ``dict`` store the entry is written as ``(ctx, expires_at)``.  With
    an ``ObjectStore`` the context is serialised to a JSON document and
    upserted under the cache key; any failure while serialising or writing is
    logged as a warning and otherwise ignored.

    Args:
        config: OAuth2 authorization-code provider configuration; used to
            derive the cache key.
        ctx: The authenticated context to cache.  ``expires_at`` is read from
            ``ctx.metadata`` when metadata is present.
    """
    store = self._token_store
    key = self._token_cache_key(config)
    if key is None or store is None:
        return

    expires_at = None
    if ctx.metadata:
        expires_at = ctx.metadata.get("expires_at")

    # In-memory path: keep the live context object itself.
    if isinstance(store, dict):
        store[key] = (ctx, expires_at)
        return

    def _as_plain_dict(mapping):
        # Empty or missing mappings are stored as JSON null.
        return dict(mapping) if mapping else None

    # ObjectStore path
    payload = {
        "headers": _as_plain_dict(ctx.headers),
        "query_params": _as_plain_dict(ctx.query_params),
        "cookies": _as_plain_dict(ctx.cookies),
        "body": ctx.body,
        "metadata": ctx.metadata,
        "expires_at": expires_at,
    }
    try:
        encoded = json.dumps(payload).encode()
        item = ObjectStoreItem(data=encoded, content_type="application/json")
        await store.upsert_object(key, item)
    except Exception:
        logger.warning("Failed to store OAuth token in object store for key %s", key, exc_info=True)

packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,15 @@ class CrossOriginResourceSharing(BaseModel):
319319
"request to '/static' and files will be served from the object store. The files will be served from the "
320320
"object store at '/static/{file_name}'."))
321321

322+
oauth_token_store: ObjectStoreRef | None = Field(
323+
default=None,
324+
description=(
325+
"Object store reference for sharing the OAuth token cache across processes or replicas. "
326+
"When set, the access token obtained during the OAuth authorization-code flow is persisted "
327+
"in the specified object store (e.g. a Redis store) so that a WebSocket reconnecting to a "
328+
"different pod does not need to repeat authentication. "
329+
"If not set, tokens are cached in process memory, which is sufficient for single-replica deployments."))
330+
322331
disable_legacy_routes: bool = Field(
323332
default=False,
324333
description="Disable the legacy routes for the FastAPI app. If True, the legacy routes are disabled.")

packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
from collections.abc import Awaitable
2222
from collections.abc import Callable
2323
from contextlib import asynccontextmanager
24+
from typing import TYPE_CHECKING
25+
26+
if TYPE_CHECKING:
27+
from nat.object_store.interfaces import ObjectStore
2428

2529
from fastapi import FastAPI
2630
from fastapi import Request
@@ -203,8 +207,9 @@ def __init__(self, config: Config):
203207
self._conversation_handlers: dict[str, WebSocketMessageHandler] = {}
204208

205209
# OAuth token cache: maps "{session_id}:{client_id}:{token_url}" to
206-
# (AuthenticatedContext, expires_at_unix_or_None)
207-
self._oauth_token_store: dict[str, tuple] = {}
210+
# (AuthenticatedContext, expires_at_unix_or_None) when using the default in-memory dict.
211+
# Replaced with an ObjectStore-backed instance in add_routes() when oauth_token_store is set in config.
212+
self._oauth_token_store: dict[str, tuple] | ObjectStore = {}
208213

209214
# Track session managers for each route
210215
self._session_managers: list[SessionManager] = []
@@ -331,6 +336,10 @@ async def configure(self, app: FastAPI, builder: WorkflowBuilder):
331336

332337
async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):
333338

339+
if self.front_end_config.oauth_token_store:
340+
self._oauth_token_store = await builder.get_object_store_client(self.front_end_config.oauth_token_store)
341+
logger.info("Using object-store-backed OAuth token cache: %s", self.front_end_config.oauth_token_store)
342+
334343
session_manager = await self._create_session_manager(builder)
335344

336345
await add_authorization_route(self, app)

packages/nvidia_nat_core/tests/nat/front_ends/auth_flow_handlers/test_websocket_flow_handler.py

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from nat.data_models.config import Config
3131
from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import WebSocketAuthenticationFlowHandler
3232
from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker
33+
from nat.object_store.in_memory_object_store import InMemoryObjectStore
3334
from nat.test.functions import EchoFunctionConfig
3435

3536

@@ -396,11 +397,11 @@ def test_token_cache_key_format(noop_handler, minimal_oauth_config):
396397
assert key == f"sess-1:{minimal_oauth_config.client_id}:{minimal_oauth_config.token_url}"
397398

398399

399-
def test_get_cached_token_miss(noop_handler, minimal_oauth_config):
400+
async def test_get_cached_token_miss(noop_handler, minimal_oauth_config):
    """An empty in-memory token store produces a cache miss (``None``)."""
    handler = noop_handler
    handler._token_store = {}
    handler._session_id = "sess-1"

    cached = await handler._get_cached_token(minimal_oauth_config)

    assert cached is None
404405

405406

406407
@pytest.mark.parametrize("expires_at,expect_hit",
@@ -410,37 +411,135 @@ def test_get_cached_token_miss(noop_handler, minimal_oauth_config):
410411
pytest.param(time.time() - 1, False, id="past"),
411412
pytest.param(time.time() + 30, False, id="within_buffer"),
412413
])
413-
def test_get_cached_token_expiry(noop_handler, minimal_oauth_config, expires_at, expect_hit):
414+
async def test_get_cached_token_expiry(noop_handler, minimal_oauth_config, expires_at, expect_hit):
    """A valid entry is returned as-is; an expired or within-60s entry is evicted and yields ``None``."""
    cached_ctx = AuthenticatedContext(headers={"Authorization": "Bearer tok"}, metadata={})
    backing: dict = {}
    noop_handler._token_store = backing
    noop_handler._session_id = "sess-1"
    cache_key = noop_handler._token_cache_key(minimal_oauth_config)
    backing[cache_key] = (cached_ctx, expires_at)

    outcome = await noop_handler._get_cached_token(minimal_oauth_config)

    if not expect_hit:
        # Miss: nothing returned and the stale entry has been removed.
        assert outcome is None
        assert cache_key not in backing
        return
    # Hit: the very same context object comes back.
    assert outcome is cached_ctx
427428

428429

429-
def test_store_token_writes_correctly(noop_handler, minimal_oauth_config):
430+
async def test_store_token_writes_correctly(noop_handler, minimal_oauth_config):
    """With a dict store, ``_store_token`` records ``(ctx, expires_at)`` under the cache key."""
    backing: dict = {}
    noop_handler._token_store = backing
    noop_handler._session_id = "sess-1"
    expiry = 9999999999.0
    ctx = AuthenticatedContext(headers={"Authorization": "Bearer tok"}, metadata={"expires_at": expiry})

    await noop_handler._store_token(minimal_oauth_config, ctx)

    cache_key = noop_handler._token_cache_key(minimal_oauth_config)
    assert cache_key in backing
    saved_ctx, saved_expiry = backing[cache_key]
    assert saved_ctx is ctx
    assert saved_expiry == expiry
442443

443444

445+
# --------------------------------------------------------------------------- #
446+
# ObjectStore-backed token cache tests #
447+
# --------------------------------------------------------------------------- #
448+
449+
450+
async def test_get_cached_token_miss_object_store(noop_handler, minimal_oauth_config):
    """An empty ObjectStore produces a cache miss (``None``)."""
    handler = noop_handler
    handler._token_store = InMemoryObjectStore()
    handler._session_id = "sess-1"

    cached = await handler._get_cached_token(minimal_oauth_config)

    assert cached is None
455+
456+
457+
async def test_store_and_retrieve_token_object_store(noop_handler, minimal_oauth_config):
    """A token written through ``_store_token`` to an ObjectStore is returned by ``_get_cached_token``."""
    noop_handler._token_store = InMemoryObjectStore()
    noop_handler._session_id = "sess-1"

    # Expiry well outside the 60-second buffer so the entry counts as valid.
    expiry = time.time() + 3600
    token_metadata = {"expires_at": expiry, "raw_token": {"access_token": "obj-tok"}}
    original_ctx = AuthenticatedContext(headers={"Authorization": "Bearer obj-tok"}, metadata=token_metadata)

    await noop_handler._store_token(minimal_oauth_config, original_ctx)
    restored = await noop_handler._get_cached_token(minimal_oauth_config)

    assert restored is not None
    assert restored.headers["Authorization"] == "Bearer obj-tok"
474+
475+
476+
async def test_get_cached_token_evicts_expired_object_store(noop_handler, minimal_oauth_config):
    """_get_cached_token evicts and returns None for an expired entry in an ObjectStore.

    Besides checking the ``None`` result, this verifies that the entry really
    was written beforehand (``_store_token`` only logs on failure, so a silent
    write failure would otherwise look like a successful eviction) and that it
    is gone from the store afterwards.
    """
    # Imported locally so this test does not depend on a module-level import.
    from nat.data_models.object_store import NoSuchKeyError

    store = InMemoryObjectStore()
    noop_handler._token_store = store
    noop_handler._session_id = "sess-1"
    expired_ctx = AuthenticatedContext(
        headers={"Authorization": "Bearer old"},
        metadata={"expires_at": time.time() - 1},
    )
    await noop_handler._store_token(minimal_oauth_config, expired_ctx)

    key = noop_handler._token_cache_key(minimal_oauth_config)
    # Precondition: the expired entry exists (get_object raises if it does not).
    assert (await store.get_object(key)) is not None

    result = await noop_handler._get_cached_token(minimal_oauth_config)
    assert result is None

    # Postcondition: the expired entry was evicted, not merely skipped.
    with pytest.raises(NoSuchKeyError):
        await store.get_object(key)
487+
488+
489+
@pytest.mark.usefixtures("set_nat_config_file_env_var")
async def test_second_authenticate_uses_object_store_cache(monkeypatch, mock_server):
    """A second ``authenticate()`` call is served from the ObjectStore cache without re-running OAuth."""

    port = _free_port()
    mock_server.register_client(client_id="cid", client_secret="secret", redirect_base=f"http://localhost:{port}")

    worker = FastApiFrontEndPluginWorker(Config(workflow=EchoFunctionConfig()))
    # Number of times the handler asked the (dummy) WebSocket to start an OAuth flow.
    flows_started = 0

    class _DummyWSHandler:

        def set_flow_handler(self, _):
            return

        async def create_websocket_message(self, msg):
            nonlocal flows_started
            flows_started += 1
            await _complete_oauth_redirect(msg.text, mock_server, worker._outstanding_flows)

    ws_handler = _AuthHandler(
        oauth_server=mock_server,
        add_flow_cb=worker._add_flow,
        remove_flow_cb=worker._remove_flow,
        web_socket_message_handler=_DummyWSHandler(),
        token_store=InMemoryObjectStore(),
        session_id="test-session",
    )

    cfg_flow = OAuth2AuthCodeFlowProviderConfig(
        client_id="cid",
        client_secret="secret",
        authorization_url="http://testserver/oauth/authorize",
        token_url="http://testserver/oauth/token",
        scopes=["read"],
        use_pkce=True,
        redirect_uri=f"http://localhost:{port}/auth/redirect",
    )

    monkeypatch.setattr("click.echo", lambda *_: None, raising=True)

    first_ctx = await ws_handler.authenticate(cfg_flow, AuthFlowType.OAUTH2_AUTHORIZATION_CODE)
    assert flows_started == 1, "OAuth flow should have run exactly once"

    second_ctx = await ws_handler.authenticate(cfg_flow, AuthFlowType.OAUTH2_AUTHORIZATION_CODE)
    assert flows_started == 1, "Second authenticate() must return from object store cache without triggering OAuth"
    assert second_ctx.headers["Authorization"] == first_ctx.headers["Authorization"]
541+
542+
444543
# --------------------------------------------------------------------------- #
445544
# Token-cache integration: second authenticate() returns from cache #
446545
# --------------------------------------------------------------------------- #

0 commit comments

Comments
 (0)