Skip to content

Commit 23af716

Browse files
feat: add pre_authenticate option for authorization providers
- This makes pre-authentication an opt-in feature - When True, authentication for the provider is triggered at WebSocket connection time before any user message is submitted. - When False (default), authentication only occurs when the workflow explicitly requires it. Signed-off-by: Patrick Chin <8509935+thepatrickchin@users.noreply.github.com>
1 parent 9f99cdb commit 23af716

File tree

4 files changed

+19
-3
lines changed

4 files changed

+19
-3
lines changed

examples/front_ends/simple_auth/src/nat_simple_auth/configs/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ authentication:
5656
- email
5757
client_id: ${NAT_OAUTH_CLIENT_ID}
5858
client_secret: ${NAT_OAUTH_CLIENT_SECRET}
59+
pre_authenticate: false
5960
use_pkce: false
6061
use_redirect_auth: false
6162

packages/nvidia_nat_core/src/nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ class OAuth2AuthCodeFlowProviderConfig(AuthProviderBaseConfig, name="oauth2_auth
4141
"remains open. When True, the browser navigates to the OAuth login page directly and is "
4242
"redirected back after authentication completes."))
4343

44+
pre_authenticate: bool = Field(
45+
default=False,
46+
description=("When True, authentication is triggered at WebSocket connection time before the user submits "
47+
"their first prompt. When False (default), authentication is deferred until the workflow first "
48+
"requires credentials. Only applies in nat serve mode."))
49+
4450
authorization_kwargs: dict[str, str] | None = Field(description=("Additional keyword arguments for the "
4551
"authorization request."),
4652
default=None)

packages/nvidia_nat_core/src/nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,12 @@ async def authenticate(
7878
async def pre_authenticate(self, auth_providers: dict[str, AuthProviderBaseConfig]) -> None:
7979
"""Run auth for every configured OAuth2 provider before the first user message.
8080
81+
Only providers with pre_authenticate option set in their config are processed.
8182
Returns immediately if tokens are already cached. Otherwise triggers the OAuth
8283
redirect so the user authenticates at page load rather than mid-workflow.
8384
"""
8485
for provider_config in auth_providers.values():
85-
if isinstance(provider_config, OAuth2AuthCodeFlowProviderConfig):
86+
if isinstance(provider_config, OAuth2AuthCodeFlowProviderConfig) and provider_config.pre_authenticate:
8687
await self.authenticate(provider_config, AuthFlowType.OAUTH2_AUTHORIZATION_CODE)
8788

8889
def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:

packages/nvidia_nat_core/tests/nat/front_ends/auth_flow_handlers/test_websocket_flow_handler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,12 @@ async def test_pre_authenticate_skips_non_oauth2_providers(noop_handler):
508508
await noop_handler.pre_authenticate({"my_api_key": api_key_config})
509509

510510

511+
async def test_pre_authenticate_skips_oauth2_provider_flag_false(noop_handler, minimal_oauth_config):
512+
"""pre_authenticate does not trigger auth for OAuth2 providers with pre_authenticate=False (the default)."""
513+
# minimal_oauth_config has pre_authenticate=False (the default); if the guard were absent this would hang
514+
await noop_handler.pre_authenticate({"my_provider": minimal_oauth_config})
515+
516+
511517
async def test_pre_authenticate_uses_cached_token(minimal_oauth_config):
512518
"""pre_authenticate returns immediately without calling create_websocket_message on a cache hit."""
513519

@@ -526,15 +532,17 @@ async def create_websocket_message(self, msg):
526532

527533
ctx = AuthenticatedContext(headers={"Authorization": "Bearer cached-tok"}, metadata={"expires_at": None})
528534
store: dict = {}
535+
# Enable pre_authenticate so the cache lookup is actually reached
536+
active_config = minimal_oauth_config.model_copy(update={"pre_authenticate": True})
529537
handler = WebSocketAuthenticationFlowHandler(
530538
add_flow_cb=_noop_add,
531539
remove_flow_cb=_noop_remove,
532540
web_socket_message_handler=_CountingWSHandler(),
533541
token_store=store,
534542
session_id="sess-1",
535543
)
536-
key = handler._token_cache_key(minimal_oauth_config)
544+
key = handler._token_cache_key(active_config)
537545
store[key] = (ctx, time.time() + 3600)
538546

539-
await handler.pre_authenticate({"my_provider": minimal_oauth_config})
547+
await handler.pre_authenticate({"my_provider": active_config})
540548
assert message_count[0] == 0, "pre_authenticate must not trigger OAuth when token is cached"

0 commit comments

Comments
 (0)