|
1 | 1 | """ |
2 | 2 | Authentication utilities for Databricks Apps deployment. |
3 | 3 |
|
4 | | -Uses service principal authentication when running on Databricks Apps, |
5 | | -and falls back to PAT token or CLI authentication for local development. |
| 4 | +On Databricks Apps, uses OBO (On Behalf Of) — each request creates a |
| 5 | +WorkspaceClient with the user's forwarded token so all SDK calls (SQL, |
| 6 | +UC, serving endpoints) execute under the user's identity and permissions. |
| 7 | +
|
| 8 | +Locally, falls back to PAT token or CLI profile (singleton client). |
6 | 9 | """ |
7 | 10 |
|
8 | 11 | import logging |
9 | 12 | import os |
| 13 | +from contextvars import ContextVar |
10 | 14 |
|
11 | 15 | from databricks.sdk import WorkspaceClient |
| 16 | +from databricks.sdk.config import Config |
12 | 17 |
|
13 | 18 | logger = logging.getLogger(__name__) |
14 | 19 |
|
15 | | -# Singleton client — avoids re-reading ~/.databrickscfg on every call |
| 20 | +# Singleton client for local dev (or fallback when no user token is available) |
16 | 21 | _client: WorkspaceClient | None = None |
17 | 22 | _auth_logged = False |
18 | 23 |
|
| 24 | +# Per-request OBO client stored in a context variable |
| 25 | +_obo_client: ContextVar[WorkspaceClient | None] = ContextVar("_obo_client", default=None) |
| 26 | + |
19 | 27 |
|
20 | 28 | def is_running_on_databricks_apps() -> bool: |
21 | 29 | """Check if running on Databricks Apps (vs local development).""" |
22 | 30 | return os.environ.get("DATABRICKS_APP_PORT") is not None |
23 | 31 |
|
24 | 32 |
|
25 | | -def get_workspace_client() -> WorkspaceClient: |
26 | | - """Get a cached Databricks WorkspaceClient with appropriate authentication. |
| 33 | +def set_obo_user_token(token: str) -> None: |
| 34 | + """Set the user's OBO token for the current request context. |
27 | 35 |
|
28 | | - The client is created once and reused for the lifetime of the process. |
29 | | - On Databricks Apps it uses the service principal; locally it uses |
30 | | - PAT token or CLI profile. |
| 36 | + Call this from middleware/dependencies with the user's Authorization |
| 37 | + header value. Creates a per-request WorkspaceClient that authenticates |
| 38 | + as the user. |
| 39 | +
|
| 40 | + We must explicitly set ``auth_type="pat"`` because the Databricks Apps |
| 41 | + environment has DATABRICKS_CLIENT_ID / DATABRICKS_CLIENT_SECRET set, |
| 42 | + and the SDK would otherwise use oauth-m2m instead of the user's token. |
31 | 43 | """ |
| 44 | + host = os.environ.get("DATABRICKS_HOST", "") |
| 45 | + if not host: |
| 46 | + default = _get_default_client() |
| 47 | + host = default.config.host or "" |
| 48 | + |
| 49 | + cfg = Config( |
| 50 | + host=host, |
| 51 | + token=token, |
| 52 | + auth_type="pat", |
| 53 | + # Prevent the SDK from reading env vars that would override the token |
| 54 | + client_id=None, |
| 55 | + client_secret=None, # gitleaks:allow |
| 56 | + ) |
| 57 | + client = WorkspaceClient(config=cfg) |
| 58 | + _obo_client.set(client) |
| 59 | + logger.debug("OBO client set for current request (host=%s, auth=%s)", host, cfg.auth_type) |
| 60 | + |
| 61 | + |
| 62 | +def clear_obo_user_token() -> None: |
| 63 | + """Clear the per-request OBO client after the request completes.""" |
| 64 | + _obo_client.set(None) |
| 65 | + |
| 66 | + |
| 67 | +def _get_default_client() -> WorkspaceClient: |
| 68 | + """Get the default singleton client (SP on Apps, CLI/PAT locally).""" |
32 | 69 | global _client, _auth_logged |
33 | 70 |
|
34 | 71 | if _client is None: |
@@ -61,15 +98,27 @@ def get_workspace_client() -> WorkspaceClient: |
61 | 98 | return _client |
62 | 99 |
|
63 | 100 |
|
| 101 | +def get_workspace_client() -> WorkspaceClient: |
| 102 | + """Get the WorkspaceClient for the current context. |
| 103 | +
|
| 104 | + Returns the OBO (per-user) client if set, otherwise the default |
| 105 | + singleton. This ensures all SDK calls in the request path use the |
| 106 | + user's credentials when running on Databricks Apps. |
| 107 | + """ |
| 108 | + obo = _obo_client.get() |
| 109 | + if obo is not None: |
| 110 | + return obo |
| 111 | + return _get_default_client() |
| 112 | + |
| 113 | + |
64 | 114 | def get_databricks_host() -> str: |
65 | 115 | """Get the Databricks workspace host URL (without trailing slash).""" |
66 | | - client = get_workspace_client() |
| 116 | + client = _get_default_client() |
67 | 117 | host = client.config.host |
68 | 118 | return host.rstrip("/") if host else "" |
69 | 119 |
|
70 | 120 |
|
71 | 121 | def get_llm_api_key() -> str: |
72 | 122 | """Get the API key for LLM serving endpoints.""" |
73 | | - if is_running_on_databricks_apps(): |
74 | | - return get_workspace_client().config.token or "" |
75 | | - return os.environ.get("DATABRICKS_TOKEN", "") |
| 123 | + client = get_workspace_client() |
| 124 | + return client.config.token or os.environ.get("DATABRICKS_TOKEN", "") |
0 commit comments