Skip to content

Commit 8052dcf

Browse files
committed
feat(a365): pluggable token extractor and token cache for telemetry
- Add TokenExtractor protocol and registry for A365 auth - Add token cache with expiry for proactive refresh - Update telemetry register and exporter for deferred token resolution Signed-off-by: afourniernv <afournier@nvidia.com>
1 parent 3e5c46f commit 8052dcf

File tree

4 files changed

+150
-499
lines changed

4 files changed

+150
-499
lines changed

packages/nvidia_nat_a365/src/nat/plugins/a365/telemetry/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@
1818

1919
# Import register module to ensure registration happens
2020
from . import register # noqa: F401
21+
from .register import TokenExtractor, register_token_extractor
22+
23+
__all__ = ["TokenExtractor", "register_token_extractor"]

packages/nvidia_nat_a365/src/nat/plugins/a365/telemetry/a365_exporter.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525

2626
from nat.builder.context import ContextState
2727
from nat.plugins.a365.exceptions import A365AuthenticationError, A365SDKError
28+
from nat.plugins.a365.telemetry.register import (
29+
_get_token_extractor,
30+
_raise_no_bearer_token,
31+
)
2832
from nat.plugins.opentelemetry.otel_span import OtelSpan
2933
from nat.plugins.opentelemetry.otel_span_exporter import OtelSpanExporter
3034
from opentelemetry.sdk.trace import Event as OtelEvent
@@ -157,8 +161,12 @@ def __init__(
157161
token_cache=None,
158162
auth_ref=None,
159163
builder=None,
164+
token_extractor=None,
160165
):
161166
"""Initialize the A365 exporter."""
167+
self._token_extractor = (
168+
token_extractor if token_extractor is not None else _get_token_extractor(None)
169+
)
162170
super().__init__(
163171
context_state=context_state,
164172
batch_size=batch_size,
@@ -213,22 +221,9 @@ async def _resolve_auth_once(self) -> None:
213221
auth_result = await auth_provider.authenticate(user_id=user_id)
214222
if not auth_result.credentials:
215223
raise A365AuthenticationError("No credentials available from auth provider")
216-
from nat.data_models.authentication import BearerTokenCred, HeaderCred
217-
from nat.authentication.interfaces import AUTHORIZATION_HEADER
218-
token = None
219-
for cred in auth_result.credentials:
220-
if isinstance(cred, BearerTokenCred):
221-
token = cred.token.get_secret_value()
222-
break
223-
if isinstance(cred, HeaderCred) and cred.name == AUTHORIZATION_HEADER:
224-
hv = cred.value.get_secret_value()
225-
token = hv[7:] if hv.startswith("Bearer ") else hv
226-
break
224+
token = self._token_extractor(auth_result)
227225
if token is None:
228-
raise A365AuthenticationError(
229-
f"No bearer token in credentials. "
230-
f"Types: {[type(c).__name__ for c in auth_result.credentials]}"
231-
)
226+
_raise_no_bearer_token(auth_result)
232227
self._token_cache.update_token(token, auth_result.token_expires_at)
233228
self._auth_provider = auth_provider
234229
except Exception as e:
@@ -262,22 +257,7 @@ async def _refresh_token_if_needed(self) -> None:
262257
logger.warning("Token refresh failed: no credentials available")
263258
return
264259

265-
from nat.data_models.authentication import BearerTokenCred, HeaderCred
266-
from nat.authentication.interfaces import AUTHORIZATION_HEADER
267-
268-
token: str | None = None
269-
for cred in auth_result.credentials:
270-
if isinstance(cred, BearerTokenCred):
271-
token = cred.token.get_secret_value()
272-
break
273-
elif isinstance(cred, HeaderCred) and cred.name == AUTHORIZATION_HEADER:
274-
header_value = cred.value.get_secret_value()
275-
if header_value.startswith("Bearer "):
276-
token = header_value[7:] # Remove "Bearer " prefix
277-
else:
278-
token = header_value
279-
break
280-
260+
token = self._token_extractor(auth_result)
281261
if token is None:
282262
logger.warning(
283263
f"No bearer token found in refreshed credentials. "

packages/nvidia_nat_a365/src/nat/plugins/a365/telemetry/register.py

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
import logging
1818
from collections.abc import Callable
1919
from datetime import datetime, timedelta, timezone
20+
from typing import TYPE_CHECKING, Protocol
21+
22+
if TYPE_CHECKING:
23+
from nat.data_models.authentication import AuthResult
2024

2125
from pydantic import Field
2226

@@ -25,39 +29,84 @@
2529
from nat.data_models.component_ref import AuthenticationRef
2630
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
2731
from nat.observability.mixin.batch_config_mixin import BatchConfigMixin
28-
from nat.plugins.a365.exceptions import A365AuthenticationError
32+
from nat.plugins.a365.exceptions import A365AuthenticationError, A365ConfigurationError
2933

3034
logger = logging.getLogger(__name__)
3135

36+
# --- Pluggable token extractor (interface + dependency injection) ---
3237

33-
def _extract_token_from_auth_result(auth_result) -> str:
34-
"""Extract bearer token from AuthResult credentials.
38+
_TOKEN_EXTRACTOR_SUPPORTED = (
39+
"BearerTokenCred or HeaderCred(Authorization)"
40+
)
3541

36-
Args:
37-
auth_result: AuthResult from auth provider
3842

39-
Returns:
40-
Bearer token string
43+
class TokenExtractor(Protocol):
44+
"""Callable that extracts a bearer token from NAT's AuthResult.
4145
42-
Raises:
43-
A365AuthenticationError: If no bearer token found in credentials
46+
Used when the default (BearerTokenCred or HeaderCred(Authorization)) does not
47+
match your auth provider's credential shape. Register a custom extractor with
48+
register_token_extractor(name, callable) and set token_extractor=name in config.
49+
"""
50+
51+
def __call__(self, auth_result: "AuthResult") -> str | None: ...
52+
53+
54+
def _default_token_extractor(auth_result: "AuthResult") -> str | None:
55+
"""Default extractor: BearerTokenCred or HeaderCred(Authorization).
56+
57+
Returns the bearer token string, or None if neither credential type is present.
58+
Caller should raise A365AuthenticationError with a clear message when None.
4459
"""
4560
from nat.data_models.authentication import BearerTokenCred, HeaderCred
4661
from nat.authentication.interfaces import AUTHORIZATION_HEADER
4762

4863
for cred in auth_result.credentials:
4964
if isinstance(cred, BearerTokenCred):
5065
return cred.token.get_secret_value()
51-
elif isinstance(cred, HeaderCred) and cred.name == AUTHORIZATION_HEADER:
52-
header_value = cred.value.get_secret_value()
53-
# Strip "Bearer " prefix if present
54-
if header_value.startswith("Bearer "):
55-
return header_value[7:] # Remove "Bearer " prefix
56-
return header_value
66+
if isinstance(cred, HeaderCred) and cred.name == AUTHORIZATION_HEADER:
67+
raw = cred.value.get_secret_value()
68+
return raw[7:] if raw.startswith("Bearer ") else raw
69+
return None
70+
71+
72+
_TOKEN_EXTRACTOR_REGISTRY: dict[str, Callable[["AuthResult"], str | None]] = {
73+
"default": _default_token_extractor,
74+
}
75+
5776

77+
def register_token_extractor(name: str, extractor: Callable[["AuthResult"], str | None]) -> None:
78+
"""Register a custom token extractor for A365 telemetry.
79+
80+
Use when your auth provider returns credentials in a shape the default extractor
81+
does not understand (e.g. a new NAT credential type). Then set
82+
token_extractor=\"name\" in your a365 telemetry exporter config.
83+
84+
Args:
85+
name: Name to use in config (e.g. \"my_provider\").
86+
extractor: Callable (AuthResult) -> str | None. Return the bearer token or None.
87+
"""
88+
_TOKEN_EXTRACTOR_REGISTRY[name] = extractor
89+
90+
91+
def _get_token_extractor(name: str | None) -> Callable[["AuthResult"], str | None]:
92+
if name is None or name == "default":
93+
return _default_token_extractor
94+
if name not in _TOKEN_EXTRACTOR_REGISTRY:
95+
raise A365ConfigurationError(
96+
f"Unknown token_extractor '{name}'. "
97+
f"Registered: {sorted(_TOKEN_EXTRACTOR_REGISTRY.keys())}. "
98+
f"Use register_token_extractor(name, callable) to add custom extractors."
99+
)
100+
return _TOKEN_EXTRACTOR_REGISTRY[name]
101+
102+
103+
def _raise_no_bearer_token(auth_result: "AuthResult") -> None:
104+
"""Raise A365AuthenticationError with a clear message when no token could be extracted."""
105+
found = [type(c).__name__ for c in auth_result.credentials]
58106
raise A365AuthenticationError(
59-
f"No bearer token found in auth provider credentials. "
60-
f"Found credential types: {[type(c).__name__ for c in auth_result.credentials]}"
107+
f"No bearer token from auth provider. "
108+
f"Supported (default): {_TOKEN_EXTRACTOR_SUPPORTED}. "
109+
f"Found credential types: {found}"
61110
)
62111

63112

@@ -163,7 +212,10 @@ async def _create_token_resolver_from_auth_ref(
163212
if not auth_result.credentials:
164213
raise A365AuthenticationError("No credentials available from auth provider")
165214

166-
token = _extract_token_from_auth_result(auth_result)
215+
extractor = _get_token_extractor(None)
216+
token = extractor(auth_result)
217+
if token is None:
218+
_raise_no_bearer_token(auth_result)
167219
expires_at = auth_result.token_expires_at
168220

169221
token_cache = _TokenCache(token, expires_at)
@@ -180,13 +232,23 @@ def token_resolver(agent_id: str, tenant_id: str) -> str | None:
180232

181233

182234
class A365TelemetryExporter(BatchConfigMixin, TelemetryExporterBaseConfig, name="a365"):
183-
"""A telemetry exporter to transmit traces to Microsoft Agent 365 backend."""
235+
"""A telemetry exporter to transmit traces to Microsoft Agent 365 backend.
236+
237+
Auth: the referenced auth provider should return a bearer token via
238+
BearerTokenCred or HeaderCred(Authorization). For other credential shapes,
239+
register a custom token extractor with register_token_extractor(name, callable)
240+
and set token_extractor=name.
241+
"""
184242

185243
agent_id: str = Field(description="The Agent 365 agent ID")
186244
tenant_id: str = Field(description="The Azure tenant ID")
187245
token_resolver: AuthenticationRef = Field(
188246
description="Reference to NAT auth provider for token resolution (e.g., 'a365_auth')"
189247
)
248+
token_extractor: str | None = Field(
249+
default=None,
250+
description="Optional name of a registered token extractor. Default uses BearerTokenCred or HeaderCred(Authorization)."
251+
)
190252
cluster_category: str = Field(
191253
default="prod",
192254
description="Cluster category/environment (e.g., 'prod', 'dev')"
@@ -215,6 +277,8 @@ async def a365_telemetry_exporter(config: A365TelemetryExporter, builder: Builde
215277
"""
216278
from nat.plugins.a365.telemetry.a365_exporter import A365OtelExporter
217279

280+
token_extractor_fn = _get_token_extractor(config.token_extractor)
281+
218282
# Defer auth: do not call get_auth_provider here (not available yet in __aenter__).
219283
token_cache = _TokenCache(None, None)
220284

@@ -236,6 +300,7 @@ def token_resolver(agent_id: str, tenant_id: str) -> str | None:
236300
token_cache=token_cache,
237301
auth_ref=config.token_resolver,
238302
builder=builder,
303+
token_extractor=token_extractor_fn,
239304
cluster_category=config.cluster_category,
240305
use_s2s_endpoint=config.use_s2s_endpoint,
241306
suppress_invoke_agent_input=config.suppress_invoke_agent_input,

0 commit comments

Comments
 (0)